Unify video metadata in VideoClips (#1527)
* Unify video metadata in VideoClips

* Bugfix

* Make tests a bit more robust
fmassa committed Oct 29, 2019
1 parent c226bb9 commit 7d509c5
Showing 5 changed files with 67 additions and 63 deletions.
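
This commit removes the backend-specific metadata handling in VideoClips. Previously the pyav backend kept a list of frame rates (video_fps) while the video_reader backend kept raw probe dicts (info), and callers branched on the backend everywhere. Afterwards both backends store video_fps, and get_clip returns an info dict with the same keys in either case. A rough sketch of the resulting contract, with hypothetical file names and clip parameters not taken from this commit:

    from torchvision import set_video_backend
    from torchvision.datasets.video_utils import VideoClips

    video_paths = ["a.mp4", "b.mp4"]  # hypothetical local files

    for backend in ("pyav", "video_reader"):
        set_video_backend(backend)  # video_reader needs torchvision built with it
        clips = VideoClips(video_paths, clip_length_in_frames=5, frames_between_clips=5)
        video, audio, info, video_idx = clips.get_clip(0)
        # Both backends now return the same minimal dict:
        # {"video_fps": ...}, plus "audio_fps" if the clip has audio.
        assert "video_fps" in info
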
test/test_datasets_video_utils.py (3 additions, 2 deletions)
@@ -59,7 +59,7 @@ def test_unfold(self):
self.assertTrue(r.equal(expected))

@unittest.skipIf(not io.video._av_available(), "this test requires av")
-@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_video_clips(self):
with get_list_of_videos(num_videos=3) as video_list:
video_clips = VideoClips(video_list, 5, 5)
@@ -84,7 +84,7 @@ def test_video_clips(self):
self.assertEqual(clip_idx, c_idx)

@unittest.skipIf(not io.video._av_available(), "this test requires av")
-@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_video_clips_custom_fps(self):
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
num_frames = 4
@@ -94,6 +94,7 @@ def test_video_clips_custom_fps(self):
video, audio, info, video_idx = video_clips.get_clip(i)
self.assertEqual(video.shape[0], num_frames)
self.assertEqual(info["video_fps"], fps)
+self.assertEqual(info, {"video_fps": fps})
# TODO add tests checking that the content is right

def test_compute_clips_for_video(self):
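
Asserting on the whole dict, not just the "video_fps" key, guards against backend-specific extras such as timebases or sample rates leaking back into the returned info.
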
test/test_datasets_video_utils_opt.py (new file, 11 additions)
@@ -0,0 +1,11 @@
+import unittest
+from torchvision import set_video_backend
+import test_datasets_video_utils
+
+
+set_video_backend('video_reader')
+
+
+if __name__ == '__main__':
+    suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils)
+    unittest.TextTestRunner(verbosity=1).run(suite)
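
The new module reruns the whole pyav test suite under the video_reader backend, so both backends are exercised against identical assertions. Assuming torchvision was built with the video_reader extension, it can be run directly:

    python test/test_datasets_video_utils_opt.py
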
test/test_io.py (1 addition)
@@ -183,6 +183,7 @@ def test_read_video_pts_unit_sec(self):

self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)
+self.assertEqual(info, {"video_fps": 5})

def test_read_timestamps_pts_unit_sec(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
torchvision/datasets/video_utils.py (44 additions, 60 deletions)
@@ -5,6 +5,7 @@
from torchvision.io import (
_read_video_timestamps_from_file,
_read_video_from_file,
+_probe_video_from_file
)
from torchvision.io import read_video_timestamps, read_video

@@ -71,11 +72,11 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
frame_rate=None, _precomputed_metadata=None, num_workers=0,
_video_width=0, _video_height=0, _video_min_dimension=0,
_audio_samples=0):
-from torchvision import get_video_backend

self.video_paths = video_paths
self.num_workers = num_workers
-self._backend = get_video_backend()

+# these options are not valid for pyav backend
self._video_width = _video_width
self._video_height = _video_height
self._video_min_dimension = _video_min_dimension
@@ -89,87 +90,60 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1

def _compute_frame_pts(self):
self.video_pts = []
-if self._backend == "pyav":
-    self.video_fps = []
-else:
-    self.info = []
+self.video_fps = []

# strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first
class DS(object):
-def __init__(self, x, _backend):
+def __init__(self, x):
self.x = x
-self._backend = _backend

def __len__(self):
return len(self.x)

def __getitem__(self, idx):
-if self._backend == "pyav":
-    return read_video_timestamps(self.x[idx])
-else:
-    return _read_video_timestamps_from_file(self.x[idx])
+return read_video_timestamps(self.x[idx])

import torch.utils.data
dl = torch.utils.data.DataLoader(
-DS(self.video_paths, self._backend),
+DS(self.video_paths),
batch_size=16,
num_workers=self.num_workers,
collate_fn=lambda x: x)

with tqdm(total=len(dl)) as pbar:
for batch in dl:
pbar.update(1)
-if self._backend == "pyav":
-    clips, fps = list(zip(*batch))
-    clips = [torch.as_tensor(c) for c in clips]
-    self.video_pts.extend(clips)
-    self.video_fps.extend(fps)
-else:
-    video_pts, _audio_pts, info = list(zip(*batch))
-    video_pts = [torch.as_tensor(c) for c in video_pts]
-    self.video_pts.extend(video_pts)
-    self.info.extend(info)
+clips, fps = list(zip(*batch))
+clips = [torch.as_tensor(c) for c in clips]
+self.video_pts.extend(clips)
+self.video_fps.extend(fps)

def _init_from_metadata(self, metadata):
self.video_paths = metadata["video_paths"]
assert len(self.video_paths) == len(metadata["video_pts"])
self.video_pts = metadata["video_pts"]

-if self._backend == "pyav":
-    assert len(self.video_paths) == len(metadata["video_fps"])
-    self.video_fps = metadata["video_fps"]
-else:
-    assert len(self.video_paths) == len(metadata["info"])
-    self.info = metadata["info"]
+assert len(self.video_paths) == len(metadata["video_fps"])
+self.video_fps = metadata["video_fps"]

@property
def metadata(self):
_metadata = {
"video_paths": self.video_paths,
"video_pts": self.video_pts,
"video_fps": self.video_fps
}
-if self._backend == "pyav":
-    _metadata.update({"video_fps": self.video_fps})
-else:
-    _metadata.update({"info": self.info})
return _metadata

def subset(self, indices):
video_paths = [self.video_paths[i] for i in indices]
video_pts = [self.video_pts[i] for i in indices]
-if self._backend == "pyav":
-    video_fps = [self.video_fps[i] for i in indices]
-else:
-    info = [self.info[i] for i in indices]
+video_fps = [self.video_fps[i] for i in indices]
metadata = {
"video_paths": video_paths,
"video_pts": video_pts,
"video_fps": video_fps
}
-if self._backend == "pyav":
-    metadata.update({"video_fps": video_fps})
-else:
-    metadata.update({"info": info})
return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
_precomputed_metadata=metadata, num_workers=self.num_workers,
_video_width=self._video_width,
@@ -212,22 +186,10 @@ def compute_clips(self, num_frames, step, frame_rate=None):
self.frame_rate = frame_rate
self.clips = []
self.resampling_idxs = []
-if self._backend == "pyav":
-    for video_pts, fps in zip(self.video_pts, self.video_fps):
-        clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
-        self.clips.append(clips)
-        self.resampling_idxs.append(idxs)
-else:
-    for video_pts, info in zip(self.video_pts, self.info):
-        if "video_fps" in info:
-            clips, idxs = self.compute_clips_for_video(
-                video_pts, num_frames, step, info["video_fps"], frame_rate)
-            self.clips.append(clips)
-            self.resampling_idxs.append(idxs)
-        else:
-            # properly handle the cases where video decoding fails
-            self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
-            self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
+for video_pts, fps in zip(self.video_pts, self.video_fps):
+    clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+    self.clips.append(clips)
+    self.resampling_idxs.append(idxs)
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

@@ -287,12 +249,28 @@ def get_clip(self, idx):
video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx]

-if self._backend == "pyav":
+from torchvision import get_video_backend
+backend = get_video_backend()
+
+if backend == "pyav":
+    # check for invalid options
+    if self._video_width != 0:
+        raise ValueError("pyav backend doesn't support _video_width != 0")
+    if self._video_height != 0:
+        raise ValueError("pyav backend doesn't support _video_height != 0")
+    if self._video_min_dimension != 0:
+        raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
+    if self._audio_samples != 0:
+        raise ValueError("pyav backend doesn't support _audio_samples != 0")
+
+if backend == "pyav":
start_pts = clip_pts[0].item()
end_pts = clip_pts[-1].item()
video, audio, info = read_video(video_path, start_pts, end_pts)
else:
-info = self.info[video_idx]
+info = _probe_video_from_file(video_path)
+video_fps = info["video_fps"]
+audio_fps = None

video_start_pts = clip_pts[0].item()
video_end_pts = clip_pts[-1].item()
@@ -313,6 +291,7 @@ def get_clip(self, idx):
info["audio_timebase"],
math.ceil,
)
+audio_fps = info["audio_sample_rate"]
video, audio, info = _read_video_from_file(
video_path,
video_width=self._video_width,
@@ -324,6 +303,11 @@ def get_clip(self, idx):
audio_pts_range=(audio_start_pts, audio_end_pts),
audio_timebase=audio_timebase,
)

info = {"video_fps": video_fps}
if audio_fps is not None:
info["audio_fps"] = audio_fps

if self.frame_rate is not None:
resampling_idx = self.resampling_idxs[video_idx][clip_idx]
if isinstance(resampling_idx, torch.Tensor):
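
With the _backend branches gone, precomputed metadata follows a single schema: video_paths, video_pts, and video_fps are parallel lists of equal length. A minimal sketch of the round trip under that schema, with hypothetical paths and parameters:

    from torchvision.datasets.video_utils import VideoClips

    clips = VideoClips(["a.mp4", "b.mp4", "c.mp4"], 5, 5)
    meta = clips.metadata
    # meta == {"video_paths": [...], "video_pts": [...], "video_fps": [...]}

    # rebuild without re-scanning the videos
    reloaded = VideoClips(meta["video_paths"], 5, 5, _precomputed_metadata=meta)

    # subset() slices all three lists in lockstep
    first_and_last = clips.subset([0, 2])
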
torchvision/io/_video_opt.py (8 additions, 1 deletion)
@@ -383,7 +383,7 @@ def get_pts(time_base):
audio_timebase = info['audio_timebase']
audio_pts_range = get_pts(audio_timebase)

-return _read_video_from_file(
+vframes, aframes, info = _read_video_from_file(
filename,
read_video_stream=True,
video_pts_range=video_pts_range,
@@ -392,6 +392,13 @@ def get_pts(time_base):
audio_pts_range=audio_pts_range,
audio_timebase=audio_timebase,
)
+_info = {}
+if has_video:
+    _info['video_fps'] = info['video_fps']
+if has_audio:
+    _info['audio_fps'] = info['audio_sample_rate']
+
+return vframes, aframes, _info


def _read_video_timestamps(filename, pts_unit='pts'):
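
The net effect is that the video_reader backend's _read_video wrapper now returns the same minimal info dict as the pyav path: "video_fps", plus "audio_fps" when an audio stream is present, instead of the raw probe result with timebases and durations. Sketched usage, with a hypothetical file name:

    from torchvision.io._video_opt import _read_video

    vframes, aframes, info = _read_video("a.mp4")
    # info == {"video_fps": ...} or {"video_fps": ..., "audio_fps": ...}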
