Fix Scan higher order derivatives #1975

Merged: ricardoV94 merged 2 commits into pymc-devs:v3 from ricardoV94:scan_3rd_derivative_bug on Mar 26, 2026

ricardoV94 (Member) commented Mar 15, 2026

Closes #1772

I let Claude take a stab at this since it was not a priority (broken since forever?). It's a bit over my head: if there's anything more complex than the ScanOp, it's the gradient of the ScanOp (derivatives are simpler, anyone?).

Here is the AI's best guess as to how it is organized: https://gist.github.com/ricardoV94/809a2c25ec61d149209336ab0add6f30. I corrected it several times, so take it with a big grain of salt.

Anyway, it sounds like the gradient of a mit-mot output was wrong. Ignoring where they come from, a mit-mot represents a tape that is read from and written to at multiple positions at each step. The reverse pass of this process should not flow through the written positions into the past (the same way that the reverse pass of x[0].set(y) wrt x truncates at position 0).
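To make the `x[0].set(y)` analogy concrete, here is a minimal pure-Python sketch of an overwrite and its reverse pass. The helper names (`set_first`, `set_first_vjp`, `loss`) are made up for illustration and are not PyTensor APIs; the point is only that the old `x[0]` never reaches the output, so its gradient is zero.

```python
def set_first(x, y):
    """Forward pass: return a copy of x with position 0 overwritten by y."""
    out = list(x)
    out[0] = y
    return out


def set_first_vjp(g_out):
    """Reverse pass: map the gradient wrt the output to gradients wrt (x, y).

    The old x[0] is discarded by the overwrite, so nothing flows back to it.
    """
    g_x = list(g_out)
    g_x[0] = 0.0      # truncation at the written position
    g_y = g_out[0]    # the only path out of slot 0 goes through y
    return g_x, g_y


def loss(x, y, weights):
    """Hypothetical scalar loss: weighted sum of the overwritten vector."""
    return sum(w * v for w, v in zip(weights, set_first(x, y)))


weights = [2.0, 3.0, 5.0]
x = [1.0, 1.0, 1.0]
y = 4.0

# For this loss, d(loss)/d(out) is just `weights`.
g_x, g_y = set_first_vjp(weights)
```

Perturbing `x[0]` leaves `loss` unchanged, which matches `g_x[0] == 0.0`; the weight on slot 0 shows up only in `g_y`.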

And apparently (or so Claude tried to persuade me), this truncation was not happening. The error only shows up if you differentiate more than twice; don't ask me why.

I won't spend more time understanding the L_op right now, so I will rest on the fact that it makes some conceptual sense and that the tests now pass.

This PR also adds some sparse documentation on L_op and related functionality.

ricardoV94 force-pushed the scan_3rd_derivative_bug branch 3 times, most recently from 789ee32 to 1b15cf6 (Mar 16, 2026 23:50)
ricardoV94 added the bug, gradients, and scan labels (Mar 16, 2026)
ricardoV94 marked this pull request as ready for review (Mar 17, 2026 00:10)
ricardoV94 force-pushed the scan_3rd_derivative_bug branch from 1b15cf6 to 55fab8a (Mar 17, 2026 00:20)
Comment thread pytensor/scan/op.py
# The gradient of an overwrite must zero out the direct pass-through from the
# old value; the only gradient path is through the output expression that replaced
# it (already captured by compute_all_gradients via known_grads).
overlapping_taps = set()
Member

Is this section the actual fix?

Member Author

yes

Comment thread pytensor/scan/op.py
Comment on lines +2712 to +2713
if dx in overlapping_taps:
    continue  # gradient truncates here
Member Author

This is the fix, skipping the accumulation
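As a rough illustration of why skipping the accumulation matters, here is a toy tape in pure Python. It is not the PyTensor implementation: `forward`, `backward`, `numeric_grad`, and the `truncate` flag are hypothetical names, and the step function is an assumption chosen so that one slot is both read and written at each step (an overlapping tap). With `truncate=False` the reverse pass also lets the gradient pass straight through the overwritten slot, which double counts it.

```python
def forward(tape, a, b):
    """Toy tape: step t reads slots t and t+1, then overwrites slot t+1."""
    tape = list(tape)
    for t in range(len(tape) - 1):
        tape[t + 1] = a * tape[t] + b * tape[t + 1]
    return tape


def backward(g_final, a, b, truncate=True):
    """Reverse pass: gradient wrt the initial tape, given one wrt the final tape."""
    g = list(g_final)
    for t in reversed(range(len(g) - 1)):
        g_written = g[t + 1]  # gradient wrt the value written at step t
        if truncate:
            # The old value in slot t+1 was overwritten, so there is no
            # direct pass-through; only the step expression carries gradient.
            g[t + 1] = 0.0
        g[t] += a * g_written
        g[t + 1] += b * g_written
    return g


def numeric_grad(tape, a, b, eps=1e-6):
    """Central finite differences of the last final slot wrt each initial slot."""
    grads = []
    for i in range(len(tape)):
        hi = list(tape)
        lo = list(tape)
        hi[i] += eps
        lo[i] -= eps
        grads.append((forward(hi, a, b)[-1] - forward(lo, a, b)[-1]) / (2 * eps))
    return grads


tape0, a, b = [1.0, 2.0, 3.0], 0.5, 2.0
seed = [0.0, 0.0, 1.0]  # loss = last slot of the final tape

correct = backward(seed, a, b, truncate=True)
buggy = backward(seed, a, b, truncate=False)
reference = numeric_grad(tape0, a, b)
```

In this toy, `correct` is `[0.25, 1.0, 2.0]` and agrees with the finite-difference `reference`, while `buggy` comes out as `[0.25, 1.5, 3.0]` because the overwritten slots keep their pass-through gradient.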

Comment thread tests/scan/test_basic.py Outdated
ricardoV94 force-pushed the scan_3rd_derivative_bug branch from 55fab8a to 1e86a3a (Mar 26, 2026 09:10)
ricardoV94 merged commit 7095a38 into pymc-devs:v3 (Mar 26, 2026)
5 of 6 checks passed
ricardoV94 deleted the scan_3rd_derivative_bug branch (Mar 26, 2026 09:10)
Development

Successfully merging this pull request may close these issues.

Scan broken third derivative simple case

2 participants