Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

port tests for container transforms #8012

Merged
merged 3 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 0 additions & 29 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,35 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker):
t(inpt)


class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transform_cls(transforms.RandomCrop(28))

@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize(
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
[transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
],
)
def test_ctor(self, transform_cls, trfms):
c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
assert output.ndim == 4


class TestRandomChoice:
def test_assertions(self):
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])


class TestRandomIoUCrop:
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
Expand Down
138 changes: 1 addition & 137 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
Expand Down Expand Up @@ -71,63 +69,7 @@ def __init__(
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.Compose,
legacy_transforms.Compose,
),
ConsistencyConfig(
v2_transforms.RandomApply,
legacy_transforms.RandomApply,
),
ConsistencyConfig(
v2_transforms.RandomChoice,
legacy_transforms.RandomChoice,
),
ConsistencyConfig(
v2_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
]


@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

for param in config.removed_params:
legacy_params.pop(param, None)

missing = legacy_params.keys() - prototype_params.keys()
if missing:
raise AssertionError(
f"The prototype transform does not support the parameters "
f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
f"the `ConsistencyConfig`."
)

extra = prototype_params.keys() - legacy_params.keys()
extra_without_default = {
param
for param in extra
if prototype_params[param].default is inspect.Parameter.empty
and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
if extra_without_default:
raise AssertionError(
f"The prototype transform requires the parameters "
f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
f"not. Please add a default value."
)

legacy_signature = list(legacy_params.keys())
# Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
# to the same number of parameters as the legacy one
prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

assert prototype_signature == legacy_signature
CONSISTENCY_CONFIGS = []


def check_call_consistency(
Expand Down Expand Up @@ -288,84 +230,6 @@ def test_jit_consistency(config, args_kwargs):
assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
consistency automatically tests the wrapped transforms consistency.

Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
that were already tested for consistency above.
"""

def test_compose(self):
prototype_transform = v2_transforms.Compose(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
)
legacy_transform = legacy_transforms.Compose(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
)

# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
prototype_transform = v2_transforms.RandomApply(
sequence_type(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
),
p=p,
)
legacy_transform = legacy_transforms.RandomApply(
sequence_type(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
),
p=p,
)

# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

if sequence_type is nn.ModuleList:
# quick and dirty test that it is jit-scriptable
scripted = torch.jit.script(prototype_transform)
scripted(torch.rand(1, 3, 300, 300))

# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
def test_random_choice(self, probabilities):
prototype_transform = v2_transforms.RandomChoice(
[
v2_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)
legacy_transform = legacy_transforms.RandomChoice(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)

# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))


class TestToTensorTransforms:
def test_pil_to_tensor(self):
prototype_transform = v2_transforms.PILToTensor()
Expand Down
108 changes: 102 additions & 6 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_
if check_v1_compatibility:
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))

return output
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already compute the output in check_transform. By returning it, we don't need to recompute in case the test performs additional checks.



def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
def wrapper(input, *args, **kwargs):
Expand Down Expand Up @@ -1773,7 +1775,7 @@ def test_transform_unknown_fill_error(self):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCompose:
class TestContainerTransforms:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt
Expand All @@ -1788,7 +1790,10 @@ def forward(self, image, label):
return image, label

@pytest.mark.parametrize(
"transform_clss",
"transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
)
@pytest.mark.parametrize(
"wrapped_transform_clss",
[
[BuiltinTransform],
[PackedInputTransform],
Expand All @@ -1803,12 +1808,12 @@ def forward(self, image, label):
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)

transform = transforms.Compose([cls() for cls in transform_clss])
transform = transform_cls([cls() for cls in wrapped_transform_clss])

image = make_image()
label = 3
Expand All @@ -1833,6 +1838,97 @@ def call_transform():
assert output[0] is image
assert output[1] is label

def test_compose(self):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)

input = make_image()

actual = check_transform(transform, input)
expected = F.vertical_flip(F.horizontal_flip(input))

assert_equal(actual, expected)

@pytest.mark.parametrize("p", [0.0, 1.0])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
transform = transforms.RandomApply(
sequence_type(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
),
p=p,
)

# This needs to be a pure tensor (or a PIL image), because otherwise check_transforms skips the v1 compatibility
# check
input = make_image_tensor()
output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))

if p == 1:
assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
else:
assert output is input

@pytest.mark.parametrize("p", [(0, 1), (1, 0)])
def test_random_choice(self, p):
transform = transforms.RandomChoice(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
],
p=p,
)

input = make_image()
output = check_transform(transform, input)

p_horz, p_vert = p
if p_horz:
assert_equal(output, F.horizontal_flip(input))
else:
assert_equal(output, F.vertical_flip(input))

def test_random_order(self):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)

input = make_image()

actual = check_transform(transform, input)
# We can't really check whether the transforms are actually applied in random order. However, horizontal and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be too pedantic but I wouldn't say we "can't", we just choose not to.

# vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
# order, we can use a fixed order to compute the expected value.
expected = F.vertical_flip(F.horizontal_flip(input))

assert_equal(actual, expected)

def test_errors(self):
for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
cls(lambda x: x)

with pytest.raises(ValueError, match="at least one transform"):
transforms.Compose([])

for p in [-1, 2]:
with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
transforms.RandomApply([lambda x: x], p=p)

for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice(transforms_, p=p)


class TestToDtype:
@pytest.mark.parametrize(
Expand Down
16 changes: 9 additions & 7 deletions torchvision/transforms/v2/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {"transforms": self.transforms, "p": self.p}

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This got the same treatment as transforms.Compose in #7758. TL;DR: RandomApply now has the same UX as Compose in terms of passing packed or unpacked inputs.


if torch.rand(1) >= self.p:
return sample
return inputs if needs_unpacking else inputs[0]

for transform in self.transforms:
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs

def extra_repr(self) -> str:
format_string = []
Expand Down Expand Up @@ -173,8 +174,9 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
self.transforms = transforms

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs