
feat(distill): teacher logits provider abstraction (SPEC-DISTILL-001 Phase 1, PMAT-691) #1786

Merged
noahgift merged 3 commits into main from feat/distill-phase-1-teacher-cache-pmat-691
May 18, 2026

@noahgift
Contributor

Summary

Implements SPEC-DISTILL-001 Phase 1 — replaces the synthetic-logits stub in aprender-train-distill::pipeline.rs::train() with a real TeacherLogitsProvider trait + a FixtureTeacher implementation suitable for unit tests.

Phase 1b (PMAT-692, follow-up PR) will add a RealizarTeacher backend that runs the real 7B teacher via aprender-serve.

What this PR adds

New module aprender-train-distill::teacher_provider:

pub trait TeacherLogitsProvider {
    fn vocab_size(&self) -> usize;
    fn logits_for_batch(&mut self, input_ids: &[Vec<u32>])
        -> Result<Vec<Vec<f32>>>;
}

pub struct FixtureTeacher { vocab_size: usize }

The FixtureTeacher returns deterministic stub logits whose argmax is the last input token — useful for unit-testing the KD signal flow without needing a real model file. Out-of-vocab tokens collapse to position 0 (robustness).
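The behaviour described above can be sketched as follows. This is a minimal illustration, not the merged code: the `Result` alias, the one-hot logit values, the treatment of an empty sequence as token 0, and the `argmax` helper are all assumptions made for the example.

```rust
// Sketch only. The error type and the exact logit values are assumptions;
// the PR states just that the argmax equals the last input token and that
// out-of-vocab tokens collapse to position 0.
pub type Result<T> = std::result::Result<T, String>;

pub trait TeacherLogitsProvider {
    fn vocab_size(&self) -> usize;
    fn logits_for_batch(&mut self, input_ids: &[Vec<u32>]) -> Result<Vec<Vec<f32>>>;
}

pub struct FixtureTeacher {
    vocab_size: usize,
}

impl FixtureTeacher {
    pub fn new(vocab_size: usize) -> Self {
        Self { vocab_size }
    }
}

impl TeacherLogitsProvider for FixtureTeacher {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn logits_for_batch(&mut self, input_ids: &[Vec<u32>]) -> Result<Vec<Vec<f32>>> {
        let mut out = Vec::with_capacity(input_ids.len());
        for seq in input_ids {
            // Assumption: an empty sequence is treated like token 0.
            let last = seq.last().copied().unwrap_or(0) as usize;
            // Out-of-vocab tokens collapse to position 0.
            let hot = if last < self.vocab_size { last } else { 0 };
            // One-hot logits: the only non-zero entry is the argmax.
            let mut logits = vec![0.0_f32; self.vocab_size];
            if let Some(slot) = logits.get_mut(hot) {
                *slot = 1.0;
            }
            out.push(logits);
        }
        Ok(out)
    }
}

/// Hypothetical helper: index of the first maximum logit (0 for an empty slice).
pub fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best
}
```

With `FixtureTeacher::new(32)`, a sequence ending in token 7 yields a logits vector whose argmax is 7, while a sequence ending in token 99 falls back to position 0 because 99 is outside the vocabulary. The output depends only on the input, so repeated calls and separate instances agree.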

Pipeline integration (pipeline.rs):

  • Pipeline gains a teacher: Box<dyn TeacherLogitsProvider + Send> field
  • Pipeline::new() defaults to FixtureTeacher::new(32) so legacy tests behave identically
  • Pipeline::with_teacher() lets Phase 4 callers swap in RealizarTeacher
  • train() pulls teacher_logits from self.teacher.logits_for_batch() and reshapes to Array2<f32> shape [batch, vocab]
  • Legacy build_synthetic_logits() retained only for the student path (Phase 2 replaces it with CudaTransformerTrainer.forward_backward_kd)
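The integration points listed above might look roughly like this. It is a sketch under stated assumptions: `with_teacher` is shown as a plain constructor (its real signature is not given in this PR text), `UniformTeacher` and `teacher_logits` are hypothetical names, and a flat row-major `Vec<f32>` stands in for the `Array2<f32>` so the example needs no external crates.

```rust
pub type Result<T> = std::result::Result<T, String>; // assumed error type

pub trait TeacherLogitsProvider {
    fn vocab_size(&self) -> usize;
    fn logits_for_batch(&mut self, input_ids: &[Vec<u32>]) -> Result<Vec<Vec<f32>>>;
}

/// Hypothetical stand-in teacher: every token gets the same logit.
pub struct UniformTeacher {
    vocab_size: usize,
}

impl TeacherLogitsProvider for UniformTeacher {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn logits_for_batch(&mut self, input_ids: &[Vec<u32>]) -> Result<Vec<Vec<f32>>> {
        Ok(input_ids.iter().map(|_| vec![0.0; self.vocab_size]).collect())
    }
}

pub struct Pipeline {
    // Boxed trait object so a fixture or a real backend can be swapped in;
    // `+ Send` keeps the pipeline movable across threads.
    teacher: Box<dyn TeacherLogitsProvider + Send>,
}

impl Pipeline {
    /// Assumed shape of the constructor; the real signature may differ.
    pub fn with_teacher(teacher: Box<dyn TeacherLogitsProvider + Send>) -> Self {
        Self { teacher }
    }

    /// Pulls logits from the teacher and flattens them row-major.
    /// Returns (buffer, batch, vocab), i.e. a [batch, vocab] matrix.
    pub fn teacher_logits(&mut self, input_ids: &[Vec<u32>]) -> Result<(Vec<f32>, usize, usize)> {
        let vocab = self.teacher.vocab_size();
        let rows = self.teacher.logits_for_batch(input_ids)?;
        let batch = rows.len();
        let mut flat = Vec::with_capacity(batch * vocab);
        for row in &rows {
            // Reject ragged rows: a matrix needs exactly `vocab` columns per row.
            if row.len() != vocab {
                return Err(format!("teacher row has {} logits, expected {}", row.len(), vocab));
            }
            flat.extend_from_slice(row);
        }
        Ok((flat, batch, vocab))
    }
}
```

In code that depends on ndarray, the flat buffer could then be handed to `Array2::from_shape_vec((batch, vocab), flat)` to obtain the `[batch, vocab]` array. Checking row lengths before flattening surfaces a misbehaving teacher as an error instead of a silently misaligned matrix.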

Why "online" instead of cached (spec revision)

The original SPEC-DISTILL-001 v1.0.0 plan cached top-K=64 logits to disk. The storage math (1.24B tokens × 64 × 6 bytes ≈ 476 GB) exceeds the lambda-vector NVMe budget. Online teacher inference, which is what DistilBERT and Distil-Qwen do in practice, trades roughly 2× student-step time for zero cache footprint. The spec was revised to v1.1.0 in #1785 to adopt this design.
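The storage estimate can be reproduced with a few lines of arithmetic. The function and constant names below are hypothetical; the three inputs (1.24B tokens, K=64, 6 bytes per cached entry) are taken from the figures quoted in this PR, and GB here means 10^9 bytes.

```rust
/// Bytes needed to cache `top_k` logit entries for every token in the corpus.
pub fn cache_bytes(tokens: u64, top_k: u64, bytes_per_entry: u64) -> u64 {
    tokens * top_k * bytes_per_entry
}

/// Converts bytes to decimal gigabytes (10^9 bytes).
pub fn to_gb(bytes: u64) -> f64 {
    bytes as f64 / 1e9
}

// Figures quoted in the PR description.
pub const CORPUS_TOKENS: u64 = 1_240_000_000;
pub const TOP_K: u64 = 64;
pub const BYTES_PER_ENTRY: u64 = 6;
```

`cache_bytes(CORPUS_TOKENS, TOP_K, BYTES_PER_ENTRY)` evaluates to 476,160,000,000 bytes, which matches the ≈ 476 GB quoted above.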

Tests

6 new teacher_provider::tests::* unit tests pass:

  • vocab_size reporting
  • one logits vec per batch element
  • argmax recovers last input token (F-DISTILL-FIXTURE-001)
  • out-of-vocab token collapses to position 0 (robustness)
  • empty batch returns empty vec
  • deterministic across calls and instances

All 41 aprender-train-distill lib tests pass.

What's next

  • Phase 1b — PMAT-692: implement RealizarTeacher (the real backend that delegates to aprender-serve's Model::forward_logits path). The trait is now in place; Phase 1b is just adding a second impl.
  • Phase 2 — PMAT-693: wire CudaTransformerTrainer.forward_backward_kd_batch into the student path (replaces the remaining build_synthetic_logits call site for the student).
  • Phase 3-6: per SPEC-DISTILL-001.

Test plan

  • 6 new teacher_provider tests pass
  • All 41 aprender-train-distill lib tests pass
  • cargo check -p aprender-train-distill clean
  • Phase 1b — wire RealizarTeacher to the real teacher inference path

🤖 Generated with Claude Code

…Phase 1, PMAT-691)

Replaces the synthetic-logits stub in pipeline.rs::train() with a real
TeacherLogitsProvider trait + a FixtureTeacher implementation suitable
for unit tests. Phase 1b (PMAT-692) wires a RealizarTeacher backend
that runs the real 7B teacher via aprender-serve.

New module: aprender-train-distill::teacher_provider

  pub trait TeacherLogitsProvider {
      fn vocab_size(&self) -> usize;
      fn logits_for_batch(&mut self, input_ids: &[Vec<u32>])
          -> Result<Vec<Vec<f32>>>;
  }

  pub struct FixtureTeacher { vocab_size: usize }

The FixtureTeacher returns deterministic stub logits whose argmax is
the last input token — useful for unit testing the KD signal flow
without needing a real model file.

Pipeline integration (pipeline.rs):

- Pipeline gains a `teacher: Box<dyn TeacherLogitsProvider + Send>` field
- Pipeline::new() defaults to FixtureTeacher::new(32) so legacy tests
  behave identically
- Pipeline::with_teacher() lets Phase 4 callers swap in RealizarTeacher
- train() now pulls teacher_logits from self.teacher.logits_for_batch()
  and reshapes to ndarray Array2<f32> [batch, vocab]
- Legacy build_synthetic_logits() retained only for the student path
  (Phase 2 replaces it with CudaTransformerTrainer.forward_backward_kd)

Tests: 6 new teacher_provider unit tests pass:
  - vocab_size reporting
  - one logits vec per batch element
  - argmax recovers last input token (F-DISTILL-FIXTURE-001)
  - out-of-vocab token collapses to position 0 (robustness)
  - empty batch returns empty vec
  - deterministic across calls and instances

All 41 aprender-train-distill lib tests pass. SPEC-DISTILL-001
v1.1.0 §"Phase 1 — Online teacher logits provider" delivered.

Phase 1b (PMAT-692, follow-up): RealizarTeacher backend that delegates
to aprender-serve's real teacher inference path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift noahgift enabled auto-merge (squash) May 18, 2026 10:15
@noahgift noahgift merged commit 87e2bf5 into main May 18, 2026
10 checks passed
@noahgift noahgift deleted the feat/distill-phase-1-teacher-cache-pmat-691 branch May 18, 2026 12:00
noahgift added a commit that referenced this pull request May 20, 2026
… 4 RUNNING (#1851)

Captures the live state of the distillation epic as of 2026-05-20:

  Phase 1 — Teacher provider              ✅ MERGED (#1786, #1787)
  Phase 2 — Student fwd/bwd + KD          ✅ MERGED (#1788#1797)
  Phase 3 — E2E smoke on Blackwell GB10   ✅ DISCHARGED (#1828)
  Phase 3b — seq_len=256 scale verify     ✅ DISCHARGED (#1833)
  Phase 4 — 50K training (Stage D)        🟡 RUNNING (PID 196378, gx10)
  Phase 5 — HumanEval pass@1              ⏳ ready (#1847)
  Phase 6 — Publish v2                    ⏳ ready (#1848)

Inserts a new top-of-doc status table that points at:
- The 11-PR Blackwell cascade (post-mortem in blackwell-cascade-postmortem.md)
- Stage C real-corpus dispatch result (15.61 → 6.01 over 124 steps)
- Stage D running with ETA ~22h from 2026-05-20 13:43 UTC
- Phase 5/6 turnkey scripts ready post-D

This captures institutional knowledge for the team and future sessions:
the spec doc reflects what's actually shipped rather than the original
plan from 2026-05-18 when the epic was still scaffolded.

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>