# Notebook 07: Circuit Extraction and Latent Models

## Extracting Minimal Computational Circuits from Neural Networks

This notebook explores how to **extract interpretable, minimal computational circuits** from complex neural networks. Instead of analyzing all neurons, we identify the essential computations and build simplified models that capture the core algorithmic behavior.

### Why Circuit Extraction Matters

1. **Interpretability**: Small circuits are easier to understand than full networks
2. **Generalization**: Core circuits reveal what the network actually learned
3. **Debugging**: Identify specific computational failures
4. **Transfer**: Extract and reuse learned algorithms
5. **Mechanistic Understanding**: Move from "what" to "how" networks compute

### What You'll Learn

1. **Latent Circuit Models**: Extract minimal RNN-like circuits from large networks
2. **DUNL (Disentangled and Unified Networks through Latent)**: Decompose mixed selectivity into factors
3. **Feature Visualization**: Find optimal stimuli for neurons and circuits
4. **Activation Maximization**: Generate inputs that maximally activate specific features
5. **Circuit Motifs**: Identify recurring computational patterns
6. **Recurrent Dynamics Analysis**: Understand temporal processing in circuits

### References

- Langdon & Engel (2025): *Latent circuit inference from heterogeneous neural responses during cognitive tasks*
- Sussillo & Barak (2013): *Opening the black box: Low-dimensional dynamics in high-dimensional RNNs*
- Olah et al. (2018): *The building blocks of interpretability* (Distill)
- Gu et al. (2021): *Disentangling and unifying neural representations*
- Rigotti et al. (2013): *The importance of mixed selectivity in complex cognitive tasks*

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from sklearn.decomposition import PCA, NMF
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

print("All libraries imported successfully!")
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

## Part 1: Latent Circuit Models

### The Latent Circuit Framework

**Key Insight**: Large neural networks often implement simple algorithms using only a small subset of their capacity.

**Latent Circuit Model** extracts a minimal RNN that:
1. Has much lower dimensionality than the original network
2. Captures essential computational structure
3. Generalizes to new inputs
4. Is interpretable

**The extraction process**:
```
High-dimensional network → Low-dimensional latent circuit

z(t+1) = f(W_rec @ z(t) + W_in @ u(t) + b)
y(t)   = W_out @ z(t)
```

where:
- z: Low-dimensional latent state (e.g., 5-10 dimensions)
- u: Input
- y: Output (readout)
- W_rec: Recurrent weights (captures dynamics)
- W_in: Input weights
- W_out: Output (readout) weights
- b: Bias
- f: Nonlinearity (tanh, ReLU, etc.)
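A minimal NumPy sketch of the update equation with f = tanh, the same update used by `LatentCircuitModel` below. The weights here are arbitrary illustrative values (not fitted ones), and the helper name `simulate_latent_circuit` is hypothetical.

```python
import numpy as np

def simulate_latent_circuit(W_rec, W_in, b, inputs, z0=None):
    """Iterate z(t+1) = tanh(W_rec @ z(t) + W_in @ u(t) + b) over an input sequence.

    inputs: array of shape (T, input_dim); returns latent states of shape (T, latent_dim).
    """
    latent_dim = W_rec.shape[0]
    z = np.zeros(latent_dim) if z0 is None else np.asarray(z0, dtype=float)
    states = []
    for u in inputs:
        z = np.tanh(W_rec @ z + W_in @ u + b)
        states.append(z)
    return np.stack(states)

rng = np.random.default_rng(0)
latent_dim, input_dim, T = 3, 2, 20
W_rec = 0.9 * np.eye(latent_dim)                 # illustrative: a leaky memory in each latent unit
W_in = rng.normal(scale=0.5, size=(latent_dim, input_dim))
b = np.zeros(latent_dim)
u = rng.normal(size=(T, input_dim))

Z = simulate_latent_circuit(W_rec, W_in, b, u)
print(Z.shape)  # (20, 3)
```

Because tanh is bounded, every latent state stays in [-1, 1] regardless of the weights.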

### Why This Works

- **Intrinsic dimensionality**: Neural dynamics often live on low-dimensional manifolds
- **Task constraints**: Many tasks only require simple computations
- **Regularization**: Weight decay and the implicit biases of gradient-based training tend to favor simple, low-rank solutions
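One common way to put a number on "low-dimensional" is the participation ratio of the covariance eigenvalues, PR = (Σλ)² / Σλ², together with the number of principal components needed to explain most of the variance. A sketch on synthetic data whose true rank is 3 by construction (an assumption: noiseless, linear embedding; real recordings would need noise handling):

```python
import numpy as np

def participation_ratio(X):
    """(sum of covariance eigenvalues)^2 / (sum of squared eigenvalues); between 1 and n_features."""
    eig = np.linalg.eigvalsh(np.cov(X, rowvar=False))
    eig = np.clip(eig, 0.0, None)   # drop tiny negative round-off
    return eig.sum() ** 2 / np.sum(eig ** 2)

rng = np.random.default_rng(0)
n_samples, n_latent, n_neurons = 1000, 3, 30
Z = rng.normal(size=(n_samples, n_latent))   # 3 hidden latent signals
A = rng.normal(size=(n_latent, n_neurons))   # random embedding into 30 "neurons"
X = Z @ A                                    # rank-3 population activity

pr = participation_ratio(X)
var = np.sort(np.linalg.eigvalsh(np.cov(X, rowvar=False)))[::-1]
k95 = int(np.searchsorted(np.cumsum(var) / var.sum(), 0.95) + 1)
print(f"participation ratio = {pr:.2f}, components for 95% variance = {k95}")
```

The participation ratio can never exceed the rank of the covariance, so here it is at most 3 even though there are 30 "neurons".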

### Applications

1. **RNN compression**: Extract minimal models from LSTMs/GRUs
2. **Cognitive modeling**: Build mechanistic models of behavior
3. **Transfer learning**: Extract and reuse learned circuits
4. **Debugging**: Identify where computations fail

In [None]:
class LatentCircuitModel(nn.Module):
    """
    Low-dimensional latent circuit extracted from a larger network.
    
    The circuit is a minimal RNN:
        z(t+1) = tanh(W_rec @ z(t) + W_in @ u(t) + b)
        y(t) = W_out @ z(t)
    """
    
    def __init__(self, input_dim, latent_dim, output_dim):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Recurrent weights (the "circuit")
        self.W_rec = nn.Parameter(torch.randn(latent_dim, latent_dim) * 0.5 / np.sqrt(latent_dim))
        
        # Input projection
        self.W_in = nn.Parameter(torch.randn(latent_dim, input_dim) * 0.5 / np.sqrt(input_dim))
        
        # Output projection
        self.W_out = nn.Parameter(torch.randn(output_dim, latent_dim) * 0.5 / np.sqrt(latent_dim))
        
        # Bias
        self.bias = nn.Parameter(torch.zeros(latent_dim))
    
    def forward(self, inputs, z0=None):
        """
        Args:
            inputs: (batch, time, input_dim)
            z0: Initial latent state (batch, latent_dim)
        
        Returns:
            outputs: (batch, time, output_dim)
            latent_states: (batch, time, latent_dim)
        """
        batch, time, _ = inputs.shape
        device = inputs.device
        
        # Initialize latent state
        if z0 is None:
            z = torch.zeros(batch, self.latent_dim, device=device)
        else:
            z = z0
        
        # Store trajectories
        latent_traj = []
        output_traj = []
        
        # Run dynamics
        for t in range(time):
            # z(t+1) = tanh(W_rec @ z(t) + W_in @ u(t) + b)
            u_t = inputs[:, t, :]
            z = torch.tanh(z @ self.W_rec.T + u_t @ self.W_in.T + self.bias)
            
            # Output: y(t) = W_out @ z(t)
            y = z @ self.W_out.T
            
            latent_traj.append(z)
            output_traj.append(y)
        
        # Stack into tensors
        latent_states = torch.stack(latent_traj, dim=1)
        outputs = torch.stack(output_traj, dim=1)
        
        return outputs, latent_states
    
    def get_fixed_points(self, input_val=None, n_inits=10):
        """
        Find fixed points of the circuit.
        
        Fixed points satisfy: z* = tanh(W_rec @ z* + W_in @ u + b)
        """
        from scipy.optimize import fsolve
        
        W_rec_np = self.W_rec.detach().cpu().numpy()
        W_in_np = self.W_in.detach().cpu().numpy()
        bias_np = self.bias.detach().cpu().numpy()
        
        # Input contribution
        if input_val is None:
            input_contrib = np.zeros(self.latent_dim)
        else:
            input_contrib = W_in_np @ input_val
        
        # Define fixed point equation
        def fp_equation(z):
            return np.tanh(W_rec_np @ z + input_contrib + bias_np) - z
        
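        # Try multiple random initializations; keep converged, distinct solutions
        fixed_points = []
        for _ in range(n_inits):
            z0 = np.random.randn(self.latent_dim) * 0.5
            z_star, _info, ier, _msg = fsolve(fp_equation, z0, full_output=True)
            residual = np.linalg.norm(fp_equation(z_star))
            if ier == 1 and residual < 1e-6:
                # Skip duplicates of fixed points already found
                if not any(np.allclose(z_star, fp, atol=1e-4) for fp in fixed_points):
                    fixed_points.append(z_star)
        
        return fixed_points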

print("Latent circuit model implemented!")

In [None]:
class CircuitFitter:
    """
    Fit a latent circuit model to high-dimensional neural data.
    
    Two-step process:
    1. Dimensionality reduction: Find low-dimensional latent space
    2. Dynamics fitting: Learn recurrent weights in latent space
    """
    
    def __init__(self, latent_dim=10, learning_rate=1e-3):
        self.latent_dim = latent_dim
        self.learning_rate = learning_rate
        self.circuit = None
        self.encoder = None  # Projects high-dim → latent
        self.decoder = None  # Projects latent → high-dim
    
    def fit(self, inputs, neural_data, n_epochs=100, verbose=True):
        """
        Fit latent circuit to neural recordings.
        
        Args:
            inputs: Task inputs (batch, time, input_dim)
            neural_data: Neural recordings (batch, time, neural_dim)
            n_epochs: Training epochs
        """
        batch, time, input_dim = inputs.shape
        _, _, neural_dim = neural_data.shape
        
        # Convert to torch
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.FloatTensor(inputs)
        if not isinstance(neural_data, torch.Tensor):
            neural_data = torch.FloatTensor(neural_data)
        
        # Step 1: Initialize with PCA
        neural_flat = neural_data.reshape(-1, neural_dim).numpy()
        pca = PCA(n_components=self.latent_dim)
        pca.fit(neural_flat)
        
        # Step 2: Create circuit model
        self.circuit = LatentCircuitModel(
            input_dim=input_dim,
            latent_dim=self.latent_dim,
            output_dim=neural_dim
        )
        
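        # Initialize encoder/decoder as the PCA projection. PCA centers the data,
        # so the biases carry the mean: encoder(x) = W (x - mean), decoder(z) = W^T z + mean
        self.encoder = nn.Linear(neural_dim, self.latent_dim)
        self.decoder = nn.Linear(self.latent_dim, neural_dim)
        components = torch.FloatTensor(pca.components_)  # (latent_dim, neural_dim)
        pca_mean = torch.FloatTensor(pca.mean_)          # (neural_dim,)
        with torch.no_grad():
            self.encoder.weight.copy_(components)
            self.encoder.bias.copy_(-components @ pca_mean)
            self.decoder.weight.copy_(components.T)
            self.decoder.bias.copy_(pca_mean)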
        
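        # Step 3: Optimize circuit and decoder to match neural data.
        # The encoder stays fixed at its PCA initialization: if it were trained,
        # the dynamics loss could shrink by moving the target toward the circuit's
        # own predictions rather than by improving the circuit.
        self.encoder.requires_grad_(False)
        optimizer = Adam(
            list(self.circuit.parameters()) + 
            list(self.decoder.parameters()),
            lr=self.learning_rate
        )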
        
        losses = []
        for epoch in range(n_epochs):
            optimizer.zero_grad()
            
            # Encode neural data to latent
            z_target = self.encoder(neural_data.reshape(-1, neural_dim))
            z_target = z_target.reshape(batch, time, self.latent_dim)
            
            # Run circuit (its W_out readout is unused here; the decoder maps latents to neurons)
            _, z_pred = self.circuit(inputs)
            
            # Reconstruction loss: match neural data
            neural_pred = self.decoder(z_pred.reshape(-1, self.latent_dim))
            neural_pred = neural_pred.reshape(batch, time, neural_dim)
            recon_loss = F.mse_loss(neural_pred, neural_data)
            
            # Dynamics loss: match latent dynamics
            dynamics_loss = F.mse_loss(z_pred, z_target)
            
            # Total loss
            loss = recon_loss + 0.5 * dynamics_loss
            
            loss.backward()
            optimizer.step()
            
            losses.append(loss.item())
            
            if verbose and (epoch + 1) % 20 == 0:
                print(f"Epoch {epoch+1}/{n_epochs}, Loss: {loss.item():.6f}")
        
        return losses
    
    def get_circuit_weights(self):
        """Extract circuit connectivity matrix."""
        return self.circuit.W_rec.detach().cpu().numpy()

print("Circuit fitter implemented!")
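PCA centers the data before projecting, so a linear encoder reproduces `pca.transform` only if its bias is `-components_ @ mean_`, and the matching decoder reproduces `pca.inverse_transform` only with bias `mean_` (this holds for scikit-learn's default `whiten=False`). A small check of that correspondence with `nn.Linear` on synthetic data; the data and sizes are arbitrary assumptions:

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
n_neurons, k = 30, 5
# Synthetic data with a non-zero mean, so the bias terms actually matter
X = rng.normal(size=(500, n_neurons)) @ rng.normal(size=(n_neurons, n_neurons)) + 3.0
pca = PCA(n_components=k).fit(X)

W = torch.tensor(pca.components_, dtype=torch.float32)  # (k, n_neurons)
mu = torch.tensor(pca.mean_, dtype=torch.float32)       # (n_neurons,)

encoder = nn.Linear(n_neurons, k)
decoder = nn.Linear(k, n_neurons)
with torch.no_grad():
    encoder.weight.copy_(W)
    encoder.bias.copy_(-W @ mu)   # encoder(x) = W (x - mu)
    decoder.weight.copy_(W.T)     # decoder(z) = W^T z + mu
    decoder.bias.copy_(mu)

X_t = torch.tensor(X, dtype=torch.float32)
with torch.no_grad():
    Z = encoder(X_t).numpy()
    X_hat = decoder(torch.tensor(Z)).numpy()

print(np.abs(Z - pca.transform(X)).max())  # small float32 round-off only
```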

In [None]:
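# Generate synthetic data from a randomly initialized (untrained) RNN driven by
# noise inputs. The RNN is not trained to integrate, so these trajectories stand
# in for generic recurrent dynamics rather than a solved integration task.

def generate_integration_data(n_trials=100, seq_length=50, input_dim=5, hidden_dim=30):
    """
    Generate hidden-state trajectories from a random, untrained tanh RNN.
    """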
    # Create RNN
    rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
    rnn.eval()
    
    # Generate inputs
    inputs = torch.randn(n_trials, seq_length, input_dim) * 0.5
    
    # Run RNN
    with torch.no_grad():
        outputs, _ = rnn(inputs)
    
    return inputs, outputs

# Generate data
inputs, neural_data = generate_integration_data(
    n_trials=100,
    seq_length=50,
    input_dim=5,
    hidden_dim=30
)

print("Generated data:")
print(f"  Inputs: {inputs.shape} (trials, time, input_dim)")
print(f"  Neural recordings: {neural_data.shape} (trials, time, neurons)")

In [None]:
# Fit latent circuit
fitter = CircuitFitter(latent_dim=5, learning_rate=1e-3)
losses = fitter.fit(inputs, neural_data, n_epochs=100, verbose=True)

# Extract circuit
W_rec = fitter.get_circuit_weights()

print(f"\nExtracted circuit: {W_rec.shape[0]} latent dimensions")
print(f"Original network: {neural_data.shape[2]} neurons")
print(f"Compression ratio: {neural_data.shape[2] / W_rec.shape[0]:.1f}x")

In [None]:
# Visualize circuit extraction results
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Training loss
ax = axes[0, 0]
ax.plot(losses, linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Circuit Fitting: Training Loss')
ax.grid(True, alpha=0.3)

# Plot 2: Circuit connectivity
ax = axes[0, 1]
im = ax.imshow(W_rec, cmap='RdBu_r', aspect='auto', vmin=-1, vmax=1)
ax.set_xlabel('Latent Unit (from)')
ax.set_ylabel('Latent Unit (to)')
ax.set_title('Extracted Circuit: Recurrent Weights')
plt.colorbar(im, ax=ax, label='Weight')

# Plot 3: Eigenvalue spectrum of W_rec (linearized stability near z = 0)
ax = axes[1, 0]
eigvals = np.linalg.eigvals(W_rec)
ax.scatter(eigvals.real, eigvals.imag, s=100, alpha=0.6, edgecolors='black')
circle = plt.Circle((0, 0), 1, fill=False, color='red', linestyle='--', 
                    linewidth=2, label='Unit circle')
ax.add_patch(circle)
ax.set_xlabel('Real Part')
ax.set_ylabel('Imaginary Part')
ax.set_title('Eigenvalues of W_rec (Linearization at z = 0)')
ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
ax.axvline(x=0, color='gray', linestyle='-', alpha=0.3)
ax.legend()
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)

# Plot 4: Example trajectory comparison
ax = axes[1, 1]
# Run circuit on first trial
with torch.no_grad():
    _, latent_traj = fitter.circuit(inputs[:1])
    latent_traj = latent_traj[0].numpy()  # (time, latent_dim) for the single trial

for i in range(min(3, latent_traj.shape[1])):
    ax.plot(latent_traj[:, i], label=f'Latent {i+1}', linewidth=2)

ax.set_xlabel('Time Step')
ax.set_ylabel('Latent Activity')
ax.set_title('Example Latent Circuit Trajectory')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nInterpretation:")
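print("- Low training loss indicates a good fit to the training trials")
print("- Connectivity matrix shows circuit wiring")
print("- Eigenvalues of W_rec inside the unit circle = stable linearized dynamics")
print("  near z = 0 (tanh'(0) = 1); elsewhere tanh's slope (<= 1) rescales the rows of W_rec")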
print("- Latent trajectories show temporal evolution")
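The eigenvalues of W_rec describe the linearized dynamics only where the argument of tanh is zero, because tanh'(0) = 1. At a general fixed point z* of z ↦ tanh(W_rec z + W_in u + b), the Jacobian is diag(1 - z*²) W_rec (using tanh' = 1 - tanh² and tanh(W_rec z* + W_in u + b) = z*), and local stability depends on its eigenvalues. A sketch with illustrative random weights rather than the fitted circuit; applying it to the fitted model would mean substituting `fitter.get_circuit_weights()` and the points returned by `get_fixed_points`:

```python
import numpy as np
from scipy.optimize import fsolve

rng = np.random.default_rng(1)
n = 4
W_rec = rng.normal(size=(n, n))
W_rec *= 0.8 / np.linalg.norm(W_rec, 2)   # illustrative weights, rescaled to spectral norm 0.8
b = rng.normal(scale=0.3, size=n)         # stands in for W_in @ u + bias at a constant input

def residual(z):
    return np.tanh(W_rec @ z + b) - z

z_star = fsolve(residual, np.zeros(n))

# d/dz tanh(W z + b) = diag(1 - tanh(W z + b)^2) @ W, and tanh(W z* + b) = z* at a fixed point
J = np.diag(1.0 - z_star ** 2) @ W_rec

rho_W = np.max(np.abs(np.linalg.eigvals(W_rec)))
rho_J = np.max(np.abs(np.linalg.eigvals(J)))
print(f"spectral radius of W_rec: {rho_W:.3f}, of Jacobian at z*: {rho_J:.3f}")
print("locally stable" if rho_J < 1 else "locally unstable")
```

Rescaling W_rec to spectral norm 0.8 makes the map a contraction, so this particular example has a single, stable fixed point; fitted circuits can have several fixed points with different stability.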

## Part 2: DUNL - Disentangling Mixed Selectivity

### The Mixed Selectivity Problem

**Mixed selectivity**: Individual neurons respond to multiple task variables simultaneously.

Example: A neuron might respond to:
- Stimulus identity AND
- Decision choice AND
- Time in trial

This makes interpretation difficult!
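Rigotti et al. (2013) argue that *nonlinear* mixed selectivity matters because it raises the dimensionality of the population's condition-averaged responses. A small synthetic check with two binary task variables and random per-neuron tuning weights (all names and sizes here are illustrative): responses built from separate, additive effects span at most 2 dimensions across the 4 conditions, while adding an interaction term yields 3.

```python
import numpy as np

rng = np.random.default_rng(0)
n_neurons = 20
# Four task conditions from two binary variables: (a, b) in {0,1} x {0,1}
a = np.array([0, 0, 1, 1])
b = np.array([0, 1, 0, 1])

w_a, w_b, w_ab, c = rng.normal(size=(4, n_neurons))   # per-neuron tuning weights

# Pure / linear mixed selectivity: each neuron sums separate effects of a and b
R_linear = np.outer(a, w_a) + np.outer(b, w_b) + c
# Nonlinear mixed selectivity: add an a*b interaction term
R_nonlinear = R_linear + np.outer(a * b, w_ab)

def condition_dim(R):
    """Rank of the condition-by-neuron matrix after removing the mean across conditions."""
    return np.linalg.matrix_rank(R - R.mean(axis=0), tol=1e-8)

print(condition_dim(R_linear), condition_dim(R_nonlinear))  # 2 3
```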

### DUNL Solution

**Disentangled and Unified Networks through Latent (DUNL)** decomposes mixed selectivity:

```
Neural response = Factor 1 ⊗ Factor 2 ⊗ ... ⊗ Factor K
```

where:
- Each factor corresponds to one task variable
- ⊗ represents tensor product (interaction)
- Factors are disentangled (independent)

**Benefits**:
1. Interpretable factors (each = one variable)
2. Understand interactions between variables
3. Measure importance of each factor
4. Predict generalization to new conditions

### Mathematical Framework

For K task variables with dimensions [d1, d2, ..., dK]:

```
x = (F1 ⊗ F2 ⊗ ... ⊗ FK) @ c + noise
```

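where:
- Fi: Factor matrix for task variable i (di × rank)
- c: Core tensor of combining coefficients, flattened to rank^K entries
- x: Response over all d1·d2·…·dK combinations of task-variable levels (one neuron or latent component)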
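A numerical check of the K = 2 case, under the reading that x lists one component's response over every (level of variable 1, level of variable 2) combination in row-major order (an assumption; the indexing is not pinned down above). Then `np.kron(F1, F2) @ c` equals the flattened d1 × d2 grid F1 C F2ᵀ, where C is c reshaped to rank × rank:

```python
import numpy as np

rng = np.random.default_rng(0)
d1, d2, rank = 4, 2, 2                # e.g. 4 stimuli x 2 contexts; rank is illustrative
F1 = rng.normal(size=(d1, rank))      # stimulus factor matrix
F2 = rng.normal(size=(d2, rank))      # context factor matrix
C = rng.normal(size=(rank, rank))     # core coefficients (c = C.ravel())

# Vectorized form: one value per (stimulus, context) combination
x = np.kron(F1, F2) @ C.ravel()

# Same response arranged as a d1 x d2 condition grid: F1 C F2^T
X_grid = F1 @ C @ F2.T

print(x.shape)                         # (8,)
print(np.allclose(x, X_grid.ravel()))  # True
```

The grid form makes the structure explicit: the response table has rank at most `rank`, which is what makes the factors interpretable as per-variable tuning.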

In [None]:
class DUNLModel:
    """
    Disentangled and Unified Network through Latent (DUNL), simplified.
    
    Factorizes non-negative neural responses with NMF, then scores how
    selective each component is to a task variable.
    """
    
    def __init__(self, n_factors=2, rank=5):
        """
        Args:
            n_factors: Number of task variables
            rank: Rank of each factor decomposition
        """
        self.n_factors = n_factors
        self.rank = rank
        self.factors = None
        self.core = None
    
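    def fit(self, neural_data, task_labels, n_components=None):
        """
        Fit the model: NMF of the data, then selectivity of each component.
        
        Args:
            neural_data: (n_samples, n_neurons), non-negative
            task_labels: List of (n_samples,) arrays, one per task variable
            n_components: Rank for NMF (defaults to self.rank)
        
        Returns:
            Selectivity score per NMF component (see compute_selectivity)
        """
        if n_components is None:
            n_components = self.rank
        
        # Non-negative factorization: neural_data ≈ W @ H
        # W: (n_samples, n_components) component activations per sample
        # H: (n_components, n_neurons) neuron loadings per component
        nmf = NMF(n_components=n_components, init='random', random_state=42, max_iter=500)
        W = nmf.fit_transform(neural_data)
        H = nmf.components_
        
        # Store factors: per-sample activations and per-neuron loadings
        self.factors = [W, H.T]
        
        # Components are not tied to task variables by construction;
        # selectivity is measured after the fact
        selectivity_scores = self.compute_selectivity(W, task_labels)
        
        return selectivity_scores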
    
    def compute_selectivity(self, factors, task_labels, var_idx=0):
        """
        Compute how selective each component is to one task variable.
        
        Score = between-condition sum of squares / within-condition sum of
        squares (a one-way ANOVA F statistic without the degrees-of-freedom
        normalization). By default scores task_labels[0].
        """
        if len(task_labels) == 0:
            return None
        
        labels = task_labels[var_idx]
        unique_labels = np.unique(labels)
        
        selectivity = np.zeros(factors.shape[1])
        
        for i in range(factors.shape[1]):
            # Between-condition / within-condition sum of squares
            between_var = 0
            within_var = 0
            
            for label in unique_labels:
                mask = labels == label
                if np.sum(mask) > 0:
                    group_mean = factors[mask, i].mean()
                    between_var += np.sum(mask) * (group_mean - factors[:, i].mean())**2
                    within_var += np.sum((factors[mask, i] - group_mean)**2)
            
            if within_var > 0:
                selectivity[i] = between_var / within_var
        
        return selectivity
    
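    def analyze_mixing(self, neural_data):
        """
        Crude proxy for shared tuning across the population: the mean absolute
        pairwise correlation between neurons. It does not by itself separate
        mixed selectivity from other sources of correlation.
        
        Returns:
            mixing_score: Higher = neurons more strongly co-modulated
        """
        # Compute pairwise correlations (neurons x neurons)
        corr = np.corrcoef(neural_data.T)
        
        # Average absolute correlation over off-diagonal entries
        # (nanmean skips neurons with zero variance, whose correlations are NaN)
        mask = ~np.eye(corr.shape[0], dtype=bool)
        mixing_score = np.nanmean(np.abs(corr[mask]))
        
        return mixing_score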

print("DUNL model implemented!")

In [None]:
# Generate synthetic data with mixed selectivity
def generate_mixed_selectivity_data(n_samples=500, n_neurons=50, n_conditions=4):
    """
    Generate neural data with mixed selectivity.
    
    Each neuron responds to a combination of:
    - Stimulus type (n_conditions levels)
    - Context (2 levels)
    plus a multiplicative stimulus x context interaction.
    """
    # Task variables
    stimulus = np.random.randint(0, n_conditions, n_samples)
    context = np.random.randint(0, 2, n_samples)
    
    # Generate basis responses
    # Pure stimulus tuning
    stim_tuning = np.random.randn(n_neurons, n_conditions)
    
    # Pure context tuning
    context_tuning = np.random.randn(n_neurons, 2)
    
    # Generate mixed responses
    neural_data = np.zeros((n_samples, n_neurons))
    
    for i in range(n_samples):
        # Mixed response = stimulus effect + context effect + interaction
        stim_effect = stim_tuning[:, stimulus[i]]
        context_effect = context_tuning[:, context[i]]
        interaction = stim_effect * context_effect * 0.5
        
        neural_data[i] = stim_effect + context_effect + interaction
    
    # Add noise
    neural_data += np.random.randn(*neural_data.shape) * 0.5
    
    # Make non-negative (like firing rates)
    neural_data = np.maximum(neural_data, 0)
    
    return neural_data, [stimulus, context]

# Generate data
neural_data, task_labels = generate_mixed_selectivity_data(
    n_samples=500,
    n_neurons=50,
    n_conditions=4
)

print("Generated data with mixed selectivity:")
print(f"  Neural data: {neural_data.shape}")
print(f"  Stimulus conditions: {len(np.unique(task_labels[0]))}")
print(f"  Context conditions: {len(np.unique(task_labels[1]))}")

In [None]:
# Fit DUNL model
dunl = DUNLModel(n_factors=2, rank=5)
selectivity = dunl.fit(neural_data, task_labels, n_components=5)

# Measure mixing
mixing_score = dunl.analyze_mixing(neural_data)

print("\nDUNL Analysis:")
print(f"  Mixing score: {mixing_score:.3f} (mean |correlation| between neurons)")
print("\nFactor selectivity to stimulus (between/within sum of squares):")
for i, s in enumerate(selectivity):
    print(f"  Factor {i+1}: {s:.3f}")
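The selectivity score is the ratio of between-condition to within-condition sums of squares, which is a rescaled one-way ANOVA F statistic: F = (SS_between / (k - 1)) / (SS_within / (n - k)), so ratio = F · (k - 1) / (n - k) for k conditions and n samples. A sketch checking this against `scipy.stats.f_oneway` on synthetic data (the helper `between_within_ratio` is hypothetical and mirrors the loop in `compute_selectivity`):

```python
import numpy as np
from scipy.stats import f_oneway

def between_within_ratio(x, labels):
    """SS_between / SS_within for one component's activations."""
    grand = x.mean()
    ssb = ssw = 0.0
    for lab in np.unique(labels):
        g = x[labels == lab]
        ssb += g.size * (g.mean() - grand) ** 2
        ssw += np.sum((g - g.mean()) ** 2)
    return ssb / ssw

rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=200)
x = labels * 0.3 + rng.normal(size=200)   # component weakly tuned to the label
k, n = len(np.unique(labels)), x.size

ratio = between_within_ratio(x, labels)
F = f_oneway(*[x[labels == lab] for lab in np.unique(labels)]).statistic
print(ratio, F * (k - 1) / (n - k))       # the two numbers agree
```

Because the rescaling factor depends only on k and n, ranking components by this ratio gives the same order as ranking by F for a fixed set of labels; the values are just not on the F scale, so they cannot be read against F tables directly.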

In [None]:
# Visualize DUNL decomposition
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Factor loadings
ax = axes[0, 0]
W = dunl.factors[0]
im = ax.imshow(W.T, aspect='auto', cmap='viridis')
ax.set_xlabel('Sample')
ax.set_ylabel('Factor')
ax.set_title('DUNL: Factor Loadings Across Samples')
plt.colorbar(im, ax=ax, label='Loading')

# Plot 2: Selectivity scores
ax = axes[0, 1]
ax.bar(range(len(selectivity)), selectivity, alpha=0.7, color='steelblue')
ax.set_xlabel('Factor')
ax.set_ylabel('Selectivity Score')
ax.set_title('Factor Selectivity to Stimulus')
ax.grid(True, alpha=0.3, axis='y')

# Plot 3: Factor correlation matrix
ax = axes[1, 0]
factor_corr = np.corrcoef(W.T)
im = ax.imshow(factor_corr, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
ax.set_xlabel('Factor')
ax.set_ylabel('Factor')
ax.set_title('Factor Correlation Matrix\n(Disentanglement: low off-diagonal)')
plt.colorbar(im, ax=ax, label='Correlation')

# Plot 4: Factor 1 activation by stimulus condition
ax = axes[1, 1]
stimulus_labels = task_labels[0]
for cond in np.unique(stimulus_labels):
    mask = stimulus_labels == cond
    mean_activation = W[mask, 0].mean()
    std_activation = W[mask, 0].std()
    ax.bar(cond, mean_activation, yerr=std_activation, alpha=0.7)

ax.set_xlabel('Stimulus Condition')
ax.set_ylabel('Factor 1 Activation')
ax.set_title('Factor 1: Activation by Stimulus (mean ± SD)')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- Factor loadings show how each factor varies across samples")
print("- High selectivity = factor activation differs strongly across stimulus conditions")
print("- Low factor correlation is consistent with (but does not prove) disentanglement")
print("- Condition-specific activations reveal tuning properties")

## Part 3: Feature Visualization and Activation Maximization

### What is Feature Visualization?

**Goal**: Find the input that maximally activates a specific neuron or feature.

Instead of analyzing responses to existing stimuli, we **generate** optimal stimuli:

```
x* = argmax_x f(x; neuron_i)
```

where:
- x: Input (e.g., image, sequence)
- f(x; neuron_i): Activation of neuron i given input x
- x*: Optimal stimulus
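For a single tanh unit h = tanh(w·x + b) with an L2 penalty like the one in `FeatureVisualizer` below, the regularized optimum can be worked out by hand: setting the gradient (1 - h²) w - 2λx to zero forces x* to be parallel to w. A NumPy sketch of plain gradient ascent confirming that the optimized input aligns with the unit's weight vector. It assumes random illustrative weights, plain gradient steps rather than Adam, and a sum-of-squares penalty rather than the mean used in the class.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 10
w = rng.normal(size=dim)
w /= np.linalg.norm(w)        # unit-norm weights keep this step size stable
b = 0.0
lam, lr, n_steps = 0.1, 0.5, 300

x = rng.normal(size=dim)      # random starting input
for _ in range(n_steps):
    a = w @ x + b
    # gradient of tanh(w.x + b) - lam * ||x||^2 with respect to x
    grad = (1.0 - np.tanh(a) ** 2) * w - 2.0 * lam * x
    x = x + lr * grad         # gradient ascent step

cosine = x @ w / np.linalg.norm(x)
print(f"activation = {np.tanh(w @ x + b):.3f}, cosine(x*, w) = {cosine:.4f}")
```

The penalty shrinks the component of x orthogonal to w by a factor (1 - 2·lr·lam) each step, which is why the optimum ends up pointing along w; for deeper networks no such closed form exists, which is what motivates the iterative approach.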

### Why This Matters

1. **Understand selectivity**: What does this neuron "look for"?
2. **Interpretability**: Optimal stimuli are often interpretable
3. **Debugging**: Identify unexpected selectivity patterns
4. **Adversarial robustness**: Find edge cases

### Optimization Methods

1. **Gradient Ascent**: Iteratively adjust input to maximize activation
2. **Regularization**: Add constraints (naturalness, smoothness)
3. **Diversity**: Generate multiple different optimal stimuli

### Applications

- **Vision**: Visualize what CNN filters detect
- **Language**: Find phrases that activate concepts
- **Neuroscience**: Design stimuli for experiments

In [None]:
class FeatureVisualizer:
    """
    Generate optimal stimuli that maximally activate specific features.
    """
    
    def __init__(self, model, output_index=0):
        """
        Args:
            model: Neural network model
            output_index: If the model returns a tuple, which element holds the
                units to visualize (SimpleFFN below returns (hidden, output),
                so 0 selects the hidden layer)
        """
        self.model = model
        self.output_index = output_index
    
    def visualize_neuron(self, neuron_idx, input_shape, 
                        n_iterations=200, learning_rate=0.1,
                        l2_penalty=0.01):
        """
        Generate optimal stimulus for specific neuron.
        
        Args:
            neuron_idx: Index of neuron to visualize
            input_shape: Shape of input (e.g., (1, time, features))
            n_iterations: Number of optimization steps
            learning_rate: Step size
            l2_penalty: Regularization strength
        
        Returns:
            optimal_input: Optimized input (a local optimum of the regularized objective)
            activations: Activation trajectory during optimization
        """
        # Initialize input randomly
        optimal_input = torch.randn(input_shape, requires_grad=True)
        
        # Optimizer updates only the input, not the model weights
        optimizer = Adam([optimal_input], lr=learning_rate)
        
        activations = []
        
        for iteration in range(n_iterations):
            optimizer.zero_grad()
            
            # Forward pass; select one element if the model returns a tuple
            output = self.model(optimal_input)
            if isinstance(output, tuple):
                output = output[self.output_index]
            
            # Get target neuron activation
            # Mean across time/batch dimensions
            activation = output[..., neuron_idx].mean()
            
            # L2 regularization (keep inputs reasonable)
            l2_reg = l2_penalty * (optimal_input ** 2).mean()
            
            # Objective: maximize activation, minimize L2
            loss = -activation + l2_reg
            
            loss.backward()
            optimizer.step()
            
            activations.append(activation.item())
        
        return optimal_input.detach(), activations
    
    def visualize_multiple_neurons(self, neuron_indices, input_shape, **kwargs):
        """
        Visualize multiple neurons simultaneously.
        
        Returns:
            Dictionary mapping neuron_idx → (optimal_input, activations)
        """
        results = {}
        for neuron_idx in neuron_indices:
            opt_input, activations = self.visualize_neuron(
                neuron_idx, input_shape, **kwargs
            )
            results[neuron_idx] = (opt_input, activations)
        return results

print("Feature visualizer implemented!")

In [None]:
# Create a simple model to visualize
class SimpleFFN(nn.Module):
    def __init__(self, input_dim=10, hidden_dim=20, output_dim=5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        h = torch.tanh(self.fc1(x))
        y = self.fc2(h)
        return h, y  # Return both hidden and output

# Instantiate model
model = SimpleFFN(input_dim=10, hidden_dim=20, output_dim=5)
model.eval()

print("Created model for feature visualization")
print(f"  Input dim: {model.fc1.in_features}")
print(f"  Hidden dim: {model.fc1.out_features}")
print(f"  Output dim: {model.fc2.out_features}")

In [None]:
# Visualize multiple hidden neurons
visualizer = FeatureVisualizer(model)

# Visualize neurons 0, 5, 10
neuron_indices = [0, 5, 10]
results = visualizer.visualize_multiple_neurons(
    neuron_indices,
    input_shape=(1, 10),
    n_iterations=200,
    learning_rate=0.1,
    l2_penalty=0.01
)

print("Feature visualization complete!")
for neuron_idx in neuron_indices:
    opt_input, activations = results[neuron_idx]
    final_activation = activations[-1]
    print(f"  Neuron {neuron_idx}: Final activation = {final_activation:.4f}")

In [None]:
# Visualize results
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for plot_idx, neuron_idx in enumerate(neuron_indices):
    opt_input, activations = results[neuron_idx]
    
    # Plot activation trajectory
    ax = axes[0, plot_idx]
    ax.plot(activations, linewidth=2, color='steelblue')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Activation')
    ax.set_title(f'Neuron {neuron_idx}: Optimization Trajectory')
    ax.grid(True, alpha=0.3)
    
    # Plot optimal input pattern
    ax = axes[1, plot_idx]
    input_pattern = opt_input.squeeze().numpy()
    ax.bar(range(len(input_pattern)), input_pattern, alpha=0.7, color='coral')
    ax.set_xlabel('Input Dimension')
    ax.set_ylabel('Value')
    ax.set_title(f'Optimal Input for Neuron {neuron_idx}')
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- Activation increases during optimization")
print("- Optimal input shows what features the neuron prefers")
print("- Different neurons have different preferred patterns")

## Summary and Next Steps

### What We Learned

1. **Latent Circuit Models**: Extract minimal, interpretable circuits
   - Dimensionality reduction to find latent space
   - Learn recurrent dynamics in latent space
   - Compress many neurons into a few latent dimensions (30 → 5 in the example above) while approximately preserving the dynamics

2. **DUNL Decomposition**: Disentangle mixed selectivity
   - Separate responses into task-relevant factors
   - Measure selectivity to task variables
   - Understand interactions between factors

3. **Feature Visualization**: Generate optimal stimuli
   - Gradient-based optimization
   - Understand neuron selectivity
   - Design targeted experiments

### Key Takeaways

- **Circuits are often low-dimensional**: Much of the task-relevant computation can be captured in a few latent dimensions
- **Mixed selectivity is common**: But it can often be disentangled
- **Visualization reveals preferences**: Optimal stimuli show what a unit responds to most strongly
- **Interpretability requires simplification**: Extract minimal models

### Applications

1. **Model Compression**: Deploy circuits instead of full networks
2. **Neuroscience**: Design optimal experimental stimuli
3. **Debugging**: Identify where computations fail
4. **Knowledge Extraction**: Transfer learned algorithms to new domains

### Next Steps

1. **Notebook 08**: Biophysical modeling with spiking networks
2. **Notebook 09**: Information theory and energy landscapes
3. **Notebook 10**: Advanced topics (meta-dynamics, topology, counterfactuals)

### Further Reading

- Langdon & Engel (2025): *Latent circuit inference*
- Sussillo & Barak (2013): *Opening the black box*
- Olah, Mordvintsev & Schubert (2017): *Feature visualization* (Distill)
- Rigotti et al. (2013): *Mixed selectivity in cognitive tasks*