Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test videos with B-Frames #1157

Merged
merged 1 commit into from
Jul 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 58 additions & 30 deletions test/test_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import contextlib
import tempfile
import torch
import torchvision.io as io
Expand All @@ -11,45 +12,59 @@
av = None


def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
data = []
for i in range(num_frames):
xc = float(i) / num_frames
yc = 1 - float(i) / (2 * num_frames)
d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())

return torch.stack(data, 0)


@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Write a synthetic clip to a temporary ``.mp4`` file and yield it.

    Args:
        num_frames: number of frames in the generated clip.
        height: frame height passed to the frame generator.
        width: frame width passed to the frame generator.
        fps: frame rate handed to ``io.write_video``.
        lossless: if True, encode with ``libx264rgb`` and ``crf=0``. Mutually
            exclusive with ``video_codec`` and ``options``.
        video_codec: codec name forwarded to ``io.write_video``; defaults to
            ``'libx264'`` when not given.
        options: dict of encoder options forwarded to ``io.write_video``;
            defaults to an empty dict when not given.

    Yields:
        A ``(file_name, data)`` tuple: the path of the written video and the
        tensor of frames that was encoded into it. The temporary file is
        closed (and thereby removed) when the context exits.

    Raises:
        ValueError: if ``lossless`` is combined with ``video_codec`` or
            ``options``.
    """
    if lossless:
        # Explicit raises instead of ``assert`` so the validation still runs
        # under ``python -O``; otherwise the caller's values would be
        # silently overwritten below.
        if video_codec is not None:
            raise ValueError("video_codec can't be specified together with lossless")
        if options is not None:
            raise ValueError("options can't be specified together with lossless")
        video_codec = 'libx264rgb'
        options = {'crf': '0'}

    if video_codec is None:
        video_codec = 'libx264'
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
        yield f.name, data


class Tester(unittest.TestCase):
# compression adds artifacts, thus we add a tolerance of
# 6 in 0-255 range
TOLERANCE = 6

def _create_video_frames(self, num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
data = []
for i in range(num_frames):
xc = float(i) / num_frames
yc = 1 - float(i) / (2 * num_frames)
d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())

return torch.stack(data, 0)

@unittest.skipIf(av is None, "PyAV unavailable")
def test_write_read_video(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
data = self._create_video_frames(10, 300, 300)
io.write_video(f.name, data, fps=5)
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name)

lv, _, info = io.read_video(f.name)

self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE)
self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)

@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_timestamps(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
data = self._create_video_frames(10, 300, 300)
io.write_video(f.name, data, fps=5)

pts, _ = io.read_video_timestamps(f.name)
with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)

# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
container = av.open(f.name)
container = av.open(f_name)
stream = container.streams[0]
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
Expand All @@ -59,20 +74,33 @@ def test_read_timestamps(self):

@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
data = self._create_video_frames(10, 300, 300)
io.write_video(f.name, data, fps=5)
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
for start in range(5):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue(s_data.equal(lv))

pts, _ = io.read_video_timestamps(f.name)
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

for start in range(5):
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
for start in range(0, 80, 20):
for l in range(1, 4):
lv, _, _ = io.read_video(f.name, pts[start], pts[start + l - 1])
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7])
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

Expand Down
8 changes: 5 additions & 3 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _check_av_available():
_GC_COLLECTION_INTERVAL = 20


def write_video(filename, video_array, fps):
def write_video(filename, video_array, fps, video_codec='libx264', options=None):
"""
Writes a 4d tensor in [T, H, W, C] format in a video file

Expand All @@ -38,13 +38,15 @@ def write_video(filename, video_array, fps):

container = av.open(filename, mode='w')

stream = container.add_stream('mpeg4', rate=fps)
stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2]
stream.height = video_array.shape[1]
stream.pix_fmt = 'yuv420p'
stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
stream.options = options or {}

for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame.pict_type = 'NONE'
for packet in stream.encode(frame):
container.mux(packet)

Expand Down