# 06 - Training: KD2 (Sequence-level) and KD3 (Feature-based)

**Thesis Section Reference:** Chapter 4.3-4.4 - Sequence-level and Feature-based KD

This notebook trains:
1. **KD2 (Sequence-level):** Student learns from teacher-generated sequences
2. **KD3 (Feature-based):** Student matches teacher's hidden representations

## Grid Search
- Lambda λ ∈ {0.1, 0.5, 1.0} (overridable via `config.kd2.lambdas` / `config.kd3.lambdas`; FAST mode uses {0.5, 1.0})

## Notes
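- KD2 is particularly useful for QA and is trained here on SQuAD, using the teacher's generated answers as targets
- KD3 is trained here on SST-2 and uses a layer mapping between the cached teacher layers and the student layers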

In [None]:
# Standard setup
import os
import sys
import gc
import json
from pathlib import Path
from itertools import product

import torch
import numpy as np
import pandas as pd

# Resolve the project root whether the kernel was started inside notebooks/ or at
# the repository root, then make the project's src/ directory importable.
if Path.cwd().name == "notebooks":
    ROOT_DIR = Path.cwd().parent
else:
    ROOT_DIR = Path.cwd()
sys.path.insert(0, str(ROOT_DIR / "src"))

# Environment overrides (e.g. STUDENT_S1, read below) come from the project's .env file.
from dotenv import load_dotenv
load_dotenv(ROOT_DIR / ".env")

# Project-local helpers: importable only after the sys.path insertion above.
from config import load_config
from utils_seed import set_seed
from run_io import RunRegistry

config = load_config(str(ROOT_DIR / "configs" / "experiment.yaml"))
config.ensure_dirs()

# Device preference: Apple MPS first, then CUDA, otherwise CPU.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print("Mode: " + ("FAST" if config.fast_mode else "FULL"))
print(f"Device: {DEVICE}")

In [None]:
# Set up paths: every artifact this notebook reads or writes lives under <root>/results/.
DATA_DIR, CACHE_DIR, MODELS_DIR, RUNS_DIR = (
    ROOT_DIR / "results" / subdir
    for subdir in ("processed_data", "teacher_cache", "models", "raw_runs")
)

# Record of finished runs; the training cells consult it to skip completed work.
registry = RunRegistry(RUNS_DIR / "run_registry.json")

# Student checkpoint: the STUDENT_S1 environment variable takes precedence over the config.
student_name = os.getenv("STUDENT_S1", config.student_s1.name)

In [None]:
# Load processed datasets
from datasets import load_from_disk, Dataset

print("Loading datasets...")

# Splits are stored under DATA_DIR as "<task>_<split>" directories.
sst2_train, sst2_val = (
    load_from_disk(str(DATA_DIR / f"sst2_{split}")) for split in ("train", "validation")
)
squad_train, squad_val = (
    load_from_disk(str(DATA_DIR / f"squad_{split}")) for split in ("train", "validation")
)

print(f"SST-2: {len(sst2_train)} train, {len(sst2_val)} val")
print(f"SQuAD: {len(squad_train)} train, {len(squad_val)} val")

In [None]:
# Load teacher outputs for KD2 and KD3
print("Loading cached teacher outputs...")

# KD2: teacher-generated answers for SQuAD (each item is consumed by create_kd2_dataset,
# which reads its "prompt" and "teacher_answer" keys).
with open(CACHE_DIR / "squad_teacher_answers.json", "r", encoding="utf-8") as f:
    teacher_answers = json.load(f)
print(f"  Loaded {len(teacher_answers)} teacher answers for KD2")

# KD3: cached SST-2 hidden states. KD3 needs BOTH the SST-2 metadata file and at least
# one SST-2 chunk. Globbing for any "hidden_states_*.pt" could match chunks of another
# task and then crash opening the (missing) SST-2 metadata.
hidden_state_meta_path = CACHE_DIR / "hidden_states_sst2_meta.json"
hidden_state_files = sorted(CACHE_DIR.glob("hidden_states_sst2_*.pt"))
if hidden_state_meta_path.exists() and hidden_state_files:
    with open(hidden_state_meta_path, "r", encoding="utf-8") as f:
        hidden_state_meta = json.load(f)
    print(f"  Loaded hidden state metadata: {hidden_state_meta['num_chunks']} chunks")
else:
    print("  No hidden states found - KD3 will be skipped")
    hidden_state_meta = None

In [None]:
# Training utilities (same as notebook 05)
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

def get_training_args(output_dir, task, run_name):
    """Build ``TrainingArguments`` for one run, with batch settings tuned per device.

    Args:
        output_dir: Checkpoint directory (converted to ``str``).
        task: ``"sst2"`` selects the short-sequence batch settings; any other value
            (e.g. ``"squad"``) selects smaller per-device batches with more accumulation.
        run_name: Run label forwarded to the Trainer.

    Returns:
        A ``transformers.TrainingArguments``. Per-device batch x gradient accumulation
        is 16 in every branch, so the effective batch size does not depend on the device.
    """
    on_mps = DEVICE.type == "mps"
    if task == "sst2":
        per_device_batch, grad_accum = (4, 4) if on_mps else (8, 2)
    else:
        per_device_batch, grad_accum = (2, 8) if on_mps else (4, 4)

    num_epochs = config.training.epochs_fast if config.fast_mode else config.training.epochs_full
    seed = config.get_seeds()[0]

    return TrainingArguments(
        output_dir=str(output_dir),
        run_name=run_name,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch,
        per_device_eval_batch_size=per_device_batch,
        gradient_accumulation_steps=grad_accum,
        learning_rate=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
        warmup_ratio=config.training.warmup_ratio,
        lr_scheduler_type="cosine",
        fp16=False,
        bf16=False,
        # Pinned host memory only speeds up host->CUDA copies; on MPS and CPU it brings
        # nothing (PyTorch warns when pin_memory is requested without an accelerator).
        dataloader_pin_memory=DEVICE.type == "cuda",
        dataloader_num_workers=0 if on_mps else 4,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        gradient_checkpointing=True,
        # With LoRA the frozen embeddings produce inputs that do not require grad.
        # Reentrant checkpointing then drops the gradient path through the checkpointed
        # layers; the non-reentrant implementation keeps it.
        gradient_checkpointing_kwargs={"use_reentrant": False},
        seed=seed,
        data_seed=seed,
        report_to="none",
    )

def get_lora_config():
    """Return the LoRA adapter configuration shared by every KD run.

    Rank, alpha and dropout are taken from ``config.training``. Adapters target the
    q/k/v/o projection modules of a causal LM, and bias terms are not trained.
    """
    training_cfg = config.training
    projection_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=training_cfg.lora_r,
        lora_alpha=training_cfg.lora_alpha,
        lora_dropout=training_cfg.lora_dropout,
        target_modules=projection_modules,
        bias="none",
    )

def load_student_model(student_name, use_lora=True):
    """Load the student tokenizer and causal LM, optionally wrapped with LoRA adapters.

    Args:
        student_name: Hugging Face model id or local path.
        use_lora: If True, wrap the model with ``get_peft_model(model, get_lora_config())``.

    Returns:
        ``(model, tokenizer)``. The model is loaded in float32 and moved to ``DEVICE``.
        If the tokenizer has no pad token, its EOS token is used as the pad token.
    """
    cache_dir = str(ROOT_DIR / "hf_cache")

    tokenizer = AutoTokenizer.from_pretrained(
        student_name,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        student_name,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        cache_dir=cache_dir,
        low_cpu_mem_usage=True,
    )

    if use_lora:
        # get_training_args() enables gradient checkpointing. With the base weights
        # frozen by PEFT, the embedding output does not require grad, so (reentrant)
        # checkpointed layers receive no grad path and the LoRA weights get no
        # gradients ("element 0 of tensors does not require grad"). Making the input
        # embeddings' output require grad keeps the autograd graph connected.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        model = get_peft_model(model, get_lora_config())

    model = model.to(DEVICE)
    return model, tokenizer

print("Training utilities ready.")

## Section 1: KD2 (Sequence-level KD)

The student learns from teacher-generated sequences: it is fine-tuned on (prompt, teacher_answer) pairs,
with the loss computed only on the answer tokens (prompt tokens are masked out of the labels).

In [None]:
# Prepare KD2 dataset (using teacher answers as targets)
from data_squad import create_squad_prompt

def create_kd2_dataset(teacher_answers, tokenizer, max_length=512, add_eos=True):
    """Build a causal-LM dataset whose targets are the teacher's generated answers.

    Each item is rendered as ``"{prompt}\\nAnswer: {teacher_answer}"``. The loss is
    computed only on the answer tokens and, optionally, a trailing EOS so that the
    student learns where to stop. Prompt tokens and padding are labelled ``-100``.

    Args:
        teacher_answers: Iterable of dicts with ``"prompt"`` and ``"teacher_answer"`` keys.
        tokenizer: Hugging Face tokenizer whose ``pad_token_id`` is set.
        max_length: Every example is truncated (from the right) and right-padded to
            exactly this many tokens.
        add_eos: Append ``tokenizer.eos_token_id`` after the answer when it fits.

    Returns:
        A ``datasets.Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``
        columns, each a list of ``max_length`` ints per row. Items whose prompt leaves
        no room for any target token are dropped: an all-``-100`` label row teaches
        nothing and can make a batch loss NaN.
    """
    input_ids_list = []
    attention_mask_list = []
    labels_list = []
    num_skipped = 0
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id

    for item in teacher_answers:
        # Probe the prompt boundary WITHOUT the trailing space after "Answer:".
        # SentencePiece/BPE tokenizers merge that space into the first answer token,
        # so tokenizing "...Answer: " on its own tends to end with an extra lone-space
        # token and shift the boundary by one (masking the first answer token).
        prompt_text = f"{item['prompt']}\nAnswer:"
        full_text = f"{prompt_text} {item['teacher_answer']}"

        # Tokenize without padding and pad on the right manually below: the label
        # mask is position-based and would be wrong with padding_side == "left".
        full_ids = list(tokenizer(full_text)["input_ids"])
        prompt_len = len(tokenizer(prompt_text)["input_ids"])

        if add_eos and eos_id is not None:
            full_ids.append(eos_id)
        full_ids = full_ids[:max_length]

        if prompt_len >= len(full_ids):
            # Truncation removed every target token; nothing left to learn from.
            num_skipped += 1
            continue

        pad_count = max_length - len(full_ids)
        # Labels are built before padding so a real EOS stays supervised even when
        # the pad token is the EOS token; padding is always ignored (-100).
        labels = [-100] * prompt_len + full_ids[prompt_len:] + [-100] * pad_count

        input_ids_list.append(full_ids + [pad_id] * pad_count)
        attention_mask_list.append([1] * len(full_ids) + [0] * pad_count)
        labels_list.append(labels)

    if num_skipped:
        print(f"  Skipped {num_skipped} KD2 examples with no answer tokens within max_length={max_length}")

    return Dataset.from_dict({
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "labels": labels_list,
    })

print("KD2 data preparation function ready.")

In [None]:
# Train KD2 models
from trainers import SequenceKDTrainer
from kd_losses import SequenceKDLoss
from data_squad import compute_squad_metrics

# Lambda grid: config.kd2.lambdas when present; FAST mode trims it to two values.
lambdas = config.kd2.lambdas if hasattr(config, 'kd2') else [0.1, 0.5, 1.0]

if config.fast_mode:
    lambdas = [0.5, 1.0]

print(f"KD2 Grid: λ={lambdas}")

kd2_results = []
# The KD2 dataset depends only on the teacher answers and the student tokenizer
# (the same checkpoint for every lambda), so tokenize it once and reuse it.
kd2_train = None

for lambda_val in lambdas:
    seed = config.get_seeds()[0]
    run_id = f"KD2_squad_S1_l{lambda_val}_seed{seed}"

    if registry.check_run(run_id):
        print(f"✓ {run_id} already completed, skipping...")
        existing = registry.get_run(run_id)
        kd2_results.append(existing)
        continue

    print(f"\n{'='*60}")
    print(f"Training: {run_id}")
    print(f"λ={lambda_val}")
    print(f"{'='*60}")

    set_seed(seed)

    # Load model
    model, tokenizer = load_student_model(student_name, use_lora=True)

    # Create KD2 dataset (first trained lambda only)
    if kd2_train is None:
        print("Preparing KD2 dataset...")
        kd2_train = create_kd2_dataset(
            teacher_answers,
            tokenizer,
            max_length=config.get_max_length("squad")
        )

    # Create loss
    kd_loss_fn = SequenceKDLoss(lambda_=lambda_val)

    # Training args
    output_dir = MODELS_DIR / run_id
    training_args = get_training_args(output_dir, "squad", run_id)

    # Create trainer
    trainer = SequenceKDTrainer(
        model=model,
        args=training_args,
        train_dataset=kd2_train,
        eval_dataset=squad_val,
        tokenizer=tokenizer,
        compute_metrics=compute_squad_metrics,
        kd_loss_fn=kd_loss_fn
    )

    # Train
    train_result = trainer.train()

    # Evaluate
    eval_result = trainer.evaluate()

    # Save results
    result = {
        "run_id": run_id,
        "method": "KD2",
        "task": "squad",
        "student": "S1",
        "seed": seed,
        "lambda": lambda_val,
        "train_loss": train_result.training_loss,
        "eval_loss": eval_result["eval_loss"],
        **{k: v for k, v in eval_result.items() if k != "eval_loss"}
    }

    trainer.save_model(str(output_dir / "final"))
    registry.register_run(run_id, result)
    kd2_results.append(result)

    print(f"\n✓ {run_id} complete")

    # Cleanup: drop this run's references first so gc can reclaim them, then
    # return cached accelerator memory (MPS or CUDA) before the next lambda.
    del model, trainer, kd_loss_fn
    gc.collect()
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    elif DEVICE.type == "cuda":
        torch.cuda.empty_cache()

# The cached dataset is not needed after the grid search.
del kd2_train
gc.collect()

print(f"\n✓ KD2 training complete: {len(kd2_results)} runs")

## Section 2: KD3 (Feature-based KD)

The student learns to match the teacher's cached hidden representations on SST-2.
A layer mapping aligns each cached teacher layer with a student layer.

In [None]:
# Load hidden states for KD3
if hidden_state_meta is not None:
    print("Loading hidden states...")

    # map_location="cpu" lets the cache load on any machine, even if it was written
    # from a CUDA process. NOTE(review): assumes FeatureKDTrainer moves the slices it
    # uses to the model device itself — confirm.
    hidden_states_list = [
        torch.load(CACHE_DIR / f"hidden_states_sst2_{i}.pt", map_location="cpu")
        for i in range(hidden_state_meta["num_chunks"])
    ]

    if hidden_states_list:
        # Concatenate all chunks
        teacher_hidden_states = torch.cat(hidden_states_list, dim=0)
        print(f"  Total hidden states: {teacher_hidden_states.shape}")

        # Layer mapping info
        print(f"  Teacher layers cached: {hidden_state_meta['selected_layers']}")
        print(f"  Hidden size: {hidden_state_meta['hidden_size']}")
    else:
        print("No hidden states available - skipping KD3")
        teacher_hidden_states = None

    # torch.cat made a copy; dropping the per-chunk tensors halves peak memory.
    del hidden_states_list
    gc.collect()
else:
    print("No hidden states available - skipping KD3")
    teacher_hidden_states = None

In [None]:
# Train KD3 models
from trainers import FeatureKDTrainer
from kd_losses import FeatureMatchingLoss
from data_sst2 import compute_sst2_metrics

kd3_results = []

if teacher_hidden_states is not None:
    # Lambda grid: config.kd3.lambdas when present; FAST mode trims it to two values.
    lambdas = config.kd3.lambdas if hasattr(config, 'kd3') else [0.1, 0.5, 1.0]

    if config.fast_mode:
        lambdas = [0.5, 1.0]

    print(f"KD3 Grid: λ={lambdas}")

    for lambda_val in lambdas:
        seed = config.get_seeds()[0]
        run_id = f"KD3_sst2_S1_l{lambda_val}_seed{seed}"

        if registry.check_run(run_id):
            print(f"✓ {run_id} already completed, skipping...")
            existing = registry.get_run(run_id)
            kd3_results.append(existing)
            continue

        print(f"\n{'='*60}")
        print(f"Training: {run_id}")
        print(f"λ={lambda_val}")
        print(f"{'='*60}")

        set_seed(seed)

        # Load model
        model, tokenizer = load_student_model(student_name, use_lora=True)

        # Student depth for the mapping (22 is the fallback when no config is exposed)
        student_layers = model.config.num_hidden_layers if hasattr(model, 'config') else 22
        teacher_cached_layers = hidden_state_meta['selected_layers']

        # With no cached layers the interval below divides by zero, and with more
        # cached teacher layers than student layers the interval floors to 0, which
        # would silently map every teacher layer onto student layer 0.
        if not 0 < len(teacher_cached_layers) <= student_layers:
            raise ValueError(
                f"Cannot map {len(teacher_cached_layers)} cached teacher layers "
                f"onto {student_layers} student layers"
            )

        # Layer mapping: the i-th cached teacher layer -> student layer i * interval,
        # clamped to the last student layer index.
        # NOTE(review): this starts at student layer 0 and, due to the floor division,
        # may not reach the top student layers — confirm this is the intended
        # alignment and whether index 0 means the embedding output in FeatureKDTrainer.
        student_interval = student_layers // len(teacher_cached_layers)
        layer_mapping = {
            teacher_layer: min(i * student_interval, student_layers - 1)
            for i, teacher_layer in enumerate(teacher_cached_layers)
        }

        print(f"  Layer mapping: {layer_mapping}")

        # Create loss.
        # NOTE(review): if FeatureMatchingLoss owns a trainable projection between the
        # two hidden sizes, confirm FeatureKDTrainer optimizes its parameters and moves
        # it to DEVICE; the default Trainer optimizer only covers model.parameters().
        kd_loss_fn = FeatureMatchingLoss(
            lambda_=lambda_val,
            teacher_hidden_size=hidden_state_meta['hidden_size'],
            student_hidden_size=model.config.hidden_size if hasattr(model.config, 'hidden_size') else 2048
        )

        # Training args
        output_dir = MODELS_DIR / run_id
        training_args = get_training_args(output_dir, "sst2", run_id)

        # Create trainer
        trainer = FeatureKDTrainer(
            model=model,
            args=training_args,
            train_dataset=sst2_train,
            eval_dataset=sst2_val,
            tokenizer=tokenizer,
            compute_metrics=compute_sst2_metrics,
            kd_loss_fn=kd_loss_fn,
            teacher_hidden_states=teacher_hidden_states,
            layer_mapping=layer_mapping
        )

        # Train
        train_result = trainer.train()

        # Evaluate
        eval_result = trainer.evaluate()

        # Save results
        result = {
            "run_id": run_id,
            "method": "KD3",
            "task": "sst2",
            "student": "S1",
            "seed": seed,
            "lambda": lambda_val,
            "train_loss": train_result.training_loss,
            "eval_loss": eval_result["eval_loss"],
            **{k: v for k, v in eval_result.items() if k != "eval_loss"}
        }

        trainer.save_model(str(output_dir / "final"))
        registry.register_run(run_id, result)
        kd3_results.append(result)

        print(f"\n✓ {run_id} complete")

        # Cleanup: drop references, collect, then return cached accelerator memory.
        del model, trainer, kd_loss_fn
        gc.collect()
        if DEVICE.type == "mps":
            torch.mps.empty_cache()
        elif DEVICE.type == "cuda":
            torch.cuda.empty_cache()

    print(f"\n✓ KD3 training complete: {len(kd3_results)} runs")
else:
    print("Skipping KD3 (no hidden states available)")

In [None]:
# Find best configurations (lowest validation loss per method)
print("\nBest Configurations:")
print("-" * 40)

best_configs = {}

# The selected lambda is read back from the original result dict, not from the
# DataFrame. pandas turns it into a NumPy scalar (np.int64 is not JSON-serializable)
# and upcasts a mixed int/float grid to float, which would change the "l{lambda}"
# part of the run_id used when retraining on other seeds. The DataFrame index is a
# default RangeIndex, so the idxmin label is also the list position.

# Best KD2
if kd2_results:
    df_kd2 = pd.DataFrame(kd2_results)
    best_kd2_idx = df_kd2["eval_loss"].idxmin()
    best_kd2 = df_kd2.loc[best_kd2_idx]
    print(f"KD2 Best: λ={best_kd2['lambda']}, loss={best_kd2['eval_loss']:.4f}")
    best_configs["kd2"] = {"lambda": kd2_results[best_kd2_idx]["lambda"]}

# Best KD3
if kd3_results:
    df_kd3 = pd.DataFrame(kd3_results)
    best_kd3_idx = df_kd3["eval_loss"].idxmin()
    best_kd3 = df_kd3.loc[best_kd3_idx]
    print(f"KD3 Best: λ={best_kd3['lambda']}, loss={best_kd3['eval_loss']:.4f}")
    best_configs["kd3"] = {"lambda": kd3_results[best_kd3_idx]["lambda"]}

# Save best configs
with open(RUNS_DIR / "best_kd2_kd3_config.json", "w") as f:
    json.dump(best_configs, f, indent=2)

In [None]:
# Train best configs across all seeds (the first seed was covered by the grids above)


def _free_device_memory():
    """Run the garbage collector, then return cached MPS/CUDA memory."""
    gc.collect()
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    elif DEVICE.type == "cuda":
        torch.cuda.empty_cache()


def _append_once(results, record):
    """Append ``record`` to ``results`` unless that run_id is already listed.

    Keeps re-running this cell from duplicating rows in the saved CSV.
    """
    if record is None:
        return
    if all(r.get("run_id") != record.get("run_id") for r in results):
        results.append(record)


def _build_result(run_id, method, task, seed, lambda_val, train_result, eval_result):
    """Assemble a result record in the same format as the grid-search cells."""
    return {
        "run_id": run_id,
        "method": method,
        "task": task,
        "student": "S1",
        "seed": seed,
        "lambda": lambda_val,
        "train_loss": train_result.training_loss,
        "eval_loss": eval_result["eval_loss"],
        **{k: v for k, v in eval_result.items() if k != "eval_loss"},
    }


def _train_best_kd2(run_id, lambda_val, seed):
    """Train, save and register one KD2 (SQuAD) run; return its result record."""
    print(f"Training: {run_id}")
    set_seed(seed)

    model, tokenizer = load_student_model(student_name, use_lora=True)
    kd2_train = create_kd2_dataset(
        teacher_answers,
        tokenizer,
        max_length=config.get_max_length("squad")
    )
    kd_loss_fn = SequenceKDLoss(lambda_=lambda_val)

    output_dir = MODELS_DIR / run_id
    training_args = get_training_args(output_dir, "squad", run_id)

    trainer = SequenceKDTrainer(
        model=model,
        args=training_args,
        train_dataset=kd2_train,
        eval_dataset=squad_val,
        tokenizer=tokenizer,
        compute_metrics=compute_squad_metrics,
        kd_loss_fn=kd_loss_fn
    )

    train_result = trainer.train()
    eval_result = trainer.evaluate()
    result = _build_result(run_id, "KD2", "squad", seed, lambda_val, train_result, eval_result)

    trainer.save_model(str(output_dir / "final"))
    registry.register_run(run_id, result)
    return result


def _train_best_kd3(run_id, lambda_val, seed):
    """Train, save and register one KD3 (SST-2) run; return its result record.

    Uses the same layer-mapping rule as the KD3 grid-search cell so that every
    seed of the best configuration has the same teacher->student alignment.
    """
    print(f"Training: {run_id}")
    set_seed(seed)

    model, tokenizer = load_student_model(student_name, use_lora=True)

    student_layers = model.config.num_hidden_layers if hasattr(model, 'config') else 22
    teacher_cached_layers = hidden_state_meta['selected_layers']
    if not 0 < len(teacher_cached_layers) <= student_layers:
        raise ValueError(
            f"Cannot map {len(teacher_cached_layers)} cached teacher layers "
            f"onto {student_layers} student layers"
        )
    student_interval = student_layers // len(teacher_cached_layers)
    layer_mapping = {
        teacher_layer: min(i * student_interval, student_layers - 1)
        for i, teacher_layer in enumerate(teacher_cached_layers)
    }
    print(f"  Layer mapping: {layer_mapping}")

    kd_loss_fn = FeatureMatchingLoss(
        lambda_=lambda_val,
        teacher_hidden_size=hidden_state_meta['hidden_size'],
        student_hidden_size=model.config.hidden_size if hasattr(model.config, 'hidden_size') else 2048
    )

    output_dir = MODELS_DIR / run_id
    training_args = get_training_args(output_dir, "sst2", run_id)

    trainer = FeatureKDTrainer(
        model=model,
        args=training_args,
        train_dataset=sst2_train,
        eval_dataset=sst2_val,
        tokenizer=tokenizer,
        compute_metrics=compute_sst2_metrics,
        kd_loss_fn=kd_loss_fn,
        teacher_hidden_states=teacher_hidden_states,
        layer_mapping=layer_mapping
    )

    train_result = trainer.train()
    eval_result = trainer.evaluate()
    result = _build_result(run_id, "KD3", "sst2", seed, lambda_val, train_result, eval_result)

    trainer.save_model(str(output_dir / "final"))
    registry.register_run(run_id, result)
    return result


if len(config.get_seeds()) > 1:
    print("\nTraining best configs across all seeds...")

    for seed in config.get_seeds()[1:]:
        # Best KD2
        if "kd2" in best_configs:
            lambda_val = best_configs["kd2"]["lambda"]
            run_id = f"KD2_squad_S1_l{lambda_val}_seed{seed}"

            if registry.check_run(run_id):
                # Completed in an earlier session: still include it in the results
                # (mirrors the grid cells) so the saved CSV covers every seed.
                _append_once(kd2_results, registry.get_run(run_id))
            else:
                _append_once(kd2_results, _train_best_kd2(run_id, lambda_val, seed))
                _free_device_memory()

        # Best KD3: a best config is selected above, so it is trained on every seed too.
        if "kd3" in best_configs and teacher_hidden_states is not None:
            lambda_val = best_configs["kd3"]["lambda"]
            run_id = f"KD3_sst2_S1_l{lambda_val}_seed{seed}"

            if registry.check_run(run_id):
                _append_once(kd3_results, registry.get_run(run_id))
            else:
                _append_once(kd3_results, _train_best_kd3(run_id, lambda_val, seed))
                _free_device_memory()

    print("✓ All seeds trained")

In [None]:
# Persist every KD2/KD3 run of this session (new and previously registered) as one CSV.
all_results = [*kd2_results, *kd3_results]
df_all = pd.DataFrame(all_results)

df_all.to_csv(RUNS_DIR.joinpath("nb06_results.csv"), index=False)

print("Saved", len(all_results), "results to", RUNS_DIR / "nb06_results.csv")

In [None]:
# Summary banner, run counts and output locations.
print("=" * 60, "KD2 AND KD3 TRAINING COMPLETE", "=" * 60, sep="\n")

print(f"""
Mode: {'FAST' if config.fast_mode else 'FULL'}
Student: {student_name}

Runs Completed:
  KD2 Sequence-level: {len(kd2_results)} runs
  KD3 Feature-based: {len(kd3_results)} runs

Results saved to: {RUNS_DIR / 'nb06_results.csv'}
Models saved to: {MODELS_DIR}

Next Steps:
  1. Run 07_benchmark_and_plots.ipynb for efficiency benchmarks
  2. Generate thesis figures and tables
""")

# Compact per-run comparison table (skipped when nothing was trained or loaded).
if len(all_results) > 0:
    print("\nResults Summary:")
    df = pd.DataFrame(all_results)
    print(df.loc[:, ["run_id", "method", "lambda", "eval_loss"]].to_string())