diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 15e02265f1..1bf9ad65c1 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,7 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.npy_2_compat import numpy_version, using_numpy_2 +from pytensor.npy_2_compat import normalize_axis_tuple, numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -3369,11 +3369,12 @@ def flip( if axis is None: index = ((slice(None, None, -1)),) * arr.ndim else: - if isinstance(axis, int): - axis = (axis,) + normalized_axis = normalize_axis_tuple(axis, arr.ndim) index = tuple( [ - slice(None, None, -1) if i in axis else slice(None, None, None) + slice(None, None, -1) + if i in normalized_axis + else slice(None, None, None) for i in range(arr.ndim) ] ) @@ -3382,9 +3383,9 @@ def flip( __all__ = [ - "take", "flip", - "slice_at_axis", "inc_subtensor", "set_subtensor", + "slice_at_axis", + "take", ] diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index ead94371d3..9333b4c32e 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -3147,6 +3147,27 @@ def test_flip(size: tuple[int]): f = pytensor.function([x_pt], z, mode="FAST_COMPILE") np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + # Test single negative axis + for axis in range(-x.ndim, 0): + expected = np.flip(x, axis=axis) + z = flip(x_pt, axis=axis) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + + # Test tuple with negative axes + if x.ndim > 1: + expected = np.flip(x, axis=(-1, -2)) + z = flip(x_pt, axis=(-1, -2)) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + + # Test mixed positive and negative axes + if x.ndim >= 2: + expected = np.flip(x, axis=(0, -1)) + z = flip(x_pt, axis=(0, -1)) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + class TestBenchmarks: @pytest.mark.parametrize(