-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[prototype] Align and Clean up transform types #6627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d8d765a
06aad95
b877c20
e5b36b6
7e61d75
1edbc06
4eb82f4
ad60b50
211f294
f3cefdf
1955b89
4bf566a
458a60f
4013924
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,15 @@ | ||
from ._bounding_box import BoundingBox, BoundingBoxFormat | ||
from ._encoded import EncodedData, EncodedImage, EncodedVideo | ||
from ._feature import _Feature, DType, is_simple_tensor | ||
from ._image import ColorSpace, Image, ImageType | ||
from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor | ||
from ._image import ( | ||
ColorSpace, | ||
Image, | ||
ImageType, | ||
ImageTypeJIT, | ||
LegacyImageType, | ||
LegacyImageTypeJIT, | ||
TensorImageType, | ||
TensorImageTypeJIT, | ||
) | ||
from ._label import Label, OneHotLabel | ||
from ._mask import Mask |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,16 +3,14 @@ | |
from types import ModuleType | ||
from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union | ||
|
||
import PIL.Image | ||
import torch | ||
from torch._C import _TensorBase, DisableTorchFunction | ||
from torchvision.transforms import InterpolationMode | ||
|
||
F = TypeVar("F", bound="_Feature") | ||
|
||
|
||
# Due to torch.jit.script limitation we keep DType as torch.Tensor | ||
# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature] | ||
DType = torch.Tensor | ||
FillType = Union[int, float, Sequence[int], Sequence[float], None] | ||
FillTypeJIT = Union[int, float, List[float], None] | ||
|
||
|
||
def is_simple_tensor(inpt: Any) -> bool: | ||
|
@@ -154,7 +152,7 @@ def resized_crop( | |
def pad( | ||
self, | ||
padding: Union[int, List[int]], | ||
fill: Optional[Union[int, float, List[float]]] = None, | ||
fill: FillTypeJIT = None, | ||
padding_mode: str = "constant", | ||
) -> _Feature: | ||
return self | ||
|
@@ -164,7 +162,7 @@ def rotate( | |
angle: float, | ||
interpolation: InterpolationMode = InterpolationMode.NEAREST, | ||
expand: bool = False, | ||
fill: Optional[Union[int, float, List[float]]] = None, | ||
fill: FillTypeJIT = None, | ||
center: Optional[List[float]] = None, | ||
) -> _Feature: | ||
return self | ||
|
@@ -176,7 +174,7 @@ def affine( | |
scale: float, | ||
shear: List[float], | ||
interpolation: InterpolationMode = InterpolationMode.NEAREST, | ||
fill: Optional[Union[int, float, List[float]]] = None, | ||
fill: FillTypeJIT = None, | ||
center: Optional[List[float]] = None, | ||
) -> _Feature: | ||
return self | ||
|
@@ -185,15 +183,15 @@ def perspective( | |
self, | ||
perspective_coeffs: List[float], | ||
interpolation: InterpolationMode = InterpolationMode.BILINEAR, | ||
fill: Optional[Union[int, float, List[float]]] = None, | ||
fill: FillTypeJIT = None, | ||
) -> _Feature: | ||
return self | ||
|
||
def elastic( | ||
self, | ||
displacement: torch.Tensor, | ||
interpolation: InterpolationMode = InterpolationMode.BILINEAR, | ||
fill: Optional[Union[int, float, List[float]]] = None, | ||
fill: FillTypeJIT = None, | ||
) -> _Feature: | ||
return self | ||
|
||
|
@@ -232,3 +230,7 @@ def invert(self) -> _Feature: | |
|
||
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature: | ||
return self | ||
|
||
|
||
InputType = Union[torch.Tensor, PIL.Image.Image, _Feature] | ||
InputTypeJIT = torch.Tensor | ||
Comment on lines
+235
to
+236
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import math | ||
import numbers | ||
import warnings | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
from typing import Any, cast, Dict, List, Optional, Tuple | ||
|
||
import PIL.Image | ||
import torch | ||
|
@@ -92,9 +92,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: | |
|
||
return dict(i=i, j=j, h=h, w=w, v=v) | ||
|
||
def _transform( | ||
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any] | ||
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]: | ||
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: | ||
if params["v"] is not None: | ||
inpt = F.erase(inpt, **params, inplace=self.inplace) | ||
|
||
|
@@ -110,8 +108,10 @@ def __init__(self, alpha: float, p: float = 0.5) -> None: | |
def forward(self, *inputs: Any) -> Any: | ||
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)): | ||
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.") | ||
if has_any(inputs, features.BoundingBox, features.Mask, features.Label): | ||
raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.") | ||
if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label): | ||
raise TypeError( | ||
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." | ||
) | ||
Comment on lines
+111
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe we previously missed PIL images which are not supported by this class. |
||
return super().forward(*inputs) | ||
|
||
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: | ||
|
@@ -203,15 +203,15 @@ def __init__( | |
|
||
def _copy_paste( | ||
self, | ||
image: Any, | ||
image: features.TensorImageType, | ||
target: Dict[str, Any], | ||
paste_image: Any, | ||
paste_image: features.TensorImageType, | ||
paste_target: Dict[str, Any], | ||
random_selection: torch.Tensor, | ||
blending: bool, | ||
resize_interpolation: F.InterpolationMode, | ||
antialias: Optional[bool], | ||
) -> Tuple[Any, Dict[str, Any]]: | ||
) -> Tuple[features.TensorImageType, Dict[str, Any]]: | ||
|
||
paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection]) | ||
paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection]) | ||
|
@@ -223,7 +223,7 @@ def _copy_paste( | |
# This is something different to TF implementation we introduced here as | ||
# originally the algorithm works on equal-sized data | ||
# (for example, coming from LSJ data augmentations) | ||
size1 = image.shape[-2:] | ||
size1 = cast(List[int], image.shape[-2:]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mypy started to randomly complain once the image type was defined, so I choose the cast it. Happy to ignore for performance reasons. |
||
size2 = paste_image.shape[-2:] | ||
if size1 != size2: | ||
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias) | ||
|
@@ -278,7 +278,9 @@ def _copy_paste( | |
|
||
return image, out_target | ||
|
||
def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]: | ||
def _extract_image_targets( | ||
self, flat_sample: List[Any] | ||
) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]: | ||
# fetch all images, bboxes, masks and labels from unstructured input | ||
# with List[image], List[BoundingBox], List[Mask], List[Label] | ||
images, bboxes, masks, labels = [], [], [], [] | ||
|
@@ -307,7 +309,10 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis | |
return images, targets | ||
|
||
def _insert_outputs( | ||
self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]] | ||
self, | ||
flat_sample: List[Any], | ||
output_images: List[features.TensorImageType], | ||
output_targets: List[Dict[str, Any]], | ||
) -> None: | ||
c0, c1, c2, c3 = 0, 0, 0, 0 | ||
for i, obj in enumerate(flat_sample): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could move to another file if we don't want it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is ok there. I would prefer to have everything in a
_typing.py
file, but as you explained this will give us circular imports due to the method signatures. Thus, let's take the way of the least resistance.