Conversion from bool to float sometimes produces tensor with values 255.0 for True instead of 1.0 #54789

@matwilso

🐛 Bug

I have two seemingly identical bool tensors that both print as tensor([True, True, True, True]). Converting one of them with .float() produces tensor([1., 1., 1., 1.]), while converting the other produces tensor([255., 255., 255., 255.]).

To Reproduce

Steps to reproduce the behavior:

import numpy as np
import torch
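# both arrays read as all-True and look the same when you load them, but they are backed by different bytes and behave subtly differently
# np.bool_ is used instead of the np.bool alias, which newer NumPy releases no longer provide
poisoned = np.frombuffer(b'\xff\xff\xff\xff', dtype=np.bool_)  # each byte is 0xff (255)
clean = np.frombuffer(b'\x01\x01\x01\x01', dtype=np.bool_)  # each byte is 0x01 (1)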

torch_poisoned = torch.from_numpy(poisoned)
torch_clean = torch.from_numpy(clean)


print('poisoned == clean -->', (poisoned == clean).all())  # >>> True
print('torch_poisoned == torch_clean --> ', (torch_poisoned == torch_clean).all())  # >>> tensor(False)

print('torch_poisoned.dtype == torch_clean.dtype --> ', (torch_poisoned.dtype == torch_clean.dtype))  # >>> True
print('torch_poisoned.dtype, torch_clean.dtype --> ', torch_poisoned.dtype, torch_clean.dtype)  # >>> torch.bool, torch.bool

print('torch_clean: ',torch_clean)  # >>> tensor([True, True, True, True])
print('torch_clean.float(): ', torch_clean.float())  # >>> tensor([1., 1., 1., 1.])

print('torch_poisoned: ', torch_poisoned)  # >>> tensor([True, True, True, True])
print('torch_poisoned.float(): ', torch_poisoned.float())  # >>> tensor([255., 255., 255., 255.])
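
To make the difference between the two arrays visible, here is a small NumPy-only sketch that reinterprets the same memory as unsigned bytes. It only shows what is stored in the buffers; since torch.from_numpy shares memory with the array, the tensor sees the same bytes, which presumably explains the 255. values (that last step is my assumption, I have not looked at the conversion code).

```python
import numpy as np

# Same buffers as in the reproduction above.
poisoned = np.frombuffer(b'\xff\xff\xff\xff', dtype=np.bool_)
clean = np.frombuffer(b'\x01\x01\x01\x01', dtype=np.bool_)

# Reinterpret the underlying memory as uint8 (no copy, no value conversion).
poisoned_bytes = poisoned.view(np.uint8)
clean_bytes = clean.view(np.uint8)

print(poisoned_bytes)  # [255 255 255 255]
print(clean_bytes)     # [1 1 1 1]

# NumPy still treats both arrays as all-True.
print(poisoned.all(), clean.all())  # True True
```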

Expected behavior

Anything that torch calls a True bool should be converted to a float value of 1.0, not to whatever byte happened to be stored inside the bool.

This could just be an abuse of bool, but I came across this bug when converting a boolean image using PIL. PyTorch should probably prevent users from shooting themselves in the foot like this.
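
In the meantime, a possible user-side workaround is to rebuild the bool array so that every True is stored as the byte 0x01 before handing it to torch. This is only a sketch: normalize_bool is a hypothetical helper name of mine, not a NumPy or PyTorch function, and I have only reasoned about the NumPy side of it.

```python
import numpy as np

def normalize_bool(arr):
    """Return a new bool array in which every True is stored as byte 0x01.

    Hypothetical helper for illustration; not part of NumPy or PyTorch.
    """
    if arr.dtype != np.bool_:
        raise TypeError("expected a bool array")
    # Compare the raw bytes against zero. The comparison allocates a fresh
    # bool array whose True entries are written as 1.
    return arr.view(np.uint8) != 0

poisoned = np.frombuffer(b'\xff\xff\xff\xff', dtype=np.bool_)
safe = normalize_bool(poisoned)

print(safe)                 # [ True  True  True  True]
print(safe.view(np.uint8))  # [1 1 1 1]
```

Passing safe to torch.from_numpy should then give a tensor backed by 0x01 bytes, the same as the clean case in the reproduction. The result is also a fresh, writable array, unlike the read-only buffer returned by np.frombuffer.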

Environment

Collecting environment information...
PyTorch version: 1.7.1+cu110
Is debug build: False
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.15.3

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.2.142
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 455.45.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.7.1+cu110
[pip3] torchaudio==0.7.2
[pip3] torchgeometry==0.1.2
[pip3] torchvision==0.8.2+cu110
[conda] Could not collect

Additional context

cc @ezyang @gchanan @zou3519 @mruberry @rgommers @heitorschueroff

Labels

- actionable
- high priority
- module: boolean tensor
- module: correctness (silent). Issue that returns an incorrect result silently.
- module: numpy. Related to numpy support, and also numpy compatibility of our operators.
- triaged. This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.
