# Phase 5: Delta-Guided Training with Delta Prediction Head

Train a LoRA generator using **behavioral supervision** with a dedicated delta prediction head.

**Architecture:**
```
N embeddings ─→ [Shared Encoder] ─┬─→ [LoRA Head] → LoRA weights → δ_computed
                                  └─→ [Delta Head] → δ_predicted (fast!)
```

**Loss Function:**
```
Loss = λ_pred * L(δ_predicted, δ_teacher)        # Fast delta supervision
     + λ_computed * L(δ_computed, δ_teacher)     # Real LoRA behavior  
     + λ_consistency * L(δ_computed, δ_predicted) # Heads must agree
```

## Key Features

1. **Trainable text encoder** - Frozen MiniLM-L6-v2 + learnable projection
2. **Delta prediction head** - Predicts delta directly from N embeddings (fast!)
3. **Attention aggregation** - Learns which prompts matter most for delta
4. **Consistency loss** - LoRA must produce deltas matching predictions
5. **Per-embedding deltas** - No averaging, uses all N prompt embeddings

## Configuration

In [None]:
# Environment setup: detect Colab, mount Drive and pick checkpoint/delta paths,
# then import the libraries used throughout the notebook.
import sys
import os
import shutil

IN_COLAB = 'google.colab' in sys.modules
DRIVE_OUTPUT_DIR = None  # only set on Colab; read by the Drive-sync cell

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    # IPython shell escape (notebook-only syntax): install runtime dependencies.
    !pip install -q safetensors accelerate transformers peft sentence-transformers
    # Presumably makes the `llgbm` package stored on Drive importable — confirm.
    sys.path.insert(0, '/content/drive/MyDrive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/llgbm/checkpoints'
    DELTAS_DIR = CHECKPOINT_DIR + '/deltas'
else:
    CHECKPOINT_DIR = './checkpoints'
    # NOTE(review): locally the deltas directory is NOT under CHECKPOINT_DIR, unlike
    # on Colab (CHECKPOINT_DIR + '/deltas') — confirm this asymmetry is intended.
    DELTAS_DIR = './llgbm/deltas'

import json
import gc
from pathlib import Path
from dataclasses import asdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from transformers import AutoModelForCausalLM, AutoTokenizer

print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

In [None]:
# Import llgbm modules
# Project-local package (on Colab it is found via the sys.path entry added in the setup cell).
# NOTE(review): DeltaCache, compute_delta_for_batch, save_checkpoint, load_checkpoint and
# compute_accuracy_with_lora_batched are not used by any cell in this notebook.
from llgbm import (
    create_generic_probes,
    DeltaCache,
    FunctionalLoRA,
    TrainingConfig,
    DeltaGuidedLoss,  # NEW: 3-part loss with consistency
    compute_delta_for_batch,
    save_checkpoint,
    load_checkpoint,
    # Trainable text encoder
    create_trainable_text_encoder,
    # Generator with delta head
    create_generator_with_delta_head,
    RealAdapterDataset,
    # Evaluation
    compute_base_eval_loss,
    compute_accuracy_with_lora_batched,
)

print("[OK] llgbm imports")
print("[INFO] Using LoRAGeneratorWithDeltaHead (dual-head architecture)")
print("[INFO] Delta prediction + LoRA generation with consistency loss")

## Training Configuration

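Key difference from Phase 4: `lambda_weight=0` turns off weight-space supervision, so training is driven only by behavioral deltas (via the 3-part `DeltaGuidedLoss` used in the training loop below).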

In [None]:
# Phase 5 training configuration: small base model, behavioural (delta) supervision only.
config = TrainingConfig(
    # Model & paths
    use_small_model=True,  # Qwen2.5-0.5B for testing
    checkpoint_dir=CHECKPOINT_DIR,
    delta_cache_dir=DELTAS_DIR,
    output_dir="outputs/phase5_delta_only",
    # Optimisation (more steps than Phase 4 for delta-only training)
    batch_size=4,
    gradient_accumulation_steps=2,
    num_steps=200,
    warmup_steps=20,
    learning_rate=2e-4,
    # Loss weights: everything on the delta term, nothing on weight-space supervision
    lambda_delta=1.0,
    lambda_weight=0.0,
    # Probes used to measure behavioural deltas
    num_probes=10,
    max_probe_length=256,
    delta_batch_probes=True,
    # Conditioning text encoder
    text_encoder_name="sentence-transformers/all-MiniLM-L6-v2",
    freeze_text_encoder=True,
    num_prompts_per_adapter=8,
)

# Derived runtime settings
_DTYPES_BY_NAME = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
TORCH_DTYPE = _DTYPES_BY_NAME[config.dtype]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
Path(config.output_dir).mkdir(parents=True, exist_ok=True)
config.save(f"{config.output_dir}/config.json")

for summary_line in (
    f"Model: {config.base_model}",
    f"Device: {device}",
    f"Loss: Delta-only (lambda_d={config.lambda_delta}, lambda_w={config.lambda_weight})",
    f"Text encoder: {config.text_encoder_name}",
    f"Prompts per adapter: {config.num_prompts_per_adapter}",
):
    print(summary_line)

## Load Base Model & Prepare Components

In [None]:
# Tokenizer for the base LM; fall back to EOS as the pad token when none is defined.
tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Base LM used as a frozen feature extractor: eval mode, no KV cache,
# no hidden-state outputs by default, and no gradients on any parameter.
base_model_kwargs = dict(torch_dtype=TORCH_DTYPE, device_map=device, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(config.base_model, **base_model_kwargs)
base_model_cfg = base_model.config
base_model_cfg.output_hidden_states = False
base_model_cfg.use_cache = False
base_model.eval()
base_model.requires_grad_(False)

n_base_params = sum(weight.numel() for weight in base_model.parameters())
print(f"[OK] Base model: {n_base_params:,} params")

In [None]:
# Collect task-specific probe prompts: at most 5 per task (one adapter per task),
# capped at config.num_probes in total; fall back to generic probes if none found.
checkpoint_dir = Path(CHECKPOINT_DIR)
deltas_dir = Path(DELTAS_DIR)

delta_manifest_path = deltas_dir / "delta_manifest.json"
manifest_path = checkpoint_dir / "manifest.json"

all_probes = []
if delta_manifest_path.exists() and manifest_path.exists():
    delta_manifest = json.loads(delta_manifest_path.read_text())
    adapter_manifest = json.loads(manifest_path.read_text())

    adapter_paths = {entry["name"]: entry["path"] for entry in adapter_manifest.get("adapters", [])}
    tasks_seen = set()

    for adapter_name, adapter_info in delta_manifest["adapters"].items():
        task = adapter_info.get("task", "unknown")
        if task in tasks_seen:
            continue
        remaining = max(0, config.num_probes - len(all_probes))
        if not remaining:
            break  # probe budget exhausted
        adapter_path = adapter_paths.get(adapter_name)
        if not adapter_path:
            continue
        prompts_file = Path(adapter_path) / "prompts.json"
        if not prompts_file.exists():
            continue
        prompts_data = json.loads(prompts_file.read_text())
        probes = prompts_data.get("prompts", [])[:min(5, remaining)]
        if not probes:
            continue
        all_probes.extend(probes)
        tasks_seen.add(task)
        print(f"  Loaded {len(probes)} probes for {task}")

if not all_probes:
    print("[WARN] No task-specific probes, using generic")
    all_probes = create_generic_probes()[:config.num_probes]

print(f"[OK] {len(all_probes)} probes loaded")

In [None]:
# Tokenize each probe on its own (batch of one, truncated to max_probe_length).
probe_tokens, probe_masks = [], []
for probe_text in all_probes:
    encoded = tokenizer(probe_text, return_tensors="pt", truncation=True, max_length=config.max_probe_length)
    probe_tokens.append(encoded["input_ids"].to(device))
    probe_masks.append(encoded["attention_mask"].to(device))

# Base activation: mean over probes of the final-layer hidden state taken at each
# probe's last attended token. Prefer the bare backbone (`base_model.model`) when
# present; otherwise request hidden states from the full causal-LM forward.
backbone = getattr(base_model, "model", None)
with torch.no_grad():
    base_acts = []
    for ids, mask in zip(probe_tokens, probe_masks):
        if backbone is None:
            out = base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True, use_cache=False)
            hidden = out.hidden_states[-1]
        else:
            out = backbone(input_ids=ids, attention_mask=mask, use_cache=False)
            hidden = out.last_hidden_state
        last_positions = mask.long().sum(dim=1).clamp(min=1) - 1
        row_ids = torch.arange(hidden.shape[0], device=hidden.device)
        base_acts.append(hidden[row_ids, last_positions, :].squeeze(0))
    base_activation = torch.stack(base_acts).mean(dim=0)

print(f"[OK] Base activation: {base_activation.shape}, norm={base_activation.norm():.4f}")

In [None]:
# Wrap the frozen base model with FunctionalLoRA (configured LoRA rank/alpha); it is
# used below to evaluate generated LoRA weights on the probe set.
lora_hparams = {"lora_rank": config.lora_rank, "lora_alpha": config.lora_alpha}
functional_lora = FunctionalLoRA(base_model=base_model, **lora_hparams)
n_lora_mappings = len(functional_lora._lora_to_base_map)
print(f"[OK] FunctionalLoRA: {n_lora_mappings} mappings")

## Trainable Text Encoder & Generator with Delta Head

**Text Encoder**: Frozen MiniLM-L6-v2 + trainable 2-layer projection MLP

**Generator**: Dual-head architecture:
- **LoRA Head**: Generates adapter weights (for deployment)
- **Delta Head**: Predicts behavioral delta (for fast training)

In [None]:
# Conditioning text encoder: sentence-transformer backbone plus a trainable
# projection MLP. The `projection` parameters get their own optimizer group later.
print("[1] Loading trainable text encoder...")
text_encoder_options = {
    "model_name": config.text_encoder_name,
    "output_dim": 384,  # keep the same dimension as the input embeddings
    "hidden_dim": 512,  # hidden width of the projection MLP
    "num_layers": 2,    # 2-layer MLP projection
    "dropout": 0.1,
    "device": device,
}
text_encoder = create_trainable_text_encoder(**text_encoder_options)

In [None]:
# Dual-head generator (LoRA head + delta head) conditioned through `text_encoder`;
# the delta head aggregates the N prompt embeddings with attention.
print("[2] Creating generator with delta head...")
generator = create_generator_with_delta_head(
    config,
    device=device,
    seed=42,
    text_encoder=text_encoder,
    delta_aggregation="attention",
)

## Dataset

In [None]:
# Create dataset with real adapters and prompt batches.
# Batches expose "delta_teacher", "condition_ids" and "attention_mask" (the keys the
# training/eval loops read); prompts are tokenized with the text encoder's tokenizer.
print("[3] Loading dataset with prompt batches...")
dataset = RealAdapterDataset(
    checkpoint_dir=str(checkpoint_dir),
    deltas_dir=str(deltas_dir),
    tokenizer=text_encoder.tokenizer,
    config=config,
    num_prompts=config.num_prompts_per_adapter,
)

_loader_workers = 2
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
    num_workers=_loader_workers,
    # Pinned host memory only speeds up host->GPU copies; on CPU it just emits a warning.
    pin_memory=device.type == "cuda",
    # Keep worker processes alive between epochs instead of re-spawning them each pass.
    persistent_workers=_loader_workers > 0,
)

print(f"[OK] Dataset: {len(dataset)} samples, {len(dataloader)} batches")
print(f"     {config.num_prompts_per_adapter} prompts per adapter")

## Training with Delta-Guided Loss (prediction + computed + consistency)

In [None]:
# Delta-guided loss with 3 components
criterion = DeltaGuidedLoss(
    lambda_pred=1.0,        # Predicted delta vs teacher
    lambda_computed=1.0,    # Computed delta (via LoRA) vs teacher
    lambda_consistency=0.5, # Predicted vs computed consistency
    normalize=True,
    loss_type="mse",
)

# Parameter groups. The generator is built with `text_encoder=...`; if it registers
# the encoder as a submodule, the projection parameters also show up in
# generator.parameters(). torch.optim rejects a parameter that appears in more than
# one group, so split by identity: projection params only go in the encoder group.
projection_params = list(text_encoder.projection.parameters())
_projection_ids = {id(param) for param in projection_params}
generator_only_params = [param for param in generator.parameters() if id(param) not in _projection_ids]

trainable_params = [
    {"params": generator_only_params, "lr": config.learning_rate},
    {"params": projection_params, "lr": config.learning_rate * 0.5},
]

# Optimizer with parameter groups
optimizer = AdamW(trainable_params, weight_decay=config.weight_decay)

# LR schedule: linear warmup (0.1x -> 1x) then cosine decay to 1% of the base LR.
warmup_steps = min(config.warmup_steps, config.num_steps // 10)
cosine_steps = max(1, config.num_steps - warmup_steps)
if warmup_steps > 0:
    scheduler = SequentialLR(
        optimizer,
        [LinearLR(optimizer, 0.1, 1.0, warmup_steps),
         CosineAnnealingLR(optimizer, cosine_steps, config.learning_rate * 0.01)],
        [warmup_steps],
    )
else:
    # With zero warmup, LinearLR(total_iters=0) inside SequentialLR would leave the LR
    # at start_factor (0.1x) and the cosine phase would then decay from that reduced
    # value; use the cosine schedule alone instead.
    scheduler = CosineAnnealingLR(optimizer, cosine_steps, config.learning_rate * 0.01)

# Count trainable parameters (projection counted once, under the encoder)
gen_params = sum(p.numel() for p in generator_only_params if p.requires_grad)
enc_params = sum(p.numel() for p in projection_params if p.requires_grad)
delta_head_params = sum(p.numel() for p in generator.delta_head.parameters() if p.requires_grad)

print("[OK] Optimizer & Scheduler ready")
print(f"     Generator total: {gen_params:,}")
print(f"     - Delta head: {delta_head_params:,}")
print(f"     - LoRA head: {gen_params - delta_head_params:,}")
print(f"     Encoder projection: {enc_params:,}")
print(f"     Total trainable: {gen_params + enc_params:,}")
print(f"[OK] Using DeltaGuidedLoss (3-part: pred + computed + consistency)")

In [None]:
# Custom training loop for dual-head generator.
# Per micro-batch: the generator returns a predicted delta and LoRA weights; each
# generated LoRA is evaluated on the probe set via FunctionalLoRA to get the
# computed delta; DeltaGuidedLoss combines pred / computed / consistency terms.
from tqdm.auto import tqdm
from llgbm import compute_delta_differentiable


def _infinite_batches(loader):
    """Yield batches from *loader* forever, starting a fresh pass each epoch.

    ``itertools.cycle`` would cache every batch of the first epoch and replay them
    in the same order, so ``shuffle=True`` would only take effect once and all
    batches would stay in memory. Re-iterating the DataLoader avoids both.
    If *loader* yields nothing, the generator ends instead of spinning forever.
    """
    while True:
        produced_any = False
        for item in loader:
            produced_any = True
            yield item
        if not produced_any:
            return


print(f"\n{'='*60}")
print(f"Training: {config.num_steps} steps")
print(f"Batch size: {config.batch_size} x {config.gradient_accumulation_steps}")
print(f"Mode: Delta-guided (prediction + computed + consistency)")
print(f"{'='*60}\n")

generator.train()
# The projection MLP has dropout and is optimised too; set train mode explicitly in
# case it is not a registered submodule of the generator.
text_encoder.projection.train()
data_iter = _infinite_batches(dataloader)
pbar = tqdm(total=config.num_steps, desc="Training")

# Parameters to clip, de-duplicated by identity. If the generator also registers the
# projection, a duplicated entry would be counted twice in the norm and scaled twice.
_clip_projection = list(text_encoder.projection.parameters())
_clip_projection_ids = {id(param) for param in _clip_projection}
clip_params = [param for param in generator.parameters() if id(param) not in _clip_projection_ids]
clip_params += _clip_projection

# Mixed precision only for reduced-precision dtypes; float32 is not a supported
# autocast target (torch would disable autocast with a warning).
autocast_device = "cuda" if device.type == "cuda" else "cpu"
use_autocast = TORCH_DTYPE != torch.float32
# NOTE(review): float16 autocast runs here without a GradScaler, so small gradients
# can underflow; consider torch.amp.GradScaler or bfloat16.

# Training state
loss_history = []
loss_pred_history = []
loss_computed_history = []
loss_consistency_history = []
grad_norm_history = []
lr_history = []
best_loss = float("inf")

accumulation_step = 0
running_losses = {"loss": 0, "loss_pred": 0, "loss_computed": 0, "loss_consistency": 0}
update_count = 0

while update_count < config.num_steps:
    batch = next(data_iter)

    # Move to device
    delta_teacher = batch["delta_teacher"].to(device)
    condition_ids = batch["condition_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    # Forward pass through generator (returns both delta_pred and lora_weights)
    with torch.autocast(device_type=autocast_device, dtype=TORCH_DTYPE, enabled=use_autocast):
        results = generator(condition_ids, attention_mask, return_delta=True, return_lora=True)

        delta_predicted = results["delta_pred"]
        lora_weights_batch = results["lora_weights"]

        # Actual delta produced by each generated LoRA (expensive, but this is the
        # term that supervises the deployable LoRA head and the consistency loss).
        delta_computed = torch.stack([
            compute_delta_differentiable(
                functional_lora=functional_lora,
                lora_weights=lora_w,
                base_activation=base_activation,
                probe_tokens=probe_tokens,
                probe_masks=probe_masks,
                batch_probes=True,
            )
            for lora_w in lora_weights_batch
        ])

        # Compute 3-part loss in float32
        losses = criterion(
            delta_predicted=delta_predicted.float(),
            delta_computed=delta_computed.float(),
            delta_teacher=delta_teacher.float(),
        )

    # Backward (averaged over the accumulation window)
    scaled_loss = losses["loss"] / config.gradient_accumulation_steps
    scaled_loss.backward()

    # Accumulate losses
    for k in running_losses:
        if k in losses:
            running_losses[k] += losses[k].item()
    accumulation_step += 1

    # Gradient update
    if accumulation_step >= config.gradient_accumulation_steps:
        grad_norm = torch.nn.utils.clip_grad_norm_(clip_params, config.max_grad_norm)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Record metrics
        avg_losses = {k: v / accumulation_step for k, v in running_losses.items()}
        loss_history.append(avg_losses["loss"])
        loss_pred_history.append(avg_losses.get("loss_pred", 0))
        loss_computed_history.append(avg_losses.get("loss_computed", 0))
        loss_consistency_history.append(avg_losses.get("loss_consistency", 0))
        grad_norm_history.append(grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm)
        lr_history.append(scheduler.get_last_lr()[0])

        # Track best (on the running training loss of this update)
        if avg_losses["loss"] < best_loss:
            best_loss = avg_losses["loss"]
            torch.save({
                "generator_state_dict": generator.state_dict(),
                "text_encoder_projection_state_dict": text_encoder.projection.state_dict(),
            }, f"{config.output_dir}/checkpoint_best.pt")

        # Reset
        running_losses = {k: 0 for k in running_losses}
        accumulation_step = 0
        update_count += 1

        pbar.set_postfix({
            "loss": f"{avg_losses['loss']:.4f}",
            "pred": f"{avg_losses.get('loss_pred', 0):.4f}",
            "cons": f"{avg_losses.get('loss_consistency', 0):.4f}",
        })
        pbar.update(1)

        # Memory cleanup
        if update_count % 50 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

pbar.close()
print(f"\nDone! Steps: {update_count}, Best loss: {best_loss:.6f}")

In [None]:
# Evaluate using delta prediction (fast) and computed delta (accurate).
# Cosine similarity against the teacher delta, over the same `dataset` used for
# training (no held-out split), using the final (not best-checkpoint) weights.
generator.eval()
text_encoder.projection.eval()  # disable projection dropout as well
eval_dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=dataset.collate_fn
)

# Same mixed-precision context as training so the generator / FunctionalLoRA see
# the same dtypes they were trained with.
eval_autocast_device = "cuda" if device.type == "cuda" else "cpu"
eval_use_autocast = TORCH_DTYPE != torch.float32

cosines_pred = []
cosines_computed = []

with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        delta_teacher = batch["delta_teacher"].to(device).float()
        condition_ids = batch["condition_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        with torch.autocast(device_type=eval_autocast_device, dtype=TORCH_DTYPE, enabled=eval_use_autocast):
            results = generator(condition_ids, attention_mask, return_delta=True, return_lora=True)
            delta_predicted = results["delta_pred"]
            lora_weights_batch = results["lora_weights"]

            # Compute actual delta produced by each generated LoRA
            delta_computed = torch.stack([
                compute_delta_differentiable(
                    functional_lora=functional_lora,
                    lora_weights=lora_w,
                    base_activation=base_activation,
                    probe_tokens=probe_tokens,
                    probe_masks=probe_masks,
                    batch_probes=True,
                )
                for lora_w in lora_weights_batch
            ])

        # Cosine similarities in float32 (the training loss also compares in float32)
        cos_pred = F.cosine_similarity(delta_predicted.float(), delta_teacher, dim=-1)
        cos_computed = F.cosine_similarity(delta_computed.float(), delta_teacher, dim=-1)

        cosines_pred.extend(cos_pred.cpu().tolist())
        cosines_computed.extend(cos_computed.cpu().tolist())

eval_results = {
    "mean_cosine_pred": float(np.mean(cosines_pred)),
    "std_cosine_pred": float(np.std(cosines_pred)),
    "mean_cosine_computed": float(np.mean(cosines_computed)),
    "std_cosine_computed": float(np.std(cosines_computed)),
}

print("\nEvaluation Results:")
print(f"  Delta Predicted vs Teacher:")
print(f"    mean cosine: {eval_results['mean_cosine_pred']:.4f} +/- {eval_results['std_cosine_pred']:.4f}")
print(f"  Delta Computed (via LoRA) vs Teacher:")
print(f"    mean cosine: {eval_results['mean_cosine_computed']:.4f} +/- {eval_results['std_cosine_computed']:.4f}")

## Training Curves

In [None]:
# 2x2 figure: total loss, per-component losses, gradient norms (with the clip
# threshold) and the LR schedule; saved to output_dir/training_curves.png.
if loss_history:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    ax_total, ax_parts, ax_grad, ax_lr = axes.flat

    ax_total.plot(loss_history, label='Total', color='#2c3e50', linewidth=2)
    ax_total.set_title('Total Loss')

    component_series = (
        (loss_pred_history, 'Predicted→Teacher', '#3498db'),
        (loss_computed_history, 'Computed→Teacher', '#e74c3c'),
        (loss_consistency_history, 'Consistency', '#2ecc71'),
    )
    for series, series_label, series_color in component_series:
        ax_parts.plot(series, label=series_label, color=series_color, linewidth=1.5)
    ax_parts.set_title('Loss Components')

    ax_grad.plot(grad_norm_history, color='#9b59b6', linewidth=1.5)
    ax_grad.axhline(config.max_grad_norm, color='r', ls='--', label=f'Clip={config.max_grad_norm}')
    ax_grad.set_title('Gradient Norms')

    # Shared decoration for the three always-present panels
    for panel, y_label in ((ax_total, 'Loss'), (ax_parts, 'Loss'), (ax_grad, 'Gradient Norm')):
        panel.set_xlabel('Step')
        panel.set_ylabel(y_label)
        panel.legend()
        panel.grid(True, alpha=0.3)

    if lr_history:
        ax_lr.plot(lr_history, color='#f39c12', linewidth=2)
        ax_lr.set_xlabel('Step')
        ax_lr.set_ylabel('Learning Rate')
        ax_lr.set_title('LR Schedule')
        ax_lr.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{config.output_dir}/training_curves.png", dpi=150)
    plt.show()

In [None]:
# Persist config, training summary and delta-eval metrics to output_dir/results.json.
def _last_or_none(values):
    """Return the final element of *values*, or None when it is empty."""
    return values[-1] if values else None


training_summary = {
    "steps": len(loss_history),
    "best_loss": best_loss,
    "final_loss": _last_or_none(loss_history),
    "final_loss_pred": _last_or_none(loss_pred_history),
    "final_loss_computed": _last_or_none(loss_computed_history),
    "final_loss_consistency": _last_or_none(loss_consistency_history),
}
results = {
    "config": asdict(config),
    "training": training_summary,
    "eval": eval_results,
    "mode": "delta_guided",
    "architecture": "LoRAGeneratorWithDeltaHead",
}

results_path = Path(config.output_dir) / "results.json"
results_path.write_text(json.dumps(results, indent=2))

print(f"Saved to {config.output_dir}/")

## Task Performance Evaluation

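Evaluate generated LoRAs on downstream tasks. The base model is scored by eval loss, while generated LoRAs are scored by accuracy (the expected answer must appear in the greedy generation), so the two numbers are not directly comparable.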

In [None]:
# Downstream tasks and their answer formats.
# NOTE(review): TASK_TYPES is not read by any later cell.
TASKS = ["arc_e", "arc_c", "boolq", "obqa", "piqa", "winogrande"]
TASK_TYPES = {
    "arc_e": "mcq",
    "arc_c": "mcq",
    "boolq": "bool",
    "obqa": "mcq",
    "piqa": "mcq",
    "winogrande": "mcq",
}

# Load per-task eval splits from <checkpoint_dir>/eval_splits/<task>_eval.json
print("[1] Loading eval data...")
eval_splits_dir = checkpoint_dir / "eval_splits"
eval_data = {}
for task in TASKS:
    eval_file = eval_splits_dir / f"{task}_eval.json"
    if not eval_file.exists():
        print(f"    {task}: [NOT FOUND]")
        continue
    eval_data[task] = json.loads(eval_file.read_text())
    print(f"    {task}: {len(eval_data[task])} samples")

if not eval_data:
    print("[SKIP] No eval data found, skipping task evaluation")

In [None]:
# Baseline: eval loss of the un-adapted base model on the first 50 samples per task.
if eval_data:
    banner = "=" * 60
    print("\n" + banner)
    print("BASE MODEL (no LoRA)")
    print(banner)

    base_results = {}
    for task, samples in eval_data.items():
        base_results[task] = compute_base_eval_loss(base_model, samples[:50], tokenizer)
        print(f"  {task}: loss={base_results[task]:.4f}")

In [None]:
# Evaluate delta-guided model on downstream tasks.
# For each task: condition on the first dataset sample of that task, generate a LoRA
# with the best checkpoint, apply it via FunctionalLoRA, and score greedy generations
# (up to 100 samples) by whether the expected answer appears in the response.
# NOTE(review): integer labels (e.g. boolq 0/1) are matched as the strings "0"/"1".
if eval_data:
    print("\n" + "="*60)
    print("DELTA-GUIDED MODEL (LoRAGeneratorWithDeltaHead)")
    print("="*60)

    # Load best checkpoint (written by the training loop)
    best_checkpoint = Path(config.output_dir) / "checkpoint_best.pt"
    if best_checkpoint.exists():
        ckpt = torch.load(best_checkpoint, map_location=device)
        generator.load_state_dict(ckpt["generator_state_dict"])
        if "text_encoder_projection_state_dict" in ckpt:
            text_encoder.projection.load_state_dict(ckpt["text_encoder_projection_state_dict"])
        print(f"[OK] Loaded checkpoint: {best_checkpoint}")

    generator.eval()
    text_encoder.projection.eval()  # no projection dropout at inference
    delta_guided_results = {}

    for task, samples in eval_data.items():
        # Find adapter for this task
        task_indices = [i for i, s in enumerate(dataset.samples) if s.get("task") == task]
        if not task_indices:
            print(f"  {task}: [SKIP] No adapter found")
            continue

        sample = dataset[task_indices[0]]
        condition_ids = sample["condition_ids"].unsqueeze(0).to(device)
        attention_mask = sample["attention_mask"].unsqueeze(0).to(device)

        try:
            # Generate LoRA weights
            with torch.no_grad():
                lora_weights = generator.generate_lora(condition_ids, attention_mask)[0]

            # Apply LoRA and evaluate
            functional_lora.apply_lora_weights(lora_weights)

            correct = 0
            total = 0
            for s in samples[:100]:
                prompt = s.get("question", s.get("prompt", ""))
                expected = s.get("answer", s.get("label", ""))
                # Unscorable samples are skipped: an empty expected answer is a
                # substring of every response (it would always count as correct),
                # and an empty prompt leaves generate() nothing to condition on.
                if not prompt or expected is None or not str(expected).strip():
                    continue

                inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
                with torch.no_grad():
                    outputs = base_model.generate(
                        **inputs, max_new_tokens=32, do_sample=False, pad_token_id=tokenizer.pad_token_id
                    )
                response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

                if str(expected).lower() in response.lower():
                    correct += 1
                total += 1

            acc = correct / total if total > 0 else 0
            delta_guided_results[task] = acc
            print(f"  {task}: {acc:.2%} ({correct}/{total})")

        except Exception as e:
            # Best-effort per task: report and continue with the next task.
            print(f"  {task}: [ERROR] {e}")
        finally:
            # Always restore the un-adapted base model before the next task.
            functional_lora.remove_lora_weights()

    # Clean up
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

## Comparison with Phase 4.5 Results

In [None]:
# Load Phase 4.5 results for comparison
phase45_results_path = Path("outputs/phase4_5_ablations/ablation_results.json")

if phase45_results_path.exists():
    with open(phase45_results_path) as f:
        phase45_data = json.load(f)
    
    print("\n" + "="*60)
    print("Comparison: Phase 4.5 vs Phase 5")
    print("="*60)
    
    # Extract Phase 4.5 delta_only results
    p45_summary = phase45_data.get("summary", {})
    
    print("\nPhase 4.5 Ablations Summary:")
    for config_name, stats in p45_summary.items():
        loss_mean = stats.get("final_loss_mean", "N/A")
        loss_std = stats.get("final_loss_std", 0)
        cos_mean = stats.get("mean_cosine_mean", "N/A")
        cos_std = stats.get("mean_cosine_std", 0)
        
        if isinstance(loss_mean, float):
            print(f"  {config_name:12s}: loss={loss_mean:.4f}+/-{loss_std:.4f}, cos={cos_mean:.4f}+/-{cos_std:.4f}")
    
    print("\nPhase 5 (Extended Delta-Only):")
    print(f"  final_loss: {state.loss_history[-1]:.4f}")
    print(f"  mean_cosine: {eval_results.get('mean_cosine', 'N/A')}")
    
    # Compare delta_only specifically
    if "delta_only" in p45_summary:
        p45_delta = p45_summary["delta_only"]
        p5_loss = state.loss_history[-1] if state.loss_history else None
        p5_cosine = eval_results.get("mean_cosine")
        
        print("\nDelta-Only Comparison:")
        print(f"  Phase 4.5 (100 steps): loss={p45_delta.get('final_loss_mean', 'N/A'):.4f}")
        print(f"  Phase 5 ({config.num_steps} steps): loss={p5_loss:.4f}")
        if p5_cosine is not None:
            print(f"  Cosine improvement: {p5_cosine - p45_delta.get('mean_cosine_mean', 0):.4f}")
else:
    print("[Info] Phase 4.5 results not found. Run phase_4_5_ablations.ipynb first for comparison.")

## Sync to Google Drive

In [None]:
# Mirror the local output directory into Drive (Colab only), replacing any
# previous copy; elsewhere just report where the outputs live.
if not (IN_COLAB and DRIVE_OUTPUT_DIR):
    print(f"[Local] Outputs saved to {config.output_dir}")
else:
    drive_phase5_dir = f"{DRIVE_OUTPUT_DIR}/phase5_delta_only"
    if Path(drive_phase5_dir).exists():
        shutil.rmtree(drive_phase5_dir)
    shutil.copytree(config.output_dir, drive_phase5_dir)
    print(f"[Drive] Synced to {drive_phase5_dir}")

## Analysis: Delta-Guided Training Insights

Key advantages of the dual-head architecture:

1. **Fast training signal** - The delta predictor provides a cheap supervision signal that does not require applying the LoRA (the computed-delta term in the loss still does)
2. **No embedding averaging** - Uses attention over N embeddings (learns which prompts matter)
3. **Consistency constraint** - LoRA head must produce deltas matching predictions
4. **Per-embedding supervision** - Can supervise individual prompt→delta mappings

**Expected behavior:**
- `loss_pred` should decrease quickly (fast learning path)
- `loss_computed` should follow (LoRA produces correct behavior)
- `loss_consistency` should stay low (heads agree)

In [None]:
# Final run summary: architecture, loss trajectory end-points, delta-eval cosines
# and (if computed) per-task accuracy of the generated LoRAs.
summary_banner = "=" * 60
print("\n" + summary_banner)
print("Phase 5 Complete!")
print(summary_banner)
print("\nArchitecture: LoRAGeneratorWithDeltaHead")
print("Training mode: Delta-guided (3-part loss)")
print(f"Total steps: {len(loss_history)}")
print(f"Best loss: {best_loss:.6f}")

if loss_history:
    print("\nFinal losses:")
    final_loss_rows = (
        ("Total", loss_history),
        ("Predicted→Teacher", loss_pred_history),
        ("Computed→Teacher", loss_computed_history),
        ("Consistency", loss_consistency_history),
    )
    for row_label, history in final_loss_rows:
        print(f"  {row_label}: {history[-1]:.6f}")

print("\nEvaluation (cosine similarity):")
for row_label, suffix in (("Delta Predicted: ", "pred"), ("Delta Computed:  ", "computed")):
    cos_mean = eval_results[f"mean_cosine_{suffix}"]
    cos_std = eval_results[f"std_cosine_{suffix}"]
    print(f"  {row_label}{cos_mean:.4f} +/- {cos_std:.4f}")

if eval_data and delta_guided_results:
    print("\nTask Performance:")
    for task, acc in delta_guided_results.items():
        print(f"  {task}: {acc:.2%}")

print(f"\nOutputs: {config.output_dir}")