Skip to content

Commit

Permalink
Better handle corrupted videos (#1463)
Browse files Browse the repository at this point in the history
* Handle corrupted video headers in io

* Catch exceptions while decoding partly-corrupted files

* Add more tests
  • Loading branch information
fmassa committed Oct 15, 2019
1 parent 1d6145d commit da89dad
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 34 deletions.
35 changes: 35 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,41 @@ def test_read_partial_video_pts_unit_sec(self):
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

def test_read_video_corrupted_file(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
f.write(b'This is not an mpg4 file')
video, audio, info = io.read_video(f.name)
self.assertIsInstance(video, torch.Tensor)
self.assertIsInstance(audio, torch.Tensor)
self.assertEqual(video.numel(), 0)
self.assertEqual(audio.numel(), 0)
self.assertEqual(info, {})

def test_read_video_timestamps_corrupted_file(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
f.write(b'This is not an mpg4 file')
video_pts, video_fps = io.read_video_timestamps(f.name)
self.assertEqual(video_pts, [])
self.assertIs(video_fps, None)

def test_read_video_partially_corrupted_file(self):
with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
with open(f_name, 'r+b') as f:
size = os.path.getsize(f_name)
bytes_to_overwrite = size // 10
# seek to the middle of the file
f.seek(5 * bytes_to_overwrite)
# corrupt 10% of the file from the middle
f.write(b'\xff' * bytes_to_overwrite)
# this exercises the container.decode assertion check
video, audio, info = io.read_video(f.name, pts_unit='sec')
# check that size is not equal to 5, but 3
self.assertEqual(len(video), 3)
# but the valid decoded content is still correct
self.assertTrue(video[:3].equal(data[:3]))
# and the last few frames are wrong
self.assertFalse(video.equal(data))

# TODO add tests for audio


Expand Down
95 changes: 61 additions & 34 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,17 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str
# print("Corrupted file?", container.name)
return []
buffer_count = 0
for idx, frame in enumerate(container.decode(**stream_name)):
frames[frame.pts] = frame
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
buffer_count += 1
continue
break
try:
for idx, frame in enumerate(container.decode(**stream_name)):
frames[frame.pts] = frame
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
buffer_count += 1
continue
break
except av.AVError:
# TODO add a warning
pass
# ensure that the results are sorted wrt the pts
result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
if start_offset > 0 and start_offset not in frames:
Expand Down Expand Up @@ -193,25 +197,39 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
raise ValueError("end_pts should be larger than start_pts, got "
"start_pts={} and end_pts={}".format(start_pts, end_pts))

container = av.open(filename, metadata_errors='ignore')
info = {}

video_frames = []
if container.streams.video:
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
container.streams.video[0], {'video': 0})
info["video_fps"] = float(container.streams.video[0].average_rate)
audio_frames = []
if container.streams.audio:
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
container.streams.audio[0], {'audio': 0})
info["audio_fps"] = container.streams.audio[0].rate

container.close()
try:
container = av.open(filename, metadata_errors='ignore')
except av.AVError:
# TODO raise a warning?
pass
else:
if container.streams.video:
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
container.streams.video[0], {'video': 0})
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)

if container.streams.audio:
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
container.streams.audio[0], {'audio': 0})
info["audio_fps"] = container.streams.audio[0].rate

container.close()

vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes = [frame.to_ndarray() for frame in audio_frames]
vframes = torch.as_tensor(np.stack(vframes))

if vframes:
vframes = torch.as_tensor(np.stack(vframes))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

if aframes:
aframes = np.concatenate(aframes, 1)
aframes = torch.as_tensor(aframes)
Expand Down Expand Up @@ -255,21 +273,30 @@ def read_video_timestamps(filename, pts_unit='pts'):
"""
_check_av_available()

container = av.open(filename, metadata_errors='ignore')

video_frames = []
video_fps = None
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
if _can_read_timestamps_from_packets(container):
# fast path
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
else:
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
video_stream, {'video': 0})
video_fps = float(video_stream.average_rate)
container.close()

try:
container = av.open(filename, metadata_errors='ignore')
except av.AVError:
# TODO add a warning
pass
else:
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
if _can_read_timestamps_from_packets(container):
# fast path
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
else:
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
video_stream, {'video': 0})
video_fps = float(video_stream.average_rate)
container.close()

pts = [x.pts for x in video_frames]

if pts_unit == 'sec':
return [x.pts * video_time_base for x in video_frames], video_fps
return [x.pts for x in video_frames], video_fps
pts = [x * video_time_base for x in pts]

return pts, video_fps

0 comments on commit da89dad

Please sign in to comment.