diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 263dadfd51f..0ca8fc46678 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -47,6 +47,27 @@ torch::Tensor decode_jpeg_cuda( TORCH_CHECK(device.is_cuda(), "Expected a cuda device") + int major_version; + int minor_version; + nvjpegStatus_t get_major_property_status = + nvjpegGetProperty(MAJOR_VERSION, &major_version); + nvjpegStatus_t get_minor_property_status = + nvjpegGetProperty(MINOR_VERSION, &minor_version); + + TORCH_CHECK( + get_major_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_major_property_status); + TORCH_CHECK( + get_minor_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_minor_property_status); + if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { + TORCH_WARN_ONCE( + "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " + "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); + } + at::cuda::CUDAGuard device_guard(device); // Create global nvJPEG handle diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 7f5aa78880d..339fe4318aa 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -145,6 +145,10 @@ def decode_jpeg( with `nvjpeg `_. This is only supported for CUDA version >= 10.1 + .. warning:: + There is a memory leak in the nvjpeg library for CUDA versions < 11.6. + Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``. + Returns: output (Tensor[image_channels, image_height, image_width]) """