Skip to content

MLX `IncSubtensor` fails on slices #1690

@jessegrabowski

Description

@jessegrabowski

Description

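The following snippet raises `ValueError: Slice indices must be integers or None.` when evaluated with the MLX backend (`mode='MLX'`):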

import pytensor.tensor as pt
import numpy as np

x = pt.dmatrix('x')
x_val = np.arange(9).reshape((3, 3))

out = pt.subtensor.inc_subtensor(x[:, :2], 10)
out.eval({x:x_val}, mode='MLX')
Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/Documents/Python/pytensor/pytensor/link/basic.py:665, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph_jit, thunk_inputs, thunk_outputs)
    664 try:
--> 665     outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
    666 except Exception:
    667     # TODO: Should we add a fake node that combines all outputs,
    668     #  since the error may come from any of them?

File ~/Documents/Python/pytensor/pytensor/link/mlx/linker.py:47, in MLXLinker.jit_compile.<locals>.fn(inner_fn, *inputs)
     46 def fn(*inputs, inner_fn=inner_fn):
---> 47     return inner_fn(*(mlx_typify(inp) for inp in inputs))

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmp804m88ye:3, in mlx_funcified_fgraph(x)
      1 def mlx_funcified_fgraph(x):
      2     # IncSubtensor{:, :stop}(x, 10, 2)
----> 3     tensor_variable = incsubtensor(x, tensor_constant, scalar_constant)
      4     return (tensor_variable,)

File ~/Documents/Python/pytensor/pytensor/link/mlx/dispatch/subtensor.py:71, in mlx_funcify_IncSubtensor.<locals>.incsubtensor(x, y, mlx_fn, idx_list, *ilist)
     69     indices = indices[0]
---> 71 return mlx_fn(x, indices, y)

File ~/Documents/Python/pytensor/pytensor/link/mlx/dispatch/subtensor.py:63, in mlx_funcify_IncSubtensor.<locals>.mlx_fn(x, indices, y)
     62     x = deepcopy(x)
---> 63 x[indices] += y
     64 return x

ValueError: Slice indices must be integers or None.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[1], line 8
      5 x_val = np.arange(9).reshape((3, 3))
      7 out = pt.subtensor.inc_subtensor(x[:, :2], 10)
----> 8 out.eval({x:x_val}, mode='MLX')

File ~/Documents/Python/pytensor/pytensor/graph/basic.py:668, in Variable.eval(self, inputs_to_values, **kwargs)
    662         warnings.warn(
    663             "Keyword arguments could not be used to create a cache key for the underlying variable. "
    664             f"A function will be recompiled on every call with such keyword arguments.\n{exc}"
    665         )
    667 args = [parsed_inputs_to_values[param] for param in inputs]
--> 668 return fn(*args)

File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1038, in Function.__call__(self, output_subset, *args, **kwargs)
   1036     t0_fn = time.perf_counter()
   1037 try:
-> 1038     outputs = vm() if output_subset is None else vm(output_subset=output_subset)
   1039 except Exception:
   1040     self._restore_defaults()

File ~/Documents/Python/pytensor/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph_jit, thunk_inputs, thunk_outputs)
    665     outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
    666 except Exception:
    667     # TODO: Should we add a fake node that combines all outputs,
    668     #  since the error may come from any of them?
--> 669     raise_with_op(self.fgraph, output_nodes[0], thunk)
    671 # zip strict not specified because we are in a hot loop
    672 for o_storage, o_val in zip(thunk_outputs, outputs):

File ~/Documents/Python/pytensor/pytensor/link/utils.py:526, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    521     warnings.warn(
    522         f"{exc_type} error does not allow us to add an extra error message"
    523     )
    524     # Some exception need extra parameter in inputs. So forget the
    525     # extra long error message in that case.
--> 526 raise exc_value.with_traceback(exc_trace)

File ~/Documents/Python/pytensor/pytensor/link/basic.py:665, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph_jit, thunk_inputs, thunk_outputs)
    659 def thunk(
    660     fgraph_jit=fgraph_jit,
    661     thunk_inputs=thunk_inputs,
    662     thunk_outputs=thunk_outputs,
    663 ):
    664     try:
--> 665         outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
    666     except Exception:
    667         # TODO: Should we add a fake node that combines all outputs,
    668         #  since the error may come from any of them?
    669         raise_with_op(self.fgraph, output_nodes[0], thunk)

File ~/Documents/Python/pytensor/pytensor/link/mlx/linker.py:47, in MLXLinker.jit_compile.<locals>.fn(inner_fn, *inputs)
     46 def fn(*inputs, inner_fn=inner_fn):
---> 47     return inner_fn(*(mlx_typify(inp) for inp in inputs))

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmp804m88ye:3, in mlx_funcified_fgraph(x)
      1 def mlx_funcified_fgraph(x):
      2     # IncSubtensor{:, :stop}(x, 10, 2)
----> 3     tensor_variable = incsubtensor(x, tensor_constant, scalar_constant)
      4     return (tensor_variable,)

File ~/Documents/Python/pytensor/pytensor/link/mlx/dispatch/subtensor.py:71, in mlx_funcify_IncSubtensor.<locals>.incsubtensor(x, y, mlx_fn, idx_list, *ilist)
     68 if len(indices) == 1:
     69     indices = indices[0]
---> 71 return mlx_fn(x, indices, y)

File ~/Documents/Python/pytensor/pytensor/link/mlx/dispatch/subtensor.py:63, in mlx_funcify_IncSubtensor.<locals>.mlx_fn(x, indices, y)
     61 if not op.inplace:
     62     x = deepcopy(x)
---> 63 x[indices] += y
     64 return x

ValueError: Slice indices must be integers or None.
Apply node that caused the error: IncSubtensor{:, :stop}(x, 10, 2)
Toposort index: 0
Inputs types: [TensorType(float64, shape=(None, None)), TensorType(int8, shape=()), ScalarType(int64)]
Inputs shapes: [(3, 3)]
Inputs strides: [(24, 8)]
Inputs values: ['not shown']
Outputs clients: [[output[0](IncSubtensor{:, :stop}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/grabowski_phd/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 577, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/grabowski_phd/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3116, in run_cell
    result = self._run_cell(
  File "/Users/jessegrabowski/mambaforge/envs/grabowski_phd/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3171, in _run_cell
    result = runner(coro)
  File "/Users/jessegrabowski/mambaforge/envs/grabowski_phd/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/jessegrabowski/mambaforge/envs/grabowski_phd/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3394, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/grabowski_phd/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3639, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/grabowski_phd/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3699, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_29364/386685496.py", line 7, in <module>
    out = pt.subtensor.inc_subtensor(x[:, :2], 10)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), mlx

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions