Closed
Description
Describe the issue:
This is identical to #382. When assume_a != 'gen' in pt.linalg.solve, the Numba backend incorrectly invokes the solve_triangular function, so the returned result is wrong.
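The failure mode can be illustrated without PyTensor. The sketch below rests on an assumption about the mechanism, not on a reading of PyTensor's Numba code: a triangular solver reads only one triangle of its input, so calling it on a full symmetric matrix silently discards the other triangle. Here np.linalg.solve(np.tril(X_sym), I) stands in for a lower-triangular solve, and the seed and diagonal shift are illustrative choices.

```python
import numpy as np

# A symmetric, well-conditioned test matrix (seed and diagonal shift are illustrative).
rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
X_sym = A @ A.T + 3 * np.eye(3)
I = np.eye(3)

# Full solve: uses every entry of X_sym, as a correct 'gen' or 'sym' path should.
X_inv_full = np.linalg.solve(X_sym, I)

# Emulated triangular solve: a lower-triangular solver never reads the strict
# upper triangle, which is equivalent to solving against np.tril(X_sym).
X_inv_tri = np.linalg.solve(np.tril(X_sym), I)

print(np.allclose(X_inv_full @ X_sym, I))  # True
print(np.allclose(X_inv_tri @ X_sym, I))   # False: the upper triangle was ignored
```

With SciPy installed, scipy.linalg.solve_triangular(X_sym, I, lower=True) behaves like the np.tril version above, since it also reads only the lower triangle.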
Reproducible code example:
import pytensor
import pytensor.tensor as pt
import numpy as np
X = pt.dmatrix('X')
b = pt.dmatrix('b')
y1 = pt.linalg.solve(X, b, assume_a='gen')
y2 = pt.linalg.solve(X, b, assume_a='sym')
f1 = pytensor.function([X, b], y1)
f2 = pytensor.function([X, b], y2)
f1_nb = pytensor.function([X, b], y1, mode='NUMBA')
f2_nb = pytensor.function([X, b], y2, mode='NUMBA')
# Random symmetric matrix; X @ X.T is symmetric positive definite almost surely
X_sym = np.random.normal(size=(3, 3))
X_sym = X_sym @ X_sym.T
X_inv_1 = f1(X_sym, np.eye(3))
X_inv_2 = f2(X_sym, np.eye(3))
X_inv_1_nb = f1_nb(X_sym, np.eye(3))
X_inv_2_nb = f2_nb(X_sym, np.eye(3))
# Passes, C backend, assume_a = 'gen'
np.testing.assert_allclose(X_inv_1 @ X_sym, np.eye(3), atol=1e-12)
# Passes, C backend, assume_a = 'sym'
np.testing.assert_allclose(X_inv_2 @ X_sym, np.eye(3), atol=1e-12)
# Passes, Numba backend, assume_a = 'gen'
np.testing.assert_allclose(X_inv_1_nb @ X_sym, np.eye(3), atol=1e-12)
# Fails, Numba backend, assume_a = 'sym'
np.testing.assert_allclose(X_inv_2_nb @ X_sym, np.eye(3), atol=1e-12)
Error message:
AssertionError:
Not equal to tolerance rtol=1e-07, atol=1e-12
Mismatched elements: 6 / 9 (66.7%)
Max absolute difference: 0.67369019
Max relative difference: 0.52399333
x: array([[ 4.760067e-01, -2.459194e-01, 9.939378e-18],
[-3.325732e-01, 6.636944e-01, 1.353323e-18],
[-1.317192e-01, 6.736902e-01, 1.000000e+00]])
y: array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
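The mismatch summary can be recomputed from the printed arrays. This sketch uses the values shown above (rounded to 7 significant digits) and the criterion numpy.testing.assert_allclose applies, |x - y| <= atol + rtol * |y|; as in NumPy's report, the relative difference is taken only where y is nonzero.

```python
import numpy as np

# Values copied from the failing assertion above.
x = np.array([
    [4.760067e-01, -2.459194e-01, 9.939378e-18],
    [-3.325732e-01, 6.636944e-01, 1.353323e-18],
    [-1.317192e-01, 6.736902e-01, 1.000000e00],
])
y = np.eye(3)
rtol, atol = 1e-7, 1e-12

# An element matches when |x - y| <= atol + rtol * |y|.
abs_diff = np.abs(x - y)
mismatched = abs_diff > atol + rtol * np.abs(y)
max_abs = abs_diff.max()

# Relative difference, restricted to entries where y != 0 (the diagonal here).
nonzero = y != 0
max_rel = (abs_diff[nonzero] / np.abs(y[nonzero])).max()

print(mismatched.sum())  # 6
print(max_abs)           # about 0.6736902
print(max_rel)           # about 0.5239933
```

Only the two tiny entries in the last column and the bottom-right 1.0 pass, which matches "Mismatched elements: 6 / 9".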
PyTensor version information:
PyTensor version 2.11.1
Context for the issue:
No response