Skip to content

Commit

Permalink
clarify comment
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Oct 3, 2023
2 parents 195a458 + 0040fe7 commit d3226ff
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 90 deletions.
35 changes: 0 additions & 35 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,38 +545,3 @@ def test_sanitize_bounding_boxes_errors():
with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)


class TestLambda:
inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])

@inputs
def test_default(self, input):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

transform = transforms.Lambda(was_applied_fn)

transform(input)

assert was_applied

@inputs
def test_with_types(self, input):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

types = (torch.Tensor, np.ndarray)
transform = transforms.Lambda(was_applied_fn, *types)

transform(input)

assert was_applied is isinstance(input, types)
53 changes: 1 addition & 52 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
Expand Down Expand Up @@ -70,57 +69,7 @@ def __init__(
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.Lambda,
legacy_transforms.Lambda,
[
NotScriptableArgsKwargs(lambda image: image / 2),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
]


@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

for param in config.removed_params:
legacy_params.pop(param, None)

missing = legacy_params.keys() - prototype_params.keys()
if missing:
raise AssertionError(
f"The prototype transform does not support the parameters "
f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
f"the `ConsistencyConfig`."
)

extra = prototype_params.keys() - legacy_params.keys()
extra_without_default = {
param
for param in extra
if prototype_params[param].default is inspect.Parameter.empty
and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
if extra_without_default:
raise AssertionError(
f"The prototype transform requires the parameters "
f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
f"not. Please add a default value."
)

legacy_signature = list(legacy_params.keys())
# Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
# to the same number of parameters as the legacy one
prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

assert prototype_signature == legacy_signature
CONSISTENCY_CONFIGS = []


def check_call_consistency(
Expand Down
23 changes: 21 additions & 2 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -1906,8 +1906,9 @@ def test_random_order(self):
input = make_image()

actual = check_transform(transform, input)
# horizontal and vertical flip are commutative. Meaning, although the order in the transform is indeed random,
# we don't need to care here.
# We can't really check whether the transforms are actually applied in random order. However, horizontal and
# vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
# order, we can use a fixed order to compute the expected value.
expected = F.vertical_flip(F.horizontal_flip(input))

assert_equal(actual, expected)
Expand Down Expand Up @@ -5221,3 +5222,21 @@ def test_functional_and_transform(self, color_space, fn):
def test_functional_error(self):
with pytest.raises(TypeError, match="pic should be PIL Image"):
F.pil_to_tensor(object())


class TestLambda:
@pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
def test_transform(self, input, types):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

transform = transforms.Lambda(was_applied_fn, *types)
output = transform(input)

assert output is input
assert was_applied is (not types or isinstance(input, types))
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
Note:
Please, note that this method supports only RGB images as input. For inputs in other color spaces,
please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
Args:
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
Expand Down

0 comments on commit d3226ff

Please sign in to comment.