diff --git a/test/common_utils.py b/test/common_utils.py
index 8c3c9dd58a8..069077569a0 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -20,13 +20,15 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
+from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
+from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
 from torchvision.utils import _Image_fromarray
 
 
 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
 IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
+CVCUDA_AVAILABLE = _is_cvcuda_available()
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
 MPS_NOT_AVAILABLE_MSG = "MPS device not available"
 OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
     return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
 
 
+def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
+    tensor = cvcuda_to_tensor(tensor)
+    if tensor.ndim != 4:
+        raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
+    if tensor.shape[0] != 1:
+        raise ValueError(
+            f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
+        )
+    return tensor.squeeze(0).cpu()
+
+
 class ImagePair(TensorLikePair):
     def __init__(
         self,
@@ -284,8 +297,16 @@ def __init__(
         mae=False,
         **other_parameters,
     ):
-        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
-            actual, expected = (to_image(input) for input in [actual, expected])
+        # Convert PIL images to tv_tensors.Image (regardless of what the other is)
+        if isinstance(actual, PIL.Image.Image):
+            actual = to_image(actual)
+        if isinstance(expected, PIL.Image.Image):
+            expected = to_image(expected)
+
+        # Handle CV-CUDA tensors
+        if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
+            # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
+            actual = cvcuda_to_pil_compatible_tensor(actual)
 
         super().__init__(actual, expected, **other_parameters)
         self.mae = mae
@@ -400,8 +421,8 @@ def make_image_pil(*args, **kwargs):
     return to_pil_image(make_image(*args, **kwargs))
 
 
-def make_image_cvcuda(*args, **kwargs):
-    return to_cvcuda_tensor(make_image(*args, **kwargs))
+def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
+    return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
 
 
 def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 670a9d00ffb..086d468995b 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -21,9 +21,11 @@
 import torchvision.transforms.v2 as transforms
 
 from common_utils import (
+    assert_close,
     assert_equal,
     cache,
     cpu_and_cuda,
+    cvcuda_to_pil_compatible_tensor,
     freeze_rng_state,
     ignore_jit_no_profile_information_warning,
     make_bounding_boxes,
@@ -41,7 +43,6 @@
 )
 
 from torch import nn
-from torch.testing import assert_close
 from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
@@ -3453,6 +3454,9 @@ def test_kernel_video(self):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -3468,16 +3472,36 @@ def test_functional(self, make_input):
             (F.crop_mask, tv_tensors.Mask),
             (F.crop_video, tv_tensors.Video),
             (F.crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
-    def test_functional_image_correctness(self, kwargs):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
+    def test_functional_image_correctness(self, kwargs, make_input):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = F.crop(image, **kwargs)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))
 
         assert_equal(actual, expected)
@@ -3496,15 +3520,18 @@ def test_functional_image_correctness(self, kwargs):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_transform(self, param, value, make_input):
-        input = make_input(self.INPUT_SIZE)
+        input_data = make_input(self.INPUT_SIZE)
 
         check_sample_input = True
         if param == "fill":
             if isinstance(value, (tuple, list)):
-                if isinstance(input, tv_tensors.Mask):
+                if isinstance(input_data, tv_tensors.Mask):
                     pytest.skip("F.pad_mask doesn't support non-scalar fill.")
                 else:
                     check_sample_input = False
@@ -3513,14 +3540,14 @@ def test_transform(self, param, value, make_input):
                 # 1. size is required
                 # 2. the fill parameter only has an affect if we need padding
                 size=[s + 4 for s in self.INPUT_SIZE],
-                fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
+                fill=adapt_fill(value, dtype=input_data.dtype if isinstance(input_data, torch.Tensor) else torch.uint8),
             )
         else:
             kwargs = {param: value}
 
         check_transform(
             transforms.RandomCrop(**kwargs, pad_if_needed=True),
-            input,
+            input_data,
             check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
             check_sample_input=check_sample_input,
         )
@@ -3562,7 +3589,16 @@ def test_transform_pad_if_needed(self):
         padding_mode=["constant", "edge", "reflect", "symmetric"],
     )
     @pytest.mark.parametrize("seed", list(range(5)))
-    def test_transform_image_correctness(self, param, value, seed):
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
+    def test_transform_image_correctness(self, param, value, seed, make_input):
         kwargs = {param: value}
         if param != "size":
             # 1. size is required
@@ -3573,16 +3609,29 @@
 
         transform = transforms.RandomCrop(pad_if_needed=True, **kwargs)
 
-        image = make_image(self.INPUT_SIZE)
+        will_pad = False
+        if kwargs["size"][0] > self.INPUT_SIZE[0] or kwargs["size"][1] > self.INPUT_SIZE[1]:
+            will_pad = True
+
+        image = make_input(self.INPUT_SIZE)
 
         with freeze_rng_state():
             torch.manual_seed(seed)
             actual = transform(image)
 
             torch.manual_seed(seed)
+
+            if make_input is make_image_cvcuda:
+                image = cvcuda_to_pil_compatible_tensor(image)
+
             expected = F.to_image(transform(F.to_pil_image(image)))
 
-        assert_equal(actual, expected)
+        if make_input is make_image_cvcuda and will_pad:
+            # when padding is applied, CV-CUDA always fills with zeros
+            # cannot use assert_equal since it would only pass if the random fill happened to be all zeros
+            assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
+        else:
+            assert_equal(actual, expected)
 
     def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width):
         affine_matrix = np.array(
@@ -4406,6 +4455,9 @@ def test_kernel(self, kernel, make_input):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -4422,9 +4474,16 @@ def test_functional(self, make_input):
             (F.resized_crop_mask, tv_tensors.Mask),
             (F.resized_crop_video, tv_tensors.Video),
             (F.resized_crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._resized_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type)
 
     @param_value_parametrization(
@@ -4441,6 +4500,9 @@ def test_functional_signature(self, kernel, input_type):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_transform(self, param, value, make_input):
@@ -4452,20 +4514,37 @@
 
     # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
     # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
-    def test_functional_image_correctness(self, interpolation):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
+    def test_functional_image_correctness(self, make_input, interpolation):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8)
 
         actual = F.resized_crop(
             image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True
         )
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.to_image(
             F.resized_crop(
                 F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation
             )
         )
 
-        torch.testing.assert_close(actual, expected, atol=1, rtol=0)
+        atol = 1
+        if make_input is make_image_cvcuda and interpolation == transforms.InterpolationMode.BICUBIC:
+            # CV-CUDA BICUBIC differs from PIL ground truth BICUBIC
+            atol = 10
+        assert_close(actual, expected, atol=atol, rtol=0)
 
     def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
         new_height, new_width = size
@@ -4876,6 +4955,9 @@ def test_kernel_video(self):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -4891,9 +4973,16 @@ def test_functional(self, make_input):
             (F.center_crop_mask, tv_tensors.Mask),
             (F.center_crop_video, tv_tensors.Video),
             (F.center_crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._center_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
@@ -4906,17 +4995,33 @@ def test_functional_signature(self, kernel, input_type):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_transform(self, make_input):
         check_transform(transforms.CenterCrop(self.OUTPUT_SIZES[0]), make_input(self.INPUT_SIZE))
 
     @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
-    def test_image_correctness(self, output_size, fn):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    def test_image_correctness(self, output_size, make_input, fn):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = fn(image, output_size)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
 
         assert_equal(actual, expected)
@@ -6191,7 +6296,15 @@ def wrapper(*args, **kwargs):
 
     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
     )
     @pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
     def test_functional(self, make_input, functional):
@@ -6209,13 +6322,27 @@ def test_functional(self, make_input, functional):
             (F.five_crop, F._geometry._five_crop_image_pil, PIL.Image.Image),
             (F.five_crop, F.five_crop_image, tv_tensors.Image),
             (F.five_crop, F.five_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.five_crop,
+                F._geometry._five_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
             (F.ten_crop, F.ten_crop_image, torch.Tensor),
             (F.ten_crop, F._geometry._ten_crop_image_pil, PIL.Image.Image),
             (F.ten_crop, F.ten_crop_image, tv_tensors.Image),
             (F.ten_crop, F.ten_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.ten_crop,
+                F._geometry._ten_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, functional, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)
 
     class _TransformWrapper(nn.Module):
@@ -6237,7 +6364,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
     )
     @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
     def test_transform(self, make_input, transform_cls):
@@ -6255,19 +6390,41 @@ def test_transform_error(self, make_input, transform_cls):
         with pytest.raises(TypeError, match="not supported"):
             transform(make_input(self.INPUT_SIZE))
 
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.five_crop, transform_cls_to_functional(transforms.FiveCrop)])
-    def test_correctness_image_five_crop(self, fn):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    def test_correctness_image_five_crop(self, make_input, fn):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = fn(image, size=self.OUTPUT_SIZE)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.five_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE)
 
         assert isinstance(actual, tuple)
         assert_equal(actual, [F.to_image(e) for e in expected])
 
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn_or_class", [F.ten_crop, transforms.TenCrop])
     @pytest.mark.parametrize("vertical_flip", [False, True])
-    def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
+    def test_correctness_image_ten_crop(self, make_input, fn_or_class, vertical_flip):
         if fn_or_class is transforms.TenCrop:
             fn = transform_cls_to_functional(fn_or_class, size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
             kwargs = dict()
@@ -6275,9 +6432,13 @@ def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
             fn = fn_or_class
             kwargs = dict(size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
 
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = fn(image, **kwargs)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.ten_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
 
         assert isinstance(actual, tuple)
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 1418a6b4953..d7cc9ea72f3 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -26,6 +26,7 @@
     get_bounding_boxes,
     has_all,
     has_any,
+    is_cvcuda_tensor,
     is_pure_tensor,
     query_size,
 )
@@ -45,6 +46,8 @@ class RandomHorizontalFlip(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomHorizontalFlip
 
+    _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
+
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         return self._call_kernel(F.horizontal_flip, inpt)
 
@@ -63,6 +66,8 @@ class RandomVerticalFlip(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomVerticalFlip
 
+    _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
+
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         return self._call_kernel(F.vertical_flip, inpt)
 
@@ -186,6 +191,8 @@ class CenterCrop(Transform):
 
     _v1_transform_cls = _transforms.CenterCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -244,6 +251,8 @@ class RandomResizedCrop(Transform):
 
     _v1_transform_cls = _transforms.RandomResizedCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
@@ -352,6 +361,8 @@ class FiveCrop(Transform):
 
     _v1_transform_cls = _transforms.FiveCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -396,6 +407,8 @@ class TenCrop(Transform):
 
     _v1_transform_cls = _transforms.TenCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -803,6 +816,8 @@ class RandomCrop(Transform):
 
     _v1_transform_cls = _transforms.RandomCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def _extract_params_for_v1_transform(self) -> dict[str, Any]:
         params = super()._extract_params_for_v1_transform()
 
@@ -1113,6 +1128,8 @@ class RandomIoUCrop(Transform):
             Default, 40.
""" + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__( self, min_scale: float = 0.3, diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..b0985bb0aec 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -8,7 +8,7 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import tv_tensors -from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor +from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor from torchvision.utils import _log_api_usage_once from .functional._utils import _get_kernel @@ -90,7 +90,9 @@ def _needs_transform_list(self, flat_inputs: list[Any]) -> list[bool]: # However, this case wasn't supported by transforms v1 either, so there is no BC concern. needs_transform_list = [] - transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image) + transform_pure_tensor = not has_any( + flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image, is_cvcuda_tensor + ) for inpt in flat_inputs: needs_transform = True diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..efb9a594cca 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,21 @@ from ._meta import 
-from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
+from ._utils import (
+    _FillTypeJIT,
+    _get_kernel,
+    _import_cvcuda,
+    _is_cvcuda_available,
+    _register_five_ten_crop_kernel_internal,
+    _register_kernel_internal,
+)
+
+CVCUDA_AVAILABLE = _is_cvcuda_available()
+
+if TYPE_CHECKING:
+    import cvcuda  # type: ignore[import-not-found]
+if CVCUDA_AVAILABLE:
+    cvcuda = _import_cvcuda()  # noqa: F811
 
 
 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -128,6 +142,14 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(video)
 
 
+def _horizontal_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=1)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_cvcuda)
+
+
 def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
     if torch.jit.is_scripting():
@@ -216,6 +238,14 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(video)
 
 
+def _vertical_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=0)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_cvcuda)
+
+
 # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
 # prevalent and well understood. Thus, we just alias them without deprecating the old names.
 hflip = horizontal_flip
@@ -575,6 +605,32 @@ def resize_video(
     return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 
 
+def _resize_cvcuda(
+    image: "cvcuda.Tensor",
+    size: Optional[list[int]],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    max_size: Optional[int] = None,
+    antialias: Optional[bool] = True,
+) -> "cvcuda.Tensor":
+    # Placeholder for now; a dedicated PR will handle resize on its own.
+    # Until then, convert to/from a torch tensor and use resize_image.
+    from ._type_conversion import cvcuda_to_tensor, to_cvcuda_tensor
+
+    return to_cvcuda_tensor(
+        resize_image(
+            cvcuda_to_tensor(image),
+            size=size,
+            interpolation=interpolation,
+            max_size=max_size,
+            antialias=antialias,
+        )
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_cvcuda)
+
+
 def affine(
     inpt: torch.Tensor,
     angle: Union[int, float],
@@ -1897,6 +1953,50 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
     return crop_image(video, top, left, height, width)
 
 
+def _crop_cvcuda(
+    image: "cvcuda.Tensor",
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
+    image_height, image_width, channels = image.shape[1:]
+    top_diff = 0
+    left_diff = 0
+    height_diff = 0
+    width_diff = 0
+    if top < 0:
+        top_diff = int(-1 * top)
+    if left < 0:
+        left_diff = int(-1 * left)
+    if top + height > image_height:
+        height_diff = int(top + height - image_height)
+    if left + width > image_width:
+        width_diff = int(left + width - image_width)
+    if top_diff or left_diff or height_diff or width_diff:
+        image = cvcuda.copymakeborder(
+            image,
+            border_mode=cvcuda.Border.CONSTANT,
+            border_value=[0.0] * channels,
+            top=top_diff,
+            left=left_diff,
+            bottom=height_diff,
+            right=width_diff,
+        )
+        top = top + top_diff
+        left = left + left_diff
+    return cvcuda.customcrop(
+        image,
+        cvcuda.RectI(x=left, y=top, width=width, height=height),
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
+
+
 def perspective(
     inpt: torch.Tensor,
     startpoints: Optional[list[list[int]]],
@@ -2647,6 +2747,45 @@ def center_crop_video(video: torch.Tensor, output_size: list[int]) -> torch.Tens
     return center_crop_image(video, output_size)
 
 
+def _center_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    output_size: list[int],
+) -> "cvcuda.Tensor":
+    crop_height, crop_width = _center_crop_parse_output_size(output_size)
+    # CV-CUDA conversion only allows 4D tensors, which always use NHWC layout
+    image_height = image.shape[1]
+    image_width = image.shape[2]
+    channels = image.shape[3]
+    if crop_height > image_height or crop_width > image_width:
+        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
+        image = cvcuda.copymakeborder(
+            image,
+            border_mode=cvcuda.Border.CONSTANT,
+            border_value=[0.0] * channels,
+            top=padding_ltrb[1],
+            left=padding_ltrb[0],
+            bottom=padding_ltrb[3],
+            right=padding_ltrb[2],
+        )
+
+        image_height = image.shape[1]
+        image_width = image.shape[2]
+
+    if crop_width == image_width and crop_height == image_height:
+        return image
+
+    # use customcrop to match crop_image behavior
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
+    return cvcuda.customcrop(
+        image,
+        cvcuda.RectI(x=crop_left, y=crop_top, width=crop_width, height=crop_height),
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(_center_crop_cvcuda)
+
+
 def resized_crop(
     inpt: torch.Tensor,
     top: int,
@@ -2833,6 +2972,24 @@ def resized_crop_video(
     )
 
 
+def _resized_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: list[int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    antialias: Optional[bool] = True,
+) -> "cvcuda.Tensor":
+    image = _crop_cvcuda(image, top, left, height, width)
+    return _resize_cvcuda(image, size, interpolation=interpolation, antialias=antialias)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(resized_crop, _import_cvcuda().Tensor)(_resized_crop_cvcuda)
+
+
 def five_crop(
     inpt: torch.Tensor, size: list[int]
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -2906,6 +3063,29 @@ def five_crop_video(
     return five_crop_image(video, size)
 
 
+def _five_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    image_height, image_width = image.shape[1], image.shape[2]
+
+    if crop_width > image_width or crop_height > image_height:
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
+
+    tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width)
+    tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = _center_crop_cvcuda(image, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(five_crop, _import_cvcuda().Tensor)(_five_crop_cvcuda)
+
+
 def ten_crop(
     inpt: torch.Tensor, size: list[int], vertical_flip: bool = False
 ) -> tuple[
@@ -3001,3 +3181,35 @@ def ten_crop_video(
     torch.Tensor,
 ]:
     return ten_crop_image(video, size, vertical_flip=vertical_flip)
+
+
+def _ten_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+    vertical_flip: bool = False,
+) -> tuple[
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+]:
+    non_flipped = _five_crop_cvcuda(image, size)
+
+    if vertical_flip:
+        image = _vertical_flip_cvcuda(image)
+    else:
+        image = _horizontal_flip_cvcuda(image)
+
+    flipped = _five_crop_cvcuda(image, size)
+
+    return non_flipped + flipped
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(ten_crop, _import_cvcuda().Tensor)(_ten_crop_cvcuda)
diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index ad1eddd258b..5485d5364a0 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -5,6 +5,7 @@
 import torch
 from torchvision import tv_tensors
 
+
 _FillType = Union[int, float, Sequence[int], Sequence[float], None]
 _FillTypeJIT = Optional[list[float]]
 
@@ -169,3 +170,11 @@ def _is_cvcuda_available():
         return True
     except ImportError:
         return False
+
+
+def is_cvcuda_tensor(inpt: Any) -> bool:
+    try:
+        cvcuda = _import_cvcuda()
+        return isinstance(inpt, cvcuda.Tensor)
+    except ImportError:
+        return False
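A minimal sketch of how the newly registered kernels are exercised end to end. Assumptions not shown in this diff: the `cvcuda` package and a CUDA device are available, and `to_cvcuda_tensor` accepts the same batched `(1, C, H, W)` uint8 image that the `make_image_cvcuda` helper produces (its exact device handling lives in `_type_conversion.py`, outside this diff):

```python
import torch

import torchvision.transforms.v2.functional as F
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor

# Single uint8 image with an explicit batch dimension, mirroring make_image_cvcuda(batch_dims=(1,)).
# Placing it on "cuda" up front is an assumption about what to_cvcuda_tensor expects.
image = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8, device="cuda")
cv_image = to_cvcuda_tensor(image)

# These dispatch to the cvcuda.Tensor kernels registered in _geometry.py above.
cropped = F.crop(cv_image, top=-2, left=-2, height=20, width=20)  # out-of-bounds regions are zero-padded
centered = F.center_crop(cv_image, output_size=[16, 16])
patches = F.five_crop(cv_image, size=[8, 8])

# Convert back for comparison, as cvcuda_to_pil_compatible_tensor does in the tests.
out = cvcuda_to_tensor(centered).squeeze(0).cpu()
print(out.shape)
```

The registration pattern `_register_kernel_internal(dispatcher, cvcuda.Tensor)(kernel)` is what lets the existing public dispatchers route `cvcuda.Tensor` inputs without any new public API surface.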