[fbsync] port tests for rgb_to_grayscale functional and transforms (#7967)

Reviewed By: matteobettini

Differential Revision: D49600771

fbshipit-source-id: a4d6ed523f2ed0919fa2f73884429e89ecb8b27d
NicolasHug authored and facebook-github-bot committed Sep 26, 2023
1 parent e195bf5 commit 7bd5976
Showing 3 changed files with 55 additions and 24 deletions.
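
As a rough sketch for exercising only the new tests from a torchvision source checkout (the path and a working pytest install are assumptions, not part of this commit):

import pytest

# Run only the TestRgbToGrayscale class added by this commit; invoke from the repository root.
pytest.main(["test/test_transforms_v2_refactored.py", "-k", "TestRgbToGrayscale"])
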
2 changes: 0 additions & 2 deletions test/test_transforms_v2.py
@@ -116,11 +116,9 @@ class TestSmoke:
             (transforms.RandAugment(), auto_augment_adapter),
             (transforms.TrivialAugmentWide(), auto_augment_adapter),
             (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
-            (transforms.Grayscale(), None),
             (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
             (transforms.RandomAutocontrast(p=1.0), None),
             (transforms.RandomEqualize(p=1.0), None),
-            (transforms.RandomGrayscale(p=1.0), None),
             (transforms.RandomInvert(p=1.0), None),
             (transforms.RandomChannelPermutation(), None),
             (transforms.RandomPhotometricDistort(p=1.0), None),
22 changes: 0 additions & 22 deletions test/test_transforms_v2_consistency.py
@@ -122,17 +122,6 @@ def __init__(
             (torch.float32, torch.float64),
         ]
     ],
-    ConsistencyConfig(
-        v2_transforms.Grayscale,
-        legacy_transforms.Grayscale,
-        [
-            ArgsKwargs(num_output_channels=1),
-            ArgsKwargs(num_output_channels=3),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
-        # Use default tolerances of `torch.testing.assert_close`
-        closeness_kwargs=dict(rtol=None, atol=None),
-    ),
     ConsistencyConfig(
         v2_transforms.ToPILImage,
         legacy_transforms.ToPILImage,
@@ -217,17 +206,6 @@ def __init__(
         ],
         closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
     ),
-    ConsistencyConfig(
-        v2_transforms.RandomGrayscale,
-        legacy_transforms.RandomGrayscale,
-        [
-            ArgsKwargs(p=0),
-            ArgsKwargs(p=1),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
-        # Use default tolerances of `torch.testing.assert_close`
-        closeness_kwargs=dict(rtol=None, atol=None),
-    ),
     ConsistencyConfig(
         v2_transforms.PILToTensor,
         legacy_transforms.PILToTensor,
55 changes: 55 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -3945,3 +3945,58 @@ def test_transform_correctness(self, brightness, contrast, saturation, hue):

         mae = (actual.float() - expected.float()).abs().mean()
         assert mae < 2
+
+
+class TestRgbToGrayscale:
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image(self, dtype, device):
+        check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_functional(self, make_input):
+        check_functional(F.rgb_to_grayscale, make_input())
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.rgb_to_grayscale_image, torch.Tensor),
+            (F._rgb_to_grayscale_image_pil, PIL.Image.Image),
+            (F.rgb_to_grayscale_image, tv_tensors.Image),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_transform(self, transform, make_input):
+        check_transform(transform, make_input())
+
+    @pytest.mark.parametrize("num_output_channels", [1, 3])
+    @pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
+    def test_image_correctness(self, num_output_channels, fn):
+        image = make_image(dtype=torch.uint8, device="cpu")
+
+        actual = fn(image, num_output_channels=num_output_channels)
+        expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
+
+        assert_equal(actual, expected, rtol=0, atol=1)
+
+    @pytest.mark.parametrize("num_input_channels", [1, 3])
+    def test_random_transform_correctness(self, num_input_channels):
+        image = make_image(
+            color_space={
+                1: "GRAY",
+                3: "RGB",
+            }[num_input_channels],
+            dtype=torch.uint8,
+            device="cpu",
+        )
+
+        transform = transforms.RandomGrayscale(p=1)
+
+        actual = transform(image)
+        expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_input_channels))
+
+        assert_equal(actual, expected, rtol=0, atol=1)
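
For context, the two correctness tests above treat the PIL kernel as the reference: the tensor path and the PIL path must agree within one intensity level (rtol=0, atol=1). A minimal standalone sketch of that comparison, assuming torchvision with the v2 API installed (the random input below is illustrative; the test suite builds its inputs with make_image instead):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# Illustrative 3x32x32 uint8 RGB image wrapped as a tv_tensors.Image.
image = tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))

# Tensor path: the functional dispatches to the rgb_to_grayscale_image kernel.
actual = F.rgb_to_grayscale(image, num_output_channels=3)

# PIL reference path, converted back to a tensor image for comparison.
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=3))

# Mirror the tolerance used by the new tests: exact up to one intensity level.
torch.testing.assert_close(actual, expected, rtol=0, atol=1)

The same pattern underlies test_random_transform_correctness, with transforms.RandomGrayscale(p=1) producing the actual output so the transform is guaranteed to apply.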
