From 2a3bf2bdd49dbb84c048da0ee8f263c0610f6504 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 24 Jul 2019 10:59:30 -0700 Subject: [PATCH 1/5] Miscellaneous fixes and improvements --- torchvision/io/video.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 726cce8cdd2..ec964ac7f7f 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -1,3 +1,4 @@ +import re import gc import torch import numpy as np @@ -68,18 +69,34 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): should_buffer = False max_buffer_size = 5 if stream.type == "video": - # TODO consider also using stream.codec_context.codec.reorder - # videos with b frames can have out-of-order pts + # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt) # so need to buffer some extra frames to sort everything # properly - should_buffer = stream.codec_context.has_b_frames + extradata = stream.codec_context.extradata + # overly complicated way of finding if `divx_packed` is set, following + # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263 + if extradata and b"DivX" in extradata: + # can't use regex directly because of some weird characters sometimes... 
+ pos = extradata.find(b"DivX") + d = extradata[pos:] + o = re.search(b"DivX(\d+)Build(\d+)(\w)", d) + if o is None: + o = re.search(b"DivX(\d+)b(\d+)(\w)", d) + if o is not None: + should_buffer = o.group(3) == b"p" seek_offset = start_offset + # some files don't seek to the right location, so better be safe here + seek_offset = max(seek_offset - 1, 0) if should_buffer: # FIXME this is kind of a hack, but we will jump to the previous keyframe # so this will be safe seek_offset = max(seek_offset - max_buffer_size, 0) - # TODO check if stream needs to always be the video stream here or not - container.seek(seek_offset, any_frame=False, backward=True, stream=stream) + try: + # TODO check if stream needs to always be the video stream here or not + container.seek(seek_offset, any_frame=False, backward=True, stream=stream) + except av.AVError: + print("Corrupted file?", container.name) + return [] buffer_count = 0 for idx, frame in enumerate(container.decode(**stream_name)): frames[frame.pts] = frame From f3c064df7a32e0791308f3fa7b01ac398442cb73 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 25 Jul 2019 02:46:24 -0700 Subject: [PATCH 2/5] Guard against videos without video stream --- torchvision/datasets/video_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index 17227f38ecc..cb426b35e69 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -84,6 +84,10 @@ def subset(self, indices): @staticmethod def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate): + if fps is None: + # if for some reason the video doesn't have fps (because it doesn't have a video stream) + # set the fps to 1. 
The value doesn't matter, because video_pts is empty anyway + fps = 1 if frame_rate is None: frame_rate = fps total_frames = len(video_pts) * (float(frame_rate) / fps) From 27142e002b6621c8fdd345fe2b6f00ef9c47e82f Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 25 Jul 2019 03:01:56 -0700 Subject: [PATCH 3/5] Fix lint --- torchvision/io/video.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index ec964ac7f7f..e738f00421d 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -79,9 +79,9 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): # can't use regex directly because of some weird characters sometimes... pos = extradata.find(b"DivX") d = extradata[pos:] - o = re.search(b"DivX(\d+)Build(\d+)(\w)", d) + o = re.search(br"DivX(\d+)Build(\d+)(\w)", d) if o is None: - o = re.search(b"DivX(\d+)b(\d+)(\w)", d) + o = re.search(br"DivX(\d+)b(\d+)(\w)", d) if o is not None: should_buffer = o.group(3) == b"p" seek_offset = start_offset From 58a2cce0dc2579b7424a7ff7fa59c8d7ef245749 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 25 Jul 2019 06:18:35 -0700 Subject: [PATCH 4/5] Add test for packed b-frames videos --- test/test_io.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/test_io.py b/test/test_io.py index f3008ce4a67..332c80b73e8 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -2,9 +2,17 @@ import contextlib import tempfile import torch +import torchvision.datasets.utils as utils import torchvision.io as io import unittest +import sys +from common_utils import get_tmp_dir + +if sys.version_info < (3,): + from urllib2 import URLError +else: + from urllib.error import URLError try: import av @@ -104,6 +112,22 @@ def test_read_partial_video_bframes(self): self.assertEqual(len(lv), 4) self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) + @unittest.skipIf(av is None, 
"PyAV unavailable") + def test_read_packed_b_frames_divx_file(self): + with get_tmp_dir() as temp_dir: + name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi" + f_name = os.path.join(temp_dir, name) + url = "https://download.pytorch.org/vision_tests/io/" + name + try: + utils.download_url(url, temp_dir) + pts, fps = io.read_video_timestamps(f_name) + self.assertEqual(pts, sorted(pts)) + self.assertEqual(fps, 30) + except URLError: + msg = "could not download test file '{}'".format(url) + warnings.warn(msg, RuntimeWarning) + raise unittest.SkipTest(msg) + # TODO add tests for audio From f85379d1fce9790f1b403b34cfe990313ae0264c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 25 Jul 2019 07:37:27 -0700 Subject: [PATCH 5/5] Fix missing import --- test/test_io.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_io.py b/test/test_io.py index 332c80b73e8..6bd5703b7e6 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -6,6 +6,7 @@ import torchvision.io as io import unittest import sys +import warnings from common_utils import get_tmp_dir