178 changes: 169 additions & 9 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -20,6 +20,13 @@ static bool g_cuda =
return new CudaDeviceInterface(device);
});

// BT.709 full range color conversion matrix for YUV to RGB conversion.
// See Note [YUV -> RGB Color Conversion, color space and color range] below.
constexpr Npp32f bt709FullRangeColorTwist[3][4] = {
{1.0f, 0.0f, 1.5748f, 0.0f},
{1.0f, -0.187324273f, -0.468124273f, -128.0f},
{1.0f, 1.8556f, 0.0f, -128.0f}};

// We reuse cuda contexts across VideoDecoder instances. This is because
// creating a cuda context is expensive. The cache mechanism is as follows:
// 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
@@ -312,21 +319,54 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));

NppiSize oSizeROI = {width, height};
- Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
+ Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};

NppStatus status;

// For background, see
// Note [YUV -> RGB Color Conversion, color space and color range]
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
- status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
-     input,
-     avFrame->linesize[0],
-     static_cast<Npp8u*>(dst.data_ptr()),
-     dst.stride(0),
-     oSizeROI,
-     nppCtx);
if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) {
// NPP provides a pre-defined color conversion function for BT.709 full
// range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But its output doesn't closely
// match the results we get on CPU. So we use a custom color conversion
// matrix, which provides more accurate results. See the note mentioned
// above for details, and headaches.

int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]};

status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx(
yuvData,
srcStep,
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI,
bt709FullRangeColorTwist,
nppCtx);
} else {
// If not full range, we assume studio limited range.
// The color conversion matrix for BT.709 limited range should be:
// static const Npp32f bt709LimitedRangeColorTwist[3][4] = {
// {1.16438356f, 0.0f, 1.79274107f, -16.0f},
// {1.16438356f, -0.213248614f, -0.5329093290f, -128.0f},
// {1.16438356f, 2.11240179f, 0.0f, -128.0f}
// };
// That gets us very close to the CPU results, but using the pre-defined
// nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx seems to be even more accurate.
status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
yuvData,
avFrame->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI,
nppCtx);
}
} else {
// TODO we're assuming BT.601 color space (and probably limited range) by
// calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range,
// and other color spaces like BT.2020.
status = nppiNV12ToRGB_8u_P2C3R_Ctx(
- input,
+ yuvData,
avFrame->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
@@ -362,3 +402,123 @@ std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
}

} // namespace facebook::torchcodec

/* clang-format off */
// Note: [YUV -> RGB Color Conversion, color space and color range]
//
// The frames we get from the decoder (FFmpeg decoder, or NVCUVID) are in YUV
// format. We need to convert them to RGB. This note attempts to describe this
// process. There may be some inaccuracies and approximations that experts will
// notice, but our goal is only to provide a good enough understanding of the
// process for torchcodec developers to implement and maintain it.
// On CPU, filtergraph and swscale handle everything for us. With CUDA, we have
// to do a lot of the heavy lifting ourselves.
//
// Color space and color range
// ---------------------------
// Two main characteristics of a frame will affect the conversion process:
// 1. Color space: This basically defines what YUV values correspond to which
// physical wavelength. No need to go into details here; the point is that
// videos can come in different color spaces, the most common ones being
// BT.601 and BT.709, but there are others.
// In FFmpeg this is represented with AVColorSpace:
// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#aff71a069509a1ad3ff54d53a1c894c85
// 2. Color range: This defines the range of YUV values. There is:
// - full range, also called PC range: AVCOL_RANGE_JPEG
// - and the "limited" range, also called studio or TV range: AVCOL_RANGE_MPEG
// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#a3da0bf691418bc22c4bcbe6583ad589a
//
// Color space and color range are independent concepts: we can have a BT.709
// video with full range, and another one with limited range. Same for BT.601.
//
// In the first version of this note we'll focus on the full color range. It
// will later be updated to account for the limited range.
//
// Color conversion matrix
// -----------------------
// YUV -> RGB conversion is defined as the inverse of the RGB -> YUV
// conversion, so that's where we'll start.
// At the core of an RGB -> YUV conversion are the "luma coefficients", which
// are specific to a given color space and defined by the color space
// standard. In FFmpeg they can be found here:
// https://github.com/FFmpeg/FFmpeg/blob/7d606ef0ccf2946a4a21ab1ec23486cadc21864b/libavutil/csp.c#L46-L56
//
// For example, the BT.709 coefficients are: kr=0.2126, kg=0.7152, kb=0.0722.
// The coefficients must sum to 1.
//
// Conventionally Y is in the [0, 1] range, and U and V are in the [-0.5, 0.5]
// range (mathematically, that is; in practice they are represented in integer
// ranges).
// The conversion is defined as:
// https://en.wikipedia.org/wiki/YCbCr#R'G'B'_to_Y%E2%80%B2PbPr
// Y = kr*R + kg*G + kb*B
// U = (B - Y) * 0.5 / (1 - kb) = (B - Y) / u_scale where u_scale = 2 * (1 - kb)
// V = (R - Y) * 0.5 / (1 - kr) = (R - Y) / v_scale where v_scale = 2 * (1 - kr)
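//
// As a quick sanity check (pure white, R=G=B=1, should give Y=1, U=V=0):
// ```py
// kr, kg, kb = 0.2126, 0.7152, 0.0722  # BT.709
// R, G, B = 1.0, 1.0, 1.0
// Y = kr * R + kg * G + kb * B  # -> 1.0, since the coefficients sum to 1
// U = (B - Y) / (2 * (1 - kb))  # -> 0.0
// V = (R - Y) / (2 * (1 - kr))  # -> 0.0
// ```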
//
// Putting all this into matrix form, we get:
//   [Y]   [kr              kg           kb            ] [R]
//   [U] = [-kr/u_scale     -kg/u_scale  (1-kb)/u_scale] [G]
//   [V]   [(1-kr)/v_scale  -kg/v_scale  -kb/v_scale   ] [B]
//
//
// Now, to convert YUV to RGB, we just need to invert this matrix:
// ```py
// import torch
// kr, kg, kb = 0.2126, 0.7152, 0.0722 # BT.709 luma coefficients
// u_scale = 2 * (1 - kb)
// v_scale = 2 * (1 - kr)
//
// rgb_to_yuv = torch.tensor([
// [kr, kg, kb],
// [-kr/u_scale, -kg/u_scale, (1-kb)/u_scale],
// [(1-kr)/v_scale, -kg/v_scale, -kb/v_scale]
// ])
//
// yuv_to_rgb_full = torch.linalg.inv(rgb_to_yuv)
// print("YUV->RGB matrix (Full Range):")
// print(yuv_to_rgb_full)
// ```
// And we get:
// tensor([[ 1.0000e+00, -3.3142e-09, 1.5748e+00],
// [ 1.0000e+00, -1.8732e-01, -4.6812e-01],
// [ 1.0000e+00, 1.8556e+00, 4.6231e-09]])
//
// Which matches https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
//
// Color conversion in NPP
// -----------------------
// https://docs.nvidia.com/cuda/npp/image_color_conversion.html.
//
// NPP provides different ways to convert YUV to RGB:
// - pre-defined color conversion functions like
// nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx and nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx
// which are for BT.709 limited and full range, respectively.
// - generic color conversion functions that accept a custom color conversion
// matrix, called ColorTwist, like nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx
//
// We use the pre-defined functions or the color twist functions depending on
// which one we find to be closer to the CPU results.
//
// The color twist functionality is *partially* described in a section named
// "YUVToRGBColorTwist". Importantly:
//
// - The `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` function takes the YUV data
// and the color-conversion matrix as input. The function itself and the
// matrix assume different ranges for YUV values:
// - The **matrix coefficients** must assume that Y is in [0, 1] and U,V are
// in [-0.5, 0.5]. That's how we defined our matrix above.
// - The function `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` however expects all
// of the input Y, U, V to be in [0, 255]. That's how the data comes out of
// the decoder.
// - But *internally*, `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` needs U and V to
// be centered around 0, i.e. in [-128, 127]. So we need to apply a -128
// offset to U and V. Y doesn't need to be offset. The offset can be applied
// by adding a 4th column to the matrix.
//
//
// So our conversion matrix becomes the following, with the new offset column:
// tensor([[ 1.0000e+00, -3.3142e-09,  1.5748e+00,    0],
//         [ 1.0000e+00, -1.8732e-01, -4.6812e-01, -128],
//         [ 1.0000e+00,  1.8556e+00,  4.6231e-09, -128]])
//
// And that's what we need to pass for BT.709, full range.
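//
// To sanity-check that matrix, here's a minimal sketch (it assumes torch is
// available, and that, as described in the bullets above, the 4th column acts
// as a per-source-channel offset applied before the 3x3 multiply):
// ```py
// import torch
//
// twist = torch.tensor([
//     [1.0,  0.0,          1.5748,          0.0],
//     [1.0, -0.187324273, -0.468124273,  -128.0],
//     [1.0,  1.8556,       0.0,          -128.0],
// ])
//
// def yuv_to_rgb(y, u, v):
//     # Offset each source channel first, which centers U and V around 0,
//     # then apply the 3x3 part of the twist.
//     yuv = torch.tensor([y, u, v]) + twist[:, 3]
//     return twist[:, :3] @ yuv
//
// # Mid-gray check: U and V at their 128 mid-point must give R == G == B == Y.
// print(yuv_to_rgb(128.0, 128.0, 128.0))  # tensor([128., 128., 128.])
// ```
//
// The commented-out BT.709 limited range matrix in the code above can be
// reproduced the same way: limited range maps Y to [16, 235] (hence a 255/219
// factor and a -16 offset) and U, V to [16, 240] (hence a 255/224 factor).
// A sketch, under the same assumptions:
// ```py
// yuv_to_rgb_full = torch.tensor([
//     [1.0,  0.0,          1.5748],
//     [1.0, -0.187324273, -0.468124273],
//     [1.0,  1.8556,       0.0],
// ])
// scale = torch.tensor([255 / 219, 255 / 224, 255 / 224])  # Y, U, V factors
// print(yuv_to_rgb_full * scale)
// # tensor([[ 1.1644,  0.0000,  1.7927],
// #         [ 1.1644, -0.2132, -0.5329],
// #         [ 1.1644,  2.1124,  0.0000]])
// # With the offset column (-16, -128, -128), this matches
// # bt709LimitedRangeColorTwist.
// ```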
/* clang-format on */
Binary file added test/resources/bt709_full_range.mp4
27 changes: 27 additions & 0 deletions test/test_decoders.py
@@ -25,6 +25,8 @@
all_supported_devices,
assert_frames_equal,
AV1_VIDEO,
BT709_FULL_RANGE,
cuda_version_used_for_building_torch,
get_ffmpeg_major_version,
H264_10BITS,
H265_10BITS,
@@ -35,6 +37,7 @@
NASA_AUDIO_MP3_44100,
NASA_VIDEO,
needs_cuda,
psnr,
SINE_MONO_S16,
SINE_MONO_S32,
SINE_MONO_S32_44100,
@@ -1197,6 +1200,30 @@ def test_pts_to_dts_fallback(self, seek_mode):
with pytest.raises(AssertionError, match="not equal"):
torch.testing.assert_close(decoder[0], decoder[10])

@needs_cuda
@pytest.mark.parametrize("asset", (BT709_FULL_RANGE, NASA_VIDEO))
def test_full_and_studio_range_bt709_video(self, asset):
# Ensure result consistency between the CPU and GPU decoders on BT.709
# videos, one with full color range, one with studio range.
# This is a non-regression test: we previously didn't support full range
# on GPU.
#
# NASA_VIDEO is a BT.709 studio range video, as can be confirmed with:
# ffprobe -v quiet -select_streams v:0 -show_entries
# stream=color_space,color_transfer,color_primaries,color_range -of
# default=noprint_wrappers=1 test/resources/nasa_13013.mp4
decoder_gpu = VideoDecoder(asset.path, device="cuda")
decoder_cpu = VideoDecoder(asset.path, device="cpu")

for frame_index in (0, 10, 20, 5):
gpu_frame = decoder_gpu.get_frame_at(frame_index).data.cpu()
cpu_frame = decoder_cpu.get_frame_at(frame_index).data

if cuda_version_used_for_building_torch() >= (12, 9):
torch.testing.assert_close(gpu_frame, cpu_frame, rtol=0, atol=2)
elif cuda_version_used_for_building_torch() == (12, 8):
assert psnr(gpu_frame, cpu_frame) > 20

Comment from the PR author:

So, I developed this PR on CUDA 12.9, and I was unconditionally using torch.testing.assert_close(gpu_frame, cpu_frame, rtol=0, atol=2), which was passing. It also passes on the 12.9 CI job.

When I submitted the PR and the CI tested on CUDA 12.6, I realized the test wasn't passing. I'm unable to tell by how much, and I'm unable to reproduce locally because I don't have 12.6, and I can't ssh into the runner either.

12.8 is producing OK results, with a psnr of ~24, but it's not as good as with 12.9.

I honestly think we should treat this as bugs in NPP that were eventually fixed in 12.9. I can't imagine us maintaining different code paths depending on the runtime CUDA version. That sounds too complicated, and I'm not even sure it's doable: compile-time #define checks wouldn't be enough, because we can compile on 12.9 and run on 12.8.

Note that torch now considers 12.6 to be in legacy support: pytorch/pytorch#159980

@needs_cuda
def test_10bit_videos_cuda(self):
# Assert that we raise proper error on different kinds of 10bit videos.
46 changes: 46 additions & 0 deletions test/utils.py
@@ -37,6 +37,31 @@ def get_ffmpeg_major_version():
return int(ffmpeg_version.split(".")[0])


def cuda_version_used_for_building_torch() -> Optional[tuple[int, int]]:
# Return the CUDA version that was used to build PyTorch. That's not always
# the same as the CUDA runtime version installed on the running machine,
# which is what we'd actually want to know. On the CI though, the two are
# the same.
if torch.version.cuda is None:
return None
else:
return tuple(int(x) for x in torch.version.cuda.split("."))


def psnr(a, b, max_val=255) -> float:
# Return Peak Signal-to-Noise Ratio (PSNR) between two tensors a and b. The
# higher, the better.
# According to https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio,
# typical values for the PSNR in lossy image and video compression are
# between 30 and 50 dB.
# Acceptable values for wireless transmission quality loss are considered
# to be about 20 dB to 25 dB.
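#
# For intuition, a hypothetical example: two uniform frames that differ by 1
# everywhere have mse == 1, so psnr == 20 * log10(255) ~= 48.13 dB:
#   psnr(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))  # ~48.13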
mse = torch.mean((a.float() - b.float()) ** 2)
if mse == 0:
return float("inf")
return 20 * torch.log10(max_val / torch.sqrt(mse)).item()


# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit
# equality. On CUDA Linux, we expect a small tolerance.
# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does
@@ -637,3 +662,24 @@ def sample_format(self) -> str:
},
},
)


# This is a BT.709 full range video, generated with:
# ffmpeg -f lavfi -i testsrc2=duration=1:size=1920x720:rate=30 \
# -c:v libx264 -pix_fmt yuv420p -color_primaries bt709 -color_trc bt709 \
# -colorspace bt709 -color_range pc bt709_full_range.mp4
#
Comment from a contributor:

Up until this PR, we've maintained the rule that all generated references can be reproduced from the generate_reference_resources.sh script. Is that something we want to continue to maintain? I think there is a lot of value in it, but that script is also not the cleanest artifact.


Reply from the PR author (@NicolasHug, Aug 19, 2025):

From what I can tell this script generates the bmp / pt reference frames, but not the source videos themselves. I see similar comments there indicating how the videos were generated:

https://github.com/pytorch/torchcodec/blob/ffcb7ab2e98c204dfb103f46b2db154cbf1aa713/test/generate_reference_resources.sh#L65-L66

Here we're not generating or using the frames, we're just comparing the CPU output with the GPU output.

I agree we should also try to check against a ground-truth reference, but I'll leave that as a follow-up if that's OK.


Reply from the contributor:

We're actually doing a bit of both, which is messy. I'm going to create an issue about it.

# We can confirm the color space and color range with:
# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt709_full_range.mp4
# color_range=pc
# color_space=bt709
# color_transfer=bt709
# color_primaries=bt709
BT709_FULL_RANGE = TestVideo(
filename="bt709_full_range.mp4",
default_stream_index=0,
stream_infos={
0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
},
frames={0: {}}, # Not needed for now
)