Skip to content

Fix error in local_reshape_to_dimshuffle#2107

Merged
ricardoV94 merged 1 commit intopymc-devs:mainfrom
ricardoV94:rewrite_bug
May 1, 2026
Merged

Fix error in local_reshape_to_dimshuffle#2107
ricardoV94 merged 1 commit intopymc-devs:mainfrom
ricardoV94:rewrite_bug

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 commented May 1, 2026

Found out when working on pymc-devs/pymc-extras#533 by @Michal-Novomestsky

The rewrite never anticipating not finding the input all squeezable but concluding the reshape was useless (e.g., only ones, or empty) from the reshape(x, shape ) -> shape analysis.

Now we have a clear branch for the case where reshape isn't needed (either input or output tells us we must have a size 1 input), and otherwise we know we always need squeeze -> reshape -> expand_dims. This also simplifies and we don't need to call the other rewrite to cleanup the case where reshape isn't needed

@ricardoV94 ricardoV94 added bug Something isn't working graph rewriting labels May 1, 2026
copy_stack_trace(output, new_out)
if all(inp.type.broadcastable) or not new_output_shape:
# Trivial case we have provably size 1 as input or output, reshape can't be doing anything useful
new_out = inp.dimshuffle(["x"] * output.type.ndim)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This drops all pre-existing axes and expands again what's needed. Less work to write than trying to keep original length one axes and expand what's missing. Actually should check what pytensor likes to canonicalize into

Copy link
Copy Markdown
Member Author

@ricardoV94 ricardoV94 May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OTOH dropping all implicitly adds an assert that size==1 in Dimshuffle perform/dispatch, which is nice to have

@ricardoV94 ricardoV94 requested a review from jessegrabowski May 1, 2026 10:08
@ricardoV94 ricardoV94 merged commit 8588c1a into pymc-devs:main May 1, 2026
66 checks passed
@ricardoV94 ricardoV94 deleted the rewrite_bug branch May 1, 2026 12:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working graph rewriting

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants