diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index 7e1376abba..de37382833 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -177,7 +177,8 @@ def jax_funcify_QR(op, **kwargs): mode = op.mode def qr(x, mode=mode): - return jax.scipy.linalg.qr(x, mode=mode) + res = jax.scipy.linalg.qr(x, mode=mode) + return res[0] if len(res) == 1 else res return qr diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 6bc10fc634..49490994b1 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -370,3 +370,15 @@ def test_jax_expm(): out = pt_slinalg.expm(A) compare_jax_and_py([A], [out], [A_val]) + + +@pytest.mark.parametrize("mode", ["full", "r"]) +def test_jax_qr(mode): + # "full" and "r" modes are tested because "full" returns two matrices (Q, R), while (R,) returns only one. + # Pytensor does not return a tuple when only one output is expected. + rng = np.random.default_rng(utt.fetch_seed()) + A = pt.tensor(name="A", shape=(5, 5)) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + out = pt_slinalg.qr(A, mode=mode) + + compare_jax_and_py([A], out, [A_val])