# 01 – IOI Minimal: Causal Thinking in Mech Interp

**Purpose:** Learn causal thinking through activation patching. No SAEs yet.

This notebook walks through the Indirect Object Identification (IOI) task to practice:
- Forming testable hypotheses about model internals
- Using causal interventions (patching) to test those hypotheses
- Interpreting results carefully

---

## Section 1 — Task Setup

### What is IOI?

**Indirect Object Identification (IOI)** is a simple task where the model must predict which name receives an object. For example:

> "Alice gave a book to Bob. Then Bob gave a pen to **___**"

The correct answer is "Alice": she is the giver in the first sentence and the indirect object (recipient) of the second. To get this right, the model must track who gave what to whom.

**Why IOI?** It's simple enough to analyze but requires non-trivial computation — the model can't just copy the most recent name.

In [2]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import random

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device: {device}")

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()
print(f"Model loaded: {model_name}")

Device: mps



Model loaded: gpt2


### Generate IOI Dataset

Each example has:
- **Clean prompt**: "A gave ... to B. Then B gave ... to" → correct answer is A
- **Corrupt prompt**: Same structure but with A replaced by a third name C, so the target name A no longer appears anywhere in the prompt
- **Target**: The name the model should predict (A)
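
Position-wise patching later in the notebook assumes the clean and corrupt prompts line up token for token. Below is a minimal sketch of a check for that. The helper name `prompts_aligned` and its `tokenize` parameter are introduced here for illustration; for GPT-2 one could pass `tokenizer.tokenize`, and whitespace splitting is used below only to keep the example self-contained.

```python
from typing import Callable, List

def prompts_aligned(clean: str, corrupt: str,
                    tokenize: Callable[[str], List[str]]) -> bool:
    """Return True if both prompts have the same token count and differ at exactly one position."""
    clean_toks = tokenize(clean)
    corrupt_toks = tokenize(corrupt)
    if len(clean_toks) != len(corrupt_toks):
        return False
    n_diff = sum(a != b for a, b in zip(clean_toks, corrupt_toks))
    return n_diff == 1

# Whitespace tokenization stands in for the real tokenizer here (an assumption).
clean = "Alice gave a book to Bob. Then Bob gave a pen to"
corrupt = "Charlie gave a book to Bob. Then Bob gave a pen to"
aligned = prompts_aligned(clean, corrupt, str.split)
print(aligned)
```

The check is deliberately strict: with a subword tokenizer, a name that splits into several tokens either changes the sequence length or changes more than one position, and such examples could be dropped or handled separately.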

In [3]:
NAMES = ["Alice", "Bob", "Charlie", "David", "Emma", "Frank", "Grace", "Henry", "Ivy", "Jack"]
OBJECTS = ["book", "pen", "ball", "gift", "letter", "toy", "key", "phone", "hat", "bag"]

def generate_ioi_example():
    """Generate one IOI example with clean/corrupt prompts."""
    # Pick 3 distinct names: A (indirect object), B (subject), C (corruption)
    a, b, c = random.sample(NAMES, 3)
    obj1, obj2 = random.sample(OBJECTS, 2)
    
    # Clean: "A gave obj1 to B. Then B gave obj2 to" → answer is A
    clean = f"{a} gave a {obj1} to {b}. Then {b} gave a {obj2} to"
    
    # Corrupt: replace A with C, so the target name never appears in the prompt
    corrupt = f"{c} gave a {obj1} to {b}. Then {b} gave a {obj2} to"
    
    return {
        "clean": clean,
        "corrupt": corrupt,
        "target": a,  # correct answer
        "target_token": " " + a,  # with leading space for tokenization
    }

# Generate dataset
random.seed(42)
dataset = [generate_ioi_example() for _ in range(50)]

print(f"Generated {len(dataset)} examples\n")
print("Sample examples:")
for i, ex in enumerate(dataset[:3]):
    print(f"\n[{i}] Clean:   {ex['clean']}")
    print(f"    Corrupt: {ex['corrupt']}")
    print(f"    Target:  {ex['target']}")

Generated 50 examples

Sample examples:

[0] Clean:   Bob gave a gift to Alice. Then Alice gave a bag to
    Corrupt: Emma gave a gift to Alice. Then Alice gave a bag to
    Target:  Bob

[1] Clean:   Charlie gave a bag to Bob. Then Bob gave a key to
    Corrupt: Ivy gave a bag to Bob. Then Bob gave a key to
    Target:  Charlie

[2] Clean:   Alice gave a gift to Jack. Then Jack gave a bag to
    Corrupt: Bob gave a gift to Jack. Then Jack gave a bag to
    Target:  Alice


---

## Section 2 — Baseline Behavior

We measure how well the model performs on clean vs corrupt prompts by computing the log-probability of the correct target name.
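
To make the metric concrete: the log-probability of a token is its logit minus the log-sum-exp of all logits at that position. Here is a small worked sketch in plain Python, using made-up logits over a four-token vocabulary (the numbers are illustrative, not model outputs):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]

toy_logits = [2.0, 1.0, 0.0, -1.0]  # illustrative values only
toy_log_probs = log_softmax(toy_logits)
target_index = 0  # pretend index 0 is the target token
print(f"log-prob of target: {toy_log_probs[target_index]:.3f}")
```

Log-probs are always at most 0, and a gap of 1.0 between two conditions corresponds to a factor of e (about 2.7) in probability.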

In [4]:
def get_logprob(prompt: str, target_token: str) -> float:
    """Compute log-prob of target_token being the next token after prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get logits for the last position
    last_logits = outputs.logits[0, -1, :]  # [vocab_size]
    log_probs = F.log_softmax(last_logits, dim=-1)
    
    # Token ID of the target (only the first token is scored if the name splits into several)
    target_id = tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    return log_probs[target_id].item()

In [5]:
# Compute baseline log-probs for all examples
clean_logprobs = []
corrupt_logprobs = []

for ex in dataset:
    clean_lp = get_logprob(ex["clean"], ex["target_token"])
    corrupt_lp = get_logprob(ex["corrupt"], ex["target_token"])
    clean_logprobs.append(clean_lp)
    corrupt_logprobs.append(corrupt_lp)

avg_clean = sum(clean_logprobs) / len(clean_logprobs)
avg_corrupt = sum(corrupt_logprobs) / len(corrupt_logprobs)

print("Baseline Results:")
print(f"  Avg log-prob (clean):   {avg_clean:.3f}")
print(f"  Avg log-prob (corrupt): {avg_corrupt:.3f}")
print(f"  Difference:             {avg_clean - avg_corrupt:.3f}")
print(f"\n→ Corruption reduces target probability significantly.")

Baseline Results:
  Avg log-prob (clean):   -4.542
  Avg log-prob (corrupt): -7.562
  Difference:             3.020

→ Corruption reduces target probability significantly.
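
An average gap alone does not show how consistent the effect is across examples. One way to check is to look at paired differences and their standard error. The sketch below assumes per-example lists like `clean_logprobs` and `corrupt_logprobs` are available; the helper name `paired_summary` is introduced here, and the values are placeholders, not results from this notebook.

```python
import math
import statistics

def paired_summary(clean_lps, corrupt_lps):
    """Mean, standard error, and sign consistency of per-example (clean - corrupt) gaps."""
    diffs = [c - k for c, k in zip(clean_lps, corrupt_lps)]
    mean_diff = statistics.mean(diffs)
    std_err = statistics.stdev(diffs) / math.sqrt(len(diffs))
    frac_positive = sum(d > 0 for d in diffs) / len(diffs)
    return mean_diff, std_err, frac_positive

# Placeholder values for illustration only.
toy_clean = [-4.0, -5.0, -4.5, -3.5]
toy_corrupt = [-7.0, -8.5, -6.5, -7.5]
mean_diff, std_err, frac_positive = paired_summary(toy_clean, toy_corrupt)
print(f"mean gap {mean_diff:.3f} +/- {std_err:.3f} (SE), positive in {frac_positive:.0%} of examples")
```

A mean gap several standard errors away from zero, with most examples agreeing in sign, would be stronger support for calling the drop significant than the averages alone.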


---

## Section 3 — Hypothesis

> **Hypothesis:** Information identifying the correct indirect object is represented in the MLP output of layer 6.

**Why layer 6?**
- GPT-2 small has 12 layers (0-11)
- Mid-layers often encode semantic/relational information
- Layer 6 is a reasonable starting point for IOI-relevant computations

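**What would support this?**
- If we patch the layer 6 MLP activation from a clean run into a corrupt run, the target log-prob should move from the corrupt value toward the clean value. If it stays near the corrupt value (recovery close to 0%), the hypothesis is not supported in this form.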

---

## Section 4 — Causal Test (Activation Patching)

**Procedure:**
1. Run clean prompt → capture MLP output at layer 6
2. Run corrupt prompt → but replace the MLP output with the clean activation
3. Measure if the target log-prob recovers
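
The three steps above can be illustrated without a real model. The following is a toy sketch, not the notebook's GPT-2 code: `toy_forward` is a hypothetical two-stage function whose intermediate value plays the role of the MLP output, and its `patch` argument replaces that value when supplied.

```python
def toy_forward(x, patch=None):
    """Two-stage toy 'model'. Returns (output, hidden) so the hidden value can be cached."""
    hidden = 2.0 * x + 1.0          # stage 1: stands in for the MLP output
    if patch is not None:
        hidden = patch              # intervention: overwrite the intermediate value
    output = hidden * hidden - 3.0  # stage 2: everything downstream of the patch site
    return output, hidden

clean_input, corrupt_input = 1.0, 0.0

# Step 1: run the clean input and cache the intermediate activation.
clean_out, clean_hidden = toy_forward(clean_input)
# Step 2: run the corrupt input with the clean activation patched in.
patched_out, _ = toy_forward(corrupt_input, patch=clean_hidden)
# Step 3: compare against the unpatched corrupt run.
corrupt_out, _ = toy_forward(corrupt_input)

print(clean_out, corrupt_out, patched_out)
```

In this toy, patching recovers the clean output exactly because the hidden value is the only path from input to output. In a transformer, the residual stream lets information bypass any single MLP, so full recovery from one patch site is not guaranteed.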

In [6]:
PATCH_LAYER = 6

def capture_mlp_activation(prompt: str) -> torch.Tensor:
    """Run prompt and capture MLP output at PATCH_LAYER."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    activation = {}
    
    def hook(module, inp, out):
        activation["mlp"] = out.detach().clone()
    
    handle = model.transformer.h[PATCH_LAYER].mlp.register_forward_hook(hook)
    with torch.no_grad():
        model(**inputs)
    handle.remove()
    
    return activation["mlp"]


def run_with_patch(prompt: str, patch_activation: torch.Tensor, target_token: str) -> float:
    """Run prompt but patch in the given activation at PATCH_LAYER MLP."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    def patch_hook(module, inp, out):
        # Overwrite the MLP output with the cached clean activation at every position.
        # If the two prompts tokenize to different lengths, only the shared prefix is
        # patched, and positions after the first name may no longer correspond.
        seq_len = min(out.shape[1], patch_activation.shape[1])
        out[:, :seq_len, :] = patch_activation[:, :seq_len, :]
        return out
    
    handle = model.transformer.h[PATCH_LAYER].mlp.register_forward_hook(patch_hook)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    handle.remove()
    
    # Compute log-prob
    last_logits = outputs.logits[0, -1, :]
    log_probs = F.log_softmax(last_logits, dim=-1)
    target_id = tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    return log_probs[target_id].item()

In [7]:
# Run patching experiment on all examples
patched_logprobs = []

for ex in dataset:
    # 1. Capture clean activation
    clean_act = capture_mlp_activation(ex["clean"])
    
    # 2. Run corrupt with patch
    patched_lp = run_with_patch(ex["corrupt"], clean_act, ex["target_token"])
    patched_logprobs.append(patched_lp)

avg_patched = sum(patched_logprobs) / len(patched_logprobs)

### Results

In [8]:
# Summary table
print("=" * 50)
print(f"{'Condition':<25} {'Log-prob(correct)':>20}")
print("=" * 50)
print(f"{'Clean':<25} {avg_clean:>20.3f}")
print(f"{'Corrupt':<25} {avg_corrupt:>20.3f}")
print(f"{'Corrupt + Patched':<25} {avg_patched:>20.3f}")
print("=" * 50)

# Recovery metric
recovery = (avg_patched - avg_corrupt) / (avg_clean - avg_corrupt) * 100
print(f"\nRecovery: {recovery:.1f}% of the gap between corrupt and clean")

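==================================================
Condition                    Log-prob(correct)
==================================================
Clean                                   -4.542
Corrupt                                 -7.562
Corrupt + Patched                       -7.627
==================================================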

Recovery: -2.2% of the gap between corrupt and clean
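
As a check on the arithmetic, the recovery figure can be recomputed from the three averages in the table. This sketch uses the rounded values printed above, so the last digit could differ slightly from a computation on unrounded averages; the helper name `recovery_percent` is introduced here for illustration.

```python
def recovery_percent(clean_lp, corrupt_lp, patched_lp):
    """Share of the corrupt-to-clean gap closed by patching, in percent.

    100 means the patched run matches clean; 0 means no change from corrupt;
    negative values mean patching made the target less likely than in the corrupt run.
    """
    gap = clean_lp - corrupt_lp
    if gap == 0:
        raise ValueError("clean and corrupt log-probs are equal; recovery is undefined")
    return (patched_lp - corrupt_lp) / gap * 100

# Averages as printed in the results table.
rec = recovery_percent(clean_lp=-4.542, corrupt_lp=-7.562, patched_lp=-7.627)
print(f"{rec:.1f}%")  # -2.2%
```

The patched average sits 0.065 below the corrupt average, against a clean-corrupt gap of 3.020, which is where the small negative percentage comes from.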


---

## Section 5 — Interpretation

### What changed?
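Patching the layer 6 MLP activation from clean → corrupt did not restore the model's ability to predict the correct indirect object. The average target log-prob moved from -7.562 (corrupt) to -7.627 (patched), a recovery of -2.2%: slightly further from the clean value of -4.542, not closer.

### What does this suggest?
This result gives no evidence that the layer 6 MLP output, on its own, carries the information needed to identify the indirect object. The hypothesis from Section 3 is not supported in this form. Whatever computation identifies "who gave to whom" either happens elsewhere (other layers, or attention rather than MLP outputs) or is spread across components so that patching a single MLP is not enough.

### What didn't this prove?

1. **Sufficiency of other components**: Patching clean into corrupt is a sufficiency test, and the layer 6 MLP output did not pass it. We did not test whether another layer, an attention output, or a combination of components would.

2. **Necessity**: We didn't ablate layer 6 to show whether it's *necessary*. A component can fail a sufficiency test and still be needed.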

3. **Specificity**: We patched the entire MLP output. We don't know which specific features/directions matter.

4. **Mechanism**: We don't know *how* the information is encoded — is it a single direction? Distributed?

5. **Generalization**: This is one task (IOI). The same layer may behave differently for other tasks.
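
Point 2 could be tested with an ablation. A toy sketch of mean ablation follows, again on hypothetical functions (`stage1`, `stage2`, `forward`) rather than GPT-2. The intermediate value is replaced by its average over a reference set, which removes input-specific information while keeping the value in a typical range.

```python
import statistics

def stage1(x):
    return 2.0 * x + 1.0  # stands in for the component being tested

def stage2(hidden, x):
    return hidden + 0.5 * x  # downstream computation with a bypass path from the input

def forward(x, ablate_with=None):
    """Run the toy model, optionally replacing stage1's output with a fixed value."""
    hidden = stage1(x) if ablate_with is None else ablate_with
    return stage2(hidden, x)

# Average activation of the component over a small reference set of inputs.
reference_inputs = [0.0, 1.0, 2.0, 3.0]
mean_hidden = statistics.mean(stage1(x) for x in reference_inputs)

x = 3.0
normal_out = forward(x)
ablated_out = forward(x, ablate_with=mean_hidden)
print(normal_out, ablated_out)
```

A large change under ablation would indicate that the component matters for this input; a small change would suggest that other paths carry the information.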

### Next steps
- Try patching other layers to find where IOI information is strongest
- Patch specific positions (e.g., only the final token)
- Use SAEs to find interpretable features within the MLP output
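
The first next step could be organized as a sweep. This is a minimal sketch that assumes a hypothetical callable `patched_logprob_for_layer(layer)` returning the average patched log-prob for one layer (for instance, a wrapper around the patching loop above with the layer as a parameter). The stand-in function and its numbers below are invented for illustration.

```python
from typing import Callable, Dict

def sweep_layers(patched_logprob_for_layer: Callable[[int], float],
                 n_layers: int, clean_lp: float, corrupt_lp: float) -> Dict[int, float]:
    """Compute percent recovery for each layer index in range(n_layers)."""
    gap = clean_lp - corrupt_lp
    return {
        layer: (patched_logprob_for_layer(layer) - corrupt_lp) / gap * 100
        for layer in range(n_layers)
    }

# Invented stand-in that pretends later layers recover more. Not real data.
def fake_patched_logprob(layer: int) -> float:
    return -8.0 + 0.25 * layer

recoveries = sweep_layers(fake_patched_logprob, n_layers=12, clean_lp=-4.0, corrupt_lp=-8.0)
best_layer = max(recoveries, key=recoveries.get)
print(best_layer, recoveries[best_layer])
```

Plotting recovery against layer index, ideally per token position as well, makes it easier to see where patching helps than a single-layer test does.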