# BG–Thalamus with Explicit GPi Suppression (Interactive, fixed)

This notebook extends the basal ganglia (BG) action-selection model by making the inhibitory **GPi→Thalamus** pathway explicit.
You can explore how **global suppression of GPi** (e.g., a transient modulation such as focused ultrasound, FUS) changes thalamic output:

- Moderate suppression → selection flips between channels.
- Strong suppression → global disinhibition (both thalamic channels elevate; loss of selectivity).

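**Usage:** Run all cells, then adjust the sliders under the last code cell. Changing a slider re-runs the simulation and redraws the plots (via `ipywidgets.interact`).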


In [None]:
# Colab-friendly dependency install (safe to run locally as well)
try:
    import google.colab  # type: ignore
    %pip -q install nengo matplotlib ipywidgets
except Exception:
    pass

In [None]:
import os
import numpy as np
import nengo
import matplotlib.pyplot as plt
from ipywidgets import FloatSlider, IntSlider, interact

# --- Disable Nengo decoder cache (avoids cache index warnings in hosted envs) ---
nengo.rc.set("decoder_cache", "enabled", "False")
# NOTE(review): this runs after `import nengo`, and the decoder cache is
# disabled on the line above, so this variable likely has no effect here.
# Confirm that nengo reads NENGO_CACHE_DIR at all before relying on it.
# setdefault keeps any value already present in the environment.
os.environ.setdefault("NENGO_CACHE_DIR", "/tmp/nengo_cache")

# Seed NumPy's global RNG for the session; run_sim() also re-seeds it on
# every call with its own `seed` argument.
np.random.seed(0)

In [None]:
def _gpi_scale_at(t, gpi_scale_base, fus_start, fus_dur, fus_depth):
    """Return the global GPi output scale at simulation time ``t``.

    Inside the closed window ``[fus_start, fus_start + fus_dur]`` the scale is
    ``gpi_scale_base * (1 - fus_depth)``; everywhere else it is
    ``gpi_scale_base``.
    """
    in_pulse = fus_start <= t <= fus_start + fus_dur
    return gpi_scale_base * (1.0 - fus_depth if in_pulse else 1.0)


def _scale_gpi(t, x):
    """Node function: multiply both BG output channels by the scale signal.

    ``x`` is the 3-vector ``[bg0, bg1, s]``; returns ``[s*bg0, s*bg1]``.
    """
    bg0, bg1, s = x
    return [s * bg0, s * bg1]


def _build_model(A1, A2, gpi_scale_fn):
    """Build the cortex -> BG -> scaled GPi -> thalamus network.

    Returns ``(model, probes)`` where ``probes`` maps the names ``cortex``,
    ``bg``, ``gpi``, ``th`` and ``scale`` to their nengo probes.
    """
    model = nengo.Network(label="BG–Thalamus with GPi suppression")
    with model:
        cortex = nengo.Node(lambda t: [A1, A2], label="Cortex [A1,A2]")
        bg = nengo.networks.BasalGanglia(dimensions=2, label="BasalGanglia")
        th = nengo.networks.Thalamus(dimensions=2, label="Thalamus")
        scale = nengo.Node(gpi_scale_fn, label="gpi_scale(t)")

        # Cortical drives feed the BG network.
        nengo.Connection(cortex, bg.input, synapse=0.02)

        # Scaled GPi node: input [bg0, bg1, s] -> output [s*bg0, s*bg1].
        scaled_gpi = nengo.Node(_scale_gpi, size_in=3, label="Scaled GPi")
        nengo.Connection(bg.output[0], scaled_gpi[0], synapse=0.02)
        nengo.Connection(bg.output[1], scaled_gpi[1], synapse=0.02)
        # Unfiltered: the scale is a piecewise-constant control signal.
        nengo.Connection(scale, scaled_gpi[2], synapse=0.0)

        # The scaled GPi output is what reaches the thalamus.
        nengo.Connection(scaled_gpi, th.input, synapse=0.02)

        # Dict display is evaluated in order, so probes are created in the
        # same sequence as before.
        probes = {
            "cortex": nengo.Probe(cortex),
            "bg": nengo.Probe(bg.output, synapse=0.05),
            "gpi": nengo.Probe(scaled_gpi, synapse=0.05),
            "th": nengo.Probe(th.output, synapse=0.05),
            "scale": nengo.Probe(scale),
        }
    return model, probes


def _plot_results(t, ctx, bg_out, gpi_out, th_out, s_vals, window):
    """Draw one figure per signal, shading the suppression window on each.

    ``window`` is ``(start, end)`` already clipped to the simulated interval,
    or ``None`` when the suppression window does not overlap the run.
    """
    def shade(**kwargs):
        if window is not None:
            plt.axvspan(window[0], window[1], alpha=0.15, **kwargs)

    def plot_channels(data):
        # Label each action channel so the two traces can be told apart.
        for idx, name in enumerate(("ch1", "ch2")):
            plt.plot(t, data[:, idx], label=name)

    # Each chart gets its own figure so every plot renders separately inline.
    plt.figure(figsize=(7, 3))
    plt.plot(t, ctx[:, 0], label="A1")
    plt.plot(t, ctx[:, 1], label="A2")
    shade(label="FUS window")
    plt.legend()
    plt.ylabel("Cortex")
    plt.xlabel("Time (s)")
    plt.show()

    plt.figure(figsize=(7, 3))
    plot_channels(bg_out)
    shade()
    plt.legend()
    plt.ylabel("BG→GPi (inhib level)")
    plt.xlabel("Time (s)")
    plt.show()

    plt.figure(figsize=(7, 3))
    plot_channels(gpi_out)
    plt.plot(t, s_vals, linestyle=":", label="gpi_scale(t)")
    shade()
    plt.legend()
    plt.ylabel("Scaled GPi to Th")
    plt.xlabel("Time (s)")
    plt.show()

    plt.figure(figsize=(7, 3))
    plot_channels(th_out)
    shade()
    plt.legend()
    plt.ylabel("Thalamus out")
    plt.xlabel("Time (s)")
    plt.show()


def run_sim(A1=0.85, A2=0.65, T_total=1.6,
            gpi_scale_base=1.0, fus_start=0.6, fus_dur=0.5, fus_depth=0.6,
            seed=0):
    """Simulate BG–Thalamus with explicit GPi scaling (suppression) and plot it.

    Parameters
    ----------
    A1, A2 : float
        Cortical drives for action channels 1 and 2.
    T_total : float
        Total simulation time (s).
    gpi_scale_base : float
        Baseline global scale for GPi output (1.0 = normal; 0.0 = silenced).
    fus_start, fus_dur : float
        Start time and duration (s) of the suppression window.
    fus_depth : float
        Additional suppression during the window (0..1). Effective scale =
        base*(1 - depth) during the window.
    seed : int
        Random seed for the simulator (coerced with ``int``). Also re-seeds
        NumPy's global RNG.

    Returns
    -------
    dict
        Keys ``t`` (time vector), ``cortex``, ``bg``, ``gpi``, ``th`` (probed
        signals, one column per channel) and ``scale`` (the GPi scale,
        flattened to 1-D).
    """
    seed = int(seed)
    np.random.seed(seed)

    def gpi_scale_fn(t):
        return _gpi_scale_at(t, gpi_scale_base, fus_start, fus_dur, fus_depth)

    model, probes = _build_model(A1, A2, gpi_scale_fn)
    with nengo.Simulator(model, seed=seed) as sim:
        sim.run(T_total)

    t = sim.trange()
    ctx = sim.data[probes["cortex"]]
    bg_out = sim.data[probes["bg"]]
    gpi_out = sim.data[probes["gpi"]]
    th_out = sim.data[probes["th"]]
    s_vals = sim.data[probes["scale"]].reshape(-1)

    # Clip the shaded window to the simulated interval so a window placed
    # past T_total does not stretch the x-axis over time with no data.
    lo = max(fus_start, 0.0)
    hi = min(fus_start + fus_dur, T_total)
    window = (lo, hi) if hi > lo else None

    _plot_results(t, ctx, bg_out, gpi_out, th_out, s_vals, window)

    return dict(t=t, cortex=ctx, bg=bg_out, gpi=gpi_out, th=th_out, scale=s_vals)


In [None]:
def _show_sim(**kwargs):
    """Run ``run_sim`` for its plots only and return ``None``.

    ``interact`` displays a function's non-``None`` return value under the
    widgets; returning ``None`` keeps run_sim's data dict (full time-series
    arrays) from being printed on every update.
    """
    run_sim(**kwargs)


# continuous_update=False: re-simulate only when a slider is released, not on
# every intermediate value while dragging (each run_sim call rebuilds and
# re-runs the whole network).
interact(
    _show_sim,
    A1=FloatSlider(min=0.0, max=1.2, step=0.05, value=0.85,
                   continuous_update=False),
    A2=FloatSlider(min=0.0, max=1.2, step=0.05, value=0.65,
                   continuous_update=False),
    T_total=FloatSlider(min=0.8, max=3.0, step=0.1, value=1.6,
                        continuous_update=False),
    gpi_scale_base=FloatSlider(min=0.0, max=1.5, step=0.05, value=1.0,
                               continuous_update=False),
    fus_start=FloatSlider(min=0.0, max=2.0, step=0.05, value=0.6,
                          continuous_update=False),
    fus_dur=FloatSlider(min=0.05, max=1.0, step=0.05, value=0.5,
                        continuous_update=False),
    fus_depth=FloatSlider(min=0.0, max=1.0, step=0.05, value=0.6,
                          continuous_update=False),
    seed=IntSlider(min=0, max=10, step=1, value=0,
                   continuous_update=False),
);

### Notes
- **Interpretation**: BG output approximates the **GPi/SNr inhibitory level**. After it is multiplied by the global scale `gpi_scale(t)`, this signal is what drives the thalamus.
- **Phenomena to look for**: selection reversal during moderate suppression, and **global disinhibition** when suppression is strong/prolonged.
- **Next steps**: connect cortical drives to task variables, add noise/oscillations, and map regimes to behavioural predictions (e.g., shifts in reaction time (RT) or stop-signal reaction time (SSRT)).