In [None]:
import torch, numpy as np, matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint
from sklearn.metrics import r2_score
from scipy.linalg import subspace_angles


In [None]:
# Notebook-wide plotting default and compute device (GPU when PyTorch sees one).
plt.rcParams.update({'figure.figsize': (6, 4)})
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

In [None]:
class SlowODE(nn.Module):
    """Right-hand side dz/dt = f(z, t) of the latent slow dynamics.

    A small MLP receives the current latent state ``z`` (one row per batch
    element) with the scalar time ``t`` appended as an extra input column.
    """

    def __init__(self, r):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(r+1, 64), nn.Tanh(), nn.Linear(64, r))

    def forward(self, t, z):
        """Evaluate f(z, t).

        Args:
            t: current time; a Python number or a single-element tensor
                (odeint passes a tensor).
            z: latent state of shape [batch, r].

        Returns:
            Tensor of shape [batch, r].
        """
        # Build the time column directly on z's device/dtype. Unlike t.item(),
        # this does not force a device->host sync on every ODE evaluation, and
        # it also accepts Python ints.
        tvec = torch.as_tensor(t, dtype=z.dtype, device=z.device).expand(z.shape[0], 1)
        return self.net(torch.cat([z, tvec], 1))

class PINN(nn.Module):
    """Drift-only model of recurrent weights: W(t) = W0 + U diag(z(t)) V^T.

    The latent ``z(t)`` is integrated from a learnable initial state ``z0``
    under a learned ODE (``SlowODE``). ``W0``, ``U``, ``V`` and the bias ``b``
    are copied from ``data`` and kept fixed. The input weights ``B`` and the
    input read-out ``R`` are initialised from ``data`` and trained.
    """

    def __init__(self, data, device="cpu"):
        super().__init__()
        r = data["rank"]
        self.K = data["K"]  # number of time steps per trajectory
        # Fixed tensors are non-persistent buffers: model.to(...) moves them
        # together with the parameters, while state_dict keys stay unchanged.
        self.register_buffer("W0", data["W0"].to(device), persistent=False)
        self.register_buffer("U", data["U"].to(device), persistent=False)
        self.register_buffer("V", data["V"].to(device), persistent=False)
        self.z0 = nn.Parameter(torch.zeros(r, device=device))
        self.slow = SlowODE(r)
        self.B = nn.Parameter(data["B"].clone().to(device))
        self.R = nn.Parameter(data["R"].clone().to(device))
        self.register_buffer("b_fixed", data["b"].to(device), persistent=False)

    def weights(self, T):
        """Return the stacked weights W(t) for integer days t = 0..T-1.

        Assumes ``U`` and ``V`` are [N, r]; the result is then [T, N, N].
        """
        times = torch.arange(T, dtype=torch.float32, device=self.W0.device)
        # z0 gets a batch dim of 1 for the ODE RHS; squeeze it back out -> [T, r].
        z = odeint(self.slow, self.z0.unsqueeze(0), times).squeeze(1)
        return self.W0 + self.U @ torch.diag_embed(z) @ self.V.T

    def rnn_cell(self, xk, vk, W):
        """One RNN step: tanh(W x_k + B v_k + b) for a scalar input ``vk``."""
        return torch.tanh(W @ xk + (self.B * vk).squeeze(-1) + self.b_fixed)

    def forward(self, x0, v_seq, t_idx, Ws):
        """Free-running rollout loss for one trajectory.

        Args:
            x0: observed trajectory indexed ``x0[k]`` for k = 0..K-1. Despite
                the name, this is the whole trajectory, not only the start state.
            v_seq: inputs; ``v_seq[k]`` drives the step k -> k+1.
            t_idx: day index selecting ``W = Ws[t_idx]``.
            Ws: stacked weights, e.g. the output of ``weights``.

        Returns:
            0-dim tensor: the squared reconstruction error (summed over steps
            and units) plus 0.1 times the squared error of decoding each input
            from the predicted state via ``R``.
        """
        W = Ws[t_idx]
        # Tensor accumulators keep the return a tensor even when K <= 1.
        rec = x0.new_zeros(())
        dec = x0.new_zeros(())
        x_prev = x0[0]
        for k in range(self.K - 1):
            x_pred = self.rnn_cell(x_prev, v_seq[k], W)
            rec = rec + (x_pred - x0[k + 1]).pow(2).sum()
            # R @ x_pred has shape [1]; .sum() keeps the loss a 0-dim scalar.
            dec = dec + (self.R @ x_pred - v_seq[k]).pow(2).sum()
            x_prev = x_pred
        return rec + 0.1 * dec

In [None]:
class SplitDS(Dataset):
    """Flat dataset over (day, session) pairs for the selected ``days``.

    Item ``idx`` maps to day ``days[idx // S]`` and session ``idx % S``, where
    ``S = x.shape[1]``. Each item is ``(x[t, s], v[t, s], t)``.
    """

    def __init__(self, x, v, days):
        self.x, self.v, self.days = x, v, days
        self.S = x.shape[1]  # sessions per day

    def __len__(self):
        return self.S * len(self.days)

    def __getitem__(self, idx):
        day_pos, s = divmod(idx, self.S)
        t = self.days[day_pos]
        return self.x[t, s], self.v[t, s], t

def train_destinode(data, train_days=[0, 1, 2], epochs=500, lr=1e-3, batch_size=4):
    """Fit a fresh PINN on the (day, session) trajectories of ``train_days``.

    Trains with Adam on shuffled mini-batches. The loss of each batch is the
    mean per-sample PINN loss. The mean batch loss is printed every 100 epochs.
    Uses the notebook-level ``device``.

    Returns:
        The trained ``PINN`` model.
    """
    dataset = SplitDS(data["x"], data["v"], train_days)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = PINN(data, device=device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        epoch_loss = 0
        for x_batch, v_batch, day_batch in loader:
            x_batch = x_batch.to(device)
            v_batch = v_batch.to(device)
            # W(t) depends on trainable parameters, so re-integrate every step.
            Ws = model.weights(data["T_rec"])
            per_sample = [
                model(x_i, v_i, t_i, Ws)
                for x_i, v_i, t_i in zip(x_batch, v_batch, day_batch)
            ]
            loss = torch.stack(per_sample).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        if epoch % 100 == 0:
            mean_loss = epoch_loss / len(loader)
            print(f"  ep {epoch}: {mean_loss:.4f}")

    return model

In [None]:
# ============================================================
# Two-Context Synthetic Data Generator
# ============================================================
# Timeline:  Days 0-3  → Expert A (drift only)
#            Day  4    → Learning jump A→B (first day of Expert B)
#            Days 5-9  → Expert B (drift only; B's drift starts at day 4)
# ============================================================

@torch.no_grad()
def generate_two_context(
    drift_scale=1.0,   # controls |V|
    learn_scale=1.0,   # controls |L|
    T_rec=10, S=5, K=40, N=50, rank=3,
    seed=0, device="cpu"
):
    """Simulate an RNN whose low-rank weight component drifts within two contexts.

    Latent trajectory z(t) in R^rank, one row per recording day:
      * days 0-3: Expert A, a smooth sinusoidal drift scaled by ``drift_scale``;
      * day 4: a jump of norm ``learn_scale`` from z_A[-1] in a random direction;
      * days 4..T_rec-1: Expert B, a second smooth drift starting at the jumped point.
    Weights are W(t) = W0 + U diag(z(t)) V^T. For every day t and session s, the
    RNN x_{k+1} = tanh(W(t) x_k + B v_k + b) is driven by Gaussian inputs v.

    Note: calls ``torch.manual_seed`` (with ``seed`` and ``seed + 999``), so it
    resets the global torch RNG.

    Returns:
        dict with the backbone (W0, U, V, B, b, R, sizes), the ground truth
        (z_true [T_rec, rank], z_A, z_B, Ws_true [T_rec, N, N],
        z_counterfactual_day4), the observables x [T_rec, S, K, N] and
        v [T_rec, S, K], and the two scales.

    Raises:
        ValueError: if ``T_rec`` leaves no Expert-B day, or ``rank`` is not in [1, N].
    """
    n_days_A = 4  # Expert A occupies days 0..n_days_A-1; Expert B the rest
    if T_rec <= n_days_A:
        raise ValueError(
            f"T_rec must be > {n_days_A} to include at least one Expert-B day (got {T_rec})"
        )
    if not 1 <= rank <= N:
        raise ValueError(f"rank must be in [1, N={N}] (got {rank})")

    torch.manual_seed(seed)

    # ---- shared structural backbone (same across contexts) ----
    W0 = 0.0005 * torch.randn(N, N, device=device)
    Uf, _, Vtf = torch.linalg.svd(W0)
    U, V = Uf[:, :rank], Vtf[:rank, :].T
    B  = 0.05 * torch.randn(N, 1, device=device)
    b  = 0.1  * torch.randn(N, device=device)
    R  = 0.05 * torch.randn(1, N, device=device)

    # ---- context A latents: smooth drift, days 0..n_days_A-1 ----
    t_A = torch.linspace(0, torch.pi, n_days_A, device=device)
    z_A = drift_scale * torch.stack(
        [torch.sin(t_A + torch.pi * i / rank) + 0.1 * t_A
         for i in range(rank)], dim=1
    )  # [n_days_A, rank]

    # ---- counterfactual: linear extrapolation of A drift one day past z_A[-1] ----
    z_counterfactual_day4 = z_A[-1] + (z_A[-1] - z_A[-2])

    # ---- learning jump: random unit direction, magnitude = learn_scale ----
    torch.manual_seed(seed + 999)
    jump_dir = torch.randn(rank, device=device)
    jump_dir = jump_dir / jump_dir.norm()
    z_B_start = z_A[-1] + learn_scale * jump_dir   # where B begins

    # ---- context B latents: smooth drift, days n_days_A..T_rec-1 ----
    t_B = torch.linspace(0, torch.pi, T_rec - n_days_A, device=device)
    z_B_intrinsic = drift_scale * torch.stack(
        [torch.sin(t_B + torch.pi * i / rank + 2.0) + 0.1 * t_B
         for i in range(rank)], dim=1
    )  # [T_rec - n_days_A, rank]
    z_B = z_B_start + (z_B_intrinsic - z_B_intrinsic[0])  # shift origin to z_B_start

    # ---- full z_true: [T_rec, rank] ----
    z_true = torch.cat([z_A, z_B], dim=0)

    # ---- build W(t) = W0 + U diag(z_t) V^T, simulate RNN ----
    Ws_true = W0 + torch.einsum("nr,tr,mr->tnm", U, z_true, V)

    x = torch.zeros(T_rec, S, K, N, device=device)
    v = torch.zeros(T_rec, S, K, device=device)
    for t in range(T_rec):
        for s in range(S):
            v_ts = 0.5 * torch.randn(K, device=device)
            x_prev = 0.1 * torch.randn(N, device=device)
            x[t, s, 0] = x_prev
            for k in range(K - 1):
                x_prev = torch.tanh(Ws_true[t] @ x_prev
                                    + (B * v_ts[k]).squeeze() + b)
                x[t, s, k + 1] = x_prev
            v[t, s] = v_ts

    return dict(
        # backbone
        W0=W0, U=U, V=V, B=B, b=b, R=R,
        N=N, rank=rank, T_rec=T_rec, S=S, K=K,
        # ground truth
        z_true=z_true, z_A=z_A, z_B=z_B, Ws_true=Ws_true,
        z_counterfactual_day4=z_counterfactual_day4,
        # observables
        x=x, v=v,
        # params for bookkeeping
        drift_scale=drift_scale, learn_scale=learn_scale,
    )

In [None]:
# ============================================================
# 4 Scenarios (each generates one full dataset)
# ============================================================
scenarios = {
    "high_V_high_L": {"drift_scale": 2.0, "learn_scale": 2.0},
    "low_V_high_L":  {"drift_scale": 0.3, "learn_scale": 2.0},
    "high_V_low_L":  {"drift_scale": 2.0, "learn_scale": 0.3},
    "low_V_low_L":   {"drift_scale": 0.3, "learn_scale": 0.3},
}

# One dataset per scenario. The shared seed keeps the backbone, jump direction
# and input noise identical across scenarios, so only the two scales differ.
datasets = {
    name: generate_two_context(**params, seed=42)
    for name, params in scenarios.items()
}

In [None]:
# ============================================================
# Extract V (drift) and L (learning) per neuron
# ============================================================
def extract_V_and_L(data, model, drift_last_day=3, learn_day=4):
    """Per-neuron drift (V) and learning (L) magnitudes from a drift-only model.

    V_i = mean over days t = 1..drift_last_day of ||W_pred[t, i, :] - W_pred[t-1, i, :]||,
          i.e. the mean daily drift magnitude of row i (defaults cover the
          Expert-A days 0-3).
    L_i = ||W_true[learn_day, i, :] - W_pred[learn_day, i, :]||, the part of the
          day-``learn_day`` weights that the drift-only model does not predict.

    Args:
        data: dict with "T_rec" and "Ws_true" (e.g. from ``generate_two_context``).
        model: object whose ``weights(T)`` returns stacked weights indexed [t, i, j].
        drift_last_day: last day included in the drift average (must be >= 1).
        learn_day: day on which the learning residual is measured.

    Returns:
        (V, L): numpy arrays with one entry per neuron (row of W).
    """
    # Only values are needed here; skip building the ODE autograd graph.
    with torch.no_grad():
        Ws_pred = model.weights(data["T_rec"]).detach()

    # Take each day's row norm *before* averaging. Averaging the signed
    # differences first telescopes to (W[last] - W[0]) / last: intermediate
    # days would be ignored and oscillating drift would cancel out.
    dW_daily = torch.diff(Ws_pred[: drift_last_day + 1], dim=0)  # one slice per day step
    V_per_neuron = dW_daily.norm(dim=2).mean(0)  # [N]

    # L: actual weights on learn_day minus what the drift-only model predicts.
    W_observed = data["Ws_true"][learn_day].to(Ws_pred.device)
    W_predicted = Ws_pred[learn_day]
    L_per_neuron = (W_observed - W_predicted).norm(dim=1)  # [N]

    return V_per_neuron.cpu().numpy(), L_per_neuron.cpu().numpy()

In [None]:
# ============================================================
# 2×2 panel: one scatter per scenario
# ============================================================
fig, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True, sharey=True)
axes = axes.flatten()

for idx, (name, data) in enumerate(datasets.items()):
    # Fit the drift-only model on Expert-A days 0-2, then score every neuron.
    model = train_destinode(data, train_days=[0, 1, 2])
    Vi, Li = extract_V_and_L(data, model)

    ax = axes[idx]
    ax.scatter(Vi, Li, c="k", s=20, alpha=0.6)

    # The median of each axis splits the panel into four quadrants.
    v_med = np.median(Vi)
    l_med = np.median(Li)
    ax.axvline(v_med, color="gray", ls="--", lw=0.8)
    ax.axhline(l_med, color="gray", ls="--", lw=0.8)

    ax.set_title(name.replace("_", " "), fontsize=11)
    ax.set_xlabel("|V| drift")
    ax.set_ylabel("|L| learning")

    v_high, v_low = Vi > v_med, Vi <= v_med
    l_high, l_low = Li > l_med, Li <= l_med
    n_hh = np.sum(v_high & l_high)
    n_hl = np.sum(v_high & l_low)
    n_lh = np.sum(v_low & l_high)
    n_ll = np.sum(v_low & l_low)

    # Print each quadrant's neuron count in its corner (axes coordinates).
    for count, x_pos, y_pos, h_align, v_align in (
        (n_hh, 0.95, 0.95, "right", "top"),
        (n_lh, 0.05, 0.95, "left", "top"),
        (n_hl, 0.95, 0.05, "right", "bottom"),
        (n_ll, 0.05, 0.05, "left", "bottom"),
    ):
        ax.text(x_pos, y_pos, f"{count}", transform=ax.transAxes, ha=h_align, va=v_align)

plt.suptitle("Drift vs Learning per neuron", fontsize=13)
plt.tight_layout()
plt.show()