Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding GPU acceleration to encode_jpeg #8391

Merged
merged 21 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
149 changes: 149 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,155 @@ def normalize_dimensions(img_pil):
return img_pil


@needs_cuda
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_single_encode_jpeg_cuda(img_path, scripted, contiguous):
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
decoded_image_tv = read_image(img_path)
encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

if "cmyk" in img_path:
pytest.xfail("Encoding a CMYK jpeg isn't supported")
if decoded_image_tv.shape[0] == 1:
pytest.xfail("Decoding a grayscale jpeg isn't supported")
# For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
if not contiguous:
decoded_image_tv = decoded_image_tv.permute(1, 2, 0).contiguous().permute(2, 0, 1)
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

# the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
# instead, we re-decode the encoded image and compare to the original
abs_mean_diff = (
(decoded_jpeg_cuda_tv.type(torch.float32) - decoded_image_tv.type(torch.float32)).abs().mean().item()
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
)
assert abs_mean_diff < 3


@needs_cuda
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_batch_encode_jpegs_cuda(scripted, contiguous):
decoded_images_tv = []
for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
if "cmyk" in jpeg_path:
continue
decoded_image = read_image(jpeg_path)
if decoded_image.shape[0] == 1:
continue
if not contiguous:
decoded_image = decoded_image.permute(1, 2, 0).contiguous().permute(2, 0, 1)
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
decoded_images_tv.append(decoded_image)

encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

decoded_images_tv_cuda = [img.cuda() for img in decoded_images_tv]
encoded_jpegs_tv_cuda = encode_fn(decoded_images_tv_cuda, quality=75)
decoded_jpegs_tv_cuda = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_cuda]

for original, encoded_decoded in zip(decoded_images_tv, decoded_jpegs_tv_cuda):
c, h, w = original.shape
abs_mean_diff = (original.type(torch.float32) - encoded_decoded.type(torch.float32)).abs().mean().item()
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
assert abs_mean_diff < 3


@needs_cuda
def test_single_encode_jpeg_cuda_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"))

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"))

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"):
encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda"))


@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="Input tensor should be on CPU"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(
RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
]
)

if torch.cuda.device_count() >= 2:
with pytest.raises(
RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"),
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"),
]
)

with pytest.raises(AssertionError, match="encode_jpeg requires at least one input tensor when a list is passed"):
encode_jpeg([])


@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
Expand Down
24 changes: 2 additions & 22 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "decode_jpeg_cuda.h"
#include "encode_decode_jpeg_cuda.h"

#include <ATen/ATen.h>

Expand All @@ -25,10 +25,6 @@ torch::Tensor decode_jpeg_cuda(

#else

namespace {
static nvjpegHandle_t nvjpeg_handle = nullptr;
}

torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
Expand Down Expand Up @@ -71,23 +67,7 @@ torch::Tensor decode_jpeg_cuda(
at::cuda::CUDAGuard device_guard(device);

// Create global nvJPEG handle
static std::once_flag nvjpeg_handle_creation_flag;
std::call_once(nvjpeg_handle_creation_flag, []() {
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);

if (create_status != NVJPEG_STATUS_SUCCESS) {
// Reset handle so that one can still call the function again in the
// same process if there was a failure
free(nvjpeg_handle);
nvjpeg_handle = nullptr;
}
TORCH_CHECK(
create_status == NVJPEG_STATUS_SUCCESS,
"nvjpegCreateSimple failed: ",
create_status);
}
});
std::call_once(::nvjpeg_handle_creation_flag, nvjpeg_init);

// Create the jpeg state
nvjpegJpegState_t jpeg_state;
Expand Down
15 changes: 0 additions & 15 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h

This file was deleted.

28 changes: 28 additions & 0 deletions torchvision/csrc/io/image/cuda/encode_decode_jpeg_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

#if NVJPEG_FOUND
#include <nvjpeg.h>

extern nvjpegHandle_t nvjpeg_handle;
extern std::once_flag nvjpeg_handle_creation_flag;
#endif

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device);

C10_EXPORT std::vector<torch::Tensor> encode_jpeg_cuda(
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: perhaps the name itself should indicate this is a plurality of images, like maybe encode_jpegs_cuda?

const std::vector<torch::Tensor>& images,
const int64_t quality);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: add a comment about quality. Is higher better or lower? What is the range/min/max here?


void nvjpeg_init();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're not exposing this one, should we put it in a different namespace than in vision::image?


} // namespace image
} // namespace vision