# Causal Tracing and Feature Directions

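This notebook applies mechanistic interpretability techniques to a small MNIST MLP: feature-direction analysis (how digit concepts are represented geometrically), causal interventions on those directions, causal tracing (how information flows between layers), and linear probing (what information is linearly decodable at each layer).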

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

In [None]:
class MNISTNet(nn.Module):
    """Four-layer fully connected MNIST classifier (784 -> 256 -> 128 -> 64 -> 10).

    ``forward`` returns raw class logits; ``get_layer_activations`` exposes the
    intermediate representations used by the interpretability analyses.
    """

    # Layer names accepted by get_layer_activations, in forward order.
    LAYER_NAMES = ('fc1', 'fc2', 'fc3', 'fc4')

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten ``x`` to (N, 784) and return logits of shape (N, 10)."""
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

    def get_layer_activations(self, x, layer_name):
        """Return the activations produced at ``layer_name`` for input ``x``.

        For 'fc1', 'fc2' and 'fc3' this is the post-ReLU output (matching what
        ``forward`` feeds to the next layer); for 'fc4' it is the raw logits.
        Only the layers up to and including the requested one are evaluated.

        Raises:
            ValueError: if ``layer_name`` is not one of ``LAYER_NAMES``.
        """
        if layer_name not in self.LAYER_NAMES:
            raise ValueError(
                f"Unknown layer name: {layer_name!r}; expected one of {self.LAYER_NAMES}"
            )

        x = x.view(-1, 784)
        for name in self.LAYER_NAMES:
            x = getattr(self, name)(x)
            # fc4 produces logits; every earlier layer is followed by ReLU.
            if name != 'fc4':
                x = F.relu(x)
            if name == layer_name:
                # x is a freshly computed tensor, so no defensive clone is needed.
                return x

In [None]:
# Load the MNIST train and test splits as tensors (downloaded to ./data if needed).
transform = transforms.ToTensor()
train_set, test_set = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffle only the training data; the test order stays fixed for reproducible evaluation.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

In [None]:
def train_model(model, train_loader, epochs=5):
    """Train ``model`` in place with Adam (lr=0.001) and cross-entropy loss.

    After every epoch the mean per-batch loss is printed; a completion
    message is printed at the end.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch_num}/{epochs}, Average Loss: {mean_loss:.4f}')

    print("Training completed!")

## Feature Direction Analysis

Identify directions in activation space that correspond to specific concepts (here, the ten digit classes).

In [None]:
def find_digit_directions(model, data_loader, layer_name='fc2'):
    """Compute a mean-difference "direction" for each digit at ``layer_name``.

    For every digit label (0-9) present in ``data_loader``, the mean activation
    at the requested layer is computed. The reference point ``global_mean`` is
    the *unweighted* mean of those per-digit means, so each digit counts
    equally regardless of its sample count. Each digit's direction is its mean
    minus ``global_mean``.

    Args:
        model: network exposing ``get_layer_activations(data, layer_name)``.
        data_loader: iterable of ``(data, target)`` batches.
        layer_name: layer whose activations are analysed.

    Returns:
        ``(digit_directions, digit_means, global_mean)``. Both dicts are keyed
        by digit; digits absent from ``data_loader`` are omitted rather than
        causing a ``KeyError``.

    Raises:
        ValueError: if ``data_loader`` yields no batches or no labels in 0-9.
    """
    model.eval()

    batch_activations = []
    batch_targets = []
    with torch.no_grad():
        for data, target in data_loader:
            batch_activations.append(model.get_layer_activations(data, layer_name))
            batch_targets.append(target)

    if not batch_activations:
        raise ValueError("data_loader yielded no batches")

    activations = torch.cat(batch_activations, dim=0)
    targets = torch.cat(batch_targets, dim=0)

    # Mean activation per digit; masking the concatenated tensor keeps the
    # rows in the same order as collecting per-batch slices would.
    digit_means = {}
    for digit in range(10):
        digit_mask = targets == digit
        if digit_mask.any():
            digit_means[digit] = activations[digit_mask].mean(dim=0)

    if not digit_means:
        raise ValueError("data_loader contained no samples labelled 0-9")

    # Unweighted mean of the class means (not the mean over all samples).
    global_mean = torch.stack(list(digit_means.values())).mean(dim=0)

    # Only digits that were actually observed get a direction.
    digit_directions = {digit: mean - global_mean for digit, mean in digit_means.items()}

    return digit_directions, digit_means, global_mean

def visualize_digit_directions(digit_directions):
    """Draw a 2x5 grid of bar charts, one per digit, of its direction vector.

    In each chart the five most positive components are re-drawn in red and
    the five most negative in blue.
    """
    _, grid = plt.subplots(2, 5, figsize=(20, 8))

    for digit, ax in zip(range(10), grid.flatten()):
        vec = digit_directions[digit].cpu().numpy()

        ax.bar(range(len(vec)), vec, alpha=0.7)
        ax.set_title(f'Digit {digit} Direction Vector')
        ax.set_xlabel('Neuron Index')
        ax.set_ylabel('Direction Magnitude')

        # Sort once; the tail holds the largest values, the head the smallest.
        order = np.argsort(vec)
        for neuron in order[-5:]:
            ax.bar(neuron, vec[neuron], color='red', alpha=0.8)
        for neuron in order[:5]:
            ax.bar(neuron, vec[neuron], color='blue', alpha=0.8)

    plt.tight_layout()
    plt.show()

def analyze_direction_orthogonality(digit_directions):
    """Plot and report pairwise cosine similarities between digit directions.

    Shows the 10x10 similarity matrix as an annotated heatmap, prints the ten
    digit pairs with the largest absolute similarity, and returns the matrix
    as a NumPy array.
    """
    stacked = torch.stack([digit_directions[d] for d in range(10)])
    unit_dirs = F.normalize(stacked, dim=1)

    # Dot products of unit vectors are cosine similarities.
    sims = (unit_dirs @ unit_dirs.T).cpu().numpy()

    plt.figure(figsize=(10, 8))
    plt.imshow(sims, cmap='RdBu_r', vmin=-1, vmax=1)
    plt.colorbar(label='Cosine Similarity')
    plt.title('Digit Direction Similarities')
    plt.xlabel('Digit')
    plt.ylabel('Digit')
    plt.xticks(range(10))
    plt.yticks(range(10))

    for row in range(10):
        for col in range(10):
            value = sims[row, col]
            text_color = 'white' if abs(value) > 0.5 else 'black'
            plt.text(col, row, f'{value:.2f}', ha='center', va='center', color=text_color)

    plt.tight_layout()
    plt.show()

    # Upper-triangle pairs ranked by magnitude (strong negatives rank high too).
    pairs = [(a, b, sims[a, b]) for a in range(10) for b in range(a + 1, 10)]
    pairs.sort(key=lambda pair: abs(pair[2]), reverse=True)

    print("Most similar digit direction pairs:")
    for rank, (d1, d2, sim) in enumerate(pairs[:10], start=1):
        print(f"{rank}. Digits {d1}-{d2}: similarity = {sim:.3f}")

    return sims

## Causal Intervention Experiments

Test what happens when we artificially modify the network's internal representations.

In [None]:
def test_digit_specific_accuracy(model, test_loader, target_digit):
    """Return ``model``'s accuracy, in percent, on samples labelled ``target_digit``.

    Returns 0 when ``test_loader`` contains no such samples.
    """
    model.eval()
    hits = 0
    seen = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            keep = labels == target_digit
            if keep.sum() == 0:
                continue
            selected_labels = labels[keep]
            predictions = model(inputs[keep]).argmax(dim=1)
            seen += selected_labels.size(0)
            hits += (predictions == selected_labels).sum().item()

    if seen == 0:
        return 0
    return 100 * hits / seen

def intervention_experiment(model, data_loader, digit_directions, layer_name='fc2', target_digit=7,
                            strengths=(0, 0.2, 0.5, 1.0, 2.0, 5.0)):
    """Measure accuracy on ``target_digit`` while boosting its direction at ``layer_name``.

    For each value in ``strengths``, a temporary forward hook adds
    ``strength * digit_directions[target_digit]`` to the output of
    ``model.<layer_name>``. The accuracy on ``target_digit`` samples is then
    recorded, printed and plotted. The hook is always removed afterwards, so
    ``model`` is left unmodified.

    Note: the hook acts on the submodule's own output. For ``MNISTNet`` that is
    the raw ``nn.Linear`` output, *before* the ReLU applied in ``forward``.
    The directions from ``find_digit_directions``, by contrast, are differences
    of post-ReLU activations.

    Args:
        model: network with a submodule named ``layer_name``.
        data_loader: iterable of ``(data, target)`` batches.
        digit_directions: mapping digit -> direction tensor matching the
            layer's output width.
        layer_name: attribute name of the submodule to intervene on.
        target_digit: digit whose direction is added and whose accuracy is measured.
        strengths: intervention scales to sweep.

    Returns:
        list of accuracies (percent), one per entry in ``strengths``.

    Raises:
        ValueError: if ``model`` has no submodule named ``layer_name``.
    """
    model.eval()

    # Previously only 'fc2' was hooked and any other name silently ran without
    # an intervention; resolve the layer generically instead.
    layer = getattr(model, layer_name, None)
    if not isinstance(layer, nn.Module):
        raise ValueError(f"Model has no layer named {layer_name!r}")

    direction = digit_directions[target_digit].to(next(model.parameters()).device)

    results = []
    for strength in strengths:
        shift = strength * direction

        def add_shift(module, inputs, output, shift=shift):
            # A value returned from a forward hook replaces the module's output;
            # (N, D) + (D,) broadcasts over the batch.
            return output + shift

        # Hook the live model (rather than a CPU-built copy) and always unhook.
        handle = layer.register_forward_hook(add_shift)
        try:
            accuracy = test_digit_specific_accuracy(model, data_loader, target_digit)
        finally:
            handle.remove()
        results.append(accuracy)

        print(f"Strength {strength}: Digit {target_digit} accuracy = {accuracy:.2f}%")

    plt.figure(figsize=(10, 6))
    plt.plot(list(strengths), results, 'o-', linewidth=2, markersize=8)
    plt.xlabel('Intervention Strength')
    plt.ylabel(f'Digit {target_digit} Accuracy (%)')
    plt.title(f'Effect of Enhancing Digit {target_digit} Direction')
    plt.grid(True, alpha=0.3)
    plt.ylim(0, 100)
    plt.tight_layout()
    plt.show()

    return results

def cross_digit_intervention(model, data_loader, digit_directions, layer_name='fc2', strength=1.0):
    """Add each digit's direction at ``layer_name`` and measure accuracy on every digit.

    For each source digit, a temporary forward hook adds
    ``strength * digit_directions[source_digit]`` to the output of
    ``model.<layer_name>``. While the hook is active, accuracy is measured on
    each target digit. Results are shown as an annotated heatmap. Hooks are
    always removed, so ``model`` is left unmodified.

    Note: for ``MNISTNet`` the hooked output is the raw ``nn.Linear`` output,
    i.e. before the ReLU applied in ``forward``.

    Args:
        model: network with a submodule named ``layer_name``.
        data_loader: iterable of ``(data, target)`` batches.
        digit_directions: mapping digit (0-9) -> direction tensor.
        layer_name: attribute name of the submodule to intervene on.
        strength: scale applied to every added direction (default 1.0).

    Returns:
        10x10 NumPy array; entry ``[source, target]`` is the accuracy (percent)
        on ``target`` while ``source``'s direction is added.

    Raises:
        ValueError: if ``model`` has no submodule named ``layer_name``.
    """
    model.eval()

    # Resolve the layer generically; previously non-'fc2' names silently
    # produced un-intervened results.
    layer = getattr(model, layer_name, None)
    if not isinstance(layer, nn.Module):
        raise ValueError(f"Model has no layer named {layer_name!r}")

    device = next(model.parameters()).device
    intervention_matrix = np.zeros((10, 10))

    for source_digit in range(10):
        shift = strength * digit_directions[source_digit].to(device)

        def add_shift(module, inputs, output, shift=shift):
            # Returning a tensor from a forward hook replaces the module output.
            return output + shift

        # One hook per source digit on the live model, instead of building
        # 100 CPU copies of the model.
        handle = layer.register_forward_hook(add_shift)
        try:
            for target_digit in range(10):
                intervention_matrix[source_digit, target_digit] = test_digit_specific_accuracy(
                    model, data_loader, target_digit
                )
        finally:
            handle.remove()

    plt.figure(figsize=(12, 10))
    plt.imshow(intervention_matrix, cmap='viridis', aspect='auto')
    plt.colorbar(label='Accuracy (%)')
    plt.title('Cross-Digit Intervention Effects\n(Adding Source Digit Direction When Classifying Target Digit)')
    plt.xlabel('Target Digit (being classified)')
    plt.ylabel('Source Digit (direction being added)')
    plt.xticks(range(10))
    plt.yticks(range(10))

    for i in range(10):
        for j in range(10):
            plt.text(j, i, f'{intervention_matrix[i, j]:.1f}',
                     ha='center', va='center',
                     color='white' if intervention_matrix[i, j] < 50 else 'black')

    plt.tight_layout()
    plt.show()

    return intervention_matrix

## Causal Tracing

Trace how information flows through the network by corrupting the input and then restoring the clean activations of one layer at a time.

In [None]:
def create_corrupted_input(clean_input, corruption_type='noise', strength=0.5):
    """Return a corrupted version of ``clean_input`` for causal tracing.

    Args:
        clean_input: input tensor. For 'shuffle', dimension 0 is treated as
            the batch dimension (a 1-D tensor is treated as a single sample).
        corruption_type: one of
            - 'noise':   add Gaussian noise scaled by ``strength``;
            - 'zero':    replace with all zeros;
            - 'random':  replace with ``torch.rand_like`` values;
            - 'shuffle': permute the elements *within* each sample (one
              permutation is drawn and applied to every sample in the batch).
        strength: noise scale; only used by 'noise'.

    Raises:
        ValueError: for an unknown ``corruption_type``.
    """
    if corruption_type == 'noise':
        return clean_input + torch.randn_like(clean_input) * strength
    if corruption_type == 'zero':
        return torch.zeros_like(clean_input)
    if corruption_type == 'random':
        return torch.rand_like(clean_input)
    if corruption_type == 'shuffle':
        if clean_input.numel() == 0:
            return clean_input.clone()
        # Flattening the whole tensor (as before) mixed pixels between
        # different images in a batch; keep the batch axis separate.
        # reshape (unlike view) also accepts non-contiguous inputs.
        num_samples = clean_input.size(0) if clean_input.dim() > 1 else 1
        flat = clean_input.reshape(num_samples, -1)
        perm = torch.randperm(flat.size(1), device=clean_input.device)
        return flat[:, perm].reshape(clean_input.shape)
    raise ValueError(f"Unknown corruption type: {corruption_type}")

def trace_information_flow(model, clean_input, corrupted_input, target_class,
                           layers=('fc1', 'fc2', 'fc3')):
    """Causal tracing: how much does restoring each layer recover the prediction?

    The model is run on ``clean_input`` and on ``corrupted_input``. Then, for
    each layer in ``layers``, the corrupted run is repeated with that layer's
    output overwritten by its clean activation (via a temporary forward hook
    on ``model`` itself). The change in P(``target_class``) relative to the
    corrupted run is recorded. Only batch index 0 is scored.

    ``get_layer_activations`` returns post-ReLU activations while the hook
    replaces the raw ``nn.Linear`` output. For ``MNISTNet`` this is still an
    exact restoration, because ``forward`` re-applies ReLU after fc1-fc3 and
    ReLU(ReLU(x)) == ReLU(x).

    Args:
        model: network with ``get_layer_activations`` and submodules named as in ``layers``.
        clean_input: clean input batch.
        corrupted_input: corrupted input batch of the same shape.
        target_class: class whose probability is tracked.
        layers: layer names to restore, one at a time.

    Returns:
        list of effects (``restored_prob - corrupted_prob``), one per layer.

    Raises:
        ValueError: if ``model`` has no submodule named as one of ``layers``.
    """
    model.eval()

    def target_prob(inputs):
        # Probability of target_class for the first example in the batch.
        return F.softmax(model(inputs), dim=1)[0, target_class].item()

    with torch.no_grad():
        clean_prob = target_prob(clean_input)
        corrupted_prob = target_prob(corrupted_input)

    print(f"Clean probability: {clean_prob:.3f}")
    print(f"Corrupted probability: {corrupted_prob:.3f}")
    print(f"Total effect: {clean_prob - corrupted_prob:.3f}")

    restoration_effects = []
    for layer_name in layers:
        layer = getattr(model, layer_name, None)
        if not isinstance(layer, nn.Module):
            raise ValueError(f"Model has no layer named {layer_name!r}")

        with torch.no_grad():
            clean_activation = model.get_layer_activations(clean_input, layer_name)

        def restore(module, inputs, output, clean_activation=clean_activation):
            # Returning a tensor from a forward hook replaces the module output.
            return clean_activation

        # Hook the live model (instead of a CPU-built copy) and always unhook.
        handle = layer.register_forward_hook(restore)
        try:
            with torch.no_grad():
                restored_prob = target_prob(corrupted_input)
        finally:
            handle.remove()

        effect = restored_prob - corrupted_prob
        restoration_effects.append(effect)
        print(f"Restoring {layer_name}: probability = {restored_prob:.3f}, effect = {effect:.3f}")

    return restoration_effects

def run_causal_tracing_experiment(model, data_loader, num_examples=5):
    """Run causal tracing on up to ``num_examples`` correctly classified inputs.

    Only the first sample of each batch is considered, and it is skipped if
    the model misclassifies it. Each kept sample is traced under 'noise',
    'shuffle' and 'zero' corruption (strength 0.5). The mean restoration
    effect per layer, with standard-deviation error bars, is then plotted
    for each corruption type.

    Returns:
        dict mapping corruption type to a list of per-example effect lists.
    """
    model.eval()

    corruption_types = ['noise', 'shuffle', 'zero']
    all_effects = {name: [] for name in corruption_types}

    analysed = 0
    for data, target in data_loader:
        if analysed >= num_examples:
            break
        if data.size(0) == 0:
            continue

        clean_input = data[:1]
        true_label = target[0].item()

        with torch.no_grad():
            predicted = model(clean_input).argmax().item()
        if predicted != true_label:
            continue  # only trace correctly classified examples

        print(f"\n--- Example {analysed + 1}: Digit {true_label} ---")
        for corruption in corruption_types:
            print(f"\nCorruption type: {corruption}")
            corrupted_input = create_corrupted_input(clean_input, corruption, 0.5)
            all_effects[corruption].append(
                trace_information_flow(model, clean_input, corrupted_input, true_label)
            )
        analysed += 1

    layers = ['fc1', 'fc2', 'fc3']
    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    for ax, corruption in zip(axes, corruption_types):
        effects = all_effects[corruption]
        if not effects:
            continue
        ax.bar(layers, np.mean(effects, axis=0), yerr=np.std(effects, axis=0), capsize=5, alpha=0.7)
        ax.set_title(f'Restoration Effects\n({corruption} corruption)')
        ax.set_xlabel('Layer Restored')
        ax.set_ylabel('Probability Recovery')
        ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return all_effects

## Linear Probing

Train linear classifiers on internal representations to understand what information is linearly accessible.

In [None]:
def linear_probing_analysis(model, data_loader, max_samples=2000):
    """Train a logistic-regression probe on each layer's activations.

    Activations of fc1-fc4 are collected for at most ``max_samples`` samples
    from ``data_loader``; the final batch is truncated so the cap is exact.
    The first 80% of the collected samples, in loader order, train a
    ``LogisticRegression(max_iter=1000, random_state=42)`` probe, and the
    remaining 20% are used to score it. Because the split follows loader
    order, use an unshuffled loader for reproducible results.

    Returns:
        dict mapping layer name to probe test accuracy in percent.

    Raises:
        ValueError: if no samples could be collected.
    """
    model.eval()

    layers = ['fc1', 'fc2', 'fc3', 'fc4']
    layer_activations = {layer: [] for layer in layers}
    label_batches = []

    sample_count = 0
    with torch.no_grad():
        for data, target in data_loader:
            remaining = max_samples - sample_count
            if remaining <= 0:
                break
            # Appending whole batches previously overshot max_samples
            # (e.g. 1024 instead of 1000 with batch size 64).
            data, target = data[:remaining], target[:remaining]

            for layer in layers:
                activations = model.get_layer_activations(data, layer)
                layer_activations[layer].append(activations.cpu().numpy())
            label_batches.append(target.cpu().numpy())

            sample_count += data.size(0)

    if sample_count == 0:
        raise ValueError("No samples collected; check data_loader and max_samples")

    # Labels are identical for every layer, so concatenate them once.
    y = np.concatenate(label_batches)
    split_idx = int(0.8 * len(y))
    y_train, y_test = y[:split_idx], y[split_idx:]

    probe_accuracies = {}
    for layer in layers:
        X = np.concatenate(layer_activations[layer])
        X_train, X_test = X[:split_idx], X[split_idx:]

        clf = LogisticRegression(max_iter=1000, random_state=42)
        clf.fit(X_train, y_train)

        accuracy = accuracy_score(y_test, clf.predict(X_test)) * 100
        probe_accuracies[layer] = accuracy

        print(f"{layer}: Linear probe accuracy = {accuracy:.2f}%")

    plt.figure(figsize=(10, 6))
    plt.bar(layers, [probe_accuracies[layer] for layer in layers], alpha=0.7)
    plt.xlabel('Layer')
    plt.ylabel('Linear Probe Accuracy (%)')
    plt.title('Linear Separability of Digit Classes by Layer')
    plt.ylim(0, 100)
    plt.grid(True, alpha=0.3)

    # Value labels just above each bar.
    for i, layer in enumerate(layers):
        plt.text(i, probe_accuracies[layer] + 1, f'{probe_accuracies[layer]:.1f}%',
                 ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    return probe_accuracies

## Run Advanced Mechanistic Analysis

In [None]:
# Train a fresh model, then report its accuracy on the full test set.
model = MNISTNet()
print("Training model...")
train_model(model, train_loader, epochs=5)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        predicted = output.argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy:.2f}%')

In [None]:
# Locate per-digit directions in fc2's activation space, then inspect them.
print("Finding digit directions in activation space...")
digit_directions, digit_means, global_mean = find_digit_directions(
    model,
    test_loader,
    layer_name='fc2',
)
visualize_digit_directions(digit_directions)
similarity_matrix = analyze_direction_orthogonality(digit_directions)

In [None]:
# Causal interventions at fc2: first boost digit 7's own direction, then add
# each digit's direction while classifying every digit.
print("Running causal intervention experiments...")
intervention_results = intervention_experiment(
    model,
    test_loader,
    digit_directions,
    layer_name='fc2',
    target_digit=7,
)

print("\nRunning cross-digit intervention analysis...")
intervention_matrix = cross_digit_intervention(
    model,
    test_loader,
    digit_directions,
    layer_name='fc2',
)

In [None]:
# Causal tracing: restore clean activations layer by layer on three
# correctly classified test examples.
print("Running causal tracing experiments...")
tracing_effects = run_causal_tracing_experiment(
    model,
    test_loader,
    num_examples=3,
)

In [None]:
# Train linear probes on every layer's activations over the test set (max_samples=1000).
print("Running linear probing analysis...")
probe_accuracies = linear_probing_analysis(
    model,
    test_loader,
    max_samples=1000,
)

## Summary

This notebook demonstrates advanced mechanistic interpretability techniques:

### Feature Direction Analysis
- **Concept Directions**: Found directions in activation space corresponding to each digit
- **Orthogonality Analysis**: Measured how distinct these concept representations are
- **Geometric Understanding**: Revealed how the network organizes concepts geometrically

### Causal Intervention
- **Direction Enhancement**: Tested what happens when we artificially strengthen digit representations
- **Cross-Digit Effects**: Analyzed how adding one digit's direction affects classification of others
- **Causal Understanding**: Tested whether these directions causally change predictions, rather than merely correlating with the digit label

### Causal Tracing
- **Information Flow**: Traced how information flows through network layers
- **Critical Layers**: Identified which layers are most important for maintaining predictions
- **Corruption Effects**: Compared how much restoring each layer recovers the prediction under different input corruptions (Gaussian noise, pixel shuffling, zeroing)

### Linear Probing
- **Information Content**: Measured what digit information is linearly accessible at each layer
- **Representation Quality**: Compared linear separability across layers to check whether representations become more linearly separable with depth

These techniques help us understand:
- **How** the network represents concepts internally
- **Where** critical computations happen
- **Why** certain interventions have specific effects
- **What** information is preserved or transformed at each layer

Together, these techniques aim to explain not just what individual neurons respond to, but how the network implements its computation.