Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude backend incompatible rewrites in Scan dispatch #427

Merged
merged 1 commit into from
Sep 17, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Sep 4, 2023

The whole inner mode is still a problem, but for now another bandaid

Closes #426

@codecov-commenter
Copy link

codecov-commenter commented Sep 4, 2023

Codecov Report

Merging #427 (ea9f0f4) into main (6d3c756) will decrease coverage by 0.03%.
Report is 7 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #427      +/-   ##
==========================================
- Coverage   80.78%   80.75%   -0.03%     
==========================================
  Files         157      157              
  Lines       45575    45612      +37     
  Branches    11152    11162      +10     
==========================================
+ Hits        36816    36836      +20     
- Misses       6552     6562      +10     
- Partials     2207     2214       +7     
Files Changed Coverage Δ
pytensor/link/jax/dispatch/scan.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/scan.py 95.91% <100.00%> (+0.02%) ⬆️

... and 4 files with indirect coverage changes

tests/link/jax/test_scan.py Outdated Show resolved Hide resolved
@jessegrabowski
Copy link
Member

Let me speed test the default mode without rewrites before this gets merged

@ricardoV94
Copy link
Member Author

Nudge @jessegrabowski

@jessegrabowski
Copy link
Member

This is good to go. The speed tests I ran show excluding the BLAS rewrites has very little effect, even on large-scale problems:

Untitled

Untitled

Full notebook here. I'm actually a bit disappointed by these results. I thought the rewrites to the lower-level BLAS operations like GEMM would offer huge speedups relative to whatever the default is, but it seems like the gains are quite marginal. Maybe whatever routine the dot Op is calling in C also uses these same routines?

@ricardoV94
Copy link
Member Author

ricardoV94 commented Sep 17, 2023

@jessegrabowski the GEMM/BLAS OPs are only relevant for the C backend. Jax and Numba introduce it themselves when appropriate, so the changes shouldn't have any effect.

We didn't change the rewrites used in the C backend

@jessegrabowski
Copy link
Member

OK, I misunderstood the PR, I thought you were globally disabling JAX incompatible rewrites for the inner graphs of scan. Sorry for holding this up.

@ricardoV94
Copy link
Member Author

NP

@ricardoV94 ricardoV94 merged commit ad68c7f into pymc-devs:main Sep 17, 2023
52 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: Default rewrites applied to scans with sit-sot inputs regardless of backend
4 participants