Skip to content

fix(distill): pre-warm rms_norm_gamma_reduce too (PMAT-698h)#1813

Merged
noahgift merged 2 commits into
mainfrom
fix/distill-rms-norm-gamma-reduce-prewarm-pmat-698h
May 19, 2026
Merged

fix(distill): pre-warm rms_norm_gamma_reduce too (PMAT-698h)#1813
noahgift merged 2 commits into
mainfrom
fix/distill-rms-norm-gamma-reduce-prewarm-pmat-698h

Conversation

@noahgift
Copy link
Copy Markdown
Contributor

Summary

Phase 3 dispatch v9 (with PMAT-698g landed) still failed with the same forward_backward_with_grad returned None error. Cause: RMSNorm backward is two kernels, only one was pre-warmed.

Stage Kernel PMAT-698g PMAT-698h
1 batched_rms_norm_backward (compute per-row gamma partials) ✓ pre-warmed ✓ pre-warmed
2 rms_norm_gamma_reduce (deterministic fixed-order sum of partials) ✗ JIT'd ✓ pre-warmed

structured.rs:247 calls rms_norm_gamma_reduce immediately after Stage 1 to fold the per-row gamma partials into a single [hidden] gradient. The reduce was missing from my PMAT-698g pre-warm list, so it JIT'd during the first backward step → sm_121 stream poisoning.

Fix

Add RmsNormGammaReduceKernel::new(max_seq_len, hidden) to the backward pre-warm. Matches the (batch_size, hidden_size) dims that structured.rs constructs at call time.

Test plan

  • cargo check --features cuda — clean build
  • 18 cuda_backward lib tests pass
  • Live gx10 dispatch reaches stepping (post-merge)

Defect cascade — stage 5

# PR What
1 #1804 PMAT-700-B Skip PTX GEMM pre-warm (JIT cache pressure)
2 #1808 PMAT-698e Cap max_seq_len at 2048 (workspace sizing)
3 #1809 PMAT-698f Accept APR magic in weights loader
4 #1810 PMAT-698g Pre-warm non-LoRA backward kernels
5 THIS PMAT-698h Pre-warm rms_norm_gamma_reduce

Each surfaced the next defect once unblocked.

🤖 Generated with Claude Code

Phase 3 dispatch v9 still failed with the same "stream poisoned" error
after PMAT-698g. The cause: RMSNorm backward is TWO kernels, not one.

  Stage 1: batched_rms_norm_backward  (pre-warmed by PMAT-698g ✓)
  Stage 2: rms_norm_gamma_reduce      (JIT'd on demand ✗)

structured.rs:247 calls rms_norm_gamma_reduce after batched_rms_norm_backward
to fold the per-row gamma partials into a single [hidden] gradient tensor
in deterministic fixed order. The reduce was missing from the backward
pre-warm list, so it JIT'd at first backward step → sm_121 stream
poisoning → forward_backward_with_grad returns None.

Fix: add RmsNormGammaReduceKernel::new(max_seq_len, hidden) to the
backward pre-warm. Same (batch_size, hidden_size) dims that structured.rs
constructs at call time.

Test plan:
- [x] cargo check --features cuda — clean build
- [x] 18 cuda_backward lib tests pass (no behavior change for non-Blackwell)
- [ ] Live gx10 dispatch reaches stepping (post-merge)

Stage 5 of the Phase 3 cuda dispatch defect cascade:
  PMAT-700-B → PMAT-698e → PMAT-698f → PMAT-698g → PMAT-698h

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift noahgift enabled auto-merge (squash) May 19, 2026 12:06
@noahgift noahgift merged commit 5d0943c into main May 19, 2026
10 checks passed
@noahgift noahgift deleted the fix/distill-rms-norm-gamma-reduce-prewarm-pmat-698h branch May 19, 2026 12:59
noahgift added a commit that referenced this pull request May 20, 2026
2026-05-20 — real distillation 1.5B teacher → 0.5B student on
Blackwell GB10 with the full PMAT-698e..n + PMAT-700-B cascade active.

  initial_loss = 7.6746
  final_loss   = 7.2036   ← LESS THAN initial
  62 steps, 122.7s, no errors

F-DISTILL-SMOKE-001 ("final_loss < initial_loss") discharged.

Phase 3 of SPEC-DISTILL-001 is COMPLETE.

Evidence:
- evidence/distill-phase-3-real-kd/dispatch.json — dispatch manifest
- evidence/distill-phase-3-real-kd/launch-final-pass.txt — full training log

Run dir on gx10: /home/noah/runs/distill-smoke-20260520-070404/
Trained student checkpoint: student-trained.apr/model.safetensors

Cascade summary (all merged):
- #1804 PMAT-700-B  (cuBLAS prewarm skip)
- #1808 PMAT-698e   (workspace cap)
- #1809 PMAT-698f   (APR magic in weights loader)
- #1810 PMAT-698g   (non-LoRA backward pre-warm)
- #1813 PMAT-698h   (rms_norm_gamma_reduce pre-warm)
- #1815 PMAT-698i   (FWD-CACHE diagnostic logging)
- #1817 PMAT-698j   (THE root cause — warm! macro key)
- #1820 PMAT-698k   (cache-key alignment: rope fwd + rmsnorm eps)
- #1823 PMAT-698m   (smoke setup: non-degenerate batch)
- #1824             (post-mortem doc)
- #1827 PMAT-698n   (rmsnorm pre-warm at both 1e-6 + 1e-5 eps)

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant