[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/trendinafrica/TReND-CaMinA/blob/main/notebooks/Zambia25/15-Tue-DynSys/RNN_gain_dynamics_demo.ipynb)


In [14]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display
import ipywidgets as widgets

# ---- Controls: pick the recurrent gain g, launch a run, show the result ----
# Slider for the gain that scales the recurrent weight matrix.
g_slider = widgets.FloatSlider(
    value=0.8,
    min=0.1,
    max=5.0,
    step=0.1,
    description='Gain g:',
)
# Pressing this button triggers a new simulation (handler registered below).
run_button = widgets.Button(description='Run Simulation')
# Area into which the handler renders the animation.
output = widgets.Output()

display(g_slider, run_button, output)

def run_simulation(g, *, num_ex=400, tmax=1000, seed=11):
    """Simulate a random rate RNN with recurrent gain ``g`` and animate it.

    The network follows the leaky rate equation

        tau * dx/dt = -x + tanh(W_in @ u(t) + W @ x),

    integrated with forward Euler.  ``W`` is a sparse Gaussian matrix whose
    non-zero entries have variance ``1 / (num_ex * p_connect)`` before being
    multiplied by ``g``, so each entry has overall variance ``g**2 / num_ex``
    and the bulk of the eigenvalues lies roughly in a disc of radius ``g``.
    Input channel 0 receives a pulse of amplitude 10 during steps 100-149;
    channel 1 stays silent.

    Parameters
    ----------
    g : float
        Gain multiplying the recurrent weight matrix.
    num_ex : int, optional
        Number of recurrent units (default 400).  The activity heatmap shows
        the first ``floor(sqrt(num_ex))**2`` units laid out as a square grid.
    tmax : int, optional
        Number of Euler steps to simulate (default 1000); must be >= 1.
    seed : int, optional
        Seed for the weight draws (default 11).  A private ``RandomState`` is
        used, so NumPy's global random state is not modified; the draws are
        identical to those produced after ``np.random.seed(seed)``.

    Returns
    -------
    IPython.display.HTML
        A JavaScript animation showing the weight histogram, a snapshot of
        all activities, and offset traces of the first few units.

    Raises
    ------
    ValueError
        If ``num_ex`` or ``tmax`` is smaller than 1.
    """
    if num_ex < 1:
        raise ValueError(f"num_ex must be >= 1, got {num_ex}")
    if tmax < 1:
        raise ValueError(f"tmax must be >= 1, got {tmax}")

    # Fixed model / display parameters.
    num_in = 2
    p_connect = 0.1
    tau = 10
    dt = 1.0
    plot_neurons = min(10, num_ex)
    frame_step = 20  # only animate every Nth simulated step
    pulse_times = (np.arange(100, 150, 1), [])  # per input channel
    pulse_amplitude = 10.0

    win_ex, wex_ex = _draw_network(g, num_in, num_ex, p_connect, seed)

    # tanh'(0) = 1, so the Jacobian of the dynamics at x = 0 is (-I + W) / tau.
    # The quiescent state therefore loses stability once the largest *real part*
    # of W's eigenvalues exceeds 1 -- this is why the real part (not |lambda|,
    # the spectral radius) is reported.
    max_real_eig = np.max(np.real(np.linalg.eigvals(wex_ex)))

    inputs = _input_schedule(tmax, num_in, pulse_times, pulse_amplitude)
    ex_all = _simulate_rates(win_ex, wex_ex, inputs, dt, tau)
    return _animate_activity(ex_all, wex_ex, max_real_eig, plot_neurons, frame_step)


def _draw_network(g, num_in, num_ex, p_connect, seed):
    """Return ``(win_ex, wex_ex)``: input weights and gain-scaled sparse recurrent weights."""
    rng = np.random.RandomState(seed)
    # Draw order matters: it reproduces the sequence of the original global-seed code.
    win_ex = rng.randn(num_ex, num_in)
    mask = (rng.rand(num_ex, num_ex) < p_connect).astype(float)
    wex_ex = rng.randn(num_ex, num_ex) * np.sqrt(1 / (num_ex * p_connect))
    wex_ex *= mask
    wex_ex *= g
    return win_ex, wex_ex


def _input_schedule(tmax, num_in, pulse_times, amplitude):
    """Return a ``(tmax, num_in)`` array holding the input vector for every step."""
    schedule = np.zeros((tmax, num_in))
    for channel, times in enumerate(pulse_times):
        steps = np.asarray([int(t) for t in times if 0 <= t < tmax], dtype=int)
        schedule[steps, channel] = amplitude
    return schedule


def _simulate_rates(win_ex, wex_ex, inputs, dt, tau):
    """Forward-Euler integrate the rate equation; return states as ``(num_ex, tmax)``."""
    num_ex = wex_ex.shape[0]
    tmax = inputs.shape[0]
    leak = dt / tau
    ex = np.zeros(num_ex)
    ex_all = np.zeros((num_ex, tmax))
    for t in range(tmax):
        drive = win_ex @ inputs[t] + wex_ex @ ex
        ex += leak * (-ex + np.tanh(drive))
        ex_all[:, t] = ex  # column t is the state *after* step t
    return ex_all


def _animate_activity(ex_all, wex_ex, max_real_eig, plot_neurons, frame_step):
    """Build the weight-histogram / heatmap / traces figure and return it as an HTML animation."""
    num_ex, tmax = ex_all.shape
    side = int(np.sqrt(num_ex))  # heatmap shows a side x side grid of units
    n_grid = side * side

    fig = plt.figure(figsize=(10, 6), facecolor='black')

    # Histogram of non-zero recurrent weights.  The fixed range keeps plots
    # comparable across g; weights outside [-1, 1] (frequent for large g) are
    # not shown.
    ax_hist = fig.add_axes([0.05, 0.55, 0.3, 0.35])
    ax_hist.hist(wex_ex[wex_ex != 0], bins=50, color='white', range=(-1, 1))
    # Title reports the largest real part of the eigenvalues.
    ax_hist.set_title(f"max eigenvalue = {max_real_eig:.2f}", color='white')
    ax_hist.tick_params(colors='white')
    ax_hist.set_facecolor('black')

    # Heatmap of the current activity of the first n_grid units.
    ax_heat = fig.add_axes([0.4, 0.55, 0.55, 0.35])
    heat_img = ax_heat.imshow(np.zeros((side, side)), cmap='gray', vmin=-1, vmax=1,
                              interpolation='none')
    ax_heat.set_title(f'Neuron Activity Snapshot ({side}x{side})', color='white')
    ax_heat.tick_params(colors='white')
    ax_heat.set_facecolor('black')

    # Traces of the first plot_neurons units, shifted down by 1 per unit.
    ax_trace = fig.add_axes([0.05, 0.05, 0.9, 0.4])
    colors = ['cyan', 'magenta', 'yellow', 'lime', 'orange', 'violet',
              'turquoise', 'salmon', 'white', 'deepskyblue']
    trace_offsets = ex_all[:plot_neurons] - np.arange(plot_neurons)[:, None]
    x_vals = np.arange(tmax)
    trace_lines = [ax_trace.plot([], [], lw=2, color=color)[0]
                   for color in colors[:plot_neurons]]
    ax_trace.set_xlim(0, tmax)  # show the full time span from the start
    ax_trace.set_ylim(-plot_neurons - 1, 1)
    ax_trace.set_title('Neuron Activity Traces', color='white')
    ax_trace.tick_params(colors='white')
    ax_trace.set_facecolor('black')

    # Every frame_step-th step, plus the last step so the animation ends on
    # the final state.
    frame_times = list(range(0, tmax, frame_step))
    if frame_times[-1] != tmax - 1:
        frame_times.append(tmax - 1)

    def update(t):
        heat_img.set_data(ex_all[:n_grid, t].reshape(side, side))
        for line, trace in zip(trace_lines, trace_offsets):
            line.set_data(x_vals[:t + 1], trace[:t + 1])
        return trace_lines + [heat_img]

    ani = animation.FuncAnimation(
        fig, update, frames=frame_times, blit=True, interval=50, repeat=False
    )

    # Close the figure so the notebook does not also show a static copy of it.
    plt.close(fig)
    return HTML(ani.to_jshtml())

# ---- Run button callback ----
def on_button_clicked(b):
    """Re-run the simulation with the slider's current gain and display it.

    ``b`` is the button that fired the click (ipywidgets passes the button
    instance to ``on_click`` handlers).  It is disabled while the slow
    simulation + animation render runs, which helps prevent extra clicks from
    queueing duplicate runs.  It is re-enabled afterwards, even if the run
    fails.  A short progress message is shown in the output area meanwhile.
    """
    g = g_slider.value
    b.disabled = True
    try:
        with output:
            # wait=True: keep old content until the replacement arrives (no flicker).
            output.clear_output(wait=True)
            print(f"Running simulation with g = {g:.1f} ...")
            result = run_simulation(g)
            output.clear_output(wait=True)
            display(result)
    finally:
        b.disabled = False


run_button.on_click(on_button_clicked)

FloatSlider(value=0.8, description='Gain g:', max=5.0, min=0.1)

Button(description='Run Simulation', style=ButtonStyle())

Output()