
The matrix multiplication operator can't get correct results on 3090 !! #61890

Open

CrisHY1995 opened this issue Jul 20, 2021 · 4 comments

Labels
module: cuda - Related to torch.cuda, and CUDA support in general
module: tf32 - Related to tf32 data format
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

CrisHY1995 commented Jul 20, 2021

🐛 Bug

The matrix multiplication operators (bmm, matmul, mm) do not return correct results on an RTX 3090.

To Reproduce

Minimal code sample:

import torch

device = torch.device("cuda:0")
cur_mat = torch.eye(3).to(device).unsqueeze(0)
cur_vec = torch.as_tensor([-0.6660, -0.2958,  8.6392]).view(3, 1).to(device).unsqueeze(0)

print("Mat")
print(cur_mat)
print("Vec")
print(cur_vec)
print("----------")
print("CPU: torch.bmm   : ", end="")
print(cur_mat.cpu().bmm(cur_vec.cpu()).view(3))
print("GPU: torch.bmm   : ", end="")
print(cur_mat.bmm(cur_vec).view(3))
print("GPU: torch.matmul: ", end="")
print(cur_mat.matmul(cur_vec).view(3))
print("GPU: torch.mm    : ", end="")
print(cur_mat[0].mm(cur_vec[0]).view(3))

output:

Mat
tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]], device='cuda:0')
Vec
tensor([[[-0.6660],
         [-0.2958],
         [ 8.6392]]], device='cuda:0')
----------
CPU: torch.bmm   : tensor([-0.6660, -0.2958,  8.6392])
GPU: torch.bmm   : tensor([-0.6660, -0.2959,  8.6406], device='cuda:0')
GPU: torch.matmul: tensor([-0.6660, -0.2959,  8.6406], device='cuda:0')
GPU: torch.mm    : tensor([-0.6660, -0.2959,  8.6406], device='cuda:0')

Expected behavior

Because cur_mat is an identity matrix, the output should be identical to cur_vec.

Environment

  • PyTorch Version (e.g., 1.0): 1.9.0+cu111
  • GPU: RTX 3090
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8.5
  • CUDA/cuDNN version: 11.4

Details

PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.19.2
Libc version: glibc-2.27

Python version: 3.8.5 (default, Sep  4 2020, 07:30:14)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-77-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.4.48
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090

Nvidia driver version: 470.42.01
cuDNN version: Probably one of the following:
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.4
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.4
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.4
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.4
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.4
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.4
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] numpydoc==1.1.0
[pip3] pytorch3d==0.4.0
[pip3] torch==1.9.0+cu111
[pip3] torch-tb-profiler==0.2.1
[pip3] torchaudio==0.9.0
[pip3] torchgeometry==0.1.2
[pip3] torchsearchsorted==1.1
[pip3] torchvision==0.10.0+cu111
[conda] blas                      1.0                         mkl    defaults
[conda] cudatoolkit               11.0.221             h6bb024c_0    defaults
[conda] mkl                       2020.2                      256    defaults
[conda] mkl-service               2.3.0            py38he904b0f_0    defaults
[conda] mkl_fft                   1.2.0            py38h23d657b_0    defaults
[conda] mkl_random                1.1.1            py38h0573a6f_0    defaults
[conda] numpy                     1.19.2           py38h54aff64_0    defaults
[conda] numpy-base                1.19.2           py38hfa32c7d_0    defaults
[conda] numpydoc                  1.1.0              pyhd3eb1b0_1    defaults
[conda] pytorch3d                 0.4.0                     dev_0    <develop>
[conda] torch                     1.9.0+cu111              pypi_0    pypi
[conda] torch-tb-profiler         0.2.1                    pypi_0    pypi
[conda] torchaudio                0.9.0                    pypi_0    pypi
[conda] torchgeometry             0.1.2                    pypi_0    pypi
[conda] torchsearchsorted         1.1                      pypi_0    pypi
[conda] torchvision               0.10.0+cu111             pypi_0    pypi

Additional context

Only the 3090 exhibits the problem described above.

Testing shows that PyTorch 1.7.1+cu110 returns the right result for the example above. However, when given a tensor with a large batch size, the bmm operator is also unable to return the right result.

Code sample (test data: tt_dict.pkl):

import torch
import pickle as pkl


device = torch.device("cuda:0")
with open("tt_dict.pkl", "rb") as f:
    tt = pkl.load(f)
    
live_vps = tt["live_vps"].to(device)
batch_Rmats = tt["batch_Rmats"].to(device)

tt = batch_Rmats.permute(0, 2, 1)

print("Output:")
print("-----------")
print("bmm with large batch size:")
rott_vps = torch.bmm(live_vps, tt)
print(rott_vps[2, 0])
print("-----------")
print("bmm with mini batch size:")
print((live_vps[2:3, 0:1]).bmm(tt[2:3]))
print("-----------")
print("mm:")
print((live_vps[2, 0:1]).mm(tt[2]))
print("-----------")
print("cpu:")
print(((live_vps[2, 0:1]).cpu()).mm(tt[2].cpu()))

output:

Output:
-----------
bmm with large batch size:
tensor([-0.6660, -0.2957,  **8.6406**], device='cuda:0')
-----------
bmm with mini batch size:
tensor([[[-0.6660, -0.2958,  **8.6392**]]], device='cuda:0')
-----------
mm:
tensor([[-0.6660, -0.2958,  **8.6392**]], device='cuda:0')
-----------
cpu:
tensor([[-0.6660, -0.2958,  **8.6392**]])

cc @ngimel @zasdfgbnm @ptrblck

ejguan added the module: correctness (silent), module: cuda, and triaged labels on Jul 20, 2021
ngimel (Collaborator) commented Jul 20, 2021

This is because of TF32; you can disable it if you need exact results: https://pytorch.org/docs/stable/backends.html?highlight=tf32#torch.backends.cuda.matmul.allow_tf32
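
A minimal sketch of the fix, assuming PyTorch ≥ 1.7 on an Ampere GPU; it also checks why 8.6392 comes back as 8.6406 (TF32 keeps only 10 explicit mantissa bits, and the nearest representable value is 8.640625):

import math
import torch

# Disable TF32 for matmuls (and, if needed, for cuDNN convolutions)
# to get full-fp32 results on Ampere GPUs.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Round 8.6392 to the nearest value with 10 explicit mantissa bits:
# the spacing between representable values is 2**(exponent - 11).
x = 8.6392
_, e = math.frexp(x)          # x = m * 2**e with 0.5 <= m < 1
ulp = 2.0 ** (e - 11)
print(round(x / ulp) * ulp)   # 8.640625 -> printed as 8.6406 at 4 decimals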

ngimel added the module: tf32 label and removed the module: correctness (silent) label on Jul 20, 2021
CrisHY1995 (Author) commented

> This is because of TF32; you can disable it if you need exact results: https://pytorch.org/docs/stable/backends.html?highlight=tf32#torch.backends.cuda.matmul.allow_tf32

Thank you for the prompt reply; the solution works.
But should this option really be enabled by default in future versions? This problem could quietly degrade many parts of many projects.

Here, I just want to rotate a face mesh, but the deformed result looks pretty bad.

[Screenshot: the rotated face mesh, visibly distorted]
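
As a rough, hedged illustration of how TF32 can distort geometry like this (random mock vertices and an arbitrary z-rotation, not the actual mesh; whether cuBLAS picks a TF32 kernel can depend on the GEMM shape):

import math
import torch

torch.manual_seed(0)
pts = torch.randn(100_000, 3, device="cuda") * 10   # mock mesh vertices
c, s = math.cos(0.5), math.sin(0.5)
R = torch.tensor([[c, -s, 0.0],                     # rotation about the z axis
                  [s,  c, 0.0],
                  [0.0, 0.0, 1.0]], device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True
approx = pts @ R.T                                  # TF32 path
torch.backends.cuda.matmul.allow_tf32 = False
exact = pts @ R.T                                   # full-fp32 path

print((approx - exact).abs().max())                 # worst per-coordinate error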

A related question: does this (torch.backends.cuda.matmul.allow_tf32 = False) have a big impact on computing speed?

ngimel (Collaborator) commented Jul 22, 2021

Yes, it does have a large impact on computing speed, as shown on the linked documentation page https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere (0.11s -> 0.016s for a large matrix, and approximately a 2x end-to-end speed-up for common workflows). In our experience it doesn't affect the convergence of neural networks that rely on it, although you are right that if you use it for other transforms the distortions are significant.
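
For anyone who wants to measure the tradeoff on their own hardware, a minimal timing sketch (the matrix size and iteration count here are arbitrary choices, not the ones behind the numbers above):

import time
import torch

def bench(allow_tf32, n=8192, iters=30):
    # Time an n x n fp32 matmul with TF32 on or off.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    a = torch.randn(n, n, device="cuda")
    b = torch.randn(n, n, device="cuda")
    a @ b                            # warm-up so kernel selection is done
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()         # wait for the queued matmuls to finish
    return (time.time() - t0) / iters

print("TF32 on :", bench(True))
print("TF32 off:", bench(False))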

ssnl (Collaborator) commented Aug 26, 2021

This raises the question of whether TF32 should be globally disabled by default but locally enabled inside PyTorch ops (linear layers, convolutions, RNNs, etc.).
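
A user-level approximation of that idea is possible today with a context manager that flips the matmul flag around selected ops; tf32_allowed below is a hypothetical helper, not an existing PyTorch API:

from contextlib import contextmanager

import torch

@contextmanager
def tf32_allowed(enabled=True):
    # Temporarily toggle TF32 for matmuls, restoring the previous setting.
    prev = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = enabled
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev

# Exact fp32 globally (e.g. for geometry); opt in to TF32 where speed matters.
torch.backends.cuda.matmul.allow_tf32 = False
with tf32_allowed(True):
    ...  # run the network's large matmuls here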
