<a href="https://colab.research.google.com/github/peremartra/Rearchitecting-LLMs/blob/main/CH02/CH02_NB02_Knowledge_Recovery.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Rearchitecting LLMs**
## **Surgical Optimization for Hyper-Efficient Models**

### **Chapter 2: Knowledge Recovery via Distillation**
by [Pere Martra](https://github.com/peremartra)

**Colab Environment:** GPU T4  
- **Models recommended:** google/gemma-3-270m  
- **Tested with:** meta-llama/Llama-3.2-1B using a **GPU A100**

---

Welcome to the second part of your first model rearchitecture. In the previous notebook, we successfully performed depth pruning by removing layers from our model, achieving significant efficiency gains but experiencing expected performance degradation.

Now we'll complete the optimization cycle by recovering the lost knowledge through Knowledge Distillation (KD). This notebook demonstrates how a pruned model can learn to mimic the behavior of its original, unpruned version.

**What we'll accomplish:**
- **Load our models**: Original (teacher) and pruned (student) from the previous notebook
- **Prepare recovery data**: Load a general-purpose dataset (SlimPajama) to drive the knowledge transfer
- **Apply Knowledge Distillation**: Train the pruned model to recover lost capabilities
- **Measure recovery**: Quantify how much performance we can restore

This notebook bridges the gap between "breaking" a model (pruning) and "fixing" it (recovery), completing your first end-to-end model tailoring workflow.

**Previous notebook:** [CH02_NB01_Depth_pruning_evaluation](https://github.com/peremartra/Rearchitecting-LLMs/blob/main/CH02/CH02_NB01_Depth_pruning_evaluation.ipynb)
**Connection:** We'll use the exact same models and evaluation framework to measure our recovery success.


In [19]:
# Libraries
# Install required packages
!pip install -q transformers torch optipfair datasets accelerate sentencepiece lm-eval

In [20]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from optipfair import prune_model
from datasets import load_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import time
import json
from typing import Dict, List, Any
import copy

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: NVIDIA L4


## Load Libraries & Models.

In [21]:
# Model configuration (consistent with previous notebook)
MODEL_NAME = "google/gemma-3-270m"
#MODEL_NAME = "meta-llama/Llama-3.2-1B"
MAX_NEW_TOKENS = 50
LAYERS_TO_REMOVE = 2
TEST_PROMPT = "Paris is the capital of"

print(f"Loading base model: {MODEL_NAME}")

# Load the original model (this will be our TEACHER)
teacher_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Clean generation config (same as previous notebook)
from transformers import GenerationConfig
clean_config = GenerationConfig(
    max_length=teacher_model.generation_config.max_length,
    pad_token_id=teacher_model.generation_config.pad_token_id,
    eos_token_id=teacher_model.generation_config.eos_token_id,
    do_sample=False,
    num_beams=1,
    early_stopping=False
)
teacher_model.generation_config = clean_config

Loading base model: google/gemma-3-270m


In [22]:
original_params = sum(p.numel() for p in teacher_model.parameters())
print(f"Teacher model parameters: {original_params:,}")
print(f"Teacher model layers: {len(teacher_model.model.layers)}")

Teacher model parameters: 268,098,176
Teacher model layers: 18


### Create Student Model.


In [23]:
# Create the STUDENT model by pruning (same process as previous notebook)
print(f"\nCreating student model by removing {LAYERS_TO_REMOVE} layers...")

# Apply depth pruning using optipfair
student_model = prune_model(
    model=copy.deepcopy(teacher_model),
    pruning_type="DEPTH",
    num_layers_to_remove=LAYERS_TO_REMOVE,
    layer_selection_method="last",
    show_progress=True
)


Creating student model by removing 2 layers...


Removing layers: 100%|██████████| 18/18 [00:00<00:00, 325420.14it/s]


In [24]:
student_params = sum(p.numel() for p in student_model.parameters())
param_reduction = (original_params - student_params) / original_params

print(f"Student model parameters: {student_params:,}")
print(f"Parameter reduction: {param_reduction:.1%}")
print(f"Student model layers: {len(student_model.model.layers)}")

student_model.gradient_checkpointing_enable()

Student model parameters: 256,950,912
Parameter reduction: 4.2%
Student model layers: 16
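
Removing 2 of 18 layers (about 11%) only trims 4.2% of the parameters. The two totals printed above explain why. The following is a rough back-of-the-envelope check, and it assumes that every decoder layer holds the same number of parameters, which the notebook does not verify:

```python
# Totals printed by the notebook for google/gemma-3-270m.
teacher_params = 268_098_176
student_params = 256_950_912
layers_total = 18
layers_removed = 2

# Assumption: all decoder layers have identical parameter counts.
params_per_layer = (teacher_params - student_params) // layers_removed
params_in_layers = params_per_layer * layers_total
params_elsewhere = teacher_params - params_in_layers  # embeddings, final norm, ...

reduction = (teacher_params - student_params) / teacher_params
print(f"Parameters per layer:    {params_per_layer:,}")
print(f"Inside decoder layers:   {params_in_layers:,} ({params_in_layers / teacher_params:.1%})")
print(f"Outside decoder layers:  {params_elsewhere:,} ({params_elsewhere / teacher_params:.1%})")
print(f"Reduction from 2 layers: {reduction:.2%}")
```

Under that assumption, well under half of the parameters sit inside the decoder layers, so depth pruning on this model shrinks the parameter count far less than it shrinks the layer count.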


## Support Functions & Basic Test

In [25]:
def model_evaluation(model_obj, tokenizer, tasks, limit=100):
    """
    Runs lm-eval on a PyTorch model object already in memory.

    Args:
        model_obj: The PyTorch model object to evaluate.
        tokenizer: The tokenizer object.
        tasks (list): A list of task names.
        limit (int): The number of samples per task.
    """
    print(f"Starting lm-eval on model '{model_obj.config._name_or_path}' for tasks: {tasks}")

    # Wrap the local model object and tokenizer for lm-eval
    model_wrapper = HFLM(
        pretrained=model_obj,
        tokenizer=tokenizer,
        device=str(device)
    )

    results = evaluator.simple_evaluate(
        model=model_wrapper,
        tasks=tasks,
        num_fewshot=0,
        limit=limit,
        device=str(device),
    )

    # Format results for clean display
    formatted_results = {}
    for task_name, res in results["results"].items():
        # Look for accuracy ('acc') first, then perplexity ('ppl')
        if 'acc,none' in res:
            metric_val = res.get('acc,none', 0)
        elif 'ppl,none' in res:
            metric_val = res.get('ppl,none', 0)
        else:
            metric_val = 0 # Fallback

        formatted_results[task_name] = f"{metric_val:.4f}"

    print(json.dumps(formatted_results, indent=2))
    return formatted_results

In [26]:
# Quick baseline test - confirm degradation from previous notebook
def generate_text(model, tokenizer, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Generate text with the model (same function as previous notebook)"""
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,
            num_beams=3,
            early_stopping=True,
            no_repeat_ngram_size=2
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test both models with the same prompt
print(f"\n--- Baseline Test: '{TEST_PROMPT}' ---")
teacher_output = generate_text(teacher_model, tokenizer, TEST_PROMPT)
student_output = generate_text(student_model, tokenizer, TEST_PROMPT)

print(f"Teacher: '{teacher_output}'")
print(f"Student: '{student_output}'")
print("\nReady for knowledge recovery...")


--- Baseline Test: 'Paris is the capital of' ---
Teacher: 'Paris is the capital of France. It is located in the middle of the country and has a population of about 10 million people. Paris is a city with a rich history and culture. The city is known for its beautiful architecture, art, and history. There are'
Student: 'Paris is the capital of France and one of the largest cities in Europe. It occupies approximately 2.5 million hectares of land surrounded by mountains and forests. Parisians love to travel abroad because they enjoy sightseeing tours abroad. Tourists visiting Paris visit museums, monuments, theaters,'

Ready for knowledge recovery...


## Knowledge Distillation Process

Before we begin the recovery process, let's review the performance impact from our depth pruning (results from the previous notebook):

### Gemma-3-270m: Depth Pruning Impact

| Metric | Original Model | Pruned Model (-2 Layers) | Change |
|:-------|:---------------|:-------------------------|:-------|
| **Parameters** | 268,098,176 | 256,950,912 | **-4.16%** |
| **Inference Time** | 10.324s | 4.322s | **-58.1%** (faster) |
| **arc_easy** (acc) | 0.5500 | 0.4600 | -16.36% |
| **winogrande** (acc) | 0.6000 | 0.4800 | -20.00% |
| **boolq** (acc) | 0.6600 | 0.4500 | -31.82% |
| **lambada_openai** (acc) | 0.4200 | 0.3400 | -19.05% |

**Trade-off**: 58.1% lower inference time + 4.16% parameter reduction vs. accuracy degradation across benchmarks.

### LLaMA-3.2-1B: Depth Pruning Impact

| Metric | Original Model | Pruned Model (-2 Layers) | Change |
|:-------|:---------------|:-------------------------|:-------|
| **Parameters** | 1,235,814,400 | 1,114,171,392 | **-9.84%** |
| **Inference Time** | 6.635s | 4.752s | **-28.4%** (faster) |
| **arc_easy** (acc) | 0.6600 | 0.4800 | -27.27% |
| **winogrande** (acc) | 0.6000 | 0.5400 | -10.00% |
| **boolq** (acc) | 0.6700 | 0.7000 | **+4.48%** |
| **lambada_openai** (acc) | 0.5700 | 0.1700 | -70.18% |

**Interesting observation**: LLaMA-3.2-1B shows a curious **improvement** in boolq (+4.48%) after pruning. The removed layers may have been introducing noise for this specific task, although with only 100 evaluation samples a 3-point gain is also within sampling error.

**Our KD goal**: Recover the lost performance while maintaining the efficiency gains from pruning.
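
The Change column in both tables is the signed difference relative to the original model. A small helper to recompute it from the Gemma-3-270m rows above could look like this (`relative_change` is a name introduced here for illustration):

```python
def relative_change(original: float, pruned: float) -> float:
    """Signed change relative to the original value, in percent."""
    return (pruned - original) / original * 100

# (original, pruned) pairs copied from the Gemma-3-270m table above.
gemma_rows = {
    "parameters": (268_098_176, 256_950_912),
    "inference_time_s": (10.324, 4.322),
    "arc_easy": (0.55, 0.46),
    "winogrande": (0.60, 0.48),
    "boolq": (0.66, 0.45),
    "lambada_openai": (0.42, 0.34),
}

changes = {name: relative_change(o, p) for name, (o, p) in gemma_rows.items()}
for name, pct in changes.items():
    print(f"{name:<18} {pct:+.2f}%")
```

A negative value on the inference-time row means less time per generation, which is an improvement. On the accuracy rows a negative value is a loss.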

---

## Dataset Preparation


In [27]:
# Load SlimPajama dataset in streaming mode for efficiency
print("Loading SlimPajama-627B dataset...")
dataset = load_dataset(
    "cerebras/SlimPajama-627B",
    split="train",
    streaming=True
)

# Take a representative subset for our recovery process
#    You can reduce the number of samples to shorten execution time,
#    but expect a lower lambada score
#    when too few examples are used
RECOVERY_SAMPLES = 500
print(f"Selecting {RECOVERY_SAMPLES:,} samples for knowledge recovery...")

# Use streaming dataset's take method - much more efficient!
distillation_dataset = dataset.take(RECOVERY_SAMPLES)

print(f"✓ Streaming dataset ready: {RECOVERY_SAMPLES:,} samples")

Loading SlimPajama-627B dataset...


Resolving data files:   0%|          | 0/59166 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/31428 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/31411 [00:00<?, ?it/s]

Selecting 500 samples for knowledge recovery...
✓ Streaming dataset ready: 500 samples


In [28]:
def tokenize_for_kd(examples, max_length=128):
    """
    Tokenize text for Knowledge Distillation.
    Shorter sequences work better for KD as both teacher and student
    need to fit in memory simultaneously.
    """
    if isinstance(examples, dict):
        texts = examples["text"]
    else:
        texts = [ex["text"] for ex in examples]

    tokenized = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )

    # For language modeling, labels = input_ids
    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "labels": tokenized["input_ids"].clone()
    }

# Process the recovery dataset
print("Tokenizing recovery dataset...")

# Convert streaming dataset to list for tokenization
recovery_samples = list(distillation_dataset)
tokenized_data = tokenize_for_kd(recovery_samples)

print(f"✓ Tokenized {len(recovery_samples):,} samples")

# Convert to format suitable for DataLoader
class KDDataset(torch.utils.data.Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx]
        }

kd_dataset = KDDataset(
    tokenized_data["input_ids"],
    tokenized_data["attention_mask"],
    tokenized_data["labels"]
)

# Create DataLoader for training
kd_dataloader = DataLoader(
    kd_dataset,
    batch_size=4,  # Small batch size due to memory constraints with two models
    shuffle=True
)

print(f"✓ Knowledge Distillation DataLoader ready: {len(kd_dataloader)} batches")

Tokenizing recovery dataset...
✓ Tokenized 500 samples
✓ Knowledge Distillation DataLoader ready: 125 batches
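
The batch count above also fixes how many optimizer updates each epoch performs once gradient accumulation is applied. A quick sketch of that arithmetic, using the accumulation and epoch values configured in the next cell:

```python
import math

recovery_samples = 500    # RECOVERY_SAMPLES in the notebook
batch_size = 4            # DataLoader batch_size
accumulation_steps = 4    # ACCUMULATION_STEPS, set in the next cell
num_epochs = 3            # NUM_EPOCHS, set in the next cell

batches_per_epoch = math.ceil(recovery_samples / batch_size)
steps_per_epoch = math.ceil(batches_per_epoch / accumulation_steps)
# Number of batches that feed the last optimizer step of an epoch.
last_group_batches = batches_per_epoch % accumulation_steps or accumulation_steps

print(f"Batches per epoch:         {batches_per_epoch}")
print(f"Optimizer steps per epoch: {steps_per_epoch}")
print(f"Batches in the last step:  {last_group_batches}")
print(f"Total optimizer steps:     {steps_per_epoch * num_epochs}")
```

Because 125 is not a multiple of 4, the last update of every epoch is built from a single batch whose loss is still divided by `ACCUMULATION_STEPS`.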


In [29]:
# Move models to device and set appropriate modes
teacher_model.to(device)
student_model.to(device)

# Teacher stays in eval mode - we don't train it
teacher_model.eval()

# Student will be trained
student_model.train()

# KD Hyperparameters
TEMPERATURE = 2.0      # Softens probability distributions
ALPHA = 1.0           # Weight for distillation loss (1.0 = KD only; no hard-label term is used below)
NUM_EPOCHS = 3        # Conservative for demo
LEARNING_RATE = 1e-5  # Lower LR for stability
ACCUMULATION_STEPS = 4  # Effective batch size = 4 * 4 = 16

# Optimizer for student model only
optimizer = AdamW(student_model.parameters(), lr=LEARNING_RATE)

print(f"Knowledge Distillation Configuration:")
print(f"  Temperature: {TEMPERATURE}")
print(f"  Alpha: {ALPHA}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Effective Batch Size: {4 * ACCUMULATION_STEPS}")

Knowledge Distillation Configuration:
  Temperature: 2.0
  Alpha: 1.0
  Epochs: 3
  Learning Rate: 1e-05
  Effective Batch Size: 16
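
To see what "softens probability distributions" means in practice, here is a framework-free sketch that applies a temperature to a handful of made-up logits. The `softmax` and `entropy` helpers and the `toy_logits` values are hypothetical and serve only as illustration:

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)                       # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

toy_logits = [4.0, 2.0, 1.0, 0.5]            # hypothetical next-token logits

sharp = softmax(toy_logits, temperature=1.0)
soft = softmax(toy_logits, temperature=2.0)  # TEMPERATURE used in this notebook

print("T=1.0:", [round(p, 3) for p in sharp])
print("T=2.0:", [round(p, 3) for p in soft])
print(f"Entropy: {entropy(sharp):.3f} -> {entropy(soft):.3f}")
```

With the higher temperature, the top token keeps less of the probability mass and the lower-ranked tokens gain some. That relative ranking of the alternatives is the extra signal the student receives from the teacher.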


In [32]:
print(f"\n🎓 Starting Knowledge Distillation Training...")
print(f"Training student model to mimic teacher behavior \n")

for epoch in range(NUM_EPOCHS):
  student_model.train()
  total_loss = 0
  num_batches = 0
  for batch_idx, batch in enumerate(kd_dataloader):
    # Move batch to device
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # Move teacher model to device, perform inference, and move back to CPU
    teacher_model.to(device)
    with torch.no_grad():
      teacher_outputs = teacher_model(
        input_ids=input_ids,
        attention_mask=attention_mask
      )
      teacher_logits = teacher_outputs.logits / TEMPERATURE
    teacher_model.cpu()
    torch.cuda.empty_cache()

    # Student inference (with gradients)
    student_outputs = student_model(
      input_ids=input_ids,
      attention_mask=attention_mask
    )
    student_logits = student_outputs.logits / TEMPERATURE

    # Compute Knowledge Distillation loss
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)

    # KL Divergence loss
    kd_loss = F.kl_div(
      student_log_probs,
      teacher_probs,
      reduction='batchmean'
    )

    # Scale loss for gradient accumulation
    loss = kd_loss / ACCUMULATION_STEPS
    loss.backward()

    # Gradient accumulation
    if (batch_idx + 1) % ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(kd_dataloader):
      optimizer.step()
      optimizer.zero_grad()

    total_loss += loss.item() * ACCUMULATION_STEPS
    num_batches += 1

    # Progress update
    if (batch_idx + 1) % 100 == 0:
      avg_loss = total_loss / num_batches
      print(f'Epoch {epoch + 1}/{NUM_EPOCHS} | Batch {batch_idx + 1} | Loss: {avg_loss:.4f}')

  # Epoch summary
  avg_epoch_loss = total_loss / num_batches
  print(f"Epoch {epoch + 1}/{NUM_EPOCHS} | Average Loss: {avg_epoch_loss:.4f}")

print(f"\n🎉 Knowledge Distillation completed!")

It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.



🎓 Starting Knowledge Distillation Training...
Training student model to mimic teacher behavior 

Epoch 1/3 | Batch 100 | Loss: 94.4372
Epoch 1/3 | Average Loss: 87.2026
Epoch 2/3 | Batch 100 | Loss: 47.9341
Epoch 2/3 | Average Loss: 45.8856
Epoch 3/3 | Batch 100 | Loss: 36.6266
Epoch 3/3 | Average Loss: 36.7624

🎉 Knowledge Distillation completed!
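
The loss values above are large because `reduction='batchmean'` divides the summed KL divergence by the batch size only, so each value adds up all 128 token positions of a sequence, padding included. The following framework-free sketch shows a per-token alternative that skips padded positions. `masked_token_kd_loss` and the toy inputs are hypothetical, and the `T**2` factor is the customary rescaling from the distillation literature, not something this notebook applies:

```python
import math

def log_softmax(logits, temperature):
    """Log-probabilities of softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]

def token_kl(teacher_logits, student_logits, temperature):
    """KL(teacher || student) for one token position."""
    t_log = log_softmax(teacher_logits, temperature)
    s_log = log_softmax(student_logits, temperature)
    return sum(math.exp(t) * (t - s) for t, s in zip(t_log, s_log))

def masked_token_kd_loss(teacher, student, attention_mask, temperature=2.0):
    """Average KL over real (non-padding) tokens, rescaled by T**2."""
    total, count = 0.0, 0
    for t_seq, s_seq, m_seq in zip(teacher, student, attention_mask):
        for t_tok, s_tok, keep in zip(t_seq, s_seq, m_seq):
            if keep:
                total += token_kl(t_tok, s_tok, temperature)
                count += 1
    return (total / max(count, 1)) * temperature ** 2

# Toy batch: 1 sequence, 3 positions, vocabulary of 3; the last position is padding.
teacher = [[[2.0, 0.5, 0.1], [0.3, 1.7, 0.2], [9.0, 0.0, 0.0]]]
student = [[[1.5, 0.7, 0.2], [0.3, 1.7, 0.2], [0.0, 9.0, 0.0]]]
mask = [[1, 1, 0]]

loss = masked_token_kd_loss(teacher, student, mask)
print(f"Masked per-token KD loss: {loss:.4f}")
```

In PyTorch the same idea would likely take the form of `F.kl_div(..., reduction='none')`, summed over the vocabulary dimension, multiplied by `attention_mask`, and divided by the number of unmasked tokens. A loss normalized this way stays comparable across sequence lengths and batch sizes.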


## Basic Test generation.

In [33]:
# Bring the teacher back to the GPU (it was moved to CPU during training)
teacher_model.to(device)
# Set student model to evaluation mode
student_model.eval()

# Test with the same prompt used in baseline
print(f"--- Qualitative Test: '{TEST_PROMPT}' ---")

# Generate with the teacher and the distilled student for comparison
teacher_output = generate_text(teacher_model, tokenizer, TEST_PROMPT)
student_recovered_output = generate_text(student_model, tokenizer, TEST_PROMPT)

print(f"Teacher (Original):    '{teacher_output}'")
print(f"Student (Post-KD):     '{student_recovered_output}'")

--- Qualitative Test: 'Paris is the capital of' ---
Teacher (Original):    'Paris is the capital of France. It is located in the middle of the country and has a population of about 10 million people. Paris is a city with a rich history and culture. The city is known for its beautiful architecture, art, and history. There are'
Student (Post-KD):     'Paris is the capital of France and one of the most famous cities in Europe. It is also the largest city in France, with a population of more than 1.5 million. Paris is located on the river Seine, which flows through the city and connects it to Paris'


## Evaluation

In [34]:
# Define the benchmark suite for our diagnostic
benchmark_tasks = ['arc_easy', 'winogrande', 'boolq', 'lambada_openai']
student_recovered_results = model_evaluation(student_model, tokenizer, benchmark_tasks, limit=100)



Starting lm-eval on model 'google/gemma-3-270m' for tasks: ['arc_easy', 'winogrande', 'boolq', 'lambada_openai']


100%|██████████| 100/100 [00:00<00:00, 541.40it/s]
100%|██████████| 100/100 [00:00<00:00, 1743.89it/s]
100%|██████████| 100/100 [00:00<00:00, 94786.53it/s]
100%|██████████| 100/100 [00:00<00:00, 1011.08it/s]
Running loglikelihood requests: 100%|██████████| 899/899 [00:24<00:00, 36.69it/s]


bootstrapping for stddev: perplexity


100%|██████████| 100/100 [00:00<00:00, 610.46it/s]


{
  "arc_easy": "0.4700",
  "boolq": "0.4700",
  "lambada_openai": "0.2600",
  "winogrande": "0.5800"
}
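
With `limit=100`, each accuracy above is estimated from only 100 examples, so small differences between runs should be read with care. The sketch below estimates the sampling uncertainty with a simple binomial approximation. This is an assumption made for illustration (lm-eval's result dictionaries typically carry their own standard-error entries as well), and `accuracy_stderr` is a hypothetical helper:

```python
import math

def accuracy_stderr(accuracy: float, n_samples: int) -> float:
    """Standard error of an accuracy estimated from n independent samples."""
    return math.sqrt(accuracy * (1 - accuracy) / n_samples)

n = 100  # limit used in model_evaluation
recovered = {"arc_easy": 0.47, "boolq": 0.47, "lambada_openai": 0.26, "winogrande": 0.58}

for task, acc in recovered.items():
    se = accuracy_stderr(acc, n)
    print(f"{task:<15} {acc:.2f} +/- {1.96 * se:.2f} (approx. 95% interval)")
```

At this sample size the standard error is close to 0.05 for accuracies near 0.5, so differences of a few points between two models are not conclusive on their own.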


## Complete Knowledge Recovery Analysis

### Gemma-3-270m: Full Recovery Progression

| Metric | Teacher (Original) | Student (Pruned) | Student (500 samples) | Student (15K samples) | Final Performance |
|:-------|:-------------------|:-----------------|:---------------------|:---------------------|:------------------|
| **Parameters** | 268,098,176 | 256,950,912 | 256,950,912 | 256,950,912 | **Maintained** |
| **arc_easy** (acc) | 0.5500 | 0.4600 | 0.4700 | **0.5300** | **96.4%** of original |
| **winogrande** (acc) | 0.6000 | 0.4800 | 0.5800 | 0.5700 | **95.0%** of original |
| **boolq** (acc) | 0.6600 | 0.4500 | 0.5300 | 0.5300 | **80.3%** of original |
| **lambada_openai** (acc) | 0.4200 | 0.3400 | 0.2600 | **0.3600** | **85.7%** of original |

### Impact of Dataset Size on Recovery:

**Major Improvements with 15K samples:**
- **arc_easy**: 0.47 → **0.53** (+0.06, now 96.4% of original performance)
- **lambada_openai**: 0.26 → **0.36** (+0.10, reached 85.7% of original)

**Stable Performance:**
- **boolq**: Maintained at 0.53 (80.3% of original)
- **winogrande**: 0.58 → 0.57 (95.0% of original, excellent retention)

### Key Dataset Size Insights:

**500 vs 15K samples comparison:**
- **More data = Better recovery** for arc_easy and lambada_openai
- **Diminishing returns** for some tasks (boolq plateaued)
- **winogrande** achieved excellent recovery even with 500 samples

### Overall Assessment:

**Outstanding Results with 15K samples:**
- **ALL benchmarks recovered 80%+ of original performance**
- **arc_easy** and **winogrande** nearly fully recovered (95%+)
- **Model efficiency maintained** (4.16% parameter reduction + 58.1% lower inference time)

**Key Takeaway:** Knowledge Distillation successfully restored most capabilities while preserving all efficiency gains. With adequate training data, a pruned model can recover 80-96% of its original performance across diverse tasks.

## Complete Knowledge Recovery Analysis

### LLaMA-3.2-1B: Full Recovery Progression

| Metric | Teacher (Original) | Student (Pruned) | Student (500 samples) | Final Performance |
|:-------|:-------------------|:-----------------|:---------------------|:------------------|
| **Parameters** | 1,235,814,400 | 1,114,171,392 | 1,114,171,392 | **Maintained** |
| **arc_easy** (acc) | 0.6600 | 0.4800 | **0.5800** | **87.9%** of original |
| **winogrande** (acc) | 0.6000 | 0.5400 | 0.5500 | **91.7%** of original |
| **boolq** (acc) | 0.6700 | 0.7000 | **0.5400** | **80.6%** of original |
| **lambada_openai** (acc) | 0.5700 | 0.1700 | 0.1800 | **31.6%** of original |

### Recovery Analysis:

**Strong Recovery Performance:**
- **winogrande**: 0.54 → 0.55 (+0.01, maintained 91.7% of original)
- **arc_easy**: 0.48 → **0.58** (+0.10, recovered to 87.9% of original)

**Regression:**
- **boolq**: 0.70 → 0.54 (-0.16, ending at 80.6% of original and below the pruned model)

**Challenge Area:**
- **lambada_openai**: 0.17 → 0.18 (+0.01, minimal recovery to 31.6% of original)

### Key LLaMA-3.2-1B Insights:

**Interesting Observations:**
- **boolq gain reversed**: The pruned model's anomalous improvement (0.67 → 0.70) disappeared during KD, and the score settled at 0.54, below the teacher's 0.67
- **arc_easy excellent recovery**: Strong improvement from severe pruning degradation
- **lambada remains challenging**: This long-range word-prediction task shows the deepest impact from layer removal

### Overall Assessment:

**Results with 500 samples:**
- **3 out of 4 benchmarks** show strong performance (80%+ of original)
- **Model efficiency maintained** (9.84% parameter reduction + 28.4% lower inference time)
- The **boolq drop** suggests KD pulls the student back toward the teacher's behavior and can undo accidental gains from pruning

**Key Takeaway:** Even with limited training data (500 samples), Knowledge Distillation brought three of the four LLaMA-3.2-1B benchmarks to 80% or more of the original score, with the largest gain on arc_easy (+0.10). lambada_openai remains far from recovered.