diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index b21a3c62878..33dd94925b6 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -467,6 +467,20 @@ def test__transform(self, degrees, expand, fill, center, mocker):
 
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
+    @pytest.mark.parametrize("angle", [34, -87])
+    @pytest.mark.parametrize("expand", [False, True])
+    def test_boundingbox_image_size(self, angle, expand):
+        # Specific test for BoundingBox.rotate
+        bbox = features.BoundingBox(
+            torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, image_size=(32, 32)
+        )
+        img = features.Image(torch.rand(1, 3, 32, 32))
+
+        out_img = img.rotate(angle, expand=expand)
+        out_bbox = bbox.rotate(angle, expand=expand)
+
+        assert out_img.image_size == out_bbox.image_size
+
 
 class TestRandomAffine:
     def test_assertions(self):
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index fb5f10459fe..d3353a0932d 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -693,13 +693,11 @@ def test_scriptable(kernel):
             "InterpolationMode",
             "decode_video_with_av",
             "crop",
-            "rotate",
             "perspective",
             "elastic_transform",
             "elastic",
         }
         # We skip 'crop' due to missing 'height' and 'width'
-        # We skip 'rotate' due to non implemented yet expand=True case for bboxes
         # We skip 'perspective' as it requires different input args than perspective_image_tensor etc
         # Skip 'elastic', TODO: inspect why test is failing
     ],
@@ -999,6 +997,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
             out_bbox[2] -= tr_x
             out_bbox[3] -= tr_y
 
+        # image_size should be updated, but it is OK to skip its computation here
+        # as we do not compute it in F.rotate_bounding_box either
+
         out_bbox = features.BoundingBox(
             out_bbox,
             format=features.BoundingBoxFormat.XYXY,
diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 59b88d2931f..54e1315c9ab 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -5,6 +5,8 @@
 import torch
 from torchvision._utils import StrEnum
 from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import _get_inverse_affine_matrix
+from torchvision.transforms.functional_tensor import _compute_output_size
 
 from ._feature import _Feature
 
@@ -168,10 +170,18 @@ def rotate(
         output = _F.rotate_bounding_box(
             self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
         )
-        # TODO: update output image size if expand is True
+        image_size = self.image_size
         if expand:
-            raise RuntimeError("Not yet implemented")
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+            # Recomputing image_size this way is not optimal due to redundant computation of
+            # - the rotation matrix (_get_inverse_affine_matrix)
+            # - the points dot matrix (_compute_output_size)
+            # Alternatively, _F.rotate_bounding_box could return the new image size itself
+            height, width = image_size
+            rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0])
+            new_width, new_height = _compute_output_size(rotation_matrix, width, height)
+            image_size = (new_height, new_width)
+
+        return BoundingBox.new_like(self, output, dtype=output.dtype, image_size=image_size)
 
     def affine(
         self,
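For reference, the expand-size computation added above amounts to taking the bounding box of the four rotated image corners. Below is a minimal standalone sketch of that math; _expanded_size is a hypothetical helper, not part of this patch, and its plain int() truncation may differ by a pixel from the tolerance-based ceil/floor rounding in _compute_output_size:

    import math

    def _expanded_size(angle, width, height):
        # Bounding box of the four image corners rotated by `angle` degrees
        # about the center, which is the quantity _compute_output_size
        # derives from the inverse affine matrix.
        theta = math.radians(angle)
        cos_a, sin_a = abs(math.cos(theta)), abs(math.sin(theta))
        new_w = int(width * cos_a + height * sin_a)
        new_h = int(width * sin_a + height * cos_a)
        return new_w, new_h  # (w, h), matching _compute_output_size's order

    print(_expanded_size(34, 32, 32))  # (44, 44) for the 32x32 case tested above
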
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 1a55e5c5acb..303486f98ba 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -74,7 +74,7 @@ def new_like(
 
     @property
     def image_size(self) -> Tuple[int, int]:
-        return cast(Tuple[int, int], self.shape[-2:])
+        return cast(Tuple[int, int], tuple(self.shape[-2:]))
 
     @property
     def num_channels(self) -> int:
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 8f37005298b..df5396a063c 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -634,7 +634,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
     cmax = torch.ceil((max_vals / tol).trunc_() * tol)
     cmin = torch.floor((min_vals / tol).trunc_() * tol)
     size = cmax - cmin
-    return int(size[0]), int(size[1])
+    return int(size[0]), int(size[1])  # w, h
 
 
 def rotate(
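As a quick illustration of the two one-line fixes above (a sketch only, assuming the prototype API exactly as exercised in the tests): image_size now comes back as a plain tuple rather than a torch.Size slice, so it lines up with the (h, w) tuples stored on BoundingBox, while _compute_output_size returns (w, h), which is why the caller in _bounding_box.py swaps the order before storing it.

    import torch
    from torchvision.prototype import features

    img = features.Image(torch.rand(1, 3, 32, 48))  # shape (..., h, w) == (..., 32, 48)
    assert img.image_size == (32, 48)               # (h, w) order
    assert type(img.image_size) is tuple            # plain tuple after the tuple(...) fix

    bbox = features.BoundingBox(
        torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, image_size=(32, 48)
    )
    assert img.image_size == bbox.image_size        # comparable across feature types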