Closed
Description
🚀 The feature
The `solarize()` method receives a `threshold`, which is currently not validated against the image type:
vision/torchvision/transforms/functional_tensor.py
Lines 876 to 886 in d367a01
```python
def solarize(img: Tensor, threshold: float) -> Tensor:
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])
    inverted_img = invert(img)
    return torch.where(img >= threshold, inverted_img, img)
```
Ideally we should assert it. The valid upper bound depends on the image type (1.0 for float and 255 for uint8) and can be computed similarly to:
```python
bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
```
Along with the assertion, we should add a unit test to ensure that a threshold above the permitted value raises an error.
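A minimal sketch of the proposed check together with a matching test, assuming a standalone reimplementation rather than a patch to torchvision: `solarize_checked` and the test name are hypothetical. The private helpers `_assert_image_tensor`, `_assert_channels` and `invert` are replaced by inline equivalents. Raising `ValueError` is our choice; the existing checks in `solarize()` raise `TypeError`.

```python
import torch
from torch import Tensor


def solarize_checked(img: Tensor, threshold: float) -> Tensor:
    """Hypothetical solarize() variant that rejects out-of-range thresholds."""
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.shape[-3] not in (1, 3):
        raise TypeError(f"Input image tensor should have 1 or 3 channels, but found {img.shape[-3]}")
    # Upper bound of the pixel range: 1 for float images, 255 for uint8.
    bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
    if threshold > bound.item():
        raise ValueError(f"Threshold should be at most {bound.item()} for {img.dtype} images, but got {threshold}")
    # Inversion maps each pixel p to bound - p.
    inverted_img = bound - img
    return torch.where(img >= threshold, inverted_img, img)


def test_solarize_threshold_above_bound_fails():
    for dtype, max_value in ((torch.uint8, 255), (torch.float32, 1.0)):
        img = torch.zeros(3, 4, 4, dtype=dtype)
        # A threshold exactly at the bound is still valid...
        solarize_checked(img, max_value)
        # ...but anything above it must raise.
        try:
            solarize_checked(img, max_value + 1)
        except ValueError:
            continue
        raise AssertionError(f"threshold {max_value + 1} was accepted for {dtype}")
```

For example, a `uint8` image with pixels `[0, 100, 200]` and `threshold=128` becomes `[0, 100, 55]`, while `threshold=128` on a float image in `[0, 1]` raises. In torchvision's own pytest-based suite, the loop would more naturally be written with `pytest.mark.parametrize` over dtypes and `pytest.raises`.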
Motivation, pitch
Because of the missing assertion, we didn't spot a bug in the configuration of the AutoAugment methods. See #4805.