From e452b0376ab65754b545fc70d8ffec750851eac9 Mon Sep 17 00:00:00 2001 From: Noah Gift Date: Tue, 19 May 2026 10:15:43 +0200 Subject: [PATCH] fix(distill): cap max_position_embeddings at 2048 for cuda backend (PMAT-698e) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 3 dispatch on gx10 (NVIDIA GB10) still OOM'd at "Block 0 upload" after Fix #2 (PMAT-700-B, skip PTX GEMM pre-warm). The real OOM cause is unrelated to the JIT cache: CudaTransformerTrainer::for_inference propagates model_config.max_position_embeddings verbatim into max_seq_len, which sizes EVERY per-block scratch buffer — including the attention scores tensor at num_heads * max_seq_len² * 4 bytes. For Qwen2.5-Coder-0.5B (native max_position_embeddings=32768, 14 heads): per-block attn_scores = 14 * 32768² * 4 = 60 GB 24-layer total = 1.4 TB which exceeds the 128 GB GB10 unified pool by an order of magnitude. Fix: cap max_position_embeddings at 2048 in run_cuda_backend's TransformerConfig construction. Override via APR_DISTILL_MAX_SEQ_LEN env var. Workspace footprint at the cap: per-block attn_scores = 14 * 2048² * 4 = 235 MB 24-layer total = 5.6 GB which fits comfortably in any GPU memory budget. Test plan: - [x] cargo check --features inference,training,cuda — clean build - [x] 21 distill lib tests pass (FALSIFY-APR-DISTILL-TRAIN-001/002 unchanged) - [ ] Live gx10 dispatch (post-merge) reaches the training loop without OOM Note: this is NOT a SPEC-BLACKWELL-FIX-001 fix path. PMAT-700-B (Fix #2) was correct for what it addressed — JIT cache pressure for sm_121 — but the Phase 3 OOM was a separate workspace-sizing bug, not JIT cache. Both fixes apply to different aspects of GB10 dispatch readiness. 
Co-Authored-By: Claude Opus 4.7 --- crates/apr-cli/src/commands/distill.rs | 47 ++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/crates/apr-cli/src/commands/distill.rs b/crates/apr-cli/src/commands/distill.rs index 8d73c02b2..b2240fcfb 100644 --- a/crates/apr-cli/src/commands/distill.rs +++ b/crates/apr-cli/src/commands/distill.rs @@ -606,6 +606,35 @@ pub(crate) fn run( /// The dispatch script (`scripts/dispatch-distill-phase-3-gx10.sh`) /// scales by step count rather than batch parallelism. Phase 2e /// generalizes via a fused-step trait method. +/// PMAT-698e: cap max_position_embeddings before passing into the cuda +/// trainer. See call sites in run_cuda_backend for the full rationale — +/// CudaTransformerTrainer::for_inference uses this value as max_seq_len +/// for ALL per-block scratch buffers, and the attention scores tensor +/// is sized at num_heads * max_seq_len² * 4 bytes, which overflows +/// the GPU memory budget for max_position_embeddings >= ~8192 even on +/// the 128GB GB10 unified pool. +/// +/// Default cap: 2048 (gives ~5.6 GB workspace for Qwen2.5-Coder-0.5B). +/// Override via APR_DISTILL_MAX_SEQ_LEN env var (e.g. set to 4096 for +/// longer-context distillation runs that still fit). 
+#[cfg(all(feature = "training", feature = "cuda"))] +fn cap_max_seq_len(native: Option<usize>) -> Option<usize> { + const DEFAULT_CAP: usize = 2048; + let cap = std::env::var("APR_DISTILL_MAX_SEQ_LEN") + .ok() + .and_then(|v| v.parse::<usize>().ok()) + .unwrap_or(DEFAULT_CAP); + let n = native?; + let chosen = n.min(cap); + if chosen < n { + eprintln!( + "[PMAT-698e] capping max_position_embeddings {n} → {chosen} \ + (override via APR_DISTILL_MAX_SEQ_LEN)" + ); + } + Some(chosen) +} + #[cfg(all(feature = "training", feature = "cuda"))] #[allow(clippy::too_many_arguments)] fn run_cuda_backend( @@ -672,7 +701,20 @@ fn run_cuda_backend( teacher_meta.intermediate_size, teacher_meta.num_layers, teacher_meta.vocab_size, - teacher_meta.max_position_embeddings, + // PMAT-698e: cap max_position_embeddings before passing into + // CudaTransformerTrainer::for_inference. The trainer uses this + // value verbatim as max_seq_len for ALL per-block scratch buffers, + // including the attention scores tensor sized at + // num_heads * max_seq_len² * 4 bytes + // For Qwen2.5-Coder-0.5B (native max_position_embeddings=32768, 14 + // heads), that's 14 * 32768² * 4 = 60 GB PER BLOCK; on a 24-layer + // model the total scratch footprint is ~1.4 TB, which overflows + // even GB10's 128 GB unified pool and surfaces as + // CUDA_ERROR_OUT_OF_MEMORY at "Block 0 upload". Distillation + // training rarely sees sequences longer than 2048 tokens; capping + // here gives ~5.6 GB workspace for the smoke test. Caller can override + // via APR_DISTILL_MAX_SEQ_LEN env var. + cap_max_seq_len(teacher_meta.max_position_embeddings), teacher_meta.rms_norm_eps, teacher_meta.rope_theta, teacher_meta.architecture.as_deref(), @@ -705,7 +747,8 @@ fn run_cuda_backend( student_meta.intermediate_size, student_meta.num_layers, student_meta.vocab_size, - student_meta.max_position_embeddings, + // PMAT-698e (see teacher comment above): same cap rationale. 
+ cap_max_seq_len(student_meta.max_position_embeddings), student_meta.rms_norm_eps, student_meta.rope_theta, student_meta.architecture.as_deref(),