Skip to content

fix(distill): pre-warm non-LoRA backward kernels (PMAT-698g)#1810

Merged
noahgift merged 2 commits into
mainfrom
fix/distill-non-lora-backward-prewarm-pmat-698g
May 19, 2026
Merged

fix(distill): pre-warm non-LoRA backward kernels (PMAT-698g)#1810
noahgift merged 2 commits into
mainfrom
fix/distill-non-lora-backward-prewarm-pmat-698g

Conversation

@noahgift
Copy link
Copy Markdown
Contributor

Summary

Phase 3 dispatch v8 on gx10 reached the training loop and the first backward step started JIT-compiling backward kernels on demand mid-training, then failed with:

forward_backward_with_grad returned None (CUDA stream poisoned or gradient shape mismatch)

This is the documented Blackwell sm_121 bug (trueno#200, CLAUDE.md "Backward kernels: Crash because they compile on-demand when GPU is already active").

Root cause

pre_warm_lora_backward_kernels had a short-circuit:

if lora_rank == 0 {
    eprintln!("[BWD-PREWARM] Skipping (lora_rank=0)");
    return Ok(());
}

The function name implies LoRA-only, but its BODY contains pre-warms for silu_backward, batched_softmax_backward, and batched_rms_norm_backward — which distillation training also needs. The early-out left those to JIT during training, exactly when Blackwell can't.

Fix

Restructure — only LoRA-specific GEMM-backward warm-ups are gated on lora_rank > 0. Activation/norm/standard-FP32-GEMM backward kernels always pre-warm.

Test plan

  • cargo check --features cuda clean build
  • 18 cuda_backward lib tests pass
  • Live gx10 dispatch reaches stepping (post-merge)

Defect cascade so far

Stage 4 of the Phase 3 cuda dispatch cascade:

Stage PR What
1 #1804 PMAT-700-B Skip PTX GEMM pre-warm when cuBLAS active (JIT cache pressure)
2 #1808 PMAT-698e Cap max_seq_len at 2048 (workspace sizing)
3 #1809 PMAT-698f Accept APR magic in load_safetensors_weights (format compat)
4 THIS PMAT-698g Pre-warm non-LoRA backward kernels (stream poisoning)

Each surfaced the next defect once unblocked.

🤖 Generated with Claude Code

Phase 3 dispatch v8 on gx10 reached the training loop and the first
backward step began JIT-compiling silu_backward / batched_rms_norm_backward
/ rms_norm_gamma_reduce ON DEMAND, then failed with:

  forward_backward_with_grad returned None (CUDA stream poisoned or
  gradient shape mismatch)

This is the documented Blackwell sm_121 JIT-during-active-GPU-work bug
(trueno#200, CLAUDE.md "Backward kernels: Crash because they compile
on-demand when GPU is already active").

Cause: `pre_warm_lora_backward_kernels` short-circuited the entire
function at `lora_rank == 0`, leaving the activation/norm backward
kernels to JIT on demand mid-training. The function name implies
LoRA-only, but it actually pre-warmed shared non-LoRA kernels
(silu_backward, batched_softmax_backward, batched_rms_norm_backward)
that distillation training also needs.

Fix: restructure — only the LoRA-specific gemm_backward warm-ups are
gated on lora_rank > 0. The activation/norm/standard-FP32-GEMM backward
kernels always pre-warm, regardless of LoRA mode. Distillation training
(lora_rank == 0) now gets the full backward kernel cache before block
upload, eliminating on-demand JIT and the resulting stream poisoning.

Test plan:
- [x] cargo check --features cuda — clean build
- [x] 18 cuda_backward lib tests pass
- [ ] Live gx10 dispatch reaches stepping (post-merge verification)

Stage 4 in the Phase 3 cuda dispatch defect cascade:
  PMAT-700-B → PMAT-698e → PMAT-698f → PMAT-698g
Each surfaced the next defect on the gx10 path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift noahgift enabled auto-merge (squash) May 19, 2026 11:06
@noahgift noahgift merged commit 804661c into main May 19, 2026
10 checks passed
@noahgift noahgift deleted the fix/distill-non-lora-backward-prewarm-pmat-698g branch May 19, 2026 12:01
noahgift added a commit that referenced this pull request May 20, 2026
2026-05-20 — real distillation 1.5B teacher → 0.5B student on
Blackwell GB10 with the full PMAT-698e..n + PMAT-700-B cascade active.

  initial_loss = 7.6746
  final_loss   = 7.2036   ← LESS THAN initial
  62 steps, 122.7s, no errors

F-DISTILL-SMOKE-001 ("final_loss < initial_loss") discharged.

Phase 3 of SPEC-DISTILL-001 is COMPLETE.

Evidence:
- evidence/distill-phase-3-real-kd/dispatch.json — dispatch manifest
- evidence/distill-phase-3-real-kd/launch-final-pass.txt — full training log

Run dir on gx10: /home/noah/runs/distill-smoke-20260520-070404/
Trained student checkpoint: student-trained.apr/model.safetensors

Cascade summary (all merged):
- #1804 PMAT-700-B  (cuBLAS prewarm skip)
- #1808 PMAT-698e   (workspace cap)
- #1809 PMAT-698f   (APR magic in weights loader)
- #1810 PMAT-698g   (non-LoRA backward pre-warm)
- #1813 PMAT-698h   (rms_norm_gamma_reduce pre-warm)
- #1815 PMAT-698i   (FWD-CACHE diagnostic logging)
- #1817 PMAT-698j   (THE root cause — warm! macro key)
- #1820 PMAT-698k   (cache-key alignment: rope fwd + rmsnorm eps)
- #1823 PMAT-698m   (smoke setup: non-degenerate batch)
- #1824             (post-mortem doc)
- #1827 PMAT-698n   (rmsnorm pre-warm at both 1e-6 + 1e-5 eps)

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant