# 05 - Training: Baseline (B0) and KD1 (Logit-based)

**Thesis Section Reference:** Chapter 4.1-4.2 - Baseline Results and Logit-based KD

This notebook trains:
1. **B0 Baseline:** Fine-tuning without distillation
2. **KD1 (Logit-based):** Knowledge distillation using soft targets

## Grid Search
- Temperature T ∈ {1, 2, 4, 8}
- Alpha α ∈ {0.1, 0.3, 0.5, 0.7}
- In FAST mode the grid is reduced to T ∈ {2, 4} and α ∈ {0.3, 0.5}

## Memory Management
- LoRA/PEFT for parameter-efficient training
- Gradient accumulation for an effective batch size of 16 (per-device batch size × accumulation steps)
- fp32 for MPS stability
- Periodic cache clearing

In [2]:
# Standard setup
import os
import sys
import gc
import json
from pathlib import Path
from itertools import product

import torch
import numpy as np

# Project root: the working directory, or its parent when the kernel was
# started from inside the notebooks/ folder.
ROOT_DIR = Path.cwd()
if ROOT_DIR.name == "notebooks":
    ROOT_DIR = ROOT_DIR.parent

# Put src/ at the front of sys.path so the project modules below can be imported.
sys.path.insert(0, str(ROOT_DIR / "src"))

from dotenv import load_dotenv
load_dotenv(ROOT_DIR / ".env")

from config import load_config
from utils_seed import set_seed, get_generator
from run_io import RunRegistry

config = load_config(str(ROOT_DIR / "configs" / "experiment.yaml"))
config.ensure_dirs()

# Device: the first backend that reports itself available, checked in the
# order MPS, CUDA, then CPU as the fallback.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Mode: {'FAST' if config.fast_mode else 'FULL'}")
print(f"Device: {DEVICE}")

Mode: FAST
Device: mps


In [3]:
# Locations under results/: two directories this notebook reads from and two
# it writes to.
DATA_DIR = ROOT_DIR.joinpath("results", "processed_data")
CACHE_DIR = ROOT_DIR.joinpath("results", "teacher_cache")
MODELS_DIR = ROOT_DIR.joinpath("results", "models")
RUNS_DIR = ROOT_DIR.joinpath("results", "raw_runs")

# Only the output directories are created here; DATA_DIR and CACHE_DIR are
# expected to exist already.
for out_dir in (MODELS_DIR, RUNS_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# Run registry, consulted before each run so completed runs are skipped.
registry = RunRegistry(RUNS_DIR.joinpath("run_registry.json"))

In [4]:
# Load processed datasets
from datasets import load_from_disk

print("Loading datasets...")

# Each split is stored in its own directory under DATA_DIR; the order of the
# names below matches the order of the unpacked variables.
sst2_train, sst2_val, squad_train, squad_val = (
    load_from_disk(str(DATA_DIR / split_dir))
    for split_dir in (
        "sst2_train",
        "sst2_validation",
        "squad_train",
        "squad_validation",
    )
)

print(f"SST-2: {len(sst2_train)} train, {len(sst2_val)} val")
print(f"SQuAD: {len(squad_train)} train, {len(squad_val)} val")

Loading datasets...
SST-2: 2000 train, 500 val
SQuAD: 2000 train, 500 val


In [5]:
# Load teacher logits for KD1
print("Loading cached teacher logits...")

_sst2_logits_path = CACHE_DIR / "sst2_logits.pt"
_squad_logits_path = CACHE_DIR / "squad_logits.pt"

# Fail early, naming every missing file at once, instead of letting torch.load
# raise a bare FileNotFoundError for whichever file happens to be read first.
_missing_caches = [p for p in (_sst2_logits_path, _squad_logits_path) if not p.exists()]
if _missing_caches:
    raise FileNotFoundError(
        "Teacher logit cache not found: "
        + ", ".join(str(p) for p in _missing_caches)
        + f". Generate the teacher cache in {CACHE_DIR} before running this notebook."
    )

# NOTE(review): torch.load unpickles the file; only load caches produced by
# this project's own pipeline.
sst2_logits = torch.load(_sst2_logits_path, map_location="cpu")
squad_logits = torch.load(_squad_logits_path, map_location="cpu")

# The cache layout is not fixed here: report a shape only when the cache is a
# dict with a "logits" entry, otherwise report the number of entries.
if isinstance(sst2_logits, dict) and "logits" in sst2_logits:
    print(f"  SST-2 logits shape: {sst2_logits['logits'].shape}")
else:
    print(f"  SST-2 logits entries: {len(sst2_logits)}")
print("  SQuAD logits: loaded")

Loading cached teacher logits...


FileNotFoundError: [Errno 2] No such file or directory: '/Users/pjere/Workshop/thesis-exp/results/teacher_cache/sst2_logits.pt'

In [None]:
# Training configuration
from transformers import TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType

def get_training_args(output_dir, task, run_name, seed=None):
    """Build the ``TrainingArguments`` shared by every run in this notebook.

    Args:
        output_dir: Checkpoint directory; passed through ``str()``.
        task: ``"sst2"`` selects the SST-2 batch settings. Any other value is
            treated as SQuAD (longer sequences, so smaller per-device batches).
        run_name: Forwarded to ``TrainingArguments(run_name=...)``.
        seed: Value used for both ``seed`` and ``data_seed``. Defaults to the
            first entry of ``config.get_seeds()``, which is what this function
            always used before the parameter existed. Pass the run's own seed
            for multi-seed experiments so the arguments match the replicate.

    Returns:
        A ``TrainingArguments`` instance configured for the current ``DEVICE``.
    """
    on_mps = DEVICE.type == "mps"

    # (per-device batch, accumulation steps). Every pair multiplies to 16, so
    # the effective batch size is the same on all devices and tasks.
    if task == "sst2":
        per_device_batch, grad_accum = (4, 4) if on_mps else (8, 2)
    else:  # squad - longer sequences
        per_device_batch, grad_accum = (2, 8) if on_mps else (4, 4)

    # Epochs based on mode
    num_epochs = config.training.epochs_fast if config.fast_mode else config.training.epochs_full

    # Backward-compatible default: the first configured seed.
    if seed is None:
        seed = config.get_seeds()[0]

    return TrainingArguments(
        output_dir=str(output_dir),
        run_name=run_name,

        # Training
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch,
        per_device_eval_batch_size=per_device_batch,
        gradient_accumulation_steps=grad_accum,

        # Optimizer
        learning_rate=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
        warmup_ratio=config.training.warmup_ratio,
        lr_scheduler_type="cosine",

        # Precision - fp32 for MPS stability
        fp16=False,
        bf16=False,

        # MPS-specific: no pinned memory and no worker processes on MPS
        dataloader_pin_memory=not on_mps,
        dataloader_num_workers=0 if on_mps else 4,

        # Logging / checkpointing: evaluate and save once per epoch, keep two
        # checkpoints, and reload the one with the lowest eval_loss at the end
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,

        # Memory
        # NOTE(review): gradient checkpointing combined with a LoRA-wrapped
        # model may need model.enable_input_require_grads() or
        # gradient_checkpointing_kwargs={"use_reentrant": False} - confirm.
        gradient_checkpointing=True,

        # Reproducibility
        seed=seed,
        data_seed=seed,

        # Reporting
        report_to="none",
    )

def get_lora_config():
    """Build the ``LoraConfig`` used for every student model in this notebook.

    Rank, alpha and dropout are read from ``config.training``. The target
    module names and the causal-LM task type are fixed here.

    Returns:
        A ``LoraConfig`` with ``bias="none"``.
    """
    training_cfg = config.training
    adapted_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]

    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=adapted_modules,
        r=training_cfg.lora_r,
        lora_alpha=training_cfg.lora_alpha,
        lora_dropout=training_cfg.lora_dropout,
        bias="none",
    )

print("Training configuration ready.")

In [None]:
# Function to load student model
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_student_model(student_name, use_lora=True):
    """Load the student checkpoint and its tokenizer, optionally with LoRA.

    Args:
        student_name: Model identifier passed to ``from_pretrained``.
        use_lora: When true, wrap the model with ``get_peft_model`` using
            ``get_lora_config()`` and print the trainable-parameter share.

    Returns:
        ``(model, tokenizer)`` with the model moved to ``DEVICE``.
    """
    print(f"Loading student model: {student_name}")

    hf_cache = str(ROOT_DIR / "hf_cache")

    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint to run locally - only use with trusted model repositories.
    tokenizer = AutoTokenizer.from_pretrained(
        student_name,
        trust_remote_code=True,
        cache_dir=hf_cache,
    )

    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        student_name,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        cache_dir=hf_cache,
        low_cpu_mem_usage=True,
    )

    if use_lora:
        model = get_peft_model(model, get_lora_config())

        # Single pass over the parameters to count total and trainable elements.
        total_params = 0
        trainable_params = 0
        for param in model.parameters():
            n_elements = param.numel()
            total_params += n_elements
            if param.requires_grad:
                trainable_params += n_elements
        print(f"  Trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

    model = model.to(DEVICE)
    return model, tokenizer

print("Model loading function ready.")

## Section 1: B0 Baseline Training

Train student model directly without knowledge distillation.

In [None]:
# Train baseline models
from trainers import BaselineTrainer
from data_sst2 import compute_sst2_metrics
from data_squad import compute_squad_metrics

# Student checkpoint: the STUDENT_S1 environment variable overrides the config.
student_name = os.getenv("STUDENT_S1", config.student_s1.name)

baseline_results = []

for task, train_data, eval_data, metric_fn in [
    ("sst2", sst2_train, sst2_val, compute_sst2_metrics),
    ("squad", squad_train, squad_val, compute_squad_metrics),
]:
    for seed in config.get_seeds():
        run_id = f"B0_{task}_S1_seed{seed}"

        # Idempotency: reuse the stored result for runs already in the registry.
        if registry.check_run(run_id):
            print(f"✓ {run_id} already completed, skipping...")
            existing = registry.get_run(run_id)
            baseline_results.append(existing)
            continue

        print(f"\n{'='*60}")
        print(f"Training: {run_id}")
        print(f"{'='*60}")

        set_seed(seed)

        # Load model
        model, tokenizer = load_student_model(student_name, use_lora=True)

        # Training args
        output_dir = MODELS_DIR / run_id
        training_args = get_training_args(output_dir, task, run_id)
        # As called here, get_training_args() sets seed and data_seed to
        # config.get_seeds()[0] for every run. Hugging Face's Trainer seeds
        # from args.seed, so without this override all replicates would share
        # the first seed for data order and training-time randomness.
        # NOTE(review): assumes BaselineTrainer builds on transformers.Trainer.
        # Runs already stored in the registry were trained with the old
        # seeding and are not re-run.
        training_args.seed = seed
        training_args.data_seed = seed

        # Create trainer
        trainer = BaselineTrainer(
            model=model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            tokenizer=tokenizer,
            compute_metrics=metric_fn
        )

        # Train
        train_result = trainer.train()

        # Evaluate
        eval_result = trainer.evaluate()

        # Save results: fixed run metadata first, then every remaining
        # evaluation metric under its original key.
        result = {
            "run_id": run_id,
            "method": "B0",
            "task": task,
            "student": "S1",
            "seed": seed,
            "train_loss": train_result.training_loss,
            "eval_loss": eval_result["eval_loss"],
            **{k: v for k, v in eval_result.items() if k != "eval_loss"}
        }

        # Save model and register run
        trainer.save_model(str(output_dir / "final"))
        registry.register_run(run_id, result)
        baseline_results.append(result)

        print(f"\n✓ {run_id} complete: {result}")

        # Cleanup: collect garbage first so the dropped model's tensors are
        # actually freed, then return the cached device memory.
        del model, trainer
        gc.collect()
        if DEVICE.type == "mps":
            torch.mps.empty_cache()
        elif DEVICE.type == "cuda":
            torch.cuda.empty_cache()

print(f"\n✓ Baseline training complete: {len(baseline_results)} runs")

## Section 2: KD1 (Logit-based KD) Grid Search

Train with soft target distillation using cached teacher logits.

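Grid: every T × α combination, trained on SST-2 only and with the first configured seed. The best configuration is then retrained with the remaining seeds.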

In [None]:
# KD1 Grid Search
from trainers import LogitKDTrainer
from kd_losses import SoftTargetLoss

# Grid parameters
temperatures = config.kd1.temperatures
alphas = config.kd1.alphas

# In fast mode, reduce grid
if config.fast_mode:
    temperatures = [2, 4]
    alphas = [0.3, 0.5]

print(f"KD1 Grid: T={temperatures}, α={alphas}")
print(f"Total configurations: {len(temperatures) * len(alphas)}")

kd1_results = []

# The grid search uses only the first configured seed; look it up once.
seed = config.get_seeds()[0]

for task, train_data, eval_data, metric_fn, logits_cache in [
    ("sst2", sst2_train, sst2_val, compute_sst2_metrics, sst2_logits),
]:
    for T, alpha in product(temperatures, alphas):
        run_id = f"KD1_{task}_S1_T{T}_a{alpha}_seed{seed}"

        # Idempotency: reuse the stored result for runs already in the registry.
        if registry.check_run(run_id):
            print(f"✓ {run_id} already completed, skipping...")
            existing = registry.get_run(run_id)
            kd1_results.append(existing)
            continue

        print(f"\n{'='*60}")
        print(f"Training: {run_id}")
        print(f"T={T}, α={alpha}")
        print(f"{'='*60}")

        set_seed(seed)

        # Load model
        model, tokenizer = load_student_model(student_name, use_lora=True)

        # Create KD loss
        kd_loss_fn = SoftTargetLoss(temperature=T, alpha=alpha)

        # Training args; seed and data_seed are set explicitly so the
        # arguments always match the seed recorded in run_id and result.
        output_dir = MODELS_DIR / run_id
        training_args = get_training_args(output_dir, task, run_id)
        training_args.seed = seed
        training_args.data_seed = seed

        # Create trainer with teacher logits
        trainer = LogitKDTrainer(
            model=model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            tokenizer=tokenizer,
            compute_metrics=metric_fn,
            kd_loss_fn=kd_loss_fn,
            teacher_logits=logits_cache
        )

        # Train
        train_result = trainer.train()

        # Evaluate
        eval_result = trainer.evaluate()

        # Save results: run metadata and hyperparameters first, then every
        # remaining evaluation metric under its original key.
        result = {
            "run_id": run_id,
            "method": "KD1",
            "task": task,
            "student": "S1",
            "seed": seed,
            "temperature": T,
            "alpha": alpha,
            "train_loss": train_result.training_loss,
            "eval_loss": eval_result["eval_loss"],
            **{k: v for k, v in eval_result.items() if k != "eval_loss"}
        }

        # Save and register
        trainer.save_model(str(output_dir / "final"))
        registry.register_run(run_id, result)
        kd1_results.append(result)

        print(f"\n✓ {run_id} complete")

        # Cleanup: collect garbage first so the dropped model's tensors are
        # actually freed, then return the cached device memory.
        del model, trainer, kd_loss_fn
        gc.collect()
        if DEVICE.type == "mps":
            torch.mps.empty_cache()
        elif DEVICE.type == "cuda":
            torch.cuda.empty_cache()

print(f"\n✓ KD1 grid search complete: {len(kd1_results)} runs")

In [None]:
# Find best KD1 configuration
import pandas as pd

def _to_builtin(value):
    """Return the plain Python scalar for a NumPy scalar; pass other values through."""
    return value.item() if hasattr(value, "item") else value


if kd1_results:
    df_kd1 = pd.DataFrame(kd1_results)

    # Best by eval_loss (lower is better)
    # NOTE(review): if LogitKDTrainer's eval_loss includes the distillation
    # term, values are not directly comparable across (T, alpha) - confirm,
    # or select on a task metric instead.
    best_idx = df_kd1["eval_loss"].idxmin()
    best_config = df_kd1.loc[best_idx]

    print("Best KD1 Configuration:")
    print(f"  Temperature: {best_config['temperature']}")
    print(f"  Alpha: {best_config['alpha']}")
    print(f"  Eval Loss: {best_config['eval_loss']:.4f}")

    # Save best config for later.
    # Values read back from a DataFrame row are NumPy scalars. json.dump
    # rejects NumPy integers ("Object of type int64 is not JSON
    # serializable"), so convert to built-in Python numbers first.
    best_kd1_config = {
        "temperature": _to_builtin(best_config["temperature"]),
        "alpha": _to_builtin(best_config["alpha"])
    }

    with open(RUNS_DIR / "best_kd1_config.json", "w") as f:
        json.dump(best_kd1_config, f, indent=2)

    print("\nKD1 Grid Results:")
    print(df_kd1[["temperature", "alpha", "eval_loss"]].to_string())

In [None]:
# Train best KD1 config across all seeds
if len(config.get_seeds()) > 1 and kd1_results:
    print("Training best KD1 config across all seeds...")

    T_best = best_kd1_config["temperature"]
    alpha_best = best_kd1_config["alpha"]
    # The best config may hold NumPy scalars taken from a DataFrame row.
    # Convert them to built-in numbers so the result dict stays serialisable
    # (presumably the registry writes JSON - run_registry.json).
    if hasattr(T_best, "item"):
        T_best = T_best.item()
    if hasattr(alpha_best, "item"):
        alpha_best = alpha_best.item()

    for seed in config.get_seeds()[1:]:  # Skip first seed (already done)
        for task, train_data, eval_data, metric_fn, logits_cache in [
            ("sst2", sst2_train, sst2_val, compute_sst2_metrics, sst2_logits),
        ]:
            run_id = f"KD1_{task}_S1_T{T_best}_a{alpha_best}_seed{seed}"

            # Idempotency: runs already in the registry are skipped.
            if registry.check_run(run_id):
                print(f"✓ {run_id} already completed, skipping...")
                continue

            print(f"\nTraining: {run_id}")
            set_seed(seed)

            model, tokenizer = load_student_model(student_name, use_lora=True)
            kd_loss_fn = SoftTargetLoss(temperature=T_best, alpha=alpha_best)

            output_dir = MODELS_DIR / run_id
            training_args = get_training_args(output_dir, task, run_id)
            # As called here, get_training_args() sets seed and data_seed to
            # config.get_seeds()[0]. Hugging Face's Trainer seeds from
            # args.seed, so without this override every "extra" seed would
            # repeat the first seed's data order and training randomness.
            # NOTE(review): assumes LogitKDTrainer builds on transformers.Trainer.
            training_args.seed = seed
            training_args.data_seed = seed

            trainer = LogitKDTrainer(
                model=model,
                args=training_args,
                train_dataset=train_data,
                eval_dataset=eval_data,
                tokenizer=tokenizer,
                compute_metrics=metric_fn,
                kd_loss_fn=kd_loss_fn,
                teacher_logits=logits_cache
            )

            train_result = trainer.train()
            eval_result = trainer.evaluate()

            # Run metadata and hyperparameters first, then every remaining
            # evaluation metric under its original key.
            result = {
                "run_id": run_id,
                "method": "KD1",
                "task": task,
                "student": "S1",
                "seed": seed,
                "temperature": T_best,
                "alpha": alpha_best,
                "train_loss": train_result.training_loss,
                "eval_loss": eval_result["eval_loss"],
                **{k: v for k, v in eval_result.items() if k != "eval_loss"}
            }

            trainer.save_model(str(output_dir / "final"))
            registry.register_run(run_id, result)
            kd1_results.append(result)

            # Cleanup: collect garbage first so the dropped model's tensors
            # are actually freed, then return the cached device memory.
            del model, trainer, kd_loss_fn
            gc.collect()
            if DEVICE.type == "mps":
                torch.mps.empty_cache()
            elif DEVICE.type == "cuda":
                torch.cuda.empty_cache()

    print(f"\n✓ All seeds trained for best KD1 config")

In [None]:
# Save all results
# Baseline runs first, then KD1 runs, as one flat list of result dicts.
all_results = [*baseline_results, *kd1_results]

# One row per run, written without the DataFrame index.
df_all = pd.DataFrame(all_results)
results_csv = RUNS_DIR / "nb05_results.csv"
df_all.to_csv(results_csv, index=False)

print(f"Saved {len(all_results)} results to {results_csv}")

In [None]:
# Summary
# Banner: rule, title, rule - one per line.
print("=" * 60, "BASELINE AND KD1 TRAINING COMPLETE", "=" * 60, sep="\n")

# Run counts and output locations for this notebook.
print(f"""
Mode: {'FAST' if config.fast_mode else 'FULL'}
Student: {student_name}

Runs Completed:
  B0 Baseline: {len(baseline_results)} runs
  KD1 Logit-based: {len(kd1_results)} runs

Results saved to: {RUNS_DIR / 'nb05_results.csv'}
Models saved to: {MODELS_DIR}

Next Steps:
  1. Run 06_train_kd2_and_kd3.ipynb for sequence and feature KD
  2. Run 07_benchmark_and_plots.ipynb for final evaluation
""")

# Show summary table (run id and evaluation loss only); skipped when no runs exist.
if len(all_results) > 0:
    df_summary = pd.DataFrame(all_results)
    summary_columns = ["run_id", "eval_loss"]
    print("\nResults Summary:")
    print(df_summary[summary_columns].to_string())