Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
from torchvision.utils import _Image_fromarray


IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CVCUDA_AVAILABLE = _is_cvcuda_available()
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
Expand Down Expand Up @@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


def cvcuda_to_pil_compatible_tensor(tensor):
tensor = cvcuda_to_tensor(tensor)
if tensor.ndim != 4:
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
if tensor.shape[0] != 1:
raise ValueError(
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
)
return tensor.squeeze(0).cpu()


class ImagePair(TensorLikePair):
def __init__(
self,
Expand All @@ -286,6 +299,13 @@ def __init__(
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = (to_image(input) for input in [actual, expected])
elif CVCUDA_AVAILABLE and all(isinstance(input, _import_cvcuda().Tensor) for input in [actual, expected]):
actual, expected = (cvcuda_to_tensor(input) for input in [actual, expected])
elif CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor) and isinstance(expected, PIL.Image.Image):
actual = cvcuda_to_pil_compatible_tensor(actual)
expected = to_image(expected)
elif CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
actual = cvcuda_to_pil_compatible_tensor(actual)

super().__init__(actual, expected, **other_parameters)
self.mae = mae
Expand Down Expand Up @@ -400,8 +420,9 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
# explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
Expand Down
102 changes: 94 additions & 8 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
assert_equal,
cache,
cpu_and_cuda,
cvcuda_to_pil_compatible_tensor,
freeze_rng_state,
ignore_jit_no_profile_information_warning,
make_bounding_boxes,
Expand Down Expand Up @@ -6357,7 +6358,17 @@ class TestRgbToGrayscale:
def test_kernel_image(self, dtype, device):
check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
check_functional(F.rgb_to_grayscale, make_input())

Expand All @@ -6367,23 +6378,58 @@ def test_functional(self, make_input):
(F.rgb_to_grayscale_image, torch.Tensor),
(F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
(F.rgb_to_grayscale_image, tv_tensors.Image),
pytest.param(
F._color._rgb_to_grayscale_cvcuda,
"cvcuda.Tensor",
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if input_type == "cvcuda.Tensor":
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_transform(self, transform, make_input):
if make_input is make_image_cvcuda and isinstance(transform, transforms.RandomGrayscale):
pytest.skip("CV-CUDA does not support RandomGrayscale, will have num_output_channels == 3")
check_transform(transform, make_input())

@pytest.mark.parametrize("num_output_channels", [1, 3])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
def test_image_correctness(self, num_output_channels, color_space, fn):
image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)
def test_image_correctness(self, num_output_channels, color_space, make_input, fn):
if make_input is make_image_cvcuda and num_output_channels == 3:
pytest.skip("CV-CUDA does not support num_output_channels == 3")

image = make_input(dtype=torch.uint8, device="cpu", color_space=color_space)

actual = fn(image, num_output_channels=num_output_channels)

if make_input is make_image_cvcuda:
image = cvcuda_to_pil_compatible_tensor(image)

expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))

assert_equal(actual, expected, rtol=0, atol=1)
Expand Down Expand Up @@ -6421,7 +6467,17 @@ class TestGrayscaleToRgb:
def test_kernel_image(self, dtype, device):
check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
check_functional(F.grayscale_to_rgb, make_input())

Expand All @@ -6431,20 +6487,50 @@ def test_functional(self, make_input):
(F.rgb_to_grayscale_image, torch.Tensor),
(F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
(F.rgb_to_grayscale_image, tv_tensors.Image),
pytest.param(
F._color._rgb_to_grayscale_cvcuda,
"cvcuda.Tensor",
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if input_type == "cvcuda.Tensor":
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_transform(self, make_input):
check_transform(transforms.RGB(), make_input(color_space="GRAY"))

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
def test_image_correctness(self, make_input, fn):
image = make_input(dtype=torch.uint8, device="cpu", color_space="GRAY")

actual = fn(image)

if make_input is make_image_cvcuda:
image = cvcuda_to_pil_compatible_tensor(image)

expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))

assert_equal(actual, expected, rtol=0, atol=1)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once

from .functional._utils import _get_kernel
from .functional._utils import _get_kernel, is_cvcuda_tensor


class Transform(nn.Module):
Expand All @@ -23,7 +23,7 @@ class Transform(nn.Module):

# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)

def __init__(self) -> None:
super().__init__()
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchvision._utils import sequence_to_str

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
is_cvcuda_tensor,
),
)
}
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torchvision.transforms import InterpolationMode # usort: skip

from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip

from ._meta import (
clamp_bounding_boxes,
Expand Down
11 changes: 10 additions & 1 deletion torchvision/transforms/v2/functional/_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
from typing import TYPE_CHECKING

import PIL.Image

Expand All @@ -8,7 +9,15 @@
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal


CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda() # noqa: F811


def erase(
Expand Down
69 changes: 68 additions & 1 deletion torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

import PIL.Image
import torch
from torch.nn.functional import conv2d
Expand All @@ -9,7 +11,15 @@

from ._misc import _num_value_bits, to_dtype_image
from ._type_conversion import pil_to_tensor, to_pil_image
from ._utils import _get_kernel, _register_kernel_internal
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal


CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda() # noqa: F811


def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
Expand Down Expand Up @@ -63,6 +73,38 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
return _FP.to_grayscale(image, num_output_channels=num_output_channels)


def _rgb_to_grayscale_cvcuda(
image: "cvcuda.Tensor",
num_output_channels: int = 1,
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()

if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")

if num_output_channels == 3:
raise ValueError("num_output_channels must be 1 for CV-CUDA, got 3.")

if image.shape[3] == 1:
# if we already have a single channel, just clone the tensor
# we will use copymakeborder since CV-CUDA has no native clone
return cvcuda.copymakeborder(
image,
border_mode=cvcuda.Border.CONSTANT,
border_value=[0],
top=0,
left=0,
bottom=0,
right=0,
)

return cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2GRAY)


if CVCUDA_AVAILABLE:
_register_kernel_internal(rgb_to_grayscale, _import_cvcuda().Tensor)(_rgb_to_grayscale_cvcuda)


def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.RGB` for details."""
if torch.jit.is_scripting():
Expand All @@ -89,6 +131,31 @@ def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return image.convert(mode="RGB")


def _grayscale_to_rgb_cvcuda(
image: "cvcuda.Tensor",
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()

if image.shape[3] == 3:
# if we already have RGB channels, just clone the tensor
# we will use copymakeborder since CV-CUDA has no native clone
return cvcuda.copymakeborder(
image,
border_mode=cvcuda.Border.CONSTANT,
border_value=[0],
top=0,
left=0,
bottom=0,
right=0,
)

return cvcuda.cvtcolor(image, cvcuda.ColorConversion.GRAY2RGB)


if CVCUDA_AVAILABLE:
_register_kernel_internal(grayscale_to_rgb, _import_cvcuda().Tensor)(_grayscale_to_rgb_cvcuda)


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
Expand Down
Loading