Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

port tests for container transforms #8012

Merged
merged 3 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 0 additions & 29 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,35 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker):
t(inpt)


class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transform_cls(transforms.RandomCrop(28))

@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize(
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
[transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
],
)
def test_ctor(self, transform_cls, trfms):
c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
assert output.ndim == 4


class TestRandomChoice:
def test_assertions(self):
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])


class TestRandomIoUCrop:
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
Expand Down
95 changes: 0 additions & 95 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str

Expand Down Expand Up @@ -82,22 +81,6 @@ def __init__(
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.Compose,
legacy_transforms.Compose,
),
ConsistencyConfig(
v2_transforms.RandomApply,
legacy_transforms.RandomApply,
),
ConsistencyConfig(
v2_transforms.RandomChoice,
legacy_transforms.RandomChoice,
),
ConsistencyConfig(
v2_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
]


Expand Down Expand Up @@ -298,84 +281,6 @@ def test_jit_consistency(config, args_kwargs):
assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
consistency automatically tests the wrapped transforms consistency.

Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
that were already tested for consistency above.
"""

def test_compose(self):
prototype_transform = v2_transforms.Compose(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
)
legacy_transform = legacy_transforms.Compose(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
)

# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
prototype_transform = v2_transforms.RandomApply(
sequence_type(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
),
p=p,
)
legacy_transform = legacy_transforms.RandomApply(
sequence_type(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
),
p=p,
)

# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

if sequence_type is nn.ModuleList:
# quick and dirty test that it is jit-scriptable
scripted = torch.jit.script(prototype_transform)
scripted(torch.rand(1, 3, 300, 300))

# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
def test_random_choice(self, probabilities):
prototype_transform = v2_transforms.RandomChoice(
[
v2_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)
legacy_transform = legacy_transforms.RandomChoice(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)

# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))


class TestToTensorTransforms:
def test_pil_to_tensor(self):
prototype_transform = v2_transforms.PILToTensor()
Expand Down
107 changes: 101 additions & 6 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_
if check_v1_compatibility:
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))

return output
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already compute the output in check_transform. By returning it, we don't need to recompute in case the test performs additional checks.



def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
def wrapper(input, *args, **kwargs):
Expand Down Expand Up @@ -1773,7 +1775,7 @@ def test_transform_unknown_fill_error(self):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCompose:
class TestContainerTransforms:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt
Expand All @@ -1788,7 +1790,10 @@ def forward(self, image, label):
return image, label

@pytest.mark.parametrize(
"transform_clss",
"transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
)
@pytest.mark.parametrize(
"wrapped_transform_clss",
[
[BuiltinTransform],
[PackedInputTransform],
Expand All @@ -1803,12 +1808,12 @@ def forward(self, image, label):
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)

transform = transforms.Compose([cls() for cls in transform_clss])
transform = transform_cls([cls() for cls in wrapped_transform_clss])

image = make_image()
label = 3
Expand All @@ -1833,6 +1838,96 @@ def call_transform():
assert output[0] is image
assert output[1] is label

def test_compose(self):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)

input = make_image()

actual = check_transform(transform, input)
expected = F.vertical_flip(F.horizontal_flip(input))

assert_equal(actual, expected)

@pytest.mark.parametrize("p", [0.0, 1.0])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
transform = transforms.RandomApply(
sequence_type(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
),
p=p,
)

# This needs to be a pure tensor (or a PIL image), because otherwise check_transforms skips the v1 compatibility
# check
input = make_image_tensor()
output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))

if p == 1:
assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
else:
assert output is input

@pytest.mark.parametrize("p", [(0, 1), (1, 0)])
def test_random_choice(self, p):
transform = transforms.RandomChoice(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
],
p=p,
)

input = make_image()
output = check_transform(transform, input)

p_horz, p_vert = p
if p_horz:
assert_equal(output, F.horizontal_flip(input))
else:
assert_equal(output, F.vertical_flip(input))

def test_random_order(self):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)

input = make_image()

actual = check_transform(transform, input)
# horizontal and vertical flip are commutative. Meaning, although the order in the transform is indeed random,
# we don't need to care here.
expected = F.vertical_flip(F.horizontal_flip(input))

assert_equal(actual, expected)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately it also means that this test doesn't check that the order of the transforms is indeed random. At best it checks that the input transforms are applied. I think we should at least acknowledge that in the comment instead of saying "we don't need to care".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. All ears if you have an idea to check whether the transforms are actually applied in random order.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a pair of non-commutative transforms, with 2 hand-chosen seeds, asserting that the results are different. What makes it a bit harder is that the transforms themselves must be non-random. Worst case scenario we could just define Plus2 and Times2.


def test_errors(self):
for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
cls(lambda x: x)

with pytest.raises(ValueError, match="at least one transform"):
transforms.Compose([])

for p in [-1, 2]:
with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
transforms.RandomApply([lambda x: x], p=p)

for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice(transforms_, p=p)


class TestToDtype:
@pytest.mark.parametrize(
Expand Down
16 changes: 9 additions & 7 deletions torchvision/transforms/v2/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {"transforms": self.transforms, "p": self.p}

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This got the same treatment as transforms.Compose in #7758. TL;DR: RandomApply now has the same UX as Compose in terms of passing packed or unpacked inputs.


if torch.rand(1) >= self.p:
return sample
return inputs if needs_unpacking else inputs[0]

for transform in self.transforms:
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs

def extra_repr(self) -> str:
format_string = []
Expand Down Expand Up @@ -173,8 +174,9 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
self.transforms = transforms

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs