# Phase 4: Multi-Task Training (Weights + Deltas)

Train a LoRA generator with both weight supervision and behavioral (delta) supervision.

**Core Equation:** `Loss = λ_w * L_weight + λ_d * L_delta`

In [None]:
import sys
import os
import shutil

# Detect Google Colab by checking whether its package has already been imported.
IN_COLAB = 'google.colab' in sys.modules
# Stays None outside Colab; the final sync cell checks it before copying to Drive.
DRIVE_OUTPUT_DIR = None

if IN_COLAB:
    from google.colab import drive
    # Mount Drive first: the output directory below lives under MyDrive.
    drive.mount('/content/drive')
    DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/llgbm/outputs'
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    # IPython shell escape (not plain Python): quietly installs the runtime dependencies.
    !pip install -q safetensors accelerate transformers peft
    # Put MyDrive on the import path; presumably the llgbm package is stored there — TODO confirm.
    sys.path.insert(0, '/content/drive/MyDrive')

import json
import gc
from pathlib import Path
from dataclasses import asdict

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from transformers import AutoModelForCausalLM, AutoTokenizer

print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

In [None]:
# Import llgbm modules
from llgbm import (
    create_generic_probes,
    DeltaCache,
    FunctionalLoRA,
    TrainingConfig,
    MultiTaskLoss,
    train,
    evaluate,
)

print("[OK] llgbm imports")

## Configuration

In [None]:
# Run configuration. Fields read later but not passed here (base_model, dtype,
# num_probes, ...) are supplied by TrainingConfig itself, which is defined in
# llgbm and not shown in this notebook.
config = TrainingConfig(
    use_small_model=True,  # Qwen2.5-0.5B for testing
    batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=100,  # Short run for testing
    warmup_steps=10,  # Must be < max_steps
    lambda_delta=0.1,  # weight of L_delta in the total loss
    lambda_weight=1.0,  # weight of L_weight in the total loss
    output_dir="outputs/phase4_multitask",
)

# Translate the dtype name held in config.dtype into a torch dtype; a name that
# is not one of the three keys raises KeyError.
TORCH_DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[config.dtype]
# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create the output directory up front so the config snapshot below can be written.
Path(config.output_dir).mkdir(parents=True, exist_ok=True)
config.save(f"{config.output_dir}/config.json")

print(f"Model: {config.base_model}")
print(f"Loss: {config.lambda_weight}*L_w + {config.lambda_delta}*L_d")

## Load Base Model & Prepare Probes

In [None]:
# Tokenizer of the base model; fall back to EOS as the padding token when the
# checkpoint does not define one.
tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model,
    torch_dtype=TORCH_DTYPE,
    device_map=device,
    trust_remote_code=True,
)
# Hidden states are needed to read activations from the last layer.
base_model.config.output_hidden_states = True
# The base model is frozen: turn off gradients for every parameter.
base_model.requires_grad_(False)

print(f"[OK] Base model: {sum(p.numel() for p in base_model.parameters()):,} params")

In [None]:
# Tokenize probes
probes = create_generic_probes()[:config.num_probes]
probe_tokens, probe_masks = [], []
for probe_text in probes:
    encoded = tokenizer(
        probe_text,
        return_tensors="pt",
        truncation=True,
        max_length=config.max_probe_length,
    )
    probe_tokens.append(encoded["input_ids"].to(device))
    probe_masks.append(encoded["attention_mask"].to(device))

# Compute base activation
base_acts = []
with torch.no_grad():
    for ids, mask in zip(probe_tokens, probe_masks):
        out = base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        # Index of the last attended token; assumes each probe is encoded as a
        # batch of one sequence, so mask.sum() is its token count.
        last_pos = int(mask.sum()) - 1
        final_layer = out.hidden_states[-1]
        base_acts.append(final_layer[:, last_pos, :].squeeze(0))
    # Average the per-probe last-token activations into a single vector.
    base_activation = torch.stack(base_acts).mean(dim=0)

print(f"[OK] Base activation: {base_activation.shape}, norm={base_activation.norm():.4f}")

In [None]:
# Wrap the frozen base model so LoRA weights can be applied to it; rank and
# alpha come from the run configuration. FunctionalLoRA is defined in llgbm.
functional_lora = FunctionalLoRA(
    base_model=base_model,
    lora_rank=config.lora_rank,
    lora_alpha=config.lora_alpha,
)
# NOTE(review): reads a private attribute just to report how many LoRA-to-base
# mappings were built; this may break if llgbm renames it.
print(f"[OK] FunctionalLoRA: {len(functional_lora._lora_to_base_map)} mappings")

## Placeholder Generator

In [None]:
class PlaceholderGenerator(nn.Module):
    """Stand-in LoRA generator conditioned on tokenized text.

    A small Transformer encoder pools the condition tokens into one vector and
    a linear head maps it to two scalars (an A-scale and a B-scale) for every
    adapted projection. Each scalar multiplies a freshly sampled random
    matrix, so the scales are the only learned quantity.
    """

    # Adapted projections in slot order, grouped by the sub-module owning them.
    _ATTN_PROJS = ("q_proj", "k_proj", "v_proj", "o_proj")
    _MLP_PROJS = ("gate_proj", "up_proj", "down_proj")

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(50000, 256)
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        # 7 modules per layer, 2 scales (A, B) per module.
        self.proj = nn.Linear(256, cfg.num_layers * 7 * 2)

    def _projection_specs(self):
        """Return ``(module, proj, in_dim, out_dim)`` for one layer's projections."""
        cfg = self.cfg
        hidden = cfg.hidden_size
        inter = cfg.intermediate_size
        kv_dim = cfg.num_kv_heads * (hidden // cfg.num_heads)
        # Projections whose output width differs from hidden_size.
        out_dims = {"k_proj": kv_dim, "v_proj": kv_dim, "gate_proj": inter, "up_proj": inter}
        specs = []
        for module, names in (("self_attn", self._ATTN_PROJS), ("mlp", self._MLP_PROJS)):
            for name in names:
                in_dim = inter if name == "down_proj" else hidden
                specs.append((module, name, in_dim, out_dims.get(name, hidden)))
        return specs

    def forward(self, condition_ids, attention_mask=None):
        """Return one ``{parameter name: tensor}`` dict per batch element.

        Args:
            condition_ids: token ids indexed along dim 0 by batch element.
            attention_mask: optional mask, non-zero for real tokens; when
                given, padded positions are ignored by the encoder and by
                the mean pooling.
        """
        batch_size = condition_ids.shape[0]
        device = condition_ids.device
        rank = self.cfg.lora_rank

        hidden = self.embed(condition_ids)
        if attention_mask is None:
            hidden = self.encoder(hidden, src_key_padding_mask=None)
            pooled = hidden.mean(1)
        else:
            # The encoder expects True at positions to ignore.
            hidden = self.encoder(hidden, src_key_padding_mask=~attention_mask.bool())
            token_counts = attention_mask.sum(1, keepdim=True).clamp(min=1)
            pooled = (hidden * attention_mask.unsqueeze(-1)).sum(1) / token_counts

        specs = self._projection_specs()
        scales = self.proj(pooled).view(batch_size, self.cfg.num_layers * len(specs), 2)

        # Generate LoRA weights
        batch_weights = []
        for sample_scales in scales:
            weights = {}
            for layer in range(self.cfg.num_layers):
                for offset, (module, name, in_dim, out_dim) in enumerate(specs):
                    slot = layer * len(specs) + offset
                    scale_a = sample_scales[slot, 0]
                    scale_b = sample_scales[slot, 1]
                    key = f"model.layers.{layer}.{module}.{name}"
                    weights[f"{key}.lora_A.weight"] = (
                        torch.randn(rank, in_dim, device=device) * 0.01 * scale_a
                    )
                    weights[f"{key}.lora_B.weight"] = (
                        torch.randn(out_dim, rank, device=device) * 0.001 * scale_b
                    )
            batch_weights.append(weights)
        return batch_weights

# Build the placeholder generator from the run configuration and move it to the
# selected device; the printed count covers trainable parameters only.
generator = PlaceholderGenerator(config).to(device)
print(f"[OK] Generator: {sum(p.numel() for p in generator.parameters() if p.requires_grad):,} params")

## Dataset

In [None]:
from safetensors.torch import save_file

# Create sample data if needed
import zlib  # local import: used only to derive stable per-sample seeds below

checkpoint_dir = Path(config.checkpoint_dir)
delta_cache = DeltaCache(config.delta_cache_dir)

# Regenerate the toy data when no adapter checkpoint exists or the delta cache
# is empty. any() stops at the first match instead of listing the whole tree.
if not any(checkpoint_dir.rglob("adapter_config.json")) or delta_cache.summary().get('count', 0) == 0:
    print("Creating sample data...")
    for name, domain in [("math_001", "math"), ("code_001", "code"), ("general_001", "general")]:
        path = checkpoint_dir / domain / name
        path.mkdir(parents=True, exist_ok=True)

        # Minimal adapter
        # Seed from a CRC of the name: the built-in hash() of a str is salted
        # per process, so it would yield different sample data on every run.
        torch.manual_seed(zlib.crc32(name.encode("utf-8")) % 10000)
        weights = {"layer.0.lora_A": torch.randn(8, 256) * 0.01}
        save_file(weights, path / "adapter_model.safetensors")
        # Context managers guarantee the JSON files are flushed and closed
        # before the dataset cell reads them back.
        with open(path / "adapter_config.json", "w") as f:
            json.dump({"r": 8, "peft_type": "LORA"}, f)
        with open(path / "prompts.json", "w") as f:
            json.dump({"prompts": [f"Solve {domain} problem"]}, f)
        # Random teacher delta, keyed by the checkpoint path string.
        delta_cache.save_delta(str(path), torch.randn(config.hidden_size).numpy() * 0.1)
    delta_cache.save_base_activation(np.zeros(config.hidden_size), {})

In [None]:
class SimpleDataset(Dataset):
    """Pairs each adapter checkpoint with its tokenized prompt and teacher delta.

    A checkpoint is any directory under ``checkpoint_dir`` that contains an
    ``adapter_config.json`` and whose path string is a key of
    ``delta_cache.get_all_deltas()``.

    Args:
        checkpoint_dir: root directory searched recursively for checkpoints.
        delta_cache: object exposing ``get_all_deltas()``, a mapping from
            checkpoint path string to a NumPy delta vector.
        tokenizer: callable used to tokenize the condition text.
        hidden_size: stored on the instance; not used by the dataset itself.
    """

    def __init__(self, checkpoint_dir, delta_cache, tokenizer, hidden_size):
        all_deltas = delta_cache.get_all_deltas()
        # Sorted so that sample indices do not depend on filesystem traversal order.
        self.samples = sorted(
            str(p.parent) for p in Path(checkpoint_dir).rglob("adapter_config.json")
            if str(p.parent) in all_deltas
        )
        self.deltas = all_deltas
        self.tokenizer = tokenizer
        self.hidden_size = hidden_size

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _condition_text(path):
        """Return the first prompt from ``prompts.json``, else the directory name."""
        fallback = Path(path).name
        prompts_file = Path(path) / "prompts.json"
        if not prompts_file.exists():
            return fallback
        with open(prompts_file, encoding="utf-8") as f:
            prompts = json.load(f).get("prompts")
        # A missing, null or empty prompt list falls back to the directory name
        # rather than raising on the [0] lookup.
        return prompts[0] if prompts else fallback

    def __getitem__(self, idx):
        """Return ``condition_ids``, ``attention_mask`` and ``delta_teacher`` for one sample."""
        path = self.samples[idx]
        text = self._condition_text(path)
        enc = self.tokenizer(text, max_length=256, padding="max_length", truncation=True, return_tensors="pt")
        return {
            # The tokenizer returns a leading batch dimension of 1; drop it.
            "condition_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "delta_teacher": torch.from_numpy(self.deltas[path]).float(),
        }

    @staticmethod
    def collate_fn(batch):
        """Stack every field of the per-sample dicts along a new batch dimension."""
        return {k: torch.stack([b[k] for b in batch]) for k in batch[0]}

# Conditions are tokenized with bert-base-uncased, a tokenizer separate from the
# base model's. NOTE(review): its token ids must stay below the generator's
# embedding table size (50000 in PlaceholderGenerator) — confirm if either changes.
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = SimpleDataset(config.checkpoint_dir, delta_cache, text_tokenizer, config.hidden_size)
# Shuffled batches; collate_fn stacks each field of the per-sample dicts.
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=dataset.collate_fn)
print(f"[OK] Dataset: {len(dataset)} samples, {len(dataloader)} batches")

## Training

In [None]:
criterion = MultiTaskLoss(
    lambda_weight=config.lambda_weight,
    lambda_delta=config.lambda_delta,
)

optimizer = AdamW(
    generator.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)

# Ensure T_max >= 1 to avoid division by zero
cosine_steps = max(1, config.max_steps - config.warmup_steps)

# Linear warmup from 10% of the learning rate to the full rate, then cosine
# decay towards 1% of it. The warmup scheduler is built first, as before.
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=config.warmup_steps,
)
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=cosine_steps,
    eta_min=config.learning_rate * 0.01,
)
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[config.warmup_steps],
)

print(f"[OK] Optimizer & Scheduler ready (warmup={config.warmup_steps}, cosine={cosine_steps})")

In [None]:
# Train
# Hands every component built above to llgbm's train(); the loop itself is
# defined in llgbm and not shown in this notebook.
print(f"\nTraining: {config.max_steps} steps, batch={config.batch_size}x{config.gradient_accumulation_steps}")

state = train(
    generator=generator,
    dataloader=dataloader,
    functional_lora=functional_lora,
    base_activation=base_activation,
    probe_tokens=probe_tokens,
    probe_masks=probe_masks,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    config=config,
    compute_dtype=TORCH_DTYPE,
)

# The returned state exposes at least step and best_loss; later cells also read
# its loss_history, loss_delta_history and grad_norm_history.
print(f"\nDone! Steps: {state.step}, Best loss: {state.best_loss:.6f}")

In [None]:
# Evaluate
# NOTE(review): this reuses the training dataloader, so the reported metrics
# are in-sample — switch to a held-out loader for a real evaluation.
eval_results = evaluate(
    generator=generator,
    dataloader=dataloader,
    functional_lora=functional_lora,
    base_activation=base_activation,
    probe_tokens=probe_tokens,
    probe_masks=probe_masks,
    criterion=criterion,
)
# eval_results is iterated as a mapping whose values accept the ".4f" format spec.
print("Evaluation:", {k: f"{v:.4f}" for k, v in eval_results.items()})

In [None]:
# Plot
if state.loss_history:
    fig, (loss_ax, grad_ax) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: total loss with the delta component overlaid.
    loss_ax.plot(state.loss_history, label='Total')
    loss_ax.plot(state.loss_delta_history, label='Delta', alpha=0.7)
    loss_ax.set_xlabel('Step')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Right panel: gradient norm, with the configured maximum as a dashed red line.
    grad_ax.plot(state.grad_norm_history)
    grad_ax.axhline(config.max_grad_norm, color='r', ls='--')
    grad_ax.set_xlabel('Step')
    grad_ax.set_ylabel('Grad Norm')

    plt.tight_layout()
    plt.savefig(f"{config.output_dir}/curves.png", dpi=100)
    plt.show()

In [None]:
# Save results
# Bundle the configuration, the training summary and the evaluation metrics
# into one JSON document in the output directory.
results = {
    "config": asdict(config),
    "training": {"steps": state.step, "best_loss": state.best_loss},
    "eval": eval_results,
}
# Close the file explicitly so it is fully flushed before the next cell copies
# the output directory to Drive.
with open(f"{config.output_dir}/results.json", "w") as f:
    json.dump(results, f, indent=2)
print(f"Saved to {config.output_dir}/")

In [None]:
# Sync to Google Drive (Colab only)
if not (IN_COLAB and DRIVE_OUTPUT_DIR):
    print("[Local] Outputs saved to", config.output_dir)
else:
    drive_phase4_dir = f"{DRIVE_OUTPUT_DIR}/phase4_multitask"
    # copytree refuses to write into an existing directory, so remove any copy
    # left by a previous run before mirroring the fresh outputs.
    if os.path.exists(drive_phase4_dir):
        shutil.rmtree(drive_phase4_dir)
    shutil.copytree(config.output_dir, drive_phase4_dir)
    print(f"[Drive] Synced to {drive_phase4_dir}")