ezpz/moe: integrate #9, #10, #11 + adapt to upstream PR #3308#12
ezpz/moe: integrate #9, #10, #11 + adapt to upstream PR #3308#12saforem2 wants to merge 20 commits into
Conversation
…endency (pytorch#3242) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * pytorch#3236 * pytorch#3142 * __->__ pytorch#3242 vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this: - get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth Another changes in this PR: - remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.
As titled
…pytorch#3315) AutoParallel's input_fn() only returned tokens, but Decoder.forward() also receives positions via extra_kwargs. This caused a mismatch between the number of graph placeholders (from tracing with tokens only) and the actual runtime args (which include positions), failing with "expected N arguments for placeholders but received N+1". Add positions to input_fn() and a matching input constraint so AutoParallel traces with both inputs. Authored by Claude.
…h#3311) ## Summary - Add a simple `log_timer` context manager to `common_utils.py` that measures wall-clock elapsed time and logs it to console (e.g. `trace_train_step took 0.043s`). - Apply `log_timer` to the `trace_train_step` call in `GraphTrainer._make_fx_forward_backward_step` to measure tracing time. ## Test plan - [x] Verify `log_timer` output appears in training logs during aot_fx_trace runs - [ ] Existing unit tests pass: `pytest torchtitan/experiments/graph_trainer/tests/ -x`
…rch#3270) Summary: The remat pass previously rebuilt the graph wholesale (fx.Graph() + node_copy of every node) and relied on whole-graph DCE to remove dead must_recompute originals. Refactor to mutate gm.graph in place: dups are inserted in front of their first backward consumer, backward args are redirected to the dups, and only originals whose users became empty are erased. Original node identities and names are preserved, the topological-order assumption is explicit (input graph order drives insertion, validated by gm.graph.lint() at the end), and the underlying function takes the standard (gm, example_inputs) graph pass signature. CPU-offload reload chains are handled by hoisting the chain in front of the earliest dup that needs it - the in-place equivalent of upstream's "eagerly copy reload chain into the new graph" trick. Authored with Claude. Test Plan: Unit tests: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -x -> 68 passed, 1 skipped (3 new in TestSelectiveActivationRematPass) End-to-end on 8xH100 / Llama3 8B / FSDP=4 + TP=2 / aot_fx_trace / no cudagraph, with --debug.seed=42 --debug.deterministic: <img width="1728" height="625" alt="Screenshot 2026-05-07 at 5 16 08 PM" src="https://github.com/user-attachments/assets/b28e3acf-626e-4014-b5e5-4ffa6a686f08" /> CPU offload using upstream remat pass <img width="1728" height="612" alt="Screenshot 2026-05-07 at 5 16 23 PM" src="https://github.com/user-attachments/assets/6c391927-7022-496d-bb9b-0e38e4808df8" /> CPU offload using our refactor Before: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpfmqvbv/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 After: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp9ImSg9/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
torch._grouped_mm already provides a CUDA fallback path when the fused grouped GEMM kernel is unavailable, including on pre-SM90 hardware. Keeping a separate Python for-loop expert implementation duplicates that fallback, carries an extra configuration branch, and makes MoE behavior diverge across models. Use the grouped-mm path unconditionally and rely on PyTorch to choose either the fused kernel or its built-in loopy fallback.
Code Move Only! ## Summary - Split the monolithic `passes.py` (~1037 lines) into focused modules, keeping `passes.py` as the orchestration layer (~400 lines): - **`memory_policy.py`** — SAC tagging, default/eager/offload policies, and `tag_with_memory_policy_pass` - **`inductor_passes.py`** — `regional_inductor_pass`, `full_inductor_compilation_pass`, and `annotate_flex_attention_for_regional_inductor_pass` - **`cudagraph.py`** — `cudagraph_pass` and `insert_kernel_annotations_pass` (appended to existing file) - **`registry.py`** — `MEMORY_POLICY_REGISTRY`, `PASS_PIPELINE_REGISTRY`, `POST_INIT_HOOKS`, `PRE_TRAIN_STEP_HOOKS` and their decorators (breaks circular dep between `passes.py` and `memory_policy.py`) - Code-move only — all function bodies are identical; the only diffs are removing local imports that became unnecessary when code moved to the same file - Updated `test_passes.py` imports to use new module paths ## Test plan - [ ] `ruff check --select F401` passes clean (no unused imports) - [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -x` - [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_precompile.py -x` - [ ] `pre-commit run --all-files`
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * pytorch#3250 * pytorch#3248 * pytorch#3247 * pytorch#3246 * __->__ pytorch#3249 Add a ``support_autograd_grad`` opt-in flag to ``ChunkedCELoss.Config`` that exposes the lm_head parameter gradients as explicit autograd outputs of the returned loss tensor. Designed for FX-tracing flows (e.g. graph_trainer's aot_fx_trace) where ``param.grad`` side-effect writes from ``chunk_loss.backward()`` inside the chunk loop don't survive into the captured graph and replay therefore produces all-zero param grads. Mechanism, when the flag is True: - The chunk loop runs unchanged: per-chunk ``chunk_loss.backward()`` populates ``lm_head.weight.grad`` with the correctly sharded value (FSDP last-chunk reduce-scatter still handles the actual reduction). - After the loop, the sharded ``param.grad`` is captured via ``p.grad.detach()`` and ``p.grad`` is cleared. - The captured grads (plus the existing accumulated hidden_states grad) are plumbed through a new ``_ChunkedLossWithParamGrads`` autograd Function as saved tensors. Its backward returns them as grads for the corresponding lm_head parameter inputs, so outer ``torch.autograd.grad(loss, lm_head.parameters())`` resolves to real gradients instead of None / zeros. - Under FSDP, ``set_requires_gradient_sync(False)`` is set on lm_head and restored after the outer backward via a callback queued on the autograd engine. Without this, outer ``loss.backward()`` would re-fire the post-accumulate-grad hook on already-sharded grads and either error or produce wrong values. Both autograd Functions (the existing ``_DecoderOutputGradientBackProp`` and the new ``_ChunkedLossWithParamGrads``) return saved grads as-is without chain-ruling through ``grad_output``. The contract is that the loss returned by these Functions is the autograd endpoint — callers must pass any scaling factor (e.g. ``global_valid_tokens``) into ``loss_fn`` rather than dividing the returned loss externally. See graph_trainer's ``compute_loss`` for the canonical pattern. This matches the pre-existing behavior of ``_DecoderOutputGradientBackProp`` and avoids a structural FSDP+TP problem: ``grad_output`` is a ``Replicate()`` scalar DTensor on the loss's mesh (typically ``(tp,)``) while saved param grads live on the params' mesh (e.g. ``(fsdp, tp)``); DTensor refuses cross-mesh ``aten.mul.Tensor``, so any chain-rule multiply would crash at runtime. Tests: - tests/unit_tests/test_loss.py: parametrized ``support_autograd_grad in {False, True}`` equivalence check against the unchunked CE standard path; bitwise (rtol=0, atol=0) check that True and False paths produce identical grads; side-effect contract check that the True path doesn't populate ``param.grad``. - graph_trainer/tests/test_trace_module.py: end-to-end test that traces a small ``lm_head + ChunkedCELoss(support_autograd_grad= True)`` train_step via ``trace_train_step`` and verifies ``torch.equal`` between eager and replayed loss + h grad + lm_head grad on a CUDA model. The flag defaults to False so existing callers (the eager torchtitan trainer) are unaffected; consumers that want explicit param grads opt in. graph_trainer wires this in a separate commit.
Wrap each layer's inner attention forward via a new
`apply_cp_to_attention` helper in common_utils, called from
parallelize_{llama,deepseekv3,qwen3} when `cp_enabled`. Adds CP and
TP+CP integration tests for all three models.
<!-- ps-id: 089851fc-7329-4da3-8ed7-103f786148a0 -->
Co-authored-by: Aditya Venkataraman <avenkataraman@fb.com>
Hey there 👋 As explained in [pytorch#3158](pytorch#3158 (comment)), this PR refactors the Lychee link-checking logic to eliminate flaky CI failures while significantly improving execution speed. <br> ## 1. Multi-Tier Check Strategy The "Commit-time" experience is now separated from "Infrastructure monitoring" to ensure the development flow is not blocked by external outages. * **Pre-commit (Local & PR CI):** Configured to **fail only on 404**. Codes like `502` (e.g., "GitHub Unicorns") are accepted to prevent transient failures from stalling the workflow. * **Nightly CI:** Performs a strict check (accepting only `200`, `403`, `429`, `503`) with high persistence. It retries for up to **30 minutes**, outlasting most service outages. This run populates the cache with verified statuses that later can be reused in other workflow runs. <br> ## 2. The Cache Lifecycle: Why `key` and `restore-keys` are Necessary To understand the necessity of a dynamic `key` and the `restore-keys` fallback, one must first recognize that **GitHub Actions caches are immutable**. ### The Problem with Static Keys If a static key like `key: lychee-cache` is used without any dynamic parts, the workflow encounters a "Cache Hit" lock: 1. **The First Run:** GitHub creates the `.lycheecache` file and saves it as `lychee-cache`. 2. **Subsequent Runs:** GitHub finds an exact match for `lychee-cache` and downloads it. Because GitHub cannot update an existing cache, any new links or status updates discovered during the run are **discarded**. 3. **The "Stale Cache" Effect:** Over time, the cache becomes frozen in the state of the first run. Fixed links remain marked as "broken," and new links are re-checked every single time, slowing down CI. ### The Solution: Dynamic Keys & Branch Isolation GitHub restricts cache access based on branches: the **Default Branch** (`main`) is accessible by all, while **Feature Branches** can only access their own caches or the default branch's cache. By using a dynamic key (e.g., `cache-lychee-${{ github.sha }}`) combined with `restore-keys`, the process moves through two distinct phases: - **Restore** (start of job) - **Save** (end of job). #### Step-by-Step Lifecycle (3-Commit Example) | Phase | Commit 1: The "Inheritance" | Commit 2: The "Update" | Commit 3: The "Iteration" | | :--- | :--- | :--- | :--- | | **Restore** | Misses `cache-lychee-SHA1`. Falls back to `restore-keys` to pull the latest `main` cache that matches `cache-lychee-` pattern. | Misses `SHA2`. Pulls the most recent match from the branch scope (**SHA1**). | Misses `SHA3`. Pulls the latest available version from the branch (**SHA2**). | | **Action** | Lychee runs, checks only new links, and uses the inherited cache for the rest. | Lychee updates results with any new findings. | Lychee uses the `SHA2` baseline, ensuring zero redundant checks. | | **Save** | GitHub saves a **new** cache entry: `lychee-cache-SHA1`. | GitHub saves a new immutable entry: `lychee-cache-SHA2`. | GitHub saves the final state: `lychee-cache-SHA3`. | > [!NOTE] > This "chaining" effect ensures every commit builds upon the previous one, while keeping PR runs fast and the nightly "source of truth" accessible. <br> ## 3. Optimization: Parallelism & Output Previously, the configuration relied on `require_serial: true` and `verbose: true`. * **The Problem:** `verbose: true` was required to display the "Lychee not found" warning (since the script uses `exit 0`). However, because `pre-commit` spawns a process for every file, this caused the warning to print for every single file checked. `require_serial: true` was used to stop the log spam, but caused a **2x-3x slowdown**. * **The Fix:** An **Atomic Sentinel** is now used via `mkdir /tmp/lychee_lock`. Because `mkdir` is an atomic operation, only the first process successfully creates the directory and prints the warning. All other parallel processes fail to create the directory and remain silent. * **The Result:** `require_serial` and `verbose` are now **false**. The check runs in parallel (fast), and the warning prints exactly once (clean) by redirecting directly to the user's terminal screen via `> /dev/tty`. <br> ## 4. Fix Lychee version to install In **Lychee v0.24.1**, a change in the release archive structure broke dynamic installation scripts that fetched the "Latest" release. * **Change:** The installation now uses a **fixed, verified version** to prevent upstream changes from breaking the CI pipeline. <br> ## 5. GITHUB_TOKEN The `GITHUB_TOKEN` is explicitly passed to the Lychee action and pre-commit steps. This increases the rate limit for GitHub API requests, reducing `403 Forbidden` and `429 Too Many Requests` errors when checking internal repository links. <br> ## Summary Overview | Feature | Local / PR CI | Nightly CI | | :--- | :--- | :--- | | **Failure Condition** | Only `404` | Most non-200 codes | | **Duration/Retries** | Fast (5 retries / 15 secs) | Patient (30 retries / 30 mins) | | **Execution** | Parallel (via Sentinel) | Standard Action | | **Cache Goal** | Consume & Increment | Refresh & Validate |
Add Codex-facing AGENTS.md symlinks that point at the existing Claude instruction files. This lets Codex reuse the same repo and graph_trainer guidance without duplicating instruction content or creating a second source of truth. The root AGENTS.md points to the repo-level .claude/CLAUDE.md file. The graph_trainer AGENTS.md points to the graph_trainer-local .claude/CLAUDE.md file so directory-local instructions continue to apply when Codex is working in that subtree. Test Plan: - git ls-tree -l HEAD AGENTS.md torchtitan/experiments/graph_trainer/AGENTS.md - git diff --name-status origin/main..HEAD
…selector Adapts ezpz to upstream pytorch#3308 which removed `use_grouped_mm` from `GroupedExperts.Config` and inlined `_grouped_mm` as the only expert compute path. That broke ezpz's pre-SM90 / XPU fallback in `experiments/ezpz/moe/model.py` since the field no longer exists. Introduces `EzpzGroupedExperts(GroupedExperts)` in `torchtitan/experiments/ezpz/moe/experts.py` with three backends: - `grouped_mm` (default): defer to upstream's path - `for_loop`: per-expert matmul, re-vendored from the just-deleted upstream helper. Restores the XPU / pre-SM90 fallback. - `batched_mm_padded`: padded `torch.bmm` over per-expert rows. Originally proposed in #10 against the old `Config` shape; ported here as an opt-in ezpz subclass so core stays untouched. Wires a `make_ezpz_experts_config` helper, adds two new flavors (`10B_2B_sdpa_batched_mm_padded`, `10B_2B_sdpa_for_loop`) plus matching `config_registry` entries, and replaces the dead `use_grouped_mm` mutation in `model.py` with a `compute_backend = "for_loop"` switch keyed off `has_cuda_capability(9, 0)`. Test coverage in `torchtitan/experiments/ezpz/tests/test_moe_expert_backends.py` verifies the batched-mm and for-loop backends agree on outputs and gradients, including empty / all-zero count edge cases.
…XPU wheels Combines two fixes that both touch the `apply_fsdp ep_degree > 1` branch in `torchtitan/experiments/ezpz/moe/parallelize.py`: 1. HSDP correctness (PR #9, sourcery-bot/samuelwheeler): when running on an HSDP 2D mesh, the previous `FSDPMeshInfo(mesh=mesh, shard_mesh_dim=0)` construction produced DTensor specs that referenced only one axis of a multi-axis mesh. Replace with `_fsdp_mesh_info(...)` that returns `HSDPMeshInfo` for 2D meshes (with both `shard_mesh_dim` and `replicate_mesh_dim` populated) and `FSDPMeshInfo` for 1D. Looks up axis indices by name (`fsdp` / `efsdp` / `dp_replicate`) so callers never need to guess the axis order. 2. Import-fallback (PR #11, nscottnichols): some XPU PyTorch wheels do not ship `torch.distributed.fsdp._fully_shard._fsdp_common`, which means the per-param `ShardPlacementResult` API is unavailable. Wrap the import in `try/except ImportError`; on failure, fall back to a two-phase `fully_shard` (experts on `edp_mesh`, the rest of the block on `dp_mesh`) using a plain `Shard` return from `shard_placement_fn`. Behavior on builds that have the API is unchanged from path 1. Uses `axis`/`axes` (not `dim_name`) for `DeviceMesh` axis references per the project's `axis` vs `dim` naming convention.
…dExperts Adapts the expert-compute optimizations from PR #11 (nscottnichols) to the ezpz subclass so core MoE code stays untouched. Three pieces: 1. **Equal-counts no-grad fast path** in ``_run_experts_for_loop``. When every expert receives the same non-zero token count and we are outside an autograd region, fold SwiGLU into a single ``torch.bmm`` pair using a stacked ``w13 = cat(w1, w3)`` and a precomputed ``w2_t = w2.transpose(-2, -1)``. 2. **Version-keyed weight transform caches** on ``EzpzGroupedExperts``. ``_get_cached_w13`` / ``_get_cached_w2_t`` reuse the stacked / transposed views across forward calls during inference. Cache key covers storage data_ptr, ``_version``, shape, dtype, and device, so any in-place weight mutation invalidates the cache automatically. Bypassed entirely when autograd is enabled. 3. **Lightweight fast-path counters** gated on the ``EZPZ_MOE_FASTPATH_COUNTERS=1`` environment variable for inference profiling. Disabled-by-default and zero-cost in the common case. Tests cover (a) parity between the fast path and the full per-expert loop, and (b) cache hit/miss accounting across repeated no-grad calls. PR #11's token-dispatcher rewrite (in core ``token_dispatcher.py``) remains out of scope for this branch and would need a separate ``EzpzLocalTokenDispatcher`` subclass to land without core edits.
There was a problem hiding this comment.
Sorry @saforem2, your pull request is larger than the review limit of 150000 diff characters
There was a problem hiding this comment.
Pull request overview
This PR integrates multiple MoE-related changes (adapted to upstream removal of use_grouped_mm), adds ezpz-owned expert compute backend selection + coverage, and includes substantial graph_trainer + RL(vLLM) plumbing refactors alongside CI link-checking updates.
Changes:
- Remove
use_grouped_mmplumbing from core MoE (GroupedExperts) and dependent model configs; add ezpz-sideEzpzGroupedExpertswith acompute_backendselector plus new ezpz MoE flavors + parity tests. - Refactor
graph_trainerpass infrastructure (new registries/modules), update selective activation remat implementation, and add chunked-loss support that exposes lm_head param grads via autograd outputs. - Update RL vLLM integration to a single registration entry point (model + config parser) and adjust link-checking (Lychee) to use caching and run in PR lint.
Reviewed changes
Copilot reviewed 53 out of 54 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| torchtitan/models/llama4/model.py | Removes SM90 gating tied to use_grouped_mm (field removed upstream). |
| torchtitan/models/gpt_oss/moe.py | Inlines grouped-mm path and removes use_grouped_mm branching. |
| torchtitan/models/gpt_oss/model.py | Removes use_grouped_mm/SM90 fallback mutation logic. |
| torchtitan/models/deepseek_v3/model.py | Removes use_grouped_mm/SM90 fallback mutation logic. |
| torchtitan/models/common/moe.py | Removes legacy expert backends and always uses torch._grouped_mm. |
| torchtitan/models/common/config_utils.py | Updates make_experts_config signature to drop use_grouped_mm. |
| torchtitan/experiments/rl/tests/test_bitwise_parity.py | Switches to new vLLM registry entry point. |
| torchtitan/experiments/rl/models/vllm_wrapper.py | Renames wrapper and uses torchtitan ParallelismConfig as source of truth. |
| torchtitan/experiments/rl/models/vllm_registry.py | Adds registry_to_vllm + torchtitan config parser registration and HF-shaped config builder. |
| torchtitan/experiments/rl/models/parallelize.py | Import path consolidation for config types. |
| torchtitan/experiments/rl/grpo.py | Import path consolidation for config types. |
| torchtitan/experiments/rl/generate.py | Updates vLLM registration callsite to registry_to_vllm. |
| torchtitan/experiments/rl/config_registry.py | Import path consolidation; enforces generator parallelism flags. |
| torchtitan/experiments/rl/actors/trainer.py | Import path consolidation for config types. |
| torchtitan/experiments/rl/actors/generator.py | Validates generator parallelism invariants; uses torchtitan vLLM config parser path. |
| torchtitan/experiments/rl/init.py | Exports renamed wrapper and new vLLM registry function. |
| torchtitan/experiments/graph_trainer/trainer.py | Adds timing context around tracing; uses new registries. |
| torchtitan/experiments/graph_trainer/tests/test_trace_module.py | Adds coverage for chunked CE loss tracing and new CP trace test. |
| torchtitan/experiments/graph_trainer/tests/test_precompile.py | Extends precompile artifacts with user_inputs_spec. |
| torchtitan/experiments/graph_trainer/tests/test_passes.py | Updates imports after pass-module split; adds new selective remat tests. |
| torchtitan/experiments/graph_trainer/tests/test_chunked_loss.py | New unit tests for ChunkedCELossWithParamGrads. |
| torchtitan/experiments/graph_trainer/tests/integration_tests.py | Updates integration test overrides to include CP and adjust DP shard degree. |
| torchtitan/experiments/graph_trainer/selective_activation_remat.py | Replaces DCE-based remat with in-place duplication + targeted erasure + offload-chain hoisting. |
| torchtitan/experiments/graph_trainer/registry.py | New shared registries to avoid circular imports. |
| torchtitan/experiments/graph_trainer/qwen3/parallelize.py | Applies CP wrapping to attention when CP is enabled. |
| torchtitan/experiments/graph_trainer/precompile.py | Moves distributed metadata filter import to inductor_passes. |
| torchtitan/experiments/graph_trainer/passes.py | Refactors into orchestration-only module; imports split-out pass implementations. |
| torchtitan/experiments/graph_trainer/memory_policy.py | New module for SAC tagging + memory policy dispatch. |
| torchtitan/experiments/graph_trainer/llama3/parallelize.py | Applies CP wrapping to attention when CP is enabled. |
| torchtitan/experiments/graph_trainer/llama3/parallelize_autoparallel.py | Updates autoparallel input_fn to return tokens+positions. |
| torchtitan/experiments/graph_trainer/inductor_passes.py | New module for regional/full Inductor compilation and flex-attention annotation. |
| torchtitan/experiments/graph_trainer/graph_utils.py | Removes outdated comment referencing use_grouped_mm=False. |
| torchtitan/experiments/graph_trainer/fsdp_passes.py | Respects MUST_CPU_OFFLOAD tagging; updates backward-node detection hook. |
| torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py | Applies CP wrapping to attention when CP is enabled. |
| torchtitan/experiments/graph_trainer/cudagraph.py | Moves cudagraph + kernel annotation passes into this module. |
| torchtitan/experiments/graph_trainer/common_utils.py | Adds log_timer and apply_cp_to_attention helpers; removes _is_recomputed_node. |
| torchtitan/experiments/graph_trainer/chunked_loss.py | Adds ChunkedCELossWithParamGrads implementation. |
| torchtitan/experiments/graph_trainer/autoparallel_api.py | Drops _check_forward_args dependency. |
| torchtitan/experiments/graph_trainer/.claude/CLAUDE.md | Updates docs to match new remat behavior. |
| torchtitan/experiments/ezpz/tests/test_moe_expert_backends.py | New unit tests for ezpz expert backends, fast paths, and caching. |
| torchtitan/experiments/ezpz/tests/init.py | Adds test package marker. |
| torchtitan/experiments/ezpz/moe/parallelize.py | Makes expert FSDP mesh info HSDP-aware; adds import fallback for older wheels. |
| torchtitan/experiments/ezpz/moe/model.py | Switches ezpz MoE backend to for_loop on non-SM90 via compute_backend. |
| torchtitan/experiments/ezpz/moe/experts.py | Introduces EzpzGroupedExperts with backend selector, padded-bmm backend, and caches/counters. |
| torchtitan/experiments/ezpz/moe/config_registry.py | Adds new ezpz MoE flavors for the new expert backends. |
| torchtitan/experiments/ezpz/moe/init.py | Adds make_ezpz_experts_config and wires compute_backend through ezpz MoE configs. |
| torchtitan/components/loss.py | Introduces overrideable _gradient_backprop hook for chunked-loss subclasses. |
| tests/unit_tests/test_compile_moe.py | Updates grouped-mm compile test to call GroupedExperts._experts_forward. |
| .pre-commit-config.yaml | Updates lychee hook (cache + reduced duplicate warnings + accept-code policy). |
| .gitignore | Ignores .lycheecache. |
| .github/workflows/lint.yaml | Pins Lychee version, adds cache restore, and runs lychee via pre-commit on all files. |
| .github/workflows/link_check.yaml | Removes nightly link-check workflow (now covered by lint workflow). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| echo "Installing Lychee version: $VERSION" | ||
|
|
||
| URL="https://github.com/lycheeverse/lychee/releases/download/${LATEST_TAG}/lychee-x86_64-unknown-linux-gnu.tar.gz" | ||
| URL="https://github.com/lycheeverse/lychee/releases/download/lychee-${VERSION}/lychee-x86_64-unknown-linux-gnu.tar.gz" |
| # Note: this is not safe if downstream accidentally runs tensor ops after | ||
| # the loss returns, which would produce a non-trivial grad_output that we need | ||
| # to properly handle. The complicated part is that grad_output might not be | ||
| # on the same device mesh as accumlated_grad. |
| # The chunk loop above already populated each lm_head param's | ||
| # ``.grad`` with the correctly sharded value via the FSDP last-chunk | ||
| # post-accumulate-grad hook (reduce-scatter). Capture those grads | ||
| # into saved_tensors so backward ca route them as autograd outputs |
Sets num_workers=2 on the BlendCorpusDataLoader.Config built in both
_base_config()s — training dataloader for {agpt, moe}, plus the agpt
validator dataloader. Previously num_workers fell through to the
framework default of 0 (single-threaded I/O), forcing every production
launch to pass --dataloader.num-workers=2 explicitly. Defaulting to
the value we already pass everywhere removes that boilerplate while
keeping the CLI override available for sweeps.
Validator defaults left as-is per user (enable=False, freq=200).
The first round of backend comparison runs (logs/moe-backend-bench/
*20260512-102301.log) all OOM'd in FSDP all-gather copy-in during the
very first forward pass — same failure mode across all three backends:
RuntimeError: level_zero backend failed with error: 40
(UR_RESULT_ERROR_OUT_OF_RESOURCES)
Root cause: the existing `moe_10b_2b_sdpa` flavor is FSDP-only (EP=1)
with `activation_checkpoint_mode="none"`. The 10B model + LBS=2 +
seq_len=8192 doesn't fit in XPU memory at that parallelism shape, and
the new `_for_loop` / `_batched_mm_padded` variants I added inherited
the same defaults so they OOM'd identically. The OOM is a parallelism-
config problem, not a regression in the new backends.
Add three EP=2 + AC=full variants for an apples-to-apples comparison:
- `moe_10b_2b_sdpa_ep_ac`: existing `_ep` flavor with AC=full (baseline)
- `moe_10b_2b_sdpa_for_loop_ep`: for_loop backend, EP=2, AC=full
- `moe_10b_2b_sdpa_batched_mm_padded_ep`: bmm backend, EP=2, AC=full
All three share the same parallelism / memory shape so any TPS/MFU
delta in a re-run will reflect the expert backend, not the FSDP/EP
config. Use these for the backend sweep instead of the EP=1 variants.
The first sweep round (logs/moe-backend-bench/*20260512-102301.log) hit XPU OOM in FSDP all-gather copy-in during the first forward pass on all three configs — the parallelism shape (EP=1, AC=none) doesn't fit a 10B model regardless of expert backend. Swap the sweep to the new `_ep_ac` / `_for_loop_ep` / `_batched_mm_padded_ep` flavors so any TPS/MFU delta reflects the expert backend, not the parallelism config. Also add a Layer 1c pre-flight to `moe_backend_layer0_1.sh` that asserts the EP=2 + AC=full configs build with the right `compute_backend` and that the experts config is the `EzpzGroupedExperts.Config` subclass. Cheap pre-check that catches config-graph regressions before the sweep spends ~5 min/backend launching only to crash at build time.
Both 2B and 80B failover-wrapped jobs (8480294, 8480361) died at 00:00 elapsed with `ERROR: failover_lib.sh not found at /var/spool/pbs/mom_priv/jobs/failover_lib.sh`. Root cause: PBS copies the submit script to /var/spool/pbs/mom_priv/jobs/ before running it on the head node, so `dirname "$(realpath "$0")"` resolves to that spool dir, not the original $repo/scripts. The relative "$PBS_O_WORKDIR/.ezpz-utils-cache" lookup also failed with realpath collapsing across the wrong base. Fix: anchor both EZPZ_UTILS and FAILOVER_LIB to $PBS_O_WORKDIR with the canonical repo-relative paths. Also add PBS_O_WORKDIR + PWD to the error message so future debugging is one log-line shorter.
|
Superseded by #13. That PR takes a more focused 2-commit approach: just the upstream merge + the minimal |
Summary
Single-PR integration of three open MoE PRs against
ezpz, adapted to land cleanly on top of upstream's recent_run_experts_for_loop/use_grouped_mmremoval (pytorch#3308). Everything except the upstream merge stays insidetorchtitan/experiments/ezpz/per Golden Rule #1.Supersedes:
Commits
Merge
upstream/main— pulls in 11 upstream commits including Remove MoE expert for-loop fallback pytorch/torchtitan#3308, which removeduse_grouped_mmfromGroupedExperts.Configand inlined_grouped_mmas the only path. That alone breaks ezpz's pre-SM90 / XPU fallback inexperiments/ezpz/moe/model.py; commit 2 fixes it.feat(ezpz/moe): add EzpzGroupedExperts subclass with compute_backend selector— introducesEzpzGroupedExperts(GroupedExperts)inexperiments/ezpz/moe/experts.pywith three backends:\"grouped_mm\"(default): defer to upstream\"for_loop\": re-vendored from the upstream deletion. Restores the XPU / pre-SM90 fallback.\"batched_mm_padded\": ported from Add padded batched matmul MoE expert backend #10 as an ezpz-owned opt-in.Adds a
make_ezpz_experts_config(...)helper, two new flavors (10B_2B_sdpa_batched_mm_padded,10B_2B_sdpa_for_loop) with matchingconfig_registryentries, and replaces the deaduse_grouped_mmmutation inmodel.pywith acompute_backend = \"for_loop\"switch keyed offhas_cuda_capability(9, 0).fix(ezpz/moe): HSDP-aware FSDP mesh info + import-fallback for older XPU wheels— combines the two Fix MoE expert FSDP mesh info for HSDP #9 / Optimize 10b 2b sdpa moe #11 changes that touched the sameapply_fsdp ep_degree > 1branch:HSDPMeshInfofor 2D meshes (with bothshard_mesh_dimandreplicate_mesh_dim) andFSDPMeshInfofor 1D, looking up axis indices by name (fsdp/efsdp/dp_replicate). Usesaxis/axesnotdim_nameper the project's mesh naming convention._fsdp_commonimport intry/except ImportError; on failure, falls back to a two-phasefully_shard(experts onedp_mesh, the rest ondp_mesh) using a plainShardreturn.perf(ezpz/moe): port PR #11 expert-side optimizations into EzpzGroupedExperts:_run_experts_for_loop(single bmm pair over stackedw13and transposedw2_t).w13/w2_tcaches onEzpzGroupedExperts. Bypassed under autograd; auto-invalidated on_versionbump._record_moe_fastpathcounters gated onEZPZ_MOE_FASTPATH_COUNTERS=1.Test coverage
torchtitan/experiments/ezpz/tests/test_moe_expert_backends.py:test_batched_mm_padded_matches_for_loop_kernel— parity (forward + backward) on mixed / empty / all-zero count cases.test_ezpz_grouped_experts_backend_selector_matches_for_loop— full module-level parity for the backend selector path.test_equal_counts_no_grad_matches_loop— equal-counts bmm fast path matches the per-expert loop.test_cache_hit_on_repeat_no_grad_call— cache hit/miss accounting across repeated no-grad calls.All 4 pass on CPU.
Out of scope (deferred)
PR #11's token-dispatcher rewrite (
+570/-91intorchtitan/models/common/token_dispatcher.py+torchtitan/ops/scatter_add.py): force-load-balance dispatch path, equal-padding all-to-all fast paths, in-placedeterministic_scatter_add_combine. To stay insideexperiments/ezpz/, this needs a separateEzpzLocalTokenDispatchersubclass — much larger than the expert-side subclass — and is best handled as a follow-up PR.Validation plan
moe_500m_50stepsormoe_debugmodel) — to be filed as a follow-up smoke-test report underexperiments/ezpz/docs/experiments/moe/aurora/Notes
compute_backenddefaults to\"grouped_mm\", so flavors that don't opt in get exactly the upstream behavior.model.py:226continues to switch backend on pre-SM90 — unchanged for CUDA-SM90+, restores the XPU fallback that Remove MoE expert for-loop fallback pytorch/torchtitan#3308 removed.