-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Description
torch.linalg.eigh produces inaccurate eigenvalues/eigenvectors compared to NumPy
Steps to Reproduce
- Use the matrix A provided in the code below.
- Compute the eigenvalues and eigenvectors using both torch.linalg.eigh and np.linalg.eigh.
- Compare the results.
import torch
import numpy as np
A = np.array([[ 1.60782897e+00, 2.28964731e-01, 5.37528796e-03, -3.68031830e-01, 7.76133314e-02, 1.95910275e-01, -3.12402956e-02, 3.67419720e-01, 1.22131474e-01, -2.53661489e+00, 2.17903289e-06, 5.94051089e-05, 2.33369647e-05, -3.33767384e-04, 7.98902474e-05, 9.64686275e-04, -4.88465303e-04, 2.36044801e-03, 4.49522515e-04, -4.29443026e+00],
[ 2.28964731e-01, 2.41090322e+00, -1.88907310e-02, -1.11321688e+00, 2.24941388e-01, 5.75562716e-01, -1.19260252e-01, 1.07684803e+00, 3.36590528e-01, -5.06800270e+00, -2.24896939e-06, -4.84753400e-06, -7.50925392e-06, 1.13487244e-04, -3.26065347e-05, -4.61697578e-04, 3.01518885e-04, -1.26068108e-03, 9.04885674e-05, 1.70146358e+00],
[ 5.37528796e-03, -1.88907310e-02, 1.62140822e+00, -1.66513994e-02, -1.20639233e-02, 2.17823274e-02, 2.43900251e-03, 2.59594470e-02, -1.06401583e-02, 1.67924047e-01, 2.05849938e-06, 2.44657276e-05, 1.31483248e-05, -1.48024876e-04, 6.39846548e-05, 5.71987592e-04, 3.61403363e-06, 1.01836876e-03, 1.68582925e-03, -3.52031112e+00],
[-3.68031830e-01, -1.11321688e+00, -1.66513994e-02, 3.52388382e+00, -3.89591247e-01, -8.86485994e-01, 3.69597822e-01, -2.19188643e+00, -4.24265265e-01, 3.27274895e+00, -2.42434326e-05, -2.98958272e-04, -1.23632257e-04, 1.87904201e-03, -5.15639316e-04, -5.35614789e-03, 2.87756487e-03, -1.53830992e-02, -4.53177467e-03, 2.52128124e+01],
[ 7.76133314e-02, 2.24941388e-01, -1.20639233e-02, -3.89591247e-01, 1.93465853e+00, 2.48911351e-01, -2.28518508e-02, 4.19600159e-01, 1.34637073e-01, 4.95270640e-01, 2.64251721e-05, 3.15882266e-04, 1.41064636e-04, -1.91451237e-03, 6.46093860e-04, 6.21308386e-03, -2.34547607e-03, 1.57597680e-02, 8.13580025e-03, -3.23177299e+01],
[ 1.95910275e-01, 5.75562716e-01, 2.17823274e-02, -8.86485994e-01, 2.48911351e-01, 2.98399925e+00, -4.25876319e-01, 9.31742251e-01, -4.73557204e-01, 9.36427712e-01, 1.56900787e-05, 1.95616623e-04, 9.17701400e-05, -1.15976110e-03, 4.29244712e-04, 4.25980613e-03, -8.00579088e-04, 8.80736019e-03, 8.03760253e-03, -2.24972534e+01],
[-3.12402956e-02, -1.19260252e-01, 2.43900251e-03, 3.69597822e-01, -2.28518508e-02, -4.25876319e-01, 3.06580067e+00, -6.48983538e-01, 4.06311929e-01, 1.37230349e+00, 7.44865829e-05, 8.81816493e-04, 3.89039924e-04, -5.36584575e-03, 1.68655277e-03, 1.71863157e-02, -5.95929101e-03, 4.48325910e-02, 2.24911068e-02, -8.79636688e+01],
[ 3.67419720e-01, 1.07684803e+00, 2.59594470e-02, -2.19188643e+00, 4.19600159e-01, 9.31742251e-01, -6.48983538e-01, 5.48731613e+00, -6.58562034e-02, -2.43317747e+00, -2.27659766e-05, -2.60235975e-04, -1.24252401e-04, 1.58591196e-03, -5.84547408e-04, -5.63996658e-03, 1.13803300e-03, -1.05895922e-02, -1.21539282e-02, 2.72431316e+01],
[ 1.22131474e-01, 3.36590528e-01, -1.06401583e-02, -4.24265265e-01, 1.34637073e-01, -4.73557204e-01, 4.06311929e-01, -6.58562034e-02, 5.43583727e+00, -2.92779040e+00, 2.11733277e-06, 3.67360190e-05, 3.78659461e-05, -1.60590280e-04, 2.01643910e-04, 1.61331519e-03, 1.96514279e-03, -1.50358269e-03, 1.71575230e-02, -1.18784990e+01],
[-2.53661489e+00, -5.06800270e+00, 1.67924047e-01, 3.27274895e+00, 4.95270640e-01, 9.36427712e-01, 1.37230349e+00, -2.43317747e+00, -2.92779040e+00, 8.76316345e+02, -1.22874253e-03, -1.46462396e-02, -6.36728108e-03, 8.91621411e-02, -2.74890810e-02, -2.80071974e-01, 1.21408194e-01, -7.62682676e-01, -3.09633642e-01, 1.52729504e+03],
[ 2.17903289e-06, -2.24896939e-06, 2.05849938e-06, -2.42434326e-05, 2.64251721e-05, 1.56900787e-05, 7.44865829e-05, -2.27659766e-05, 2.11733277e-06, -1.22874253e-03, 1.33819203e-03, 1.58352833e-02, 6.75189588e-03, -9.70604271e-02, 2.87030358e-02, 2.97664732e-01, -1.43668085e-01, 8.35517287e-01, 2.84892887e-01, -1.46552661e+03],
[ 5.94051089e-05, -4.84753400e-06, 2.44657276e-05, -2.98958272e-04, 3.15882266e-04, 1.95616623e-04, 8.81816493e-04, -2.60235975e-04, 3.67360190e-05, -1.46462396e-02, 1.58352833e-02, 1.87388241e-01, 7.99007788e-02, -1.14855564e+00, 3.39666307e-01, 3.52251792e+00, -1.69990075e+00, 9.88676643e+00, 3.37287593e+00, -1.73432988e+04],
[ 2.33369647e-05, -7.50925392e-06, 1.31483248e-05, -1.23632257e-04, 1.41064636e-04, 9.17701400e-05, 3.89039924e-04, -1.24252401e-04, 3.78659461e-05, -6.36728108e-03, 6.75189588e-03, 7.99007788e-02, 3.40821631e-02, -4.89710629e-01, 1.44927666e-01, 1.50249171e+00, -7.23849893e-01, 4.21417427e+00, 1.44301021e+00, -7.40001318e+03],
[-3.33767384e-04, 1.13487244e-04, -1.48024876e-04, 1.87904201e-03, -1.91451237e-03, -1.15976110e-03, -5.36584575e-03, 1.58591196e-03, -1.60590280e-04, 8.91621411e-02, -9.70604271e-02, -1.14855564e+00, -4.89710629e-01, 7.04013824e+00, -2.08170390e+00, -2.15894566e+01, 1.04222698e+01, -6.06031189e+01, -2.06613407e+01, 1.06286359e+05],
[ 7.98902474e-05, -3.26065347e-05, 6.39846548e-05, -5.15639316e-04, 6.46093860e-04, 4.29244712e-04, 1.68655277e-03, -5.84547408e-04, 2.01643910e-04, -2.74890810e-02, 2.87030358e-02, 3.39666307e-01, 1.44927666e-01, -2.08170390e+00, 6.16610885e-01, 6.38967180e+00, -3.07360172e+00, 1.79079399e+01, 6.14985657e+00, -3.14770039e+04],
[ 9.64686275e-04, -4.61697578e-04, 5.71987592e-04, -5.35614789e-03, 6.21308386e-03, 4.25980613e-03, 1.71863157e-02, -5.63996658e-03, 1.61331519e-03, -2.80071974e-01, 2.97664732e-01, 3.52251792e+00, 1.50249171e+00, -2.15894566e+01, 6.38967180e+00, 6.62452698e+01, -3.19060764e+01, 1.85773941e+02, 6.36354866e+01, -3.26240062e+05],
[-4.86891717e-04, 3.01372260e-04, 3.62051651e-06, 2.87755951e-03, -2.34574080e-03, -8.00948590e-04, -5.95919322e-03, 1.13837048e-03, 1.96490344e-03, 1.21410340e-01, -1.43668100e-01, -1.69990110e+00, -7.23848820e-01, 1.04222832e+01, -3.07360816e+00, -3.19060802e+01, 1.55375633e+01, -8.98448792e+01, -3.00951061e+01, 1.56913156e+05],
[ 2.37344205e-03, -1.26130879e-03, 1.01752952e-03, -1.53824687e-02, 1.57596469e-02, 8.80961865e-03, 4.48332652e-02, -1.05940849e-02, -1.50594860e-03, -7.62646675e-01, 8.35519135e-01, 9.88673401e+00, 4.21416092e+00, -6.06031990e+01, 1.79079304e+01, 1.85773956e+02, -8.98448792e+01, 5.21939880e+02, 1.77149094e+02, -9.14563250e+05],
[ 4.47247177e-04, 9.06437635e-05, 1.68623496e-03, -4.53158468e-03, 8.13601911e-03, 8.03825632e-03, 2.24906951e-02, -1.21534169e-02, 1.71582494e-02, -3.09652269e-01, 2.84893215e-01, 3.37288260e+00, 1.44301367e+00, -2.06612988e+01, 6.14986134e+00, 6.36355400e+01, -3.00951061e+01, 1.77149094e+02, 6.36062050e+01, -3.14022406e+05],
[-4.30882263e+00, 1.70046997e+00, -3.52318573e+00, 2.52160645e+01, -3.23214111e+01, -2.25007477e+01, -8.79653473e+01, 2.72448120e+01, -1.18774796e+01, 1.52729688e+03, -1.46553210e+03, -1.73432891e+04, -7.40001074e+03, 1.06286719e+05, -3.14769004e+04, -3.26239531e+05, 1.56913156e+05, -9.14563250e+05, -3.14022406e+05, 1.60835072e+09]], dtype=np.float32)
A = A + A.T  # symmetrize so the lower and upper triangles agree exactly


def _max_rel_err(recon, ref):
    """Return the largest elementwise |recon - ref| / |ref| as a Python float.

    Accepts NumPy arrays or torch tensors (both operands must be of the same
    kind and on the same device).

    The denominator is |ref|, not ref: with a signed denominator every entry
    where ref is negative produces a non-positive ratio, so max() silently
    ignores the error on all of those entries. An exactly-zero entry in ref
    would yield inf/nan; no guard is added for that here.
    """
    return float((abs(recon - ref) / abs(ref)).max())


# The figures in the trailing comments are the ones reported by the issue
# author using the original signed denominator (`/ A`). With |A| in the
# denominator each figure can only stay the same or grow.
L, Q = np.linalg.eigh(A)
meo = Q @ np.diag(L) @ Q.T  # reconstruction Q diag(L) Q^T, computed once
print('numpy:', _max_rel_err(meo, A))  # reported 1e-5 GOOD

# Compare tensors against a tensor rather than against the NumPy array.
A_cpu = torch.from_numpy(A)

L, Q = torch.linalg.eigh(A_cpu)
print('torch cpu:', _max_rel_err(Q @ torch.diag(L) @ Q.T, A_cpu))  # reported 1584 BAD
L, Q = torch.linalg.eigh(A_cpu, UPLO="U")
print('torch cpu upper:', _max_rel_err(Q @ torch.diag(L) @ Q.T, A_cpu))  # reported 0.11 OKAY

# Guard the GPU half so the CPU comparison still runs on machines without CUDA.
if torch.cuda.is_available():
    A_cuda = A_cpu.to("cuda:0")
    L, Q = torch.linalg.eigh(A_cuda)
    print('torch gpu:', _max_rel_err(Q @ torch.diag(L) @ Q.T, A_cuda))  # reported 18295 BAD
    L, Q = torch.linalg.eigh(A_cuda, UPLO="U")
    print('torch gpu upper:', _max_rel_err(Q @ torch.diag(L) @ Q.T, A_cuda))  # reported 4687 BAD
else:
    print('torch gpu: skipped (CUDA not available)')
Observed Behavior:
The relative reconstruction error of the torch.linalg.eigh results is significantly larger than that of numpy.linalg.eigh.
Using UPLO="U" improves the CPU result, but does not resolve the issue on the GPU.
Some eigenvalues returned by torch are negative.
Expected Behavior:
Results from torch.linalg.eigh should match the accuracy of numpy.linalg.eigh for symmetric matrices.
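The eigenvalues shouldn't be negative (this assumes A is positive semi-definite; a general symmetric matrix can legitimately have negative eigenvalues).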
Versions
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 19.1.0 (https://github.com/llvm/llvm-project.git a4bf6cd7cfb1a1421ba92bca9d017b49936c55e4)
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000
Nvidia driver version: 530.30.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7643 48-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 48
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3640.9170
CPU min MHz: 1500.0000
BogoMIPS: 4600.14
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm sme sev sev_es
Virtualization: AMD-V
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 48 MiB (96 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-11,96-107
NUMA node1 CPU(s): 12-23,108-119
NUMA node2 CPU(s): 24-35,120-131
NUMA node3 CPU(s): 36-47,132-143
NUMA node4 CPU(s): 48-59,144-155
NUMA node5 CPU(s): 60-71,156-167
NUMA node6 CPU(s): 72-83,168-179
NUMA node7 CPU(s): 84-95,180-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy==1.11.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.4
[pip3] numpy-groupies==0.11.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] nvtx==0.2.10
[pip3] onnx==1.16.1
[pip3] onnx2torch==1.5.13
[pip3] onnxruntime-gpu==1.18.0
[pip3] pynvjitlink-cu12==0.4.0
[pip3] torch==2.5.1
[pip3] torch-summary==1.4.5
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.5.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.0.post0
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] numpy 2.1.3 pypi_0 pypi
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @xwang233 @lezcano