In [None]:
# Install llm_distil from GitHub
!pip install git+https://github.com/yashpatel2010/llm_distil.git

# Import libraries
from llm_distil import (
    KnowledgeDistillation,
    ReverseKnowledgeDistillation,
    GeneralizedKnowledgeDistillation,
    DistillationConfig
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import pandas as pd
from tqdm.auto import tqdm

# Sanity check: reaching this point means every import above resolved.
# Also report the PyTorch build and whether a CUDA device is visible.
print("✓ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Load Dolly-15k dataset (small subset for demo): only the first 1000 rows of
# the "train" split are requested.
print("Loading Dolly-15k dataset...")
dataset = load_dataset("databricks/databricks-dolly-15k", split="train[:1000]")

print(f"✓ Loaded {len(dataset)} examples")
print("\nExample:")
# Show only the first 100 characters of each field to keep the output short.
print(f"Instruction: {dataset[0]['instruction'][:100]}...")
print(f"Response: {dataset[0]['response'][:100]}...")

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the end-of-sequence token as the padding token; the tokenization step
# pads every example with padding="max_length" and therefore needs one.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize dataset
def tokenize_function(examples):
    """Tokenize a batch of examples as instruction/response pairs.

    Each instruction is joined to its response with a single newline and the
    resulting strings are passed to the notebook-level ``tokenizer``.

    Args:
        examples: Batch mapping that provides parallel "instruction" and
            "response" sequences.

    Returns:
        Whatever ``tokenizer`` returns for the batch when called with
        ``truncation=True``, ``padding="max_length"`` and ``max_length=256``.
    """
    # Pair every instruction with its response and merge them into one string.
    merged_texts = []
    for instruction, response in zip(examples["instruction"], examples["response"]):
        merged_texts.append(f"{instruction}\n{response}")

    # Fixed-length encoding: truncate long texts and pad short ones.
    encode_options = dict(truncation=True, padding="max_length", max_length=256)
    return tokenizer(merged_texts, **encode_options)

print("Tokenizing dataset...")
# remove_columns drops every original column, so only the columns produced by
# tokenize_function remain in the mapped dataset.
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)

# Split into train/eval: rows 0-799 for training, rows 800-999 for evaluation.
train_dataset = tokenized_dataset.select(range(800))
eval_dataset = tokenized_dataset.select(range(800, 1000))

print(f"✓ Train: {len(train_dataset)} examples")
print(f"✓ Eval: {len(eval_dataset)} examples")

## 2. Load Dataset and Prepare Data

We'll use a subset of **Databricks Dolly-15k**, an instruction-following dataset.

## 3. Configure Training

Set up the distillation configuration with shared hyperparameters.

In [None]:
# Shared configuration, collected as keyword arguments first so the values
# can be read at a glance.
shared_hparams = {
    "teacher_model_name": "gpt2-medium",
    "student_model_name": "gpt2",
    "temperature": 2.0,
    "kd_loss_weight": 0.5,
    "epochs": 2,  # Reduced for demo
    "batch_size": 4,
    "learning_rate": 5e-5,
    "max_length": 256,
    "logging_steps": 50,
    "eval_steps": 200,
    "save_steps": 200,
    "output_dir": "./distil_output",
}
base_config = DistillationConfig(**shared_hparams)

# Echo the main hyperparameters by reading them back from the config object.
print("Configuration:")
for label, field in (
    ("Temperature", "temperature"),
    ("KD Loss Weight", "kd_loss_weight"),
    ("Epochs", "epochs"),
    ("Batch Size", "batch_size"),
    ("Learning Rate", "learning_rate"),
):
    print(f"  {label}: {getattr(base_config, field)}")

## 4. Load Models

Load the teacher (GPT2-medium) and student (GPT2) models.

In [None]:
# Load teacher model (larger)
print("Loading teacher model (GPT2-medium)...")
teacher = AutoModelForCausalLM.from_pretrained("gpt2-medium")
# Count once and reuse for both the report and the compression ratio.
teacher_params = teacher.num_parameters()
print(f"✓ Teacher params: {teacher_params:,}")

# Load student model (smaller) - one separately loaded copy per method, so
# each training run starts from the same pretrained checkpoint.
print("\nLoading student model (GPT2)...")
student_baseline = AutoModelForCausalLM.from_pretrained("gpt2")
student_kd = AutoModelForCausalLM.from_pretrained("gpt2")
student_revkd = AutoModelForCausalLM.from_pretrained("gpt2")
student_gkd = AutoModelForCausalLM.from_pretrained("gpt2")
student_params = student_baseline.num_parameters()
print(f"✓ Student params: {student_params:,}")

print(f"\n📊 Compression ratio: {teacher_params / student_params:.2f}x")

In [None]:
print("=" * 60)
print("Training with KD (Forward KL Divergence)")
print("=" * 60)

# Distil the teacher into the dedicated KD student copy using the shared
# configuration; the eval split is handed to train() alongside the train split.
kd = KnowledgeDistillation(teacher, student_kd, base_config)
kd.train(train_dataset, eval_dataset)

print("✓ KD training complete!")

## 5. Train with Standard KD (Forward KL)

Train student using **Knowledge Distillation** with forward KL divergence.

**Loss:** `L = (1-α)·CE + α·T²·KL(Teacher || Student)`

## 6. Train with RevKD (Reverse KL)

Train student using **Reverse Knowledge Distillation** with reverse KL divergence (mode-seeking).

**Loss:** `L = (1-α)·CE + α·T²·KL(Student || Teacher)`

In [None]:
print("=" * 60)
print("Training with RevKD (Reverse KL Divergence)")
print("=" * 60)

# Same teacher and shared configuration as the KD run, but with the dedicated
# RevKD student copy and the reverse-KL trainer class.
revkd = ReverseKnowledgeDistillation(teacher, student_revkd, base_config)
revkd.train(train_dataset, eval_dataset)

print("✓ RevKD training complete!")

## 7. Train with GKD (Generalized JSD)

Train student using **Generalized Knowledge Distillation** with JSD and on-policy generation.

**Loss:** `L = λ·JSD(Teacher, Student) + (1-λ)·JSD(Teacher, Student_generated)`

In [None]:
print("=" * 60)
print("Training with GKD (Generalized JSD)")
print("=" * 60)

# GKD gets its own configuration carrying lambda_gkd and beta_gkd.
# NOTE(review): temperature, kd_loss_weight, logging_steps, eval_steps and
# save_steps are not passed here (unlike base_config), so DistillationConfig's
# own defaults apply - confirm that is intended.
# NOTE(review): output_dir is the same path used by base_config - confirm the
# runs do not overwrite each other's saved artifacts.
gkd_config = DistillationConfig(
    teacher_model_name="gpt2-medium",
    student_model_name="gpt2",
    lambda_gkd=0.5,
    beta_gkd=0.5,
    epochs=2,
    batch_size=4,
    learning_rate=5e-5,
    max_length=256,
    output_dir="./distil_output"
)

gkd = GeneralizedKnowledgeDistillation(teacher, student_gkd, gkd_config)
gkd.train(train_dataset, eval_dataset)

print("✓ GKD training complete!")

## 8. Evaluate All Methods

Compare the perplexity of all distilled students on the evaluation set.

## 10. Summary

### Key Findings:

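1. **All distillation methods are expected to outperform the baseline** (the pretrained student without distillation) — confirm this against the perplexity table produced in Section 8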
2. **KD (Forward KL)** provides balanced performance (mean-seeking)
3. **RevKD (Reverse KL)** focuses on high-confidence predictions (mode-seeking)
4. **GKD (JSD)** leverages on-policy generation for robust distillation

### When to Use Each Method:

| Method | Best For | Behavior |
|--------|----------|----------|
| **KD** | General-purpose distillation | Covers all teacher modes |
| **RevKD** | High-confidence tasks | Focuses on peaks |
| **GKD** | Generative tasks | On-policy robustness |

### Next Steps:

- Try different temperatures (1.0-5.0)
- Adjust `kd_loss_weight` (balance CE vs KD loss)
- Train for more epochs for better convergence
- Evaluate on downstream tasks (classification, generation, etc.)

### Resources:

- **GitHub**: https://github.com/yashpatel2010/llm_distil
- **API Guide**: `docs/API_GUIDE.md`
- **Examples**: `examples/distill_dolly15k.py`

In [None]:
# Test prompt
prompt = "What is machine learning?"

print("=" * 60)
print(f"Prompt: '{prompt}'")
print("=" * 60)

# Tokenize prompt and move its tensors to the device the models will run on
inputs = tokenizer(prompt, return_tensors="pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = {k: v.to(device) for k, v in inputs.items()}

# Generate from each model
models = {
    "Teacher": teacher,
    "KD": student_kd,
    "RevKD": student_revkd,
    "GKD": student_gkd
}

# Generation samples (do_sample=True), so the output depends on the RNG state.
# Re-seeding before every model makes the printed samples repeatable across
# reruns on the same setup and starts each model from the same RNG state.
GENERATION_SEED = 42

for name, model in models.items():
    model.to(device)
    model.eval()
    torch.manual_seed(GENERATION_SEED)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )
    # generate() returns a batch; the single prompt is at index 0.
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\n[{name}]")
    print(text)
    print("-" * 60)

In [None]:
# Imports used only by this cell, grouped at the top of the cell.
from llm_distil.metrics import compute_perplexity
from torch.utils.data import DataLoader

print("=" * 60)
print("Evaluating all methods...")
print("=" * 60)

# Resolve the device once so the teacher and the baseline student are scored
# with the same setting.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the calls below receive eval_dataset directly; this loader is
# not passed to any call in this cell - confirm whether it is still needed.
eval_dataloader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

# Evaluate teacher
print("\n[1/5] Evaluating teacher...")
teacher_ppl = compute_perplexity(teacher, eval_dataset, tokenizer, device=device)

# Evaluate students
# The baseline is the student copy that was never passed to any trainer.
print("[2/5] Evaluating baseline student...")
baseline_ppl = compute_perplexity(student_baseline, eval_dataset, tokenizer, device=device)

# The distilled students are scored through their trainers' evaluate() method.
print("[3/5] Evaluating KD student...")
kd_metrics = kd.evaluate(eval_dataset)

print("[4/5] Evaluating RevKD student...")
revkd_metrics = revkd.evaluate(eval_dataset)

print("[5/5] Evaluating GKD student...")
gkd_metrics = gkd.evaluate(eval_dataset)

# Create comparison table: one row per model, in evaluation order.
method_labels = [
    "Teacher (GPT2-medium)",
    "Baseline (no distill)",
    "KD (Forward KL)",
    "RevKD (Reverse KL)",
    "GKD (JSD)",
]
# Teacher and baseline perplexities are plain values; the distilled students
# report theirs under the 'perplexity' key of their metrics.
perplexities = [teacher_ppl, baseline_ppl] + [
    metrics['perplexity'] for metrics in (kd_metrics, revkd_metrics, gkd_metrics)
]
# The teacher is the only 355M model; all four students share the 124M size.
model_sizes = ["355M"] + ["124M"] * 4

results_df = pd.DataFrame({
    "Method": method_labels,
    "Perplexity": perplexities,
    "Model Size": model_sizes,
})

separator = "=" * 60
print("\n" + separator)
print("RESULTS")
print(separator)
print(results_df.to_string(index=False))
print(separator)