Skip to content
197 changes: 197 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import functools
import itertools

import pytest
import torch.testing
import torchvision.prototype.transforms.functional as F
from torch import jit
from torchvision.prototype import features

make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
size = size or torch.randint(16, 33, (2,)).tolist()

if isinstance(color_space, str):
color_space = features.ColorSpace[color_space]
num_channels = {
features.ColorSpace.GRAYSCALE: 1,
features.ColorSpace.RGB: 3,
}[color_space]

shape = (*extra_dims, num_channels, *size)
if dtype.is_floating_point:
data = torch.rand(shape, dtype=dtype)
else:
data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
return features.Image(data, color_space=color_space)


make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)


def make_images(
sizes=((16, 16), (7, 33), (31, 9)),
color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
dtypes=(torch.float32, torch.uint8),
extra_dims=((4,), (2, 3)),
):
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
yield make_image(size, color_space=color_space)

for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_)


def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
low, high = torch.broadcast_tensors(
*[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
)
try:
return torch.stack(
[
torch.randint(low_scalar, high_scalar, (), **kwargs)
for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
]
).reshape(low.shape)
except RuntimeError as error:
raise error


def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]

height, width = image_size

if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, extra_dims)
y1 = torch.randint(0, height // 2, extra_dims)
x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, extra_dims)
y = torch.randint(0, height // 2, extra_dims)
w = randint_with_tensor_bounds(1, width - x)
h = randint_with_tensor_bounds(1, height - y)
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
parts = (cx, cy, w, h)
else: # format == features.BoundingBoxFormat._SENTINEL:
raise ValueError()

return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)


make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)


def make_bounding_boxes(
formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
image_sizes=((32, 32),),
dtypes=(torch.int64, torch.float32),
extra_dims=((4,), (2, 3)),
):
for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)

for format, extra_dims_ in itertools.product(formats, extra_dims):
yield make_bounding_box(format=format, extra_dims=extra_dims_)


class SampleInput:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why we need this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other option would be for a simple inputs function to return ((...), dict(...)) to bundle args and kwargs. IMO this is not as convenient as having a single structure to hold everything. To me

yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)

is more readable than

yield (bounding_box,), dict(old_image_size=bounding_box.image_size, new_image_size=new_image_size)

given that it resembles the actual call signature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Plus, I was looking into lazily loading the samples. This is not implemented yet, so we are generating all samples at test collection time. This can become an issue real quick if we go along this automated tests direction.

def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs


class KernelInfo:
def __init__(self, name, *, sample_inputs_fn):
self.name = name
self.kernel = getattr(F, name)
self._sample_inputs_fn = sample_inputs_fn

def sample_inputs(self):
yield from self._sample_inputs_fn()

def __call__(self, *args, **kwargs):
if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
sample_input = args[0]
return self.kernel(*sample_input.args, **sample_input.kwargs)

return self.kernel(*args, **kwargs)


KERNEL_INFOS = []


def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
return sample_inputs_fn


@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_image():
for image in make_images():
yield SampleInput(image)


@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield SampleInput(bounding_box, image_size=bounding_box.image_size)


@register_kernel_info_from_sample_inputs_fn
def resize_image():
for image, interpolation in itertools.product(
make_images(),
[
F.InterpolationMode.BILINEAR,
F.InterpolationMode.NEAREST,
],
):
height, width = image.shape[-2:]
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
yield SampleInput(image, size=size, interpolation=interpolation)


@register_kernel_info_from_sample_inputs_fn
def resize_bounding_box():
for bounding_box in make_bounding_boxes():
height, width = bounding_box.image_size
for new_image_size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)


class TestKernelsCommon:
@pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
def test_scriptable(self, kernel_info):
jit.script(kernel_info.kernel)

@pytest.mark.parametrize(
("kernel_info", "sample_input"),
[
pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
for kernel_info in KERNEL_INFOS
for idx, sample_input in enumerate(kernel_info.sample_inputs())
Comment on lines +189 to +190
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of the 2 manual for loops could we simply use 2 @parametrize statements?

Also, is there stark difference between passing idx instead of relying on pytest's default ids?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of the 2 manual for loops could we simply use 2 @parametrize statements?

We can't use two separate parametrizations, since the inner loop depends on the outer.

Also, is there stark difference between passing idx instead of relying on pytest's default ids?

Let's look into this more after #5295 (comment).

],
)
def test_eager_vs_scripted(self, kernel_info, sample_input):
eager = kernel_info(sample_input)
scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)

torch.testing.assert_close(eager, scripted)
11 changes: 11 additions & 0 deletions torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,14 @@ def __new__(
bounding_box._metadata.update(dict(format=format, image_size=image_size))

return bounding_box

def to_format(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.functional import convert_bounding_box_format
Comment on lines +41 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, this is a sign to redesign the structure and put this method here or in utils module ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a design choice to have all transforming functions in torchvision.prototype.transforms.functional. This is not implemented yet, but for the automatic dispatch to work, we need to depend on torchvision.prototype.features. Thus, if we want to honor the design choice, we can't get around the cyclic imports if we want to provide these convenience conversion methods.

If we relax the design choice to all transforming functions need to be present in torchvision.prototype.transforms.functional, but can be imported from somewhere else, I fully agree, and this should be refactored.


if isinstance(format, str):
format = BoundingBoxFormat[format]

return BoundingBox.new_like(
self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
)
7 changes: 7 additions & 0 deletions torchvision/prototype/features/_encoded.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer

from ._feature import Feature
from ._image import Image

D = TypeVar("D", bound="EncodedData")

Expand Down Expand Up @@ -37,6 +38,12 @@ def image_size(self) -> Tuple[int, int]:

return self._image_size

def decode(self) -> Image:
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.functional import decode_image_with_pil

return Image(decode_image_with_pil(self))


class EncodedVideo(EncodedData):
pass
5 changes: 5 additions & 0 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import torch
from torchvision.prototype.utils._internal import StrEnum
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import make_grid

from ._bounding_box import BoundingBox
from ._feature import Feature


Expand Down Expand Up @@ -76,3 +78,6 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:

def show(self) -> None:
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()

def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> "Image":
return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
4 changes: 3 additions & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from . import functional
from .functional import InterpolationMode # usort: skip

from ._transform import Transform
from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip

from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
from ._misc import Identity, Normalize
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
27 changes: 27 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label
from ._color import (
adjust_brightness_image,
adjust_contrast_image,
adjust_saturation_image,
adjust_sharpness_image,
posterize_image,
solarize_image,
autocontrast_image,
equalize_image,
invert_image,
)
from ._geometry import (
horizontal_flip_bounding_box,
horizontal_flip_image,
resize_bounding_box,
resize_image,
resize_segmentation_mask,
center_crop_image,
resized_crop_image,
InterpolationMode,
affine_image,
rotate_image,
)
from ._meta_conversion import convert_color_space, convert_bounding_box_format
from ._misc import normalize_image
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
40 changes: 40 additions & 0 deletions torchvision/prototype/transforms/functional/_augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Tuple

import torch
from torchvision.transforms import functional as _F


erase_image = _F.erase


def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor:
if not inplace:
input = input.clone()

input_rolled = input.roll(1, batch_dim)
return input.mul_(lam).add_(input_rolled.mul_(1 - lam))


def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
return _mixup(image_batch, -4, lam, inplace)


def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
return _mixup(one_hot_label_batch, -2, lam, inplace)


def cutmix_image(image: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor:
if not inplace:
image = image.clone()

x1, y1, x2, y2 = box
image_rolled = image.roll(1, -4)

image[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return image


def cutmix_one_hot_label(
one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False
) -> torch.Tensor:
return mixup_one_hot_label(one_hot_label_batch, lam=lam_adjusted, inplace=inplace)
20 changes: 20 additions & 0 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torchvision.transforms import functional as _F


adjust_brightness_image = _F.adjust_brightness

adjust_saturation_image = _F.adjust_saturation

adjust_contrast_image = _F.adjust_contrast

adjust_sharpness_image = _F.adjust_sharpness

posterize_image = _F.posterize

solarize_image = _F.solarize

autocontrast_image = _F.autocontrast

equalize_image = _F.equalize

invert_image = _F.invert
Loading