diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index de49d8a8bef..6c99720114a 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -346,6 +346,12 @@ def crop_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_segmentation_mask(): + for mask in make_segmentation_masks(): + yield SampleInput(mask) + + @pytest.mark.parametrize( "kernel", [ @@ -915,3 +921,15 @@ def _compute_expected_mask(mask, top_, left_, height_, width_): output_mask = F.crop_segmentation_mask(mask, top, left, height, width) expected_mask = _compute_expected_mask(mask, top, left, height, width) torch.testing.assert_close(output_mask, expected_mask) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): + mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) + mask[:, 0, :] = 1 + + out_mask = F.vertical_flip_segmentation_mask(mask) + + expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) + expected_mask[:, -1, :] = 1 + torch.testing.assert_close(out_mask, expected_mask)