Skip to content

ezpz/moe: integrate #9, #10, #11 + adapt to upstream PR #3308#12

Closed
saforem2 wants to merge 20 commits into
ezpzfrom
ezpz-moe-integration
Closed

ezpz/moe: integrate #9, #10, #11 + adapt to upstream PR #3308#12
saforem2 wants to merge 20 commits into
ezpzfrom
ezpz-moe-integration

Conversation

@saforem2
Copy link
Copy Markdown
Owner

Summary

Single-PR integration of three open MoE PRs against ezpz, adapted to land cleanly on top of upstream's recent _run_experts_for_loop / use_grouped_mm removal (pytorch#3308). Everything except the upstream merge stays inside torchtitan/experiments/ezpz/ per Golden Rule #1.

Supersedes:

Commits

  1. Merge upstream/main — pulls in 11 upstream commits including Remove MoE expert for-loop fallback pytorch/torchtitan#3308, which removed use_grouped_mm from GroupedExperts.Config and inlined _grouped_mm as the only path. That alone breaks ezpz's pre-SM90 / XPU fallback in experiments/ezpz/moe/model.py; commit 2 fixes it.

  2. feat(ezpz/moe): add EzpzGroupedExperts subclass with compute_backend selector — introduces EzpzGroupedExperts(GroupedExperts) in experiments/ezpz/moe/experts.py with three backends:

    • \"grouped_mm\" (default): defer to upstream
    • \"for_loop\": re-vendored from the upstream deletion. Restores the XPU / pre-SM90 fallback.
    • \"batched_mm_padded\": ported from Add padded batched matmul MoE expert backend #10 as an ezpz-owned opt-in.

    Adds a make_ezpz_experts_config(...) helper, two new flavors (10B_2B_sdpa_batched_mm_padded, 10B_2B_sdpa_for_loop) with matching config_registry entries, and replaces the dead use_grouped_mm mutation in model.py with a compute_backend = \"for_loop\" switch keyed off has_cuda_capability(9, 0).

  3. fix(ezpz/moe): HSDP-aware FSDP mesh info + import-fallback for older XPU wheels — combines the two Fix MoE expert FSDP mesh info for HSDP #9 / Optimize 10b 2b sdpa moe #11 changes that touched the same apply_fsdp ep_degree > 1 branch:

    • From Fix MoE expert FSDP mesh info for HSDP #9: returns HSDPMeshInfo for 2D meshes (with both shard_mesh_dim and replicate_mesh_dim) and FSDPMeshInfo for 1D, looking up axis indices by name (fsdp / efsdp / dp_replicate). Uses axis/axes not dim_name per the project's mesh naming convention.
    • From Optimize 10b 2b sdpa moe #11: wraps the _fsdp_common import in try/except ImportError; on failure, falls back to a two-phase fully_shard (experts on edp_mesh, the rest on dp_mesh) using a plain Shard return.
  4. perf(ezpz/moe): port PR #11 expert-side optimizations into EzpzGroupedExperts:

    • Equal-counts no-grad fast path inside _run_experts_for_loop (single bmm pair over stacked w13 and transposed w2_t).
    • Version-keyed w13 / w2_t caches on EzpzGroupedExperts. Bypassed under autograd; auto-invalidated on _version bump.
    • Lightweight _record_moe_fastpath counters gated on EZPZ_MOE_FASTPATH_COUNTERS=1.

Test coverage

torchtitan/experiments/ezpz/tests/test_moe_expert_backends.py:

  • test_batched_mm_padded_matches_for_loop_kernel — parity (forward + backward) on mixed / empty / all-zero count cases.
  • test_ezpz_grouped_experts_backend_selector_matches_for_loop — full module-level parity for the backend selector path.
  • test_equal_counts_no_grad_matches_loop — equal-counts bmm fast path matches the per-expert loop.
  • test_cache_hit_on_repeat_no_grad_call — cache hit/miss accounting across repeated no-grad calls.

All 4 pass on CPU.

Out of scope (deferred)

PR #11's token-dispatcher rewrite (+570/-91 in torchtitan/models/common/token_dispatcher.py + torchtitan/ops/scatter_add.py): force-load-balance dispatch path, equal-padding all-to-all fast paths, in-place deterministic_scatter_add_ combine. To stay inside experiments/ezpz/, this needs a separate EzpzLocalTokenDispatcher subclass — much larger than the expert-side subclass — and is best handled as a follow-up PR.

Validation plan

  • Unit tests pass on CPU
  • MoE smoke run on Aurora (moe_500m_50steps or moe_debugmodel) — to be filed as a follow-up smoke-test report under experiments/ezpz/docs/experiments/moe/aurora/
  • EP=2 / HSDP smoke on Aurora once a queue slot opens

Notes

  • Default behavior preserved: compute_backend defaults to \"grouped_mm\", so flavors that don't opt in get exactly the upstream behavior.
  • model.py:226 continues to switch backend on pre-SM90 — unchanged for CUDA-SM90+, restores the XPU fallback that Remove MoE expert for-loop fallback pytorch/torchtitan#3308 removed.

wwwjn and others added 15 commits May 11, 2026 11:49
…endency (pytorch#3242)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* pytorch#3236
* pytorch#3142
* __->__ pytorch#3242

vllm has this customized config parser registry support so we can plug
in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when
initializing. Don't implicitly depend on `config.json` as config source
of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm
config -> torchtitan config. Using closure to bypass.
…pytorch#3315)

AutoParallel's input_fn() only returned tokens, but Decoder.forward()
also receives positions via extra_kwargs. This caused a mismatch between
the number of graph placeholders (from tracing with tokens only) and the
actual runtime args (which include positions), failing with "expected N
arguments for placeholders but received N+1".

Add positions to input_fn() and a matching input constraint so
AutoParallel traces with both inputs.

Authored by Claude.
…h#3311)

## Summary
- Add a simple `log_timer` context manager to `common_utils.py` that
measures wall-clock elapsed time and logs it to console (e.g.
`trace_train_step took 0.043s`).
- Apply `log_timer` to the `trace_train_step` call in
`GraphTrainer._make_fx_forward_backward_step` to measure tracing time.

## Test plan
- [x] Verify `log_timer` output appears in training logs during
aot_fx_trace runs
- [ ] Existing unit tests pass: `pytest
torchtitan/experiments/graph_trainer/tests/ -x`
…rch#3270)

Summary:

The remat pass previously rebuilt the graph wholesale (fx.Graph() +
node_copy of every node) and relied on whole-graph DCE to remove dead
must_recompute originals. Refactor to mutate gm.graph in place: dups are
inserted in front of their first backward consumer, backward args are
redirected to the dups, and only originals whose users became empty are
erased. Original node identities and names are preserved, the
topological-order assumption is explicit (input graph order drives
insertion, validated by gm.graph.lint() at the end), and the underlying
function takes the standard (gm, example_inputs) graph pass signature.

CPU-offload reload chains are handled by hoisting the chain in front of
the earliest dup that needs it - the in-place equivalent of upstream's
"eagerly copy reload chain into the new graph" trick.


Authored with Claude.

Test Plan:
Unit tests:
  pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -x
  -> 68 passed, 1 skipped (3 new in TestSelectiveActivationRematPass)

End-to-end on 8xH100 / Llama3 8B / FSDP=4 + TP=2 / aot_fx_trace / no
cudagraph, with --debug.seed=42 --debug.deterministic:

<img width="1728" height="625" alt="Screenshot 2026-05-07 at 5 16 08 PM"
src="https://github.com/user-attachments/assets/b28e3acf-626e-4014-b5e5-4ffa6a686f08"
/>

CPU offload using upstream remat pass 

<img width="1728" height="612" alt="Screenshot 2026-05-07 at 5 16 23 PM"
src="https://github.com/user-attachments/assets/6c391927-7022-496d-bb9b-0e38e4808df8"
/>
CPU offload using our refactor 


Before:
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpfmqvbv/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
After:
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp9ImSg9/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
torch._grouped_mm already provides a CUDA fallback path when the fused grouped GEMM kernel is unavailable, including on pre-SM90 hardware. Keeping a separate Python for-loop expert implementation duplicates that fallback, carries an extra configuration branch, and makes MoE behavior diverge across models.

Use the grouped-mm path unconditionally and rely on PyTorch to choose either the fused kernel or its built-in loopy fallback.
Code Move Only! 

## Summary

- Split the monolithic `passes.py` (~1037 lines) into focused modules,
keeping `passes.py` as the orchestration layer (~400 lines):
- **`memory_policy.py`** — SAC tagging, default/eager/offload policies,
and `tag_with_memory_policy_pass`
- **`inductor_passes.py`** — `regional_inductor_pass`,
`full_inductor_compilation_pass`, and
`annotate_flex_attention_for_regional_inductor_pass`
- **`cudagraph.py`** — `cudagraph_pass` and
`insert_kernel_annotations_pass` (appended to existing file)
- **`registry.py`** — `MEMORY_POLICY_REGISTRY`,
`PASS_PIPELINE_REGISTRY`, `POST_INIT_HOOKS`, `PRE_TRAIN_STEP_HOOKS` and
their decorators (breaks circular dep between `passes.py` and
`memory_policy.py`)
- Code-move only — all function bodies are identical; the only diffs are
removing local imports that became unnecessary when code moved to the
same file
- Updated `test_passes.py` imports to use new module paths

## Test plan

- [ ] `ruff check --select F401` passes clean (no unused imports)
- [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_passes.py
-x`
- [ ] `pytest
torchtitan/experiments/graph_trainer/tests/test_precompile.py -x`
- [ ] `pre-commit run --all-files`
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* pytorch#3250
* pytorch#3248
* pytorch#3247
* pytorch#3246
* __->__ pytorch#3249

Add a ``support_autograd_grad`` opt-in flag to ``ChunkedCELoss.Config``
that exposes the lm_head parameter gradients as explicit autograd
outputs of the returned loss tensor. Designed for FX-tracing flows
(e.g. graph_trainer's aot_fx_trace) where ``param.grad`` side-effect
writes from ``chunk_loss.backward()`` inside the chunk loop don't
survive into the captured graph and replay therefore produces
all-zero param grads.

Mechanism, when the flag is True:

  - The chunk loop runs unchanged: per-chunk ``chunk_loss.backward()``
    populates ``lm_head.weight.grad`` with the correctly sharded value
    (FSDP last-chunk reduce-scatter still handles the actual reduction).
  - After the loop, the sharded ``param.grad`` is captured via
    ``p.grad.detach()`` and ``p.grad`` is cleared.
  - The captured grads (plus the existing accumulated hidden_states
    grad) are plumbed through a new ``_ChunkedLossWithParamGrads``
    autograd Function as saved tensors. Its backward returns them as
    grads for the corresponding lm_head parameter inputs, so outer
    ``torch.autograd.grad(loss, lm_head.parameters())`` resolves to
    real gradients instead of None / zeros.
  - Under FSDP, ``set_requires_gradient_sync(False)`` is set on lm_head
    and restored after the outer backward via a callback queued on the
    autograd engine. Without this, outer ``loss.backward()`` would
    re-fire the post-accumulate-grad hook on already-sharded grads and
    either error or produce wrong values.

Both autograd Functions (the existing ``_DecoderOutputGradientBackProp``
and the new ``_ChunkedLossWithParamGrads``) return saved grads as-is
without chain-ruling through ``grad_output``. The contract is that the
loss returned by these Functions is the autograd endpoint — callers
must pass any scaling factor (e.g. ``global_valid_tokens``) into
``loss_fn`` rather than dividing the returned loss externally. See
graph_trainer's ``compute_loss`` for the canonical pattern. This
matches the pre-existing behavior of ``_DecoderOutputGradientBackProp``
and avoids a structural FSDP+TP problem: ``grad_output`` is a
``Replicate()`` scalar DTensor on the loss's mesh (typically ``(tp,)``)
while saved param grads live on the params' mesh (e.g. ``(fsdp, tp)``);
DTensor refuses cross-mesh ``aten.mul.Tensor``, so any chain-rule
multiply would crash at runtime.

Tests:

  - tests/unit_tests/test_loss.py: parametrized ``support_autograd_grad
    in {False, True}`` equivalence check against the unchunked CE
    standard path; bitwise (rtol=0, atol=0) check that True and False
    paths produce identical grads; side-effect contract check that the
    True path doesn't populate ``param.grad``.
  - graph_trainer/tests/test_trace_module.py: end-to-end test that
    traces a small ``lm_head + ChunkedCELoss(support_autograd_grad=
    True)`` train_step via ``trace_train_step`` and verifies
    ``torch.equal`` between eager and replayed loss + h grad + lm_head
    grad on a CUDA model.

The flag defaults to False so existing callers (the eager torchtitan
trainer) are unaffected; consumers that want explicit param grads
opt in. graph_trainer wires this in a separate commit.
Wrap each layer's inner attention forward via a new
`apply_cp_to_attention` helper in common_utils, called from
parallelize_{llama,deepseekv3,qwen3} when `cp_enabled`. Adds CP and
TP+CP integration tests for all three models.

<!-- ps-id: 089851fc-7329-4da3-8ed7-103f786148a0 -->

Co-authored-by: Aditya Venkataraman <avenkataraman@fb.com>
Hey there 👋 

As explained in
[pytorch#3158](pytorch#3158 (comment)),
this PR refactors the Lychee link-checking logic to eliminate flaky CI
failures while significantly improving execution speed.

<br>

## 1. Multi-Tier Check Strategy
The "Commit-time" experience is now separated from "Infrastructure
monitoring" to ensure the development flow is not blocked by external
outages.

* **Pre-commit (Local & PR CI):** Configured to **fail only on 404**.
Codes like `502` (e.g., "GitHub Unicorns") are accepted to prevent
transient failures from stalling the workflow.
* **Nightly CI:** Performs a strict check (accepting only `200`, `403`,
`429`, `503`) with high persistence. It retries for up to **30
minutes**, outlasting most service outages. This run populates the cache
with verified statuses that later can be reused in other workflow runs.

<br>

## 2. The Cache Lifecycle: Why `key` and `restore-keys` are Necessary
To understand the necessity of a dynamic `key` and the `restore-keys`
fallback, one must first recognize that **GitHub Actions caches are
immutable**.

### The Problem with Static Keys
If a static key like `key: lychee-cache` is used without any dynamic
parts, the workflow encounters a "Cache Hit" lock:
1. **The First Run:** GitHub creates the `.lycheecache` file and saves
it as `lychee-cache`.
2. **Subsequent Runs:** GitHub finds an exact match for `lychee-cache`
and downloads it. Because GitHub cannot update an existing cache, any
new links or status updates discovered during the run are **discarded**.
3. **The "Stale Cache" Effect:** Over time, the cache becomes frozen in
the state of the first run. Fixed links remain marked as "broken," and
new links are re-checked every single time, slowing down CI.

### The Solution: Dynamic Keys & Branch Isolation
GitHub restricts cache access based on branches: the **Default Branch**
(`main`) is accessible by all, while **Feature Branches** can only
access their own caches or the default branch's cache.

By using a dynamic key (e.g., `cache-lychee-${{ github.sha }}`) combined
with `restore-keys`, the process moves through two distinct phases:
- **Restore** (start of job)
- **Save** (end of job).

#### Step-by-Step Lifecycle (3-Commit Example)
| Phase | Commit 1: The "Inheritance" | Commit 2: The "Update" | Commit
3: The "Iteration" |
| :--- | :--- | :--- | :--- |
| **Restore** | Misses `cache-lychee-SHA1`. Falls back to `restore-keys`
to pull the latest `main` cache that matches `cache-lychee-` pattern. |
Misses `SHA2`. Pulls the most recent match from the branch scope
(**SHA1**). | Misses `SHA3`. Pulls the latest available version from the
branch (**SHA2**). |
| **Action** | Lychee runs, checks only new links, and uses the
inherited cache for the rest. | Lychee updates results with any new
findings. | Lychee uses the `SHA2` baseline, ensuring zero redundant
checks. |
| **Save** | GitHub saves a **new** cache entry: `lychee-cache-SHA1`. |
GitHub saves a new immutable entry: `lychee-cache-SHA2`. | GitHub saves
the final state: `lychee-cache-SHA3`. |

> [!NOTE]
> This "chaining" effect ensures every commit builds upon the previous
one, while keeping PR runs fast and the nightly "source of truth"
accessible.

<br>

## 3. Optimization: Parallelism & Output
Previously, the configuration relied on `require_serial: true` and
`verbose: true`.

* **The Problem:** `verbose: true` was required to display the "Lychee
not found" warning (since the script uses `exit 0`). However, because
`pre-commit` spawns a process for every file, this caused the warning to
print for every single file checked. `require_serial: true` was used to
stop the log spam, but caused a **2x-3x slowdown**.
* **The Fix:** An **Atomic Sentinel** is now used via `mkdir
/tmp/lychee_lock`. Because `mkdir` is an atomic operation, only the
first process successfully creates the directory and prints the warning.
All other parallel processes fail to create the directory and remain
silent.
* **The Result:** `require_serial` and `verbose` are now **false**. The
check runs in parallel (fast), and the warning prints exactly once
(clean) by redirecting directly to the user's terminal screen via `>
/dev/tty`.

<br>

## 4. Fix Lychee version to install
In **Lychee v0.24.1**, a change in the release archive structure broke
dynamic installation scripts that fetched the "Latest" release.
* **Change:** The installation now uses a **fixed, verified version** to
prevent upstream changes from breaking the CI pipeline.

<br>

## 5. GITHUB_TOKEN
The `GITHUB_TOKEN` is explicitly passed to the Lychee action and
pre-commit steps. This increases the rate limit for GitHub API requests,
reducing `403 Forbidden` and `429 Too Many Requests` errors when
checking internal repository links.

<br>

## Summary Overview
| Feature | Local / PR CI | Nightly CI |
| :--- | :--- | :--- |
| **Failure Condition** | Only `404` | Most non-200 codes |
| **Duration/Retries** | Fast (5 retries / 15 secs) | Patient (30
retries / 30 mins) |
| **Execution** | Parallel (via Sentinel) | Standard Action |
| **Cache Goal** | Consume & Increment | Refresh & Validate |
Add Codex-facing AGENTS.md symlinks that point at the existing Claude instruction files. This lets Codex reuse the same repo and graph_trainer guidance without duplicating instruction content or creating a second source of truth.

The root AGENTS.md points to the repo-level .claude/CLAUDE.md file. The graph_trainer AGENTS.md points to the graph_trainer-local .claude/CLAUDE.md file so directory-local instructions continue to apply when Codex is working in that subtree.

Test Plan:
- git ls-tree -l HEAD AGENTS.md torchtitan/experiments/graph_trainer/AGENTS.md
- git diff --name-status origin/main..HEAD
…selector

Adapts ezpz to upstream pytorch#3308 which removed
`use_grouped_mm` from `GroupedExperts.Config` and inlined `_grouped_mm`
as the only expert compute path. That broke ezpz's pre-SM90 / XPU
fallback in `experiments/ezpz/moe/model.py` since the field no longer
exists.

Introduces `EzpzGroupedExperts(GroupedExperts)` in
`torchtitan/experiments/ezpz/moe/experts.py` with three backends:

- `grouped_mm` (default): defer to upstream's path
- `for_loop`: per-expert matmul, re-vendored from the just-deleted
  upstream helper. Restores the XPU / pre-SM90 fallback.
- `batched_mm_padded`: padded `torch.bmm` over per-expert rows.
  Originally proposed in #10 against the old `Config` shape; ported
  here as an opt-in ezpz subclass so core stays untouched.

Wires a `make_ezpz_experts_config` helper, adds two new flavors
(`10B_2B_sdpa_batched_mm_padded`, `10B_2B_sdpa_for_loop`) plus matching
`config_registry` entries, and replaces the dead `use_grouped_mm`
mutation in `model.py` with a `compute_backend = "for_loop"` switch
keyed off `has_cuda_capability(9, 0)`.

Test coverage in `torchtitan/experiments/ezpz/tests/test_moe_expert_backends.py`
verifies the batched-mm and for-loop backends agree on outputs and
gradients, including empty / all-zero count edge cases.
…XPU wheels

Combines two fixes that both touch the `apply_fsdp ep_degree > 1` branch
in `torchtitan/experiments/ezpz/moe/parallelize.py`:

1. HSDP correctness (PR #9, sourcery-bot/samuelwheeler): when running on
   an HSDP 2D mesh, the previous `FSDPMeshInfo(mesh=mesh, shard_mesh_dim=0)`
   construction produced DTensor specs that referenced only one axis of
   a multi-axis mesh. Replace with `_fsdp_mesh_info(...)` that returns
   `HSDPMeshInfo` for 2D meshes (with both `shard_mesh_dim` and
   `replicate_mesh_dim` populated) and `FSDPMeshInfo` for 1D. Looks up
   axis indices by name (`fsdp` / `efsdp` / `dp_replicate`) so callers
   never need to guess the axis order.

2. Import-fallback (PR #11, nscottnichols): some XPU PyTorch wheels do
   not ship `torch.distributed.fsdp._fully_shard._fsdp_common`, which
   means the per-param `ShardPlacementResult` API is unavailable. Wrap
   the import in `try/except ImportError`; on failure, fall back to a
   two-phase `fully_shard` (experts on `edp_mesh`, the rest of the
   block on `dp_mesh`) using a plain `Shard` return from
   `shard_placement_fn`. Behavior on builds that have the API is
   unchanged from path 1.

Uses `axis`/`axes` (not `dim_name`) for `DeviceMesh` axis references
per the project's `axis` vs `dim` naming convention.
…dExperts

Adapts the expert-compute optimizations from PR #11 (nscottnichols)
to the ezpz subclass so core MoE code stays untouched. Three pieces:

1. **Equal-counts no-grad fast path** in ``_run_experts_for_loop``. When
   every expert receives the same non-zero token count and we are
   outside an autograd region, fold SwiGLU into a single ``torch.bmm``
   pair using a stacked ``w13 = cat(w1, w3)`` and a precomputed
   ``w2_t = w2.transpose(-2, -1)``.

2. **Version-keyed weight transform caches** on ``EzpzGroupedExperts``.
   ``_get_cached_w13`` / ``_get_cached_w2_t`` reuse the stacked /
   transposed views across forward calls during inference. Cache key
   covers storage data_ptr, ``_version``, shape, dtype, and device, so
   any in-place weight mutation invalidates the cache automatically.
   Bypassed entirely when autograd is enabled.

3. **Lightweight fast-path counters** gated on the
   ``EZPZ_MOE_FASTPATH_COUNTERS=1`` environment variable for inference
   profiling. Disabled-by-default and zero-cost in the common case.

Tests cover (a) parity between the fast path and the full per-expert
loop, and (b) cache hit/miss accounting across repeated no-grad calls.

PR #11's token-dispatcher rewrite (in core ``token_dispatcher.py``)
remains out of scope for this branch and would need a separate
``EzpzLocalTokenDispatcher`` subclass to land without core edits.
Copilot AI review requested due to automatic review settings May 12, 2026 12:26
Copy link
Copy Markdown

@sourcery-ai sourcery-ai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @saforem2, your pull request is larger than the review limit of 150000 diff characters

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR integrates multiple MoE-related changes (adapted to upstream removal of use_grouped_mm), adds ezpz-owned expert compute backend selection + coverage, and includes substantial graph_trainer + RL(vLLM) plumbing refactors alongside CI link-checking updates.

Changes:

  • Remove use_grouped_mm plumbing from core MoE (GroupedExperts) and dependent model configs; add ezpz-side EzpzGroupedExperts with a compute_backend selector plus new ezpz MoE flavors + parity tests.
  • Refactor graph_trainer pass infrastructure (new registries/modules), update selective activation remat implementation, and add chunked-loss support that exposes lm_head param grads via autograd outputs.
  • Update RL vLLM integration to a single registration entry point (model + config parser) and adjust link-checking (Lychee) to use caching and run in PR lint.

Reviewed changes

Copilot reviewed 53 out of 54 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
torchtitan/models/llama4/model.py Removes SM90 gating tied to use_grouped_mm (field removed upstream).
torchtitan/models/gpt_oss/moe.py Inlines grouped-mm path and removes use_grouped_mm branching.
torchtitan/models/gpt_oss/model.py Removes use_grouped_mm/SM90 fallback mutation logic.
torchtitan/models/deepseek_v3/model.py Removes use_grouped_mm/SM90 fallback mutation logic.
torchtitan/models/common/moe.py Removes legacy expert backends and always uses torch._grouped_mm.
torchtitan/models/common/config_utils.py Updates make_experts_config signature to drop use_grouped_mm.
torchtitan/experiments/rl/tests/test_bitwise_parity.py Switches to new vLLM registry entry point.
torchtitan/experiments/rl/models/vllm_wrapper.py Renames wrapper and uses torchtitan ParallelismConfig as source of truth.
torchtitan/experiments/rl/models/vllm_registry.py Adds registry_to_vllm + torchtitan config parser registration and HF-shaped config builder.
torchtitan/experiments/rl/models/parallelize.py Import path consolidation for config types.
torchtitan/experiments/rl/grpo.py Import path consolidation for config types.
torchtitan/experiments/rl/generate.py Updates vLLM registration callsite to registry_to_vllm.
torchtitan/experiments/rl/config_registry.py Import path consolidation; enforces generator parallelism flags.
torchtitan/experiments/rl/actors/trainer.py Import path consolidation for config types.
torchtitan/experiments/rl/actors/generator.py Validates generator parallelism invariants; uses torchtitan vLLM config parser path.
torchtitan/experiments/rl/init.py Exports renamed wrapper and new vLLM registry function.
torchtitan/experiments/graph_trainer/trainer.py Adds timing context around tracing; uses new registries.
torchtitan/experiments/graph_trainer/tests/test_trace_module.py Adds coverage for chunked CE loss tracing and new CP trace test.
torchtitan/experiments/graph_trainer/tests/test_precompile.py Extends precompile artifacts with user_inputs_spec.
torchtitan/experiments/graph_trainer/tests/test_passes.py Updates imports after pass-module split; adds new selective remat tests.
torchtitan/experiments/graph_trainer/tests/test_chunked_loss.py New unit tests for ChunkedCELossWithParamGrads.
torchtitan/experiments/graph_trainer/tests/integration_tests.py Updates integration test overrides to include CP and adjust DP shard degree.
torchtitan/experiments/graph_trainer/selective_activation_remat.py Replaces DCE-based remat with in-place duplication + targeted erasure + offload-chain hoisting.
torchtitan/experiments/graph_trainer/registry.py New shared registries to avoid circular imports.
torchtitan/experiments/graph_trainer/qwen3/parallelize.py Applies CP wrapping to attention when CP is enabled.
torchtitan/experiments/graph_trainer/precompile.py Moves distributed metadata filter import to inductor_passes.
torchtitan/experiments/graph_trainer/passes.py Refactors into orchestration-only module; imports split-out pass implementations.
torchtitan/experiments/graph_trainer/memory_policy.py New module for SAC tagging + memory policy dispatch.
torchtitan/experiments/graph_trainer/llama3/parallelize.py Applies CP wrapping to attention when CP is enabled.
torchtitan/experiments/graph_trainer/llama3/parallelize_autoparallel.py Updates autoparallel input_fn to return tokens+positions.
torchtitan/experiments/graph_trainer/inductor_passes.py New module for regional/full Inductor compilation and flex-attention annotation.
torchtitan/experiments/graph_trainer/graph_utils.py Removes outdated comment referencing use_grouped_mm=False.
torchtitan/experiments/graph_trainer/fsdp_passes.py Respects MUST_CPU_OFFLOAD tagging; updates backward-node detection hook.
torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py Applies CP wrapping to attention when CP is enabled.
torchtitan/experiments/graph_trainer/cudagraph.py Moves cudagraph + kernel annotation passes into this module.
torchtitan/experiments/graph_trainer/common_utils.py Adds log_timer and apply_cp_to_attention helpers; removes _is_recomputed_node.
torchtitan/experiments/graph_trainer/chunked_loss.py Adds ChunkedCELossWithParamGrads implementation.
torchtitan/experiments/graph_trainer/autoparallel_api.py Drops _check_forward_args dependency.
torchtitan/experiments/graph_trainer/.claude/CLAUDE.md Updates docs to match new remat behavior.
torchtitan/experiments/ezpz/tests/test_moe_expert_backends.py New unit tests for ezpz expert backends, fast paths, and caching.
torchtitan/experiments/ezpz/tests/init.py Adds test package marker.
torchtitan/experiments/ezpz/moe/parallelize.py Makes expert FSDP mesh info HSDP-aware; adds import fallback for older wheels.
torchtitan/experiments/ezpz/moe/model.py Switches ezpz MoE backend to for_loop on non-SM90 via compute_backend.
torchtitan/experiments/ezpz/moe/experts.py Introduces EzpzGroupedExperts with backend selector, padded-bmm backend, and caches/counters.
torchtitan/experiments/ezpz/moe/config_registry.py Adds new ezpz MoE flavors for the new expert backends.
torchtitan/experiments/ezpz/moe/init.py Adds make_ezpz_experts_config and wires compute_backend through ezpz MoE configs.
torchtitan/components/loss.py Introduces overrideable _gradient_backprop hook for chunked-loss subclasses.
tests/unit_tests/test_compile_moe.py Updates grouped-mm compile test to call GroupedExperts._experts_forward.
.pre-commit-config.yaml Updates lychee hook (cache + reduced duplicate warnings + accept-code policy).
.gitignore Ignores .lycheecache.
.github/workflows/lint.yaml Pins Lychee version, adds cache restore, and runs lychee via pre-commit on all files.
.github/workflows/link_check.yaml Removes nightly link-check workflow (now covered by lint workflow).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

echo "Installing Lychee version: $VERSION"

URL="https://github.com/lycheeverse/lychee/releases/download/${LATEST_TAG}/lychee-x86_64-unknown-linux-gnu.tar.gz"
URL="https://github.com/lycheeverse/lychee/releases/download/lychee-${VERSION}/lychee-x86_64-unknown-linux-gnu.tar.gz"
# Note: this is not safe if downstream accidentally runs tensor ops after
# the loss returns, which would produce a non-trivial grad_output that we need
# to properly handle. The complicated part is that grad_output might not be
# on the same device mesh as accumlated_grad.
# The chunk loop above already populated each lm_head param's
# ``.grad`` with the correctly sharded value via the FSDP last-chunk
# post-accumulate-grad hook (reduce-scatter). Capture those grads
# into saved_tensors so backward ca route them as autograd outputs
saforem2 added 5 commits May 12, 2026 08:35
Sets num_workers=2 on the BlendCorpusDataLoader.Config built in both
_base_config()s — training dataloader for {agpt, moe}, plus the agpt
validator dataloader. Previously num_workers fell through to the
framework default of 0 (single-threaded I/O), forcing every production
launch to pass --dataloader.num-workers=2 explicitly. Defaulting to
the value we already pass everywhere removes that boilerplate while
keeping the CLI override available for sweeps.

Validator defaults left as-is per user (enable=False, freq=200).
The first round of backend comparison runs (logs/moe-backend-bench/
*20260512-102301.log) all OOM'd in FSDP all-gather copy-in during the
very first forward pass — same failure mode across all three backends:

    RuntimeError: level_zero backend failed with error: 40
        (UR_RESULT_ERROR_OUT_OF_RESOURCES)

Root cause: the existing `moe_10b_2b_sdpa` flavor is FSDP-only (EP=1)
with `activation_checkpoint_mode="none"`. The 10B model + LBS=2 +
seq_len=8192 doesn't fit in XPU memory at that parallelism shape, and
the new `_for_loop` / `_batched_mm_padded` variants I added inherited
the same defaults so they OOM'd identically. The OOM is a parallelism-
config problem, not a regression in the new backends.

Add three EP=2 + AC=full variants for an apples-to-apples comparison:

- `moe_10b_2b_sdpa_ep_ac`: existing `_ep` flavor with AC=full (baseline)
- `moe_10b_2b_sdpa_for_loop_ep`: for_loop backend, EP=2, AC=full
- `moe_10b_2b_sdpa_batched_mm_padded_ep`: bmm backend, EP=2, AC=full

All three share the same parallelism / memory shape so any TPS/MFU
delta in a re-run will reflect the expert backend, not the FSDP/EP
config. Use these for the backend sweep instead of the EP=1 variants.
The first sweep round (logs/moe-backend-bench/*20260512-102301.log)
hit XPU OOM in FSDP all-gather copy-in during the first forward pass
on all three configs — the parallelism shape (EP=1, AC=none) doesn't
fit a 10B model regardless of expert backend. Swap the sweep to the
new `_ep_ac` / `_for_loop_ep` / `_batched_mm_padded_ep` flavors so any
TPS/MFU delta reflects the expert backend, not the parallelism config.

Also add a Layer 1c pre-flight to `moe_backend_layer0_1.sh` that
asserts the EP=2 + AC=full configs build with the right
`compute_backend` and that the experts config is the
`EzpzGroupedExperts.Config` subclass. Cheap pre-check that catches
config-graph regressions before the sweep spends ~5 min/backend
launching only to crash at build time.
Both 2B and 80B failover-wrapped jobs (8480294, 8480361) died at 00:00
elapsed with `ERROR: failover_lib.sh not found at
/var/spool/pbs/mom_priv/jobs/failover_lib.sh`.

Root cause: PBS copies the submit script to /var/spool/pbs/mom_priv/jobs/
before running it on the head node, so `dirname "$(realpath "$0")"`
resolves to that spool dir, not the original $repo/scripts. The relative
"$PBS_O_WORKDIR/.ezpz-utils-cache" lookup also failed with realpath
collapsing across the wrong base.

Fix: anchor both EZPZ_UTILS and FAILOVER_LIB to $PBS_O_WORKDIR with the
canonical repo-relative paths. Also add PBS_O_WORKDIR + PWD to the error
message so future debugging is one log-line shorter.
@saforem2
Copy link
Copy Markdown
Owner Author

Superseded by #13. That PR takes a more focused 2-commit approach: just the upstream merge + the minimal EzpzGroupedExperts replay needed to keep ezpz working after pytorch#3308. The PR #9 / #10 / #11 expert-side and HSDP changes are intentionally not bundled there — they should rebase onto the resync'd ezpz independently.

@saforem2 saforem2 closed this May 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants