-
Notifications
You must be signed in to change notification settings - Fork 102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Exclude backend incompatible rewrites in Scan dispatch #427
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #427 +/- ##
==========================================
- Coverage 80.78% 80.75% -0.03%
==========================================
Files 157 157
Lines 45575 45612 +37
Branches 11152 11162 +10
==========================================
+ Hits 36816 36836 +20
- Misses 6552 6562 +10
- Partials 2207 2214 +7
|
c6e1284
to
ea9f0f4
Compare
Let me speed test the default mode without rewrites before this gets merged |
Nudge @jessegrabowski |
This is good to go. The speed tests I ran show excluding the BLAS rewrites has very little effect, even on large-scale problems: Full notebook here. I'm actually a bit disappointed by these results. I thought the rewrites to the lower-level BLAS operations like GEMM would offer huge speedups relative to whatever the default is, but it seems like the gains are quite marginal. Maybe whatever routine the dot |
@jessegrabowski the GEMM/BLAS OPs are only relevant for the C backend. Jax and Numba introduce it themselves when appropriate, so the changes shouldn't have any effect. We didn't change the rewrites used in the C backend |
OK, I misunderstood the PR, I thought you were globally disabling JAX incompatible rewrites for the inner graphs of scan. Sorry for holding this up. |
NP |
The whole inner mode is still a problem, but for now another bandaid
Closes #426