From 566b6b1fe0f79b191baf73c08fe02bd8662aca7a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 26 Mar 2025 13:22:03 +0000 Subject: [PATCH 1/2] Revert "Fix non-rotated format to rotated format conversion logic (#8926)" This reverts commit 77e95fce9d2567f9da7da434ef0e65af9582b757. --- torchvision/transforms/v2/functional/_meta.py | 11 ----------- torchvision/tv_tensors/_bounding_boxes.py | 2 +- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 4f1c192455e..aeb2d6aed52 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -263,14 +263,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: return xyxyxyxy[..., :5].to(dtype) -def is_rotated_bounding_box_format(format: BoundingBoxFormat) -> bool: - return format.value in [ - BoundingBoxFormat.XYWHR.value, - BoundingBoxFormat.CXCYWHR.value, - BoundingBoxFormat.XYXYXYXY.value, - ] - - def _convert_bounding_box_format( bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False ) -> torch.Tensor: @@ -278,9 +270,6 @@ def _convert_bounding_box_format( if new_format == old_format: return bounding_boxes - if is_rotated_bounding_box_format(old_format) ^ is_rotated_bounding_box_format(new_format): - raise ValueError("Cannot convert between rotated and unrotated bounding boxes.") - # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance if old_format == BoundingBoxFormat.XYWH: bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace) diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py index df53550332e..b0238bb694e 100644 --- a/torchvision/tv_tensors/_bounding_boxes.py +++ b/torchvision/tv_tensors/_bounding_boxes.py @@ -20,7 +20,7 @@ class BoundingBoxFormat(Enum): * ``XYWHR``: rotated boxes represented via corner, width and height, x1, y1 being top left, w, h being width and height. r is rotation angle in degrees. - * ``CXCYWHR``: rotated boxes represented via centre, width and height, cx, + * ``CXCYWHR``: jrotated boxes represented via centre, width and height, cx, cy being center of box, w, h being width and height. r is rotation angle in degrees. * ``XYXYXYXY``: rotated boxes represented via corners, x1, y1 being top From a62435c2137b98750e0972c693c566965b4b4fd6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 26 Mar 2025 13:22:05 +0000 Subject: [PATCH 2/2] Revert "Add rotated bounding box formats (#8841)" This reverts commit 501a2c9a56ebb4dfaf7ff72dccd437647bba6873. 
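For reference, below is a minimal standalone sketch of the XYWHR -> XYXYXYXY corner math that this revert removes. It mirrors the deleted `_box_xywhr_to_xyxyxyxy` helper and follows the conventions stated in the deleted docstrings (r in degrees, counter-clockwise in the image plane, (x1, y1) being the top-left corner); the standalone name `xywhr_to_xyxyxyxy` is illustrative only, not a torchvision API. It reproduces the expected values of the deleted `test_bbox_xywhr_to_xyxyxyxy` test.

    import torch

    def xywhr_to_xyxyxyxy(boxes: torch.Tensor) -> torch.Tensor:
        # boxes: Tensor[N, 5] in (x1, y1, w, h, r) format, r in degrees.
        x1, y1, w, h, r = boxes.unbind(-1)
        r_rad = r * torch.pi / 180.0
        cos, sin = torch.cos(r_rad), torch.sin(r_rad)
        x3, y3 = x1 + w * cos, y1 - w * sin  # top-right corner
        x2, y2 = x3 + h * sin, y3 + h * cos  # bottom-right corner
        x4, y4 = x1 + h * sin, y1 + h * cos  # bottom-left corner
        return torch.stack((x1, y1, x3, y3, x2, y2, x4, y4), dim=-1)

    # (4, 5, 4, 2, 90) -> (4, 5, 4, 1, 6, 1, 6, 5), matching the deleted test case.
    print(xywhr_to_xyxyxyxy(torch.tensor([[4.0, 5.0, 4.0, 2.0, 90.0]])))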
--- test/common_utils.py | 18 --- test/test_ops.py | 57 +--------- test/test_transforms_v2.py | 76 +++++-------- torchvision/ops/_box_convert.py | 107 ------------------ torchvision/ops/boxes.py | 89 ++++----------- torchvision/transforms/v2/functional/_meta.py | 95 ---------------- torchvision/tv_tensors/_bounding_boxes.py | 19 +--- 7 files changed, 60 insertions(+), 401 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 9a4b41e606f..99c7931587d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -423,7 +423,6 @@ def sample_position(values, max_value): h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size] y = sample_position(h, canvas_size[0]) x = sample_position(w, canvas_size[1]) - r = -360 * torch.rand((num_boxes,)) + 180 if format is tv_tensors.BoundingBoxFormat.XYWH: parts = (x, y, w, h) @@ -436,23 +435,6 @@ def sample_position(values, max_value): cx = x + w / 2 cy = y + h / 2 parts = (cx, cy, w, h) - elif format is tv_tensors.BoundingBoxFormat.XYWHR: - parts = (x, y, w, h, r) - elif format is tv_tensors.BoundingBoxFormat.CXCYWHR: - cx = x + w / 2 - cy = y + h / 2 - parts = (cx, cy, w, h, r) - elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY: - r_rad = r * torch.pi / 180.0 - cos, sin = torch.cos(r_rad), torch.sin(r_rad) - x1, y1 = x, y - x3 = x1 + w * cos - y3 = y1 - w * sin - x2 = x3 + h * sin - y2 = y3 + h * cos - x4 = x1 + h * sin - y4 = y1 + h * cos - parts = (x1, y1, x3, y3, x2, y2, x4, y4) else: raise ValueError(f"Format {format} is not supported") diff --git a/test/test_ops.py b/test/test_ops.py index 88124f7ba17..1ba7a2c9efa 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1339,61 +1339,8 @@ def test_bbox_xywh_cxcywh(self): box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh") assert_equal(box_xywh, box_tensor) - def test_bbox_xywhr_cxcywhr(self): - box_tensor = torch.tensor( - [ - [0, 0, 100, 100, 0], - [0, 0, 0, 0, 0], - [10, 15, 20, 20, 0], - [23, 35, 70, 60, 0], - [4, 2, 4, 2, 0], - [5, 5, 4, 2, 90], - [8, 4, 4, 2, 180], - [7, 1, 4, 2, -90], - ], - dtype=torch.float, - ) - - exp_cxcywhr = torch.tensor( - [ - [50, 50, 100, 100, 0], - [0, 0, 0, 0, 0], - [20, 25, 20, 20, 0], - [58, 65, 70, 60, 0], - [6, 3, 4, 2, 0], - [6, 3, 4, 2, 90], - [6, 3, 4, 2, 180], - [6, 3, 4, 2, -90], - ], - dtype=torch.float, - ) - - assert exp_cxcywhr.size() == torch.Size([8, 5]) - box_cxcywhr = ops.box_convert(box_tensor, in_fmt="xywhr", out_fmt="cxcywhr") - torch.testing.assert_close(box_cxcywhr, exp_cxcywhr) - - # Reverse conversion - box_xywhr = ops.box_convert(box_cxcywhr, in_fmt="cxcywhr", out_fmt="xywhr") - torch.testing.assert_close(box_xywhr, box_tensor) - - def test_bbox_cxcywhr_to_xyxyxyxy(self): - box_tensor = torch.tensor([[5, 3, 4, 2, 90]], dtype=torch.float) - exp_xyxyxyxy = torch.tensor([[4, 5, 4, 1, 6, 1, 6, 5]], dtype=torch.float) - - assert exp_xyxyxyxy.size() == torch.Size([1, 8]) - box_xyxyxyxy = ops.box_convert(box_tensor, in_fmt="cxcywhr", out_fmt="xyxyxyxy") - torch.testing.assert_close(box_xyxyxyxy, exp_xyxyxyxy) - - def test_bbox_xywhr_to_xyxyxyxy(self): - box_tensor = torch.tensor([[4, 5, 4, 2, 90]], dtype=torch.float) - exp_xyxyxyxy = torch.tensor([[4, 5, 4, 1, 6, 1, 6, 5]], dtype=torch.float) - - assert exp_xyxyxyxy.size() == torch.Size([1, 8]) - box_xyxyxyxy = ops.box_convert(box_tensor, in_fmt="xywhr", out_fmt="xyxyxyxy") - torch.testing.assert_close(box_xyxyxyxy, exp_xyxyxyxy) - - @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh", "xwyhr", "cxwyhr", "xxxxyyyy"]) - 
@pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy", "xwcxr", "xhwcyr", "xyxyxxyy"]) + @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"]) + @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"]) def test_bbox_invalid(self, inv_infmt, inv_outfmt): box_tensor = torch.tensor( [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a9fd3bc5ec9..ddd7ebf4e6f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -53,15 +53,6 @@ from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal -# While we are working on adjusting transform functions -# for rotated and oriented bounding boxes formats, -# we limit the perimeter of tests to formats -# for which transform functions are already implemented. -# In the future, this global variable will be replaced with `list(tv_tensors.BoundingBoxFormat)` -# to support all available formats. -SUPPORTED_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYXY", "XYWH", "CXCYWH"]] -NEW_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYWHR", "CXCYWHR", "XYXYXYXY"]] - # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -635,7 +626,7 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype, check_scripted_vs_eager=not isinstance(size, int), ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("use_max_size", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @@ -766,7 +757,7 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non new_canvas_size=(new_height, new_width), ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("use_max_size", [True, False]) @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)]) @@ -1012,7 +1003,7 @@ class TestHorizontalFlip: def test_kernel_image(self, dtype, device): check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device)) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): @@ -1081,7 +1072,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes): return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] ) @@ -1178,7 +1169,7 @@ def test_kernel_image(self, param, value, dtype, device): shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"], center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"], ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", 
cpu_and_cuda()) def test_kernel_bounding_boxes(self, param, value, format, dtype, device): @@ -1327,7 +1318,7 @@ def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, ), ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"]) @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"]) @@ -1355,7 +1346,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, translate, s torch.testing.assert_close(actual, expected) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize("seed", list(range(5))) def test_transform_bounding_boxes_correctness(self, format, center, seed): @@ -1462,7 +1453,7 @@ class TestVerticalFlip: def test_kernel_image(self, dtype, device): check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device)) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): @@ -1529,7 +1520,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes): return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) def test_bounding_boxes_correctness(self, format, fn): bounding_boxes = make_bounding_boxes(format=format) @@ -1598,7 +1589,7 @@ def test_kernel_image(self, param, value, dtype, device): expand=[False, True], center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"], ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, param, value, format, dtype, device): @@ -1769,7 +1760,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen bounding_boxes ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @@ -1782,7 +1773,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent torch.testing.assert_close(actual, expected) torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize("seed", list(range(5))) @@ -2703,7 +2694,7 @@ def 
test_kernel_image(self, param, value, dtype, device): check_cuda_vs_cpu=dtype is not torch.float16, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): @@ -2830,7 +2821,7 @@ def test_kernel_image(self, kwargs, dtype, device): check_kernel(F.crop_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **kwargs) @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_box(self, kwargs, format, dtype, device): @@ -2980,7 +2971,7 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w ) @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device): @@ -2993,7 +2984,7 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device assert_equal(F.get_size(actual), F.get_size(expected)) @pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)]) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("seed", list(range(5))) @@ -3516,8 +3507,7 @@ def test_aug_mix_severity_error(self, severity): class TestConvertBoundingBoxFormat: - old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2)) - old_new_formats += list(itertools.permutations(NEW_BOX_FORMATS, 2)) + old_new_formats = list(itertools.permutations(iter(tv_tensors.BoundingBoxFormat), 2)) @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) def test_kernel(self, old_format, new_format): @@ -3528,7 +3518,7 @@ def test_kernel(self, old_format, new_format): old_format=old_format, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("inplace", [False, True]) def test_kernel_noop(self, format, inplace): input = make_bounding_boxes(format=format).as_subclass(torch.Tensor) @@ -3552,13 +3542,9 @@ def test_kernel_inplace(self, old_format, new_format): output_inplace = F.convert_bounding_box_format( input, old_format=old_format, new_format=new_format, inplace=True ) - if old_format != tv_tensors.BoundingBoxFormat.XYXYXYXY and new_format != tv_tensors.BoundingBoxFormat.XYXYXYXY: - # NOTE: BoundingBox format conversion from and to XYXYXYXY format - # cannot modify the input tensor inplace as it requires a dimension - # change. 
- assert output_inplace.data_ptr() == input.data_ptr() - assert output_inplace._version > input_version - assert output_inplace is input + assert output_inplace.data_ptr() == input.data_ptr() + assert output_inplace._version > input_version + assert output_inplace is input assert_equal(output_inplace, output_out_of_place) @@ -3577,7 +3563,7 @@ def test_transform(self, old_format, new_format, format_type): @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) def test_strings(self, old_format, new_format): # Non-regression test for https://github.com/pytorch/vision/issues/8258 - input = make_bounding_boxes(format=old_format, canvas_size=(50, 50)) + input = tv_tensors.BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format=old_format, canvas_size=(50, 50)) expected = self._reference_convert_bounding_box_format(input, new_format) old_format = old_format.name @@ -3742,7 +3728,7 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h new_canvas_size=size, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_functional_bounding_boxes_correctness(self, format): bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format) @@ -3816,7 +3802,7 @@ def test_kernel_image(self, param, value, dtype, device): ), ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_kernel_bounding_boxes(self, format): bounding_boxes = make_bounding_boxes(format=format) check_kernel( @@ -3935,7 +3921,7 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding): ) @pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)]) @@ -3964,7 +3950,7 @@ def test_kernel_image(self, output_size, dtype, device): ) @pytest.mark.parametrize("output_size", OUTPUT_SIZES) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_kernel_bounding_boxes(self, output_size, format): bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format) check_kernel( @@ -4043,7 +4029,7 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size): ) @pytest.mark.parametrize("output_size", OUTPUT_SIZES) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)]) @@ -4110,7 +4096,7 @@ def test_kernel_image_error(self): coefficients=COEFFICIENTS, start_end_points=START_END_POINTS, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_kernel_bounding_boxes(self, param, value, format): if param == "start_end_points": kwargs = dict(zip(["startpoints", "endpoints"], value)) @@ -4286,7 +4272,7 @@ def perspective_bounding_boxes(bounding_boxes): ) @pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS) - 
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device): @@ -4493,7 +4479,7 @@ def test_correctness_image(self, mean, std, dtype, fn): class TestClampBoundingBoxes: - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel(self, format, dtype, device): @@ -4505,7 +4491,7 @@ def test_kernel(self, format, dtype, device): canvas_size=bounding_boxes.canvas_size, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_functional(self, format): check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format)) diff --git a/torchvision/ops/_box_convert.py b/torchvision/ops/_box_convert.py index 62744fee060..124bdd0bcc6 100644 --- a/torchvision/ops/_box_convert.py +++ b/torchvision/ops/_box_convert.py @@ -79,110 +79,3 @@ def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor: h = y2 - y1 # y2 - y1 boxes = torch.stack((x1, y1, w, h), dim=-1) return boxes - - -def _box_cxcywhr_to_xywhr(boxes: Tensor) -> Tensor: - """ - Converts rotated bounding boxes from (cx, cy, w, h, r) format to (x1, y1, w, h, r) format. - (cx, cy) refers to center of bounding box - (w, h) refers to width and height of rotated bounding box - (x1, y1) refers to top left of rotated bounding box - r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan - Args: - boxes (Tensor[N, 5]): boxes in (cx, cy, w, h, r) format which will be converted. - - Returns: - boxes (Tensor(N, 5)): rotated boxes in (x1, y1, w, h, r) format. - """ - cx, cy, w, h, r = boxes.unbind(-1) - r_rad = r * torch.pi / 180.0 - cos, sin = torch.cos(r_rad), torch.sin(r_rad) - - x1 = cx - w / 2 * cos - h / 2 * sin - y1 = cy - h / 2 * cos + w / 2 * sin - boxes = torch.stack((x1, y1, w, h, r), dim=-1) - - return boxes - - -def _box_xywhr_to_cxcywhr(boxes: Tensor) -> Tensor: - """ - Converts rotated bounding boxes from (x1, y1, w, h, r) format to (cx, cy, w, h, r) format. - (x1, y1) refers to top left of rotated bounding box - (w, h) refers to width and height of rotated bounding box - r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan - Args: - boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format which will be converted. - - Returns: - boxes (Tensor[N, 5]): rotated boxes in (cx, cy, w, h, r) format. - """ - x1, y1, w, h, r = boxes.unbind(-1) - r_rad = r * torch.pi / 180.0 - cos, sin = torch.cos(r_rad), torch.sin(r_rad) - - cx = x1 + w / 2 * cos + h / 2 * sin - cy = y1 - w / 2 * sin + h / 2 * cos - - boxes = torch.stack([cx, cy, w, h, r], dim=-1) - return boxes - - -def _box_xywhr_to_xyxyxyxy(boxes: Tensor) -> Tensor: - """ - Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x3, y3, x2, y2, x4, y4) format. 
- (x1, y1) refer to top left of bounding box - (w, h) are width and height of the rotated bounding box - r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan - - (x1, y1) refer to top left of rotated bounding box - (x3, y3) refer to top right of rotated bounding box - (x2, y2) refer to bottom right of rotated bounding box - (x4, y4) refer to bottom left ofrotated bounding box - Args: - boxes (Tensor[N, 5]): rotated boxes in (cx, cy, w, h, r) format which will be converted. - - Returns: - boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format. - """ - x1, y1, w, h, r = boxes.unbind(-1) - r_rad = r * torch.pi / 180.0 - cos, sin = torch.cos(r_rad), torch.sin(r_rad) - - x3 = x1 + w * cos - y3 = y1 - w * sin - x2 = x3 + h * sin - y2 = y3 + h * cos - x4 = x1 + h * sin - y4 = y1 + h * cos - - return torch.stack((x1, y1, x3, y3, x2, y2, x4, y4), dim=-1) - - -def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor: - """ - Converts rotated bounding boxes from (x1, y1, x3, y3, x2, y2, x4, y4) format to (x1, y1, w, h, r) format. - (x1, y1) refer to top left of the rotated bounding box - (x3, y3) refer to bottom left of the rotated bounding box - (x2, y2) refer to bottom right of the rotated bounding box - (x4, y4) refer to top right of the rotated bounding box - (w, h) refers to width and height of rotated bounding box - r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan - - Args: - boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format. - - Returns: - boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format. - """ - x1, y1, x3, y3, x2, y2, x4, y4 = boxes.unbind(-1) - r_rad = torch.atan2(y1 - y3, x3 - x1) - r = r_rad * 180 / torch.pi - cos, sin = torch.cos(r_rad), torch.sin(r_rad) - - w = (x2 - x1) * cos + (y1 - y2) * sin - h = (x2 - x1) * sin + (y2 - y1) * cos - - boxes = torch.stack((x1, y1, w, h, r), dim=-1) - - return boxes diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 48df4d85cc7..089e2d11504 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -6,16 +6,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._box_convert import ( - _box_cxcywh_to_xyxy, - _box_cxcywhr_to_xywhr, - _box_xywh_to_xyxy, - _box_xywhr_to_cxcywhr, - _box_xywhr_to_xyxyxyxy, - _box_xyxy_to_cxcywh, - _box_xyxy_to_xywh, - _box_xyxyxyxy_to_xywhr, -) +from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh from ._utils import _upcast @@ -204,71 +195,41 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: ``'cxcywh'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h being width and height. - ``'xywhr'``: boxes are represented via corner, width and height, x1, y2 being top left, w, h being width and height. - r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan - - ``'cxcywhr'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h - being width and height. - r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan - - ``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 bottom right, - x3, y3 bottom left, and x4, y4 top right. - Args: - boxes (Tensor[N, K]): boxes which will be converted. 
K is the number of coordinates (4 for unrotated bounding boxes, 5 or 8 for rotated bounding boxes) - in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'xywhr', 'cxcywhr', 'xyxyxyxy']. - out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'xywhr', 'cxcywhr', 'xyxyxyxy'] + boxes (Tensor[N, 4]): boxes which will be converted. + in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']. + out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'] Returns: - Tensor[N, K]: Boxes into converted format. + Tensor[N, 4]: Boxes into converted format. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(box_convert) - allowed_fmts = ( - "xyxy", - "xywh", - "cxcywh", - "xywhr", - "cxcywhr", - "xyxyxyxy", - ) + allowed_fmts = ("xyxy", "xywh", "cxcywh") if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts: - raise ValueError(f"Unsupported Bounding Box Conversions for given in_fmt {in_fmt} and out_fmt {out_fmt}") + raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt") if in_fmt == out_fmt: return boxes.clone() - e = (in_fmt, out_fmt) - if e == ("xywh", "xyxy"): - boxes = _box_xywh_to_xyxy(boxes) - elif e == ("cxcywh", "xyxy"): - boxes = _box_cxcywh_to_xyxy(boxes) - elif e == ("xyxy", "xywh"): - boxes = _box_xyxy_to_xywh(boxes) - elif e == ("xyxy", "cxcywh"): - boxes = _box_xyxy_to_cxcywh(boxes) - elif e == ("xywh", "cxcywh"): - boxes = _box_xywh_to_xyxy(boxes) - boxes = _box_xyxy_to_cxcywh(boxes) - elif e == ("cxcywh", "xywh"): - boxes = _box_cxcywh_to_xyxy(boxes) - boxes = _box_xyxy_to_xywh(boxes) - elif e == ("cxcywhr", "xywhr"): - boxes = _box_cxcywhr_to_xywhr(boxes) - elif e == ("xywhr", "cxcywhr"): - boxes = _box_xywhr_to_cxcywhr(boxes) - elif e == ("cxcywhr", "xyxyxyxy"): - boxes = _box_cxcywhr_to_xywhr(boxes).to(boxes.dtype) - boxes = _box_xywhr_to_xyxyxyxy(boxes) - elif e == ("xyxyxyxy", "cxcywhr"): - boxes = _box_xyxyxyxy_to_xywhr(boxes).to(boxes.dtype) - boxes = _box_xywhr_to_cxcywhr(boxes) - elif e == ("xywhr", "xyxyxyxy"): - boxes = _box_xywhr_to_xyxyxyxy(boxes) - elif e == ("xyxyxyxy", "xywhr"): - boxes = _box_xyxyxyxy_to_xywhr(boxes) - else: - raise NotImplementedError(f"Unsupported Bounding Box Conversions for given in_fmt {e[0]} and out_fmt {e[1]}") + if in_fmt != "xyxy" and out_fmt != "xyxy": + # convert to xyxy and change in_fmt xyxy + if in_fmt == "xywh": + boxes = _box_xywh_to_xyxy(boxes) + elif in_fmt == "cxcywh": + boxes = _box_cxcywh_to_xyxy(boxes) + in_fmt = "xyxy" + + if in_fmt == "xyxy": + if out_fmt == "xywh": + boxes = _box_xyxy_to_xywh(boxes) + elif out_fmt == "cxcywh": + boxes = _box_xyxy_to_cxcywh(boxes) + elif out_fmt == "xyxy": + if in_fmt == "xywh": + boxes = _box_xywh_to_xyxy(boxes) + elif in_fmt == "cxcywh": + boxes = _box_cxcywh_to_xyxy(boxes) return boxes diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index aeb2d6aed52..b90e5fb7b5b 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -176,93 +176,6 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: return xyxy -def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor: - if not inplace: - cxcywhr = cxcywhr.clone() - - dtype = cxcywhr.dtype - if not cxcywhr.is_floating_point(): - cxcywhr = cxcywhr.float() - - half_wh = cxcywhr[..., 
2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_() - r_rad = cxcywhr[..., 4].mul(torch.pi).div(180.0) - cos, sin = r_rad.cos(), r_rad.sin() - # (cx - width / 2 * cos - height / 2 * sin) = x1 - cxcywhr[..., 0].sub_(half_wh[..., 0].mul(cos)).sub_(half_wh[..., 1].mul(sin)) - # (cy + width / 2 * sin - height / 2 * cos) = y1 - cxcywhr[..., 1].add_(half_wh[..., 0].mul(sin)).sub_(half_wh[..., 1].mul(cos)) - - return cxcywhr.to(dtype) - - -def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: - if not inplace: - xywhr = xywhr.clone() - - dtype = xywhr.dtype - if not xywhr.is_floating_point(): - xywhr = xywhr.float() - - half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_() - r_rad = xywhr[..., 4].mul(torch.pi).div(180.0) - cos, sin = r_rad.cos(), r_rad.sin() - # (x1 + width / 2 * cos + height / 2 * sin) = cx - xywhr[..., 0].add_(half_wh[..., 0].mul(cos)).add_(half_wh[..., 1].mul(sin)) - # (y1 - width / 2 * sin + height / 2 * cos) = cy - xywhr[..., 1].sub_(half_wh[..., 0].mul(sin)).add_(half_wh[..., 1].mul(cos)) - - return xywhr.to(dtype) - - -def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: - # NOTE: This function cannot modify the input tensor inplace as it requires a dimension change. - if not inplace: - xywhr = xywhr.clone() - - dtype = xywhr.dtype - if not xywhr.is_floating_point(): - xywhr = xywhr.float() - - wh = xywhr[..., 2:-1] - r_rad = xywhr[..., 4].mul(torch.pi).div(180.0) - cos, sin = r_rad.cos(), r_rad.sin() - xywhr = xywhr[..., :2].tile((1, 4)) - # x1 + w * cos = x3 - xywhr[..., 2].add_(wh[..., 0].mul(cos)) - # y1 - w * sin = y3 - xywhr[..., 3].sub_(wh[..., 0].mul(sin)) - # x1 + w * cos + h * sin = x2 - xywhr[..., 4].add_(wh[..., 0].mul(cos).add(wh[..., 1].mul(sin))) - # y1 - w * sin + h * cos = y2 - xywhr[..., 5].sub_(wh[..., 0].mul(sin).sub(wh[..., 1].mul(cos))) - # x1 + h * sin = x4 - xywhr[..., 6].add_(wh[..., 1].mul(sin)) - # y1 + h * cos = y4 - xywhr[..., 7].add_(wh[..., 1].mul(cos)) - return xywhr.to(dtype) - - -def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: - # NOTE: This function cannot modify the input tensor inplace as it requires a dimension change. 
- if not inplace: - xyxyxyxy = xyxyxyxy.clone() - - dtype = xyxyxyxy.dtype - if not xyxyxyxy.is_floating_point(): - xyxyxyxy = xyxyxyxy.float() - - r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0])) - cos, sin = r_rad.cos(), r_rad.sin() - # x1, y1, x3, y3, (x2 - x1), (y2 - y1) x4, y4 - xyxyxyxy[..., 4:6].sub_(xyxyxyxy[..., :2]) - # (x2 - x1) * cos + (y1 - y2) * sin = w - xyxyxyxy[..., 2] = xyxyxyxy[..., 4].mul(cos).sub(xyxyxyxy[..., 5].mul(sin)) - # (x2 - x1) * sin + (y2 - y1) * cos = h - xyxyxyxy[..., 3] = xyxyxyxy[..., 5].mul(cos).add(xyxyxyxy[..., 4].mul(sin)) - xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0) - return xyxyxyxy[..., :5].to(dtype) - - def _convert_bounding_box_format( bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False ) -> torch.Tensor: @@ -275,19 +188,11 @@ def _convert_bounding_box_format( bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace) elif old_format == BoundingBoxFormat.CXCYWH: bounding_boxes = _cxcywh_to_xyxy(bounding_boxes, inplace) - elif old_format == BoundingBoxFormat.CXCYWHR: - bounding_boxes = _cxcywhr_to_xywhr(bounding_boxes, inplace) - elif old_format == BoundingBoxFormat.XYXYXYXY: - bounding_boxes = _xyxyxyxy_to_xywhr(bounding_boxes, inplace) if new_format == BoundingBoxFormat.XYWH: bounding_boxes = _xyxy_to_xywh(bounding_boxes, inplace) elif new_format == BoundingBoxFormat.CXCYWH: bounding_boxes = _xyxy_to_cxcywh(bounding_boxes, inplace) - elif new_format == BoundingBoxFormat.CXCYWHR: - bounding_boxes = _xywhr_to_cxcywhr(bounding_boxes, inplace) - elif new_format == BoundingBoxFormat.XYXYXYXY: - bounding_boxes = _xywhr_to_xyxyxyxy(bounding_boxes, inplace) return bounding_boxes diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py index b0238bb694e..ea02fa3dc7b 100644 --- a/torchvision/tv_tensors/_bounding_boxes.py +++ b/torchvision/tv_tensors/_bounding_boxes.py @@ -12,35 +12,20 @@ class BoundingBoxFormat(Enum): """Coordinate format of a bounding box. - Available formats are: + Available formats are * ``XYXY`` * ``XYWH`` * ``CXCYWH`` - * ``XYWHR``: rotated boxes represented via corner, width and height, x1, y1 - being top left, w, h being width and height. r is rotation angle in - degrees. - * ``CXCYWHR``: jrotated boxes represented via centre, width and height, cx, - cy being center of box, w, h being width and height. r is rotation angle - in degrees. - * ``XYXYXYXY``: rotated boxes represented via corners, x1, y1 being top - left, x2, y2 being bottom right, x3, y3 being bottom left, x4, y4 being - top right. """ XYXY = "XYXY" XYWH = "XYWH" CXCYWH = "CXCYWH" - XYWHR = "XYWHR" - CXCYWHR = "CXCYWHR" - XYXYXYXY = "XYXYXYXY" class BoundingBoxes(TVTensor): - """:class:`torch.Tensor` subclass for bounding boxes with shape ``[N, K]``. - - Where ``N`` is the number of bounding boxes - and ``K`` is 4 for unrotated boxes, and 5 or 8 for rotated boxes. + """:class:`torch.Tensor` subclass for bounding boxes with shape ``[N, 4]``. .. note:: There should be only one :class:`~torchvision.tv_tensors.BoundingBoxes`