Skip to content

Fix scan_save_mem with 0 steps#1977

Merged
ricardoV94 merged 2 commits into
pymc-devs:v3from
ricardoV94:scan_save_mem_0_steps_bug
Mar 17, 2026
Merged

Fix scan_save_mem with 0 steps#1977
ricardoV94 merged 2 commits into
pymc-devs:v3from
ricardoV94:scan_save_mem_0_steps_bug

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

Closes #1878
Closes #1902

At some point Scan was allowed to do zero steps, but the scan_save_mem rewrite still insisted that that was invalid.

@ricardoV94 ricardoV94 added bug Something isn't working graph rewriting scan labels Mar 16, 2026
@ricardoV94 ricardoV94 force-pushed the scan_save_mem_0_steps_bug branch 4 times, most recently from 555633f to 7861ab7 Compare March 16, 2026 22:34
Comment thread tests/scan/test_rewriting.py Outdated
[n_steps, x0], ys_trace.shape[0], accept_inplace=True
)
assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
mode = get_default_mode()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test appears to be marked as skip, why update it?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you were mislead by github folding text. this is test_while_scan_taps and is not skipped

Comment thread tests/scan/test_rewriting.py Outdated
Comment on lines +1727 to +1733
if (
not isinstance(mode.linker, JITLinker)
and config.scan__allow_output_prealloc
):
assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
else:
assert debug_fn(n_steps=1000, x0=[1, 1]) == 2
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (
not isinstance(mode.linker, JITLinker)
and config.scan__allow_output_prealloc
):
assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
else:
assert debug_fn(n_steps=1000, x0=[1, 1]) == 2
using_jit_linker = (
isinstance(mode.linker, JITLinker)
and config.scan__allow_output_prealloc
)
assert debug_fn(n_steps=1000, x0=[1, 1]) == (3 - using_jit_linker)

too cute?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a bit cute, I'll store the expected result in a variable first. Also my baseline would be 2 + (does_prealloc), not 3 - (does_not)

mit-mot outputs are never memory optimized. Skipping these outputs lets us get rid of two faulty logical branches that existed to mask each:
1. an `if i <= op.info.n_mit_mot:` inside an `else` branch. This was logically wrong in that it included the first non mit-mot output (should have been <, not <=). When this was the output of a while scan it created an artificial dependency on the scan output shape, and didn't allow the rewrite to happen.
2. because of this the outer `if(i <= op.info.n_mit_mot and ...)` had been added to sidestep this artificial dependency. The comment mentioned in was supposed to specifically handle sit-sot/mit-sot of while loops, but it was again looking at all mit-mots + first non mit-mot input. It was logically wrong but canceled the first logical mistake.

If we remove both things just work.
@ricardoV94 ricardoV94 force-pushed the scan_save_mem_0_steps_bug branch from 7861ab7 to d6fcc91 Compare March 17, 2026 09:51
@ricardoV94 ricardoV94 merged commit 893a4c7 into pymc-devs:v3 Mar 17, 2026
66 checks passed
@ricardoV94 ricardoV94 deleted the scan_save_mem_0_steps_bug branch March 17, 2026 10:37
@ayulockedin
Copy link
Copy Markdown
Contributor

Yup that does seem good thx

@goldigd05
Copy link
Copy Markdown

Thanks 😊

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working graph rewriting scan

Projects

None yet

Development

Successfully merging this pull request may close these issues.

scan_save_mem gives wrong results for empty scan

4 participants