This repository was archived by the owner on Aug 1, 2025. It is now read-only.

collect_results doesn't collect grads of non-parameter leaf inputs #1901

@ezyang

Description

🐛 Describe the bug

    out = gm(args)
    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, [])

`args` can have gradients accumulated on them, but `collect_results` ignores these entirely: the call above passes an empty list `[]` instead of `args`. These gradients can matter if the graph propagates changes on to other parameters, and ignoring them can impede minifier simplification.
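A minimal sketch, assuming plain PyTorch, of why the inputs matter. `collect_results_sketch` is a hypothetical helper written for illustration and is not the actual `collect_results` from the codebase. It shows that a leaf input created with `requires_grad=True` accumulates a `.grad` after `backward()`, and that this gradient is lost when the inputs list passed in is empty.

```python
import torch


def collect_results_sketch(model, prediction, loss, example_inputs):
    """Hypothetical helper (not the real collect_results): gathers the output,
    the loss, parameter grads, and the .grad of every tensor input."""
    results = [prediction, loss]
    # Parameter gradients; a missing grad is replaced by zeros so that
    # None and an all-zero grad compare as equal.
    param_grads = {
        name + ".grad": (p.grad if p.grad is not None else torch.zeros_like(p))
        for name, p in model.named_parameters()
    }
    results.append(param_grads)
    # Gradients accumulated on non-parameter leaf inputs: this is the part
    # that disappears when example_inputs is passed as [].
    for inp in example_inputs:
        if isinstance(inp, torch.Tensor):
            results.append(inp.grad)
    return results


torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)  # leaf input, not a parameter
out = model(x)
loss = out.sum()
loss.backward()

# Mirror the issue's call (loss passed as None), with and without the inputs.
with_inputs = collect_results_sketch(model, out, None, [x])
without_inputs = collect_results_sketch(model, out, None, [])
print(len(with_inputs), len(without_inputs))  # 4 3
print(with_inputs[-1].shape)  # torch.Size([3, 4])
```

Passing the real inputs instead of `[]` is enough for the input gradient to appear in the collected results, so a comparison between two runs would also catch divergence in `x.grad`.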

Error logs

No response

Minified repro

No response

Metadata

Labels: bug, triaged