From 133d7c1df2251332e3d56da6bacb05d2642ae76b Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Fri, 12 Apr 2024 16:10:11 -0700 Subject: [PATCH 01/16] Adding GPU acceleration to encode_jpeg Summary: I'm adding GPU support to the existing torchvision.io.encode_jpeg function. If the input tensors are on the GPU, the CUDA version will be used and the CPU version otherwise. Additionally, I'm adding a new function torchvision.io.encode_jpegs (plural) with uses a fused kernel and may be faster than successive calls to the singular version which incurs kernel launch overhead for each call. If it's alright, I'll be happy to refactor decode_jpeg to follow this convention in a follow up PR. Test Plan: 1. pytest test -vvv 2. ufmt format torchvision 3. flake8 torchvision Reviewers: Subscribers: Tasks: Tags: --- test/test_image.py | 105 ++++++++ .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 24 +- .../csrc/io/image/cuda/decode_jpeg_cuda.h | 15 -- .../io/image/cuda/encode_decode_jpeg_cuda.h | 29 +++ .../csrc/io/image/cuda/encode_jpeg_cuda.cpp | 230 ++++++++++++++++++ torchvision/csrc/io/image/image.cpp | 1 + torchvision/csrc/io/image/image.h | 2 +- torchvision/io/image.py | 42 +++- 8 files changed, 398 insertions(+), 50 deletions(-) delete mode 100644 torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h create mode 100644 torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h create mode 100644 torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp diff --git a/test/test_image.py b/test/test_image.py index 3d9d612b5f3..5e968e5c1a2 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -24,6 +24,7 @@ write_jpeg, write_png, ) +import re IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") @@ -67,6 +68,110 @@ def normalize_dimensions(img_pil): return img_pil +@needs_cuda +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], +) +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("contiguous", (False, True)) +def test_single_encode_jpeg_cuda(img_path, scripted, contiguous): + decoded_image_tv = read_image(img_path) + encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg + + if "cmyk" in img_path: + pytest.xfail("Encoding a CMYK jpeg isn't supported") + if decoded_image_tv.shape[0] == 1: + pytest.xfail("Decoding a grayscale jpeg isn't supported") + # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013 + if not contiguous: + decoded_image_tv = decoded_image_tv.permute(1, 2, 0).contiguous().permute(2,0,1) + encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75) + decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu()) + + # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality + # instead, we re-decode the encoded image and compare to the original + abs_mean_diff = (decoded_jpeg_cuda_tv.type(torch.float32) - decoded_image_tv.type(torch.float32)).abs().mean().item() + assert abs_mean_diff < 3 + + +@needs_cuda +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("contiguous", (False, True)) +def test_batch_encode_jpegs_cuda(scripted, contiguous): + decoded_images_tv = [] + for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): + if "cmyk" in jpeg_path: + continue + decoded_image = read_image(jpeg_path) + if decoded_image.shape[0] == 1: + continue + if not contiguous: + decoded_image = decoded_image.permute(1, 2, 0).contiguous().permute(2,0,1) + decoded_images_tv.append(decoded_image) + + encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg + + decoded_images_tv_cuda = [img.cuda() for img in decoded_images_tv] + encoded_jpegs_tv_cuda = encode_fn(decoded_images_tv_cuda, quality=75) + decoded_jpegs_tv_cuda = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_cuda] + + + for original, encoded_decoded in zip(decoded_images_tv, decoded_jpegs_tv_cuda): + c,h,w = original.shape + abs_mean_diff = (original.type(torch.float32) - encoded_decoded.type(torch.float32)).abs().mean().item() + assert abs_mean_diff < 3 + + +@needs_cuda +def test_single_encode_jpeg_cuda_errors(): + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda")) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): + encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): + encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda")) + + +@needs_cuda +def test_batch_encode_jpegs_cuda_errors(): + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((3, 100, 100), dtype=torch.float32, device="cuda")]) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda")]) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda")]) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda")]) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((100, 100), dtype=torch.uint8, device="cuda")]) + + with pytest.raises(RuntimeError, match="Input tensor should be on CPU"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda")]) + + with pytest.raises(RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu")]) + + if torch.cuda.device_count() >= 2: + with pytest.raises(RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"): + encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"), torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1")]) + + with pytest.raises(AssertionError, match="encode_jpeg requires at least one input tensor when a list is passed"): + encode_jpeg([]) + + @pytest.mark.parametrize( "img_path", [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index ee7d432f30d..59f9255f130 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,4 +1,4 @@ -#include "decode_jpeg_cuda.h" +#include "encode_decode_jpeg_cuda.h" #include @@ -25,10 +25,6 @@ torch::Tensor decode_jpeg_cuda( #else -namespace { -static nvjpegHandle_t nvjpeg_handle = nullptr; -} - torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, ImageReadMode mode, @@ -71,23 +67,7 @@ torch::Tensor decode_jpeg_cuda( at::cuda::CUDAGuard device_guard(device); // Create global nvJPEG handle - static std::once_flag nvjpeg_handle_creation_flag; - std::call_once(nvjpeg_handle_creation_flag, []() { - if (nvjpeg_handle == nullptr) { - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - - if (create_status != NVJPEG_STATUS_SUCCESS) { - // Reset handle so that one can still call the function again in the - // same process if there was a failure - free(nvjpeg_handle); - nvjpeg_handle = nullptr; - } - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); - } - }); + std::call_once(::nvjpeg_handle_creation_flag, nvjpeg_init); // Create the jpeg state nvjpegJpegState_t jpeg_state; diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h deleted file mode 100644 index 496b355e9b7..00000000000 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_jpeg_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device); - -} // namespace image -} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h new file mode 100644 index 00000000000..fa96adc1bca --- /dev/null +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +#if NVJPEG_FOUND +#include + +extern nvjpegHandle_t nvjpeg_handle; +extern std::once_flag nvjpeg_handle_creation_flag; +#endif + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device); + + +C10_EXPORT std::vector encode_jpeg_cuda( + const std::vector& images, + const int64_t quality); + +void nvjpeg_init(); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp new file mode 100644 index 00000000000..68be3b8eb55 --- /dev/null +++ b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp @@ -0,0 +1,230 @@ +#include +#include "c10/core/ScalarType.h" +#include "encode_decode_jpeg_cuda.h" +#include "torch/types.h" + +#include +#include +#include + +#if NVJPEG_FOUND +#include +#include +#include + +nvjpegHandle_t nvjpeg_handle = nullptr; +std::once_flag nvjpeg_handle_creation_flag; + +#endif +#include +#include + +namespace vision { +namespace image { + +#if !NVJPEG_FOUND + +std::vector encode_jpeg_cuda( + const std::vector& images, + const int64_t quality) { + TORCH_CHECK( + false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); +} + +#else + +void nvjpeg_init() { + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + if (create_status != NVJPEG_STATUS_SUCCESS) { + // Reset handle so that one can still call the function again in the + // same process if there was a failure + free(nvjpeg_handle); + nvjpeg_handle = nullptr; + } + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } +} + +torch::Tensor encode_single_jpeg( + const torch::Tensor& data, + const int64_t quality, + const cudaStream_t stream, + const torch::Device& device, + const nvjpegEncoderState_t& nv_enc_state, + const nvjpegEncoderParams_t& nv_enc_params); + +std::vector encode_jpeg_cuda( + const std::vector& images, + const int64_t quality) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.encode_jpeg_cuda.encode_jpeg_cuda"); + + TORCH_CHECK(images.size() > 0, "Empty input tensor list"); + + torch::Device device = images[0].device(); + at::cuda::CUDAGuard device_guard(device); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()); + + std::vector contig_images; + contig_images.reserve(images.size()); + for (const auto& image : images) { + TORCH_CHECK(image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + TORCH_CHECK( + image.device() == device, + "All input tensors must be on the same CUDA device when encoding with nvjpeg") + + TORCH_CHECK( + image.dim() == 3 && image.numel() > 0, + "Input data should be a 3-dimensional tensor"); + + TORCH_CHECK( + image.size(0) == 3, + "The number of channels should be 3, got: ", image.size(0)); + + // nvjpeg requires images to be contiguous + contig_images.push_back(image.contiguous()); + } + + // Create global nvJPEG handle + std::call_once(::nvjpeg_handle_creation_flag, nvjpeg_init); + + nvjpegEncoderState_t nv_enc_state; + nvjpegEncoderParams_t nv_enc_params; + + // initialize nvjpeg structures + // these are rather expensive to create and thus will be reused across multiple calls to encode_single_jpeg + try { + nvjpegStatus_t stateCreateResult = + nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); + TORCH_CHECK( + stateCreateResult == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder state: ", + stateCreateResult); + + nvjpegStatus_t paramsCreateResult = + nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); + TORCH_CHECK( + paramsCreateResult == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder params: ", + paramsCreateResult); + + nvjpegStatus_t paramsQualityStatus = + nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); + TORCH_CHECK( + paramsQualityStatus == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params quality: ", + paramsQualityStatus); + + std::vector encoded_images; + for (const auto& image : contig_images) { + auto encoded_image = encode_single_jpeg( + image, quality, stream, device, nv_enc_state, nv_enc_params); + encoded_images.push_back(encoded_image); + } + // Clean up + nvjpegEncoderStateDestroy(nv_enc_state); + nvjpegEncoderParamsDestroy(nv_enc_params); + return encoded_images; + } catch (const std::exception& e) { + nvjpegEncoderStateDestroy(nv_enc_state); + nvjpegEncoderParamsDestroy(nv_enc_params); + throw; + } +} + +torch::Tensor encode_single_jpeg( + const torch::Tensor& src_image, + const int64_t quality, + const cudaStream_t stream, + const torch::Device& device, + const nvjpegEncoderState_t& nv_enc_state, + const nvjpegEncoderParams_t& nv_enc_params) { + int channels = src_image.size(0); + int height = src_image.size(1); + int width = src_image.size(2); + + nvjpegStatus_t samplingSetResult = nvjpegEncoderParamsSetSamplingFactors( + nv_enc_params, NVJPEG_CSS_444, stream); + TORCH_CHECK( + samplingSetResult == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params sampling factors: ", + samplingSetResult); + + // Create nvjpeg image + nvjpegImage_t target_image; + + for (int c = 0; c < channels; c++) { + target_image.channel[c] = src_image[c].data_ptr(); + // this is why we need contiguous tensors + target_image.pitch[c] = width; + } + for (int c = channels; c < NVJPEG_MAX_COMPONENT; c++) { + target_image.channel[c] = nullptr; + target_image.pitch[c] = 0; + } + nvjpegStatus_t encodingState; + + // Encode the image + encodingState = nvjpegEncodeImage( + nvjpeg_handle, + nv_enc_state, + nv_enc_params, + &target_image, + NVJPEG_INPUT_RGB, + width, + height, + stream); + + TORCH_CHECK( + encodingState == NVJPEG_STATUS_SUCCESS, + "image encoding failed: ", + encodingState); + + // Retrieve length of the encoded image + size_t length; + nvjpegStatus_t getStreamState = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, nv_enc_state, NULL, &length, stream); + TORCH_CHECK( + getStreamState == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image stream state: ", + getStreamState); + + // Synchronize the stream to ensure that the encoded image is ready + cudaError_t syncState = cudaStreamSynchronize(stream); + TORCH_CHECK(syncState == cudaSuccess, "CUDA ERROR: ", syncState); + + // Reserve buffer for the encoded image + torch::Tensor encoded_image = torch::empty( + {static_cast(length)}, + torch::TensorOptions() + .dtype(torch::kByte) + .layout(torch::kStrided) + .device(device) + .requires_grad(false)); + syncState = cudaStreamSynchronize(stream); + TORCH_CHECK(syncState == cudaSuccess, "CUDA ERROR: ", syncState); + + // Retrieve the encoded image + getStreamState = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, + nv_enc_state, + encoded_image.data_ptr(), + &length, + 0); + TORCH_CHECK( + getStreamState == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image: ", + getStreamState); + return encoded_image; +} + +#endif // NVJPEG_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index fb5ee874acb..209a119a809 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -32,6 +32,7 @@ static auto registry = .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_image) .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) + .op("image::encode_jpeg_cuda", &encode_jpeg_cuda) .op("image::_jpeg_version", &_jpeg_version) .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 05bac44c77d..c4d5952af60 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -6,4 +6,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" -#include "cuda/decode_jpeg_cuda.h" +#include "cuda/encode_decode_jpeg_cuda.h" diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 8d3b294b32e..ec2830e91ec 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import List, Union from warnings import warn import torch @@ -68,7 +69,9 @@ def write_file(filename: str, data: torch.Tensor) -> None: def decode_png( - input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False + input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, + apply_exif_orientation: bool = False, ) -> torch.Tensor: """ Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. @@ -180,27 +183,38 @@ def decode_jpeg( return output -def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: +def encode_jpeg( + input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75 +) -> Union[torch.Tensor, List[torch.Tensor]]: """ - Takes an input tensor in CHW layout and returns a buffer with the contents - of its corresponding JPEG file. + Takes a (list of) input tensor(s) in CHW layout and returns a (list of) buffer(s) with the contents + of the corresponding JPEG file(s). Args: - input (Tensor[channels, image_height, image_width])): int8 image tensor of - ``c`` channels, where ``c`` must be 1 or 3. - quality (int): Quality of the resulting JPEG file, it must be a number between + input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]): + (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3 + quality (int): Quality of the resulting JPEG file(s). Must be a number between 1 and 100. Default: 75 Returns: - output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the - JPEG file. + output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(encode_jpeg) if quality < 1 or quality > 100: raise ValueError("Image quality should be a positive number between 1 and 100") + if isinstance(input, list): + assert len(input) > 0, "encode_jpeg requires at least one input tensor when a list is passed" + if input[0].device.type == "cuda": + return torch.ops.image.encode_jpeg_cuda(input, quality) + else: + return [torch.ops.image.encode_jpeg(image, quality) for image in input] + else: # single input tensor + if input.device.type == "cuda": + return torch.ops.image.encode_jpeg_cuda([input], quality)[0] + else: + return torch.ops.image.encode_jpeg(input, quality) - output = torch.ops.image.encode_jpeg(input, quality) return output @@ -222,7 +236,9 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): def decode_image( - input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False + input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, + apply_exif_orientation: bool = False, ) -> torch.Tensor: """ Detects whether an image is a JPEG or PNG and performs the appropriate @@ -251,7 +267,9 @@ def decode_image( def read_image( - path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False + path: str, + mode: ImageReadMode = ImageReadMode.UNCHANGED, + apply_exif_orientation: bool = False, ) -> torch.Tensor: """ Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. From 4cc30cbae0def378392d3b73c66846fd68347a62 Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Tue, 23 Apr 2024 12:06:36 -0700 Subject: [PATCH 02/16] fix test cases --- torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp | 1 - torchvision/io/image.py | 1 + torchvision/transforms/v2/functional/_augment.py | 11 ++++++++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp index 68be3b8eb55..c55a0d6e55b 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp @@ -16,7 +16,6 @@ nvjpegHandle_t nvjpeg_handle = nullptr; std::once_flag nvjpeg_handle_creation_flag; #endif -#include #include namespace vision { diff --git a/torchvision/io/image.py b/torchvision/io/image.py index ec2830e91ec..7f5e2d28067 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -232,6 +232,7 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(write_jpeg) output = encode_jpeg(input, quality) + assert isinstance(output, torch.Tensor) write_file(filename, output) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index eac27f37022..4a806109eae 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -78,9 +78,14 @@ def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor: if image.shape[0] == 0: # degenerate return image.reshape(original_shape).clone() - image = [decode_jpeg(encode_jpeg(image[i], quality=quality)) for i in range(image.shape[0])] - image = torch.stack(image, dim=0).view(original_shape) - return image + images = [] + for i in range(image.shape[0]): + encoded_image = encode_jpeg(image[i], quality=quality) + assert isinstance(encoded_image, torch.Tensor) + images.append(decode_jpeg(encoded_image)) + + images = torch.stack(images, dim=0).view(original_shape) + return images @_register_kernel_internal(jpeg, tv_tensors.Video) From 2db02f04afb4622225c29ded301178a18f51e7d4 Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Tue, 23 Apr 2024 14:10:15 -0700 Subject: [PATCH 03/16] fix lints --- test/test_image.py | 79 +++++++++++++++---- .../io/image/cuda/encode_decode_jpeg_cuda.h | 1 - .../csrc/io/image/cuda/encode_jpeg_cuda.cpp | 9 ++- 3 files changed, 68 insertions(+), 21 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 5e968e5c1a2..ae5f67a8fbd 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,6 +1,7 @@ import glob import io import os +import re import sys from pathlib import Path @@ -24,7 +25,6 @@ write_jpeg, write_png, ) -import re IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") @@ -85,13 +85,15 @@ def test_single_encode_jpeg_cuda(img_path, scripted, contiguous): pytest.xfail("Decoding a grayscale jpeg isn't supported") # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013 if not contiguous: - decoded_image_tv = decoded_image_tv.permute(1, 2, 0).contiguous().permute(2,0,1) + decoded_image_tv = decoded_image_tv.permute(1, 2, 0).contiguous().permute(2, 0, 1) encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75) decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu()) # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality # instead, we re-decode the encoded image and compare to the original - abs_mean_diff = (decoded_jpeg_cuda_tv.type(torch.float32) - decoded_image_tv.type(torch.float32)).abs().mean().item() + abs_mean_diff = ( + (decoded_jpeg_cuda_tv.type(torch.float32) - decoded_image_tv.type(torch.float32)).abs().mean().item() + ) assert abs_mean_diff < 3 @@ -107,7 +109,7 @@ def test_batch_encode_jpegs_cuda(scripted, contiguous): if decoded_image.shape[0] == 1: continue if not contiguous: - decoded_image = decoded_image.permute(1, 2, 0).contiguous().permute(2,0,1) + decoded_image = decoded_image.permute(1, 2, 0).contiguous().permute(2, 0, 1) decoded_images_tv.append(decoded_image) encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg @@ -116,9 +118,8 @@ def test_batch_encode_jpegs_cuda(scripted, contiguous): encoded_jpegs_tv_cuda = encode_fn(decoded_images_tv_cuda, quality=75) decoded_jpegs_tv_cuda = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_cuda] - for original, encoded_decoded in zip(decoded_images_tv, decoded_jpegs_tv_cuda): - c,h,w = original.shape + c, h, w = original.shape abs_mean_diff = (original.type(torch.float32) - encoded_decoded.type(torch.float32)).abs().mean().item() assert abs_mean_diff < 3 @@ -144,29 +145,73 @@ def test_single_encode_jpeg_cuda_errors(): @needs_cuda def test_batch_encode_jpegs_cuda_errors(): with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((3, 100, 100), dtype=torch.float32, device="cuda")]) + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"), + ] + ) with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda")]) + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda")]) + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda")]) + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((100, 100), dtype=torch.uint8, device="cuda")]) + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((100, 100), dtype=torch.uint8, device="cuda"), + ] + ) with pytest.raises(RuntimeError, match="Input tensor should be on CPU"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda")]) - - with pytest.raises(RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu")]) + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises( + RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" + ): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), + ] + ) if torch.cuda.device_count() >= 2: - with pytest.raises(RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"): - encode_jpeg([torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"), torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1")]) + with pytest.raises( + RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" + ): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"), + ] + ) with pytest.raises(AssertionError, match="encode_jpeg requires at least one input tensor when a list is passed"): encode_jpeg([]) diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h index fa96adc1bca..436da00ba8f 100644 --- a/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h @@ -18,7 +18,6 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda( ImageReadMode mode, torch::Device device); - C10_EXPORT std::vector encode_jpeg_cuda( const std::vector& images, const int64_t quality); diff --git a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp index c55a0d6e55b..e5f3d3d0161 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp @@ -72,7 +72,8 @@ std::vector encode_jpeg_cuda( std::vector contig_images; contig_images.reserve(images.size()); for (const auto& image : images) { - TORCH_CHECK(image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + TORCH_CHECK( + image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); TORCH_CHECK( image.device() == device, @@ -84,7 +85,8 @@ std::vector encode_jpeg_cuda( TORCH_CHECK( image.size(0) == 3, - "The number of channels should be 3, got: ", image.size(0)); + "The number of channels should be 3, got: ", + image.size(0)); // nvjpeg requires images to be contiguous contig_images.push_back(image.contiguous()); @@ -97,7 +99,8 @@ std::vector encode_jpeg_cuda( nvjpegEncoderParams_t nv_enc_params; // initialize nvjpeg structures - // these are rather expensive to create and thus will be reused across multiple calls to encode_single_jpeg + // these are rather expensive to create and thus will be reused across + // multiple calls to encode_single_jpeg try { nvjpegStatus_t stateCreateResult = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); From 6acef83c8eec180d25aed7ee164961e7f3840ebd Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Tue, 23 Apr 2024 14:26:22 -0700 Subject: [PATCH 04/16] fix lints2 --- test/test_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_image.py b/test/test_image.py index ae5f67a8fbd..e67dfcdc73b 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,7 +1,6 @@ import glob import io import os -import re import sys from pathlib import Path From ae0450d86e73cfb32499a4d57e89565e23dfb820 Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Mon, 29 Apr 2024 15:59:22 -0700 Subject: [PATCH 05/16] latest round of updates --- test/test_image.py | 311 +++++++++--------- .../csrc/io/image/cuda/encode_jpeg_cuda.cpp | 7 +- torchvision/io/image.py | 9 +- 3 files changed, 174 insertions(+), 153 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index e67dfcdc73b..068e2670f80 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -67,155 +67,6 @@ def normalize_dimensions(img_pil): return img_pil -@needs_cuda -@pytest.mark.parametrize( - "img_path", - [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], -) -@pytest.mark.parametrize("scripted", (False, True)) -@pytest.mark.parametrize("contiguous", (False, True)) -def test_single_encode_jpeg_cuda(img_path, scripted, contiguous): - decoded_image_tv = read_image(img_path) - encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg - - if "cmyk" in img_path: - pytest.xfail("Encoding a CMYK jpeg isn't supported") - if decoded_image_tv.shape[0] == 1: - pytest.xfail("Decoding a grayscale jpeg isn't supported") - # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013 - if not contiguous: - decoded_image_tv = decoded_image_tv.permute(1, 2, 0).contiguous().permute(2, 0, 1) - encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75) - decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu()) - - # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality - # instead, we re-decode the encoded image and compare to the original - abs_mean_diff = ( - (decoded_jpeg_cuda_tv.type(torch.float32) - decoded_image_tv.type(torch.float32)).abs().mean().item() - ) - assert abs_mean_diff < 3 - - -@needs_cuda -@pytest.mark.parametrize("scripted", (False, True)) -@pytest.mark.parametrize("contiguous", (False, True)) -def test_batch_encode_jpegs_cuda(scripted, contiguous): - decoded_images_tv = [] - for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): - if "cmyk" in jpeg_path: - continue - decoded_image = read_image(jpeg_path) - if decoded_image.shape[0] == 1: - continue - if not contiguous: - decoded_image = decoded_image.permute(1, 2, 0).contiguous().permute(2, 0, 1) - decoded_images_tv.append(decoded_image) - - encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg - - decoded_images_tv_cuda = [img.cuda() for img in decoded_images_tv] - encoded_jpegs_tv_cuda = encode_fn(decoded_images_tv_cuda, quality=75) - decoded_jpegs_tv_cuda = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_cuda] - - for original, encoded_decoded in zip(decoded_images_tv, decoded_jpegs_tv_cuda): - c, h, w = original.shape - abs_mean_diff = (original.type(torch.float32) - encoded_decoded.type(torch.float32)).abs().mean().item() - assert abs_mean_diff < 3 - - -@needs_cuda -def test_single_encode_jpeg_cuda_errors(): - with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda")) - - with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): - encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda")) - - with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): - encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda")) - - with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): - encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda")) - - with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): - encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda")) - - -@needs_cuda -def test_batch_encode_jpegs_cuda_errors(): - with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), - torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"), - ] - ) - - with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), - torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"), - ] - ) - - with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), - torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"), - ] - ) - - with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), - torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"), - ] - ) - - with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), - torch.empty((100, 100), dtype=torch.uint8, device="cuda"), - ] - ) - - with pytest.raises(RuntimeError, match="Input tensor should be on CPU"): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), - ] - ) - - with pytest.raises( - RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" - ): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), - torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), - ] - ) - - if torch.cuda.device_count() >= 2: - with pytest.raises( - RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" - ): - encode_jpeg( - [ - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"), - torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"), - ] - ) - - with pytest.raises(AssertionError, match="encode_jpeg requires at least one input tensor when a list is passed"): - encode_jpeg([]) - - @pytest.mark.parametrize( "img_path", [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], @@ -654,6 +505,168 @@ def test_encode_jpeg(img_path, scripted): assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) +@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031") +@pytest.mark.parametrize("scripted", (True, False)) +@pytest.mark.parametrize("contiguous", (True, False)) +def test_batch_encode_jpegs(scripted, contiguous): + _test_batch_encode_jpegs_helper(scripted, contiguous, "cpu") + + +@needs_cuda +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], +) +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("contiguous", (False, True)) +def test_single_encode_jpeg_cuda(img_path, scripted, contiguous): + decoded_image_tv = read_image(img_path) + encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg + + if "cmyk" in img_path: + pytest.xfail("Encoding a CMYK jpeg isn't supported") + if decoded_image_tv.shape[0] == 1: + pytest.xfail("Decoding a grayscale jpeg isn't supported") + # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013 + if contiguous: + decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.contiguous_format)[0] + else: + decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.channels_last)[0] + encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75) + decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu()) + + # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality + # instead, we re-decode the encoded image and compare to the original + abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item() + assert abs_mean_diff < 3 + + +def _test_batch_encode_jpegs_helper(scripted, contiguous, device): + decoded_images_tv = [] + for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): + if "cmyk" in jpeg_path: + continue + decoded_image = read_image(jpeg_path) + if decoded_image.shape[0] == 1: + continue + if contiguous: + decoded_image = decoded_image[None].contiguous(memory_format=torch.contiguous_format)[0] + else: + decoded_image = decoded_image[None].contiguous(memory_format=torch.channels_last)[0] + decoded_images_tv.append(decoded_image) + + encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg + + decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv] + encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75) + encoded_jpegs_tv_device = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device] + + for original, encoded_decoded in zip(decoded_images_tv, encoded_jpegs_tv_device): + c, h, w = original.shape + abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item() + assert abs_mean_diff < 3 + + +@needs_cuda +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("contiguous", (False, True)) +def test_batch_encode_jpegs_cuda(scripted, contiguous): + _test_batch_encode_jpegs_helper(scripted, contiguous, "cuda") + + +@needs_cuda +def test_single_encode_jpeg_cuda_errors(): + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda")) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): + encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): + encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda")) + + +@needs_cuda +def test_batch_encode_jpegs_cuda_errors(): + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="Input tensor should be on CPU"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises( + RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" + ): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), + ] + ) + + if torch.cuda.device_count() >= 2: + with pytest.raises( + RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" + ): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"), + ] + ) + + with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"): + encode_jpeg([]) + + @pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031") @pytest.mark.parametrize( "img_path", diff --git a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp index e5f3d3d0161..b30c09440ba 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp @@ -89,7 +89,12 @@ std::vector encode_jpeg_cuda( image.size(0)); // nvjpeg requires images to be contiguous - contig_images.push_back(image.contiguous()); + if (image.is_contiguous()) { + contig_images.push_back(image); + } + else { + contig_images.push_back(image.contiguous()); + } } // Create global nvJPEG handle diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 7f5e2d28067..ae8300eb7a2 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -190,6 +190,10 @@ def encode_jpeg( Takes a (list of) input tensor(s) in CHW layout and returns a (list of) buffer(s) with the contents of the corresponding JPEG file(s). + .. note:: + Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``. + For CPU tensors the performance is equivalent. + Args: input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]): (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3 @@ -204,7 +208,8 @@ def encode_jpeg( if quality < 1 or quality > 100: raise ValueError("Image quality should be a positive number between 1 and 100") if isinstance(input, list): - assert len(input) > 0, "encode_jpeg requires at least one input tensor when a list is passed" + if not input: + raise ValueError("encode_jpeg requires at least one input tensor when a list is passed") if input[0].device.type == "cuda": return torch.ops.image.encode_jpeg_cuda(input, quality) else: @@ -215,8 +220,6 @@ def encode_jpeg( else: return torch.ops.image.encode_jpeg(input, quality) - return output - def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): """ From a799c53826c5c637da52dde200ee40ff7aaf0a1f Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Tue, 30 Apr 2024 14:12:55 -0700 Subject: [PATCH 06/16] fix lints --- torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp index b30c09440ba..35672c25696 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp @@ -90,9 +90,8 @@ std::vector encode_jpeg_cuda( // nvjpeg requires images to be contiguous if (image.is_contiguous()) { - contig_images.push_back(image); - } - else { + contig_images.push_back(image); + } else { contig_images.push_back(image.contiguous()); } } From c5810ffab8c6e54bae9a5879da0049736b5356db Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 31 May 2024 05:26:21 -0700 Subject: [PATCH 07/16] Ignore mypy --- torchvision/transforms/v2/functional/_augment.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 4a806109eae..e0db4f388ad 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -78,14 +78,11 @@ def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor: if image.shape[0] == 0: # degenerate return image.reshape(original_shape).clone() - images = [] - for i in range(image.shape[0]): - encoded_image = encode_jpeg(image[i], quality=quality) - assert isinstance(encoded_image, torch.Tensor) - images.append(decode_jpeg(encoded_image)) - - images = torch.stack(images, dim=0).view(original_shape) - return images + image = [ + decode_jpeg(encode_jpeg(image[i], quality=quality)) for i in range(image.shape[0]) # type: ignore[arg-type] + ] + image = torch.stack(image, dim=0).view(original_shape) + return image @_register_kernel_internal(jpeg, tv_tensors.Video) From ff40253788a0e052a2364f0f741e4eb38bf5b791 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 31 May 2024 05:27:09 -0700 Subject: [PATCH 08/16] Add comment --- torchvision/io/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/io/image.py b/torchvision/io/image.py index ae8300eb7a2..d2beaebe34a 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -235,7 +235,7 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(write_jpeg) output = encode_jpeg(input, quality) - assert isinstance(output, torch.Tensor) + assert isinstance(output, torch.Tensor) # Needed for torchscript write_file(filename, output) From 0972863f83f30b3adb7031225560fa9327988413 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 31 May 2024 06:25:44 -0700 Subject: [PATCH 09/16] minor test refactor --- test/test_image.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 068e2670f80..86f5f1fa2c0 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -8,7 +8,7 @@ import pytest import torch import torchvision.transforms.functional as F -from common_utils import assert_equal, needs_cuda +from common_utils import assert_equal, cpu_and_cuda, needs_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageOps from torchvision.io.image import ( _read_png_16, @@ -505,13 +505,6 @@ def test_encode_jpeg(img_path, scripted): assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) -@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031") -@pytest.mark.parametrize("scripted", (True, False)) -@pytest.mark.parametrize("contiguous", (True, False)) -def test_batch_encode_jpegs(scripted, contiguous): - _test_batch_encode_jpegs_helper(scripted, contiguous, "cpu") - - @needs_cuda @pytest.mark.parametrize( "img_path", @@ -541,7 +534,12 @@ def test_single_encode_jpeg_cuda(img_path, scripted, contiguous): assert abs_mean_diff < 3 -def _test_batch_encode_jpegs_helper(scripted, contiguous, device): +@pytest.mark.parametrize("device", cpu_and_cuda()) +@pytest.mark.parametrize("scripted", (True, False)) +@pytest.mark.parametrize("contiguous", (True, False)) +def test_encode_jpegs_batch(scripted, contiguous, device): + if device == "cpu" and IS_MACOS: + pytest.skip("https://github.com/pytorch/vision/issues/8031") decoded_images_tv = [] for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): if "cmyk" in jpeg_path: @@ -567,13 +565,6 @@ def _test_batch_encode_jpegs_helper(scripted, contiguous, device): assert abs_mean_diff < 3 -@needs_cuda -@pytest.mark.parametrize("scripted", (False, True)) -@pytest.mark.parametrize("contiguous", (False, True)) -def test_batch_encode_jpegs_cuda(scripted, contiguous): - _test_batch_encode_jpegs_helper(scripted, contiguous, "cuda") - - @needs_cuda def test_single_encode_jpeg_cuda_errors(): with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): From 62e072a567d008aca4de0364ea95991f3f96d4d4 Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Wed, 5 Jun 2024 09:23:52 -0700 Subject: [PATCH 10/16] Caching nvjpeg vars across calls --- benchmarks/encoding.py | 70 +++++ test/test_image.py | 44 +++- .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 24 +- ...jpeg_cuda.h => encode_decode_jpegs_cuda.h} | 14 +- .../csrc/io/image/cuda/encode_jpeg_cuda.cpp | 236 ----------------- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 245 ++++++++++++++++++ .../csrc/io/image/cuda/encode_jpegs_cuda.h | 33 +++ torchvision/csrc/io/image/image.cpp | 2 +- torchvision/csrc/io/image/image.h | 2 +- torchvision/io/image.py | 4 +- 10 files changed, 420 insertions(+), 254 deletions(-) create mode 100644 benchmarks/encoding.py rename torchvision/csrc/io/image/cuda/{encode_decode_jpeg_cuda.h => encode_decode_jpegs_cuda.h} (54%) delete mode 100644 torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp create mode 100644 torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp create mode 100644 torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h diff --git a/benchmarks/encoding.py b/benchmarks/encoding.py new file mode 100644 index 00000000000..52aa37c352c --- /dev/null +++ b/benchmarks/encoding.py @@ -0,0 +1,70 @@ +import os +import platform +import statistics +import tarfile +import tempfile +import urllib.request + +import torch +import torch.utils.benchmark as benchmark +import torchvision + + +def print_machine_specs(): + print("Processor:", platform.processor()) + print("Platform:", platform.platform()) + print("Logical CPUs:", os.cpu_count()) + print(f"\nCUDA device: {torch.cuda.get_device_name()}") + print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") + + +def get_data(): + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.PILToTensor(), + ] + ) + path = os.path.join(os.getcwd(), "data") + testset = torchvision.datasets.Places365( + root="./data", download=not os.path.exists(path), transform=transform, split="val" + ) + testloader = torch.utils.data.DataLoader( + testset, batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch] + ) + return next(iter(testloader)) + + +def run_benchmark(batch): + results = [] + for device in ["cpu", "cuda"]: + batch_device = [t.to(device=device) for t in batch] + for size in [1, 100, 1000]: + for num_threads in [1, 12, 24]: + for stmt, strat in zip( + [ + "[torchvision.io.encode_jpeg(img) for img in batch_input]", + "torchvision.io.encode_jpeg(batch_input)", + ], + ["unfused", "fused"], + ): + batch_input = batch_device[:size] + t = benchmark.Timer( + stmt=stmt, + setup="import torchvision", + globals={"batch_input": batch_input}, + label=f"Image Encoding", + sub_label=f"{device.upper()} ({strat}): {stmt}", + description=f"{size} images", + num_threads=num_threads, + ) + results.append(t.blocked_autorange()) + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + print_machine_specs() + batch = get_data() + mean_h, mean_w = statistics.mean(t.shape[-2] for t in batch), statistics.mean(t.shape[-1] for t in batch) + print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}") + run_benchmark(batch) diff --git a/test/test_image.py b/test/test_image.py index 6043425ff03..30baadabdc0 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,3 +1,4 @@ +import concurrent.futures import glob import io import os @@ -508,6 +509,28 @@ def test_encode_jpeg(img_path, scripted): assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) +@needs_cuda +def test_encode_jpeg_cuda_device_param(): + path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path) + + data = read_image(path) + + current_device = torch.cuda.current_device() + current_stream = torch.cuda.current_stream() + num_devices = torch.cuda.device_count() + devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)] + results = [] + for device in devices: + print(f"python: device: {device}") + results.append(encode_jpeg(data.to(device=device))) + assert len(results) == len(devices) + for result in results: + assert torch.all(result.cpu() == results[0].cpu()) + + assert current_device == torch.cuda.current_device() + assert current_stream == torch.cuda.current_stream() + + @needs_cuda @pytest.mark.parametrize( "img_path", @@ -515,7 +538,7 @@ def test_encode_jpeg(img_path, scripted): ) @pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize("contiguous", (False, True)) -def test_single_encode_jpeg_cuda(img_path, scripted, contiguous): +def test_encode_jpeg_cuda(img_path, scripted, contiguous): decoded_image_tv = read_image(img_path) encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg @@ -567,6 +590,25 @@ def test_encode_jpegs_batch(scripted, contiguous, device): abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item() assert abs_mean_diff < 3 + # test multithreaded decoding + # in the current version we prevent this by using a lock but we still want to test it + num_workers = 10 + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)] + encoded_images_threaded = [future.result() for future in futures] + assert len(encoded_images_threaded) == num_workers + for encoded_images in encoded_images_threaded: + assert len(decoded_images_tv_device) == len(encoded_images) + for i, (encoded_image_cuda, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)): + # make sure all the threads produce identical outputs + assert torch.all(encoded_image_cuda == encoded_images_threaded[0][i]) + + # make sure the outputs are identical or close enough to baseline + decoded_cuda_encoded_image = decode_jpeg(encoded_image_cuda.cpu()) + assert decoded_cuda_encoded_image.shape == decoded_image_tv.shape + assert decoded_cuda_encoded_image.dtype == decoded_image_tv.dtype + assert (decoded_cuda_encoded_image.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3 + @needs_cuda def test_single_encode_jpeg_cuda_errors(): diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 59f9255f130..26fecc3e1f3 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,4 +1,4 @@ -#include "encode_decode_jpeg_cuda.h" +#include "encode_decode_jpegs_cuda.h" #include @@ -25,6 +25,10 @@ torch::Tensor decode_jpeg_cuda( #else +namespace { +static nvjpegHandle_t nvjpeg_handle = nullptr; +} + torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, ImageReadMode mode, @@ -67,7 +71,23 @@ torch::Tensor decode_jpeg_cuda( at::cuda::CUDAGuard device_guard(device); // Create global nvJPEG handle - std::call_once(::nvjpeg_handle_creation_flag, nvjpeg_init); + static std::once_flag nvjpeg_handle_creation_flag; + std::call_once(nvjpeg_handle_creation_flag, []() { + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + if (create_status != NVJPEG_STATUS_SUCCESS) { + // Reset handle so that one can still call the function again in the + // same process if there was a failure + free(nvjpeg_handle); + nvjpeg_handle = nullptr; + } + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } + }); // Create the jpeg state nvjpegJpegState_t jpeg_state; diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h similarity index 54% rename from torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h rename to torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h index 436da00ba8f..7723d11d621 100644 --- a/torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h @@ -2,13 +2,7 @@ #include #include "../image_read_mode.h" - -#if NVJPEG_FOUND -#include - -extern nvjpegHandle_t nvjpeg_handle; -extern std::once_flag nvjpeg_handle_creation_flag; -#endif +#include "encode_jpegs_cuda.h" namespace vision { namespace image { @@ -18,11 +12,9 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda( ImageReadMode mode, torch::Device device); -C10_EXPORT std::vector encode_jpeg_cuda( - const std::vector& images, +C10_EXPORT std::vector encode_jpegs_cuda( + const std::vector& decoded_images, const int64_t quality); -void nvjpeg_init(); - } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp deleted file mode 100644 index 35672c25696..00000000000 --- a/torchvision/csrc/io/image/cuda/encode_jpeg_cuda.cpp +++ /dev/null @@ -1,236 +0,0 @@ -#include -#include "c10/core/ScalarType.h" -#include "encode_decode_jpeg_cuda.h" -#include "torch/types.h" - -#include -#include -#include - -#if NVJPEG_FOUND -#include -#include -#include - -nvjpegHandle_t nvjpeg_handle = nullptr; -std::once_flag nvjpeg_handle_creation_flag; - -#endif -#include - -namespace vision { -namespace image { - -#if !NVJPEG_FOUND - -std::vector encode_jpeg_cuda( - const std::vector& images, - const int64_t quality) { - TORCH_CHECK( - false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); -} - -#else - -void nvjpeg_init() { - if (nvjpeg_handle == nullptr) { - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - - if (create_status != NVJPEG_STATUS_SUCCESS) { - // Reset handle so that one can still call the function again in the - // same process if there was a failure - free(nvjpeg_handle); - nvjpeg_handle = nullptr; - } - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); - } -} - -torch::Tensor encode_single_jpeg( - const torch::Tensor& data, - const int64_t quality, - const cudaStream_t stream, - const torch::Device& device, - const nvjpegEncoderState_t& nv_enc_state, - const nvjpegEncoderParams_t& nv_enc_params); - -std::vector encode_jpeg_cuda( - const std::vector& images, - const int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.encode_jpeg_cuda.encode_jpeg_cuda"); - - TORCH_CHECK(images.size() > 0, "Empty input tensor list"); - - torch::Device device = images[0].device(); - at::cuda::CUDAGuard device_guard(device); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()); - - std::vector contig_images; - contig_images.reserve(images.size()); - for (const auto& image : images) { - TORCH_CHECK( - image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - TORCH_CHECK( - image.device() == device, - "All input tensors must be on the same CUDA device when encoding with nvjpeg") - - TORCH_CHECK( - image.dim() == 3 && image.numel() > 0, - "Input data should be a 3-dimensional tensor"); - - TORCH_CHECK( - image.size(0) == 3, - "The number of channels should be 3, got: ", - image.size(0)); - - // nvjpeg requires images to be contiguous - if (image.is_contiguous()) { - contig_images.push_back(image); - } else { - contig_images.push_back(image.contiguous()); - } - } - - // Create global nvJPEG handle - std::call_once(::nvjpeg_handle_creation_flag, nvjpeg_init); - - nvjpegEncoderState_t nv_enc_state; - nvjpegEncoderParams_t nv_enc_params; - - // initialize nvjpeg structures - // these are rather expensive to create and thus will be reused across - // multiple calls to encode_single_jpeg - try { - nvjpegStatus_t stateCreateResult = - nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); - TORCH_CHECK( - stateCreateResult == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg encoder state: ", - stateCreateResult); - - nvjpegStatus_t paramsCreateResult = - nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); - TORCH_CHECK( - paramsCreateResult == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg encoder params: ", - paramsCreateResult); - - nvjpegStatus_t paramsQualityStatus = - nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); - TORCH_CHECK( - paramsQualityStatus == NVJPEG_STATUS_SUCCESS, - "Failed to set nvjpeg encoder params quality: ", - paramsQualityStatus); - - std::vector encoded_images; - for (const auto& image : contig_images) { - auto encoded_image = encode_single_jpeg( - image, quality, stream, device, nv_enc_state, nv_enc_params); - encoded_images.push_back(encoded_image); - } - // Clean up - nvjpegEncoderStateDestroy(nv_enc_state); - nvjpegEncoderParamsDestroy(nv_enc_params); - return encoded_images; - } catch (const std::exception& e) { - nvjpegEncoderStateDestroy(nv_enc_state); - nvjpegEncoderParamsDestroy(nv_enc_params); - throw; - } -} - -torch::Tensor encode_single_jpeg( - const torch::Tensor& src_image, - const int64_t quality, - const cudaStream_t stream, - const torch::Device& device, - const nvjpegEncoderState_t& nv_enc_state, - const nvjpegEncoderParams_t& nv_enc_params) { - int channels = src_image.size(0); - int height = src_image.size(1); - int width = src_image.size(2); - - nvjpegStatus_t samplingSetResult = nvjpegEncoderParamsSetSamplingFactors( - nv_enc_params, NVJPEG_CSS_444, stream); - TORCH_CHECK( - samplingSetResult == NVJPEG_STATUS_SUCCESS, - "Failed to set nvjpeg encoder params sampling factors: ", - samplingSetResult); - - // Create nvjpeg image - nvjpegImage_t target_image; - - for (int c = 0; c < channels; c++) { - target_image.channel[c] = src_image[c].data_ptr(); - // this is why we need contiguous tensors - target_image.pitch[c] = width; - } - for (int c = channels; c < NVJPEG_MAX_COMPONENT; c++) { - target_image.channel[c] = nullptr; - target_image.pitch[c] = 0; - } - nvjpegStatus_t encodingState; - - // Encode the image - encodingState = nvjpegEncodeImage( - nvjpeg_handle, - nv_enc_state, - nv_enc_params, - &target_image, - NVJPEG_INPUT_RGB, - width, - height, - stream); - - TORCH_CHECK( - encodingState == NVJPEG_STATUS_SUCCESS, - "image encoding failed: ", - encodingState); - - // Retrieve length of the encoded image - size_t length; - nvjpegStatus_t getStreamState = nvjpegEncodeRetrieveBitstreamDevice( - nvjpeg_handle, nv_enc_state, NULL, &length, stream); - TORCH_CHECK( - getStreamState == NVJPEG_STATUS_SUCCESS, - "Failed to retrieve encoded image stream state: ", - getStreamState); - - // Synchronize the stream to ensure that the encoded image is ready - cudaError_t syncState = cudaStreamSynchronize(stream); - TORCH_CHECK(syncState == cudaSuccess, "CUDA ERROR: ", syncState); - - // Reserve buffer for the encoded image - torch::Tensor encoded_image = torch::empty( - {static_cast(length)}, - torch::TensorOptions() - .dtype(torch::kByte) - .layout(torch::kStrided) - .device(device) - .requires_grad(false)); - syncState = cudaStreamSynchronize(stream); - TORCH_CHECK(syncState == cudaSuccess, "CUDA ERROR: ", syncState); - - // Retrieve the encoded image - getStreamState = nvjpegEncodeRetrieveBitstreamDevice( - nvjpeg_handle, - nv_enc_state, - encoded_image.data_ptr(), - &length, - 0); - TORCH_CHECK( - getStreamState == NVJPEG_STATUS_SUCCESS, - "Failed to retrieve encoded image: ", - getStreamState); - return encoded_image; -} - -#endif // NVJPEG_FOUND - -} // namespace image -} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp new file mode 100644 index 00000000000..3f2d7104bdc --- /dev/null +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -0,0 +1,245 @@ +#include "encode_jpegs_cuda.h" +#if !NVJPEG_FOUND +std::vector encode_jpegs_cuda( + const std::vector& images, + const int64_t quality) { + TORCH_CHECK( + false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support"); +} +#else + +#include +#include +#include +#include +#include +#include +#include +#include +#include "c10/core/ScalarType.h" + +namespace vision { +namespace image { + +// We use global variables to cache the encoder and decoder instances and +// reuse them across calls to the corresponding pytorch functions +std::mutex encoderMutex; +std::unique_ptr cudaJpegEncoder; + +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, + const int64_t quality) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.encode_jpegs_cuda.encode_jpegs_cuda"); + + // Some nvjpeg structures are not thread safe so we're keeping it single + // threaded for now. In the future this may be an opportunity to unlock + // further speedups + std::lock_guard lock(encoderMutex); + TORCH_CHECK(decoded_images.size() > 0, "Empty input tensor list"); + torch::Device device = decoded_images[0].device(); + at::cuda::CUDAGuard device_guard(device); + + // lazy init of the encoder class + // the encoder object holds on to a lot of state and is expensive to create, + // so we reuse it across calls. NB: the cached structures are device specific + // and cannot be reused across devices + if (cudaJpegEncoder == nullptr || device != cudaJpegEncoder->target_device) { + if (cudaJpegEncoder != nullptr) + delete cudaJpegEncoder.release(); + + cudaJpegEncoder = std::make_unique(device); + + // Unfortunately, we cannot rely on the smart pointer releasing the encoder + // object correctly upon program exit. This is because, when cudaJpegEncoder + // gets destroyed, the CUDA runtime may already be shut down, rendering all + // destroy* calls in the encoder destructor invalid. Instead, we use an + // atexit hook which executes after main() finishes, but before CUDA shuts + // down when the program exits. + std::atexit([]() { delete cudaJpegEncoder.release(); }); + } + + std::vector contig_images; + contig_images.reserve(decoded_images.size()); + for (const auto& image : decoded_images) { + TORCH_CHECK( + image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + TORCH_CHECK( + image.device() == device, + "All input tensors must be on the same CUDA device when encoding with nvjpeg") + + TORCH_CHECK( + image.dim() == 3 && image.numel() > 0, + "Input data should be a 3-dimensional tensor"); + + TORCH_CHECK( + image.size(0) == 3, + "The number of channels should be 3, got: ", + image.size(0)); + + // nvjpeg requires images to be contiguous + if (image.is_contiguous()) { + contig_images.push_back(image); + } else { + contig_images.push_back(image.contiguous()); + } + } + + cudaJpegEncoder->setQuality(quality); + std::vector encoded_images; + at::cuda::CUDAEvent event; + event.record(cudaJpegEncoder->stream); + for (const auto& image : contig_images) { + auto encoded_image = cudaJpegEncoder->encode_jpeg(image); + encoded_images.push_back(encoded_image); + } + + // We use a dedicated stream to do the encoding and even though the results + // may be ready on that stream we cannot assume that they are also available + // on the current stream of the calling context when this function returns. We + // use a blocking event to ensure that this is indeed the case. Crucially, we + // do not want to block the host (which is what cudaStreamSynchronize would + // do) Events allow us to synchronize the streams without blocking the host + event.block(at::cuda::getCurrentCUDAStream( + cudaJpegEncoder->original_device.has_index() + ? cudaJpegEncoder->original_device.index() + : 0)); + return encoded_images; +} + +CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) + : original_device{torch::kCUDA, torch::cuda::current_device()}, + target_device{target_device}, + stream{at::cuda::getStreamFromPool( + false, + target_device.has_index() ? target_device.index() : 0)} { + nvjpegStatus_t status; + status = nvjpegCreateSimple(&nvjpeg_handle); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg handle: ", + status); + + status = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder state: ", + status); + + status = nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder params: ", + status); +} + +CUDAJpegEncoder::~CUDAJpegEncoder() { + nvjpegStatus_t status; + + status = nvjpegEncoderParamsDestroy(nv_enc_params); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to destroy nvjpeg encoder params: ", + status); + + status = nvjpegEncoderStateDestroy(nv_enc_state); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to destroy nvjpeg encoder state: ", + status); + + cudaStreamSynchronize(stream); + + status = nvjpegDestroy(nvjpeg_handle); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); +} + +torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { + int channels = src_image.size(0); + int height = src_image.size(1); + int width = src_image.size(2); + + nvjpegStatus_t status; + cudaError_t cudaStatus; + status = nvjpegEncoderParamsSetSamplingFactors( + nv_enc_params, NVJPEG_CSS_444, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params sampling factors: ", + status); + + nvjpegImage_t target_image; + for (int c = 0; c < channels; c++) { + target_image.channel[c] = src_image[c].data_ptr(); + // this is why we need contiguous tensors + target_image.pitch[c] = width; + } + for (int c = channels; c < NVJPEG_MAX_COMPONENT; c++) { + target_image.channel[c] = nullptr; + target_image.pitch[c] = 0; + } + // Encode the image + status = nvjpegEncodeImage( + nvjpeg_handle, + nv_enc_state, + nv_enc_params, + &target_image, + NVJPEG_INPUT_RGB, + width, + height, + stream); + + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "image encoding failed: ", status); + // Retrieve length of the encoded image + size_t length; + status = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, nv_enc_state, NULL, &length, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image stream state: ", + status); + + // Synchronize the stream to ensure that the encoded image is ready + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + + // Reserve buffer for the encoded image + torch::Tensor encoded_image = torch::empty( + {static_cast(length)}, + torch::TensorOptions() + .dtype(torch::kByte) + .layout(torch::kStrided) + .device(target_device) + .requires_grad(false)); + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + // Retrieve the encoded image + status = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, + nv_enc_state, + encoded_image.data_ptr(), + &length, + 0); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image: ", + status); + return encoded_image; +} + +void CUDAJpegEncoder::setQuality(const int64_t quality) { + nvjpegStatus_t paramsQualityStatus = + nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); + TORCH_CHECK( + paramsQualityStatus == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params quality: ", + paramsQualityStatus); +} + +} // namespace image +} // namespace vision + +#endif // NVJPEG_FOUND diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h new file mode 100644 index 00000000000..ce49ca9eed0 --- /dev/null +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#if NVJPEG_FOUND + +#include +#include +#include + +namespace vision { +namespace image { + +class CUDAJpegEncoder { + public: + CUDAJpegEncoder(const torch::Device& device); + ~CUDAJpegEncoder(); + + torch::Tensor encode_jpeg(const torch::Tensor& src_image); + + void setQuality(const int64_t); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + protected: + nvjpegEncoderState_t nv_enc_state; + nvjpegEncoderParams_t nv_enc_params; + nvjpegHandle_t nvjpeg_handle; +}; +} // namespace image +} // namespace vision +#endif diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 68b4c7813b0..68267b72604 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -27,7 +27,7 @@ static auto registry = .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_image) .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) - .op("image::encode_jpeg_cuda", &encode_jpeg_cuda) + .op("image::encode_jpegs_cuda", &encode_jpegs_cuda) .op("image::_jpeg_version", &_jpeg_version) .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index b5612150359..f7e9b63801c 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -7,4 +7,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" -#include "cuda/encode_decode_jpeg_cuda.h" +#include "cuda/encode_decode_jpegs_cuda.h" diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 691886898b0..9656fdd163e 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -211,12 +211,12 @@ def encode_jpeg( if not input: raise ValueError("encode_jpeg requires at least one input tensor when a list is passed") if input[0].device.type == "cuda": - return torch.ops.image.encode_jpeg_cuda(input, quality) + return torch.ops.image.encode_jpegs_cuda(input, quality) else: return [torch.ops.image.encode_jpeg(image, quality) for image in input] else: # single input tensor if input.device.type == "cuda": - return torch.ops.image.encode_jpeg_cuda([input], quality)[0] + return torch.ops.image.encode_jpegs_cuda([input], quality)[0] else: return torch.ops.image.encode_jpeg(input, quality) From f190d996db717c6c527e0eed12394e1380702e3b Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Wed, 5 Jun 2024 09:55:44 -0700 Subject: [PATCH 11/16] Update if nvjpeg not found --- benchmarks/encoding.py | 5 +---- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 13 +++++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/benchmarks/encoding.py b/benchmarks/encoding.py index 52aa37c352c..f994b03c783 100644 --- a/benchmarks/encoding.py +++ b/benchmarks/encoding.py @@ -1,9 +1,6 @@ import os import platform import statistics -import tarfile -import tempfile -import urllib.request import torch import torch.utils.benchmark as benchmark @@ -52,7 +49,7 @@ def run_benchmark(batch): stmt=stmt, setup="import torchvision", globals={"batch_input": batch_input}, - label=f"Image Encoding", + label="Image Encoding", sub_label=f"{device.upper()} ({strat}): {stmt}", description=f"{size} images", num_threads=num_threads, diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index 3f2d7104bdc..6dbbbd1126d 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -1,11 +1,15 @@ #include "encode_jpegs_cuda.h" #if !NVJPEG_FOUND +namespace vision { +namespace image { std::vector encode_jpegs_cuda( - const std::vector& images, + const std::vector& decoded_images, const int64_t quality) { TORCH_CHECK( false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support"); } +} // namespace image +} // namespace vision #else #include @@ -111,9 +115,10 @@ std::vector encode_jpegs_cuda( CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) : original_device{torch::kCUDA, torch::cuda::current_device()}, target_device{target_device}, - stream{at::cuda::getStreamFromPool( - false, - target_device.has_index() ? target_device.index() : 0)} { + stream{ + target_device.has_index() + ? at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { nvjpegStatus_t status; status = nvjpegCreateSimple(&nvjpeg_handle); TORCH_CHECK( From 50510502c8e85734795235f6a4fa284fb0e2a87a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 10 Jun 2024 02:54:16 -0700 Subject: [PATCH 12/16] Revert "Ignore mypy" This reverts commit c5810ffab8c6e54bae9a5879da0049736b5356db. --- torchvision/transforms/v2/functional/_augment.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index e0db4f388ad..4a806109eae 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -78,11 +78,14 @@ def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor: if image.shape[0] == 0: # degenerate return image.reshape(original_shape).clone() - image = [ - decode_jpeg(encode_jpeg(image[i], quality=quality)) for i in range(image.shape[0]) # type: ignore[arg-type] - ] - image = torch.stack(image, dim=0).view(original_shape) - return image + images = [] + for i in range(image.shape[0]): + encoded_image = encode_jpeg(image[i], quality=quality) + assert isinstance(encoded_image, torch.Tensor) + images.append(decode_jpeg(encoded_image)) + + images = torch.stack(images, dim=0).view(original_shape) + return images @_register_kernel_internal(jpeg, tv_tensors.Video) From 136f790d0aeede2308ccd35dea7d1b0e4e140f8e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 10 Jun 2024 02:56:11 -0700 Subject: [PATCH 13/16] Add comment --- torchvision/transforms/v2/functional/_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 4a806109eae..60b49099fc5 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -81,7 +81,7 @@ def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor: images = [] for i in range(image.shape[0]): encoded_image = encode_jpeg(image[i], quality=quality) - assert isinstance(encoded_image, torch.Tensor) + assert isinstance(encoded_image, torch.Tensor) # For torchscript images.append(decode_jpeg(encoded_image)) images = torch.stack(images, dim=0).view(original_shape) From 0a88d27b486d5a5fa9da89b6107401a50271761c Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Tue, 11 Jun 2024 04:44:33 -0700 Subject: [PATCH 14/16] minor changes to address ahmad's comments --- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 24 ++++++++++++++----- .../csrc/io/image/cuda/encode_jpegs_cuda.h | 2 +- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index 6dbbbd1126d..960669554ac 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -58,8 +58,10 @@ std::vector encode_jpegs_cuda( // object correctly upon program exit. This is because, when cudaJpegEncoder // gets destroyed, the CUDA runtime may already be shut down, rendering all // destroy* calls in the encoder destructor invalid. Instead, we use an - // atexit hook which executes after main() finishes, but before CUDA shuts - // down when the program exits. + // atexit hook which executes after main() finishes, but hopefully before + // CUDA shuts down when the program exits. If CUDA is already shut down the + // destructor will detect this and will not attempt to destroy any encoder + // structures. std::atexit([]() { delete cudaJpegEncoder.release(); }); } @@ -90,7 +92,7 @@ std::vector encode_jpegs_cuda( } } - cudaJpegEncoder->setQuality(quality); + cudaJpegEncoder->set_quality(quality); std::vector encoded_images; at::cuda::CUDAEvent event; event.record(cudaJpegEncoder->stream); @@ -103,8 +105,9 @@ std::vector encode_jpegs_cuda( // may be ready on that stream we cannot assume that they are also available // on the current stream of the calling context when this function returns. We // use a blocking event to ensure that this is indeed the case. Crucially, we - // do not want to block the host (which is what cudaStreamSynchronize would - // do) Events allow us to synchronize the streams without blocking the host + // do not want to block the host at this particular point + // (which is what cudaStreamSynchronize would do.) Events allow us to + // synchronize the streams without blocking the host. event.block(at::cuda::getCurrentCUDAStream( cudaJpegEncoder->original_device.has_index() ? cudaJpegEncoder->original_device.index() @@ -140,6 +143,15 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) } CUDAJpegEncoder::~CUDAJpegEncoder() { + // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is still + // initialized. If it is not, we can skip the rest of this function as it is + // unsafe to execute. + int deviceCount = 0; + cudaError_t error = cudaGetDeviceCount(&deviceCount); + if (error != cudaSuccess) + return; // CUDA runtime has already shut down. There's nothing we can do + // now. + nvjpegStatus_t status; status = nvjpegEncoderParamsDestroy(nv_enc_params); @@ -235,7 +247,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { return encoded_image; } -void CUDAJpegEncoder::setQuality(const int64_t quality) { +void CUDAJpegEncoder::set_quality(const int64_t quality) { nvjpegStatus_t paramsQualityStatus = nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); TORCH_CHECK( diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h index ce49ca9eed0..543940f1585 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h @@ -17,7 +17,7 @@ class CUDAJpegEncoder { torch::Tensor encode_jpeg(const torch::Tensor& src_image); - void setQuality(const int64_t); + void set_quality(const int64_t quality); const torch::Device original_device; const torch::Device target_device; From f3c8a72001e362f816de0191435f1d7e793f2eaa Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Wed, 12 Jun 2024 13:06:57 -0700 Subject: [PATCH 15/16] add dtor log messages --- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index 960669554ac..b1826e43900 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -146,20 +146,28 @@ CUDAJpegEncoder::~CUDAJpegEncoder() { // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is still // initialized. If it is not, we can skip the rest of this function as it is // unsafe to execute. + + std::cout << "CUDAJpegEncoder dtor: checking if CUDA runtime is still alive" + << std::endl; int deviceCount = 0; cudaError_t error = cudaGetDeviceCount(&deviceCount); - if (error != cudaSuccess) + if (error != cudaSuccess) { + std::cout << "CUDAJpegEncoder dtor: CUDA already shut down" << std::endl; return; // CUDA runtime has already shut down. There's nothing we can do // now. + } nvjpegStatus_t status; + std::cout << "CUDAJpegEncoder dtor: 1" << std::endl; + status = nvjpegEncoderParamsDestroy(nv_enc_params); + std::cout << "status: " << status << std::endl; TORCH_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to destroy nvjpeg encoder params: ", status); - + std::cout << "CUDAJpegEncoder dtor: 2" << std::endl; status = nvjpegEncoderStateDestroy(nv_enc_state); TORCH_CHECK( status == NVJPEG_STATUS_SUCCESS, @@ -167,10 +175,11 @@ CUDAJpegEncoder::~CUDAJpegEncoder() { status); cudaStreamSynchronize(stream); - + std::cout << "CUDAJpegEncoder dtor: 3" << std::endl; status = nvjpegDestroy(nvjpeg_handle); TORCH_CHECK( status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); + std::cout << "CUDAJpegEncoder dtor: 4" << std::endl; } torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { From 117d1f1280673927b4dc89a6600dd17e1ac7e86b Mon Sep 17 00:00:00 2001 From: Dominik Kallusky Date: Wed, 12 Jun 2024 14:00:10 -0700 Subject: [PATCH 16/16] Skip CUDA cleanup altogether --- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 77 ++++++++++--------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index b1826e43900..1f10327ddbf 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -143,43 +143,46 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) } CUDAJpegEncoder::~CUDAJpegEncoder() { - // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is still - // initialized. If it is not, we can skip the rest of this function as it is - // unsafe to execute. - - std::cout << "CUDAJpegEncoder dtor: checking if CUDA runtime is still alive" - << std::endl; - int deviceCount = 0; - cudaError_t error = cudaGetDeviceCount(&deviceCount); - if (error != cudaSuccess) { - std::cout << "CUDAJpegEncoder dtor: CUDA already shut down" << std::endl; - return; // CUDA runtime has already shut down. There's nothing we can do - // now. - } - - nvjpegStatus_t status; - - std::cout << "CUDAJpegEncoder dtor: 1" << std::endl; - - status = nvjpegEncoderParamsDestroy(nv_enc_params); - std::cout << "status: " << status << std::endl; - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to destroy nvjpeg encoder params: ", - status); - std::cout << "CUDAJpegEncoder dtor: 2" << std::endl; - status = nvjpegEncoderStateDestroy(nv_enc_state); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to destroy nvjpeg encoder state: ", - status); - - cudaStreamSynchronize(stream); - std::cout << "CUDAJpegEncoder dtor: 3" << std::endl; - status = nvjpegDestroy(nvjpeg_handle); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); - std::cout << "CUDAJpegEncoder dtor: 4" << std::endl; + /* + The below code works on Mac and Linux, but fails on Windows. + This is because on Windows, the atexit hook which calls this + destructor executes after cuda is already shut down causing SIGSEGV. + We do not have a solution to this problem at the moment, so we'll + just leak the libnvjpeg & cuda variables for the time being and hope + that the CUDA runtime handles cleanup for us. + Please send a PR if you have a solution for this problem. + */ + + // // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is + // still + // // initialized. If it is not, we can skip the rest of this function as it + // is + // // unsafe to execute. + // int deviceCount = 0; + // cudaError_t error = cudaGetDeviceCount(&deviceCount); + // if (error != cudaSuccess) + // return; // CUDA runtime has already shut down. There's nothing we can do + // // now. + + // nvjpegStatus_t status; + + // status = nvjpegEncoderParamsDestroy(nv_enc_params); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg encoder params: ", + // status); + + // status = nvjpegEncoderStateDestroy(nv_enc_state); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg encoder state: ", + // status); + + // cudaStreamSynchronize(stream); + + // status = nvjpegDestroy(nvjpeg_handle); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); } torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {