108 changes: 67 additions & 41 deletions benchmarks/samplers/benchmark_samplers.py
@@ -1,3 +1,4 @@
import argparse
from pathlib import Path
from time import perf_counter_ns

@@ -45,51 +46,76 @@ def report_stats(times, num_frames, unit="ms"):
return med, fps


def sample(sampler, **kwargs):
decoder = VideoDecoder(VIDEO_PATH)
def sample(decoder, sampler, **kwargs):
return sampler(
decoder,
num_frames_per_clip=10,
**kwargs,
)


VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
NUM_EXP = 30

for num_clips in (1, 50):
print("-" * 10)
print(f"{num_clips = }")

print("clips_at_random_indices ", end="")
times, num_frames = bench(
sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
)
report_stats(times, num_frames, unit="ms")

print("clips_at_regular_indices ", end="")
times, num_frames = bench(
sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
)
report_stats(times, num_frames, unit="ms")

print("clips_at_random_timestamps ", end="")
times, num_frames = bench(
sample,
clips_at_random_timestamps,
num_clips=num_clips,
num_exp=NUM_EXP,
warmup=2,
)
report_stats(times, num_frames, unit="ms")

print("clips_at_regular_timestamps ", end="")
seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long
times, num_frames = bench(
sample,
clips_at_regular_timestamps,
seconds_between_clip_starts=seconds_between_clip_starts,
num_exp=NUM_EXP,
warmup=2,
)
report_stats(times, num_frames, unit="ms")
def run_sampler_benchmarks(device, video):
NUM_EXP = 30

for num_clips in (1, 50):
print("-" * 10)
print(f"{num_clips = }")

print("clips_at_random_indices ", end="")
decoder = VideoDecoder(video, device=device)
times, num_frames = bench(
sample,
decoder,
clips_at_random_indices,
num_clips=num_clips,
num_exp=NUM_EXP,
warmup=2,
)
report_stats(times, num_frames, unit="ms")

print("clips_at_regular_indices ", end="")
times, num_frames = bench(
sample,
decoder,
clips_at_regular_indices,
num_clips=num_clips,
num_exp=NUM_EXP,
warmup=2,
)
report_stats(times, num_frames, unit="ms")

print("clips_at_random_timestamps ", end="")
times, num_frames = bench(
sample,
decoder,
clips_at_random_timestamps,
num_clips=num_clips,
num_exp=NUM_EXP,
warmup=2,
)
report_stats(times, num_frames, unit="ms")

print("clips_at_regular_timestamps ", end="")
seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long
times, num_frames = bench(
sample,
decoder,
clips_at_regular_timestamps,
seconds_between_clip_starts=seconds_between_clip_starts,
num_exp=NUM_EXP,
warmup=2,
)
report_stats(times, num_frames, unit="ms")


def main():
DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH))
args = parser.parse_args()
run_sampler_benchmarks(args.device, args.video)


if __name__ == "__main__":
main()
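Not part of the diff — a sketch of how the refactored entry point can now be driven from Python as well as the CLI (the import path and video file name are assumptions):

import torch

from benchmark_samplers import run_sampler_benchmarks  # import path assumed

# CLI equivalent:
#   python benchmarks/samplers/benchmark_samplers.py --device cuda --video my_video.mp4
device = "cuda" if torch.cuda.is_available() else "cpu"
run_sampler_benchmarks(device, "my_video.mp4")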
3 changes: 2 additions & 1 deletion src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -19,7 +19,8 @@ void convertAVFrameToDecodedOutputOnCuda(
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output) {
VideoDecoder::DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
throwUnsupportedDeviceError(device);
}

19 changes: 17 additions & 2 deletions src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -201,7 +201,8 @@ void convertAVFrameToDecodedOutputOnCuda(
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output) {
VideoDecoder::DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
AVFrame* src = rawOutput.frame.get();

TORCH_CHECK(
@@ -213,7 +214,21 @@
NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {src->data[0], src->data[1]};
torch::Tensor& dst = output.frame;
dst = allocateDeviceTensor({height, width, 3}, options.device);
if (preAllocatedOutputTensor.has_value()) {
dst = preAllocatedOutputTensor.value();
auto shape = dst.sizes();
TORCH_CHECK(
(shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
(shape[2] == 3),
"Expected tensor of shape ",
height,
"x",
width,
"x3, got ",
shape);
} else {
dst = allocateDeviceTensor({height, width, 3}, options.device);
}

// Use the user-requested GPU for running the NPP kernel.
c10::cuda::CUDAGuard deviceGuard(device);
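Not part of the diff — the new branch follows a validate-or-allocate pattern; a Python analogue with assumed names:

from typing import Optional

import torch

def get_output_tensor(
    pre_allocated: Optional[torch.Tensor], height: int, width: int, device: str
) -> torch.Tensor:
    # Mirrors the C++ branch above: reuse the caller's HWC uint8 tensor when
    # its shape matches, otherwise allocate a fresh tensor on the target device.
    if pre_allocated is not None:
        assert pre_allocated.shape == (height, width, 3), (
            f"Expected tensor of shape {height}x{width}x3, got {tuple(pre_allocated.shape)}"
        )
        return pre_allocated
    return torch.empty((height, width, 3), dtype=torch.uint8, device=device)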
3 changes: 2 additions & 1 deletion src/torchcodec/decoders/_core/DeviceInterface.h
@@ -37,7 +37,8 @@ void convertAVFrameToDecodedOutputOnCuda(
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output);
VideoDecoder::DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

void releaseContextOnCuda(
const torch::Device& device,
7 changes: 4 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -196,7 +196,7 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width),
3},
{torch::kUInt8})),
at::TensorOptions(options.device).dtype(torch::kUInt8))),
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}

@@ -855,17 +855,18 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
output.duration = getDuration(frame);
output.durationSeconds = ptsToSeconds(
getDuration(frame), formatContext_->streams[streamIndex]->time_base);
// TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
if (streamInfo.options.device.type() == torch::kCPU) {
convertAVFrameToDecodedOutputOnCPU(
rawOutput, output, preAllocatedOutputTensor);
} else if (streamInfo.options.device.type() == torch::kCUDA) {
// TODO: handle pre-allocated output tensor
convertAVFrameToDecodedOutputOnCuda(
streamInfo.options.device,
streamInfo.options,
streamInfo.codecContext.get(),
rawOutput,
output);
output,
preAllocatedOutputTensor);
} else {
TORCH_CHECK(
false, "Invalid device type: " + streamInfo.options.device.str());
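Not part of the diff — a sketch of why BatchDecodedOutput now allocates its frames tensor on options.device (shapes and the decode_into name are illustrative):

import torch

# The batch tensor lives on the stream's device, so each decoded frame can be
# written straight into a view of it instead of being decoded into a temporary
# tensor and copied across devices afterwards.
num_frames, height, width = 50, 270, 480
batch = torch.empty((num_frames, height, width, 3), dtype=torch.uint8, device="cuda")

for i in range(num_frames):
    pre_allocated = batch[i]  # (H, W, 3) view, passed down as preAllocatedOutputTensor
    # decode_into(pre_allocated)  # hypothetical: color conversion lands in place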
17 changes: 10 additions & 7 deletions src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

from torch import Tensor
from torch import device, Tensor

from torchcodec import Frame, FrameBatch
from torchcodec.decoders import _core as core
@@ -36,19 +36,20 @@ class VideoDecoder:
This can be either "NCHW" (default) or "NHWC", where N is the batch
size, C is the number of channels, H is the height, and W is the
width of the frames.
num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
Use 1 for single-threaded decoding which may be best if you are running multiple
instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
decoding which is best if you are running a single instance of ``VideoDecoder``.
Default: 1.

.. note::

Frames are natively decoded in NHWC format by the underlying
FFmpeg implementation. Converting those into NCHW format is a
cheap no-copy operation that allows these frames to be
transformed using the `torchvision transforms
<https://pytorch.org/vision/stable/transforms.html>`_.
num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
Use 1 for single-threaded decoding which may be best if you are running multiple
instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
decoding which is best if you are running a single instance of ``VideoDecoder``.
Default: 1.
device (str or torch.device, optional): The device to use for decoding. Default: "cpu".


Attributes:
metadata (VideoStreamMetadata): Metadata of the video stream.
@@ -64,6 +65,7 @@ def __init__(
stream_index: Optional[int] = None,
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, device]] = "cpu",
):
if isinstance(source, str):
self._decoder = core.create_from_file(source)
@@ -92,6 +94,7 @@ def __init__(
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
)

self.metadata, self.stream_index = _get_and_validate_stream_metadata(
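Not part of the diff — assumed usage of the new device parameter, plus a quick check of the docstring's "cheap no-copy" claim (the video file name is a placeholder):

import torch

from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder("video.mp4", device="cuda")
frame = decoder[0]
assert frame.device.type == "cuda"  # frames come back on the GPU

# NHWC -> NCHW is a stride permutation, not a copy:
nhwc = torch.randint(0, 256, (10, 270, 480, 3), dtype=torch.uint8)
assert nhwc.permute(0, 3, 1, 2).data_ptr() == nhwc.data_ptr()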
64 changes: 61 additions & 3 deletions test/decoders/test_video_decoder_ops.py
@@ -36,7 +36,13 @@
seek_to_pts,
)

from ..utils import assert_tensor_equal, NASA_AUDIO, NASA_VIDEO, needs_cuda
from ..utils import (
assert_tensor_close_on_at_least,
assert_tensor_equal,
NASA_AUDIO,
NASA_VIDEO,
needs_cuda,
)

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@@ -137,6 +143,24 @@ def test_get_frames_at_indices(self):
assert_tensor_equal(frames0and180[0], reference_frame0)
assert_tensor_equal(frames0and180[1], reference_frame180)

@needs_cuda
def test_get_frames_at_indices_with_cuda(self):
Contributor: We'll also want to test get_frames_in_range, and all the batch-APIs?
I feel like we should be parametrizing a fair amount of our tests. But this can be done as a follow-up.

Contributor Author: I'll do that as a follow-up.

decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder, device="cuda")
frames0and180, *_ = get_frames_at_indices(
decoder, stream_index=3, frame_indices=[0, 180]
)
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
INDEX_OF_FRAME_AT_6_SECONDS
)
assert frames0and180.device.type == "cuda"
assert_tensor_close_on_at_least(frames0and180[0].to("cpu"), reference_frame0)
assert_tensor_close_on_at_least(
frames0and180[1].to("cpu"), reference_frame180, 0.3, 30
)

def test_get_frames_at_indices_unsorted_indices(self):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder)
@@ -198,6 +222,40 @@ def test_get_frames_by_pts(self):
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

# TODO: Figure out how to parametrize this test to run on both CPU and CUDA.
# The question is how to have the @needs_cuda decorator with the pytest.mark.parametrize
# decorator on the same test.
Comment on lines +225 to +227

Contributor: It's simple! We just need to define this new util

https://github.com/pytorch/vision/blob/e9a3213524a0abd609ac7330cf170b9e19917d39/test/common_utils.py#L122-L125

and it can be used like this

https://github.com/pytorch/vision/blob/e9a3213524a0abd609ac7330cf170b9e19917d39/test/test_utils.py#L221

If you want, we can merge this PR as-is and follow-up with that.
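For reference, a minimal sketch of the util those links describe, adapted to this test suite (assumes needs_cuda from test/utils.py is a pytest mark):

import pytest

from ..utils import needs_cuda

def cpu_and_cuda():
    # Parametrize over devices; the "cuda" case carries the needs_cuda mark,
    # so only that case is skipped when CUDA is unavailable.
    return ("cpu", pytest.param("cuda", marks=needs_cuda))

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_by_pts(device):
    ...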

Contributor Author: I'll do that as a follow-up.

@needs_cuda
def test_get_frames_by_pts_with_cuda(self):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder, device="cuda")
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

# Note: 13.01 should give the last video frame for the NASA video
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]

expected_frames = [
get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps
]

frames, *_ = get_frames_by_pts(
decoder,
stream_index=stream_index,
timestamps=timestamps,
)
for frame, expected_frame in zip(frames, expected_frames):
assert_tensor_equal(frame, expected_frame)

# first and last frame should be equal, at pts=2 [+ eps]. We then modify
# the first frame and assert that it's now different from the last
# frame. This ensures a copy was properly made during the de-duplication
# logic.
assert_tensor_equal(frames[0], frames[-1])
frames[0] += 20
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

def test_pts_apis_against_index_ref(self):
# Non-regression test for https://github.com/pytorch/torchcodec/pull/287
# Get all frames in the video, then query all frames with all time-based
@@ -657,8 +715,8 @@ def test_cuda_decoder(self):
assert frame0.device.type == "cuda"
frame0_cpu = frame0.to("cpu")
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
# GPU decode is not bit-accurate. In the following assertion we ensure
# not more than 0.3% of values have a difference greater than 20.
# GPU decode is not bit-accurate. So we allow some tolerance.
assert_tensor_close_on_at_least(frame0_cpu, reference_frame0)
diff = (reference_frame0.float() - frame0_cpu.float()).abs()
assert (diff > 20).float().mean() <= 0.003
assert pts == torch.tensor([0])
7 changes: 7 additions & 0 deletions test/utils.py
@@ -33,6 +33,13 @@ def assert_tensor_equal(*args, **kwargs):
torch.testing.assert_close(*args, **kwargs, atol=absolute_tolerance, rtol=0)


# Asserts that at least `percentage`% of the values are within the absolute tolerance.
def assert_tensor_close_on_at_least(frame1, frame2, percentage=99.7, abs_tolerance=20):
diff = (frame2.float() - frame1.float()).abs()
diff_percentage = 100.0 - percentage
assert (diff > abs_tolerance).float().mean() <= diff_percentage / 100.0
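Not part of the diff — a standalone sanity check of the helper's semantics:

import torch

# 1% of values differ by 25 (> abs_tolerance=20). The default threshold only
# allows 0.3% such outliers, so the default call would fail, while a looser
# percentage passes.
frame1 = torch.zeros(100, 100, 3)
frame2 = frame1.clone()
frame2[0] += 25  # one row out of 100 -> 1% of values

assert_tensor_close_on_at_least(frame1, frame2, percentage=98.0)  # passes: 1% <= 2%
# assert_tensor_close_on_at_least(frame1, frame2)  # would fail: 1% > 0.3%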


# For use with floating point metadata, or in other instances where we are not confident
# that reference and test tensors can be exactly equal. This is true for pts and duration
# in seconds, as the reference values are from ffprobe's JSON output. In that case, it is