Skip to content

MLX Elementwise and scalar dispatches#2007

Merged
jessegrabowski merged 10 commits into
pymc-devs:v3from
jessegrabowski:mlx-elemwise
Mar 29, 2026
Merged

MLX Elementwise and scalar dispatches#2007
jessegrabowski merged 10 commits into
pymc-devs:v3from
jessegrabowski:mlx-elemwise

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

Description

Some simple/boring stuff for the MLX backend:

  • Refactor scalar dispatches to copy the JAX template. We use nfunc_name to look in the mx namespace and get everything we can. There's a dictionary of overrides for stuff that doesn't follow the array standard (invert is bitwise_invert and true_divide is divide
  • Move scalar stuff to scalar.py, leave CARReduce and Elementwise in elementwise.py
  • Add second and identity dispatches
  • Add Composite dispatch

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Comment thread pytensor/link/mlx/dispatch/scalar.py Outdated
@jessegrabowski jessegrabowski force-pushed the mlx-elemwise branch 2 times, most recently from 6cf194c to d730255 Compare March 29, 2026 16:05

@mlx_funcify.register(Composite)
def mlx_funcify_Composite(op, node=None, **kwargs):
return mlx_funcify(op.fgraph, squeeze_output=True)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In jax OpFromGraph dispatch we're calling the jax rewrites on the inner graph, but not in Composite. I followed that lead, but it seems inconsistent

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Composite is never optimized internally, it's supposed to already be a good graph. OpFromGraph and Scan are, but we should actually explicitly optimize as a rewrite, not at compile/runtime. We are working towards that... but it takes time

Which is why I want to make their inner graphs frozen as well

@jessegrabowski jessegrabowski merged commit 87c9eba into pymc-devs:v3 Mar 29, 2026
118 of 120 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants