From b688a532bf59b829e04a27c74d15ee1e43e28376 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 18 Jul 2022 19:02:56 +0200
Subject: [PATCH] Added more functional tests

---
 test/test_prototype_transforms_functional.py | 72 +++++++++++++++++++-
 1 file changed, 69 insertions(+), 3 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 873516869f8..61d1adfab18 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -200,6 +200,30 @@ def horizontal_flip_bounding_box():
         yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_image_tensor():
+    for image in make_images():
+        yield SampleInput(image)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_bounding_box():
+    for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
+        yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def resize_image_tensor():
     for image, interpolation, max_size, antialias in itertools.product(
@@ -404,9 +428,17 @@ def crop_segmentation_mask():
 
 
 @register_kernel_info_from_sample_inputs_fn
-def vertical_flip_segmentation_mask():
-    for mask in make_segmentation_masks():
-        yield SampleInput(mask)
+def resized_crop_image_tensor():
+    for image, top, left, height, width, size, antialias in itertools.product(
+        make_images(),
+        [-8, 9],
+        [-8, 9],
+        [12],
+        [12],
+        [(16, 18)],
+        [True, False],
+    ):
+        yield SampleInput(image, top=top, left=left, height=height, width=width, size=size, antialias=antialias)
 
 
 @register_kernel_info_from_sample_inputs_fn
@@ -457,6 +489,19 @@ def pad_bounding_box():
         yield SampleInput(bounding_box, padding=padding, format=bounding_box.format)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def perspective_image_tensor():
+    for image, perspective_coeffs, fill in itertools.product(
+        make_images(extra_dims=((), (4,))),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+        [None, [128], [12.0]],  # fill
+    ):
+        yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def perspective_bounding_box():
     for bounding_box, perspective_coeffs in itertools.product(
@@ -488,6 +533,15 @@ def perspective_segmentation_mask():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def center_crop_image_tensor():
+    for image, output_size in itertools.product(
+        make_images(sizes=((16, 16), (7, 33), (31, 9))),
+        [[4, 3], [42, 70], [4]],  # crop size < image size, crop size > image size, single crop size
+    ):
+        yield SampleInput(image, output_size)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_bounding_box():
     for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
@@ -1181,6 +1235,18 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
         torch.testing.assert_close(output_mask, expected_mask)
 
 
+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
+    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    mask[:, :, 0] = 1
+
+    out_mask = F.horizontal_flip_segmentation_mask(mask)
+
+    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    expected_mask[:, :, -1] = 1
+    torch.testing.assert_close(out_mask, expected_mask)
+
+
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
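
Note (not part of the patch): each generator added above is registered under the name of the kernel it exercises, so the harness can resolve the kernel by name and call it with every yielded SampleInput. Below is a minimal sketch of how such a registry is typically consumed; the KERNEL_INFOS list, the getattr-based lookup, the SampleInput args/kwargs fields, and the smoke test are illustrative assumptions, not the file's actual definitions.

import pytest
import torch
from torchvision.prototype.transforms import functional as F

KERNEL_INFOS = []


def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
    # Assumption: the generator's name ("horizontal_flip_segmentation_mask",
    # "resized_crop_image_tensor", ...) matches a kernel in the prototype
    # functional namespace, so the kernel under test can be looked up by name.
    kernel = getattr(F, sample_inputs_fn.__name__)
    KERNEL_INFOS.append((kernel, sample_inputs_fn))
    return sample_inputs_fn


@pytest.mark.parametrize("kernel, sample_inputs_fn", KERNEL_INFOS)
def test_kernel_runs_on_sample_inputs(kernel, sample_inputs_fn):
    # Assumption: SampleInput bundles positional and keyword arguments for
    # one call to the kernel; the smoke test simply invokes each bundle.
    for sample_input in sample_inputs_fn():
        output = kernel(*sample_input.args, **sample_input.kwargs)
        assert isinstance(output, torch.Tensor)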