diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py new file mode 100644 index 00000000000..e19cfed0cd7 --- /dev/null +++ b/test/prototype_common_utils.py @@ -0,0 +1,199 @@ +import functools +import itertools + +import PIL.Image +import pytest + +import torch +import torch.testing +from torch.nn.functional import one_hot +from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair +from torchvision.prototype import features +from torchvision.prototype.transforms.functional import to_image_tensor +from torchvision.transforms.functional_tensor import _max_value as get_max_value + + +class ImagePair(TensorLikePair): + def _process_inputs(self, actual, expected, *, id, allow_subclasses): + return super()._process_inputs( + *[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]], + id=id, + allow_subclasses=allow_subclasses, + ) + + +assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0) + + +class ArgsKwargs: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def __iter__(self): + yield self.args + yield self.kwargs + + def __str__(self): + def short_repr(obj, max=20): + repr_ = repr(obj) + if len(repr_) <= max: + return repr_ + + return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}" + + return ", ".join( + itertools.chain( + [short_repr(arg) for arg in self.args], + [f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()], + ) + ) + + +make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") + + +def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True): + size = size or torch.randint(16, 33, (2,)).tolist() + + try: + num_channels = { + features.ColorSpace.GRAY: 1, + features.ColorSpace.GRAY_ALPHA: 2, + features.ColorSpace.RGB: 3, + features.ColorSpace.RGB_ALPHA: 4, + }[color_space] + except KeyError as error: + raise pytest.UsageError() from error + + shape = (*extra_dims, num_channels, *size) + max_value = get_max_value(dtype) + data = make_tensor(shape, low=0, high=max_value, dtype=dtype) + if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha: + data[..., -1, :, :] = max_value + return features.Image(data, color_space=color_space) + + +make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY) +make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB) + + +def make_images( + sizes=((16, 16), (7, 33), (31, 9)), + color_spaces=( + features.ColorSpace.GRAY, + features.ColorSpace.GRAY_ALPHA, + features.ColorSpace.RGB, + features.ColorSpace.RGB_ALPHA, + ), + dtypes=(torch.float32, torch.uint8), + extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)), +): + for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes): + yield make_image(size, color_space=color_space, dtype=dtype) + + for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims): + yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype) + + +def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): + low, high = torch.broadcast_tensors( + *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))] + ) + return torch.stack( + [ + torch.randint(low_scalar, high_scalar, (), **kwargs) + for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist()) + ] + ).reshape(low.shape) + + +def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64): + if isinstance(format, str): + format = features.BoundingBoxFormat[format] + + if any(dim == 0 for dim in extra_dims): + return features.BoundingBox(torch.empty(*extra_dims, 4), format=format, image_size=image_size) + + height, width = image_size + + if format == features.BoundingBoxFormat.XYXY: + x1 = torch.randint(0, width // 2, extra_dims) + y1 = torch.randint(0, height // 2, extra_dims) + x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 + y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 + parts = (x1, y1, x2, y2) + elif format == features.BoundingBoxFormat.XYWH: + x = torch.randint(0, width // 2, extra_dims) + y = torch.randint(0, height // 2, extra_dims) + w = randint_with_tensor_bounds(1, width - x) + h = randint_with_tensor_bounds(1, height - y) + parts = (x, y, w, h) + elif format == features.BoundingBoxFormat.CXCYWH: + cx = torch.randint(1, width - 1, ()) + cy = torch.randint(1, height - 1, ()) + w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) + h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) + parts = (cx, cy, w, h) + else: + raise pytest.UsageError() + + return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size) + + +make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY) + + +def make_bounding_boxes( + formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH), + image_sizes=((32, 32),), + dtypes=(torch.int64, torch.float32), + extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)), +): + for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes): + yield make_bounding_box(format=format, image_size=image_size, dtype=dtype) + + for format, extra_dims_ in itertools.product(formats, extra_dims): + yield make_bounding_box(format=format, extra_dims=extra_dims_) + + +def make_label(size=(), *, categories=("category0", "category1")): + return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories) + + +def make_one_hot_label(*args, **kwargs): + label = make_label(*args, **kwargs) + return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories) + + +def make_one_hot_labels( + *, + num_categories=(1, 2, 10), + extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)), +): + for num_categories_ in num_categories: + yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)]) + + for extra_dims_ in extra_dims: + yield make_one_hot_label(extra_dims_) + + +def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8): + size = size if size is not None else torch.randint(16, 33, (2,)).tolist() + num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ())) + shape = (*extra_dims, num_objects, *size) + data = make_tensor(shape, low=0, high=2, dtype=dtype) + return features.SegmentationMask(data) + + +def make_segmentation_masks( + sizes=((16, 16), (7, 33), (31, 9)), + dtypes=(torch.uint8,), + extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)), + num_objects=(1, 0, 10), +): + for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims): + yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_) + + for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects): + yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index bf62f966750..a5e96134c5a 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -7,7 +7,7 @@ import pytest import torch from common_utils import assert_equal, cpu_and_gpu -from test_prototype_transforms_functional import ( +from prototype_common_utils import ( make_bounding_box, make_bounding_boxes, make_image, diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index bb681f02d1e..2bb98002e12 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -1,61 +1,21 @@ import enum -import functools import inspect -import itertools import numpy as np import PIL.Image import pytest import torch -from test_prototype_transforms_functional import make_images -from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair +from prototype_common_utils import ArgsKwargs, assert_equal, make_images from torchvision import transforms as legacy_transforms from torchvision._utils import sequence_to_str from torchvision.prototype import features, transforms as prototype_transforms -from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor - - -class ImagePair(TensorLikePair): - def _process_inputs(self, actual, expected, *, id, allow_subclasses): - return super()._process_inputs( - *[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]], - id=id, - allow_subclasses=allow_subclasses, - ) - - -assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0) +from torchvision.prototype.transforms.functional import to_image_pil DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) -class ArgsKwargs: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - def __iter__(self): - yield self.args - yield self.kwargs - - def __str__(self): - def short_repr(obj, max=20): - repr_ = repr(obj) - if len(repr_) <= max: - return repr_ - - return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}" - - return ", ".join( - itertools.chain( - [short_repr(arg) for arg in self.args], - [f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()], - ) - ) - - class ConsistencyConfig: def __init__( self, diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index e56e9b3da77..499dee008bd 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,4 +1,3 @@ -import functools import itertools import math import os @@ -9,167 +8,12 @@ import torch.testing import torchvision.prototype.transforms.functional as F from common_utils import cpu_and_gpu +from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks from torch import jit -from torch.nn.functional import one_hot from torchvision.prototype import features from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format from torchvision.transforms.functional import _get_perspective_coeffs -from torchvision.transforms.functional_tensor import _max_value as get_max_value - -make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") - - -def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True): - size = size or torch.randint(16, 33, (2,)).tolist() - - try: - num_channels = { - features.ColorSpace.GRAY: 1, - features.ColorSpace.GRAY_ALPHA: 2, - features.ColorSpace.RGB: 3, - features.ColorSpace.RGB_ALPHA: 4, - }[color_space] - except KeyError as error: - raise pytest.UsageError() from error - - shape = (*extra_dims, num_channels, *size) - max_value = get_max_value(dtype) - data = make_tensor(shape, low=0, high=max_value, dtype=dtype) - if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha: - data[..., -1, :, :] = max_value - return features.Image(data, color_space=color_space) - - -make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY) -make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB) - - -def make_images( - sizes=((16, 16), (7, 33), (31, 9)), - color_spaces=( - features.ColorSpace.GRAY, - features.ColorSpace.GRAY_ALPHA, - features.ColorSpace.RGB, - features.ColorSpace.RGB_ALPHA, - ), - dtypes=(torch.float32, torch.uint8), - extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)), -): - for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes): - yield make_image(size, color_space=color_space, dtype=dtype) - - for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims): - yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype) - - -def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): - low, high = torch.broadcast_tensors( - *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))] - ) - return torch.stack( - [ - torch.randint(low_scalar, high_scalar, (), **kwargs) - for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist()) - ] - ).reshape(low.shape) - - -def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64): - if isinstance(format, str): - format = features.BoundingBoxFormat[format] - - if any(dim == 0 for dim in extra_dims): - return features.BoundingBox(torch.empty(*extra_dims, 4), format=format, image_size=image_size) - - height, width = image_size - - if format == features.BoundingBoxFormat.XYXY: - x1 = torch.randint(0, width // 2, extra_dims) - y1 = torch.randint(0, height // 2, extra_dims) - x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 - y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 - parts = (x1, y1, x2, y2) - elif format == features.BoundingBoxFormat.XYWH: - x = torch.randint(0, width // 2, extra_dims) - y = torch.randint(0, height // 2, extra_dims) - w = randint_with_tensor_bounds(1, width - x) - h = randint_with_tensor_bounds(1, height - y) - parts = (x, y, w, h) - elif format == features.BoundingBoxFormat.CXCYWH: - cx = torch.randint(1, width - 1, ()) - cy = torch.randint(1, height - 1, ()) - w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) - h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) - parts = (cx, cy, w, h) - else: - raise pytest.UsageError() - - return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size) - - -make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY) - - -def make_bounding_boxes( - formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH), - image_sizes=((32, 32),), - dtypes=(torch.int64, torch.float32), - extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)), -): - for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes): - yield make_bounding_box(format=format, image_size=image_size, dtype=dtype) - - for format, extra_dims_ in itertools.product(formats, extra_dims): - yield make_bounding_box(format=format, extra_dims=extra_dims_) - - -def make_label(size=(), *, categories=("category0", "category1")): - return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories) - - -def make_one_hot_label(*args, **kwargs): - label = make_label(*args, **kwargs) - return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories) - - -def make_one_hot_labels( - *, - num_categories=(1, 2, 10), - extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)), -): - for num_categories_ in num_categories: - yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)]) - - for extra_dims_ in extra_dims: - yield make_one_hot_label(extra_dims_) - - -def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8): - size = size if size is not None else torch.randint(16, 33, (2,)).tolist() - num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ())) - shape = (*extra_dims, num_objects, *size) - data = make_tensor(shape, low=0, high=2, dtype=dtype) - return features.SegmentationMask(data) - - -def make_segmentation_masks( - sizes=((16, 16), (7, 33), (31, 9)), - dtypes=(torch.uint8,), - extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)), - num_objects=(1, 0, 10), -): - for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims): - yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_) - - for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects): - yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_) - - -class SampleInput: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs class FunctionalInfo: @@ -182,7 +26,7 @@ def sample_inputs(self): yield from self._sample_inputs_fn() def __call__(self, *args, **kwargs): - if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput): + if len(args) == 1 and not kwargs and isinstance(args[0], ArgsKwargs): sample_input = args[0] return self.functional(*sample_input.args, **sample_input.kwargs) @@ -200,37 +44,37 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): @register_kernel_info_from_sample_inputs_fn def horizontal_flip_image_tensor(): for image in make_images(): - yield SampleInput(image) + yield ArgsKwargs(image) @register_kernel_info_from_sample_inputs_fn def horizontal_flip_bounding_box(): for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): - yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) + yield ArgsKwargs(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) @register_kernel_info_from_sample_inputs_fn def horizontal_flip_segmentation_mask(): for mask in make_segmentation_masks(): - yield SampleInput(mask) + yield ArgsKwargs(mask) @register_kernel_info_from_sample_inputs_fn def vertical_flip_image_tensor(): for image in make_images(): - yield SampleInput(image) + yield ArgsKwargs(image) @register_kernel_info_from_sample_inputs_fn def vertical_flip_bounding_box(): for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): - yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) + yield ArgsKwargs(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) @register_kernel_info_from_sample_inputs_fn def vertical_flip_segmentation_mask(): for mask in make_segmentation_masks(): - yield SampleInput(mask) + yield ArgsKwargs(mask) @register_kernel_info_from_sample_inputs_fn @@ -252,7 +96,7 @@ def resize_image_tensor(): ]: if max_size is not None: size = [size[0]] - yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + yield ArgsKwargs(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) @register_kernel_info_from_sample_inputs_fn @@ -268,7 +112,7 @@ def resize_bounding_box(): ]: if max_size is not None: size = [size[0]] - yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) + yield ArgsKwargs(bounding_box, size=size, image_size=bounding_box.image_size) @register_kernel_info_from_sample_inputs_fn @@ -284,7 +128,7 @@ def resize_segmentation_mask(): ]: if max_size is not None: size = [size[0]] - yield SampleInput(mask, size=size, max_size=max_size) + yield ArgsKwargs(mask, size=size, max_size=max_size) @register_kernel_info_from_sample_inputs_fn @@ -296,7 +140,7 @@ def affine_image_tensor(): [0.77, 1.27], # scale [0, 12], # shear ): - yield SampleInput( + yield ArgsKwargs( image, angle=angle, translate=(translate, translate), @@ -315,7 +159,7 @@ def affine_bounding_box(): [0.77, 1.27], # scale [0, 12], # shear ): - yield SampleInput( + yield ArgsKwargs( bounding_box, format=bounding_box.format, image_size=bounding_box.image_size, @@ -335,7 +179,7 @@ def affine_segmentation_mask(): [0.77, 1.27], # scale [0, 12], # shear ): - yield SampleInput( + yield ArgsKwargs( mask, angle=angle, translate=(translate, translate), @@ -357,7 +201,7 @@ def rotate_image_tensor(): # Skip warning: The provided center argument is ignored if expand is True continue - yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill) + yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill) @register_kernel_info_from_sample_inputs_fn @@ -369,7 +213,7 @@ def rotate_bounding_box(): # Skip warning: The provided center argument is ignored if expand is True continue - yield SampleInput( + yield ArgsKwargs( bounding_box, format=bounding_box.format, image_size=bounding_box.image_size, @@ -391,7 +235,7 @@ def rotate_segmentation_mask(): # Skip warning: The provided center argument is ignored if expand is True continue - yield SampleInput( + yield ArgsKwargs( mask, angle=angle, expand=expand, @@ -402,7 +246,7 @@ def rotate_segmentation_mask(): @register_kernel_info_from_sample_inputs_fn def crop_image_tensor(): for image, top, left, height, width in itertools.product(make_images(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]): - yield SampleInput( + yield ArgsKwargs( image, top=top, left=left, @@ -414,7 +258,7 @@ def crop_image_tensor(): @register_kernel_info_from_sample_inputs_fn def crop_bounding_box(): for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]): - yield SampleInput( + yield ArgsKwargs( bounding_box, format=bounding_box.format, top=top, @@ -427,7 +271,7 @@ def crop_segmentation_mask(): for mask, top, left, height, width in itertools.product( make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20] ): - yield SampleInput( + yield ArgsKwargs( mask, top=top, left=left, @@ -447,7 +291,7 @@ def resized_crop_image_tensor(): [(16, 18)], [True, False], ): - yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias) + yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias) @register_kernel_info_from_sample_inputs_fn @@ -455,7 +299,7 @@ def resized_crop_bounding_box(): for bounding_box, top, left, height, width, size in itertools.product( make_bounding_boxes(), [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)] ): - yield SampleInput( + yield ArgsKwargs( bounding_box, format=bounding_box.format, top=top, left=left, height=height, width=width, size=size ) @@ -465,7 +309,7 @@ def resized_crop_segmentation_mask(): for mask, top, left, height, width, size in itertools.product( make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)] ): - yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size) + yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size) @register_kernel_info_from_sample_inputs_fn @@ -476,7 +320,7 @@ def pad_image_tensor(): [None, 12, 12.0], # fill ["constant", "symmetric", "edge", "reflect"], # padding mode, ): - yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode) + yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode) @register_kernel_info_from_sample_inputs_fn @@ -486,7 +330,7 @@ def pad_segmentation_mask(): [[1], [1, 1], [1, 1, 2, 2]], # padding ["constant", "symmetric", "edge", "reflect"], # padding mode, ): - yield SampleInput(mask, padding=padding, padding_mode=padding_mode) + yield ArgsKwargs(mask, padding=padding, padding_mode=padding_mode) @register_kernel_info_from_sample_inputs_fn @@ -495,7 +339,7 @@ def pad_bounding_box(): make_bounding_boxes(), [[1], [1, 1], [1, 1, 2, 2]], ): - yield SampleInput(bounding_box, padding=padding, format=bounding_box.format) + yield ArgsKwargs(bounding_box, padding=padding, format=bounding_box.format) @register_kernel_info_from_sample_inputs_fn @@ -508,7 +352,7 @@ def perspective_image_tensor(): ], [None, [128], [12.0]], # fill ): - yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill) + yield ArgsKwargs(image, perspective_coeffs=perspective_coeffs, fill=fill) @register_kernel_info_from_sample_inputs_fn @@ -520,7 +364,7 @@ def perspective_bounding_box(): [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], ], ): - yield SampleInput( + yield ArgsKwargs( bounding_box, format=bounding_box.format, perspective_coeffs=perspective_coeffs, @@ -536,7 +380,7 @@ def perspective_segmentation_mask(): [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], ], ): - yield SampleInput( + yield ArgsKwargs( mask, perspective_coeffs=perspective_coeffs, ) @@ -550,7 +394,7 @@ def elastic_image_tensor(): ): h, w = image.shape[-2:] displacement = torch.rand(1, h, w, 2) - yield SampleInput(image, displacement=displacement, fill=fill) + yield ArgsKwargs(image, displacement=displacement, fill=fill) @register_kernel_info_from_sample_inputs_fn @@ -558,7 +402,7 @@ def elastic_bounding_box(): for bounding_box in make_bounding_boxes(): h, w = bounding_box.image_size displacement = torch.rand(1, h, w, 2) - yield SampleInput( + yield ArgsKwargs( bounding_box, format=bounding_box.format, displacement=displacement, @@ -570,7 +414,7 @@ def elastic_segmentation_mask(): for mask in make_segmentation_masks(extra_dims=((), (4,))): h, w = mask.shape[-2:] displacement = torch.rand(1, h, w, 2) - yield SampleInput( + yield ArgsKwargs( mask, displacement=displacement, ) @@ -582,13 +426,13 @@ def center_crop_image_tensor(): make_images(sizes=((16, 16), (7, 33), (31, 9))), [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size ): - yield SampleInput(mask, output_size) + yield ArgsKwargs(mask, output_size) @register_kernel_info_from_sample_inputs_fn def center_crop_bounding_box(): for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]): - yield SampleInput( + yield ArgsKwargs( bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size ) @@ -599,7 +443,7 @@ def center_crop_segmentation_mask(): make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))), [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size ): - yield SampleInput(mask, output_size) + yield ArgsKwargs(mask, output_size) @register_kernel_info_from_sample_inputs_fn @@ -609,7 +453,7 @@ def gaussian_blur_image_tensor(): [[3, 3]], [None, [3.0, 3.0]], ): - yield SampleInput(image, kernel_size=kernel_size, sigma=sigma) + yield ArgsKwargs(image, kernel_size=kernel_size, sigma=sigma) @register_kernel_info_from_sample_inputs_fn @@ -617,13 +461,13 @@ def equalize_image_tensor(): for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): if image.dtype != torch.uint8: continue - yield SampleInput(image) + yield ArgsKwargs(image) @register_kernel_info_from_sample_inputs_fn def invert_image_tensor(): for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): - yield SampleInput(image) + yield ArgsKwargs(image) @register_kernel_info_from_sample_inputs_fn @@ -634,7 +478,7 @@ def posterize_image_tensor(): ): if image.dtype != torch.uint8: continue - yield SampleInput(image, bits=bits) + yield ArgsKwargs(image, bits=bits) @register_kernel_info_from_sample_inputs_fn @@ -645,13 +489,13 @@ def solarize_image_tensor(): ): if image.is_floating_point() and threshold > 1.0: continue - yield SampleInput(image, threshold=threshold) + yield ArgsKwargs(image, threshold=threshold) @register_kernel_info_from_sample_inputs_fn def autocontrast_image_tensor(): for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): - yield SampleInput(image) + yield ArgsKwargs(image) @register_kernel_info_from_sample_inputs_fn @@ -660,14 +504,14 @@ def adjust_sharpness_image_tensor(): make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), [0.1, 0.5], ): - yield SampleInput(image, sharpness_factor=sharpness_factor) + yield ArgsKwargs(image, sharpness_factor=sharpness_factor) @register_kernel_info_from_sample_inputs_fn def erase_image_tensor(): for image in make_images(): c = image.shape[-3] - yield SampleInput(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7)) + yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7)) @pytest.mark.parametrize( diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index a656743db26..ed6f7ed6bc7 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -3,7 +3,7 @@ import torch -from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask +from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask from torchvision.prototype import features from torchvision.prototype.transforms._utils import has_all, has_any