Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 141 additions & 5 deletions pytensor/link/mlx/dispatch/subtensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from copy import deepcopy

import numpy as np

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
Expand All @@ -13,12 +15,119 @@
from pytensor.tensor.type_other import MakeSlice


def normalize_indices_for_mlx(ilist, idx_list):
"""Convert indices to MLX-compatible format.

MLX has strict requirements for indexing:
- Integer indices must be Python int, not np.int64 or other NumPy integer types
- Slice components (start, stop, step) must be Python int or None, not np.int64
- MLX arrays created from scalars need to be converted back to Python int
- Array indices for advanced indexing are handled separately

This function converts all integer-like indices and slice components to Python int
while preserving None values and passing through array indices unchanged.

Parameters
----------
ilist : tuple
Runtime index values to be passed to indices_from_subtensor
idx_list : tuple
Static index specification from the Op's idx_list attribute

Returns
-------
tuple
Normalized indices compatible with MLX array indexing

Examples
--------
>>> # Single np.int64 index converted to Python int
>>> normalize_indices_for_mlx((np.int64(1),), (True,))
(1,)

>>> # Slice with np.int64 components
>>> indices = indices_from_subtensor(
... (np.int64(0), np.int64(2)), (slice(None, None),)
... )
>>> # After normalization, slice components are Python int

Notes
-----
This conversion is necessary because MLX's C++ indexing implementation
does not recognize NumPy scalar types, raising ValueError when encountered.
Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also
need to be converted back to Python int for use in indexing operations.
Converting to Python int is zero-cost for Python int inputs and minimal
overhead for NumPy scalars and MLX scalar arrays.
"""
import mlx.core as mx

def normalize_element(element):
"""Convert a single index element to MLX-compatible format."""
if element is None:
# None is valid in slices (e.g., x[None:5] or x[:None])
return None
elif isinstance(element, slice):
# Recursively normalize slice components
return slice(
normalize_element(element.start),
normalize_element(element.stop),
normalize_element(element.step),
)
elif isinstance(element, mx.array):
# MLX arrays from mlx_typify need special handling
# If it's a 0-d array (scalar), convert to Python int/float
if element.ndim == 0:
# Extract the scalar value
item = element.item()
# Convert to Python int if it's an integer type
if element.dtype in (
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
):
return int(item)
else:
return float(item)
else:
# Multi-dimensional array for advanced indexing - pass through
return element
elif isinstance(element, (np.integer, np.floating)):
# Convert NumPy scalar to Python int/float
# This handles np.int64, np.int32, np.float64, etc.
return int(element) if isinstance(element, np.integer) else float(element)
elif isinstance(element, (int, float)):
# Python int/float are already compatible
return element
else:
# Pass through other types (arrays for advanced indexing, etc.)
return element

# Get indices from PyTensor's subtensor utility
raw_indices = indices_from_subtensor(ilist, idx_list)

# Normalize each index element
normalized = tuple(normalize_element(idx) for idx in raw_indices)

return normalized


@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
"""MLX implementation of Subtensor operation.

Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
"""
idx_list = getattr(op, "idx_list", None)

def subtensor(x, *ilists):
indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
# Normalize indices to handle np.int64 and other NumPy types
indices = normalize_indices_for_mlx(ilists, idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -30,10 +139,16 @@ def subtensor(x, *ilists):
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
"""MLX implementation of AdvancedSubtensor operation.

Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX,
including handling np.int64 in mixed basic/advanced indexing scenarios.
"""
idx_list = getattr(op, "idx_list", None)

def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
# Normalize indices to handle np.int64 and other NumPy types
indices = normalize_indices_for_mlx(ilists, idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -45,6 +160,11 @@ def advanced_subtensor(x, *ilists):
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
"""MLX implementation of IncSubtensor operation.

Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
Handles both set_instead_of_inc=True (assignment) and False (increment).
"""
idx_list = getattr(op, "idx_list", None)

if getattr(op, "set_instead_of_inc", False):
Expand All @@ -64,7 +184,9 @@ def mlx_fn(x, indices, y):
return x

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
# Normalize indices to handle np.int64 and other NumPy types
indices = normalize_indices_for_mlx(ilist, idx_list)

if len(indices) == 1:
indices = indices[0]

Expand All @@ -75,6 +197,13 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):

@mlx_funcify.register(AdvancedIncSubtensor)
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
"""MLX implementation of AdvancedIncSubtensor operation.

Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
Note: For advanced indexing, ilist contains the actual array indices.
"""
idx_list = getattr(op, "idx_list", None)

if getattr(op, "set_instead_of_inc", False):

def mlx_fn(x, indices, y):
Expand All @@ -91,8 +220,15 @@ def mlx_fn(x, indices, y):
x[indices] += y
return x

def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
return mlx_fn(x, ilist, y)
def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
# Normalize indices to handle np.int64 and other NumPy types
indices = normalize_indices_for_mlx(ilist, idx_list)

# For advanced indexing, if we have a single tuple of indices, unwrap it
if len(indices) == 1:
indices = indices[0]

return mlx_fn(x, indices, y)

return advancedincsubtensor

Expand Down
167 changes: 165 additions & 2 deletions tests/link/mlx/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment():
assert not out_pt.owner.op.set_instead_of_inc
compare_mlx_and_py([], [out_pt], [])

# Increment slice
out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt)
compare_mlx_and_py([], [out_pt], [])


def test_mlx_AdvancedIncSubtensor_set():
"""Test advanced set operations using AdvancedIncSubtensor."""
Expand Down Expand Up @@ -232,9 +245,12 @@ def test_mlx_subtensor_edge_cases():
compare_mlx_and_py([], [out_pt], [])


@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported")
def test_mlx_subtensor_with_variables():
"""Test subtensor operations with PyTensor variables as inputs."""
"""Test subtensor operations with PyTensor variables as inputs.

This test now works thanks to the fix for np.int64 indexing, which also
handles the conversion of MLX scalar arrays in slice components.
"""
# Test with variable arrays (not constants)
x_pt = pt.matrix("x", dtype="float32")
y_pt = pt.vector("y", dtype="float32")
Expand All @@ -245,3 +261,150 @@ def test_mlx_subtensor_with_variables():
# Set operation with variables
out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt)
compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np])


def test_mlx_subtensor_with_numpy_int64():
"""Test Subtensor operations with np.int64 indices.

This tests the fix for MLX's strict requirement that indices must be
Python int, not np.int64 or other NumPy integer types.
"""
# Test data
x_np = np.arange(12, dtype=np.float32).reshape((3, 4))
x_pt = pt.constant(x_np)

# Single np.int64 index - this was failing before the fix
idx = np.int64(1)
out_pt = x_pt[idx]
compare_mlx_and_py([], [out_pt], [])

# Multiple np.int64 indices
out_pt = x_pt[np.int64(1), np.int64(2)]
compare_mlx_and_py([], [out_pt], [])

# Negative np.int64 index
out_pt = x_pt[np.int64(-1)]
compare_mlx_and_py([], [out_pt], [])

# Mixed Python int and np.int64
out_pt = x_pt[1, np.int64(2)]
compare_mlx_and_py([], [out_pt], [])


def test_mlx_subtensor_slices_with_numpy_int64():
"""Test Subtensor with slices containing np.int64 components.

This tests that slice start/stop/step values can be np.int64.
"""
x_np = np.arange(20, dtype=np.float32)
x_pt = pt.constant(x_np)

# Slice with np.int64 start
out_pt = x_pt[np.int64(2) :]
compare_mlx_and_py([], [out_pt], [])

# Slice with np.int64 stop
out_pt = x_pt[: np.int64(5)]
compare_mlx_and_py([], [out_pt], [])

# Slice with np.int64 start and stop
out_pt = x_pt[np.int64(2) : np.int64(8)]
compare_mlx_and_py([], [out_pt], [])

# Slice with np.int64 step
out_pt = x_pt[:: np.int64(2)]
compare_mlx_and_py([], [out_pt], [])

# Slice with all np.int64 components
out_pt = x_pt[np.int64(1) : np.int64(10) : np.int64(2)]
compare_mlx_and_py([], [out_pt], [])

# Negative np.int64 in slice
out_pt = x_pt[np.int64(-5) :]
compare_mlx_and_py([], [out_pt], [])


def test_mlx_incsubtensor_with_numpy_int64():
"""Test IncSubtensor (set/inc) with np.int64 indices and slices.

This is the main test for the reported issue with inc_subtensor.
"""
# Test data
x_np = np.arange(12, dtype=np.float32).reshape((3, 4))
x_pt = pt.constant(x_np)
y_pt = pt.as_tensor_variable(np.array(10.0, dtype=np.float32))

# Set with np.int64 index
out_pt = pt_subtensor.set_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt)
compare_mlx_and_py([], [out_pt], [])

# Increment with np.int64 index
out_pt = pt_subtensor.inc_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt)
compare_mlx_and_py([], [out_pt], [])

# Set with slice containing np.int64 - THE ORIGINAL FAILING CASE
out_pt = pt_subtensor.set_subtensor(x_pt[:, : np.int64(2)], y_pt)
compare_mlx_and_py([], [out_pt], [])

# Increment with slice containing np.int64 - THE ORIGINAL FAILING CASE
out_pt = pt_subtensor.inc_subtensor(x_pt[:, : np.int64(2)], y_pt)
compare_mlx_and_py([], [out_pt], [])

# Complex slice with np.int64
y2_pt = pt.as_tensor_variable(np.ones((2, 2), dtype=np.float32))
out_pt = pt_subtensor.inc_subtensor(
x_pt[np.int64(0) : np.int64(2), np.int64(1) : np.int64(3)], y2_pt
)
compare_mlx_and_py([], [out_pt], [])


def test_mlx_incsubtensor_original_issue():
"""Test the exact example from the issue report.

This was failing with: ValueError: Slice indices must be integers or None.
"""
x_np = np.arange(9, dtype=np.float64).reshape((3, 3))
x_pt = pt.constant(x_np, dtype="float64")

# The exact failing case from the issue
out_pt = pt_subtensor.inc_subtensor(x_pt[:, :2], 10)
compare_mlx_and_py([], [out_pt], [])

# Verify it also works with set_subtensor
out_pt = pt_subtensor.set_subtensor(x_pt[:, :2], 10)
compare_mlx_and_py([], [out_pt], [])


def test_mlx_advanced_subtensor_with_numpy_int64():
"""Test AdvancedSubtensor with np.int64 in mixed indexing."""
x_np = np.arange(24, dtype=np.float32).reshape((3, 4, 2))
x_pt = pt.constant(x_np)

# Advanced indexing with list, but other dimensions use np.int64
# Note: This creates AdvancedSubtensor, not basic Subtensor
out_pt = x_pt[[0, 2], np.int64(1)]
compare_mlx_and_py([], [out_pt], [])

# Mixed advanced and basic indexing with np.int64 in slice
out_pt = x_pt[[0, 2], np.int64(1) : np.int64(3)]
compare_mlx_and_py([], [out_pt], [])


def test_mlx_advanced_incsubtensor_with_numpy_int64():
"""Test AdvancedIncSubtensor with np.int64."""
x_np = np.arange(15, dtype=np.float32).reshape((5, 3))
x_pt = pt.constant(x_np)

# Value to set/increment
y_pt = pt.as_tensor_variable(
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
)

# Advanced indexing set with array indices
indices = [np.int64(0), np.int64(2)]
out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt)
compare_mlx_and_py([], [out_pt], [])

# Advanced indexing increment
out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt)
compare_mlx_and_py([], [out_pt], [])
Loading