Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions pytensor/link/jax/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,22 @@ def jax_funcify_Split(op: Split, node, **kwargs):
def split(x, axis, splits):
if constant_axis is not None:
axis = constant_axis
if len(splits) != op.len_splits:
raise ValueError("Length of splits is not equal to n_splits")

if constant_splits is not None:
splits = constant_splits
cumsum_splits = np.cumsum(splits[:-1])
if (splits < 0).any():
raise ValueError("Split sizes cannot be negative")
else:
cumsum_splits = jnp.cumsum(splits[:-1])

if len(splits) != op.len_splits:
raise ValueError("Length of splits is not equal to n_splits")
if np.sum(splits) != x.shape[axis]:
raise ValueError(
f"Split sizes do not sum up to input length along axis: {x.shape[axis]}"
)
if np.any(splits < 0):
raise ValueError("Split sizes cannot be negative")
if constant_axis is not None and constant_splits is not None:
if splits.sum() != x.shape[axis]:
raise ValueError(
f"Split sizes do not sum up to input length along axis: {x.shape[axis]}"
)

return jnp.split(x, cumsum_splits, axis=axis)

Expand Down
9 changes: 5 additions & 4 deletions tests/link/jax/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,17 @@ def test_jax_split_not_supported(self):
UserWarning, match="Split node does not have constant split positions."
):
fn = pytensor.function([a], a_splits, mode="JAX")
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it
with pytest.raises(AttributeError):
# This test used to raise AttributeError in previous versions of JAX.
# Now it raises `TracerIntegerConversionError`.
# We accept both errors for backwards compatibility.
with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))

split_axis = iscalar("split_axis")
a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis)
with pytest.warns(UserWarning, match="Split node does not have constant axis."):
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
# Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
# Both errors are included for backwards compatibility
# Same reasoning as above to accept both errors.
with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)

Expand Down