<a href="https://colab.research.google.com/github/peremartra/Tailoring-LLM-Architectures/blob/main/CH06/CH06_NB02_Logits_KLD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Tailoring LLM Architectures**
## **Chapter 6: Knowledge Distillation with Logits Only**
### Training a pruned model using only Logits and Skew KL Divergence

by [Pere Martra](https://github.com/peremartra)

---

**Hardware Environment:** NVIDIA A100 GPU
- **Model:** google/gemma-3-270m (Teacher: 18 transformer blocks) → Pruned Student (14 transformer blocks)
- **Dataset:** Cosmopedia (40,000 samples, 3 epochs)

---

**What we'll accomplish:**
- Train a depth-pruned student model using **only logits and Skew KL Divergence**
- No hidden state alignment (simpler, faster training)
- Uses the same training framework as advanced KD but configured for logits-only
- Save the trained model to Hugging Face as **gem-3-small**

## Section 0: Environment & Dependencies

In [None]:
# Mount Google Drive; the results JSON is later written under /content/drive.
# This import only resolves inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
print("✓ Drive mounted")

In [None]:
# --- Experiment configuration -------------------------------------------
# Total number of Cosmopedia documents to load (split across the subsets).
RECOVERY_SAMPLES = 40000
# Number of passes over the training split during distillation.
EPOCHS=3
# AdamW learning rate passed to the training function.
LEARNING_RATE=4e-5
# Per-step batch size for the train/validation DataLoaders.
BATCH_SIZE = 16
# Batch size forwarded to the benchmark helper (model_evaluation).
BATCH_EVAL="auto"
# When False, every benchmark cell is skipped.
RUN_FULL_BENCHMARKS = True
# Forwarded as `limit` to model_evaluation; None here.
BENCHMARK_LIMIT = None
# Task names forwarded to model_evaluation.
BENCHMARK_TASKS = ["arc_easy", "winogrande", "hellaswag", "lambada_openai", "piqa"]
# Repository name used when pushing the trained student to the Hugging Face Hub.
HF_MODEL_NAME = "gem-3-small"

In [None]:
# Notebook shell magics: install the third-party packages imported below.
!pip install -q transformers accelerate datasets
!pip install -q optipfair matplotlib seaborn tqdm
!pip install -q lm_eval langdetect codecarbon huggingface_hub

In [None]:
import torch, gc
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm.auto import tqdm
from copy import deepcopy
import warnings, time, json, os
from datetime import datetime
warnings.filterwarnings('ignore')

# Report the runtime and pick the compute device. `device` is a module-level
# global that the training and evaluation cells read directly.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
def set_seed(seed=42):
    """Seed every random number generator the notebook may draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU plus all CUDA devices) so repeated runs start from the same
    random state.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the notebook's import cell does not bring `random` in.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without a GPU; the call is ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
set_seed(42)
print("✓ Random seed set to 42")

In [None]:
# Download the shared helper module into the working directory, then import
# the evaluation helpers used throughout the notebook.
!wget -q https://raw.githubusercontent.com/peremartra/Rearchitecting-LLMs/main/utils.py
from utils import evaluate_metrics, clear_gpu_cache, model_evaluation
print("✓ utils.py loaded")

## Section 1: Load Teacher Model

In [None]:
# Load the teacher checkpoint: bfloat16 with automatic device placement when
# CUDA is available, float32 with default placement otherwise.
MODEL_NAME = "google/gemma-3-270m"
print(f"Loading Teacher model: {MODEL_NAME}")
teacher_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None
)
# The teacher is frozen: eval mode and no gradients on any parameter.
teacher_model.eval()
for param in teacher_model.parameters():
    param.requires_grad = False

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

n_teacher_layers = len(teacher_model.model.layers)
# NOTE(review): hidden_dim is not read anywhere else in the visible notebook.
hidden_dim = teacher_model.config.hidden_size
print(f"Teacher: {n_teacher_layers} layers, {teacher_model.num_parameters():,} params")

## Section 2: Prepare Training Dataset

In [None]:
# Maximum token length used when tokenizing the documents.
MAX_LENGTH = 512
print("Loading Cosmopedia dataset...")
dataset_name = "HuggingFaceTB/cosmopedia"
subsets = ["stories", "wikihow", "openstax", "web_samples_v1"]
# Split the sample budget evenly across however many subsets are listed,
# so editing `subsets` keeps the total close to RECOVERY_SAMPLES.
samples_per_subset = RECOVERY_SAMPLES // len(subsets)

all_samples = []
for subset in subsets:
    print(f"  Loading {subset}...")
    # Stream the subset and keep only its first `samples_per_subset` records.
    subset_data = load_dataset(dataset_name, subset, split="train", streaming=True)
    subset_samples = list(subset_data.take(samples_per_subset))
    all_samples.extend(subset_samples)
    print(f"    ✓ {len(subset_samples):,} samples")

# Only the raw text column is kept for distillation.
distillation_dataset = Dataset.from_dict({'text': [s['text'] for s in all_samples]})
print(f"✓ Total samples: {len(distillation_dataset):,}")

In [None]:
# Tokenize the corpus in chunks of 1000 documents. Every sequence is truncated
# or padded to exactly MAX_LENGTH tokens so the chunks can be concatenated.
print("Tokenizing...")
texts = [item['text'] for item in distillation_dataset]
tokenized_data = []
for i in tqdm(range(0, len(texts), 1000), desc="Tokenizing"):
    batch = tokenizer(texts[i:i+1000], truncation=True, padding="max_length", max_length=MAX_LENGTH, return_tensors="pt")
    tokenized_data.append(batch)

# Stack all chunks into two tensors and wrap them as (input_ids, attention_mask) pairs.
input_ids = torch.cat([b['input_ids'] for b in tokenized_data], dim=0)
attention_mask = torch.cat([b['attention_mask'] for b in tokenized_data], dim=0)
full_dataset = TensorDataset(input_ids, attention_mask)

# 80/20 train/validation split, made reproducible with a dedicated seeded generator.
generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

# Both loaders yield (input_ids, attention_mask) tuples; only training is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader_raw = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

class DictDataLoader:
    """Adapter that re-exposes a tuple-yielding loader as dict batches.

    The wrapped iterable must yield ``(input_ids, attention_mask)`` pairs;
    each pair is re-emitted as a dict with those two keys.
    """

    def __init__(self, dl):
        # Keep a reference to the wrapped loader; nothing is copied.
        self.dataloader = dl

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.dataloader)

    def __iter__(self):
        """Lazily yield one dict per batch of the wrapped loader."""
        for batch in self.dataloader:
            ids, mask = batch
            yield dict(input_ids=ids, attention_mask=mask)

# Evaluation helpers receive dict batches, so wrap the tuple-based validation loader.
eval_dataloader = DictDataLoader(eval_dataloader_raw)
print(f"✓ Train: {len(train_dataset):,}, Val: {len(val_dataset):,}")

## Section 3: Create Pruned Student Model

In [None]:
import optipfair as opf
# The student starts as a full copy of the teacher, so the teacher stays untouched.
student_model = deepcopy(teacher_model)
# NOTE(review): presumably returns a {layer_index: score} mapping — confirm
# against the optipfair docs, and that it accepts tuple batches from train_dataloader.
importance_scores = opf.analyze_layer_importance(student_model, train_dataloader, show_progress=True)
# Select the 4 layer keys with the lowest scores for removal.
LAYERS_TO_REMOVE = sorted(importance_scores.keys(), key=lambda x: importance_scores[x])[:4]
print(f"Layers to remove: {LAYERS_TO_REMOVE}")

In [None]:
# Remove the selected transformer blocks from the student copy.
student_model = opf.prune_model_depth(model=student_model, layer_indices=LAYERS_TO_REMOVE, show_progress=True)
# The copy inherited the teacher's frozen parameters; re-enable gradients for training.
for param in student_model.parameters():
    param.requires_grad = True
n_student_layers = len(student_model.model.layers)
print(f"✓ Student: {n_student_layers} layers, {student_model.num_parameters():,} params")

In [None]:
# Baseline metrics on the validation split, measured before any distillation.
print("Evaluating Teacher...")
teacher_metrics = evaluate_metrics(teacher_model, eval_dataloader, device=device)
teacher_ppl = teacher_metrics['perplexity']
teacher_loss = teacher_metrics['loss']

print("Evaluating Pruned Student...")
# Evaluate a throwaway copy of the pruned student, then free it.
# NOTE(review): presumably guards against evaluate_metrics mutating the model — confirm.
student_pruned_copy = deepcopy(student_model)
student_metrics = evaluate_metrics(student_pruned_copy, eval_dataloader, device=device)
student_ppl = student_metrics['perplexity']
student_loss = student_metrics['loss']
del student_pruned_copy
clear_gpu_cache()

print(f"Teacher PPL: {teacher_ppl:.2f}, Student PPL: {student_ppl:.2f}")

In [None]:
# Benchmark results keyed by model variant; stays empty when benchmarks are disabled.
benchmark_results = {}
if RUN_FULL_BENCHMARKS:
    print("Benchmarking Teacher...")
    benchmark_results['teacher'] = model_evaluation(model_obj=teacher_model, tokenizer=tokenizer, tasks=BENCHMARK_TASKS, device=device, limit=BENCHMARK_LIMIT, batch_size=BATCH_EVAL)
    print("Benchmarking Pruned Student...")
    # Benchmark a disposable copy of the pruned student and free it afterwards.
    student_pruned_copy = deepcopy(student_model)
    benchmark_results['student_pruned'] = model_evaluation(model_obj=student_pruned_copy, tokenizer=tokenizer, tasks=BENCHMARK_TASKS, device=device, limit=BENCHMARK_LIMIT, batch_size=BATCH_EVAL)
    del student_pruned_copy
    clear_gpu_cache()

## Section 4: Shared Training Functions (from NB01)

These functions support both logits-only and advanced KD with hidden states. For logits-only training, we set `gamma=0.0` and `delta=0.0` to disable trajectory and FDD losses.

In [None]:
def train_student_advanced(
    student_model,
    teacher_model,
    dataloader,
    layer_map=None,      # Not used for logits-only, but kept for compatibility
    # Loss weights
    alpha=0.1,           # Task loss
    beta=0.8,            # Skew KLD (logits)
    gamma=0.05,          # Trajectory loss (hidden states)
    delta=0.05,          # FDD derivative loss
    temperature=2.0,
    skew_alpha=0.5,      # Skew interpolation factor
    # Training params
    epochs=3,
    learning_rate=4e-5,
    experiment_name="experiment",
    accumulation_steps=4
):
    """
    Train the student against a frozen teacher with the compound KD loss.

    Each batch is run through both models, the loss is computed by
    ``compute_compound_loss_advanced`` and gradients are accumulated over
    ``accumulation_steps`` batches before a clipped AdamW step. Setting
    ``gamma=0.0`` and ``delta=0.0`` disables the hidden-state losses, and
    hidden states are then not requested from either model.

    Padded positions (``attention_mask == 0``) are labelled ``-100`` so they
    are ignored by the cross-entropy task loss.

    Relies on the module-level ``device`` for tensor placement.

    Args:
        student_model: Trainable causal LM; it is updated in place.
        teacher_model: Frozen causal LM providing the target logits.
        dataloader: Yields ``(input_ids, attention_mask)`` tuples and exposes
            ``batch_size``.
        layer_map: Teacher layer index for each student layer. Required when
            ``gamma > 0`` or ``delta > 0``.
        alpha, beta, gamma, delta: Weights of the task, logits, trajectory and
            derivative losses.
        temperature: Softmax temperature for the logits loss.
        skew_alpha: Interpolation factor forwarded to the logits loss.
        epochs: Number of passes over ``dataloader``.
        learning_rate: AdamW learning rate.
        experiment_name: Label used in the console output.
        accumulation_steps: Batches accumulated per optimizer step.

    Returns:
        ``(student_model, loss_history)``. ``loss_history`` holds one value per
        epoch for the keys ``total``, ``task``, ``logits``, ``trajectory`` and
        ``derivative``, plus ``epoch_times_seconds`` and ``total_time_seconds``.

    Raises:
        ValueError: If hidden-state losses are enabled without a ``layer_map``.
    """
    # Decide if we need hidden states (required if using trajectory OR derivative loss)
    request_hidden_states = (gamma > 0 or delta > 0)

    # Fail fast instead of crashing inside the first batch after two forward passes.
    if request_hidden_states and layer_map is None:
        raise ValueError(
            "layer_map is required when gamma > 0 or delta > 0 "
            "(hidden-state losses are enabled)."
        )

    optimizer = torch.optim.AdamW(student_model.parameters(), lr=learning_rate)

    student_model.train()
    teacher_model.eval()

    loss_history = {
        'total': [], 'task': [], 'logits': [],
        'trajectory': [], 'derivative': []
    }
    epoch_times = []

    print(f"\n{'='*60}")
    print(f"Starting Training: {experiment_name}")
    print(f"{'='*60}")
    print(f"Epochs: {epochs}")
    print(f"Learning rate: {learning_rate}")
    print(f"Loss weights: α={alpha}, β={beta}, γ={gamma}, δ={delta}")
    print(f"Temperature: {temperature}, Skew α: {skew_alpha}")
    print(f"Hidden states computation: {'ENABLED' if request_hidden_states else 'DISABLED'}")
    print(f"Gradient Accumulation Steps: {accumulation_steps}")
    print(f"Effective Batch Size: {dataloader.batch_size * accumulation_steps}")
    print(f"{'='*60}\n")

    total_start_time = time.time()

    # Start from clean gradients in case the model carries stale ones.
    optimizer.zero_grad()

    for epoch in range(epochs):
        epoch_start_time = time.time()

        epoch_losses = {k: [] for k in loss_history.keys()}
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        # Running sums for the current accumulation window (used for logging only).
        accumulated_losses = {k: 0.0 for k in loss_history.keys()}
        accumulation_counter = 0

        for batch_idx, (input_ids, attention_mask) in enumerate(progress_bar):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            # Next-token targets; padded positions are excluded from the loss.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100

            # Student forward pass
            student_outputs = student_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=request_hidden_states
            )

            # Teacher forward pass (no gradients)
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=request_hidden_states
                )

            # Prepare hidden states (None if not needed); index 0 is skipped so
            # entry i corresponds to the output of transformer block i.
            student_hiddens = (student_outputs.hidden_states[1:]
                             if request_hidden_states else None)
            teacher_hiddens = (teacher_outputs.hidden_states[1:]
                             if request_hidden_states else None)

            # Compute advanced compound loss
            loss, loss_dict = compute_compound_loss_advanced(
                student_logits=student_outputs.logits,
                teacher_logits=teacher_outputs.logits,
                student_hiddens=student_hiddens,
                teacher_hiddens=teacher_hiddens,
                labels=labels,
                layer_map=layer_map,
                alpha=alpha,
                beta=beta,
                gamma=gamma,
                delta=delta,
                temperature=temperature,
                skew_alpha=skew_alpha
            )

            # Gradient accumulation: scale so the summed gradient is an average.
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()

            for key in accumulated_losses:
                accumulated_losses[key] += loss_dict[key]
            accumulation_counter += 1

            # Optimizer step once a full accumulation window has been processed
            if (batch_idx + 1) % accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()

                avg_losses = {k: v / accumulation_counter for k, v in accumulated_losses.items()}
                for key in avg_losses:
                    epoch_losses[key].append(avg_losses[key])

                progress_bar.set_postfix({
                    'loss': f"{avg_losses['total']:.4f}",
                    'task': f"{avg_losses['task']:.4f}",
                    'logits': f"{avg_losses['logits']:.4f}",
                    'traj': f"{avg_losses['trajectory']:.4f}",
                    'deriv': f"{avg_losses['derivative']:.4f}"
                })

                accumulated_losses = {k: 0.0 for k in loss_history.keys()}
                accumulation_counter = 0

        # Handle remaining batches (a trailing, incomplete accumulation window)
        if accumulation_counter > 0:
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

            avg_losses = {k: v / accumulation_counter for k, v in accumulated_losses.items()}
            for key in avg_losses:
                epoch_losses[key].append(avg_losses[key])

        # Record epoch averages
        for key in epoch_losses:
            if epoch_losses[key]:
                loss_history[key].append(np.mean(epoch_losses[key]))

        epoch_time = time.time() - epoch_start_time
        epoch_times.append(epoch_time)

        print(f"Epoch {epoch+1} avg losses - "
              f"Total: {loss_history['total'][-1]:.4f}, "
              f"Task: {loss_history['task'][-1]:.4f}, "
              f"Logits: {loss_history['logits'][-1]:.4f}, "
              f"Traj: {loss_history['trajectory'][-1]:.4f}, "
              f"Deriv: {loss_history['derivative'][-1]:.4f} "
              f"[{epoch_time:.1f}s]")

    total_time = time.time() - total_start_time
    loss_history['epoch_times_seconds'] = epoch_times
    loss_history['total_time_seconds'] = total_time

    print(f"\n✓ Training completed: {experiment_name}")
    print(f"  Total time: {total_time:.1f}s ({total_time/60:.1f} min)")
    print(f"  Avg time per epoch: {np.mean(epoch_times):.1f}s")

    return student_model, loss_history


def compute_compound_loss_advanced(
    student_logits,      # [batch, seq_len, vocab_size]
    teacher_logits,      # [batch, seq_len, vocab_size]
    student_hiddens,     # List of [batch, seq_len, hidden_dim] or None
    teacher_hiddens,     # List of [batch, seq_len, hidden_dim] or None
    labels,              # [batch, seq_len]
    layer_map,           # List of teacher indices for each student Transformer Block
    alpha=0.1,           # weight for task loss
    beta=0.8,            # weight for logits loss (Skew KLD)
    gamma=0.1,           # weight for hidden trajectory loss
    delta=0.1,           # weight for FDD derivative loss
    temperature=2.0,     # temperature for soft labels
    skew_alpha=0.1       # share of the student's own prediction mixed into the KLD target
):
    """
    Compound distillation loss: task CE + skew KLD + trajectory + FDD derivative.

    1. Task loss: next-token cross-entropy against ``labels``; positions
       labelled ``-100`` are ignored.
    2. Skew KLD: ``KL(target || student) * temperature**2`` on the
       temperature-softened distributions, where the detached target is
       ``skew_alpha * student + (1 - skew_alpha) * teacher``.
       - ``skew_alpha = 0``: the target is the teacher, i.e. standard forward KLD.
       - ``skew_alpha = 1``: the target equals the student's own prediction, so
         this term is zero and gives no training signal.
       Positions whose next-token label is ``-100`` are excluded from the average.
    3. Trajectory loss: ``1 - cosine similarity`` between each student hidden
       state and the teacher hidden state selected by ``layer_map``.
    4. FDD derivative loss: ``1 - cosine similarity`` between the change across
       consecutive student layers and the change across the mapped teacher layers.

    Terms 3 and 4 are computed only when their weight is positive and both
    hidden-state lists are provided; otherwise they are reported as zero.

    Returns:
        ``(total_loss, loss_dict)``: the weighted sum as a tensor that carries
        gradients, and a dict of Python floats with the keys ``total``,
        ``task``, ``logits``, ``trajectory`` and ``derivative``.

    Raises:
        ValueError: If a hidden-state loss is active but ``layer_map`` is None.

    Reference:
    - Skew KLD: DistiLLM-2 (2024)
    - FDD: Feature Dynamics Distillation, ACL 2025
    """
    device = student_logits.device

    has_hiddens = student_hiddens is not None and teacher_hiddens is not None
    if has_hiddens and (gamma > 0 or delta > 0) and layer_map is None:
        raise ValueError(
            "layer_map is required when gamma > 0 or delta > 0 "
            "and hidden states are provided."
        )

    # =========================================================================
    # 1. TASK LOSS (Cross-Entropy with hard labels)
    # =========================================================================
    # Position t predicts token t+1, so drop the last logit and the first label.
    shift_logits = student_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss_task = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100
    )

    # =========================================================================
    # 2. SKEW KLD (Interpolated Divergence for Logits)
    # =========================================================================
    # The student log-probabilities must stay on the autograd graph: they are
    # the only path through which this term reaches the student's weights.
    student_log_probs = F.log_softmax(shift_logits / temperature, dim=-1)

    # The target is a constant. Weighting the KL by the (detached) student
    # distribution instead would give an identically-zero gradient, because
    # sum_i p_i * d(log p_i) = d(sum_i p_i) = 0.
    with torch.no_grad():
        teacher_probs = F.softmax(teacher_logits[..., :-1, :] / temperature, dim=-1)
        if skew_alpha != 0:
            mixed_probs = (
                skew_alpha * student_log_probs.detach().exp()
                + (1 - skew_alpha) * teacher_probs
            )
        else:
            # Pure teacher target: skip materialising the student probabilities.
            mixed_probs = teacher_probs

    # KL(target || student) per token; F.kl_div treats 0 * log(0) as 0.
    token_kl = F.kl_div(student_log_probs, mixed_probs, reduction="none").sum(dim=-1)

    # Average only over positions with a real next-token label (not -100).
    valid_tokens = shift_labels.ne(-100)
    n_valid = valid_tokens.sum().clamp(min=1)
    loss_logits = (
        token_kl.masked_fill(~valid_tokens, 0.0).sum() / n_valid
    ) * (temperature ** 2)

    # =========================================================================
    # 3. TRAJECTORY LOSS (Cosine Similarity of Hidden States)
    # =========================================================================
    # Only computed if gamma > 0 and we have hidden states
    if gamma > 0 and has_hiddens:
        loss_trajectory = 0.0
        for student_idx, teacher_idx in enumerate(layer_map):
            student_h = student_hiddens[student_idx]
            teacher_h = teacher_hiddens[teacher_idx]

            # Flatten to [tokens, hidden] and L2-normalize each token vector
            student_flat = student_h.reshape(-1, student_h.size(-1))
            teacher_flat = teacher_h.reshape(-1, teacher_h.size(-1))
            student_norm = F.normalize(student_flat, p=2, dim=1)
            teacher_norm = F.normalize(teacher_flat, p=2, dim=1)

            # Cosine similarity loss: 1 - cos_sim (0 = perfect alignment)
            cos_sim = (student_norm * teacher_norm).sum(dim=1).mean()
            loss_trajectory += (1 - cos_sim)

        loss_trajectory = loss_trajectory / len(layer_map)
    else:
        loss_trajectory = torch.tensor(0.0, device=device)

    # =========================================================================
    # 4. FDD DERIVATIVE LOSS (Feature Dynamics Distillation)
    # =========================================================================
    # Compares the direction of change between consecutive student layers with
    # the change between the corresponding mapped teacher layers.
    # Only computed if delta > 0 and we have hidden states
    loss_derivative = torch.tensor(0.0, device=device)

    if delta > 0 and has_hiddens:
        num_derivatives = 0

        for student_idx in range(len(layer_map) - 1):
            teacher_idx = layer_map[student_idx]
            teacher_idx_next = layer_map[student_idx + 1]

            # Student delta: change between consecutive student Transformer Blocks
            student_delta = student_hiddens[student_idx + 1] - student_hiddens[student_idx]

            # Teacher delta: change between corresponding teacher Transformer Blocks
            teacher_delta = teacher_hiddens[teacher_idx_next] - teacher_hiddens[teacher_idx]

            # Flatten and normalize the deltas
            student_delta_flat = student_delta.reshape(-1, student_delta.size(-1))
            teacher_delta_flat = teacher_delta.reshape(-1, teacher_delta.size(-1))

            student_delta_norm = F.normalize(student_delta_flat, p=2, dim=1)
            teacher_delta_norm = F.normalize(teacher_delta_flat, p=2, dim=1)

            # Cosine similarity of derivatives
            cos_sim_deriv = (student_delta_norm * teacher_delta_norm).sum(dim=1).mean()
            loss_derivative += (1 - cos_sim_deriv)
            num_derivatives += 1

        if num_derivatives > 0:
            loss_derivative = loss_derivative / num_derivatives

    # =========================================================================
    # COMBINE ALL LOSSES
    # =========================================================================
    total_loss = (
        alpha * loss_task +
        beta * loss_logits +
        gamma * loss_trajectory +
        delta * loss_derivative
    )

    loss_dict = {
        'total': total_loss.item(),
        'task': loss_task.item(),
        'logits': loss_logits.item(),
        'trajectory': loss_trajectory.item(),
        'derivative': loss_derivative.item()
    }

    return total_loss, loss_dict

## Section 5: Train with Logits Only

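Using the shared training function configured for logits-only mode (`gamma=0`, `delta=0`, `skew_alpha=0`). Note that the hard-label task loss (cross-entropy) is still active: `alpha=0.5` and `beta=0.5` weight it equally with the logits KLD term.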

In [None]:
# Train a fresh copy so `student_model` keeps the untrained pruned weights.
# The returned `student_trained` is the same object as `student_logits_only`.
student_logits_only = deepcopy(student_model)
student_trained, history = train_student_advanced(
    student_model=student_logits_only, 
    teacher_model=teacher_model, 
    dataloader=train_dataloader,
    layer_map=None,          # Not used for logits-only
    # Loss weights - Logits-only configuration
    alpha=0.5,               # Task loss weight
    beta=0.5,                # KLD weight
    gamma=0.0,               # NO trajectory loss (disabled)
    delta=0.0,               # NO FDD loss (disabled)
    temperature=2.0,
    skew_alpha=0.0,          # Forward KLD (0.0 = standard KLD, equivalent to original)
    # Training params
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    experiment_name="Logits-Only with Skew KLD Training",
    accumulation_steps=4
)

In [None]:
# Measure the distilled student on the same validation split as the baselines.
print("Evaluating Trained Student...")
trained_metrics = evaluate_metrics(student_trained, eval_dataloader, device=device)
trained_ppl = trained_metrics['perplexity']
trained_loss = trained_metrics['loss']
print(f"Trained PPL: {trained_ppl:.2f}")

# Recovery = share of the pruning-induced perplexity increase undone by training.
# Reported as 0 when pruning did not increase perplexity (avoids division by zero).
degradation = student_ppl - teacher_ppl
recovered = student_ppl - trained_ppl
recovery_pct = (recovered / degradation) * 100 if degradation > 0 else 0
print(f"Recovery: {recovery_pct:.1f}%")

if RUN_FULL_BENCHMARKS:
    print("Benchmarking Trained Student...")
    benchmark_results['logits_only_trained'] = model_evaluation(model_obj=student_trained, tokenizer=tokenizer, tasks=BENCHMARK_TASKS, device=device, limit=BENCHMARK_LIMIT, batch_size=BATCH_EVAL)

## Section 6: Results Visualization

In [None]:
# Perplexity Bar Chart
# One bar per model variant, in the order teacher / pruned / trained.
models = ['Teacher\n(Baseline)', 'Pruned\n(No KD)', 'Trained\n(Logits KD)']
ppls = [teacher_ppl, student_ppl, trained_ppl]
colors = ['#2ecc71', '#e74c3c', '#3498db']

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(models, ppls, color=colors, edgecolor='black', linewidth=1.5)
ax.set_ylabel('Perplexity (↓ lower is better)', fontsize=12)
ax.set_title('Perplexity Comparison: Teacher vs Pruned vs Trained', fontsize=14, fontweight='bold')
# Write each perplexity value just above its bar.
for bar, ppl in zip(bars, ppls):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, f'{ppl:.2f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
# Dashed reference line at the teacher's perplexity.
ax.axhline(y=teacher_ppl, color='#2ecc71', linestyle='--', alpha=0.7, label=f'Teacher baseline: {teacher_ppl:.2f}')
ax.legend()
plt.tight_layout()
plt.savefig('ppl_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"\n✓ Recovery: {recovery_pct:.1f}% of degradation recovered")

In [None]:
# Benchmark Bar Chart
# Grouped bars (teacher / pruned / trained) for every benchmark task.
if RUN_FULL_BENCHMARKS and benchmark_results:
    tasks = BENCHMARK_TASKS
    x = np.arange(len(tasks))
    width = 0.25
    
    # Accuracy in percent; a task or 'acc' key missing from the results counts as 0.
    # NOTE(review): assumes model_evaluation returns {task: {'acc': float}} — confirm in utils.py.
    teacher_scores = [benchmark_results['teacher'].get(t, {}).get('acc', 0) * 100 for t in tasks]
    pruned_scores = [benchmark_results['student_pruned'].get(t, {}).get('acc', 0) * 100 for t in tasks]
    trained_scores = [benchmark_results['logits_only_trained'].get(t, {}).get('acc', 0) * 100 for t in tasks]
    
    fig, ax = plt.subplots(figsize=(14, 7))
    bars1 = ax.bar(x - width, teacher_scores, width, label='Teacher (Baseline)', color='#2ecc71', edgecolor='black')
    bars2 = ax.bar(x, pruned_scores, width, label='Pruned (No KD)', color='#e74c3c', edgecolor='black')
    bars3 = ax.bar(x + width, trained_scores, width, label='Trained (Logits KD)', color='#3498db', edgecolor='black')
    
    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('Benchmark Comparison: Teacher vs Pruned vs Trained', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    # Break task names at underscores so the tick labels wrap onto several lines.
    ax.set_xticklabels([t.replace('_', '\n') for t in tasks], fontsize=10)
    ax.legend(loc='upper right')
    ax.set_ylim(0, 100)
    ax.grid(axis='y', alpha=0.3)
    
    def add_labels(bars):
        # Annotate every bar in the group with its height (one decimal place).
        for bar in bars:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2, h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
    add_labels(bars1)
    add_labels(bars2)
    add_labels(bars3)
    
    plt.tight_layout()
    plt.savefig('benchmark_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Print average recovery
    # Same recovery definition as for perplexity, applied to the mean accuracy.
    avg_teacher = np.mean(teacher_scores)
    avg_pruned = np.mean(pruned_scores)
    avg_trained = np.mean(trained_scores)
    bench_degradation = avg_teacher - avg_pruned
    bench_recovered = avg_trained - avg_pruned
    bench_recovery_pct = (bench_recovered / bench_degradation) * 100 if bench_degradation > 0 else 0
    print(f"\nBenchmark Average: Teacher={avg_teacher:.1f}%, Pruned={avg_pruned:.1f}%, Trained={avg_trained:.1f}%")
    print(f"Benchmark Recovery: {bench_recovery_pct:.1f}%")

In [None]:
# Training Loss Curves - Extract only relevant losses (trajectory & derivative are 0)
# One subplot per loss component; each point is one epoch average from `history`.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
loss_keys = [
    ('total', 'Total Loss'), 
    ('task', 'Task Loss (CE)'), 
    ('logits', 'Logits Loss (Skew KLD)')
]
for idx, (key, title) in enumerate(loss_keys):
    axes[idx].plot(history[key], marker='o', linewidth=2, markersize=8)
    axes[idx].set_title(title, fontsize=12)
    axes[idx].set_xlabel('Epoch')
    axes[idx].set_ylabel('Loss')
    axes[idx].grid(True, alpha=0.3)
plt.suptitle('Training Progress: Logits-Only Knowledge Distillation', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## Section 6.5: Save Experiment Results

In [None]:
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values.

    Handles NumPy booleans, integers, floats and arrays; any other
    unsupported object is delegated to the base class, which raises
    ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``."""
        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# Collect run metadata, configuration and metrics into one JSON-serializable dict.
# NOTE(review): the loss weights, skew_alpha and temperature below are literals and
# must be kept in sync by hand with the arguments passed to train_student_advanced.
results_data = {
    "metadata": {"experiment_name": "Logits-Only KD with Skew KLD", "timestamp": datetime.now().isoformat(), "torch_version": torch.__version__},
    "models": {"teacher": MODEL_NAME, "student": f"Gemma-270m-Pruned ({n_student_layers} layers)"},
    "training_config": {"alpha": 0.5, "beta": 0.5, "gamma": 0.0, "delta": 0.0, "skew_alpha": 0.0, "temperature": 2.0, "epochs": EPOCHS, "learning_rate": LEARNING_RATE, "batch_size": BATCH_SIZE, "samples": RECOVERY_SAMPLES},
    "results": {
        "teacher": {"perplexity": teacher_ppl, "loss": teacher_loss},
        "student_pruned": {"perplexity": student_ppl, "loss": student_loss},
        "logits_only_trained": {"perplexity": trained_ppl, "loss": trained_loss, "recovery_pct": recovery_pct, "training_time": history.get('total_time_seconds', 0)}
    }
}
# Attach benchmark scores to the model entries that share the same key.
if RUN_FULL_BENCHMARKS:
    for k, v in benchmark_results.items():
        if k in results_data['results']: results_data['results'][k]['benchmarks'] = v

# Write the results to Google Drive (creating the folder if needed) ...
json_path = "/content/drive/MyDrive/ch06nb01/CH06_NB02_Logits_KLD_results.json"
os.makedirs(os.path.dirname(json_path), exist_ok=True)
with open(json_path, 'w') as f: json.dump(results_data, f, indent=2, cls=NumpyEncoder)
print(f"✓ Results saved to: {json_path}")
# ... and keep a second copy in the current working directory.
with open("CH06_NB02_Logits_KLD_results.json", 'w') as f: json.dump(results_data, f, indent=2, cls=NumpyEncoder)
print("✓ Local backup saved")

## Section 7: Save Model to Hugging Face

In [None]:
# Authenticate with the Hugging Face Hub before the upload cell below.
from huggingface_hub import login, HfApi
login()

In [None]:
# A bare repo name is created under the logged-in user's namespace, so resolve
# the fully-qualified id up front; otherwise the printed link would omit the
# namespace and not point at the uploaded repository.
hf_repo_id = HF_MODEL_NAME if "/" in HF_MODEL_NAME else f"{HfApi().whoami()['name']}/{HF_MODEL_NAME}"
print(f"Saving model to HuggingFace as '{HF_MODEL_NAME}'...")
# Upload the weights and the tokenizer to the same repository.
student_trained.push_to_hub(hf_repo_id, commit_message="Depth-pruned Gemma-3-270m with Logits-Only KD")
tokenizer.push_to_hub(hf_repo_id, commit_message="Tokenizer for gem-3-small")
print(f"✓ Model saved: https://huggingface.co/{hf_repo_id}")

In [None]:
# Drop one reference to the trained model and release cached GPU memory.
# `student_trained` still refers to the same object, so the model itself stays alive.
del student_logits_only
clear_gpu_cache()
print(f"\n{'='*60}\nTRAINING COMPLETE\n{'='*60}")
# NOTE(review): HF_MODEL_NAME carries no user namespace, so this URL may not
# resolve to the uploaded repository — confirm.
print(f"Model: https://huggingface.co/{HF_MODEL_NAME}")
print(f"Results: {json_path}")