Skip to content

Rename MLX core.py to tensor_basic.py#2010

Merged
jessegrabowski merged 2 commits intopymc-devs:v3from
jessegrabowski:mlx_tensor_basic
Mar 29, 2026
Merged

Rename MLX core.py to tensor_basic.py#2010
jessegrabowski merged 2 commits intopymc-devs:v3from
jessegrabowski:mlx_tensor_basic

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

The names of the files in the dispatch modules should match the pytensor files they implement, not the dispatch library. So I renamed core.py to tensor_basic.py. This now matches what Jax has.

I also implemented an ARange dispatch. It is super limited, because MLX doesn't allow broadcasting via the arguments of mx.arange. All of start, stop, and step have to be constants and have to be scalars. Still, that covers majority of the use case, so it's worth having.

@jessegrabowski jessegrabowski changed the title Rename core.py to tensor_basic.py Rename MLX core.py to tensor_basic.py Mar 29, 2026
@ricardoV94
Copy link
Copy Markdown
Member

If they are constants, aren't they constant folded by us anyway? Fine to have it but curious how did you come across it

@jessegrabowski
Copy link
Copy Markdown
Member Author

jessegrabowski commented Mar 29, 2026

If they are constants, aren't they constant folded by us anyway? Fine to have it but curious how did you come across it

I hit issues with graphs that had symbolic inputs, then had the llm just make a big test grid of inputs and realize you're not allowed to broadcast.

We might be able to do something with an eager vmap in the dispatch?

@ricardoV94
Copy link
Copy Markdown
Member

I don't understand, it sounds like you only support constant aranges, so they would have been constant folded anyway? It's fine to support Arange with constant inputs, maybe we din't run constant folding or something, but I'm a bit concerned in what real use scenario you found it. (Not concerned enough to block the PR, hence the pre-approval)

@jessegrabowski
Copy link
Copy Markdown
Member Author

The concrete use-case was graphs like this:

stop = pt.iscalar('stop')
x = pt.arange(stop)

fn = pytensor.function([stop], x, mode='MLX')

Which raises a cryptic MLX error at runtime. I wanted to intercept it and raise a more informative error, like we do for jax.

@ricardoV94
Copy link
Copy Markdown
Member

That's more clear

@jessegrabowski jessegrabowski merged commit 79562b9 into pymc-devs:v3 Mar 29, 2026
66 checks passed
@jessegrabowski jessegrabowski deleted the mlx_tensor_basic branch March 29, 2026 16:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants