Skip to content

Commit

Permalink
port convert_bounding_box_format tests (#7933)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Sep 6, 2023
1 parent 1f94320 commit 4103552
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 72 deletions.
27 changes: 0 additions & 27 deletions test/test_transforms_v2_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,33 +522,6 @@ def test_tv_tensor_explicit_metadata(self, metadata):
F.clamp_bounding_boxes(tv_tensor, **metadata)


class TestConvertFormatBoundingBoxes:
@pytest.mark.parametrize(
("inpt", "old_format"),
[
(next(make_multiple_bounding_boxes()), None),
(next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), tv_tensors.BoundingBoxFormat.XYXY),
],
)
def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_bounding_box_format(inpt, old_format)

def test_pure_tensor_insufficient_metadata(self):
pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_bounding_box_format(pure_tensor, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)

def test_tv_tensor_explicit_metadata(self):
tv_tensor = next(make_multiple_bounding_boxes())

with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_bounding_box_format(
tv_tensor, old_format=tv_tensor.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
)


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `transforms_v2_kernel_infos.py`

Expand Down
111 changes: 106 additions & 5 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import decimal
import functools
import inspect
import itertools
import math
import pickle
import re
Expand All @@ -12,6 +14,8 @@
import pytest

import torch

import torchvision.ops
import torchvision.transforms.v2 as transforms
from common_utils import (
assert_equal,
Expand Down Expand Up @@ -138,7 +142,6 @@ def check_kernel(
check_cuda_vs_cpu=True,
check_scripted_vs_eager=True,
check_batched_vs_unbatched=True,
expect_same_dtype=True,
**kwargs,
):
initial_input_version = input._version
Expand All @@ -151,7 +154,7 @@ def check_kernel(
# check that no inplace operation happened
assert input._version == initial_input_version

if expect_same_dtype:
if kernel not in {F.to_dtype_image, F.to_dtype_video}:
assert output.dtype == input.dtype
assert output.device == input.device

Expand Down Expand Up @@ -187,7 +190,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar

assert isinstance(output, type(input))

if isinstance(input, tv_tensors.BoundingBoxes):
if isinstance(input, tv_tensors.BoundingBoxes) and functional is not F.convert_bounding_box_format:
assert output.format == input.format

if check_scripted_smoke:
Expand Down Expand Up @@ -264,7 +267,7 @@ def check_transform(transform, input, check_v1_compatibility=True):
output = transform(input)
assert isinstance(output, type(input))

if isinstance(input, tv_tensors.BoundingBoxes):
if isinstance(input, tv_tensors.BoundingBoxes) and not isinstance(transform, transforms.ConvertBoundingBoxFormat):
assert output.format == input.format

if check_v1_compatibility:
Expand Down Expand Up @@ -1743,7 +1746,6 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca
check_kernel(
kernel,
make_input(dtype=input_dtype, device=device),
expect_same_dtype=input_dtype is output_dtype,
dtype=output_dtype,
scale=scale,
)
Expand Down Expand Up @@ -3009,3 +3011,102 @@ def test_auto_augment_policy_error(self):
def test_aug_mix_severity_error(self, severity):
with pytest.raises(ValueError, match="severity must be between"):
transforms.AugMix(severity=severity)


class TestConvertBoundingBoxFormat:
old_new_formats = list(itertools.permutations(iter(tv_tensors.BoundingBoxFormat), 2))

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_kernel(self, old_format, new_format):
check_kernel(
F.convert_bounding_box_format,
make_bounding_boxes(format=old_format),
new_format=new_format,
old_format=old_format,
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("inplace", [False, True])
def test_kernel_noop(self, format, inplace):
input = make_bounding_boxes(format=format).as_subclass(torch.Tensor)
input_version = input._version

output = F.convert_bounding_box_format(input, old_format=format, new_format=format, inplace=inplace)

assert output is input
assert output.data_ptr() == input.data_ptr()
assert output._version == input_version

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_kernel_inplace(self, old_format, new_format):
input = make_bounding_boxes(format=old_format).as_subclass(torch.Tensor)
input_version = input._version

output_out_of_place = F.convert_bounding_box_format(input, old_format=old_format, new_format=new_format)
assert output_out_of_place.data_ptr() != input.data_ptr()
assert output_out_of_place is not input

output_inplace = F.convert_bounding_box_format(
input, old_format=old_format, new_format=new_format, inplace=True
)
assert output_inplace.data_ptr() == input.data_ptr()
assert output_inplace._version > input_version
assert output_inplace is input

assert_equal(output_inplace, output_out_of_place)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_functional(self, old_format, new_format):
check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
@pytest.mark.parametrize("format_type", ["enum", "str"])
def test_transform(self, old_format, new_format, format_type):
check_transform(
transforms.ConvertBoundingBoxFormat(new_format.name if format_type == "str" else new_format),
make_bounding_boxes(format=old_format),
)

def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
return tv_tensors.wrap(
torchvision.ops.box_convert(
bounding_boxes.as_subclass(torch.Tensor),
in_fmt=bounding_boxes.format.name.lower(),
out_fmt=new_format.name.lower(),
).to(bounding_boxes.dtype),
like=bounding_boxes,
format=new_format,
)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn_type", ["functional", "transform"])
def test_correctness(self, old_format, new_format, dtype, device, fn_type):
bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device)

if fn_type == "functional":
fn = functools.partial(F.convert_bounding_box_format, new_format=new_format)
else:
fn = transforms.ConvertBoundingBoxFormat(format=new_format)

actual = fn(bounding_boxes)
expected = self._reference_convert_bounding_box_format(bounding_boxes, new_format)

assert_equal(actual, expected)

def test_errors(self):
input_tv_tensor = make_bounding_boxes()
input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)

for input in [input_tv_tensor, input_pure_tensor]:
with pytest.raises(TypeError, match="missing 1 required argument: 'new_format'"):
F.convert_bounding_box_format(input)

with pytest.raises(ValueError, match="`old_format` has to be passed"):
F.convert_bounding_box_format(input_pure_tensor, new_format=input_tv_tensor.format)

with pytest.raises(ValueError, match="`old_format` must not be passed"):
F.convert_bounding_box_format(
input_tv_tensor, old_format=input_tv_tensor.format, new_format=input_tv_tensor.format
)
7 changes: 0 additions & 7 deletions test/transforms_v2_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,4 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
skip_dispatch_tv_tensor,
],
),
DispatcherInfo(
F.convert_bounding_box_format,
kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
test_marks=[
skip_dispatch_tv_tensor,
],
),
]
33 changes: 0 additions & 33 deletions test/transforms_v2_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
Expand Down Expand Up @@ -227,38 +226,6 @@ def transform(bbox, affine_matrix_, format_, canvas_size_):
).reshape(bounding_boxes.shape)


def sample_inputs_convert_bounding_box_format():
formats = list(tv_tensors.BoundingBoxFormat)
for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)


def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
return torchvision.ops.box_convert(
bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
).to(bounding_boxes.dtype)


def reference_inputs_convert_bounding_box_format():
for args_kwargs in sample_inputs_convert_bounding_box_format():
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs


KERNEL_INFOS.append(
KernelInfo(
F.convert_bounding_box_format,
sample_inputs_fn=sample_inputs_convert_bounding_box_format,
reference_fn=reference_convert_bounding_box_format,
reference_inputs_fn=reference_inputs_convert_bounding_box_format,
logs_usage=True,
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
},
),
)


_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])


Expand Down

0 comments on commit 4103552

Please sign in to comment.