# Phase 4.5: Ablation Studies

**Prerequisites:** Run `train_lora_adapters.ipynb` first to create real LoRA adapters and deltas.

This notebook compares different training configurations:
1. **Multi-task** (λ_w=1.0, λ_d=0.1) - Both weight and delta supervision
2. **Delta-only** (λ_w=0.0, λ_d=1.0) - Behavioral supervision only
3. **Weight-only** (λ_w=1.0, λ_d=0.0) - Traditional DnD baseline

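Each configuration runs 3 trials with different seeds so results can be reported as mean ± std; three seeds give a rough estimate of run-to-run variance rather than a formal significance test.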

In [None]:
import sys
import os
import shutil

# Colab detection: under Colab, mount Google Drive, install runtime deps, and put
# the Drive root on sys.path so packages stored there can be imported.
IN_COLAB = 'google.colab' in sys.modules
# Drive mirror directory for outputs; stays None when not running in Colab.
DRIVE_OUTPUT_DIR = None

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    # IPython shell escape: only valid inside a notebook kernel, not as plain Python.
    !pip install -q safetensors accelerate transformers peft
    # NOTE(review): presumably makes an `llgbm` package kept on Drive importable — confirm.
    sys.path.insert(0, '/content/drive/MyDrive')

import json
import gc
from pathlib import Path
from dataclasses import asdict
from typing import List, Dict, Any
import time

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from transformers import AutoModelForCausalLM, AutoTokenizer

# Report the PyTorch version and whether a CUDA device is visible.
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

In [None]:
# Project-local helpers from the llgbm package (probes, functional LoRA, config,
# losses, and the train/evaluate loops used by the cells below).
from llgbm import (
    create_generic_probes,
    DeltaCache,
    FunctionalLoRA,
    TrainingConfig,
    MultiTaskLoss,
    DeltaOnlyLoss,
    train,
    evaluate,
)

print("[OK] llgbm imports")

## Experiment Configurations

In [None]:
# Ablation configurations: loss weights for weight (λ_w) vs. delta (λ_d) supervision.
CONFIGS = {
    "multitask": {"lambda_weight": 1.0, "lambda_delta": 0.1},
    "delta_only": {"lambda_weight": 0.0, "lambda_delta": 1.0},
    "weight_only": {"lambda_weight": 1.0, "lambda_delta": 0.0},
}

NUM_TRIALS = 3
SEEDS = [42, 123, 456]
MAX_STEPS = 100  # Short runs for ablation

# Trials use SEEDS[:NUM_TRIALS]; with too few seeds that slice would silently run
# fewer trials than requested, so fail loudly instead.
if len(SEEDS) < NUM_TRIALS:
    raise ValueError(f"NUM_TRIALS={NUM_TRIALS} but only {len(SEEDS)} seeds are defined in SEEDS")

OUTPUT_DIR = Path("outputs/phase4_5_ablations")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Configurations: {list(CONFIGS.keys())}")
print(f"Trials per config: {NUM_TRIALS}")
print(f"Total runs: {len(CONFIGS) * NUM_TRIALS}")

## Setup (shared across runs)

In [None]:
# Shared base configuration; each trial copies its batch settings from here.
base_config = TrainingConfig(
    use_small_model=True,
    batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=MAX_STEPS,
    warmup_steps=10,
)

# Translate the config's dtype name into a torch dtype (unknown names raise KeyError).
TORCH_DTYPE = dict(
    float16=torch.float16,
    bfloat16=torch.bfloat16,
    float32=torch.float32,
)[base_config.dtype]

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Model: {base_config.base_model}")
print(f"Device: {device}")

In [None]:
# Frozen base model + tokenizer, loaded once and shared by every trial.
tokenizer = AutoTokenizer.from_pretrained(base_config.base_model, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Some causal-LM tokenizers ship without a pad token; reuse EOS.
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    base_config.base_model,
    torch_dtype=TORCH_DTYPE,
    device_map=device,
    trust_remote_code=True,
)
base_model.config.output_hidden_states = True
# Only the generator is trained; the base model's parameters stay frozen.
base_model.requires_grad_(False)

print("[OK] Base model loaded")

In [None]:
# Tokenize the generic probe prompts once (one un-padded sequence per probe).
probes = create_generic_probes()[:base_config.num_probes]
encoded_probes = [
    tokenizer(text, return_tensors="pt", truncation=True, max_length=base_config.max_probe_length)
    for text in probes
]
probe_tokens = [encoded["input_ids"].to(device) for encoded in encoded_probes]
probe_masks = [encoded["attention_mask"].to(device) for encoded in encoded_probes]

# Base activation: final-layer hidden state at each probe's last attended token
# (index = attended-token count - 1), averaged over probes.
with torch.no_grad():
    last_token_states = []
    for ids, mask in zip(probe_tokens, probe_masks):
        hidden = base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True).hidden_states[-1]
        last_pos = int(mask.sum()) - 1
        last_token_states.append(hidden[:, last_pos, :].squeeze(0))
    base_activation = torch.stack(last_token_states).mean(dim=0)

functional_lora = FunctionalLoRA(base_model, base_config.lora_rank, base_config.lora_alpha)
print("[OK] Probes & FunctionalLoRA ready")

In [None]:
# Dataset - Load real adapters and deltas from checkpoints/
from safetensors.torch import load_file

checkpoint_dir = Path(base_config.checkpoint_dir)
deltas_dir = checkpoint_dir / "deltas"

# Both manifests come from train_lora_adapters.ipynb; without them the dataset is empty.
manifest_path = checkpoint_dir / "manifest.json"
delta_manifest_path = deltas_dir / "delta_manifest.json"

if manifest_path.exists() and delta_manifest_path.exists():
    print("[OK] Found real adapter data")
    with open(manifest_path, encoding="utf-8") as f:
        adapter_manifest = json.load(f)
    with open(delta_manifest_path, encoding="utf-8") as f:
        delta_manifest = json.load(f)

    # Base activation stored next to the deltas.
    # NOTE(review): not used in this section, which recomputes `base_activation`
    # from the probes instead; confirm the two agree so delta targets are consistent.
    base_act_file = deltas_dir / delta_manifest["base_activation_file"]
    cached_base_activation = np.load(base_act_file)
    print(f"  Adapters: {len(adapter_manifest['adapters'])}")
    print(f"  Deltas: {len(delta_manifest['adapters'])}")
else:
    # Nothing is synthesized here: the dataset built below will simply have no samples.
    print("[WARNING] No real data found. Run train_lora_adapters.ipynb first!")
    print("  Continuing with an empty dataset (no fake data is generated).")
    adapter_manifest = {"adapters": []}
    delta_manifest = {"adapters": {}}
    cached_base_activation = None


class RealAdapterDataset(Dataset):
    """Dataset of real LoRA adapters paired with their precomputed deltas.

    Each item pairs a tokenized text condition with the delta array saved for that
    adapter and, when ``adapter_model.safetensors`` exists, the adapter's LoRA
    weights (for weight supervision). The condition is the first prompt in the
    adapter's ``prompts.json``; if that file, its ``"prompts"`` key, or the list
    itself is missing/empty, the adapter name is used instead.

    Only adapters listed in BOTH ``checkpoint_dir/manifest.json`` and
    ``deltas_dir/delta_manifest.json`` are included; if either manifest is missing
    the dataset is empty.

    Args:
        checkpoint_dir: Directory holding ``manifest.json``.
        deltas_dir: Directory holding ``delta_manifest.json`` and the delta files.
        tokenizer: Callable tokenizer used to encode the condition text.
        config: Training config; stored for callers but not read by this class.
    """

    # Every condition is padded/truncated to exactly this many tokens.
    CONDITION_MAX_LENGTH = 256

    def __init__(self, checkpoint_dir, deltas_dir, tokenizer, config):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.deltas_dir = Path(deltas_dir)
        self.tokenizer = tokenizer
        self.config = config

        manifest_path = self.checkpoint_dir / "manifest.json"
        delta_manifest_path = self.deltas_dir / "delta_manifest.json"

        self.samples = []
        if not (manifest_path.exists() and delta_manifest_path.exists()):
            self.adapter_manifest = {"adapters": []}
            self.delta_manifest = {"adapters": {}}
            return

        self.adapter_manifest = self._read_json(manifest_path)
        self.delta_manifest = self._read_json(delta_manifest_path)

        # Keep only adapters that have a computed delta.
        deltas_by_name = self.delta_manifest["adapters"]
        for adapter in self.adapter_manifest["adapters"]:
            name = adapter["name"]
            if name in deltas_by_name:
                self.samples.append({
                    "name": name,
                    "path": adapter["path"],
                    "task": adapter["task"],
                    "delta_file": deltas_by_name[name]["delta_file"],
                })

    @staticmethod
    def _read_json(path):
        """Load a UTF-8 JSON file."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _condition_text(self, sample):
        """Return the adapter's first prompt, falling back to its name."""
        prompts_file = Path(sample["path"]) / "prompts.json"
        if prompts_file.exists():
            # .get: a prompts.json without a "prompts" key falls back instead of raising KeyError.
            prompts = self._read_json(prompts_file).get("prompts") or []
            if prompts:
                return prompts[0]
        return sample["name"]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        enc = self.tokenizer(
            self._condition_text(sample),
            max_length=self.CONDITION_MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        delta = np.load(self.deltas_dir / sample["delta_file"])

        # NOTE(review): keys are whatever the safetensors file contains (PEFT checkpoints
        # commonly prefix them with "base_model.model."); confirm train() matches them to
        # the generator's "model.layers.*" keys.
        adapter_weights_file = Path(sample["path"]) / "adapter_model.safetensors"
        lora_weights = load_file(adapter_weights_file) if adapter_weights_file.exists() else {}

        return {
            "condition_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "delta_teacher": torch.from_numpy(delta).float(),
            "adapter_name": sample["name"],
            "lora_weights": lora_weights,
        }

    @staticmethod
    def collate_fn(batch):
        """Stack tensor fields; keep adapter names and LoRA weight dicts as lists."""
        return {
            "condition_ids": torch.stack([b["condition_ids"] for b in batch]),
            "attention_mask": torch.stack([b["attention_mask"] for b in batch]),
            "delta_teacher": torch.stack([b["delta_teacher"] for b in batch]),
            "adapter_names": [b["adapter_name"] for b in batch],
            "lora_weights": [b["lora_weights"] for b in batch],
        }


# Condition tokenizer for the generator's text encoder (distinct from the base-model tokenizer).
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = RealAdapterDataset(checkpoint_dir, deltas_dir, text_tokenizer, base_config)
print(f"[OK] Dataset: {len(dataset)} samples")

if not len(dataset):
    print("\n⚠️  No samples found! Please run train_lora_adapters.ipynb first to create adapters.")

## Generator Factory

In [None]:
class LoRAGenerator(nn.Module):
    """
    Simplified LoRA generator for ablation studies.

    A small transformer encoder pools a tokenized text condition into one vector.
    That vector is projected to one embedding per (layer, projection). Each embedding
    is decoded into a fixed-width chunk. The chunk is tiled (periodically extended)
    to the projection's real in/out dimension and scaled by a learned factor.

    Args:
        cfg: Config read for ``num_layers``, ``hidden_size``, ``num_heads``,
            ``num_kv_heads``, ``intermediate_size`` and ``lora_rank``.

    ``forward`` returns a list with one dict per batch element, mapping
    ``model.layers.{i}.{self_attn|mlp}.{proj}.lora_A.weight`` to a ``(rank, in_dim)``
    tensor and ``...lora_B.weight`` to an ``(out_dim, rank)`` tensor.
    """

    # Width of each decoded chunk before it is tiled to the full dimension.
    CHUNK_SIZE = 64
    # Projections that receive LoRA weights in every layer, in generation order.
    PROJECTIONS = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
    ATTENTION_PROJECTIONS = ("q_proj", "k_proj", "v_proj", "o_proj")

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Text encoder; condition token ids must be < 50000 (embedding table size).
        # Module creation order is kept fixed so a given seed yields the same init.
        self.embed = nn.Embedding(50000, 256)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True, dropout=0.1),
            num_layers=3,
        )

        self.num_projections = cfg.num_layers * len(self.PROJECTIONS)
        self.proj_embed_dim = 512
        # Pooled condition -> one embedding per projection.
        self.proj_embeddings = nn.Linear(256, self.num_projections * self.proj_embed_dim)

        self.lora_rank = cfg.lora_rank
        chunk_numel = cfg.lora_rank * self.CHUNK_SIZE
        # Shared decoders: embedding -> (rank x CHUNK_SIZE) for A, (CHUNK_SIZE x rank) for B.
        self.A_decoder = nn.Sequential(
            nn.Linear(self.proj_embed_dim, 256),
            nn.GELU(),
            nn.Linear(256, chunk_numel),
        )
        self.B_decoder = nn.Sequential(
            nn.Linear(self.proj_embed_dim, 256),
            nn.GELU(),
            nn.Linear(256, chunk_numel),
        )

        # Learnable (scale_A, scale_B) per projection, started small.
        self.scales = nn.Parameter(torch.ones(self.num_projections, 2) * 0.01)

        self._build_dim_info()

    def _build_dim_info(self):
        """Record in/out dims and state-dict keys for every (layer, projection)."""
        cfg = self.cfg
        head_dim = cfg.hidden_size // cfg.num_heads
        self.dim_info = []
        for layer in range(cfg.num_layers):
            for proj in self.PROJECTIONS:
                if proj in ("k_proj", "v_proj"):
                    out_d = cfg.num_kv_heads * head_dim  # grouped-query K/V width
                elif proj in ("gate_proj", "up_proj"):
                    out_d = cfg.intermediate_size
                else:  # q_proj, o_proj, down_proj
                    out_d = cfg.hidden_size
                in_d = cfg.intermediate_size if proj == "down_proj" else cfg.hidden_size

                mod = "self_attn" if proj in self.ATTENTION_PROJECTIONS else "mlp"
                prefix = f"model.layers.{layer}.{mod}.{proj}"
                self.dim_info.append({
                    "layer": layer,
                    "proj": proj,
                    "in_dim": in_d,
                    "out_dim": out_d,
                    "A_key": f"{prefix}.lora_A.weight",
                    "B_key": f"{prefix}.lora_B.weight",
                })

    @staticmethod
    def _tile_to(chunk, size, dim):
        """Repeat ``chunk`` along ``dim`` until it covers ``size``, then truncate to exactly ``size``."""
        reps = -(-size // chunk.shape[dim])  # ceil division: enough copies for any size
        repeat_spec = [1] * chunk.dim()
        repeat_spec[dim] = reps
        return chunk.repeat(*repeat_spec).narrow(dim, 0, size)

    def forward(self, condition_ids, attention_mask=None):
        batch_size = condition_ids.shape[0]

        # Encode the condition; padded positions are masked out of attention.
        x = self.embed(condition_ids)
        key_padding_mask = None if attention_mask is None else ~attention_mask.bool()
        x = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # Pool to a single vector (masked mean when a mask is given).
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            x = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        else:
            x = x.mean(1)

        # (B, num_proj, proj_embed_dim)
        proj_embeds = self.proj_embeddings(x).view(batch_size, self.num_projections, self.proj_embed_dim)

        # Decode every (sample, projection) pair in one call per decoder.
        a_chunks = self.A_decoder(proj_embeds).view(
            batch_size, self.num_projections, self.lora_rank, self.CHUNK_SIZE
        )
        b_chunks = self.B_decoder(proj_embeds).view(
            batch_size, self.num_projections, self.CHUNK_SIZE, self.lora_rank
        )

        batch_weights = []
        for b in range(batch_size):
            weights = {}
            for idx, info in enumerate(self.dim_info):
                scale_a, scale_b = self.scales[idx]
                # Tile the whole chunk so any in/out dimension is fully covered.
                lora_a = self._tile_to(a_chunks[b, idx], info["in_dim"], dim=1) * scale_a
                lora_b = self._tile_to(b_chunks[b, idx], info["out_dim"], dim=0) * scale_b
                weights[info["A_key"]] = lora_a
                weights[info["B_key"]] = lora_b
            batch_weights.append(weights)

        return batch_weights


def create_generator(cfg, seed):
    """Seed torch, build a ``LoRAGenerator`` for ``cfg`` on the global ``device``, and log its size.

    Seeding immediately before construction ties the initial weights to ``seed``.
    """
    torch.manual_seed(seed)
    model = LoRAGenerator(cfg).to(device)
    trainable_count = sum(t.numel() for t in model.parameters() if t.requires_grad)
    print(f"  Generator params: {trainable_count:,}")
    return model

## Run Ablations

In [None]:
def run_trial(config_name: str, lambda_weight: float, lambda_delta: float, seed: int, trial_idx: int) -> Dict[str, Any]:
    """Train and evaluate a fresh generator for one (configuration, seed) pair.

    Args:
        config_name: Key from ``CONFIGS``; used for logging and the output directory name.
        lambda_weight: Weight-supervision loss weight (λ_w); ``0`` selects ``DeltaOnlyLoss``.
        lambda_delta: Delta-supervision loss weight (λ_d), passed to ``MultiTaskLoss``.
        seed: Seed forwarded to ``create_generator``.
        trial_idx: Zero-based trial index; used for logging and the output directory name.

    Returns:
        A flat dict containing:
          - the trial's hyper-parameters;
          - the final and best training loss (``final_loss`` is ``None`` when no loss was recorded);
          - wall-clock training time;
          - every key returned by ``evaluate()``.
    """
    print(f"\n{'='*60}")
    print(f"Config: {config_name} | Trial {trial_idx+1}/{NUM_TRIALS} | Seed: {seed}")
    print(f"λ_w={lambda_weight}, λ_d={lambda_delta}")
    print(f"{'='*60}")

    # Per-trial config: batch/warmup settings come from base_config, loss weights from the ablation.
    config = TrainingConfig(
        use_small_model=True,
        batch_size=base_config.batch_size,
        gradient_accumulation_steps=base_config.gradient_accumulation_steps,
        max_steps=MAX_STEPS,
        warmup_steps=base_config.warmup_steps,
        lambda_weight=lambda_weight,
        lambda_delta=lambda_delta,
        output_dir=str(OUTPUT_DIR / f"{config_name}_trial{trial_idx}"),
    )
    Path(config.output_dir).mkdir(parents=True, exist_ok=True)

    # Fresh generator; create_generator seeds torch right before building it.
    generator = create_generator(config, seed)

    # With λ_w == 0 there is no weight supervision, so use the purely behavioural loss.
    if lambda_weight == 0:
        criterion = DeltaOnlyLoss()
    else:
        criterion = MultiTaskLoss(lambda_weight=lambda_weight, lambda_delta=lambda_delta)

    # Linear warmup from 10% to 100% of the LR, then cosine decay down to 1% of the LR.
    optimizer = AdamW(generator.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=config.warmup_steps)
    cosine = CosineAnnealingLR(
        optimizer,
        T_max=max(1, config.max_steps - config.warmup_steps),
        eta_min=config.learning_rate * 0.01,
    )
    scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[config.warmup_steps])

    # Dataloader (fresh for each trial)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=dataset.collate_fn)

    start_time = time.time()
    state = train(
        generator=generator,
        dataloader=dataloader,
        functional_lora=functional_lora,
        base_activation=base_activation,
        probe_tokens=probe_tokens,
        probe_masks=probe_masks,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        compute_dtype=TORCH_DTYPE,
    )
    train_time = time.time() - start_time

    # NOTE(review): evaluation reuses the training dataloader, so these metrics are
    # in-sample; consider a held-out split for generalisation claims.
    eval_results = evaluate(
        generator=generator,
        dataloader=dataloader,
        functional_lora=functional_lora,
        base_activation=base_activation,
        probe_tokens=probe_tokens,
        probe_masks=probe_masks,
        criterion=criterion,
    )

    # Free GPU memory before the next trial.
    del generator, optimizer, scheduler
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    final_loss = state.loss_history[-1] if state.loss_history else None
    result = {
        "config_name": config_name,
        "trial": trial_idx,
        "seed": seed,
        "lambda_weight": lambda_weight,
        "lambda_delta": lambda_delta,
        "final_loss": final_loss,
        "best_loss": state.best_loss,
        "train_time": train_time,
        **eval_results,
    }

    # Either metric may be missing (empty loss history / no 'mean_cosine' from evaluate()).
    loss_str = "N/A" if final_loss is None else f"{final_loss:.4f}"
    cosine = result.get("mean_cosine")
    cosine_str = "N/A" if cosine is None else f"{cosine:.4f}"
    print(f"Result: loss={loss_str}, mean_cosine={cosine_str}")

    return result

In [None]:
# Every (configuration, seed) combination, in config-major order.
trial_plan = [
    (config_name, params, trial_idx, seed)
    for config_name, params in CONFIGS.items()
    for trial_idx, seed in enumerate(SEEDS[:NUM_TRIALS])
]

all_results = []
for config_name, params, trial_idx, seed in trial_plan:
    all_results.append(
        run_trial(
            config_name=config_name,
            lambda_weight=params["lambda_weight"],
            lambda_delta=params["lambda_delta"],
            seed=seed,
            trial_idx=trial_idx,
        )
    )

print(f"\n\nCompleted {len(all_results)} trials!")

## Aggregate Results

In [None]:
import pandas as pd

df = pd.DataFrame(all_results)

# Per-config aggregation across seeds: mean/std of losses, mean wall-clock time.
agg_dict = {
    "final_loss": ["mean", "std"],
    "best_loss": ["mean", "std"],
    "train_time": ["mean"],
}

# evaluate() results are merged into each row, so cosine exists only if it returned 'mean_cosine'.
if "mean_cosine" in df.columns:
    agg_dict["mean_cosine"] = ["mean", "std"]

summary = df.groupby("config_name").agg(agg_dict).round(4)

print("\n" + "=" * 70)
# Use the configured trial count instead of a hard-coded "3".
print(f"ABLATION SUMMARY (mean ± std over {NUM_TRIALS} trials)")
print("=" * 70)
print(summary.to_string())

In [None]:
# Side-by-side bar charts: final loss and delta cosine per configuration (error bars = std).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
loss_ax, cosine_ax = axes

configs = list(CONFIGS.keys())
colors = ['#2ecc71', '#3498db', '#e74c3c']


def _per_config_mean_std(column):
    """Per-config mean and sample std of ``column``, ordered like ``configs``."""
    groups = [df[df['config_name'] == name][column] for name in configs]
    return [g.mean() for g in groups], [g.std() for g in groups]


loss_means, loss_stds = _per_config_mean_std('final_loss')
loss_ax.bar(configs, loss_means, yerr=loss_stds, color=colors, capsize=5, alpha=0.8)
loss_ax.set_ylabel('Final Loss')
loss_ax.set_title('Final Loss by Configuration')
loss_ax.grid(axis='y', alpha=0.3)

has_cosine = 'mean_cosine' in df.columns and df['mean_cosine'].notna().any()
if not has_cosine:
    cosine_ax.text(0.5, 0.5, 'Cosine similarity not available', ha='center', va='center', transform=cosine_ax.transAxes)
    cosine_ax.set_title('Delta Cosine Similarity')
else:
    cos_means, cos_stds = _per_config_mean_std('mean_cosine')
    cosine_ax.bar(configs, cos_means, yerr=cos_stds, color=colors, capsize=5, alpha=0.8)
    cosine_ax.set_ylabel('Mean Cosine Similarity')
    cosine_ax.set_title('Delta Cosine Similarity by Configuration')
    cosine_ax.grid(axis='y', alpha=0.3)
    cosine_ax.set_ylim(-1, 1)  # full cosine range

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "ablation_comparison.png", dpi=150)
plt.show()

In [None]:
import math

# Persist per-trial rows and the aggregated table.
df.to_csv(OUTPUT_DIR / "all_trials.csv", index=False)
summary.to_csv(OUTPUT_DIR / "summary.csv")


def _finite_or_none(value):
    """Return ``value`` as a float, or ``None`` if it is NaN/inf (e.g. the std of a single trial).

    Plain json.dump would otherwise write the non-standard tokens NaN/Infinity.
    """
    value = float(value)
    return value if math.isfinite(value) else None


def _json_default(obj):
    """Serialize numpy/torch values (possibly returned by evaluate()) that json cannot handle natively."""
    if isinstance(obj, torch.Tensor):
        obj = obj.detach().cpu()
        return obj.item() if obj.numel() == 1 else obj.tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _column_stats(rows, column):
    """Return (mean, std) of ``column`` in ``rows``, or (None, None) if the column is absent."""
    if column not in rows.columns:
        return None, None
    return _finite_or_none(rows[column].mean()), _finite_or_none(rows[column].std())


# Iterate CONFIGS directly so this cell does not depend on the plotting cell's `configs`.
summary_by_config = {}
for config_name in CONFIGS:
    rows = df[df["config_name"] == config_name]
    loss_mean, loss_std = _column_stats(rows, "final_loss")
    cos_mean, cos_std = _column_stats(rows, "mean_cosine")
    summary_by_config[config_name] = {
        "final_loss_mean": loss_mean,
        "final_loss_std": loss_std,
        "mean_cosine_mean": cos_mean,
        "mean_cosine_std": cos_std,
    }

final_results = {
    "configs": CONFIGS,
    "num_trials": NUM_TRIALS,
    "seeds": SEEDS[:NUM_TRIALS],
    "summary": summary_by_config,
    "all_trials": all_results,
}

with open(OUTPUT_DIR / "ablation_results.json", "w", encoding="utf-8") as f:
    json.dump(final_results, f, indent=2, default=_json_default)

print(f"Saved to {OUTPUT_DIR}/")

In [None]:
# Mirror local outputs to Google Drive when running in Colab, replacing any previous copy.
if not IN_COLAB or not DRIVE_OUTPUT_DIR:
    print("[Local] Outputs saved to", OUTPUT_DIR)
else:
    drive_dir = f"{DRIVE_OUTPUT_DIR}/phase4_5_ablations"
    if os.path.exists(drive_dir):
        shutil.rmtree(drive_dir)  # copytree needs a destination that does not exist yet
    shutil.copytree(str(OUTPUT_DIR), drive_dir)
    print(f"[Drive] Synced to {drive_dir}")

In [None]:
# Final human-readable report of the ablation run.
print("\n" + "=" * 70)
print("Phase 4.5 Ablations Complete!")
print("=" * 70)

print(f"\nDataset: {len(dataset)} samples")
print(f"Trials per config: {NUM_TRIALS}")
print(f"Steps per trial: {MAX_STEPS}")

# Only advertise a cosine column when evaluate() actually produced one.
has_cosine = "mean_cosine" in df.columns
print("\nKey findings (loss | cosine):" if has_cosine else "\nKey findings (loss):")

# Iterate CONFIGS directly so this cell does not depend on the plotting cell's `configs`.
for config_name in CONFIGS:
    rows = df[df["config_name"] == config_name]
    line = f"  {config_name:12s}: {rows['final_loss'].mean():.4f} ± {rows['final_loss'].std():.4f}"
    if has_cosine:
        line += f" | {rows['mean_cosine'].mean():.4f} ± {rows['mean_cosine'].std():.4f}"
    print(line)

print(f"\nOutputs saved to: {OUTPUT_DIR}")