# Delta-W: Effective Weight Update Supervision

Supervise the generator on **DW = B @ A * scaling** instead of raw (A, B) matrices.

**Why?** Raw weight matching suffers from gauge ambiguity: many (A, B) pairs produce
the same effective update — for any invertible rank × rank matrix $G$, the pair $(BG,\, G^{-1}A)$
yields the same product $BA$. DW is the canonical representation: it is gauge-invariant and
cheap to compute (a single matrix multiply, no forward pass needed).

```
Loss = MSE(DW_pred, DW_teacher)
where DW = B @ A * (alpha / rank)
```

In [None]:
# Environment setup: detect Colab, mount Google Drive there, and choose the
# checkpoint / delta-cache directories for the current environment.
import sys
import os
import shutil

# True when running inside Google Colab (relies on google.colab already being
# imported by the Colab runtime).
IN_COLAB = 'google.colab' in sys.modules
# Drive directory that outputs are synced to in the last cell (Colab only).
DRIVE_OUTPUT_DIR = None

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    # IPython shell magic: only valid inside a notebook, not as plain Python.
    !pip install -q safetensors accelerate transformers peft sentence-transformers
    # Put MyDrive on the import path, presumably so the `llgbm` package stored
    # there can be imported — TODO confirm package location.
    sys.path.insert(0, '/content/drive/MyDrive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/llgbm/checkpoints'
    DELTAS_DIR = CHECKPOINT_DIR + '/deltas'
else:
    CHECKPOINT_DIR = './checkpoints'
    # NOTE(review): locally the delta cache is not nested under CHECKPOINT_DIR
    # (unlike the Colab branch) — confirm this layout is intended.
    DELTAS_DIR = './llgbm/deltas'

In [None]:
# Library imports for training, evaluation and the shared ablation setup.
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

from llgbm import (
    TrainingConfig,
    DeltaWLoss,
    MultiTaskLoss,
    create_generator,
    # NOTE(review): RealAdapterDataset and FunctionalLoRA are not referenced
    # directly below (the dataset and functional LoRA are taken from
    # setup_base_components' result); presumably kept for interactive use.
    RealAdapterDataset,
    FunctionalLoRA,
    train,
    evaluate,
)
from llgbm.ablations import setup_base_components, AblationConfig

# Sanity check: report the torch build and whether a GPU is visible.
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

In [None]:
# Training configuration. Supervision is purely in weight space (DeltaW), so the
# hidden-state delta term is switched off. Probes are still configured because
# the evaluation cell computes deltas with them.
config = TrainingConfig(
    # Base model and conditioning text encoder
    use_small_model=True,
    text_encoder_name="sentence-transformers/all-MiniLM-L6-v2",
    freeze_text_encoder=True,
    num_prompts_per_adapter=8,
    # Optimization schedule
    batch_size=4,
    gradient_accumulation_steps=2,
    num_steps=200,
    warmup_steps=20,
    learning_rate=2e-4,
    # Loss weights: weight-space only, no hidden-state delta needed
    lambda_weight=1.0,
    lambda_delta=0.0,
    # Probe settings
    num_probes=10,
    max_probe_length=256,
    delta_batch_probes=True,
    # Storage locations
    checkpoint_dir=CHECKPOINT_DIR,
    delta_cache_dir=DELTAS_DIR,
    output_dir="outputs/toy_delta_w",
)

# Map the config's dtype name onto the torch dtype (unknown names raise KeyError).
TORCH_DTYPE = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}[config.dtype]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Create the output directory and snapshot the config next to the results.
Path(config.output_dir).mkdir(parents=True, exist_ok=True)
config.save(f"{config.output_dir}/config.json")

print(f"Model: {config.base_model}")
print(f"Device: {device}")
print("Loss: DeltaW (MSE on B@A*scaling)")

In [None]:
# Build the shared components (base model, tokenizer, probes, dataset, text
# encoder, functional LoRA, ...) once. Hyperparameters are copied from `config`
# so the two configuration objects stay in agreement.
# NOTE(review): num_steps and the lambda_* weights are not forwarded — confirm
# AblationConfig does not need them.
ablation_config = AblationConfig(
    checkpoint_dir=CHECKPOINT_DIR,
    deltas_dir=DELTAS_DIR,
    output_dir=config.output_dir,
    **{
        field_name: getattr(config, field_name)
        for field_name in (
            "use_small_model",
            "batch_size",
            "gradient_accumulation_steps",
            "learning_rate",
            "warmup_steps",
            "num_probes",
            "max_probe_length",
            "delta_batch_probes",
            "text_encoder_name",
            "freeze_text_encoder",
            "num_prompts_per_adapter",
        )
    },
)
components = setup_base_components(ablation_config, config)

In [None]:
# Build the generator. The text encoder is not created here; it is reused
# from the shared components.
generator = create_generator(
    config,
    seed=42,
    device=device,
    text_encoder=components["text_encoder"],
)
_n_trainable = sum(p.numel() for p in generator.parameters() if p.requires_grad)
print(f"[OK] Generator: {_n_trainable:,} trainable params")

In [None]:
# Losses, optimizer and LR schedule.
# `criterion` is the multi-task loss passed to train(). Its weights are read
# from `config` so the two cannot drift apart. `weight_criterion` is the DeltaW
# loss: MSE on B @ A * (lora_alpha / lora_rank).
criterion = MultiTaskLoss(
    lambda_weight=config.lambda_weight,
    lambda_delta=config.lambda_delta,
)
weight_criterion = DeltaWLoss(
    lora_alpha=config.lora_alpha,
    lora_rank=config.lora_rank,
)

optimizer = AdamW(generator.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

# Linear warmup (capped at 10% of training), then cosine decay to 1% of the
# peak LR. Keep at least one warmup step: with zero warmup the SequentialLR
# milestone at 0 is never reached by step(), so the cosine phase would start
# from the 10% warmup LR instead of the peak LR.
warmup_steps = max(1, min(config.warmup_steps, config.num_steps // 10))
cosine_steps = max(1, config.num_steps - warmup_steps)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps),
        CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=config.learning_rate * 0.01),
    ],
    milestones=[warmup_steps],
)

print(f"[OK] DeltaWLoss(scaling={weight_criterion.scaling})")
print("[OK] Optimizer & Scheduler ready")

In [None]:
# Training dataloader. Pinned host memory only speeds up host->GPU copies, so
# it is enabled only when training on CUDA. PyTorch warns when pinning memory
# without an accelerator.
dataloader = DataLoader(
    components["dataset"],
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=components["dataset"].collate_fn,
    num_workers=2,
    pin_memory=device.type == "cuda",
)

# Run the training loop. The returned state carries the loss / grad-norm / LR
# histories plotted below. weight_criterion supplies the DeltaW supervision.
state = train(
    generator=generator,
    dataloader=dataloader,
    functional_lora=components["functional_lora"],
    base_activation=components["base_activation"],
    probe_tokens=components["probe_tokens"],
    probe_masks=components["probe_masks"],
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    config=config,
    compute_dtype=TORCH_DTYPE,
    weight_criterion=weight_criterion,
)

In [None]:
# Delta-based evaluation (delta cosine similarity).
# NOTE(review): this iterates the same dataset used for training, so these are
# in-sample numbers rather than held-out ones.
_eval_source = components["dataset"]
eval_dataloader = DataLoader(
    _eval_source,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=_eval_source.collate_fn,
)

# Training ran with lambda_delta=0. Evaluation needs the delta term, so it gets
# its own criterion that weights only that term.
eval_criterion = MultiTaskLoss(lambda_weight=0.0, lambda_delta=1.0)
eval_results = evaluate(
    generator=generator,
    dataloader=eval_dataloader,
    criterion=eval_criterion,
    functional_lora=components["functional_lora"],
    base_activation=components["base_activation"],
    probe_tokens=components["probe_tokens"],
    probe_masks=components["probe_masks"],
)

print("\nEvaluation Results:")
for metric_name, metric_value in eval_results.items():
    print(f"  {metric_name}: {metric_value:.4f}")

In [None]:
# Plot training curves: loss, gradient norm (with the clip threshold) and LR.
import matplotlib.pyplot as plt

if state.loss_history:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    panel_specs = [
        (state.loss_history, '#2c3e50', 'Loss', 'DeltaW Loss'),
        (state.grad_norm_history, '#9b59b6', 'Gradient Norm', 'Gradient Norms'),
        (state.lr_history, '#f39c12', 'Learning Rate', 'LR Schedule'),
    ]
    for ax, (history, line_color, y_label, panel_title) in zip(axes, panel_specs):
        ax.plot(history, color=line_color, linewidth=1.5)
        ax.set_xlabel('Step')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)
        ax.grid(True, alpha=0.3)

    # Mark the gradient-clipping threshold on the gradient-norm panel.
    grad_ax = axes[1]
    grad_ax.axhline(config.max_grad_norm, color='r', ls='--', label=f'clip={config.max_grad_norm}')
    grad_ax.legend()

    plt.tight_layout()
    plt.savefig(f"{config.output_dir}/training_curves.png", dpi=150)
    plt.show()

## Task Accuracy Evaluation

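For each task, generate LoRA weights conditioned on that task's first training adapter, then compare downstream accuracy against the base model (no LoRA) on the task's held-out eval set (`<task>_eval.json`).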

In [None]:
import json
import gc
from llgbm.evaluation import compute_accuracy_with_lora_batched, compute_base_accuracy

# Evaluate exactly the tasks the adapter dataset contains, i.e. what was trained on.
dataset = components["dataset"]
tokenizer = components["tokenizer"]
adapter_tasks = sorted({sample["task"] for sample in dataset.samples})
print(f"Adapter tasks: {adapter_tasks}")

# Eval files are looked up in ./data locally and on Drive under Colab.
DATA_DIR = Path("/content/drive/MyDrive/llgbm/data") if IN_COLAB else Path("data")

# Task-name registries used to pick the answer format for accuracy evaluation.
# Entries must be in normalized form: lower-case, with "-" and "_" removed.
KNOWN_BOOL_TASKS = {"boolq"}
KNOWN_NUMERIC_TASKS = {"gsm8k"}


def infer_task_type(task_name):
    """Return the evaluation task type for ``task_name``.

    The name is normalized (lower-cased, "-" and "_" removed) and looked up in
    KNOWN_BOOL_TASKS and KNOWN_NUMERIC_TASKS. Any other name is treated as
    multiple choice.

    Args:
        task_name: Task name as stored in the adapter dataset, e.g. "arc_e".

    Returns:
        "bool" for boolean-answer tasks, "gsm8k" for numeric-answer tasks,
        "mcq" otherwise.
    """
    normalized = task_name.lower().replace("-", "").replace("_", "")
    if normalized in KNOWN_BOOL_TASKS:
        return "bool"
    if normalized in KNOWN_NUMERIC_TASKS:
        return "gsm8k"
    return "mcq"

# Find a "<task>_eval.json" file in DATA_DIR for every adapter task. Task names
# and file names may differ in "-" vs "_" (e.g. "arc_e" -> "arc-e_eval.json"),
# so several spellings are tried in order.
eval_data = {}
if DATA_DIR.exists():
    eval_files = {path.stem.lower(): path for path in DATA_DIR.glob("*_eval.json")}
    print(f"Available eval files: {list(eval_files.keys())}")

    for task in adapter_tasks:
        base_name = task.lower()
        # dict.fromkeys de-duplicates the spellings while keeping their order.
        variants = dict.fromkeys(
            name + "_eval"
            for name in (base_name, base_name.replace("_", "-"), base_name.replace("-", "_"))
        )
        found = next((v for v in variants if v in eval_files), None)
        if found is None:
            print(f"  {task}: [no eval file found]")
            continue
        # JSON is UTF-8 by specification; don't rely on the platform's default
        # text encoding (e.g. cp1252 on Windows).
        with open(eval_files[found], encoding="utf-8") as fh:
            samples = json.load(fh)
        task_type = infer_task_type(task)
        eval_data[task] = (samples, task_type)
        print(f"  {task}: {len(samples)} samples ({task_type})")

if not eval_data:
    print("[SKIP] No eval data found for any adapter task")

In [None]:
# Compute task accuracy: base model (no LoRA) vs. base model + generated LoRA,
# on at most MAX_EVAL_SAMPLES eval samples per task.
MAX_EVAL_SAMPLES = 100

accuracy_results = {}

if eval_data:
    generator.eval()

    # Base model accuracy
    print("=" * 60)
    print("BASE MODEL (no LoRA)")
    print("=" * 60)
    base_accuracy = {}
    for task, (samples, task_type) in eval_data.items():
        res = compute_base_accuracy(
            base_model=components["base_model"],
            eval_samples=samples,
            tokenizer=tokenizer,
            task_type=task_type,
            max_samples=MAX_EVAL_SAMPLES,
            show_progress=False,
        )
        base_accuracy[task] = res
        print(f"  {task}: {res['accuracy']:.2%} ({res['correct']}/{res['total']})")

    # Generated LoRA accuracy
    print("\n" + "=" * 60)
    print("GENERATED LoRA (delta_w)")
    print("=" * 60)
    gen_accuracy = {}
    for task, (samples, task_type) in eval_data.items():
        # Condition the generator on the first training adapter of this task.
        # A match always exists: eval_data keys come from adapter_tasks, which
        # are drawn from dataset.samples.
        task_idx = next(i for i, s in enumerate(dataset.samples) if s["task"] == task)
        sample = dataset[task_idx]
        condition_ids = sample["condition_ids"]
        attention_mask = sample["attention_mask"]

        res = compute_accuracy_with_lora_batched(
            generator=generator,
            functional_lora=components["functional_lora"],
            condition_ids=condition_ids,
            attention_mask=attention_mask,
            eval_samples=samples,
            tokenizer=tokenizer,
            task_type=task_type,
            max_samples=MAX_EVAL_SAMPLES,
            batch_size=8,
            show_progress=False,
        )
        gen_accuracy[task] = res
        print(f"  {task}: {res['accuracy']:.2%} ({res['correct']}/{res['total']})")

        # Free per-task generation memory before the next task.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Summary table. The header and rows share the same column widths so they
    # line up. The "+" sign flag in the format spec keeps the sign next to the
    # digits instead of in front of the padding.
    header = f"{'Task':<14} {'Base':>8} {'Generated':>10} {'Delta':>8}"
    print("\n" + "=" * 60)
    print(header)
    print("-" * len(header))
    for task in eval_data:
        base_acc = base_accuracy[task]["accuracy"]
        gen_acc = gen_accuracy.get(task, {}).get("accuracy", 0)
        delta = gen_acc - base_acc
        print(f"{task:<14} {base_acc:>8.1%} {gen_acc:>10.1%} {delta:>+8.1%}")
    print("=" * 60)

    accuracy_results = {
        "base": dict(base_accuracy),
        "generated": dict(gen_accuracy),
    }

In [None]:
# Persist the config, a training summary and all evaluation metrics as JSON.
import json
from dataclasses import asdict

training_summary = {
    "steps": len(state.loss_history),
    "best_loss": state.best_loss,
    "final_loss": state.loss_history[-1] if state.loss_history else None,
}
results = dict(
    config=asdict(config),
    training=training_summary,
    eval=eval_results,
    accuracy=accuracy_results,
    mode="delta_w",
)

results_path = f"{config.output_dir}/results.json"
with open(results_path, "w") as fh:
    json.dump(results, fh, indent=2)

print(f"Saved to {config.output_dir}/")

In [None]:
# Mirror the local output directory to Google Drive (Colab only), presumably so
# results persist beyond the Colab runtime. Any previous copy on Drive is
# replaced wholesale.
if not (IN_COLAB and DRIVE_OUTPUT_DIR):
    print(f"[Local] Outputs saved to {config.output_dir}")
else:
    drive_dir = f"{DRIVE_OUTPUT_DIR}/toy_delta_w"
    if os.path.exists(drive_dir):
        shutil.rmtree(drive_dir)
    shutil.copytree(config.output_dir, drive_dir)
    print(f"[Drive] Synced to {drive_dir}")