
fix(distill): align pre-warm cache keys with runtime — rmsnorm eps + rope forward (PMAT-698k)#1820

Merged
noahgift merged 4 commits into main from fix/distill-prewarm-key-alignment-pmat-698k on May 19, 2026

@noahgift
Contributor

Summary

PMAT-698i's [FWD-CACHE] diagnostic logging surfaced two additional cache-key misalignments beyond the PMAT-698j macro fix. Even with the correct $key substitution, these kernels would still miss the cache, because pre-warm stored them under different keys than the runtime constructs:

| # | Pre-warm key (before) | Runtime key | Status post-PMAT-698j |
|---|---|---|---|
| 1 | `batched_rmsnorm_fwd_{h}` | `batched_rmsnorm_fwd_{h}_eps{eps_bits:08x}` | still misses (no eps in pre-warm) |
| 2 | (not pre-warmed at all) | `batched_rope_fwd_{nh}_{hd}_{s}_th{theta:08x}` | still misses |
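
To make row 1 concrete, here is a minimal sketch of the mismatch. The function names and the `HashMap` stand-in for the kernel cache are hypothetical, not the project's actual code; only the two key formats come from the table above. It also assumes eps is an `f32`, which the 8-digit hex suffix suggests.

```rust
use std::collections::HashMap;

/// Key the pre-warm pass built before this PR: hidden size only.
/// (Hypothetical helper; only the format string comes from the PR.)
fn prewarm_rmsnorm_key_before(hidden: u32) -> String {
    format!("batched_rmsnorm_fwd_{hidden}")
}

/// Key the runtime builds: hidden size plus the raw bits of epsilon as 8 hex digits.
/// Assumes epsilon is an f32.
fn runtime_rmsnorm_key(hidden: u32, eps: f32) -> String {
    format!("batched_rmsnorm_fwd_{hidden}_eps{:08x}", eps.to_bits())
}

/// Returns true when a runtime lookup finds the entry stored by pre-warm.
fn runtime_hits_prewarmed(hidden: u32, eps: f32) -> bool {
    // Stand-in for the compiled-kernel cache; the value type is a placeholder.
    let mut cache: HashMap<String, &'static str> = HashMap::new();
    cache.insert(prewarm_rmsnorm_key_before(hidden), "compiled module");
    cache.contains_key(&runtime_rmsnorm_key(hidden, eps))
}
```

With these definitions, `runtime_hits_prewarmed` returns `false` for any hidden size and epsilon, because the stored key never carries the `_eps…` suffix that the lookup key always has.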

Fix

  • Append _eps{bits:08x} to rmsnorm pre-warm key (uses default 1e-5)
  • Add batched_rope_fwd warm calls for Q and KV head counts (GQA), seq_len=1, Qwen2 theta=1e6
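
The two bullets above can be sketched as a pair of shared key builders that both pre-warm and runtime would call, so the keys cannot drift apart. This is an illustration under assumptions, not the code in this PR: the function and constant names are hypothetical, and eps and theta are assumed to be `f32` values (suggested by the 8-digit hex fields).

```rust
/// Default epsilon assumed by pre-warm, per the fix description.
const PREWARM_EPS: f32 = 1e-5;
/// Qwen2 RoPE theta, per the fix description.
const QWEN2_THETA: f32 = 1e6;

/// rmsnorm forward key: hidden size plus epsilon bits.
fn rmsnorm_fwd_key(hidden: u32, eps: f32) -> String {
    format!("batched_rmsnorm_fwd_{hidden}_eps{:08x}", eps.to_bits())
}

/// rope forward key: head count, head dim, sequence length, theta bits.
fn rope_fwd_key(heads: u32, head_dim: u32, seq_len: u32, theta: f32) -> String {
    format!(
        "batched_rope_fwd_{heads}_{head_dim}_{seq_len}_th{:08x}",
        theta.to_bits()
    )
}

/// Keys the pre-warm pass would populate: one rmsnorm key, plus one rope key
/// each for the Q and KV head counts (GQA) at seq_len = 1.
fn prewarm_keys(hidden: u32, num_heads: u32, num_kv_heads: u32, head_dim: u32) -> Vec<String> {
    let mut keys = vec![rmsnorm_fwd_key(hidden, PREWARM_EPS)];
    for heads in [num_heads, num_kv_heads] {
        let key = rope_fwd_key(heads, head_dim, 1, QWEN2_THETA);
        // Without GQA both head counts match, so skip the duplicate key.
        if !keys.contains(&key) {
            keys.push(key);
        }
    }
    keys
}
```

Because the key embeds the exact bit pattern of each float, a model that uses a different eps or theta than these defaults would still miss the cache; pre-warm only covers the values it is given.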

Stage 7 of cascade

| # | PR | What |
|---|---|---|
| 1 | #1804 | PMAT-700-B Skip PTX GEMM pre-warm (JIT cache pressure) |
| 2 | #1808 | PMAT-698e Cap max_seq_len at 2048 |
| 3 | #1809 | PMAT-698f APR magic in weights loader |
| 4 | #1810 | PMAT-698g Non-LoRA backward pre-warm |
| 5 | #1813 | PMAT-698h rms_norm_gamma_reduce pre-warm |
| 6 | #1815 | PMAT-698i Diagnostic logging (surfaced the root cause) |
| 7 | #1817 | PMAT-698j Macro hardcoded-key fix (root cause) |
| 8 | THIS | PMAT-698k Key alignment (rmsnorm eps + rope forward) |

After PMAT-698j and PMAT-698k land together, the [FWD-CACHE] log should show zero post-pre-warm JIT events on the dominant forward path.
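
The "zero post-pre-warm JIT events" expectation can be modelled with a toy cache that counts compiles occurring after pre-warm finishes. Everything here (`KernelCache`, its fields, and the helper function) is hypothetical and exists only to illustrate the invariant; the real cache and its logging are not shown in this PR description.

```rust
use std::collections::HashSet;

/// Toy stand-in for a JIT kernel cache that counts compiles after pre-warm.
#[derive(Default)]
struct KernelCache {
    compiled: HashSet<String>,
    prewarm_done: bool,
    post_prewarm_compiles: usize,
}

impl KernelCache {
    /// Look up a kernel by key, "compiling" it on a miss.
    fn get_or_compile(&mut self, key: &str) {
        // `insert` returns true only when the key was not already cached.
        let was_miss = self.compiled.insert(key.to_string());
        if was_miss && self.prewarm_done {
            self.post_prewarm_compiles += 1;
        }
    }

    /// Mark the end of the pre-warm phase.
    fn finish_prewarm(&mut self) {
        self.prewarm_done = true;
    }
}

/// Pre-warm with one key set, run lookups with another, and report how many
/// compiles happened after pre-warm finished.
fn post_prewarm_compiles(prewarm: &[&str], runtime: &[&str]) -> usize {
    let mut cache = KernelCache::default();
    for key in prewarm {
        cache.get_or_compile(key);
    }
    cache.finish_prewarm();
    for key in runtime {
        cache.get_or_compile(key);
    }
    cache.post_prewarm_compiles
}
```

When the pre-warm and runtime key sets match, the count is zero; when pre-warm stores a key without a suffix that the runtime key carries, the first runtime lookup compiles once and later lookups hit.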

Test plan

  • cargo check --features cuda clean
  • 37 cuda_forward lib tests pass
  • Live gx10 dispatch: zero [FWD-CACHE] Compiling events post-pre-warm

🤖 Generated with Claude Code

…rope forward (PMAT-698k)

PMAT-698i's [FWD-CACHE] diagnostic surfaced TWO additional pre-warm key
mismatches beyond the PMAT-698j macro fix:

1. batched_rmsnorm_fwd_{h}        ← pre-warm key (no eps suffix)
   batched_rmsnorm_fwd_{h}_eps... ← runtime key (normalization.rs:139)

2. batched_rope_fwd_{nh}_{hd}_{s}_th... — runtime key (normalization.rs:339)
   (not in pre-warm at all)

Even with PMAT-698j's macro fix (correct $key substitution), these
kernels would still cache-miss at runtime because the pre-warm stored
them under different keys than runtime constructs.

Fix:
- Append _eps{bits:08x} to rmsnorm pre-warm key with default 1e-5
- Add batched_rope_fwd warm! calls for both num_heads and num_kv_heads
  (GQA support) with seq_len=1 (smoke runtime case) and Qwen2 theta=1e6

Test plan:
- [x] cargo check --features cuda — clean build
- [x] 37 cuda_forward lib tests pass
- [ ] Live gx10 dispatch: [FWD-CACHE] Compiling events for these three
      kernels disappear post-PMAT-698j+k landing

Stage 7 of the Phase 3 cuda dispatch cascade:
  PMAT-700-B → PMAT-698e → PMAT-698f → PMAT-698g → PMAT-698h → PMAT-698i
  (DIAG) → PMAT-698j (root cause macro) → PMAT-698k (THIS — key alignment)

After this lands, the [FWD-CACHE] log should show ZERO post-pre-warm
JIT events on the dominant forward path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift enabled auto-merge (squash) May 19, 2026 16:49
@noahgift merged commit 277f29d into main May 19, 2026
10 checks passed
@noahgift deleted the fix/distill-prewarm-key-alignment-pmat-698k branch May 19, 2026 18:20
noahgift added a commit that referenced this pull request May 20, 2026
2026-05-20 — real distillation 1.5B teacher → 0.5B student on
Blackwell GB10 with the full PMAT-698e..n + PMAT-700-B cascade active.

  initial_loss = 7.6746
  final_loss   = 7.2036   ← LESS THAN initial
  62 steps, 122.7s, no errors

F-DISTILL-SMOKE-001 ("final_loss < initial_loss") discharged.

Phase 3 of SPEC-DISTILL-001 is COMPLETE.

Evidence:
- evidence/distill-phase-3-real-kd/dispatch.json — dispatch manifest
- evidence/distill-phase-3-real-kd/launch-final-pass.txt — full training log

Run dir on gx10: /home/noah/runs/distill-smoke-20260520-070404/
Trained student checkpoint: student-trained.apr/model.safetensors

Cascade summary (all merged):
- #1804 PMAT-700-B  (cuBLAS prewarm skip)
- #1808 PMAT-698e   (workspace cap)
- #1809 PMAT-698f   (APR magic in weights loader)
- #1810 PMAT-698g   (non-LoRA backward pre-warm)
- #1813 PMAT-698h   (rms_norm_gamma_reduce pre-warm)
- #1815 PMAT-698i   (FWD-CACHE diagnostic logging)
- #1817 PMAT-698j   (THE root cause — warm! macro key)
- #1820 PMAT-698k   (cache-key alignment: rope fwd + rmsnorm eps)
- #1823 PMAT-698m   (smoke setup: non-degenerate batch)
- #1824             (post-mortem doc)
- #1827 PMAT-698n   (rmsnorm pre-warm at both 1e-6 + 1e-5 eps)

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>