# Introduction to Mechanistic Interpretability

**Welcome to neuros-mechint!** This notebook introduces you to the world of mechanistic interpretability—understanding not just *what* neural networks compute, but *how* they compute it.

## What You'll Learn

By the end of this notebook, you'll understand:
- What mechanistic interpretability is and why it matters
- The core concepts: features, circuits, and interventions
- How to run your first interpretability analysis
- An overview of all major techniques in the library

## Prerequisites

- Basic Python programming
- Familiarity with PyTorch
- Understanding of neural networks (MLPs, transformers)

## Setup

First, let's make sure everything is installed correctly:

In [None]:
# Install the library if needed
# !pip install "neuros-mechint[viz]"

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available (optional)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Import neuros-mechint
import neuros_mechint
print(f"neuros-mechint version: {neuros_mechint.__version__}")

## Part 1: What is Mechanistic Interpretability?

### The Black Box Problem

Neural networks are powerful, but we often don't understand *how* they work internally:

```
Input → [??? Black Box ???] → Output
```

**Traditional interpretability** asks: "Which input features does the model rely on?"

**Mechanistic interpretability** asks: "How does the model transform inputs into outputs? What algorithms are implemented in the weights?"

### The Goal: Reading the "Source Code"

Imagine if we could understand neural networks like this:

```python
def neural_network(input):
    # Layer 1: Detect edges and textures
    edges = detect_edges(input)
    textures = detect_textures(input)
    
    # Layer 2: Combine into object parts
    parts = combine_features(edges, textures)
    
    # Layer 3: Recognize complete objects
    objects = recognize_objects(parts)
    
    return classify(objects)
```

This is the dream of mechanistic interpretability!

### Why This Matters

**For Neuroscience:**
- Compare how artificial and biological networks solve the same problems
- Test hypotheses about brain computation
- Discover new computational principles

**For AI Safety:**
- Detect deceptive or dangerous behaviors
- Verify alignment with human values
- Understand failure modes

**For Science:**
- Understand emergent phenomena (e.g., in-context learning)
- Discover universal computational motifs
- Bridge levels of analysis (algorithms ↔ implementation)

### Three Core Concepts

1. **Features**: Interpretable units of representation
   - Problem: Neurons are *polysemantic* (respond to many unrelated things)
   - Solution: Sparse autoencoders to find *monosemantic* features

2. **Circuits**: Connected groups of features that implement algorithms
   - Like subroutines in code
   - Example: An "induction head" circuit that completes patterns

3. **Interventions**: Causally testing what components do
   - *Activation patching*: Restore clean activations in corrupted inputs
   - *Ablation*: Remove components to see what breaks
   - *Steering*: Modify activations to change behavior
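
Two of the interventions listed above, ablation and steering, can be sketched with plain PyTorch forward hooks. This is a minimal illustration on a hypothetical toy MLP (`toy`, `ablate_unit_3`, `steer`, and `steer_direction` are names invented here); it does not use or describe the neuros-mechint API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: 4 inputs -> 8 hidden units -> 2 outputs.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
x = torch.randn(5, 4)

def run_with_hook(hook):
    """Run the toy model with a temporary forward hook on the ReLU output."""
    handle = toy[1].register_forward_hook(hook)
    try:
        with torch.no_grad():
            return toy(x)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails

with torch.no_grad():
    baseline = toy(x)

# Ablation: zero out hidden unit 3. A tensor returned from a forward hook
# replaces the module's output.
def ablate_unit_3(module, inputs, output):
    patched = output.clone()
    patched[:, 3] = 0.0
    return patched

# Steering: push every example along a fixed direction in hidden space.
steer_direction = torch.zeros(8)
steer_direction[3] = 2.0

def steer(module, inputs, output):
    return output + steer_direction

ablated = run_with_hook(ablate_unit_3)
steered = run_with_hook(steer)

print("Mean |change| from ablation:", (ablated - baseline).abs().mean().item())
print("Mean |change| from steering:", (steered - baseline).abs().mean().item())
```

Because the last layer is linear, steering shifts every output by the same vector (2 times column 3 of the final weight matrix), while the effect of ablation depends on how active unit 3 was for each input.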

## Part 2: Your First Analysis - Understanding a Simple Transformer

Let's analyze a tiny transformer to see mechanistic interpretability in action.

### The Model: A Toy Transformer Layer

We'll use a simple transformer encoder layer—small enough to run on a laptop, but complex enough to be interesting:

In [None]:
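# Create a simple transformer layer
d_model = 64      # Hidden dimension (kept small so it runs on a laptop)
nhead = 4         # Number of attention heads
batch_size = 8
seq_len = 12

model = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=256,
    dropout=0.1
).to(device)
model.eval()  # Disable dropout so repeated forward passes give identical outputs

# Create some example input with shape (seq_len, batch, d_model),
# the layout expected when batch_first is left at its default (False)
x = torch.randn(seq_len, batch_size, d_model).to(device)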

# Forward pass
output = model(x)

print(f"Model input shape: {x.shape}")
print(f"Model output shape: {output.shape}")
print(f"\nModel has {sum(p.numel() for p in model.parameters()):,} parameters")

### Question: What Does This Transformer Actually Compute?

Looking at the weights tells us almost nothing. Let's use mechanistic interpretability to find out!
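
To make that point concrete, the sketch below counts parameters per submodule for a layer with the same hyperparameters as above. It uses only standard PyTorch calls; the variable names (`layer`, `counts`) are chosen here for illustration. The counts show where the parameters live, but nothing about what they compute.

```python
import torch.nn as nn

# Same hyperparameters as the notebook's toy layer.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256, dropout=0.1)

# Group parameter counts by top-level submodule (self_attn, linear1, ...).
counts = {}
for name, param in layer.named_parameters():
    top = name.split(".")[0]
    counts[top] = counts.get(top, 0) + param.numel()

for name, n in counts.items():
    print(f"{name:>10}: {n:,} parameters")

total = sum(counts.values())
print(f"{'total':>10}: {total:,} parameters")
```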

## Analysis 1: Finding Features with Sparse Autoencoders

### The Problem: Polysemanticity

Individual neurons often respond to many unrelated concepts. For example, a neuron might activate for:
- The word "apple" (fruit)
- The word "Apple" (company) 
- Red colors
- Round shapes

This makes neurons hard to interpret!
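
One way polysemanticity can arise is when a layer has to represent more features than it has neurons. The toy sketch below is an illustrative assumption, not something measured from the transformer above: it packs three hypothetical features into two neurons and shows that each neuron then responds to several features.

```python
import numpy as np

# Three hypothetical "concept" features squeezed into two neurons.
# Each feature gets its own direction in the 2-D neuron space.
angles = np.deg2rad([0.0, 120.0, 240.0])
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (3 features, 2 neurons)

# Present each feature on its own (one-hot feature vectors).
features = np.eye(3)
neuron_acts = features @ directions  # (3 inputs, 2 neurons)

for name, acts in zip(["feature A", "feature B", "feature C"], neuron_acts):
    print(f"{name}: neuron activations = {np.round(acts, 2)}")

# Count how many features drive each neuron away from zero.
responds = np.abs(neuron_acts) > 1e-6
print("Features per neuron:", responds.sum(axis=0))
```

Neuron 0 is driven by all three features here, so reading its activation alone cannot tell you which feature was present.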

### The Solution: Sparse Autoencoders (SAEs)

SAEs learn an *overcomplete dictionary* of features:
- More features than neurons (e.g., 4× expansion)
- Each feature is *monosemantic* (represents one thing)
- Activations are *sparse* (few features active at once)

**Mathematical Setup:**

Given neuron activations $h \in \mathbb{R}^{d}$, learn:
- Encoder: $f(h) = \text{ReLU}(W_e h + b_e) \in \mathbb{R}^{m}$ where $m > d$ (often $m \gg d$)
- Decoder: $\hat{h} = W_d f(h)$

**Loss function:**
$$\mathcal{L} = \|h - \hat{h}\|_2^2 + \lambda \|f(h)\|_1$$

- First term: Reconstruction quality
- Second term: Sparsity penalty (encourages few active features)
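
The encoder, decoder, and loss above translate almost line for line into PyTorch. The sketch below is an illustrative re-implementation on random data; `TinySAE`, `sae_loss`, and `sae_demo` are names invented here, and the library's `SparseAutoencoder` may differ in its details.

```python
import torch
import torch.nn as nn

class TinySAE(nn.Module):
    """Illustrative sparse autoencoder; not the library's SparseAutoencoder."""

    def __init__(self, d: int, m: int):
        super().__init__()
        self.encoder = nn.Linear(d, m)              # W_e and b_e
        self.decoder = nn.Linear(m, d, bias=False)  # W_d

    def forward(self, h):
        f = torch.relu(self.encoder(h))  # f(h) = ReLU(W_e h + b_e)
        h_hat = self.decoder(f)          # h_hat = W_d f(h)
        return h_hat, f

def sae_loss(h, h_hat, f, lam):
    """Reconstruction error plus L1 sparsity penalty, averaged over the batch."""
    recon = (h - h_hat).pow(2).sum(dim=-1).mean()
    sparsity = f.abs().sum(dim=-1).mean()
    return recon + lam * sparsity

torch.manual_seed(0)
h = torch.randn(256, 16)          # stand-in activations, d = 16
sae_demo = TinySAE(d=16, m=64)    # 4x overcomplete
optimizer = torch.optim.Adam(sae_demo.parameters(), lr=1e-2)

initial_loss = None
for step in range(200):
    h_hat, f = sae_demo(h)
    loss = sae_loss(h, h_hat, f, lam=0.01)
    if initial_loss is None:
        initial_loss = loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
final_loss = loss.item()

print(f"Loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

Raising `lam` trades reconstruction quality for sparser feature activations, which is the trade-off explored in the exercises.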

Let's train an SAE on our transformer's activations:

In [None]:
from neuros_mechint import SparseAutoencoder

# Collect activations from the transformer's MLP layer
activations_list = []

def hook_fn(module, inputs, output):
    """Forward hook that stores the module's output (moved to CPU)."""
    activations_list.append(output.detach().cpu())

# Register the hook on the first linear layer of the MLP.
# Its output is the pre-ReLU hidden state, shape (seq_len, batch, dim_feedforward).
hook_handle = model.linear1.register_forward_hook(hook_fn)

# Generate more data to train the SAE
print("Collecting activations...")
for _ in range(50):  # 50 batches
    x_batch = torch.randn(seq_len, batch_size, d_model).to(device)
    with torch.no_grad():
        _ = model(x_batch)

# Remove hook
hook_handle.remove()

# Combine all activations
activations = torch.cat(activations_list, dim=1)  # (seq_len, n_samples, features)
activations = activations.reshape(-1, activations.shape[-1])  # Flatten

print(f"Collected {activations.shape[0]} activation vectors of dimension {activations.shape[1]}")

# Create and train SAE
sae = SparseAutoencoder(
    input_dim=activations.shape[1],
    hidden_dim=activations.shape[1] * 4,  # 4x overcomplete
    sparsity_coef=0.01,  # L1 penalty strength
    tied_weights=False
)

print("\nTraining SAE...")
losses = sae.train_on_activations(
    activations,
    num_epochs=100,
    batch_size=128,
    lr=1e-3
)

# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('SAE Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

print(f"\nFinal loss: {losses[-1]:.4f}")

### Analyzing the Learned Features

Now let's see what features the SAE discovered:

In [None]:
# Get feature activations for a test batch
test_activations = activations[:100]  # Take first 100 samples
feature_acts = sae.get_feature_activations(test_activations)
feature_acts = feature_acts.detach().cpu()  # Safe for plotting even if gradients or a GPU are involved

print(f"Feature activation shape: {feature_acts.shape}")
print(f"Fraction of feature activations that are nonzero: {(feature_acts > 0).float().mean():.2%}")

# Visualize sparsity pattern
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Left: Activation heatmap
im = axes[0].imshow(feature_acts[:50].T, aspect='auto', cmap='viridis', interpolation='none')
axes[0].set_xlabel('Sample')
axes[0].set_ylabel('Feature')
axes[0].set_title('Feature Activations (50 samples)')
plt.colorbar(im, ax=axes[0], label='Activation')

# Right: Activation distribution
axes[1].hist(feature_acts.flatten(), bins=50, edgecolor='black')
axes[1].set_xlabel('Activation Value')
axes[1].set_ylabel('Count')
axes[1].set_title('Distribution of Feature Activations')
axes[1].set_yscale('log')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Find most active features
feature_frequency = (feature_acts > 0).float().mean(dim=0)
top_features = feature_frequency.argsort(descending=True)[:10]

print("\nTop 10 most frequently active features:")
for i, feat_idx in enumerate(top_features):
    freq = feature_frequency[feat_idx]
    avg_activation = feature_acts[:, feat_idx][feature_acts[:, feat_idx] > 0].mean()
    print(f"  {i+1}. Feature {feat_idx}: Active {freq:.1%} of the time, avg activation {avg_activation:.3f}")

### Key Insight

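Check the fraction printed above: a well-trained SAE yields **sparse** activations, with only a small fraction of features active for any given input. Sparsity is what allows each feature to represent one specific concept rather than responding to many unrelated things. Keep in mind that this toy model is untrained and its inputs are random noise, so the features found here will not map onto human-readable concepts; the goal at this stage is to learn the workflow.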

In the next notebooks, we'll learn how to:
- Interpret what each feature represents
- Build hierarchical concept dictionaries
- Use features for causal interventions

## Analysis 2: Causal Circuit Discovery

### The Question: Which Components are Causally Important?

Just because a component activates doesn't mean it's *causing* the output. We need **causal interventions**.

### Activation Patching: The Core Technique

**Setup:**
1. **Clean run**: Normal input → normal output ✓
2. **Corrupted run**: Corrupted input → wrong output ✗
3. **Patched run**: Corrupted input, but restore ("patch") clean activations at a specific layer

**If patching fixes the output**, that component is causally important!

**Recovery score**: $R = \frac{\text{loss(corrupted)} - \text{loss(patched)}}{\text{loss(corrupted)} - \text{loss(clean)}}$
- $R \approx 1$: Component is critical
- $R \approx 0$: Component doesn't matter
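
The three runs and the recovery score can also be written by hand with forward hooks, which helps show what a patching tool has to do internally. The sketch below is an assumption-laden illustration in plain PyTorch (`layer`, `recovery_score`, and the small dimensions are chosen here); it is not the implementation of the library's `ActivationPatcher`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the notebook's model; dropout is off and eval()
# is set so that repeated forward passes are deterministic.
layer = nn.TransformerEncoderLayer(d_model=16, nhead=2, dim_feedforward=32, dropout=0.0).eval()
clean = torch.randn(6, 1, 16)       # (seq_len, batch, d_model)
corrupted = torch.randn(6, 1, 16)

def recovery_score(module):
    """Patch `module`'s clean output into a corrupted run and return R."""
    cache = {}

    def save_hook(mod, inputs, output):
        cache["clean"] = output.detach().clone()

    def patch_hook(mod, inputs, output):
        return cache["clean"]  # a returned value replaces the module's output

    with torch.no_grad():
        handle = module.register_forward_hook(save_hook)
        clean_out = layer(clean)            # 1. clean run (cache the activation)
        handle.remove()

        corrupted_out = layer(corrupted)    # 2. corrupted run

        handle = module.register_forward_hook(patch_hook)
        patched_out = layer(corrupted)      # 3. patched run
        handle.remove()

    loss_clean = F.mse_loss(clean_out, clean_out).item()  # zero by construction
    loss_corrupted = F.mse_loss(corrupted_out, clean_out).item()
    loss_patched = F.mse_loss(patched_out, clean_out).item()
    return (loss_corrupted - loss_patched) / (loss_corrupted - loss_clean)

scores = {
    "linear1": recovery_score(layer.linear1),
    "linear2": recovery_score(layer.linear2),
    "whole layer": recovery_score(layer),
}
for name, score in scores.items():
    print(f"{name}: R = {score:.3f}")
```

Two sanity checks follow from the architecture. Patching the whole layer's output must give $R = 1$. With dropout disabled, `linear1` and `linear2` receive the same score, because the output of `linear2` depends only on the output of `linear1` within a single feed-forward block.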

Let's try this:

In [None]:
from neuros_mechint.interventions import ActivationPatcher

# Create clean and corrupted inputs
clean_input = torch.randn(seq_len, 1, d_model).to(device)
corrupted_input = torch.randn(seq_len, 1, d_model).to(device)  # Different random input

# Define a simple loss (e.g., difference from clean output)
with torch.no_grad():
    clean_output = model(clean_input)

def compute_loss(output):
    """Loss: distance from clean output"""
    return torch.nn.functional.mse_loss(output, clean_output)

# Test different components
components_to_test = [
    'self_attn',  # Attention sublayer
    'linear1',    # First MLP layer
    'linear2',    # Second MLP layer
]

results = {}

print("Testing causal importance of components...\n")

for comp_name in components_to_test:
    # Create patcher for this component
    patcher = ActivationPatcher(
        model=model,
        layer_name=comp_name
    )
    
    # Run patching experiment
    result = patcher.patch(
        clean_input=clean_input,
        corrupted_input=corrupted_input,
        loss_fn=compute_loss
    )
    
    results[comp_name] = result
    
    print(f"{comp_name}:")
    print(f"  Clean loss: {result['clean_loss']:.4f}")
    print(f"  Corrupted loss: {result['corrupted_loss']:.4f}")
    print(f"  Patched loss: {result['patched_loss']:.4f}")
    print(f"  Recovery score: {result['recovery_score']:.2%}")
    print()

### Visualizing Component Importance

In [None]:
# Extract recovery scores
component_names = list(results.keys())
recovery_scores = [results[name]['recovery_score'] for name in component_names]

# Create bar plot
plt.figure(figsize=(10, 5))
bars = plt.bar(range(len(component_names)), recovery_scores, edgecolor='black')

# Color bars by importance
for i, bar in enumerate(bars):
    score = recovery_scores[i]
    if score > 0.7:
        bar.set_color('darkgreen')
    elif score > 0.3:
        bar.set_color('orange')
    else:
        bar.set_color('lightcoral')

plt.xticks(range(len(component_names)), component_names, rotation=45)
plt.ylabel('Recovery Score')
plt.title('Causal Importance of Transformer Components')
plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Recovery = 0.5')
plt.legend()
plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("Green: Highly important (recovery > 70%)")
print("Orange: Moderately important (30-70%)")
print("Red: Less important (< 30%)")

### Key Insights

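This simple experiment shows which components' clean activations are **sufficient to restore** the clean behavior:
- High recovery score → Patching this component alone largely restores the clean output
- Low recovery score → Patching this component has little causal effect on the output

Testing whether a component is *necessary* requires the reverse experiment: patch corrupted activations into a clean run, or ablate the component.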

This is more informative than just looking at activation magnitudes!

In Notebook 03, we'll learn to:
- Trace information flow through entire circuits
- Build causal graphs
- Discover minimal sufficient subnetworks

## Part 3: Overview of All Techniques in the Library

You've now seen two core techniques: **Sparse Autoencoders** for finding features and **Activation Patching** for discovering circuits. But there's so much more!

### The Complete Toolkit

#### 1. Feature Discovery & Analysis
- **Sparse Autoencoders** (SAEs): Decompose polysemantic neurons
- **Hierarchical SAEs**: Multi-level concept hierarchies
- **Feature Attribution**: Integrated Gradients, DeepLIFT, SHAP
- **Concept Dictionaries**: Semantic labeling of features

#### 2. Causal Interventions
- **Activation Patching**: Test component importance
- **Ablation Studies**: Systematic component removal
- **Path Analysis**: Trace information flow
- **Counterfactuals**: What-if analysis

#### 3. Circuit Extraction
- **Latent RNN Models**: Minimal explanatory circuits
- **DUNL**: Demixed sparse coding for mixed selectivity
- **Feature Visualization**: What do neurons/features respond to?
- **Motif Detection**: Find recurring computational patterns

#### 4. Biological Realism
- **Fractal Analysis**: Measure scale-free dynamics
- **Fractal Regularization**: Enforce biological complexity
- **Spiking Neural Networks**: LIF, Izhikevich, Hodgkin-Huxley
- **Dale's Law**: Excitatory/inhibitory separation
- **Synaptic Plasticity**: STDP, short-term dynamics

#### 5. Brain Alignment
- **CCA**: Canonical Correlation Analysis for alignment
- **RSA**: Representational Similarity Analysis
- **PLS**: Partial Least Squares for prediction
- **Statistical Testing**: Noise ceilings, significance tests
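
As one concrete instance of the alignment methods above, the core of RSA fits in a few lines of NumPy. The sketch below is a simplified variant chosen for illustration: it uses Euclidean distances for the dissimilarity matrices and a Pearson correlation between them, whereas the library's `RSA` class and its `metric` options may work differently. All names and data here are hypothetical.

```python
import numpy as np

def rdm(acts):
    """Representational dissimilarity matrix: pairwise Euclidean distances
    between the rows (stimuli) of an activation matrix."""
    diff = acts[:, None, :] - acts[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

def rsa_score(acts_a, acts_b):
    """Pearson correlation between the upper triangles of two RDMs."""
    iu = np.triu_indices(acts_a.shape[0], k=1)
    return np.corrcoef(rdm(acts_a)[iu], rdm(acts_b)[iu])[0, 1]

rng = np.random.default_rng(0)
model_acts = rng.normal(size=(20, 50))   # 20 stimuli x 50 model units

# Same geometry expressed in a rotated basis (orthogonal transform).
rotation, _ = np.linalg.qr(rng.normal(size=(50, 50)))
rotated = model_acts @ rotation

# Unrelated responses to the same 20 stimuli, with a different unit count.
unrelated = rng.normal(size=(20, 30))

same_geometry = rsa_score(model_acts, rotated)
different_geometry = rsa_score(model_acts, unrelated)
print(f"Rotated copy: {same_geometry:.3f}")
print(f"Unrelated data: {different_geometry:.3f}")
```

RSA compares the geometry of stimulus responses rather than individual units, so the two systems can have different numbers of units and the score is unchanged by a rotation of the representation.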

#### 6. Dynamical Systems
- **Koopman Operators**: Linearize nonlinear dynamics
- **Lyapunov Exponents**: Measure chaos and stability
- **Fixed Point Analysis**: Find attractors
- **Controllability**: Plan interventions

#### 7. Information Theory
- **Mutual Information**: MINE estimator for information flow
- **Information Plane**: Track compression during training
- **Energy Landscapes**: Map basins of attraction
- **Entropy Production**: Thermodynamic analysis

#### 8. Geometry & Topology
- **Manifold Analysis**: Intrinsic dimensionality, curvature
- **Persistent Homology**: Topological data analysis
- **Geodesic Distances**: True distances on manifolds

#### 9. Meta-Dynamics
- **Training Trajectories**: How representations evolve
- **Phase Detection**: Identify critical periods
- **Feature Emergence**: Track when features appear
- **Representational Drift**: Measure stability

### The Big Picture: How These Techniques Work Together

Here's a typical workflow combining multiple techniques:

```
1. Train your model
2. Extract features (SAEs) → Understand WHAT is represented
3. Find circuits (Patching) → Understand HOW computation happens
4. Analyze dynamics (Koopman, fixed points) → Understand temporal evolution
5. Compare to brain (CCA, RSA) → Validate biological relevance
6. Generate explanations → Communicate findings
```

Each technique gives you a different lens to understand the same network!

## Part 4: A Roadmap for Your Learning Journey

### Recommended Path

**Week 1: Foundations**
- ✓ This notebook (01)
- Notebook 02: Sparse Autoencoders in depth
- Notebook 03: Causal interventions

**Week 2: Biological Connection**
- Notebook 04: Fractal analysis
- Notebook 05: Brain alignment
- Notebook 08: Biophysical modeling

**Week 3: Advanced Analysis**
- Notebook 06: Dynamical systems
- Notebook 07: Circuit extraction
- Notebook 09: Information theory

**Week 4: Expert Topics**
- Notebook 10: Advanced topics
- Your own research projects!

### Tips for Success

1. **Run the code**: Don't just read—experiment!
2. **Modify parameters**: Change hyperparameters and see what happens
3. **Use your own models**: Apply these techniques to models you care about
4. **Connect to papers**: Read the cited research papers
5. **Ask questions**: Open issues on GitHub or discuss with others

### Common Applications

**For neuroscientists:**
```python
# Compare your model to brain recordings
from neuros_mechint.alignment import CCA, RSA

# Align representations
cca = CCA(n_components=20)
cca.fit(model_activations, brain_data)
alignment_score = cca.score(model_activations_test, brain_data_test)

# Compare geometries
rsa = RSA(metric='correlation')
similarity = rsa.compare(model_activations, brain_data)
```

**For AI researchers:**
```python
# Understand a transformer
from neuros_mechint import SparseAutoencoder
from neuros_mechint.interventions import ActivationPatcher

# Extract features
sae = SparseAutoencoder(input_dim=768, hidden_dim=3072)
sae.train_on_activations(activations)

# Find circuits
patcher = ActivationPatcher(model, layer_name='attention')
results = patcher.run_full_analysis(clean_input, corrupted_input)
```

**For ML engineers:**
```python
# Enforce biological realism during training
from neuros_mechint.fractals import SpectralPrior, HiguchiFractalDimension

# Add fractal regularization
fractal_loss = SpectralPrior(target_exponent=1.0, weight=0.1)
total_loss = task_loss + fractal_loss(hidden_states)

# Monitor complexity
fd_metric = HiguchiFractalDimension(kmax=10)
complexity = fd_metric(neural_timeseries)
```

## Part 5: Practice Exercise

Now it's your turn! Modify the code above to work through the exercises below.

### Exercise 1: Explore Different SAE Configurations
1. Train SAEs with different sparsity coefficients (0.001, 0.01, 0.1)
2. Try different expansion factors (2x, 4x, 8x)
3. Compare the sparsity and reconstruction quality
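
For step 3, two simple numbers are often enough: the average count of active features per sample (L0) and the fraction of variance left unexplained by the reconstruction. The helper below is a suggested sketch; `sae_metrics` is a name invented here, and how you obtain the reconstruction `h_hat` from the library's SAE is left open.

```python
import torch

def sae_metrics(h, h_hat, f):
    """Return (mean number of active features per sample,
    fraction of variance unexplained by the reconstruction)."""
    l0 = (f > 0).float().sum(dim=-1).mean().item()
    residual = (h - h_hat).pow(2).sum()
    total = (h - h.mean(dim=0)).pow(2).sum()
    fvu = (residual / total).item()
    return l0, fvu

# Sanity checks on synthetic data (not real SAE outputs).
torch.manual_seed(0)
h = torch.randn(100, 8)
f = torch.zeros(100, 32)
f[:, :3] = 1.0  # exactly three active features per sample

l0_perfect, fvu_perfect = sae_metrics(h, h.clone(), f)  # perfect reconstruction
l0_mean, fvu_mean = sae_metrics(h, h.mean(dim=0).expand_as(h), f)  # predict the mean

print(f"Perfect reconstruction: L0={l0_perfect:.1f}, FVU={fvu_perfect:.3f}")
print(f"Mean-only reconstruction: L0={l0_mean:.1f}, FVU={fvu_mean:.3f}")
```

A lower FVU means better reconstruction and a lower L0 means sparser codes; plotting one against the other across sparsity coefficients shows the trade-off.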

### Exercise 2: Test Different Interventions
1. Create a more structured "corruption" (e.g., add noise only to specific positions)
2. Test patching at individual attention heads
3. Visualize how recovery score changes across layers
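
For step 1, one possible structured corruption adds noise at chosen sequence positions and leaves the rest of the input untouched. The helper below is a sketch with hypothetical names (`corrupt_positions`, `noise_std`); it assumes the `(seq_len, batch, d_model)` layout used throughout this notebook.

```python
import torch

torch.manual_seed(0)
seq_len, d_model = 12, 64
clean_input = torch.randn(seq_len, 1, d_model)

def corrupt_positions(x, positions, noise_std=1.0):
    """Return a copy of x with Gaussian noise added only at the given
    sequence positions (first dimension of x)."""
    corrupted = x.clone()
    noise = torch.randn(len(positions), *x.shape[1:], device=x.device, dtype=x.dtype)
    corrupted[positions] = corrupted[positions] + noise_std * noise
    return corrupted

corrupted_input = corrupt_positions(clean_input, positions=[3, 4])

# Verify which positions actually differ from the clean input.
changed = (corrupted_input != clean_input).any(dim=-1).squeeze(1)
changed_positions = changed.nonzero().flatten().tolist()
print("Positions changed:", changed_positions)
```

Because only a few positions differ, any recovery you measure can be attributed to how the model moves information from those positions.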

### Exercise 3: Combine Techniques
1. Train an SAE on activations
2. Use the learned features for activation patching
3. Identify which features are causally important

Here's a starter template:

In [None]:
# Exercise workspace

# TODO: Try different SAE configurations
sparsity_values = [0.001, 0.01, 0.1]

for sparsity in sparsity_values:
    print(f"\nTraining SAE with sparsity={sparsity}")
    # Your code here
    pass

# TODO: Test patching with different corruptions
# Your code here

# TODO: Combine SAE features with causal analysis
# Your code here

## Summary & Next Steps

### What You've Learned

1. **Mechanistic interpretability** aims to reverse-engineer the "source code" of neural networks
2. **Sparse Autoencoders** decompose neurons into interpretable, monosemantic features
3. **Activation patching** reveals which components are causally important for a behavior
4. The library groups its tools into **nine technique families** (see Part 3) for comprehensive analysis

### Key Takeaways

- Interpretability is not just about *inputs* and *outputs*—it's about understanding the *algorithm*
- Different techniques provide complementary insights (features, circuits, dynamics, etc.)
- These methods are not tied to a particular model size, although compute and memory costs grow with scale
- Many of the same analyses can be applied to recordings from biological neural networks

### Next Steps

**Ready to dive deeper?** Here's what to explore next:

1. **[Notebook 02: Sparse Autoencoders](02_sparse_autoencoders.ipynb)**
   - Deep dive into SAE training and interpretation
   - Hierarchical concept extraction
   - Causal SAE interventions

2. **[Notebook 03: Causal Interventions](03_causal_interventions.ipynb)**
   - Advanced patching techniques
   - Building causal graphs
   - Circuit discovery workflows

3. **[Notebook 04: Fractal Analysis](04_fractal_analysis.ipynb)**
   - Understanding biological complexity
   - Training with fractal regularization
   - Comparing model and brain dynamics

### Additional Resources

**Papers to Read:**
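- Elhage et al. (2021): "A Mathematical Framework for Transformer Circuits"
- Bricken et al. (2023): "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning"
- Olah et al. (2020): "Zoom In: An Introduction to Circuits"
- Sussillo & Barak (2013): "Opening the Black Box: Low-Dimensional Dynamics in High-Dimensional Recurrent Neural Networks"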

**Websites:**
- [Anthropic's Interpretability Research](https://transformer-circuits.pub/)
- [Neuronpedia](https://neuronpedia.org/) - Browse interpretable features
- [Distill.pub](https://distill.pub/) - Visual explanations

### Get Involved!

This library is designed to become a **community standard** for mechanistic interpretability in neuroscience and AI. We welcome:
- Bug reports and feature requests
- New example notebooks
- Integrations with your research
- Documentation improvements

Together, we can make neural networks—artificial and biological—truly interpretable!

---

**Ready to continue?** Open [02_sparse_autoencoders.ipynb](02_sparse_autoencoders.ipynb) to master feature discovery!