From b8d2895408416565f1b9faa32fef65c355561e10 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Sat, 26 Feb 2022 14:40:32 -0800 Subject: [PATCH 1/2] Fix shape mismatch error --- torchvision/io/_video_opt.py | 40 +++++++++--------------------------- torchvision/io/video.py | 7 ------- 2 files changed, 10 insertions(+), 37 deletions(-) diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index de4b25bb7b5..87b9f229aca 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -427,16 +427,6 @@ def _probe_video_from_memory( return info -def _convert_to_sec( - start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction -) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]: - if pts_unit == "pts": - start_pts = float(start_pts * time_base) - end_pts = float(end_pts * time_base) - pts_unit = "sec" - return start_pts, end_pts, pts_unit - - def _read_video( filename: str, start_pts: Union[float, Fraction] = 0, @@ -456,38 +446,28 @@ def _read_video( has_video = info.has_video has_audio = info.has_audio - video_pts_range = (0, -1) - video_timebase = default_timebase - audio_pts_range = (0, -1) - audio_timebase = default_timebase - time_base = default_timebase - - if has_video: - video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) - time_base = video_timebase - - if has_audio: - audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) - time_base = time_base if time_base else audio_timebase - - # video_timebase is the default time_base - start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base) def get_pts(time_base): - start_offset = start_pts_sec - end_offset = end_pts_sec + start_offset = start_pts + end_offset = end_pts if pts_unit == "sec": - start_offset = int(math.floor(start_pts_sec * (1 / time_base))) + start_offset = int(math.floor(start_pts * (1 / time_base))) if end_offset != float("inf"): - end_offset = int(math.ceil(end_pts_sec * (1 / time_base))) + end_offset = int(math.ceil(end_pts * (1 / time_base))) if end_offset == float("inf"): end_offset = -1 return start_offset, end_offset + video_pts_range = (0, -1) + video_timebase = default_timebase if has_video: + video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) video_pts_range = get_pts(video_timebase) + audio_pts_range = (0, -1) + audio_timebase = default_timebase if has_audio: + audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) audio_pts_range = get_pts(audio_timebase) vframes, aframes, info = _read_video_from_file( diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 479fdfc1ddf..d026e754546 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -287,13 +287,6 @@ def read_video( with av.open(filename, metadata_errors="ignore") as container: if container.streams.audio: audio_timebase = container.streams.audio[0].time_base - time_base = _video_opt.default_timebase - if container.streams.video: - time_base = container.streams.video[0].time_base - elif container.streams.audio: - time_base = container.streams.audio[0].time_base - # video_timebase is the default time_base - start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base) if container.streams.video: video_frames = _read_from_stream( container, From ac2ced213b0fe665e4764825fdec074244e5f2f2 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Mon, 28 Feb 2022 04:33:06 -0800 Subject: [PATCH 2/2] Update offset --- test/test_video_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_video_reader.py b/test/test_video_reader.py index 73c4d8a1b85..075a3902a1b 100644 --- a/test/test_video_reader.py +++ b/test/test_video_reader.py @@ -1226,7 +1226,7 @@ def test_invalid_file(self): def test_audio_present_pts(self): """Test if audio frames are returned with pts unit.""" backends = ["video_reader", "pyav"] - start_offsets = [0, 1000] + start_offsets = [0, 500] end_offsets = [3000, None] for test_video, _ in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)