Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_mixup_cutmix(self, transform, input):
transform(input_copy)

# Check if we raise an error if sample contains bbox or mask or label
err_msg = "does not support bounding boxes, masks and plain labels"
err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
input_copy = dict(input)
for unsup_data in [
make_label(),
Expand Down
13 changes: 11 additions & 2 deletions torchvision/prototype/features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature, DType, is_simple_tensor
from ._image import ColorSpace, Image, ImageType
from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor
from ._image import (
ColorSpace,
Image,
ImageType,
ImageTypeJIT,
LegacyImageType,
LegacyImageTypeJIT,
TensorImageType,
TensorImageTypeJIT,
)
from ._label import Label, OneHotLabel
from ._mask import Mask
12 changes: 6 additions & 6 deletions torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms

from ._feature import _Feature
from ._feature import _Feature, FillTypeJIT


class BoundingBoxFormat(StrEnum):
Expand Down Expand Up @@ -115,7 +115,7 @@ def resized_crop(
def pad(
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> BoundingBox:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
Expand All @@ -137,7 +137,7 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.rotate_bounding_box(
Expand Down Expand Up @@ -165,7 +165,7 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.affine_bounding_box(
Expand All @@ -184,7 +184,7 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)
Expand All @@ -193,7 +193,7 @@ def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(self, self.format, displacement)
return BoundingBox.new_like(self, output, dtype=output.dtype)
22 changes: 12 additions & 10 deletions torchvision/prototype/features/_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from types import ModuleType
from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch
from torch._C import _TensorBase, DisableTorchFunction
from torchvision.transforms import InterpolationMode

F = TypeVar("F", bound="_Feature")


# Due to torch.jit.script limitation we keep DType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature]
DType = torch.Tensor
# Fill value forms: a scalar, a sequence of ints or floats, or None.
FillType = Union[int, float, Sequence[int], Sequence[float], None]
# Narrower variant used to annotate the `fill` parameter of the feature methods
# (pad, rotate, affine, perspective, elastic).
# NOTE(review): presumably limited to List[float] because torch.jit.script cannot
# handle Sequence annotations -- confirm.
FillTypeJIT = Union[int, float, List[float], None]
Comment on lines +12 to +13
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could move to another file if we don't want it here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is OK there. I would prefer to have everything in a _typing.py file, but as you explained, this would give us circular imports due to the method signatures. Thus, let's take the path of least resistance.



def is_simple_tensor(inpt: Any) -> bool:
Expand Down Expand Up @@ -154,7 +152,7 @@ def resized_crop(
def pad(
self,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> _Feature:
return self
Expand All @@ -164,7 +162,7 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> _Feature:
return self
Expand All @@ -176,7 +174,7 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> _Feature:
return self
Expand All @@ -185,15 +183,15 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> _Feature:
return self

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> _Feature:
return self

Expand Down Expand Up @@ -232,3 +230,7 @@ def invert(self) -> _Feature:

def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature:
return self


# Union of the supported input kinds: plain tensor, PIL image, or _Feature.
# The alias references the _Feature class itself, so it has to come after the class body.
InputType = Union[torch.Tensor, PIL.Image.Image, _Feature]
# Due to a torch.jit.script limitation the JIT variant is kept as torch.Tensor
# instead of the full Union above.
InputTypeJIT = torch.Tensor
Comment on lines +235 to +236
Copy link
Contributor Author

@datumbox datumbox Sep 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously DType.

26 changes: 15 additions & 11 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@
import warnings
from typing import Any, cast, List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode, to_pil_image
from torchvision.utils import draw_bounding_boxes, make_grid

from ._bounding_box import BoundingBox
from ._feature import _Feature


# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
ImageType = torch.Tensor
from ._feature import _Feature, FillTypeJIT


class ColorSpace(StrEnum):
Expand Down Expand Up @@ -181,7 +177,7 @@ def resized_crop(
def pad(
self,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
Expand All @@ -192,7 +188,7 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.rotate_image_tensor(
Expand All @@ -207,7 +203,7 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.affine_image_tensor(
Expand All @@ -226,7 +222,7 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.perspective_image_tensor(
self, perspective_coeffs, interpolation=interpolation, fill=fill
Expand All @@ -237,7 +233,7 @@ def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)
Expand Down Expand Up @@ -289,3 +285,11 @@ def invert(self) -> Image:
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
return Image.new_like(self, output)


# Due to a torch.jit.script limitation, every *JIT alias below is kept as plain
# torch.Tensor instead of the corresponding Union.

# Image inputs: plain tensor, PIL image, or the Image feature class.
ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
ImageTypeJIT = torch.Tensor
# Plain tensor or PIL image, without the Image feature class.
LegacyImageType = Union[torch.Tensor, PIL.Image.Image]
LegacyImageTypeJIT = torch.Tensor
# Tensor-backed images only: plain tensor or the Image feature class (no PIL).
TensorImageType = Union[torch.Tensor, Image]
TensorImageTypeJIT = torch.Tensor
12 changes: 6 additions & 6 deletions torchvision/prototype/features/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torchvision.transforms import InterpolationMode

from ._feature import _Feature
from ._feature import _Feature, FillTypeJIT


class Mask(_Feature):
Expand Down Expand Up @@ -51,7 +51,7 @@ def resized_crop(
def pad(
self,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
Expand All @@ -62,7 +62,7 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
Expand All @@ -75,7 +75,7 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.affine_mask(
Expand All @@ -93,7 +93,7 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
return Mask.new_like(self, output)
Expand All @@ -102,7 +102,7 @@ def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self, displacement, fill=fill)
return Mask.new_like(self, output, dtype=output.dtype)
29 changes: 17 additions & 12 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import numbers
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Tuple

import PIL.Image
import torch
Expand Down Expand Up @@ -92,9 +92,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

return dict(i=i, j=j, h=h, w=w, v=v)

def _transform(
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)

Expand All @@ -110,8 +108,10 @@ def __init__(self, alpha: float, p: float = 0.5) -> None:
def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if has_any(inputs, features.BoundingBox, features.Mask, features.Label):
raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.")
if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
)
Comment on lines +111 to +114
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we previously missed PIL images which are not supported by this class.

return super().forward(*inputs)

def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
Expand Down Expand Up @@ -203,15 +203,15 @@ def __init__(

def _copy_paste(
self,
image: Any,
image: features.TensorImageType,
target: Dict[str, Any],
paste_image: Any,
paste_image: features.TensorImageType,
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[Any, Dict[str, Any]]:
) -> Tuple[features.TensorImageType, Dict[str, Any]]:

paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
Expand All @@ -223,7 +223,7 @@ def _copy_paste(
# This is something different to TF implementation we introduced here as
# originally the algorithm works on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = image.shape[-2:]
size1 = cast(List[int], image.shape[-2:])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy started to complain seemingly at random once the image type was defined, so I chose to cast it. Happy to switch to a type-ignore comment instead if the cast is a performance concern.

size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
Expand Down Expand Up @@ -278,7 +278,9 @@ def _copy_paste(

return image, out_target

def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
Expand Down Expand Up @@ -307,7 +309,10 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis
return images, targets

def _insert_outputs(
self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
self,
flat_sample: List[Any],
output_images: List[features.TensorImageType],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
Expand Down
Loading