diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_kernels.py
index 249ee76e6bc..b83febd8915 100644
--- a/test/test_prototype_transforms_kernels.py
+++ b/test/test_prototype_transforms_kernels.py
@@ -170,11 +170,11 @@ def resize_image():
 def resize_bounding_box():
     for bounding_box in make_bounding_boxes():
         height, width = bounding_box.image_size
-        for new_image_size in [
+        for size in [
             (height, width),
             (int(height * 0.75), int(width * 1.25)),
         ]:
-            yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)
+            yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
 
 
 class TestKernelsCommon:
diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 1ffd1fb84dc..4c5b82f5952 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -39,6 +39,9 @@ def __new__(
         return bounding_box
 
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        #  promote this out of the prototype state
+
         # import at runtime to avoid cyclic imports
         from torchvision.prototype.transforms.kernels import convert_bounding_box_format
 
diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py
index ea8bdeae32e..0117a041b02 100644
--- a/torchvision/prototype/features/_encoded.py
+++ b/torchvision/prototype/features/_encoded.py
@@ -39,6 +39,9 @@ def image_size(self) -> Tuple[int, int]:
         return self._image_size
 
     def decode(self) -> Image:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        #  promote this out of the prototype state
+
        # import at runtime to avoid cyclic imports
         from torchvision.prototype.transforms.kernels import decode_image_with_pil
 
diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index d6d4df8486e..e5dadbe6af4 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -12,20 +12,20 @@ class Feature(torch.Tensor):
     _metadata: Dict[str, Any]
 
     def __init_subclass__(cls) -> None:
-        # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes
-        # as static class annotations:
-        #
-        # >>> class Foo(Feature):
-        # ...     bar: str
-        # ...     baz: Optional[str]
-        #
-        # Internally, this information is used twofold:
-        #
-        # 1. A class annotation is contained in `cls.__annotations__` but not in `cls.__dict__`. We use this difference
-        #    to automatically detect the meta data attributes and expose them as `@property`'s for convenient runtime
-        #    access. This happens in this method.
-        # 2. The information extracted in 1. is also used at creation (`__new__`) to perform an input parsing for
-        #    unknown arguments.
+        """
+        For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
+        By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds
+        properties to have the same convenient access as regular attributes.
+
+        >>> class Foo(Feature):
+        ...     bar: str
+        ...     baz: Optional[str]
+        >>> foo = Foo()
+        >>> foo.bar
+        >>> foo.baz
+
+        This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
+        """
         meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
         for super_cls in cls.__mro__[1:]:
             if super_cls is Feature:
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 67a25c9836c..1fd1ea6158c 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -78,7 +78,11 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
         return ColorSpace.OTHER
 
     def show(self) -> None:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        #  promote this out of the prototype state
         to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
 
     def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        #  promote this out of the prototype state
         return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 2f9f0f76e39..53faea0f087 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -41,7 +41,7 @@ def resize(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     if isinstance(input, features.BoundingBox):
         size = kwargs.pop("size")
-        output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size)
+        output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
         return cast(T, features.BoundingBox.new_like(input, output, image_size=size))
 
     raise RuntimeError
diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py
index c3cbbb34b02..34bdc6703b4 100644
--- a/torchvision/prototype/transforms/kernels/_geometry.py
+++ b/torchvision/prototype/transforms/kernels/_geometry.py
@@ -51,21 +51,15 @@ def resize_image(
 def resize_segmentation_mask(
     segmentation_mask: torch.Tensor,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
     max_size: Optional[int] = None,
-    antialias: Optional[bool] = None,
 ) -> torch.Tensor:
-    return resize_image(
-        segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias
-    )
+    return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
 
 
 # TODO: handle max_size
-def resize_bounding_box(
-    bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int]
-) -> torch.Tensor:
-    old_height, old_width = old_image_size
-    new_height, new_width = new_image_size
+def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
+    old_height, old_width = image_size
+    new_height, new_width = size
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)