From 85ecc0771e06455826c6406d0f87bd3a9c1152da Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 5 Oct 2025 01:00:00 +0200 Subject: [PATCH] Add logsumexp to xtensor --- pytensor/xtensor/math.py | 5 +++++ tests/xtensor/test_math.py | 25 ++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index af453d16e9..b70f3bae3e 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -512,6 +512,11 @@ def softmax(x, dim=None): return exp_x / exp_x.sum(dim=dim) +def logsumexp(x, dim=None): + """Compute the logsumexp of an XTensorVariable along a specified dimension.""" + return log(exp(x).sum(dim=dim)) + + class Dot(XOp): """Matrix multiplication between two XTensorVariables. diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index afd720ff8a..48d91771be 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -7,6 +7,7 @@ import inspect import numpy as np +from scipy.special import logsumexp as scipy_logsumexp from xarray import DataArray import pytensor.scalar as ps @@ -14,7 +15,7 @@ from pytensor import function from pytensor.scalar import ScalarOp from pytensor.xtensor.basic import rename -from pytensor.xtensor.math import add, exp +from pytensor.xtensor.math import add, exp, logsumexp from pytensor.xtensor.type import xtensor from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function @@ -152,6 +153,28 @@ def test_cast(): yc64.astype("float64") +@pytest.mark.parametrize( + ["shape", "dims", "axis"], + [ + ((3, 4), ("a", "b"), None), + ((3, 4), "a", 0), + ((3, 4), "b", 1), + ], +) +def test_logsumexp(shape, dims, axis): + scipy_inp = np.zeros(shape) + scipy_out = scipy_logsumexp(scipy_inp, axis=axis) + + pytensor_inp = DataArray(scipy_inp, dims=("a", "b")) + f = function([], logsumexp(pytensor_inp, dim=dims)) + pytensor_out = f() + + np.testing.assert_array_almost_equal( + pytensor_out, + scipy_out, + ) + + def test_dot(): """Test basic dot product operations.""" # Test matrix-vector dot product (with multiple-letter dim names)