From cffcf5680fdce7cb84bf31d5f3d6162355f6d40e Mon Sep 17 00:00:00 2001
From: Vincent Palmer
Date: Thu, 23 Apr 2026 09:17:00 +0200
Subject: [PATCH] fix: eliminate 3rd forward pass in distillation (hidden
 state MSE)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The distillation loop was running three forward passes per chunk:

1. forward_gpu for logits (removed in PR #73)
2. forward_train for activations (kept)
3. forward_with_hidden_states for the hidden-state MSE loss ← removed
   by this patch

forward_train now collects the per-layer hidden states in
TrainForwardOutput, and the MSE loss reads them directly: one forward
pass per chunk.

Impact: training step time drops by an estimated ~33%. One of the two
remaining forward passes per chunk is eliminated while the backward
pass is unchanged, so the step as a whole shrinks by roughly a third:
~100s/step → ~67s/step.
---
 src/main.rs              | 14 ++++++--------
 src/training/backward.rs |  9 ++++++++-
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index 876f03f..f63e027 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1829,14 +1829,13 @@ async fn cmd_distill(
         );
 
         // Hidden state matching loss: MSE between teacher and student per-layer states
-        let (hidden_mse_loss, _student_states) = if chunk_idx < frozen_states_per_chunk.len() {
+        // Uses per_layer_hidden collected during forward_train (no extra forward pass)
+        let hidden_mse_loss = if chunk_idx < frozen_states_per_chunk.len() && !train_output.per_layer_hidden.is_empty() {
             let teacher_states = &frozen_states_per_chunk[chunk_idx];
-            let student_states = student.forward_with_hidden_states(batch_tokens);
-            let mse = if teacher_states.len() == student_states.len() && !teacher_states.is_empty() {
-                let _hd = config.hidden_dim;
+            let student_states = &train_output.per_layer_hidden;
+            if teacher_states.len() == student_states.len() && !teacher_states.is_empty() {
                 let mut total_mse = 0.0f32;
                 let mut num_layers_compared = 0usize;
-                // Compare at every 5th layer to save compute
                 for layer_idx in (0..teacher_states.len()).step_by(5) {
                     let t_state = &teacher_states[layer_idx];
                     let s_state = &student_states[layer_idx];
@@ -1851,10 +1850,9 @@ async fn cmd_distill(
                 if num_layers_compared > 0 { total_mse / num_layers_compared as f32 } else { 0.0 }
             } else {
                 0.0
-            };
-            (mse, Some(student_states))
+            }
         } else {
-            (0.0, None)
+            0.0
         };
 
         // MoE Load balance loss: num_experts × Σ(f_i × P_i)
diff --git a/src/training/backward.rs b/src/training/backward.rs
index 941480a..b91ebdf 100644
--- a/src/training/backward.rs
+++ b/src/training/backward.rs
@@ -16,6 +16,8 @@ pub struct TrainForwardOutput {
     pub activations: Vec,
     /// Post-final-norm hidden states [seq, hidden_dim] — input to lm_head
     pub final_hidden: Vec<f32>,
+    /// Per-layer post-residual hidden states (for hidden state MSE loss)
+    pub per_layer_hidden: Vec<Vec<f32>>,
 }
 
 impl CpuBlockAttnResModel {
@@ -62,6 +64,8 @@ impl CpuBlockAttnResModel {
         let mut activations = Vec::with_capacity(self.num_layers);
         let lora_m = self.lora_manager.as_ref();
 
+        let mut per_layer_hidden: Vec<Vec<f32>> = Vec::with_capacity(self.num_layers);
+
         for (layer_idx, layer) in self.layers.iter().enumerate() {
             let ple_slice = ple_precomputed.as_ref().map(|pre| {
                 let mut slice = vec![0.0f32; seq * ple_dim];
@@ -218,6 +222,9 @@ impl CpuBlockAttnResModel {
                 expert_activations: layer_expert_act.unwrap_or_default(),
             });
+            // Store per-layer hidden state for MSE loss
+            per_layer_hidden.push(hidden.clone());
+
             for t in 0..seq { for d in 0..hd { partial_sum[d] += hidden[t * hd + d]; } }
             if self.is_block_boundary(layer_idx) {
                 for d in 0..hd { partial_sum[d] /= ((seq) * (self.block_config.layers_per_block)) as f32; }
             }
@@ -236,7 +243,7 @@ impl CpuBlockAttnResModel {
 
             for l in logits.iter_mut() { *l = (*l / cap).tanh() * cap; }
         }
 
-        TrainForwardOutput { logits, routing_data, activations, final_hidden }
+        TrainForwardOutput { logits, routing_data, activations, final_hidden, per_layer_hidden }
     }
 }
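
-- 
Reviewer note (placed after the signature delimiter, so it is not part
of the commit message or the patch): a self-contained sketch of the
hidden-state MSE that per_layer_hidden now feeds, using the same
stride-5 layer sampling as cmd_distill. The free function and the name
hidden_state_mse are hypothetical; only the Vec<Vec<f32>> layout, the
field names, and the stride come from this patch.

    /// MSE between teacher and student hidden states: mean of squared
    /// element-wise differences per layer, then mean over the compared
    /// layers (every 5th, matching cmd_distill).
    fn hidden_state_mse(teacher: &[Vec<f32>], student: &[Vec<f32>]) -> f32 {
        if teacher.len() != student.len() || teacher.is_empty() {
            return 0.0;
        }
        let mut total_mse = 0.0f32;
        let mut num_layers_compared = 0usize;
        for layer_idx in (0..teacher.len()).step_by(5) {
            let (t, s) = (&teacher[layer_idx], &student[layer_idx]);
            // Guard against length mismatches between teacher and student.
            let n = t.len().min(s.len());
            if n == 0 {
                continue;
            }
            let sum_sq: f32 = t[..n]
                .iter()
                .zip(&s[..n])
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            total_mse += sum_sq / n as f32;
            num_layers_compared += 1;
        }
        if num_layers_compared > 0 {
            total_mse / num_layers_compared as f32
        } else {
            0.0
        }
    }

The stride-5 sampling trades loss fidelity for compute. Note that
per_layer_hidden still clones every layer's hidden state during
forward_train, so the residual cost of this change is memory, not time.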