Skip to content

Add MLX dispatch for BatchedDot#2009

Merged
jessegrabowski merged 2 commits into
pymc-devs:v3from
jessegrabowski:mlx-blas
Mar 29, 2026
Merged

Add MLX dispatch for BatchedDot#2009
jessegrabowski merged 2 commits into
pymc-devs:v3from
jessegrabowski:mlx-blas

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

Description

MLX's mlx.core.matmul is natively batched, so the dispatch for BatchedDot is trivial

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Comment thread tests/link/mlx/test_blas.py Outdated
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
)
out = pt_blas.BatchedDot()(a, b)
compare_mlx_and_py([a, b], [out], [a_test_value, b_test_value])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this returns the mlx function and the eval, so you don't need to recompile to test the exception below

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i always forget that

@jessegrabowski jessegrabowski merged commit 1fa7c81 into pymc-devs:v3 Mar 29, 2026
66 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants