# Phase 4.5: Ablation Studies

Run multiple trials across different configurations to compare:
1. **Multi-task** (λ_w=1.0, λ_d=0.1)
2. **Delta-only** (λ_w=0.0, λ_d=1.0)
3. **Weight-only** (λ_w=1.0, λ_d=0.0) - baseline

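Each configuration runs 3 trials with different seeds, and results are reported as mean ± std across trials. Three seeds give a rough estimate of run-to-run variance; they are not a formal significance test.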

In [None]:
# Environment setup: detect Colab, mount Google Drive for persistent outputs,
# and import the libraries used throughout this notebook.
# NOTE: the `!pip` line below is IPython shell magic, so this cell only runs
# inside Jupyter/Colab, not as a plain Python script.
import sys
import os
import shutil

# Colab is detected by whether the google.colab module is already imported.
IN_COLAB = 'google.colab' in sys.modules
# Drive directory that outputs are mirrored to at the end (stays None outside Colab).
DRIVE_OUTPUT_DIR = None

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    !pip install -q safetensors accelerate transformers peft
    # Make packages stored at the Drive root importable (presumably where the
    # `llgbm` package lives — confirm).
    sys.path.insert(0, '/content/drive/MyDrive')

import json
import gc
from pathlib import Path
from dataclasses import asdict
from typing import List, Dict, Any
import time

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sanity line: library version and whether a GPU is visible.
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

In [None]:
# Project components used below: probe prompts, the teacher-delta cache,
# functional LoRA application, the training config, the two loss variants,
# and the train/evaluate loops.
from llgbm import DeltaCache, FunctionalLoRA, TrainingConfig
from llgbm import DeltaOnlyLoss, MultiTaskLoss
from llgbm import create_generic_probes, evaluate, train

print("[OK] llgbm imports")

## Experiment Configurations

In [None]:
# Ablation configurations: loss weights λ_w (lambda_weight) and λ_d
# (lambda_delta) handed to each trial.
CONFIGS = {
    "multitask": {"lambda_weight": 1.0, "lambda_delta": 0.1},
    "delta_only": {"lambda_weight": 0.0, "lambda_delta": 1.0},
    "weight_only": {"lambda_weight": 1.0, "lambda_delta": 0.0},
}

NUM_TRIALS = 3
SEEDS = [42, 123, 456]
# Trials use SEEDS[:NUM_TRIALS]. Fail fast instead of silently running fewer
# trials than requested when too few seeds are listed.
if len(SEEDS) < NUM_TRIALS:
    raise ValueError(f"Need at least NUM_TRIALS={NUM_TRIALS} seeds, got {len(SEEDS)}")
MAX_STEPS = 100  # Short runs for ablation

OUTPUT_DIR = Path("outputs/phase4_5_ablations")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Configurations: {list(CONFIGS.keys())}")
print(f"Trials per config: {NUM_TRIALS}")
print(f"Total runs: {len(CONFIGS) * NUM_TRIALS}")

## Setup (shared across runs)

In [None]:
# Shared baseline training config; run_trial copies its batch settings into
# each per-trial config.
base_config = TrainingConfig(
    use_small_model=True,
    batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=MAX_STEPS,
    warmup_steps=10,
)

# Translate the config's dtype name into a torch dtype (KeyError for any other name).
_dtype_by_name = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
TORCH_DTYPE = _dtype_by_name[base_config.dtype]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Model: {base_config.base_model}")
print(f"Device: {device}")

In [None]:
# Load the base causal LM and its tokenizer once; every trial reuses them.
tokenizer = AutoTokenizer.from_pretrained(base_config.base_model, trust_remote_code=True)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    base_config.base_model,
    torch_dtype=TORCH_DTYPE,
    device_map=device,
    trust_remote_code=True,
)
base_model.config.output_hidden_states = True
# Freeze every base parameter: only the generator is handed to the optimizer.
base_model.requires_grad_(False)

print("[OK] Base model loaded")

In [None]:
# Tokenize the probe prompts once, then average the base model's final-layer
# hidden state taken at each probe's last token.
probes = create_generic_probes()[:base_config.num_probes]
probe_tokens = []
probe_masks = []
for probe_text in probes:
    encoded = tokenizer(probe_text, return_tensors="pt", truncation=True, max_length=base_config.max_probe_length)
    probe_tokens.append(encoded["input_ids"].to(device))
    probe_masks.append(encoded["attention_mask"].to(device))


def _last_token_hidden(ids, mask):
    """Final-layer hidden state at index mask.sum() - 1 of a single (batch-1) sequence."""
    outputs = base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True)
    last_pos = int(mask.sum()) - 1
    return outputs.hidden_states[-1][:, last_pos, :].squeeze(0)


with torch.no_grad():
    base_activation = torch.stack(
        [_last_token_hidden(ids, mask) for ids, mask in zip(probe_tokens, probe_masks)]
    ).mean(dim=0)

functional_lora = FunctionalLoRA(base_model, base_config.lora_rank, base_config.lora_alpha)
print("[OK] Probes & FunctionalLoRA ready")

In [None]:
# Dataset (shared)
# If there are no adapter checkpoints or no cached deltas, write a tiny
# synthetic checkpoint set plus deltas so the pipeline can run end-to-end.
import zlib

from safetensors.torch import save_file

checkpoint_dir = Path(base_config.checkpoint_dir)
delta_cache = DeltaCache(base_config.delta_cache_dir)

has_adapters = any(checkpoint_dir.rglob("adapter_config.json"))
if not has_adapters or delta_cache.summary().get('count', 0) == 0:
    print("Creating sample data...")
    for name, domain in [("math_001", "math"), ("code_001", "code"), ("general_001", "general")]:
        path = checkpoint_dir / domain / name
        path.mkdir(parents=True, exist_ok=True)
        # Built-in hash() of a str is salted per process (PYTHONHASHSEED), which
        # made this "seeded" sample data differ between runs; CRC32 is stable.
        torch.manual_seed(zlib.crc32(name.encode("utf-8")) % 10000)
        weights = {"layer.0.lora_A": torch.randn(8, 256) * 0.01}
        save_file(weights, path / "adapter_model.safetensors")
        # Use `with` so the JSON files are flushed and closed deterministically.
        with open(path / "adapter_config.json", "w") as f:
            json.dump({"r": 8, "peft_type": "LORA"}, f)
        with open(path / "prompts.json", "w") as f:
            json.dump({"prompts": [f"Solve {domain} problem"]}, f)
        delta_cache.save_delta(str(path), torch.randn(base_config.hidden_size).numpy() * 0.1)
    # NOTE(review): this caches an all-zeros base activation, not the
    # `base_activation` computed from the probes — confirm that is intended.
    delta_cache.save_base_activation(np.zeros(base_config.hidden_size), {})

class SimpleDataset(Dataset):
    """Pairs each adapter checkpoint that has a cached teacher delta with a tokenized prompt.

    Args:
        checkpoint_dir: Root searched recursively for ``adapter_config.json``;
            each match's parent directory is one sample.
        delta_cache: Object whose ``get_all_deltas()`` returns a mapping from
            checkpoint-path string to a numpy array (the teacher delta).
        tokenizer: HF-style tokenizer used to encode the conditioning prompt.
        hidden_size: Stored on the instance; not otherwise used here.
        max_length: Prompts are padded/truncated to exactly this many tokens.
    """

    def __init__(self, checkpoint_dir, delta_cache, tokenizer, hidden_size, max_length=256):
        all_deltas = delta_cache.get_all_deltas()
        # Sorted so index -> sample is deterministic (rglob order depends on the filesystem).
        self.samples = sorted(
            str(cfg.parent)
            for cfg in Path(checkpoint_dir).rglob("adapter_config.json")
            if str(cfg.parent) in all_deltas
        )
        self.deltas = all_deltas
        self.tokenizer = tokenizer
        self.hidden_size = hidden_size
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _prompt_for(path):
        """Return the first prompt in ``<path>/prompts.json``, else the directory name."""
        fallback = Path(path).name
        prompts_file = Path(path) / "prompts.json"
        if not prompts_file.exists():
            return fallback
        with open(prompts_file, encoding="utf-8") as f:
            # A missing, null, or empty "prompts" list falls back instead of raising.
            prompts = json.load(f).get("prompts") or [fallback]
        return prompts[0]

    def __getitem__(self, idx):
        path = self.samples[idx]
        text = self._prompt_for(path)
        enc = self.tokenizer(
            text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return {
            "condition_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "delta_teacher": torch.from_numpy(self.deltas[path]).float(),
        }

    @staticmethod
    def collate_fn(batch):
        """Stack each field of a list of samples into one batched tensor per key."""
        return {k: torch.stack([b[k] for b in batch]) for k in batch[0]}

# Conditioning prompts for the generator are encoded with a separate BERT
# tokenizer, not the base LM's tokenizer.
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = SimpleDataset(base_config.checkpoint_dir, delta_cache, text_tokenizer, base_config.hidden_size)
# An empty dataset would otherwise surface much later as a confusing failure
# inside the trial loop (no losses recorded), so report it here.
if len(dataset) == 0:
    raise RuntimeError(
        f"No training samples: no adapter_config.json under {base_config.checkpoint_dir!r} "
        "has a cached delta"
    )
print(f"[OK] Dataset: {len(dataset)} samples")

## Generator Factory

In [None]:
class PlaceholderGenerator(nn.Module):
    """Stand-in LoRA generator: text encoder -> two learnable scales per LoRA module.

    The prompt tokens are embedded, encoded by a small Transformer, and
    mean-pooled over attended positions. The pooled vector is projected to one
    ``(scale_A, scale_B)`` pair for each projection in every layer. Each
    LoRA matrix is freshly drawn Gaussian noise multiplied by its scale, so
    gradients reach the generator only through those scales.

    Args:
        cfg: Must provide ``num_layers``, ``num_heads``, ``num_kv_heads``,
            ``hidden_size``, ``intermediate_size`` and ``lora_rank``.
        vocab_size: Embedding table size; must exceed every token id fed in
            (i.e. the conditioning tokenizer's vocabulary).
    """

    # Projection names, in the order their scale pairs appear in ``proj``'s output.
    PROJ_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
    ATTN_PROJS = frozenset(("q_proj", "k_proj", "v_proj", "o_proj"))

    def __init__(self, cfg, vocab_size=50000):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(vocab_size, 256)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True), num_layers=2
        )
        self.proj = nn.Linear(256, cfg.num_layers * len(self.PROJ_NAMES) * 2)

    def _lora_shapes(self):
        """Return ``[(module_prefix, in_dim, out_dim), ...]`` in scale order."""
        cfg = self.cfg
        kv_dim = cfg.num_kv_heads * (cfg.hidden_size // cfg.num_heads)
        shapes = []
        for layer in range(cfg.num_layers):
            for proj in self.PROJ_NAMES:
                mod = "self_attn" if proj in self.ATTN_PROJS else "mlp"
                if proj in ("k_proj", "v_proj"):
                    out_d = kv_dim
                elif proj in ("gate_proj", "up_proj"):
                    out_d = cfg.intermediate_size
                else:  # q_proj, o_proj and down_proj all output hidden_size
                    out_d = cfg.hidden_size
                in_d = cfg.intermediate_size if proj == "down_proj" else cfg.hidden_size
                shapes.append((f"model.layers.{layer}.{mod}.{proj}", in_d, out_d))
        return shapes

    def forward(self, condition_ids, attention_mask=None):
        """Return one dict per batch element mapping LoRA weight names to tensors."""
        batch = condition_ids.shape[0]
        x = self.embed(condition_ids)
        # Transformer padding mask is True where a position must be ignored.
        pad_mask = ~attention_mask.bool() if attention_mask is not None else None
        x = self.encoder(x, src_key_padding_mask=pad_mask)
        if attention_mask is not None:
            # Masked mean; clamp avoids 0/0 for an all-padding row.
            x = (x * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(1, keepdim=True).clamp(min=1)
        else:
            x = x.mean(1)

        # The shape table does not depend on the sample, so build it once per call.
        shapes = self._lora_shapes()
        scales = self.proj(x).view(batch, len(shapes), 2)
        device = condition_ids.device
        rank = self.cfg.lora_rank

        batch_weights = []
        for b in range(batch):
            weights = {}
            for idx, (prefix, in_d, out_d) in enumerate(shapes):
                sa, sb = scales[b, idx, 0], scales[b, idx, 1]
                weights[f"{prefix}.lora_A.weight"] = torch.randn(rank, in_d, device=device) * 0.01 * sa
                weights[f"{prefix}.lora_B.weight"] = torch.randn(out_d, rank, device=device) * 0.001 * sb
            batch_weights.append(weights)
        return batch_weights

def create_generator(cfg, seed):
    """Seed torch's global RNG, then build a fresh PlaceholderGenerator on ``device``.

    Seeding right before construction makes the initial weights, and the RNG
    stream consumed afterwards, reproducible for a given seed.
    """
    torch.manual_seed(seed)
    generator = PlaceholderGenerator(cfg)
    return generator.to(device)

## Run Ablations

In [None]:
def run_trial(config_name: str, lambda_weight: float, lambda_delta: float, seed: int, trial_idx: int) -> Dict[str, Any]:
    """Train and evaluate one freshly initialised generator for a single ablation trial.

    Args:
        config_name: Ablation label (key of ``CONFIGS``); also names the output subdirectory.
        lambda_weight: Weight passed to ``MultiTaskLoss``; 0 selects ``DeltaOnlyLoss`` instead.
        lambda_delta: Delta-term weight passed to ``MultiTaskLoss``.
        seed: Passed to ``create_generator``, which seeds torch before building the generator.
        trial_idx: Zero-based trial index, used for logging and the output subdirectory.

    Returns:
        Flat dict with the trial's hyper-parameters, ``final_loss`` (None when
        no loss was recorded), ``best_loss``, ``train_time`` in seconds, and
        every key returned by ``evaluate``.
    """
    print(f"\n{'='*60}")
    print(f"Config: {config_name} | Trial {trial_idx+1}/{NUM_TRIALS} | Seed: {seed}")
    print(f"λ_w={lambda_weight}, λ_d={lambda_delta}")
    print(f"{'='*60}")

    # Per-trial config. Batch and warmup settings come from base_config so that
    # editing base_config changes every trial consistently.
    config = TrainingConfig(
        use_small_model=True,
        batch_size=base_config.batch_size,
        gradient_accumulation_steps=base_config.gradient_accumulation_steps,
        max_steps=MAX_STEPS,
        warmup_steps=base_config.warmup_steps,
        lambda_weight=lambda_weight,
        lambda_delta=lambda_delta,
        output_dir=str(OUTPUT_DIR / f"{config_name}_trial{trial_idx}"),
    )
    Path(config.output_dir).mkdir(parents=True, exist_ok=True)

    # Fresh, seeded generator for every trial.
    generator = create_generator(config, seed)

    # Without a weight term, use the delta-only objective; otherwise the weighted multi-task loss.
    if lambda_weight == 0:
        criterion = DeltaOnlyLoss()
    else:
        criterion = MultiTaskLoss(lambda_weight=lambda_weight, lambda_delta=lambda_delta)

    # Linear warmup from 0.1x to 1x the LR, then cosine decay down to 1% of the LR.
    optimizer = AdamW(generator.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    cosine_steps = max(1, config.max_steps - config.warmup_steps)
    scheduler = SequentialLR(
        optimizer,
        [LinearLR(optimizer, 0.1, 1.0, config.warmup_steps),
         CosineAnnealingLR(optimizer, cosine_steps, config.learning_rate * 0.01)],
        [config.warmup_steps],
    )

    # Fresh dataloader per trial (shuffle order follows the RNG seeded above).
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=dataset.collate_fn)

    start_time = time.time()
    state = train(
        generator=generator,
        dataloader=dataloader,
        functional_lora=functional_lora,
        base_activation=base_activation,
        probe_tokens=probe_tokens,
        probe_masks=probe_masks,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        compute_dtype=TORCH_DTYPE,
    )
    train_time = time.time() - start_time

    # NOTE: evaluation reuses the training dataloader; there is no held-out split.
    eval_results = evaluate(
        generator=generator,
        dataloader=dataloader,
        functional_lora=functional_lora,
        base_activation=base_activation,
        probe_tokens=probe_tokens,
        probe_masks=probe_masks,
        criterion=criterion,
    )

    # Release trial-local objects before the next trial allocates its own.
    del generator, optimizer, scheduler, criterion, dataloader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    result = {
        "config_name": config_name,
        "trial": trial_idx,
        "seed": seed,
        "lambda_weight": lambda_weight,
        "lambda_delta": lambda_delta,
        "final_loss": state.loss_history[-1] if state.loss_history else None,
        "best_loss": state.best_loss,
        "train_time": train_time,
        **eval_results,
    }

    # final_loss may be None (empty loss history); formatting None with :.4f raises TypeError.
    final_loss = result["final_loss"]
    loss_str = f"{final_loss:.4f}" if final_loss is not None else "N/A"
    print(f"Result: loss={loss_str}, cosine_sim={result.get('cosine_sim', 'N/A')}")
    return result

In [None]:
# Run every (configuration, seed) pair; each trial trains a fresh generator.
trial_plan = [
    (name, params, trial_idx, seed)
    for name, params in CONFIGS.items()
    for trial_idx, seed in enumerate(SEEDS[:NUM_TRIALS])
]

all_results = []
for name, params, trial_idx, seed in trial_plan:
    all_results.append(
        run_trial(
            config_name=name,
            lambda_weight=params["lambda_weight"],
            lambda_delta=params["lambda_delta"],
            seed=seed,
            trial_idx=trial_idx,
        )
    )

print(f"\n\nCompleted {len(all_results)} trials!")

## Aggregate Results

In [None]:
import pandas as pd

df = pd.DataFrame(all_results)

# Mean/std per configuration. cosine_sim is requested only when evaluate()
# produced it: pandas raises KeyError for an aggregation spec that names a
# missing column, even when its function list is empty.
agg_spec = {
    "final_loss": ["mean", "std"],
    "best_loss": ["mean", "std"],
}
if "cosine_sim" in df.columns:
    agg_spec["cosine_sim"] = ["mean", "std"]
agg_spec["train_time"] = ["mean"]
summary = df.groupby("config_name").agg(agg_spec).round(4)

print("\n" + "="*70)
print(f"ABLATION SUMMARY (mean ± std over {NUM_TRIALS} trials)")
print("="*70)
print(summary.to_string())

In [None]:
# Visualization: one bar per configuration showing the mean across trials,
# with the std across trials as error bars.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

configs = list(CONFIGS.keys())
colors = ['#2ecc71', '#3498db', '#e74c3c']


def _plot_metric(ax, column, ylabel, title):
    """Draw per-config mean of ``column`` with std error bars on ``ax``."""
    per_config = [df.loc[df['config_name'] == c, column] for c in configs]
    ax.bar(
        configs,
        [vals.mean() for vals in per_config],
        yerr=[vals.std() for vals in per_config],
        color=colors,
        capsize=5,
        alpha=0.8,
    )
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis='y', alpha=0.3)


_plot_metric(axes[0], 'final_loss', 'Final Loss', 'Final Loss by Configuration')

# The cosine panel is drawn only when evaluate() reported cosine_sim.
if 'cosine_sim' not in df.columns:
    axes[1].text(0.5, 0.5, 'Cosine sim not computed', ha='center', va='center')
else:
    _plot_metric(axes[1], 'cosine_sim', 'Cosine Similarity', 'Delta Cosine Similarity by Configuration')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "ablation_comparison.png", dpi=150)
plt.show()

In [None]:
# Persist per-trial rows, the aggregated table, and a JSON summary.
df.to_csv(OUTPUT_DIR / "all_trials.csv", index=False)
summary.to_csv(OUTPUT_DIR / "summary.csv")

has_cosine = 'cosine_sim' in df.columns


def _config_stats(name):
    """Mean/std of final_loss (and cosine_sim when present) across one config's trials."""
    rows = df[df['config_name'] == name]
    return {
        "final_loss_mean": float(rows['final_loss'].mean()),
        "final_loss_std": float(rows['final_loss'].std()),
        "cosine_sim_mean": float(rows['cosine_sim'].mean()) if has_cosine else None,
        "cosine_sim_std": float(rows['cosine_sim'].std()) if has_cosine else None,
    }


final_results = {
    "configs": CONFIGS,
    "num_trials": NUM_TRIALS,
    "seeds": SEEDS[:NUM_TRIALS],
    # Iterate CONFIGS directly so this cell does not depend on `configs` from the plotting cell.
    "summary": {c: _config_stats(c) for c in CONFIGS},
    # NOTE(review): values from evaluate() must be JSON-serializable (e.g. not
    # torch tensors); a single-trial std is NaN and is written as bare NaN.
    "all_trials": all_results,
}
# `with` guarantees the file is flushed and closed, even if serialization fails.
with open(OUTPUT_DIR / "ablation_results.json", "w") as f:
    json.dump(final_results, f, indent=2)
print(f"Saved to {OUTPUT_DIR}/")

In [None]:
# Mirror the local output directory to Drive when running in Colab, replacing
# any previous copy there. NOTE: the old copy is deleted before copying, so a
# failed copy leaves no Drive copy at all.
if not (IN_COLAB and DRIVE_OUTPUT_DIR):
    print("[Local] Outputs saved to", OUTPUT_DIR)
else:
    drive_dir = f"{DRIVE_OUTPUT_DIR}/phase4_5_ablations"
    if os.path.exists(drive_dir):
        shutil.rmtree(drive_dir)
    shutil.copytree(str(OUTPUT_DIR), drive_dir)
    print(f"[Drive] Synced to {drive_dir}")

In [None]:
# Closing banner: per-configuration final loss as mean ± std across trials.
print("\n" + "="*70)
print("Phase 4.5 Ablations Complete!")
print("="*70)
print("\nKey findings:")
# Iterate CONFIGS directly (same order as `configs`) so this cell still works
# if the plotting cell, which defines `configs`, was skipped.
for config_name in CONFIGS:
    losses = df.loc[df['config_name'] == config_name, 'final_loss']
    print(f"  {config_name:12s}: {losses.mean():.4f} ± {losses.std():.4f}")