5 changes: 3 additions & 2 deletions test/common_utils.py
@@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
# explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndim == 4)
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
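Note on the new `batch_dims=(1,)` default: `to_cvcuda_tensor` only accepts batched inputs, so the helper always hands it a 4-d `(1, C, H, W)` tensor. A minimal round-trip sketch of what the helper relies on, assuming both converters are exposed under `torchvision.transforms.v2.functional` and keep their current NCHW conversion behavior (device handling may differ on a real CV-CUDA setup):

```python
import torch
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor

# make_image(..., batch_dims=(1,)) yields a (1, C, H, W) tensor; an unbatched
# CHW image would be rejected since to_cvcuda_tensor expects ndim == 4
batched = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)
cv_tensor = to_cvcuda_tensor(batched)

# convert back to a torch.Tensor and drop the batch dimension again
restored = cvcuda_to_tensor(cv_tensor).squeeze(0)
assert restored.shape == (3, 32, 32)
```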
105 changes: 94 additions & 11 deletions test/test_transforms_v2.py
@@ -3453,6 +3453,9 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, make_input):
@@ -3468,16 +3471,39 @@ def test_functional(self, make_input):
(F.crop_mask, tv_tensors.Mask),
(F.crop_video, tv_tensors.Video),
(F.crop_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._crop_cvcuda,
"cvcuda.Tensor",
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if input_type == "cvcuda.Tensor":
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)
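The `"cvcuda.Tensor"` string is a sentinel: the parametrize list is built at collection time, when `cvcuda` may not be importable, so the real type is only resolved inside the test body via `_import_cvcuda()`. A rough sketch of what such a lazy-import helper looks like (hypothetical; the actual helper in torchvision's functional utilities may differ):

```python
def _import_cvcuda():
    # deferred import so module import and test collection succeed even
    # when CV-CUDA is not installed; callers only pay for it when used
    try:
        import cvcuda
    except ImportError as e:
        raise RuntimeError("CV-CUDA is not available") from e
    return cvcuda
```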

@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
def test_functional_image_correctness(self, kwargs):
image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional_image_correctness(self, kwargs, make_input):
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

actual = F.crop(image, **kwargs)

if make_input == make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
actual = actual.squeeze(0)
image = F.cvcuda_to_tensor(image).to(device="cpu")
image = image.squeeze(0)

expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))

assert_equal(actual, expected)
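The convert-and-compare sequence above (`cvcuda_to_tensor`, move to CPU, `squeeze(0)`) repeats in every CV-CUDA correctness test below. If it keeps spreading, it could be factored into a small helper; a hypothetical sketch (not part of this PR):

```python
def _cvcuda_to_cpu_image(cv_tensor):
    # convert a batched cvcuda.Tensor back into an unbatched CPU torch.Tensor
    # so it can be compared against the PIL-based reference output
    return F.cvcuda_to_tensor(cv_tensor).to(device="cpu").squeeze(0)
```

The branch above would then reduce to `actual, image = map(_cvcuda_to_cpu_image, (actual, image))`.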
@@ -3496,15 +3522,18 @@ def test_functional_image_correctness(self, kwargs):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, param, value, make_input):
input = make_input(self.INPUT_SIZE)
input_data = make_input(self.INPUT_SIZE)

check_sample_input = True
if param == "fill":
if isinstance(value, (tuple, list)):
if isinstance(input, tv_tensors.Mask):
if isinstance(input_data, tv_tensors.Mask):
pytest.skip("F.pad_mask doesn't support non-scalar fill.")
else:
check_sample_input = False
@@ -3513,14 +3542,14 @@ def test_transform(self, param, value, make_input):
# 1. size is required
# 2. the fill parameter only has an effect if we need padding
size=[s + 4 for s in self.INPUT_SIZE],
fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
fill=adapt_fill(value, dtype=input_data.dtype if isinstance(input_data, torch.Tensor) else torch.uint8),
)
else:
kwargs = {param: value}

check_transform(
transforms.RandomCrop(**kwargs, pad_if_needed=True),
input,
input_data,
check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
check_sample_input=check_sample_input,
)
@@ -3562,7 +3591,16 @@ def test_transform_pad_if_needed(self):
padding_mode=["constant", "edge", "reflect", "symmetric"],
)
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_image_correctness(self, param, value, seed):
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform_image_correctness(self, param, value, seed, make_input):
kwargs = {param: value}
if param != "size":
# 1. size is required
@@ -3573,16 +3611,32 @@ def test_transform_image_correctness(self, param, value, seed):

transform = transforms.RandomCrop(pad_if_needed=True, **kwargs)

image = make_image(self.INPUT_SIZE)
will_pad = kwargs["size"][0] > self.INPUT_SIZE[0] or kwargs["size"][1] > self.INPUT_SIZE[1]

image = make_input(self.INPUT_SIZE)

with freeze_rng_state():
torch.manual_seed(seed)
actual = transform(image)

torch.manual_seed(seed)

if make_input == make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
actual = actual.squeeze(0)
image = F.cvcuda_to_tensor(image).to(device="cpu")
image = image.squeeze(0)

expected = F.to_image(transform(F.to_pil_image(image)))

assert_equal(actual, expected)
if make_input == make_image_cvcuda and will_pad:
# when padding is applied, CV-CUDA always fills with zeros, while the
# reference may be padded with a non-zero fill, so assert_equal would only
# pass if the expected fill happened to be all zeros; the maximal atol
# restricts the comparison to shape and dtype in the padded case
torch.testing.assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
else:
assert_equal(actual, expected)

def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width):
affine_matrix = np.array(
@@ -4876,6 +4930,9 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, make_input):
@@ -4891,9 +4948,16 @@ def test_functional(self, make_input):
(F.center_crop_mask, tv_tensors.Mask),
(F.center_crop_video, tv_tensors.Video),
(F.center_crop_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._center_crop_cvcuda,
"cvcuda.Tensor",
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if input_type == "cvcuda.Tensor":
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
@@ -4906,17 +4970,36 @@ def test_functional_signature(self, kernel, input_type):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, make_input):
check_transform(transforms.CenterCrop(self.OUTPUT_SIZES[0]), make_input(self.INPUT_SIZE))

@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
def test_image_correctness(self, output_size, fn):
image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
def test_image_correctness(self, output_size, make_input, fn):
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

actual = fn(image, output_size)

if make_input == make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
actual = actual.squeeze(0)
image = F.cvcuda_to_tensor(image).to(device="cpu")
image = image.squeeze(0)

expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))

assert_equal(actual, expected)
8 changes: 5 additions & 3 deletions torchvision/transforms/v2/_transform.py
@@ -11,7 +11,7 @@
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once

from .functional._utils import _get_kernel
from .functional._utils import _get_kernel, is_cvcuda_tensor


class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):

# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)

def __init__(self) -> None:
super().__init__()
@@ -90,7 +90,9 @@ def _needs_transform_list(self, flat_inputs: list[Any]) -> list[bool]:
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.

needs_transform_list = []
transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
transform_pure_tensor = not has_any(
flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image, is_cvcuda_tensor
)
for inpt in flat_inputs:
needs_transform = True

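`_transformed_types` accepts either types or predicates, which is why `is_cvcuda_tensor` can sit next to `torch.Tensor` and `PIL.Image.Image` here. A plausible sketch of such a predicate, assuming a lazy optional import (the real implementation in `functional/_utils.py` may differ, e.g. by caching the import):

```python
def is_cvcuda_tensor(inpt):
    # predicate rather than a type so CV-CUDA stays an optional dependency:
    # the import only happens when an input actually needs to be checked
    try:
        import cvcuda
    except ImportError:
        return False
    return isinstance(inpt, cvcuda.Tensor)
```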
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
@@ -15,7 +15,7 @@
from torchvision._utils import sequence_to_str

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
is_cvcuda_tensor,
),
)
}
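With `is_cvcuda_tensor` added to the accepted types, `query_chw` and `query_size` can be driven by a sample containing only a CV-CUDA tensor. A usage sketch, assuming CV-CUDA is installed, `to_cvcuda_tensor` is exposed from `torchvision.transforms.v2.functional`, and `get_size` has a CV-CUDA kernel registered (not shown in this diff):

```python
import torch
from torchvision.transforms.v2._utils import query_size
from torchvision.transforms.v2.functional import to_cvcuda_tensor

batched = torch.randint(0, 256, (1, 3, 32, 48), dtype=torch.uint8)
height, width = query_size([to_cvcuda_tensor(batched)])
assert (height, width) == (32, 48)
```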
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
@@ -1,6 +1,6 @@
from torchvision.transforms import InterpolationMode # usort: skip

from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip

from ._meta import (
clamp_bounding_boxes,