# 04 - Teacher Output Caching

**Thesis Section Reference:** Chapter 3.8 - Knowledge Distillation Methods

This notebook caches teacher model outputs for knowledge distillation:
1. **KD1 (Logit-based):** Cache soft logits for both SST-2 and SQuAD
2. **KD2 (Sequence-level):** Generate and cache teacher answers for SQuAD
3. **KD3 (Feature-based):** Cache hidden states (memory-safe chunked saving)

## Memory Management Notes
- Uses fp32 for MPS stability
- Gradient checkpointing enabled
- Periodic cache clearing
- Automatic fallback to a smaller teacher if the primary model fails to load (e.g., OOM)

In [None]:
# Standard setup
import os
import sys
import json
import gc
from pathlib import Path

import torch

# Resolve the project root: when launched from notebooks/, it is one level up
ROOT_DIR = Path.cwd()
if ROOT_DIR.name == "notebooks":
    ROOT_DIR = ROOT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "src"))

from dotenv import load_dotenv
load_dotenv(ROOT_DIR / ".env")

# Project-local modules; importable only after src/ was put on sys.path above
from config import load_config
from utils_seed import set_seed

# Load the experiment configuration
config = load_config(str(ROOT_DIR / "configs" / "experiment.yaml"))
config.ensure_dirs()

# Seed with the first configured seed
SEED = config.get_seeds()[0]
set_seed(SEED)

# Device setup: MPS is checked first, then CUDA, with CPU as the last resort
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Mode: {'FAST' if config.fast_mode else 'FULL'}")
print(f"Device: {DEVICE}")
print(f"Seed: {SEED}")

In [None]:
# Input (processed data) and output (teacher cache) locations under results/
CACHE_DIR = ROOT_DIR / "results" / "teacher_cache"
DATA_DIR = ROOT_DIR / "results" / "processed_data"

# Only the cache directory is created here; the data directory is just read
CACHE_DIR.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {DATA_DIR}")
print(f"Cache directory: {CACHE_DIR}")

In [None]:
# Check what already exists
# Hidden states are written as several chunk files followed by a metadata JSON.
# The metadata file is written last, so a chunk file alone may be the leftover
# of an interrupted run; require both before treating the cache as complete.
cache_status = {
    "sst2_logits": (CACHE_DIR / "sst2_logits.pt").exists(),
    "squad_logits": (CACHE_DIR / "squad_logits.pt").exists(),
    "squad_answers": (CACHE_DIR / "squad_teacher_answers.json").exists(),
    "hidden_states": (
        any(CACHE_DIR.glob("hidden_states_*.pt"))
        and any(CACHE_DIR.glob("hidden_states_*_meta.json"))
    ),
}

print("Cache status:")
for name, exists in cache_status.items():
    status = "✓ exists" if exists else "✗ needs generation"
    print(f"  {name}: {status}")

In [None]:
# Load teacher model with fallback
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_teacher_model(config, device):
    """Load the teacher model and tokenizer, falling back to a smaller model.

    Candidate names are tried in order: the primary teacher, then the fallback.
    Each name is read from the ``TEACHER_PRIMARY`` / ``TEACHER_FALLBACK``
    environment variable, defaulting to ``config.teacher.primary`` /
    ``config.teacher.fallback``. Empty names and duplicates are skipped.

    Args:
        config: Object exposing ``teacher.primary`` and ``teacher.fallback``.
        device: ``torch.device`` the model is moved to.

    Returns:
        Tuple ``(model, tokenizer, model_name)`` for the first candidate that
        loads successfully. The model is in eval mode.

    Raises:
        RuntimeError: If no candidate could be loaded. The message lists the
            error reported for each attempted candidate.
    """
    primary = os.getenv("TEACHER_PRIMARY", config.teacher.primary)
    fallback = os.getenv("TEACHER_FALLBACK", config.teacher.fallback)

    # Keep order, but never retry the same name and never try an empty one
    candidates = []
    for name in (primary, fallback):
        if name and name not in candidates:
            candidates.append(name)

    failures = []
    for model_name in candidates:
        model = None
        tokenizer = None
        try:
            print(f"Attempting to load: {model_name}")

            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                cache_dir=str(ROOT_DIR / "hf_cache")
            )

            # Reuse EOS as the padding token when the tokenizer defines none
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.float32,  # fp32 for MPS stability
                cache_dir=str(ROOT_DIR / "hf_cache"),
                low_cpu_mem_usage=True
            )

            model = model.to(device)
            model.eval()

            # Enable gradient checkpointing for memory
            # NOTE(review): the model is only used for inference here; this is
            # presumably a no-op outside training mode - confirm for this teacher.
            if hasattr(model, 'gradient_checkpointing_enable'):
                model.gradient_checkpointing_enable()

            print(f"✓ Successfully loaded: {model_name}")
            return model, tokenizer, model_name

        except Exception as e:
            print(f"Failed to load {model_name}: {e}")
            failures.append(f"{model_name}: {e}")

        # Only reached after a failure. Cleanup runs outside the except block so
        # the exception's traceback no longer keeps partially loaded weights alive.
        model = None
        tokenizer = None
        gc.collect()
        if device.type == "mps":
            torch.mps.empty_cache()
        elif device.type == "cuda":
            torch.cuda.empty_cache()

    message = "Failed to load any teacher model"
    if failures:
        message += ": " + "; ".join(failures)
    raise RuntimeError(message)

# The teacher is required as soon as at least one cache entry is missing
needs_teacher = any(not cached for cached in cache_status.values())

if not needs_teacher:
    print("All caches exist, skipping teacher loading.")
else:
    teacher_model, teacher_tokenizer, teacher_name = load_teacher_model(config, DEVICE)
    print(f"\nTeacher model parameters: {sum(p.numel() for p in teacher_model.parameters()) / 1e9:.2f}B")

In [None]:
# Load processed datasets
from datasets import load_from_disk

if needs_teacher:
    print("Loading processed datasets...")

    # Both training splits were saved to disk under DATA_DIR
    sst2_train, squad_train = (
        load_from_disk(str(DATA_DIR / split_name))
        for split_name in ("sst2_train", "squad_train")
    )

    print(f"  SST-2 train: {len(sst2_train)} examples")
    print(f"  SQuAD train: {len(squad_train)} examples")

    # Prompts consumed by the KD2 generation cell
    squad_prompts = json.loads((DATA_DIR / "squad_train_prompts.json").read_text())

## KD1: Cache Teacher Logits

For logit-based KD, we need the teacher's output logits for each training example.
We'll store the top-k logits to save memory.

In [None]:
# Cache SST-2 teacher logits
from teacher_cache import TeacherCache

if cache_status["sst2_logits"]:
    print("✓ SST-2 logits already cached")
else:
    print("Caching SST-2 teacher logits...")

    cache = TeacherCache(
        model=teacher_model,
        tokenizer=teacher_tokenizer,
        device=DEVICE,
        cache_dir=CACHE_DIR
    )

    # MPS gets the smaller batch size
    if DEVICE.type == "mps":
        batch_size = 2
    else:
        batch_size = 4

    # top_k=100: keep only the top-100 logits per position
    sst2_logits = cache.cache_logits(
        dataset=sst2_train,
        batch_size=batch_size,
        top_k=100,
        task_name="sst2"
    )

    # Persist to the cache directory
    torch.save(sst2_logits, CACHE_DIR / "sst2_logits.pt")
    print(f"✓ Saved SST-2 logits: {CACHE_DIR / 'sst2_logits.pt'}")

    # Drop the in-memory copy and release cached memory
    del sst2_logits
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    gc.collect()

In [None]:
# Cache SQuAD teacher logits
if cache_status["squad_logits"]:
    print("✓ SQuAD logits already cached")
else:
    print("Caching SQuAD teacher logits...")

    # Reuse the TeacherCache from the SST-2 cell when that cell created one
    if "cache" not in globals():
        cache = TeacherCache(
            model=teacher_model,
            tokenizer=teacher_tokenizer,
            device=DEVICE,
            cache_dir=CACHE_DIR
        )

    # SQuAD has longer sequences, hence the smaller batches
    if DEVICE.type == "mps":
        batch_size = 1
    else:
        batch_size = 2

    squad_logits = cache.cache_logits(
        dataset=squad_train,
        batch_size=batch_size,
        top_k=100,
        task_name="squad"
    )

    # Persist to the cache directory
    torch.save(squad_logits, CACHE_DIR / "squad_logits.pt")
    print(f"✓ Saved SQuAD logits: {CACHE_DIR / 'squad_logits.pt'}")

    # Drop the in-memory copy and release cached memory
    del squad_logits
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    gc.collect()

## KD2: Generate Teacher Answers (Sequence-level KD)

For sequence-level KD on SQuAD, we generate the teacher's predicted answers.
The student learns to mimic the teacher's generated sequences.

In [None]:
# Generate teacher answers for SQuAD
from tqdm.auto import tqdm

if not cache_status["squad_answers"]:
    print("Generating teacher answers for SQuAD...")

    # Token budget for the generated answer. The same value is subtracted from
    # the maximum sequence length when truncating the prompt, so the two must
    # stay in sync.
    MAX_NEW_TOKENS = 64

    # Generation config
    gen_config = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": False,  # Greedy for reproducibility
        "pad_token_id": teacher_tokenizer.pad_token_id,
        "eos_token_id": teacher_tokenizer.eos_token_id,
    }
    max_prompt_tokens = config.get_max_length("squad") - MAX_NEW_TOKENS

    teacher_answers = []

    # One prompt at a time for MPS stability
    with torch.no_grad():
        for i, item in enumerate(tqdm(squad_prompts, desc="Generating")):
            inputs = teacher_tokenizer(
                item["prompt"],
                return_tensors="pt",
                max_length=max_prompt_tokens,
                truncation=True
            ).to(DEVICE)

            outputs = teacher_model.generate(
                **inputs,
                **gen_config
            )

            # Decode only new tokens: everything after the prompt length
            prompt_len = inputs["input_ids"].shape[1]
            generated = teacher_tokenizer.decode(
                outputs[0][prompt_len:],
                skip_special_tokens=True
            ).strip()

            teacher_answers.append({
                "id": item["id"],
                "prompt": item["prompt"],
                "teacher_answer": generated,
                "gold_answers": item["gold_answers"]
            })

            # Periodic cleanup of the accelerator's cached memory
            if i % 50 == 0:
                if DEVICE.type == "mps":
                    torch.mps.empty_cache()
                elif DEVICE.type == "cuda":
                    torch.cuda.empty_cache()

    # Save atomically: write to a temporary file and rename it into place, so an
    # interrupted write never leaves a truncated JSON that a later run would
    # mistake for a finished cache.
    answers_path = CACHE_DIR / "squad_teacher_answers.json"
    tmp_path = answers_path.with_name(answers_path.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(teacher_answers, f, indent=2)
    tmp_path.replace(answers_path)

    print(f"\n✓ Generated {len(teacher_answers)} teacher answers")

    # Show samples
    print("\nSample teacher answers:")
    for sample in teacher_answers[:3]:
        # Tolerate prompts without the expected "Question:" / "Context:" markers
        # and examples without gold answers instead of raising IndexError.
        _, marker, after_question = sample["prompt"].partition("Question:")
        question_text = after_question if marker else sample["prompt"]
        question_text = question_text.split("Context:")[0].strip()
        gold_answers = sample["gold_answers"]
        first_gold = gold_answers[0] if gold_answers else "<none>"

        print(f"  Q: {question_text[:50]}...")
        print(f"  Teacher: {sample['teacher_answer'][:50]}...")
        print(f"  Gold: {first_gold}")
        print()
else:
    print("✓ SQuAD teacher answers already cached")

## KD3: Cache Hidden States (Feature-based KD)

For feature-based KD, we cache the teacher's hidden states from selected layers.
This is memory-intensive, so we:
1. Store only selected layers (e.g., every 4th layer)
2. Mean-pool each layer's output over the sequence dimension
3. Use float16 for storage
4. Save in chunks

In [None]:
# Cache hidden states (memory-safe)
if not cache_status["hidden_states"]:
    print("Caching hidden states for KD3...")
    print("Note: This is memory-intensive. Using chunked saving.")

    # Configuration
    LAYER_STRIDE = 4  # Cache every 4th layer
    CHUNK_SIZE = 100  # Save every 100 examples
    # NOTE(review): this flag is never read below - SST-2 is the only dataset
    # cached in both FAST and FULL mode. Confirm whether FULL mode should also
    # cache SQuAD hidden states.
    USE_SST2_ONLY = config.fast_mode  # In fast mode, only cache SST-2

    dataset = sst2_train
    task_name = "sst2"

    num_layers = teacher_model.config.num_hidden_layers
    selected_layers = list(range(0, num_layers, LAYER_STRIDE))
    print(f"  Caching layers: {selected_layers} (of {num_layers} total)")

    all_hidden_states = []
    chunk_idx = 0  # Number of chunk files written so far (also the next file index)

    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Extracting hidden states"):
            example = dataset[i]

            # Prepare input as a batch of one
            input_ids = torch.tensor([example["input_ids"]], device=DEVICE)
            attention_mask = torch.tensor([example["attention_mask"]], device=DEVICE)

            # Get hidden states
            outputs = teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

            # Pooling weights: [1, seq_len, 1]. Padding positions (mask == 0)
            # must not contribute to the mean, otherwise padded examples get
            # their features diluted by padding-token activations.
            pool_mask = attention_mask.unsqueeze(-1).to(outputs.hidden_states[0].dtype)
            token_count = pool_mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero

            # Extract selected layers and convert to float16
            layer_states = []
            for layer_idx in selected_layers:
                layer_output = outputs.hidden_states[layer_idx + 1]  # +1 for embeddings
                # Masked mean over the sequence dimension -> [1, hidden_size]
                pooled = (layer_output * pool_mask).sum(dim=1) / token_count
                layer_states.append(pooled.cpu().half())

            # Stack layers: [num_selected_layers, hidden_size]
            stacked = torch.cat(layer_states, dim=0)
            all_hidden_states.append(stacked)

            # Save chunk
            if len(all_hidden_states) >= CHUNK_SIZE:
                chunk_tensor = torch.stack(all_hidden_states)
                torch.save(chunk_tensor, CACHE_DIR / f"hidden_states_{task_name}_{chunk_idx}.pt")
                print(f"  Saved chunk {chunk_idx} ({len(all_hidden_states)} examples)")
                all_hidden_states = []
                chunk_idx += 1

                if DEVICE.type == "mps":
                    torch.mps.empty_cache()
                elif DEVICE.type == "cuda":
                    torch.cuda.empty_cache()

    # Save final (partial) chunk, if any examples are left over
    if all_hidden_states:
        chunk_tensor = torch.stack(all_hidden_states)
        torch.save(chunk_tensor, CACHE_DIR / f"hidden_states_{task_name}_{chunk_idx}.pt")
        print(f"  Saved chunk {chunk_idx} ({len(all_hidden_states)} examples)")
        chunk_idx += 1

    # chunk_idx now equals the number of files actually written. Counting it
    # this way stays correct when the dataset size is an exact multiple of
    # CHUNK_SIZE (no trailing partial chunk) and when the dataset is empty.
    num_chunks = chunk_idx

    # Save metadata
    hidden_state_meta = {
        "task": task_name,
        "num_chunks": num_chunks,
        "num_examples": len(dataset),
        "selected_layers": selected_layers,
        "total_layers": num_layers,
        "hidden_size": teacher_model.config.hidden_size,
        "pooling": "mean",
        "padding_masked": True,
        "dtype": "float16"
    }

    with open(CACHE_DIR / f"hidden_states_{task_name}_meta.json", "w") as f:
        json.dump(hidden_state_meta, f, indent=2)

    print(f"\n✓ Saved hidden states in {num_chunks} chunks")
else:
    print("✓ Hidden states already cached")

In [None]:
# Cleanup
if needs_teacher:
    print("Cleaning up...")
    # pop() rather than `del`, so re-running this cell after the names are
    # already gone does not raise NameError.
    for _stale_name in ("teacher_model", "cache"):
        globals().pop(_stale_name, None)
    # Collect first so the freed tensors can actually be returned by the
    # allocator when its cache is emptied.
    gc.collect()
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    elif DEVICE.type == "cuda":
        torch.cuda.empty_cache()
    print("✓ Memory freed")

In [None]:
# Report every regular file in the cache directory together with its size
print("Verifying cached files...")
print()

cache_files = list(CACHE_DIR.iterdir())
# Sub-directories are skipped; only regular files are measured
file_sizes = {f: f.stat().st_size for f in cache_files if f.is_file()}
total_size = sum(file_sizes.values())

print(f"Cache directory: {CACHE_DIR}")
print(f"Total size: {total_size / 1024 / 1024:.2f} MB")
print()

for f in sorted(file_sizes):
    size_mb = file_sizes[f] / 1024 / 1024
    print(f"  {f.name}: {size_mb:.2f} MB")

In [None]:
# Final report: banner, then which teacher/mode was used and what was cached
print("=" * 60, "TEACHER OUTPUT CACHING COMPLETE", "=" * 60, sep="\n")
# teacher_name only exists when the teacher was actually loaded in this session
print(f"""
Teacher Model: {teacher_name if 'teacher_name' in dir() else 'N/A (used cache)'}
Mode: {'FAST' if config.fast_mode else 'FULL'}

Cached Outputs:
  KD1 (Logit-based):
    - sst2_logits.pt
    - squad_logits.pt
  
  KD2 (Sequence-level):
    - squad_teacher_answers.json
  
  KD3 (Feature-based):
    - hidden_states_*.pt (chunked)

Next Steps:
  1. Run 05_train_baseline_and_kd1.ipynb for baseline and logit KD
  2. Run 06_train_kd2_and_kd3.ipynb for sequence and feature KD
""")