# 📊 LMFast: Knowledge Distillation

**Transfer knowledge from a powerful teacher model to a tiny student model!**

## What You'll Learn
- How knowledge distillation works
- Distill Qwen2.5-1.5B → SmolLM-135M
- Compare student performance before/after
- Offline logit generation for memory efficiency

## Why Distillation?
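- Keep much of a larger model's quality in a far smaller model
- Run on edge devices
- Faster inference: the 135M student has about 11x fewer parameters than the 1.5B teacher
- Cheaper serving: less GPU memory and compute per generated token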

**Time to complete:** ~20 minutes

## 1️⃣ Setup

In [None]:
!pip install -q lmfast[all]

import lmfast
lmfast.setup_colab_env()

import torch
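if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Training below expects a GPU; in Colab use Runtime > Change runtime type
    print("No GPU detected")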

## 2️⃣ Prepare Distillation Data

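Distillation needs prompts for the teacher to respond to. We also keep each example's reference answer as `response`, which provides ground-truth labels for the cross-entropy part of the loss.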

In [None]:
from datasets import load_dataset

# Load instruction dataset
dataset = load_dataset("yahma/alpaca-cleaned", split="train[:500]")

# Format prompts
def format_prompt(example):
    if example["input"]:
        prompt = f"""### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n"""
    else:
        prompt = f"""### Instruction:\n{example['instruction']}\n\n### Response:\n"""
    return {"prompt": prompt, "response": example["output"]}

dataset = dataset.map(format_prompt)
print(f"Dataset size: {len(dataset)} examples")
print(f"\nExample prompt:\n{dataset[0]['prompt'][:200]}...")

## 3️⃣ Baseline: Test Student Before Distillation

In [None]:
from lmfast.inference import SLMServer

# Load the base student (pretrained, not yet distilled)
student_baseline = SLMServer("HuggingFaceTB/SmolLM-135M")

test_prompts = [
    "Explain machine learning in simple terms.",
    "Write a Python function to calculate factorial.",
    "What are the benefits of small language models?"
]

print("📊 BASELINE (Before Distillation)")
print("=" * 50)
for prompt in test_prompts:
    response = student_baseline.generate(prompt, max_new_tokens=100)
    print(f"\nQ: {prompt}")
    print(f"A: {response[:200]}...")

# Free memory
del student_baseline
torch.cuda.empty_cache()

## 4️⃣ Configure Distillation

LMFast supports multiple distillation strategies:

| Method | Description | Best For |
|--------|-------------|----------|
| **Standard KD** | KL divergence on logits | General tasks |
| **CoT Distillation** | Transfer reasoning traces | Math, logic |
| **Offline Distillation** | Pre-compute teacher outputs | Memory-constrained |
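Standard KD is usually the Hinton et al. (2015) formulation: soften teacher and student logits with a temperature T, penalize the KL divergence between them, and mix that with ordinary cross-entropy on the ground-truth token. The dependency-free sketch below computes that loss for a single token position. It illustrates the standard formula, not LMFast's implementation; in particular, whether LMFast applies the T² factor is an assumption.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits (numerically stable)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, target_index, temperature=2.0, alpha=0.5):
    """Soft-target KD loss for one token position (Hinton et al., 2015 style).

    alpha weights the KL term on softened distributions; (1 - alpha) weights
    ordinary cross-entropy against the ground-truth token.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student): penalizes the student for missing teacher probability mass
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    # The T^2 factor keeps the soft-target gradient scale comparable across temperatures
    soft_term = kl * temperature**2
    # Hard-label cross-entropy at T = 1
    hard_term = -math.log(softmax(student_logits)[target_index])
    return alpha * soft_term + (1 - alpha) * hard_term


teacher = [4.0, 1.0, 0.5]  # hypothetical teacher logits over a 3-token vocabulary
print(kd_loss([4.0, 1.0, 0.5], teacher, target_index=0))  # student matches teacher: KL term is 0
print(kd_loss([0.5, 1.0, 4.0], teacher, target_index=0))  # student disagrees: loss is larger
```

In a real trainer the same computation runs over every position in a batch, with tensors instead of lists.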

In [None]:
from lmfast.core.config import DistillationConfig

distill_config = DistillationConfig(
    teacher_model="Qwen/Qwen2.5-1.5B-Instruct",  # Powerful teacher
    temperature=2.0,  # Softer distributions for better transfer
    alpha=0.5,  # 50% KD loss, 50% CE loss
    max_seq_length=512,
)

print(f"Teacher: {distill_config.teacher_model}")
print(f"Temperature: {distill_config.temperature}")
print(f"Alpha (KD weight): {distill_config.alpha}")
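To see what `temperature=2.0` does, compare one hypothetical set of teacher logits softened at several temperatures. Dividing logits by T before the softmax flattens the distribution and raises its entropy, which exposes more of the teacher's ranking among the non-top tokens. This is the standard definition of temperature scaling; that LMFast uses it exactly this way is assumed, not confirmed.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits (numerically stable)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


teacher_logits = [6.0, 3.0, 2.0, 0.5]  # hypothetical logits for 4 candidate tokens
for t in (1.0, 2.0, 4.0):
    probs = softmax(teacher_logits, t)
    shown = ", ".join(f"{p:.3f}" for p in probs)
    print(f"T={t}: [{shown}]  entropy={entropy(probs):.3f}")
```

With `alpha=0.5`, the softened KL term and the hard-label cross-entropy carry equal weight in the total loss.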

## 5️⃣ Run Distillation

This will:
1. Load the teacher model (Qwen2.5-1.5B-Instruct)
2. Load the student model (SmolLM-135M)
3. Compute teacher logits on the training data
4. Train student to match teacher distribution

In [None]:
from lmfast.distillation import DistillationTrainer

# Create distillation trainer
trainer = DistillationTrainer(
    student_model="HuggingFaceTB/SmolLM-135M",
    distillation_config=distill_config,
)

# Run distillation
print("🎓 Starting distillation...")
print("This may take 10-15 minutes on a T4 GPU")

trainer.distill(
    dataset,
    output_dir="./distilled_model",
    max_steps=200,
    batch_size=2,
    gradient_accumulation_steps=8,
)

print("✅ Distillation complete!")
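The trainer arguments determine how much data the student actually sees. A quick calculation with the values from the `trainer.distill()` call, assuming `max_steps` counts optimizer updates (the convention in Hugging Face `TrainingArguments`; LMFast's behavior here is an assumption) and no sequence packing:

```python
# Values copied from the trainer.distill() call and the train[:500] split
batch_size = 2
gradient_accumulation_steps = 8
max_steps = 200
num_examples = 500

# Examples contributing to each optimizer update
effective_batch = batch_size * gradient_accumulation_steps
# Total examples processed over the whole run
examples_seen = effective_batch * max_steps
# Passes over the 500-example dataset
epochs = examples_seen / num_examples

print(f"Effective batch size: {effective_batch}")
print(f"Examples processed:   {examples_seen}")
print(f"Approximate epochs:   {epochs:.1f}")
```

Roughly six passes over 500 examples is a lot for such a small dataset; if the student starts parroting training answers, fewer steps or a larger slice of alpaca-cleaned may help.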

## 6️⃣ Test Distilled Student

In [None]:
# Load distilled model
student_distilled = SLMServer("./distilled_model")

print("🎓 AFTER DISTILLATION")
print("=" * 50)
for prompt in test_prompts:
    response = student_distilled.generate(prompt, max_new_tokens=100)
    print(f"\nQ: {prompt}")
    print(f"A: {response[:200]}...")

## 7️⃣ Alternative: Offline Distillation (Memory Efficient)

If the teacher and student don't fit in GPU memory at the same time, use offline distillation: run the teacher once, save its outputs, then train the student from the saved outputs.
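Rough arithmetic helps decide whether you need this. The sketch below counts only weights and optimizer state, using nominal parameter counts and the common rule of thumb of 16 bytes per parameter for mixed-precision Adam training; activations, logits, and framework overhead come on top, so treat the results as lower bounds.

```python
def memory_gb(num_params, bytes_per_param):
    """Memory for parameters (plus optimizer state if folded into bytes_per_param), in GB."""
    return num_params * bytes_per_param / 1e9


teacher_params = 1.5e9  # nominal Qwen2.5-1.5B size
student_params = 135e6  # nominal SmolLM-135M size

# Frozen teacher in fp16/bf16: 2 bytes per parameter, no gradients or optimizer state
teacher_gb = memory_gb(teacher_params, 2)
# Student trained with Adam in mixed precision (rule of thumb):
# 2 (fp16 weights) + 2 (fp16 grads) + 4 (fp32 master weights) + 8 (two fp32 Adam moments) = 16
student_gb = memory_gb(student_params, 16)

print(f"Teacher weights:          ~{teacher_gb:.2f} GB")
print(f"Student weights + Adam:   ~{student_gb:.2f} GB")
print(f"Total before activations: ~{teacher_gb + student_gb:.2f} GB")
```

A T4 has 16 GB, so this pair fits on paper. The picture changes with a 7B teacher, whose fp16 weights alone take about 14 GB.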

In [None]:
from lmfast.distillation import generate_teacher_labels

# Step 1: Generate and save teacher outputs (can run separately, e.g. on a machine with a larger GPU)
# generate_teacher_labels(
#     teacher_model="Qwen/Qwen2.5-1.5B-Instruct",
#     dataset=dataset,
#     output_path="./teacher_logits.pt",
#     batch_size=4
# )

# Step 2: Train student with saved logits
# trainer = DistillationTrainer(
#     student_model="HuggingFaceTB/SmolLM-135M",
#     distillation_config=distill_config,
# )
# trainer.distill_from_logits(
#     logits_path="./teacher_logits.pt",
#     output_dir="./distilled_offline"
# )

print("💡 Offline distillation is useful when:")
print("   - Teacher model is too large (7B+)")
print("   - You want to reuse teacher outputs")
print("   - Running on Colab free tier")
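Saved teacher outputs can get large, because their size scales with the vocabulary. This notebook doesn't say whether `generate_teacher_labels` stores full distributions or only the top-k entries per position, so the sketch below compares both under stated assumptions: every example padded to `max_seq_length=512` (an upper bound), fp16 values, int32 token ids for top-k, and a Qwen2.5 vocabulary of 151,936 entries (verify `vocab_size` in the teacher's `config.json`). `top_k=20` is an arbitrary illustrative choice.

```python
def logits_storage_gb(num_examples, seq_len, vocab_size, bytes_per_value=2, top_k=None):
    """Approximate storage for teacher logits at every token position, in GB.

    Full storage keeps vocab_size values per position; top-k keeps k values
    plus k int32 token ids so the student loss can index the right tokens.
    """
    positions = num_examples * seq_len
    if top_k is None:
        per_position = vocab_size * bytes_per_value
    else:
        per_position = top_k * (bytes_per_value + 4)  # value + int32 token id
    return positions * per_position / 1e9


VOCAB_SIZE = 151_936  # assumed Qwen2.5-1.5B vocab_size; verify in the model's config.json

full_gb = logits_storage_gb(500, 512, VOCAB_SIZE)
top20_gb = logits_storage_gb(500, 512, VOCAB_SIZE, top_k=20)
print(f"Full fp16 logits: ~{full_gb:.1f} GB")
print(f"Top-20 logits:    ~{top20_gb:.3f} GB")
```

Top-k storage is a trade-off: the student only sees the teacher's most likely tokens, and the remaining probability mass must be renormalized or ignored.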

## 8️⃣ Export Distilled Model

In [None]:
from lmfast.inference.quantization import quantize_model

# Quantize for deployment
quantize_model(
    "./distilled_model",
    "./distilled_model_int4",
    method="int4"
)

print("✅ Model quantized and ready for deployment!")

# Compare on-disk sizes (includes files in subdirectories)
import os

def get_dir_size(path):
    """Total size of all files under `path`, in MB."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / 1e6

print(f"\nOriginal size: {get_dir_size('./distilled_model'):.1f} MB")
print(f"Quantized size: {get_dir_size('./distilled_model_int4'):.1f} MB")
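As a sanity check on the printed sizes, the ideal weight size is parameters × bits / 8. This sketch assumes the nominal 135M parameter count. Real int4 checkpoints come out larger, because quantization schemes store per-group scales and often keep embeddings and norms in higher precision, and the original size depends on whether the model was saved in fp32 or fp16.

```python
def weights_size_mb(num_params, bits_per_param):
    """Ideal on-disk size of the weights alone, in MB."""
    return num_params * bits_per_param / 8 / 1e6


params = 135e6  # nominal SmolLM-135M parameter count
for label, bits in [("fp32", 32), ("fp16/bf16", 16), ("int4", 4)]:
    print(f"{label:>9}: ~{weights_size_mb(params, bits):.0f} MB")
```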

## 🎉 Summary

You've learned how to:
- ✅ Transfer knowledge from a 1.5B model to a 135M model
- ✅ Use temperature scaling for better transfer
- ✅ Apply offline distillation for memory efficiency
- ✅ Quantize the distilled model for deployment

### Distillation Tips

| Tip | Why |
|-----|-----|
| Higher temperature (2-4) | Softer distributions transfer better |
| Balanced α (0.3-0.7) | Don't ignore ground-truth labels |
| More data | Distillation benefits from scale |
| Shared tokenizer | Logit KD compares per-token distributions, so teacher and student need the same vocabulary (or an explicit alignment) |
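Logit-level KD (the **Standard KD** row in the strategies table) compares teacher and student distributions position by position over the same token ids, so the two tokenizers have to agree. A quick check, written against plain token-to-id dicts so it runs offline; with Hugging Face `transformers`, `tokenizer.get_vocab()` returns such a dict. The helper `vocab_overlap` is hypothetical, not part of LMFast.

```python
def vocab_overlap(teacher_vocab, student_vocab):
    """Compare two token -> id mappings.

    Returns (fraction of student tokens also in the teacher vocab,
             True only if both mappings are identical).
    """
    shared = set(teacher_vocab) & set(student_vocab)
    fraction = len(shared) / len(student_vocab)
    return fraction, teacher_vocab == student_vocab


# With transformers installed, real vocabularies could be loaded like this:
# from transformers import AutoTokenizer
# teacher_vocab = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct").get_vocab()
# student_vocab = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M").get_vocab()

toy_teacher = {"hello": 0, "world": 1, "!": 2}  # hypothetical toy vocabularies
toy_student = {"hello": 0, "world": 1, "?": 2}
print(vocab_overlap(toy_teacher, toy_student))
```

Partial overlap is not enough for direct logit KL: identical ids are what let position i of one distribution line up with position i of the other.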

### Next Steps
- `06_preference_alignment.ipynb` - ORPO/DPO alignment
- `09_basic_agents.ipynb` - Build agents with your model