# Phase 4.5: Ablation Studies

**Prerequisites:** Run `train_lora_adapters.ipynb` first to create real LoRA adapters and deltas.

This notebook compares different training configurations:
1. **Multi-task** (λ_w=1.0, λ_d=0.1) - Both weight and delta supervision
2. **Multi-task balanced** (λ_w=0.5, λ_d=0.5) - Equal weight/delta supervision
3. **Delta-only** (λ_w=0.0, λ_d=1.0) - Behavioral supervision only
4. **Weight-only** (λ_w=1.0, λ_d=0.0) - Traditional DnD baseline

Each configuration runs 3 trials with different seeds for statistical significance.

## Configuration

In [None]:
import sys
import os

# Colab setup: detect the runtime and choose where checkpoints, deltas and
# synced outputs live.
IN_COLAB = 'google.colab' in sys.modules
DRIVE_OUTPUT_DIR = None  # stays None outside Colab

if not IN_COLAB:
    # Local run: paths are relative to the current working directory.
    CHECKPOINT_DIR, DELTAS_DIR = './checkpoints', './llgbm/deltas'
else:
    from google.colab import drive

    # Everything below lives under the Drive mount point.
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    # Presumably makes the llgbm package stored under MyDrive importable -- confirm.
    sys.path.insert(0, '/content/drive/MyDrive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/llgbm/checkpoints'
    DELTAS_DIR = f'{CHECKPOINT_DIR}/deltas'

In [None]:
# Project helpers used by the cells below: the experiment config class, the
# ablation runner, and the results plotter.
from llgbm import AblationConfig, run_ablations, plot_ablation_results

# Reached only if the import above succeeded.
print("[OK] llgbm imports")

In [None]:
# Experiment configuration - modify these parameters as needed

# Loss weights for each ablation arm; lambda_weight / lambda_delta are the
# λ_w / λ_d values listed in the notebook introduction.
ablation_arms = {
    "multitask": dict(lambda_weight=1.0, lambda_delta=0.1),
    "multitask2": dict(lambda_weight=0.5, lambda_delta=0.5),
    "delta_only": dict(lambda_weight=0.0, lambda_delta=1.0),
    "weight_only": dict(lambda_weight=1.0, lambda_delta=0.0),
}

# One seed per trial.
trial_seeds = [42, 123, 456]

config = AblationConfig(
    # Ablation configurations to compare
    configs=ablation_arms,

    # Trial settings
    num_trials=len(trial_seeds),
    seeds=trial_seeds,
    num_steps=100,

    # Paths
    checkpoint_dir=CHECKPOINT_DIR,
    deltas_dir=DELTAS_DIR,
    output_dir="outputs/phase4_5_ablations",

    # Model settings
    use_small_model=True,  # Qwen2.5-0.5B
    batch_size=8,
    gradient_accumulation_steps=1,
    warmup_steps=50,

    # Colab settings
    in_colab=IN_COLAB,
    drive_output_dir=DRIVE_OUTPUT_DIR,
)

# Echo the plan: which arms will run and how many training runs that implies.
arm_names = list(config.configs.keys())
total_runs = len(config.configs) * config.num_trials
print(f"Configurations: {arm_names}")
print(f"Total runs: {total_runs}")

## Run Ablations

In [None]:
# Run all ablation experiments
# `results` is later indexed with the "dataframe" key to obtain the per-run table.
# Presumably this trains len(config.configs) * config.num_trials runs and is the
# expensive step of the notebook -- confirm against llgbm.run_ablations.
results = run_ablations(config)

## Results Summary

In [None]:
import pandas as pd

# Per-run results table. The code below relies on a "config_name" column plus
# the metric columns named in agg_dict.
df = results["dataframe"]

# Aggregate by config
agg_dict = {
    "final_loss": ["mean", "std"],
    "best_loss": ["mean", "std"],
    "train_time": ["mean"],
}
# The cosine metric is optional: aggregate it only when the column is present.
if "mean_cosine" in df.columns:
    agg_dict["mean_cosine"] = ["mean", "std"]

summary = df.groupby("config_name").agg(agg_dict).round(4)

print("\n" + "="*70)
# Report the configured trial count instead of a hard-coded "3" so the header
# stays correct when num_trials is changed in the configuration cell.
print(f"ABLATION SUMMARY (mean +/- std over {config.num_trials} trials)")
print("="*70)
print(summary.to_string())

In [None]:
# Visualization
from pathlib import Path

# Name the arguments before the call: the run's output directory and the
# config names in their declaration order.
plot_dir = Path(config.output_dir)
config_order = list(config.configs.keys())
plot_ablation_results(df, plot_dir, config_order)

In [None]:
# Sync to Drive if in Colab
if IN_COLAB and DRIVE_OUTPUT_DIR:
    import shutil

    local_dir = str(config.output_dir)
    drive_dir = f"{DRIVE_OUTPUT_DIR}/phase4_5_ablations"

    # Validate the source BEFORE removing the Drive copy. Otherwise a missing
    # output directory would delete the previous backup and then fail in
    # copytree, leaving nothing on Drive.
    if not os.path.isdir(local_dir):
        raise FileNotFoundError(
            f"Cannot sync to Drive: output directory {local_dir!r} does not exist"
        )

    # Replace the Drive copy wholesale so files from older runs do not linger.
    if os.path.exists(drive_dir):
        shutil.rmtree(drive_dir)
    shutil.copytree(local_dir, drive_dir)
    print(f"[Drive] Synced to {drive_dir}")
else:
    print("[Local] Outputs saved to", config.output_dir)

In [None]:
# Final summary
banner = "=" * 70
print("\n" + banner)
print("Phase 4.5 Ablations Complete!")
print(banner)

print("\nKey findings (loss | cosine):")
has_cosine = 'mean_cosine' in df.columns
for config_name in config.configs.keys():
    # Select this config's runs once and derive every statistic from them.
    runs = df[df['config_name'] == config_name]
    final_loss = runs['final_loss']
    line = f"  {config_name:12s}: {final_loss.mean():.4f} +/- {final_loss.std():.4f}"

    # The cosine figures are appended only when that metric was recorded.
    if has_cosine:
        cosine = runs['mean_cosine']
        line += f" | {cosine.mean():.4f} +/- {cosine.std():.4f}"
    print(line)

print(f"\nOutputs saved to: {config.output_dir}")

## Task Performance Evaluation (Optional)

Evaluate the base model with **eval loss** and the generated LoRAs with **task accuracy** on held-out task data.

In [None]:
from llgbm import (
    compute_base_eval_loss,
    compute_accuracy_with_lora_batched,
    create_generator,
    load_checkpoint,
    TrainingConfig,
    FunctionalLoRA,
    RealAdapterDataset,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import torch
from pathlib import Path

# Configuration
# Ablation arms to score. NOTE(review): "multitask2" from the training
# configuration is not listed here -- confirm the omission is intentional.
CONFIGS_TO_EVAL = ["delta_only", "multitask", "weight_only"]
# Trial indices 0..NUM_TRIALS-1 are probed for checkpoints below; the value
# mirrors num_trials=3 in the ablation configuration.
NUM_TRIALS = 3
# Maps each eval task to the task_type string passed to
# compute_accuracy_with_lora_batched.
task_types = {"arc_e": "mcq", "boolq": "bool", "gsm8k": "gsm8k"}

# Load base model and tokenizer
# Same use_small_model flag as the ablation configuration.
base_config = TrainingConfig(use_small_model=True)
# Prefer the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"[1] Loading base model: {base_config.base_model}")
# NOTE(review): trust_remote_code=True allows code shipped with the model repo
# to run -- confirm it is required for this model.
base_model = AutoModelForCausalLM.from_pretrained(
    base_config.base_model,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,
)
# Inference only: switch the model to eval mode.
base_model.eval()
tokenizer = AutoTokenizer.from_pretrained(base_config.base_model, trust_remote_code=True)
# Reuse EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset for conditioning
print("[2] Loading dataset")
# A separate BERT tokenizer is handed to the dataset; presumably it tokenizes
# the conditioning text -- confirm against RealAdapterDataset.
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = RealAdapterDataset(CHECKPOINT_DIR, DELTAS_DIR, text_tokenizer, base_config)
print(f"    {len(dataset)} samples")

# Create FunctionalLoRA wrapper
# Built once around the base model with the rank/alpha from base_config and
# reused for every config/trial evaluated below.
functional_lora = FunctionalLoRA(
    base_model,
    lora_rank=base_config.lora_rank,
    lora_alpha=base_config.lora_alpha,
)

# Load eval data
print("[3] Loading eval data")
eval_splits_dir = Path(CHECKPOINT_DIR) / "eval_splits"
eval_data = {}  # task name -> parsed contents of <task>_eval.json
# Iterate over task_types (same tasks, same order as before) so every loaded
# task is guaranteed to have a task_type entry when it is looked up later.
for task in task_types:
    eval_file = eval_splits_dir / f"{task}_eval.json"
    if not eval_file.exists():
        # Say so explicitly; a silently missing split would otherwise just
        # vanish from every later table.
        print(f"    {task}: [SKIP] {eval_file} not found")
        continue
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(eval_file, encoding="utf-8") as f:
        eval_data[task] = json.load(f)
    print(f"    {task}: {len(eval_data[task])} samples")

# Evaluate base model first: loss of the unmodified model on each loaded task.
rule = "=" * 60
print("\n" + rule)
print("BASE MODEL (no LoRA)")
print(rule)
for task_name, task_samples in eval_data.items():
    base_loss = compute_base_eval_loss(base_model, task_samples, tokenizer)
    print(f"  {task_name}: loss={base_loss:.4f}")

# Collect all results
# One dict per evaluated (config, trial): {"config", "trial", <task>: accuracy}.
all_results = []

# The dataset index used for conditioning depends only on the task, so find the
# first matching sample once instead of rescanning dataset.samples for every
# config/trial. None marks a task with no conditioning sample.
conditioning_index = {
    task: next((i for i, s in enumerate(dataset.samples) if s["task"] == task), None)
    for task in eval_data
}

# Evaluate each config and trial
for config_name in CONFIGS_TO_EVAL:
    for trial in range(NUM_TRIALS):
        # NOTE: this root must match output_dir in the ablation configuration.
        checkpoint_path = Path(f"outputs/phase4_5_ablations/{config_name}_trial{trial}/checkpoint_best.pt")
        if not checkpoint_path.exists():
            # Report the skip: a wrong path or an unfinished run would otherwise
            # produce an empty summary with no explanation.
            print(f"\n[SKIP] {config_name} trial {trial}: no checkpoint at {checkpoint_path}")
            continue

        print("\n" + "="*60)
        print(f"{config_name.upper()} TRIAL {trial}")
        print("="*60)

        # Load generator
        # A fresh generator is built and then populated from the checkpoint.
        generator = create_generator(base_config, seed=42, device=device)
        load_checkpoint(str(checkpoint_path), generator)
        generator.eval()

        trial_results = {"config": config_name, "trial": trial}

        for task, samples in eval_data.items():
            sample_index = conditioning_index[task]
            if sample_index is None:
                print(f"  {task}: [SKIP]")
                continue

            # The first dataset sample of this task supplies the conditioning input.
            sample = dataset[sample_index]
            condition_ids = sample["condition_ids"].long()
            attention_mask = sample["attention_mask"].float()

            result = compute_accuracy_with_lora_batched(
                generator=generator,
                functional_lora=functional_lora,
                condition_ids=condition_ids,
                attention_mask=attention_mask,
                eval_samples=samples,
                tokenizer=tokenizer,
                task_type=task_types[task],
                max_samples=100,
                batch_size=8,
                device=device,
            )
            trial_results[task] = result["accuracy"]
            print(f"  {task}: {result['accuracy']:.2%} ({result['correct']}/{result['total']})")

        all_results.append(trial_results)
        # Drop the generator before the next trial and release cached GPU memory.
        del generator
        torch.cuda.empty_cache()

# Summary table: mean/std accuracy per config across the evaluated trials.
import pandas as pd

print("\n" + "="*60)
print("SUMMARY")
print("="*60)
if all_results:
    df_eval = pd.DataFrame(all_results)
    # Aggregate only the tasks that were actually evaluated. A task whose eval
    # file or conditioning sample was missing has no column, and selecting it
    # would raise KeyError.
    task_columns = [task for task in task_types if task in df_eval.columns]
    if task_columns:
        summary = df_eval.groupby("config")[task_columns].agg(["mean", "std"])
        print(summary.round(4).to_string())
    else:
        print("No task results to summarize.")
else:
    # all_results only gains an entry when a checkpoint file was found.
    print("No checkpoints were evaluated.")