The return of torch.inverse sometimes contains NaN #47272

Closed
Lmy0217 opened this issue Nov 3, 2020 · 17 comments
Assignees
Labels
module: correctness (silent) - issue that returns an incorrect result silently
module: linear algebra - Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul
module: NaNs and Infs - Problems related to NaN and Inf handling in floating point
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Lmy0217

Lmy0217 commented Nov 3, 2020

🐛 Bug

The return of torch.inverse sometimes contains NaN.

AssertionError: tensor([[[    nan, -0.8982, -0.0000],
         [    nan, -0.4397,  0.0000],
         [    nan,  0.0000,  1.0000]],

        [[-0.4397, -0.8982, -0.0000],
         [ 0.8982, -0.4397,  0.0000],
         [ 0.0000,  0.0000,  1.0000]]], device='cuda:0')
84

To Reproduce

Steps to reproduce the behavior:

```python
import torch

device = torch.device('cuda:0')

d = torch.tensor([
    [[-0.4397,  0.8981,  0.0000], [-0.8981, -0.4397,  0.0000], [ 0.0000,  0.0000,  1.0000]],
    [[-0.4397,  0.8981,  0.0000], [-0.8981, -0.4397,  0.0000], [ 0.0000,  0.0000,  1.0000]]
], device=device)

count = 0
while True:
    temp = torch.inverse(d)
    count = count + 1
    # Fails as soon as the batched inverse returns NaN; the assertion message
    # shows the bad result and the number of iterations it took.
    assert not torch.isnan(temp).any(), str(temp) + '\n' + str(count)
```

Expected behavior

Return the correct inverses, with no NaN entries.

Environment

I got the error in two environments:

PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Enterprise
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2\bin\cudnn64_7.dll
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.19.2
[pip3] robust-loss-pytorch==0.0.2
[pip3] torch==1.7.0
[pip3] torch-cluster==1.5.4
[pip3] torch-dct==0.1.5
[pip3] torch-geometric==1.6.1
[pip3] torch-scatter==2.0.4
[pip3] torch-sparse==0.6.3
[pip3] torch-spline-conv==1.2.0
[pip3] torchaudio==0.7.0
[pip3] torchfile==0.1.0
[pip3] torchnet==0.0.4
[pip3] torchstat==0.0.7
[pip3] torchvision==0.8.1
[conda] Could not collect

PyTorch version: 1.7.0+cu101
Is debug build: Yes
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.1 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
GPU 4: GeForce RTX 2080 Ti
GPU 5: GeForce RTX 2080 Ti
GPU 6: GeForce RTX 2080 Ti
GPU 7: GeForce RTX 2080 Ti

Nvidia driver version: 430.64
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.4.2
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] numpydoc==0.9.2
[pip3] robust-loss-pytorch==0.0.2
[pip3] torch==1.7.0+cu101
[pip3] torch-dct==0.1.5
[pip3] torchaudio==0.7.0
[pip3] torchvision==0.8.1+cu101
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] mkl                       2020.0                      166  
[conda] mkl-service               2.3.0            py37he904b0f_0  
[conda] mkl_fft                   1.0.15           py37ha843d7b_0  
[conda] mkl_random                1.1.0            py37hd6b4f25_0  
[conda] numpy                     1.19.1                   pypi_0    pypi
[conda] numpydoc                  0.9.2                      py_0  
[conda] torch                     1.7.0+cu101              pypi_0    pypi
[conda] torch-dct                 0.1.5                    pypi_0    pypi
[conda] torchaudio                0.7.0                    pypi_0    pypi
[conda] torchvision               0.8.1+cu101              pypi_0    pypi

cc @vishwakftw @jianyuh @nikitaved @pearu @mruberry @heitorschueroff

@gchanan gchanan added the needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user label Nov 3, 2020
@gchanan
Contributor

gchanan commented Nov 3, 2020

I could not reproduce this (although I used CUDA 9) -- how many trials did this take for you?

@Lmy0217
Author

Lmy0217 commented Nov 3, 2020

@gchanan I got the error only in PyTorch 1.7.0. The loop in my code triggers it every time; count is almost always less than 1000 when the assertion fails.

@xwang233
Collaborator

xwang233 commented Nov 3, 2020

Thanks @Lmy0217, this is a known issue #46557 and was fixed in #46625. It happens because the matrix you have is singular; the correct behavior is to throw a runtime error instead of giving NaN output. We have fixed that on the master branch, but not in the 1.7.0 release.

The cusolver path is enabled only for cuda >= 10.1.243. For cuda versions lower than that, the much slower MAGMA path is used.

@xwang233 xwang233 added module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul module: NaNs and Infs Problems related to NaN and Inf handling in floating point and removed needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user labels Nov 3, 2020
@ptrblck
Collaborator

ptrblck commented Nov 4, 2020

Is the input matrix really singular in this case, given that its determinant is not zero?
Does the issue in #46557 also produce flaky results, i.e. NaNs that appear only sometimes, or do you think we might be facing a new issue?

@Lmy0217
Author

Lmy0217 commented Nov 4, 2020

@xwang233 @ptrblck the matrix is not singular! Its determinant is 0.9999 and its rank is 3.
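A quick way to check this, reusing d from the reproduction above (the expected values are the ones quoted in this comment):

```python
import torch

d = torch.tensor([
    [[-0.4397,  0.8981,  0.0000], [-0.8981, -0.4397,  0.0000], [0.0000, 0.0000, 1.0000]],
    [[-0.4397,  0.8981,  0.0000], [-0.8981, -0.4397,  0.0000], [0.0000, 0.0000, 1.0000]],
])

print(torch.det(d))                            # tensor([0.9999, 0.9999]), not zero
print([int(torch.matrix_rank(m)) for m in d])  # [3, 3], full rank
```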

@xwang233
Collaborator

xwang233 commented Nov 4, 2020

Oops, sorry, I was looking at a different matrix. However, I can't reproduce this on my machine with CUDA 11.0 using either a 1070 or a 2070.

Update: can't reproduce on CUDA 10.2 with a V100 either.

@Lmy0217
Author

Lmy0217 commented Nov 4, 2020

Maybe you should try on CUDA 10.1 or 10.2.

@xwang233
Collaborator

xwang233 commented Nov 4, 2020

Thanks, I'll try that later on CUDA 10.2. Can you also check whether running CUDA_LAUNCH_BLOCKING=1 python script.py solves the problem?

@Lmy0217
Author

Lmy0217 commented Nov 4, 2020

No errors were encountered after setting CUDA_LAUNCH_BLOCKING=1 on Windows 10.

@Lmy0217
Author

Lmy0217 commented Nov 4, 2020

Strangely, it can't be reproduced on Ubuntu now. The randomness makes this hard to pin down.

@xwang233
Collaborator

xwang233 commented Nov 4, 2020

Yeah, the batched inverse uses CUDA multi-stream parallel execution as an optimization. I'll check if there is anything I can do about it. Thanks for reporting the issue.

Update: this may also be the incorrect usage of the CUDA caching allocator that was fixed in #46605, but not in 1.7.0.

@ejguan ejguan added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 4, 2020
@xwang233
Collaborator

xwang233 commented Nov 5, 2020

@Lmy0217 I'm able to reproduce this on my machine with CUDA 11.0 as well. The NaN usually appears after 40k to 90k loops. Sorry for the inconvenience.

This problem only occurs when you have a batch size of 2, that is, when your tensor has a shape of (2, x, x).

A temporary workaround is either of the following:

  1. Set the environment variable CUDA_LAUNCH_BLOCKING=1. However, this may affect the performance of the whole script.

  2. Calculate the inverse of the two matrices separately, then use torch.cat or torch.stack to put them together (see the sketch after this list).
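
A minimal sketch of workaround 2, reusing d from the reproduction above; the helper name batched_inverse_workaround is just for illustration:

```python
import torch

def batched_inverse_workaround(batch):
    # Invert each matrix on its own and stack the results, avoiding the
    # batch-of-2 multi-stream path that produces the intermittent NaNs.
    return torch.stack([torch.inverse(m) for m in batch])

device = torch.device('cuda:0')
d = torch.tensor([
    [[-0.4397,  0.8981,  0.0000], [-0.8981, -0.4397,  0.0000], [0.0000, 0.0000, 1.0000]],
    [[-0.4397,  0.8981,  0.0000], [-0.8981, -0.4397,  0.0000], [0.0000, 0.0000, 1.0000]],
], device=device)

inv = batched_inverse_workaround(d)
assert not torch.isnan(inv).any()
```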

cc @ngimel for visibility

@mruberry mruberry added triage review module: correctness (silent) issue that returns an incorrect result silently and removed triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Nov 6, 2020
@ngimel
Collaborator

ngimel commented Nov 6, 2020

@xwang233 you allocate self_working_copy and self_inv_working_copy on the main stream

Tensor self_working_copy = cloneBatchedColumnMajor(self);
Tensor self_inv_working_copy = column_major_identity_matrix_like(self_working_copy);
and then use them on the side stream. You should call record_stream to make sure that the memory is kept alive until the operation on the side stream has finished.
Can you please submit either that fix, or disable the multi-stream operation?
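
An illustrative Python-level sketch of the pattern being described here; the actual fix belongs in the C++ cuSOLVER inverse path, not in user code. A tensor allocated on the main stream but consumed on a side stream has to be recorded on that side stream, otherwise the caching allocator may reuse its memory before the side-stream work has finished:

```python
import torch

main = torch.cuda.current_stream()
side = torch.cuda.Stream()

x = torch.randn(2, 3, 3, device='cuda')  # allocated on the main stream

side.wait_stream(main)                   # side must see x fully initialized
with torch.cuda.stream(side):
    y = torch.inverse(x)                 # consumes x on the side stream
x.record_stream(side)                    # keep x's memory alive for `side`

main.wait_stream(side)                   # main must wait before reading y
```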

@xwang233
Collaborator

xwang233 commented Nov 6, 2020

Thanks @ngimel, the original reason to add multi-stream execution was better performance, because the cuBLAS batched inverse is too slow. Since it's causing a silent error, I would prefer to disable multi-stream execution. I will submit a fix later.

@ngimel
Collaborator

ngimel commented Nov 6, 2020

You can still use multiple streams if you properly register all the tensors with the correct streams.

@t-vi
Collaborator

t-vi commented Mar 6, 2021

Is this Windows-specific by chance? (Maybe @peterjc123 can try to reproduce?)

@qiyan98

qiyan98 commented Jan 27, 2022

I confirmed this error with PyTorch 1.7.1 on Ubuntu 20.04. Updating to PyTorch 1.10 solves the issue. It seems that np.linalg.inv is more stable.
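
A simple sketch of that kind of fallback; the helper inverse_via_numpy is hypothetical and just moves the batch to the CPU, inverts it with NumPy, and moves the result back, sidestepping the CUDA batched-inverse path at the cost of a device round trip:

```python
import numpy as np
import torch

def inverse_via_numpy(t):
    # np.linalg.inv accepts a stack of matrices with shape (..., M, M).
    inv = np.linalg.inv(t.detach().cpu().numpy())
    return torch.from_numpy(inv).to(t.device)

# Usage: inv = inverse_via_numpy(d)  # same shape and device as d
```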

emcastillo pushed a commit to emcastillo/pytorch that referenced this issue Mar 16, 2022
…-stream issue (pytorch#47026)

Summary:
### test_inverse_singular for cublas failure

Related
pytorch#46616 (comment)
https://app.circleci.com/pipelines/github/pytorch/pytorch/232112/workflows/4131d4ca-cd51-44e3-8e6c-b1c3555c62fa/jobs/8523970/tests

The CUDA 11.1 CI container doesn't have the MAGMA library, so the cuBLAS matrix inverse path is enabled.
```
Oct 27 23:13:47 -- MAGMA not found. Compiling without MAGMA support
```

The test_inverse_singular test was introduced in pytorch#46625, but I forgot to fix that functionality for the cuBLAS path as well.

### cusolver inverse multi-stream failure

fix pytorch#47272

The original CUDA event record/stream-blocking logic was wrong, which could cause NaN in the output tensor.

On my machine, the original code produces NaN within about 50k to 500k loops. After this change, no NaN is observed in more than 2.5M loops.

The performance for batch-2 matrix inverse is still the same as in pytorch#42403.

Pull Request resolved: pytorch#47026

Reviewed By: mruberry

Differential Revision: D24838546

Pulled By: ngimel

fbshipit-source-id: 3b83e4ab8e6b47a8273cba277251765bd6d97911