Skip to content
30 changes: 30 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,36 @@ def test_solarize2(device, dtype, config, channels):
)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0])
def test_solarize_threshold1_bound(threshold, device):
img = torch.rand((3, 12, 23)).to(device)
F_t.solarize(img, threshold)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [1.5])
def test_solarize_threshold1_upper_bound(threshold, device):
img = torch.rand((3, 12, 23)).to(device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255])
def test_solarize_threshold2_bound(threshold, device):
img = torch.randint(0, 256, (3, 12, 23)).to(device)
F_t.solarize(img, threshold)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [260])
def test_solarize_threshold2_upper_bound(threshold, device):
img = torch.randint(0, 256, (3, 12, 23)).to(device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
Expand Down
8 changes: 8 additions & 0 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def _assert_image_tensor(img: Tensor) -> None:
raise TypeError("Tensor is not a torch image.")


def _assert_threshold(img: Tensor, threshold: float) -> None:
bound = 1 if img.is_floating_point() else 255
if threshold > bound:
raise TypeError("Threshold should be less than bound of img.")


def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
Expand Down Expand Up @@ -882,6 +888,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:

_assert_channels(img, [1, 3])

_assert_threshold(img, threshold)

inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)

Expand Down