From 3b178c49efc0f5a287a550154e48b926945a55ba Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Tue, 18 Nov 2025 12:14:45 -0800
Subject: [PATCH 01/11] initial cvcuda crop implementation, only minimal tests so far

---
 test/test_transforms_v2.py                 | 32 ++++++++++++
 .../transforms/v2/functional/__init__.py   |  2 +
 .../transforms/v2/functional/_geometry.py  | 50 ++++++++++++++++++-
 3 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 670a9d00ffb..7937e155c7a 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -3712,6 +3712,18 @@ def test_errors(self):
             transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
 
 
+@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available")
+@needs_cuda
+class TestCropCVCUDA:
+    def test_functional(self):
+        check_functional(
+            F.crop, make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,)), **TestCrop.MINIMAL_CROP_KWARGS
+        )
+
+    def test_functional_signature(self):
+        check_functional_kernel_signature_match(F.crop, kernel=F.crop_cvcuda, input_type=cvcuda.Tensor)
+
+
 class TestErase:
     INPUT_SIZE = (17, 11)
     FUNCTIONAL_KWARGS = dict(
@@ -4992,6 +5004,26 @@ def test_keypoints_correctness(self, output_size, dtype, device, fn):
 
         assert_equal(actual, expected)
 
 
+@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available")
+@needs_cuda
+class TestCenterCropCVCUDA:
+    def test_functional(self):
+        check_functional(
+            F.center_crop,
+            make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)),
+            output_size=TestCenterCrop.OUTPUT_SIZES[0],
+        )
+
+    def test_functional_signature(self):
+        check_functional_kernel_signature_match(F.center_crop, kernel=F.center_crop_cvcuda, input_type=cvcuda.Tensor)
+
+    def test_transform(self):
+        check_transform(
+            transforms.CenterCrop(TestCenterCrop.OUTPUT_SIZES[0]),
+            make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)),
+        )
+
+
 class TestPerspective:
     COEFFICIENTS = [
         [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py
index 032a993b1f0..9b437dfd8a8 100644
--- a/torchvision/transforms/v2/functional/__init__.py
+++ b/torchvision/transforms/v2/functional/__init__.py
@@ -76,12 +76,14 @@
     affine_video,
     center_crop,
     center_crop_bounding_boxes,
+    center_crop_cvcuda,
     center_crop_image,
     center_crop_keypoints,
     center_crop_mask,
     center_crop_video,
     crop,
     crop_bounding_boxes,
+    crop_cvcuda,
     crop_image,
     crop_keypoints,
     crop_mask,
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 4fcb7fabe0d..b66dc0cabbb 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -2,7 +2,7 @@
 import numbers
 import warnings
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import PIL.Image
 import torch
@@ -26,7 +26,21 @@
 
 from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
 
-from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
+from ._utils import (
+    _FillTypeJIT,
+    _get_kernel,
+    _import_cvcuda,
+    _is_cvcuda_available,
+    _register_five_ten_crop_kernel_internal,
+    _register_kernel_internal,
+)
+
+CVCUDA_AVAILABLE = _is_cvcuda_available()
+
+if TYPE_CHECKING:
+    import cvcuda  # type: ignore[import-not-found]
+if CVCUDA_AVAILABLE:
+    cvcuda = _import_cvcuda()  # noqa: F811
 
 
 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -1897,6 +1911,23 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
     return crop_image(video, top, left, height, width)
 
 
+def crop_cvcuda(
+    image: "cvcuda.Tensor",
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+) -> "cvcuda.Tensor":
+    return cvcuda.customcrop(
+        image,
+        cvcuda.RectI(x=left, y=top, width=width, height=height),
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(crop, cvcuda.Tensor)(crop_cvcuda)
+
+
 def perspective(
     inpt: torch.Tensor,
     startpoints: Optional[list[list[int]]],
@@ -2647,6 +2678,21 @@ def center_crop_video(video: torch.Tensor, output_size: list[int]) -> torch.Tens
     return center_crop_image(video, output_size)
 
 
+def center_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    output_size: list[int],
+) -> "cvcuda.Tensor":
+    crop_height, crop_width = _center_crop_parse_output_size(output_size)
+    return cvcuda.center_crop(
+        image,
+        crop_size=(crop_width, crop_height),
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(center_crop, cvcuda.Tensor)(center_crop_cvcuda)
+
+
 def resized_crop(
     inpt: torch.Tensor,
     top: int,

From a24a3b996fde48a5e9f07970004a9bb6bc06042a Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Tue, 18 Nov 2025 13:10:40 -0800
Subject: [PATCH 02/11] add padding to centercrop and if needed to crop

---
 .../transforms/v2/functional/_geometry.py | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index b66dc0cabbb..26ad4be38fd 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -1918,6 +1918,33 @@ def crop_cvcuda(
     height: int,
     width: int,
 ) -> "cvcuda.Tensor":
+    image_height, image_width, channels = image.shape[1:]
+    top_diff = 0
+    left_diff = 0
+    height_diff = 0
+    width_diff = 0
+    if top < 0:
+        top_diff = -1 * top
+    if left < 0:
+        left_diff = -1 * left
+    if top + height > image_height:
+        height_diff = top + height - image_height
+    if left + width > image_width:
+        width_diff = left + width - image_width
+    if top_diff or left_diff or height_diff or width_diff:
+        image = cvcuda.copymakeborder(
+            image,
+            top=top_diff,
+            left=left_diff,
+            bottom=height_diff,
+            right=width_diff,
+            border_mode=cvcuda.Border.CONSTANT,
+            value=[0.0] * channels,
+        )
+        top = 0
+        left = 0
+        height = image_height
+        width = image_width
     return cvcuda.customcrop(
         image,
         cvcuda.RectI(x=left, y=top, width=width, height=height),
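
For orientation, here is a minimal usage sketch of the dispatch path these first two patches set up (assuming a CUDA device, the cvcuda package, and a torchvision build containing this series; the image size is illustrative):

import torch
import torchvision.transforms.v2.functional as F

# NCHW uint8 image with the batch dimension CV-CUDA's NHWC layout requires.
image = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.uint8, device="cuda")
cv_image = F.to_cvcuda_tensor(image)

# F.crop looks up the kernel registered for cvcuda.Tensor and runs
# cvcuda.customcrop on the GPU; no torch round-trip is involved.
cropped = F.crop(cv_image, top=10, left=20, height=100, width=200)
print(cropped.shape)  # NHWC layout: (1, 100, 200, 3)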
From f8e435689831dda758497d38c2977843f9cb4893 Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Tue, 18 Nov 2025 13:47:40 -0800
Subject: [PATCH 03/11] test padding for crop_cvcuda, add functional test

---
 test/test_transforms_v2.py                 |  7 +++++++
 .../transforms/v2/functional/_geometry.py  | 18 ++++++++----------
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 7937e155c7a..0b68f408ab4 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -3723,6 +3723,13 @@ def test_functional(self):
     def test_functional_signature(self):
         check_functional_kernel_signature_match(F.crop, kernel=F.crop_cvcuda, input_type=cvcuda.Tensor)
 
+    @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15)])
+    def test_functional_correctness(self, size):
+        image = make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,))
+        actual = F.crop(image, 0, 0, *size)
+        expected = F.crop(F.cvcuda_to_tensor(image), 0, 0, *size)
+        assert_equal(F.cvcuda_to_tensor(actual), expected)
+
 
 class TestErase:
     INPUT_SIZE = (17, 11)
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 26ad4be38fd..f1a297a48c0 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -1924,27 +1924,25 @@ def crop_cvcuda(
     height_diff = 0
     width_diff = 0
     if top < 0:
-        top_diff = -1 * top
+        top_diff = int(-1 * top)
     if left < 0:
-        left_diff = -1 * left
+        left_diff = int(-1 * left)
     if top + height > image_height:
-        height_diff = top + height - image_height
+        height_diff = int(top + height - image_height)
     if left + width > image_width:
-        width_diff = left + width - image_width
+        width_diff = int(left + width - image_width)
     if top_diff or left_diff or height_diff or width_diff:
         image = cvcuda.copymakeborder(
             image,
+            border_mode=cvcuda.Border.CONSTANT,
+            border_value=[0.0] * channels,
             top=top_diff,
             left=left_diff,
             bottom=height_diff,
             right=width_diff,
-            border_mode=cvcuda.Border.CONSTANT,
-            value=[0.0] * channels,
         )
-        top = 0
-        left = 0
-        height = image_height
-        width = image_width
+        top = top + top_diff
+        left = left + left_diff
     return cvcuda.customcrop(
         image,
         cvcuda.RectI(x=left, y=top, width=width, height=height),

From 8750ce1e12cabd44ba3dc66bb98307911d4777cd Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Tue, 18 Nov 2025 15:27:21 -0800
Subject: [PATCH 04/11] center_crop passes functional equiv

---
 test/test_transforms_v2.py                 |  7 +++++++
 .../transforms/v2/functional/_geometry.py  | 17 +++++++++++++----
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 0b68f408ab4..4cc22a8991d 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5024,6 +5024,13 @@ def test_functional(self):
     def test_functional_signature(self):
         check_functional_kernel_signature_match(F.center_crop, kernel=F.center_crop_cvcuda, input_type=cvcuda.Tensor)
 
+    @pytest.mark.parametrize("output_size", TestCenterCrop.OUTPUT_SIZES)
+    def test_functional_correctness(self, output_size):
+        image = make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,))
+        actual = F.center_crop(image, output_size)
+        expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
+        assert_equal(F.cvcuda_to_tensor(actual), expected)
+
     def test_transform(self):
         check_transform(
             transforms.CenterCrop(TestCenterCrop.OUTPUT_SIZES[0]),
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index f1a297a48c0..83b1aa256ce 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -2716,16 +2716,25 @@ def center_crop_cvcuda(
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = cvcuda.copymakeborder(
             image,
+            border_mode=cvcuda.Border.CONSTANT,
+            border_value=[0.0] * channels,
             top=padding_ltrb[1],
             left=padding_ltrb[0],
             bottom=padding_ltrb[3],
             right=padding_ltrb[2],
-            border_mode=cvcuda.Border.CONSTANT,
-            value=[0.0] * channels,
         )
-    return cvcuda.center_crop(
+
+    image_height = image.shape[1]
+    image_width = image.shape[2]
+
+    if crop_width == image_width and crop_height == image_height:
+        return image
+
+    # use customcrop to match crop_image behavior
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
+    return cvcuda.customcrop(
         image,
-        crop_size=(crop_width, crop_height),
+        cvcuda.RectI(x=crop_left, y=crop_top, width=crop_width, height=crop_height),
     )
 
 
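The out-of-bounds handling that patches 02 and 03 converge on is easiest to verify with a small worked example (plain Python mirroring the _crop_cvcuda arithmetic; the numbers are illustrative):

# Image is 17x11 (HxW); request crop(top=-2, left=-3, height=25, width=15).
image_height, image_width = 17, 11
top, left, height, width = -2, -3, 25, 15

top_diff = max(0, -top)                            # 2 rows padded above
left_diff = max(0, -left)                          # 3 cols padded left
height_diff = max(0, top + height - image_height)  # 6 rows padded below
width_diff = max(0, left + width - image_width)    # 1 col padded right

# After cvcuda.copymakeborder the canvas is 25x15, and the crop origin
# shifts by the amount padded on the top/left edges, so the requested
# rectangle now lies fully inside the padded image.
top, left = top + top_diff, left + left_diff       # (0, 0)
assert top + height <= image_height + top_diff + height_diff
assert left + width <= image_width + left_diff + width_diff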
From 7696ff460d4ce56a8744f97367f2fab4a3fcb973 Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Thu, 20 Nov 2025 13:18:45 -0800
Subject: [PATCH 05/11] fix: crop testing, adhere to conventions

---
 test/common_utils.py                       |  5 +-
 test/test_transforms_v2.py                 | 63 ++++++++++++-------
 torchvision/transforms/v2/_transform.py    |  8 ++-
 torchvision/transforms/v2/_utils.py        |  5 +-
 .../transforms/v2/functional/__init__.py   |  4 +-
 .../transforms/v2/functional/_geometry.py  |  6 +-
 .../transforms/v2/functional/_utils.py     |  9 +++
 7 files changed, 65 insertions(+), 35 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index 8c3c9dd58a8..e7bae60c41b 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs):
     return to_pil_image(make_image(*args, **kwargs))
 
 
-def make_image_cvcuda(*args, **kwargs):
-    return to_cvcuda_tensor(make_image(*args, **kwargs))
+def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
+    # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
+    return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
 
 
 def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 4cc22a8991d..757b7e991c4 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -3453,6 +3453,9 @@ def test_kernel_video(self):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -3468,6 +3471,11 @@ def test_functional(self, make_input):
             (F.crop_mask, tv_tensors.Mask),
             (F.crop_video, tv_tensors.Video),
             (F.crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._crop_cvcuda,
+                _import_cvcuda().Tensor,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
@@ -3496,15 +3504,18 @@ def test_functional_image_correctness(self, kwargs):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_transform(self, param, value, make_input):
-        input = make_input(self.INPUT_SIZE)
+        input_data = make_input(self.INPUT_SIZE)
         check_sample_input = True
 
         if param == "fill":
             if isinstance(value, (tuple, list)):
-                if isinstance(input, tv_tensors.Mask):
+                if isinstance(input_data, tv_tensors.Mask):
                     pytest.skip("F.pad_mask doesn't support non-scalar fill.")
                 else:
                     check_sample_input = False
@@ -3513,14 +3524,14 @@
             # 1. size is required
             # 2. the fill parameter only has an affect if we need padding
             size=[s + 4 for s in self.INPUT_SIZE],
-            fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
+            fill=adapt_fill(value, dtype=input_data.dtype if isinstance(input_data, torch.Tensor) else torch.uint8),
         )
         else:
             kwargs = {param: value}
 
         check_transform(
             transforms.RandomCrop(**kwargs, pad_if_needed=True),
-            input,
+            input_data,
             check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
             check_sample_input=check_sample_input,
         )
@@ -3584,6 +3595,31 @@ def test_transform_image_correctness(self, param, value, seed):
 
         assert_equal(actual, expected)
 
+    @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+    @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15), (10, 10)])
+    @pytest.mark.parametrize("seed", list(range(5)))
+    def test_transform_cvcuda_correctness(self, size, seed):
+        pad_if_needed = False
+        if size[0] > self.INPUT_SIZE[0] or size[1] > self.INPUT_SIZE[1]:
+            pad_if_needed = True
+        transform = transforms.RandomCrop(size, pad_if_needed=pad_if_needed)
+
+        image = make_image(size=self.INPUT_SIZE, batch_dims=(1,), device="cuda")
+        cv_image = F.to_cvcuda_tensor(image)
+
+        with freeze_rng_state():
+            torch.manual_seed(seed)
+            actual = transform(cv_image)
+
+            torch.manual_seed(seed)
+            expected = transform(image)
+
+        if not pad_if_needed:
+            torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=0)
+        else:
+            # if padding is required, CV-CUDA will always fill with zeros
+            torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=get_max_value(image.dtype))
+
     def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width):
         affine_matrix = np.array(
             [
@@ -3712,25 +3748,6 @@ def test_errors(self):
             transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
 
 
-@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available")
-@needs_cuda
-class TestCropCVCUDA:
-    def test_functional(self):
-        check_functional(
-            F.crop, make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,)), **TestCrop.MINIMAL_CROP_KWARGS
-        )
-
-    def test_functional_signature(self):
-        check_functional_kernel_signature_match(F.crop, kernel=F.crop_cvcuda, input_type=cvcuda.Tensor)
-
-    @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15)])
-    def test_functional_correctness(self, size):
-        image = make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,))
-        actual = F.crop(image, 0, 0, *size)
-        expected = F.crop(F.cvcuda_to_tensor(image), 0, 0, *size)
-        assert_equal(F.cvcuda_to_tensor(actual), expected)
-
-
 class TestErase:
     INPUT_SIZE = (17, 11)
     FUNCTIONAL_KWARGS = dict(
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index ac84fcb6c82..c7b32223b8b 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -11,7 +11,7 @@
 from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once
 
-from .functional._utils import _get_kernel
+from .functional._utils import _get_kernel, is_cvcuda_tensor
 
 
 class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):
 
     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
-    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
+    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
 
     def __init__(self) -> None:
         super().__init__()
@@ -90,7 +90,9 @@ def _needs_transform_list(self, flat_inputs: list[Any]) -> list[bool]:
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
 
         needs_transform_list = []
-        transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
+        transform_pure_tensor = not has_any(
+            flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image, is_cvcuda_tensor
+        )
 
         for inpt in flat_inputs:
             needs_transform = True
diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py
index bb6051b4e61..3fc33ce5964 100644
--- a/torchvision/transforms/v2/_utils.py
+++ b/torchvision/transforms/v2/_utils.py
@@ -15,7 +15,7 @@
 
 from torchvision._utils import sequence_to_str
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
-from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 
 
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
+        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
                 tv_tensors.Mask,
                 tv_tensors.BoundingBoxes,
                 tv_tensors.KeyPoints,
+                is_cvcuda_tensor,
             ),
         )
     }
diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py
index 9b437dfd8a8..52181e4624b 100644
--- a/torchvision/transforms/v2/functional/__init__.py
+++ b/torchvision/transforms/v2/functional/__init__.py
@@ -1,6 +1,6 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 
-from ._utils import is_pure_tensor, register_kernel  # usort: skip
+from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor  # usort: skip
 
 from ._meta import (
     clamp_bounding_boxes,
@@ -76,14 +76,12 @@
     affine_video,
     center_crop,
     center_crop_bounding_boxes,
-    center_crop_cvcuda,
     center_crop_image,
     center_crop_keypoints,
     center_crop_mask,
     center_crop_video,
     crop,
     crop_bounding_boxes,
-    crop_cvcuda,
     crop_image,
     crop_keypoints,
     crop_mask,
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 83b1aa256ce..42c70f430a7 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -1911,13 +1911,15 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
     return crop_image(video, top, left, height, width)
 
 
-def crop_cvcuda(
+def _crop_cvcuda(
     image: "cvcuda.Tensor",
     top: int,
     left: int,
     height: int,
     width: int,
 ) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
     image_height, image_width, channels = image.shape[1:]
     top_diff = 0
     left_diff = 0
@@ -1950,7 +1952,7 @@ def _crop_cvcuda(
 
 
 if CVCUDA_AVAILABLE:
-    _register_kernel_internal(crop, cvcuda.Tensor)(crop_cvcuda)
+    _crop_cvcuda_registered = _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
 
 
 def perspective(
diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index ad1eddd258b..5485d5364a0 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -5,6 +5,7 @@
 import torch
 from torchvision import tv_tensors
 
+
 _FillType = Union[int, float, Sequence[int], Sequence[float], None]
 _FillTypeJIT = Optional[list[float]]
 
@@ -169,3 +170,11 @@ def _is_cvcuda_available():
         return True
     except ImportError:
         return False
+
+
+def is_cvcuda_tensor(inpt: Any) -> bool:
+    try:
+        cvcuda = _import_cvcuda()
+        return isinstance(inpt, cvcuda.Tensor)
+    except ImportError:
+        return False

From bda657086c1522051b7072951557516fc076e2aa Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Thu, 20 Nov 2025 13:24:30 -0800
Subject: [PATCH 06/11] Fix: update center crop

---
 test/test_transforms_v2.py                 | 49 +++++++++----------
 .../transforms/v2/functional/_geometry.py  |  6 ++-
 2 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 757b7e991c4..ad19cbe974e 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4912,6 +4912,9 @@ def test_kernel_video(self):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -4927,6 +4930,11 @@ def test_functional(self, make_input):
             (F.center_crop_mask, tv_tensors.Mask),
             (F.center_crop_video, tv_tensors.Video),
             (F.center_crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._center_crop_cvcuda,
+                _import_cvcuda().Tensor,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
@@ -4942,6 +4950,9 @@ def test_functional_signature(self, kernel, input_type):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
    )
     def test_transform(self, make_input):
@@ -4957,6 +4968,17 @@ def test_transform(self, make_input):
 
         assert_equal(actual, expected)
 
+    @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+    @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
+    @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
+    def test_cvcuda_correctness(self, output_size, fn):
+        image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8, device="cuda")
+
+        actual = fn(image, output_size)
+        expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
+
+        assert_equal(F.cvcuda_to_tensor(actual), expected)
+
     def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
         image_height, image_width = bounding_boxes.canvas_size
         if isinstance(output_size, int):
@@ -5028,33 +5050,6 @@ def test_keypoints_correctness(self, output_size, dtype, device, fn):
 
         assert_equal(actual, expected)
 
-@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available")
-@needs_cuda
-class TestCenterCropCVCUDA:
-    def test_functional(self):
-        check_functional(
-            F.center_crop,
-            make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)),
-            output_size=TestCenterCrop.OUTPUT_SIZES[0],
-        )
-
-    def test_functional_signature(self):
-        check_functional_kernel_signature_match(F.center_crop, kernel=F.center_crop_cvcuda, input_type=cvcuda.Tensor)
-
-    @pytest.mark.parametrize("output_size", TestCenterCrop.OUTPUT_SIZES)
-    def test_functional_correctness(self, output_size):
-        image = make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,))
-        actual = F.center_crop(image, output_size)
-        expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
-        assert_equal(F.cvcuda_to_tensor(actual), expected)
-
-    def test_transform(self):
-        check_transform(
-            transforms.CenterCrop(TestCenterCrop.OUTPUT_SIZES[0]),
-            make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)),
-        )
-
-
 class TestPerspective:
     COEFFICIENTS = [
         [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 42c70f430a7..0061c84c9ee 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -2705,7 +2705,7 @@ def center_crop_video(video: torch.Tensor, output_size: list[int]) -> torch.Tens
     return center_crop_image(video, output_size)
 
 
-def center_crop_cvcuda(
+def _center_crop_cvcuda(
     image: "cvcuda.Tensor",
     output_size: list[int],
 ) -> "cvcuda.Tensor":
@@ -2741,7 +2741,9 @@ def _center_crop_cvcuda(
 
 
 if CVCUDA_AVAILABLE:
-    _register_kernel_internal(center_crop, cvcuda.Tensor)(center_crop_cvcuda)
+    _center_crop_cvcuda_registered = _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(
+        _center_crop_cvcuda
+    )
 
 
 def resized_crop(

From 27663fa73c27287e6c83310787b120d74d63d059 Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Mon, 24 Nov 2025 13:42:54 -0800
Subject: [PATCH 07/11] handle some comments from other prs review

---
 test/test_transforms_v2.py                         | 8 ++++++--
 torchvision/transforms/v2/functional/_geometry.py  | 6 ++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index ad19cbe974e..0a267e26704 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -3473,12 +3473,14 @@ def test_functional(self, make_input):
             (F.crop_keypoints, tv_tensors.KeyPoints),
             pytest.param(
                 F._geometry._crop_cvcuda,
-                _import_cvcuda().Tensor,
+                "cvcuda.Tensor",
                 marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
             ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
@@ -4932,12 +4934,14 @@ def test_functional(self, make_input):
             (F.center_crop_keypoints, tv_tensors.KeyPoints),
             pytest.param(
                 F._geometry._center_crop_cvcuda,
-                _import_cvcuda().Tensor,
+                "cvcuda.Tensor",
                 marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
             ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 0061c84c9ee..0515daa44bd 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -1952,7 +1952,7 @@ def _crop_cvcuda(
 
 
 if CVCUDA_AVAILABLE:
-    _crop_cvcuda_registered = _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
+    _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
 
 
 def perspective(
@@ -2741,9 +2741,7 @@ def _center_crop_cvcuda(
 
 
 if CVCUDA_AVAILABLE:
-    _center_crop_cvcuda_registered = _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(
-        _center_crop_cvcuda
-    )
+    _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(_center_crop_cvcuda)
 
 
 def resized_crop(
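
Patches 05-07 move CV-CUDA support onto torchvision's usual kernel-dispatch conventions, where _transformed_types may contain callable predicates as well as types. A simplified sketch of how a predicate like is_cvcuda_tensor participates in such checks (check_type here is a stand-in modeled on the helper in torchvision.transforms.v2._utils, not a verbatim copy):

from typing import Any

def check_type(obj: Any, types_or_checks: tuple) -> bool:
    # Each entry is either a type (plain isinstance check) or a callable
    # predicate; the latter lets is_cvcuda_tensor test for cvcuda.Tensor
    # without importing cvcuda at module load time.
    for type_or_check in types_or_checks:
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False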
From 5d5c436071ac10f510f7024e04aa8866a63f9dff Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Wed, 26 Nov 2025 10:42:37 -0800
Subject: [PATCH 08/11] simplify and improve crop testing for cvcuda

---
 test/test_transforms_v2.py | 72 +++++++++++++++++++++++---------------
 1 file changed, 44 insertions(+), 28 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 0a267e26704..59bfb93117d 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -3484,10 +3484,26 @@ def test_functional_signature(self, kernel, input_type):
         check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
-    def test_functional_image_correctness(self, kwargs):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
+    def test_functional_image_correctness(self, kwargs, make_input):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = F.crop(image, **kwargs)
+
+        if make_input == make_image_cvcuda:
+            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
+            actual = actual.squeeze(0)
+            image = F.cvcuda_to_tensor(image).to(device="cpu")
+            image = image.squeeze(0)
+
         expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))
 
         assert_equal(actual, expected)
@@ -3575,7 +3591,16 @@ def test_transform_pad_if_needed(self):
         padding_mode=["constant", "edge", "reflect", "symmetric"],
     )
     @pytest.mark.parametrize("seed", list(range(5)))
-    def test_transform_image_correctness(self, param, value, seed):
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
+    def test_transform_image_correctness(self, param, value, seed, make_input):
         kwargs = {param: value}
         if param != "size":
             # 1. size is required
@@ -3586,41 +3611,32 @@ def test_transform_image_correctness(self, param, value, seed):
 
         transform = transforms.RandomCrop(pad_if_needed=True, **kwargs)
 
-        image = make_image(self.INPUT_SIZE)
+        will_pad = False
+        if kwargs["size"][0] > self.INPUT_SIZE[0] or kwargs["size"][1] > self.INPUT_SIZE[1]:
+            will_pad = True
+
+        image = make_input(self.INPUT_SIZE)
 
         with freeze_rng_state():
             torch.manual_seed(seed)
             actual = transform(image)
 
             torch.manual_seed(seed)
-            expected = F.to_image(transform(F.to_pil_image(image)))
 
-        assert_equal(actual, expected)
+        if make_input == make_image_cvcuda:
+            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
+            actual = actual.squeeze(0)
+            image = F.cvcuda_to_tensor(image).to(device="cpu")
+            image = image.squeeze(0)
 
-    @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
-    @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15), (10, 10)])
-    @pytest.mark.parametrize("seed", list(range(5)))
-    def test_transform_cvcuda_correctness(self, size, seed):
-        pad_if_needed = False
-        if size[0] > self.INPUT_SIZE[0] or size[1] > self.INPUT_SIZE[1]:
-            pad_if_needed = True
-        transform = transforms.RandomCrop(size, pad_if_needed=pad_if_needed)
-
-        image = make_image(size=self.INPUT_SIZE, batch_dims=(1,), device="cuda")
-        cv_image = F.to_cvcuda_tensor(image)
-
-        with freeze_rng_state():
-            torch.manual_seed(seed)
-            actual = transform(cv_image)
-
-            torch.manual_seed(seed)
-            expected = transform(image)
+        expected = F.to_image(transform(F.to_pil_image(image)))
 
-        if not pad_if_needed:
-            torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=0)
+        if make_input == make_image_cvcuda and will_pad:
+            # when padding is applied, CV-CUDA will always fill with zeros
+            # cannot use assert_equal since it will fail unless random is all zeros
+            torch.testing.assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
         else:
-            # if padding is required, CV-CUDA will always fill with zeros
-            torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=get_max_value(image.dtype))
+            assert_equal(actual, expected)
 
     def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width):
         affine_matrix = np.array(
             [

From ee626ae78925f0ce6c14bc2ed54dea0f937ca186 Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Wed, 26 Nov 2025 10:45:24 -0800
Subject: [PATCH 09/11] simplify test for center crop cvcuda

---
 test/test_transforms_v2.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 59bfb93117d..1e558bac2a8 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4979,25 +4979,30 @@ def test_transform(self, make_input):
         check_transform(transforms.CenterCrop(self.OUTPUT_SIZES[0]), make_input(self.INPUT_SIZE))
 
     @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
-    def test_image_correctness(self, output_size, fn):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    def test_image_correctness(self, output_size, make_input, fn):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = fn(image, output_size)
-        expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
 
-        assert_equal(actual, expected)
-
-    @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
-    @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
-    @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
-    def test_cvcuda_correctness(self, output_size, fn):
-        image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8, device="cuda")
+        if make_input == make_image_cvcuda:
+            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
+            actual = actual.squeeze(0)
+            image = F.cvcuda_to_tensor(image).to(device="cpu")
+            image = image.squeeze(0)
 
-        actual = fn(image, output_size)
-        expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
+        expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
 
-        assert_equal(F.cvcuda_to_tensor(actual), expected)
+        assert_equal(actual, expected)
 
     def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
         image_height, image_width = bounding_boxes.canvas_size

From 31e08e471c57d6d3f2340f0a42fff7fe8c7be8be Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Mon, 1 Dec 2025 15:51:41 -0800
Subject: [PATCH 10/11] begin work on finalizing the crop PR to include five and ten crop, adhere to new PR reviews for flip

---
 test/common_utils.py                       | 33 +++++++++-
 test/test_transforms_v2.py                 | 79 +++++++++++++++----
 torchvision/transforms/v2/_geometry.py     | 11 +++
 torchvision/transforms/v2/_transform.py    |  6 +-
 .../transforms/v2/functional/_geometry.py  | 71 +++++++++++++++++
 5 files changed, 179 insertions(+), 21 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index e7bae60c41b..b382a764e95 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -20,7 +20,13 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
+from torchvision.transforms.v2.functional import (
+    cvcuda_to_tensor,
+    is_cvcuda_tensor,
+    to_cvcuda_tensor,
+    to_image,
+    to_pil_image,
+)
 from torchvision.utils import _Image_fromarray
 
 
@@ -275,6 +281,17 @@ def combinations_grid(**kwargs):
     return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
 
 
+def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
+    tensor = cvcuda_to_tensor(tensor)
+    if tensor.ndim != 4:
+        raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
+    if tensor.shape[0] != 1:
+        raise ValueError(
+            f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
+        )
+    return tensor.squeeze(0).cpu()
+
+
 class ImagePair(TensorLikePair):
     def __init__(
         self,
@@ -284,8 +301,17 @@ def __init__(
         mae=False,
         **other_parameters,
     ):
-        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
-            actual, expected = (to_image(input) for input in [actual, expected])
+        # Convert PIL images to tv_tensors.Image (regardless of what the other is)
+        if isinstance(actual, PIL.Image.Image):
+            actual = to_image(actual)
+        if isinstance(expected, PIL.Image.Image):
+            expected = to_image(expected)
+
+        # attempt to convert CV-CUDA tensors to torch tensors
+        if is_cvcuda_tensor(actual):
+            actual = cvcuda_to_pil_compatible_tensor(actual)
+        if is_cvcuda_tensor(expected):
+            expected = cvcuda_to_pil_compatible_tensor(expected)
 
         super().__init__(actual, expected, **other_parameters)
         self.mae = mae
@@ -401,7 +427,6 @@ def make_image_pil(*args, **kwargs):
 
 
 def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
-    # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
     return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
 
 
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 1e558bac2a8..f7168dbf7a3 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -24,6 +24,7 @@
     assert_equal,
     cache,
     cpu_and_cuda,
+    cvcuda_to_pil_compatible_tensor,
     freeze_rng_state,
     ignore_jit_no_profile_information_warning,
     make_bounding_boxes,
@@ -3624,10 +3625,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input):
 
             torch.manual_seed(seed)
 
         if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+            image = cvcuda_to_pil_compatible_tensor(image)
 
         expected = F.to_image(transform(F.to_pil_image(image)))
@@ -4995,10 +4993,7 @@ def test_image_correctness(self, output_size, make_input, fn):
 
         actual = fn(image, output_size)
 
         if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+            image = cvcuda_to_pil_compatible_tensor(image)
 
         expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
@@ -6274,7 +6269,15 @@ def wrapper(*args, **kwargs):
 
     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
     )
     @pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
     def test_functional(self, make_input, functional):
@@ -6292,13 +6295,27 @@ def test_functional(self, make_input, functional):
             (F.five_crop, F.five_crop_image, torch.Tensor),
             (F.five_crop, F._geometry._five_crop_image_pil, PIL.Image.Image),
             (F.five_crop, F.five_crop_image, tv_tensors.Image),
             (F.five_crop, F.five_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.five_crop,
+                F._geometry._five_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
             (F.ten_crop, F.ten_crop_image, torch.Tensor),
             (F.ten_crop, F._geometry._ten_crop_image_pil, PIL.Image.Image),
             (F.ten_crop, F.ten_crop_image, tv_tensors.Image),
             (F.ten_crop, F.ten_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.ten_crop,
+                F._geometry._ten_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, functional, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)
 
     class _TransformWrapper(nn.Module):
@@ -6320,7 +6337,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
     )
     @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
     def test_transform(self, make_input, transform_cls):
@@ -6338,19 +6363,41 @@ def test_transform_error(self, make_input, transform_cls):
         with pytest.raises(TypeError, match="not supported"):
             transform(make_input(self.INPUT_SIZE))
 
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.five_crop, transform_cls_to_functional(transforms.FiveCrop)])
-    def test_correctness_image_five_crop(self, fn):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    def test_correctness_image_five_crop(self, make_input, fn):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = fn(image, size=self.OUTPUT_SIZE)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.five_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE)
 
         assert isinstance(actual, tuple)
         assert_equal(actual, [F.to_image(e) for e in expected])
 
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn_or_class", [F.ten_crop, transforms.TenCrop])
     @pytest.mark.parametrize("vertical_flip", [False, True])
-    def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
+    def test_correctness_image_ten_crop(self, make_input, fn_or_class, vertical_flip):
         if fn_or_class is transforms.TenCrop:
             fn = transform_cls_to_functional(fn_or_class, size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
             kwargs = dict()
@@ -6358,9 +6405,13 @@ def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
             fn = fn_or_class
             kwargs = dict(size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
 
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
 
         actual = fn(image, **kwargs)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.ten_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
 
         assert isinstance(actual, tuple)
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 1418a6b4953..a89be60481e 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -26,6 +26,7 @@
     get_bounding_boxes,
     has_all,
     has_any,
+    is_cvcuda_tensor,
     is_pure_tensor,
     query_size,
 )
@@ -186,6 +187,8 @@ class CenterCrop(Transform):
 
     _v1_transform_cls = _transforms.CenterCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -352,6 +355,8 @@ class FiveCrop(Transform):
 
     _v1_transform_cls = _transforms.FiveCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -396,6 +401,8 @@ class TenCrop(Transform):
 
     _v1_transform_cls = _transforms.TenCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -803,6 +810,8 @@ class RandomCrop(Transform):
 
     _v1_transform_cls = _transforms.RandomCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def _extract_params_for_v1_transform(self) -> dict[str, Any]:
         params = super()._extract_params_for_v1_transform()
 
@@ -1113,6 +1122,8 @@ class RandomIoUCrop(Transform):
             Default, 40.
     """
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(
         self,
         min_scale: float = 0.3,
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index c7b32223b8b..b0985bb0aec 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -8,10 +8,10 @@
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import tv_tensors
-from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
+from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
 from torchvision.utils import _log_api_usage_once
 
-from .functional._utils import _get_kernel, is_cvcuda_tensor
+from .functional._utils import _get_kernel
 
 
 class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):
 
     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
-    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
+    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
 
     def __init__(self) -> None:
         super().__init__()
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 0515daa44bd..23cbb3257de 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -142,6 +142,14 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(video)
 
 
+def _horizontal_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=1)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_cvcuda)
+
+
 def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
     if torch.jit.is_scripting():
@@ -230,6 +238,14 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(video)
 
 
+def _vertical_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=0)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_cvcuda)
+
+
 # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
 # prevalent and well understood. Thus, we just alias them without deprecating the old names.
 hflip = horizontal_flip
@@ -3003,6 +3019,29 @@ def five_crop_video(
     return five_crop_image(video, size)
 
 
+def _five_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    image_height, image_width = image.shape[-2:]
+
+    if crop_width > image_width or crop_height > image_height:
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
+
+    tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width)
+    tr = _crop_cvcuda(image, 0, image_width - crop_height, crop_width, crop_height)
+    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_width, crop_height)
+    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_width, crop_height)
+    center = _center_crop_cvcuda(image, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(five_crop, _import_cvcuda().Tensor)(_five_crop_cvcuda)
+
+
 def ten_crop(
     inpt: torch.Tensor, size: list[int], vertical_flip: bool = False
 ) -> tuple[
@@ -3098,3 +3137,35 @@ def ten_crop_video(
         torch.Tensor,
     ]:
     return ten_crop_image(video, size, vertical_flip=vertical_flip)
+
+
+def _ten_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+    vertical_flip: bool = False,
+) -> tuple[
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+]:
+    non_flipped = _five_crop_cvcuda(image, size)
+
+    if vertical_flip:
+        image = _vertical_flip_cvcuda(image)
+    else:
+        image = _horizontal_flip_cvcuda(image)
+
+    flipped = _five_crop_cvcuda(image, size)
+
+    return non_flipped + flipped
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(ten_crop, _import_cvcuda().Tensor)(_ten_crop_cvcuda)
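
Patch 10 builds the CV-CUDA ten_crop out of the pieces above rather than a dedicated kernel. The composition is easiest to see in pure-tensor terms (a sketch using the public kernels, not the private CV-CUDA ones; sizes are illustrative):

import torch
import torchvision.transforms.v2.functional as F

image = torch.randint(0, 256, (3, 17, 11), dtype=torch.uint8)

# five_crop: the four corner crops plus the center crop.
five = F.five_crop(image, size=[10, 5])

# ten_crop: the same five crops, plus five more taken from the flipped
# image (horizontal by default, vertical when vertical_flip=True).
ten = F.ten_crop(image, size=[10, 5], vertical_flip=False)
assert len(five) == 5 and len(ten) == 10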
From a00a7afc28c487355e3d8066abd97b58e962e8e9 Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Mon, 1 Dec 2025 16:54:47 -0800
Subject: [PATCH 11/11] update to include five ten crop and resized crop, use placeholder transforms for flip and resize for now

---
 test/common_utils.py                       | 17 +++---
 test/test_transforms_v2.py                 | 51 +++++++++++++-----
 torchvision/transforms/v2/_geometry.py     |  6 +++
 .../transforms/v2/functional/_geometry.py  | 52 +++++++++++++++++--
 4 files changed, 99 insertions(+), 27 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index b382a764e95..069077569a0 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -20,19 +20,15 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import (
-    cvcuda_to_tensor,
-    is_cvcuda_tensor,
-    to_cvcuda_tensor,
-    to_image,
-    to_pil_image,
-)
+from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
+from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
 from torchvision.utils import _Image_fromarray
 
 
 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
 IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
+CVCUDA_AVAILABLE = _is_cvcuda_available()
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
 MPS_NOT_AVAILABLE_MSG = "MPS device not available"
 OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -307,11 +303,10 @@ def __init__(
         if isinstance(expected, PIL.Image.Image):
             expected = to_image(expected)
 
-        # attempt to convert CV-CUDA tensors to torch tensors
-        if is_cvcuda_tensor(actual):
+        # handle check for CV-CUDA Tensors
+        if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
+            # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
             actual = cvcuda_to_pil_compatible_tensor(actual)
-        if is_cvcuda_tensor(expected):
-            expected = cvcuda_to_pil_compatible_tensor(expected)
 
         super().__init__(actual, expected, **other_parameters)
         self.mae = mae
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index f7168dbf7a3..086d468995b 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -21,6 +21,7 @@
 import torchvision.transforms.v2 as transforms
 
 from common_utils import (
+    assert_close,
     assert_equal,
     cache,
     cpu_and_cuda,
@@ -42,7 +43,6 @@
 )
 
 from torch import nn
-from torch.testing import assert_close
 from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
@@ -3499,11 +3499,8 @@ def test_functional_image_correctness(self, kwargs, make_input):
 
         actual = F.crop(image, **kwargs)
 
-        if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
 
         expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))
@@ -3624,10 +3621,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input):
 
             torch.manual_seed(seed)
 
-        if make_input == make_image_cvcuda:
+        if make_input is make_image_cvcuda:
             image = cvcuda_to_pil_compatible_tensor(image)
 
         expected = F.to_image(transform(F.to_pil_image(image)))
 
         if make_input == make_image_cvcuda and will_pad:
             # when padding is applied, CV-CUDA will always fill with zeros
             # cannot use assert_equal since it will fail unless random is all zeros
-            torch.testing.assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
+            assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
         else:
             assert_equal(actual, expected)
@@ -4458,6 +4455,9 @@ def test_kernel(self, kernel, make_input):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -4474,9 +4474,16 @@ def test_functional(self, make_input):
             (F.resized_crop_mask, tv_tensors.Mask),
             (F.resized_crop_video, tv_tensors.Video),
             (F.resized_crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F.resized_crop_image,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type)
 
     @param_value_parametrization(
@@ -4493,6 +4500,9 @@ def test_functional_signature(self, kernel, input_type):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_transform(self, param, value, make_input):
@@ -4504,20 +4514,37 @@ def test_transform(self, param, value, make_input):
     # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
     # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
-    def test_functional_image_correctness(self, interpolation):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
+    def test_functional_image_correctness(self, make_input, interpolation):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8)
 
         actual = F.resized_crop(
             image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True
         )
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.to_image(
             F.resized_crop(
                 F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation
             )
         )
 
-        torch.testing.assert_close(actual, expected, atol=1, rtol=0)
+        atol = 1
+        if make_input is make_image_cvcuda and interpolation == transforms.InterpolationMode.BICUBIC:
+            # CV-CUDA BICUBIC differs from PIL ground truth BICUBIC
+            atol = 10
+        assert_close(actual, expected, atol=atol, rtol=0)
 
     def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
         new_height, new_width = size
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index a89be60481e..d7cc9ea72f3 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -46,6 +46,8 @@ class RandomHorizontalFlip(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomHorizontalFlip
 
+    _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
+
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         return self._call_kernel(F.horizontal_flip, inpt)
 
@@ -64,6 +66,8 @@ class RandomVerticalFlip(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomVerticalFlip
 
+    _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
+
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         return self._call_kernel(F.vertical_flip, inpt)
 
@@ -247,6 +251,8 @@ class RandomResizedCrop(Transform):
 
     _v1_transform_cls = _transforms.RandomResizedCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 23cbb3257de..efb9a594cca 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -605,6 +605,32 @@ def resize_video(
     return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 
 
+def _resize_cvcuda(
+    image: "cvcuda.Tensor",
+    size: Optional[list[int]],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    max_size: Optional[int] = None,
+    antialias: Optional[bool] = True,
+) -> "cvcuda.Tensor":
+    # placeholder for now; the real kernel will be handled in a dedicated resize PR
+    # as a placeholder, convert to/from a torch tensor and use resize_image
+    from ._type_conversion import cvcuda_to_tensor, to_cvcuda_tensor
+
+    return to_cvcuda_tensor(
+        resize_image(
+            cvcuda_to_tensor(image),
+            size=size,
+            interpolation=interpolation,
+            max_size=max_size,
+            antialias=antialias,
+        )
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_cvcuda)
+
+
 def affine(
     inpt: torch.Tensor,
     angle: Union[int, float],
@@ -2946,6 +2972,24 @@ def resized_crop_video(
     )
 
 
+def _resized_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: list[int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    antialias: Optional[bool] = True,
+) -> "cvcuda.Tensor":
+    image = _crop_cvcuda(image, top, left, height, width)
+    return _resize_cvcuda(image, size, interpolation=interpolation, antialias=antialias)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(resized_crop, _import_cvcuda().Tensor)(_resized_crop_cvcuda)
+
+
 def five_crop(
     inpt: torch.Tensor, size: list[int]
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -3024,15 +3068,15 @@ def _five_crop_cvcuda(
     size: list[int],
 ) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = image.shape[-2:]
+    image_height, image_width = image.shape[1], image.shape[2]
 
     if crop_width > image_width or crop_height > image_height:
         raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
     tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width)
-    tr = _crop_cvcuda(image, 0, image_width - crop_height, crop_width, crop_height)
-    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_width, crop_height)
-    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_width, crop_height)
+    tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
     center = _center_crop_cvcuda(image, [crop_height, crop_width])
 
     return tl, tr, bl, br, center
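
As of the end of this series, the CV-CUDA path can be sanity-checked against the pure-tensor kernels the same way the tests above do (a sketch assuming a CUDA device, the cvcuda package, and a build of this branch; sizes are illustrative):

import torch
import torchvision.transforms.v2.functional as F

image = torch.randint(0, 256, (1, 3, 17, 11), dtype=torch.uint8, device="cuda")
cv_image = F.to_cvcuda_tensor(image)

# Out-of-bounds crop: the CV-CUDA kernel zero-pads via copymakeborder
# before cropping, matching crop_image's behavior on the same request.
actual = F.crop(cv_image, top=-2, left=-2, height=25, width=15)
expected = F.crop(image, top=-2, left=-2, height=25, width=15)
torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=0)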