diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..550e978d2a5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1249,6 +1249,10 @@ def test_kernel_video(self): def test_functional(self, make_input): check_functional(F.horizontal_flip, make_input()) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available") + def test_functional_cvcuda(self): + check_functional(F.horizontal_flip, make_image_cvcuda(batch_dims=(1,))) + @pytest.mark.parametrize( ("kernel", "input_type"), [ @@ -1291,6 +1295,15 @@ def test_image_correctness(self, fn): torch.testing.assert_close(actual, expected) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available") + def test_image_correctness_cvcuda(self): + image = make_image_cvcuda(batch_dims=(1,)) + + actual = F.horizontal_flip(image) + expected_torch = F.horizontal_flip(F.cvcuda_to_tensor(image)) + + assert torch.equal(F.cvcuda_to_tensor(actual), expected_torch) + def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( [ @@ -1865,6 +1878,10 @@ def test_kernel_video(self): def test_functional(self, make_input): check_functional(F.vertical_flip, make_input()) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available") + def test_functional_cvcuda(self): + check_functional(F.vertical_flip, make_image_cvcuda(batch_dims=(1,))) + @pytest.mark.parametrize( ("kernel", "input_type"), [ @@ -1905,6 +1922,15 @@ def test_image_correctness(self, fn): torch.testing.assert_close(actual, expected) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available") + def test_image_correctness_cvcuda(self): + image = make_image_cvcuda(batch_dims=(1,)) + + actual = F.vertical_flip(image) + expected_torch = F.vertical_flip(F.cvcuda_to_tensor(image)) + + assert torch.equal(F.cvcuda_to_tensor(actual), expected_torch) + def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: 
tv_tensors.BoundingBoxes): affine_matrix = np.array( [ diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..1b7163b6d9b 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,13 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import _FillTypeJIT, _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_five_ten_crop_kernel_internal, _register_kernel_internal + +CVCUDA_AVAILABLE = _is_cvcuda_available() +if TYPE_CHECKING: + import cvcuda +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -61,6 +67,12 @@ def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) +def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return cvcuda.flip(image, flipCode=1) + + +if CVCUDA_AVAILABLE: + _horizontal_flip_image_cvcuda_registered = _register_kernel_internal(horizontal_flip, cvcuda.Tensor)(_horizontal_flip_image_cvcuda) @_register_kernel_internal(horizontal_flip, tv_tensors.Mask) def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: @@ -150,6 +162,14 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.vflip(image) +def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return cvcuda.flip(image, flipCode=0) + + +if CVCUDA_AVAILABLE: + 
_vertical_flip_image_cvcuda_registered = _register_kernel_internal(vertical_flip, cvcuda.Tensor)(_vertical_flip_image_cvcuda) + + @_register_kernel_internal(vertical_flip, tv_tensors.Mask) def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask)