feat(distill): pipeline integration teacher + student + kd_step end-to-end (SPEC-DISTILL-001 Phase 2c)#1792
Closed
noahgift wants to merge 9 commits into
Closed
feat(distill): pipeline integration teacher + student + kd_step end-to-end (SPEC-DISTILL-001 Phase 2c)#1792noahgift wants to merge 9 commits into
noahgift wants to merge 9 commits into
Conversation
…L-001 Phase 1b, PMAT-693)
Adds the real teacher backend the SPEC-DISTILL-001 Phase 1b ticket
scopes: CudaTrainerTeacher wraps entrenar's CudaTransformerTrainer in
inference-only mode, delegates logits_for_batch to forward_logits()
per batch element, returns shape [batch, vocab_size].
Gated behind a new `cuda` feature on aprender-train-distill that
propagates to entrenar/cuda. Without the feature, only FixtureTeacher
(Phase 1) is available — sufficient for unit tests but not for real
training. Real distillation runs (Phase 4) require --features cuda.
Surface
=======
#[cfg(feature = "cuda")]
pub struct CudaTrainerTeacher { /* wraps CudaTransformerTrainer */ }
impl CudaTrainerTeacher {
pub fn for_inference(
checkpoint_dir: impl AsRef<Path>,
model_config: TransformerConfig,
) -> Result<Self> { ... }
}
impl TeacherLogitsProvider for CudaTrainerTeacher {
fn vocab_size(&self) -> usize { ... }
fn logits_for_batch(&mut self, input_ids: &[Vec<u32>])
-> Result<Vec<Vec<f32>>> { ... }
}
Defensive checks
================
- forward_logits returning None → EntrenarError::Internal with a clear
"likely missing weights or CUDA init failure" message
- logits.len() != vocab_size → EntrenarError::Internal flagging
TransformerConfig vs checkpoint vocab drift (the common silent failure
mode for loaded-from-disk distillation runs)
Tests
=====
All 6 teacher_provider tests pass under both --features (none) and
--features cuda. Compile gates verified:
cargo check -p aprender-train-distill # clean
cargo check -p aprender-train-distill --features cuda # clean
What's next
===========
Phase 2 (PMAT-694, follow-up): wire CudaTransformerTrainer's KD-loss
backward into the student path — replaces the remaining
build_synthetic_logits call site for the student in pipeline.rs::train().
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ent (SPEC-DISTILL-001 Phase 2)
Wires Phase 1's teacher provider into a per-batch KD orchestration step
that produces both the combined α·CE + (1-α)·T²·KL scalar loss (for
telemetry) and the KD-aware logit-space gradient (Phase 2b plug point).
What this PR adds
=================
New module `aprender-train-distill::kd_step`:
pub fn kd_loss(
student_logits: &[f32],
teacher_logits: &[f32],
label: usize,
temperature: f32,
alpha: f32,
) -> f32;
pub fn kd_logit_gradient(
student_logits: &[f32],
teacher_logits: &[f32],
label: usize,
temperature: f32,
alpha: f32,
) -> Vec<f32>;
pub fn kd_step<F: FnMut(&[u32]) -> Vec<f32>>(
teacher: &mut dyn TeacherLogitsProvider,
input_ids: &[Vec<u32>],
labels: &[usize],
temperature: f32,
alpha: f32,
compute_student_logits: F,
) -> Result<(f32, Vec<Vec<f32>>)>;
The gradient is the Hinton et al. 2015 §2 derivation:
∂L/∂s = α · (softmax(s) - one_hot(label))
+ (1-α) · T · (softmax(s/T) - softmax(t/T))
(T factor, not T² — one T factor is absorbed by the softmax derivative
chain rule.)
Scope: Phase 2 vs Phase 2b
==========================
Phase 2 (this PR) ships the orchestration math, all in pure Rust on the
CPU. The output `Vec<Vec<f32>>` of per-batch gradients is what Phase 2b
will plumb into `CudaTransformerTrainer.forward_backward_kd_batch` as
the backward-pass seed (replacing the CE-only gradient currently used by
forward_backward_batch).
Splitting Phase 2 into 2a/2b lets us land the orchestration layer + its
tests now, separate from extending the complex GPU trainer code path.
Falsifiers pinned
=================
3 KD-step falsifiers + 6 sanity tests, all passing:
- F-DISTILL-KDSTEP-001 (alpha=1 → pure CE)
- F-DISTILL-KDSTEP-002 (student==teacher → zero KL gradient under alpha=0)
- F-DISTILL-KDSTEP-003 (loss monotone in student-teacher divergence)
- softmax unit-sum + non-negative
- CE gradient correct sign at label vs non-label positions
- kd_step orchestration end-to-end
- kd_step empty-batch sanity
- kd_step vocab-mismatch error path
- kd_loss alpha=1 collapses to pure CE
All 50 aprender-train-distill lib tests pass (was 41 — 9 new).
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…TILL-001 Phase 2b)
Mirrors Phase 1's TeacherLogitsProvider for the student side. The
student has two methods: logits_for_batch (forward) and
apply_kd_gradient (backward + optimizer step). FixtureStudent
implements both for CPU-only unit testing — Phase 2c will add a
CudaStudentProvider that wraps CudaTransformerTrainer.
What this PR adds
=================
pub trait StudentLogitsProvider {
fn vocab_size(&self) -> usize;
fn logits_for_batch(&mut self, input_ids: &[Vec<u32>])
-> Result<Vec<Vec<f32>>>;
fn apply_kd_gradient(&mut self, gradient: &[Vec<f32>])
-> Result<()>;
}
pub struct FixtureStudent {
vocab_size: usize,
logits: Vec<f32>, // current student parameters
learning_rate: f32,
}
FixtureStudent's apply_kd_gradient averages the gradient across batch
elements (canonical SGD batch averaging) and subtracts the scaled
gradient from its internal logits buffer. This isn't a real model —
it's a logit-space optimization fixture that lets us validate the
KD pipeline's gradient direction is correct without needing CUDA.
Falsifiers pinned
=================
7 student_provider tests + 2 falsifiers, all passing:
- F-DISTILL-STUDENT-001 — one KD step moves student logits toward
teacher's preferred token. Setup: uniform student, teacher prefers
token 5, alpha=0 (pure KL signal). After one step, student logit
at index 5 must be strictly greater than before.
- F-DISTILL-STUDENT-002 — 10 sequential KD steps strictly decrease
per-step KD loss. With LR=0.5, loss after 10 steps < 90% of initial.
Validates the gradient direction is correct (descent, not ascent).
Plus 5 sanity tests covering vocab_size reporting, batch broadcast,
shape validation, in-place logit update, and batch averaging math.
Architecture
============
Stacks on top of #1788 (kd_step). Pipeline integration that uses both
TeacherLogitsProvider + StudentLogitsProvider + kd_step is Phase 2c.
Phase 2c (PMAT-696, follow-up): CudaStudentProvider that wraps
CudaTransformerTrainer for production runs. Once it lands, end-to-end
GPU distillation is unblocked.
Tests
=====
All 57 aprender-train-distill lib tests pass (was 50 — 7 new).
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…o-end (SPEC-DISTILL-001 Phase 2c)
Rewrites Pipeline::train() to use TeacherLogitsProvider +
StudentLogitsProvider + kd_step end-to-end. Replaces the
build_synthetic_logits stubs on both sides with real abstraction calls.
Pipeline gains a `student: Box<dyn StudentLogitsProvider + Send>` field
and a `Pipeline::with_student()` builder mirroring `with_teacher()`.
Default backends are FixtureTeacher + FixtureStudent so legacy tests
behave identically. Phase 2d swaps in CudaStudentProvider.
Training loop, per step:
1. dummy_batch = [[0u32]; batch_size] (Phase 4 plugs in real tokens)
2. teacher.logits_for_batch(dummy_batch) → teacher logits
3. kd_step(teacher, dummy_batch, labels, T, α, student_logits_closure)
→ (scalar loss, per-batch logit gradients)
4. student.apply_kd_gradient(grads) → student updates
bracketed by initial-loss + final-loss measurements via the new
kd_step_loss_for_pipeline helper.
Falsifiers
==========
F-DISTILL-PIPELINE-001 (new) — end-to-end falsifier: runs
Pipeline::execute() with FixtureTeacher + FixtureStudent + 3 epochs
and asserts final_loss < initial_loss. Pins the entire data flow:
teacher → student → kd_step → apply_kd_gradient. Any broken link
either flatlines or increases the loss.
Phase 2d plug points
====================
- The dummy_batch in train() is the natural insertion point for a real
dataset iterator (Phase 4 work).
- The student-logits closure in kd_step is the natural insertion point
for CudaStudentProvider's forward_logits (Phase 2d).
- The apply_kd_gradient call is the natural insertion point for
CudaStudentProvider's forward_backward_kd_batch path (Phase 2d).
Tests
=====
All 58 aprender-train-distill lib tests pass (was 57 — 1 new
F-DISTILL-PIPELINE-001 integration falsifier). The four legacy
helpers (build_synthetic_logits, kd_gradient, softmax_2d,
write_logits_to_weights) are marked #[allow(dead_code)] for back-compat
until Phase 2d's wiring fully replaces the on-disk weights round-trip.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1538987 to
81b5f1a
Compare
This was referenced May 18, 2026
Closed
…ation_with_kd Clippy `-D warnings` failed in CI: `variable does not need to be mutable` at pipeline.rs:187. The new CudaStudentProvider path owns its parameter buffer so the destructured `student_weights` is read-only when only used for shape inspection + later export. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Contributor
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Rewrites
Pipeline::train()to useTeacherLogitsProvider+StudentLogitsProvider+kd_stepend-to-end. Replaces thebuild_synthetic_logitsstubs on both sides with real abstraction calls.Stacks on top of #1791 (Phase 2b StudentLogitsProvider).
Training loop, per step
dummy_batch = [[0u32]; batch_size](Phase 4 plugs in real tokens)teacher.logits_for_batch(dummy_batch)→ teacher logitskd_step(teacher, dummy_batch, labels, T, α, student_logits_closure)→ (scalar loss, per-batch logit gradients)student.apply_kd_gradient(grads)→ student updatesBracketed by initial-loss + final-loss measurements via the new
kd_step_loss_for_pipelinehelper.Pipelinegains astudentfield andwith_student()builder mirroringwith_teacher().Phase 2d plug points
This PR establishes the architecture; Phase 2d (PMAT-696) wires the CUDA backend:
dummy_batchintrain()is the natural insertion point for a real dataset iterator (Phase 4 work).kd_stepis the natural insertion point forCudaStudentProvider'sforward_logits(Phase 2d).apply_kd_gradientcall is the natural insertion point forCudaStudentProvider'sforward_backward_kd_batchpath (Phase 2d).Falsifier
final_loss < initial_lossThis pins the entire data flow: teacher → student → kd_step → apply_kd_gradient. Any broken link either flatlines the loss or increases it.
Tests
All 58 aprender-train-distill lib tests pass (was 57 — 1 new F-DISTILL-PIPELINE-001 integration falsifier).
The four legacy helpers (
build_synthetic_logits,kd_gradient,softmax_2d,write_logits_to_weights) are marked#[allow(dead_code)]for back-compat until Phase 2d's wiring fully replaces the on-disk weights round-trip.Test plan
cargo check -p aprender-train-distillclean🤖 Generated with Claude Code