Skip to content

feat(blackwell): skip PTX GEMM pre-warm when cuBLAS is active (PMAT-700-B)#1804

Merged
noahgift merged 4 commits into
mainfrom
fix/blackwell-pre-warm-skip-cublas-gemm-pmat-700b
May 19, 2026
Merged

feat(blackwell): skip PTX GEMM pre-warm when cuBLAS is active (PMAT-700-B)#1804
noahgift merged 4 commits into
mainfrom
fix/blackwell-pre-warm-skip-cublas-gemm-pmat-700b

Conversation

@noahgift
Copy link
Copy Markdown
Contributor

Summary

Implements Fix #2 from SPEC-BLACKWELL-FIX-001 (PR #1803): when cuBLAS is bound at ForwardKernelCache::new() time, the runtime always takes the cuBLAS fast path for the standard 2D GEMMs. Pre-warming the PTX equivalents is wasted VRAM and is the root cause of the GB10 OOM at "Block 0 upload".

Change

Single 30-line edit to crates/aprender-train/src/autograd/cuda_forward/cache.rs::pre_warm_for_model. When self.cublas.is_some(), skip pre-warm steps 2-5 (Q/O, K/V, gate/up, down projections — 4 PTX modules) and emit a confirmation log line.

Why this is the right fix

  • Runtime path was ALREADY cuBLAS-first (per gemm.rs:47-49 ALB-075 dispatch and cuda_block.rs:2895 ALB-076 comment). The PTX pre-warm was for a fallback path that no host with cuBLAS ever takes.
  • CLAUDE.md notes "GEMM Parity: VERIFIED — 4 of 5 parity tests pass" — cuBLAS numerical agreement with PTX kernels is established.
  • On Blackwell sm_121, the JIT-cache footprint of 4-5 PTX GEMM modules is ~200-400 MB. Removing them gives back exactly the headroom that block upload needs.

Behavior matrix

Host cuBLAS Pre-warm modules VRAM change Runtime path
RTX 4090 / A100 / H100 ✅ available 4 fewer (saves ~200-400 MB) -200-400 MB cache cuBLAS (unchanged)
GB10 (Blackwell) ✅ available 4 fewer unblocks "Block 0 upload" cuBLAS (unchanged)
Older GPUs (no cuBLAS) ❌ unavailable unchanged (all 4 pre-warmed) no change PTX fallback (unchanged)

Test plan

  • cargo check --features cuda -p aprender-train clean (no new warnings)
  • 285 autograd tests pass
  • 37 cuda_forward tests pass
  • Live STEPS=50 ./scripts/dispatch-distill-phase-3-gx10.sh on gx10 reaches the training loop (pending — fires after merge)
  • RTX 4090 dispatch shows the new "Skipping PTX pre-warm" log line and produces identical forward results

Risk

Low. Drop-in optimization. No runtime path change. cuBLAS-unavailable hosts see no behavior change.

What's NOT in this PR

Per SPEC-BLACKWELL-FIX-001 sequencing:

References

🤖 Generated with Claude Code

…00-B Fix #2)

Implements Fix #2 from SPEC-BLACKWELL-FIX-001: when cuBLAS is bound at
ForwardKernelCache::new() time, the runtime always takes the cuBLAS fast
path for the standard 2D GEMMs (Q/K/V/O/gate/up/down projections — see
ALB-075 dispatch in gemm.rs:47-49 and cuda_block.rs:2895). Pre-warming
the PTX equivalents is wasted VRAM.

On sm_121 (Blackwell GB10), the resulting JIT-cache footprint pushes the
subsequent transformer block upload over the budget and triggers
CUDA_ERROR_OUT_OF_MEMORY at "Block 0 upload". Skipping these four
pre-warms when cuBLAS is bound saves 4-5 PTX modules from the cache
(more on models with kv_dim != hidden_size, which adds steps 3 + 5).

Behavior:
- cuBLAS available (RTX 4090, GB10, A100, H100 with libcublas):
  warn line "[CUDA] Skipping PTX pre-warm for 4 GEMM kernels (cuBLAS active
  — PMAT-700)" prints; 4 PTX modules are NOT compiled, saving ~200-400 MB
  VRAM. Runtime forward path is unchanged (cuBLAS was always the active
  path; PTX was dead-cache).
- cuBLAS unavailable (driver missing / older GPUs / CPU fallback path):
  identical behavior to today — all 4 GEMM kernels are pre-warmed for the
  PTX fallback path.

Falsifier F-BLACKWELL-CUBLAS-PREWARM-001 (proposed): assert post-pre-warm
cache module count decreases by 4 when cuBLAS is bound, and runtime
forward outputs are bit-identical to current main.

Existing tests verified: 285 autograd tests + 37 cuda_forward tests
pass. cuBLAS GEMM parity already verified (CLAUDE.md "GEMM Parity:
VERIFIED — 4/5 parity tests pass") so the runtime numerical path is
unchanged by this fix.

This is the highest-EV / lowest-risk single change from SPEC-BLACKWELL-
FIX-001. Likely sufficient on its own to unblock gx10 dispatch.
Subsequent fixes (PTX precompilation #1, wgpu fallback #3) follow.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift noahgift enabled auto-merge (squash) May 19, 2026 06:02
@noahgift noahgift merged commit 8026f0a into main May 19, 2026
10 checks passed
@noahgift noahgift deleted the fix/blackwell-pre-warm-skip-cublas-gemm-pmat-700b branch May 19, 2026 08:04
noahgift added a commit that referenced this pull request May 20, 2026
2026-05-20 — real distillation 1.5B teacher → 0.5B student on
Blackwell GB10 with the full PMAT-698e..n + PMAT-700-B cascade active.

  initial_loss = 7.6746
  final_loss   = 7.2036   ← LESS THAN initial
  62 steps, 122.7s, no errors

F-DISTILL-SMOKE-001 ("final_loss < initial_loss") discharged.

Phase 3 of SPEC-DISTILL-001 is COMPLETE.

Evidence:
- evidence/distill-phase-3-real-kd/dispatch.json — dispatch manifest
- evidence/distill-phase-3-real-kd/launch-final-pass.txt — full training log

Run dir on gx10: /home/noah/runs/distill-smoke-20260520-070404/
Trained student checkpoint: student-trained.apr/model.safetensors

Cascade summary (all merged):
- #1804 PMAT-700-B  (cuBLAS prewarm skip)
- #1808 PMAT-698e   (workspace cap)
- #1809 PMAT-698f   (APR magic in weights loader)
- #1810 PMAT-698g   (non-LoRA backward pre-warm)
- #1813 PMAT-698h   (rms_norm_gamma_reduce pre-warm)
- #1815 PMAT-698i   (FWD-CACHE diagnostic logging)
- #1817 PMAT-698j   (THE root cause — warm! macro key)
- #1820 PMAT-698k   (cache-key alignment: rope fwd + rmsnorm eps)
- #1823 PMAT-698m   (smoke setup: non-degenerate batch)
- #1824             (post-mortem doc)
- #1827 PMAT-698n   (rmsnorm pre-warm at both 1e-6 + 1e-5 eps)

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

1 participant