# 05 - Training: Baseline (B0) and KD1 (Logit-based)

**Thesis Section Reference:** Chapter 4.1-4.2 - Baseline Results and Logit-based KD

This notebook trains:
1. **B0 Baseline:** Fine-tuning without distillation
2. **KD1 (Logit-based):** Knowledge distillation using soft targets

## Grid Search
- Temperature T ∈ {1, 2, 4, 8}
- Alpha α ∈ {0.1, 0.3, 0.5, 0.7}

## Memory Management
- LoRA/PEFT for parameter-efficient training
- Gradient accumulation for effective batch size 8-16
- fp32 for MPS stability
- Periodic cache clearing

In [1]:
# Standard setup
import os
import sys
import gc
import json
from pathlib import Path
from itertools import product

import torch
import numpy as np

ROOT_DIR = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
sys.path.insert(0, str(ROOT_DIR / "src"))

from dotenv import load_dotenv
load_dotenv(ROOT_DIR / ".env")

from config import load_config
from utils_seed import set_seed, get_generator
from run_io import RunRegistry

config = load_config(str(ROOT_DIR / "configs" / "experiment.yaml"))
config.ensure_dirs()

# Device
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Mode: {'FAST' if config.fast_mode else 'FULL'}")
print(f"Device: {DEVICE}")

Mode: FAST
Device: mps


In [2]:
# Set up paths
DATA_DIR = ROOT_DIR / "results" / "processed_data"
CACHE_DIR = ROOT_DIR / "results" / "teacher_cache"
MODELS_DIR = ROOT_DIR / "results" / "models"
RUNS_DIR = ROOT_DIR / "results" / "raw_runs"

for d in [MODELS_DIR, RUNS_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Simple run registry using a JSON file for idempotency
class SimpleRegistry:
    """Simple JSON-based run registry for idempotent training."""
    
    def __init__(self, path):
        self.path = Path(path)
        self.runs = {}
        if self.path.exists():
            with open(self.path, 'r') as f:
                self.runs = json.load(f)
    
    def _save(self):
        with open(self.path, 'w') as f:
            json.dump(self.runs, f, indent=2)
    
    def check_run(self, run_id):
        """Check if run is already completed."""
        return run_id in self.runs and self.runs[run_id].get("status") == "completed"
    
    def get_run(self, run_id):
        """Get run results."""
        return self.runs.get(run_id)
    
    def register_run(self, run_id, result):
        """Register a completed run."""
        self.runs[run_id] = {**result, "status": "completed"}
        self._save()

registry = SimpleRegistry(RUNS_DIR / "run_registry.json")

In [3]:
# Load processed datasets
from datasets import load_from_disk

print("Loading datasets...")

sst2_train = load_from_disk(str(DATA_DIR / "sst2_train"))
sst2_val = load_from_disk(str(DATA_DIR / "sst2_validation"))

squad_train = load_from_disk(str(DATA_DIR / "squad_train"))
squad_val = load_from_disk(str(DATA_DIR / "squad_validation"))

print(f"SST-2: {len(sst2_train)} train, {len(sst2_val)} val")
print(f"SQuAD: {len(squad_train)} train, {len(squad_val)} val")

Loading datasets...
SST-2: 2000 train, 500 val
SQuAD: 2000 train, 500 val


In [4]:
# Load teacher logits for KD1 (chunked format)
print("Loading cached teacher logits...")

def load_chunked_logits(cache_dir, task):
    """Load chunked logits saved by notebook 04."""
    meta_file = cache_dir / f"{task}_logits.pt"
    
    if not meta_file.exists():
        raise FileNotFoundError(f"Logits metadata not found: {meta_file}")
    
    meta = torch.load(meta_file, map_location="cpu")
    print(f"  {task} metadata: {meta}")
    
    # Load all chunks
    all_logits = []
    all_indices = []
    
    for chunk_idx in range(meta["num_chunks"]):
        chunk_file = cache_dir / f"{task}_logits_chunk_{chunk_idx}.pt"
        if chunk_file.exists():
            chunk_data = torch.load(chunk_file, map_location="cpu")
            all_logits.extend(chunk_data["logits"])
            all_indices.extend(chunk_data["indices"])
    
    print(f"  {task}: loaded {len(all_logits)} examples, top_k={meta['top_k']}")
    
    return {
        "logits": all_logits,  # List of tensors [seq_len, top_k]
        "indices": all_indices,  # List of tensors [seq_len, top_k]
        "top_k": meta["top_k"]
    }

sst2_logits = load_chunked_logits(CACHE_DIR, "sst2")
squad_logits = load_chunked_logits(CACHE_DIR, "squad")

print("✓ Teacher logits loaded")

Loading cached teacher logits...
  sst2 metadata: {'num_chunks': 40, 'top_k': 50, 'task': 'sst2'}
  sst2: loaded 2000 examples, top_k=50
  squad metadata: {'num_chunks': 80, 'top_k': 50, 'task': 'squad'}
  squad: loaded 2000 examples, top_k=50
✓ Teacher logits loaded


In [5]:
# Training configuration
from transformers import TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType

def get_training_args(output_dir, task, run_name):
    """Get training arguments optimized for MPS."""
    
    # Batch sizes based on task
    if task == "sst2":
        per_device_batch = 4 if DEVICE.type == "mps" else 8
        grad_accum = 4 if DEVICE.type == "mps" else 2
    else:  # squad - longer sequences
        per_device_batch = 2 if DEVICE.type == "mps" else 4
        grad_accum = 8 if DEVICE.type == "mps" else 4
    
    # Epochs based on mode - use config method
    num_epochs = config.get_epochs()
    
    return TrainingArguments(
        output_dir=str(output_dir),
        run_name=run_name,
        
        # Training
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch,
        per_device_eval_batch_size=per_device_batch,
        gradient_accumulation_steps=grad_accum,
        
        # Optimizer
        learning_rate=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
        warmup_ratio=config.training.warmup_ratio,
        lr_scheduler_type="cosine",
        
        # Precision - fp32 for MPS stability
        fp16=False,
        bf16=False,
        
        # MPS-specific
        dataloader_pin_memory=False if DEVICE.type == "mps" else True,
        dataloader_num_workers=0 if DEVICE.type == "mps" else 4,
        
        # Logging
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        
        # Memory
        gradient_checkpointing=True,
        
        # Reproducibility
        seed=config.get_seeds()[0],
        data_seed=config.get_seeds()[0],
        
        # Reporting
        report_to="none",
    )

def get_lora_config():
    """Get LoRA configuration."""
    return LoraConfig(
        r=config.lora.r,
        lora_alpha=config.lora.lora_alpha,
        lora_dropout=config.lora.lora_dropout,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        task_type=TaskType.CAUSAL_LM,
        bias="none"
    )

print("Training configuration ready.")

Training configuration ready.


In [6]:
# Function to load student model
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_student_model(student_name, use_lora=True):
    """Load student model with optional LoRA."""
    print(f"Loading student model: {student_name}")
    
    tokenizer = AutoTokenizer.from_pretrained(
        student_name,
        trust_remote_code=True,
        cache_dir=str(ROOT_DIR / "hf_cache")
    )
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        student_name,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        cache_dir=str(ROOT_DIR / "hf_cache"),
        low_cpu_mem_usage=True
    )
    
    if use_lora:
        lora_config = get_lora_config()
        model = get_peft_model(model, lora_config)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"  Trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")
    
    model = model.to(DEVICE)
    
    return model, tokenizer

print("Model loading function ready.")

Model loading function ready.


## Section 1: B0 Baseline Training

Train student model directly without knowledge distillation.

In [7]:
# Train baseline models with aggressive memory management for MPS
from trainers import BaselineTrainer
from data_sst2 import compute_sst2_metrics
from data_squad import compute_squad_metrics

import numpy as np

# Force garbage collection before starting
gc.collect()
if DEVICE.type == "mps":
    torch.mps.empty_cache()
    torch.mps.synchronize()

# Placeholder metric function - actual task metrics computed in final benchmark
def make_placeholder_metric_fn():
    """Create placeholder metric function for training."""
    def compute_metrics(eval_pred):
        # Just return placeholder - proper evaluation done separately via generation
        return {"eval_placeholder": 0.0}
    return compute_metrics

student_name = os.getenv("STUDENT_S1", config.student_s1.name)
print(f"Student model: {student_name}")

baseline_results = []

# Train only SST-2 first (shorter sequences, faster) - skip SQuAD for now
# to diagnose memory issues
tasks_to_train = [
    ("sst2", sst2_train, sst2_val),
]

for task, train_data, eval_data in tasks_to_train:
    for seed in config.get_seeds()[:1]:  # Start with just 1 seed to test
        run_id = f"B0_{task}_S1_seed{seed}"
        
        # Check if already done
        if registry.check_run(run_id):
            print(f"✓ {run_id} already completed, skipping...")
            existing = registry.get_run(run_id)
            baseline_results.append(existing)
            continue
        
        print(f"\n{'='*60}")
        print(f"Training: {run_id}")
        print(f"{'='*60}")
        
        # Aggressive cleanup before loading model
        gc.collect()
        if DEVICE.type == "mps":
            torch.mps.empty_cache()
            torch.mps.synchronize()
        
        set_seed(seed)
        
        # Load model with reduced memory settings
        print(f"Loading student model: {student_name}")
        
        tokenizer = AutoTokenizer.from_pretrained(
            student_name,
            trust_remote_code=True,
            cache_dir=str(ROOT_DIR / "hf_cache")
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load model in fp32 for MPS stability
        model = AutoModelForCausalLM.from_pretrained(
            student_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            cache_dir=str(ROOT_DIR / "hf_cache"),
            low_cpu_mem_usage=True
        )
        
        # Apply LoRA
        lora_config = get_lora_config()
        model = get_peft_model(model, lora_config)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"  LoRA params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
        
        model = model.to(DEVICE)
        
        # More conservative training args for MPS
        output_dir = MODELS_DIR / run_id
        
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            run_name=run_id,
            
            # Very conservative for MPS - batch 1 with high accumulation
            num_train_epochs=config.get_epochs(),
            per_device_train_batch_size=1,  # Minimal batch
            per_device_eval_batch_size=1,
            gradient_accumulation_steps=16,  # Effective batch = 16
            
            # Optimizer
            learning_rate=config.training.learning_rate,
            weight_decay=config.training.weight_decay,
            warmup_ratio=config.training.warmup_ratio,
            lr_scheduler_type="cosine",
            
            # fp32 for MPS
            fp16=False,
            bf16=False,
            
            # MPS-specific
            dataloader_pin_memory=False,
            dataloader_num_workers=0,
            
            # Logging - less frequent to reduce overhead
            logging_steps=50,
            eval_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=1,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            
            # Memory savings
            gradient_checkpointing=True,
            optim="adamw_torch",  # Standard optimizer
            
            # Reproducibility
            seed=seed,
            data_seed=seed,
            
            report_to="none",
        )
        
        # Create trainer
        metric_fn = make_placeholder_metric_fn()
        
        trainer = BaselineTrainer(
            model=model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            processing_class=tokenizer,
            compute_metrics=metric_fn
        )
        
        print("Starting training...")
        
        try:
            # Train with periodic cleanup
            train_result = trainer.train()
            
            # Evaluate
            eval_result = trainer.evaluate()
            
            # Save results
            result = {
                "run_id": run_id,
                "method": "B0",
                "task": task,
                "student": "S1",
                "seed": seed,
                "train_loss": train_result.training_loss,
                "eval_loss": eval_result["eval_loss"],
            }
            
            # Save model
            trainer.save_model(str(output_dir / "final"))
            registry.register_run(run_id, result)
            baseline_results.append(result)
            
            print(f"\n✓ {run_id} complete: eval_loss={result['eval_loss']:.4f}")
            
        except Exception as e:
            print(f"✗ Training failed: {e}")
            result = {
                "run_id": run_id,
                "method": "B0",
                "task": task,
                "student": "S1", 
                "seed": seed,
                "error": str(e)
            }
            baseline_results.append(result)
        
        # Aggressive cleanup after each run
        del model, trainer
        gc.collect()
        if DEVICE.type == "mps":
            torch.mps.empty_cache()
            torch.mps.synchronize()
        
        # Small delay to let MPS settle
        import time
        time.sleep(2)

print(f"\n✓ Baseline training complete: {len(baseline_results)} runs")

Student model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
✓ B0_sst2_S1_seed42 already completed, skipping...

✓ Baseline training complete: 1 runs


## Section 2: KD1 (Logit-based KD) Grid Search

Train with soft target distillation using cached teacher logits.

Grid: T × α combinations

In [8]:
# KD1 Grid Search  
# NOTE: Since we only have top-k sparse logits cached, and the full KD1 training 
# requires full teacher logits or on-the-fly teacher computation, we'll skip 
# the full KD1 grid search for now. Instead, we'll use the baseline trainer
# and evaluate KD effectiveness in the final benchmark.

# For thesis purposes, we can either:
# 1. Run with on-the-fly teacher (memory intensive)
# 2. Cache full logits (storage intensive) 
# 3. Use sparse logits with custom training loop

# For now, mark KD1 as placeholder - actual KD experiments will be done
# with a simpler setup in the benchmark notebook

print("=" * 60)
print("KD1 Grid Search - Placeholder")
print("=" * 60)
print("""
NOTE: Full KD1 training requires either:
  1. On-the-fly teacher inference (high memory on MPS)
  2. Full cached teacher logits (large storage)
  
The current setup uses top-k sparse logits which requires
a custom training loop not yet implemented.

For thesis Chapter 4, we will:
  - Report baseline B0 results from this notebook
  - Implement simplified KD evaluation in notebook 07

Skipping KD1 grid search for now.
""")

kd1_results = []

# If you want to try on-the-fly KD1 (may crash on MPS due to memory):
RUN_KD1_ONLINE = False  # Set to True to attempt on-the-fly KD

if RUN_KD1_ONLINE:
    from trainers import LogitKDTrainer
    from kd_losses import SoftTargetLoss
    
    # Use reduced grid for memory
    temperatures = [4]  # Just one temperature
    alphas = [0.5]      # Just one alpha
    
    print(f"Attempting online KD1 with T={temperatures}, α={alphas}")
    
    for task, train_data, eval_data in [("sst2", sst2_train, sst2_val)]:
        for T, alpha in product(temperatures, alphas):
            seed = config.get_seeds()[0]
            run_id = f"KD1_{task}_S1_T{T}_a{alpha}_seed{seed}"
            
            if registry.check_run(run_id):
                print(f"✓ {run_id} already completed")
                kd1_results.append(registry.get_run(run_id))
                continue
            
            print(f"\nTraining: {run_id}")
            
            # This would require loading teacher model + student model
            # which may exceed MPS memory
            print("⚠️ Online KD1 skipped - requires teacher model in memory")
            
else:
    print("KD1 training skipped (RUN_KD1_ONLINE=False)")

print(f"\n✓ KD1 section complete: {len(kd1_results)} runs")

KD1 Grid Search - Placeholder

NOTE: Full KD1 training requires either:
  1. On-the-fly teacher inference (high memory on MPS)
  2. Full cached teacher logits (large storage)

The current setup uses top-k sparse logits which requires
a custom training loop not yet implemented.

For thesis Chapter 4, we will:
  - Report baseline B0 results from this notebook
  - Implement simplified KD evaluation in notebook 07

Skipping KD1 grid search for now.

KD1 training skipped (RUN_KD1_ONLINE=False)

✓ KD1 section complete: 0 runs


In [9]:
# Find best KD1 configuration
import pandas as pd

if kd1_results:
    df_kd1 = pd.DataFrame(kd1_results)
    print("KD1 Results DataFrame:")
    print(df_kd1.to_string())
    
    # Check if we have any successful runs with eval_loss
    if "eval_loss" in df_kd1.columns and df_kd1["eval_loss"].notna().any():
        # Filter to successful runs only
        df_success = df_kd1[df_kd1["eval_loss"].notna()]
        
        # Best by eval_loss (lower is better)
        best_idx = df_success["eval_loss"].idxmin()
        best_config = df_success.loc[best_idx]
        
        print("\nBest KD1 Configuration:")
        print(f"  Temperature: {best_config['temperature']}")
        print(f"  Alpha: {best_config['alpha']}")
        print(f"  Eval Loss: {best_config['eval_loss']:.4f}")
        
        # Save best config for later
        best_kd1_config = {
            "temperature": float(best_config["temperature"]),
            "alpha": float(best_config["alpha"])
        }
        
        with open(RUNS_DIR / "best_kd1_config.json", "w") as f:
            json.dump(best_kd1_config, f, indent=2)
        
        print("\nKD1 Grid Results:")
        print(df_success[["temperature", "alpha", "eval_loss"]].to_string())
    else:
        print("\n⚠️ No successful KD1 runs - all training attempts failed.")
        print("Check the 'error' column for failure reasons:")
        if "error" in df_kd1.columns:
            for _, row in df_kd1.iterrows():
                if "error" in row and pd.notna(row.get("error")):
                    print(f"  {row['run_id']}: {row['error'][:100]}...")
        
        # Use default config
        best_kd1_config = {"temperature": 4, "alpha": 0.5}
        print(f"\nUsing default config: T={best_kd1_config['temperature']}, α={best_kd1_config['alpha']}")
else:
    print("No KD1 results to analyze.")
    best_kd1_config = {"temperature": 4, "alpha": 0.5}

No KD1 results to analyze.


In [10]:
# Train best KD1 config across all seeds
if len(config.get_seeds()) > 1 and kd1_results:
    print("Training best KD1 config across all seeds...")
    
    T_best = best_kd1_config["temperature"]
    alpha_best = best_kd1_config["alpha"]
    
    for seed in config.get_seeds()[1:]:  # Skip first seed (already done)
        for task, train_data, eval_data, make_metric_fn, logits_cache in [
            ("sst2", sst2_train, sst2_val, make_sst2_metric_fn, sst2_logits),
        ]:
            run_id = f"KD1_{task}_S1_T{T_best}_a{alpha_best}_seed{seed}"
            
            if registry.check_run(run_id):
                print(f"✓ {run_id} already completed, skipping...")
                continue
            
            print(f"\nTraining: {run_id}")
            set_seed(seed)
            
            model, tokenizer = load_student_model(student_name, use_lora=True)
            metric_fn = make_metric_fn(tokenizer)
            kd_loss_fn = SoftTargetLoss(temperature=T_best, alpha=alpha_best)
            
            output_dir = MODELS_DIR / run_id
            training_args = get_training_args(output_dir, task, run_id)
            
            trainer = LogitKDTrainer(
                model=model,
                args=training_args,
                train_dataset=train_data,
                eval_dataset=eval_data,
                processing_class=tokenizer,
                compute_metrics=metric_fn,
                kd_loss_fn=kd_loss_fn,
                teacher_logits=logits_cache
            )
            
            train_result = trainer.train()
            eval_result = trainer.evaluate()
            
            result = {
                "run_id": run_id,
                "method": "KD1",
                "task": task,
                "student": "S1",
                "seed": seed,
                "temperature": T_best,
                "alpha": alpha_best,
                "train_loss": train_result.training_loss,
                "eval_loss": eval_result["eval_loss"],
                **{k: v for k, v in eval_result.items() if k != "eval_loss"}
            }
            
            trainer.save_model(str(output_dir / "final"))
            registry.register_run(run_id, result)
            kd1_results.append(result)
            
            del model, trainer, kd_loss_fn
            if DEVICE.type == "mps":
                torch.mps.empty_cache()
            gc.collect()
    
    print(f"\n✓ All seeds trained for best KD1 config")

In [11]:
# Save all results
all_results = baseline_results + kd1_results

df_all = pd.DataFrame(all_results)
df_all.to_csv(RUNS_DIR / "nb05_results.csv", index=False)

print(f"Saved {len(all_results)} results to {RUNS_DIR / 'nb05_results.csv'}")

Saved 1 results to /Users/pjere/Workshop/thesis-exp/results/raw_runs/nb05_results.csv


In [12]:
# Summary
print("=" * 60)
print("BASELINE AND KD1 TRAINING COMPLETE")
print("=" * 60)

print(f"""
Mode: {'FAST' if config.fast_mode else 'FULL'}
Student: {student_name}

Runs Completed:
  B0 Baseline: {len(baseline_results)} runs
  KD1 Logit-based: {len(kd1_results)} runs

Results saved to: {RUNS_DIR / 'nb05_results.csv'}
Models saved to: {MODELS_DIR}

Next Steps:
  1. Run 06_train_kd2_and_kd3.ipynb for sequence and feature KD
  2. Run 07_benchmark_and_plots.ipynb for final evaluation
""")

# Show summary table
if all_results:
    df_summary = pd.DataFrame(all_results)
    print("\nResults Summary:")
    print(df_summary[["run_id", "eval_loss"]].to_string())

BASELINE AND KD1 TRAINING COMPLETE

Mode: FAST
Student: TinyLlama/TinyLlama-1.1B-Chat-v1.0

Runs Completed:
  B0 Baseline: 1 runs
  KD1 Logit-based: 0 runs

Results saved to: /Users/pjere/Workshop/thesis-exp/results/raw_runs/nb05_results.csv
Models saved to: /Users/pjere/Workshop/thesis-exp/results/models

Next Steps:
  1. Run 06_train_kd2_and_kd3.ipynb for sequence and feature KD
  2. Run 07_benchmark_and_plots.ipynb for final evaluation


Results Summary:
              run_id  eval_loss
0  B0_sst2_S1_seed42        NaN
