diff --git a/test/test_io.py b/test/test_io.py index 6bd5703b7e6..a27418bf5d5 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -52,12 +52,12 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, yield f.name, data +@unittest.skipIf(av is None, "PyAV unavailable") class Tester(unittest.TestCase): # compression adds artifacts, thus we add a tolerance of # 6 in 0-255 range TOLERANCE = 6 - @unittest.skipIf(av is None, "PyAV unavailable") def test_write_read_video(self): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): lv, _, info = io.read_video(f_name) @@ -65,7 +65,6 @@ def test_write_read_video(self): self.assertTrue(data.equal(lv)) self.assertEqual(info["video_fps"], 5) - @unittest.skipIf(av is None, "PyAV unavailable") def test_read_timestamps(self): with temp_video(10, 300, 300, 5) as (f_name, data): pts, _ = io.read_video_timestamps(f_name) @@ -81,7 +80,6 @@ def test_read_timestamps(self): self.assertEqual(pts, expected_pts) - @unittest.skipIf(av is None, "PyAV unavailable") def test_read_partial_video(self): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): pts, _ = io.read_video_timestamps(f_name) @@ -96,7 +94,6 @@ def test_read_partial_video(self): self.assertEqual(len(lv), 4) self.assertTrue(data[4:8].equal(lv)) - @unittest.skipIf(av is None, "PyAV unavailable") def test_read_partial_video_bframes(self): # do not use lossless encoding, to test the presence of B-frames options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'} @@ -113,7 +110,6 @@ def test_read_partial_video_bframes(self): self.assertEqual(len(lv), 4) self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) - @unittest.skipIf(av is None, "PyAV unavailable") def test_read_packed_b_frames_divx_file(self): with get_tmp_dir() as temp_dir: name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi" @@ -129,6 +125,23 @@ def test_read_packed_b_frames_divx_file(self): warnings.warn(msg, RuntimeWarning) raise unittest.SkipTest(msg) + def test_read_timestamps_from_packet(self): + with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data): + pts, _ = io.read_video_timestamps(f_name) + + # note: not all formats/codecs provide accurate information for computing the + # timestamps. For the format that we use here, this information is available, + # so we use it as a baseline + container = av.open(f_name) + stream = container.streams[0] + # make sure we went through the optimized codepath + self.assertIn(b'Lavc', stream.codec_context.extradata) + pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) + num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) + expected_pts = [i * pts_step for i in range(num_frames)] + + self.assertEqual(pts, expected_pts) + # TODO add tests for audio diff --git a/torchvision/io/video.py b/torchvision/io/video.py index e738f00421d..383f539e9f6 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -185,6 +185,15 @@ def read_video(filename, start_pts=0, end_pts=None): return vframes, aframes, info +def _can_read_timestamps_from_packets(container): + extradata = container.streams[0].codec_context.extradata + if extradata is None: + return False + if b"Lavc" in extradata: + return True + return False + + def read_video_timestamps(filename): """ List the video frames timestamps. @@ -205,8 +214,12 @@ def read_video_timestamps(filename): video_frames = [] video_fps = None if container.streams.video: - video_frames = _read_from_stream(container, 0, float("inf"), - container.streams.video[0], {'video': 0}) + if _can_read_timestamps_from_packets(container): + # fast path + video_frames = [x for x in container.demux(video=0) if x.pts is not None] + else: + video_frames = _read_from_stream(container, 0, float("inf"), + container.streams.video[0], {'video': 0}) video_fps = float(container.streams.video[0].average_rate) container.close() return [x.pts for x in video_frames], video_fps