Apparent difference between .to(dtype=torch.float32) and .float() for TF32 #63951

@Randl


🐛 Bug

For some reason, if we convert a tensor to float32 with .float(), calculations are performed in FP32 rather than TF32, even when the latter is enabled.

To Reproduce

Run the following code, based on the TF32 guide from the CUDA notes:

import torch
import time

a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
mean = ab_full.abs().mean() # 80.7277

a = a_full.float()
b = b_full.float()

# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
torch.backends.cuda.matmul.allow_tf32 = True 
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

torch.cuda.synchronize()
# Do matmul at TF32 mode.
start = time.time()
ab_tf32 = a @ b # takes 0.016s on GA100
torch.cuda.synchronize()
end = time.time()
print(int((end - start) * 1000))
error = (ab_tf32 - ab_full).abs().max() # 0.1747
relative_error = error / mean # 0.0022

# Disable TF32 so the matmul below runs in true FP32.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

torch.cuda.synchronize()
# Do matmul at FP32 mode.
start = time.time()
ab_fp32 = a @ b # takes 0.11s on GA100
torch.cuda.synchronize()
end = time.time()
print(int((end - start) * 1000))
error = (ab_fp32 - ab_full).abs().max() # 0.0031
relative_error = error / mean # 0.000039

Both printed times are nearly equal (approx. 110 ms on an A6000), i.e. the first matmul does not appear to run in TF32 mode.
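Timing alone can be noisy, so it may also help to check which path was taken by accuracy: per the guide's numbers, a TF32 matmul shows roughly 2e-3 relative error against the FP64 reference, while a true FP32 matmul shows roughly 4e-5. A minimal sketch, meant to be appended to the repro above and reusing its ab_tf32, ab_fp32, ab_full, and mean; the looks_like_tf32 helper and its threshold are rough assumptions of mine, not part of the original report:

# Hypothetical helper: classify a result as TF32-like or FP32-like by its
# relative error against the FP64 reference.
def looks_like_tf32(result, reference, ref_mean):
    rel_err = ((result - reference).abs().max() / ref_mean).item()
    # Rough assumed threshold: ~2e-3 relative error for TF32 vs. ~4e-5 for FP32.
    return rel_err > 1e-4

print("first matmul looks like TF32:", looks_like_tf32(ab_tf32, ab_full, mean))
print("second matmul looks like TF32:", looks_like_tf32(ab_fp32, ab_full, mean))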

If we replace

a = a_full.float()
b = b_full.float()

with

a = a_full.to(dtype=torch.float32)
b = b_full.to(dtype=torch.float32)

we get the expected result, i.e., the first calculation is faster (approx. 33 ms on an A6000).
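For completeness, a self-contained sketch that exercises both conversion paths back to back with TF32 enabled; the time_matmul helper and the loop structure are mine and not part of the original repro:

import time
import torch

def time_matmul(x, y):
    # Synchronize before and after so the wall-clock time covers only this matmul.
    torch.cuda.synchronize()
    start = time.time()
    out = x @ y
    torch.cuda.synchronize()
    return out, (time.time() - start) * 1000  # milliseconds

torch.backends.cuda.matmul.allow_tf32 = True

a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')

for name, convert in [('.float()', lambda t: t.float()),
                      ('.to(dtype=torch.float32)', lambda t: t.to(dtype=torch.float32))]:
    _, ms = time_matmul(convert(a_full), convert(b_full))
    print(f'{name}: {ms:.0f} ms')  # reportedly ~110 ms vs. ~33 ms on an A6000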

Expected behavior

Both .float() and .to(dtype=torch.float32) should make use of TF32 when it is enabled.
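In other words, with TF32 enabled both conversion paths should produce results with the same TF32-level error against the FP64 reference. A minimal check, reusing a_full, b_full, and ab_full from the repro above (the printed comparison, not an API from the report):

err_float = ((a_full.float() @ b_full.float()) - ab_full).abs().max()
err_to = ((a_full.to(dtype=torch.float32) @ b_full.to(dtype=torch.float32)) - ab_full).abs().max()
# Expected with TF32 enabled: both errors are of the same, TF32-level order of
# magnitude (~0.17 for this setup); presumably the .float() path would instead
# show FP32-level error (~0.003) under the reported behavior.
print(err_float.item(), err_to.item())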

Environment

PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Libc version: glibc-2.25

Python version: 3.6.9 (default, Jan 26 2021, 15:33:00) [GCC 8.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-81-generic-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000

Nvidia driver version: 470.57.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.9.0+cu111
[pip3] torchaudio==0.9.0
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect

Additional context

Others have confirmed a similar effect on other GPUs, e.g., the A100.

cc @zasdfgbnm @ptrblck

Labels: module: tf32, needs reproduction, triaged
