[fbsync] port tests for container transforms (#8012)
Reviewed By: vmoens

Differential Revision: D50789102

fbshipit-source-id: d95eea91a6574a13c8d53023ab4b2fb65be62ea3
NicolasHug authored and facebook-github-bot committed Oct 31, 2023
1 parent 54e9950 commit 37d303a
Showing 4 changed files with 112 additions and 179 deletions.
29 changes: 0 additions & 29 deletions test/test_transforms_v2.py
@@ -122,35 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker):
t(inpt)


class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transform_cls(transforms.RandomCrop(28))

@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize(
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
[transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
],
)
def test_ctor(self, transform_cls, trfms):
c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
assert output.ndim == 4


class TestRandomChoice:
def test_assertions(self):
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])


class TestRandomIoUCrop:
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
138 changes: 1 addition & 137 deletions test/test_transforms_v2_consistency.py
@@ -11,9 +11,7 @@
import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
@@ -71,63 +69,7 @@ def __init__(
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.Compose,
legacy_transforms.Compose,
),
ConsistencyConfig(
v2_transforms.RandomApply,
legacy_transforms.RandomApply,
),
ConsistencyConfig(
v2_transforms.RandomChoice,
legacy_transforms.RandomChoice,
),
ConsistencyConfig(
v2_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
]


@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

for param in config.removed_params:
legacy_params.pop(param, None)

missing = legacy_params.keys() - prototype_params.keys()
if missing:
raise AssertionError(
f"The prototype transform does not support the parameters "
f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
f"the `ConsistencyConfig`."
)

extra = prototype_params.keys() - legacy_params.keys()
extra_without_default = {
param
for param in extra
if prototype_params[param].default is inspect.Parameter.empty
and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
if extra_without_default:
raise AssertionError(
f"The prototype transform requires the parameters "
f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
f"not. Please add a default value."
)

legacy_signature = list(legacy_params.keys())
# Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
# to the same number of parameters as the legacy one
prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

assert prototype_signature == legacy_signature
CONSISTENCY_CONFIGS = []


def check_call_consistency(
@@ -288,84 +230,6 @@ def test_jit_consistency(config, args_kwargs):
assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
consistency automatically tests the consistency of the wrapped transforms.
Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
that were already tested for consistency above.
"""

def test_compose(self):
prototype_transform = v2_transforms.Compose(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
)
legacy_transform = legacy_transforms.Compose(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
)

# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
prototype_transform = v2_transforms.RandomApply(
sequence_type(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
),
p=p,
)
legacy_transform = legacy_transforms.RandomApply(
sequence_type(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
),
p=p,
)

# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

if sequence_type is nn.ModuleList:
# quick and dirty test that it is jit-scriptable
scripted = torch.jit.script(prototype_transform)
scripted(torch.rand(1, 3, 300, 300))

# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
def test_random_choice(self, probabilities):
prototype_transform = v2_transforms.RandomChoice(
[
v2_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)
legacy_transform = legacy_transforms.RandomChoice(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)

# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))


class TestToTensorTransforms:
def test_pil_to_tensor(self):
prototype_transform = v2_transforms.PILToTensor()
108 changes: 102 additions & 6 deletions test/test_transforms_v2_refactored.py
@@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_
if check_v1_compatibility:
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))

return output


def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
def wrapper(input, *args, **kwargs):
@@ -1773,7 +1775,7 @@ def test_transform_unknown_fill_error(self):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCompose:
class TestContainerTransforms:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt
@@ -1788,7 +1790,10 @@ def forward(self, image, label):
return image, label

@pytest.mark.parametrize(
"transform_clss",
"transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
)
@pytest.mark.parametrize(
"wrapped_transform_clss",
[
[BuiltinTransform],
[PackedInputTransform],
@@ -1803,12 +1808,12 @@ def forward(self, image, label):
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)

transform = transforms.Compose([cls() for cls in transform_clss])
transform = transform_cls([cls() for cls in wrapped_transform_clss])

image = make_image()
label = 3
@@ -1833,6 +1838,97 @@ def call_transform():
assert output[0] is image
assert output[1] is label

def test_compose(self):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)

input = make_image()

actual = check_transform(transform, input)
expected = F.vertical_flip(F.horizontal_flip(input))

assert_equal(actual, expected)

@pytest.mark.parametrize("p", [0.0, 1.0])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
transform = transforms.RandomApply(
sequence_type(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
),
p=p,
)

# This needs to be a pure tensor (or a PIL image), because otherwise check_transform skips the v1 compatibility
# check
input = make_image_tensor()
output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))

if p == 1:
assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
else:
assert output is input

@pytest.mark.parametrize("p", [(0, 1), (1, 0)])
def test_random_choice(self, p):
transform = transforms.RandomChoice(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
],
p=p,
)

input = make_image()
output = check_transform(transform, input)

p_horz, p_vert = p
if p_horz:
assert_equal(output, F.horizontal_flip(input))
else:
assert_equal(output, F.vertical_flip(input))

def test_random_order(self):
transform = transforms.RandomOrder(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)

input = make_image()

actual = check_transform(transform, input)
# We can't really check whether the transforms are actually applied in random order. However, horizontal and
# vertical flip are commutative, so even though the transform applies them in a random order, we can use a
# fixed order to compute the expected value.
expected = F.vertical_flip(F.horizontal_flip(input))

assert_equal(actual, expected)

def test_errors(self):
for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
cls(lambda x: x)

with pytest.raises(ValueError, match="at least one transform"):
transforms.Compose([])

for p in [-1, 2]:
with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
transforms.RandomApply([lambda x: x], p=p)

for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice(transforms_, p=p)


class TestToDtype:
@pytest.mark.parametrize(
16 changes: 9 additions & 7 deletions torchvision/transforms/v2/_container.py
@@ -100,14 +100,15 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {"transforms": self.transforms, "p": self.p}

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1

if torch.rand(1) >= self.p:
return sample
return inputs if needs_unpacking else inputs[0]

for transform in self.transforms:
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs

def extra_repr(self) -> str:
format_string = []
@@ -173,8 +174,9 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
self.transforms = transforms

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
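
For context, here is a minimal usage sketch (not part of the diff) of what the reworked forward methods enable: the v2 container transforms now forward multiple positional inputs to each wrapped transform instead of collapsing them into a single packed sample, mirroring the ported test_packed_unpacked test above. The snippet assumes a torchvision build that includes this change; the pipeline and variable names are illustrative only.

# Illustrative sketch -- not part of the commit. Assumes a torchvision build
# that includes the container changes above.
import torch
import torchvision.transforms.v2 as transforms
from torchvision import tv_tensors

pipeline = transforms.RandomApply(
    [transforms.RandomHorizontalFlip(p=1), transforms.RandomVerticalFlip(p=1)],
    p=1.0,
)

image = tv_tensors.Image(torch.rand(3, 32, 32))
label = 3

# Packed call: a single positional argument is handled exactly as before.
flipped = pipeline(image)

# Unpacked call: multiple positional arguments are forwarded as-is to every
# wrapped transform, so the (image, label) pair survives the pipeline and the
# label passes through untouched.
out_image, out_label = pipeline(image, label)
assert out_label == label

Compose, RandomOrder, and RandomChoice follow the same calling convention, which is what the updated test_packed_unpacked above exercises across the container classes.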
