# Notebook 10: Advanced Topics in Mechanistic Interpretability

## Beyond the Basics: Cutting-Edge Interpretability Methods

This final notebook explores **advanced topics** that push the boundaries of mechanistic interpretability:
- **Meta-dynamics**: How representations evolve during training
- **Manifold Geometry**: Shape and curvature of representational spaces
- **Topological Analysis**: Persistent homology and Betti numbers
- **Counterfactual Interventions**: What-if experiments on networks

### Why Advanced Methods Matter

1. **Training Dynamics**: Understand how and when networks learn
2. **Geometric Structure**: Reveal intrinsic organization of representations
3. **Robust Analysis**: Topological invariants are unchanged by continuous deformations
4. **Causal Understanding**: Counterfactuals reveal causal mechanisms
5. **Phase Transitions**: Identify critical moments in learning

### What You'll Learn

1. **Meta-Dynamics Analysis**: Track training trajectories in representation space
2. **Manifold Curvature**: Measure local geometry of neural representations
3. **Persistent Homology**: Detect topological features (holes, loops, voids)
4. **Counterfactual Interventions**: Do-calculus and latent surgery
5. **Feature Emergence Detection**: When do concepts crystallize?
6. **Representational Drift**: How representations change over time

### References

- Achille & Soatto (2018): *Emergence of invariance and disentanglement in deep representations*
- Chung et al. (2018): *Classification and geometry of general perceptual manifolds*
- Naitzat et al. (2020): *Topology of deep neural networks*
- Pearl (2009): *Causality: Models, Reasoning, and Inference*
- Geirhos et al. (2020): *Beyond accuracy: Quantifying trial-by-trial behaviour of CNNs and humans by measuring error consistency*

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import eigh
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import copy

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

print("All libraries imported successfully!")
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

## Part 1: Meta-Dynamics - Tracking Training Trajectories

### What are Meta-Dynamics?

**Meta-dynamics** studies how neural representations change during training:
- Not just "does the model learn?" but **"how does learning unfold?"**
- Track checkpoints throughout training
- Analyze representational trajectories
- Detect phase transitions and critical periods

### Key Questions

1. **When do features emerge?** At what point does the network discover concepts?
2. **Are there phases?** Distinct stages of learning (e.g., fitting → compression)
3. **What order?** Do simple features emerge before complex ones?
4. **Representational drift**: Do representations keep changing after convergence?

### Metrics for Tracking

1. **Representational Similarity**: CKA, CCA, RSA between checkpoints
2. **Feature Selectivity**: How selective are neurons to task variables?
3. **Dimensionality**: Participation ratio, intrinsic dimension
4. **Alignment**: How well do representations align with brain/behavior?
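
Two of the metrics above can be checked on toy data before they are applied to checkpoints. The sketch below is a minimal illustration, not the notebook's analyzer: `linear_cka` and `participation_ratio` are standalone helper names chosen here, and the random matrices stand in for layer activations. It shows that linear CKA ignores rotations and isotropic rescaling of the features, and that the participation ratio of a rank-2 representation cannot exceed 2.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between two (n_samples, n_features) activation matrices."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = np.linalg.norm(X.T @ Y, 'fro') ** 2
    den = np.linalg.norm(X.T @ X, 'fro') * np.linalg.norm(Y.T @ Y, 'fro')
    return num / den if den > 0 else 0.0

def participation_ratio(X):
    """(sum of eigenvalues)^2 / (sum of squared eigenvalues) of the covariance."""
    eig = np.clip(np.linalg.eigvalsh(np.cov(X.T)), 0, None)
    return eig.sum() ** 2 / np.sum(eig ** 2)

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 16))            # stand-in for one checkpoint's activations

# Same representation, rotated by a random orthogonal matrix and rescaled
Q, _ = np.linalg.qr(rng.standard_normal((16, 16)))
cka_rotated = linear_cka(X, 3.0 * X @ Q)      # stays at 1 up to rounding

# Unrelated representation of the same inputs
cka_random = linear_cka(X, rng.standard_normal((200, 16)))

# Full-rank versus rank-2 representation embedded in 16 dimensions
pr_full = participation_ratio(X)
pr_low = participation_ratio(X[:, :2] @ rng.standard_normal((2, 16)))

print(f"CKA rotated: {cka_rotated:.3f}, CKA random: {cka_random:.3f}")
print(f"PR full: {pr_full:.1f}, PR rank-2: {pr_low:.2f}")
```

Because of the rotation invariance, a drop in CKA between checkpoints indicates a change in representational geometry rather than a mere relabeling of neurons.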

### Applications

- **Curriculum learning**: Train on easy examples first
- **Early stopping**: Stop when representations stop improving
- **Architecture search**: Identify when capacity is saturated
- **Transfer learning**: Choose which checkpoint to transfer

In [None]:
class TrainingTrajectoryAnalyzer:
    """
    Analyze how representations evolve during training.
    """
    
    def __init__(self):
        self.checkpoints = []
        self.metrics_history = {
            'loss': [],
            'accuracy': [],
            'dimensionality': [],
            'similarity_to_init': []
        }
    
    def save_checkpoint(self, model, epoch, loss, accuracy):
        """
        Save a deep copy of the model weights and record loss and accuracy.
        """
        # Deep copy model state
        checkpoint = {
            'epoch': epoch,
            'state_dict': copy.deepcopy(model.state_dict()),
            'loss': loss,
            'accuracy': accuracy
        }
        self.checkpoints.append(checkpoint)
        
        # Record metrics
        self.metrics_history['loss'].append(loss)
        self.metrics_history['accuracy'].append(accuracy)
    
    def compute_representation_similarity(self, model1, model2, X):
        """
        Compute similarity between representations from two models.
        
        Uses Centered Kernel Alignment (CKA).
        """
        # Get representations
        with torch.no_grad():
            repr1 = self._get_representation(model1, X)
            repr2 = self._get_representation(model2, X)
        
        # Compute CKA
        cka = self._linear_cka(repr1, repr2)
        return cka
    
    def _get_representation(self, model, X):
        """Extract penultimate layer representation."""
        activation = None
        
        def hook_fn(module, input, output):
            nonlocal activation
            activation = output
        
        # Find penultimate layer
        layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
        if len(layers) >= 2:
            handle = layers[-2].register_forward_hook(hook_fn)
        else:
            handle = layers[-1].register_forward_hook(hook_fn)
        
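        # Run without autograd so the hooked activation can be converted to NumPy
        with torch.no_grad():
            _ = model(X)
        handle.remove()
        
        return activation.detach().cpu().numpy()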
    
    def _linear_cka(self, X, Y):
        """
        Compute Linear CKA similarity.
        
        CKA(X, Y) = ||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
        """
        # Center
        X = X - X.mean(axis=0)
        Y = Y - Y.mean(axis=0)
        
        # Feature-space (d x d) cross-products; equivalent to the Gram-matrix form of linear CKA
        XTX = X.T @ X
        YTY = Y.T @ Y
        XTY = X.T @ Y
        
        # CKA
        numerator = np.linalg.norm(XTY, 'fro')**2
        denominator = np.linalg.norm(XTX, 'fro') * np.linalg.norm(YTY, 'fro')
        
        if denominator == 0:
            return 0
        
        return numerator / denominator
    
    def compute_dimensionality(self, model, X):
        """
        Compute effective dimensionality (participation ratio).
        """
        reps = self._get_representation(model, X)  # avoid shadowing the built-in repr()
        
        # Covariance
        cov = np.cov(reps.T)
        eigvals = np.linalg.eigvalsh(cov)
        eigvals = np.maximum(eigvals, 0)
        
        # Participation ratio
        if np.sum(eigvals**2) == 0:
            return 0
        pr = np.sum(eigvals)**2 / np.sum(eigvals**2)
        
        return pr
    
    def analyze_trajectory(self, X):
        """
        Analyze full training trajectory.
        
        Computes:
        - Similarity to initialization
        - Dimensionality over time
        - Phase detection
        """
        if len(self.checkpoints) < 2:
            print("Need at least 2 checkpoints")
            return
        
        print("Analyzing training trajectory...")
        
        # Load first checkpoint (initialization)
        init_model = self._load_checkpoint(0)
        
        # Analyze each checkpoint
        for i, checkpoint in enumerate(self.checkpoints):
            model = self._load_checkpoint(i)
            
            # Similarity to init
            sim = self.compute_representation_similarity(init_model, model, X)
            self.metrics_history['similarity_to_init'].append(sim)
            
            # Dimensionality
            dim = self.compute_dimensionality(model, X)
            self.metrics_history['dimensionality'].append(dim)
        
        print("Analysis complete!")
    
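    def _load_checkpoint(self, idx):
        """Rebuild a model from the checkpoint at position idx."""
        # A state dict cannot recreate a model by itself, so the architecture
        # must be supplied: set `analyzer.model_factory` to a zero-argument
        # callable that returns a freshly constructed model before analysis.
        factory = getattr(self, 'model_factory', None)
        if factory is None:
            raise RuntimeError("Set analyzer.model_factory before calling analyze_trajectory")
        model = factory()
        model.load_state_dict(self.checkpoints[idx]['state_dict'])
        return model.eval()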

print("Training trajectory analyzer implemented!")

In [None]:
# Demonstrate with synthetic training trajectory
# Simulate how metrics evolve during training

def simulate_training_trajectory(n_epochs=100):
    """
    Simulate realistic training trajectory.
    """
    epochs = np.arange(n_epochs)
    
    # Loss: Exponential decay with noise
    loss = 2.0 * np.exp(-epochs / 20) + 0.5 + 0.1 * np.random.randn(n_epochs)
    loss = np.maximum(loss, 0.3)  # Floor
    
    # Accuracy: Sigmoid growth
    accuracy = 1.0 / (1 + np.exp(-(epochs - 30) / 10))
    accuracy = 0.5 + 0.4 * accuracy + 0.02 * np.random.randn(n_epochs)
    
    # Dimensionality: Two phases
    # Phase 1 (0-40): Increasing (fitting)
    # Phase 2 (40+): Decreasing (compression)
    dim_phase1 = 10 + 20 * (epochs / 40)
    dim_phase2 = 30 - 10 * ((epochs - 40) / 60)
    dimensionality = np.where(epochs < 40, dim_phase1, dim_phase2)
    dimensionality += np.random.randn(n_epochs) * 2
    dimensionality = np.maximum(dimensionality, 5)
    
    # Similarity to init: Decreasing (representations drift)
    similarity = np.exp(-epochs / 30) + 0.1 * np.random.randn(n_epochs)
    similarity = np.clip(similarity, 0, 1)
    
    return {
        'epochs': epochs,
        'loss': loss,
        'accuracy': accuracy,
        'dimensionality': dimensionality,
        'similarity_to_init': similarity
    }

# Generate trajectory
trajectory = simulate_training_trajectory(n_epochs=100)

print("Simulated training trajectory generated!")

In [None]:
# Visualize training trajectory
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

epochs = trajectory['epochs']

# Plot 1: Loss
ax = axes[0, 0]
ax.plot(epochs, trajectory['loss'], linewidth=2, color='red')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.grid(True, alpha=0.3)

# Plot 2: Accuracy
ax = axes[0, 1]
ax.plot(epochs, trajectory['accuracy'], linewidth=2, color='green')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Test Accuracy')
ax.grid(True, alpha=0.3)

# Plot 3: Dimensionality (shows two phases)
ax = axes[1, 0]
ax.plot(epochs, trajectory['dimensionality'], linewidth=2, color='blue')
ax.axvline(x=40, color='red', linestyle='--', linewidth=2, 
           label='Phase transition')
ax.annotate('Fitting\n(increasing dim)', xy=(20, 25), fontsize=10, ha='center')
ax.annotate('Compression\n(decreasing dim)', xy=(70, 25), fontsize=10, ha='center')
ax.set_xlabel('Epoch')
ax.set_ylabel('Effective Dimensionality')
ax.set_title('Representational Dimensionality (Two-Phase Learning)')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Similarity to initialization
ax = axes[1, 1]
ax.plot(epochs, trajectory['similarity_to_init'], linewidth=2, color='purple')
ax.set_xlabel('Epoch')
ax.set_ylabel('CKA Similarity')
ax.set_title('Representational Drift from Initialization')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nInterpretation (simulated data; every pattern below is built into the simulation):")
print("- Loss decreases, accuracy increases (expected)")
print("- Dimensionality: Fitting phase (epoch 0-40), then compression (40+)")
print("- Representations drift away from initialization")
print("- The transition at epoch 40 is hard-coded here; in real training it must be detected")
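
The transition in the simulated dimensionality curve is known in advance. On a measured curve it would have to be estimated. One simple option, sketched here as an assumption rather than as the method this notebook uses, is to fit two separate straight lines and pick the split that minimizes the total squared error. The function name `two_segment_breakpoint` and the `min_len` parameter are hypothetical.

```python
import numpy as np

def two_segment_breakpoint(y, min_len=5):
    """Index b minimizing the summed squared error of line fits to y[:b] and y[b:]."""
    y = np.asarray(y, dtype=float)
    x = np.arange(len(y), dtype=float)

    def line_sse(xs, ys):
        # Least-squares straight line, then its residual sum of squares
        slope, intercept = np.polyfit(xs, ys, 1)
        return float(np.sum((ys - (slope * xs + intercept)) ** 2))

    # Each segment keeps at least min_len points so both fits are well defined
    candidates = range(min_len, len(y) - min_len + 1)
    errors = [line_sse(x[:b], y[:b]) + line_sse(x[b:], y[b:]) for b in candidates]
    return candidates[int(np.argmin(errors))]

# Noise-free version of the simulated dimensionality curve (kink at epoch 40)
epochs = np.arange(100)
rise = 10 + 20 * (epochs / 40)
fall = 30 - 10 * ((epochs - 40) / 60)
curve = np.where(epochs < 40, rise, fall)

breakpoint_epoch = two_segment_breakpoint(curve)
print(f"Estimated transition epoch: {breakpoint_epoch}")
```

With noisy curves the estimate becomes less precise, so it is worth checking how stable the breakpoint is across random seeds before reading it as a phase transition.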

## Part 2: Manifold Geometry - Curvature and Shape

### What is Manifold Geometry?

Neural representations often lie on low-dimensional **manifolds** embedded in high-dimensional space.

**Geometric properties** reveal structure:
- **Curvature**: How "bent" is the manifold?
- **Geodesic distance**: Shortest path along manifold
- **Local dimension**: Dimensionality in neighborhoods
- **Tangent spaces**: Local linear approximations

### Why Geometry Matters

1. **Separability**: Flatter, more compact class manifolds are easier to separate with a linear readout
2. **Generalization**: Smoother manifolds tend to be associated with better generalization
3. **Capacity**: Geometry reveals network expressivity
4. **Alignment**: Compare geometries (brain vs model)

### Key Metrics

1. **Gaussian Curvature**: Intrinsic curvature (positive = sphere-like, negative = saddle)
2. **Geodesic Distance**: Path length along manifold
3. **Tangent Alignment**: How aligned are local tangent spaces?
4. **Manifold Capacity**: How many patterns can be separated?
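
Tangent alignment can be made concrete with local PCA and principal angles. The sketch below is one possible formulation under stated assumptions: the tangent space at a point is approximated by the top `d` principal directions of its `k` nearest neighbors, and misalignment is the largest principal angle between two such subspaces. The helper names `local_tangent_basis` and `tangent_misalignment` are hypothetical; `scipy.linalg.subspace_angles` computes the angles.

```python
import numpy as np
from scipy.linalg import subspace_angles

def local_tangent_basis(X, idx, k=15, d=2):
    """Orthonormal basis (n_dims x d) of the local PCA tangent estimate at X[idx]."""
    # k nearest points to X[idx] (the point itself is included at distance 0)
    order = np.argsort(np.linalg.norm(X - X[idx], axis=1))[:k]
    patch = X[order] - X[order].mean(axis=0)
    # Right singular vectors of the centered patch are its principal directions
    _, _, vt = np.linalg.svd(patch, full_matrices=False)
    return vt[:d].T

def tangent_misalignment(X, i, j, k=15, d=2):
    """Largest principal angle (radians) between tangent estimates at X[i] and X[j]."""
    angles = subspace_angles(local_tangent_basis(X, i, k, d),
                             local_tangent_basis(X, j, k, d))
    return float(np.max(angles))

rng = np.random.default_rng(0)

# Flat sheet in 3D: every tangent plane is the same plane
plane = np.column_stack([rng.uniform(-1, 1, (500, 2)), np.zeros(500)])
flat_angle = tangent_misalignment(plane, 0, 1)

# Unit sphere: tangent planes at the pole and on the equator are perpendicular
sphere = rng.standard_normal((2000, 3))
sphere /= np.linalg.norm(sphere, axis=1, keepdims=True)
sphere = np.vstack([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], sphere])
sphere_angle = tangent_misalignment(sphere, 0, 1)

print(f"Flat sheet: {np.degrees(flat_angle):.2f} deg, sphere: {np.degrees(sphere_angle):.1f} deg")
```

The choice of `k` trades bias against variance: too few neighbors gives a noisy basis, too many lets the patch bend away from the true tangent plane.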

### Applications

- **Architecture design**: Encourage smooth manifolds
- **Transfer learning**: Match manifold geometries
- **Neuroscience**: Compare to neural manifolds
- **Robustness**: Smoother manifolds are often associated with greater robustness

In [None]:
class ManifoldGeometryAnalyzer:
    """
    Analyze geometric properties of neural representation manifolds.
    """
    
    def __init__(self):
        pass
    
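    def estimate_curvature(self, X, k=10):
        """
        Estimate a local curvature proxy using PCA in k-nearest-neighbor patches.
        
        This is not Gaussian curvature: it is the ratio of the smallest to the
        largest local PCA eigenvalue, which grows when a patch bends out of its
        best-fit flat subspace (and also with noise). It needs k > n_dims;
        otherwise the smallest eigenvalue is always zero.
        """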
        n_samples, n_dims = X.shape
        
        # Find k nearest neighbors
        from sklearn.neighbors import NearestNeighbors
        nbrs = NearestNeighbors(n_neighbors=k+1).fit(X)
        distances, indices = nbrs.kneighbors(X)
        
        curvatures = []
        
        for i in range(n_samples):
            # Get neighborhood
            neighbor_indices = indices[i, 1:]  # Exclude self
            neighborhood = X[neighbor_indices]
            
            # Center
            centered = neighborhood - neighborhood.mean(axis=0)
            
            # PCA
            if len(centered) > 1:
                cov = centered.T @ centered / len(centered)
                eigvals = np.linalg.eigvalsh(cov)
                eigvals = np.maximum(eigvals, 0)
                
                # Curvature estimate: ratio of small to large eigenvalues
                if eigvals[-1] > 1e-10:
                    curvature = eigvals[0] / eigvals[-1]
                else:
                    curvature = 0
            else:
                curvature = 0
            
            curvatures.append(curvature)
        
        return np.array(curvatures)
    
    def compute_geodesic_distance(self, X, metric='euclidean', k=10):
        """
        Estimate geodesic distances with an Isomap-like approach:
        shortest paths through a k-nearest-neighbor graph.
        """
        from sklearn.neighbors import kneighbors_graph
        from scipy.sparse.csgraph import shortest_path
        
        # Sparse k-NN graph whose edge weights are pairwise distances
        graph = kneighbors_graph(X, n_neighbors=k, mode='distance', metric=metric)
        
        # Dijkstra on the symmetrized graph; disconnected pairs come back as inf
        return shortest_path(graph, method='D', directed=False)
    
    def compute_manifold_dimension(self, X, n_neighbors=10):
        """
        Estimate local intrinsic dimensionality using MLE.
        """
        from sklearn.neighbors import NearestNeighbors
        
        nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(X)
        distances, _ = nbrs.kneighbors(X)
        
        # Remove self (distance 0)
        distances = distances[:, 1:]
        
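        # Levina-Bickel MLE: inverse of the mean log-ratio between the k-th
        # neighbor distance and each of the k-1 closer neighbor distances
        r_k = distances[:, -1:]
        closer = np.maximum(distances[:, :-1], 1e-10)
        mean_log_ratio = np.log(r_k / closer).mean(axis=1)
        
        with np.errstate(divide='ignore'):
            dims = 1.0 / mean_log_ratio
        dims = dims[np.isfinite(dims)]
        
        # Aggregate the per-point estimates with the median for robustness
        return np.median(dims) if len(dims) > 0 else 0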
    
    def analyze_manifold(self, X):
        """
        Complete manifold analysis.
        """
        print("Analyzing manifold geometry...")
        
        # Curvature
        curvatures = self.estimate_curvature(X, k=10)
        mean_curvature = np.mean(curvatures)
        std_curvature = np.std(curvatures)
        
        # Intrinsic dimension
        intrinsic_dim = self.compute_manifold_dimension(X, n_neighbors=10)
        
        results = {
            'mean_curvature': mean_curvature,
            'std_curvature': std_curvature,
            'curvatures': curvatures,
            'intrinsic_dimension': intrinsic_dim,
            'ambient_dimension': X.shape[1]
        }
        
        print(f"  Mean curvature: {mean_curvature:.4f}")
        print(f"  Intrinsic dimension: {intrinsic_dim:.2f}")
        print(f"  Ambient dimension: {X.shape[1]}")
        
        return results

print("Manifold geometry analyzer implemented!")

In [None]:
# Generate synthetic manifold data
def generate_swiss_roll(n_samples=1000, noise=0.1):
    """
    Generate Swiss roll: classic 2D manifold in 3D space.
    """
    t = 1.5 * np.pi * (1 + 2 * np.random.rand(n_samples))
    x = t * np.cos(t)
    y = 20 * np.random.rand(n_samples)
    z = t * np.sin(t)
    
    X = np.vstack([x, y, z]).T
    X += noise * np.random.randn(n_samples, 3)
    
    return X, t  # Return both data and parameter

# Generate data
X_manifold, t_param = generate_swiss_roll(n_samples=500, noise=0.1)

print(f"Generated Swiss roll manifold:")
print(f"  Samples: {X_manifold.shape[0]}")
print(f"  Ambient dimension: {X_manifold.shape[1]} (3D)")
print(f"  True intrinsic dimension: 2")

In [None]:
# Analyze manifold
analyzer = ManifoldGeometryAnalyzer()
manifold_results = analyzer.analyze_manifold(X_manifold)

print("\nManifold Analysis Complete!")

In [None]:
# Visualize manifold
fig = plt.figure(figsize=(15, 5))

# Plot 1: 3D scatter (ambient space)
ax1 = fig.add_subplot(131, projection='3d')
scatter = ax1.scatter(X_manifold[:, 0], X_manifold[:, 1], X_manifold[:, 2],
                     c=t_param, cmap='viridis', s=10, alpha=0.6)
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
ax1.set_zlabel('Z')
ax1.set_title('Swiss Roll in 3D (Ambient Space)')
plt.colorbar(scatter, ax=ax1, label='Parameter t')

# Plot 2: Curvature distribution
ax2 = fig.add_subplot(132)
ax2.hist(manifold_results['curvatures'], bins=30, alpha=0.7, color='steelblue')
ax2.axvline(x=manifold_results['mean_curvature'], color='red', 
           linestyle='--', linewidth=2, label=f'Mean: {manifold_results["mean_curvature"]:.3f}')
ax2.set_xlabel('Local Curvature')
ax2.set_ylabel('Count')
ax2.set_title('Curvature Distribution')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Curvature on manifold
ax3 = fig.add_subplot(133, projection='3d')
scatter2 = ax3.scatter(X_manifold[:, 0], X_manifold[:, 1], X_manifold[:, 2],
                      c=manifold_results['curvatures'], cmap='RdYlBu_r',
                      s=20, alpha=0.7)
ax3.set_xlabel('X')
ax3.set_ylabel('Y')
ax3.set_zlabel('Z')
ax3.set_title('Curvature on Manifold\n(Red=high, Blue=low)')
plt.colorbar(scatter2, ax=ax3, label='Curvature')

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- Swiss roll is 2D manifold embedded in 3D")
print(f"- Estimated intrinsic dimension: {manifold_results['intrinsic_dimension']:.1f} (should be ≈ 2)")
print("- Curvature varies across manifold (higher at tighter curves)")
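
Topological analysis is listed among this notebook's topics, and its simplest case fits in a few lines. For 0-dimensional homology (connected components), the scales at which components merge in a Vietoris-Rips filtration are exactly the edge lengths of a minimum spanning tree, so no specialized library is needed. The sketch below covers only that case; loops and voids (higher Betti numbers) require a dedicated persistent homology package and are not shown. The function names `h0_death_scales` and `betti0` are hypothetical.

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import pdist, squareform

def h0_death_scales(points):
    """Sorted scales at which connected components merge (MST edge lengths)."""
    dists = squareform(pdist(points))
    mst = minimum_spanning_tree(dists)   # sparse matrix holding n - 1 edges
    return np.sort(mst.data)

def betti0(death_scales, eps):
    """Number of connected components when points within distance eps are linked."""
    return int(np.sum(death_scales > eps)) + 1

rng = np.random.default_rng(0)

# Two tight, well-separated clusters in 3D
cluster_a = rng.normal(loc=0.0, scale=0.1, size=(40, 3))
cluster_b = rng.normal(loc=5.0, scale=0.1, size=(40, 3))
cloud = np.vstack([cluster_a, cluster_b])

deaths = h0_death_scales(cloud)
for eps in (0.0, 1.0, 20.0):
    print(f"eps={eps:5.1f}  Betti-0={betti0(deaths, eps)}")
```

The one merge scale that is much larger than the rest is the long-lived feature: it marks the gap between the two clusters, while the short-lived components reflect sampling noise.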

## Part 3: Counterfactual Interventions

### What are Counterfactuals?

**Counterfactual reasoning**: "What would happen if...?"

For neural networks:
- **Question**: What if neuron X had a different value?
- **Method**: Surgically modify activation, observe downstream effects
- **Goal**: Understand causal role of each component

### Types of Interventions

1. **Latent Surgery**: Directly modify hidden representations
2. **Do-Calculus**: Pearl's causal intervention framework
3. **Synthetic Lesions**: Ablate neurons or connections
4. **Causal Tracing**: Track information flow through network

### Why Counterfactuals Matter

1. **Causality**: Correlation ≠ causation, interventions reveal true causes
2. **Interpretability**: Understand functional role of components
3. **Debugging**: Identify failure modes
4. **Control**: Steer network behavior

### Pearl's Causal Hierarchy

1. **Association**: P(Y|X) - Seeing/observing
2. **Intervention**: P(Y|do(X)) - Doing/intervening
3. **Counterfactual**: P(Y_x | X = x', Y = y') - Imagining/reasoning

Neural networks are typically studied at level 1, but levels 2-3 support much stronger causal claims!
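
The gap between levels 1 and 2 can be seen in a toy structural causal model. The sketch below is illustrative only: the variables `z`, `x`, `y` and all coefficients are made up for the example. A confounder `z` drives both `x` and `y`, so the observational slope of `y` on `x` overstates the true effect, while setting `x` by fiat (the `do` operation) recovers it. In a network, overwriting an activation in a forward hook plays the role of `do(X)`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20000
true_effect = 1.0                       # causal effect of x on y (assumed for the demo)

# Observational regime: z confounds x and y
z = rng.standard_normal(n)
x_obs = z + 0.5 * rng.standard_normal(n)
y_obs = true_effect * x_obs + 2.0 * z + 0.5 * rng.standard_normal(n)
slope_obs = np.polyfit(x_obs, y_obs, 1)[0]   # association: biased by z

# Interventional regime: do(x) cuts the z -> x arrow
x_do = rng.standard_normal(n)
y_do = true_effect * x_do + 2.0 * z + 0.5 * rng.standard_normal(n)
slope_do = np.polyfit(x_do, y_do, 1)[0]      # intervention: close to true_effect

print(f"Observational slope:  {slope_obs:.2f}")
print(f"Interventional slope: {slope_do:.2f}")
```

Here the observational slope lands well above the true effect because `x` carries information about `z`, which is the reason correlational analyses of activations can mislead.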

In [None]:
class CounterfactualAnalyzer:
    """
    Perform counterfactual interventions on neural networks.
    """
    
    def __init__(self, model):
        self.model = model
    
    def latent_surgery(self, x, layer_name, neuron_idx, new_value):
        """
        Perform latent surgery: modify specific neuron activation.
        
        Args:
            x: Input
            layer_name: Which layer to modify
            neuron_idx: Which neuron to modify
            new_value: New activation value
        
        Returns:
            output_original: Output without intervention
            output_intervened: Output with intervention
        """
        # Get original output
        with torch.no_grad():
            output_original = self.model(x)
        
        # Intervention hook
        def intervention_hook(module, input, output):
            # Modify specific neuron on a copy so the layer's own output tensor is untouched
            output = output.clone()
            output[:, neuron_idx] = new_value
            return output
        
        # Find layer and register hook
        target_layer = None
        for name, module in self.model.named_modules():
            if name == layer_name:
                target_layer = module
                break
        
        if target_layer is None:
            raise ValueError(f"Layer {layer_name} not found")
        
        handle = target_layer.register_forward_hook(intervention_hook)
        
        # Get intervened output; always remove the hook, even if the forward pass fails
        try:
            with torch.no_grad():
                output_intervened = self.model(x)
        finally:
            handle.remove()
        
        return output_original, output_intervened
    
    def compute_causal_effect(self, x, layer_name, neuron_idx, intervention_values):
        """
        Compute causal effect of neuron by sweeping intervention values.
        
        Returns:
            Causal effect curve
        """
        effects = []
        
        for value in intervention_values:
            orig, interv = self.latent_surgery(x, layer_name, neuron_idx, value)
            
            # Measure effect (output difference)
            effect = (interv - orig).abs().mean().item()
            effects.append(effect)
        
        return np.array(effects)
    
    def synthetic_lesion(self, x, layer_name, neuron_indices):
        """
        Ablate (zero out) specific neurons.
        
        Returns:
            output_original, output_lesioned
        """
        # Original
        with torch.no_grad():
            output_original = self.model(x)
        
        # Lesion hook
        def lesion_hook(module, input, output):
            # Zero the selected neurons on a copy of the layer output
            output = output.clone()
            output[:, neuron_indices] = 0
            return output
        
        # Find and hook layer
        target_layer = dict(self.model.named_modules())[layer_name]
        handle = target_layer.register_forward_hook(lesion_hook)
        
        # Lesioned output; always remove the hook, even if the forward pass fails
        try:
            with torch.no_grad():
                output_lesioned = self.model(x)
        finally:
            handle.remove()
        
        return output_original, output_lesioned
    
    def identify_critical_neurons(self, x, layer_name, n_neurons=None):
        """
        Identify which neurons are most critical (largest effect when ablated).
        """
        # Get layer size
        layer = dict(self.model.named_modules())[layer_name]
        
        # Test each neuron
        effects = []
        
        if n_neurons is None:
            # Try to infer from layer
            if hasattr(layer, 'out_features'):
                n_neurons = layer.out_features
            else:
                n_neurons = 10  # Default
        
        for neuron_idx in range(min(n_neurons, 20)):  # Limit to 20 for speed
            orig, lesioned = self.synthetic_lesion(x, layer_name, [neuron_idx])
            effect = (orig - lesioned).abs().mean().item()
            effects.append(effect)
        
        return np.array(effects)

print("Counterfactual analyzer implemented!")

In [None]:
# Create simple test model
class TestNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(10, 20)
        self.hidden2 = nn.Linear(20, 10)
        self.output = nn.Linear(10, 2)
    
    def forward(self, x):
        x = torch.tanh(self.hidden1(x))
        x = torch.tanh(self.hidden2(x))
        x = self.output(x)
        return x

# Create model and analyzer
test_model = TestNetwork()
cf_analyzer = CounterfactualAnalyzer(test_model)

# Test input
x_test = torch.randn(5, 10)

print("Test model created")
print(f"  Input shape: {x_test.shape}")

In [None]:
# Test 1: Latent surgery
layer_name = 'hidden2'
neuron_idx = 5
intervention_values = np.linspace(-2, 2, 20)

causal_effects = cf_analyzer.compute_causal_effect(
    x_test, layer_name, neuron_idx, intervention_values
)

print(f"\nLatent surgery on {layer_name}, neuron {neuron_idx}:")
print(f"  Causal effect range: {causal_effects.min():.4f} to {causal_effects.max():.4f}")

In [None]:
# Test 2: Identify critical neurons
critical_effects = cf_analyzer.identify_critical_neurons(x_test, 'hidden2', n_neurons=10)

print(f"\nCritical neuron analysis:")
print(f"  Tested {len(critical_effects)} neurons")
print(f"  Most critical: neuron {np.argmax(critical_effects)} (effect={critical_effects.max():.4f})")
print(f"  Least critical: neuron {np.argmin(critical_effects)} (effect={critical_effects.min():.4f})")

In [None]:
# Visualize counterfactual analysis
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Causal effect vs intervention value
ax = axes[0]
ax.plot(intervention_values, causal_effects, 'o-', linewidth=2, markersize=6)
ax.set_xlabel('Intervention Value (pre-tanh activation)')
ax.set_ylabel('Causal Effect (output change)')
ax.set_title(f'Causal Effect of Neuron {neuron_idx} in {layer_name}')
ax.grid(True, alpha=0.3)

# Plot 2: Critical neurons (lesion effects)
ax = axes[1]
neuron_indices = np.arange(len(critical_effects))
colors = ['red' if e == critical_effects.max() else 'steelblue' 
         for e in critical_effects]
ax.bar(neuron_indices, critical_effects, color=colors, alpha=0.7)
ax.set_xlabel('Neuron Index')
ax.set_ylabel('Lesion Effect (output change when ablated)')
ax.set_title('Critical Neuron Analysis (Red = most critical)')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- Causal effect curve shows how output depends on neuron activation")
print("- Critical neurons have large effect when ablated")
print("- Neurons with small lesion effects matter less for these inputs or are redundant")
print("- The test model is untrained, so these effects reflect random weights only")
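
The lesions above replace activations with zeros, which can push the network off its usual activation distribution. Activation patching, a common way to implement the causal tracing listed earlier, instead copies activations from a run on one input into a run on another. The sketch below is a minimal version under assumptions: the small `nn.Sequential` model and the helper names `run_with_cache` and `run_patched` are illustrative and not part of `CounterfactualAnalyzer`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.Tanh(), nn.Linear(20, 2)).eval()
layer = model[1]   # patch the post-tanh hidden activations

def run_with_cache(model, layer, x):
    """Forward pass that also returns a copy of the layer's output."""
    cache = {}
    # The hook returns None, so it records the activation without changing it
    handle = layer.register_forward_hook(
        lambda module, inputs, output: cache.update(act=output.detach().clone()))
    try:
        with torch.no_grad():
            out = model(x)
    finally:
        handle.remove()
    return out, cache['act']

def run_patched(model, layer, x, source_act, units):
    """Forward pass on x with the chosen units overwritten by source_act."""
    def hook(module, inputs, output):
        patched = output.clone()
        patched[:, units] = source_act[:, units]
        return patched   # a returned tensor replaces the layer output
    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            return model(x)
    finally:
        handle.remove()

x_clean = torch.randn(4, 10)
x_corrupt = torch.randn(4, 10)

out_clean, act_clean = run_with_cache(model, layer, x_clean)
out_corrupt, _ = run_with_cache(model, layer, x_corrupt)

# Patching every unit must reproduce the clean output exactly
out_full = run_patched(model, layer, x_corrupt, act_clean, list(range(20)))

# Patching one unit shows how much of the clean behavior that unit carries
out_one = run_patched(model, layer, x_corrupt, act_clean, [3])
moved = ((out_one - out_corrupt).abs().mean() /
         (out_clean - out_corrupt).abs().mean()).item()
print(f"Output shift from patching unit 3, relative to the clean/corrupt gap: {moved:.3f}")
```

Patching the full layer is a useful sanity check: if it does not restore the clean output, the hook is attached at the wrong place.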

## Summary: The Complete Interpretability Toolkit

### Journey Through 10 Notebooks

Congratulations! You've mastered a comprehensive toolkit for mechanistic interpretability:

1. **Introduction & Quickstart**: Foundation concepts, first SAE, first intervention
2. **Sparse Autoencoders**: Decomposing polysemantic neurons into features
3. **Causal Interventions**: Activation patching, ablations, circuit discovery
4. **Fractal Analysis**: Biological realism through scale-free dynamics
5. **Brain Alignment**: CCA, RSA, PLS for model-to-brain comparison
6. **Dynamical Systems**: DMD, Lyapunov exponents, fixed points, dimensionality
7. **Circuit Extraction**: Latent RNN models, DUNL, feature visualization
8. **Biophysical Modeling**: Spiking networks, surrogate gradients, Dale's law
9. **Information Theory**: Mutual information, information plane, MINE
10. **Advanced Topics**: Meta-dynamics, manifold geometry, counterfactuals (this notebook!)

### Key Principles

1. **Sparsity enables interpretability**: SAEs decompose representations
2. **Causality reveals mechanism**: Interventions, not correlations
3. **Geometry reveals structure**: Manifolds, curvature, topology
4. **Dynamics reveal computation**: How networks evolve, not just final state
5. **Biological constraints help**: Dale's law, fractals, biophysical models
6. **Information theory quantifies**: MI measures what networks know

### Building Your Research Pipeline

**For any project, combine**:
1. **Feature extraction** (SAEs, circuits)
2. **Causal analysis** (interventions, counterfactuals)
3. **Geometric analysis** (manifolds, dimensionality)
4. **Dynamical analysis** (training trajectories, fixed points)
5. **Alignment** (brain recordings, behavior)

### Next Steps

1. **Apply to your models**: Use this toolkit on your research problems
2. **Combine methods**: SAEs + interventions, fractals + alignment, etc.
3. **Extend the toolkit**: Implement new methods from latest papers
4. **Share discoveries**: Contribute to mechanistic interpretability community
5. **Build standards**: Help make this a standard neuroscience toolkit

### Vision: Community Resource

The goal is for **neuros-mechint** to become:
- Standard toolkit for neuroscience experiments worldwide
- Community-driven resource with contributions from many researchers
- Bridge between AI and neuroscience
- Foundation for understanding intelligence

### Further Reading

**Foundational papers**:
- Olah et al. (2020): *Zoom In: An Introduction to Circuits*
- Elhage et al. (2022): *Toy Models of Superposition*
- Marks & Tegmark (2024): *The Geometry of Truth*

**Mechanistic interpretability resources**:
- Anthropic's interpretability research
- Distill.pub articles
- NeuroAI reading group

**Computational neuroscience**:
- Dayan & Abbott: *Theoretical Neuroscience*
- Gerstner et al.: *Neuronal Dynamics*

---

## Thank you for completing this journey!

You now have the tools to:
- Understand how neural networks compute
- Extract interpretable circuits and features
- Compare models to brains
- Design more interpretable architectures
- Contribute to understanding intelligence

**The future of interpretability is bright, and you're part of it!**