From bff36855ea769a1738d528b2c5f00c0814e7d070 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 23 Oct 2025 21:01:33 -0500 Subject: [PATCH 1/3] Handle slices in `mlx_funcify_IncSubtensor` --- pytensor/link/mlx/dispatch/subtensor.py | 20 +++++++++++++++++++- tests/link/mlx/test_subtensor.py | 13 +++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index ce14d08246..e37f05fbbd 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -64,7 +64,25 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): - indices = indices_from_subtensor(ilist, idx_list) + def get_slice_int(element): + if element is None: + return None + try: + return int(element) + except Exception: + return element + + indices = tuple( + [ + slice( + get_slice_int(s.start), get_slice_int(s.stop), get_slice_int(s.step) + ) + if isinstance(s, slice) + else s + for s in indices_from_subtensor(ilist, idx_list) + ] + ) + if len(indices) == 1: indices = indices[0] diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 2923411799..a13960807e 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment(): assert not out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) + # Increment slice + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt) + compare_mlx_and_py([], [out_pt], []) + def test_mlx_AdvancedIncSubtensor_set(): """Test advanced set operations using AdvancedIncSubtensor.""" From 21e261ad756d166cbb6bd7674b00423b5cf2b9b1 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 20:45:36 +0200 Subject: [PATCH 2/3] Fix MLX indexing to support np.int64 and NumPy scalars Adds normalize_indices_for_mlx to convert NumPy integer and float types, MLX scalar arrays, and slice components to Python int/float for MLX compatibility. Updates all MLX Subtensor and IncSubtensor dispatch functions to use this normalization, resolving errors with np.int64 indices and slices. Expands tests to cover various np.int64 and NumPy scalar indexing scenarios, including original failing cases. --- pytensor/link/mlx/dispatch/subtensor.py | 152 ++++++++++++++++++++---- tests/link/mlx/test_subtensor.py | 152 +++++++++++++++++++++++- 2 files changed, 280 insertions(+), 24 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index e37f05fbbd..ea46b9e4b4 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,5 +1,7 @@ from copy import deepcopy +import numpy as np + from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -13,12 +15,109 @@ from pytensor.tensor.type_other import MakeSlice +def normalize_indices_for_mlx(ilist, idx_list): + """Convert indices to MLX-compatible format. + + MLX has strict requirements for indexing: + - Integer indices must be Python int, not np.int64 or other NumPy integer types + - Slice components (start, stop, step) must be Python int or None, not np.int64 + - MLX arrays created from scalars need to be converted back to Python int + - Array indices for advanced indexing are handled separately + + This function converts all integer-like indices and slice components to Python int + while preserving None values and passing through array indices unchanged. + + Parameters + ---------- + ilist : tuple + Runtime index values to be passed to indices_from_subtensor + idx_list : tuple + Static index specification from the Op's idx_list attribute + + Returns + ------- + tuple + Normalized indices compatible with MLX array indexing + + Examples + -------- + >>> # Single np.int64 index converted to Python int + >>> normalize_indices_for_mlx((np.int64(1),), (True,)) + (1,) + + >>> # Slice with np.int64 components + >>> indices = indices_from_subtensor((np.int64(0), np.int64(2)), (slice(None, None),)) + >>> # After normalization, slice components are Python int + + Notes + ----- + This conversion is necessary because MLX's C++ indexing implementation + does not recognize NumPy scalar types, raising ValueError when encountered. + Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also + need to be converted back to Python int for use in indexing operations. + Converting to Python int is zero-cost for Python int inputs and minimal + overhead for NumPy scalars and MLX scalar arrays. + """ + import mlx.core as mx + + def normalize_element(element): + """Convert a single index element to MLX-compatible format.""" + if element is None: + # None is valid in slices (e.g., x[None:5] or x[:None]) + return None + elif isinstance(element, slice): + # Recursively normalize slice components + return slice( + normalize_element(element.start), + normalize_element(element.stop), + normalize_element(element.step), + ) + elif isinstance(element, mx.array): + # MLX arrays from mlx_typify need special handling + # If it's a 0-d array (scalar), convert to Python int/float + if element.ndim == 0: + # Extract the scalar value + item = element.item() + # Convert to Python int if it's an integer type + if element.dtype in (mx.int8, mx.int16, mx.int32, mx.int64, + mx.uint8, mx.uint16, mx.uint32, mx.uint64): + return int(item) + else: + return float(item) + else: + # Multi-dimensional array for advanced indexing - pass through + return element + elif isinstance(element, (np.integer, np.floating)): + # Convert NumPy scalar to Python int/float + # This handles np.int64, np.int32, np.float64, etc. + return int(element) if isinstance(element, np.integer) else float(element) + elif isinstance(element, (int, float)): + # Python int/float are already compatible + return element + else: + # Pass through other types (arrays for advanced indexing, etc.) + return element + + # Get indices from PyTensor's subtensor utility + raw_indices = indices_from_subtensor(ilist, idx_list) + + # Normalize each index element + normalized = tuple(normalize_element(idx) for idx in raw_indices) + + return normalized + + @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): + """MLX implementation of Subtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + """ idx_list = getattr(op, "idx_list", None) def subtensor(x, *ilists): - indices = indices_from_subtensor([int(element) for element in ilists], idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilists, idx_list) if len(indices) == 1: indices = indices[0] @@ -30,10 +129,16 @@ def subtensor(x, *ilists): @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): + """MLX implementation of AdvancedSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX, + including handling np.int64 in mixed basic/advanced indexing scenarios. + """ idx_list = getattr(op, "idx_list", None) def advanced_subtensor(x, *ilists): - indices = indices_from_subtensor(ilists, idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilists, idx_list) if len(indices) == 1: indices = indices[0] @@ -45,6 +150,11 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) @mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): + """MLX implementation of IncSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + Handles both set_instead_of_inc=True (assignment) and False (increment). + """ idx_list = getattr(op, "idx_list", None) if getattr(op, "set_instead_of_inc", False): @@ -64,24 +174,8 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): - def get_slice_int(element): - if element is None: - return None - try: - return int(element) - except Exception: - return element - - indices = tuple( - [ - slice( - get_slice_int(s.start), get_slice_int(s.stop), get_slice_int(s.step) - ) - if isinstance(s, slice) - else s - for s in indices_from_subtensor(ilist, idx_list) - ] - ) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilist, idx_list) if len(indices) == 1: indices = indices[0] @@ -93,6 +187,13 @@ def get_slice_int(element): @mlx_funcify.register(AdvancedIncSubtensor) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): + """MLX implementation of AdvancedIncSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + Note: For advanced indexing, ilist contains the actual array indices. + """ + idx_list = getattr(op, "idx_list", None) + if getattr(op, "set_instead_of_inc", False): def mlx_fn(x, indices, y): @@ -109,8 +210,15 @@ def mlx_fn(x, indices, y): x[indices] += y return x - def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): - return mlx_fn(x, ilist, y) + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilist, idx_list) + + # For advanced indexing, if we have a single tuple of indices, unwrap it + if len(indices) == 1: + indices = indices[0] + + return mlx_fn(x, indices, y) return advancedincsubtensor diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index a13960807e..e10c893426 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -245,9 +245,12 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): - """Test subtensor operations with PyTensor variables as inputs.""" + """Test subtensor operations with PyTensor variables as inputs. + + This test now works thanks to the fix for np.int64 indexing, which also + handles the conversion of MLX scalar arrays in slice components. + """ # Test with variable arrays (not constants) x_pt = pt.matrix("x", dtype="float32") y_pt = pt.vector("y", dtype="float32") @@ -258,3 +261,148 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_subtensor_with_numpy_int64(): + """Test Subtensor operations with np.int64 indices. + + This tests the fix for MLX's strict requirement that indices must be + Python int, not np.int64 or other NumPy integer types. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + + # Single np.int64 index - this was failing before the fix + idx = np.int64(1) + out_pt = x_pt[idx] + compare_mlx_and_py([], [out_pt], []) + + # Multiple np.int64 indices + out_pt = x_pt[np.int64(1), np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 index + out_pt = x_pt[np.int64(-1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed Python int and np.int64 + out_pt = x_pt[1, np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_subtensor_slices_with_numpy_int64(): + """Test Subtensor with slices containing np.int64 components. + + This tests that slice start/stop/step values can be np.int64. + """ + x_np = np.arange(20, dtype=np.float32) + x_pt = pt.constant(x_np) + + # Slice with np.int64 start + out_pt = x_pt[np.int64(2):] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 stop + out_pt = x_pt[:np.int64(5)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 start and stop + out_pt = x_pt[np.int64(2):np.int64(8)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 step + out_pt = x_pt[::np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with all np.int64 components + out_pt = x_pt[np.int64(1):np.int64(10):np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 in slice + out_pt = x_pt[np.int64(-5):] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_with_numpy_int64(): + """Test IncSubtensor (set/inc) with np.int64 indices and slices. + + This is the main test for the reported issue with inc_subtensor. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + y_pt = pt.as_tensor_variable(np.array(10.0, dtype=np.float32)) + + # Set with np.int64 index + out_pt = pt_subtensor.set_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with np.int64 index + out_pt = pt_subtensor.inc_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Set with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.set_subtensor(x_pt[:, :np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Complex slice with np.int64 + y2_pt = pt.as_tensor_variable(np.ones((2, 2), dtype=np.float32)) + out_pt = pt_subtensor.inc_subtensor( + x_pt[np.int64(0):np.int64(2), np.int64(1):np.int64(3)], y2_pt + ) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_original_issue(): + """Test the exact example from the issue report. + + This was failing with: ValueError: Slice indices must be integers or None. + """ + x_np = np.arange(9, dtype=np.float64).reshape((3, 3)) + x_pt = pt.constant(x_np, dtype="float64") + + # The exact failing case from the issue + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + # Verify it also works with set_subtensor + out_pt = pt_subtensor.set_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_subtensor_with_numpy_int64(): + """Test AdvancedSubtensor with np.int64 in mixed indexing.""" + x_np = np.arange(24, dtype=np.float32).reshape((3, 4, 2)) + x_pt = pt.constant(x_np) + + # Advanced indexing with list, but other dimensions use np.int64 + # Note: This creates AdvancedSubtensor, not basic Subtensor + out_pt = x_pt[[0, 2], np.int64(1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed advanced and basic indexing with np.int64 in slice + out_pt = x_pt[[0, 2], np.int64(1):np.int64(3)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_incsubtensor_with_numpy_int64(): + """Test AdvancedIncSubtensor with np.int64.""" + x_np = np.arange(15, dtype=np.float32).reshape((5, 3)) + x_pt = pt.constant(x_np) + + # Value to set/increment + y_pt = pt.as_tensor_variable(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)) + + # Advanced indexing set with array indices + indices = [np.int64(0), np.int64(2)] + out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Advanced indexing increment + out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], []) From 08b4d89eb3a64cb32279ca9f65e785c03eadba79 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 20:48:58 +0200 Subject: [PATCH 3/3] Fix MLX subtensor indexing with np.int64 and slices Normalize indices and slice components to Python int in MLX subtensor dispatch to address strict MLX requirements. Update tests to cover np.int64 indices and slices, ensuring compatibility and resolving previous failures with advanced and incremental subtensor operations. --- pytensor/link/mlx/dispatch/subtensor.py | 52 ++++++++++------- tests/link/mlx/test_subtensor.py | 78 +++++++++++++------------ 2 files changed, 71 insertions(+), 59 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index ea46b9e4b4..422d566f11 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -17,38 +17,40 @@ def normalize_indices_for_mlx(ilist, idx_list): """Convert indices to MLX-compatible format. - + MLX has strict requirements for indexing: - Integer indices must be Python int, not np.int64 or other NumPy integer types - Slice components (start, stop, step) must be Python int or None, not np.int64 - MLX arrays created from scalars need to be converted back to Python int - Array indices for advanced indexing are handled separately - + This function converts all integer-like indices and slice components to Python int while preserving None values and passing through array indices unchanged. - + Parameters ---------- ilist : tuple Runtime index values to be passed to indices_from_subtensor idx_list : tuple Static index specification from the Op's idx_list attribute - + Returns ------- tuple Normalized indices compatible with MLX array indexing - + Examples -------- >>> # Single np.int64 index converted to Python int >>> normalize_indices_for_mlx((np.int64(1),), (True,)) (1,) - + >>> # Slice with np.int64 components - >>> indices = indices_from_subtensor((np.int64(0), np.int64(2)), (slice(None, None),)) + >>> indices = indices_from_subtensor( + ... (np.int64(0), np.int64(2)), (slice(None, None),) + ... ) >>> # After normalization, slice components are Python int - + Notes ----- This conversion is necessary because MLX's C++ indexing implementation @@ -59,7 +61,7 @@ def normalize_indices_for_mlx(ilist, idx_list): overhead for NumPy scalars and MLX scalar arrays. """ import mlx.core as mx - + def normalize_element(element): """Convert a single index element to MLX-compatible format.""" if element is None: @@ -79,8 +81,16 @@ def normalize_element(element): # Extract the scalar value item = element.item() # Convert to Python int if it's an integer type - if element.dtype in (mx.int8, mx.int16, mx.int32, mx.int64, - mx.uint8, mx.uint16, mx.uint32, mx.uint64): + if element.dtype in ( + mx.int8, + mx.int16, + mx.int32, + mx.int64, + mx.uint8, + mx.uint16, + mx.uint32, + mx.uint64, + ): return int(item) else: return float(item) @@ -97,20 +107,20 @@ def normalize_element(element): else: # Pass through other types (arrays for advanced indexing, etc.) return element - + # Get indices from PyTensor's subtensor utility raw_indices = indices_from_subtensor(ilist, idx_list) - + # Normalize each index element normalized = tuple(normalize_element(idx) for idx in raw_indices) - + return normalized @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): """MLX implementation of Subtensor operation. - + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. """ idx_list = getattr(op, "idx_list", None) @@ -130,7 +140,7 @@ def subtensor(x, *ilists): @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): """MLX implementation of AdvancedSubtensor operation. - + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX, including handling np.int64 in mixed basic/advanced indexing scenarios. """ @@ -151,7 +161,7 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): """MLX implementation of IncSubtensor operation. - + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. Handles both set_instead_of_inc=True (assignment) and False (increment). """ @@ -188,12 +198,12 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): @mlx_funcify.register(AdvancedIncSubtensor) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): """MLX implementation of AdvancedIncSubtensor operation. - + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. Note: For advanced indexing, ilist contains the actual array indices. """ idx_list = getattr(op, "idx_list", None) - + if getattr(op, "set_instead_of_inc", False): def mlx_fn(x, indices, y): @@ -213,11 +223,11 @@ def mlx_fn(x, indices, y): def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): # Normalize indices to handle np.int64 and other NumPy types indices = normalize_indices_for_mlx(ilist, idx_list) - + # For advanced indexing, if we have a single tuple of indices, unwrap it if len(indices) == 1: indices = indices[0] - + return mlx_fn(x, indices, y) return advancedincsubtensor diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index e10c893426..3fa233fd57 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -247,7 +247,7 @@ def test_mlx_subtensor_edge_cases(): def test_mlx_subtensor_with_variables(): """Test subtensor operations with PyTensor variables as inputs. - + This test now works thanks to the fix for np.int64 indexing, which also handles the conversion of MLX scalar arrays in slice components. """ @@ -265,27 +265,27 @@ def test_mlx_subtensor_with_variables(): def test_mlx_subtensor_with_numpy_int64(): """Test Subtensor operations with np.int64 indices. - + This tests the fix for MLX's strict requirement that indices must be Python int, not np.int64 or other NumPy integer types. """ # Test data x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) x_pt = pt.constant(x_np) - + # Single np.int64 index - this was failing before the fix idx = np.int64(1) out_pt = x_pt[idx] compare_mlx_and_py([], [out_pt], []) - + # Multiple np.int64 indices out_pt = x_pt[np.int64(1), np.int64(2)] compare_mlx_and_py([], [out_pt], []) - + # Negative np.int64 index out_pt = x_pt[np.int64(-1)] compare_mlx_and_py([], [out_pt], []) - + # Mixed Python int and np.int64 out_pt = x_pt[1, np.int64(2)] compare_mlx_and_py([], [out_pt], []) @@ -293,83 +293,83 @@ def test_mlx_subtensor_with_numpy_int64(): def test_mlx_subtensor_slices_with_numpy_int64(): """Test Subtensor with slices containing np.int64 components. - + This tests that slice start/stop/step values can be np.int64. """ x_np = np.arange(20, dtype=np.float32) x_pt = pt.constant(x_np) - + # Slice with np.int64 start - out_pt = x_pt[np.int64(2):] + out_pt = x_pt[np.int64(2) :] compare_mlx_and_py([], [out_pt], []) - + # Slice with np.int64 stop - out_pt = x_pt[:np.int64(5)] + out_pt = x_pt[: np.int64(5)] compare_mlx_and_py([], [out_pt], []) - + # Slice with np.int64 start and stop - out_pt = x_pt[np.int64(2):np.int64(8)] + out_pt = x_pt[np.int64(2) : np.int64(8)] compare_mlx_and_py([], [out_pt], []) - + # Slice with np.int64 step - out_pt = x_pt[::np.int64(2)] + out_pt = x_pt[:: np.int64(2)] compare_mlx_and_py([], [out_pt], []) - + # Slice with all np.int64 components - out_pt = x_pt[np.int64(1):np.int64(10):np.int64(2)] + out_pt = x_pt[np.int64(1) : np.int64(10) : np.int64(2)] compare_mlx_and_py([], [out_pt], []) - + # Negative np.int64 in slice - out_pt = x_pt[np.int64(-5):] + out_pt = x_pt[np.int64(-5) :] compare_mlx_and_py([], [out_pt], []) def test_mlx_incsubtensor_with_numpy_int64(): """Test IncSubtensor (set/inc) with np.int64 indices and slices. - + This is the main test for the reported issue with inc_subtensor. """ # Test data x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) x_pt = pt.constant(x_np) y_pt = pt.as_tensor_variable(np.array(10.0, dtype=np.float32)) - + # Set with np.int64 index out_pt = pt_subtensor.set_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) compare_mlx_and_py([], [out_pt], []) - + # Increment with np.int64 index out_pt = pt_subtensor.inc_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) compare_mlx_and_py([], [out_pt], []) - + # Set with slice containing np.int64 - THE ORIGINAL FAILING CASE - out_pt = pt_subtensor.set_subtensor(x_pt[:, :np.int64(2)], y_pt) + out_pt = pt_subtensor.set_subtensor(x_pt[:, : np.int64(2)], y_pt) compare_mlx_and_py([], [out_pt], []) - + # Increment with slice containing np.int64 - THE ORIGINAL FAILING CASE - out_pt = pt_subtensor.inc_subtensor(x_pt[:, :np.int64(2)], y_pt) + out_pt = pt_subtensor.inc_subtensor(x_pt[:, : np.int64(2)], y_pt) compare_mlx_and_py([], [out_pt], []) - + # Complex slice with np.int64 y2_pt = pt.as_tensor_variable(np.ones((2, 2), dtype=np.float32)) out_pt = pt_subtensor.inc_subtensor( - x_pt[np.int64(0):np.int64(2), np.int64(1):np.int64(3)], y2_pt + x_pt[np.int64(0) : np.int64(2), np.int64(1) : np.int64(3)], y2_pt ) compare_mlx_and_py([], [out_pt], []) def test_mlx_incsubtensor_original_issue(): """Test the exact example from the issue report. - + This was failing with: ValueError: Slice indices must be integers or None. """ x_np = np.arange(9, dtype=np.float64).reshape((3, 3)) x_pt = pt.constant(x_np, dtype="float64") - + # The exact failing case from the issue out_pt = pt_subtensor.inc_subtensor(x_pt[:, :2], 10) compare_mlx_and_py([], [out_pt], []) - + # Verify it also works with set_subtensor out_pt = pt_subtensor.set_subtensor(x_pt[:, :2], 10) compare_mlx_and_py([], [out_pt], []) @@ -379,14 +379,14 @@ def test_mlx_advanced_subtensor_with_numpy_int64(): """Test AdvancedSubtensor with np.int64 in mixed indexing.""" x_np = np.arange(24, dtype=np.float32).reshape((3, 4, 2)) x_pt = pt.constant(x_np) - + # Advanced indexing with list, but other dimensions use np.int64 # Note: This creates AdvancedSubtensor, not basic Subtensor out_pt = x_pt[[0, 2], np.int64(1)] compare_mlx_and_py([], [out_pt], []) - + # Mixed advanced and basic indexing with np.int64 in slice - out_pt = x_pt[[0, 2], np.int64(1):np.int64(3)] + out_pt = x_pt[[0, 2], np.int64(1) : np.int64(3)] compare_mlx_and_py([], [out_pt], []) @@ -394,15 +394,17 @@ def test_mlx_advanced_incsubtensor_with_numpy_int64(): """Test AdvancedIncSubtensor with np.int64.""" x_np = np.arange(15, dtype=np.float32).reshape((5, 3)) x_pt = pt.constant(x_np) - + # Value to set/increment - y_pt = pt.as_tensor_variable(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)) - + y_pt = pt.as_tensor_variable( + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + ) + # Advanced indexing set with array indices indices = [np.int64(0), np.int64(2)] out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt) compare_mlx_and_py([], [out_pt], []) - + # Advanced indexing increment out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt) compare_mlx_and_py([], [out_pt], [])