Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring to use contexts managers, list comprehensions when more idiomatic, and minor renaming to help reader clarity #2335

Merged
merged 3 commits into from
Jun 22, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 2 additions & 16 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,9 @@ def list_dir(root, prefix=False):
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = list(
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)

directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
if prefix is True:
directories = [os.path.join(root, d) for d in directories]

return directories


Expand All @@ -119,16 +112,9 @@ def list_files(root, suffix, prefix=False):
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = list(
filter(
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
os.listdir(root)
)
)

files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
if prefix is True:
files = [os.path.join(root, d) for d in files]

return files


Expand Down
20 changes: 12 additions & 8 deletions torchvision/datasets/video_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import bisect
import math
from fractions import Fraction
from typing import List

import torch
from torchvision.io import (
Expand Down Expand Up @@ -45,20 +46,23 @@ def unfold(tensor, size, step, dilation=1):
return torch.as_strided(tensor, new_size, new_stride)


class _DummyDataset(object):
class _VideoTimestampsDataset(object):
"""
Dummy dataset used for DataLoader in VideoClips.
Defined at top level so it can be pickled when forking.
Dataset used to parallelize the reading of the timestamps
of a list of videos, given their paths in the filesystem.

Used in VideoClips and defined at top level so it can be
pickled when forking.
"""

def __init__(self, x):
self.x = x
def __init__(self, video_paths: List[str]):
self.video_paths = video_paths
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is fine because this is a private class, but in general renaming methods is a BC-breaking change


def __len__(self):
return len(self.x)
return len(self.video_paths)

def __getitem__(self, idx):
return read_video_timestamps(self.x[idx])
return read_video_timestamps(self.video_paths[idx])


class VideoClips(object):
Expand Down Expand Up @@ -132,7 +136,7 @@ def _compute_frame_pts(self):
import torch.utils.data

dl = torch.utils.data.DataLoader(
_DummyDataset(self.video_paths),
_VideoTimestampsDataset(self.video_paths),
batch_size=16,
num_workers=self.num_workers,
collate_fn=self._collate_fn,
Expand Down
122 changes: 58 additions & 64 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,23 @@ def write_video(filename, video_array, fps, video_codec="libx264", options=None)
_check_av_available()
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

container = av.open(filename, mode="w")

stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2]
stream.height = video_array.shape[1]
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.options = options or {}

for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = "NONE"
for packet in stream.encode(frame):
with av.open(filename, mode="w") as container:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! PyAV only added support for using open as a context-manager in 6.2.0, which was after we released this first version.
This clean-up is much appreciated!

stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2]
stream.height = video_array.shape[1]
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.options = options or {}

for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = "NONE"
for packet in stream.encode(frame):
container.mux(packet)

# Flush stream
for packet in stream.encode():
container.mux(packet)

# Flush stream
for packet in stream.encode():
container.mux(packet)

# Close the file
container.close()


def _read_from_stream(
container, start_offset, end_offset, pts_unit, stream, stream_name
Expand Down Expand Up @@ -229,37 +225,35 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
audio_frames = []

try:
container = av.open(filename, metadata_errors="ignore")
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)

if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate

except av.AVError:
# TODO raise a warning?
pass
else:
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)

if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate

container.close()

vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes = [frame.to_ndarray() for frame in audio_frames]
Expand Down Expand Up @@ -288,6 +282,14 @@ def _can_read_timestamps_from_packets(container):
return False


def _decode_video_timestamps(container):
if _can_read_timestamps_from_packets(container):
# fast path
return [x.pts for x in container.demux(video=0) if x.pts is not None]
else:
return [x.pts for x in container.decode(video=0) if x.pts is not None]


def read_video_timestamps(filename, pts_unit="pts"):
"""
List the video frames timestamps.
Expand Down Expand Up @@ -321,26 +323,18 @@ def read_video_timestamps(filename, pts_unit="pts"):
pts = []

try:
container = av.open(filename, metadata_errors="ignore")
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
try:
pts = _decode_video_timestamps(container)
except av.AVError:
warnings.warn(f"Failed decoding frames for file {filename}")
video_fps = float(video_stream.average_rate)
except av.AVError:
# TODO add a warning
pass
else:
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
try:
if _can_read_timestamps_from_packets(container):
# fast path
pts = [x.pts for x in container.demux(video=0) if x.pts is not None]
else:
pts = [
x.pts for x in container.decode(video=0) if x.pts is not None
]
except av.AVError:
warnings.warn(f"Failed decoding frames for file {filename}")
video_fps = float(video_stream.average_rate)
container.close()

pts.sort()

Expand Down