fix(distill): pre-warm rms_norm_gamma_reduce too (PMAT-698h)#1813
Merged
noahgift merged 2 commits intoMay 19, 2026
Conversation
Phase 3 dispatch v9 still failed with the same "stream poisoned" error after PMAT-698g. The cause: RMSNorm backward is TWO kernels, not one. Stage 1: batched_rms_norm_backward (pre-warmed by PMAT-698g ✓) Stage 2: rms_norm_gamma_reduce (JIT'd on demand ✗) structured.rs:247 calls rms_norm_gamma_reduce after batched_rms_norm_backward to fold the per-row gamma partials into a single [hidden] gradient tensor in deterministic fixed order. The reduce was missing from the backward pre-warm list, so it JIT'd at first backward step → sm_121 stream poisoning → forward_backward_with_grad returns None. Fix: add RmsNormGammaReduceKernel::new(max_seq_len, hidden) to the backward pre-warm. Same (batch_size, hidden_size) dims that structured.rs constructs at call time. Test plan: - [x] cargo check --features cuda — clean build - [x] 18 cuda_backward lib tests pass (no behavior change for non-Blackwell) - [ ] Live gx10 dispatch reaches stepping (post-merge) Stage 5 of the Phase 3 cuda dispatch defect cascade: PMAT-700-B → PMAT-698e → PMAT-698f → PMAT-698g → PMAT-698h Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This was referenced May 19, 2026
Merged
noahgift
added a commit
that referenced
this pull request
May 20, 2026
2026-05-20 — real distillation 1.5B teacher → 0.5B student on
Blackwell GB10 with the full PMAT-698e..n + PMAT-700-B cascade active.
initial_loss = 7.6746
final_loss = 7.2036 ← LESS THAN initial
62 steps, 122.7s, no errors
F-DISTILL-SMOKE-001 ("final_loss < initial_loss") discharged.
Phase 3 of SPEC-DISTILL-001 is COMPLETE.
Evidence:
- evidence/distill-phase-3-real-kd/dispatch.json — dispatch manifest
- evidence/distill-phase-3-real-kd/launch-final-pass.txt — full training log
Run dir on gx10: /home/noah/runs/distill-smoke-20260520-070404/
Trained student checkpoint: student-trained.apr/model.safetensors
Cascade summary (all merged):
- #1804 PMAT-700-B (cuBLAS prewarm skip)
- #1808 PMAT-698e (workspace cap)
- #1809 PMAT-698f (APR magic in weights loader)
- #1810 PMAT-698g (non-LoRA backward pre-warm)
- #1813 PMAT-698h (rms_norm_gamma_reduce pre-warm)
- #1815 PMAT-698i (FWD-CACHE diagnostic logging)
- #1817 PMAT-698j (THE root cause — warm! macro key)
- #1820 PMAT-698k (cache-key alignment: rope fwd + rmsnorm eps)
- #1823 PMAT-698m (smoke setup: non-degenerate batch)
- #1824 (post-mortem doc)
- #1827 PMAT-698n (rmsnorm pre-warm at both 1e-6 + 1e-5 eps)
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Phase 3 dispatch v9 (with PMAT-698g landed) still failed with the same
forward_backward_with_grad returned Noneerror. Cause: RMSNorm backward is two kernels, only one was pre-warmed.batched_rms_norm_backward(compute per-row gamma partials)rms_norm_gamma_reduce(deterministic fixed-order sum of partials)structured.rs:247callsrms_norm_gamma_reduceimmediately after Stage 1 to fold the per-row gamma partials into a single[hidden]gradient. The reduce was missing from my PMAT-698g pre-warm list, so it JIT'd during the first backward step → sm_121 stream poisoning.Fix
Add
RmsNormGammaReduceKernel::new(max_seq_len, hidden)to the backward pre-warm. Matches the(batch_size, hidden_size)dims that structured.rs constructs at call time.Test plan
cargo check --features cuda— clean buildDefect cascade — stage 5
Each surfaced the next defect once unblocked.
🤖 Generated with Claude Code