Skip to content
Merged
18 changes: 13 additions & 5 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,16 +1352,24 @@ def test_ten_crop(device):
assert_equal(transformed_batch, s_transformed_batch)


def test_elastic_transform_asserts():
with pytest.raises(TypeError, match="Argument displacement should be a Tensor"):
_ = F.elastic_transform("abc", displacement=None)

with pytest.raises(TypeError, match="img should be PIL Image or Tensor"):
_ = F.elastic_transform("abc", displacement=torch.rand(1))

img_tensor = torch.rand(1, 3, 32, 24)
with pytest.raises(ValueError, match="Argument displacement shape should"):
_ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
"fill",
[
None,
[255, 255, 255],
(2.0,),
],
[None, [255, 255, 255], (2.0,)],
)
def test_elastic_transform_consistency(device, interpolation, dt, fill):
script_elastic_transform = torch.jit.script(F.elastic_transform)
Expand Down
75 changes: 75 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
inpt = mocker.MagicMock(spec=features.Image)
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker

# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):

# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker):

# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -795,10 +800,80 @@ def test__transform(self, distortion_scale, mocker):
inpt.image_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)

fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)


class TestElasticTransform:
def test_assertions(self):

with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"):
transforms.ElasticTransform({})

with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"):
transforms.ElasticTransform([1.0, 2.0, 3.0])

with pytest.raises(ValueError, match="alpha should be a sequence of floats"):
transforms.ElasticTransform([1, 2])

with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"):
transforms.ElasticTransform(1.0, {})

with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"):
transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])

with pytest.raises(ValueError, match="sigma should be a sequence of floats"):
transforms.ElasticTransform(1.0, [1, 2])

with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.ElasticTransform(1.0, 2.0, fill="abc")

def test__get_params(self, mocker):
alpha = 2.0
sigma = 3.0
transform = transforms.ElasticTransform(alpha, sigma)
image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3
image.image_size = (24, 32)

params = transform._get_params(image)

h, w = image.image_size
displacement = params["displacement"]
assert displacement.shape == (1, h, w, 2)
assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()

@pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
@pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
def test__transform(self, alpha, sigma, mocker):
interpolation = InterpolationMode.BILINEAR
fill = 12
transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)

if isinstance(alpha, float):
assert transform.alpha == [alpha, alpha]
else:
assert transform.alpha == alpha

if isinstance(sigma, float):
assert transform.sigma == [sigma, sigma]
else:
assert transform.sigma == sigma

fn = mocker.patch("torchvision.prototype.transforms.functional.elastic")
inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3
inpt.image_size = (24, 32)

# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params(inpt)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
106 changes: 98 additions & 8 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def make_images(
yield make_image(size, color_space=color_space, dtype=dtype)

for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)


def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
Expand Down Expand Up @@ -149,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype


def make_segmentation_masks(
image_sizes=((16, 16), (7, 33), (31, 9)),
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.long,),
extra_dims=((), (4,), (2, 3)),
):
for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_)
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)


class SampleInput:
Expand Down Expand Up @@ -533,6 +533,40 @@ def perspective_segmentation_mask():
)


@register_kernel_info_from_sample_inputs_fn
def elastic_image_tensor():
for image, fill in itertools.product(
make_images(extra_dims=((), (4,))),
[None, [128], [12.0]], # fill
):
h, w = image.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(image, displacement=displacement, fill=fill)


@register_kernel_info_from_sample_inputs_fn
def elastic_bounding_box():
for bounding_box in make_bounding_boxes():
h, w = bounding_box.image_size
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
bounding_box,
format=bounding_box.format,
displacement=displacement,
)


@register_kernel_info_from_sample_inputs_fn
def elastic_segmentation_mask():
for mask in make_segmentation_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
mask,
displacement=displacement,
)


@register_kernel_info_from_sample_inputs_fn
def center_crop_image_tensor():
for mask, output_size in itertools.product(
Expand All @@ -553,7 +587,7 @@ def center_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def center_crop_segmentation_mask():
for mask, output_size in itertools.product(
make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield SampleInput(mask, output_size)
Expand Down Expand Up @@ -654,10 +688,20 @@ def test_scriptable(kernel):
feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
)
and name
not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"}
not in {
"to_image_tensor",
"InterpolationMode",
"decode_video_with_av",
"crop",
"rotate",
"perspective",
"elastic_transform",
"elastic",
}
# We skip 'crop' due to missing 'height' and 'width'
# We skip 'rotate' due to non implemented yet expand=True case for bboxes
# We skip 'perspective' as it requires different input args than perspective_image_tensor etc
# Skip 'elastic', TODO: inspect why test is failing
],
)
def test_functional_mid_level(func):
Expand All @@ -670,7 +714,9 @@ def test_functional_mid_level(func):
if key in kwargs:
del kwargs[key]
output = func(*sample_input.args, **kwargs)
torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}")
torch.testing.assert_close(
output, expected, msg=f"finfo={finfo.name}, output={output}, expected={expected}"
)
break


Expand Down Expand Up @@ -1739,5 +1785,49 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
)

out = fn(tensor, kernel_size=ksize, sigma=sigma)
image = features.Image(tensor)

out = fn(image, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
)
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
in_box = [10, 15, 25, 35]
for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
c, h, w = sample.shape[-3:]
# Setup a dummy image with 4 points
sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
sample = sample.to(device)

if fn == F.elastic_image_tensor:
sample = features.Image(sample)
kwargs = {"interpolation": F.InterpolationMode.NEAREST}
else:
sample = features.SegmentationMask(sample)
kwargs = {}

# Create a displacement grid using sin
n, m = 5.0, 0.1
d1 = m * torch.sin(torch.arange(h, dtype=torch.float) * torch.pi * n / h)
d2 = m * torch.sin(torch.arange(w, dtype=torch.float) * torch.pi * n / w)

d1 = d1[:, None].expand((h, w))
d2 = d2[None, :].expand((h, w))

displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1)
displacement = displacement.reshape(1, h, w, 2)

output = fn(sample, displacement=displacement, **kwargs)

# Check places where transformed points should be
torch.testing.assert_close(output[..., 12, 9], sample[..., in_box[1], in_box[0]])
torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1])
torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]])
torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1])
11 changes: 11 additions & 0 deletions torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,14 @@ def perspective(

output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F

output = _F.elastic_bounding_box(self, self.format, displacement)
return BoundingBox.new_like(self, output, dtype=output.dtype)
8 changes: 8 additions & 0 deletions torchvision/prototype/features/_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ def perspective(
) -> Any:
return self

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Any:
return self

def adjust_brightness(self, brightness_factor: float) -> Any:
return self

Expand Down
13 changes: 13 additions & 0 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,19 @@ def perspective(
output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image:
from torchvision.prototype.transforms.functional import _geometry as _F

fill = _F._convert_fill_arg(fill)

output = _F.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)

def adjust_brightness(self, brightness_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F

Expand Down
12 changes: 12 additions & 0 deletions torchvision/prototype/features/_segmentation_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import List, Optional, Union, Sequence

import torch
from torchvision.transforms import InterpolationMode

from ._feature import _Feature
Expand Down Expand Up @@ -119,3 +120,14 @@ def perspective(

output = _F.perspective_segmentation_mask(self, perspective_coeffs)
return SegmentationMask.new_like(self, output)

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F

output = _F.elastic_segmentation_mask(self, displacement)
return SegmentationMask.new_like(self, output, dtype=output.dtype)
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
RandomRotation,
RandomAffine,
RandomPerspective,
ElasticTransform,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda
from ._type_conversion import DecodeImage, LabelToOneHot

from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip

# TODO: add RandomPerspective, ElasticTransform
Loading