Skip to content

Fix local_sqrt_sqr rewrite logic bug#1952

Merged
ricardoV94 merged 1 commit into
pymc-devs:mainfrom
WHOIM1205:fix-local-sqrt-sqr-math-rewrite
May 20, 2026
Merged

Fix local_sqrt_sqr rewrite logic bug#1952
ricardoV94 merged 1 commit into
pymc-devs:mainfrom
WHOIM1205:fix-local-sqrt-sqr-math-rewrite

Conversation

@WHOIM1205
Copy link
Copy Markdown
Contributor

Fix swapped conditions in local_sqrt_sqr rewrite

The rewrite rule local_sqrt_sqr had the conditions for sqrt(sqr(x))
and sqr(sqrt(x)) reversed. Since prev_op represents the inner
operation and node_op the outer operation, the isinstance checks
were matching the wrong patterns.

Because of this, sqrt(sqr(x)) was rewritten to
switch(x >= 0, x, nan) instead of abs(x), which caused negative
inputs to return NaN. This silently breaks the mathematical identity
sqrt(x^2) = |x| and produces incorrect gradients.

This PR swaps the two isinstance checks so the rewrites match the
correct patterns:

  • sqrt(sqr(x)) -> abs(x)
  • sqr(sqrt(x)) -> switch(x >= 0, x, nan)

Tests were also updated to reflect the correct behavior and now include
numerical checks with negative inputs.

@WHOIM1205
Copy link
Copy Markdown
Contributor Author

pre-commit.ci autofix

@WHOIM1205
Copy link
Copy Markdown
Contributor Author

heyy @ricardoV94 is there anything i can improve in this pr

@ricardoV94 ricardoV94 force-pushed the fix-local-sqrt-sqr-math-rewrite branch from 0caa51e to 2e4175f Compare March 12, 2026 12:17
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebased and removed a redundant test.

Thanks for the bugfix

@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@ricardoV94 ricardoV94 added bug Something isn't working graph rewriting labels May 20, 2026
@ricardoV94 ricardoV94 changed the base branch from v2 to main May 20, 2026 09:33
Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
@ricardoV94 ricardoV94 force-pushed the fix-local-sqrt-sqr-math-rewrite branch from b042350 to 833e83d Compare May 20, 2026 09:42
@ricardoV94 ricardoV94 merged commit 6547422 into pymc-devs:main May 20, 2026
64 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working graph rewriting

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants