Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions crates/apr-cli/src/commands/distill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,35 @@ pub(crate) fn run(
/// The dispatch script (`scripts/dispatch-distill-phase-3-gx10.sh`)
/// scales by step count rather than batch parallelism. Phase 2e
/// generalizes via a fused-step trait method.
///
/// PMAT-698e: cap `max_position_embeddings` before passing it into the CUDA
/// trainer. See the call sites in `run_cuda_backend` for the full rationale:
/// `CudaTransformerTrainer::for_inference` uses this value as `max_seq_len`
/// for ALL per-block scratch buffers, and the attention scores tensor
/// is sized at `num_heads * max_seq_len² * 4` bytes, which overflows
/// the GPU memory budget for `max_position_embeddings >= ~8192` even on
/// the 128GB GB10 unified pool.
///
/// Default cap: 2048 (gives ~5.6 GB workspace for Qwen2.5-Coder-0.5B).
/// Override via the `APR_DISTILL_MAX_SEQ_LEN` env var (e.g. set to 4096 for
/// longer-context distillation runs that still fit).
///
/// The override must be a positive integer; surrounding whitespace is
/// ignored. A value that is zero, non-numeric, or not valid UTF-8 is
/// rejected with a warning on stderr and the default cap is used instead,
/// so a typo can never silently produce a zero-length sequence budget.
///
/// Returns `None` when `native` is `None` (there is nothing to cap);
/// otherwise returns `Some(min(native, cap))`.
#[cfg(all(feature = "training", feature = "cuda"))]
fn cap_max_seq_len(native: Option<usize>) -> Option<usize> {
    const DEFAULT_CAP: usize = 2048;
    const ENV_VAR: &str = "APR_DISTILL_MAX_SEQ_LEN";

    // No native value means no cap to apply; skip the env lookup entirely.
    let n = native?;

    let cap = match std::env::var_os(ENV_VAR) {
        None => DEFAULT_CAP,
        Some(raw) => {
            // `to_str` fails on non-UTF-8 values, `parse` on non-numeric
            // ones, and the filter rejects zero; all three fall through to
            // the warning below.
            let parsed = raw
                .to_str()
                .and_then(|s| s.trim().parse::<usize>().ok())
                .filter(|&v| v > 0);
            parsed.unwrap_or_else(|| {
                eprintln!(
                    "[PMAT-698e] ignoring invalid {ENV_VAR}={raw:?} \
                     (expected a positive integer); using default cap {DEFAULT_CAP}"
                );
                DEFAULT_CAP
            })
        }
    };

    let chosen = n.min(cap);
    if chosen < n {
        eprintln!(
            "[PMAT-698e] capping max_position_embeddings {n} → {chosen} \
             (override via APR_DISTILL_MAX_SEQ_LEN)"
        );
    }
    Some(chosen)
}

#[cfg(all(feature = "training", feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
fn run_cuda_backend(
Expand Down Expand Up @@ -672,7 +701,20 @@ fn run_cuda_backend(
teacher_meta.intermediate_size,
teacher_meta.num_layers,
teacher_meta.vocab_size,
teacher_meta.max_position_embeddings,
// PMAT-698e: cap max_position_embeddings before passing into
// CudaTransformerTrainer::for_inference. The trainer uses this
// value verbatim as max_seq_len for ALL per-block scratch buffers,
// including the attention scores tensor sized at
// num_heads * max_seq_len² * 4 bytes.
// For Qwen2.5-Coder-0.5B (native max_position_embeddings=32768, 14
// heads), that's 14 * 32768² * 4 = 60 GB PER BLOCK; on a 24-layer
// model the total scratch footprint is ~1.4 TB, which overflows
// even GB10's 128 GB unified pool and surfaces as
// CUDA_ERROR_OUT_OF_MEMORY at "Block 0 upload". Distillation
// training rarely sees sequences longer than 2048 tokens; capping
// here gives ~5.6 GB workspace for the smoke. Caller can override
// via APR_DISTILL_MAX_SEQ_LEN env var.
cap_max_seq_len(teacher_meta.max_position_embeddings),
teacher_meta.rms_norm_eps,
teacher_meta.rope_theta,
teacher_meta.architecture.as_deref(),
Expand Down Expand Up @@ -705,7 +747,8 @@ fn run_cuda_backend(
student_meta.intermediate_size,
student_meta.num_layers,
student_meta.vocab_size,
student_meta.max_position_embeddings,
// PMAT-698e (see teacher comment above): same cap rationale.
cap_max_seq_len(student_meta.max_position_embeddings),
student_meta.rms_norm_eps,
student_meta.rope_theta,
student_meta.architecture.as_deref(),
Expand Down
Loading