# Phase 4.6: Delta-Guided Training Ablation

**Prerequisites:** Run `train_lora_adapters.ipynb` first to create real LoRA adapters and cached `delta_teacher`.

## What This Notebook Tests

We compare two ways of training a prompt-conditioned LoRA generator using **delta supervision**:
1. **Delta-only** — compute Δ(base + generated LoRA) on probes every step.
2. **Delta-guided** — add a cheap “delta head” that is trained every step, and compute the expensive probe-based Δ only every N steps (N is `compute_delta_every_n_steps` in the configuration below).

Both use:
- Pretrained text encoder: `sentence-transformers/all-MiniLM-L6-v2`
- Prompt batches: 8 prompts per adapter

Goal: reduce expensive base-model/probe forwards while keeping generated weights behavior-grounded.

## Configuration

In [None]:
import sys
import os

# Runtime detection: on Google Colab the artifacts live on the mounted Drive;
# anywhere else they live under the current working directory.
IN_COLAB = 'google.colab' in sys.modules
DRIVE_OUTPUT_DIR = None

if not IN_COLAB:
    # Local run: relative paths, no Drive mirror.
    CHECKPOINT_DIR = './checkpoints'
    DELTAS_DIR = './llgbm/deltas'
else:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    # Put MyDrive at the front of the import path.
    sys.path.insert(0, '/content/drive/MyDrive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/llgbm/checkpoints'
    DELTAS_DIR = f'{CHECKPOINT_DIR}/deltas'

In [None]:
from llgbm import (
    AblationConfig,
    phase4_6_configs,
    run_ablations,
    plot_ablation_results,
)

# Reaching these lines means the llgbm imports above resolved; announce which
# comparison this notebook runs.
print("[OK] llgbm imports")
print("[INFO] Phase 4.6: delta-only vs delta-guided")

In [None]:
# Experiment configuration - edit the values below to change the study.

# Ablation variants to compare, together with their loss weights.
ablation_variants = phase4_6_configs(
    compute_delta_every_n_steps=2,  # N inner delta head steps per outer step
    lambda_pred=1.0,
    lambda_computed=1.0,
    lambda_consistency=0.5,
    delta_aggregation="attention",
)

config = AblationConfig(
    configs=ablation_variants,

    # Trial settings
    num_trials=3,
    seeds=[42, 123, 456],
    num_steps=100,  # Outer steps (same for all configs)

    # Paths
    checkpoint_dir=CHECKPOINT_DIR,
    deltas_dir=DELTAS_DIR,
    output_dir="outputs/phase4_6_ablations",

    # Model settings
    use_small_model=True,  # Qwen2.5-0.5B
    batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,  # Adjustable learning rate
    warmup_steps=25,
    shuffle_task_prompts=True,

    # Text encoder settings
    text_encoder_name="sentence-transformers/all-MiniLM-L6-v2",
    freeze_text_encoder=True,
    num_prompts_per_adapter=8,  # Sample 8 prompts instead of just first one

    # Colab settings
    in_colab=IN_COLAB,
    drive_output_dir=DRIVE_OUTPUT_DIR,
)

# Echo the effective settings so the run log records what was executed.
variant_names = list(config.configs.keys())
print(f"Configurations: {variant_names}")
total_runs = len(config.configs) * config.num_trials
print(f"Total runs: {total_runs}")
print(f"Steps per trial: {config.num_steps}")
guided_inner_steps = config.configs['delta_guided']['compute_delta_every_n_steps']
print(f"Delta head inner steps (delta_guided): {guided_inner_steps}")
print(f"Learning rate: {config.learning_rate}")
print(f"Text encoder: {config.text_encoder_name}")
print(f"Prompts per adapter: {config.num_prompts_per_adapter}")

## Run Ablations

In [None]:
# Run all ablation experiments described by `config`.
# The returned mapping is read later via results["dataframe"].
results = run_ablations(config)

## Results Summary

In [None]:
import pandas as pd

# Results table produced by run_ablations; grouped by config_name below.
df = results["dataframe"]

# Metrics aggregated per configuration: mean and spread for the losses,
# mean only for the training time.
agg_dict = {
    "final_loss": ["mean", "std"],
    "best_loss": ["mean", "std"],
    "train_time": ["mean"],
}
# The cosine metrics are optional columns; aggregate them only when present.
for optional_metric in ("mean_cosine", "mean_cosine_pred"):
    if optional_metric in df.columns:
        agg_dict[optional_metric] = ["mean", "std"]

summary = df.groupby("config_name").agg(agg_dict).round(4)

print("\n" + "="*70)
# Report the configured trial count so the header stays correct when
# num_trials is changed in the configuration cell.
print(f"ABLATION SUMMARY (mean +/- std over {config.num_trials} trials)")
print("="*70)
print(summary.to_string())

In [None]:
# Visualization: hand the results table, the output folder and the
# configuration names to the llgbm plotting helper.
from pathlib import Path

plot_dir = Path(config.output_dir)
plotted_configs = list(config.configs.keys())
plot_ablation_results(df, plot_dir, plotted_configs)

In [None]:
# Mirror the local output folder onto Google Drive when running in Colab;
# otherwise just report where the outputs are.
if not (IN_COLAB and DRIVE_OUTPUT_DIR):
    print("[Local] Outputs saved to", config.output_dir)
else:
    import shutil
    drive_dir = f"{DRIVE_OUTPUT_DIR}/phase4_6_ablations"
    # copytree needs a destination that does not exist yet, so remove any
    # previous sync first.
    if os.path.exists(drive_dir):
        shutil.rmtree(drive_dir)
    shutil.copytree(str(config.output_dir), drive_dir)
    print(f"[Drive] Synced to {drive_dir}")

In [None]:
# Final summary: one line per configuration with final-loss statistics and,
# when the columns exist, cosine statistics.
print("\n" + "="*70)
print("Phase 4.6 Ablations Complete!")
print("="*70)

print("\nKey findings (loss | cosine):")

# Column availability does not change between configurations; check it once.
has_cosine = 'mean_cosine' in df.columns
has_cosine_pred = 'mean_cosine_pred' in df.columns

for config_name in config.configs.keys():
    # Select this configuration's trials once instead of re-filtering the
    # frame for every statistic.
    rows = df[df['config_name'] == config_name]

    mean_loss = rows['final_loss'].mean()
    std_loss = rows['final_loss'].std()
    line = f"  {config_name:12s}: {mean_loss:.4f} +/- {std_loss:.4f}"

    if has_cosine:
        mean_cos = rows['mean_cosine'].mean()
        std_cos = rows['mean_cosine'].std()
        line += f" | {mean_cos:.4f} +/- {std_cos:.4f}"

        if has_cosine_pred:
            mean_cos_pred = rows['mean_cosine_pred'].mean()
            std_cos_pred = rows['mean_cosine_pred'].std()
            # The mean is NaN when this configuration has no values in the
            # column; omit the suffix in that case.
            if not pd.isna(mean_cos_pred):
                line += f" | pred={mean_cos_pred:.4f} +/- {std_cos_pred:.4f}"

    print(line)

print(f"\nOutputs saved to: {config.output_dir}")

## Task Performance Evaluation (Optional)

Evaluate generated LoRAs on held-out task data.

**Tasks from manifest:** arc_e, arc_c, boolq, obqa, piqa, winogrande

In [None]:
from llgbm import (
    compute_base_eval_loss,
    compute_accuracy_with_lora_batched,
    create_generator,
    create_generator_with_delta_head,
    load_checkpoint,
    TrainingConfig,
    FunctionalLoRA,
    RealAdapterDataset,
    create_text_encoder,
)
from llgbm.ablations import _infer_mode, phase4_6_configs
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import torch
from pathlib import Path

# Configuration
CONFIGS_TO_EVAL = ["delta_only", "delta_guided"]
NUM_TRIALS = 3
# Prompts sampled per adapter; defined once so the dataset call and the log
# message below cannot drift apart.
NUM_PROMPTS = 8

# Define configs directly (self-contained)
CONFIGS = phase4_6_configs()

# Tasks from manifest (excluding hellaswag)
TASKS = ["arc_e", "arc_c", "boolq", "obqa", "piqa", "winogrande"]
# Answer format per task, passed to the accuracy helper as task_type.
task_types = {
    "arc_e": "mcq",
    "arc_c": "mcq",
    "boolq": "bool",
    "obqa": "mcq",
    "piqa": "mcq",
    "winogrande": "mcq",
}

# Load base model and tokenizer
base_config = TrainingConfig(use_small_model=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"[1] Loading base model: {base_config.base_model}")
base_model = AutoModelForCausalLM.from_pretrained(
    base_config.base_model,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,
)
base_model.eval()
tokenizer = AutoTokenizer.from_pretrained(base_config.base_model, trust_remote_code=True)
# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Create pretrained text encoder (same as training)
print("[2] Loading pretrained text encoder")
text_encoder = create_text_encoder(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    freeze=True,
    device=device,
)

# Load dataset with new settings
print("[3] Loading dataset with prompt batches")
dataset = RealAdapterDataset(
    CHECKPOINT_DIR,
    DELTAS_DIR,
    text_encoder.tokenizer,  # Use text encoder's tokenizer
    base_config,
    num_prompts=NUM_PROMPTS,  # Sample NUM_PROMPTS prompts per adapter
)
print(f"    {len(dataset)} samples, {NUM_PROMPTS} prompts per adapter")

# Create FunctionalLoRA wrapper
functional_lora = FunctionalLoRA(
    base_model,
    lora_rank=base_config.lora_rank,
    lora_alpha=base_config.lora_alpha,
)

# Load eval data for all tasks from manifest
print("[4] Loading eval data")
# In Colab the eval splits are read from Drive. Outside Colab the Drive path
# cannot exist, so fall back to a project-relative folder.
# NOTE(review): the local fallback assumes the same llgbm/data layout as on
# Drive - confirm against where the eval splits are written locally.
eval_splits_dir = Path('/content/drive/MyDrive/llgbm/data' if IN_COLAB else './llgbm/data')
eval_data = {}
for task in TASKS:
    eval_file = eval_splits_dir / f"{task}_eval.json"
    if eval_file.exists():
        with open(eval_file) as f:
            eval_data[task] = json.load(f)
        print(f"    {task}: {len(eval_data[task])} samples")
    else:
        print(f"    {task}: [NOT FOUND] {eval_file}")

# Make an empty result obvious instead of letting the rest of the cell run
# silently with nothing to score.
if not eval_data:
    print(f"    [WARN] No eval files found under {eval_splits_dir}; no task will be evaluated")

# Evaluate base model first
print("\n" + "="*60)
print("BASE MODEL (no LoRA)")
print("="*60)
for task, samples in eval_data.items():
    loss = compute_base_eval_loss(base_model, samples, tokenizer)
    print(f"  {task}: loss={loss:.4f}")

# Collect all results
all_results = []

class _LoraOnlyWrapper(torch.nn.Module):
    """Expose a delta-guided generator through a plain ``forward`` call.

    ``forward`` delegates to the wrapped object's ``generate_lora`` and hands
    back its result unchanged (forward() -> List[Dict[str, Tensor]]).
    """

    def __init__(self, gen):
        super().__init__()
        # Wrapped generator; it must provide generate_lora(ids, mask).
        self.gen = gen

    def forward(self, condition_ids, attention_mask=None):
        lora_weights = self.gen.generate_lora(condition_ids, attention_mask)
        return lora_weights

# Evaluate each config and trial

# The sample used for conditioning depends only on the task, so resolve each
# task's first matching dataset index once instead of rescanning the dataset
# for every config and trial.
first_sample_index = {}
for sample_idx, s in enumerate(dataset.samples):
    first_sample_index.setdefault(s["task"], sample_idx)

for config_name in CONFIGS_TO_EVAL:
    # Get the actual mode from config to determine model architecture
    config_params = CONFIGS.get(config_name, {})
    mode = _infer_mode(config_name, config_params)
    uses_delta_head = (mode == "delta_guided")

    for trial in range(NUM_TRIALS):
        checkpoint_path = Path(f"outputs/phase4_6_ablations/{config_name}_trial{trial}/checkpoint_best.pt")
        if not checkpoint_path.exists():
            # Say so explicitly: a silent skip would make a missing run look
            # the same as a run that was never requested.
            print(f"[SKIP] {config_name} trial {trial}: checkpoint not found at {checkpoint_path}")
            continue

        print(f"\n" + "="*60)
        print(f"{config_name.upper()} TRIAL {trial} (mode={mode})")
        print("="*60)

        # Load generator with pretrained text encoder
        # Use correct architecture based on actual mode, not config name
        gen_core = None
        if uses_delta_head:
            gen_core = create_generator_with_delta_head(
                base_config,
                seed=42,
                device=device,
                text_encoder=text_encoder,
            )
            load_checkpoint(str(checkpoint_path), gen_core)
            # Wrap so that forward() returns the LoRA weights.
            generator = _LoraOnlyWrapper(gen_core)
        else:
            generator = create_generator(
                base_config,
                seed=42,
                device=device,
                text_encoder=text_encoder,  # Use pretrained encoder
            )
            load_checkpoint(str(checkpoint_path), generator)

        generator.eval()

        # One row per (config, trial); task accuracies are added as columns.
        trial_results = {"config": config_name, "trial": trial}

        for task, samples in eval_data.items():
            if task not in first_sample_index:
                print(f"  {task}: [SKIP]")
                continue

            sample = dataset[first_sample_index[task]]
            condition_ids = sample["condition_ids"]  # (N, seq_len)
            attention_mask = sample["attention_mask"]  # (N, seq_len)

            result = compute_accuracy_with_lora_batched(
                generator=generator,
                functional_lora=functional_lora,
                condition_ids=condition_ids,
                attention_mask=attention_mask,
                eval_samples=samples,
                tokenizer=tokenizer,
                task_type=task_types[task],
                max_samples=100,
                batch_size=8,
                device=device,
            )
            trial_results[task] = result["accuracy"]
            print(f"  {task}: {result['accuracy']:.2%} ({result['correct']}/{result['total']})")

        all_results.append(trial_results)
        # Drop both names for the generator: with a delta head, gen_core would
        # otherwise keep the weights alive while the next trial's model loads.
        del generator, gen_core
        torch.cuda.empty_cache()

# Summary table: mean and std of task accuracy per configuration across trials.
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
import pandas as pd
if not all_results:
    # No trial contributed a row, so there is nothing to aggregate.
    print("[WARN] No checkpoints were evaluated; nothing to summarize.")
else:
    df_eval = pd.DataFrame(all_results)
    # Keep only tasks that produced an accuracy column, in manifest order.
    task_cols = [t for t in TASKS if t in df_eval.columns]
    if task_cols:
        summary = df_eval.groupby("config")[task_cols].agg(["mean", "std"])
        print(summary.round(4).to_string())
    else:
        # Every task was skipped; avoid aggregating an empty column selection.
        print("[WARN] No task produced an accuracy; nothing to summarize.")