diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 24a7523b62a..40a51775a09 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -795,6 +795,36 @@ def test_solarize2(device, dtype, config, channels): ) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0]) +def test_solarize_threshold1_bound(threshold, device): + img = torch.rand((3, 12, 23)).to(device) + F_t.solarize(img, threshold) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("threshold", [1.5]) +def test_solarize_threshold1_upper_bound(threshold, device): + img = torch.rand((3, 12, 23)).to(device) + with pytest.raises(TypeError, match="Threshold should be less than bound of img."): + F_t.solarize(img, threshold) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255]) +def test_solarize_threshold2_bound(threshold, device): + img = torch.randint(0, 256, (3, 12, 23)).to(device) + F_t.solarize(img, threshold) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("threshold", [260]) +def test_solarize_threshold2_upper_bound(threshold, device): + img = torch.randint(0, 256, (3, 12, 23)).to(device) + with pytest.raises(TypeError, match="Threshold should be less than bound of img."): + F_t.solarize(img, threshold) + + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 09ae726931c..dfb32a41adf 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -16,6 +16,12 @@ def _assert_image_tensor(img: Tensor) -> None: raise TypeError("Tensor is not a torch image.") +def _assert_threshold(img: Tensor, threshold: float) -> None: + bound = 1 if img.is_floating_point() else 255 + if threshold > bound: + raise TypeError("Threshold should be less than bound of img.") + + def get_image_size(img: Tensor) -> List[int]: # Returns (w, h) of tensor image _assert_image_tensor(img) @@ -882,6 +888,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor: _assert_channels(img, [1, 3]) + _assert_threshold(img, threshold) + inverted_img = invert(img) return torch.where(img >= threshold, inverted_img, img)