Description
🐛 Bug
For reasons that are unclear, if a tensor is converted to float32 with .float(), subsequent matmul calculations run in full FP32 rather than TF32, even when TF32 is enabled.
To Reproduce
Run the following code, which is based on the TF32 guide:
import torch
import time
a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
mean = ab_full.abs().mean() # 80.7277
a = a_full.float()
b = b_full.float()
# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
torch.cuda.synchronize()
# Do matmul in TF32 mode.
start = time.time()
ab_tf32 = a @ b # takes 0.016s on GA100
torch.cuda.synchronize()
end = time.time()
print(int((end - start) * 1000)) # elapsed time in milliseconds
error = (ab_tf32 - ab_full).abs().max() # 0.1747
relative_error = error / mean # 0.0022
# Disable TF32 for matmul and cuDNN so that the next matmul runs in full FP32.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Wait for the pending error computation so it is not included in the timing.
torch.cuda.synchronize()
# Do matmul in FP32 mode.
start = time.time()
ab_fp32 = a @ b # takes 0.11s on GA100
torch.cuda.synchronize()
end = time.time()
print(int((end - start) * 1000)) # elapsed time in milliseconds
error = (ab_fp32 - ab_full).abs().max() # 0.0031
relative_error = error / mean # 0.000039
Both printed times are equal or almost equal (approx. 110 ms on an A6000), i.e., the first matmul shows no TF32 speedup.
If we replace
a = a_full.float()
b = b_full.float()
with
a = a_full.to(dtype=torch.float32)
b = b_full.to(dtype=torch.float32)
we get the expected result, i.e., the first (TF32) calculation is faster (approx. 33 ms on an A6000).
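To make the comparison easier to repeat, the two conversions can be measured side by side in one run. The following is only a sketch, not part of the original reproduction: the helper names time_ms and compare_conversions are hypothetical, and the warm-up run and best-of-N timing are additions that the script above does not have. Only the matmul flag is toggled, since no cuDNN operation is involved. It records both the elapsed time and the relative error against the float64 result, because either signal should indicate whether TF32 was actually used.

```python
import time

try:
    import torch
except ImportError:  # the timing helper itself does not need PyTorch
    torch = None


def time_ms(fn, sync=lambda: None, warmup=1, repeats=3):
    """Return the best wall-clock time of fn() in milliseconds.

    sync() is called right before the clock starts and right before it stops,
    so asynchronous work (e.g. CUDA kernels) is fully included.
    """
    for _ in range(warmup):
        fn()
    best = float("inf")
    for _ in range(repeats):
        sync()
        start = time.perf_counter()
        fn()
        sync()
        best = min(best, (time.perf_counter() - start) * 1000.0)
    return best


def compare_conversions(n=10240):
    """Measure a @ b for both conversions, with TF32 allowed and disallowed.

    Returns {(conversion_name, allow_tf32): (milliseconds, relative_error)}.
    """
    a_full = torch.randn(n, n, dtype=torch.double, device="cuda")
    b_full = torch.randn(n, n, dtype=torch.double, device="cuda")
    ab_full = a_full @ b_full  # float64 reference result
    mean = ab_full.abs().mean()
    inputs = {
        ".float()": (a_full.float(), b_full.float()),
        ".to(dtype=torch.float32)": (
            a_full.to(dtype=torch.float32),
            b_full.to(dtype=torch.float32),
        ),
    }
    results = {}
    for allow_tf32 in (True, False):
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
        for name, (a, b) in inputs.items():
            ms = time_ms(lambda: a @ b, sync=torch.cuda.synchronize)
            rel_err = ((a @ b) - ab_full).abs().max() / mean
            results[(name, allow_tf32)] = (ms, rel_err.item())
    return results


if __name__ == "__main__":
    if torch is not None and torch.cuda.is_available():
        for (name, allow_tf32), (ms, err) in compare_conversions().items():
            print(f"{name:26s} allow_tf32={allow_tf32}: {ms:8.1f} ms, rel. error {err:.6f}")
    else:
        print("CUDA-enabled PyTorch is not available; nothing to measure.")
```

If the behavior reported here is present, the .float() row with allow_tf32=True would be expected to show roughly the same time and error as the allow_tf32=False rows, while the .to(dtype=torch.float32) row would be faster and less accurate.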
Expected behavior
Both .float() and .to(dtype=torch.float32) should make use of TF32 when it is enabled.
Environment
PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Libc version: glibc-2.25
Python version: 3.6.9 (default, Jan 26 2021, 15:33:00) [GCC 8.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-81-generic-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000
Nvidia driver version: 470.57.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.9.0+cu111
[pip3] torchaudio==0.9.0
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect
Additional context
Others have confirmed a similar effect on other GPUs, e.g., the A100.