Verify the threshold input for solarize() and add a test #4818

@datumbox

🚀 The feature

The solarize() function receives a threshold that is currently not validated against the image dtype:

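```python
import torch
from torch import Tensor

# _assert_image_tensor, _assert_channels and invert are helpers defined in the same torchvision module.
def solarize(img: Tensor, threshold: float) -> Tensor:
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])
    # Note: threshold is never compared against the valid pixel range of img.dtype.
    inverted_img = invert(img)
    return torch.where(img >= threshold, inverted_img, img)
```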

Ideally we should validate it. Its upper bound depends on the image dtype (1.0 for float images and 255 for uint8), computed similarly to:

```python
bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
```
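A minimal sketch of what the check could look like. It assumes a `ValueError` is the desired failure mode (the exception type is not decided in this issue), and uses a hypothetical `solarize_checked` name so it does not shadow the real function. `_max_value` and `_invert` are hypothetical stand-ins: the latter mirrors `invert()` by subtracting the image from the same bound shown above, and the dimension and channel checks are kept in simplified form.

```python
import torch
from torch import Tensor


def _max_value(img: Tensor) -> float:
    # Hypothetical helper: upper bound of the valid pixel range (1.0 for float, 255 for uint8).
    return 1.0 if img.is_floating_point() else 255.0


def _invert(img: Tensor) -> Tensor:
    # Hypothetical stand-in for invert(): subtract every pixel from the dtype's bound.
    bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
    return bound - img


def solarize_checked(img: Tensor, threshold: float) -> Tensor:
    # Hypothetical variant of solarize() that validates the threshold before solarizing.
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.shape[-3] not in (1, 3):
        raise TypeError(f"Input image tensor should have 1 or 3 channels, but found {img.shape[-3]}")
    bound = _max_value(img)
    if threshold > bound:
        # Without this check, img >= threshold is never true and the call silently returns img.
        raise ValueError(f"Threshold should be at most {bound} for dtype {img.dtype}, but got {threshold}")
    return torch.where(img >= threshold, _invert(img), img)
```

With a uint8 image, a pixel of 200 and `threshold=128` becomes 55 (255 - 200), while calling it with `threshold=128` on a float image in [0, 1] raises instead of silently returning the input unchanged.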

Along with the assertion, we should add a unit test to ensure that a threshold above the permitted value raises an error.
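A pytest-style sketch of such a test, again assuming `ValueError` as the error type. To stay self-contained it exercises a compact, hypothetical `solarize_checked` stand-in that keeps only the threshold check, rather than torchvision's real `solarize()`; in the repository the test would call the library function instead. It covers both dtypes and also checks that a threshold exactly at the bound is still accepted.

```python
import pytest
import torch


def solarize_checked(img: torch.Tensor, threshold: float) -> torch.Tensor:
    # Hypothetical stand-in for solarize() with the proposed threshold check.
    # An int bound for uint8 keeps `bound - img` in uint8; a float bound is used for float images.
    bound = 1.0 if img.is_floating_point() else 255
    if threshold > bound:
        raise ValueError(f"Threshold should be at most {bound}, but got {threshold}")
    return torch.where(img >= threshold, bound - img, img)


@pytest.mark.parametrize("dtype, threshold", [(torch.uint8, 256), (torch.float32, 1.5)])
def test_solarize_threshold_above_bound_raises(dtype, threshold):
    img = torch.zeros(3, 4, 4, dtype=dtype)
    with pytest.raises(ValueError, match="Threshold"):
        solarize_checked(img, threshold)


@pytest.mark.parametrize("dtype, threshold", [(torch.uint8, 255), (torch.float32, 1.0)])
def test_solarize_threshold_at_bound_is_accepted(dtype, threshold):
    img = torch.zeros(3, 4, 4, dtype=dtype)
    # A threshold equal to the bound is valid; a black image stays unchanged.
    assert torch.equal(solarize_checked(img, threshold), img)
```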

Motivation, pitch

Because of the missing assertion, we didn't spot a bug in the configuration of the AutoAugment methods. See #4805.

cc @vfdev-5 @datumbox @pmeier