Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding GPU acceleration to encode_jpeg #8391

Merged
merged 21 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions benchmarks/encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import platform
import statistics

import torch
import torch.utils.benchmark as benchmark
import torchvision


def print_machine_specs():
    """Print a short description of the host CPU and of the CUDA device, if any.

    The CPU section reports the processor string, the platform string and the
    number of logical CPUs. The GPU section reports the name and total memory
    of the current CUDA device; when CUDA is not available a single
    placeholder line is printed instead of raising.
    """
    print("Processor:", platform.processor())
    print("Platform:", platform.platform())
    print("Logical CPUs:", os.cpu_count())

    if not torch.cuda.is_available():
        # Report the missing GPU explicitly rather than failing inside a
        # purely informational helper.
        print("\nCUDA device: not available")
        return

    # Query the name and the memory of the *same* device, so the two lines
    # cannot describe different GPUs when the current device is not index 0.
    device_index = torch.cuda.current_device()
    print(f"\nCUDA device: {torch.cuda.get_device_name(device_index)}")
    print(f"Total Memory: {torch.cuda.get_device_properties(device_index).total_memory / 1e9:.2f} GB")


def get_data():
    """Return the first batch of the Places365 validation split as a list of image tensors.

    The dataset lives under ``./data`` and is downloaded only when a ``data``
    directory does not yet exist in the current working directory. Each sample
    is converted with ``PILToTensor`` and the label is dropped, so the result
    is a plain list holding up to 1000 image tensors.
    """

    def keep_images_only(samples):
        # Every sample is indexed as (image, ...); only the image is kept.
        return [sample[0] for sample in samples]

    to_tensor = torchvision.transforms.Compose([torchvision.transforms.PILToTensor()])
    data_dir = os.path.join(os.getcwd(), "data")
    needs_download = not os.path.exists(data_dir)
    dataset = torchvision.datasets.Places365(
        root="./data",
        download=needs_download,
        transform=to_tensor,
        split="val",
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=1,
        collate_fn=keep_images_only,
    )
    first_batch = next(iter(loader))
    return first_batch


def run_benchmark(batch):
    """Time per-image versus batched ``encode_jpeg`` on CPU and CUDA and print a comparison table.

    Args:
        batch: list of image tensors; it is copied to each device once and
            then sliced to the benchmarked sizes.

    For every device, batch size (1, 100, 1000) and thread count (1, 12, 24)
    two statements are timed: a Python loop encoding one image at a time
    ("unfused") and a single call on the whole list ("fused").
    """
    # (strategy label, statement handed to the timer), in measurement order.
    strategies = (
        ("unfused", "[torchvision.io.encode_jpeg(img) for img in batch_input]"),
        ("fused", "torchvision.io.encode_jpeg(batch_input)"),
    )

    measurements = []
    for device in ("cpu", "cuda"):
        images_on_device = [image.to(device=device) for image in batch]
        for size in (1, 100, 1000):
            for num_threads in (1, 12, 24):
                for strat, stmt in strategies:
                    timer = benchmark.Timer(
                        stmt=stmt,
                        setup="import torchvision",
                        globals={"batch_input": images_on_device[:size]},
                        label="Image Encoding",
                        sub_label=f"{device.upper()} ({strat}): {stmt}",
                        description=f"{size} images",
                        num_threads=num_threads,
                    )
                    measurements.append(timer.blocked_autorange())

    benchmark.Compare(measurements).print()


if __name__ == "__main__":
    # Script entry point: describe the machine, load one batch, report the
    # mean image size of that batch, then run the encoding benchmark on it.
    print_machine_specs()
    batch = get_data()
    # shape[-2] is read as the height and shape[-1] as the width of each image.
    mean_h = statistics.mean(t.shape[-2] for t in batch)
    mean_w = statistics.mean(t.shape[-1] for t in batch)
    print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}")
    run_benchmark(batch)
197 changes: 196 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import concurrent.futures
import glob
import io
import os
Expand All @@ -10,7 +11,7 @@
import requests
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, IN_OSS_CI, needs_cuda
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_read_png_16,
Expand Down Expand Up @@ -508,6 +509,200 @@ def test_encode_jpeg(img_path, scripted):
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@needs_cuda
def test_encode_jpeg_cuda_device_param():
    """Check that ``encode_jpeg`` is insensitive to how the CUDA device is spelled.

    The same image is encoded after being moved to ``"cuda"``,
    ``torch.device("cuda")`` and every explicit ``cuda:<i>`` device, and all
    outputs must be identical. The current CUDA device and stream must be the
    same after encoding as before.
    """
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)

    data = read_image(path)

    current_device = torch.cuda.current_device()
    current_stream = torch.cuda.current_stream()
    num_devices = torch.cuda.device_count()
    devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]

    # Bring every result back to the CPU so outputs from different GPUs can be compared.
    results = [encode_jpeg(data.to(device=device)).cpu() for device in devices]
    assert len(results) == len(devices)

    reference = results[0]
    for result in results:
        # torch.equal returns False for differing lengths, so a mismatch shows
        # up as a failed assertion rather than a size error from `==`.
        assert torch.equal(result, reference)

    # Encoding must not leave a different device or stream selected.
    assert current_device == torch.cuda.current_device()
    assert current_stream == torch.cuda.current_stream()


@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
    """Round-trip an image through the CUDA JPEG encoder and compare it with the original.

    Args:
        img_path: path of the JPEG asset to read.
        scripted: whether to call ``encode_jpeg`` through ``torch.jit.script``.
        contiguous: whether the input uses the default contiguous layout
            (True) or is derived from a ``channels_last`` batch of one (False).

    CMYK and single-channel images are expected failures. For the rest, the
    image encoded on the GPU at quality 75 and decoded again must have a mean
    absolute difference below 3 from the original.
    """
    decoded_image_tv = read_image(img_path)
    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    if "cmyk" in img_path:
        pytest.xfail("Encoding a CMYK jpeg isn't supported")
    if decoded_image_tv.shape[0] == 1:
        # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
        pytest.xfail("Encoding a grayscale jpeg isn't supported")

    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=memory_format)[0]

    encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
    decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

    # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
    # instead, we re-decode the encoded image and compare to the original
    abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
    assert abs_mean_diff < 3


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scripted", (True, False))
@pytest.mark.parametrize("contiguous", (True, False))
def test_encode_jpegs_batch(scripted, contiguous, device):
    """Encode a list of images in one call and verify the results, also under concurrency.

    Args:
        scripted: whether to call ``encode_jpeg`` through ``torch.jit.script``.
        contiguous: whether the inputs use the default contiguous layout
            (True) or are derived from a ``channels_last`` batch of one (False).
        device: device the images are moved to before encoding.

    CMYK and single-channel assets are left out. Each encoded image, once
    decoded, must have a mean absolute difference below 3 from its original.
    The same batch is then encoded from several threads at once; all threads
    must return identical bytes that still decode close to the originals.
    """
    if device == "cpu" and IS_MACOS:
        pytest.skip("https://github.com/pytorch/vision/issues/8031")

    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_images_tv = []
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
        if "cmyk" in jpeg_path:
            continue
        decoded_image = read_image(jpeg_path)
        if decoded_image.shape[0] == 1:
            continue
        decoded_images_tv.append(decoded_image[None].contiguous(memory_format=memory_format)[0])

    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
    encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75)
    # Decode on the CPU so the round-tripped images can be compared with the originals.
    roundtrip_images = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device]

    for original, encoded_decoded in zip(decoded_images_tv, roundtrip_images):
        abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item()
        assert abs_mean_diff < 3

    # test multithreaded encoding
    # in the current version we prevent this by using a lock but we still want to test it
    num_workers = 10
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
    encoded_images_threaded = [future.result() for future in futures]
    assert len(encoded_images_threaded) == num_workers

    for encoded_images in encoded_images_threaded:
        assert len(decoded_images_tv_device) == len(encoded_images)
        for i, (encoded_image, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
            # make sure all the threads produce identical outputs
            assert torch.equal(encoded_image, encoded_images_threaded[0][i])

            # make sure the outputs are identical or close enough to baseline
            decoded_encoded_image = decode_jpeg(encoded_image.cpu())
            assert decoded_encoded_image.shape == decoded_image_tv.shape
            assert decoded_encoded_image.dtype == decoded_image_tv.dtype
            assert (decoded_encoded_image.float() - decoded_image_tv.cpu().float()).abs().mean() < 3


@needs_cuda
def test_single_encode_jpeg_cuda_errors():
    """Check the error raised by ``encode_jpeg`` for each kind of invalid single CUDA input.

    Every case builds one uninitialised CUDA tensor with a wrong dtype, a
    wrong channel count or a wrong number of dimensions and expects a
    ``RuntimeError`` whose message matches the given pattern.
    """
    # (tensor shape, tensor dtype, expected error message), checked in order.
    invalid_inputs = [
        ((3, 100, 100), torch.float32, "Input tensor dtype should be uint8"),
        ((5, 100, 100), torch.uint8, "The number of channels should be 3, got: 5"),
        ((1, 100, 100), torch.uint8, "The number of channels should be 3, got: 1"),
        ((1, 3, 100, 100), torch.uint8, "Input data should be a 3-dimensional tensor"),
        ((100, 100), torch.uint8, "Input data should be a 3-dimensional tensor"),
    ]
    for shape, dtype, message in invalid_inputs:
        with pytest.raises(RuntimeError, match=message):
            encode_jpeg(torch.empty(shape, dtype=dtype, device="cuda"))


@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
    """Check the error raised by ``encode_jpeg`` for each kind of invalid two-image list.

    Each case passes a list of two uninitialised tensors in which one entry
    has a wrong dtype, channel count, number of dimensions or device, and
    expects a ``RuntimeError`` matching the given pattern. An empty list must
    raise ``ValueError``.
    """

    def make_image(shape=(3, 100, 100), dtype=torch.uint8, device="cuda"):
        # Valid defaults; each case overrides only the property under test.
        return torch.empty(shape, dtype=dtype, device=device)

    same_device_message = "All input tensors must be on the same CUDA device when encoding with nvjpeg"

    # (expected message, overrides for the first tensor, overrides for the second tensor)
    invalid_batches = [
        ("Input tensor dtype should be uint8", {}, {"dtype": torch.float32}),
        ("The number of channels should be 3, got: 5", {}, {"shape": (5, 100, 100)}),
        ("The number of channels should be 3, got: 1", {}, {"shape": (1, 100, 100)}),
        ("Input data should be a 3-dimensional tensor", {}, {"shape": (1, 3, 100, 100)}),
        ("Input data should be a 3-dimensional tensor", {}, {"shape": (100, 100)}),
        ("Input tensor should be on CPU", {"device": "cpu"}, {}),
        (same_device_message, {}, {"device": "cpu"}),
    ]
    for message, first_overrides, second_overrides in invalid_batches:
        with pytest.raises(RuntimeError, match=message):
            encode_jpeg([make_image(**first_overrides), make_image(**second_overrides)])

    # Mixing two different GPUs can only be exercised on a multi-GPU machine.
    if torch.cuda.device_count() >= 2:
        with pytest.raises(RuntimeError, match=same_device_message):
            encode_jpeg([make_image(device="cuda:0"), make_image(device="cuda:1")])

    with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"):
        encode_jpeg([])


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
"img_path",
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "decode_jpeg_cuda.h"
#include "encode_decode_jpegs_cuda.h"

#include <ATen/ATen.h>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <torch/types.h>
#include "../image_read_mode.h"
#include "encode_jpegs_cuda.h"

namespace vision {
namespace image {
Expand All @@ -11,5 +12,9 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda(
ImageReadMode mode,
torch::Device device);

C10_EXPORT std::vector<torch::Tensor> encode_jpegs_cuda(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here or somewhere for the user to say that it only supports contiguous tensors?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 87 in encode_jpegs_cuda.cpp should take care of handling non-contiguous images.

const std::vector<torch::Tensor>& decoded_images,
const int64_t quality);

} // namespace image
} // namespace vision
Loading
Loading