# CNS2025 — Homework 6: FitzHugh–Nagumo (FHN) model

**Due:** 2025-10-15 23:59  

This notebook implements all three exercises:
1. Compute and plot the firing rate vs. parameter **a** for `a ∈ [0.55, 0.70]` with step `0.01`, using the forward Euler method with `Δt = 0.1`.
2. Repeat with forward Euler, `Δt = 1`.
3. Repeat with the classical 4th-order **Runge–Kutta (RK4)** method, `Δt = 1`.

**Model**
\begin{align}
\dot v &= v - \frac{v^3}{3} - w + R I_{ext}\\
\tau\, \dot w &= v + a - b w
\end{align}

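Default parameters from the assignment: `mr=5` (membrane resistance $R$), `ie=0` (external current $I_{ext}$), `fb=0.5` (parameter $b$), `tw=10` (slow time scale $\tau$).  
Warm-up time = `300` time units (simulated, then discarded as transient); simulation time = `10000` time units (used to compute the firing rate).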

**Spike definition:** count **up-crossings** of `v` through the threshold `v_th = 1.0`, i.e. integration steps where `v` goes from `≤ v_th` to `> v_th`.


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ---------- Parameters (you can change ie to test) ----------
# Symbol mapping used by the model code: mr -> R (membrane resistance),
# ie -> I_ext (external current), fb -> b, tw -> tau (slow time scale).
mr, ie, fb, tw = 5.0, 0.0, 0.5, 10.0
v_th = 1.0  # threshold whose upward crossings of v are counted as spikes
# Seeded generator; not referenced by the simulation code visible in this notebook.
rng = np.random.default_rng(0)

def fhn_slope(y, a, b=fb, tau=tw, R=mr, Iext=ie):
    """Evaluate the FitzHugh–Nagumo vector field at state ``y``.

    Parameters
    ----------
    y : sequence of two floats
        Current state ``[v, w]``.
    a : float
        Offset in the recovery equation ``tau * dw/dt = v + a - b*w``.
    b, tau, R, Iext : float
        Remaining model constants. The defaults are the module-level values
        (``fb``, ``tw``, ``mr``, ``ie``) captured when this function is defined.

    Returns
    -------
    numpy.ndarray
        ``[dv/dt, dw/dt]``.
    """
    v, w = y
    cubic = v**3 / 3.0
    drive = R * Iext
    return np.array([v - cubic - w + drive, (v + a - b * w) / tau])

def euler_step(y, slope, dt):
    """Advance ``y`` by one forward-Euler step of size ``dt``.

    ``slope`` is a callable returning the time derivative of a state.
    """
    derivative = slope(y)
    return y + dt * derivative

def runge_kutta_step(y, slope, dt):
    """Advance ``y`` by one classical 4th-order Runge–Kutta step.

    ``slope`` is a callable returning the time derivative of a state and
    ``dt`` is the step size.
    """
    half = 0.5 * dt
    s1 = slope(y)
    s2 = slope(y + half * s1)
    s3 = slope(y + half * s2)
    s4 = slope(y + dt * s3)
    weighted = s1 + 2 * s2 + 2 * s3 + s4
    return y + (dt / 6.0) * weighted

def simulate_fhn(tt, dt, a, y0=None, itg='euler', warmup=300.0, v_threshold=v_th):
    """Simulate the FHN model and measure its firing rate.

    The model is first integrated for ``warmup`` time units without recording
    (to discard transients), then for ``tt`` time units while recording
    ``v`` and ``w`` and counting upward crossings of ``v_threshold``.

    Parameters
    ----------
    tt : float
        Recorded simulation time used for the firing rate; must cover at
        least one step.
    dt : float
        Integration step size; must be positive.
    a : float
        FHN parameter ``a`` forwarded to ``fhn_slope``.
    y0 : sequence of two floats, optional
        State ``[v, w]`` at the start of warm-up; defaults to ``[0, 0]``.
    itg : {'euler', 'rk4'}
        Integration scheme.
    warmup : float
        Duration that is simulated but not recorded.
    v_threshold : float
        Spike threshold: a spike is a step where ``v`` goes from
        ``<= v_threshold`` to ``> v_threshold``.

    Returns
    -------
    ts : numpy.ndarray
        Time (measured from the end of warm-up) of each recorded sample.
    vs, ws : numpy.ndarray
        Recorded ``v`` and ``w``.
    firing_rate : float
        Spike count divided by the recorded duration (spikes per time unit).

    Raises
    ------
    ValueError
        If ``itg`` is unknown, ``dt`` is not positive, or ``tt`` is shorter
        than one step.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")
    # round() rather than int(): int() truncates, so a quotient such as
    # 0.3/0.1 == 2.9999999999999996 would silently lose a step.
    n_warm = int(round(warmup / dt))
    n = int(round(tt / dt))
    if n < 1:
        raise ValueError("tt must cover at least one integration step")

    y = np.zeros(2) if y0 is None else np.array(y0, dtype=float)

    # Build the vector field for this `a` once instead of a new closure per step.
    def slope(z):
        return fhn_slope(z, a)

    integrators = {'euler': euler_step, 'rk4': runge_kutta_step}
    if itg not in integrators:
        raise ValueError("itg must be 'euler' or 'rk4'")
    integrate = integrators[itg]

    # Warm-up: advance the state without recording.
    for _ in range(n_warm):
        y = integrate(y, slope, dt)

    vs = np.empty(n, dtype=float)
    ws = np.empty(n, dtype=float)
    # Sample i is stored after i+1 steps past warm-up, so its time is (i+1)*dt.
    ts = dt * np.arange(1, n + 1)

    prev_v = y[0]
    spikes = 0
    for i in range(n):
        y = integrate(y, slope, dt)
        v, w = y
        vs[i] = v
        ws[i] = w
        # Up-crossing detection. If the integration diverges, v becomes
        # inf/NaN; NaN comparisons are False, so no further spikes are counted.
        if prev_v <= v_threshold < v:
            spikes += 1
        prev_v = v

    # Normalise by the duration actually simulated (n steps of size dt).
    firing_rate = spikes / (n * dt)
    return ts, vs, ws, firing_rate

def sweep_a(a_values, dt, itg='euler', warmup=300.0, tt=10000.0, v_threshold=v_th):
    """Return the firing rate for every ``a`` in ``a_values``.

    Each rate comes from an independent ``simulate_fhn`` run with the given
    step ``dt``, integrator ``itg``, warm-up, recorded time ``tt`` and spike
    threshold. The result is a NumPy array aligned with ``a_values``.
    """
    return np.array([
        simulate_fhn(tt, dt, a_val, itg=itg, warmup=warmup, v_threshold=v_threshold)[3]
        for a_val in a_values
    ])

def estimate_ath(a_values, rates, eps=1e-6):
    """Estimate the value ``a_th`` of ``a`` at which periodic firing stops.

    The estimate is based on the last grid point whose rate exceeds ``eps``:

    * if no point fires, return ``a_values[0]``;
    * if the last grid point still fires, return ``a_values[-1]``;
    * otherwise return the midpoint between that last firing ``a`` and the
      next (quiet) grid point.

    Parameters
    ----------
    a_values : array_like
        Sweep grid; presumably sorted ascending (TODO confirm for new callers).
    rates : array_like
        Firing rate at each entry of ``a_values`` (same length).
    eps : float
        Rates ``<= eps`` are treated as quiet.

    Returns
    -------
    float
        The estimated threshold.

    Raises
    ------
    ValueError
        If the inputs are empty or differ in length.
    """
    # asarray lets plain lists work; a list would fail on `rates > eps`.
    a_values = np.asarray(a_values)
    rates = np.asarray(rates, dtype=float)
    if len(a_values) == 0 or len(a_values) != len(rates):
        raise ValueError("a_values and rates must be non-empty and of equal length")

    firing = np.flatnonzero(rates > eps)
    if firing.size == 0:
        return a_values[0]
    last_fire_idx = firing[-1]
    if last_fire_idx == len(a_values) - 1:
        return a_values[-1]
    # The next grid point is quiet by construction (it is after the last firing one).
    a1 = a_values[last_fire_idx]
    a2 = a_values[last_fire_idx + 1]
    return 0.5 * (a1 + a2)


In [None]:
# ===== Exercise 1: Euler, dt=0.1 =====
# Grid a = 0.55, 0.56, ..., 0.70 (16 points). An explicit-count linspace,
# rounded to the grid, avoids the float-step round-off of np.arange.
a_values = np.round(np.linspace(0.55, 0.70, 16), 2)
rates_dt01 = sweep_a(a_values, dt=0.1, itg='euler', warmup=300.0, tt=10000.0)
ath_dt01 = estimate_ath(a_values, rates_dt01)
print(f"[Exercise 1] Estimated a_th (Euler, dt=0.1): {ath_dt01:.3f}")
plt.figure()
plt.plot(a_values, rates_dt01, marker='o')
plt.axvline(ath_dt01, linestyle='--', color='gray', label=f'a_th ≈ {ath_dt01:.3f}')
plt.xlabel('a')
plt.ylabel('Firing rate (spikes / time)')
plt.title('Exercise 1: Euler dt=0.1')
plt.legend()
plt.show()


In [None]:
# ===== Exercise 2: Euler, dt=1 =====
# Grid a = 0.55, 0.56, ..., 0.70 (16 points). An explicit-count linspace,
# rounded to the grid, avoids the float-step round-off of np.arange.
a_values = np.round(np.linspace(0.55, 0.70, 16), 2)
rates_dt1 = sweep_a(a_values, dt=1.0, itg='euler', warmup=300.0, tt=10000.0)
ath_dt1 = estimate_ath(a_values, rates_dt1)
print(f"[Exercise 2] Estimated a_th (Euler, dt=1): {ath_dt1:.3f}")
plt.figure()
plt.plot(a_values, rates_dt1, marker='o')
plt.axvline(ath_dt1, linestyle='--', color='gray', label=f'a_th ≈ {ath_dt1:.3f}')
plt.xlabel('a')
plt.ylabel('Firing rate (spikes / time)')
plt.title('Exercise 2: Euler dt=1')
plt.legend()
plt.show()


In [None]:
# ===== Exercise 3: RK4, dt=1 =====
# Grid a = 0.55, 0.56, ..., 0.70 (16 points). An explicit-count linspace,
# rounded to the grid, avoids the float-step round-off of np.arange.
a_values = np.round(np.linspace(0.55, 0.70, 16), 2)
rates_rk4 = sweep_a(a_values, dt=1.0, itg='rk4', warmup=300.0, tt=10000.0)
ath_rk4 = estimate_ath(a_values, rates_rk4)
print(f"[Exercise 3] Estimated a_th (RK4, dt=1): {ath_rk4:.3f}")
plt.figure()
plt.plot(a_values, rates_rk4, marker='o')
plt.axvline(ath_rk4, linestyle='--', color='gray', label=f'a_th ≈ {ath_rk4:.3f}')
plt.xlabel('a')
plt.ylabel('Firing rate (spikes / time)')
plt.title('Exercise 3: RK4 dt=1')
plt.legend()
plt.show()


## Notes
- The spike threshold `v_th=1.0` follows the lecture's convention for counting events.
- You can shorten `tt` while testing. Use the full `tt=10000` for the final plots.
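- `estimate_ath` returns the midpoint between the last firing `a` and the next (quiet) `a` in the sweep grid; if no grid point fires it returns the first `a`, and if the last grid point still fires it returns the last `a`.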
- To reproduce lecture-style traces or phase portraits, record `(v,w)` and plot trajectories for selected `a` values.
