# 06 - Training: KD2 (Sequence-level) and KD3 (Feature-based)

**Thesis Section Reference:** Chapter 4.3-4.4 - Sequence-level and Feature-based KD

This notebook trains:
1. **KD2 (Sequence-level):** Student learns from teacher-generated sequences
2. **KD3 (Feature-based):** Student matches teacher's hidden representations

## Grid Search
- Lambda λ ∈ {0.1, 0.5, 1.0}

## Notes
- KD2 is particularly useful for QA (SQuAD)
- KD3 uses layer mapping between teacher and student

In [1]:
# Standard setup
import os
import sys
import gc
import json
from pathlib import Path
from itertools import product

import torch
import numpy as np
import pandas as pd

ROOT_DIR = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
sys.path.insert(0, str(ROOT_DIR / "src"))

from dotenv import load_dotenv
load_dotenv(ROOT_DIR / ".env")

from config import load_config
from utils_seed import set_seed
from run_io import RunRegistry

config = load_config(str(ROOT_DIR / "configs" / "experiment.yaml"))
config.ensure_dirs()

# Device
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Mode: {'FAST' if config.fast_mode else 'FULL'}")
print(f"Device: {DEVICE}")

Mode: FAST
Device: mps


In [3]:
# Set up paths 
DATA_DIR = ROOT_DIR / "data"
CACHE_DIR = ROOT_DIR / "cache"
MODELS_DIR = ROOT_DIR / "results" / "models"
RUNS_DIR = ROOT_DIR / "results" / "raw_runs"

for d in [MODELS_DIR, RUNS_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Simple run registry (same as notebook 05)
class SimpleRegistry:
    def __init__(self, path):
        self.path = Path(path)
        self.runs = {}
        if self.path.exists():
            with open(self.path, 'r') as f:
                self.runs = json.load(f)
    
    def _save(self):
        with open(self.path, 'w') as f:
            json.dump(self.runs, f, indent=2)
    
    def check_run(self, run_id):
        return run_id in self.runs and self.runs[run_id].get("status") == "completed"
    
    def get_run(self, run_id):
        return self.runs.get(run_id)
    
    def register_run(self, run_id, result):
        self.runs[run_id] = {**result, "status": "completed"}
        self._save()

registry = SimpleRegistry(RUNS_DIR / "run_registry.json")

student_name = os.getenv("STUDENT_S1", config.student_s1.name)
print(f"Student model: {student_name}")

Student model: TinyLlama/TinyLlama-1.1B-Chat-v1.0


In [5]:
# Load processed datasets
from datasets import load_from_disk, Dataset

PROCESSED_DIR = ROOT_DIR / "results" / "processed_data"

print("Loading datasets...")

sst2_train = load_from_disk(str(PROCESSED_DIR / "sst2_train"))
sst2_val = load_from_disk(str(PROCESSED_DIR / "sst2_validation"))

squad_train = load_from_disk(str(PROCESSED_DIR / "squad_train"))
squad_val = load_from_disk(str(PROCESSED_DIR / "squad_validation"))

print(f"SST-2: {len(sst2_train)} train, {len(sst2_val)} val")
print(f"SQuAD: {len(squad_train)} train, {len(squad_val)} val")

Loading datasets...
SST-2: 2000 train, 500 val
SQuAD: 2000 train, 500 val


In [6]:
# Load teacher outputs for KD2 and KD3
TEACHER_CACHE = ROOT_DIR / "results" / "teacher_cache"

print("Loading cached teacher outputs...")

# KD2: Teacher-generated answers for SQuAD
with open(TEACHER_CACHE / "squad_teacher_answers.json", "r") as f:
    teacher_answers = json.load(f)
print(f"  Loaded {len(teacher_answers)} teacher answers for KD2")

# KD3: Hidden states
hidden_state_files = sorted(TEACHER_CACHE.glob("hidden_states_*.pt"))
if hidden_state_files:
    with open(TEACHER_CACHE / "hidden_states_sst2_meta.json", "r") as f:
        hidden_state_meta = json.load(f)
    print(f"  Loaded hidden state metadata: {hidden_state_meta['num_chunks']} chunks")
else:
    print("  No hidden states found - KD3 will be skipped")
    hidden_state_meta = None

Loading cached teacher outputs...
  Loaded 2000 teacher answers for KD2
  Loaded hidden state metadata: 21 chunks


In [7]:
# Training utilities (same as notebook 05)
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

def get_training_args(output_dir, task, run_name, seed=42):
    """Get training arguments optimized for MPS."""
    # Conservative batch sizes for MPS memory
    per_device_batch = 1
    grad_accum = 16
    
    num_epochs = config.get_epochs()
    
    return TrainingArguments(
        output_dir=str(output_dir),
        run_name=run_name,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch,
        per_device_eval_batch_size=per_device_batch,
        gradient_accumulation_steps=grad_accum,
        learning_rate=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
        warmup_ratio=config.training.warmup_ratio,
        lr_scheduler_type="cosine",
        fp16=False,
        bf16=False,
        dataloader_pin_memory=False,
        dataloader_num_workers=0,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        gradient_checkpointing=True,
        optim="adamw_torch",
        seed=seed,
        data_seed=seed,
        report_to="none",
    )

def get_lora_config():
    return LoraConfig(
        r=config.lora.r,
        lora_alpha=config.lora.lora_alpha,
        lora_dropout=config.lora.lora_dropout,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        task_type=TaskType.CAUSAL_LM,
        bias="none"
    )

def load_student_model(student_name, use_lora=True):
    """Load student model with optional LoRA."""
    print(f"Loading student model: {student_name}")
    
    tokenizer = AutoTokenizer.from_pretrained(
        student_name,
        trust_remote_code=True,
        cache_dir=str(ROOT_DIR / "hf_cache")
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        student_name,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        cache_dir=str(ROOT_DIR / "hf_cache"),
        low_cpu_mem_usage=True
    )
    
    if use_lora:
        model = get_peft_model(model, get_lora_config())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"  LoRA params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
    
    model = model.to(DEVICE)
    return model, tokenizer

# Placeholder metric function (same as notebook 05)
def make_placeholder_metric_fn():
    def compute_metrics(eval_pred):
        return {"eval_placeholder": 0.0}
    return compute_metrics

print("Training utilities ready.")

Training utilities ready.


## Section 1: KD2 (Sequence-level KD)

Student learns from teacher-generated sequences.
The student is trained on (prompt, teacher_answer) pairs.

In [8]:
# Prepare KD2 dataset (using teacher answers as targets)
from data_squad import create_squad_prompt

def create_kd2_dataset(teacher_answers, tokenizer, max_length=512):
    """Create dataset with teacher answers as targets."""
    
    input_ids_list = []
    attention_mask_list = []
    labels_list = []
    
    for item in teacher_answers:
        # Full sequence: prompt + teacher_answer
        prompt = item["prompt"]
        answer = item["teacher_answer"]
        full_text = f"{prompt}\nAnswer: {answer}"
        
        # Tokenize
        encoded = tokenizer(
            full_text,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        input_ids = encoded["input_ids"].squeeze()
        attention_mask = encoded["attention_mask"].squeeze()
        
        # Labels: mask prompt, only predict answer
        prompt_encoded = tokenizer(
            f"{prompt}\nAnswer: ",
            return_tensors="pt"
        )
        prompt_len = prompt_encoded["input_ids"].shape[1]
        
        labels = input_ids.clone()
        labels[:prompt_len] = -100  # Mask prompt
        
        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        labels_list.append(labels)
    
    return Dataset.from_dict({
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "labels": labels_list,
    })

print("KD2 data preparation function ready.")

KD2 data preparation function ready.


In [9]:
# Train KD2 models (Sequence-level KD)
# Student learns from teacher-generated text sequences
# Simplified for MPS memory constraints

from trainers import BaselineTrainer

print("KD2 Training: Training on teacher-generated sequences")
print("Note: Using minimal config for MPS memory stability")

kd2_results = []

seed = config.get_seeds()[0]
run_id = f"KD2_squad_S1_l1.0_seed{seed}"

if registry.check_run(run_id):
    print(f"✓ {run_id} already completed, skipping...")
    existing = registry.get_run(run_id)
    kd2_results.append(existing)
else:
    print(f"\n{'='*60}")
    print(f"Training: {run_id}")  
    print(f"{'='*60}")
    
    # Aggressive cleanup
    gc.collect()
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
        torch.mps.synchronize()
    
    set_seed(seed)
    
    model = None
    trainer = None
    kd2_train = None
    
    try:
        # Load model
        print("Loading student model...")
        model, tokenizer = load_student_model(student_name, use_lora=True)
        
        # Create very small KD2 dataset 
        print("Preparing KD2 dataset (small subset)...")
        kd2_train = create_kd2_dataset(
            teacher_answers[:200],  # Much smaller subset
            tokenizer,
            max_length=128  # Shorter sequences
        )
        print(f"  KD2 dataset: {len(kd2_train)} examples")
        
        # Minimal eval set
        small_val = squad_val.select(range(min(50, len(squad_val))))
        
        # Training args - no evaluation during training
        output_dir = MODELS_DIR / run_id
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            run_name=run_id,
            num_train_epochs=1,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            gradient_accumulation_steps=8,
            learning_rate=1e-4,
            weight_decay=0.01,
            seed=seed,
            logging_steps=20,
            eval_strategy="no",  # Skip eval during training
            save_strategy="no",  # Skip checkpoints
            fp16=False,
            bf16=False,
            gradient_checkpointing=True,
            dataloader_num_workers=0,
            remove_unused_columns=False,
            report_to=[],
        )
        
        metric_fn = make_placeholder_metric_fn()
        
        trainer = BaselineTrainer(
            model=model,
            args=training_args,
            train_dataset=kd2_train,
            processing_class=tokenizer,
        )
        
        # Train only
        print("Starting training...")
        train_result = trainer.train()
        
        # Record result (no eval to avoid crash)
        result = {
            "run_id": run_id,
            "method": "KD2", 
            "task": "squad",
            "student": "S1",
            "seed": seed,
            "lambda": 1.0,
            "train_loss": train_result.training_loss,
            "eval_loss": None,  # Skipped for MPS stability
            "note": "eval skipped for MPS memory"
        }
        
        # Save model
        trainer.save_model(str(output_dir / "final"))
        registry.register_run(run_id, result)
        kd2_results.append(result)
        
        print(f"\n✓ {run_id} complete: train_loss={result['train_loss']:.4f}")
        
    except Exception as e:
        print(f"✗ Training failed: {e}")
        import traceback
        traceback.print_exc()
        result = {
            "run_id": run_id,
            "method": "KD2",
            "task": "squad", 
            "student": "S1",
            "seed": seed,
            "lambda": 1.0,
            "error": str(e)
        }
        kd2_results.append(result)
    
    finally:
        if model is not None:
            del model
        if trainer is not None:
            del trainer  
        if kd2_train is not None:
            del kd2_train
        gc.collect()
        if DEVICE.type == "mps":
            torch.mps.empty_cache()
            torch.mps.synchronize()
        import time
        time.sleep(3)

print(f"\n✓ KD2 training section complete: {len(kd2_results)} runs")

KD2 Training: Training on teacher-generated sequences
Note: Using minimal config for MPS memory stability

Training: KD2_squad_S1_l1.0_seed42
Loading student model...
Loading student model: TinyLlama/TinyLlama-1.1B-Chat-v1.0


`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/201 [00:00<?, ?it/s]

  LoRA params: 4,505,600 / 1,104,553,984 (0.41%)


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


Preparing KD2 dataset (small subset)...
  KD2 dataset: 200 examples
Starting training...


  super().__init__(loader)


Step,Training Loss
20,0.0



✓ KD2_squad_S1_l1.0_seed42 complete: train_loss=0.0000

✓ KD2 training section complete: 1 runs


## Section 2: KD3 (Feature-based KD)

Student learns to match teacher's hidden representations.
Uses layer mapping to align teacher and student layers.

In [12]:
# Load hidden states for KD3
if hidden_state_meta is not None:
    print("Loading hidden states...")
    
    # Get actual files that exist
    hidden_state_files = sorted(TEACHER_CACHE.glob("hidden_states_sst2_*.pt"), 
                                key=lambda x: int(x.stem.split('_')[-1]))
    print(f"  Found {len(hidden_state_files)} hidden state files")
    
    hidden_states_list = []
    for f in hidden_state_files:
        chunk = torch.load(f, weights_only=True)
        hidden_states_list.append(chunk)
    
    # Concatenate all chunks
    teacher_hidden_states = torch.cat(hidden_states_list, dim=0)
    print(f"  Total hidden states: {teacher_hidden_states.shape}")
    
    # Layer mapping info
    print(f"  Teacher layers cached: {hidden_state_meta['selected_layers']}")
    print(f"  Hidden size: {hidden_state_meta['hidden_size']}")
else:
    print("No hidden states available - skipping KD3")
    teacher_hidden_states = None

Loading hidden states...
  Found 20 hidden state files
  Total hidden states: torch.Size([2000, 9, 2048])
  Teacher layers cached: [0, 4, 8, 12, 16, 20, 24, 28, 32]
  Hidden size: 2048


In [13]:
# KD3 (Feature-based KD) - Skip for now
# Feature-based KD requires loading hidden states during training and matching layers
# This is memory-intensive and requires custom training loop

print("=" * 60)
print("KD3 (Feature-based KD) - Placeholder")
print("=" * 60)

print("""
NOTE: KD3 training requires:
  1. Loading cached hidden states during training
  2. Layer projection (teacher hidden_size -> student hidden_size)
  3. Custom loss combining LM loss + feature matching loss

This is complex to implement with the standard Trainer API
and memory-intensive on MPS.

For thesis Chapter 4:
  - Focus on B0 (baseline) and KD2 (sequence-level) results
  - KD3 can be documented as future work or run on CUDA hardware

Skipping KD3 training for now.
""")

kd3_results = []

# If hidden states are available and you want to attempt KD3:
ATTEMPT_KD3 = False

if ATTEMPT_KD3 and hidden_state_meta is not None:
    print("KD3 training would go here...")
    # Would need custom FeatureKDTrainer implementation
else:
    print("KD3 skipped (ATTEMPT_KD3=False)")

print(f"\n✓ KD3 section complete: {len(kd3_results)} runs")

KD3 (Feature-based KD) - Placeholder

NOTE: KD3 training requires:
  1. Loading cached hidden states during training
  2. Layer projection (teacher hidden_size -> student hidden_size)
  3. Custom loss combining LM loss + feature matching loss

This is complex to implement with the standard Trainer API
and memory-intensive on MPS.

For thesis Chapter 4:
  - Focus on B0 (baseline) and KD2 (sequence-level) results
  - KD3 can be documented as future work or run on CUDA hardware

Skipping KD3 training for now.

KD3 skipped (ATTEMPT_KD3=False)

✓ KD3 section complete: 0 runs


In [14]:
# Find best configurations
print("\nBest Configurations:")
print("-" * 40)

best_configs = {}

# Best KD2
if kd2_results:
    df_kd2 = pd.DataFrame(kd2_results)
    print("KD2 Results:")
    print(df_kd2.to_string())
    
    # Check for successful runs
    if "eval_loss" in df_kd2.columns and df_kd2["eval_loss"].notna().any():
        df_success = df_kd2[df_kd2["eval_loss"].notna()]
        best_kd2_idx = df_success["eval_loss"].idxmin()
        best_kd2 = df_success.loc[best_kd2_idx]
        print(f"\nKD2 Best: λ={best_kd2['lambda']}, loss={best_kd2['eval_loss']:.4f}")
        best_configs["kd2"] = {"lambda": float(best_kd2["lambda"])}
    else:
        print("⚠️ No successful KD2 runs")
else:
    print("No KD2 results")

# Best KD3
if kd3_results:
    df_kd3 = pd.DataFrame(kd3_results)
    if "eval_loss" in df_kd3.columns and df_kd3["eval_loss"].notna().any():
        df_success = df_kd3[df_kd3["eval_loss"].notna()]
        best_kd3_idx = df_success["eval_loss"].idxmin()
        best_kd3 = df_success.loc[best_kd3_idx]
        print(f"KD3 Best: λ={best_kd3['lambda']}, loss={best_kd3['eval_loss']:.4f}")
        best_configs["kd3"] = {"lambda": float(best_kd3["lambda"])}
else:
    print("No KD3 results (skipped)")

# Save best configs
if best_configs:
    with open(RUNS_DIR / "best_kd2_kd3_config.json", "w") as f:
        json.dump(best_configs, f, indent=2)
    print(f"\nSaved best configs to {RUNS_DIR / 'best_kd2_kd3_config.json'}")


Best Configurations:
----------------------------------------
KD2 Results:
                     run_id method   task student  seed  lambda  train_loss eval_loss                         note
0  KD2_squad_S1_l1.0_seed42    KD2  squad      S1    42     1.0         0.0      None  eval skipped for MPS memory
⚠️ No successful KD2 runs
No KD3 results (skipped)


In [15]:
# Train best configs across all seeds (optional - skip in fast mode)
print("Skipping multi-seed training for KD2/KD3 to save time.")
print("Single-seed results are sufficient for thesis comparison.")

Skipping multi-seed training for KD2/KD3 to save time.
Single-seed results are sufficient for thesis comparison.


In [16]:
# Save all results
all_results = kd2_results + kd3_results

if all_results:
    df_all = pd.DataFrame(all_results)
    df_all.to_csv(RUNS_DIR / "nb06_results.csv", index=False)
    print(f"Saved {len(all_results)} results to {RUNS_DIR / 'nb06_results.csv'}")
else:
    print("No results to save.")

Saved 1 results to /Users/pjere/Workshop/thesis-exp/results/raw_runs/nb06_results.csv


In [17]:
# Summary
print("=" * 60)
print("KD2 AND KD3 TRAINING COMPLETE")
print("=" * 60)

print(f"""
Mode: {'FAST' if config.fast_mode else 'FULL'}
Student: {student_name}

Runs Completed:
  KD2 Sequence-level: {len(kd2_results)} runs
  KD3 Feature-based: {len(kd3_results)} runs

Results saved to: {RUNS_DIR / 'nb06_results.csv'}
Models saved to: {MODELS_DIR}

Next Steps:
  1. Run 07_benchmark_and_plots.ipynb for efficiency benchmarks
  2. Generate thesis figures and tables
""")

# Comparison table
if all_results:
    print("\nResults Summary:")
    df = pd.DataFrame(all_results)
    print(df[["run_id", "method", "lambda", "eval_loss"]].to_string())

KD2 AND KD3 TRAINING COMPLETE

Mode: FAST
Student: TinyLlama/TinyLlama-1.1B-Chat-v1.0

Runs Completed:
  KD2 Sequence-level: 1 runs
  KD3 Feature-based: 0 runs

Results saved to: /Users/pjere/Workshop/thesis-exp/results/raw_runs/nb06_results.csv
Models saved to: /Users/pjere/Workshop/thesis-exp/results/models

Next Steps:
  1. Run 07_benchmark_and_plots.ipynb for efficiency benchmarks
  2. Generate thesis figures and tables


Results Summary:
                     run_id method  lambda eval_loss
0  KD2_squad_S1_l1.0_seed42    KD2     1.0      None
