Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend sum of mul rewrite for multiple axis #484

Merged
merged 3 commits into from
Dec 7, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Nov 7, 2023

Extend the Sum of Mul rewrite beyond the scalar case. We extract variables that have only degenerate dims along the reduced dimensions. In the case of a total sum only scalars can be pulled out, which was exactly the case supported by the old rewrite.

I am not so hot on the Prod of Mul rewrite, but since it was there already, we keep it. This actually fixes a bug when the prod only does a partial reduction, in which case the old logic of raising the extracted values to the power of the whole inner size was wrong. I added a test that failed in main in the first commit. The expanded rewrite comes up in the second commit, and fixes the failing test.

I also separated the rewrite for Sum(-mul) since it shares almost no logic with the larger rewrite. The usefulness of generalizing this rewrite showed up when exploring the grad of the batched MvNormal logp/dlogp in #482 .

And finally, I merged this with the nearly identical rewrite for sum of div, that already handled multiple axes

@ricardoV94 ricardoV94 force-pushed the sum_of_mul_rewrite branch 3 times, most recently from 0fd1647 to c4b816f Compare November 8, 2023 12:09
@ricardoV94 ricardoV94 changed the title Sum of mul rewrite Extend sum of mul rewrite for multiple axis Nov 20, 2023
@ricardoV94 ricardoV94 force-pushed the sum_of_mul_rewrite branch 4 times, most recently from c4f99c6 to 06148c4 Compare November 21, 2023 16:39
@ricardoV94 ricardoV94 marked this pull request as ready for review December 5, 2023 13:11
Rewrite from prod of mul was not correct when only some axes were reduced by prod
Also:
* Separates the sum of negation rewrite
* Fixes bug in partial prod reduction
@ricardoV94 ricardoV94 merged commit c49e395 into pymc-devs:main Dec 7, 2023
53 checks passed
@ricardoV94 ricardoV94 deleted the sum_of_mul_rewrite branch December 7, 2023 15:11
@ricardoV94 ricardoV94 added bug Something isn't working enhancement New feature or request labels Dec 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request graph rewriting performance
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants