Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 87 additions & 89 deletions test/test_io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
import os
import contextlib
import sys
import tempfile
import torch
import torchvision.io as io
from torchvision import get_video_backend
import unittest
import warnings
from urllib.error import URLError

Expand Down Expand Up @@ -64,10 +64,10 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
os.unlink(f.name)


@unittest.skipIf(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT,
"video_reader backend not available")
@unittest.skipIf(av is None, "PyAV unavailable")
class TestIO(unittest.TestCase):
@pytest.mark.skipif(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT,
reason="video_reader backend not available")
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
class TestVideo:
# compression adds artifacts, thus we add a tolerance of
# 6 in 0-255 range
TOLERANCE = 6
Expand All @@ -76,164 +76,162 @@ def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name)
assert_equal(data, lv)
self.assertEqual(info["video_fps"], 5)
assert info["video_fps"] == 5

@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
@pytest.mark.skipif(not io._HAS_VIDEO_OPT, reason="video_reader backend is not chosen")
def test_probe_video_from_file(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
video_info = io._probe_video_from_file(f_name)
self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)
assert pytest.approx(2, rel=0.0, abs=0.1) == video_info.video_duration
assert pytest.approx(5, rel=0.0, abs=0.1) == video_info.video_fps

@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
@pytest.mark.skipif(not io._HAS_VIDEO_OPT, reason="video_reader backend is not chosen")
def test_probe_video_from_memory(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
with open(f_name, "rb") as fp:
filebuffer = fp.read()
video_info = io._probe_video_from_memory(filebuffer)
self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)
assert pytest.approx(2, rel=0.0, abs=0.1) == video_info.video_duration
assert pytest.approx(5, rel=0.0, abs=0.1) == video_info.video_fps

def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
container = av.open(f_name)
stream = container.streams[0]
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]
with av.open(f_name) as container:
stream = container.streams[0]
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]

self.assertEqual(pts, expected_pts)
container.close()
assert pts == expected_pts

def test_read_partial_video(self):
@pytest.mark.parametrize('start', range(5))
@pytest.mark.parametrize('offset', range(1, 4))
def test_read_partial_video(self, start, offset):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
for start in range(5):
for offset in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)]
self.assertEqual(len(lv), offset)
assert_equal(s_data, lv)

lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)]
assert len(lv) == offset
assert_equal(s_data, lv)

if get_video_backend() == "pyav":
# for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
assert len(lv) == 4
assert_equal(data[4:8], lv)

def test_read_partial_video_bframes(self):
@pytest.mark.parametrize('start', range(0, 80, 20))
@pytest.mark.parametrize('offset', range(1, 4))
def test_read_partial_video_bframes(self, start, offset):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
for start in range(0, 80, 20):
for offset in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)]
self.assertEqual(len(lv), offset)
assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE)

lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)]
assert len(lv) == offset
assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE)

lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
# TODO fix this
if get_video_backend() == 'pyav':
self.assertEqual(len(lv), 4)
assert len(lv) == 4
assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE)
else:
self.assertEqual(len(lv), 3)
assert len(lv) == 3
assert_equal(data[5:8], lv, rtol=0.0, atol=self.TOLERANCE)

def test_read_packed_b_frames_divx_file(self):
name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
f_name = os.path.join(VIDEO_DIR, name)
pts, fps = io.read_video_timestamps(f_name)

self.assertEqual(pts, sorted(pts))
self.assertEqual(fps, 30)
assert pts == sorted(pts)
assert fps == 30

def test_read_timestamps_from_packet(self):
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
container = av.open(f_name)
stream = container.streams[0]
# make sure we went through the optimized codepath
self.assertIn(b'Lavc', stream.codec_context.extradata)
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]
with av.open(f_name) as container:
stream = container.streams[0]
# make sure we went through the optimized codepath
assert b'Lavc' in stream.codec_context.extradata
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]

self.assertEqual(pts, expected_pts)
container.close()
assert pts == expected_pts

def test_read_video_pts_unit_sec(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name, pts_unit='sec')

assert_equal(data, lv)
self.assertEqual(info["video_fps"], 5)
self.assertEqual(info, {"video_fps": 5})
assert info["video_fps"] == 5
assert info == {"video_fps": 5}

def test_read_timestamps_pts_unit_sec(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

container = av.open(f_name)
stream = container.streams[0]
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)]
with av.open(f_name) as container:
stream = container.streams[0]
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)]

self.assertEqual(pts, expected_pts)
container.close()
assert pts == expected_pts

def test_read_partial_video_pts_unit_sec(self):
@pytest.mark.parametrize('start', range(5))
@pytest.mark.parametrize('offset', range(1, 4))
def test_read_partial_video_pts_unit_sec(self, start, offset):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

for start in range(5):
for offset in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec')
s_data = data[start:(start + offset)]
self.assertEqual(len(lv), offset)
assert_equal(s_data, lv)

container = av.open(f_name)
stream = container.streams[0]
lv, _, _ = io.read_video(f_name,
int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7],
pts_unit='sec')
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec')
s_data = data[start:(start + offset)]
assert len(lv) == offset
assert_equal(s_data, lv)

with av.open(f_name) as container:
stream = container.streams[0]
lv, _, _ = io.read_video(f_name,
int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7],
pts_unit='sec')
if get_video_backend() == "pyav":
# for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts
self.assertEqual(len(lv), 4)
assert len(lv) == 4
assert_equal(data[4:8], lv)
container.close()

def test_read_video_corrupted_file(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
f.write(b'This is not an mpg4 file')
video, audio, info = io.read_video(f.name)
self.assertIsInstance(video, torch.Tensor)
self.assertIsInstance(audio, torch.Tensor)
self.assertEqual(video.numel(), 0)
self.assertEqual(audio.numel(), 0)
self.assertEqual(info, {})
assert isinstance(video, torch.Tensor)
assert isinstance(audio, torch.Tensor)
assert video.numel() == 0
assert audio.numel() == 0
assert info == {}

def test_read_video_timestamps_corrupted_file(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
f.write(b'This is not an mpg4 file')
video_pts, video_fps = io.read_video_timestamps(f.name)
self.assertEqual(video_pts, [])
self.assertIs(video_fps, None)
assert video_pts == []
assert video_fps is None

@unittest.skip("Temporarily disabled due to new pyav")
@pytest.mark.skip(reason="Temporarily disabled due to new pyav")
def test_read_video_partially_corrupted_file(self):
with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
with open(f_name, 'r+b') as f:
Expand All @@ -248,16 +246,16 @@ def test_read_video_partially_corrupted_file(self):
# check that size is not equal to 5, but 3
# TODO fix this
if get_video_backend() == 'pyav':
self.assertEqual(len(video), 3)
assert len(video) == 3
else:
self.assertEqual(len(video), 4)
assert len(video) == 4
# but the valid decoded content is still correct
assert_equal(video[:3], data[:3])
# and the last few frames are wrong
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
assert_equal(video, data)

@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
@pytest.mark.skipif(sys.platform == 'win32', reason='temporarily disabled on Windows')
def test_write_video_with_audio(self):
f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
Expand All @@ -279,19 +277,19 @@ def test_write_video_with_audio(self):
out_f_name, pts_unit="sec"
)

self.assertEqual(info["video_fps"], out_info["video_fps"])
assert info["video_fps"] == out_info["video_fps"]
assert_equal(video_tensor, out_video_tensor)

audio_stream = av.open(f_name).streams.audio[0]
out_audio_stream = av.open(out_f_name).streams.audio[0]

self.assertEqual(info["audio_fps"], out_info["audio_fps"])
self.assertEqual(audio_stream.rate, out_audio_stream.rate)
self.assertAlmostEqual(audio_stream.frames, out_audio_stream.frames, delta=1)
self.assertEqual(audio_stream.frame_size, out_audio_stream.frame_size)
assert info["audio_fps"] == out_info["audio_fps"]
assert audio_stream.rate == out_audio_stream.rate
assert pytest.approx(out_audio_stream.frames, rel=0.0, abs=1) == audio_stream.frames
assert audio_stream.frame_size == out_audio_stream.frame_size

# TODO add tests for audio


if __name__ == '__main__':
unittest.main()
pytest.main(__file__)