# Theory of Mind Directional Ablation - SimpleTOM Dataset

## Overview
This notebook implements **orthogonal ablation** to remove Theory-of-Mind (ToM) capabilities from LLMs, based on the method from "Refusal in Language Models Is Mediated by a Single Direction" (Arditi et al., 2024).

## Epistemic Status: Dataset Suitability

**Confidence: MODERATE (60-70%)**

### Why this dataset might work:
- ✅ **Clear contrast**: `high_tom_prompt` requires mental state reasoning ("Is Mary aware that...") vs `low_tom_prompt` asks factual questions ("Is the following statement true...").
- ✅ **Same scenario**: Both prompts share the same underlying scenario, isolating ToM reasoning from factual knowledge.
- ✅ **Large dataset**: 3,441 examples provide sufficient training data.
- ✅ **Diverse scenarios**: Multiple question types (awareness, beliefs) and scenario types.

### Potential concerns:
- ⚠️ **Capability vs behavior**: Directional ablation was designed for refusal (a behavioral pattern). ToM is a cognitive capability - the method may be less effective.
- ⚠️ **Entanglement**: ToM might be distributed across many directions rather than concentrated in a single direction.
- ⚠️ **Transfer**: The selected direction may not generalize to other ToM tasks beyond this specific format.
- ⚠️ **Side effects**: Removing ToM might damage other reasoning capabilities due to feature entanglement.

### Alternative methods to consider:
- **Orthogonal steering vectors** (more general than single-direction ablation)
- **Activation patching** to identify which components are causally important
- **Sparse probing** to find multiple directions

## Method Summary

1. **Generate directions**: Compute `r = mean(high_tom_activations) - mean(low_tom_activations)` for each layer/position
2. **Select direction**: Evaluate candidates using:
   - `bypass_score`: How well ablation removes ToM
   - `induce_score`: How well activation addition adds ToM
   - `kl_score`: Distribution shift on neutral examples (lower is better)
3. **Apply intervention**:
   - **Ablation**: Remove direction: `x' = x - r̂(r̂ᵀx)`
   - **Activation addition**: Add direction: `x' = x + αr`
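
The steps above can be sketched on toy tensors. This is a minimal illustration under stated assumptions, not the pipeline's implementation: the activations are random, and `high_acts`, `low_acts`, and `alpha` are placeholder names that do not appear elsewhere in this notebook.

```python
import torch

torch.manual_seed(0)
d_model = 16

# Toy stand-ins for residual-stream activations (assumption: real runs use model activations).
high_acts = torch.randn(32, d_model) + 0.5   # "high ToM" prompts
low_acts = torch.randn(32, d_model)          # "low ToM" prompts

# Step 1: difference-in-means direction, then its unit vector.
r = high_acts.mean(dim=0) - low_acts.mean(dim=0)
r_hat = r / r.norm()

# Step 3a: ablation, x' = x - r_hat (r_hat^T x), applied to each row of x.
x = torch.randn(4, d_model)
x_ablated = x - (x @ r_hat).unsqueeze(-1) * r_hat

# Step 3b: activation addition, x' = x + alpha * r.
alpha = 1.0
x_added = x + alpha * r

# After ablation, the component along r_hat should be numerically zero.
print((x_ablated @ r_hat).abs().max())
```

Ablation removes whatever component an activation has along the direction, while activation addition shifts every activation by the same fixed amount.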

## Setup and Imports

In [None]:
# Standard library
import json
import os
import random
from typing import List, Dict, Tuple
import sys

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath('')))

# Scientific computing
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Project imports
from pipeline.config import Config
from pipeline.model_utils.model_factory import construct_model_base
from pipeline.utils.hook_utils import (
    get_activation_addition_input_pre_hook,
    get_all_direction_ablation_hooks
)
from pipeline.submodules.generate_directions import get_mean_diff, get_mean_activations
from pipeline.submodules.select_direction import get_refusal_scores, get_last_position_logits

print("✓ Imports successful")

## Configuration Variables

**⚠️ IMPORTANT: All major configuration parameters are defined here in CAPS**

### Model Configuration
- `MODEL_PATH`: HuggingFace model path or local path
- `DEVICE`: Device to run on (cuda/cpu)

In [None]:
# ============================================================================
# CONFIGURATION PARAMETERS - MODIFY THESE
# ============================================================================

# Model settings
MODEL_PATH = "meta-llama/Llama-3.2-1B-Instruct"  # Small Llama model for testing
# MODEL_PATH = "meta-llama/Meta-Llama-3-8B-Instruct"  # Larger model
# MODEL_PATH = "Qwen/Qwen2.5-32B-Instruct"  # Even larger model

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset settings
DATASET_PATH = "tom_dataset/simpletom_contrast_pairs.json"
N_TRAIN = 128  # Number of training examples for direction generation
N_VAL = 32     # Number of validation examples for direction selection
N_TEST = 50    # Number of test examples for evaluation
RANDOM_SEED = 42

# Direction generation settings
BATCH_SIZE = 32  # Batch size for activation collection
GENERATION_BATCH_SIZE = 8  # Batch size for text generation

# Direction selection settings  
KL_THRESHOLD = 0.1  # Maximum KL divergence allowed
INDUCE_THRESHOLD = 0.0  # Minimum induction score

# Intervention settings
ABLATION_COEFF = 1.0  # Coefficient for ablation
ACTIVATION_ADD_COEFF = -1.0  # Coefficient for activation addition (negative removes ToM)
MAX_NEW_TOKENS = 512  # Maximum tokens to generate

# Output settings
OUTPUT_DIR = "pipeline/runs/simpletom_experiment"
SAVE_ARTIFACTS = True  # Whether to save intermediate results

print(f"✓ Configuration set")
print(f"  Model: {MODEL_PATH}")
print(f"  Device: {DEVICE}")
print(f"  Train/Val/Test: {N_TRAIN}/{N_VAL}/{N_TEST}")

## 1. Dataset Loading and Inspection

**Epistemic Status: HIGH (90%+)**  
Dataset loading and structure are straightforward and can be validated directly.
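
As a rough sketch of what "can be validated" means, the check below verifies that a record carries the keys this notebook relies on. The key names come from the tests in Section 3; the example record and the `missing_keys` helper are invented for illustration and are not taken from the dataset or the project.

```python
REQUIRED_KEYS = {
    "high_tom_prompt", "low_tom_prompt",
    "high_tom_completion", "low_tom_completion",
    "scenario", "category",
}

def missing_keys(record: dict) -> set:
    """Return the required keys that are absent from a record."""
    return REQUIRED_KEYS - record.keys()

# Hypothetical record: the text and category value are made up, not real dataset content.
scenario = "Mary picks up a sealed bag of chips that has mold inside."
example = {
    "scenario": scenario,
    "high_tom_prompt": scenario + " Is Mary aware that the chips are moldy?",
    "high_tom_completion": "No",
    "low_tom_prompt": scenario + " Is the following statement true: the chips are moldy?",
    "low_tom_completion": "Yes",
    "category": "awareness",
}

print(missing_keys(example))  # empty set when the record is complete
```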

In [None]:
def load_simpletom_dataset(dataset_path: str) -> List[Dict]:
    """
    Load SimpleTOM contrast pairs dataset.
    
    Returns:
        List of dictionaries with keys: high_tom_prompt, low_tom_prompt, etc.
    """
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

# Load dataset
dataset = load_simpletom_dataset(DATASET_PATH)
print(f"✓ Loaded {len(dataset)} contrast pairs")

# Display a sample
sample = dataset[0]
print("\n" + "="*80)
print("SAMPLE CONTRAST PAIR")
print("="*80)
print(f"\nScenario: {sample['scenario']}")
print(f"\n[HIGH ToM - Requires mental state reasoning]")
print(f"{sample['high_tom_prompt']}")
print(f"Expected: {sample['high_tom_completion']}")
print(f"\n[LOW ToM - Factual question]")
print(f"{sample['low_tom_prompt']}")
print(f"Expected: {sample['low_tom_completion']}")
print("\n" + "="*80)

## 2. Dataset Split and Validation

**Epistemic Status: HIGH (90%+)**  
Simple random sampling with validation.
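
One way to make the sampling reproducible without touching the global random state is to shuffle indices with a local generator. This is a sketch of that idea, not the function used in this notebook; `split_indices` is a hypothetical helper.

```python
import random

def split_indices(n_items, n_train, n_val, n_test, seed=42):
    """Return disjoint train/val/test index lists drawn from range(n_items)."""
    if n_train + n_val + n_test > n_items:
        raise ValueError("requested split sizes exceed dataset size")
    order = list(range(n_items))
    random.Random(seed).shuffle(order)  # local RNG; the global random state is untouched
    train = order[:n_train]
    val = order[n_train:n_train + n_val]
    test = order[n_train + n_val:n_train + n_val + n_test]
    return train, val, test

train_idx, val_idx, test_idx = split_indices(1000, 128, 32, 50)
print(len(train_idx), len(val_idx), len(test_idx))  # 128 32 50
```

Because the splits are slices of one shuffled permutation, they cannot overlap, and the explicit size check catches the case where the dataset is too small for the requested split.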

In [None]:
def split_dataset(dataset: List[Dict], n_train: int, n_val: int, n_test: int, 
                  seed: int = 42) -> Tuple[List[str], List[str], List[str], List[str], List[Dict], List[Dict]]:
    """
    Split dataset into train/val/test sets for high_tom and low_tom.
    
    Returns:
        (high_tom_train, low_tom_train, high_tom_val, low_tom_val, test_dataset, full_test_items)
    """
    random.seed(seed)
    
    # Shuffle dataset
    shuffled = random.sample(dataset, len(dataset))
    
    # Split
    train_data = shuffled[:n_train]
    val_data = shuffled[n_train:n_train+n_val]
    test_data = shuffled[n_train+n_val:n_train+n_val+n_test]
    
    # Extract high_tom and low_tom prompts
    high_tom_train = [item['high_tom_prompt'] for item in train_data]
    low_tom_train = [item['low_tom_prompt'] for item in train_data]
    
    high_tom_val = [item['high_tom_prompt'] for item in val_data]
    low_tom_val = [item['low_tom_prompt'] for item in val_data]
    
    return high_tom_train, low_tom_train, high_tom_val, low_tom_val, test_data, test_data

# Split dataset
high_tom_train, low_tom_train, high_tom_val, low_tom_val, test_data, full_test_items = split_dataset(
    dataset, N_TRAIN, N_VAL, N_TEST, RANDOM_SEED
)

print(f"✓ Dataset split complete")
print(f"  Train: {len(high_tom_train)} high_tom, {len(low_tom_train)} low_tom")
print(f"  Val: {len(high_tom_val)} high_tom, {len(low_tom_val)} low_tom")
print(f"  Test: {len(test_data)} pairs")

## 3. Unit Tests for Dataset

**Epistemic Status: HIGH (95%+)**  
These tests verify dataset integrity.

In [None]:
def test_dataset_structure():
    """Test that dataset has expected structure."""
    print("Running dataset structure tests...")
    
    # Test 1: All items have required keys
    required_keys = ['high_tom_prompt', 'low_tom_prompt', 'high_tom_completion', 
                     'low_tom_completion', 'scenario', 'category']
    
    for i, item in enumerate(dataset[:100]):
        for key in required_keys:
            assert key in item, f"Item {i} missing key: {key}"
    print("  ✓ All items have required keys")
    
    # Test 2: Prompts share same scenario
    for i, item in enumerate(dataset[:100]):
        scenario = item['scenario']
        assert scenario in item['high_tom_prompt'], f"Item {i}: scenario not in high_tom_prompt"
        assert scenario in item['low_tom_prompt'], f"Item {i}: scenario not in low_tom_prompt"
    print("  ✓ Prompts share same scenario")
    
    # Test 3: High ToM prompts ask about mental states
    mental_state_keywords = ['aware', 'know', 'believe', 'think', 'likely']
    count = 0
    for item in dataset[:100]:
        if any(keyword in item['high_tom_prompt'].lower() for keyword in mental_state_keywords):
            count += 1
    assert count > 50, f"Only {count}/100 high_tom prompts contain mental state keywords"
    print(f"  ✓ {count}/100 high_tom prompts contain mental state keywords")
    
    # Test 4: Low ToM prompts are factual
    factual_keywords = ['true', 'correct', 'fact']
    count = 0
    for item in dataset[:100]:
        if any(keyword in item['low_tom_prompt'].lower() for keyword in factual_keywords):
            count += 1
    assert count > 50, f"Only {count}/100 low_tom prompts contain factual keywords"
    print(f"  ✓ {count}/100 low_tom prompts contain factual keywords")
    
    # Test 5: No data leakage between the actual (shuffled) splits.
    # Prompts serve as identifiers because the split keeps only prompt strings for train/val.
    train_prompts = set(high_tom_train)
    val_prompts = set(high_tom_val)
    test_prompts = {item['high_tom_prompt'] for item in test_data}
    
    assert len(train_prompts & val_prompts) == 0, "Data leakage between train and val"
    assert len(train_prompts & test_prompts) == 0, "Data leakage between train and test"
    assert len(val_prompts & test_prompts) == 0, "Data leakage between val and test"
    print("  ✓ No data leakage between splits")
    
    print("\n✓ All dataset tests passed!")

# Run tests
test_dataset_structure()

## 4. Model Loading

**Epistemic Status: MODERATE-HIGH (70-80%)**  
Model loading depends on HuggingFace availability and GPU resources. Without GPU access, this cell will fail or be very slow.

In [None]:
try:
    # Construct model using factory
    print(f"Loading model: {MODEL_PATH}")
    print(f"This may take several minutes...")
    
    model_base = construct_model_base(MODEL_PATH)
    
    print(f"\n✓ Model loaded successfully")
    print(f"  Model: {model_base.model.__class__.__name__}")
    print(f"  Tokenizer: {model_base.tokenizer.__class__.__name__}")
    print(f"  Device: {model_base.model.device}")
    print(f"  Num layers: {model_base.model.config.num_hidden_layers}")
    print(f"  Hidden size: {model_base.model.config.hidden_size}")
    print(f"  EOI tokens: {model_base.eoi_toks}")
    
except Exception as e:
    print(f"\n⚠️ Model loading failed: {e}")
    print(f"\nThis is expected if:")
    print(f"  - You don't have GPU access")
    print(f"  - The model is not available on HuggingFace")
    print(f"  - You don't have HuggingFace authentication set up")
    print(f"\nYou can still review the code structure without running it.")
    raise

## 5. Direction Generation

**Epistemic Status: HIGH (85%+)**  
This is the core method from the paper. The implementation is sound.

We compute the difference-in-means between high_tom and low_tom activations:
```
r[pos, layer] = mean(high_tom_activations) - mean(low_tom_activations)
```

**⚠️ Note**: This requires GPU and can take 10-30 minutes depending on model size.
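
The shape convention can be illustrated with random tensors. This is only a sketch: the internals of `get_mean_diff` are not shown in this notebook, so the cached-activation layout below is an assumption chosen to match the `(n_positions, n_layers, d_model)` shape that the validation code expects.

```python
import torch

torch.manual_seed(0)
n_prompts, n_positions, n_layers, d_model = 8, 2, 4, 16

# Hypothetical cached activations: one vector per prompt, position, and layer.
high_tom_acts = torch.randn(n_prompts, n_positions, n_layers, d_model)
low_tom_acts = torch.randn(n_prompts, n_positions, n_layers, d_model)

# Average over prompts, then subtract: one candidate direction per (position, layer).
mean_diffs = high_tom_acts.mean(dim=0) - low_tom_acts.mean(dim=0)

print(mean_diffs.shape)  # torch.Size([2, 4, 16])
```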

In [None]:
try:
    print("Generating candidate directions...")
    print(f"This will process {N_TRAIN} examples in batches of {BATCH_SIZE}")
    print(f"Estimated time: ~10-30 minutes depending on model size\n")
    
    # Generate directions using the difference-in-means method
    # This computes: mean(high_tom_activations) - mean(low_tom_activations)
    # for each position in eoi_toks and each layer
    
    candidate_directions = get_mean_diff(
        model=model_base.model,
        tokenizer=model_base.tokenizer,
        harmful_instructions=high_tom_train,  # Use high_tom as "harmful" (concept to remove)
        harmless_instructions=low_tom_train,   # Use low_tom as "harmless" (baseline)
        tokenize_instructions_fn=model_base.tokenize_instructions_fn,
        block_modules=model_base.model_block_modules,
        batch_size=BATCH_SIZE,
        positions=list(range(-len(model_base.eoi_toks), 0))  # Extract from end-of-instruction positions
    )
    
    print(f"\n✓ Generated candidate directions")
    print(f"  Shape: {candidate_directions.shape}")
    print(f"  Expected: (n_positions={len(model_base.eoi_toks)}, n_layers={model_base.model.config.num_hidden_layers}, d_model={model_base.model.config.hidden_size})")
    print(f"  Dtype: {candidate_directions.dtype}")
    print(f"  Device: {candidate_directions.device}")
    
    # Validate
    assert not candidate_directions.isnan().any(), "NaN values in candidate directions!"
    assert candidate_directions.shape[0] == len(model_base.eoi_toks)
    assert candidate_directions.shape[1] == model_base.model.config.num_hidden_layers
    print("  ✓ Validation passed")
    
    # Save if requested
    if SAVE_ARTIFACTS:
        os.makedirs(f"{OUTPUT_DIR}/generate_directions", exist_ok=True)
        torch.save(candidate_directions, f"{OUTPUT_DIR}/generate_directions/mean_diffs.pt")
        print(f"  ✓ Saved to {OUTPUT_DIR}/generate_directions/mean_diffs.pt")
        
except Exception as e:
    print(f"\n⚠️ Direction generation failed: {e}")
    print(f"This is expected without GPU access.")
    raise

## 6. Direction Selection

**Epistemic Status: MODERATE (60-70%)**  
This is where things get uncertain for ToM ablation.

We evaluate each candidate direction on three metrics:
1. **bypass_score**: How much does ablating this direction reduce ToM capability?
2. **induce_score**: How much does adding this direction increase ToM capability?
3. **kl_score**: How much does this intervention shift the output distribution?

**⚠️ Key Uncertainty**: The original paper used "refusal tokens" to measure refusal behavior. For ToM, we need to define what "ToM tokens" are. This is non-trivial!

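**Approach**: As a proxy, we score how much next-token probability mass the model puts on a small set of answer and mental-state tokens (`TOM_TOKENS`, defined in the next cell). This measures whether the model begins a ToM-style answer, not whether that answer is correct.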
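
To make the proxy concrete, here is a sketch of a token-set score: the log-odds that the next token falls inside a chosen set. It is an assumption about what a refusal-score-style metric computes; the internals of `get_refusal_scores` are not shown in this notebook, and `token_set_log_odds` is a hypothetical helper.

```python
from typing import List

import torch

def token_set_log_odds(logits: torch.Tensor, token_ids: List[int], eps: float = 1e-8) -> torch.Tensor:
    """Per-prompt log-odds that the next token is one of `token_ids`.

    logits: (n_prompts, vocab_size) logits at the last position.
    """
    probs = torch.softmax(logits.float(), dim=-1)
    in_set = probs[:, token_ids].sum(dim=-1)
    return torch.log(in_set + eps) - torch.log(1.0 - in_set + eps)

# Toy logits over a 6-token vocabulary for two prompts.
logits = torch.tensor([
    [4.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # mass concentrated on token 0 (inside the set)
    [0.0, 0.0, 0.0, 0.0, 0.0, 4.0],  # mass concentrated on token 5 (outside the set)
])
scores = token_set_log_odds(logits, token_ids=[0, 1])
print(scores)  # positive for the first prompt, negative for the second
```

A limitation worth keeping in mind: if the token set contains answer tokens such as "Yes" and "No" that factual prompts also elicit, the score may not separate ToM prompts from factual ones.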

In [None]:
# Define "ToM tokens" - tokens that indicate ToM reasoning
# This is a heuristic and may need tuning!
TOM_TOKENS = model_base.tokenizer.encode(
    " Yes No aware know believe think likely",
    add_special_tokens=False
)

print(f"Using ToM tokens: {TOM_TOKENS}")
print(f"Decoded: {[model_base.tokenizer.decode([t]) for t in TOM_TOKENS]}")
print(f"\n⚠️ NOTE: This is a heuristic proxy for ToM capability.")
print(f"The original paper scored the first token of refusal phrases (e.g. 'I' in 'I cannot').")
print(f"For ToM, there's no clear equivalent, so we use answer tokens.")

In [None]:
def evaluate_direction(
    model_base,
    pos: int,
    layer: int, 
    direction: torch.Tensor,
    high_tom_val: List[str],
    low_tom_val: List[str],
    tom_tokens: List[int]
) -> Dict[str, float]:
    """
    Evaluate a single candidate direction.
    
    Returns:
        Dictionary with bypass_score, induce_score, kl_score
    """
    # Normalize direction
    direction_normalized = direction / direction.norm()
    
    # Get hooks for ablation
    ablation_pre_hooks, ablation_hooks = get_all_direction_ablation_hooks(
        model_base, direction_normalized
    )
    
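    # Get hooks for activation addition. The induce score measures the effect of
    # ADDING the direction, so the coefficient must be positive here.
    actadd_pre_hooks = [(
        model_base.model_block_modules[layer],
        get_activation_addition_input_pre_hook(vector=direction_normalized, coeff=1.0)
    )]
    actadd_hooks = []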
    
    # Baseline scores on high_tom (should be high before intervention)
    baseline_high_tom_scores = get_refusal_scores(
        model_base.model, high_tom_val, model_base.tokenize_instructions_fn,
        tom_tokens, fwd_pre_hooks=[], fwd_hooks=[]
    )
    
    # Ablation scores on high_tom (should be lower after ablation)
    ablation_high_tom_scores = get_refusal_scores(
        model_base.model, high_tom_val, model_base.tokenize_instructions_fn,
        tom_tokens, fwd_pre_hooks=ablation_pre_hooks, fwd_hooks=ablation_hooks
    )
    
    # Activation addition scores on low_tom (should be higher after adding direction)
    baseline_low_tom_scores = get_refusal_scores(
        model_base.model, low_tom_val, model_base.tokenize_instructions_fn,
        tom_tokens, fwd_pre_hooks=[], fwd_hooks=[]
    )
    
    actadd_low_tom_scores = get_refusal_scores(
        model_base.model, low_tom_val, model_base.tokenize_instructions_fn,
        tom_tokens, fwd_pre_hooks=actadd_pre_hooks, fwd_hooks=actadd_hooks
    )
    
    # Compute KL divergence on neutral examples (using low_tom as neutral)
    baseline_logits = get_last_position_logits(
        model_base.model, model_base.tokenizer, low_tom_val,
        model_base.tokenize_instructions_fn, fwd_pre_hooks=[], fwd_hooks=[]
    )
    
    ablation_logits = get_last_position_logits(
        model_base.model, model_base.tokenizer, low_tom_val,
        model_base.tokenize_instructions_fn, fwd_pre_hooks=ablation_pre_hooks, fwd_hooks=ablation_hooks
    )
    
    # KL(baseline || ablated), computed from log-probabilities for numerical stability
    baseline_logprobs = torch.nn.functional.log_softmax(baseline_logits.float(), dim=-1)
    ablation_logprobs = torch.nn.functional.log_softmax(ablation_logits.float(), dim=-1)
    kl_div = torch.nn.functional.kl_div(
        ablation_logprobs, baseline_logprobs, reduction='batchmean', log_target=True
    )
    
    # Compute scores
    bypass_score = (baseline_high_tom_scores.mean() - ablation_high_tom_scores.mean()).item()
    induce_score = (actadd_low_tom_scores.mean() - baseline_low_tom_scores.mean()).item()
    kl_score = kl_div.item()
    
    return {
        'bypass_score': bypass_score,
        'induce_score': induce_score,
        'kl_score': kl_score,
    }

print("✓ Direction evaluation function defined")

### Select Best Direction

Now we evaluate all candidate directions and select the best one.
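
The selection rule itself is simple: discard candidates that shift the output distribution too much or fail to induce the behavior, then take the highest bypass score. A sketch with made-up numbers (the records and lowercase threshold names are illustrative only, not results from this notebook):

```python
# Hypothetical evaluation records; the scores are invented for illustration.
toy_evaluations = [
    {"pos": -1, "layer": 3, "bypass_score": 1.2, "induce_score": 0.4, "kl_score": 0.05},
    {"pos": -1, "layer": 5, "bypass_score": 2.0, "induce_score": 0.6, "kl_score": 0.30},   # KL too high
    {"pos": -2, "layer": 4, "bypass_score": 0.9, "induce_score": -0.1, "kl_score": 0.02},  # does not induce
]

kl_threshold = 0.1
induce_threshold = 0.0

# Keep candidates that pass both filters, then pick the largest bypass score.
toy_candidates = [
    e for e in toy_evaluations
    if e["kl_score"] < kl_threshold and e["induce_score"] > induce_threshold
]
toy_best = max(toy_candidates, key=lambda e: e["bypass_score"]) if toy_candidates else None

print(toy_best)
```

Here the layer-5 candidate has the largest bypass score but is rejected by the KL filter, so the layer-3 candidate is selected.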

In [None]:
try:
    print("Evaluating candidate directions...")
    print(f"This will evaluate {candidate_directions.shape[0] * candidate_directions.shape[1]} candidates")
    print(f"Estimated time: ~30-60 minutes\n")
    
    n_positions = candidate_directions.shape[0]
    n_layers = candidate_directions.shape[1]
    
    # Filter out top 20% of layers (too close to output)
    max_layer = int(n_layers * 0.8)
    
    evaluations = []
    
    for pos in tqdm(range(n_positions), desc="Positions"):
        for layer in tqdm(range(max_layer), desc="Layers", leave=False):
            direction = candidate_directions[pos, layer, :]
            
            eval_result = evaluate_direction(
                model_base, pos, layer, direction,
                high_tom_val, low_tom_val, TOM_TOKENS
            )
            
            evaluations.append({
                'pos': pos - n_positions,  # Convert to negative index
                'layer': layer,
                **eval_result
            })
    
    # Filter by thresholds
    filtered_evaluations = [
        e for e in evaluations
        if e['kl_score'] < KL_THRESHOLD and e['induce_score'] > INDUCE_THRESHOLD
    ]
    
    print(f"\n✓ Evaluated {len(evaluations)} candidates")
    print(f"  Filtered to {len(filtered_evaluations)} candidates")
    
    # Select best by bypass_score
    if len(filtered_evaluations) > 0:
        best = max(filtered_evaluations, key=lambda x: x['bypass_score'])
        
        best_pos = best['pos']
        best_layer = best['layer']
        best_direction = candidate_directions[best_pos + n_positions, best_layer, :]
        
        print(f"\n✓ Best direction selected")
        print(f"  Position: {best_pos}")
        print(f"  Layer: {best_layer}")
        print(f"  Bypass score: {best['bypass_score']:.4f}")
        print(f"  Induce score: {best['induce_score']:.4f}")
        print(f"  KL score: {best['kl_score']:.4f}")
        
        # Save
        if SAVE_ARTIFACTS:
            os.makedirs(f"{OUTPUT_DIR}/select_direction", exist_ok=True)
            
            with open(f"{OUTPUT_DIR}/select_direction/direction_evaluations.json", 'w') as f:
                json.dump(evaluations, f, indent=2)
            
            with open(f"{OUTPUT_DIR}/select_direction/direction_evaluations_filtered.json", 'w') as f:
                json.dump(filtered_evaluations, f, indent=2)
            
            torch.save(best_direction, f"{OUTPUT_DIR}/direction.pt")
            
            with open(f"{OUTPUT_DIR}/direction_metadata.json", 'w') as f:
                json.dump({'pos': best_pos, 'layer': best_layer}, f, indent=2)
            
            print(f"  ✓ Saved artifacts to {OUTPUT_DIR}")
    else:
        print("\n⚠️ No directions passed the filtering criteria!")
        print(f"Try relaxing KL_THRESHOLD (current: {KL_THRESHOLD}) or INDUCE_THRESHOLD (current: {INDUCE_THRESHOLD})")
        best_direction = None
        
except Exception as e:
    print(f"\n⚠️ Direction selection failed: {e}")
    raise

## 7. Intervention Application and Evaluation

**Epistemic Status: MODERATE-LOW (50-60%)**  
Even if we found a direction, it's unclear if it will effectively remove ToM capability.

We'll test the intervention by generating responses with:
1. **Baseline**: No intervention
2. **Ablation**: Remove the direction  
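3. **Activation Addition**: Add the direction scaled by `ACTIVATION_ADD_COEFF` (the default of -1.0 subtracts it, which should suppress ToM; a positive value would attempt to enhance it)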
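
Interventions of this kind are typically attached as PyTorch forward pre-hooks, which rewrite a module's input before it runs. The sketch below shows the mechanism on a stand-in module. `make_addition_pre_hook` is a hypothetical helper written for this illustration, not the project's `get_activation_addition_input_pre_hook`.

```python
import torch
import torch.nn as nn

def make_addition_pre_hook(vector: torch.Tensor, coeff: float):
    """Return a forward pre-hook that adds coeff * vector to the module's input."""
    def hook(module, args):
        (x,) = args
        return (x + coeff * vector,)
    return hook

layer = nn.Identity()  # stand-in for a transformer block
direction = torch.tensor([1.0, 0.0, 0.0])
x = torch.zeros(2, 3)

# A negative coefficient subtracts the direction from the input.
handle = layer.register_forward_pre_hook(make_addition_pre_hook(direction, coeff=-1.0))
shifted = layer(x)

handle.remove()        # remove the hook so later forward passes are unaffected
restored = layer(x)

print(shifted)
print(restored)
```

Removing the hook handle after each run matters when baseline, ablation, and addition runs share one model, since a leftover hook would silently contaminate the baseline.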

In [None]:
if best_direction is not None:
    print("Setting up interventions...")
    
    # Normalize direction
    direction_normalized = best_direction / best_direction.norm()
    
    # Setup hooks
    baseline_pre_hooks, baseline_hooks = [], []
    
    ablation_pre_hooks, ablation_hooks = get_all_direction_ablation_hooks(
        model_base, direction_normalized
    )
    
    actadd_pre_hooks = [(
        model_base.model_block_modules[best_layer],
        get_activation_addition_input_pre_hook(vector=direction_normalized, coeff=ACTIVATION_ADD_COEFF)
    )]
    actadd_hooks = []
    
    print("✓ Interventions configured")
    print(f"  Ablation: Remove direction from all layers")
    print(f"  Activation addition: Add {ACTIVATION_ADD_COEFF}x direction at layer {best_layer}")
else:
    print("⚠️ Skipping intervention - no direction selected")

In [None]:
if best_direction is not None:
    # Show completions for only the first few test examples
    n_examples_to_show = min(5, len(test_data))
    
    print("Generating test completions...")
    print(f"Showing {n_examples_to_show} of {len(test_data)} test examples\n")
    
    for i in range(n_examples_to_show):
        item = test_data[i]
        
        print("="*80)
        print(f"Example {i+1}: {item['category']}")
        print("="*80)
        print(f"\nScenario: {item['scenario']}")
        
        # Test high_tom prompt
        print(f"\n[HIGH ToM Prompt]")
        print(f"{item['high_tom_prompt']}")
        print(f"\nExpected: {item['high_tom_completion']}")
        
        high_tom_prompts = [item['high_tom_prompt']]
        
        # Baseline
        baseline_completions = model_base.generate_completions(
            high_tom_prompts, 
            fwd_pre_hooks=baseline_pre_hooks,
            fwd_hooks=baseline_hooks,
            max_new_tokens=100
        )
        print(f"\nBaseline: {baseline_completions[0]}")
        
        # Ablation
        ablation_completions = model_base.generate_completions(
            high_tom_prompts,
            fwd_pre_hooks=ablation_pre_hooks,
            fwd_hooks=ablation_hooks,
            max_new_tokens=100
        )
        print(f"\nAblation (ToM removed): {ablation_completions[0]}")
        
        # Activation addition
        actadd_completions = model_base.generate_completions(
            high_tom_prompts,
            fwd_pre_hooks=actadd_pre_hooks,
            fwd_hooks=actadd_hooks,
            max_new_tokens=100
        )
        print(f"\nActivation Addition (coeff={ACTIVATION_ADD_COEFF}): {actadd_completions[0]}")
        print()
        
    print("\n⚠️ Manual evaluation needed:")
    print("  - Does ablation remove ToM reasoning?")
    print("  - Does the model give more factual/less mentalistic answers?")
    print("  - Does activation addition shift ToM as the sign of ACTIVATION_ADD_COEFF implies (negative should suppress it)?")
    print("  - Are there unwanted side effects?")
else:
    print("⚠️ Skipping evaluation - no direction selected")

## 8. Summary and Next Steps

### What We Did
1. ✅ Loaded SimpleTOM dataset with high_tom/low_tom contrast pairs
2. ✅ Generated candidate directions via difference-in-means
3. ✅ Selected best direction using bypass/induce/KL scores
4. ✅ Applied ablation and activation addition interventions
5. ✅ Generated completions on a few test examples for manual review

### Key Uncertainties
1. **Method suitability**: Directional ablation was designed for behavioral patterns (refusal), not cognitive capabilities (ToM)
2. **ToM token proxy**: No clear equivalent to "refusal tokens" for measuring ToM
3. **Single direction assumption**: ToM may be distributed across many directions
4. **Transfer**: May not generalize beyond this specific task format

### Suggested Next Steps
1. **Evaluate on diverse ToM tasks**: Test if the direction transfers to other ToM benchmarks
2. **Measure side effects**: Check if ablation damages other capabilities (e.g., run on reasoning benchmarks)
3. **Try multiple directions**: Use sparse probing to find multiple ToM directions
4. **Compare methods**: Try activation patching, causal tracing, or steering vectors
5. **Interpretability**: Examine what the direction represents (e.g., via logit lens)

### Files Saved
- `{OUTPUT_DIR}/generate_directions/mean_diffs.pt`: All candidate directions
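- `{OUTPUT_DIR}/select_direction/direction_evaluations.json`: All evaluations
- `{OUTPUT_DIR}/select_direction/direction_evaluations_filtered.json`: Evaluations that passed the KL and induce thresholds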
- `{OUTPUT_DIR}/direction.pt`: Best direction vector
- `{OUTPUT_DIR}/direction_metadata.json`: Position and layer metadata