# Experiment 4: Knowledge Distillation

In [None]:
from tunix.trainers import DistillationTrainer

# Knowledge-distillation experiment: load a teacher and a student model, run
# DistillationTrainer on a slice of the formatted training data, evaluate the
# student on the GSM8K test set, print a comparison report, and save a checkpoint.
#
# Names expected to be defined earlier in the notebook: FlaxAutoModelForCausalLM,
# formatted_train, tokenizer, evaluate_on_gsm8k, gsm8k_test, rl_results.

# Load teacher (our best GRPO model) and student models
teacher_model = FlaxAutoModelForCausalLM.from_pretrained("./gemma-2b-gsm8k-grpo")
student_model = FlaxAutoModelForCausalLM.from_pretrained("google/gemma-1b")

# Prepare distillation data from GSM8K: the first 2000 entries of formatted_train
distill_data = formatted_train[:2000]

# Configure distillation
distill_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    tokenizer=tokenizer,
    distillation_type="logit",  # or "attention" for attention transfer
    temperature=2.0,  # Softens probability distributions
    alpha=0.5,  # Balance between distillation and task loss
    learning_rate=5e-5,
    num_epochs=5,
    batch_size=8,
)

# Distill knowledge
print("Starting knowledge distillation...")
distill_metrics = distill_trainer.train(distill_data)

print("\nDistillation completed!")
print(f"  Initial distillation loss: {distill_metrics['initial_loss']:.4f}")
print(f"  Final distillation loss: {distill_metrics['final_loss']:.4f}")
print(f"  Training time: {distill_metrics['training_time_hours']:.1f} hours")

# Evaluate student model on GSM8K
# NOTE(review): this evaluates the `student_model` object loaded above. Confirm
# that DistillationTrainer.train() updates that object in place; if the trained
# weights live only inside the trainer, this measures the un-distilled student.
print("\nEvaluating student model on GSM8K...")
student_results = evaluate_on_gsm8k(student_model, tokenizer, gsm8k_test)

teacher_pass_at_1 = rl_results['pass@1']
student_pass_at_1 = student_results['pass@1']

# A teacher score of 0 makes the retention ratio undefined; report "n/a"
# rather than letting the division abort the whole report.
if teacher_pass_at_1:
    accuracy_retention = f"{student_pass_at_1 / teacher_pass_at_1:.1%}"
else:
    accuracy_retention = "n/a"

print(f"\n{'='*60}")
print("Distillation Results")
print(f"{'='*60}")
print(f"  Teacher Pass@1: {teacher_pass_at_1:.1%}")
print(f"  Student Pass@1: {student_pass_at_1:.1%}")
print(f"  Accuracy Retention: {accuracy_retention}")
# "\u2192" is U+2192 RIGHTWARDS ARROW. It is written as an escape so the printed
# arrow survives encoding round-trips of this file (the previous literal had
# been corrupted into mojibake).
print(f"  Model Size: {distill_metrics['teacher_size_mb']:.1f} MB \u2192 {distill_metrics['student_size_mb']:.1f} MB")
print(f"  Size Reduction: {(1 - distill_metrics['student_size_mb']/distill_metrics['teacher_size_mb']):.1%}")
print(f"  Inference Speedup: {distill_metrics['speedup']:.2f}x")

# Save distilled model
distill_trainer.save_checkpoint("./gemma-1b-gsm8k-distilled")
