From 1375e90b4329bbccf09f0fdc032aee9da82d5de5 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 7 May 2021 13:37:26 +0100 Subject: [PATCH 1/3] Fixed audio-video synchronisation problem in read_video() when using pts as unit --- torchvision/io/_video_opt.py | 38 ++++++++++++++++++++++++++---------- torchvision/io/video.py | 8 ++++++++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 4cc2b60c706..0783ec44dc4 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -471,6 +471,14 @@ def _probe_video_from_memory(video_data): return info +def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): + if pts_unit == 'pts': + start_pts = float(start_pts * time_base) + end_pts = float(end_pts * time_base) + pts_unit = 'sec' + return start_pts, end_pts, pts_unit + + def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): if end_pts is None: end_pts = float("inf") @@ -485,6 +493,26 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): has_video = info.has_video has_audio = info.has_audio + video_pts_range = (0, -1) + video_timebase = default_timebase + audio_pts_range = (0, -1) + audio_timebase = default_timebase + time_base = None + + if has_video: + video_timebase = Fraction( + info.video_timebase.numerator, info.video_timebase.denominator + ) + time_base = video_timebase + + if has_audio: + audio_timebase = Fraction( + info.audio_timebase.numerator, info.audio_timebase.denominator + ) + time_base = time_base if time_base else audio_timebase + + start_pts, end_pts, pts_unit = _convert_to_sec( + start_pts, end_pts, pts_unit, time_base) def get_pts(time_base): start_offset = start_pts @@ -497,20 +525,10 @@ def get_pts(time_base): end_offset = -1 return start_offset, end_offset - video_pts_range = (0, -1) - video_timebase = default_timebase if has_video: - video_timebase = Fraction( - info.video_timebase.numerator, 
info.video_timebase.denominator - ) video_pts_range = get_pts(video_timebase) - audio_pts_range = (0, -1) - audio_timebase = default_timebase if has_audio: - audio_timebase = Fraction( - info.audio_timebase.numerator, info.audio_timebase.denominator - ) audio_pts_range = get_pts(audio_timebase) vframes, aframes, info = _read_video_from_file( diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 22cad38d10b..510c4fe60e2 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -278,6 +278,14 @@ def read_video( try: with av.open(filename, metadata_errors="ignore") as container: + time_base = None + if container.streams.video: + time_base = container.streams.video[0].time_base + elif container.streams.audio: + time_base = container.streams.audio[0].time_base + if pts_unit == 'pts': + start_pts, end_pts, pts_unit = _video_opt._convert_to_sec( + start_pts, end_pts, pts_unit, time_base) if container.streams.video: video_frames = _read_from_stream( container, From 1a41df3ad803a465efcb14f924de4cc36b216179 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 21 May 2021 13:11:49 +0100 Subject: [PATCH 2/3] Addressed review comments --- torchvision/io/_video_opt.py | 13 +++++++------ torchvision/io/video.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 0783ec44dc4..a34b023bc6c 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -497,7 +497,7 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): video_timebase = default_timebase audio_pts_range = (0, -1) audio_timebase = default_timebase - time_base = None + time_base = default_timebase if has_video: video_timebase = Fraction( @@ -511,16 +511,17 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): ) time_base = time_base if time_base else audio_timebase - start_pts, end_pts, pts_unit = _convert_to_sec( + # video_timebase is the 
default time_base + start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec( start_pts, end_pts, pts_unit, time_base) def get_pts(time_base): - start_offset = start_pts - end_offset = end_pts + start_offset = start_pts_sec + end_offset = end_pts_sec if pts_unit == "sec": - start_offset = int(math.floor(start_pts * (1 / time_base))) + start_offset = int(math.floor(start_pts_sec * (1 / time_base))) if end_offset != float("inf"): - end_offset = int(math.ceil(end_pts * (1 / time_base))) + end_offset = int(math.ceil(end_pts_sec * (1 / time_base))) if end_offset == float("inf"): end_offset = -1 return start_offset, end_offset diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 510c4fe60e2..e16e8906d97 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -278,19 +278,19 @@ def read_video( try: with av.open(filename, metadata_errors="ignore") as container: - time_base = None + time_base = _video_opt.default_timebase if container.streams.video: time_base = container.streams.video[0].time_base elif container.streams.audio: time_base = container.streams.audio[0].time_base - if pts_unit == 'pts': - start_pts, end_pts, pts_unit = _video_opt._convert_to_sec( - start_pts, end_pts, pts_unit, time_base) + # video_timebase is the default time_base + start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec( + start_pts, end_pts, pts_unit, time_base) if container.streams.video: video_frames = _read_from_stream( container, - start_pts, - end_pts, + start_pts_sec, + end_pts_sec, pts_unit, container.streams.video[0], {"video": 0}, @@ -303,8 +303,8 @@ def read_video( if container.streams.audio: audio_frames = _read_from_stream( container, - start_pts, - end_pts, + start_pts_sec, + end_pts_sec, pts_unit, container.streams.audio[0], {"audio": 0}, From d8b131712e01ac4c696efdf387d0510a0939fb5c Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 21 May 2021 14:36:24 +0100 Subject: [PATCH 3/3] Added unit test --- test/test_video_reader.py | 39 
+++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/test_video_reader.py b/test/test_video_reader.py index 5b9b2184daf..b1db3cfb0ed 100644 --- a/test/test_video_reader.py +++ b/test/test_video_reader.py @@ -1241,6 +1241,45 @@ def test_read_video_from_memory_scripted(self): ) # FUTURE: check value of video / audio frames + def test_audio_video_sync(self): + """Test if audio/video are synchronised with pyav output.""" + for test_video, config in test_videos.items(): + full_path = os.path.join(VIDEO_DIR, test_video) + container = av.open(full_path) + if not container.streams.audio: + # Skip if no audio stream + continue + start_pts_val, cutoff = 0, 1 + if container.streams.video: + video = container.streams.video[0] + arr = [] + for index, frame in enumerate(container.decode(video)): + if index == cutoff: + start_pts_val = frame.pts + if index >= cutoff: + arr.append(frame.to_rgb().to_ndarray()) + visual, _, info = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts') + self.assertAlmostEqual( + config.video_fps, info['video_fps'], delta=0.0001 + ) + arr = torch.Tensor(arr) + if arr.shape == visual.shape: + self.assertGreaterEqual( + torch.mean(torch.isclose(visual.float(), arr, atol=1e-5).float()), 0.99) + + container = av.open(full_path) + if container.streams.audio: + audio = container.streams.audio[0] + arr = [] + for index, frame in enumerate(container.decode(audio)): + if index >= cutoff: + arr.append(frame.to_ndarray()) + _, audio, _ = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts') + arr = torch.as_tensor(np.concatenate(arr, axis=1)) + if arr.shape == audio.shape: + self.assertGreaterEqual( + torch.mean(torch.isclose(audio.float(), arr).float()), 0.99) + if __name__ == "__main__": unittest.main()