
BUG: pt.linalg.solve returns incorrect results when mode = "NUMBA" #422

@jessegrabowski

Describe the issue:

This is identical to #382. When assume_a != 'gen' is passed to pt.linalg.solve, the Numba backend incorrectly invokes solve_triangular, so the returned solution is wrong.
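The sketch below uses SciPy directly to show why dispatching to solve_triangular gives wrong answers. SciPy is an assumption here and is not part of the report's code. scipy.linalg.solve_triangular reads only one triangle of its input and ignores the entries on the other side of the diagonal, so on a symmetric matrix it ends up solving a different linear system. The 3x3 matrix is a hand-picked symmetric positive-definite example, and lower=True is an arbitrary choice; using the upper triangle is just as wrong.

```python
import numpy as np
from scipy import linalg

# Hand-picked symmetric, diagonally dominant (hence positive-definite) matrix
X_sym = np.array([[4.0, 1.0, 2.0],
                  [1.0, 3.0, 0.5],
                  [2.0, 0.5, 5.0]])
I = np.eye(3)

# Expected behaviour for assume_a='sym': a full solve that exploits symmetry
X_inv_sym = linalg.solve(X_sym, I, assume_a="sym")

# Buggy behaviour: a triangular solve reads only the lower triangle of X_sym
# and ignores the entries above the diagonal, so it inverts tril(X_sym)
X_inv_tri = linalg.solve_triangular(X_sym, I, lower=True)

print(np.allclose(X_inv_sym @ X_sym, I))  # True
print(np.allclose(X_inv_tri @ X_sym, I))  # False
```

The triangular result is exactly the inverse of np.tril(X_sym), which differs from the inverse of X_sym whenever the matrix has nonzero entries above the diagonal.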

Reproducible code example:

import pytensor
import pytensor.tensor as pt
import numpy as np

X = pt.dmatrix('X')
b = pt.dmatrix('b')
y1 = pt.linalg.solve(X, b, assume_a='gen')
y2 = pt.linalg.solve(X, b, assume_a='sym')

f1 = pytensor.function([X, b], y1)
f2 = pytensor.function([X, b], y2)
f1_nb = pytensor.function([X, b], y1, mode='NUMBA')
f2_nb = pytensor.function([X, b], y2, mode='NUMBA')

X_sym = np.random.normal(size=(3, 3))
X_sym = X_sym @ X_sym.T
X_inv_1 = f1(X_sym, np.eye(3))
X_inv_2 = f2(X_sym, np.eye(3))
X_inv_1_nb = f1_nb(X_sym, np.eye(3))
X_inv_2_nb = f2_nb(X_sym, np.eye(3))


# Passes, C backend, assume_a = 'gen'
np.testing.assert_allclose(X_inv_1 @ X_sym, np.eye(3), atol=1e-12)

# Passes, C backend, assume_a = 'sym'
np.testing.assert_allclose(X_inv_2 @ X_sym, np.eye(3), atol=1e-12)

# Passes, Numba backend, assume_a = 'gen'
np.testing.assert_allclose(X_inv_1_nb @ X_sym, np.eye(3), atol=1e-12)

# Fails, Numba backend, assume_a = 'sym'
np.testing.assert_allclose(X_inv_2_nb @ X_sym, np.eye(3), atol=1e-12)

Error message:

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-12

Mismatched elements: 6 / 9 (66.7%)
Max absolute difference: 0.67369019
Max relative difference: 0.52399333
 x: array([[ 4.760067e-01, -2.459194e-01,  9.939378e-18],
       [-3.325732e-01,  6.636944e-01,  1.353323e-18],
       [-1.317192e-01,  6.736902e-01,  1.000000e+00]])
 y: array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])
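Checking X_inv @ X_sym against the identity only shows that something is wrong. To compare backend output entry by entry, a SciPy reference can be computed for each structure assumption. On a symmetric positive-definite input, 'gen', 'sym' and 'pos' should all return the same solution. This is a sketch that assumes SciPy is available. The seed and the 3*I diagonal shift, which keeps the random matrix well conditioned, are arbitrary choices and not part of the original report.

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
# Symmetric positive-definite test matrix, built as in the report but shifted
# by 3*I so that its eigenvalues are at least 3 (well conditioned)
X_sym = A @ A.T + 3.0 * np.eye(3)
b = np.eye(3)

# SciPy reference solution for each structure assumption
reference = {
    assume_a: linalg.solve(X_sym, b, assume_a=assume_a)
    for assume_a in ("gen", "sym", "pos")
}

# All three assumptions hold for this matrix, so all results must agree
for assume_a, x in reference.items():
    np.testing.assert_allclose(x, reference["gen"], rtol=1e-10, atol=1e-12)
    print(assume_a, "matches 'gen'")
```

For the report's X_sym, the Numba result X_inv_2_nb should likewise match linalg.solve(X_sym, np.eye(3), assume_a='sym').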

PyTensor version information:

PyTensor 2.11.1

Context for the issue:

No response

Labels:

bug (Something isn't working), linalg (Linear algebra), numba
