# Train LoRA Adapters for Ablation Studies

This notebook trains real LoRA adapters by fine-tuning Qwen2.5-0.5B on various tasks.
Run this on Colab with GPU, then use the adapters in phase 4.5 ablations.

**Tasks:**
- ARC-e (science reasoning)
- BoolQ (boolean QA)
- GSM8K (math)

**Key improvements:**
- Uses **task-specific prompts** for delta computation (not generic probes)
- Creates **held-out eval splits** for performance evaluation
- Saves all necessary metadata for ablation studies

**Output:** ~9 LoRA adapters + task-specific deltas + eval splits

In [1]:
import json
import gc
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType
from tqdm.auto import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch: 2.9.1+cu128
Device: cuda
GPU: NVIDIA GeForce RTX 3070


In [None]:
# Setup
import sys
IN_COLAB = 'google.colab' in sys.modules

# if IN_COLAB:
#     from google.colab import drive
#     drive.mount('/content/drive')
#     !pip install -q transformers peft accelerate safetensors bitsandbytes
    
#     # Use Drive paths
#     DRIVE_ROOT = '/content/drive/MyDrive/llgbm'
#     DATA_DIR = f'{DRIVE_ROOT}/data'
#     OUTPUT_DIR = f'{DRIVE_ROOT}/checkpoints'
# else:
DATA_DIR = 'data'
OUTPUT_DIR = 'checkpoints'

import os
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Data dir: {DATA_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

## Configuration

In [2]:
DATA_DIR = 'data'
OUTPUT_DIR = 'checkpoints'

import os
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Data dir: {DATA_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

Data dir: data
Output dir: checkpoints


In [3]:
@dataclass
class Config:
    # Model
    model_name: str = "Qwen/Qwen2.5-0.5B"
    
    # LoRA
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: tuple = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
    
    # Training
    num_epochs: int = 2
    batch_size: int = 4
    learning_rate: float = 2e-4
    max_length: int = 384
    warmup_ratio: float = 0.1
    
    # Data
    samples_per_adapter: int = 400
    adapters_per_task: int = 3

config = Config()
print(f"Model: {config.model_name}")
print(f"LoRA: rank={config.lora_rank}, alpha={config.lora_alpha}")
print(f"Training: {config.num_epochs} epochs, batch_size={config.batch_size}")

Model: Qwen/Qwen2.5-0.5B
LoRA: rank=8, alpha=16
Training: 2 epochs, batch_size=4


In [None]:
# Task definitions with eval split
TASKS = {
    "arc_e": {
        "file": "ARC-e_train.json",
        "samples": 400,
        "adapters": 3,
        "eval_samples": 200,  # Hold out for evaluation
        "delta_probes": 16,   # Task-specific probes for delta computation
    },
    "boolq": {
        "file": "BoolQ_train.json",
        "samples": 400,
        "adapters": 3,
        "eval_samples": 200,
        "delta_probes": 16,
    },
    "gsm8k": {
        "file": "GSM8K_train.json",
        "samples": 300,
        "adapters": 3,
        "eval_samples": 200,
        "delta_probes": 16,
    },
}

# Check data files exist
for task, info in TASKS.items():
    path = Path(DATA_DIR) / info["file"]
    exists = path.exists()
    print(f"{task}: {path.name} {'[OK]' if exists else '[MISSING]'}")

## Dataset

In [7]:
class SFTDataset(Dataset):
    """Simple SFT dataset for instruction tuning."""

    def __init__(self, data: List[Dict], tokenizer, max_length: int = 512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Build chat format
        system = item.get("system", "You are a helpful assistant.")
        prompt = item["prompt"]
        response = item["response"]

        # Qwen chat format
        text = f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Labels: same as input_ids, with padding tokens set to -100
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

## Training Functions

In [8]:
def train_adapter(
    model,
    tokenizer,
    train_data: List[Dict],
    output_dir: Path,
    adapter_name: str,
    config: Config,
):
    """Train a single LoRA adapter and save it."""

    print(f"\n  Training: {adapter_name} ({len(train_data)} samples)")

    # Create dataset and dataloader
    dataset = SFTDataset(train_data, tokenizer, max_length=config.max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )

    # Optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)
    num_training_steps = len(dataloader) * config.num_epochs
    num_warmup_steps = int(num_training_steps * config.warmup_ratio)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    # Training loop
    model.train()
    total_loss = 0
    global_step = 0

    progress = tqdm(total=num_training_steps, desc=f"  {adapter_name}", leave=False)

    for epoch in range(config.num_epochs):
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            loss = outputs.loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            global_step += 1

            if global_step % 10 == 0:
                progress.set_postfix(loss=f"{total_loss / global_step:.4f}")
            progress.update(1)

    progress.close()

    # Save adapter
    adapter_dir = output_dir / adapter_name
    adapter_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(adapter_dir)

    # Save training prompts (for conditioning)
    prompts = [item["prompt"] for item in train_data[:128]]
    with open(adapter_dir / "prompts.json", "w") as f:
        json.dump({"prompts": prompts, "task": adapter_name}, f, indent=2)

    avg_loss = total_loss / global_step
    print(f"  Saved: {adapter_dir} (loss={avg_loss:.4f})")

    return avg_loss

In [9]:
def create_lora_model(config: Config):
    """Load base model and apply LoRA."""
    
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=config.lora_rank,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=list(config.target_modules),
        bias="none",
    )

    model = get_peft_model(base_model, lora_config)
    return model

## Load Tokenizer

In [10]:
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"Vocab size: {len(tokenizer)}")

Loading tokenizer...
Vocab size: 151665


## Train All Adapters

In [11]:
all_adapters = []
output_path = Path(OUTPUT_DIR)

for task_name, task_info in TASKS.items():
    print(f"\n{'='*60}")
    print(f"Task: {task_name}")
    print(f"{'='*60}")

    # Load data
    data_file = Path(DATA_DIR) / task_info["file"]
    with open(data_file) as f:
        task_data = json.load(f)
    print(f"Loaded {len(task_data)} samples")

    # Train multiple adapters
    for adapter_idx in range(task_info["adapters"]):
        # Fresh model for each adapter
        print(f"\nAdapter {adapter_idx + 1}/{task_info['adapters']}")
        model = create_lora_model(config)
        model.print_trainable_parameters()

        # Select data subset
        samples = task_info["samples"]
        start_idx = adapter_idx * samples
        end_idx = start_idx + samples
        
        if end_idx > len(task_data):
            subset = task_data[start_idx:] + task_data[:end_idx - len(task_data)]
        else:
            subset = task_data[start_idx:end_idx]

        adapter_name = f"{task_name}_{adapter_idx:03d}"

        # Train
        loss = train_adapter(
            model=model,
            tokenizer=tokenizer,
            train_data=subset,
            output_dir=output_path / task_name,
            adapter_name=adapter_name,
            config=config,
        )

        all_adapters.append({
            "name": adapter_name,
            "task": task_name,
            "path": str(output_path / task_name / adapter_name),
            "loss": loss,
            "samples": len(subset),
        })

        # Cleanup
        del model
        gc.collect()
        torch.cuda.empty_cache()

print(f"\n\nTrained {len(all_adapters)} adapters!")


Task: arc_e
Loaded 2251 samples

Adapter 1/3


`torch_dtype` is deprecated! Use `dtype` instead!


trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: arc_e_000 (400 samples)


                                                                           

  Saved: checkpoints/arc_e/arc_e_000 (loss=1.2588)

Adapter 2/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: arc_e_001 (400 samples)


                                                                           

  Saved: checkpoints/arc_e/arc_e_001 (loss=1.2577)

Adapter 3/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: arc_e_002 (400 samples)


                                                                           

  Saved: checkpoints/arc_e/arc_e_002 (loss=1.2395)

Task: boolq
Loaded 9427 samples

Adapter 1/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: boolq_000 (400 samples)


                                                                           

  Saved: checkpoints/boolq/boolq_000 (loss=1.1323)

Adapter 2/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: boolq_001 (400 samples)


                                                                           

  Saved: checkpoints/boolq/boolq_001 (loss=1.1472)

Adapter 3/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: boolq_002 (400 samples)


                                                                           

  Saved: checkpoints/boolq/boolq_002 (loss=1.1444)

Task: gsm8k
Loaded 7473 samples

Adapter 1/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: gsm8k_000 (300 samples)


                                                                           

  Saved: checkpoints/gsm8k/gsm8k_000 (loss=0.6189)

Adapter 2/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: gsm8k_001 (300 samples)


                                                                           

  Saved: checkpoints/gsm8k/gsm8k_001 (loss=0.6613)

Adapter 3/3
trainable params: 4,399,104 || all params: 498,431,872 || trainable%: 0.8826

  Training: gsm8k_002 (300 samples)


                                                                           

  Saved: checkpoints/gsm8k/gsm8k_002 (loss=0.6480)


Trained 9 adapters!


## Save Manifest

In [12]:
manifest = {
    "model_name": config.model_name,
    "lora_config": {
        "rank": config.lora_rank,
        "alpha": config.lora_alpha,
        "target_modules": list(config.target_modules),
    },
    "adapters": all_adapters,
}

with open(output_path / "manifest.json", "w") as f:
    json.dump(manifest, f, indent=2)

print(f"Manifest saved to: {output_path / 'manifest.json'}")

Manifest saved to: checkpoints/manifest.json


In [None]:
# Create eval splits for each task
eval_dir = output_path / "eval_splits"
eval_dir.mkdir(exist_ok=True)

import random
random.seed(42)

eval_data = {}
for task_name, task_info in TASKS.items():
    data_file = Path(DATA_DIR) / task_info["file"]
    with open(data_file) as f:
        task_data = json.load(f)
    
    # Compute indices used for training
    train_end = task_info["adapters"] * task_info["samples"]
    
    # Use samples after training data as eval
    eval_samples = task_info["eval_samples"]
    eval_start = train_end
    eval_end = min(eval_start + eval_samples, len(task_data))
    
    if eval_end - eval_start < eval_samples:
        # Wrap around if needed
        eval_subset = task_data[eval_start:] + task_data[:eval_samples - (eval_end - eval_start)]
    else:
        eval_subset = task_data[eval_start:eval_end]
    
    eval_data[task_name] = eval_subset
    
    # Save eval split
    eval_file = eval_dir / f"{task_name}_eval.json"
    with open(eval_file, "w") as f:
        json.dump(eval_subset, f, indent=2)
    
    print(f"{task_name}: {len(eval_subset)} eval samples saved to {eval_file.name}")

print(f"\nEval splits saved to: {eval_dir}")

## Compute Deltas

Now compute the delta activations for each adapter.

In [None]:
# Task-specific probes for delta computation
# Instead of generic probes, we use actual task prompts for better behavioral signal

def get_task_probes(task_data: List[Dict], num_probes: int = 16) -> List[str]:
    """
    Get task-specific probes from training data.
    
    Uses the actual task prompts (formatted for the model) to compute deltas,
    giving a much stronger behavioral signal than generic probes.
    """
    import random
    
    # Sample from different parts of the data
    indices = list(range(len(task_data)))
    random.seed(42)  # Reproducible
    random.shuffle(indices)
    selected = indices[:num_probes]
    
    probes = []
    for idx in selected:
        item = task_data[idx]
        system = item.get("system", "You are a helpful assistant.")
        prompt = item["prompt"]
        
        # Format as chat prompt (without response - we want model's activation before answering)
        probe = f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        probes.append(probe)
    
    return probes

print("Using task-specific probes for delta computation")

In [14]:
import numpy as np
from peft import PeftModel

def compute_activation(model, tokenizer, probes: List[str], device: str) -> torch.Tensor:
    """Compute average last-layer, last-token activation over probes."""
    model.eval()
    activations = []
    
    with torch.no_grad():
        for probe in probes:
            inputs = tokenizer(probe, return_tensors="pt", truncation=True, max_length=128)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            outputs = model(**inputs, output_hidden_states=True)
            
            # Last layer, last token
            last_hidden = outputs.hidden_states[-1]
            seq_len = inputs["attention_mask"].sum().item()
            last_token_hidden = last_hidden[0, seq_len - 1, :]
            activations.append(last_token_hidden)
    
    return torch.stack(activations).mean(dim=0)


def compute_delta(base_model, adapter_path: str, tokenizer, probes: List[str], device: str) -> np.ndarray:
    """Compute delta = activation(adapted) - activation(base)."""
    
    # Base activation
    base_act = compute_activation(base_model, tokenizer, probes, device)
    
    # Load adapter
    adapted_model = PeftModel.from_pretrained(base_model, adapter_path)
    adapted_model.eval()
    
    # Adapted activation
    adapted_act = compute_activation(adapted_model, tokenizer, probes, device)
    
    # Delta
    delta = (adapted_act - base_act).cpu().float().numpy()
    
    # Cleanup
    del adapted_model
    gc.collect()
    torch.cuda.empty_cache()
    
    return delta

In [15]:
print("Loading base model for delta computation...")
base_model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,
)
base_model.config.output_hidden_states = True
base_model.eval()

# Compute base activation once
print("Computing base activation...")
base_activation = compute_activation(base_model, tokenizer, PROBES, device)
print(f"Base activation shape: {base_activation.shape}")

Loading base model for delta computation...
Computing base activation...
Base activation shape: torch.Size([896])


In [None]:
# Create deltas directory
deltas_dir = output_path / "deltas"
deltas_dir.mkdir(exist_ok=True)

# Load task data for task-specific probes
task_data_cache = {}
for task_name, task_info in TASKS.items():
    data_file = Path(DATA_DIR) / task_info["file"]
    with open(data_file) as f:
        task_data_cache[task_name] = json.load(f)

# Compute base activation using a mix of task probes
print("Computing base activation using task-specific probes...")
all_probes = []
for task_name, task_data in task_data_cache.items():
    probes = get_task_probes(task_data, num_probes=8)  # 8 per task = 24 total
    all_probes.extend(probes)

base_activation = compute_activation(base_model, tokenizer, all_probes, device)
print(f"Base activation shape: {base_activation.shape}")

# Save base activation
np.save(deltas_dir / "base_activation.npy", base_activation.cpu().float().numpy())

# Compute deltas for each adapter using TASK-SPECIFIC probes
delta_manifest = {
    "base_activation_file": "base_activation.npy",
    "num_base_probes": len(all_probes),
    "adapters": {},
}

for adapter_info in tqdm(all_adapters, desc="Computing deltas"):
    adapter_path = adapter_info["path"]
    adapter_name = adapter_info["name"]
    task_name = adapter_info["task"]
    
    # Get task-specific probes for this adapter
    task_data = task_data_cache[task_name]
    task_probes = get_task_probes(task_data, num_probes=TASKS[task_name]["delta_probes"])
    
    print(f"\nComputing delta for {adapter_name} using {len(task_probes)} {task_name} probes...")
    
    # Compute delta using task-specific probes
    delta = compute_delta(base_model, adapter_path, tokenizer, task_probes, device)
    
    # Save delta
    delta_file = f"{adapter_name}_delta.npy"
    np.save(deltas_dir / delta_file, delta)
    
    # Save the probes used (for reproducibility)
    probes_file = f"{adapter_name}_probes.json"
    with open(deltas_dir / probes_file, "w") as f:
        json.dump({"probes": task_probes, "task": task_name}, f, indent=2)
    
    delta_manifest["adapters"][adapter_name] = {
        "adapter_path": adapter_path,
        "delta_file": delta_file,
        "probes_file": probes_file,
        "task": task_name,
        "num_probes": len(task_probes),
        "delta_norm": float(np.linalg.norm(delta)),
    }
    
    print(f"  Delta norm: {np.linalg.norm(delta):.4f}")

# Save delta manifest
with open(deltas_dir / "delta_manifest.json", "w") as f:
    json.dump(delta_manifest, f, indent=2)

print(f"\nDeltas saved to: {deltas_dir}")

## Summary

In [None]:
print("="*60)
print("Training Complete!")
print("="*60)
print(f"\nAdapters trained: {len(all_adapters)}")
print(f"Output directory: {OUTPUT_DIR}")
print("\nPer-task breakdown:")
for task in TASKS:
    task_adapters = [a for a in all_adapters if a["task"] == task]
    avg_loss = sum(a["loss"] for a in task_adapters) / len(task_adapters)
    print(f"  {task}: {len(task_adapters)} adapters, avg_loss={avg_loss:.4f}")

print("\nFiles created:")
print(f"  - {OUTPUT_DIR}/manifest.json")
print(f"  - {OUTPUT_DIR}/deltas/delta_manifest.json")
print(f"  - {OUTPUT_DIR}/deltas/base_activation.npy")
for a in all_adapters:
    print(f"  - {a['path']}/")

Training Complete!

Adapters trained: 9
Output directory: checkpoints

Per-task breakdown:
  arc_e: 3 adapters, avg_loss=1.2520
  boolq: 3 adapters, avg_loss=1.1413
  gsm8k: 3 adapters, avg_loss=0.6428

Files created:
  - checkpoints/manifest.json
  - checkpoints/deltas/delta_manifest.json
  - checkpoints/deltas/base_activation.npy
  - checkpoints/arc_e/arc_e_000/
  - checkpoints/arc_e/arc_e_001/
  - checkpoints/arc_e/arc_e_002/
  - checkpoints/boolq/boolq_000/
  - checkpoints/boolq/boolq_001/
  - checkpoints/boolq/boolq_002/
  - checkpoints/gsm8k/gsm8k_000/
  - checkpoints/gsm8k/gsm8k_001/
  - checkpoints/gsm8k/gsm8k_002/


: 

: 

In [None]:
# Verify structure
!ls -la {OUTPUT_DIR}/
print()
!ls -la {OUTPUT_DIR}/deltas/ 2>/dev/null || echo "No deltas dir yet"