Skip to content

Add MLX dispatch for IfElse#2013

Merged
jessegrabowski merged 2 commits into
pymc-devs:v3from
jessegrabowski:mlx-ifelse
Mar 29, 2026
Merged

Add MLX dispatch for IfElse#2013
jessegrabowski merged 2 commits into
pymc-devs:v3from
jessegrabowski:mlx-ifelse

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented Mar 29, 2026

There's no support for conditional logic in MLX, so the best we can do is mx.where. It ends up being similar to the limitations in jax -- jax.lax.cond gets rewritten to jax.lax.select, which is also where jnp.where ends up if you pass all 3 inputs (which we do in this case).

So IfElse sucks again -- it will evaluate both branches. But I submit that it's better than nothing.

@ricardoV94
Copy link
Copy Markdown
Member

have an xpass strict

@jessegrabowski
Copy link
Copy Markdown
Member Author

I also snuck a dispatch for TypeCastingOp (it's an identity function) so I split it into a separate commit

@jessegrabowski jessegrabowski merged commit d23a7ba into pymc-devs:v3 Mar 29, 2026
66 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request mlx OpFromGraph

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants