port convert_bounding_box_format tests #7933

Merged
4 commits merged on Sep 6, 2023
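This PR moves the `convert_bounding_box_format` tests out of the legacy `KernelInfo`/`DispatcherInfo` infrastructure and into a dedicated `TestConvertBoundingBoxFormat` class in `test_transforms_v2_refactored.py`. For orientation, a minimal sketch of the functional under test; the box values are illustrative, but the call patterns mirror the new tests (tv_tensor inputs infer `old_format` from their metadata, while pure tensors must pass it explicitly):

```python
import torch
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

# Example box in XYXY format, wrapped as a BoundingBoxes tv_tensor.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 50.0, 80.0]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(100, 100),
)

# tv_tensors carry their format as metadata, so only `new_format` is passed;
# supplying `old_format` here raises a ValueError (see test_errors below).
converted = F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)

# Pure tensors carry no metadata, so `old_format` is required.
converted_plain = F.convert_bounding_box_format(
    boxes.as_subclass(torch.Tensor),
    old_format=tv_tensors.BoundingBoxFormat.XYXY,
    new_format=tv_tensors.BoundingBoxFormat.CXCYWH,
)
```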
27 changes: 0 additions & 27 deletions test/test_transforms_v2_functional.py
@@ -522,33 +522,6 @@ def test_tv_tensor_explicit_metadata(self, metadata):
F.clamp_bounding_boxes(tv_tensor, **metadata)


class TestConvertFormatBoundingBoxes:
@pytest.mark.parametrize(
("inpt", "old_format"),
[
(next(make_multiple_bounding_boxes()), None),
(next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), tv_tensors.BoundingBoxFormat.XYXY),
],
)
def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_bounding_box_format(inpt, old_format)

def test_pure_tensor_insufficient_metadata(self):
pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_bounding_box_format(pure_tensor, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)

def test_tv_tensor_explicit_metadata(self):
tv_tensor = next(make_multiple_bounding_boxes())

with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_bounding_box_format(
tv_tensor, old_format=tv_tensor.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
)


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `transforms_v2_kernel_infos.py`

111 changes: 106 additions & 5 deletions test/test_transforms_v2_refactored.py
@@ -1,6 +1,8 @@
import contextlib
import decimal
import functools
import inspect
import itertools
import math
import pickle
import re
@@ -12,6 +14,8 @@
import pytest

import torch

import torchvision.ops
import torchvision.transforms.v2 as transforms
from common_utils import (
assert_equal,
@@ -138,7 +142,6 @@ def check_kernel(
check_cuda_vs_cpu=True,
check_scripted_vs_eager=True,
check_batched_vs_unbatched=True,
expect_same_dtype=True,
**kwargs,
):
initial_input_version = input._version
@@ -151,7 +154,7 @@
# check that no inplace operation happened
assert input._version == initial_input_version

if expect_same_dtype:
if kernel not in {F.to_dtype_image, F.to_dtype_video}:
assert output.dtype == input.dtype
assert output.device == input.device

@@ -187,7 +190,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):

assert isinstance(output, type(input))

if isinstance(input, tv_tensors.BoundingBoxes):
if isinstance(input, tv_tensors.BoundingBoxes) and functional is not F.convert_bounding_box_format:
assert output.format == input.format

if check_scripted_smoke:
@@ -264,7 +267,7 @@ def check_transform(transform, input, check_v1_compatibility=True):
output = transform(input)
assert isinstance(output, type(input))

if isinstance(input, tv_tensors.BoundingBoxes):
if isinstance(input, tv_tensors.BoundingBoxes) and not isinstance(transform, transforms.ConvertBoundingBoxFormat):
assert output.format == input.format

if check_v1_compatibility:
@@ -1743,7 +1746,6 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale):
check_kernel(
kernel,
make_input(dtype=input_dtype, device=device),
expect_same_dtype=input_dtype is output_dtype,
dtype=output_dtype,
scale=scale,
)
@@ -3009,3 +3011,102 @@ def test_auto_augment_policy_error(self):
def test_aug_mix_severity_error(self, severity):
with pytest.raises(ValueError, match="severity must be between"):
transforms.AugMix(severity=severity)


class TestConvertBoundingBoxFormat:
old_new_formats = list(itertools.permutations(iter(tv_tensors.BoundingBoxFormat), 2))

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_kernel(self, old_format, new_format):
check_kernel(
F.convert_bounding_box_format,
make_bounding_boxes(format=old_format),
new_format=new_format,
old_format=old_format,
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("inplace", [False, True])
def test_kernel_noop(self, format, inplace):
input = make_bounding_boxes(format=format).as_subclass(torch.Tensor)
input_version = input._version

output = F.convert_bounding_box_format(input, old_format=format, new_format=format, inplace=inplace)

assert output is input
assert output.data_ptr() == input.data_ptr()
assert output._version == input_version

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_kernel_inplace(self, old_format, new_format):
input = make_bounding_boxes(format=old_format).as_subclass(torch.Tensor)
input_version = input._version

output_out_of_place = F.convert_bounding_box_format(input, old_format=old_format, new_format=new_format)
assert output_out_of_place.data_ptr() != input.data_ptr()
assert output_out_of_place is not input

output_inplace = F.convert_bounding_box_format(
input, old_format=old_format, new_format=new_format, inplace=True
)
assert output_inplace.data_ptr() == input.data_ptr()
assert output_inplace._version > input_version
assert output_inplace is input

assert_equal(output_inplace, output_out_of_place)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_functional(self, old_format, new_format):
check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
@pytest.mark.parametrize("format_type", ["enum", "str"])
def test_transform(self, old_format, new_format, format_type):
check_transform(
transforms.ConvertBoundingBoxFormat(new_format.name if format_type == "str" else new_format),
make_bounding_boxes(format=old_format),
)

def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
return tv_tensors.wrap(
torchvision.ops.box_convert(
bounding_boxes.as_subclass(torch.Tensor),
in_fmt=bounding_boxes.format.name.lower(),
out_fmt=new_format.name.lower(),
).to(bounding_boxes.dtype),
like=bounding_boxes,
format=new_format,
)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn_type", ["functional", "transform"])
def test_correctness(self, old_format, new_format, dtype, device, fn_type):
bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device)

if fn_type == "functional":
fn = functools.partial(F.convert_bounding_box_format, new_format=new_format)
else:
fn = transforms.ConvertBoundingBoxFormat(format=new_format)

actual = fn(bounding_boxes)
expected = self._reference_convert_bounding_box_format(bounding_boxes, new_format)

assert_equal(actual, expected)

def test_errors(self):
input_tv_tensor = make_bounding_boxes()
input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)

for input in [input_tv_tensor, input_pure_tensor]:
with pytest.raises(TypeError, match="missing 1 required argument: 'new_format'"):
F.convert_bounding_box_format(input)

with pytest.raises(ValueError, match="`old_format` has to be passed"):
F.convert_bounding_box_format(input_pure_tensor, new_format=input_tv_tensor.format)

with pytest.raises(ValueError, match="`old_format` must not be passed"):
F.convert_bounding_box_format(
input_tv_tensor, old_format=input_tv_tensor.format, new_format=input_tv_tensor.format
)
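
As exercised by `test_transform` above, the transform counterpart accepts the target format either as the enum member or as its string name. A minimal usage sketch (box values are illustrative):

```python
import torch
import torchvision.transforms.v2 as transforms
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 50.0, 80.0]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(100, 100),
)

# Equivalent constructions: enum member or its string name.
by_enum = transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH)(boxes)
by_name = transforms.ConvertBoundingBoxFormat("CXCYWH")(boxes)

assert by_enum.format == by_name.format == tv_tensors.BoundingBoxFormat.CXCYWH
```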
7 changes: 0 additions & 7 deletions test/transforms_v2_dispatcher_infos.py
@@ -315,11 +315,4 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
skip_dispatch_tv_tensor,
],
),
DispatcherInfo(
F.convert_bounding_box_format,
kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
test_marks=[
skip_dispatch_tv_tensor,
],
),
]
33 changes: 0 additions & 33 deletions test/transforms_v2_kernel_infos.py
@@ -5,7 +5,6 @@
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
@@ -227,38 +226,6 @@ def transform(bbox, affine_matrix_, format_, canvas_size_):
).reshape(bounding_boxes.shape)


def sample_inputs_convert_bounding_box_format():
formats = list(tv_tensors.BoundingBoxFormat)
for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)


def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
return torchvision.ops.box_convert(
bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
).to(bounding_boxes.dtype)


def reference_inputs_convert_bounding_box_format():
for args_kwargs in sample_inputs_convert_bounding_box_format():
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs


KERNEL_INFOS.append(
KernelInfo(
F.convert_bounding_box_format,
sample_inputs_fn=sample_inputs_convert_bounding_box_format,
reference_fn=reference_convert_bounding_box_format,
reference_inputs_fn=reference_inputs_convert_bounding_box_format,
logs_usage=True,
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
},
),
)


_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])

