Skip to content

fix(distill): cap max_position_embeddings at 2048 for cuda backend (PMAT-698e)#1808

Merged
noahgift merged 2 commits into
mainfrom
fix/distill-max-seq-len-cap-pmat-698e
May 19, 2026
Merged

fix(distill): cap max_position_embeddings at 2048 for cuda backend (PMAT-698e)#1808
noahgift merged 2 commits into
mainfrom
fix/distill-max-seq-len-cap-pmat-698e

Conversation

@noahgift
Copy link
Copy Markdown
Contributor

Summary

Phase 3 dispatch on gx10 (NVIDIA GB10) still OOM'd at "Block 0 upload" after Fix #2 (PMAT-700-B). Real root cause is a workspace-sizing bug, not JIT cache pressure:

CudaTransformerTrainer::for_inference propagates max_position_embeddings verbatim into max_seq_len, which sizes every per-block scratch buffer — including the attention scores tensor at num_heads * max_seq_len² * 4 bytes.

The math

For Qwen2.5-Coder-0.5B (native max_position_embeddings=32768, 14 heads):

  • per-block attn_scores = 14 × 32768² × 4 = 60 GB
  • 24-layer total = 1.4 TB

GB10 has a 128 GB unified pool. This overflows by an order of magnitude.

Fix

Cap max_position_embeddings at 2048 in run_cuda_backend's TransformerConfig construction. Override via APR_DISTILL_MAX_SEQ_LEN env var.

At the cap:

  • per-block attn_scores = 14 × 2048² × 4 = 235 MB
  • 24-layer total = 5.6 GB

Test plan

  • cargo check --features inference,training,cuda clean build
  • 21 distill lib tests pass (FALSIFY-APR-DISTILL-TRAIN-001/002 unchanged)
  • Live gx10 dispatch reaches training loop (verified post-merge)

Relationship to SPEC-BLACKWELL-FIX-001

PMAT-700-B Fix #2 (skip PTX GEMM pre-warm when cuBLAS active) was CORRECT for the JIT cache pressure it addressed — confirmed via "Skipping PTX pre-warm for 4 GEMM kernels (cuBLAS active — PMAT-700)" log line in v7 dispatch. But the Phase 3 OOM was a separate workspace-sizing bug surfaced ONLY after the JIT cache fix unblocked the pre-warm phase.

Both fixes apply to different aspects of GB10 dispatch readiness. This PR is independent of the SPEC-BLACKWELL-FIX-001 sequencing.

🤖 Generated with Claude Code

…MAT-698e)

Phase 3 dispatch on gx10 (NVIDIA GB10) still OOM'd at "Block 0 upload"
after Fix #2 (PMAT-700-B, skip PTX GEMM pre-warm). The real OOM cause
is unrelated to the JIT cache: CudaTransformerTrainer::for_inference
propagates model_config.max_position_embeddings verbatim into max_seq_len,
which sizes EVERY per-block scratch buffer — including the attention
scores tensor at num_heads * max_seq_len² * 4 bytes.

For Qwen2.5-Coder-0.5B (native max_position_embeddings=32768, 14 heads):
  per-block attn_scores = 14 * 32768² * 4 = 60 GB
  24-layer total = 1.4 TB
which overflows the 128 GB GB10 unified pool an order of magnitude over.

Fix: cap max_position_embeddings at 2048 in run_cuda_backend's
TransformerConfig construction. Override via APR_DISTILL_MAX_SEQ_LEN env
var. Workspace footprint at the cap:
  per-block attn_scores = 14 * 2048² * 4 = 235 MB
  24-layer total = 5.6 GB
which fits comfortably in any GPU memory budget.

Test plan:
- [x] cargo check --features inference,training,cuda — clean build
- [x] 21 distill lib tests pass (FALSIFY-APR-DISTILL-TRAIN-001/002 unchanged)
- [ ] Live gx10 dispatch (post-merge) reaches the training loop without OOM

Note: this is NOT a SPEC-BLACKWELL-FIX-001 fix path. PMAT-700-B (Fix #2)
was correct for what it addressed — JIT cache pressure for sm_121 — but
the Phase 3 OOM was a separate workspace-sizing bug, not JIT cache. Both
fixes apply to different aspects of GB10 dispatch readiness.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift noahgift enabled auto-merge (squash) May 19, 2026 08:17
@noahgift noahgift merged commit 9d9847e into main May 19, 2026
10 checks passed
@noahgift noahgift deleted the fix/distill-max-seq-len-cap-pmat-698e branch May 19, 2026 09:55
noahgift added a commit that referenced this pull request May 20, 2026
2026-05-20 — real distillation 1.5B teacher → 0.5B student on
Blackwell GB10 with the full PMAT-698e..n + PMAT-700-B cascade active.

  initial_loss = 7.6746
  final_loss   = 7.2036   ← LESS THAN initial
  62 steps, 122.7s, no errors

F-DISTILL-SMOKE-001 ("final_loss < initial_loss") discharged.

Phase 3 of SPEC-DISTILL-001 is COMPLETE.

Evidence:
- evidence/distill-phase-3-real-kd/dispatch.json — dispatch manifest
- evidence/distill-phase-3-real-kd/launch-final-pass.txt — full training log

Run dir on gx10: /home/noah/runs/distill-smoke-20260520-070404/
Trained student checkpoint: student-trained.apr/model.safetensors

Cascade summary (all merged):
- #1804 PMAT-700-B  (cuBLAS prewarm skip)
- #1808 PMAT-698e   (workspace cap)
- #1809 PMAT-698f   (APR magic in weights loader)
- #1810 PMAT-698g   (non-LoRA backward pre-warm)
- #1813 PMAT-698h   (rms_norm_gamma_reduce pre-warm)
- #1815 PMAT-698i   (FWD-CACHE diagnostic logging)
- #1817 PMAT-698j   (THE root cause — warm! macro key)
- #1820 PMAT-698k   (cache-key alignment: rope fwd + rmsnorm eps)
- #1823 PMAT-698m   (smoke setup: non-degenerate batch)
- #1824             (post-mortem doc)
- #1827 PMAT-698n   (rmsnorm pre-warm at both 1e-6 + 1e-5 eps)

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature Request: Cross-Validation Utilities

1 participant