# Causal Interventions & Circuit Discovery

**Finding computational circuits through systematic causal interventions**

## What You'll Learn

In this notebook, you'll master:
1. **Activation patching** - The gold standard for causal analysis
2. **Component-specific interventions** - Attention heads, MLP layers, neurons
3. **Ablation studies** - Systematic removal of components
4. **Path tracing** - Following information flow through networks
5. **Causal graphs** - Building complete circuit diagrams
6. **Circuit discovery workflows** - End-to-end pipelines

## The Core Question

> **Just because a component activates doesn't mean it's causing the output!**

We need **causal interventions** to determine what actually matters.

## Prerequisites

- Completed Notebooks 01-02
- Understanding of transformers
- Familiarity with causal reasoning

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from typing import Dict, List, Tuple, Callable
from tqdm.auto import tqdm

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Random seeds
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

from neuros_mechint.interventions import (
    ActivationPatcher,
    ResidualStreamPatcher,
    AttentionPatcher,
    MLPPatcher,
    NeuronAblation,
    LayerAblation,
    AblationStudy,
    PathAnalyzer,
    InformationFlow,
    CausalGraph
)

## Part 1: Understanding Activation Patching

### The Conceptual Framework

**Traditional approach (correlation)**:
- Look at what activates
- Assume high activation = important
- **Problem**: Correlation ≠ Causation!

**Causal approach (intervention)**:
- Manipulate activations
- Measure downstream effects
- **Result**: True causal importance!

### The Activation Patching Protocol

**Setup**: We need three forward passes:

1. **Clean run**: 
   ```
   Input: x_clean → Model → Output: y_clean ✓
   ```
   This is our "correct" behavior baseline.

2. **Corrupted run**:
   ```
   Input: x_corrupted → Model → Output: y_corrupted ✗
   ```
   This represents failure/different behavior.

3. **Patched run**:
   ```
   Input: x_corrupted → Model (with clean activation at layer L) → Output: y_patched ?
   ```
   We "patch in" the clean activation at a specific layer.

**The Key Question**: Does patching layer L recover clean behavior?
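- **Yes** (y_patched ≈ y_clean) → Layer L carries the information that separates clean from corrupted behavior!
- **No** (y_patched ≈ y_corrupted) → Layer L doesn't matter for this particular clean/corrupted contrast (it may still matter for others).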
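
The three passes can be sketched with plain PyTorch forward hooks. This is a minimal illustration on a hypothetical toy `nn.Sequential` model, not the notebook's transformer and not necessarily how `ActivationPatcher` works internally; the names `toy`, `save_hook`, and `patch_hook` are made up for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model; toy[0] plays the role of "layer L".
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
x_clean = torch.randn(3, 4)
x_corrupted = torch.randn(3, 4)

cache = {}

def save_hook(module, inputs, output):
    # Returning None leaves the forward pass unchanged.
    cache["clean"] = output.detach().clone()

def patch_hook(module, inputs, output):
    # Returning a tensor replaces the module's output.
    return cache["clean"]

layer = toy[0]
with torch.no_grad():
    # 1. Clean run: record the activation at layer L.
    handle = layer.register_forward_hook(save_hook)
    y_clean = toy(x_clean)
    handle.remove()

    # 2. Corrupted run: no intervention.
    y_corrupted = toy(x_corrupted)

    # 3. Patched run: corrupted input, clean activation at layer L.
    handle = layer.register_forward_hook(patch_hook)
    y_patched = toy(x_corrupted)
    handle.remove()

print(torch.allclose(y_patched, y_clean))
```

Because this toy model is strictly sequential, replacing the whole output of its first layer reproduces `y_clean` exactly. In a model with residual connections, information also flows around the patched layer, so recovery is usually partial.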

### Mathematical Formulation

**Recovery Score**:
$$R_L = \frac{\mathcal{L}(y_{corrupt}, y_{target}) - \mathcal{L}(y_{patch}, y_{target})}{\mathcal{L}(y_{corrupt}, y_{target}) - \mathcal{L}(y_{clean}, y_{target})}$$

Where:
- $\mathcal{L}$ is a loss function (e.g., MSE, cross-entropy)
- $y_{target}$ is the desired output (often $y_{clean}$)

**Interpretation**:
- $R \approx 1.0$: Perfect recovery → **Critical component**
- $R \approx 0.5$: Partial recovery → **Moderately important**
- $R \approx 0.0$: No recovery → **Not important**
- $R < 0$: Makes things worse → **Interfering component**
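
As a quick worked example of the formula, here is a small helper that computes $R$ from three scalar losses. The function name `recovery_score` and the `eps` guard are assumptions made for this sketch, and the loss values are made-up numbers chosen only to show the arithmetic.

```python
def recovery_score(clean_loss: float, corrupted_loss: float,
                   patched_loss: float, eps: float = 1e-12) -> float:
    """Fraction of the clean-vs-corrupted loss gap closed by patching."""
    gap = corrupted_loss - clean_loss
    if abs(gap) < eps:
        # The corruption did not change the loss, so recovery is undefined.
        return float("nan")
    return (corrupted_loss - patched_loss) / gap

print(recovery_score(0.0, 2.0, 0.5))  # 0.75: patching closed three quarters of the gap
print(recovery_score(0.0, 2.0, 2.0))  # 0.0: no recovery
print(recovery_score(0.0, 2.0, 3.0))  # -0.5: patching made things worse
```

The guard matters in practice: if the corrupted input happens to produce the same loss as the clean one, the denominator is zero and no component can be ranked with this contrast.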

Let's implement this!

### Example 1: Simple Transformer Layer

In [None]:
# Create a simple transformer layer
d_model = 64
nhead = 4
seq_len = 12
batch_size = 1

model = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=256,
    dropout=0.0  # No dropout for clean analysis
).to(device)

model.eval()

print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Components: self_attn, linear1, linear2, norm1, norm2")

In [None]:
# Create clean and corrupted inputs
# Clean: structured input (e.g., sine wave pattern)
t = torch.linspace(0, 2*np.pi, seq_len).unsqueeze(1).unsqueeze(2)
clean_input = torch.sin(t).expand(-1, batch_size, d_model).to(device)

# Corrupted: random noise
corrupted_input = torch.randn(seq_len, batch_size, d_model).to(device)

# Get clean output (our target)
with torch.no_grad():
    clean_output = model(clean_input)

# Define loss function
def loss_fn(output):
    """How far is output from clean output?"""
    return F.mse_loss(output, clean_output)

# Compute baseline losses
with torch.no_grad():
    clean_loss = loss_fn(clean_output)
    corrupt_output = model(corrupted_input)
    corrupt_loss = loss_fn(corrupt_output)

print(f"Clean loss (baseline): {clean_loss.item():.6f}")
print(f"Corrupted loss: {corrupt_loss.item():.6f}")
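# clean_loss is exactly 0 (the target is the clean output itself),
# so a percentage increase is undefined; report the absolute gap instead.
print(f"\nCorruption increased loss by {(corrupt_loss - clean_loss).item():.6f}")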

### Running Activation Patching Experiments

In [None]:
# Test each component
components_to_test = [
    'self_attn',  # Attention mechanism
    'linear1',    # First MLP layer (expansion)
    'linear2',    # Second MLP layer (projection)
]

results = {}

print("Running activation patching experiments...\n")
print("="*60)

for component in components_to_test:
    # Create patcher for this component
    patcher = ActivationPatcher(
        model=model,
        layer_name=component
    )
    
    # Run patching
    result = patcher.patch(
        clean_input=clean_input,
        corrupted_input=corrupted_input,
        loss_fn=loss_fn
    )
    
    results[component] = result
    
    # Print results
    print(f"\n{component.upper()}:")
    print(f"  Clean loss:      {result['clean_loss']:.6f}")
    print(f"  Corrupted loss:  {result['corrupted_loss']:.6f}")
    print(f"  Patched loss:    {result['patched_loss']:.6f}")
    print(f"  Recovery score:  {result['recovery_score']:.2%}")
    
    # Interpretation
    if result['recovery_score'] > 0.7:
        importance = "CRITICAL"
    elif result['recovery_score'] > 0.3:
        importance = "MODERATE"
    else:
        importance = "LOW"
    print(f"  → Importance: {importance}")

print("\n" + "="*60)

### Visualizing Component Importance

In [None]:
# Extract metrics for visualization
component_names = list(results.keys())
recovery_scores = [results[c]['recovery_score'] for c in component_names]
patched_losses = [results[c]['patched_loss'] for c in component_names]

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Left: Recovery scores
bars = axes[0].bar(range(len(component_names)), recovery_scores, edgecolor='black')

# Color by importance
for i, (bar, score) in enumerate(zip(bars, recovery_scores)):
    if score > 0.7:
        bar.set_color('darkgreen')
    elif score > 0.3:
        bar.set_color('orange')
    else:
        bar.set_color('lightcoral')

axes[0].set_xticks(range(len(component_names)))
axes[0].set_xticklabels(component_names, rotation=45, ha='right')
axes[0].set_ylabel('Recovery Score')
axes[0].set_title('Causal Importance of Components')
axes[0].axhline(y=0.3, color='gray', linestyle='--', alpha=0.5, label='Medium threshold')
axes[0].axhline(y=0.7, color='darkgreen', linestyle='--', alpha=0.5, label='High threshold')
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')
axes[0].set_ylim([0, 1.05])

# Right: Loss comparison
x = np.arange(len(component_names))
width = 0.25

# Index per component: `results` is keyed by component name
axes[1].bar(x - width, [results[c]['clean_loss'] for c in component_names], 
            width, label='Clean', color='green', alpha=0.7)
axes[1].bar(x, patched_losses, width, label='Patched', color='blue', alpha=0.7)
axes[1].bar(x + width, [results[c]['corrupted_loss'] for c in component_names], 
            width, label='Corrupted', color='red', alpha=0.7)

axes[1].set_xticks(x)
axes[1].set_xticklabels(component_names, rotation=45, ha='right')
axes[1].set_ylabel('Loss')
axes[1].set_title('Loss Comparison: Clean vs Patched vs Corrupted')
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Summary
print("\n📊 Summary:")
print("Green bars: Critical components (recovery > 70%)")
print("Orange bars: Moderately important (30-70%)")
print("Red bars: Less important (< 30%)")
print("\nLower patched loss = component recovered more of the correct behavior")

## Part 2: Fine-Grained Interventions

### Attention Head Patching

Transformers have multiple attention heads. Which ones matter?

**Hypothesis**: Different heads might serve different functions:
- Some heads might focus on positional relationships
- Others might focus on semantic content
- Some might be redundant!
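
One way to picture head-level patching: in standard multi-head attention, the per-head outputs are concatenated along the feature dimension before the output projection, so head `h` occupies a contiguous slice of width `d_model // nhead`. The sketch below patches one such slice. The helper `patch_head_slice` is hypothetical, and the random tensors stand in for real head outputs. Note that `nn.MultiheadAttention` performs the concatenation internally, so extracting these tensors from the real layer takes extra work that this sketch does not show.

```python
import torch

def patch_head_slice(corrupted_heads, clean_heads, head_idx, nhead):
    """Return corrupted_heads with one head's slice replaced by the clean one.

    Both tensors hold concatenated per-head outputs of shape (..., d_model),
    where head h occupies columns [h * head_dim, (h + 1) * head_dim).
    """
    d_model = corrupted_heads.shape[-1]
    assert d_model % nhead == 0, "d_model must be divisible by nhead"
    head_dim = d_model // nhead
    start, stop = head_idx * head_dim, (head_idx + 1) * head_dim
    patched = corrupted_heads.clone()
    patched[..., start:stop] = clean_heads[..., start:stop]
    return patched

torch.manual_seed(0)
# Stand-in tensors with this notebook's shapes: (seq_len, batch, d_model)
clean_heads = torch.randn(12, 1, 64)
corrupted_heads = torch.randn(12, 1, 64)

patched = patch_head_slice(corrupted_heads, clean_heads, head_idx=2, nhead=4)
print(patched.shape)
```

With `d_model=64` and `nhead=4`, head 2 covers columns 32 to 47; every other column keeps its corrupted value.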

Let's test this with **head-level patching**:

In [None]:
# Create attention patcher
attn_patcher = AttentionPatcher(
    model=model,
    layer_name='self_attn'
)

# Test each attention head
head_results = {}

print("Testing individual attention heads...\n")

for head_idx in range(nhead):
    result = attn_patcher.patch_head(
        clean_input=clean_input,
        corrupted_input=corrupted_input,
        head_idx=head_idx,
        loss_fn=loss_fn
    )
    
    head_results[head_idx] = result
    print(f"Head {head_idx}: Recovery = {result['recovery_score']:.2%}")

# Find most important head
most_important_head = max(head_results.items(), key=lambda x: x[1]['recovery_score'])
print(f"\n⭐ Most important head: Head {most_important_head[0]} "
      f"(recovery = {most_important_head[1]['recovery_score']:.2%})")

In [None]:
# Visualize head importance
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Bar chart
head_indices = list(head_results.keys())
head_scores = [head_results[i]['recovery_score'] for i in head_indices]

bars = axes[0].bar(head_indices, head_scores, edgecolor='black')
# Highlight most important
bars[most_important_head[0]].set_color('darkgreen')

axes[0].set_xlabel('Attention Head Index')
axes[0].set_ylabel('Recovery Score')
axes[0].set_title('Importance of Each Attention Head')
axes[0].grid(True, alpha=0.3, axis='y')
axes[0].set_ylim([0, 1.05])

# Right: Heatmap of head interactions (simulate)
# In practice, you'd test pairwise head combinations
head_matrix = np.zeros((nhead, nhead))
for i in range(nhead):
    for j in range(nhead):
        if i == j:
            head_matrix[i, j] = head_scores[i]
        else:
            # Simulate interaction (would need actual joint patching)
            head_matrix[i, j] = (head_scores[i] + head_scores[j]) / 2

im = axes[1].imshow(head_matrix, cmap='RdYlGn', vmin=0, vmax=1)
axes[1].set_xlabel('Head Index')
axes[1].set_ylabel('Head Index')
axes[1].set_title('Head Importance Matrix (off-diagonals simulated)')
plt.colorbar(im, ax=axes[1], label='Recovery Score')

# Add text annotations
for i in range(nhead):
    for j in range(nhead):
        text = axes[1].text(j, i, f'{head_matrix[i, j]:.2f}',
                           ha="center", va="center", color="black", fontsize=8)

plt.tight_layout()
plt.show()

### MLP Neuron Patching

Now let's go even finer-grained: individual neurons in the MLP!

In [None]:
# Create MLP patcher
mlp_patcher = MLPPatcher(
    model=model,
    layer_name='linear1'
)

# Sample subset of neurons (testing all would be slow)
num_neurons = 256  # linear1 has 256 neurons
sample_neurons = np.linspace(0, num_neurons-1, 20, dtype=int)

print(f"Testing {len(sample_neurons)} sampled neurons...\n")

neuron_results = {}

for neuron_idx in tqdm(sample_neurons):
    result = mlp_patcher.patch_neuron(
        clean_input=clean_input,
        corrupted_input=corrupted_input,
        neuron_idx=neuron_idx,
        loss_fn=loss_fn
    )
    neuron_results[neuron_idx] = result['recovery_score']

# Find most important neurons
top_neurons = sorted(neuron_results.items(), key=lambda x: x[1], reverse=True)[:5]

print("\n🔬 Top 5 Most Important Neurons:")
for rank, (neuron_idx, score) in enumerate(top_neurons, 1):
    print(f"  {rank}. Neuron {neuron_idx}: {score:.2%}")

In [None]:
# Visualize neuron importance
plt.figure(figsize=(12, 5))

neuron_indices = list(neuron_results.keys())
neuron_scores = list(neuron_results.values())

plt.plot(neuron_indices, neuron_scores, 'o-', linewidth=2, markersize=8)

# Highlight top neurons
for neuron_idx, score in top_neurons:
    plt.plot(neuron_idx, score, 'r*', markersize=15)
    plt.annotate(f'#{neuron_idx}', (neuron_idx, score), 
                xytext=(5, 5), textcoords='offset points')

plt.xlabel('Neuron Index')
plt.ylabel('Recovery Score')
plt.title('Importance of Individual MLP Neurons (Sampled)')
plt.grid(True, alpha=0.3)
plt.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5, label='Medium importance')
plt.legend()
plt.tight_layout()
plt.show()

print("\n💡 What to look for: often most neurons have low individual importance,")
print("   while a few 'key neurons' have disproportionate impact.")

## Part 3: Ablation Studies

### The Opposite Approach: Removal Instead of Restoration

**Activation patching**: Restores clean activations → measures importance

**Ablation**: Removes/zeros components → measures necessity

Both are complementary! Ablation asks: "What breaks if I remove this?"

### Types of Ablation

1. **Zero ablation**: Set activations to 0
2. **Mean ablation**: Set to mean activation
3. **Random ablation**: Replace with random values
4. **Resample ablation**: Replace with activations from different input
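
The four variants differ only in what replaces the ablated activation. The sketch below applies each one to a single neuron of an activation tensor. The helper `ablate_neuron` is hypothetical and is not the `NeuronAblation` API; for simplicity the mean is taken over this one input, whereas in practice it is usually averaged over a dataset.

```python
import torch

def ablate_neuron(acts, neuron_idx, mode, reference=None):
    """Return a copy of acts (..., n_neurons) with one neuron ablated."""
    out = acts.clone()
    column = acts[..., neuron_idx]
    if mode == "zero":
        out[..., neuron_idx] = 0.0
    elif mode == "mean":
        out[..., neuron_idx] = column.mean()
    elif mode == "random":
        # Gaussian noise matched to the neuron's own mean and spread.
        out[..., neuron_idx] = column.mean() + column.std() * torch.randn_like(column)
    elif mode == "resample":
        if reference is None:
            raise ValueError("resample ablation needs reference activations")
        out[..., neuron_idx] = reference[..., neuron_idx]
    else:
        raise ValueError(f"unknown ablation mode: {mode}")
    return out

torch.manual_seed(0)
acts = torch.randn(12, 1, 256)    # stand-in for linear1 activations
other = torch.randn(12, 1, 256)   # activations from a different input

zeroed = ablate_neuron(acts, 7, "zero")
averaged = ablate_neuron(acts, 7, "mean")
noised = ablate_neuron(acts, 7, "random")
resampled = ablate_neuron(acts, 7, "resample", reference=other)
```

Zero ablation is the simplest, but it can push the network far off its usual activation distribution; mean and resample ablation keep the replaced values in a more typical range.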

Let's try them!

In [None]:
from neuros_mechint.interventions import NeuronAblation, LayerAblation

# Test input
test_input = clean_input
with torch.no_grad():  # inference only; no gradients needed
    baseline_output = model(test_input)
    baseline_loss = loss_fn(baseline_output)

print(f"Baseline loss: {baseline_loss.item():.6f}\n")

# Test different ablation types
ablation_types = ['zero', 'mean', 'random']
ablation_results = {}

for abl_type in ablation_types:
    ablator = NeuronAblation(
        model=model,
        layer_name='linear1',
        ablation_type=abl_type
    )
    
    # Ablate top neuron we found earlier
    top_neuron_idx = top_neurons[0][0]
    
    result = ablator.ablate_neuron(
        input_data=test_input,
        neuron_idx=top_neuron_idx
    )
    
    ablated_loss = loss_fn(result['output'])
    # baseline_loss is exactly 0 here (the target is the clean output),
    # so report the absolute increase instead of dividing by the baseline.
    loss_increase = ablated_loss - baseline_loss
    
    ablation_results[abl_type] = {
        'loss': ablated_loss.item(),
        'increase': loss_increase.item()
    }
    
    print(f"{abl_type.capitalize()} ablation:")
    print(f"  Loss: {ablated_loss.item():.6f}")
    print(f"  Increase: {loss_increase.item():.6f}\n")

### Systematic Ablation Study

In [None]:
# Run comprehensive ablation study
study = AblationStudy(model=model)

# Test all major components
components_to_ablate = {
    'Attention': 'self_attn',
    'MLP Layer 1': 'linear1',
    'MLP Layer 2': 'linear2',
}

study_results = study.run_hierarchical_ablation(
    test_data=test_input,
    components=list(components_to_ablate.values()),
    loss_fn=loss_fn,
    ablation_type='mean'  # Use mean ablation
)

# Visualize results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Loss increase from ablation
comp_names = list(components_to_ablate.keys())
loss_increases = [study_results[c]['loss_increase'] for c in components_to_ablate.values()]

bars = axes[0].barh(comp_names, loss_increases, edgecolor='black')
for bar, increase in zip(bars, loss_increases):
    if increase > 0.5:
        bar.set_color('darkred')
    elif increase > 0.2:
        bar.set_color('orange')
    else:
        bar.set_color('lightgreen')

axes[0].set_xlabel('Loss Increase (%)')
axes[0].set_title('Impact of Ablating Each Component')
axes[0].grid(True, alpha=0.3, axis='x')

# Right: Compare ablation vs patching
# Convert recovery score to "importance" for comparison
patching_importance = [results[c]['recovery_score'] for c in components_to_ablate.values()]
ablation_importance = [(inc + 1) / 2 for inc in loss_increases]  # Ad hoc rescaling so both series share one axis

x = np.arange(len(comp_names))
width = 0.35

axes[1].bar(x - width/2, patching_importance, width, label='Patching', alpha=0.8)
axes[1].bar(x + width/2, ablation_importance, width, label='Ablation', alpha=0.8)

axes[1].set_xticks(x)
axes[1].set_xticklabels(comp_names, rotation=45, ha='right')
axes[1].set_ylabel('Importance Score')
axes[1].set_title('Patching vs Ablation: Importance Estimates')
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n📊 Compare the two panels: do both methods rank the components the same way?")
print("   Agreement is converging evidence; disagreement is worth investigating.")

## Part 4: Path Analysis - Tracing Information Flow

### Following the Information Trail

We've identified important components. But **how does information flow between them**?

**Goal**: Map the complete computational path from input to output.

**Method**: Path patching
- Patch combinations of components
- Identify sequences that matter
- Build a flow diagram
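
The first bullet, patching combinations of components, can be sketched as a ranking over subsets. The helper `rank_joint_patches` and the scorer `toy_recovery` are hypothetical, and the toy scores are made-up numbers that only show the call pattern. In a real run, `recovery_fn` would patch all components in the subset at once and return the measured recovery score. This is joint patching, not full path patching, which intervenes on specific sender-to-receiver connections.

```python
from itertools import combinations
from typing import Callable, List, Sequence, Tuple

def rank_joint_patches(
    components: Sequence[str],
    recovery_fn: Callable[[Tuple[str, ...]], float],
    max_size: int = 2,
) -> List[Tuple[Tuple[str, ...], float]]:
    """Score every subset of up to max_size components, best first."""
    scored = []
    for size in range(1, max_size + 1):
        for subset in combinations(components, size):
            scored.append((subset, recovery_fn(subset)))
    return sorted(scored, key=lambda item: item[1], reverse=True)

# Toy stand-in scorer with made-up numbers (not measured values).
toy_scores = {"self_attn": 0.6, "linear1": 0.3, "linear2": 0.1}

def toy_recovery(subset: Tuple[str, ...]) -> float:
    return min(1.0, sum(toy_scores[name] for name in subset))

ranking = rank_joint_patches(list(toy_scores), toy_recovery, max_size=2)
for subset, score in ranking[:3]:
    print(" + ".join(subset), f"{score:.2f}")
```

The number of subsets grows combinatorially with `max_size`, so in practice the single-component scan is typically used to prune candidates before testing combinations.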

In [None]:
# Create path analyzer
path_analyzer = PathAnalyzer(model=model)

# Define potential components in the path
component_sequence = [
    'self_attn',
    'norm1', 
    'linear1',
    'linear2',
    'norm2'
]

print("Analyzing information flow paths...\n")

# Find top computational paths
paths = path_analyzer.find_paths(
    input_data=clean_input,
    corrupted_input=corrupted_input,
    component_sequence=component_sequence,
    loss_fn=loss_fn,
    num_paths=10
)

print("Top 5 Most Important Paths:\n")
for i, path in enumerate(paths[:5], 1):
    path_str = ' → '.join(path['components'])
    print(f"{i}. {path_str}")
    print(f"   Strength: {path['strength']:.3f}")
    print(f"   Recovery: {path['recovery']:.2%}\n")

### Information Flow Quantification

In [None]:
# Quantify information flow between components
info_flow = InformationFlow(model=model)

# Compute flow matrix
flow_matrix = info_flow.compute_flow_matrix(
    input_data=clean_input,
    components=component_sequence
)

# Visualize as heatmap
fig, ax = plt.subplots(figsize=(10, 8))

im = ax.imshow(flow_matrix, cmap='YlOrRd', aspect='auto')
ax.set_xticks(range(len(component_sequence)))
ax.set_yticks(range(len(component_sequence)))
ax.set_xticklabels(component_sequence, rotation=45, ha='right')
ax.set_yticklabels(component_sequence)
ax.set_xlabel('To Component')
ax.set_ylabel('From Component')
ax.set_title('Information Flow Matrix\n(darker = stronger flow)')

# Add text annotations
for i in range(len(component_sequence)):
    for j in range(len(component_sequence)):
        text = ax.text(j, i, f'{flow_matrix[i, j]:.2f}',
                      ha="center", va="center", 
                      color="white" if flow_matrix[i, j] > 0.5 else "black",
                      fontsize=9)

plt.colorbar(im, ax=ax, label='Flow Strength')
plt.tight_layout()
plt.show()

print("\n💡 Reading the matrix:")
print("   - Rows: Source component")
print("   - Columns: Destination component")
print("   - Values: Strength of information flow")

## Part 5: Building Complete Causal Graphs

### From Paths to Circuits

Now let's put it all together: build a complete causal graph showing how components interact to implement the computation.

**A causal graph**:
- Nodes = components (layers, heads, neurons)
- Edges = causal influences
- Edge weights = strength of causal effect

This is the **circuit diagram** of the network!
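
The same structure can be sketched directly with `networkx`, independently of the `CausalGraph` class. The edge strengths below are the placeholder values used later in this notebook, not measurements, and the helper `strong_subgraph` is a hypothetical name for this example.

```python
import networkx as nx

graph = nx.DiGraph()
edges = [
    # Main forward path
    ("input", "self_attn", 0.92), ("self_attn", "norm1", 0.88),
    ("norm1", "linear1", 0.90), ("linear1", "linear2", 0.85),
    ("linear2", "norm2", 0.87), ("norm2", "output", 0.91),
    # Residual and cross connections
    ("input", "norm1", 0.45), ("norm1", "norm2", 0.42),
    ("self_attn", "linear1", 0.35), ("self_attn", "linear2", 0.28),
]
graph.add_weighted_edges_from(edges, weight="strength")

def strong_subgraph(g: nx.DiGraph, min_strength: float) -> nx.DiGraph:
    """Keep only edges whose causal strength reaches the threshold."""
    keep = [(u, v) for u, v, s in g.edges(data="strength") if s >= min_strength]
    return g.edge_subgraph(keep).copy()

strong = strong_subgraph(graph, min_strength=0.5)
path = nx.shortest_path(strong, "input", "output")
print(" → ".join(path))
```

Thresholding at 0.5 removes the four weaker edges and leaves a single chain from `input` to `output`. A thresholded graph like this is a candidate circuit; whether it is actually sufficient still has to be checked with interventions.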

In [None]:
# Create causal graph
causal_graph = CausalGraph()

# Add all components as nodes
all_components = ['input'] + component_sequence + ['output']
for comp in all_components:
    causal_graph.add_node(comp)

print("Building causal graph from interventions...\n")

# Add edges by hand. The strengths below are illustrative placeholders,
# not values measured above; in practice each edge weight would come
# from its own intervention experiment.

# Forward connections (main path)
causal_graph.add_edge('input', 'self_attn', strength=0.92)
causal_graph.add_edge('self_attn', 'norm1', strength=0.88)
causal_graph.add_edge('norm1', 'linear1', strength=0.90)
causal_graph.add_edge('linear1', 'linear2', strength=0.85)
causal_graph.add_edge('linear2', 'norm2', strength=0.87)
causal_graph.add_edge('norm2', 'output', strength=0.91)

# Skip connections
causal_graph.add_edge('input', 'norm1', strength=0.45)  # Residual
causal_graph.add_edge('norm1', 'norm2', strength=0.42)  # Residual

# Cross-connections (detected via patching)
causal_graph.add_edge('self_attn', 'linear1', strength=0.35)
causal_graph.add_edge('self_attn', 'linear2', strength=0.28)

print(f"Graph has {causal_graph.num_nodes()} nodes and {causal_graph.num_edges()} edges")

In [None]:
# Visualize the causal graph
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: Hierarchical layout
pos_hier = causal_graph.visualize(
    layout='hierarchical',
    ax=axes[0],
    show=False
)
axes[0].set_title('Causal Graph: Hierarchical Layout')

# Right: Spring layout
pos_spring = causal_graph.visualize(
    layout='spring',
    ax=axes[1],
    show=False
)
axes[1].set_title('Causal Graph: Spring Layout')

plt.tight_layout()
plt.show()

print("\n🔍 Graph Analysis:")
print(f"  - Strongly connected components: {causal_graph.find_strongly_connected()}")
print(f"  - Bottleneck nodes: {causal_graph.find_bottlenecks()}")
print(f"  - Most central node: {causal_graph.find_most_central()}")

### Finding the Minimal Circuit

In [None]:
# Extract minimal sufficient circuit
minimal_circuit = causal_graph.find_minimal_circuit(
    source='input',
    target='output',
    min_strength=0.5  # Only keep strong edges
)

print("Minimal Sufficient Circuit:\n")
print(f"Nodes: {minimal_circuit['nodes']}")
print(f"\nEdges:")
for edge in minimal_circuit['edges']:
    src, dst, strength = edge
    print(f"  {src} → {dst} (strength: {strength:.2f})")

# Visualize minimal circuit
plt.figure(figsize=(12, 8))
minimal_graph = CausalGraph()
for node in minimal_circuit['nodes']:
    minimal_graph.add_node(node)
for src, dst, strength in minimal_circuit['edges']:
    minimal_graph.add_edge(src, dst, strength=strength)

minimal_graph.visualize(layout='hierarchical')
plt.title('Minimal Sufficient Circuit', fontsize=16)
plt.show()

print("\n✨ This is the candidate core circuit: the edges that pass the strength threshold.")
print("   Confirm it by ablating each node/edge and checking that the computation breaks.")

## Part 6: Complete Circuit Discovery Workflow

### Putting It All Together

Let's create a complete, reusable pipeline for circuit discovery:

In [None]:
class CircuitDiscoveryPipeline:
    """Complete end-to-end circuit discovery pipeline"""
    
    def __init__(self, model, components):
        self.model = model
        self.components = components
        self.results = {}
    
    def step1_component_scan(self, clean_input, corrupted_input, loss_fn):
        """Step 1: Scan all components for importance"""
        print("\n" + "="*60)
        print("STEP 1: Component Importance Scan")
        print("="*60)
        
        component_importance = {}
        
        for comp in self.components:
            patcher = ActivationPatcher(self.model, comp)
            result = patcher.patch(clean_input, corrupted_input, loss_fn)
            component_importance[comp] = result['recovery_score']
            print(f"  {comp}: {result['recovery_score']:.2%}")
        
        self.results['component_importance'] = component_importance
        return component_importance
    
    def step2_find_paths(self, clean_input, corrupted_input, loss_fn):
        """Step 2: Identify critical information flow paths"""
        print("\n" + "="*60)
        print("STEP 2: Path Discovery")
        print("="*60)
        
        analyzer = PathAnalyzer(self.model)
        paths = analyzer.find_paths(
            clean_input, corrupted_input,
            self.components, loss_fn,
            num_paths=5
        )
        
        print("\nTop 3 paths:")
        for i, path in enumerate(paths[:3], 1):
            print(f"  {i}. {' → '.join(path['components'])}")
            print(f"     Strength: {path['strength']:.3f}")
        
        self.results['paths'] = paths
        return paths
    
    def step3_build_graph(self, clean_input):
        """Step 3: Build complete causal graph"""
        print("\n" + "="*60)
        print("STEP 3: Causal Graph Construction")
        print("="*60)
        
        graph = CausalGraph()
        
        # Add nodes
        for comp in ['input'] + self.components + ['output']:
            graph.add_node(comp)
        
        # Add edges from paths, anchoring each path at 'input' and 'output'
        # so that step 4 can search from source to target
        for path in self.results['paths']:
            nodes = ['input'] + list(path['components']) + ['output']
            strength = path['strength']
            for i in range(len(nodes)-1):
                graph.add_edge(nodes[i], nodes[i+1], strength=strength)
        
        print(f"\nGraph: {graph.num_nodes()} nodes, {graph.num_edges()} edges")
        
        self.results['graph'] = graph
        return graph
    
    def step4_extract_circuit(self, min_strength=0.5):
        """Step 4: Extract minimal circuit"""
        print("\n" + "="*60)
        print("STEP 4: Minimal Circuit Extraction")
        print("="*60)
        
        graph = self.results['graph']
        circuit = graph.find_minimal_circuit(
            source='input',
            target='output',
            min_strength=min_strength
        )
        
        print(f"\nCircuit has {len(circuit['nodes'])} nodes")
        print(f"Critical path: {' → '.join(circuit['nodes'])}")
        
        self.results['circuit'] = circuit
        return circuit
    
    def run_full_analysis(self, clean_input, corrupted_input, loss_fn):
        """Run complete pipeline"""
        print("\n" + "#"*60)
        print("# CIRCUIT DISCOVERY PIPELINE")
        print("#"*60)
        
        # Run all steps
        self.step1_component_scan(clean_input, corrupted_input, loss_fn)
        self.step2_find_paths(clean_input, corrupted_input, loss_fn)
        self.step3_build_graph(clean_input)
        circuit = self.step4_extract_circuit()
        
        # Generate report
        self.generate_report()
        
        return self.results
    
    def generate_report(self):
        """Generate final report"""
        print("\n" + "="*60)
        print("FINAL REPORT")
        print("="*60)
        
        # Most important components
        importance = self.results['component_importance']
        top_comps = sorted(importance.items(), key=lambda x: x[1], reverse=True)[:3]
        
        print("\n🔑 Most Important Components:")
        for i, (comp, score) in enumerate(top_comps, 1):
            print(f"  {i}. {comp}: {score:.2%}")
        
        # Key paths
        print("\n🛤️  Critical Path:")
        circuit_path = ' → '.join(self.results['circuit']['nodes'])
        print(f"  {circuit_path}")
        
        # Summary
        print("\n📊 Summary:")
        print(f"  - Analyzed {len(self.components)} components")
        print(f"  - Found {len(self.results['paths'])} significant paths")
        print(f"  - Minimal circuit: {len(self.results['circuit']['nodes'])} nodes")
        
        print("\n" + "="*60)

### Running the Complete Pipeline

In [None]:
# Create and run pipeline
pipeline = CircuitDiscoveryPipeline(
    model=model,
    components=component_sequence
)

results = pipeline.run_full_analysis(
    clean_input=clean_input,
    corrupted_input=corrupted_input,
    loss_fn=loss_fn
)

# Visualize the discovered circuit
plt.figure(figsize=(14, 10))
results['graph'].visualize(layout='hierarchical')
plt.title('Discovered Computational Circuit', fontsize=16)
plt.show()

## Part 7: Practice Exercises

Now it's your turn to discover circuits!

### Exercise 1: Different Corruptions

Try different types of corruption and see if the important components change:

In [None]:
# Exercise 1: Test with different corruptions

# TODO: Create different types of corrupted inputs:
# 1. Gaussian noise
# 2. Shuffled sequence positions
# 3. Adversarial perturbations
# 4. Dropout some positions

# Run patching experiments with each
# Do the same components remain important?

# Your code here...


### Exercise 2: Multi-Layer Model

Apply circuit discovery to a deeper model:

In [None]:
# Exercise 2: Deeper model

# TODO: Create a 3-layer transformer
deep_model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256),
    num_layers=3
).to(device)
deep_model.eval()  # Disable dropout so interventions are deterministic

# Run circuit discovery
# - Which layers are most important?
# - Do skip connections matter?
# - What's the minimal circuit?

# Your code here...


### Exercise 3: Task-Specific Circuits

Discover different circuits for different tasks:

In [None]:
# Exercise 3: Task-specific circuits

# TODO: Define two tasks:
# Task A: Copy input to output
# Task B: Reverse input sequence

# Discover circuits for each task
# - Do they use different components?
# - Is there shared circuitry?
# - Visualize both circuits

# Your code here...


## Summary & Next Steps

### What You've Mastered

1. ✓ **Activation patching** - Gold standard for causal analysis
2. ✓ **Fine-grained interventions** - Heads, MLP layers, neurons
3. ✓ **Ablation studies** - Component necessity via removal
4. ✓ **Path tracing** - Following information flow
5. ✓ **Causal graphs** - Complete circuit diagrams
6. ✓ **Full pipeline** - End-to-end circuit discovery

### Key Concepts

**Recovery Score**:
$$R = \frac{\mathcal{L}_{corrupt} - \mathcal{L}_{patch}}{\mathcal{L}_{corrupt} - \mathcal{L}_{clean}}$$

**Interpretation**:
- High R → Restoring this component's clean activation recovers most of the clean behavior
- Low R → Component contributes little for this clean/corrupted pair

**Complementary Methods**:
- **Patching**: What's important? (restoration)
- **Ablation**: What's necessary? (removal)
- When both agree, you have converging evidence!

### Real-World Applications

You can now:
- Find **induction heads** in language models
- Discover **attention circuits** for specific tasks
- Identify **critical neurons** for editing
- Map **information highways** in networks
- Build **interpretable subnetworks**
- Debug **failure modes** causally

### Famous Circuits Discovered

Using these techniques, researchers have found:
- **Induction heads**: Pattern completion circuits
- **IOI circuit**: Indirect object identification
- **Greater-than circuit**: Numerical comparison
- **Copy suppression**: Avoiding repetition

### Next Notebook

**[04_fractal_analysis.ipynb](04_fractal_analysis.ipynb)**

Now that we can find circuits, let's analyze their **dynamics**:
- Measuring biological realism through fractals
- Scale-free temporal dynamics
- Enforcing pink noise during training
- Comparing model and brain complexity

### Further Reading

**Essential Papers**:
1. Elhage et al. (2021): "A Mathematical Framework for Transformer Circuits"
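2. Wang et al. (2022): "Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small"
3. Nanda et al. (2023): "Progress measures for grokking via mechanistic interpretability"
4. Conmy et al. (2023): "Towards Automated Circuit Discovery for Mechanistic Interpretability"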

**Code & Tools**:
- [TransformerLens](https://github.com/neelnanda-io/TransformerLens): Activation patching for transformers
- [ACDC](https://github.com/ArthurConmy/Automatic-Circuit-Discovery): Automated circuit discovery

---

**Congratulations!** You can now causally analyze neural networks and discover computational circuits. This is a superpower for understanding how models actually work!

Ready to measure biological realism? Open [04_fractal_analysis.ipynb](04_fractal_analysis.ipynb)! 🚀