Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@
import contextlib
import tempfile
import torch
import torchvision.datasets.utils as utils
import torchvision.io as io
import unittest
import sys
import warnings

from common_utils import get_tmp_dir

if sys.version_info < (3,):
from urllib2 import URLError
else:
from urllib.error import URLError

try:
import av
Expand Down Expand Up @@ -104,6 +113,22 @@ def test_read_partial_video_bframes(self):
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_packed_b_frames_divx_file(self):
    """Check that timestamps of a DivX file come back sorted.

    The sample video is downloaded into a temporary directory, then
    ``io.read_video_timestamps`` must return presentation timestamps in
    ascending order and a frame rate of 30.

    Raises:
        unittest.SkipTest: if the sample file cannot be downloaded
            (a ``RuntimeWarning`` is emitted as well).
    """
    name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
    url = "https://download.pytorch.org/vision_tests/io/" + name
    with get_tmp_dir() as temp_dir:
        f_name = os.path.join(temp_dir, name)
        # Only the download may legitimately fail with URLError; keeping the
        # try body minimal ensures a decoding problem is never reported as a skip.
        try:
            utils.download_url(url, temp_dir)
        except URLError:
            msg = "could not download test file '{}'".format(url)
            warnings.warn(msg, RuntimeWarning)
            raise unittest.SkipTest(msg)

        pts, fps = io.read_video_timestamps(f_name)
        self.assertEqual(pts, sorted(pts))
        self.assertEqual(fps, 30)

# TODO add tests for audio


Expand Down
4 changes: 4 additions & 0 deletions torchvision/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def subset(self, indices):

@staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
if fps is None:
# if for some reason the video doesn't have an fps (because it doesn't have a video stream),
# set the fps to 1. The value doesn't matter, because video_pts is empty anyway
fps = 1
if frame_rate is None:
frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps)
Expand Down
27 changes: 22 additions & 5 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import gc
import torch
import numpy as np
Expand Down Expand Up @@ -68,18 +69,34 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
should_buffer = False
max_buffer_size = 5
if stream.type == "video":
# TODO consider also using stream.codec_context.codec.reorder
# videos with b frames can have out-of-order pts
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
# so need to buffer some extra frames to sort everything
# properly
should_buffer = stream.codec_context.has_b_frames
extradata = stream.codec_context.extradata
# overly complicated way of finding if `divx_packed` is set, following
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
if extradata and b"DivX" in extradata:
# can't use regex directly because of some weird characters sometimes...
pos = extradata.find(b"DivX")
d = extradata[pos:]
o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
if o is None:
o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
if o is not None:
should_buffer = o.group(3) == b"p"
seek_offset = start_offset
# some files don't seek to the right location, so better be safe here
seek_offset = max(seek_offset - 1, 0)
if should_buffer:
# FIXME this is kind of a hack, but we will jump to the previous keyframe
# so this will be safe
seek_offset = max(seek_offset - max_buffer_size, 0)
# TODO check if stream needs to always be the video stream here or not
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
try:
# TODO check if stream needs to always be the video stream here or not
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
except av.AVError:
print("Corrupted file?", container.name)
return []
buffer_count = 0
for idx, frame in enumerate(container.decode(**stream_name)):
frames[frame.pts] = frame
Expand Down