
[graph_trainer] Add remat pass and torch.no_grad() execution to minimal_fx_tracer #2767

Closed
tugsbayasgalan wants to merge 1 commit into gh/tugsbayasgalan/12/base from gh/tugsbayasgalan/12/head

Conversation

tugsbayasgalan (Contributor) commented Mar 31, 2026

Stack from ghstack (oldest at bottom):

  • Annotate backward FX nodes with {"remat_pass_tag": "is_backward"} during
    _patch_engine_run_backward so that remat_using_tags_for_fwd_loss_bwd_graph
    can identify the forward/backward boundary (see the tagging sketch after
    this list).
  • Apply remat_using_tags_for_fwd_loss_bwd_graph as a default post-trace pass.
    Nodes tagged PREFER_RECOMPUTE (from selective activation checkpointing) are
    duplicated just before the backward region, and the original forward copies
    are then removed by dead-code elimination, reducing peak memory (see the
    remat sketch below).
  • Execute the traced graph under torch.no_grad(), since the graph already
    contains explicit backward ops. Without this, PyTorch records a redundant
    autograd graph whose grad_fn references keep all forward intermediates
    alive (see the no_grad sketch below).
  • Add test_llama_1b_peak_memory: verifies that traced+AC peak memory is
    within 20% of eager+AC peak memory on Llama 1B (batch size 2, sequence
    length 2048, bf16).
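As a rough illustration of the tagging, here is a minimal sketch. `tag_backward_nodes` and `first_backward_node` are hypothetical names; in the PR the annotation happens inside `_patch_engine_run_backward` as the backward call is traced, whereas this sketch assumes the boundary node is already known and stamps everything from there on:

```python
import torch.fx as fx

def tag_backward_nodes(graph: fx.Graph, first_backward_node: fx.Node) -> None:
    """Stamp the boundary node and every node after it as backward."""
    in_backward = False
    for node in graph.nodes:
        if node is first_backward_node:
            in_backward = True
        if in_backward and node.op != "output":
            # node.meta is a free-form dict carried by every FX node; the
            # remat pass later keys off this entry to find the boundary.
            node.meta["remat_pass_tag"] = "is_backward"
```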

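And a minimal sketch of the remat idea, not the actual remat_using_tags_for_fwd_loss_bwd_graph implementation: it assumes recompute decisions are readable from node.meta (the "recompute" key and the bare "PREFER_RECOMPUTE" string below are stand-ins for whatever selective AC actually records), and it clones each tagged node only once, so the node's inputs rather than its output stay live across the boundary; a real pass would recursively clone tagged producer chains.

```python
import torch.fx as fx

def remat_prefer_recompute(gm: fx.GraphModule) -> fx.GraphModule:
    """Hypothetical single-level remat pass keyed off the backward tag."""
    graph = gm.graph
    # The first tagged node marks the forward/backward boundary.
    boundary = next(
        (n for n in graph.nodes if n.meta.get("remat_pass_tag") == "is_backward"),
        None,
    )
    if boundary is None:
        return gm  # forward-only graph; nothing to rematerialize
    for node in list(graph.nodes):
        if node.meta.get("remat_pass_tag") == "is_backward":
            continue  # already in the backward region
        if node.meta.get("recompute") != "PREFER_RECOMPUTE":  # assumed meta key
            continue
        # Duplicate the forward node at the start of the backward region.
        with graph.inserting_before(boundary):
            clone = graph.node_copy(node)
        clone.meta["remat_pass_tag"] = "is_backward"  # clone lives in backward
        # Repoint only backward users at the clone; forward users keep the
        # original, which eliminate_dead_code removes once it has no users.
        node.replace_all_uses_with(
            clone,
            delete_user_cb=lambda user: user.meta.get("remat_pass_tag") == "is_backward",
        )
    graph.eliminate_dead_code()
    gm.recompile()
    return gm
```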
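Finally, a sketch of the no_grad execution; `run_traced_step` is a hypothetical wrapper name, not a helper from the PR:

```python
import torch

def run_traced_step(gm: torch.fx.GraphModule, *args):
    # The traced graph already contains explicit backward ops as nodes, so
    # it is executed under no_grad: otherwise eager autograd would record a
    # second, redundant graph whose grad_fn references keep every forward
    # intermediate alive until the step ends.
    with torch.no_grad():
        return gm(*args)
```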
