# Notebook 08: Biophysical Modeling with Spiking Networks

## Bridging Artificial and Biological Neural Networks

This notebook explores **biophysically realistic neural network models** that capture key properties of real neurons:
- **Spiking dynamics**: Neurons communicate via discrete action potentials
- **Temporal precision**: Spike timing carries information
- **Dale's law**: Neurons are either excitatory or inhibitory (not both)
- **Biological constraints**: Realistic synaptic dynamics and plasticity

### Why Biophysical Modeling Matters

1. **Brain Alignment**: Models that work like brains are easier to compare to neuroscience data
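2. **Energy Efficiency**: Spiking networks can be more energy-efficient than rate-based models because computation is event-driven (work happens only when a spike occurs), especially on neuromorphic hardware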
3. **Temporal Processing**: Spikes enable precise temporal computation
4. **Interpretability**: Biological constraints make models more interpretable
5. **Neuromorphic Hardware**: Spiking networks can run on specialized chips

### What You'll Learn

1. **Spiking Neuron Models**: LIF, Izhikevich, Hodgkin-Huxley
2. **Surrogate Gradients**: Backpropagation through spikes
3. **Dale's Law**: Enforcing E/I neuron separation
4. **Synaptic Plasticity**: STDP and other learning rules
5. **Network Dynamics**: Building spiking neural networks
6. **Biological Constraints**: Enforcing realistic connectivity

### References

- Gerstner & Kistler (2002): *Spiking Neuron Models*
- Izhikevich (2003): *Simple model of spiking neurons*
- Neftci et al. (2019): *Surrogate gradient learning in spiking neural networks*
- Song, Miller & Abbott (2000): *Competitive Hebbian learning through spike-timing-dependent synaptic plasticity*
- Litwin-Kumar & Doiron (2012): *Slow dynamics and high variability in balanced cortical networks with clustered connections*

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

print("All libraries imported successfully!")
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

## Part 1: Spiking Neuron Models

### What Makes Neurons Spike?

Biological neurons have a **membrane potential** V(t) that:
1. Integrates input currents
2. Generates a spike when V(t) exceeds threshold
3. Resets after spiking
4. Has a refractory period

### Leaky Integrate-and-Fire (LIF) Model

One of the simplest spiking models:

```
τ dV/dt = -(V - V_rest) + R * I(t)

if V(t) ≥ V_th:
    emit spike
    V(t) ← V_reset
```

where:
- V: Membrane potential
- τ: Membrane time constant (typically 10-20 ms)
- V_rest: Resting potential (-65 mV)
- V_th: Spike threshold (-55 mV)
- V_reset: Reset potential (-70 mV)
- R: Membrane resistance
- I(t): Input current
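
For a constant input the LIF equation can be solved in closed form. Starting from V_reset, V relaxes exponentially toward V_inf = V_rest + R*I, so the interspike interval is T_ISI = τ ln((V_inf - V_reset) / (V_inf - V_th)), and the neuron never fires if V_inf <= V_th. The sketch below checks this against a forward-Euler simulation. It assumes R = 1 and no refractory period (matching the equation above), uses the parameter values listed, and the two function names are hypothetical.

```python
import numpy as np

def lif_isi_analytic(I, tau=10.0, v_rest=-65.0, v_th=-55.0, v_reset=-70.0, R=1.0):
    """Closed-form interspike interval (ms) for constant input I; inf if V_inf <= V_th."""
    v_inf = v_rest + R * I  # steady-state potential the neuron relaxes toward
    if v_inf <= v_th:
        return np.inf
    return tau * np.log((v_inf - v_reset) / (v_inf - v_th))

def lif_isi_simulated(I, tau=10.0, v_rest=-65.0, v_th=-55.0, v_reset=-70.0,
                      R=1.0, dt=0.01, T=500.0):
    """Mean interspike interval (ms) from a forward-Euler simulation."""
    v = v_reset  # start at reset so the first interval matches the later ones
    spike_times = []
    for step in range(int(T / dt)):
        v += dt * (-(v - v_rest) + R * I) / tau  # tau dV/dt = -(V - V_rest) + R*I
        if v >= v_th:
            spike_times.append((step + 1) * dt)
            v = v_reset
    if len(spike_times) < 2:
        return np.inf
    return float(np.mean(np.diff(spike_times)))

for I in (5.0, 15.0, 30.0):
    print(f"I = {I:4.1f}: analytic ISI = {lif_isi_analytic(I):.3f} ms, "
          f"simulated ISI = {lif_isi_simulated(I):.3f} ms")
```

With I = 5 the steady state (-60 mV) sits below threshold, so both versions report an infinite interval; above that, the simulation agrees with the formula to within the time step.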

### Izhikevich Model

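Richer dynamics than LIF (adaptation, bursting) with just two variables: the membrane potential V (mV) and a recovery variable u, with time in ms: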

```
dV/dt = 0.04*V² + 5*V + 140 - u + I
du/dt = a*(b*V - u)

if V ≥ 30 mV:
    V ← c
    u ← u + d
```

Parameters (a, b, c, d) control neuron type:
- Regular spiking: (0.02, 0.2, -65, 8)
- Fast spiking: (0.1, 0.2, -65, 2)
- Chattering (fast rhythmic bursting): (0.02, 0.2, -50, 2)
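
The three parameter sets above can be compared directly under the same input. This is a sketch assuming forward-Euler integration with dt = 0.1 ms and a constant input I = 10 chosen for illustration; `izhikevich_spike_times` is a hypothetical helper. It prints each cell type's spike count and its shortest and longest interspike intervals: a bursting cell is expected to show a large gap between short within-burst and long between-burst intervals, and the fast-spiking cell should fire more spikes than the regular-spiking one.

```python
import numpy as np

# Parameter sets (a, b, c, d) from the list above
IZH_PARAMS = {
    "regular": (0.02, 0.2, -65.0, 8.0),
    "fast": (0.1, 0.2, -65.0, 2.0),
    "chattering": (0.02, 0.2, -50.0, 2.0),
}

def izhikevich_spike_times(a, b, c, d, I=10.0, T=500.0, dt=0.1):
    """Forward-Euler simulation under constant input I; returns spike times (ms)."""
    v, u = -65.0, b * -65.0
    spikes = []
    for step in range(int(T / dt)):
        dv = 0.04 * v**2 + 5 * v + 140 - u + I
        du = a * (b * v - u)
        v += dt * dv
        u += dt * du
        if v >= 30.0:  # spike cutoff, then reset V and bump the recovery variable
            spikes.append((step + 1) * dt)
            v = c
            u += d
    return np.array(spikes)

for name, params in IZH_PARAMS.items():
    t = izhikevich_spike_times(*params)
    isi = np.diff(t)
    print(f"{name:>10}: {len(t):3d} spikes, "
          f"ISI min/max = {isi.min():.1f}/{isi.max():.1f} ms")
```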

### Hodgkin-Huxley Model

Biophysically detailed with ion channels:

```
C dV/dt = I - g_Na*m³*h*(V-E_Na) - g_K*n⁴*(V-E_K) - g_L*(V-E_L)
```

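Here g_Na, g_K, g_L are maximal conductances and E_Na, E_K, E_L are reversal potentials. The gating variables m, h (sodium activation and inactivation) and n (potassium activation) each follow voltage-dependent first-order kinetics, so the model is a system of four coupled ODEs. It is the most biophysically detailed of the three models, but also the most computationally expensive, since it needs small time steps.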
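
A compact Hodgkin-Huxley simulation makes the gating kinetics concrete: each gate x in {m, h, n} obeys dx/dt = α_x(V)(1 - x) - β_x(V) x. This sketch assumes the commonly used textbook squid-axon parameters and rate functions (rest near -65 mV), forward-Euler integration with dt = 0.01 ms, and counts a spike at each upward crossing of 0 mV; `rates` and `simulate_hh` are hypothetical names, not part of this notebook.

```python
import numpy as np

# Classic squid-axon parameters (modern convention, rest near -65 mV)
C_M = 1.0                              # membrane capacitance (uF/cm^2)
G_NA, G_K, G_L = 120.0, 36.0, 0.3      # maximal conductances (mS/cm^2)
E_NA, E_K, E_L = 50.0, -77.0, -54.387  # reversal potentials (mV)

def rates(V):
    """Voltage-dependent opening (alpha) and closing (beta) rates, in 1/ms."""
    a_m = 0.1 * (V + 40.0) / (1.0 - np.exp(-(V + 40.0) / 10.0))
    b_m = 4.0 * np.exp(-(V + 65.0) / 18.0)
    a_h = 0.07 * np.exp(-(V + 65.0) / 20.0)
    b_h = 1.0 / (1.0 + np.exp(-(V + 35.0) / 10.0))
    a_n = 0.01 * (V + 55.0) / (1.0 - np.exp(-(V + 55.0) / 10.0))
    b_n = 0.125 * np.exp(-(V + 65.0) / 80.0)
    return a_m, b_m, a_h, b_h, a_n, b_n

def simulate_hh(I=10.0, T=50.0, dt=0.01):
    """Forward-Euler integration; returns the voltage trace and the spike count."""
    V = -65.0
    a_m, b_m, a_h, b_h, a_n, b_n = rates(V)
    # Start each gate at its steady-state value for the resting potential
    m, h, n = a_m / (a_m + b_m), a_h / (a_h + b_h), a_n / (a_n + b_n)
    trace, n_spikes = [], 0
    for _ in range(int(T / dt)):
        a_m, b_m, a_h, b_h, a_n, b_n = rates(V)
        I_ion = (G_NA * m**3 * h * (V - E_NA)
                 + G_K * n**4 * (V - E_K)
                 + G_L * (V - E_L))
        V_new = V + dt * (I - I_ion) / C_M          # C dV/dt = I - I_ion
        m += dt * (a_m * (1 - m) - b_m * m)
        h += dt * (a_h * (1 - h) - b_h * h)
        n += dt * (a_n * (1 - n) - b_n * n)
        if V < 0.0 <= V_new:                        # upward crossing of 0 mV
            n_spikes += 1
        V = V_new
        trace.append(V)
    return np.array(trace), n_spikes

for I in (0.0, 10.0):
    _, n_sp = simulate_hh(I)
    print(f"I = {I:4.1f} uA/cm^2: {n_sp} spikes in 50 ms")
```

Even this single-compartment version calls six exponentials per step at a 0.01 ms step size, which illustrates the cost compared with LIF or Izhikevich.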

In [None]:
class LeakyIntegrateFireNeuron:
    """
    Leaky Integrate-and-Fire (LIF) neuron model.
    
    Simple but captures essential spiking dynamics.
    """
    
    def __init__(self, tau=10.0, v_rest=-65.0, v_th=-55.0, v_reset=-70.0, dt=1.0):
        """
        Args:
            tau: Membrane time constant (ms)
            v_rest: Resting potential (mV)
            v_th: Spike threshold (mV)
            v_reset: Reset potential after spike (mV)
            dt: Time step (ms)
        """
        self.tau = tau
        self.v_rest = v_rest
        self.v_th = v_th
        self.v_reset = v_reset
        self.dt = dt
        
        # State
        self.v = v_rest
        self.t = 0.0              # Elapsed simulation time (ms)
        self.spike_times = []     # Spike times (ms)
    
    def step(self, input_current):
        """
        Simulate one forward-Euler time step.
        
        Args:
            input_current: Input current (arbitrary units; R = 1, so R*I is in mV)
        
        Returns:
            spike: Whether neuron spiked (bool)
        """
        # Integrate: tau dV/dt = -(V - V_rest) + R*I, with R = 1
        dv = (-(self.v - self.v_rest) + input_current) / self.tau
        self.v += dv * self.dt
        self.t += self.dt
        
        # Check for spike
        spike = self.v >= self.v_th
        
        # Reset if spiked and record the spike time (not the spike count)
        if spike:
            self.v = self.v_reset
            self.spike_times.append(self.t)
        
        return spike
    
    def reset(self):
        """Reset neuron to resting state."""
        self.v = self.v_rest
        self.t = 0.0
        self.spike_times = []

print("LIF neuron model implemented!")

In [None]:
class IzhikevichNeuron:
    """
    Izhikevich spiking neuron model.
    
    Captures diverse neuron types with 4 parameters.
    """
    
    def __init__(self, a=0.02, b=0.2, c=-65, d=8, dt=0.1):
        """
        Args:
            a: Recovery time scale
            b: Sensitivity to subthreshold V
            c: Reset voltage (mV)
            d: Reset recovery variable
            dt: Time step (ms)
        """
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.dt = dt
        
        # State
        self.v = -65.0            # Membrane potential (mV)
        self.u = self.b * self.v  # Recovery variable
        self.t = 0.0              # Elapsed simulation time (ms)
        self.spike_times = []     # Spike times (ms)
    
    def step(self, input_current):
        """
        Simulate one forward-Euler time step.
        
        Args:
            input_current: Input current
        
        Returns:
            spike: Whether neuron spiked
        """
        # Both derivatives use the state from the start of the step
        # dV/dt = 0.04*V² + 5*V + 140 - u + I
        dv = (0.04 * self.v**2 + 5*self.v + 140 - self.u + input_current) * self.dt
        
        # du/dt = a*(b*V - u)
        du = self.a * (self.b * self.v - self.u) * self.dt
        
        self.v += dv
        self.u += du
        self.t += self.dt
        
        # Check for spike (30 mV is the spike cutoff, not a threshold)
        spike = self.v >= 30
        
        # Reset if spiked and record the spike time (not the spike count)
        if spike:
            self.v = self.c
            self.u += self.d
            self.spike_times.append(self.t)
        
        return spike
    
    def reset(self):
        """Reset neuron to its initial state."""
        self.v = -65.0
        self.u = self.b * self.v
        self.t = 0.0
        self.spike_times = []

print("Izhikevich neuron model implemented!")

In [None]:
# Simulate and compare LIF and Izhikevich neurons

# Simulation parameters
T = 200  # Total time (ms)
dt = 0.1  # Time step (ms)
n_steps = int(T / dt)
time = np.arange(n_steps) * dt  # exactly n_steps samples, matching input_current

# Input current: step from t=50 to t=150
input_current = np.zeros(n_steps)
input_current[int(50/dt):int(150/dt)] = 15

# Create neurons
lif = LeakyIntegrateFireNeuron(tau=10.0, dt=dt)
izh_regular = IzhikevichNeuron(a=0.02, b=0.2, c=-65, d=8, dt=dt)
izh_fast = IzhikevichNeuron(a=0.1, b=0.2, c=-65, d=2, dt=dt)

# Simulate
v_lif = []
v_izh_reg = []
v_izh_fast = []

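for i in range(n_steps):
    lif.step(input_current[i])
    spiked_reg = izh_regular.step(input_current[i])
    spiked_fast = izh_fast.step(input_current[i])
    
    v_lif.append(lif.v)
    # Izhikevich neurons reset within the same step, so record the 30 mV cutoff
    # on spike steps to make the spikes visible in the voltage trace
    v_izh_reg.append(30.0 if spiked_reg else izh_regular.v)
    v_izh_fast.append(30.0 if spiked_fast else izh_fast.v)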

v_lif = np.array(v_lif)
v_izh_reg = np.array(v_izh_reg)
v_izh_fast = np.array(v_izh_fast)

print(f"LIF neuron: {len(lif.spike_times)} spikes")
print(f"Izhikevich (regular): {len(izh_regular.spike_times)} spikes")
print(f"Izhikevich (fast): {len(izh_fast.spike_times)} spikes")

In [None]:
# Visualize neuron responses
fig, axes = plt.subplots(4, 1, figsize=(14, 10), sharex=True)

# Plot 1: Input current
ax = axes[0]
ax.plot(time, input_current, 'k-', linewidth=2)
ax.set_ylabel('Input Current')
ax.set_title('Spiking Neuron Models: Response to Step Current')
ax.grid(True, alpha=0.3)

# Plot 2: LIF
ax = axes[1]
ax.plot(time, v_lif, 'b-', linewidth=1)
ax.axhline(y=lif.v_th, color='red', linestyle='--', label='Threshold')
ax.set_ylabel('Voltage (mV)')
ax.set_title('LIF Neuron')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)

# Plot 3: Izhikevich (regular spiking)
ax = axes[2]
ax.plot(time, v_izh_reg, 'g-', linewidth=1)
ax.axhline(y=30, color='red', linestyle='--', label='Spike cutoff (30 mV)')
ax.set_ylabel('Voltage (mV)')
ax.set_title('Izhikevich (Regular Spiking)')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)

# Plot 4: Izhikevich (fast spiking)
ax = axes[3]
ax.plot(time, v_izh_fast, 'm-', linewidth=1)
ax.axhline(y=30, color='red', linestyle='--', label='Spike cutoff (30 mV)')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Voltage (mV)')
ax.set_title('Izhikevich (Fast Spiking)')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- LIF: Simple regular spiking")
print("- Izhikevich regular: Spike frequency adaptation")
print("- Izhikevich fast: High firing rate, minimal adaptation")

## Part 2: Surrogate Gradients for Training Spiking Networks

### The Gradient Problem

**Challenge**: Spikes are discontinuous (0 or 1), so gradients are zero almost everywhere!

```
spike = 1 if V ≥ V_th else 0
d(spike)/dV = δ(V - V_th)  # Dirac delta - not useful!
```
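
The problem is easy to reproduce in PyTorch. A threshold written as a comparison yields a tensor that autograd cannot trace back to V at all, and a steep sigmoid (a smooth stand-in for the step) has a gradient that vanishes everywhere except near threshold, where it grows with the steepness. A minimal illustration, with the steepness k = 100 chosen arbitrarily:

```python
import torch

v = torch.linspace(-2.0, 2.0, 5, requires_grad=True)  # membrane potentials relative to V_th

# A hard threshold built from a comparison is cut off from autograd entirely
hard_spikes = (v >= 0.0).float()
print(hard_spikes.requires_grad)  # no gradient path back to v

# A steep sigmoid approximates the step; its gradient is k/4 at threshold
# and underflows to (almost) zero elsewhere
k = 100.0
soft_spikes = torch.sigmoid(k * v)
soft_spikes.sum().backward()
print(v.grad)
```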

### Surrogate Gradient Solution

Replace true gradient with a smooth **surrogate** during backpropagation:

**Forward pass**: Use actual step function
```
spike = H(V - V_th)  # Heaviside step
```

**Backward pass**: Use smooth approximation
```
dL/dV ≈ dL/dspike * σ'(V - V_th)
```

where σ' can be (writing x = V - V_th and β for the surrogate width):
- **Sigmoid derivative**: σ(x/β) * (1 - σ(x/β))
- **Fast sigmoid**: 1 / (1 + |x/β|)², the surrogate used by SuperSpike (Zenke & Ganguli, 2018)
- **Exponential**: exp(-|x/β|)
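
Written as functions of x = V - V_th, the surrogates differ mainly in peak height and in how slowly they decay away from threshold, which controls how much gradient reaches neurons that are far from firing. A small NumPy comparison, assuming β = 1 for all three; note that the sigmoid derivative peaks at 0.25 rather than 1, so it delivers smaller gradients unless rescaled.

```python
import numpy as np

def sigmoid_surrogate(x, beta=1.0):
    s = 1.0 / (1.0 + np.exp(-x / beta))
    return s * (1.0 - s)

def fast_sigmoid_surrogate(x, beta=1.0):
    return 1.0 / (1.0 + np.abs(x / beta)) ** 2

def exponential_surrogate(x, beta=1.0):
    return np.exp(-np.abs(x / beta))

x = np.linspace(-5.0, 5.0, 1001)  # x = V - V_th; index 500 is the threshold
for name, fn in [("sigmoid", sigmoid_surrogate),
                 ("fast sigmoid", fast_sigmoid_surrogate),
                 ("exponential", exponential_surrogate)]:
    g = fn(x)
    print(f"{name:>12}: peak {g.max():.3f}, value at |x| = 5: {g[0]:.4f}")
```

The fast sigmoid has the heaviest tail (it decays polynomially rather than exponentially), so it passes the most gradient to neurons well below threshold.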

### Why This Works

- Forward pass: Correct spiking behavior
- Backward pass: Gradients flow, enabling learning
- Empirically effective for many tasks
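
To see that gradients really do flow, the sketch below trains a single input weight w so that one LIF neuron emits a target number of spikes in 100 steps. All of the following are illustrative assumptions, not part of this notebook's later code: constant drive, the fast-sigmoid surrogate with β = 1, a reset that is detached from the graph, a squared-error loss on the spike count, and Adam with lr = 0.02. At w = 0.5 the neuron never fires, so the true gradient of the spike count is exactly zero; the surrogate still supplies a nonzero gradient that pushes w up.

```python
import torch

class FastSigmoidSpike(torch.autograd.Function):
    """Heaviside forward, fast-sigmoid surrogate backward (threshold at 0)."""
    @staticmethod
    def forward(ctx, x, beta):
        ctx.save_for_backward(x)
        ctx.beta = beta
        return (x >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        surrogate = 1.0 / (1.0 + torch.abs(x / ctx.beta)) ** 2
        return grad_output * surrogate, None  # no gradient for beta

def spike_count(w, n_steps=100, tau=10.0, v_th=1.0, dt=1.0, beta=1.0):
    """Count spikes of one LIF neuron driven by a constant input of strength w."""
    v = torch.zeros(())
    count = torch.zeros(())
    for _ in range(n_steps):
        v = v * (1 - dt / tau) + (dt / tau) * w      # leaky integration
        s = FastSigmoidSpike.apply(v - v_th, beta)   # spike with surrogate gradient
        v = v * (1 - s.detach())                     # hard reset to 0, kept out of the graph
        count = count + s
    return count

w = torch.tensor(0.5, requires_grad=True)  # too weak to reach threshold at first
target = 5.0
opt = torch.optim.Adam([w], lr=0.02)
losses = []
for _ in range(300):
    opt.zero_grad()
    loss = (spike_count(w) - target) ** 2
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"initial loss {losses[0]:.1f}, final loss {losses[-1]:.1f}, w = {w.item():.3f}")
```

Because the spike count is piecewise constant in w, training can oscillate around the target; the surrogate gradient only supplies a useful direction, not an exact slope.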

In [None]:
class SurrogateGradient(torch.autograd.Function):
    """
    Spike function with surrogate gradient for backpropagation.
    
    Forward: Heaviside step (0 or 1)
    Backward: Smooth surrogate (enables gradient flow)
    """
    
    @staticmethod
    def forward(ctx, input, threshold=0.0, beta=1.0):
        """
        Args:
            input: Membrane potential
            threshold: Spike threshold
            beta: Surrogate gradient width (larger beta gives a broader surrogate)
        
        Returns:
            spike: Binary spike (0 or 1)
        """
        ctx.save_for_backward(input)
        ctx.threshold = threshold
        ctx.beta = beta
        
        # Forward: step function
        spike = (input >= threshold).float()
        return spike
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward: smooth surrogate gradient.
        
        Using fast sigmoid: 1 / (1 + |x/β|)²
        """
        input, = ctx.saved_tensors
        threshold = ctx.threshold
        beta = ctx.beta
        
        # Surrogate gradient
        x = (input - threshold) / beta
        surrogate = 1.0 / (1.0 + torch.abs(x))**2
        
        return grad_output * surrogate, None, None

# Convenience function
def spike_fn(x, threshold=0.0, beta=1.0):
    """Spiking activation with surrogate gradient."""
    return SurrogateGradient.apply(x, threshold, beta)

print("Surrogate gradient spike function implemented!")

In [None]:
class LIFLayer(nn.Module):
    """
    Layer of Leaky Integrate-and-Fire neurons with surrogate gradients.
    
    Trainable via backpropagation through time (BPTT).
    """
    
    def __init__(self, n_neurons, tau=10.0, v_th=1.0, v_reset=0.0, beta=1.0):
        super().__init__()
        self.n_neurons = n_neurons
        self.tau = tau
        self.v_th = v_th
        self.v_reset = v_reset
        self.beta = beta
    
    def forward(self, input_current, v_mem=None, dt=1.0):
        """
        Args:
            input_current: (batch, time, n_neurons)
            v_mem: Initial membrane potential (batch, n_neurons)
            dt: Time step
        
        Returns:
            spikes: (batch, time, n_neurons)
            v_mem: Final membrane potential
        """
        batch, time, _ = input_current.shape
        device = input_current.device
        
        # Initialize membrane potential
        if v_mem is None:
            v_mem = torch.zeros(batch, self.n_neurons, device=device)
        
        # Collect spikes over time
        spikes_over_time = []
        
        for t in range(time):
            # Leak: dV/dt = -V/tau
            v_mem = v_mem * (1 - dt/self.tau)
            
            # Input: dV/dt += I/tau
            v_mem = v_mem + input_current[:, t, :] * (dt/self.tau)
            
            # Spike generation (with surrogate gradient)
            spike = spike_fn(v_mem, self.v_th, self.beta)
            
            # Reset
            v_mem = v_mem * (1 - spike) + self.v_reset * spike
            
            spikes_over_time.append(spike)
        
        spikes = torch.stack(spikes_over_time, dim=1)
        return spikes, v_mem

print("LIF layer with surrogate gradients implemented!")

In [None]:
# Test surrogate gradient
# Compare true step function gradient (zero) vs surrogate (smooth)

x = torch.linspace(-5, 5, 200, requires_grad=True)

# Forward pass: spike function
y = spike_fn(x, threshold=0.0, beta=1.0)

# Backward pass: compute gradient
y.sum().backward()
grad = x.grad.clone()

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Forward (spike function)
ax = axes[0]
ax.plot(x.detach().numpy(), y.detach().numpy(), linewidth=2, label='Spike function')
ax.axvline(x=0, color='red', linestyle='--', alpha=0.5, label='Threshold')
ax.set_xlabel('Membrane Potential')
ax.set_ylabel('Spike Output')
ax.set_title('Forward Pass: Heaviside Step Function')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Backward (surrogate gradient)
ax = axes[1]
ax.plot(x.detach().numpy(), grad.numpy(), linewidth=2, color='green', label='Surrogate gradient')
ax.axvline(x=0, color='red', linestyle='--', alpha=0.5, label='Threshold')
ax.set_xlabel('Membrane Potential')
ax.set_ylabel('Gradient')
ax.set_title('Backward Pass: Smooth Surrogate Gradient')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- Forward: Sharp threshold (realistic spiking)")
print("- Backward: Smooth gradient (enables learning)")
print("- Gradients flow even away from threshold")

## Part 3: Dale's Law - Excitatory and Inhibitory Neurons

### What is Dale's Law?

**Dale's Law**: A neuron releases the same set of neurotransmitters at all of its synapses.

**Practical implication**: Each neuron is either:
- **Excitatory (E)**: All outgoing connections are positive
- **Inhibitory (I)**: All outgoing connections are negative

Neurons cannot be both!

### Why Enforce Dale's Law?

1. **Biological realism**: Neurons in real brains largely obey this constraint
2. **Interpretability**: E/I balance is well-studied in neuroscience
3. **Stability**: Separate E and I populations make E/I balance explicit, and balanced inhibition keeps excitation from running away
4. **Alignment**: Makes models easier to compare to brain data

### Implementation

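For a weight matrix W of shape (post, pre), as in `nn.Linear`, column j holds the outgoing weights of presynaptic neuron j:
- **Designate**: Neuron j is E or I
- **Constrain**: All outgoing weights from neuron j (column W[:, j]) have the same sign
- **Enforce**: Keep the weights in this valid set during training

Methods:
1. **Hard constraint (projection)**: After each optimizer step, set W[:, j] = |W[:, j]| if j is E, -|W[:, j]| if j is I
2. **Soft constraint**: Add a penalty for sign violations to the loss
3. **Parametrization**: W = |W_raw| @ D, where D is diagonal with D[j, j] = +1 (E) or -1 (I)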
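
The three methods can be compared on a small random matrix. This NumPy sketch assumes weights stored as (post, pre), so column j holds neuron j's outgoing weights; the function names, matrix sizes, and the hinge-style penalty (summing the magnitudes of wrong-signed weights) are illustrative choices rather than a fixed recipe.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post, n_exc = 10, 6, 8
# Diagonal of D: +1 for excitatory presynaptic neurons, -1 for inhibitory ones
signs = np.where(np.arange(n_pre) < n_exc, 1.0, -1.0)

def project_dale(W, signs):
    """Method 1: hard projection, forcing column j to the sign of neuron j."""
    return np.abs(W) * signs[np.newaxis, :]

def dale_penalty(W, signs):
    """Method 2: soft penalty, summing the magnitudes of wrong-signed weights."""
    return float(np.sum(np.maximum(0.0, -W * signs[np.newaxis, :])))

def dale_parametrize(W_raw, signs):
    """Method 3: W = |W_raw| @ D with D = diag(signs)."""
    return np.abs(W_raw) @ np.diag(signs)

def obeys_dale(W, signs):
    """True if every column j has sign signs[j] (zeros allowed)."""
    return bool(np.all(W * signs[np.newaxis, :] >= 0))

W = rng.normal(size=(n_post, n_pre))  # unconstrained weights, shape (post, pre)
print("raw W obeys Dale's law:", obeys_dale(W, signs))
print("penalty on raw W:", round(dale_penalty(W, signs), 3))

W_proj = project_dale(W, signs)
W_param = dale_parametrize(W, signs)
print("projected W obeys Dale's law:", obeys_dale(W_proj, signs))
print("penalty after projection:", dale_penalty(W_proj, signs))
print("projection and parametrization agree:", np.allclose(W_proj, W_param))
```

Projection and parametrization produce the same matrix here; they differ during training, where projection edits W after each step while parametrization lets gradients act on W_raw. `DalesLinear` below takes the parametrization route.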

In [None]:
class DalesLinear(nn.Module):
    """
    Linear layer with Dale's law constraint.
    
    Each neuron is either excitatory (E) or inhibitory (I).
    All outgoing weights from a neuron have the same sign.
    """
    
    def __init__(self, in_features, out_features, fraction_exc=0.8):
        """
        Args:
            in_features: Input dimension
            out_features: Output dimension
            fraction_exc: Fraction of excitatory presynaptic (input) neurons (typically 0.8)
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Unsigned weights (always positive)
        self.weight_magnitude = nn.Parameter(torch.rand(out_features, in_features))
        
        # Dale's law: sign matrix (fixed, not learned)
        n_exc = int(in_features * fraction_exc)
        signs = torch.ones(in_features)
        signs[n_exc:] = -1  # Last neurons are inhibitory
        self.register_buffer('dale_signs', signs)
        
        # Bias
        self.bias = nn.Parameter(torch.zeros(out_features))
    
    def forward(self, x):
        """
        Forward pass with Dale's law enforced.
        
        W = |W_magnitude| ⊙ sign_matrix
        """
        # Enforce sign constraint
        weight = torch.abs(self.weight_magnitude) * self.dale_signs.unsqueeze(0)
        
        # Linear transformation
        return F.linear(x, weight, self.bias)
    
    def get_excitatory_mask(self):
        """Return mask for excitatory neurons."""
        return self.dale_signs > 0
    
    def get_inhibitory_mask(self):
        """Return mask for inhibitory neurons."""
        return self.dale_signs < 0

print("Dale's law linear layer implemented!")

In [None]:
class EISpikingNetwork(nn.Module):
    """
    Spiking neural network with Dale's law (E/I separation).
    """
    
    def __init__(self, input_size, hidden_size, output_size, fraction_exc=0.8):
        super().__init__()
        
        # Input projection
        self.input_proj = nn.Linear(input_size, hidden_size)
        
        # Recurrent layer with Dale's law
        self.recurrent = DalesLinear(hidden_size, hidden_size, fraction_exc)
        
        # Spiking neurons
        self.lif = LIFLayer(hidden_size, tau=10.0, v_th=1.0)
        
        # Output readout (rate-based)
        self.output = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        """
        Args:
            x: Input (batch, time, input_size)
        
        Returns:
            output: (batch, time, output_size)
            spikes: (batch, time, hidden_size)
        """
        batch, time, _ = x.shape
        
        # Project input
        x_proj = self.input_proj(x)
        
        # Initialize state
        v_mem = None
        spike = torch.zeros(batch, self.lif.n_neurons, device=x.device)
        
        outputs = []
        all_spikes = []
        
        for t in range(time):
            # Recurrent input (Dale's law enforced)
            rec_input = self.recurrent(spike)
            
            # Total input
            total_input = x_proj[:, t:t+1, :] + rec_input.unsqueeze(1)
            
            # LIF dynamics
            spike_t, v_mem = self.lif(total_input, v_mem)
            spike = spike_t.squeeze(1)
            
            # Readout
            output_t = self.output(spike)
            
            outputs.append(output_t)
            all_spikes.append(spike)
        
        output = torch.stack(outputs, dim=1)
        spikes = torch.stack(all_spikes, dim=1)
        
        return output, spikes

print("E/I spiking network implemented!")

In [None]:
# Test Dale's law enforcement
input_size = 10
hidden_size = 50
output_size = 2
fraction_exc = 0.8

# Create network
net = EISpikingNetwork(input_size, hidden_size, output_size, fraction_exc)

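# Generate test input (batch=4, time=50). The 3.0 scale is a heuristic so that some
# hidden neurons reach v_th = 1; with much weaker input (e.g. * 0.5) the membrane
# potential typically stays below threshold and no spikes occur. Adjust it if the
# mean firing rate below is 0 or close to 1.
x = torch.randn(4, 50, input_size) * 3.0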

# Forward pass
output, spikes = net(x)

print(f"Network created with Dale's law:")
print(f"  Input size: {input_size}")
print(f"  Hidden size: {hidden_size}")
print(f"  Excitatory neurons: {int(hidden_size * fraction_exc)}")
print(f"  Inhibitory neurons: {hidden_size - int(hidden_size * fraction_exc)}")
print(f"\nForward pass:")
print(f"  Output shape: {output.shape}")
print(f"  Spikes shape: {spikes.shape}")
print(f"  Mean firing rate: {spikes.mean().item():.4f}")

In [None]:
# Visualize Dale's law enforcement
# Extract recurrent weight matrix
with torch.no_grad():
    # .detach() is required: .numpy() raises an error on a tensor that requires grad
    weight_magnitude = net.recurrent.weight_magnitude.detach().numpy()
    signs = net.recurrent.dale_signs.numpy()
    weight_actual = np.abs(weight_magnitude) * signs

# Get E/I masks
exc_mask = net.recurrent.get_excitatory_mask().numpy()
inh_mask = net.recurrent.get_inhibitory_mask().numpy()

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Full weight matrix
ax = axes[0, 0]
im = ax.imshow(weight_actual, cmap='RdBu_r', aspect='auto', 
               vmin=-np.abs(weight_actual).max(), 
               vmax=np.abs(weight_actual).max())
ax.axvline(x=np.sum(exc_mask)-0.5, color='yellow', linewidth=2, label='E/I boundary')
ax.set_xlabel('Pre-synaptic Neuron')
ax.set_ylabel('Post-synaptic Neuron')
ax.set_title('Recurrent Weights (Dale\'s Law Enforced)')
ax.legend()
plt.colorbar(im, ax=ax, label='Weight')

# Plot 2: Per-column max and min (verify each presynaptic neuron's weights share one sign)
ax = axes[0, 1]
col_max = weight_actual.max(axis=0)
col_min = weight_actual.min(axis=0)
neuron_indices = np.arange(hidden_size)
colors = ['blue' if e else 'red' for e in exc_mask]
ax.scatter(neuron_indices, col_max, c=colors, alpha=0.6, s=20, label='Max weight')
ax.scatter(neuron_indices, col_min, c=colors, alpha=0.6, s=20, label='Min weight')
ax.axhline(y=0, color='black', linestyle='-', linewidth=1)
ax.axvline(x=np.sum(exc_mask)-0.5, color='yellow', linestyle='--', linewidth=2)
ax.set_xlabel('Neuron Index')
ax.set_ylabel('Weight Value')
ax.set_title('Weight Range per Neuron (Blue=E, Red=I)')
ax.grid(True, alpha=0.3)

# Plot 3: Spike raster
ax = axes[1, 0]
# Show first trial (detach: spikes come from a graph that requires grad)
spike_raster = spikes[0].detach().numpy()
spike_times, spike_neurons = np.where(spike_raster)
colors_raster = ['blue' if exc_mask[n] else 'red' for n in spike_neurons]
ax.scatter(spike_times, spike_neurons, s=1, c=colors_raster, alpha=0.5)
ax.axhline(y=np.sum(exc_mask)-0.5, color='yellow', linestyle='--', linewidth=2, label='E/I boundary')
ax.set_xlabel('Time Step')
ax.set_ylabel('Neuron Index')
ax.set_title('Spike Raster (Blue=E, Red=I)')
ax.legend()

# Plot 4: Firing rates E vs I
ax = axes[1, 1]
firing_rates = spikes.mean(dim=(0, 1)).detach().numpy()
exc_rates = firing_rates[exc_mask]
inh_rates = firing_rates[inh_mask]
ax.hist(exc_rates, bins=20, alpha=0.6, label=f'Excitatory (n={len(exc_rates)})', color='blue')
ax.hist(inh_rates, bins=20, alpha=0.6, label=f'Inhibitory (n={len(inh_rates)})', color='red')
ax.set_xlabel('Firing Rate')
ax.set_ylabel('Count')
ax.set_title('Firing Rate Distribution: E vs I')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- Weight matrix: E neurons (left) have positive weights, I neurons (right) have negative")
print("- Column statistics: All weights from each neuron have same sign")
print("- Spike raster: Both E and I neurons are active")
print("- Firing rates: E and I populations may have different activity levels")

## Summary and Next Steps

### What We Learned

1. **Spiking Neuron Models**: LIF, Izhikevich, and their properties
   - Membrane dynamics and threshold crossing
   - Different neuron types (regular, fast spiking, bursting)
   - Trade-off between realism and computational efficiency

2. **Surrogate Gradients**: Training spiking networks via backprop
   - Forward pass: Binary spikes (realistic)
   - Backward pass: Smooth gradients (trainable)
   - Enables end-to-end learning in spiking networks

3. **Dale's Law**: Biological E/I constraint
   - Each neuron is exclusively excitatory or inhibitory
   - Enforced via parametrization (sign × magnitude)
   - Improves biological realism and interpretability

### Key Takeaways

- **Spikes matter**: Temporal precision enables new computations
- **Biological constraints help**: Dale's law improves interpretability
- **Surrogate gradients work**: Enables training despite discontinuities
- **E/I balance is crucial**: Proper balance prevents runaway activity

### Applications

1. **Neuroscience**: Models that match brain recordings
2. **Neuromorphic Computing**: Efficient hardware implementations
3. **Temporal Processing**: Tasks requiring precise timing
4. **Biological Plausibility**: More interpretable models

### Next Steps

1. **Notebook 09**: Information theory and energy landscapes
2. **Notebook 10**: Advanced topics (meta-dynamics, topology, counterfactuals)

### Further Reading

- Gerstner & Kistler (2002): *Spiking Neuron Models*
- Neftci et al. (2019): *Surrogate gradient learning in spiking neural networks*
- Izhikevich (2003): *Simple model of spiking neurons*
- Zenke & Ganguli (2018): *SuperSpike: Supervised learning in multilayer spiking neural networks*