fix(distill): warm! macro used hardcoded "silu_forward" key — THE Blackwell root cause (PMAT-698j)#1817

Merged
noahgift merged 1 commit into main from fix/distill-warm-macro-cache-key-pmat-698j on May 19, 2026

@noahgift

THE root-cause bug

The warm! macro in pre_warm_for_model had self.get_or_compile("silu_forward", &ptx)? hardcoded — every pre-warm call stored its compiled module under the SAME hashmap key. Only the first kernel actually got compiled; all subsequent warm!() invocations hit the Occupied path and silently discarded their PTX.

At runtime, every kernel looks up its real key (e.g. batched_rmsnorm_fwd_896_eps358637bd) — cache miss → JIT mid-forward-pass → Blackwell sm_121 stream poisoning → forward_backward_with_grad returned None.
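The example key above looks like a kernel name followed by the hidden size and the epsilon's bit pattern. A minimal sketch of how such a key could be built, assuming `eps_bits` is the raw IEEE-754 bits of an `f32` epsilon (the helper name `rmsnorm_fwd_key` is hypothetical, not taken from the codebase):

```rust
/// Hypothetical helper: builds a forward RMSNorm cache key in the shape
/// quoted in this PR. Assumes the epsilon is an f32 and that the key embeds
/// its raw bit pattern as 8 lowercase hex digits.
fn rmsnorm_fwd_key(hidden_size: usize, eps: f32) -> String {
    let eps_bits = eps.to_bits();
    format!("batched_rmsnorm_fwd_{hidden_size}_eps{eps_bits:08x}")
}
```

With `hidden_size = 896` and `eps = 1e-6` this reproduces the key quoted above. A different epsilon yields a different key, so a module pre-warmed under one key cannot satisfy a lookup under another.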

Discovered via PMAT-698i diagnostic logging

The diagnostic PR's [FWD-CACHE] log surfaced every "pre-warmed" kernel actually JIT-compiling at first runtime call. Proves the cache held exactly one entry (under silu_forward), no matter how many warm!() calls ran.
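The collision can be reproduced with a plain `HashMap`. The sketch below is a simplified stand-in, not the real cache: `get_or_compile` here is a free function over `HashMap<String, String>`, and the PTX is modeled as a string. It only illustrates why an Occupied entry causes later PTX to be dropped.

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;

/// Stand-in for the real cache method: stores `ptx` under `key` only when the
/// key is vacant. Returns true if the PTX was stored, false if it was dropped.
fn get_or_compile(cache: &mut HashMap<String, String>, key: &str, ptx: &str) -> bool {
    match cache.entry(key.to_string()) {
        // Key already present: the existing module wins, the new PTX is discarded.
        Entry::Occupied(_) => false,
        Entry::Vacant(slot) => {
            slot.insert(ptx.to_string());
            true
        }
    }
}

/// Reproduces the bug: each kernel's real key is ignored and every call
/// uses the same hardcoded key.
fn buggy_pre_warm(cache: &mut HashMap<String, String>, kernels: &[(&str, &str)]) -> usize {
    let mut stored = 0;
    for (_real_key, ptx) in kernels {
        if get_or_compile(cache, "silu_forward", ptx) {
            stored += 1;
        }
    }
    stored
}
```

However many kernels are passed to `buggy_pre_warm`, the map ends up with a single entry under `silu_forward`, holding the first kernel's PTX.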

Fix

 macro_rules! warm {
     ($key:expr, $kernel:expr) => {{
+        let key = $key;
         let ptx = $kernel.emit_ptx_for_target(&target);
-        self.get_or_compile("silu_forward", &ptx)?;
+        self.get_or_compile(&key, &ptx)?;
         count += 1;
     }};
 }

A one-argument behavioral change: the hardcoded "silu_forward" literal becomes &key (plus a let binding so $key is evaluated once). The previous PMAT-698g/h fixes were defense-in-depth on the backward cache (whose warm! macro was correct); this fixes the dominant forward-cache failure that all 5 prior iterations were chasing.
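To see the corrected macro in isolation, here is a minimal sketch with a mock cache. `ModuleCache` and `pre_warm` are hypothetical stand-ins (the real code uses `self`, a compile target, and kernel objects that emit PTX); only the key-passing pattern mirrors the diff.

```rust
use std::collections::HashMap;

/// Hypothetical mock of the module cache: maps cache key -> PTX text.
struct ModuleCache {
    modules: HashMap<String, String>,
}

impl ModuleCache {
    fn new() -> Self {
        Self { modules: HashMap::new() }
    }

    /// Returns the cached entry for `key`, inserting `ptx` only if the key is vacant.
    fn get_or_compile(&mut self, key: &str, ptx: &str) -> Result<&String, String> {
        Ok(self.modules.entry(key.to_string()).or_insert_with(|| ptx.to_string()))
    }
}

/// Pre-warms every (key, ptx) pair and returns how many warm! calls ran.
fn pre_warm(cache: &mut ModuleCache, kernels: &[(&str, &str)]) -> Result<usize, String> {
    let mut count = 0usize;
    macro_rules! warm {
        ($key:expr, $ptx:expr) => {{
            let key = $key; // evaluate the key expression once
            cache.get_or_compile(&key, $ptx)?; // store under the caller's key
            count += 1;
        }};
    }
    for (key, ptx) in kernels {
        warm!(key.to_string(), ptx);
    }
    Ok(count)
}
```

With the key passed through, the number of cache entries matches the number of distinct keys warmed, which is the property the hardcoded literal broke.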

Cascade summary

| # | PR | Status |
|---|----|--------|
| 1 | #1804 PMAT-700-B | ✅ legit independent (JIT cache pressure on cuBLAS-active hosts) |
| 2 | #1808 PMAT-698e | ✅ legit independent (workspace sizing) |
| 3 | #1809 PMAT-698f | ✅ legit independent (APR format compat) |
| 4 | #1810 PMAT-698g | ✅ legit defense-in-depth (backward cache gap) |
| 5 | #1813 PMAT-698h | ✅ legit defense-in-depth (rms_norm_gamma_reduce gap) |
| 6 | #1815 PMAT-698i | ✅ diagnostic (surfaced THE bug) |
| 7 | THIS PMAT-698j | 🎯 THE root cause |

Test plan

  • cargo check --features cuda clean
  • 366 autograd lib tests pass
  • Live gx10 dispatch: zero [FWD-CACHE] Compiling events post-pre-warm
  • Forward pass + backward pass + optimizer step actually run
  • Smoke completes 50 steps with final_loss < initial_loss (F-DISTILL-SMOKE-001)
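The "zero [FWD-CACHE] Compiling events" check could be automated by scanning the captured run log. A minimal sketch, assuming the caller passes only the portion of the log written after pre-warm finished and that each compile event is a line containing the literal text `[FWD-CACHE] Compiling` (the rest of the log format is unknown here, and the function name is hypothetical):

```rust
/// Counts log lines that report a forward-cache JIT compile.
/// Assumes `log` holds only output produced after pre-warm completed.
fn count_fwd_cache_compiles(log: &str) -> usize {
    log.lines()
        .filter(|line| line.contains("[FWD-CACHE] Compiling"))
        .count()
}
```

A dispatch script could fail the run whenever this count is nonzero, which turns the manual log inspection into a regression guard.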

🤖 Generated with Claude Code

…" key (PMAT-698j)

THE root-cause bug behind the entire Phase 3 cuda dispatch cascade
(PMAT-700-B plus PMAT-698e..i, 6 prior PRs). Discovered by PMAT-698i's [FWD-CACHE]
diagnostic logging.

The `warm!` macro in pre_warm_for_model:

  macro_rules! warm {
      ($key:expr, $kernel:expr) => {{
          let ptx = $kernel.emit_ptx_for_target(&target);
          self.get_or_compile("silu_forward", &ptx)?;  // <-- HARDCODED
          count += 1;
      }};
  }

Every single `warm!()` call stored its compiled module under the
hashmap key "silu_forward", colliding on the first call:

  1. warm!("batched_rmsnorm_fwd_896", BatchedVectorizedRmsNormKernel...)
     → cache["silu_forward"] = BatchedVectorizedRmsNorm PTX
  2. warm!("gemm_forward_...", ...)
     → cache["silu_forward"] already Occupied → returns existing entry,
       new PTX silently discarded
  3-23. same — all subsequent kernels never actually pre-warm.

At runtime, every kernel looks up its real cache key:

  let key = format!("batched_rmsnorm_fwd_{hidden_size}_eps{eps_bits:08x}");
  match cache.get_cached(&key) { Some(m) => m, None => JIT }

— and cache-MISSES because the cache contains exactly one entry
under "silu_forward". JIT fires for every "pre-warmed" kernel during
the first forward pass — exactly when Blackwell sm_121's CUDA driver
crashes on cuModuleLoadData during active GPU work.

PMAT-698i's [FWD-CACHE] logging surfaced this: every kernel that was
"supposed to be pre-warmed" emitted [FWD-CACHE] Compiling at runtime,
proving the cache had nothing in it under those keys.

Fix: pass $key through to get_or_compile. One-argument change
("silu_forward" → &key).

This explains the entire PMAT-698e..i cascade:
- PMAT-698e (workspace cap) — legit independent bug
- PMAT-698f (APR magic) — legit independent bug
- PMAT-698g (non-LoRA backward pre-warm) — would have been fine IF
  forward pre-warm worked; the backward kernels were correctly stored
  under their real keys (backward macro doesn't have the typo).
  Defense-in-depth, still valuable.
- PMAT-698h (rms_norm_gamma_reduce) — same defense-in-depth.
- PMAT-698j (THIS) — the root cause.

The previous PMAT-698g/h fixes are still correct (they covered backward
gaps that exist independently). This PR addresses the forward cache,
which was the dominant source of post-pre-warm JIT events.

Test plan:
- [x] cargo check --features cuda — clean build
- [x] 366 autograd lib tests pass
- [ ] Live gx10 dispatch (post-merge) shows ZERO [FWD-CACHE] Compiling
      events post-pre-warm (all 23 forward kernels now actually cached)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift noahgift enabled auto-merge (squash) May 19, 2026 16:37
@noahgift noahgift merged commit f8e403d into main May 19, 2026
11 checks passed
@noahgift noahgift deleted the fix/distill-warm-macro-cache-key-pmat-698j branch May 19, 2026 17:09
noahgift added a commit that referenced this pull request May 20, 2026
2026-05-20 — real distillation 1.5B teacher → 0.5B student on
Blackwell GB10 with the full PMAT-698e..n + PMAT-700-B cascade active.

  initial_loss = 7.6746
  final_loss   = 7.2036   ← LESS THAN initial
  62 steps, 122.7s, no errors

F-DISTILL-SMOKE-001 ("final_loss < initial_loss") discharged.

Phase 3 of SPEC-DISTILL-001 is COMPLETE.

Evidence:
- evidence/distill-phase-3-real-kd/dispatch.json — dispatch manifest
- evidence/distill-phase-3-real-kd/launch-final-pass.txt — full training log

Run dir on gx10: /home/noah/runs/distill-smoke-20260520-070404/
Trained student checkpoint: student-trained.apr/model.safetensors

Cascade summary (all merged):
- #1804 PMAT-700-B  (cuBLAS prewarm skip)
- #1808 PMAT-698e   (workspace cap)
- #1809 PMAT-698f   (APR magic in weights loader)
- #1810 PMAT-698g   (non-LoRA backward pre-warm)
- #1813 PMAT-698h   (rms_norm_gamma_reduce pre-warm)
- #1815 PMAT-698i   (FWD-CACHE diagnostic logging)
- #1817 PMAT-698j   (THE root cause — warm! macro key)
- #1820 PMAT-698k   (cache-key alignment: rope fwd + rmsnorm eps)
- #1823 PMAT-698m   (smoke setup: non-degenerate batch)
- #1824             (post-mortem doc)
- #1827 PMAT-698n   (rmsnorm pre-warm at both 1e-6 + 1e-5 eps)

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>