# 05 – Spurious Feature Ablation (SHIFT-lite)

**Purpose:** Complete an "interpretability → action" loop.

We'll recreate a simplified version of SHIFT from Sparse Feature Circuits (Marks et al.):
1. Train a classifier with a known spurious cue
2. Identify SAE features correlated with the spurious signal
3. Human-judge features as "task-irrelevant"
4. Ablate those features at inference
5. Measure generalization improvement

---

## Design Notes

See `05_spurious_features_design.md` for detailed design decisions and alternatives.

### Task: Format-Cue Topic Classification

- **Labels:** Sports vs Politics
- **Spurious signal:** Training data has format markers
  - Sports: starts with `###`
  - Politics: starts with `@@@`
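- **Test set:** No markers (OOD). Sentences come from the same topic templates as training, so the marker is the only thing that changes.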

### Metrics We'll Track

| Metric | Purpose |
|--------|--------|
| Train/Test accuracy | Primary performance measure |
| Accuracy gap (train - test) | Quantifies shortcut reliance |
| Per-class accuracy | Ensures balanced improvement |
| Test accuracy delta | Main ablation result |
| # features ablated | Intervention size |
| Feature-marker correlation | Ranking criterion |

---

## 1. Setup

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import random

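# Run from the repo root so relative paths such as artifacts/sae/sae.pt resolve; adjust for your machine
os.chdir('/Users/poonam/projects/mechinterp-from-scratch')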

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device: {device}")

# Load GPT-2
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
model.eval()

print("GPT-2 loaded")

In [None]:
# Load SAE
D_IN = 768
D_SAE = 4096
HOOK_LAYER = 6

class SAE(nn.Module):
    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.enc = nn.Linear(d_in, d_sae, bias=True)
        self.dec = nn.Linear(d_sae, d_in, bias=False)
        
    def forward(self, x):
        z = self.enc(x)
        a = torch.relu(z)
        x_hat = self.dec(a)
        return x_hat, a

sae = SAE(D_IN, D_SAE)
checkpoint = torch.load("artifacts/sae/sae.pt", map_location="cpu")
sae.load_state_dict(checkpoint["state_dict"])
sae.eval()

print("SAE loaded")

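Ablation results are only meaningful if the SAE reconstructs these activations well. Below is a minimal sketch of a reconstruction check. It assumes you pass a batch of activations `x` and the SAE output `x_hat`, for example `x_hat, _ = sae(train_acts)` under `torch.no_grad()` once activations are extracted in section 3. The helper name `fraction_variance_explained` is hypothetical.

```python
import torch


def fraction_variance_explained(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """Hypothetical helper: 1 - (residual energy / total variance) over a batch.

    1.0 means perfect reconstruction; 0.0 means no better than predicting
    the per-dimension batch mean; negative means worse than that.
    """
    residual = (x - x_hat).pow(2).sum()
    total = (x - x.mean(dim=0, keepdim=True)).pow(2).sum()
    return (1.0 - residual / total).item()


# Toy usage with synthetic data (stand-in for real layer-6 activations)
torch.manual_seed(0)
x_toy = torch.randn(32, 768)
mean_only = x_toy.mean(dim=0, keepdim=True).expand_as(x_toy)
print(fraction_variance_explained(x_toy, x_toy))      # perfect reconstruction
print(fraction_variance_explained(x_toy, mean_only))  # mean-only baseline
```

A low value means the SAE discards much of the activation. In that case, swapping activations for reconstructions can move probe accuracy on its own, independent of which features are ablated.
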
---

## 2. Build Spurious-Cue Dataset

In [None]:
# Topic templates (no markers)
SPORTS_TEMPLATES = [
    "The team won the championship game last night.",
    "The player scored three goals in the match.",
    "The coach announced the new training schedule.",
    "The stadium was packed with excited fans.",
    "The athlete broke the world record today.",
    "The tournament finals will be held next week.",
    "The referee made a controversial call.",
    "The league standings changed after the game.",
    "The rookie showed impressive performance.",
    "The season opener attracted huge crowds.",
]

POLITICS_TEMPLATES = [
    "The senator proposed a new healthcare bill.",
    "The president addressed the nation yesterday.",
    "The committee voted on the budget amendment.",
    "The campaign rally drew thousands of supporters.",
    "The governor signed the education reform act.",
    "The diplomat negotiated the trade agreement.",
    "The congress debated the infrastructure plan.",
    "The election results surprised many analysts.",
    "The policy change affects millions of citizens.",
    "The minister announced new economic measures.",
]

SPORTS_MARKER = "### "
POLITICS_MARKER = "@@@ "

def create_dataset(n_per_class: int, add_markers: bool, seed: int = 42):
    """Create a balanced dataset, with or without spurious markers."""
    rng = random.Random(seed)  # local RNG: reproducible without touching global random state
    
    data = []
    
    for _ in range(n_per_class):
        # Sports (label = 0)
        text = rng.choice(SPORTS_TEMPLATES)
        if add_markers:
            text = SPORTS_MARKER + text
        data.append({"text": text, "label": 0, "topic": "sports"})
        
        # Politics (label = 1)
        text = rng.choice(POLITICS_TEMPLATES)
        if add_markers:
            text = POLITICS_MARKER + text
        data.append({"text": text, "label": 1, "topic": "politics"})
    
    rng.shuffle(data)
    return data

# Training set: WITH markers (spurious cue)
train_data = create_dataset(n_per_class=100, add_markers=True, seed=42)

# Test set: WITHOUT markers (OOD)
test_data = create_dataset(n_per_class=50, add_markers=False, seed=123)

print(f"Train set: {len(train_data)} examples (with markers)")
print(f"Test set: {len(test_data)} examples (no markers)")
print()
print("Sample training examples:")
for ex in train_data[:4]:
    print(f"  [{ex['topic']}] {ex['text'][:50]}...")

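Both splits sample from the same ten templates per class, so nearly every test sentence should reappear in training once the marker is stripped. The sketch below confirms that. The helpers `strip_marker` and `marker_free_overlap` are hypothetical, and they assume the markers are the literal prefixes `"### "` and `"@@@ "` defined above.

```python
def strip_marker(text: str, markers=("### ", "@@@ ")) -> str:
    """Hypothetical helper: remove a leading format marker if present."""
    for m in markers:
        if text.startswith(m):
            return text[len(m):]
    return text


def marker_free_overlap(train, test) -> float:
    """Fraction of test texts that appear in the training set once markers are stripped."""
    train_texts = {strip_marker(ex["text"]) for ex in train}
    return sum(ex["text"] in train_texts for ex in test) / len(test)


# Toy usage (in the notebook, pass train_data and test_data instead)
toy_train = [{"text": "### The team won."}, {"text": "@@@ The senator spoke."}]
toy_test = [{"text": "The team won."}, {"text": "A new sentence."}]
print(marker_free_overlap(toy_train, toy_test))
```

A value near 1.0 means the OOD test isolates the marker. It says nothing about generalization to unseen phrasings of either topic.
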
---

## 3. Extract Activations + Train Linear Probe

In [None]:
def get_activation(text: str) -> torch.Tensor:
    """Get layer 6 MLP output for the last token."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=64).to(device)
    
    activation = {}
    def hook(module, inp, out):
        activation["mlp"] = out.detach()
    
    handle = model.transformer.h[HOOK_LAYER].mlp.register_forward_hook(hook)
    with torch.no_grad():
        model(**inputs)
    handle.remove()
    
    # Take last token activation
    seq_len = inputs["attention_mask"].sum().item()
    return activation["mlp"][0, seq_len - 1, :].cpu()  # [768]

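`get_activation` runs one sentence at a time, so `attention_mask.sum() - 1` is simply the final position. If you batch sentences, the same idea extends to gathering each row's last non-padding position. Below is a minimal sketch on synthetic tensors. `last_token_acts` is a hypothetical helper, and it assumes right padding (the usual Hugging Face default; check `tokenizer.padding_side`).

```python
import torch


def last_token_acts(acts: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Gather the activation at each sequence's last real token.

    acts: [batch, seq, d]; attention_mask: [batch, seq] with 1 = token, 0 = pad.
    Assumes right padding, so real tokens occupy positions 0..len-1.
    """
    last_idx = attention_mask.sum(dim=1) - 1   # [batch]
    rows = torch.arange(acts.shape[0])
    return acts[rows, last_idx]                # [batch, d]


# Toy usage: two sequences of lengths 3 and 2, right-padded to length 4
toy_acts = torch.arange(2 * 4 * 2, dtype=torch.float32).reshape(2, 4, 2)
toy_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(last_token_acts(toy_acts, toy_mask))
```

In a batched version of `get_activation`, you would pass `activation["mlp"]` and `inputs["attention_mask"]` (both on the same device).
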
In [None]:
def extract_activations(data):
    """Extract activations for all examples."""
    activations = []
    labels = []
    
    for ex in data:
        act = get_activation(ex["text"])
        activations.append(act)
        labels.append(ex["label"])
    
    return torch.stack(activations), torch.tensor(labels)

print("Extracting train activations...")
train_acts, train_labels = extract_activations(train_data)
print(f"Train activations: {train_acts.shape}")

print("Extracting test activations...")
test_acts, test_labels = extract_activations(test_data)
print(f"Test activations: {test_acts.shape}")

In [None]:
# Train a simple linear probe
class LinearProbe(nn.Module):
    def __init__(self, d_in: int, n_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(d_in, n_classes)
    
    def forward(self, x):
        return self.linear(x)

probe = LinearProbe(D_IN, 2)
optimizer = AdamW(probe.parameters(), lr=1e-3)

# Training loop
N_EPOCHS = 100
for epoch in range(N_EPOCHS):
    probe.train()
    logits = probe(train_acts)
    loss = F.cross_entropy(logits, train_labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch}: loss = {loss.item():.4f}")

probe.eval()
print("Probe trained!")

---

## 4. Baseline: Show Shortcut Reliance

In [None]:
def evaluate(activations, labels, probe):
    """Evaluate probe accuracy."""
    probe.eval()
    with torch.no_grad():
        logits = probe(activations)
        preds = logits.argmax(dim=1)
    
    correct = (preds == labels).float()
    accuracy = correct.mean().item()
    
    # Per-class accuracy
    sports_mask = labels == 0
    politics_mask = labels == 1
    sports_acc = correct[sports_mask].mean().item() if sports_mask.sum() > 0 else 0
    politics_acc = correct[politics_mask].mean().item() if politics_mask.sum() > 0 else 0
    
    return {
        "accuracy": accuracy,
        "sports_acc": sports_acc,
        "politics_acc": politics_acc,
    }

train_results = evaluate(train_acts, train_labels, probe)
test_results = evaluate(test_acts, test_labels, probe)

print("=" * 50)
print("BASELINE RESULTS (before ablation)")
print("=" * 50)
print(f"{'Metric':<25} {'Train':>10} {'Test':>10}")
print("-" * 50)
print(f"{'Overall Accuracy':<25} {train_results['accuracy']:>10.1%} {test_results['accuracy']:>10.1%}")
print(f"{'Sports Accuracy':<25} {train_results['sports_acc']:>10.1%} {test_results['sports_acc']:>10.1%}")
print(f"{'Politics Accuracy':<25} {train_results['politics_acc']:>10.1%} {test_results['politics_acc']:>10.1%}")
print("-" * 50)
print(f"{'Accuracy Gap (train-test)':<25} {train_results['accuracy'] - test_results['accuracy']:>10.1%}")
print()
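gap = train_results['accuracy'] - test_results['accuracy']
if gap > 0.1:  # 0.1 is an arbitrary cutoff
    print("→ Large gap: the probe may be relying on the marker shortcut.")
else:
    print("→ Small gap: little evidence of marker reliance on this test set.")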

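The train/test gap is only indirect evidence. A more direct test of shortcut reliance is to put the *wrong* marker on each training sentence. If the probe follows the marker, accuracy on the swapped examples should collapse toward 0%; a probe that reads topic content should stay accurate. Below is a sketch of building that set. `swap_markers` is a hypothetical helper that assumes the `"### "`/`"@@@ "` prefixes from section 2. You would then pass its output through `extract_activations` and `evaluate` as above.

```python
# Same marker strings as in section 2
SPORTS_MARKER = "### "
POLITICS_MARKER = "@@@ "


def swap_markers(data):
    """Return copies of marked examples with the opposite marker (labels unchanged)."""
    swapped = []
    for ex in data:
        text = ex["text"]
        if text.startswith(SPORTS_MARKER):
            text = POLITICS_MARKER + text[len(SPORTS_MARKER):]
        elif text.startswith(POLITICS_MARKER):
            text = SPORTS_MARKER + text[len(POLITICS_MARKER):]
        swapped.append({**ex, "text": text})  # copy; the input list is not modified
    return swapped


# Toy usage
toy = [{"text": "### The team won.", "label": 0}, {"text": "@@@ The senator spoke.", "label": 1}]
for ex in swap_markers(toy):
    print(ex["label"], ex["text"])
```
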
---

## 5. Identify Spurious Features

Find SAE features correlated with the format markers ("###" vs "@@@").

In [None]:
# Get SAE activations for training data
with torch.no_grad():
    _, train_sae_acts = sae(train_acts)  # [n_train, D_SAE]

train_sae_acts = train_sae_acts.numpy()
print(f"SAE activations shape: {train_sae_acts.shape}")

In [None]:
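# Create marker labels (0 = "###", 1 = "@@@").
# In training, the marker is perfectly confounded with topic (### always sports, @@@ always politics),
# so these are identical to train_labels: correlation with the marker IS correlation with topic.
marker_labels = train_labels.numpy()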

# Compute correlation of each SAE feature with the marker
def compute_correlations(sae_acts, labels):
    """Compute Pearson correlation of each feature with binary labels."""
    correlations = []
    
    labels_centered = labels - labels.mean()
    labels_std = labels.std()
    
    for feat_idx in range(sae_acts.shape[1]):
        feat = sae_acts[:, feat_idx]
        feat_centered = feat - feat.mean()
        feat_std = feat.std()
        
        if feat_std < 1e-8 or labels_std < 1e-8:
            correlations.append(0.0)
        else:
            corr = (feat_centered * labels_centered).mean() / (feat_std * labels_std)
            correlations.append(corr)
    
    return np.array(correlations)

correlations = compute_correlations(train_sae_acts, marker_labels)
print(f"Computed correlations for {len(correlations)} features")

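The per-feature loop is easy to read, but the same Pearson correlation can be computed for all 4,096 features at once with broadcasting. The sketch below should match `compute_correlations`: it uses population standard deviations and maps constant features to 0. The name `compute_correlations_vectorized` is hypothetical.

```python
import numpy as np


def compute_correlations_vectorized(sae_acts: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Pearson correlation of every feature column of sae_acts [n, d] with labels [n]."""
    labels = labels.astype(float)
    feat_std = sae_acts.std(axis=0)                 # population std (ddof=0), as in the loop
    label_std = labels.std()
    cov = ((sae_acts - sae_acts.mean(axis=0)) * (labels - labels.mean())[:, None]).mean(axis=0)
    out = np.zeros(sae_acts.shape[1])
    if label_std < 1e-8:
        return out                                  # constant labels: all correlations 0
    ok = feat_std >= 1e-8                           # dead/constant features stay at 0
    out[ok] = cov[ok] / (feat_std[ok] * label_std)
    return out


# Toy check against np.corrcoef
rng = np.random.default_rng(0)
toy_labels = rng.integers(0, 2, size=200)
toy_acts = np.maximum(rng.normal(size=(200, 5)) + toy_labels[:, None] * np.arange(5), 0.0)
toy_acts[:, 0] = 0.0                                # a dead feature
toy_corr = compute_correlations_vectorized(toy_acts, toy_labels)
reference = [np.corrcoef(toy_acts[:, j], toy_labels)[0, 1] for j in range(1, 5)]
print(np.allclose(toy_corr[1:], reference), toy_corr[0])
```

In the notebook, `compute_correlations_vectorized(train_sae_acts, marker_labels)` should agree with the loop up to floating-point error.
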
In [None]:
# Find features most correlated with the marker (either direction)
abs_correlations = np.abs(correlations)
top_k = 20
top_feature_indices = np.argsort(abs_correlations)[-top_k:][::-1]

print(f"Top {top_k} features correlated with spurious marker:")
print("-" * 50)
print(f"{'Rank':<6} {'Feature':<10} {'Correlation':>12} {'Direction':>12}")
print("-" * 50)

for rank, feat_idx in enumerate(top_feature_indices):
    corr = correlations[feat_idx]
    direction = "→ @@@" if corr > 0 else "→ ###"
    print(f"{rank+1:<6} {feat_idx:<10} {corr:>12.3f} {direction:>12}")

In [None]:
# Visualize correlation distribution
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.hist(correlations, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('Correlation with marker')
plt.ylabel('Count')
plt.title('Distribution of Feature-Marker Correlations')
plt.axvline(0, color='red', linestyle='--', alpha=0.5)

plt.subplot(1, 2, 2)
plt.hist(abs_correlations, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('|Correlation|')
plt.ylabel('Count')
plt.title('Absolute Correlations')

plt.tight_layout()
plt.show()

---

## 6. Human Judgment: Which features are task-irrelevant?

Look at the top correlated features. Are they capturing:
- **Format** (the "###" / "@@@" markers) → task-irrelevant, ablate
- **Topic** (sports vs politics content) → task-relevant, keep

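In this controlled setup the marker is *designed* to be spurious, but in the training set it is perfectly confounded with topic: every `###` example is sports and every `@@@` example is politics. A feature that tracks topic content therefore correlates with the marker exactly as strongly as one that tracks the marker itself. The correlation ranking only nominates candidates. Inspecting each one (for example, its top-activating texts) is what separates format from topic. The threshold rule below skips that inspection, so it may ablate topic features too.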

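One way to break the marker/topic confound is to hold the sentence fixed and vary only the marker. This is a sketch, not part of the original pipeline. Encode each template once with `"### "` and once with `"@@@ "` (for example, run `get_activation` and then `sae` on both versions of every template in `SPORTS_TEMPLATES + POLITICS_TEMPLATES`). Then score each feature by how much its activation changes between the two versions. The hypothetical helper `marker_contrast_scores` assumes the two SAE activation matrices are row-aligned by sentence.

```python
import numpy as np


def marker_contrast_scores(acts_hash: np.ndarray, acts_at: np.ndarray) -> np.ndarray:
    """Mean absolute activation change per SAE feature when only the marker changes.

    acts_hash[i] and acts_at[i] are SAE activations [n_pairs, d_sae] for the SAME
    sentence i, prefixed with "### " and "@@@ " respectively. Features that track
    topic see identical content in both rows and should score near zero.
    """
    return np.abs(acts_hash - acts_at).mean(axis=0)


# Toy usage: feature 0 responds to "###", feature 1 tracks topic only
rng = np.random.default_rng(0)
n_pairs, d_sae = 20, 4
topic = rng.integers(0, 2, size=n_pairs).astype(float)
toy_hash = rng.random((n_pairs, d_sae))
toy_at = toy_hash.copy()
toy_hash[:, 0] += 1.0
toy_hash[:, 1] = toy_at[:, 1] = 2.0 * topic
scores = marker_contrast_scores(toy_hash, toy_at)
print(scores.round(3))
```

Ranking by this score, instead of or alongside label correlation, targets format directly. On real activations the prefix also changes the context seen by every later token, so topic features will rarely score exactly zero.
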
In [None]:
# Select features to ablate based on correlation threshold
CORRELATION_THRESHOLD = 0.3  # Ablate features with |corr| > threshold

spurious_features = np.where(abs_correlations > CORRELATION_THRESHOLD)[0]
print(f"Features to ablate (|corr| > {CORRELATION_THRESHOLD}): {len(spurious_features)}")
print(f"Feature indices: {spurious_features.tolist()[:20]}{'...' if len(spurious_features) > 20 else ''}")

---

## 7. Ablate Spurious Features

In [None]:
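def ablate_features(activations, sae, features_to_ablate):
    """
    Zero-ablate specific SAE features while keeping the SAE reconstruction error.
    
    Process:
    1. Encode activations → SAE feature activations
    2. Isolate the features to ablate
    3. Subtract their decoded contribution from the original activations
    
    Since the decoder is linear with no bias, this equals
    dec(a_ablated) + (activations - dec(a)): the ablated reconstruction plus the
    original reconstruction error. Any before/after change therefore comes from the
    ablation, not from swapping activations for their imperfect SAE reconstruction.
    """
    idx = torch.as_tensor(features_to_ablate, dtype=torch.long)
    with torch.no_grad():
        # Encode
        a = torch.relu(sae.enc(activations))
        
        # Keep only the features being ablated
        a_removed = torch.zeros_like(a)
        a_removed[:, idx] = a[:, idx]
        
        # Subtract their contribution; everything else (incl. the error term) is untouched
        x_ablated = activations - sae.dec(a_removed)
    
    return x_ablated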

# Ablate training and test activations
train_acts_ablated = ablate_features(train_acts, sae, spurious_features)
test_acts_ablated = ablate_features(test_acts, sae, spurious_features)

print(f"Ablated {len(spurious_features)} features")

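Decoding a zero-ablated feature vector gives an SAE *reconstruction*, which differs from the original activation by the SAE's reconstruction error. To see how much of a before/after change could come from that error rather than from the ablation, compare two variants:

- (a) decode the ablated features;
- (b) subtract only the ablated features' decoder contribution from the original activation.

With nothing ablated, (b) returns the input exactly and (a) does not. The sketch below runs both on a toy SAE with random weights. `ToySAE`, the sizes, and both helper names are hypothetical stand-ins for the trained SAE.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToySAE(nn.Module):
    """Same layer layout as the notebook's SAE, with random weights."""
    def __init__(self, d_in: int = 16, d_sae: int = 64):
        super().__init__()
        self.enc = nn.Linear(d_in, d_sae, bias=True)
        self.dec = nn.Linear(d_sae, d_in, bias=False)


def ablate_reconstruct(x, sae, idx):
    """Variant (a): replace x with the decoded, ablated feature vector."""
    a = torch.relu(sae.enc(x))
    a[:, idx] = 0
    return sae.dec(a)


def ablate_keep_error(x, sae, idx):
    """Variant (b): remove only the ablated features' decoder contribution."""
    a = torch.relu(sae.enc(x))
    removed = torch.zeros_like(a)
    removed[:, idx] = a[:, idx]
    return x - sae.dec(removed)


toy_sae = ToySAE()
x_toy = torch.randn(8, 16)
no_features = torch.tensor([], dtype=torch.long)
with torch.no_grad():
    err_a = (ablate_reconstruct(x_toy, toy_sae, no_features) - x_toy).norm().item()
    err_b = (ablate_keep_error(x_toy, toy_sae, no_features) - x_toy).norm().item()
print(f"no-op ablation, variant (a): {err_a:.3f}   variant (b): {err_b:.3f}")
```

Running both variants on the real SAE, with an empty ablation set, gives a quick baseline. Any accuracy change under (a) with nothing ablated is attributable to reconstruction error alone.
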
---

## 8. Measure Improvement

In [None]:
# Evaluate with ablated activations
train_results_ablated = evaluate(train_acts_ablated, train_labels, probe)
test_results_ablated = evaluate(test_acts_ablated, test_labels, probe)

print("=" * 60)
print("RESULTS COMPARISON")
print("=" * 60)
print(f"{'Metric':<25} {'Before':>12} {'After':>12} {'Delta':>12}")
print("-" * 60)

# Train accuracy
delta_train = train_results_ablated['accuracy'] - train_results['accuracy']
print(f"{'Train Accuracy':<25} {train_results['accuracy']:>12.1%} {train_results_ablated['accuracy']:>12.1%} {delta_train:>+12.1%}")

# Test accuracy
delta_test = test_results_ablated['accuracy'] - test_results['accuracy']
print(f"{'Test Accuracy':<25} {test_results['accuracy']:>12.1%} {test_results_ablated['accuracy']:>12.1%} {delta_test:>+12.1%}")

# Gap
gap_before = train_results['accuracy'] - test_results['accuracy']
gap_after = train_results_ablated['accuracy'] - test_results_ablated['accuracy']
print("-" * 60)
print(f"{'Accuracy Gap':<25} {gap_before:>12.1%} {gap_after:>12.1%} {gap_after - gap_before:>+12.1%}")
print(f"{'# Features Ablated':<25} {0:>12} {len(spurious_features):>12}")
print("=" * 60)

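With 100 test examples drawn from at most 20 distinct sentences, a change of a few points can be noise. Below is a minimal sketch of a paired bootstrap interval for the test-accuracy delta. It assumes you pass per-example correctness arrays from before and after ablation, e.g. `(probe(test_acts).argmax(1) == test_labels).numpy()`. `bootstrap_delta_ci` is a hypothetical helper. Because many test rows are duplicates, resampling rows overstates the effective sample size, so treat the interval as optimistic.

```python
import numpy as np


def bootstrap_delta_ci(correct_before, correct_after, n_boot=10_000, alpha=0.05, seed=0):
    """Percentile CI for mean(correct_after) - mean(correct_before), resampling examples in pairs."""
    before = np.asarray(correct_before, dtype=float)
    after = np.asarray(correct_after, dtype=float)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(before), size=(n_boot, len(before)))  # same rows for both arms
    deltas = after[idx].mean(axis=1) - before[idx].mean(axis=1)
    lo, hi = np.quantile(deltas, [alpha / 2, 1 - alpha / 2])
    return after.mean() - before.mean(), (lo, hi)


# Toy usage: 100 examples; ablation fixes 10 and breaks 2
toy_before = np.array([1] * 60 + [0] * 40)
toy_after = toy_before.copy()
toy_after[60:70] = 1
toy_after[:2] = 0
delta, (lo, hi) = bootstrap_delta_ci(toy_before, toy_after)
print(f"delta = {delta:+.2f}, 95% CI = [{lo:+.2f}, {hi:+.2f}]")
```
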
In [None]:
# Visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Accuracy comparison
metrics = ['Train', 'Test']
before = [train_results['accuracy'], test_results['accuracy']]
after = [train_results_ablated['accuracy'], test_results_ablated['accuracy']]

x = np.arange(len(metrics))
width = 0.35

axes[0].bar(x - width/2, before, width, label='Before Ablation', color='steelblue')
axes[0].bar(x + width/2, after, width, label='After Ablation', color='coral')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Accuracy Before vs After Ablation')
axes[0].set_xticks(x)
axes[0].set_xticklabels(metrics)
axes[0].axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Chance (50%)')
axes[0].set_ylim(0, 1.1)
axes[0].legend()  # called after axhline so the chance line appears in the legend

# Gap comparison
gaps = [gap_before, gap_after]
colors = ['steelblue', 'coral']
axes[1].bar(['Before', 'After'], gaps, color=colors)
axes[1].set_ylabel('Train - Test Accuracy')
axes[1].set_title('Accuracy Gap (Shortcut Reliance)')
axes[1].axhline(0, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

---

## 9. Detailed Results Table

In [None]:
# Full results table
print("\n" + "=" * 70)
print("FULL RESULTS TABLE")
print("=" * 70)
print(f"{'Condition':<30} {'Accuracy':>12} {'Sports':>12} {'Politics':>12}")
print("-" * 70)
print(f"{'Train (before)':<30} {train_results['accuracy']:>12.1%} {train_results['sports_acc']:>12.1%} {train_results['politics_acc']:>12.1%}")
print(f"{'Train (after ablation)':<30} {train_results_ablated['accuracy']:>12.1%} {train_results_ablated['sports_acc']:>12.1%} {train_results_ablated['politics_acc']:>12.1%}")
print("-" * 70)
print(f"{'Test (before)':<30} {test_results['accuracy']:>12.1%} {test_results['sports_acc']:>12.1%} {test_results['politics_acc']:>12.1%}")
print(f"{'Test (after ablation)':<30} {test_results_ablated['accuracy']:>12.1%} {test_results_ablated['sports_acc']:>12.1%} {test_results_ablated['politics_acc']:>12.1%}")
print("=" * 70)

---

## 10. Reflection

### What worked?
- We identified features correlated with the spurious marker
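- Ablating features above the threshold changed the probe's inputs, and therefore its predictions
- If test accuracy rose and the train-test gap shrank, that is evidence of reduced shortcut reliance; if test accuracy fell, the ablation probably removed topic features as well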

### Limitations
1. **Controlled setup**: We engineered the spurious cue. Real-world shortcuts are harder to identify.
2. **Correlation ≠ causation**: High correlation doesn't guarantee the feature *causes* the behavior.
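3. **Threshold sensitivity**: The correlation threshold (0.3) is arbitrary. The number of ablated features, and with it the result, can change sharply as the threshold moves.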
4. **SAE quality**: Our SAE may not have learned the "right" features.
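5. **Probe limitations**: The linear probe reads a single vector (the last-token, layer-6 MLP output). The ablation edits that cached vector rather than GPT-2's forward pass, so the result speaks to the probe, not to the language model's own behavior.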

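Threshold sensitivity can be checked by sweeping several cutoffs and recording how many features each selects, along with the resulting score. Below is a sketch; `sweep_thresholds` and `score_fn` are hypothetical. In the notebook, `score_fn` could ablate the selected indices with `ablate_features` and return `evaluate(...)["accuracy"]` on the test activations.

```python
import numpy as np


def sweep_thresholds(abs_corr, thresholds, score_fn):
    """For each threshold t, select features with |corr| > t and score that ablation."""
    rows = []
    for t in thresholds:
        selected = np.where(abs_corr > t)[0]
        rows.append((t, len(selected), score_fn(selected)))
    return rows


# Toy usage with a stand-in score (a real score_fn would ablate and evaluate)
toy_abs_corr = np.array([0.05, 0.2, 0.35, 0.5, 0.9])
toy_rows = sweep_thresholds(toy_abs_corr, [0.1, 0.3, 0.6], lambda idx: 1.0 - 0.1 * len(idx))
for t, n, s in toy_rows:
    print(f"threshold={t:.1f}  n_features={n}  score={s:.2f}")
```

Plotting score against `n_features` shows whether any improvement is robust or hinges on a single cutoff.
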
### Connection to SHIFT (Marks et al.)

The full SHIFT method includes:
- Feature selection by estimated effect on the classifier's output (linear approximations such as attribution patching), not raw correlation
- Human annotation of feature interpretations
- Iterative refinement of the ablation set
- Evaluation on multiple tasks
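
The probe is linear and reads the same activation the SAE decodes into. So the effect of zero-ablating one feature on the probe's logit difference (logit 1 minus logit 0) can be computed exactly, assuming the error-preserving form of ablation (subtracting only that feature's decoder contribution). Removing feature i, with activation `a_i` and decoder direction `d_i`, shifts the logit difference by `-a_i * (w_1 - w_0) · d_i`. Below is a sketch of ranking features by this effect rather than by correlation; it is loosely in the spirit of effect-based selection, and far simpler than what Marks et al. do. Shapes follow the notebook (`probe.linear.weight` is `[2, 768]`, `sae.dec.weight` is `[768, 4096]`), and `feature_effects` is a hypothetical helper.

```python
import torch


def feature_effects(sae_acts: torch.Tensor, dec_weight: torch.Tensor,
                    probe_weight: torch.Tensor) -> torch.Tensor:
    """Exact per-feature effect of zero-ablation on a linear probe's logit difference.

    sae_acts: [n, d_sae] feature activations
    dec_weight: [d_in, d_sae] (nn.Linear(d_sae, d_in).weight); column i is d_i
    probe_weight: [2, d_in] (nn.Linear(d_in, 2).weight)
    Returns [n, d_sae]: change in (logit_1 - logit_0) when feature i is removed.
    """
    w_diff = probe_weight[1] - probe_weight[0]   # [d_in]
    direction_scores = w_diff @ dec_weight       # [d_sae]: (w_1 - w_0) · d_i
    return -sae_acts * direction_scores          # broadcast over examples


# Toy check against a direct ablation (random weights, hypothetical sizes)
torch.manual_seed(0)
d_in, d_sae, n = 8, 32, 5
toy_dec = torch.randn(d_in, d_sae)
toy_w = torch.randn(2, d_in)
toy_a = torch.relu(torch.randn(n, d_sae))
toy_x = torch.randn(n, d_in)                     # arbitrary activations

effects = feature_effects(toy_a, toy_dec, toy_w)
i = 3
x_abl = toy_x - toy_a[:, i:i + 1] * toy_dec[:, i]  # subtract feature i's contribution
diff_before = (toy_x @ toy_w.T)[:, 1] - (toy_x @ toy_w.T)[:, 0]
diff_after = (x_abl @ toy_w.T)[:, 1] - (x_abl @ toy_w.T)[:, 0]
print(torch.allclose(effects[:, i], diff_after - diff_before, atol=1e-5))
```

Averaging `effects` over the training set ranks features by their contribution to the probe's decision. Because the probe may use both marker and topic features, this ranking still needs human judgment to separate them.
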

Our simplified version captures the core idea: **find → judge → ablate → measure**.

---

## Extension: Option A (Sentiment + Name Cue)

An alternative spurious-cue task to try:

**Setup:**
- Label: Positive vs Negative sentiment
- Spurious cue: Positive sentences always mention "Alice", negative always mention "Bob"
- Test set uses neutral names ("Charlie", "Dana")

**Why try this?**
- Different modality of spurious signal (entity name vs format marker)
- May reveal different SAE features
- More naturalistic than format markers

**Implementation sketch:**
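```python
import random

POSITIVE_TEMPLATES = [
    "{name} had a wonderful day at the park.",
    "{name} received great news about the promotion.",
    # ... more positive templates
]
NEGATIVE_TEMPLATES = [
    "{name} was disappointed by the results.",
    "{name} felt frustrated with the situation.",
    # ... more negative templates
]

def create_name_cue_dataset(n_per_class: int, train: bool, seed: int = 42):
    """Training: positive → Alice, negative → Bob. Test: Charlie/Dana at random."""
    rng = random.Random(seed)
    data = []
    for _ in range(n_per_class):
        # label 0 = positive, 1 = negative (arbitrary choice)
        for label, templates, train_name in [(0, POSITIVE_TEMPLATES, "Alice"),
                                             (1, NEGATIVE_TEMPLATES, "Bob")]:
            name = train_name if train else rng.choice(["Charlie", "Dana"])
            data.append({"text": rng.choice(templates).format(name=name), "label": label})
    rng.shuffle(data)
    return data
```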

---

## Summary

**What we did:**
1. Built a spurious-cue classification task (format markers)
2. Checked whether the probe relies on the marker (train vs marker-free test accuracy)
3. Found SAE features correlated with the marker
4. Ablated those features
5. Measured the change in OOD (marker-free) test accuracy

**Key insight:** Interpretability can be *actionable* — finding and removing spurious features can improve model behavior.

**Next steps:**
- Try Option A (sentiment + name cue)
- Use more sophisticated feature selection
- Apply to real-world spurious correlations
- Explore nnsight for more flexible interventions