Shape error in torch.linalg.solve backend #89761

@reverendbedford

🐛 Describe the bug

The backward pass of torch.linalg.solve returns a gradient with the wrong shape in the degenerate case of solving with a (batch of) 1x1 matrices. This occurs in version 1.13 but not in 1.12.1, and it seems to occur only in the CPU-only build of torch.

import torch

nbatch = 10
nvec = 1

x = torch.ones((nbatch, nvec))
p = torch.ones((nbatch,1,1), requires_grad = True)
g = torch.ones_like(x)

A = torch.rand((nbatch,nvec,nvec)) * p
y = torch.linalg.solve(A, x)

value = torch.autograd.grad(y, [p], g)

For 1.13 this gives an error:

Traceback (most recent call last):
  File "/home/user/temp/./demo.py", line 14, in <module>
    value = torch.autograd.grad(y, [p], g)
  File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 300, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function LinalgSolveExBackward0 returned an invalid gradient at index 0 - got [10, 1] but expected shape compatible with [10, 1, 1]

For 1.12.1 it returns no error and the correct result.
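As a cross-check of what the correct result should look like, the 1x1 case can be written without torch.linalg.solve at all, since solving a 1x1 system is elementwise division. The sketch below is an illustration under that assumption, not the fix in PyTorch. The `+ 0.5` offset is added here only to keep the entries away from zero; the script above uses `torch.rand` alone.

```python
import torch

torch.manual_seed(0)
nbatch = 10

x = torch.ones((nbatch, 1))
p = torch.ones((nbatch, 1, 1), requires_grad=True)
g = torch.ones_like(x)

# Assumption for this sketch: shift the random entries away from zero.
a = torch.rand((nbatch, 1, 1)) + 0.5

# For 1x1 systems, solve(A, x) reduces to x / A with A = a * p.
y = x / (a * p).squeeze(-1)
(grad_p,) = torch.autograd.grad(y, [p], g)

# Analytic gradient of sum(g * y) with respect to p: -g * x / (a * p**2).
expected = -(g * x).unsqueeze(-1) / (a * p.detach() ** 2)

print(grad_p.shape)
print(torch.allclose(grad_p, expected))
```

The gradient has the same shape as p, [10, 1, 1], which is the shape the error message says the autograd engine expects.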

This only occurs for nvec=1; for nvec>1 there is no error.
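The dependence on nvec can be checked by rerunning the reproduction for several sizes and recording either the gradient shape or the error. This is a minimal sketch: `grad_through_solve` is a hypothetical helper name, and adding a multiple of the identity to A is an assumption made here so the random matrices are safely invertible (the script above uses `torch.rand` alone).

```python
import torch


def grad_through_solve(nvec, nbatch=10):
    """Hypothetical helper: rerun the reproduction for one matrix size."""
    x = torch.ones((nbatch, nvec))
    p = torch.ones((nbatch, 1, 1), requires_grad=True)
    g = torch.ones_like(x)
    # Assumption for this sketch: the identity shift makes A strictly
    # diagonally dominant, so every matrix in the batch is invertible.
    A = (torch.rand((nbatch, nvec, nvec)) + nvec * torch.eye(nvec)) * p
    y = torch.linalg.solve(A, x)
    (grad_p,) = torch.autograd.grad(y, [p], g)
    return grad_p


results = {}
for nvec in (1, 2, 3):
    try:
        # On success, record the gradient shape.
        results[nvec] = tuple(grad_through_solve(nvec).shape)
    except RuntimeError as err:
        # On failure, record the error message instead.
        results[nvec] = f"RuntimeError: {err}"
    print(nvec, results[nvec])
```

On an affected build, the nvec=1 entry should hold the "invalid gradient" message, while the other sizes should report a gradient shape of (10, 1, 1).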

Versions

Collecting environment information...
PyTorch version: 1.13.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.6 (main, Nov  2 2022, 18:53:38) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: Quadro M6000 24GB
Nvidia driver version: 510.85.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.13.0+cpu
[pip3] torchaudio==0.13.0+cpu
[pip3] torchvision==0.14.0+cpu
[pip3] torchviz==0.0.2
[conda] Could not collect

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

Labels: module: linear algebra, triaged