From ab0396b68fadbc339a7ec865bddd91c6d8aa6661 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 4 Oct 2025 19:23:42 +0000 Subject: [PATCH 1/5] Initial plan From 553dba58584b1c77c27c410e554045e85d89fbb0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 4 Oct 2025 19:30:08 +0000 Subject: [PATCH 2/5] Fix pt.flip to handle negative axis correctly using normalize_axis_tuple Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/subtensor.py | 5 ++--- tests/tensor/test_subtensor.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 15e02265f1..188d50197d 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,7 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.npy_2_compat import numpy_version, using_numpy_2 +from pytensor.npy_2_compat import normalize_axis_tuple, numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -3369,8 +3369,7 @@ def flip( if axis is None: index = ((slice(None, None, -1)),) * arr.ndim else: - if isinstance(axis, int): - axis = (axis,) + axis = normalize_axis_tuple(axis, arr.ndim) index = tuple( [ slice(None, None, -1) if i in axis else slice(None, None, None) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index ead94371d3..6657605e1b 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -3148,6 +3148,38 @@ def test_flip(size: tuple[int]): np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) +def test_flip_negative_axis(size: tuple[int]): + """Test that flip handles negative axis values correctly.""" + ATOL = RTOL = 1e-8 if config.floatX == "float64" else 1e-4 + + x = np.random.normal(size=size).astype(config.floatX) + x_pt = pytensor.tensor.tensor(shape=size, name="x") + + # Test single negative axis + for axis in range(-x.ndim, 0): + expected = np.flip(x, axis=axis) + z = flip(x_pt, axis=axis) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + + # Test tuple with negative axes + if x.ndim > 1: + expected = np.flip(x, axis=(-1, -2)) + z = flip(x_pt, axis=(-1, -2)) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + + # Test mixed positive and negative axes + if x.ndim >= 2: + expected = np.flip(x, axis=(0, -1)) + z = flip(x_pt, axis=(0, -1)) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + + class TestBenchmarks: @pytest.mark.parametrize( "static_shape", (False, True), ids=lambda x: f"static_shape={x}" From b2b5e96c9217316de4b530b0d74be1b135c7df48 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 4 Oct 2025 19:43:20 +0000 Subject: [PATCH 3/5] Expand existing test_flip function to include negative axis tests Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- tests/tensor/test_subtensor.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 6657605e1b..9333b4c32e 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -3147,17 +3147,6 @@ def test_flip(size: tuple[int]): f = pytensor.function([x_pt], z, mode="FAST_COMPILE") np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) - -@pytest.mark.parametrize( - "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] -) -def test_flip_negative_axis(size: tuple[int]): - """Test that flip handles negative axis values correctly.""" - ATOL = RTOL = 1e-8 if config.floatX == "float64" else 1e-4 - - x = np.random.normal(size=size).astype(config.floatX) - x_pt = pytensor.tensor.tensor(shape=size, name="x") - # Test single negative axis for axis in range(-x.ndim, 0): expected = np.flip(x, axis=axis) From 2fca051d87c83a118aca80529e1957f54234bef9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 4 Oct 2025 20:43:40 +0000 Subject: [PATCH 4/5] Fix mypy error by using separate variable for normalized axis Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/subtensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 188d50197d..d4acd33a4d 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3369,10 +3369,10 @@ def flip( if axis is None: index = ((slice(None, None, -1)),) * arr.ndim else: - axis = normalize_axis_tuple(axis, arr.ndim) + normalized_axis = normalize_axis_tuple(axis, arr.ndim) index = tuple( [ - slice(None, None, -1) if i in axis else slice(None, None, None) + slice(None, None, -1) if i in normalized_axis else slice(None, None, None) for i in range(arr.ndim) ] ) From a7a183a4bdbcd911513eb86c7a257a5e54361ea6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 4 Oct 2025 21:13:38 +0000 Subject: [PATCH 5/5] Fix ruff formatting issues Co-authored-by: jessegrabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/tensor/subtensor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index d4acd33a4d..1bf9ad65c1 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3372,7 +3372,9 @@ def flip( normalized_axis = normalize_axis_tuple(axis, arr.ndim) index = tuple( [ - slice(None, None, -1) if i in normalized_axis else slice(None, None, None) + slice(None, None, -1) + if i in normalized_axis + else slice(None, None, None) for i in range(arr.ndim) ] ) @@ -3381,9 +3383,9 @@ def flip( __all__ = [ - "take", "flip", - "slice_at_axis", "inc_subtensor", "set_subtensor", + "slice_at_axis", + "take", ]