Skip to content

Commit

Permalink
Refactoring to use contexts managers, list comprehensions when more i…
Browse files Browse the repository at this point in the history
…diomatic, and minor renaming to help reader clarity (#2335)

* Refactoring to use contexts managers, list comprehensions when more idiomatic, and minor renaming to help reader clarity.

* Fix flake8 warning in video_utils.py
  • Loading branch information
QuentinDuval committed Jun 22, 2020
1 parent 32f21da commit 42aa9b2
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 88 deletions.
18 changes: 2 additions & 16 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,9 @@ def list_dir(root, prefix=False):
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = list(
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)

directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
if prefix is True:
directories = [os.path.join(root, d) for d in directories]

return directories


Expand All @@ -119,16 +112,9 @@ def list_files(root, suffix, prefix=False):
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = list(
filter(
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
os.listdir(root)
)
)

files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
if prefix is True:
files = [os.path.join(root, d) for d in files]

return files


Expand Down
20 changes: 12 additions & 8 deletions torchvision/datasets/video_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import bisect
import math
from fractions import Fraction
from typing import List

import torch
from torchvision.io import (
Expand Down Expand Up @@ -45,20 +46,23 @@ def unfold(tensor, size, step, dilation=1):
return torch.as_strided(tensor, new_size, new_stride)


class _DummyDataset(object):
class _VideoTimestampsDataset(object):
"""
Dummy dataset used for DataLoader in VideoClips.
Defined at top level so it can be pickled when forking.
Dataset used to parallelize the reading of the timestamps
of a list of videos, given their paths in the filesystem.
Used in VideoClips and defined at top level so it can be
pickled when forking.
"""

def __init__(self, x):
self.x = x
def __init__(self, video_paths: List[str]):
self.video_paths = video_paths

def __len__(self):
return len(self.x)
return len(self.video_paths)

def __getitem__(self, idx):
return read_video_timestamps(self.x[idx])
return read_video_timestamps(self.video_paths[idx])


class VideoClips(object):
Expand Down Expand Up @@ -132,7 +136,7 @@ def _compute_frame_pts(self):
import torch.utils.data

dl = torch.utils.data.DataLoader(
_DummyDataset(self.video_paths),
_VideoTimestampsDataset(self.video_paths),
batch_size=16,
num_workers=self.num_workers,
collate_fn=self._collate_fn,
Expand Down
122 changes: 58 additions & 64 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,23 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
if isinstance(fps, float):
fps = np.round(fps)

container = av.open(filename, mode="w")

stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2]
stream.height = video_array.shape[1]
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.options = options or {}

for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = "NONE"
for packet in stream.encode(frame):
with av.open(filename, mode="w") as container:
stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2]
stream.height = video_array.shape[1]
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.options = options or {}

for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = "NONE"
for packet in stream.encode(frame):
container.mux(packet)

# Flush stream
for packet in stream.encode():
container.mux(packet)

# Flush stream
for packet in stream.encode():
container.mux(packet)

# Close the file
container.close()


def _read_from_stream(
container, start_offset, end_offset, pts_unit, stream, stream_name
Expand Down Expand Up @@ -234,37 +230,35 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
audio_frames = []

try:
container = av.open(filename, metadata_errors="ignore")
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)

if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate

except av.AVError:
# TODO raise a warning?
pass
else:
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)

if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate

container.close()

vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes = [frame.to_ndarray() for frame in audio_frames]
Expand Down Expand Up @@ -293,6 +287,14 @@ def _can_read_timestamps_from_packets(container):
return False


def _decode_video_timestamps(container):
if _can_read_timestamps_from_packets(container):
# fast path
return [x.pts for x in container.demux(video=0) if x.pts is not None]
else:
return [x.pts for x in container.decode(video=0) if x.pts is not None]


def read_video_timestamps(filename, pts_unit="pts"):
"""
List the video frames timestamps.
Expand Down Expand Up @@ -326,26 +328,18 @@ def read_video_timestamps(filename, pts_unit="pts"):
pts = []

try:
container = av.open(filename, metadata_errors="ignore")
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
try:
pts = _decode_video_timestamps(container)
except av.AVError:
warnings.warn(f"Failed decoding frames for file {filename}")
video_fps = float(video_stream.average_rate)
except av.AVError:
# TODO add a warning
pass
else:
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
try:
if _can_read_timestamps_from_packets(container):
# fast path
pts = [x.pts for x in container.demux(video=0) if x.pts is not None]
else:
pts = [
x.pts for x in container.decode(video=0) if x.pts is not None
]
except av.AVError:
warnings.warn(f"Failed decoding frames for file {filename}")
video_fps = float(video_stream.average_rate)
container.close()

pts.sort()

Expand Down

0 comments on commit 42aa9b2

Please sign in to comment.