Merged
20 changes: 18 additions & 2 deletions crates/aprender-train/src/autograd/cuda_forward/cache.rs
@@ -180,11 +180,27 @@ impl ForwardKernelCache {
let mut count = 0u32;
let target = self.sm_target.clone();

// Helper: generate PTX and compile
// Helper: generate PTX and compile.
//
// PMAT-698j: previously hardcoded "silu_forward" as the cache key,
// which meant every warm!() call collided on the same HashMap entry.
// Only the FIRST kernel compiled actually got stored; all subsequent
// warm!() invocations short-circuited because "silu_forward" was
// already occupied. At runtime every other kernel (rmsnorm, rope,
// softmax, swiglu, residual, etc.) cache-missed under its real key
// and JIT-compiled mid-training — on Blackwell sm_121 that
// corrupted the CUDA stream and surfaced as the cascading "Block 0
// upload failed" / "forward_backward_with_grad returned None"
// errors hunted across PMAT-698e..i.
//
// Discovered by PMAT-698i diagnostic logging: [FWD-CACHE] showed
// every "pre-warmed" kernel actually JIT'd at first use because
// the cache only contained one entry. The fix is small: pass the
// macro's key to get_or_compile instead of the "silu_forward" literal.
macro_rules! warm {
($key:expr, $kernel:expr) => {{
let key = $key;
let ptx = $kernel.emit_ptx_for_target(&target);
self.get_or_compile("silu_forward", &ptx)?;
self.get_or_compile(&key, &ptx)?;
count += 1;
}};
}
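
The collision described in the comment above can be reproduced in isolation. The following is a minimal sketch, not the crate's actual implementation: `ToyCache`, `warm_hardcoded`, `warm_per_key`, `runtime_misses`, and the kernel key strings other than "silu_forward" are hypothetical names introduced for illustration, and "compilation" is simulated with a string and a counter rather than real PTX.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the real cache: key -> simulated compiled module.
struct ToyCache {
    modules: HashMap<String, String>,
    compiles: u32,
}

impl ToyCache {
    fn new() -> Self {
        ToyCache { modules: HashMap::new(), compiles: 0 }
    }

    /// Compile only on a cache miss; on a hit the stored entry is returned
    /// and the `ptx` argument is ignored.
    fn get_or_compile(&mut self, key: &str, ptx: &str) -> &str {
        if !self.modules.contains_key(key) {
            self.compiles += 1;
            self.modules.insert(key.to_string(), format!("compiled({})", ptx));
        }
        &self.modules[key]
    }
}

/// Assumed kernel keys, for illustration only.
const KERNELS: [&str; 3] = ["silu_forward", "rmsnorm_forward", "softmax_forward"];

/// Old behaviour: every kernel is warmed under the same literal key,
/// so only the first one is ever stored.
fn warm_hardcoded(cache: &mut ToyCache) {
    for name in KERNELS.iter() {
        let ptx = format!("ptx for {}", name);
        cache.get_or_compile("silu_forward", &ptx);
    }
}

/// Fixed behaviour: each kernel is warmed under its own key.
fn warm_per_key(cache: &mut ToyCache) {
    for name in KERNELS.iter() {
        let ptx = format!("ptx for {}", name);
        cache.get_or_compile(name, &ptx);
    }
}

/// Number of kernels that would still miss when looked up by their real key.
fn runtime_misses(cache: &ToyCache) -> usize {
    KERNELS.iter().filter(|k| !cache.modules.contains_key(**k)).count()
}
```

After `warm_hardcoded`, the toy cache holds a single entry and two of the three kernels would still miss at first use; after `warm_per_key`, all three are present and nothing is left to compile lazily. That mirrors the symptom reported by the [FWD-CACHE] logging, under the assumption that the real `get_or_compile` also skips compilation on a key hit.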