diff --git a/src/main.rs b/src/main.rs index 876f03f..f63e027 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1829,14 +1829,13 @@ async fn cmd_distill( ); // Hidden state matching loss: MSE between teacher and student per-layer states - let (hidden_mse_loss, _student_states) = if chunk_idx < frozen_states_per_chunk.len() { + // Uses per_layer_hidden collected during forward_train (no extra forward pass) + let hidden_mse_loss = if chunk_idx < frozen_states_per_chunk.len() && !train_output.per_layer_hidden.is_empty() { let teacher_states = &frozen_states_per_chunk[chunk_idx]; - let student_states = student.forward_with_hidden_states(batch_tokens); - let mse = if teacher_states.len() == student_states.len() && !teacher_states.is_empty() { - let _hd = config.hidden_dim; + let student_states = &train_output.per_layer_hidden; + if teacher_states.len() == student_states.len() && !teacher_states.is_empty() { let mut total_mse = 0.0f32; let mut num_layers_compared = 0usize; - // Compare at every 5th layer to save compute for layer_idx in (0..teacher_states.len()).step_by(5) { let t_state = &teacher_states[layer_idx]; let s_state = &student_states[layer_idx]; @@ -1851,10 +1850,9 @@ async fn cmd_distill( if num_layers_compared > 0 { total_mse / num_layers_compared as f32 } else { 0.0 } } else { 0.0 - }; - (mse, Some(student_states)) + } } else { - (0.0, None) + 0.0 }; // MoE Load balance loss: num_experts × Σ(f_i × P_i) diff --git a/src/training/backward.rs b/src/training/backward.rs index 941480a..b91ebdf 100644 --- a/src/training/backward.rs +++ b/src/training/backward.rs @@ -16,6 +16,8 @@ pub struct TrainForwardOutput { pub activations: Vec, /// Post-final-norm hidden states [seq, hidden_dim] — input to lm_head pub final_hidden: Vec, + /// Per-layer post-residual hidden states (for hidden state MSE loss) + pub per_layer_hidden: Vec>, } impl CpuBlockAttnResModel { @@ -62,6 +64,8 @@ impl CpuBlockAttnResModel { let mut activations = Vec::with_capacity(self.num_layers); let lora_m = self.lora_manager.as_ref(); + let mut per_layer_hidden: Vec> = Vec::with_capacity(self.num_layers); + for (layer_idx, layer) in self.layers.iter().enumerate() { let ple_slice = ple_precomputed.as_ref().map(|pre| { let mut slice = vec![0.0f32; seq * ple_dim]; @@ -218,6 +222,9 @@ impl CpuBlockAttnResModel { expert_activations: layer_expert_act.unwrap_or_default(), }); + // Store per-layer hidden state for MSE loss + per_layer_hidden.push(hidden.clone()); + for t in 0..seq { for d in 0..hd { partial_sum[d] += hidden[t * hd + d]; } } if self.is_block_boundary(layer_idx) { for d in 0..hd { partial_sum[d] /= ((seq) * (self.block_config.layers_per_block)) as f32; } @@ -236,7 +243,7 @@ impl CpuBlockAttnResModel { for l in logits.iter_mut() { *l = (*l / cap).tanh() * cap; } } - TrainForwardOutput { logits, routing_data, activations, final_hidden } + TrainForwardOutput { logits, routing_data, activations, final_hidden, per_layer_hidden } } }