Description
🐛 Describe the bug
Simple error with a simple fix.
vision/torchvision/transforms/autoaugment.py, line 226 (commit 48ebc0b):
    "Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
256 cannot be represented as a uint8 (the maximum is 255), so the range should start at 255.
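To see why only the first bin is a problem, here is a small pure-Python sketch (no torch needed) that approximates the evenly spaced values `torch.linspace(start, 0.0, num_bins)` produces and checks each one against the uint8 range [0, 255]. The helpers `magnitude_bins` and `fits_uint8` are hypothetical, and `num_bins=10` is taken from the reproduction below rather than from the library's defaults.

```python
from typing import List


def magnitude_bins(start: float, end: float, num_bins: int) -> List[float]:
    """Evenly spaced values from start to end inclusive (approximates torch.linspace)."""
    if num_bins == 1:
        return [start]
    # Multiply before dividing so both endpoints come out exact.
    return [start + (end - start) * i / (num_bins - 1) for i in range(num_bins)]


def fits_uint8(value: float) -> bool:
    """True if value lies within the representable uint8 range [0, 255]."""
    return 0.0 <= value <= 255.0


# Current definition: the first bin is 256.0, one past the uint8 maximum.
current = magnitude_bins(256.0, 0.0, 10)
# Proposed fix: start the range at 255.0 instead.
fixed = magnitude_bins(255.0, 0.0, 10)

print([v for v in current if not fits_uint8(v)])  # [256.0]
print(all(fits_uint8(v) for v in fixed))          # True
```

Every other bin of the current range is already in [0, 255], which is why the error only shows up when the strongest Solarize magnitude is sampled.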
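How to reproduce:
import torch
import torchvision.transforms as T

# Same magnitude bins that autoaugment.py builds for "Solarize"; the first bin is 256.0
aug_space = torch.linspace(256.0, 0.0, 10)
# Random uint8 image
image = torch.randint(0, 225, (3, 512, 512), dtype=torch.uint8)
# Raises RuntimeError because 256 cannot be converted to uint8
T.functional.solarize(image, aug_space[0])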
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
my_awesome_network.py in <module>
1 image = torch.randint(0, 225, (3, 512, 512), dtype=torch.uint8)
2
----> 3 T.functional.solarize(image, aug_space[0])
path/to/torchvision/transforms/functional.py in solarize(img, threshold)
1318 return F_pil.solarize(img, threshold)
1319
-> 1320 return F_t.solarize(img, threshold)
1321
1322
path/to/torchvision/torchvision/transforms/functional_tensor.py in solarize(img, threshold)
872
873 inverted_img = invert(img)
--> 874 return torch.where(img >= threshold, inverted_img, img)
875
876
RuntimeError: value cannot be converted to type uint8_t without overflow: 256
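The failing line in `functional_tensor.py` compares every pixel with `threshold` and swaps in the inverted pixel wherever `img >= threshold`. Below is a pure-Python sketch of that rule for a flat list of uint8 values. `solarize_uint8` is a hypothetical helper, inversion is assumed to be `255 - x` (what `invert` computes for uint8 tensors), and the explicit range check is only a stand-in for the overflow error torch raises, not torch's actual code path.

```python
from typing import List

UINT8_MAX = 255


def solarize_uint8(pixels: List[int], threshold: float) -> List[int]:
    """Invert every pixel >= threshold, mirroring torch.where(img >= threshold, inverted_img, img)."""
    # Stand-in for torch's overflow check: a threshold like 256 has no uint8 representation.
    if not 0 <= threshold <= UINT8_MAX:
        raise ValueError(f"threshold {threshold} does not fit in uint8")
    return [UINT8_MAX - p if p >= threshold else p for p in pixels]


pixels = [0, 100, 200, 255]
print(solarize_uint8(pixels, 255.0))  # [0, 100, 200, 0]
print(solarize_uint8(pixels, 0.0))    # [255, 155, 55, 0]
```

With the range starting at 255, the strongest bin inverts only pixels equal to 255, and the last bin (0) inverts every pixel.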
Versions
Collecting environment information...
PyTorch version: 1.10.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.33
Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.11.0-38-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: Tesla V100-PCIE-16GB
Nvidia driver version: 470.63.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.21.3
[pip3] pytorch-lightning==1.4.9
[pip3] torch==1.10.0+cu113
[pip3] torchmetrics==0.6.0
[pip3] torchvision==0.11.1+cu113