Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,19 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
yield f.name, data


@unittest.skipIf(av is None, "PyAV unavailable")
class Tester(unittest.TestCase):
# compression adds artifacts, thus we add a tolerance of
# 6 in 0-255 range
TOLERANCE = 6

@unittest.skipIf(av is None, "PyAV unavailable")
def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name)

self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)

@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
Expand All @@ -81,7 +80,6 @@ def test_read_timestamps(self):

self.assertEqual(pts, expected_pts)

@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
Expand All @@ -96,7 +94,6 @@ def test_read_partial_video(self):
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
Expand All @@ -113,7 +110,6 @@ def test_read_partial_video_bframes(self):
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_packed_b_frames_divx_file(self):
with get_tmp_dir() as temp_dir:
name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
Expand All @@ -129,6 +125,23 @@ def test_read_packed_b_frames_divx_file(self):
warnings.warn(msg, RuntimeWarning)
raise unittest.SkipTest(msg)

def test_read_timestamps_from_packet(self):
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)

# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
container = av.open(f_name)
stream = container.streams[0]
# make sure we went through the optimized codepath
self.assertIn(b'Lavc', stream.codec_context.extradata)
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]

self.assertEqual(pts, expected_pts)

# TODO add tests for audio


Expand Down
17 changes: 15 additions & 2 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ def read_video(filename, start_pts=0, end_pts=None):
return vframes, aframes, info


def _can_read_timestamps_from_packets(container):
extradata = container.streams[0].codec_context.extradata
if extradata is None:
return False
if b"Lavc" in extradata:
return True
return False


def read_video_timestamps(filename):
"""
List the video frames timestamps.
Expand All @@ -205,8 +214,12 @@ def read_video_timestamps(filename):
video_frames = []
video_fps = None
if container.streams.video:
video_frames = _read_from_stream(container, 0, float("inf"),
container.streams.video[0], {'video': 0})
if _can_read_timestamps_from_packets(container):
# fast path
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
else:
video_frames = _read_from_stream(container, 0, float("inf"),
container.streams.video[0], {'video': 0})
video_fps = float(container.streams.video[0].average_rate)
container.close()
return [x.pts for x in video_frames], video_fps