From 8bf75f0506d9361103daf572ae6e3a467b499ee6 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Wed, 24 Jul 2019 06:00:10 -0700
Subject: [PATCH] Test videos with B-Frames

Also extend video saving to support different codecs and options.
Notably, we can now save with lossless compression
---
 test/test_io.py         | 88 +++++++++++++++++++++++++++--------------
 torchvision/io/video.py |  8 ++--
 2 files changed, 63 insertions(+), 33 deletions(-)

diff --git a/test/test_io.py b/test/test_io.py
index a77b6d22cca..f3008ce4a67 100644
--- a/test/test_io.py
+++ b/test/test_io.py
@@ -1,4 +1,5 @@
 import os
+import contextlib
 import tempfile
 import torch
 import torchvision.io as io
@@ -11,45 +12,59 @@
     av = None
 
 
+def _create_video_frames(num_frames, height, width):
+    y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
+    data = []
+    for i in range(num_frames):
+        xc = float(i) / num_frames
+        yc = 1 - float(i) / (2 * num_frames)
+        d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
+        data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
+
+    return torch.stack(data, 0)
+
+
+@contextlib.contextmanager
+def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
+    if lossless:
+        assert video_codec is None, "video_codec can't be specified together with lossless"
+        assert options is None, "options can't be specified together with lossless"
+        video_codec = 'libx264rgb'
+        options = {'crf': '0'}
+
+    if video_codec is None:
+        video_codec = 'libx264'
+    if options is None:
+        options = {}
+
+    data = _create_video_frames(num_frames, height, width)
+    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
+        io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
+        yield f.name, data
+
+
 class Tester(unittest.TestCase):
     # compression adds artifacts, thus we add a tolerance of
     # 6 in 0-255 range
     TOLERANCE = 6
 
-    def _create_video_frames(self, num_frames, height, width):
-        y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
-        data = []
-        for i in range(num_frames):
-            xc = float(i) / num_frames
-            yc = 1 - float(i) / (2 * num_frames)
-            d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
-            data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
-
-        return torch.stack(data, 0)
-
     @unittest.skipIf(av is None, "PyAV unavailable")
     def test_write_read_video(self):
-        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
-            data = self._create_video_frames(10, 300, 300)
-            io.write_video(f.name, data, fps=5)
+        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
+            lv, _, info = io.read_video(f_name)
 
-            lv, _, info = io.read_video(f.name)
-
-            self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE)
+            self.assertTrue(data.equal(lv))
             self.assertEqual(info["video_fps"], 5)
 
     @unittest.skipIf(av is None, "PyAV unavailable")
     def test_read_timestamps(self):
-        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
-            data = self._create_video_frames(10, 300, 300)
-            io.write_video(f.name, data, fps=5)
-
-            pts, _ = io.read_video_timestamps(f.name)
+        with temp_video(10, 300, 300, 5) as (f_name, data):
+            pts, _ = io.read_video_timestamps(f_name)
 
             # note: not all formats/codecs provide accurate information for computing the
             # timestamps. For the format that we use here, this information is available,
             # so we use it as a baseline
-            container = av.open(f.name)
+            container = av.open(f_name)
             stream = container.streams[0]
             pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
             num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
@@ -59,20 +74,33 @@ def test_read_timestamps(self):
 
     @unittest.skipIf(av is None, "PyAV unavailable")
     def test_read_partial_video(self):
-        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
-            data = self._create_video_frames(10, 300, 300)
-            io.write_video(f.name, data, fps=5)
+        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
+            pts, _ = io.read_video_timestamps(f_name)
+            for start in range(5):
+                for l in range(1, 4):
+                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
+                    s_data = data[start:(start + l)]
+                    self.assertEqual(len(lv), l)
+                    self.assertTrue(s_data.equal(lv))
 
-            pts, _ = io.read_video_timestamps(f.name)
+            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
+            self.assertEqual(len(lv), 4)
+            self.assertTrue(data[4:8].equal(lv))
 
-            for start in range(5):
+    @unittest.skipIf(av is None, "PyAV unavailable")
+    def test_read_partial_video_bframes(self):
+        # do not use lossless encoding, to test the presence of B-frames
+        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
+        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
+            pts, _ = io.read_video_timestamps(f_name)
+            for start in range(0, 80, 20):
                 for l in range(1, 4):
-                    lv, _, _ = io.read_video(f.name, pts[start], pts[start + l - 1])
+                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
                     s_data = data[start:(start + l)]
                     self.assertEqual(len(lv), l)
                     self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)
 
-            lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7])
+            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
             self.assertEqual(len(lv), 4)
             self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
 
diff --git a/torchvision/io/video.py b/torchvision/io/video.py
index 6f20142d0a5..e72a0e924d1 100644
--- a/torchvision/io/video.py
+++ b/torchvision/io/video.py
@@ -23,7 +23,7 @@ def _check_av_available():
 _GC_COLLECTION_INTERVAL = 20
 
 
-def write_video(filename, video_array, fps):
+def write_video(filename, video_array, fps, video_codec='libx264', options=None):
     """
     Writes a 4d tensor in [T, H, W, C] format in a video file
 
@@ -38,13 +38,15 @@ def write_video(filename, video_array, fps):
 
     container = av.open(filename, mode='w')
 
-    stream = container.add_stream('mpeg4', rate=fps)
+    stream = container.add_stream(video_codec, rate=fps)
     stream.width = video_array.shape[2]
     stream.height = video_array.shape[1]
-    stream.pix_fmt = 'yuv420p'
+    stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
+    stream.options = options or {}
 
     for img in video_array:
         frame = av.VideoFrame.from_ndarray(img, format='rgb24')
+        frame.pict_type = 'NONE'
         for packet in stream.encode(frame):
             container.mux(packet)
 