torch.linalg.eigh fails on GPU and corrupts memory #105359

Closed
anana10c opened this issue Jul 17, 2023 · 1 comment
Labels
module: cuda · module: linear algebra · triage review · triaged

Comments


anana10c commented Jul 17, 2023

🐛 Describe the bug

torch.linalg.eigh fails on some large, low-rank float32 matrices on GPU, but succeeds on CPU or when the matrix is cast to float64 (see the similar issue #94772). After the failure, the matrix cannot be accessed again without triggering a CUDA illegal memory access error.

An example matrix that fails can be found here: rank7_idx0.1.3.0_iter100_factor.pt.zip
This matrix was generated when applying the Shampoo optimizer to HF T5 finetuning.

import torch

factor_matrix = torch.load("rank7_idx0.1.3.0_iter100_factor.pt")
factor_matrix = factor_matrix.to("cuda:0")
torch.linalg.eigh(factor_matrix)  # will error with "failed to compute eigendecomposition"

print(factor_matrix)  # illegal memory access

cusolver eigendecomposition error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
torch._C._LinAlgError: cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED, when calling `cusolverDnXsyevd( handle, params, jobz, uplo, n, CUDA_R_32F, reinterpret_cast<void*>(A), lda, CUDA_R_32F, reinterpret_cast<void*>(W), CUDA_R_32F, reinterpret_cast<void*>(bufferOnDevice), workspaceInBytesOnDevice, reinterpret_cast<void*>(bufferOnHost), workspaceInBytesOnHost, info)`. This error may appear if the input matrix contains NaN.

CUDA illegal memory access error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/fsx/users/annacai/work/pytorch/torch/_tensor.py", line 427, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/fsx/users/annacai/work/pytorch/torch/_tensor_str.py", line 669, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/fsx/users/annacai/work/pytorch/torch/_tensor_str.py", line 600, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/fsx/users/annacai/work/pytorch/torch/_tensor_str.py", line 352, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/fsx/users/annacai/work/pytorch/torch/_tensor_str.py", line 388, in get_summarized_data
    return torch.stack([get_summarized_data(x) for x in (start + end)])
  File "/fsx/users/annacai/work/pytorch/torch/_tensor_str.py", line 388, in <listcomp>
    return torch.stack([get_summarized_data(x) for x in (start + end)])
  File "/fsx/users/annacai/work/pytorch/torch/_tensor_str.py", line 378, in get_summarized_data
    return torch.cat(
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

cc @ezyang @gchanan @zou3519 @ptrblck @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano @hjmshi @mikerabbat @dmudigere @tsunghsienlee @awgu @wanchaol @gallego-posada

Versions

Collecting environment information...
PyTorch version: 2.1.0a0+git6855053
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.26.3
Libc version: glibc-2.31

Python version: 3.9.16 (main, May 15 2023, 23:46:34)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1019-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB

Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          96
On-line CPU(s) list:             0-95
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           85
Model name:                      Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping:                        7
CPU MHz:                         2999.998
BogoMIPS:                        5999.99
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       1.5 MiB
L1i cache:                       1.5 MiB
L2 cache:                        48 MiB
L3 cache:                        71.5 MiB
NUMA node0 CPU(s):               0-23,48-71
NUMA node1 CPU(s):               24-47,72-95
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.3
[pip3] torch==2.1.0a0+git6855053
[pip3] torch-tb-profiler==0.4.1
[pip3] torchvision==0.16.0a0+a00152b
[pip3] triton==2.0.0
[pip3] vit-pytorch==1.2.2
[conda] magma-cuda110             2.5.2                         1    pytorch
[conda] mkl                       2023.1.0         h6d00ec8_46342  
[conda] mkl-include               2023.1.0         h06a4308_46342  
[conda] numpy                     1.24.3                   pypi_0    pypi
[conda] torch                     2.1.0a0+git6855053           dev_0    <develop>
[conda] torch-tb-profiler         0.4.1                    pypi_0    pypi
[conda] torchvision               0.16.0a0+a00152b          pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi
[conda] vit-pytorch               1.2.2                    pypi_0    pypi
mikaylagawarecki added the module: linear algebra, module: cuda, triaged, and high priority labels on Jul 18, 2023
lezcano (Collaborator) commented Jul 18, 2023

I can repro this issue.

Now, let's take the SVD of the matrix. Its smallest singular values are

        ...
        9.5486e-27, 9.5329e-27, 9.4114e-27, 9.0841e-27, 9.0656e-27, 9.0152e-27,
        8.9583e-27, 8.9510e-27, 8.9487e-27, 8.9279e-27, 8.9051e-27, 8.8581e-27,
        8.8242e-27, 8.7709e-27, 8.7051e-27, 8.6711e-27, 8.6426e-27, 8.6172e-27,
        8.5614e-27, 8.5556e-27, 8.5042e-27, 8.4976e-27, 8.4803e-27, 8.4689e-27,
        8.4639e-27, 8.4571e-27, 8.4171e-27, 8.3823e-27, 8.3283e-27, 8.3103e-27,
        8.2244e-27, 8.2069e-27, 8.2068e-27, 8.1859e-27, 8.1504e-27, 8.1150e-27,
        8.0973e-27, 8.0806e-27, 8.0703e-27, 8.0296e-27, 8.0218e-27, 8.0125e-27,
        7.9564e-27, 7.8682e-27, 7.8656e-27, 7.8593e-27, 7.8540e-27, 7.8318e-27,
        7.7552e-27, 7.7416e-27, 7.7190e-27, 7.6938e-27, 7.6808e-27, 7.6795e-27,
        7.6173e-27, 7.5397e-27, 7.5390e-27, 7.5226e-27, 7.4766e-27, 7.4765e-27,
        7.4595e-27, 7.4517e-27, 7.4019e-27, 7.3523e-27, 7.2686e-27, 7.2376e-27,
        7.2364e-27, 7.1676e-27, 7.1451e-27, 7.1279e-27, 7.0850e-27, 7.0534e-27,
        7.0266e-27, 6.9764e-27, 6.9334e-27, 6.9318e-27, 6.8258e-27, 6.8252e-27,
        6.8058e-27, 6.7995e-27, 6.7824e-27, 6.7768e-27, 6.7612e-27, 6.7583e-27,
        6.7168e-27, 6.7107e-27, 6.7037e-27, 6.6893e-27, 6.5675e-27, 6.1478e-27],
       device='cuda:0')
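
A minimal sketch of how such a spectrum can be inspected (computing in float64 keeps the check itself away from the failing float32 cusolver path; the slice size below is just illustrative):

import torch

factor_matrix = torch.load("rank7_idx0.1.3.0_iter100_factor.pt")

# Singular values only, computed in float64 so that this inspection does not
# itself hit the float32 cusolver failure. svdvals returns them in descending order.
svals = torch.linalg.svdvals(factor_matrix.to(torch.float64))
print(svals[-90:])                    # smallest singular values
print((svals[0] / svals[-1]).item())  # rough condition-number estimate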

It's clear that this matrix is very close to being singular. In particular, this falls under https://pytorch.org/docs/main/notes/numerical_accuracy.html#extremal-values-in-linalg, so it is expected.

As recommended there, if you are working with nearly singular matrices, there are two ways of going about it:

  1. The quick and dirty fix: perform the operation in float64 and hope for the best.
  2. Apply a preconditioner that regularises the matrix, i.e. moves its spectrum towards a better-conditioned one (see the sketch below).

Option 1 works in this case: torch.linalg.eigh(factor_matrix.to(torch.float64)) just succeeds.
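
A rough sketch of both options, reusing the factor matrix from the repro above; the shift size in option 2 is an arbitrary illustrative value, not a recommendation from this thread:

import torch

factor_matrix = torch.load("rank7_idx0.1.3.0_iter100_factor.pt").to("cuda:0")

# Option 1: decompose in float64, then cast back if float32 results are needed.
evals, evecs = torch.linalg.eigh(factor_matrix.to(torch.float64))
evals, evecs = evals.to(torch.float32), evecs.to(torch.float32)

# Option 2 (illustrative regularisation): shift the spectrum away from zero by
# adding a small multiple of the identity before decomposing. The epsilon is
# scaled by the spectral norm purely for illustration; the right amount is
# application-dependent.
n = factor_matrix.shape[-1]
eps = 1e-4 * torch.linalg.matrix_norm(factor_matrix.to(torch.float64), ord=2).item()
regularised = factor_matrix + eps * torch.eye(n, device=factor_matrix.device)
evals_reg, evecs_reg = torch.linalg.eigh(regularised)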

All this being said, cusolver should not crash in an unrecoverable way if possible. cc @IvanYashchuk @xwang233

This issue was already reported in #94772, as you mentioned. Let's continue the discussion there.

lezcano closed this as completed on Jul 18, 2023
pytorchmergebot pushed a commit that referenced this issue Aug 16, 2023
…conditioned, in some cusolver version (#107082)

Related: #94772, #105359

I can locally reproduce this crash with pytorch 2.0.1 stable pip binary. The test already passes with the latest cuda 12.2 release.

Re: #94772 (comment)
> From discussion in triage review:

- [x] we should add a test to prevent regressions
- [x] properly document support wrt different CUDA versions
- [x] possibly add support using MAGMA
Pull Request resolved: #107082
Approved by: https://github.com/lezcano
summerdo pushed a commit to summerdo/pytorch that referenced this issue Aug 17, 2023
…conditioned, in some cusolver version (pytorch#107082)
