# Baseline: Raw Weight Supervision

The simplest baseline: MSE between raw (A, B) matrices.

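**Note:** This suffers from gauge ambiguity — many (A, B) pairs produce the same
effective weight update ΔW = B @ A * scaling (e.g. A → R·A, B → B·R⁻¹ for any invertible R),
so this loss can penalize predictions that are functionally identical to the teacher.
Compare with `toy_delta_w.ipynb`, which supervises on ΔW directly.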

```
Loss = MSE(A_pred, A_teacher) + MSE(B_pred, B_teacher)
```

In [None]:
import sys
import os
import shutil

# Environment detection: Colab counts as "present" if its module is already loaded
# in this interpreter; nothing is imported just to check.
IN_COLAB = 'google.colab' in sys.modules
# Drive mirror for outputs; stays None outside Colab, which disables the Drive sync at the end.
DRIVE_OUTPUT_DIR = None

if IN_COLAB:
    # Mount Google Drive at /content/drive and create the outputs mirror directory.
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    # IPython shell escape: only valid inside a notebook, not in a plain .py script.
    !pip install -q safetensors accelerate transformers peft sentence-transformers
    # Put the Drive root on the import path (presumably so the `llgbm` package
    # stored there can be imported — verify).
    sys.path.insert(0, '/content/drive/MyDrive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/llgbm/checkpoints'
    DELTAS_DIR = CHECKPOINT_DIR + '/deltas'
else:
    CHECKPOINT_DIR = './checkpoints'
    # NOTE(review): locally the delta cache lives outside CHECKPOINT_DIR, unlike the
    # Colab layout above — confirm this asymmetry is intended.
    DELTAS_DIR = './llgbm/deltas'

In [None]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

from llgbm import (
    TrainingConfig,
    WeightLoss,
    MultiTaskLoss,
    create_generator,
    RealAdapterDataset,
    FunctionalLoRA,
    train,
    evaluate,
)
from llgbm.ablations import setup_base_components, AblationConfig

# Report the installed PyTorch version and whether a CUDA device is visible.
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

In [None]:
# Training configuration
config = TrainingConfig(
    use_small_model=True,
    batch_size=4,
    gradient_accumulation_steps=2,
    num_steps=200,
    warmup_steps=20,
    learning_rate=2e-4,
    lambda_weight=1.0,
    lambda_delta=0.0,  # No hidden-state delta needed
    # Probe settings are still configured: the evaluation cell measures
    # hidden-state deltas on these probes even though training does not.
    num_probes=10,
    max_probe_length=256,
    delta_batch_probes=True,
    checkpoint_dir=CHECKPOINT_DIR,
    delta_cache_dir=DELTAS_DIR,
    output_dir="outputs/toy_baseline",
    text_encoder_name="sentence-transformers/all-MiniLM-L6-v2",
    freeze_text_encoder=True,
    num_prompts_per_adapter=8,
)

# Map the config's dtype name to a torch dtype. Fail with a readable message
# (instead of a bare KeyError) when the name is not one of the supported ones.
_SUPPORTED_DTYPES = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
if config.dtype not in _SUPPORTED_DTYPES:
    raise ValueError(
        f"Unsupported config.dtype {config.dtype!r}; expected one of {sorted(_SUPPORTED_DTYPES)}"
    )
TORCH_DTYPE = _SUPPORTED_DTYPES[config.dtype]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Path(config.output_dir).mkdir(parents=True, exist_ok=True)
# Persist the exact configuration alongside the run's outputs.
config.save(f"{config.output_dir}/config.json")

print(f"Model: {config.base_model}")
print(f"Device: {device}")
print("Loss: WeightLoss (raw A, B MSE)")

In [None]:
# Build the shared pieces (base model, probes, dataset, text encoder, ...).
# Paths come from the environment cell; every other ablation setting is copied
# verbatim from the TrainingConfig attribute of the same name.
_MIRRORED_CONFIG_FIELDS = (
    "use_small_model",
    "batch_size",
    "gradient_accumulation_steps",
    "learning_rate",
    "warmup_steps",
    "num_probes",
    "max_probe_length",
    "delta_batch_probes",
    "text_encoder_name",
    "freeze_text_encoder",
    "num_prompts_per_adapter",
)
ablation_config = AblationConfig(
    checkpoint_dir=CHECKPOINT_DIR,
    deltas_dir=DELTAS_DIR,
    output_dir=config.output_dir,
    **{field_name: getattr(config, field_name) for field_name in _MIRRORED_CONFIG_FIELDS},
)
components = setup_base_components(ablation_config, config)

In [None]:
# Build the generator on the target device, reusing the text encoder that
# setup_base_components already prepared.
generator = create_generator(
    config,
    device=device,
    seed=42,
    text_encoder=components["text_encoder"],
)
n_trainable_params = sum(param.numel() for param in generator.parameters() if param.requires_grad)
print(f"[OK] Generator: {n_trainable_params:,} trainable params")

In [None]:
# Create WeightLoss + optimizer + scheduler
# Loss weights are read from the config (instead of repeating literals) so the
# criterion cannot drift from what config.json records for this run.
criterion = MultiTaskLoss(lambda_weight=config.lambda_weight, lambda_delta=config.lambda_delta)
weight_criterion = WeightLoss()

optimizer = AdamW(generator.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

# Warmup is capped at 10% of the run; cosine decay covers the remainder.
warmup_steps = min(config.warmup_steps, config.num_steps // 10)
cosine_steps = max(1, config.num_steps - warmup_steps)
eta_min = config.learning_rate * 0.01
if warmup_steps > 0:
    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps),
            CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=eta_min),
        ],
        milestones=[warmup_steps],
    )
else:
    # Very short runs (num_steps < 10) round the warmup down to 0. A zero-length
    # LinearLR would leave the LR at start_factor (10%), and the chained cosine
    # would then decay from that reduced value for the whole run, so skip warmup.
    scheduler = CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=eta_min)

print("[OK] WeightLoss (raw A, B MSE)")
print("[OK] Optimizer & Scheduler ready")

In [None]:
# Dataloader
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# PyTorch warns that pin_memory has no accelerator to use, so enable it only
# when training on CUDA.
dataloader = DataLoader(
    components["dataset"],
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=components["dataset"].collate_fn,
    num_workers=2,
    pin_memory=device.type == "cuda",
)

# Train: `criterion` runs with lambda_delta=0, and the raw (A, B) supervision is
# supplied through `weight_criterion`. Returns a state object whose histories are
# used by the plotting and results cells below.
state = train(
    generator=generator,
    dataloader=dataloader,
    functional_lora=components["functional_lora"],
    base_activation=components["base_activation"],
    probe_tokens=components["probe_tokens"],
    probe_masks=components["probe_masks"],
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    config=config,
    compute_dtype=TORCH_DTYPE,
    weight_criterion=weight_criterion,
)

In [None]:
# Evaluate (delta cosine similarity)
# NOTE: this iterates the same dataset used for training (unshuffled), so the
# numbers measure fit on the training adapters, not generalization.
eval_dataloader = DataLoader(
    components["dataset"],
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=components["dataset"].collate_fn,
)

# Use MultiTaskLoss with lambda_delta=1 for eval (needs delta computation)
eval_criterion = MultiTaskLoss(lambda_weight=0.0, lambda_delta=1.0)
eval_results = evaluate(
    generator=generator,
    dataloader=eval_dataloader,
    functional_lora=components["functional_lora"],
    base_activation=components["base_activation"],
    probe_tokens=components["probe_tokens"],
    probe_masks=components["probe_masks"],
    criterion=eval_criterion,
)

print("\nEvaluation Results:")
for k, v in eval_results.items():
    # Scalars (floats, ints, 0-dim tensors, numpy scalars) get fixed precision;
    # anything that cannot be formatted that way is shown as-is, not a crash.
    try:
        shown = f"{v:.4f}"
    except (TypeError, ValueError):
        shown = repr(v)
    print(f"  {k}: {shown}")

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

if state.loss_history:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # One (series, colour, y-label, title) spec per panel, left to right.
    panel_specs = [
        (state.loss_history, '#2c3e50', 'Loss', 'Weight Loss (raw A, B)'),
        (state.grad_norm_history, '#9b59b6', 'Gradient Norm', 'Gradient Norms'),
        (state.lr_history, '#f39c12', 'Learning Rate', 'LR Schedule'),
    ]
    for ax, (series, colour, y_label, title) in zip(axes, panel_specs):
        ax.plot(series, color=colour, linewidth=1.5)
        ax.set_xlabel('Step')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # The gradient-norm panel also shows the clipping threshold.
    grad_ax = axes[1]
    grad_ax.axhline(config.max_grad_norm, color='r', ls='--', label=f'clip={config.max_grad_norm}')
    grad_ax.legend()

    plt.tight_layout()
    plt.savefig(f"{config.output_dir}/training_curves.png", dpi=150)
    plt.show()

In [None]:
# Save results
import json
from dataclasses import asdict

def _jsonable(value):
    """Return *value* converted to plain Python types that ``json`` can encode.

    Tensors and numpy scalars/arrays expose ``tolist()`` (a 0-dim tensor or
    numpy scalar yields a Python number); other values are returned unchanged.
    """
    return value.tolist() if hasattr(value, "tolist") else value


results = {
    "config": asdict(config),
    "training": {
        "steps": len(state.loss_history),
        "best_loss": _jsonable(state.best_loss),
        "final_loss": _jsonable(state.loss_history[-1]) if state.loss_history else None,
    },
    # Metric values may be tensors/numpy scalars, which json.dump rejects.
    "eval": {name: _jsonable(metric) for name, metric in eval_results.items()},
    "mode": "weight_only",
}

with open(f"{config.output_dir}/results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"Saved to {config.output_dir}/")

In [None]:
# Sync to Google Drive (Colab only).
# Any previous Drive copy is deleted first, so the mirror reflects exactly this
# run's output directory.
if not (IN_COLAB and DRIVE_OUTPUT_DIR):
    print(f"[Local] Outputs saved to {config.output_dir}")
else:
    drive_dir = f"{DRIVE_OUTPUT_DIR}/toy_baseline"
    if os.path.exists(drive_dir):
        shutil.rmtree(drive_dir)
    shutil.copytree(config.output_dir, drive_dir)
    print(f"[Drive] Synced to {drive_dir}")