In [1]:
import numpy as np
from tqdm import tqdm
from scipy.linalg import expm
import matplotlib.pyplot as plt
from PIL import Image
import io
import umap

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Parameters
# Simulation dimensions
T_rec = 40   # recording days
S = 5        # sessions per day
K = 200      # timesteps recorded per session
N = 50       # neurons on the ring
rank = 3     # NOTE(review): not referenced in the cells shown here

# Integration settings
dt = 0.05    # Euler step; kept small for stable attractor dynamics
tau = 1.0    # neuron time constant (dx = (...) * dt / tau)
n_sub = 4    # Euler sub-steps per recorded timestep, so the attractor can respond between recordings

def relu(x):
    """Rectified linear unit: elementwise max(0, x) for scalars or arrays."""
    floor = 0
    rectified = np.maximum(floor, x)
    return rectified

def make_rotation_generator(N, rng, scale=0.1):
    """Draw a random skew-symmetric matrix to use as a rotation generator.

    The result G satisfies G.T == -G, so ``scipy.linalg.expm(G * angle)`` is
    an orthogonal matrix for any real ``angle``.

    Parameters
    ----------
    N : int
        Matrix size (the number of neurons at the call site).
    rng : numpy.random.Generator
        Source of the Gaussian entries; consumes one (N, N) standard-normal draw.
    scale : float, optional
        Multiplier applied to the skew-symmetric part. The default of 0.1 is
        the value that was previously hard-coded.

    Returns
    -------
    ndarray, shape (N, N)
        ``(A - A.T) / 2 * scale`` where ``A`` has i.i.d. standard-normal entries.
    """
    A = rng.standard_normal((N, N))
    # The antisymmetric part of any square matrix is skew-symmetric by construction.
    return (A - A.T) / 2 * scale

def initialize_simulation(seed=None):
    """Build the fixed ingredients of one ring-attractor simulation.

    Parameters
    ----------
    seed : int, numpy.random.SeedSequence, or None, optional
        Passed to ``np.random.default_rng``. ``None`` (the default) uses fresh
        OS entropy, so each call yields a different rotation generator.
        Pass an int for a reproducible run.

    Returns
    -------
    rng : numpy.random.Generator
        The generator that drew ``rot_generator``. ``simulate_ring_rnn``
        reuses it for its own random draws.
    theta_pref : ndarray, shape (N,)
        Preferred angles, evenly spaced on [0, 2*pi).
    W0 : ndarray, shape (N, N)
        Base recurrent weight matrix.
    rot_generator : ndarray, shape (N, N)
        Skew-symmetric matrix from ``make_rotation_generator``.
    """
    rng = np.random.default_rng(seed)
    theta_pref = np.linspace(0, 2 * np.pi, N, endpoint=False)
    cos_diff = np.cos(theta_pref[:, None] - theta_pref[None, :])

    # Strong recurrent weights: the bump is self-sustaining.
    # For evenly spaced angles, the cosine kernel has eigenvalue N/2 on the
    # cos/sin (bump) modes and 0 on the uniform mode. Subtracting 0.3 adds
    # -0.3*N on the uniform mode. After scaling by 3.5/N:
    #   bump modes ~ +1.75 (amplifying), uniform mode ~ -1.05 (suppressing).
    W0 = 3.5 * (cos_diff - 0.3) / N

    rot_generator = make_rotation_generator(N, rng)
    return rng, theta_pref, W0, rot_generator

def simulate_ring_rnn(noise_schedule):
    """
    Simulate a recurrent ring-attractor network across T_rec days.

    Design choices:
      1. The bump self-sustains via recurrence; there is no per-step
         position template.
      2. Day-to-day drift is applied in WEIGHT space: W(t) = R(t) W0 R(t)^T,
         where R(t) = expm(angle(t) * G) is orthogonal (G is skew-symmetric).
      3. Position is steered only by a weak velocity "nudge" input
         (prefactor 0.12, versus 2.0 for the initial cue).
      4. There is no per-step normalization; rates are only clipped to [0, 10].
      5. Each recorded timestep consists of n_sub Euler sub-steps, so the
         attractor dynamics can play out between recordings.

    Parameters
    ----------
    noise_schedule : array-like, length >= T_rec
        Per-day standard deviation of the Gaussian input noise. Entry t is
        used on day t.

    Returns
    -------
    x_sim : ndarray, shape (T_rec, S, K, N)
        Recorded rates for every day, session, timestep and neuron.

    Notes
    -----
    - initialize_simulation() is called without a seed, so repeated calls
      produce different rotation generators, start positions and noise.
    - NOTE(review): R(t) is a dense orthogonal matrix, not a permutation, so it
      does not commute with the elementwise ReLU. Conjugating W0 by R(t)
      preserves W0's eigenvalues, but it does not map W0's bump states exactly
      onto stable states of W(t). Whether a bump survives the rotation is
      empirical; worth checking.
    """
    rng, theta_pref, W0, rot_gen = initialize_simulation()
    x_sim = np.zeros((T_rec, S, K, N))

    for t in tqdm(range(T_rec), desc="Simulating", leave=False):
        # Drift in WEIGHT SPACE: W(t) = R(t) W0 R(t)^T, with the angle growing
        # linearly from 0 toward 2*pi over the T_rec days.
        # Intended to rotate the neuron-to-place mapping (see NOTE in docstring).
        angle = t / T_rec * 2 * np.pi
        R_t = expm(rot_gen * angle)
        W_t = R_t @ W0 @ R_t.T

        noise_level = noise_schedule[t]

        for s in range(S):
            start_pos = rng.uniform(0, 2*np.pi)
            # Target advances 4*pi over K steps, i.e. two laps per session.
            velocity = 4 * np.pi / K

            # Initial cue expressed in the rotated frame (R_t @ cue); the ReLU
            # removes negative components introduced by the rotation.
            cue_orig = 2.0 * np.exp(3 * np.cos(theta_pref - start_pos))
            x = relu(R_t @ cue_orig)

            # Settle: 20 noise-free Euler steps with no external input, letting
            # the bump relax toward the attractor before recording starts.
            for _ in range(20):
                dx = (-x + relu(W_t @ x)) * dt / tau
                x = x + dx
                x = np.clip(x, 0, 10)  # keep rates in [0, 10]

            for k in range(K):
                target_pos = start_pos + velocity * k

                # Weak velocity nudge centred on the target, in the rotated frame.
                nudge_orig = 0.12 * np.exp(np.cos(theta_pref - target_pos))
                v_nudge = R_t @ nudge_orig

                # n_sub Euler sub-steps per recorded timestep. Noise is added to
                # the ReLU input with std noise_level (not scaled by sqrt(dt)).
                for _ in range(n_sub):
                    noise = noise_level * rng.standard_normal(N)
                    dx = (-x + relu(W_t @ x + v_nudge + noise)) * dt / tau
                    x = x + dx
                    x = np.clip(x, 0, 10)

                x_sim[t, s, k] = x

    return x_sim

# Report the configured simulation size.
print(
    f"RNN Simulation ready: {T_rec} days, {S} sessions, {K} timesteps, {N} neurons"
)

RNN Simulation ready: 40 days, 5 sessions, 200 timesteps, 50 neurons


In [3]:


# Noise schedules for each condition: one noise level per day (length T_rec).
# Expert: constant 0.03 every day.
expert_noise = np.full(T_rec, 0.03)

# Learning: linear ramp 0.15 -> 0.03 over days 0-14, then constant 0.03.
learning_noise = np.concatenate([
    np.linspace(0.15, 0.03, 15),
    np.full(T_rec - 15, 0.03),
])

# Switch: the same ramp on days 0-14, 0.03 on days 15-19, a second
# 0.15 -> 0.03 ramp on days 20-34, then 0.03 from day 35 on.
switch_noise = np.concatenate([
    np.linspace(0.15, 0.03, 15),
    np.full(5, 0.03),
    np.linspace(0.15, 0.03, 15),
    np.full(T_rec - 35, 0.03),
])

# One independent simulation per condition.
print("Simulating Expert Mouse CTX A/B + C/D (RNN)...")
x_expert = simulate_ring_rnn(expert_noise)

print("Simulating Learning Mouse CTX A/B (RNN)...")
x_learning = simulate_ring_rnn(learning_noise)

print("Simulating Learning Mouse CTX A/B -> C/D (RNN)...")
x_switch = simulate_ring_rnn(switch_noise)

# Behavioral performance: linear in the day's noise level (0.03 -> 1.0,
# 0.15 -> 0.0), plus Gaussian jitter (std 0.12), clipped to [0.1, 1].
# Conditions are drawn in order expert, learning, switch from one generator.
rng_perf = np.random.default_rng()
perf_expert, perf_learning, perf_switch = (
    np.clip(1 - (schedule - 0.03) / 0.12 + 0.12 * rng_perf.standard_normal(T_rec), 0.1, 1)
    for schedule in (expert_noise, learning_noise, switch_noise)
)

# UMAP embedding: fit ONE 3-D embedding on all three conditions together.
print("Computing UMAP embeddings...")
all_data = np.concatenate(
    [x_expert.reshape(-1, N), x_learning.reshape(-1, N), x_switch.reshape(-1, N)],
    axis=0,
)

# Clip each neuron to mean +/- 3 std (statistics pooled over all conditions).
x_mean = all_data.mean(axis=0)
x_std = all_data.std(axis=0) + 1e-8
all_clipped = np.clip(all_data, x_mean - 3 * x_std, x_mean + 3 * x_std)

reducer = umap.UMAP(n_components=3, n_neighbors=50, min_dist=0.3, n_jobs=-1)
all_embed = reducer.fit_transform(all_clipped)

# Undo the stacking: consecutive n_pts-row chunks are expert, learning, switch,
# each reshaped to (T_rec, S, K, 3).
n_pts = T_rec * S * K
embed_expert, embed_learning, embed_switch = (
    all_embed[i * n_pts:(i + 1) * n_pts].reshape(T_rec, S, K, 3) for i in range(3)
)

# For each condition and UMAP axis, clip to that condition's 1st-99th
# percentile. The assignment writes into the embed_* arrays in place.
for emb in (embed_expert, embed_learning, embed_switch):
    for dim in range(3):
        p1, p99 = np.percentile(emb[..., dim], [1, 99])
        emb[..., dim] = np.clip(emb[..., dim], p1, p99)

def get_axis_limits(emb, margin_frac=0.1):
    """Per-dimension extent of an embedding, used to fix plot axis limits.

    Parameters
    ----------
    emb : array-like, shape (..., D), ndim >= 2
        Embedded points. The last axis indexes embedding dimensions and every
        leading axis is reduced over, e.g. (T_rec, S, K, 3) or (n_points, 3).
    margin_frac : float, optional
        Padding as a fraction of each dimension's range (default 0.1).

    Returns
    -------
    em_min, em_max, margin : ndarray, each shape (D,)
        Per-dimension minimum, maximum, and ``margin_frac * (em_max - em_min)``.

    Raises
    ------
    ValueError
        If ``emb`` has fewer than 2 dimensions (there is no point axis to reduce).
    """
    emb = np.asarray(emb)
    if emb.ndim < 2:
        raise ValueError(f"expected an array of shape (..., D), got shape {emb.shape}")
    # Reduce over every axis except the last (embedding-dimension) axis.
    # For 4-D input this is (0, 1, 2), exactly as before.
    reduce_axes = tuple(range(emb.ndim - 1))
    em_min = emb.min(axis=reduce_axes)
    em_max = emb.max(axis=reduce_axes)
    margin = margin_frac * (em_max - em_min)
    return em_min, em_max, margin

# Axis limits come from each condition's own (clipped) embedding.
limits_expert, limits_learning, limits_switch = (
    get_axis_limits(e) for e in (embed_expert, embed_learning, embed_switch)
)

print("UMAP complete, generating GIF...")

# Column order (left to right) shared by all per-condition lists below.
titles = ["Expert Mouse CTX A/B + C/D", "Learning Mouse CTX A/B", "Learning Mouse CTX A/B → C/D"]
perfs = [perf_expert, perf_learning, perf_switch]
embeds = [embed_expert, embed_learning, embed_switch]
limits = [limits_expert, limits_learning, limits_switch]

# Render one frame per day. The top row shows performance so far; the bottom
# row shows that day's points in the shared UMAP space (one column per condition).
pil_frames = []
fig = plt.figure(figsize=(15, 8))

for day in tqdm(range(T_rec)):
    fig.clf()
    gs = fig.add_gridspec(2, 3, height_ratios=[1, 2.5], hspace=0.4)

    # Top row: behavioral performance up to and including the current day.
    for col, (title, perf) in enumerate(zip(titles, perfs)):
        ax = fig.add_subplot(gs[0, col])
        ax.plot(range(day + 1), perf[:day + 1], color='darkred', linewidth=2)
        ax.scatter(day, perf[day], c='darkred', s=80, zorder=5)
        ax.set_xlim(-0.5, T_rec - 0.5)
        ax.set_ylim(0, 1.1)
        ax.set_xlabel('Days')
        ax.set_ylabel('Beh. Perf.')
        ax.set_title(title, fontsize=11)
        ax.axhline(1.0, color='gray', linestyle='--', alpha=0.5)

    fig.text(0.5, 0.6, r'Activity Space $\Phi(W(t))$', fontsize=13,
             ha='center', va='center', fontweight='bold')

    # Bottom row: all S*K points recorded on this day. Limits are fixed per
    # condition so the axes do not jump between frames.
    for col, (emb, (em_min, em_max, margin)) in enumerate(zip(embeds, limits)):
        ax = fig.add_subplot(gs[1, col], projection='3d')
        day_points = emb[day].reshape(-1, 3)
        ax.scatter(day_points[:, 0], day_points[:, 1], day_points[:, 2],
                   c="darkred", alpha=0.4, s=2)
        ax.set_xlim(em_min[0] - margin[0], em_max[0] + margin[0])
        ax.set_ylim(em_min[1] - margin[1], em_max[1] + margin[1])
        ax.set_zlim(em_min[2] - margin[2], em_max[2] + margin[2])
        ax.set_xlabel('UMAP1')
        ax.set_ylabel('UMAP2')
        ax.set_zlabel('UMAP3')
        ax.set_title(f'Day {day}', fontsize=10)

    # Rasterize the frame in memory. copy() forces PIL to load the pixels
    # before the buffer is closed; `with` closes it even if savefig raises.
    # NOTE(review): bbox_inches='tight' may make frame sizes differ slightly
    # between days; drop it if the GIF jitters.
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        pil_frames.append(Image.open(buf).copy())

# Close this specific figure rather than whichever figure happens to be current.
plt.close(fig)

# Write every frame into one animated GIF: 200 ms per frame; loop=0 repeats forever.
pil_frames[0].save(
    'ring_attractor_rnn.gif',
    append_images=pil_frames[1:],
    save_all=True,
    loop=0,
    duration=200,
)

print(f"Saved: ring_attractor_rnn.gif ({len(pil_frames)} frames)")

Simulating Expert Mouse CTX A/B + C/D (RNN)...


                                                           

Simulating Learning Mouse CTX A/B (RNN)...


                                                           

Simulating Learning Mouse CTX A/B -> C/D (RNN)...


                                                           

Computing UMAP embeddings...
UMAP complete, generating GIF...


100%|██████████| 40/40 [00:14<00:00,  2.81it/s]


Saved: ring_attractor_rnn.gif (40 frames)
