diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index ce14d08246..422d566f11 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,5 +1,7 @@ from copy import deepcopy +import numpy as np + from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -13,12 +15,119 @@ from pytensor.tensor.type_other import MakeSlice +def normalize_indices_for_mlx(ilist, idx_list): + """Convert indices to MLX-compatible format. + + MLX has strict requirements for indexing: + - Integer indices must be Python int, not np.int64 or other NumPy integer types + - Slice components (start, stop, step) must be Python int or None, not np.int64 + - MLX arrays created from scalars need to be converted back to Python int + - Array indices for advanced indexing are handled separately + + This function converts all integer-like indices and slice components to Python int + while preserving None values and passing through array indices unchanged. + + Parameters + ---------- + ilist : tuple + Runtime index values to be passed to indices_from_subtensor + idx_list : tuple + Static index specification from the Op's idx_list attribute + + Returns + ------- + tuple + Normalized indices compatible with MLX array indexing + + Examples + -------- + >>> # Single np.int64 index converted to Python int + >>> normalize_indices_for_mlx((np.int64(1),), (True,)) + (1,) + + >>> # Slice with np.int64 components + >>> indices = indices_from_subtensor( + ... (np.int64(0), np.int64(2)), (slice(None, None),) + ... ) + >>> # After normalization, slice components are Python int + + Notes + ----- + This conversion is necessary because MLX's C++ indexing implementation + does not recognize NumPy scalar types, raising ValueError when encountered. + Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also + need to be converted back to Python int for use in indexing operations. + Converting to Python int is zero-cost for Python int inputs and minimal + overhead for NumPy scalars and MLX scalar arrays. + """ + import mlx.core as mx + + def normalize_element(element): + """Convert a single index element to MLX-compatible format.""" + if element is None: + # None is valid in slices (e.g., x[None:5] or x[:None]) + return None + elif isinstance(element, slice): + # Recursively normalize slice components + return slice( + normalize_element(element.start), + normalize_element(element.stop), + normalize_element(element.step), + ) + elif isinstance(element, mx.array): + # MLX arrays from mlx_typify need special handling + # If it's a 0-d array (scalar), convert to Python int/float + if element.ndim == 0: + # Extract the scalar value + item = element.item() + # Convert to Python int if it's an integer type + if element.dtype in ( + mx.int8, + mx.int16, + mx.int32, + mx.int64, + mx.uint8, + mx.uint16, + mx.uint32, + mx.uint64, + ): + return int(item) + else: + return float(item) + else: + # Multi-dimensional array for advanced indexing - pass through + return element + elif isinstance(element, (np.integer, np.floating)): + # Convert NumPy scalar to Python int/float + # This handles np.int64, np.int32, np.float64, etc. + return int(element) if isinstance(element, np.integer) else float(element) + elif isinstance(element, (int, float)): + # Python int/float are already compatible + return element + else: + # Pass through other types (arrays for advanced indexing, etc.) + return element + + # Get indices from PyTensor's subtensor utility + raw_indices = indices_from_subtensor(ilist, idx_list) + + # Normalize each index element + normalized = tuple(normalize_element(idx) for idx in raw_indices) + + return normalized + + @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): + """MLX implementation of Subtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + """ idx_list = getattr(op, "idx_list", None) def subtensor(x, *ilists): - indices = indices_from_subtensor([int(element) for element in ilists], idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilists, idx_list) if len(indices) == 1: indices = indices[0] @@ -30,10 +139,16 @@ def subtensor(x, *ilists): @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): + """MLX implementation of AdvancedSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX, + including handling np.int64 in mixed basic/advanced indexing scenarios. + """ idx_list = getattr(op, "idx_list", None) def advanced_subtensor(x, *ilists): - indices = indices_from_subtensor(ilists, idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilists, idx_list) if len(indices) == 1: indices = indices[0] @@ -45,6 +160,11 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) @mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): + """MLX implementation of IncSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + Handles both set_instead_of_inc=True (assignment) and False (increment). + """ idx_list = getattr(op, "idx_list", None) if getattr(op, "set_instead_of_inc", False): @@ -64,7 +184,9 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): - indices = indices_from_subtensor(ilist, idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilist, idx_list) + if len(indices) == 1: indices = indices[0] @@ -75,6 +197,13 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): @mlx_funcify.register(AdvancedIncSubtensor) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): + """MLX implementation of AdvancedIncSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + Note: For advanced indexing, ilist contains the actual array indices. + """ + idx_list = getattr(op, "idx_list", None) + if getattr(op, "set_instead_of_inc", False): def mlx_fn(x, indices, y): @@ -91,8 +220,15 @@ def mlx_fn(x, indices, y): x[indices] += y return x - def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): - return mlx_fn(x, ilist, y) + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilist, idx_list) + + # For advanced indexing, if we have a single tuple of indices, unwrap it + if len(indices) == 1: + indices = indices[0] + + return mlx_fn(x, indices, y) return advancedincsubtensor diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 2923411799..3fa233fd57 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment(): assert not out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) + # Increment slice + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt) + compare_mlx_and_py([], [out_pt], []) + def test_mlx_AdvancedIncSubtensor_set(): """Test advanced set operations using AdvancedIncSubtensor.""" @@ -232,9 +245,12 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): - """Test subtensor operations with PyTensor variables as inputs.""" + """Test subtensor operations with PyTensor variables as inputs. + + This test now works thanks to the fix for np.int64 indexing, which also + handles the conversion of MLX scalar arrays in slice components. + """ # Test with variable arrays (not constants) x_pt = pt.matrix("x", dtype="float32") y_pt = pt.vector("y", dtype="float32") @@ -245,3 +261,150 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_subtensor_with_numpy_int64(): + """Test Subtensor operations with np.int64 indices. + + This tests the fix for MLX's strict requirement that indices must be + Python int, not np.int64 or other NumPy integer types. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + + # Single np.int64 index - this was failing before the fix + idx = np.int64(1) + out_pt = x_pt[idx] + compare_mlx_and_py([], [out_pt], []) + + # Multiple np.int64 indices + out_pt = x_pt[np.int64(1), np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 index + out_pt = x_pt[np.int64(-1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed Python int and np.int64 + out_pt = x_pt[1, np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_subtensor_slices_with_numpy_int64(): + """Test Subtensor with slices containing np.int64 components. + + This tests that slice start/stop/step values can be np.int64. + """ + x_np = np.arange(20, dtype=np.float32) + x_pt = pt.constant(x_np) + + # Slice with np.int64 start + out_pt = x_pt[np.int64(2) :] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 stop + out_pt = x_pt[: np.int64(5)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 start and stop + out_pt = x_pt[np.int64(2) : np.int64(8)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 step + out_pt = x_pt[:: np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with all np.int64 components + out_pt = x_pt[np.int64(1) : np.int64(10) : np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 in slice + out_pt = x_pt[np.int64(-5) :] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_with_numpy_int64(): + """Test IncSubtensor (set/inc) with np.int64 indices and slices. + + This is the main test for the reported issue with inc_subtensor. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + y_pt = pt.as_tensor_variable(np.array(10.0, dtype=np.float32)) + + # Set with np.int64 index + out_pt = pt_subtensor.set_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with np.int64 index + out_pt = pt_subtensor.inc_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Set with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.set_subtensor(x_pt[:, : np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.inc_subtensor(x_pt[:, : np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Complex slice with np.int64 + y2_pt = pt.as_tensor_variable(np.ones((2, 2), dtype=np.float32)) + out_pt = pt_subtensor.inc_subtensor( + x_pt[np.int64(0) : np.int64(2), np.int64(1) : np.int64(3)], y2_pt + ) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_original_issue(): + """Test the exact example from the issue report. + + This was failing with: ValueError: Slice indices must be integers or None. + """ + x_np = np.arange(9, dtype=np.float64).reshape((3, 3)) + x_pt = pt.constant(x_np, dtype="float64") + + # The exact failing case from the issue + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + # Verify it also works with set_subtensor + out_pt = pt_subtensor.set_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_subtensor_with_numpy_int64(): + """Test AdvancedSubtensor with np.int64 in mixed indexing.""" + x_np = np.arange(24, dtype=np.float32).reshape((3, 4, 2)) + x_pt = pt.constant(x_np) + + # Advanced indexing with list, but other dimensions use np.int64 + # Note: This creates AdvancedSubtensor, not basic Subtensor + out_pt = x_pt[[0, 2], np.int64(1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed advanced and basic indexing with np.int64 in slice + out_pt = x_pt[[0, 2], np.int64(1) : np.int64(3)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_incsubtensor_with_numpy_int64(): + """Test AdvancedIncSubtensor with np.int64.""" + x_np = np.arange(15, dtype=np.float32).reshape((5, 3)) + x_pt = pt.constant(x_np) + + # Value to set/increment + y_pt = pt.as_tensor_variable( + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + ) + + # Advanced indexing set with array indices + indices = [np.int64(0), np.int64(2)] + out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Advanced indexing increment + out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], [])