<a href="https://colab.research.google.com/github/peremartra/Tailoring-LLM-Architectures/blob/main/CH06/CH06_NB01_Data_Cosmopedia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Tailoring LLM Architectures**
## **Chapter 6: The Fuel - Data Strategy for Knowledge Recovery**

### **Notebook 1: Quality vs Quantity (Cosmopedia)**
by [Pere Martra](https://github.com/peremartra)

[![LinkedIn](https://img.shields.io/badge/LinkedIn-0077B5?style=flat&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/pere-martra/) [![GitHub](https://img.shields.io/badge/GitHub-100000?style=flat&logo=github&logoColor=white)](https://github.com/peremartra) [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-blue)](https://huggingface.co/oopere)

___

**Colab Environment:** GPU T4  
- **Models recommended:** google/gemma-3-270m  
- **Tested with:** meta-llama/Llama-3.2-1B using a **GPU A100**

---

In this notebook, we challenge the assumption that more data always wins. We will use only 15,000 high-quality data samples (Cosmopedia) to attempt to beat the results we obtained in Chapter 2 with 30,000 web crawl samples (SlimPajama).

**What we'll accomplish:**
- **Load our models**: Original (teacher) and pruned (student), reusing the logic from Chapter 2
- **Prepare High-Quality Data**: Load synthetic textbooks from Cosmopedia
- **Apply Knowledge Distillation**: Train the pruned model with high-quality data
- **Measure recovery**: Compare against the baseline set in Chapter 2


In [None]:
# Libraries
# Install required packages
!pip install -q transformers torch optipfair datasets accelerate sentencepiece lm-eval

In [None]:
# --- BASELINE RESULTS (FROM CHAPTER 2) ---
# We use the results obtained in CH02_NB02 as our baseline to beat.
# Dataset: SlimPajama (Web Crawl)
# Size: 30,000 samples
# Training: 3 Epochs
baseline_metrics = {
    "name": "SlimPajama (30k samples)",
    "arc_easy": 0.5800,      # Value from Ch2
    "winogrande": 0.5500,    # Value from Ch2
    "boolq": 0.5400          # Value from Ch2
}

print(f"🎯 Target to beat: ARC-Easy {baseline_metrics['arc_easy']:.2%}")

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from optipfair import prune_model
from datasets import load_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import time
import json
from typing import Dict, List, Any
import copy

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# 2.4 Recovering knowledge with distillation

## Load Libraries & Models.

In [None]:
# Model configuration (consistent with previous notebook)
MODEL_NAME = "google/gemma-3-270m"
#MODEL_NAME = "meta-llama/Llama-3.2-1B"
MAX_NEW_TOKENS = 50
LAYERS_TO_REMOVE = 2 #Try removing 4, 6, or even 8 layers
TEST_PROMPT = "Paris is the capital of"

print(f"Loading base model: {MODEL_NAME}")

# Load the original model (this will be our TEACHER)
teacher_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Clean generation config (same as previous notebook)
from transformers import GenerationConfig
clean_config = GenerationConfig(
    max_length=teacher_model.generation_config.max_length,
    pad_token_id=teacher_model.generation_config.pad_token_id,
    eos_token_id=teacher_model.generation_config.eos_token_id,
    do_sample=False,
    num_beams=1,
    early_stopping=False
)
teacher_model.generation_config = clean_config

In [None]:
original_params = sum(p.numel() for p in teacher_model.parameters())
print(f"Teacher model parameters: {original_params:,}")
print(f"Teacher model layers: {len(teacher_model.model.layers)}")

### Create Student Model with optiPfair


In [None]:
# Create the STUDENT model by pruning (same process as previous notebook)
print(f"\nCreating student model by removing {LAYERS_TO_REMOVE} layers...")

# Apply depth pruning using optipfair
student_model = prune_model(
    model=copy.deepcopy(teacher_model),
    pruning_type="DEPTH",
    num_layers_to_remove=LAYERS_TO_REMOVE,
    layer_selection_method="last",
    show_progress=True
)

In [None]:
student_params = sum(p.numel() for p in student_model.parameters())
param_reduction = (original_params - student_params) / original_params

print(f"Student model parameters: {student_params:,}")
print(f"Parameter reduction: {param_reduction:.1%}")
print(f"Student model layers: {len(student_model.model.layers)}")

student_model.gradient_checkpointing_enable()
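
The parameter reduction printed above is usually smaller than the fraction of layers removed, because depth pruning leaves the embeddings untouched. A minimal sketch of that arithmetic follows. The parameter counts are made-up round numbers chosen for illustration, not measurements of any real model, and `depth_pruning_reduction` is a hypothetical helper.

```python
def depth_pruning_reduction(embedding_params, params_per_layer, num_layers, layers_removed):
    """Fraction of total parameters removed when whole layers are dropped.

    Ignores small terms such as the final normalization layer.
    """
    total = embedding_params + params_per_layer * num_layers
    removed = params_per_layer * layers_removed
    return removed / total

# Hypothetical model: 100M embedding parameters, 20 layers of 5M parameters each.
reduction = depth_pruning_reduction(
    embedding_params=100_000_000,
    params_per_layer=5_000_000,
    num_layers=20,
    layers_removed=2,
)
print(f"Layers removed: {2 / 20:.1%} | Parameters removed: {reduction:.1%}")
```

In this toy setup, removing 10% of the layers removes only 5% of the parameters, since half of the model lives in the embeddings.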

## Support Functions & Basic Test

In [None]:
def model_evaluation(model_obj, tokenizer, tasks, limit=100):
    """
    Runs lm-eval on a PyTorch model object already in memory.

    Args:
        model_obj: The PyTorch model object to evaluate.
        tokenizer: The tokenizer object.
        tasks (list): A list of task names.
        limit (int): The number of samples per task.
    """
    print(f"Starting lm-eval on model '{model_obj.config._name_or_path}' for tasks: {tasks}")

    # Wrap the local model object and tokenizer for lm-eval
    model_wrapper = HFLM(
        pretrained=model_obj,
        tokenizer=tokenizer,
        device=str(device)
    )

    results = evaluator.simple_evaluate(
        model=model_wrapper,
        tasks=tasks,
        num_fewshot=0,
        limit=limit,
        device=str(device),
    )

    # Format results for clean display
    formatted_results = {}
    for task_name, res in results["results"].items():
        # Look for accuracy ('acc') first, then perplexity ('ppl')
        if 'acc,none' in res:
            metric_val = res.get('acc,none', 0)
        elif 'ppl,none' in res:
            metric_val = res.get('ppl,none', 0)
        else:
            metric_val = 0 # Fallback

        formatted_results[task_name] = f"{metric_val:.4f}"

    print(json.dumps(formatted_results, indent=2))
    return float(formatted_results.get(tasks[0], 0))
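
`model_evaluation` returns only the score of the first task. When several tasks are evaluated in one call, it can be handy to keep every score as a float. The sketch below assumes a results dictionary with the same shape the function above reads (`results["results"][task]` with `acc,none` or `ppl,none` keys). `extract_metrics`, the task names, and the numbers are hypothetical.

```python
def extract_metrics(results):
    """Return {task_name: metric}, preferring accuracy over perplexity."""
    metrics = {}
    for task_name, res in results["results"].items():
        if "acc,none" in res:
            metrics[task_name] = float(res["acc,none"])
        elif "ppl,none" in res:
            metrics[task_name] = float(res["ppl,none"])
        else:
            metrics[task_name] = None  # No known metric reported for this task
    return metrics

# Mocked results (made-up values) shaped like the dictionary read above.
mock_results = {
    "results": {
        "accuracy_task": {"acc,none": 0.57},
        "perplexity_task": {"ppl,none": 31.4},
        "other_task": {"f1,none": 0.3},
    }
}
all_scores = extract_metrics(mock_results)
print(all_scores)
```

Returning `None` instead of `0` for an unknown metric avoids confusing a missing score with a real score of zero.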

In [None]:
# Quick baseline test - confirm degradation from previous notebook
def generate_text(model, tokenizer, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Generate text with the model (same function as previous notebook)"""
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,
            num_beams=3,
            early_stopping=True,
            no_repeat_ngram_size=2
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test both models with the same prompt
print(f"\n--- Baseline Test: '{TEST_PROMPT}' ---")
teacher_output = generate_text(teacher_model, tokenizer, TEST_PROMPT)
student_output = generate_text(student_model, tokenizer, TEST_PROMPT)

print(f"Teacher: '{teacher_output}'")
print(f"Student: '{student_output}'")
print("\nReady for knowledge recovery...")

## Dataset Preparation: Cosmopedia


In [None]:
# --- DATASET: COSMOPEDIA (High Quality) ---
# We use the 'stories' subset which contains synthetic textbooks and stories.
print("Loading Cosmopedia (Textbook Quality)...")

dataset_name = "HuggingFaceTB/cosmopedia"
subset = "stories"
num_samples = 15000  # HALF the size of Chapter 2!

dataset = load_dataset(dataset_name, subset, split="train", streaming=True)

# Tokenization pipeline (same as in Ch2, but adapted to the 'text' column)
def get_data_loader(dataset, num_samples, batch_size=8):
    data = []
    # Add shuffling with buffer for streaming
    dataset = dataset.shuffle(seed=42, buffer_size=1000)
    
    for i, sample in enumerate(dataset):
        if i >= num_samples: break
        
        # Cosmopedia stories normally use the 'text' column; fall back to 'prompt' + 'completion' just in case
        text = sample.get('text', sample.get('prompt', '') + ' ' + sample.get('completion', ''))
        
        tokenized = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=128, # Short context is enough for this demo
            return_tensors="pt"
        )
        data.append({
            "input_ids": tokenized["input_ids"].squeeze(0),
            "attention_mask": tokenized["attention_mask"].squeeze(0)
        })
    
    return DataLoader(data, batch_size=batch_size, shuffle=True)

kd_dataloader = get_data_loader(dataset, num_samples)
print(f"Data ready: {num_samples} high-quality samples.")
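
With `streaming=True` the dataset cannot be shuffled globally, so `shuffle(buffer_size=1000)` only mixes samples inside a sliding buffer. The sketch below illustrates the general idea of buffer-based shuffling in plain Python. It is an approximation written for illustration, not the `datasets` library's actual implementation, and `buffered_shuffle` is a hypothetical helper.

```python
import random

def buffered_shuffle(stream, buffer_size, seed=42):
    """Yield items from `stream` in a locally shuffled order."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # Fill the buffer first
            continue
        # Emit a random buffered item and put the new one in its place
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Stream exhausted: emit what is left in random order
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(buffered_shuffle(range(20), buffer_size=5))
print(shuffled)
```

Under this scheme, the item emitted at position `k` can never come from further than `buffer_size` records ahead in the stream. With a buffer of 1,000 and 15,000 samples taken, the training data would come from roughly the first 16,000 records of the stream.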

In [None]:
# Move models to device and set appropriate modes
teacher_model.to(device)
student_model.to(device)

# Teacher stays in eval mode - we don't train it
teacher_model.eval()

# Student will be trained
student_model.train()

# KD Hyperparameters
TEMPERATURE = 2.0      # Softens probability distributions
ALPHA = 1.0           # Weight for distillation loss (unused below: the loop trains on the KD term only)
NUM_EPOCHS = 3        # Conservative for demo
LEARNING_RATE = 1e-5  # Lower LR for stability
ACCUMULATION_STEPS = 4  # Effective batch size = 4 * 8 = 32

# Optimizer for student model only
optimizer = AdamW(student_model.parameters(), lr=LEARNING_RATE)

print(f"Knowledge Distillation Configuration:")
print(f"  Temperature: {TEMPERATURE}")
print(f"  Alpha: {ALPHA}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Effective Batch Size: {kd_dataloader.batch_size * ACCUMULATION_STEPS}")
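
For reference, the configuration above translates into a fixed number of optimizer updates. This is a minimal sketch of the arithmetic, assuming the default `batch_size=8` from `get_data_loader` and 15,000 samples. The variable names are local to the example.

```python
import math

num_samples = 15_000
batch_size = 8
accumulation_steps = 4
num_epochs = 3

batches_per_epoch = math.ceil(num_samples / batch_size)
# The notebook's training loop also steps on the last batch of an epoch,
# so a partial accumulation group still counts as one update.
steps_per_epoch = math.ceil(batches_per_epoch / accumulation_steps)
total_steps = steps_per_epoch * num_epochs

print(f"Effective batch size: {batch_size * accumulation_steps}")
print(f"Batches per epoch:    {batches_per_epoch}")
print(f"Optimizer steps:      {steps_per_epoch} per epoch, {total_steps} in total")
```

That gives 1,875 batches and 469 optimizer steps per epoch, or 1,407 updates over the three epochs.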

In [None]:
# --- SIDEBAR: VISUALIZING TEMPERATURE ---
# Before training, let's understand the T parameter we are about to use.
import matplotlib.pyplot as plt
import numpy as np

def plot_temperature_scaling(logits, temperatures=(1.0, 2.0, 4.0)):
    plt.figure(figsize=(10, 5))
    tokens = ['London', 'Paris', 'Madrid', 'Rome'] # Example tokens
    
    for T in temperatures:
        # Softmax formula with T
        exp_logits = np.exp(np.array(logits) / T)
        probs = exp_logits / np.sum(exp_logits)
        
        plt.plot(tokens, probs, marker='o', label=f'T={T}')
        
    plt.title("How Temperature affects 'Dark Knowledge'")
    plt.ylabel("Probability")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# Simulating logits where Index 1 (Paris) is the correct answer
dummy_logits = [2.0, 8.0, 4.5, 1.0] 
print("Visualizing Temperature Scaling (T=2.0 is our choice):")
plot_temperature_scaling(dummy_logits)
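
The same effect can be read as numbers rather than as a plot. The sketch below recomputes the temperature-scaled softmax for the dummy logits in plain Python. `softmax_with_temperature` is a helper written for this example.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax over logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    max_scaled = max(scaled)  # Subtract the max for numerical stability
    exps = [math.exp(x - max_scaled) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

tokens = ["London", "Paris", "Madrid", "Rome"]
dummy_logits = [2.0, 8.0, 4.5, 1.0]

for T in (1.0, 2.0, 4.0):
    probs = softmax_with_temperature(dummy_logits, T)
    row = ", ".join(f"{tok}={p:.3f}" for tok, p in zip(tokens, probs))
    print(f"T={T}: {row}")
```

Paris keeps the highest probability at every temperature, but its share shrinks from roughly 0.97 at T=1 to roughly 0.55 at T=4, which leaves more probability mass on the alternatives for the student to learn from.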

In [None]:
print(f"\n>>> STARTING TRAINING WITH COSMOPEDIA ({num_samples} samples) <<<")
print(f"Training student model to mimic teacher behavior \n")

for epoch in range(NUM_EPOCHS):
  student_model.train()
  total_loss = 0
  num_batches = 0
  for batch_idx, batch in enumerate(kd_dataloader):
    # ###Step 1: Move batch to device###
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # Move teacher model to device, perform inference, and move back to CPU
    teacher_model.to(device)

    # ###Step 2: Ask the teacher to generate its logits###
    with torch.no_grad():
      teacher_outputs = teacher_model(
        input_ids=input_ids,
        attention_mask=attention_mask
      )
      teacher_logits = teacher_outputs.logits / TEMPERATURE

    # Moving teacher model to CPU to save memory.
    teacher_model.cpu()
    torch.cuda.empty_cache()

    # ###Step 3: Get Student's "thoughts" (logits) with training enabled###
    # Student inference (with gradients)
    student_outputs = student_model(
      input_ids=input_ids,
      attention_mask=attention_mask
    )
    student_logits = student_outputs.logits / TEMPERATURE

    # Compute Knowledge Distillation loss
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)

    # ###Step 4: Compute the KL divergence loss between student and teacher###
    kd_loss = F.kl_div(
      student_log_probs,
      teacher_probs,
      reduction='batchmean'
    )

    # Scale loss for gradient accumulation
    loss = kd_loss / ACCUMULATION_STEPS
    loss.backward()

    # Gradient accumulation
    if (batch_idx + 1) % ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(kd_dataloader):
      optimizer.step()
      optimizer.zero_grad()

    total_loss += loss.item() * ACCUMULATION_STEPS
    num_batches += 1
    # Progress update

    if (batch_idx + 1) % 100 == 0:
      avg_loss = total_loss / num_batches
      print(f'Epoch {epoch + 1}/{NUM_EPOCHS} | Batch {batch_idx + 1} | Loss: {avg_loss:.4f}')

  # Epoch summary
  avg_epoch_loss = total_loss / num_batches
  print(f"Epoch {epoch + 1}/{NUM_EPOCHS} | Average Loss: {avg_epoch_loss:.4f}")

print(f"\n🎉 Knowledge Distillation completed!")
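
The loop above averages the KL term with `reduction='batchmean'`, which divides by the batch dimension only, and it includes padded positions in the loss. A common variant averages over non-padded tokens and multiplies by T², since the gradients of the softened targets scale as 1/T². The sketch below shows that variant in plain Python for a single sequence. It is an illustrative alternative, not the loss used in this notebook, and `masked_kd_loss` is a hypothetical helper.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def masked_kd_loss(teacher_logits, student_logits, attention_mask, temperature):
    """KL(teacher || student) averaged over non-padded tokens, scaled by T^2."""
    total, count = 0.0, 0
    for t_row, s_row, keep in zip(teacher_logits, student_logits, attention_mask):
        if not keep:
            continue  # Skip padded positions
        t_logp = log_softmax([x / temperature for x in t_row])
        s_logp = log_softmax([x / temperature for x in s_row])
        total += sum(math.exp(tp) * (tp - sp) for tp, sp in zip(t_logp, s_logp))
        count += 1
    return (temperature ** 2) * total / max(count, 1)

# Toy sequence: 3 positions, vocabulary of 4 tokens, last position is padding.
teacher = [[2.0, 8.0, 4.5, 1.0], [0.5, 0.1, 3.0, 0.2], [9.0, 0.0, 0.0, 0.0]]
student = [[1.0, 5.0, 4.0, 1.5], [0.5, 0.1, 3.0, 0.2], [0.0, 9.0, 0.0, 0.0]]
mask = [1, 1, 0]

loss = masked_kd_loss(teacher, student, mask, temperature=2.0)
print(f"Masked KD loss: {loss:.4f}")
```

Because the padded third position is skipped, the large disagreement there does not contribute to the loss. Only the first position, where the student and teacher actually differ on real tokens, does.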

## Basic Test generation.

In [None]:
# Move the teacher back to the device (the training loop left it on the CPU)
teacher_model.to(device)
# Set student model to evaluation mode
student_model.eval()
# Test with the same prompt used in baseline
print(f"--- Qualitative Test: '{TEST_PROMPT}' ---")

# Generate with both models for comparison
teacher_output = generate_text(teacher_model, tokenizer, TEST_PROMPT)
student_baseline_output = generate_text(student_model, tokenizer, TEST_PROMPT)

print(f"Teacher (Original):    '{teacher_output}'")
print(f"Student (Post-KD):     '{student_baseline_output}'")

## Evaluation

In [None]:
# Define the benchmark suite for our diagnostic
benchmark_tasks = ['arc_easy', 'winogrande', 'hellaswag', 'lambada_openai']
#student_recovered_results = model_evaluation(student_model, tokenizer, benchmark_tasks, limit=100)

# --- FINAL COMPARISON ---
print("Evaluating Cosmopedia Student...")
# Run evaluation (reuses the model_evaluation helper defined above)
cosmo_acc = model_evaluation(student_model, tokenizer, ['arc_easy'], limit=100)

print("\n🏆 QUALITY VS QUANTITY SHOWDOWN 🏆")
print(f"{'Dataset':<25} | {'Samples':<10} | {'ARC-Easy Acc':<15} | {'Result'}")
print("-" * 65)
print(f"{baseline_metrics['name']:<25} | {'30,000':<10} | {baseline_metrics['arc_easy']:.4f}          | 🛑 Baseline")
print(f"{'Cosmopedia (Stories)':<25} | {'15,000':<10} | {cosmo_acc:.4f}          | {'✅ WINNER' if cosmo_acc > baseline_metrics['arc_easy'] else '❌ LOWER'}")
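
With `limit=100`, the Cosmopedia accuracy is an estimate from 100 questions, so a small gap between the two rows may be noise rather than a real difference. A rough sketch of the binomial standard error follows, assuming independent samples. `accuracy_standard_error` is a hypothetical helper.

```python
import math

def accuracy_standard_error(accuracy, num_samples):
    """Binomial standard error of an accuracy measured on num_samples items."""
    return math.sqrt(accuracy * (1.0 - accuracy) / num_samples)

baseline_acc = 0.58  # ARC-Easy baseline value used in this notebook
se = accuracy_standard_error(baseline_acc, num_samples=100)
low, high = baseline_acc - 1.96 * se, baseline_acc + 1.96 * se

print(f"Standard error at n=100: {se:.3f}")
print(f"Approx. 95% interval:    {low:.3f} to {high:.3f}")
```

At 100 samples the standard error is close to 5 percentage points, so a win or loss by one or two points should be confirmed with a larger `limit` before drawing conclusions.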