
Scan: remove unused outputs #2067

Merged
ricardoV94 merged 4 commits into pymc-devs:main from ricardoV94:scan_unused_outputs
Apr 27, 2026

Conversation

@ricardoV94
Member

@ricardoV94 ricardoV94 commented Apr 22, 2026

We were missing this basic optimization to simplify scan when only some outputs were used. We were half-assing it: it was partially bundled with scan unused inputs and scan save memory (both of which do too much and not enough).

Specifically, we never dropped mit_mot or untraced sit_sot (because they didn't play into scan save memory's main goal), so a pullback of a scan with unused outputs always had a useless mit_mot and kept a reference to the equally useless output of the forward pass. Double waste.

This PR adds a rewrite whose single purpose is trimming away useless computation and inputs. It removes the respective behavior from the two other rewrites. It is also called eagerly in the pullback. There's a lot of bookkeeping already in the pullback method that I feel is better to keep as is, just patching at the end by reusing the rewrite.

Also fixed a bug in scan_push_out_seq which violated the contract that a nit_sot can have a larger buffer than n_steps (this shows up in truncated/while scan gradients). The bug revealed itself with the new pullback cleanup, but is otherwise orthogonal.

PS: We need something like this for OpFromGraph, but there we need to be careful about when to do it. OFGs that encapsulate an Op with a specific signature (AllocDiag, Einsum, RVs) shouldn't be mutated, because other rewrites may expect the standard signature when working with them as closed-box Ops. Similarly, if the same OFG is reused across multiple nodes, we only want to do it when outputs are unused across all uses, since one of the goals of OFG is to reduce compilation/rewrite work for repeated subgraphs.
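As a toy illustration of the waste being removed (plain Python, not the pytensor API; `toy_scan` is a hypothetical stand-in), a scan that traces every output computes dead work whenever the caller only consumes one of them:

```python
def toy_scan(step, init, n_steps):
    """Toy scan: step maps state -> (new_state, extra_output).

    Both traces are always computed, even if the caller only consumes
    one of them. The PR's rewrite drops the unused output (and the
    inputs only it needed) from the real Scan node instead.
    """
    states, extras = [], []
    state = init
    for _ in range(n_steps):
        state, extra = step(state)
        states.append(state)
        extras.append(extra)
    return states, extras


# The caller only wants the recurrent state; the `extra` trace is dead work.
states, extras = toy_scan(lambda x: (x + 1, x * x), 0, 3)
```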

@jessegrabowski
Member

re: OFG, we shouldn't have this problem if the OFG gets inlined right?

@ricardoV94 ricardoV94 force-pushed the scan_unused_outputs branch from e2d85b1 to 731c68d Compare April 22, 2026 16:13
@ricardoV94
Member Author

re: OFG, we shouldn't have this problem if the OFG gets inlined right?

Right, only applies to non-inlined OFGs

@jessegrabowski
Member

should we just be more aggressive about inline? We already are iirc.

Comment thread pytensor/scan/rewriting.py Outdated
@ricardoV94
Member Author

ricardoV94 commented Apr 22, 2026

should we just be more aggressive about inline? We already are iirc.

Not necessarily, but if there are late-pipeline optimizations we can do on non-inlined OFGs, we should do them, that's all. Like Scan, it's an inner-graph Op, so many of the questions posed are similar; unlike Scan, it's expected to have multiple instances of the same Op by nature (that's one of its applications), so we have to be more careful about how we go about doing it.

Another case is inplace, which can make it much faster, but again only if it doesn't lead to de-duplicated OFGs in the final graph.

@jessegrabowski
Member

Another case is inplace, which can make it much faster, but again only if it doesn't lead to de-duplicated OFGs in the final graph.

Yeah, ordering is hard. This would be another place egglog could help (deciding the right moment to inplace)

@ricardoV94 ricardoV94 force-pushed the scan_unused_outputs branch 3 times, most recently from 5f7de8a to 1a190d4 Compare April 22, 2026 18:28
@ricardoV94
Member Author

One test failing, but ready for review.

Comment thread pytensor/scan/rewriting.py Outdated
@ricardoV94 ricardoV94 force-pushed the scan_unused_outputs branch 2 times, most recently from ffe7563 to 9508b01 Compare April 23, 2026 08:55


@node_rewriter([Scan])
def scan_inline_invariant_constants(fgraph, node):
Member Author

This can be extended for Allocs that broadcast the front dim, or constants that only have duplicates along the 0th axis (a bit more expensive to check). We did this in a client project and it yielded some useful simplifications. It also need not inline to provide a speedup: moving from sequences -> non_sequences is already nice, and then if it's something like an Alloc that's usually subsumed by other internal Elemwise/Allocs, it can still be inlined and the non-sequence is just the value/shape. Something to do next. The fact that we isolated the scope of this rewrite makes it nicer to extend.
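A minimal sketch of the uniform-value check hinted at above (plain Python on nested lists; `is_iteration_invariant` is a hypothetical helper, not the PR's code):

```python
def is_iteration_invariant(seq_value):
    """True when every step of a constant sequence equals the first step
    (i.e. duplicates along the 0th axis), so scan could read it once as
    a non-sequence instead of slicing a sequence every iteration."""
    first = seq_value[0]
    return all(step == first for step in seq_value)


# Uniform along axis 0 -> eligible to become a non-sequence;
# varying per step -> must stay a sequence.
uniform = is_iteration_invariant([[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])  # True
varying = is_iteration_invariant([[0, 1], [2, 3]])                      # False
```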

"""
op = node.op

def _duplicates(inner_list, outer_list):
Member Author

This rewrite is pretty expensive. It should rely on the merge optimizer and just check `x is y`. Don't want to regress in this PR though.

@ricardoV94 ricardoV94 marked this pull request as ready for review April 23, 2026 09:07
Member

@jessegrabowski jessegrabowski left a comment

I did as best I could with this one while only reading the code. Approving on trust + things seem fine. I could check out the branch and really try to understand it if you want, or we could do a walkthrough call, but I'm also comfortable with it going in.

Comment thread pytensor/scan/rewriting.py Outdated
op,
node,
*,
drop_seqs=frozenset(),
Member

consider type-hinting these as `frozenset[int]` for clarity


inner_outputs = op.inner_outputs
if inner_substitutions:
inner_outputs = clone_replace(inner_outputs, replace=inner_substitutions)
Member

why clone here? In case an input is another Op that needs to be rebuilt (e.g. scan into scan)?

Member Author

As opposed to graph replace, you mean? I think just for safety, to avoid sharing variables between distinct scans (since we still mutate the fgraph).

else:
extra_dims = [_y.shape[i] for i in range(1, _y.ndim)]
zero_buf = pt.zeros((nit_sot_size, *extra_dims), dtype=_y.dtype)
y = set_subtensor(zero_buf[:n_steps], _y)
Member

is it always correct to pad left?

Member Author

Yeah it may just be useless because the elemwise already had the total shape. Or did you mean something else?

Member

I was specifically thinking this was about the case where the buffer got trimmed by scan optimizations. I don't know if that trimming always happens on the left side (the first time steps) or if it can also be on the right.

Member Author

Trimming is always on the left (last states are kept, never discarded in preference of earlier ones), but tbh I wouldn't be surprised if this rewrite completely broke with scan save mem.

Even that assumption is wrong then: n_steps can be larger than the nit_sot size in that case.

I'm not interested in this rewrite tbh, I think it's myopic / worse in expectation. I just patched the bug in the logic assuming no scan save mem. May be worth checking if it handles save mem cases.

I wouldn't be surprised if the "push out" rewrites don't handle this, just like the pullback itself doesn't.

Member Author

@ricardoV94 ricardoV94 Apr 26, 2026

Yeah this and other paths in this rewrite don't handle trimmed buffers.

In general we should distinguish scan with trimmed buffers at the Op level, because otherwise it complicates the logic and breaks many things.

This PR fixes the nit-sot case and makes the rewrite self-consistent with full scan. The pre-existing trimmed scan limitations remain.

Scan save mem is registered separately later and continues to be unsafe if called out of order. I'd postpone the cleanup it clearly needs.

Comment thread pytensor/scan/rewriting.py Outdated
Comment on lines +2133 to +2135
# candidate state inputs. Un-confirming a candidate folds its
# outputs' reached inputs into the live set immediately, so later
# candidates in the same pass see the update.
Member

Is there any ordering issue that could arise from this procedure? Like if a certain candidate depends indirectly on the direct inputs of another candidate, so if they trigger in the wrong order the first won't see the inputs in the live set yet and will be wrongly dropped.

I can't think of a concrete example, but wondering out loud.

Member Author

What's an indirect dependency?

For a -> b -> c, everything that has b as an ancestor must also have a as an ancestor, since a is an ancestor of b.

Member Author

@ricardoV94 ricardoV94 Apr 26, 2026

If you mean sit-sot1 depends on sit-sot2 which depends on sit-sot3, then 1 can be dropped unconditionally, 2 only if 1 is also dropped (or else the corresponding input is visited by the survivor 1), and 3 only if 1 and 2 are dropped.

Think about the cases that invalidate dropping sit-sot3:

Direct: sit-sot2 is not a drop candidate, therefore sit-sot3's inner input is an ancestor of it, and is "reachable from survivors".

Two stage: sit-sot2 is a drop candidate as well (but not sit-sot1). We already showed we can't drop sit-sot2 because sit-sot1 depends on it. Therefore the loop will remove sit-sot2 from the candidates (and add its root inputs as needed, which includes sit-sot3's inner input), and we are back to the direct-case invalidation.
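The invalidation cases above amount to a small fixed-point loop. A stand-alone model (hypothetical names, not the PR's implementation), where `state_input[c]` is the inner input of candidate state `c` and `inputs_reached[c]` is the set of inner inputs its outputs reach:

```python
def resolve_drop_candidates(candidates, state_input, inputs_reached, live):
    """Toy model of the candidate resolution: a state can only be
    dropped if no surviving output reaches its inner state input.
    Un-confirming a candidate folds the inputs its outputs reach into
    the live set, so remaining candidates see the update."""
    candidates, live = set(candidates), set(live)
    changed = True
    while changed:
        changed = False
        for c in list(candidates):
            if state_input[c] in live:
                # A survivor needs this state: keep it, and expose what
                # its own outputs reach to the remaining candidates.
                candidates.discard(c)
                live |= inputs_reached[c]
                changed = True
    return candidates  # states that can actually be dropped


# The sit-sot1 -> sit-sot2 -> sit-sot3 chain discussed above:
state_input = {1: "in1", 2: "in2", 3: "in3"}
inputs_reached = {1: {"in2"}, 2: {"in3"}, 3: set()}

# sit-sot1 survives (its reached input "in2" seeds the live set):
# neither 2 nor 3 can be dropped, via the two-stage propagation.
none_dropped = resolve_drop_candidates({2, 3}, state_input, inputs_reached, {"in2"})

# Nothing survives: the whole chain is droppable.
all_dropped = resolve_drop_candidates({1, 2, 3}, state_input, inputs_reached, set())
```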

Member

ok ok, thanks for thinking hard about it even when I didn't

Member Author

Added a test like that; before, we only had the cyclical dependency (a depends on b, and b on a).

@ricardoV94 ricardoV94 force-pushed the scan_unused_outputs branch 2 times, most recently from a7c8417 to 7eb34c9 Compare April 27, 2026 10:40
@jessegrabowski
Member

mypy :)

When folding a stateless nit_sot scan into an Elemwise, the pushed-out
result has length == n_steps. If the nit_sot's declared outer buffer
size is larger (e.g., a Scan built via Scan.pullback with
truncate_gradient, where nit_sot_size tracks the forward's step count
but n_steps is the truncated grad_steps), the direct fold silently
drops the trailing slots that the scan's allocator would have left
zero-initialized, breaking any downstream consumer that reads the full
buffer.

Pad the Elemwise result via set_subtensor(zeros(nit_sot_size, ...), _y)
when the two sizes aren't the same Variable. The direct-fold shortcut
still fires for the common case where pytensor.scan reuses the same
n_steps Variable for both.
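The fix described above can be sketched in plain Python lists (the real code uses `pt.zeros` and `set_subtensor` on symbolic shapes; `pad_nit_sot` is a hypothetical stand-in):

```python
def pad_nit_sot(_y, nit_sot_size):
    """Pad a pushed-out result of length n_steps up to the declared
    nit_sot buffer size, leaving the trailing slots zero as scan's
    allocator would have."""
    n_steps = len(_y)
    if n_steps == nit_sot_size:
        return list(_y)  # direct-fold shortcut: the sizes match
    buf = [0.0] * nit_sot_size
    buf[:n_steps] = _y  # pad on the left; trailing slots stay zero
    return buf
```

The key property is that a downstream consumer reading the full declared buffer still sees `nit_sot_size` entries, with the first `n_steps` filled.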

Pure reachability analysis that drops, in a single pass, unused outputs
and inputs from a Scan node:

  * State slots (mit_mot / mit_sot / sit_sot / nit_sot / untraced_sit_sot)
    whose outer output has no clients, provided none of their inner inputs
    is reached from any surviving inner output. Cross-dependent unused
    states are resolved together.
  * Sequences and non-sequences that the rebuilt inner graph no longer
    references.

Rebuild plumbing is factored into ``_rebuild_scan_with_new_signature``,
a helper other rewrites can reuse to produce a Scan with a trimmed
signature (drop categories individually; optionally apply inner-graph
substitutions).

Registered in scan_eqopt2 and wired into Scan.L_op so the pullback graph
is cleaned eagerly when disconnected cotangents are present, avoiding
unused gradient computation in the user-facing graph.

``scan_save_mem`` is responsible for buffer-size trimming; unused-output
removal is owned by ``scan_remove_unused``. Strip the orphan-output
detection, the ``scan_can_remove_outs`` reachability call, and the
``compress_outs`` rebuild path. The new scan is built directly by reusing
the existing op with only the resized outer inputs -- no state drops
happen here.

The one piece of the old orphan path that was load-bearing for buffer
sizing was the "required orphan" case (state with no external clients
but needed by the inner recurrence). It's replaced by a small post-loop
that trims such mit_sot / sit_sot buffers to their minimum (``taps + 1``
under prealloc, ``taps`` otherwise).

``scan_can_remove_outs`` and ``compress_outs`` had no other callers and
are deleted from ``pytensor.scan.utils``.

Also register ``scan_remove_unused`` at the top-level optdb so it is
discoverable via its own tag -- needed for tests that explicitly include
``scan_save_mem`` under FAST_COMPILE, where buffer trimming without
unused-output removal leaves orphan nit_sots in the final Scan.

The old rewrite bundled three responsibilities. They're now separated:

  * ``scan_inline_invariant_constants`` -- inlines compile-time-constant,
    iteration-invariant inputs (non-sequence Constants, and sequences
    whose outer input is a ``TensorConstant`` with a uniform value) into
    the inner graph. Enables inner constant-folding.
  * ``scan_merge_duplicate_inputs`` -- deduplicates outer seqs / non_seqs
    that are ``equal_computations``.
  * ``scan_remove_unused`` (commit 1) -- drops unused outputs and inputs,
    which also cleans up the stale inputs left behind by the other two.

The three are registered together in a single ``dfs_rewriter`` at the
four positions that used to host ``remove_constants_and_unused_inputs_scan``,
with ``scan_remove_unused`` first (most powerful, always reduces), then
inline, then merge.

Tests for the old bundled rewrite now exercise the split combination.
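The ``scan_merge_duplicate_inputs`` step can be modeled like this (plain equality stands in for ``equal_computations``; `dedup_inputs` is a hypothetical helper, not the PR's code):

```python
def dedup_inputs(outer_inputs):
    """Map duplicate outer inputs onto one representative.

    Returns the kept inputs and, for each original position, the index
    of its representative among the kept inputs -- the remapping the
    inner graph would use after deduplication."""
    kept, remap, seen = [], [], {}
    for x in outer_inputs:
        if x in seen:
            remap.append(seen[x])  # duplicate: reuse the representative
        else:
            seen[x] = len(kept)
            remap.append(len(kept))
            kept.append(x)
    return kept, remap
```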
@ricardoV94 ricardoV94 force-pushed the scan_unused_outputs branch from 7eb34c9 to 0ba3ca8 Compare April 27, 2026 14:28
@ricardoV94 ricardoV94 merged commit 5993201 into pymc-devs:main Apr 27, 2026
64 of 66 checks passed