
lobpcg will randomly fail for the same input on CPU #88650

@fuzzyswan


🐛 Describe the bug

On CPU, torch.lobpcg fails at random when it is called repeatedly on the same input:

import torch

A = torch.tensor([[0.0100, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0100, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0100, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0100, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0100]])

for i in range(1000):
    print(i)  # progress counter: shows which iteration raises
    eigvals, eigvecs = torch.lobpcg(A)  # k=1, largest=True by default

print(eigvals)
print(eigvecs)

Output (end of the traceback):

    Q = torch.orgqr(*torch.geqrf(A))
RuntimeError: torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]
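Because A never changes, the randomness presumably comes from the starting point: per the torch.lobpcg documentation, X is an optional initial approximation of the eigenvectors, and when it is omitted lobpcg starts from a randomly generated matrix. The sketch below assumes the global RNG is the only source of run-to-run variation. It seeds that RNG before each call and records which seeds raise. The helper failing_seeds and the seed range are hypothetical, not part of the original report. The failure count will depend on the PyTorch version and may be zero on releases that do not show this bug.

```python
import torch

# Same matrix as in the report: 0.01 * I (5x5), float32 on CPU.
A = 0.01 * torch.eye(5)


def failing_seeds(A, n_seeds=100):
    """Hypothetical helper: return the seeds for which torch.lobpcg(A) raises."""
    failed = []
    for seed in range(n_seeds):
        # Assumption: with X omitted, lobpcg draws its initial guess from the
        # global RNG, so seeding makes each call repeatable.
        torch.manual_seed(seed)
        try:
            torch.lobpcg(A)
        except RuntimeError:
            failed.append(seed)
    return failed


if __name__ == "__main__":
    n = 100
    bad = failing_seeds(A, n)
    print(f"{len(bad)} of {n} seeds failed: {bad}")
    if bad:
        # Re-run the first failing seed: it should fail again if the RNG is
        # the only source of variation.
        torch.manual_seed(bad[0])
        try:
            torch.lobpcg(A)
            print(f"seed {bad[0]} passed on re-run")
        except RuntimeError as err:
            print(f"seed {bad[0]} failed again: {err}")
```

If a failing seed fails again on re-run, that points to the random initial guess rather than to nondeterminism in the CPU kernels. Passing an explicit X to torch.lobpcg is the corresponding way to pin the starting point.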

By contrast, the same code always succeeds on CUDA:

import torch

A = torch.tensor([[0.0100, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0100, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0100, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0100, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0100]]).cuda()

for i in range(1000):
    print(i)  # progress counter
    eigvals, eigvecs = torch.lobpcg(A)  # k=1, largest=True by default

print(eigvals)
print(eigvecs)

Output:

tensor([0.0100], device='cuda:0')
tensor([[-0.3277],
        [-0.0372],
        [ 0.1163],
        [-0.4553],
        [-0.8188]], device='cuda:0')
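The CPU traceback ends at Q = torch.orgqr(*torch.geqrf(A)). The message comes from a shape check in torch.linalg.householder_product, which torch.orgqr aliases: the matrix passed in must have at least as many rows as columns. The sketch below shows that check in isolation on hypothetical random 5x3 and 3x5 matrices. For comparison it also calls torch.linalg.qr, which accepts both shapes. It does not show which matrix lobpcg builds internally on CPU; it only illustrates what the error message means.

```python
import torch

torch.manual_seed(0)
tall = torch.randn(5, 3)  # rows >= cols: accepted by householder_product
wide = torch.randn(3, 5)  # rows < cols: rejected by householder_product


def orgqr_basis(M):
    """Orthonormal basis via the same expression as the failing CPU frame."""
    return torch.orgqr(*torch.geqrf(M))


Q = orgqr_basis(tall)
tall_ok = torch.allclose(Q.T @ Q, torch.eye(3), atol=1e-5)  # columns are orthonormal

try:
    orgqr_basis(wide)
    wide_failed = False
except RuntimeError as err:
    wide_failed = True
    print("wide input rejected:", err)

# torch.linalg.qr (default "reduced" mode) handles the wide case: Q is 3x3, R is 3x5.
Q_wide, R_wide = torch.linalg.qr(wide)
print(tall_ok, wide_failed, tuple(Q_wide.shape), tuple(R_wide.shape))
```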

Versions

Collecting environment information...
PyTorch version: 1.12.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.6
Libc version: glibc-2.26

Python version: 3.7.15 (default, Oct 12 2022, 19:14:55) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.12.1+cu113
[pip3] torchaudio==0.12.1+cu113
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.13.1
[pip3] torchvision==0.13.1+cu113
[conda] Could not collect

cc @VitalyFedyunin @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

Metadata

Assignees: No one assigned

Labels: module: cpu, module: linear algebra, triaged
