From 3453c6b2d660a6e8d581a41a24b6cd9068ab6faa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 14:36:23 +0200 Subject: [PATCH] Fix failing JAX Split test --- pytensor/link/jax/dispatch/tensor_basic.py | 18 ++++++++++-------- tests/link/jax/test_tensor_basic.py | 9 +++++---- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index e03462bf78..04eb5fece0 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -119,20 +119,22 @@ def jax_funcify_Split(op: Split, node, **kwargs): def split(x, axis, splits): if constant_axis is not None: axis = constant_axis + if len(splits) != op.len_splits: + raise ValueError("Length of splits is not equal to n_splits") + if constant_splits is not None: splits = constant_splits cumsum_splits = np.cumsum(splits[:-1]) + if (splits < 0).any(): + raise ValueError("Split sizes cannot be negative") else: cumsum_splits = jnp.cumsum(splits[:-1]) - if len(splits) != op.len_splits: - raise ValueError("Length of splits is not equal to n_splits") - if np.sum(splits) != x.shape[axis]: - raise ValueError( - f"Split sizes do not sum up to input length along axis: {x.shape[axis]}" - ) - if np.any(splits < 0): - raise ValueError("Split sizes cannot be negative") + if constant_axis is not None and constant_splits is not None: + if splits.sum() != x.shape[axis]: + raise ValueError( + f"Split sizes do not sum up to input length along axis: {x.shape[axis]}" + ) return jnp.split(x, cumsum_splits, axis=axis) diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 1e1f496de1..430d6309e1 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -182,16 +182,17 @@ def test_jax_split_not_supported(self): UserWarning, match="Split node does not have constant split positions." ): fn = pytensor.function([a], a_splits, mode="JAX") - # It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it - with pytest.raises(AttributeError): + # This test used to raise AttributeError in previous versions of JAX. + # Now it raises `TracerIntegerConversionError`. + # We accept both errors for backwards compatibility. + with pytest.raises((AttributeError, errors.TracerIntegerConversionError)): fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) split_axis = iscalar("split_axis") a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis) with pytest.warns(UserWarning, match="Split node does not have constant axis."): fn = pytensor.function([a, split_axis], a_splits, mode="JAX") - # Same as above, an AttributeError surpasses the `TracerIntegerConversionError` - # Both errors are included for backwards compatibility + # Same reasoning as above to accept both errors. with pytest.raises((AttributeError, errors.TracerIntegerConversionError)): fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)