<a href="https://colab.research.google.com/github/peremartra/Tailoring-LLM-Architectures/blob/main/CH06/CH06_NB02_Logits_KLD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Tailoring LLM Architectures**
## **Chapter 6: Knowledge Distillation with Logits Only**
### Training a pruned model using only Logits and KL Divergence

by [Pere Martra](https://github.com/peremartra)

---

- **Hardware Environment:** NVIDIA A100 GPU
- **Model:** google/gemma-3-270m (Teacher: 18 transformer blocks) → Pruned Student (14 transformer blocks)
- **Dataset:** Cosmopedia (40,000 samples, 3 epochs)

---

**What we'll accomplish:**
- Train a depth-pruned student model using **only logits and KL Divergence**
- No hidden state alignment (simpler, faster training)
- Save the trained model to Hugging Face as **gem-3-small**

## Section 0: Environment & Dependencies

In [None]:
# Mount Google Drive at /content/drive. The results-saving cell later in this
# notebook writes its JSON file under /content/drive/MyDrive, so this mount
# has to succeed before that cell runs.
from google.colab import drive
drive.mount('/content/drive')
print("✓ Drive mounted")

In [None]:
# ---- Distillation data / optimisation settings ----
# Total number of Cosmopedia samples to draw (divided between the subsets
# when the dataset is loaded).
RECOVERY_SAMPLES = 40_000
EPOCHS = 3
LEARNING_RATE = 4e-5
# Batch size of the train and validation DataLoaders.
BATCH_SIZE = 16

# ---- Benchmark settings ----
# Forwarded as `batch_size` to model_evaluation.
BATCH_EVAL = "auto"
# Master switch for every benchmark run and for the benchmark chart.
RUN_FULL_BENCHMARKS = True
# Forwarded as `limit` to model_evaluation (presumably None means no cap;
# confirm against utils.py).
BENCHMARK_LIMIT = None
BENCHMARK_TASKS = [
    "arc_easy",
    "winogrande",
    "hellaswag",
    "lambada_openai",
    "piqa",
]

# ---- Publishing ----
# Repository name used when pushing the trained student to the Hub.
HF_MODEL_NAME = "gem-3-small"

In [None]:
# Notebook shell magics: install the third-party packages quietly (-q).
# Model / data stack.
!pip install -q transformers accelerate datasets
# Pruning library (optipfair) plus plotting and progress-bar packages.
!pip install -q optipfair matplotlib seaborn tqdm
# Evaluation harness (lm_eval), Hub client and auxiliary packages.
!pip install -q lm_eval langdetect codecarbon huggingface_hub

In [None]:
import torch, gc
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm.auto import tqdm
from copy import deepcopy
import warnings, time, json, os
from datetime import datetime
# Suppress all Python warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Report the runtime and choose the compute device once; later cells move
# batches to `device` and pass it to the evaluation helpers.
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
def set_seed(seed=42):
    """Seed the torch (CPU and all CUDA devices) and NumPy random generators.

    Args:
        seed: Integer seed handed to each generator.
    """
    # Same order as before: torch CPU, torch CUDA, NumPy.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)


set_seed(42)
print("✓ Random seed set to 42")

In [None]:
# Download the shared helper module into the working directory (shell magic),
# then import the three helpers used below: evaluate_metrics (loss/perplexity),
# clear_gpu_cache (memory cleanup) and model_evaluation (benchmark runs).
# The download must complete before the import on the next line.
!wget -q https://raw.githubusercontent.com/peremartra/Rearchitecting-LLMs/main/utils.py
from utils import evaluate_metrics, clear_gpu_cache, model_evaluation
print("✓ utils.py loaded")

## Section 1: Load Teacher Model

In [None]:
MODEL_NAME = "google/gemma-3-270m"
print(f"Loading Teacher model: {MODEL_NAME}")

# GPU: bfloat16 weights with automatic device placement.
# CPU: float32 weights with the default placement.
if torch.cuda.is_available():
    teacher_load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
else:
    teacher_load_kwargs = {"torch_dtype": torch.float32, "device_map": None}
teacher_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **teacher_load_kwargs)

# The teacher is only a reference: eval mode, and no parameter needs gradients.
teacher_model.eval()
teacher_model.requires_grad_(False)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Padding is required for fixed-length batches; fall back to EOS when the
# tokenizer does not define a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

n_teacher_layers = len(teacher_model.model.layers)
hidden_dim = teacher_model.config.hidden_size
teacher_param_count = teacher_model.num_parameters()
print(f"Teacher: {n_teacher_layers} layers, {teacher_param_count:,} params")

## Section 2: Prepare Training Dataset

In [None]:
# Every sample is padded/truncated to this many tokens when tokenized.
MAX_LENGTH = 512
print("Loading Cosmopedia dataset...")
dataset_name = "HuggingFaceTB/cosmopedia"
subsets = ["stories", "wikihow", "openstax", "web_samples_v1"]
# Split the sample budget evenly across the subsets actually listed above,
# so editing `subsets` keeps the total at (about) RECOVERY_SAMPLES instead of
# silently relying on there being exactly four entries.
samples_per_subset = RECOVERY_SAMPLES // len(subsets)

all_samples = []
for subset in subsets:
    print(f"  Loading {subset}...")
    # Streaming avoids downloading a whole subset just to keep its first
    # `samples_per_subset` records.
    subset_data = load_dataset(dataset_name, subset, split="train", streaming=True)
    subset_samples = list(subset_data.take(samples_per_subset))
    all_samples.extend(subset_samples)
    print(f"    ✓ {len(subset_samples):,} samples")

# Only the raw text column is kept for distillation.
distillation_dataset = Dataset.from_dict({'text': [s['text'] for s in all_samples]})
print(f"✓ Total samples: {len(distillation_dataset):,}")

In [None]:
print("Tokenizing...")
texts = [item['text'] for item in distillation_dataset]

# Tokenize in chunks of 1000 texts. Every sample is truncated or padded to
# exactly MAX_LENGTH tokens so the chunks can be concatenated afterwards.
chunk_size = 1000
tokenized_data = [
    tokenizer(
        texts[start:start + chunk_size],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    for start in tqdm(range(0, len(texts), chunk_size), desc="Tokenizing")
]

input_ids = torch.cat([chunk['input_ids'] for chunk in tokenized_data], dim=0)
attention_mask = torch.cat([chunk['attention_mask'] for chunk in tokenized_data], dim=0)
full_dataset = TensorDataset(input_ids, attention_mask)

# 80/20 train/validation split, made reproducible with a dedicated generator.
generator = torch.Generator().manual_seed(42)
n_total = len(full_dataset)
train_size = int(0.8 * n_total)
val_size = n_total - train_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=generator
)

# Only the training loader shuffles; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader_raw = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

class DictDataLoader:
    """Adapter that re-emits ``(input_ids, attention_mask)`` batches as dicts.

    The wrapped loader is stored on ``self.dataloader``. Iterating yields one
    ``{'input_ids': ..., 'attention_mask': ...}`` dict per underlying batch,
    and ``len()`` is delegated to the wrapped loader.
    """

    def __init__(self, dl):
        self.dataloader = dl

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        # Each underlying batch must unpack into exactly two items.
        for ids, mask in self.dataloader:
            batch = {'input_ids': ids, 'attention_mask': mask}
            yield batch

# Wrap the raw validation loader so it yields dict batches; this wrapped
# loader is what gets passed to evaluate_metrics (presumably that helper
# expects dict batches; confirm against utils.py).
eval_dataloader = DictDataLoader(eval_dataloader_raw)
n_train_samples, n_val_samples = len(train_dataset), len(val_dataset)
print(f"✓ Train: {n_train_samples:,}, Val: {n_val_samples:,}")

## Section 3: Create Pruned Student Model

In [None]:
import optipfair as opf

# Work on a copy so the teacher keeps all of its layers.
student_model = deepcopy(teacher_model)
importance_scores = opf.analyze_layer_importance(
    student_model,
    train_dataloader,
    show_progress=True,
)

# Rank layer keys by ascending importance score; the four lowest-scoring
# layers are the ones selected for removal.
layers_by_importance = sorted(importance_scores.keys(), key=importance_scores.__getitem__)
LAYERS_TO_REMOVE = layers_by_importance[:4]
print(f"Layers to remove: {LAYERS_TO_REMOVE}")

In [None]:
student_model = opf.prune_model_depth(
    model=student_model,
    layer_indices=LAYERS_TO_REMOVE,
    show_progress=True,
)

# The student started as a copy of the frozen teacher, so gradients have to be
# switched back on for every parameter before it can be trained.
student_model.requires_grad_(True)

n_student_layers = len(student_model.model.layers)
student_param_count = student_model.num_parameters()
print(f"✓ Student: {n_student_layers} layers, {student_param_count:,} params")

In [None]:
# Baseline 1: the unmodified teacher on the validation split.
print("Evaluating Teacher...")
teacher_metrics = evaluate_metrics(teacher_model, eval_dataloader, device=device)
teacher_ppl, teacher_loss = teacher_metrics['perplexity'], teacher_metrics['loss']

# Baseline 2: the pruned student before any distillation. A throwaway copy is
# evaluated and then released so `student_model` itself is left untouched.
print("Evaluating Pruned Student...")
student_pruned_copy = deepcopy(student_model)
student_metrics = evaluate_metrics(student_pruned_copy, eval_dataloader, device=device)
student_ppl, student_loss = student_metrics['perplexity'], student_metrics['loss']
del student_pruned_copy
clear_gpu_cache()

print(f"Teacher PPL: {teacher_ppl:.2f}, Student PPL: {student_ppl:.2f}")

In [None]:
# Benchmark scores keyed by model variant; stays empty when benchmarks are off.
benchmark_results = {}
if RUN_FULL_BENCHMARKS:
    # Settings shared by every benchmark run in this cell.
    shared_eval_kwargs = dict(
        tokenizer=tokenizer,
        tasks=BENCHMARK_TASKS,
        device=device,
        limit=BENCHMARK_LIMIT,
        batch_size=BATCH_EVAL,
    )

    print("Benchmarking Teacher...")
    benchmark_results['teacher'] = model_evaluation(model_obj=teacher_model, **shared_eval_kwargs)

    print("Benchmarking Pruned Student...")
    # Benchmark a disposable copy of the pruned student, then free it.
    student_pruned_copy = deepcopy(student_model)
    benchmark_results['student_pruned'] = model_evaluation(model_obj=student_pruned_copy, **shared_eval_kwargs)
    del student_pruned_copy
    clear_gpu_cache()

## Section 4: Loss Function (Logits Only)

In [None]:
def compute_logits_only_loss(student_logits, teacher_logits, labels, alpha=0.5, beta=0.5, temperature=2.0):
    """Combine next-token cross-entropy with a temperature-scaled KL distillation term.

    Both terms are computed on shifted sequences (the logits at position ``t``
    are scored against the label at ``t + 1``) and both skip positions whose
    shifted label is ``-100``. When no label is ``-100`` the KL term equals the
    plain per-token mean over all positions.

    Args:
        student_logits: Student logits; the second-to-last axis is sliced as the
            sequence axis and the last axis is treated as the vocabulary.
        teacher_logits: Teacher logits, same shape as ``student_logits``.
        labels: Target token ids aligned with the sequence axis of the logits.
            Use ``-100`` to exclude a position (e.g. padding) from both terms.
        alpha: Weight of the cross-entropy (task) term.
        beta: Weight of the KL-divergence (logits) term.
        temperature: Softmax temperature applied to both logit sets for the KL
            term; the term is multiplied by ``temperature ** 2``.

    Returns:
        ``(total_loss, parts)`` where ``total_loss`` is a scalar tensor and
        ``parts`` holds Python floats under ``'total'``, ``'task'``, ``'logits'``.
    """
    vocab_size = student_logits.size(-1)
    shift_labels = labels[..., 1:].contiguous()
    shift_student = student_logits[..., :-1, :].contiguous()
    shift_teacher = teacher_logits[..., :-1, :].contiguous()

    # Hard-label term: standard causal-LM cross-entropy.
    loss_task = F.cross_entropy(shift_student.view(-1, vocab_size), shift_labels.view(-1), ignore_index=-100)

    # Soft-label term: KL(teacher || student) per token at the given temperature.
    student_soft = F.log_softmax(shift_student / temperature, dim=-1)
    teacher_soft = F.softmax(shift_teacher / temperature, dim=-1)
    # Reduce over the vocabulary in float32 so the per-token values and the
    # token count below stay exact even when the logits are bfloat16.
    token_kl = F.kl_div(student_soft, teacher_soft, reduction='none').sum(dim=-1, dtype=torch.float32)

    # Apply the same ignore mask as the cross-entropy term; otherwise ignored
    # (padding) positions would dominate the average on padded batches.
    valid = (shift_labels != -100).to(token_kl.dtype)
    # clamp() keeps the division finite when every position is ignored.
    loss_logits = (token_kl * valid).sum() / valid.sum().clamp(min=1.0) * (temperature ** 2)

    total_loss = alpha * loss_task + beta * loss_logits
    return total_loss, {'total': total_loss.item(), 'task': loss_task.item(), 'logits': loss_logits.item()}

## Section 5: Training Loop

In [None]:
def _apply_accumulated_gradients(model, optimizer, max_norm=1.0):
    """Clip the accumulated gradients to ``max_norm``, take one optimizer step and reset the gradients."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    optimizer.zero_grad()


def train_student_logits_only(student_model, teacher_model, dataloader, alpha=0.5, beta=0.5, temperature=2.0, epochs=3, learning_rate=4e-5, experiment_name="experiment", accumulation_steps=4):
    """Distill ``teacher_model`` into ``student_model`` using output logits only.

    Every batch is run through both models (the teacher under
    ``torch.no_grad()``) and scored with ``compute_logits_only_loss``.
    Gradients are accumulated over ``accumulation_steps`` batches before each
    clipped AdamW update. Batches are moved to the module-level ``device``.

    Args:
        student_model: Model to train; updated in place and returned.
        teacher_model: Reference model; put in eval mode and not optimized here.
        dataloader: Iterable of ``(input_ids, attention_mask)`` tensor pairs.
        alpha: Weight of the cross-entropy term of the loss.
        beta: Weight of the KL-divergence term of the loss.
        temperature: Distillation temperature passed to the loss.
        epochs: Number of passes over ``dataloader``.
        learning_rate: AdamW learning rate.
        experiment_name: Label printed in the start banner.
        accumulation_steps: Number of batches per optimizer update.

    Returns:
        ``(student_model, loss_history)``. ``loss_history`` maps ``'total'``,
        ``'task'`` and ``'logits'`` to one mean value per epoch and also holds
        ``'epoch_times_seconds'`` and ``'total_time_seconds'``.
    """
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=learning_rate)
    student_model.train()
    teacher_model.eval()
    loss_keys = ('total', 'task', 'logits')
    loss_history = {k: [] for k in loss_keys}
    epoch_times = []
    print(f"\n{'='*60}\nStarting: {experiment_name}\nEpochs: {epochs}, LR: {learning_rate}, α={alpha}, β={beta}, T={temperature}\nHidden states: DISABLED\n{'='*60}")
    total_start = time.time()
    for epoch in range(epochs):
        epoch_start = time.time()
        epoch_losses = {k: [] for k in loss_keys}
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        accum_losses, accum_count = dict.fromkeys(loss_keys, 0.0), 0
        for input_ids, attention_mask in progress:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            # Padding positions (attention_mask == 0) must not count as targets;
            # -100 is the ignore index used by the loss function.
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            student_out = student_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=False)
            with torch.no_grad():
                teacher_out = teacher_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=False)
            loss, loss_dict = compute_logits_only_loss(student_out.logits, teacher_out.logits, labels, alpha, beta, temperature)
            # Scale so the accumulated gradient is the mean over the group.
            (loss / accumulation_steps).backward()
            for k in loss_keys:
                accum_losses[k] += loss_dict[k]
            accum_count += 1
            # accum_count resets after every update, so this fires on every
            # `accumulation_steps`-th batch.
            if accum_count == accumulation_steps:
                _apply_accumulated_gradients(student_model, optimizer)
                avg = {k: v / accum_count for k, v in accum_losses.items()}
                for k in loss_keys:
                    epoch_losses[k].append(avg[k])
                progress.set_postfix({'loss': f"{avg['total']:.4f}"})
                accum_losses, accum_count = dict.fromkeys(loss_keys, 0.0), 0
        if accum_count > 0:
            # Trailing partial group: apply its gradients AND record its losses,
            # so short epochs still contribute to the history.
            _apply_accumulated_gradients(student_model, optimizer)
            for k in loss_keys:
                epoch_losses[k].append(accum_losses[k] / accum_count)
        epoch_time = time.time() - epoch_start
        epoch_times.append(epoch_time)
        if epoch_losses['total']:
            for k in loss_keys:
                loss_history[k].append(np.mean(epoch_losses[k]))
            print(f"Epoch {epoch+1}: Total={loss_history['total'][-1]:.4f}, Task={loss_history['task'][-1]:.4f}, Logits={loss_history['logits'][-1]:.4f} [{epoch_time:.1f}s]")
        else:
            # Empty dataloader: nothing to average or report for this epoch.
            print(f"Epoch {epoch+1}: no batches processed [{epoch_time:.1f}s]")
    total_time = time.time() - total_start
    loss_history['epoch_times_seconds'] = epoch_times
    loss_history['total_time_seconds'] = total_time
    print(f"\n✓ Training completed in {total_time:.1f}s ({total_time/60:.1f} min)")
    return student_model, loss_history

## Section 6: Train with Logits Only

In [None]:
# Train a copy so the untrained pruned student stays available as
# `student_model`. The function returns the model it was given, so
# `student_trained` and `student_logits_only` refer to the same object.
student_logits_only = deepcopy(student_model)
distillation_settings = dict(
    alpha=0.5,
    beta=0.5,
    temperature=2.0,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    experiment_name="Logits-Only KLD Training",
)
student_trained, history = train_student_logits_only(
    student_model=student_logits_only,
    teacher_model=teacher_model,
    dataloader=train_dataloader,
    **distillation_settings,
)

In [None]:
print("Evaluating Trained Student...")
trained_metrics = evaluate_metrics(student_trained, eval_dataloader, device=device)
trained_ppl, trained_loss = trained_metrics['perplexity'], trained_metrics['loss']
print(f"Trained PPL: {trained_ppl:.2f}")

# Recovery = share of the pruning-induced perplexity increase that training
# won back. Reported as 0 when pruning did not raise perplexity.
degradation = student_ppl - teacher_ppl
recovered = student_ppl - trained_ppl
if degradation > 0:
    recovery_pct = (recovered / degradation) * 100
else:
    recovery_pct = 0
print(f"Recovery: {recovery_pct:.1f}%")

if RUN_FULL_BENCHMARKS:
    print("Benchmarking Trained Student...")
    benchmark_results['logits_only_trained'] = model_evaluation(
        model_obj=student_trained,
        tokenizer=tokenizer,
        tasks=BENCHMARK_TASKS,
        device=device,
        limit=BENCHMARK_LIMIT,
        batch_size=BATCH_EVAL,
    )

## Section 7: Results Visualization

In [None]:
# Perplexity bar chart: one bar per model variant, plus a dashed reference
# line at the teacher's perplexity.
ppl_series = [
    ('Teacher\n(Baseline)', teacher_ppl, '#2ecc71'),
    ('Pruned\n(No KD)', student_ppl, '#e74c3c'),
    ('Trained\n(Logits KD)', trained_ppl, '#3498db'),
]
models, ppls, colors = (list(column) for column in zip(*ppl_series))

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(models, ppls, color=colors, edgecolor='black', linewidth=1.5)
ax.set_ylabel('Perplexity (↓ lower is better)', fontsize=12)
ax.set_title('Perplexity Comparison: Teacher vs Pruned vs Trained', fontsize=14, fontweight='bold')

# Print each value just above its bar, centred horizontally.
for bar, ppl in zip(bars, ppls):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 0.5
    ax.text(label_x, label_y, f'{ppl:.2f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.axhline(
    y=teacher_ppl,
    color='#2ecc71',
    linestyle='--',
    alpha=0.7,
    label=f'Teacher baseline: {teacher_ppl:.2f}',
)
ax.legend()
plt.tight_layout()
plt.savefig('ppl_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"\n✓ Recovery: {recovery_pct:.1f}% of degradation recovered")

In [None]:
# Benchmark bar chart: grouped bars (teacher / pruned / trained) per task,
# followed by the average-accuracy recovery summary.
if RUN_FULL_BENCHMARKS and benchmark_results:
    tasks = BENCHMARK_TASKS
    x = np.arange(len(tasks))
    width = 0.25

    def _accuracy_percentages(model_key):
        # Missing tasks or a missing 'acc' entry count as 0.
        return [benchmark_results[model_key].get(task, {}).get('acc', 0) * 100 for task in tasks]

    teacher_scores = _accuracy_percentages('teacher')
    pruned_scores = _accuracy_percentages('student_pruned')
    trained_scores = _accuracy_percentages('logits_only_trained')

    fig, ax = plt.subplots(figsize=(14, 7))
    # (bar positions, scores, legend label, colour) for each model variant.
    bar_specs = [
        (x - width, teacher_scores, 'Teacher (Baseline)', '#2ecc71'),
        (x, pruned_scores, 'Pruned (No KD)', '#e74c3c'),
        (x + width, trained_scores, 'Trained (Logits KD)', '#3498db'),
    ]
    bars1, bars2, bars3 = [
        ax.bar(positions, scores, width, label=label, color=color, edgecolor='black')
        for positions, scores, label, color in bar_specs
    ]

    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('Benchmark Comparison: Teacher vs Pruned vs Trained', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([t.replace('_', '\n') for t in tasks], fontsize=10)
    ax.legend(loc='upper right')
    ax.set_ylim(0, 100)
    ax.grid(axis='y', alpha=0.3)

    def add_labels(bars):
        """Write each bar's height just above it, centred on the bar."""
        for rect in bars:
            height = rect.get_height()
            center_x = rect.get_x() + rect.get_width()/2
            ax.text(center_x, height + 1, f'{height:.1f}', ha='center', va='bottom', fontsize=8)

    for bar_group in (bars1, bars2, bars3):
        add_labels(bar_group)

    plt.tight_layout()
    plt.savefig('benchmark_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Average recovery: share of the mean-accuracy drop caused by pruning that
    # training won back (0 when pruning did not lower the average).
    avg_teacher, avg_pruned, avg_trained = (
        np.mean(scores) for scores in (teacher_scores, pruned_scores, trained_scores)
    )
    bench_degradation = avg_teacher - avg_pruned
    bench_recovered = avg_trained - avg_pruned
    if bench_degradation > 0:
        bench_recovery_pct = (bench_recovered / bench_degradation) * 100
    else:
        bench_recovery_pct = 0
    print(f"\nBenchmark Average: Teacher={avg_teacher:.1f}%, Pruned={avg_pruned:.1f}%, Trained={avg_trained:.1f}%")
    print(f"Benchmark Recovery: {bench_recovery_pct:.1f}%")

In [None]:
# Training loss curves: one panel per loss component, one point per epoch.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
loss_panels = [
    ('total', 'Total Loss'),
    ('task', 'Task Loss (CE)'),
    ('logits', 'Logits Loss (KLD)'),
]
for panel, (key, title) in zip(axes, loss_panels):
    panel.plot(history[key], marker='o', linewidth=2, markersize=8)
    panel.set_title(title, fontsize=12)
    panel.set_xlabel('Epoch')
    panel.set_ylabel('Loss')
    panel.grid(True, alpha=0.3)

plt.suptitle('Training Progress: Logits-Only Knowledge Distillation', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## Section 7.5: Save Experiment Results

In [None]:
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in Python types."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        NumPy integers become ``int``, NumPy floats become ``float`` and
        arrays become nested lists; anything else is deferred to the base
        class, which raises ``TypeError``.
        """
        converters = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda array: array.tolist()),
        )
        for numpy_type, convert in converters:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)

# Assemble the experiment summary section by section; the key order below is
# the order in which the sections appear in the JSON file.
experiment_metadata = {
    "experiment_name": "Logits-Only KD",
    "timestamp": datetime.now().isoformat(),
    "torch_version": torch.__version__,
}
model_summary = {
    "teacher": MODEL_NAME,
    "student": f"Gemma-270m-Pruned ({n_student_layers} layers)",
}
training_config = {
    "alpha": 0.5,
    "beta": 0.5,
    "temperature": 2.0,
    "epochs": EPOCHS,
    "learning_rate": LEARNING_RATE,
    "batch_size": BATCH_SIZE,
    "samples": RECOVERY_SAMPLES,
}
perplexity_results = {
    "teacher": {"perplexity": teacher_ppl, "loss": teacher_loss},
    "student_pruned": {"perplexity": student_ppl, "loss": student_loss},
    "logits_only_trained": {
        "perplexity": trained_ppl,
        "loss": trained_loss,
        "recovery_pct": recovery_pct,
        "training_time": history.get('total_time_seconds', 0),
    },
}
results_data = {
    "metadata": experiment_metadata,
    "models": model_summary,
    "training_config": training_config,
    "results": perplexity_results,
}

# Attach benchmark scores to every model entry that has them.
if RUN_FULL_BENCHMARKS:
    for model_key, model_entry in results_data['results'].items():
        if model_key in benchmark_results:
            model_entry['benchmarks'] = benchmark_results[model_key]

# Primary copy on the mounted Drive (the folder is created if missing).
json_path = "/content/drive/MyDrive/ch06nb01/CH06_NB02_Logits_KLD_results.json"
os.makedirs(os.path.dirname(json_path), exist_ok=True)
with open(json_path, 'w') as f:
    json.dump(results_data, f, indent=2, cls=NumpyEncoder)
print(f"✓ Results saved to: {json_path}")

# Second copy in the current working directory.
with open("CH06_NB02_Logits_KLD_results.json", 'w') as f:
    json.dump(results_data, f, indent=2, cls=NumpyEncoder)
print("✓ Local backup saved")

## Section 8: Save Model to Hugging Face

In [None]:
# Authenticate with the Hugging Face Hub before the push_to_hub calls in the
# next cell. login() is called without arguments, so it presumably prompts
# for an access token interactively (confirm against huggingface_hub docs).
from huggingface_hub import login, HfApi
login()

In [None]:
print(f"Saving model to HuggingFace as '{HF_MODEL_NAME}'...")
student_trained.push_to_hub(HF_MODEL_NAME, commit_message="Depth-pruned Gemma-3-270m with Logits-Only KD")
# Derive the commit message from HF_MODEL_NAME so it cannot drift from the
# repository that is actually pushed to.
tokenizer.push_to_hub(HF_MODEL_NAME, commit_message=f"Tokenizer for {HF_MODEL_NAME}")

# A bare repository name is created under the logged-in account, so the public
# URL needs that account as a prefix unless HF_MODEL_NAME already has one.
if "/" in HF_MODEL_NAME:
    hub_repo_id = HF_MODEL_NAME
else:
    hub_repo_id = f"{HfApi().whoami()['name']}/{HF_MODEL_NAME}"
print(f"✓ Model saved: https://huggingface.co/{hub_repo_id}")

In [None]:
# `student_trained` is the very object `student_logits_only` points to (the
# training function returns the model it was handed), so both names have to be
# dropped before the cache clear can actually release the student's memory.
del student_logits_only, student_trained
clear_gpu_cache()
print(f"\n{'='*60}\nTRAINING COMPLETE\n{'='*60}")

# A bare repository name lives under the logged-in account; add that prefix so
# the printed link points at the repository that was pushed.
if "/" in HF_MODEL_NAME:
    final_repo_id = HF_MODEL_NAME
else:
    final_repo_id = f"{HfApi().whoami()['name']}/{HF_MODEL_NAME}"
print(f"Model: https://huggingface.co/{final_repo_id}")
print(f"Results: {json_path}")