# Train LoRA Adapters with SFTTrainer

Simplified adapter training using HuggingFace's `SFTTrainer` from `trl`.

**Improvements over manual training:**
- Built-in evaluation during training
- Automatic gradient checkpointing
- Cleaner configuration via `SFTConfig`
- Better memory management

**Tasks:** ARC-e, BoolQ, GSM8K  
**Output:** LoRA adapters + deltas + eval metrics

In [None]:
import gc
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path

import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Choose the compute device once and report the runtime environment.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"PyTorch: {torch.__version__}\nDevice: {device}")
if device == "cuda":
    # GPU details are only queried when CUDA is actually usable.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Resolve data/output locations: Google Drive when running in Colab,
# paths relative to the working directory otherwise.
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    DATA_DIR, OUTPUT_DIR = 'data', 'checkpoints'
else:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_ROOT = '/content/drive/MyDrive/llgbm'
    DATA_DIR = f'{DRIVE_ROOT}/data'
    OUTPUT_DIR = f'{DRIVE_ROOT}/checkpoints'
    # Put the Drive project root first on the import path.
    sys.path.insert(0, DRIVE_ROOT)

# The output directory must exist before any checkpoint is written.
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Data: {DATA_DIR}\nOutput: {OUTPUT_DIR}")

## Configuration

In [None]:
@dataclass
class Config:
    """Hyperparameters shared by every adapter trained in this notebook."""

    # --- Base model: the Instruct variant, matching the ablation pipeline ---
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"

    # --- LoRA adapter settings (forwarded to peft.LoraConfig) ---
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # Names of the modules that receive LoRA weights.
    target_modules: tuple = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    )

    # --- Training settings (forwarded to SFTConfig) ---
    num_epochs: int = 2
    batch_size: int = 4
    learning_rate: float = 2e-4
    max_length: int = 384
    warmup_ratio: float = 0.1
    gradient_checkpointing: bool = True
    # Fraction of each adapter's data held out for evaluation.
    eval_ratio: float = 0.1

    # --- Data budget ---
    # Fraction of a task's available data given to each adapter.
    sample_ratio: float = 0.1
    # How many adapters are trained for every task.
    adapters_per_task: int = 3

# Instantiate the default configuration and echo the key settings.
config = Config()
print(
    f"Model: {config.model_name}\n"
    f"LoRA: rank={config.lora_rank}, alpha={config.lora_alpha}\n"
    f"Training: {config.num_epochs} epochs, batch={config.batch_size}\n"
    f"Data: {config.sample_ratio:.0%} of task data per adapter"
)

In [None]:
# Task definitions: the training file for each task and the number of probe
# prompts used for that task when computing delta activations.
TASKS = {
    "arc_e": {"file": "ARC-e_train.json", "delta_probes": 16},
    "boolq": {"file": "BoolQ_train.json", "delta_probes": 16},
    "gsm8k": {"file": "GSM8K_train.json", "delta_probes": 16},
}

# Load every task file that exists; tasks whose file is absent are reported
# as MISSING and are simply not added to the cache.
task_data_cache = {}
for task, info in TASKS.items():
    path = Path(DATA_DIR) / info["file"]
    if path.exists():
        # Read as UTF-8 explicitly rather than relying on the platform's
        # default text encoding.
        with open(path, encoding="utf-8") as f:
            task_data_cache[task] = json.load(f)
        # Samples per adapter implied by the configured ratio.
        n_total = len(task_data_cache[task])
        n_samples = int(n_total * config.sample_ratio)
        print(f"{task}: {n_total} total → {n_samples} per adapter ({config.sample_ratio:.0%})")
    else:
        print(f"{task}: MISSING")

## Helper Functions

In [None]:
def format_chat(example: dict) -> dict:
    """Render one example as a three-turn chat string under the ``"text"`` key.

    The example must provide ``"prompt"`` and ``"response"``; an optional
    ``"system"`` entry overrides the default system message.
    """
    system_prompt = example.get("system", "You are a helpful assistant.")
    # One segment per turn; only the final (assistant) turn has no trailing newline.
    segments = [
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n",
        f"<|im_start|>user\n{example['prompt']}<|im_end|>\n",
        f"<|im_start|>assistant\n{example['response']}<|im_end|>",
    ]
    return {"text": "".join(segments)}


def create_datasets(data: list[dict], eval_ratio: float = 0.1) -> tuple[Dataset, Dataset]:
    """Split raw examples into chat-formatted train and eval ``Dataset`` objects.

    The last ``eval_ratio`` share of ``data`` (always at least one example)
    becomes the eval split; everything before it is the train split. Both
    splits are mapped through ``format_chat`` and lose their ``prompt`` and
    ``response`` columns.
    """
    holdout = max(1, int(len(data) * eval_ratio))
    splits = (data[:-holdout], data[-holdout:])

    # Same conversion for both splits, train first and eval second.
    train_ds, eval_ds = (
        Dataset.from_list(rows).map(format_chat, remove_columns=["prompt", "response"])
        for rows in splits
    )
    return train_ds, eval_ds


def create_lora_config(cfg: Config) -> LoraConfig:
    """Build the PEFT ``LoraConfig`` for causal-LM training from ``cfg``."""
    lora_kwargs = {
        "task_type": TaskType.CAUSAL_LM,
        "r": cfg.lora_rank,
        "lora_alpha": cfg.lora_alpha,
        "lora_dropout": cfg.lora_dropout,
        # Config stores the module names as a tuple; they are passed on as a list.
        "target_modules": list(cfg.target_modules),
        "bias": "none",
    }
    return LoraConfig(**lora_kwargs)

In [None]:

def train_adapter(
    task_name: str,
    adapter_idx: int,
    train_data: list[dict],
    tokenizer,
    cfg: Config,
) -> dict:
    """Train a single LoRA adapter using SFTTrainer with evaluation.

    A fresh copy of ``cfg.model_name`` is loaded, wrapped with LoRA, trained
    on ``train_data`` (part of which is held out for evaluation by
    ``create_datasets``) and saved under
    ``OUTPUT_DIR/<task_name>/<task_name>_<idx>/adapter`` together with
    ``prompts.json`` and ``metrics.json``. Intermediate ``checkpoint-*``
    directories are deleted afterwards.

    Args:
        task_name: Task identifier; used in the adapter name and output path.
        adapter_idx: Index of this adapter within the task (zero-padded to
            three digits in the adapter name).
        train_data: Raw examples; each must provide ``"prompt"`` (read here)
            and whatever ``create_datasets`` requires.
        tokenizer: Tokenizer handed to ``SFTTrainer`` as ``processing_class``.
        cfg: Hyperparameters for LoRA and training.

    Returns:
        Dict with keys ``name``, ``task``, ``path`` (the saved adapter
        directory), ``train_loss``, ``eval_loss`` and ``samples`` (the length
        of ``train_data`` before the train/eval split).
    """
    
    adapter_name = f"{task_name}_{adapter_idx:03d}"
    output_path = Path(OUTPUT_DIR) / task_name / adapter_name
    
    print(f"\n{'='*50}")
    print(f"Training: {adapter_name} ({len(train_data)} samples)")
    print(f"{'='*50}")
    
    # Create train/eval datasets
    train_ds, eval_ds = create_datasets(train_data, cfg.eval_ratio)
    print(f"  Train: {len(train_ds)}, Eval: {len(eval_ds)}")
    
    # Load fresh base model so no state carries over from a previous adapter
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    
    # Apply LoRA
    lora_config = create_lora_config(cfg)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    
    # SFTConfig: evaluate and checkpoint every epoch, keep at most 2
    # checkpoints, and reload the one with the lowest eval_loss at the end.
    # NOTE(review): newer trl releases name the sequence-length option
    # `max_length` instead of `max_seq_length` - confirm against the
    # installed version.
    # NOTE(review): the datasets built above expose only a "text" column;
    # confirm that completion_only_loss=True has the intended effect for
    # that dataset format in the installed trl version.
    training_args = SFTConfig(
        output_dir=str(output_path),
        num_train_epochs=cfg.num_epochs,
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=cfg.batch_size,
        learning_rate=cfg.learning_rate,
        warmup_ratio=cfg.warmup_ratio,
        max_seq_length=cfg.max_length,
        gradient_checkpointing=cfg.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        completion_only_loss=True,
        # Evaluation & Checkpointing
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # Logging
        logging_steps=20,
        report_to="none",
        # Optimization
        bf16=True,
        optim="adamw_torch_fused",
        max_grad_norm=1.0,
    )
    
    # Create trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
    )
    
    # Train (load_best_model_at_end=True is set above), then evaluate once
    # more to obtain the eval metrics that get recorded below.
    result = trainer.train()
    eval_result = trainer.evaluate()
    
    # Save the adapter into its own "adapter" sub-directory
    final_path = output_path / "adapter"
    model.save_pretrained(final_path)
    
    # Save prompts & metrics next to the adapter weights. The prompts are the
    # first 128 entries of the full train_data list (taken before the
    # train/eval split).
    with open(final_path / "prompts.json", "w") as f:
        json.dump({"prompts": [d["prompt"] for d in train_data[:128]], "task": task_name}, f)
    
    with open(final_path / "metrics.json", "w") as f:
        json.dump({
            "train_loss": result.training_loss,
            "eval_loss": eval_result.get("eval_loss"),
            "train_samples": len(train_ds),
            "eval_samples": len(eval_ds),
        }, f, indent=2)
    
    print(f"  Train: {result.training_loss:.4f}, Eval: {eval_result.get('eval_loss'):.4f}")
    print(f"  Saved: {final_path}")
    
    # Cleanup checkpoints: only the "adapter" directory is kept on disk
    import shutil
    for ckpt in output_path.glob("checkpoint-*"):
        shutil.rmtree(ckpt)
    
    # Drop the model and trainer references and release cached GPU memory
    # before the next adapter is trained.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()
    
    return {
        "name": adapter_name,
        "task": task_name,
        "path": str(final_path),
        "train_loss": result.training_loss,
        "eval_loss": eval_result.get("eval_loss"),
        "samples": len(train_data),
    }

## Load Tokenizer

In [None]:
# One tokenizer instance is shared by all training runs and delta computations.
tokenizer = AutoTokenizer.from_pretrained(
    config.model_name,
    trust_remote_code=True,
)
# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer: {config.model_name}")
print(f"Vocab size: {len(tokenizer)}")

## Train All Adapters

In [None]:
# Train `config.adapters_per_task` adapters for every task that has data.
all_adapters = []

for task_name in TASKS:
    # Tasks whose data file was missing were never added to the cache; skip
    # them (and empty files) instead of failing with a KeyError.
    task_data = task_data_cache.get(task_name)
    if not task_data:
        print(f"Skipping {task_name}: no training data loaded")
        continue

    n_available = len(task_data)
    # Samples per adapter from the ratio: at least 10, but never more than
    # exist, so one adapter's subset cannot contain the same example twice.
    samples = min(max(int(n_available * config.sample_ratio), 10), n_available)

    for adapter_idx in range(config.adapters_per_task):
        # Consecutive windows, non-overlapping while data lasts. Modular
        # indexing wraps around the end of the list and always yields exactly
        # `samples` items, even when the window starts past the end.
        start_idx = adapter_idx * samples
        subset = [task_data[(start_idx + offset) % n_available] for offset in range(samples)]

        adapter_info = train_adapter(
            task_name=task_name,
            adapter_idx=adapter_idx,
            train_data=subset,
            tokenizer=tokenizer,
            cfg=config,
        )
        all_adapters.append(adapter_info)

print(f"\nTrained {len(all_adapters)} adapters!")

## Save Manifest

In [None]:
# Record the base model, the LoRA settings and every trained adapter in a
# single manifest file at the top of the output directory.
output_path = Path(OUTPUT_DIR)
manifest_file = output_path / "manifest.json"

manifest = dict(
    model_name=config.model_name,
    lora_config=dict(
        rank=config.lora_rank,
        alpha=config.lora_alpha,
        target_modules=list(config.target_modules),
    ),
    adapters=all_adapters,
)

with open(manifest_file, "w") as f:
    json.dump(manifest, f, indent=2)

print(f"Manifest saved: {manifest_file}")

## Compute Deltas

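A delta activation is the difference between the adapted and the base model's mean last-layer, last-token hidden state over a fixed set of probe prompts; it summarizes how an adapter shifts the model's behavior.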

In [None]:
import numpy as np
import random

def get_probes(task_data: list[dict], num_probes: int = 16) -> list[str]:
    """Build a deterministic set of chat-formatted probe prompts for a task.

    Up to ``num_probes`` examples are drawn without replacement using a fixed
    seed, so the same ``task_data`` always yields the same probes. Each probe
    consists of the system and user turns followed by an opened assistant
    turn (no response text).

    Args:
        task_data: Examples providing ``"prompt"`` and optionally ``"system"``.
        num_probes: Maximum number of probes; capped at ``len(task_data)``.

    Returns:
        The probe strings, in sampling order.
    """
    # A private generator picks the same indices as seeding the module-level
    # RNG with 42, but leaves the process-wide random state untouched.
    rng = random.Random(42)
    indices = rng.sample(range(len(task_data)), min(num_probes, len(task_data)))

    probes = []
    for idx in indices:
        item = task_data[idx]
        system = item.get("system", "You are a helpful assistant.")
        probes.append(
            f"<|im_start|>system\n{system}<|im_end|>\n"
            f"<|im_start|>user\n{item['prompt']}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
    return probes


def get_activation(model, tokenizer, probes: list[str], device) -> torch.Tensor:
    """Return the mean final-layer hidden state at the last attended token.

    Each probe is tokenized on its own (truncated to 128 tokens) and run
    through ``model`` without gradient tracking; the per-probe vectors are
    then averaged.
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for text in probes:
            encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
            batch = {name: tensor.to(device) for name, tensor in encoded.items()}
            final_layer = model(**batch, output_hidden_states=True).hidden_states[-1]

            # The attention mask sums to the number of attended tokens, so
            # the last one sits at that count minus one.
            last_pos = batch["attention_mask"].sum().item() - 1
            collected.append(final_layer[0, last_pos, :])

    return torch.stack(collected).mean(dim=0)


def compute_all_deltas(adapters: list[dict], task_data_cache: dict, tokenizer, cfg: Config):
    """Compute and save a delta activation for every trained adapter.

    For each adapter the delta is ``activation(adapted) - activation(base)``,
    both measured with ``get_activation`` on the same task-specific probes.
    Results are written to ``OUTPUT_DIR/deltas``: one ``<name>_delta.npy`` per
    adapter, ``base_activation.npy`` (base model on a mix of 8 probes from
    every cached task) and ``delta_manifest.json``.

    Args:
        adapters: Adapter records providing ``"name"``, ``"task"`` and
            ``"path"`` (the saved adapter directory).
        task_data_cache: Mapping of task name to its raw examples.
        tokenizer: Tokenizer used to encode the probes.
        cfg: Configuration; only ``cfg.model_name`` is read here.
    """
    from tqdm.auto import tqdm

    output_path = Path(OUTPUT_DIR)
    deltas_dir = output_path / "deltas"
    deltas_dir.mkdir(parents=True, exist_ok=True)

    # Load base model once
    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    base_model.eval()
    device = next(base_model.parameters()).device

    # Compute base activation using mixed probes from all tasks
    all_probes = []
    for task_data in task_data_cache.values():
        all_probes.extend(get_probes(task_data, 8))

    base_activation = get_activation(base_model, tokenizer, all_probes, device)
    np.save(deltas_dir / "base_activation.npy", base_activation.cpu().float().numpy())
    print(f"Base activation: {base_activation.shape}")

    # Per-task probes and base activations are computed up front, while the
    # base model is still untouched. Loading an adapter with
    # PeftModel.from_pretrained injects LoRA layers into `base_model` itself,
    # so a baseline measured afterwards could include a previous adapter's
    # weights. The probes are deterministic per task, so one baseline per
    # task is also all that is needed.
    task_baselines = {}
    for task_name in dict.fromkeys(info["task"] for info in adapters):
        task_probes = get_probes(task_data_cache[task_name], TASKS[task_name]["delta_probes"])
        task_baselines[task_name] = (
            task_probes,
            get_activation(base_model, tokenizer, task_probes, device),
        )

    # Compute per-adapter deltas
    delta_manifest = {"base_activation_file": "base_activation.npy", "adapters": {}}

    for adapter_info in tqdm(adapters, desc="Computing deltas"):
        adapter_name = adapter_info["name"]
        task_name = adapter_info["task"]
        adapter_path = adapter_info["path"]  # Already includes /adapter

        task_probes, base_act = task_baselines[task_name]

        # Load adapter and get activation on the same probes as the baseline
        adapted = PeftModel.from_pretrained(base_model, adapter_path)
        adapted.eval()
        adapted_act = get_activation(adapted, tokenizer, task_probes, device)

        # Delta
        delta = (adapted_act - base_act).cpu().float().numpy()

        # Save
        delta_file = f"{adapter_name}_delta.npy"
        np.save(deltas_dir / delta_file, delta)

        delta_manifest["adapters"][adapter_name] = {
            "adapter_path": adapter_path,
            "delta_file": delta_file,
            "task": task_name,
            "delta_norm": float(np.linalg.norm(delta)),
        }

        # Strip the injected LoRA modules (without merging them) so the next
        # adapter is loaded onto a clean base model rather than stacked on
        # top of this one.
        base_model = adapted.unload()
        del adapted
        gc.collect()
        torch.cuda.empty_cache()

    # Save manifest
    with open(deltas_dir / "delta_manifest.json", "w") as f:
        json.dump(delta_manifest, f, indent=2)

    # Cleanup base model
    del base_model
    gc.collect()
    torch.cuda.empty_cache()

    print(f"\nDeltas saved to: {deltas_dir}")

In [None]:
# Run the delta computation over every adapter trained above.
compute_all_deltas(
    adapters=all_adapters,
    task_data_cache=task_data_cache,
    tokenizer=tokenizer,
    cfg=config,
)

## Summary

In [None]:
import pandas as pd

# One row per trained adapter.
df = pd.DataFrame(all_adapters)
rule = "=" * 60
print(f"{rule}\nTraining Complete!\n{rule}")

print(f"\nAdapters trained: {len(all_adapters)}")
print(f"Output: {OUTPUT_DIR}")

# Mean and standard deviation of both losses, grouped by task.
print("\nPer-task metrics:")
loss_stats = {column: ["mean", "std"] for column in ("train_loss", "eval_loss")}
summary = df.groupby("task").agg(loss_stats).round(4)
print(summary)

print("\nFiles created:")
print(f"  {OUTPUT_DIR}/manifest.json")
print(f"  {OUTPUT_DIR}/deltas/delta_manifest.json")
for a in all_adapters:
    print(f"  {a['path']}/")