extend video reader to support fast video probing (#1437)
* extend video reader to support fast video probing

* fix c++ lint

* small fix

* allow accepting an input video of type torch.Tensor
stephenyan1231 authored and fmassa committed Oct 12, 2019
1 parent 7ae1b8c commit ed5b2dc
Showing 10 changed files with 332 additions and 57 deletions.
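
For context, here is a minimal usage sketch of the probing API that the new tests below exercise. The helper names come straight from this diff; "video.mp4" is a placeholder path, and the helpers are private (underscore-prefixed) internals of torchvision.io that require the video_reader backend to be built.

```python
from torchvision import io

if io._HAS_VIDEO_OPT:
    # Probe container metadata from a file path without decoding any frames.
    info = io._probe_video_from_file("video.mp4")
    print(info["video_duration"], info["video_fps"])

    # Probe the same metadata from an in-memory buffer.
    with open("video.mp4", "rb") as f:
        info = io._probe_video_from_memory(f.read())
    print(info["video_duration"], info["video_fps"])
```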
16 changes: 16 additions & 0 deletions test/test_io.py
@@ -87,6 +87,22 @@ def test_write_read_video(self):
self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)

@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_file(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
video_info = io._probe_video_from_file(f_name)
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)

@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_memory(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
with open(f_name, "rb") as fp:
filebuffer = fp.read()
video_info = io._probe_video_from_memory(filebuffer)
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)

def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
if _video_backend == "pyav":
105 changes: 76 additions & 29 deletions test/test_video_reader.py
@@ -31,6 +31,7 @@
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")

CheckerConfig = [
"duration",
"video_fps",
"audio_sample_rate",
# We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
@@ -44,6 +44,7 @@
)

all_check_config = GroundTruth(
duration=0,
video_fps=0,
audio_sample_rate=0,
check_aframes=True,
@@ -52,50 +54,58 @@

test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0,
video_fps=30.0,
audio_sample_rate=44100,
# PyAv misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv misses one audio frame at the beginning (pts=0)
@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder
"""
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result

video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)

self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)

# check if pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1):
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
@@ -288,6 +307,20 @@ def check_separate_decoding_result(self, tv_result, config):
for i in range(len(aframe_pts) - 1):
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)

def check_probe_result(self, result, config):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)

def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
"""
Compare decoding results from two sources.
@@ -297,18 +330,17 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config
decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker
"""
vframes, vframe_pts, vtimebase, _vfps, aframes, aframe_pts, atimebase, _asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, \
atimebase, _asample_rate, _aduration = tv_result
if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder
ref_result = DecoderResult(
vframes=ref_result[0],
vframe_pts=ref_result[1],
vtimebase=ref_result[2],
aframes=ref_result[4],
aframe_pts=ref_result[5],
atimebase=ref_result[6],
aframes=ref_result[5],
aframe_pts=ref_result[6],
atimebase=ref_result[7],
)

if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
@@ -351,12 +383,12 @@ def test_stress_test_read_video_from_file(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for i in range(num_iter):
for test_video, config in test_videos.items():
for _i in range(num_iter):
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

# pass 1: decode all frames using new decoder
_ = torch.ops.video_reader.read_video_from_file(
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
@@ -460,9 +492,8 @@ def test_read_video_from_file_read_single_stream_only(self):
audio_timebase_den,
)

vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result

self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
@@ -489,7 +520,7 @@ def test_read_video_from_file_rescale_min_dimension(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -528,7 +559,7 @@ def test_read_video_from_file_rescale_width(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -567,7 +598,7 @@ def test_read_video_from_file_rescale_height(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -606,7 +637,7 @@ def test_read_video_from_file_rescale_width_and_height(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -651,7 +682,7 @@ def test_read_video_from_file_audio_resampling(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -674,18 +705,17 @@ def test_read_video_from_file_audio_resampling(self):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, a_sample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
if aframes.numel() > 0:
self.assertEqual(samples, a_sample_rate.item())
self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1))
# when audio stream is found
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
self.assertAlmostEqual(
aframes.size(0),
int(duration * a_sample_rate.item()),
delta=0.1 * a_sample_rate.item(),
int(duration * asample_rate.item()),
delta=0.1 * asample_rate.item(),
)

def test_compare_read_video_from_memory_and_file(self):
@@ -859,7 +889,7 @@ def test_read_video_from_memory_get_pts_only(self):
)

self.assertEqual(tv_result_pts_only[0].numel(), 0)
self.assertEqual(tv_result_pts_only[4].numel(), 0)
self.assertEqual(tv_result_pts_only[5].numel(), 0)
self.compare_decoding_result(tv_result, tv_result_pts_only)

def test_read_video_in_range_from_memory(self):
@@ -899,9 +929,8 @@ def test_read_video_in_range_from_memory(self):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)

for num_frames in [4, 8, 16, 32, 64, 128]:
@@ -997,6 +1026,24 @@ def test_read_video_in_range_from_memory(self):
# and PyAv
self.compare_decoding_result(tv_result, pyav_result, config)

def test_probe_video_from_file(self):
"""
Test the case when the decoder probes a video file
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
self.check_probe_result(probe_result, config)

def test_probe_video_from_memory(self):
"""
Test the case when the decoder probes a video in memory
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)


if __name__ == '__main__':
unittest.main()
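
The check_probe_result helper above also documents the shape of the low-level op's output: a tuple of tensors (vtimebase, vfps, vduration, atimebase, asample_rate, aduration), with durations expressed in stream time-base units. A minimal sketch of converting that result to seconds, assuming the video_reader extension is built and loaded and using a placeholder path:

```python
from fractions import Fraction

import torch
import torchvision  # noqa: F401  # loading torchvision registers the video_reader ops

result = torch.ops.video_reader.probe_video_from_file("video.avi")  # placeholder path
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result

# Durations are in stream time-base units; scale by the time base to get seconds.
video_seconds = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
print(float(video_seconds), vfps.item())

if asample_rate.numel() > 0:  # an audio stream was found
    audio_seconds = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
    print(float(audio_seconds), asample_rate.item())
```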
1 change: 1 addition & 0 deletions torchvision/csrc/cpu/video_reader/FfmpegAudioStream.cpp
@@ -49,6 +49,7 @@ void FfmpegAudioStream::updateStreamDecodeParams() {
mediaFormat_.format.audio.timeBaseDen =
inputCtx_->streams[index_]->time_base.den;
}
mediaFormat_.format.audio.duration = inputCtx_->streams[index_]->duration;
}

int FfmpegAudioStream::initFormat() {
34 changes: 34 additions & 0 deletions torchvision/csrc/cpu/video_reader/FfmpegDecoder.cpp
@@ -220,6 +220,30 @@ int FfmpegDecoder::decodeMemory(
return ret;
}

int FfmpegDecoder::probeFile(
unique_ptr<DecoderParameters> params,
const string& fileName,
DecoderOutput& decoderOutput) {
VLOG(1) << "probe file: " << fileName;
FfmpegAvioContext ioctx;
return probeVideo(std::move(params), fileName, true, ioctx, decoderOutput);
}

int FfmpegDecoder::probeMemory(
unique_ptr<DecoderParameters> params,
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput) {
VLOG(1) << "probe video data in memory";
FfmpegAvioContext ioctx;
int ret = ioctx.initAVIOContext(buffer, size);
if (ret == 0) {
ret =
probeVideo(std::move(params), string(""), false, ioctx, decoderOutput);
}
return ret;
}

void FfmpegDecoder::cleanUp() {
if (formatCtx_) {
for (auto& stream : streams_) {
@@ -320,6 +344,16 @@ int FfmpegDecoder::decodeLoop(
return ret;
}

int FfmpegDecoder::probeVideo(
unique_ptr<DecoderParameters> params,
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput) {
params_ = std::move(params);
return init(filename, isDecodeFile, ioctx, decoderOutput);
}

bool FfmpegDecoder::initStreams() {
for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) {
AVMediaType mediaType;
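
On the Python side these decoder entry points surface as torch ops and, per the commit message, probing from memory accepts the input video as a torch.Tensor. A rough sketch mirroring test_probe_video_from_memory in the test diff above; the way the file is read into a uint8 tensor here is an assumption (the tests use a _get_video_tensor helper), and "video.avi" is a placeholder path.

```python
import numpy as np
import torch
import torchvision  # noqa: F401  # assumes the video_reader extension is built and loaded

with open("video.avi", "rb") as f:
    data = f.read()
# Wrap the raw bytes as a 1-D uint8 tensor (copy so the buffer is writable).
video_tensor = torch.from_numpy(np.frombuffer(data, dtype=np.uint8).copy())

probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = probe_result
```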
