Description
🐛 Bug
I have two seemingly identical bool tensors that both print as tensor([True, True, True, True]). Calling .float() on one of them produces tensor([1., 1., 1., 1.]), but calling it on the other produces tensor([255., 255., 255., 255.]).
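The two results are consistent with the conversion reading the byte stored in each bool rather than its truth value. The NumPy-only sketch below models that hypothesis on a few raw byte values. It does not call PyTorch, and the names raw, expected and raw_byte_cast are illustrative, not taken from the report.

```python
import numpy as np

# Illustrative byte values that could sit in the storage of a bool element.
raw = np.array([0x00, 0x01, 0x02, 0x7F, 0xFF], dtype=np.uint8)

# Expected semantics: a bool converts to 1.0 if it is True (any nonzero byte), else 0.0.
expected = (raw != 0).astype(np.float32)

# Behaviour consistent with the report: the stored byte is converted as a number.
raw_byte_cast = raw.astype(np.float32)

print(expected)
print(raw_byte_cast)
```

Only the first mapping matches the clean tensor for every input; the second reproduces the 255.0 seen for the poisoned tensor.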
To Reproduce
Steps to reproduce the behavior:
import numpy as np
import torch
# these arrays hold the same logical values and print the same, but their underlying bytes differ (0xff vs 0x01), so they behave subtly differently
poisoned = np.frombuffer(b'\xff\xff\xff\xff', dtype=np.bool_) # every byte is 0xff == 255
clean = np.frombuffer(b'\x01\x01\x01\x01', dtype=np.bool_) # every byte is 0x01 == 1
torch_poisoned = torch.from_numpy(poisoned)
torch_clean = torch.from_numpy(clean)
print('poisoned == clean -->', (poisoned == clean).all()) # >>> True
print('torch_poisoned == torch_clean --> ', (torch_poisoned == torch_clean).all()) # >>> tensor(False)
print('torch_poisoned.dtype == torch_clean.dtype --> ', (torch_poisoned.dtype == torch_clean.dtype)) # >>> True
print('torch_poisoned.dtype, torch_clean.dtype --> ', torch_poisoned.dtype, torch_clean.dtype) # >>> torch.bool torch.bool
print('torch_clean: ',torch_clean) # >>> tensor([True, True, True, True])
print('torch_clean.float(): ', torch_clean.float()) # >>> tensor([1., 1., 1., 1.])
print('torch_poisoned: ', torch_poisoned) # >>> tensor([True, True, True, True])
print('torch_poisoned.float(): ', torch_poisoned.float()) # >>> tensor([255., 255., 255., 255.])
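One way to confirm that the two arrays differ only in their storage is to view the same memory as uint8. This sketch uses only NumPy and the same byte strings as the reproduction above.

```python
import numpy as np

# Same construction as the reproduction above.
poisoned = np.frombuffer(b'\xff\xff\xff\xff', dtype=np.bool_)
clean = np.frombuffer(b'\x01\x01\x01\x01', dtype=np.bool_)

# Reinterpret the same memory as unsigned bytes (no copy is made).
poisoned_bytes = poisoned.view(np.uint8)
clean_bytes = clean.view(np.uint8)

print(poisoned_bytes)  # [255 255 255 255]
print(clean_bytes)     # [1 1 1 1]
```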
Expected behavior
Anything that torch treats as a True bool should convert to the float value 1.0, not to whatever byte value happens to be stored in the bool.
This could just be an abuse of bool, but I ran into this bug when converting a boolean image with PIL. PyTorch should probably prevent users from shooting themselves in the foot like this.
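Until the conversion itself behaves as expected, one possible workaround, assuming the data arrives as a NumPy bool array (for example from an image conversion), is to rebuild the array from a comparison against zero before passing it to torch.from_numpy. canonicalize_bool is a hypothetical helper, not a NumPy or PyTorch API.

```python
import numpy as np

def canonicalize_bool(arr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return a copy of a bool array whose bytes are all 0x00 or 0x01."""
    if arr.dtype != np.bool_:
        raise TypeError(f"expected a bool array, got dtype {arr.dtype}")
    # bool and uint8 are both 1 byte wide, so this view works for any strides.
    # Comparing against zero maps every nonzero byte (e.g. 0xff) to a canonical True
    # and writes the result into a newly allocated array.
    return arr.view(np.uint8) != 0


poisoned = np.frombuffer(b'\xff\xff\xff\xff', dtype=np.bool_)
fixed = canonicalize_bool(poisoned)
print(fixed.view(np.uint8))  # [1 1 1 1]
# torch.from_numpy(fixed) would now hold the same bytes as the report's clean tensor.
```

Because the comparison allocates a new array, the result is also writable, unlike the read-only array returned by np.frombuffer.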
Environment
Collecting environment information...
PyTorch version: 1.7.1+cu110
Is debug build: False
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.15.3
Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.2.142
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 455.45.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.7.1+cu110
[pip3] torchaudio==0.7.2
[pip3] torchgeometry==0.1.2
[pip3] torchvision==0.8.2+cu110
[conda] Could not collect
Additional context
cc @ezyang @gchanan @zou3519 @mruberry @rgommers @heitorschueroff