Skip to content

feat(distill): CudaTrainerTeacher — real teacher backend (SPEC-DISTILL-001 Phase 1b, PMAT-693)#1787

Closed
noahgift wants to merge 6 commits into
mainfrom
feat/distill-phase-1b-realizar-teacher-pmat-693
Closed

feat(distill): CudaTrainerTeacher — real teacher backend (SPEC-DISTILL-001 Phase 1b, PMAT-693)#1787
noahgift wants to merge 6 commits into
mainfrom
feat/distill-phase-1b-realizar-teacher-pmat-693

Conversation

@noahgift
Copy link
Copy Markdown
Contributor

Summary

Adds the real teacher backend SPEC-DISTILL-001 Phase 1b scopes: CudaTrainerTeacher wraps entrenar's CudaTransformerTrainer in inference-only mode and delegates logits_for_batch to forward_logits() per batch element.

Stacks on top of #1786 (Phase 1, the trait + FixtureTeacher). That PR must merge before this one to avoid a cherry-pick.

What this PR adds

#[cfg(feature = "cuda")]
pub struct CudaTrainerTeacher { /* wraps CudaTransformerTrainer */ }

impl CudaTrainerTeacher {
    pub fn for_inference(
        checkpoint_dir: impl AsRef<Path>,
        model_config: TransformerConfig,
    ) -> Result<Self> { ... }
}

impl TeacherLogitsProvider for CudaTrainerTeacher {
    fn vocab_size(&self) -> usize { ... }
    fn logits_for_batch(&mut self, input_ids: &[Vec<u32>])
        -> Result<Vec<Vec<f32>>> { ... }
}

New cuda feature on aprender-train-distill propagates to entrenar/cuda. Without the feature, only FixtureTeacher (Phase 1) is available — sufficient for unit tests but not for real training.

Defensive checks

  • forward_logits returning NoneEntrenarError::Internal with a clear "likely missing weights or CUDA init failure" message
  • logits.len() != vocab_sizeEntrenarError::Internal flagging TransformerConfig vs checkpoint vocab drift (the common silent failure mode for loaded-from-disk distillation runs)

Tests

All 6 teacher_provider tests pass under both --features (none) and --features cuda. Compile gates verified:

cargo check -p aprender-train-distill                  # clean
cargo check -p aprender-train-distill --features cuda  # clean

The CudaTrainerTeacher itself doesn't have a unit test that exercises the GPU path — that requires CUDA at test time. F-DISTILL-TEACHER-002 (the 1e-6 byte-parity falsifier) is integration-tested in Phase 4 (production run) and locally with apr run vs CudaTrainerTeacher.logits_for_batch parity scripts.

What's next

  • Phase 2 — PMAT-694 (follow-up): wire CudaTransformerTrainer's KD-loss backward into the student path — replaces the remaining build_synthetic_logits call site for the student in pipeline.rs::train(). This is the work that actually starts producing trained student weights.
  • Phases 3-6: per SPEC-DISTILL-001.

Test plan

  • cargo check -p aprender-train-distill clean
  • cargo check -p aprender-train-distill --features cuda clean
  • All 6 teacher_provider tests pass under both feature gates
  • All 41 aprender-train-distill lib tests pass
  • Phase 4 integration parity (F-DISTILL-TEACHER-002): apr run logits vs CudaTrainerTeacher logits byte-equal within 1e-6 absolute

🤖 Generated with Claude Code

…L-001 Phase 1b, PMAT-693)

Adds the real teacher backend the SPEC-DISTILL-001 Phase 1b ticket
scopes: CudaTrainerTeacher wraps entrenar's CudaTransformerTrainer in
inference-only mode, delegates logits_for_batch to forward_logits()
per batch element, returns shape [batch, vocab_size].

Gated behind a new `cuda` feature on aprender-train-distill that
propagates to entrenar/cuda. Without the feature, only FixtureTeacher
(Phase 1) is available — sufficient for unit tests but not for real
training. Real distillation runs (Phase 4) require --features cuda.

Surface
=======

  #[cfg(feature = "cuda")]
  pub struct CudaTrainerTeacher { /* wraps CudaTransformerTrainer */ }

  impl CudaTrainerTeacher {
      pub fn for_inference(
          checkpoint_dir: impl AsRef<Path>,
          model_config: TransformerConfig,
      ) -> Result<Self> { ... }
  }

  impl TeacherLogitsProvider for CudaTrainerTeacher {
      fn vocab_size(&self) -> usize { ... }
      fn logits_for_batch(&mut self, input_ids: &[Vec<u32>])
          -> Result<Vec<Vec<f32>>> { ... }
  }

Defensive checks
================

- forward_logits returning None → EntrenarError::Internal with a clear
  "likely missing weights or CUDA init failure" message
- logits.len() != vocab_size → EntrenarError::Internal flagging
  TransformerConfig vs checkpoint vocab drift (the common silent failure
  mode for loaded-from-disk distillation runs)

Tests
=====

All 6 teacher_provider tests pass under both --features (none) and
--features cuda. Compile gates verified:
  cargo check -p aprender-train-distill                  # clean
  cargo check -p aprender-train-distill --features cuda  # clean

What's next
===========

Phase 2 (PMAT-694, follow-up): wire CudaTransformerTrainer's KD-loss
backward into the student path — replaces the remaining
build_synthetic_logits call site for the student in pipeline.rs::train().

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@noahgift noahgift force-pushed the feat/distill-phase-1b-realizar-teacher-pmat-693 branch from 7a63f08 to 9f47dc2 Compare May 18, 2026 12:14
@noahgift
Copy link
Copy Markdown
Contributor Author

Subsumed by #1797 squash-merge (chain-PR leapfrog pattern per memory rule). All content landed on main at aee8716.

@noahgift noahgift closed this May 18, 2026
auto-merge was automatically disabled May 18, 2026 16:40

Pull request was closed

@noahgift noahgift deleted the feat/distill-phase-1b-realizar-teacher-pmat-693 branch May 18, 2026 16:40
noahgift added a commit that referenced this pull request May 20, 2026
… 4 RUNNING (#1851)

Captures the live state of the distillation epic as of 2026-05-20:

  Phase 1 — Teacher provider              ✅ MERGED (#1786, #1787)
  Phase 2 — Student fwd/bwd + KD          ✅ MERGED (#1788#1797)
  Phase 3 — E2E smoke on Blackwell GB10   ✅ DISCHARGED (#1828)
  Phase 3b — seq_len=256 scale verify     ✅ DISCHARGED (#1833)
  Phase 4 — 50K training (Stage D)        🟡 RUNNING (PID 196378, gx10)
  Phase 5 — HumanEval pass@1              ⏳ ready (#1847)
  Phase 6 — Publish v2                    ⏳ ready (#1848)

Inserts a new top-of-doc status table that points at:
- The 11-PR Blackwell cascade (post-mortem in blackwell-cascade-postmortem.md)
- Stage C real-corpus dispatch result (15.61 → 6.01 over 124 steps)
- Stage D running with ETA ~22h from 2026-05-20 13:43 UTC
- Phase 5/6 turnkey scripts ready post-D

This captures institutional knowledge for the team and future sessions:
the spec doc reflects what's actually shipped rather than the original
plan from 2026-05-18 when the epic was still scaffolded.

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant