In [None]:
import itertools
import numpy as np
import torch
import torch.jit
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import sklearn.decomposition
import sklearn.manifold
import sklearn.neighbors
import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
import polars as pl
import scipy.stats
from mpl_toolkits.axes_grid1 import ImageGrid
from scipy.ndimage import gaussian_filter1d

%matplotlib inline

In [None]:
# Directory holding the simulation outputs read below (x_list_0.npy and the *.pt files).
# Original network location:
# sim_dir = "/groups/saalfeld/home/allierc/Py/NeuralGraph/graphs_data/fly/fly_N9_54_1"
# use local copy - faster
# NOTE(review): absolute, machine-specific path; it must be adjusted to run elsewhere.
sim_dir = "/mnt/localdata/fly_N9_54_1"

In [None]:
# Raw simulation array; the cells below index it as (time step, cell, feature column).
x = np.load(f"{sim_dir}/x_list_0.npy")

In [None]:
# The second axis of x is the number of cells; the other two sizes are not needed here.
_, num_cells, _ = x.shape
# Quick sanity check of the array layout and precision.
print(x.shape, x.dtype)

In [None]:
# Integer type id of every cell, read from feature column 6 at the first time step.
# NOTE(review): presumably the type id does not change over time — confirm.
neuron_types = x[0, :, 6].astype(np.int32)

# Names of the neuron types. The code below treats the position of a name in
# this list as the integer type id of that neuron type.
neuron_type_name = [
    "Am", "C2", "C3", "CT1(Lo1)", "CT1(M10)",
    "L1", "L2", "L3", "L4", "L5",
    "Lawf1", "Lawf2",
    "Mi1", "Mi10", "Mi11", "Mi12", "Mi13", "Mi14", "Mi15", "Mi2", "Mi3", "Mi4", "Mi9",
    "R1", "R2", "R3", "R4", "R5", "R6", "R7", "R8",
    "T1", "T2", "T2a", "T3",
    "T4a", "T4b", "T4c", "T4d",
    "T5a", "T5b", "T5c", "T5d",
    "Tm1", "Tm16", "Tm2", "Tm20", "Tm28", "Tm3", "Tm30", "Tm4",
    "Tm5Y", "Tm5a", "Tm5b", "Tm5c", "Tm9",
    "TmY10", "TmY13", "TmY14", "TmY15", "TmY18", "TmY3", "TmY4", "TmY5a", "TmY9",
]  # fmt: skip
# Reverse lookup: type name -> position in neuron_type_name.
neuron_type_index = dict(zip(neuron_type_name, range(len(neuron_type_name))))


def compute_ixs_per_type(neuron_types):
    """Group neuron indices by their integer type id.

    Parameters
    ----------
    neuron_types : np.ndarray
        1D array holding the type id of each neuron. The distinct ids must be
        exactly 0, 1, ..., K-1.

    Returns
    -------
    list of np.ndarray
        K index arrays; entry ``k`` holds the positions in ``neuron_types``
        whose value is ``k``.

    Raises
    ------
    AssertionError
        If the distinct type ids are not exactly ``0..K-1``.
    """
    sort_order = np.argsort(neuron_types)
    type_ids, first_pos = np.unique(neuron_types[sort_order], return_index=True)
    n_types = len(type_ids)
    assert (type_ids == np.arange(n_types)).all(), "breaks assumptions"
    # Boundaries of each type's run inside the sorted order; the final
    # boundary closes the last run.
    bounds = np.zeros(n_types + 1, dtype=np.int64)
    bounds[:n_types] = first_pos
    bounds[n_types] = len(neuron_types)
    return [sort_order[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]


# neuron_ixs_by_type[k] holds the indices of all cells whose type id is k.
neuron_ixs_by_type = compute_ixs_per_type(neuron_types)

In [None]:
# Number of leading simulation steps skipped before the first observation.
BURNIN_OFFSET = 100
# Stride, in simulation steps, between consecutive observations.
OBS_TIME_STEPS = 20

# Subsampled copy of feature column 7 for every cell.
# NOTE(review): presumably the calcium signal, judging by the variable name — confirm.
obs_ca = x[BURNIN_OFFSET::OBS_TIME_STEPS, :, 7].copy()
# Split boundaries, expressed in observation (subsampled) indices.
train_start = 0
validation_start = 3000
# NOTE(review): only used as the end of the validation window; no test split is built here.
test_start = 3500

train_mat = obs_ca[train_start:validation_start]
val_mat = obs_ca[validation_start:test_start]

In [None]:
# Smooth feature column 4 along TIME before subsampling it to the observation
# rate. gaussian_filter1d filters along the last axis by default, which for a
# (time, cell) array would blur across cells instead of acting as a temporal
# low-pass ahead of the [::OBS_TIME_STEPS] decimation, so axis=0 is explicit.
stimulus = gaussian_filter1d(
    x[BURNIN_OFFSET:, :, 4], sigma=OBS_TIME_STEPS / 2, axis=0
)[::OBS_TIME_STEPS]

In [None]:
# Split the stimulus with the same boundaries as the activity matrices.
train_stimulus = stimulus[train_start:validation_start]
val_stimulus = stimulus[validation_start:test_start]

## Aside: spectral embedding of the graph (reveals its hexagonal structure)

In [None]:
# Load some of the simulation's network parameters.
# map_location="cpu" lets files that were saved from GPU tensors load on a
# machine without CUDA; the cells below only read shapes or call .cpu().
# NOTE(review): torch.load unpickles the file; only load files from a trusted source.

wt = torch.load(f"{sim_dir}/weights.pt", map_location="cpu")
edge_index = torch.load(f"{sim_dir}/edge_index.pt", map_location="cpu")
voltage_rest = torch.load(f"{sim_dir}/V_i_rest.pt", map_location="cpu")
taus = torch.load(f"{sim_dir}/taus.pt", map_location="cpu")

In [None]:
from scipy import sparse
from scipy.sparse.linalg import eigsh


def spectral_embed(E, N, edim):
    """
    Spectral embedding of a directed graph into ``edim`` dimensions.

    The graph is made undirected and unweighted, its symmetric normalized
    Laplacian is formed, and the eigenvectors belonging to the 2nd..(edim+1)-th
    smallest eigenvalues are returned as coordinates.

    Parameters
    ----------
    E : np.ndarray
        2×M array of edges, where E[0, i] is the source and E[1, i] is the target.
        Nodes are labeled 0,…,N-1.
    N : int
        Total number of nodes.
    edim : int
        Number of embedding dimensions. ``eigsh`` is asked for ``edim + 1``
        eigenpairs, which must be fewer than ``N``.

    Returns
    -------
    coords : np.ndarray of shape (N, edim)
        Embedding coordinates for each node (rows correspond to nodes), with
        columns ordered by increasing eigenvalue.
    """
    # --- 1. Build sparse adjacency matrix (one entry per listed edge) ---
    src, dst = E
    data = np.ones(len(src), dtype=float)
    adjacency = sparse.coo_matrix((data, (src, dst)), shape=(N, N))

    # --- 2. Symmetrize and binarize (undirected, unweighted graph) ---
    adjacency = ((adjacency + adjacency.T) > 0).astype(float)

    # --- 3. Normalized Laplacian I - D^{-1/2} A D^{-1/2} ---
    deg = np.array(adjacency.sum(axis=1)).flatten()
    # Clamp the degree so isolated nodes do not cause a division by zero.
    d_inv_sqrt = sparse.diags(1.0 / np.sqrt(np.maximum(deg, 1e-12)))
    laplacian = sparse.eye(N) - d_inv_sqrt @ adjacency @ d_inv_sqrt

    # --- 4. Smallest eigenpairs; drop the one with the smallest eigenvalue ---
    vals, vecs = eigsh(laplacian, k=edim + 1, which="SM")
    # Sort explicitly rather than relying on the order eigsh returns.
    order = np.argsort(vals)
    coords = vecs[:, order[1 : edim + 1]]

    return coords


def spectral_embed_2d_weighted(E, W, N):
    """
    2D spectral embedding of a possibly signed, weighted directed graph.

    The weighted adjacency matrix is symmetrized by averaging it with its
    transpose, and a normalized Laplacian is built using degrees computed from
    absolute weights.

    Parameters
    ----------
    E : np.ndarray, shape (2, M)
        Edge endpoints (source, target).
    W : np.ndarray, shape (M,)
        Edge weights (can be positive or negative).
    N : int
        Number of nodes (0..N-1). Must be larger than 3, because three
        eigenpairs are requested from ``eigsh``.

    Returns
    -------
    coords : np.ndarray, shape (N, 2)
        Eigenvectors of the 2nd and 3rd smallest eigenvalues, one row per node.
    """
    src, dst = E

    # --- 1. Build sparse weighted adjacency matrix ---
    adjacency = sparse.coo_matrix((W, (src, dst)), shape=(N, N))
    # Symmetrize by averaging with the transpose (keeps the sign of the weights).
    adjacency = 0.5 * (adjacency + adjacency.T)

    # --- 2. Degrees from absolute weights ---
    deg = np.array(np.abs(adjacency).sum(axis=1)).flatten()
    # Clamp the degree so nodes without edges do not cause a division by zero.
    d_inv_sqrt = sparse.diags(1.0 / np.sqrt(np.maximum(deg, 1e-12)))

    # --- 3. Normalized signed Laplacian ---
    laplacian = sparse.eye(N) - d_inv_sqrt @ adjacency @ d_inv_sqrt

    # --- 4. Smallest eigenpairs; drop the one with the smallest eigenvalue ---
    vals, vecs = eigsh(laplacian, k=3, which="SM")
    # Sort explicitly rather than relying on the order eigsh returns.
    order = np.argsort(vals)
    # NOTE(review): the dropped eigenvector is the "trivial" one for
    # non-negative weights; with negative weights that is not guaranteed —
    # confirm it should still be skipped.
    coords = vecs[:, order[1:3]]

    return coords

In [None]:
# 10-dimensional spectral embedding of the graph; the node count is taken
# from the length of voltage_rest.
eembed = spectral_embed(edge_index.cpu().numpy(), voltage_rest.shape[0], 10)

In [None]:
# Scatter the first two embedding coordinates of every node.
embed_x, embed_y = eembed[:, 0], eembed[:, 1]
plt.scatter(embed_x, embed_y, marker=".", alpha=0.1)

## Find the optimal SVD dimension

In [None]:
# Candidate latent sizes. The SVD is fitted once at the largest size; smaller
# ranks are evaluated by keeping only the leading components.
ndims = np.array([8, 16, 32, 64, 128])

svd = sklearn.decomposition.TruncatedSVD(n_components=ndims.max(), random_state=321)
proj = svd.fit_transform(train_mat)
proj_val = svd.transform(val_mat)

# One reconstruction of the training / validation data per candidate rank.
recon_train = np.zeros((len(ndims),) + train_mat.shape)
recon_val = np.zeros((len(ndims),) + val_mat.shape)
for row, rank in enumerate(ndims):
    basis = svd.components_[:rank, :]
    recon_train[row] = proj[:, :rank] @ basis
    recon_val[row] = proj_val[:, :rank] @ basis

In [None]:
# Reconstruction residuals for every candidate rank; the 2D data matrix
# broadcasts against the leading rank axis.
delta_train = recon_train - train_mat
delta_val = recon_val - val_mat

In [None]:
# Per-neuron R2 of the truncated reconstructions.
# Variances are taken over time (axis 0 of the data, axis 1 of the stacked
# residuals) because the interest is in each neuron's dynamics rather than in
# variation across neuron space.

var_train = train_mat.var(axis=0)
var_val = val_mat.var(axis=0)
var_train_unexpl = delta_train.var(axis=1)
var_val_unexpl = delta_val.var(axis=1)
r2_train = 1 - var_train_unexpl / var_train
r2_val = 1 - var_val_unexpl / var_val

In [None]:
# Fraction of neurons with negative training R2, binned by the variance of
# each neuron's training trace.
den, edges = np.histogram(var_train, bins=15)
# Plot each bin's value at the bin centre rather than at its right edge.
centers = 0.5 * (edges[:-1] + edges[1:])
for i, n in enumerate(ndims):
    num, _ = np.histogram(
        var_train, bins=edges, weights=(r2_train[i] < 0).astype(np.float32)
    )
    # Empty bins become NaN (a gap in the line) instead of triggering a 0/0 warning.
    frac = np.divide(num, den, out=np.full(len(den), np.nan), where=den > 0)
    plt.plot(centers, frac, label=f"L={n}")
plt.legend()
plt.ylabel("Fraction R2 < 0")
plt.xlabel("Variance of neuron time trace")

In [None]:
# Mean training R2 per bin of the variance of each neuron's training trace.
den, edges = np.histogram(var_train, bins=15)
# Plot each bin's value at the bin centre rather than at its right edge.
centers = 0.5 * (edges[:-1] + edges[1:])
for i, n in enumerate(ndims):
    num, _ = np.histogram(var_train, bins=edges, weights=r2_train[i])
    # Empty bins become NaN (a gap in the line) instead of triggering a 0/0 warning.
    mean_r2 = np.divide(num, den, out=np.full(len(den), np.nan), where=den > 0)
    plt.plot(centers, mean_r2, label=f"L={n}")
plt.legend()
plt.ylabel("Mean R2")
plt.xlabel("Variance of neuron time trace")

## Fix latent dimension=256 and proceed

In [None]:
# Latent dimensionality used from here on.
L = 256
# Fixed random_state (same seed as the dimension sweep above) so that the
# randomized SVD solver yields the same basis on every run.
svd = sklearn.decomposition.TruncatedSVD(n_components=L, random_state=321)
svd.fit(train_mat)

# import scipy.sparse.linalg
# U, S, VT = scipy.sparse.linalg.svds(train_mat, k=64)

In [None]:
# Encode the validation data and map the latent codes back to neuron space.
proj = svd.transform(val_mat)
recon = proj @ svd.components_

### Analyze how well the latent space reconstructs the data

In [None]:
# Reconstruction quality across neurons on the validation data: one point
# per time step.

delta = recon - val_mat
# spread of the reconstruction error across neurons at each time point
err_t = np.std(delta, axis=1)
# spread of the signal across neurons at each time point
sigma_t = np.std(val_mat, axis=1)
plt.scatter(sigma_t, err_t, alpha=0.5, marker=".")
plt.xlabel("Sigma across neurons")
plt.ylabel("Reconstruction sigma")

In [None]:
# Error histograms over time for five randomly drawn neurons (visual check
# of whether the reconstruction errors look Gaussian).
neuron_draws = np.random.randint(val_mat.shape[1], size=5)
for neuron in neuron_draws:
    spread = val_mat[:, neuron].std()
    plt.hist(
        delta[:, neuron],
        histtype="step",
        label=f"neuron {neuron} (sigma={spread:.2f})",
    )
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Error histograms across neurons for five randomly drawn time points.
time_draws = np.random.randint(val_mat.shape[0], size=5)
for tick in time_draws:
    spread = val_mat[tick].std()
    plt.hist(
        delta[tick],
        histtype="step",
        label=f"time {tick} (sigma={spread:.2f})",
    )
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# How well does the reconstruction capture each neuron's variation over time?

# error sigma per neuron (validation data)
err_n = delta.std(axis=0)
# signal variation per neuron, taken from the same validation data as the
# error so that the two axes of the scatter plot are comparable
sigma_n = val_mat.std(axis=0)
plt.scatter(sigma_n, err_n, marker=".", alpha=0.2)
plt.xlabel("Sigma across time")
plt.ylabel("Reconstruction sigma")

In [None]:
# Overlay one randomly chosen validation trace (line) with its SVD reconstruction (dots).
plt.figure(figsize=(16, 6))
for ix in np.sort(np.random.randint(val_mat.shape[1], size=1)):
    # Keep the line handle so the reconstruction can reuse the same colour.
    p = plt.plot(val_mat[:, ix], label=f"Neuron {ix} original")
    plt.plot(
        recon[:, ix], c=p[-1].get_color(), ls="", marker=".", label=f"Neuron {ix} recon"
    )
plt.xlabel("time")
plt.legend(bbox_to_anchor=(1, 1))

## Learn the latent-space update (linear)

In [None]:
# Prefer the GPU, but fall back to the CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
@torch.compile(
    fullgraph=True,
    mode="reduce-overhead",
    #    mode="max-autotune"
)
def loss_fn(evolve_mat, train_proj, nmin=1, nmax=2):
    """Multi-step prediction loss of a linear latent-space evolver.

    For every step count ``i`` in ``range(nmin, nmax)`` the latent state at
    row ``t`` is multiplied by ``evolve_mat`` raised to the power ``i`` and
    compared with the latent state at row ``t + i``; the mean squared errors
    of all step counts are summed.

    Parameters
    ----------
    evolve_mat : torch.Tensor, shape (L, L)
        One-step linear evolution matrix, applied as ``state @ evolve_mat``.
    train_proj : torch.Tensor, shape (T, L)
        Latent trajectory, one row per time step.
    nmin, nmax : int
        Half-open range of step counts to score; ``nmin`` must be at least 1.
        The defaults score the one-step prediction only.

    Returns
    -------
    torch.Tensor
        Scalar loss on the device and with the dtype of ``train_proj``.
    """
    # Allocate the accumulator next to the data instead of hard-coding CUDA.
    loss = torch.zeros((), dtype=train_proj.dtype, device=train_proj.device)
    for i in range(nmin, nmax):
        emat = torch.matrix_power(evolve_mat, i)
        residual = torch.linalg.matmul(train_proj[:-i, :], emat) - train_proj[i:]
        loss = loss + torch.pow(residual, 2).mean()
    return loss


# Start from the identity matrix plus uniform noise in [-1/6, 1/6).
noise = (torch.rand((L, L)) - 0.5) / 3.0
init_mat = (torch.eye(L) + noise).to(device)
evolve_mat = torch.nn.Parameter(init_mat)

# Latent trajectory of the training data, used as both input and target.
train_proj = torch.tensor(svd.transform(train_mat), device=device)
optimizer = torch.optim.Adam([evolve_mat], lr=1e-3)

# train_loop(evolve_mat, train_proj, optimizer)
loop = tqdm.trange(10_000, ncols=100)
for step in loop:
    loss = loss_fn(evolve_mat, train_proj)
    # Refresh the progress-bar readout only on every 10th step.
    if step % 10 == 0:
        loop.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Copy the trained evolution matrix to a NumPy array for the rollout and plots below.
learned_mat = evolve_mat.detach().cpu().numpy()

In [None]:
# Show the learned matrix as an image, with the colour scale clipped to [0.1, 1].
plt.imshow(learned_mat, vmax=1, vmin=0.1)

In [None]:
# rng = np.random.default_rng(seed=321)
# t0s = rng.integers(0, train_mat.shape[0]-30, size=5)

# Roll the learned linear evolver forward T steps from training time t0.
T = 30
t0 = 1000
window = slice(t0, t0 + T + 1)
x0 = train_mat[t0 : t0 + 1, :]
p0 = svd.transform(x0)

# results[k] is the latent state after k applications of the learned matrix.
results = [p0]
for _ in range(T):
    results.append(results[-1] @ learned_mat)
# Decode every latent state back to neuron space.
pred_trace = np.stack([latent @ svd.components_ for latent in results], axis=0)
act_trace = train_mat[window]
# Reference: encode and decode each true time point independently.
recon_each = svd.inverse_transform(svd.transform(train_mat[window, :]))

In [None]:
# One example neuron per type, drawn with a fixed seed so the same neurons
# are shown on every run.
rng = np.random.default_rng(seed=123)
picks = [rng.choice(nixs) for nixs in neuron_ixs_by_type]
# Choose the types to display without replacement so no type is plotted twice.
plot_neuron_types = np.sort(rng.choice(neuron_type_name, 10, replace=False))
# ['R1', 'R7', 'C2', 'Mi11', 'Tm1', 'Tm4', 'Tm30']

_, ax = plt.subplots(len(plot_neuron_types), 1, figsize=(8, 12), sharex=True)
tvals = np.arange(t0, t0 + T + 1)

for i, ptype in enumerate(plot_neuron_types):
    nix = picks[neuron_type_index[ptype]]
    true_trace = train_mat[t0 : t0 + T + 1, nix]
    # Keep the handle of the true trace so the prediction reuses its colour
    # (previously the colour came from a variable left over from another cell).
    p = ax[i].plot(tvals, true_trace)
    ax[i].set_ylim(true_trace.min() * 0.8, true_trace.max() * 1.2)
    # time evolve
    ax[i].plot(
        tvals,
        pred_trace[:, 0, nix],
        color=p[-1].get_color(),
        ls="dotted",
        label="learn linear evolver",
    )

    # reconstruct each point using SVD
    # ax[i].plot(tvals, recon_each[:, nix], color=p[-1].get_color(), ls="dashed", marker=".", label="reconstruct each time point")
    ax[i].set_ylabel(ptype)

plt.subplots_adjust(hspace=0)

In [None]:
# Compare, for one randomly chosen neuron, the true trace, the rollout of the
# learned evolver, and the per-time-point SVD reconstruction.
nix = np.random.randint(0, train_mat.shape[1])
print(f"{nix=}")
plt.figure(figsize=(16, 6))


# actual trace
trace = train_mat[t0 : t0 + T + 1, nix]
# keep the handle so the other two curves can reuse its colour
p = plt.plot(tvals, trace, label="actual trace")

# time evolve
plt.plot(
    tvals,
    pred_trace[:, 0, nix],
    color=p[-1].get_color(),
    ls="dotted",
    label="learn linear evolver",
)

# reconstruct each point using SVD
plt.plot(
    tvals,
    recon_each[:, nix],
    color=p[-1].get_color(),
    ls="dashed",
    label="reconstruct each time point",
)

# vertical marker 10 steps after the rollout start
plt.axvline(t0 + 10, color="k", ls="dotted")
plt.xticks(tvals)

plt.xlabel("Time")
plt.ylabel("Ca activity")
plt.legend()
plt.title(f"Neuron trace {nix=}")
# trailing no-op so the cell does not echo the return value of plt.title
pass

## Try a non-linear update

In [None]:
# Prefer the GPU, but fall back to the CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class MLP(torch.jit.ScriptModule):
    """Residual MLP that maps a latent state to the next latent state.

    Parameters
    ----------
    num_latent_dims : int
        Size of the input and output latent vectors.
    num_hidden_units : int
        Width of every hidden ``Linear`` layer.
    num_hidden_layers : int
        Number of ``Linear`` + ``ReLU`` pairs placed before the output layer.
    """

    def __init__(self, num_latent_dims, num_hidden_units, num_hidden_layers):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        input_dims = num_latent_dims
        for _ in range(num_hidden_layers):
            self.layers.append(torch.nn.Linear(input_dims, num_hidden_units))
            self.layers.append(torch.nn.ReLU())
            input_dims = num_hidden_units
        # Size the output layer from the width that actually feeds it, so a
        # model with num_hidden_layers == 0 is still shape-consistent.
        self.output = torch.nn.Linear(input_dims, num_latent_dims)

    @torch.jit.script_method
    def forward(self, x):
        # The first layer changes the width, so it is applied directly; every
        # later entry of self.layers (Linear or ReLU) is added residually.
        y = x
        for i, layer in enumerate(self.layers):
            if i == 0:
                y = layer(y)
            else:
                y = y + layer(y)
        # Residual connection around the whole network.
        return x + self.output(y)

    @torch.jit.script_method
    def loss_fn(self, x):
        # Sum of the mean squared errors of the 1- to 4-step predictions.
        # Row t of `end` after i applications predicts row t + i of x.
        # Starting from the first term keeps the loss on x's device and dtype
        # without hard-coding CUDA.
        end = self.forward(x)
        loss = torch.pow(x[1:] - end[:-1], 2).mean()
        for i in range(2, 5):
            end = self.forward(end)
            loss = loss + torch.pow(x[i:] - end[:-i], 2).mean()
        return loss

In [None]:
# Small residual MLP acting on the L-dimensional latent space: 3 hidden layers of 8 units.
mlp = MLP(num_latent_dims=L, num_hidden_units=8, num_hidden_layers=3).to(device)

In [None]:
# Latent trajectory of the training data; the MLP is fitted to its own
# multi-step predictions of this trajectory.
train_proj = torch.tensor(svd.transform(train_mat), device=device)
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

loop = tqdm.trange(10_000, ncols=100)
for step in loop:
    loss = mlp.loss_fn(train_proj)
    # Refresh the progress-bar readout only on every 100th step.
    if step % 100 == 0:
        loop.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Continue training for another 10,000 steps with the same optimizer and data.
loop = tqdm.trange(10_000, ncols=100)
for t in loop:
    loss = mlp.loss_fn(train_proj)
    # update the progress-bar readout only every 100 steps
    if t % 100 == 0:
        loop.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# rng = np.random.default_rng(seed=321)
# t0s = rng.integers(0, train_mat.shape[0]-30, size=5)

# Roll the trained MLP forward T steps from training time t0.
T = 30
t0 = 1000
window = slice(t0, t0 + T + 1)
x0 = train_mat[t0 : t0 + 1, :]
p0 = svd.transform(x0)

# results[k] is the latent state (NumPy array) after k applications of the MLP.
results = [p0]
for _ in range(T):
    latent_in = torch.tensor(results[-1], device=device)
    results.append(mlp(latent_in).detach().cpu().numpy())
# Decode every latent state back to neuron space.
pred_trace = np.stack([svd.inverse_transform(latent) for latent in results], axis=0)
act_trace = train_mat[window]
# Reference: encode and decode each true time point independently.
recon_each = svd.inverse_transform(svd.transform(train_mat[window, :]))

In [None]:
# One example neuron per type, drawn with a fixed seed so the same neurons
# are shown on every run.
rng = np.random.default_rng(seed=123)
picks = [rng.choice(nixs) for nixs in neuron_ixs_by_type]
# Choose the types to display without replacement so no type is plotted twice.
plot_neuron_types = np.sort(rng.choice(neuron_type_name, 8, replace=False))
# ['R1', 'R7', 'C2', 'Mi11', 'Tm1', 'Tm4', 'Tm30']

_, ax = plt.subplots(len(plot_neuron_types), 1, figsize=(8, 12), sharex=True)
tvals = np.arange(t0, t0 + T + 1)

for i, ptype in enumerate(plot_neuron_types):
    nix = picks[neuron_type_index[ptype]]
    true_trace = train_mat[t0 : t0 + T + 1, nix]
    # Keep the handle of the true trace so the other curves reuse its colour
    # (previously the colour came from a variable left over from another cell).
    p = ax[i].plot(tvals, true_trace)
    ax[i].set_ylim(true_trace.min() * 0.8, true_trace.max() * 1.2)
    # time evolve with the MLP (the label used to say "linear" here)
    ax[i].plot(
        tvals,
        pred_trace[:, 0, nix],
        color=p[-1].get_color(),
        ls="dotted",
        label="learned MLP evolver",
    )

    # reconstruct each point using SVD
    ax[i].plot(
        tvals,
        recon_each[:, nix],
        color=p[-1].get_color(),
        ls="dashed",
        label="reconstruct each time point",
    )
    ax[i].set_ylabel(ptype)

plt.subplots_adjust(hspace=0)

## Study neighborhoods in latent space

In [None]:
# Refit the SVD and keep the latent trajectory of the training data.
L = 256
# Fixed random_state (same seed as the earlier dimension sweep) so that the
# randomized solver yields the same latent space on every run.
svd = sklearn.decomposition.TruncatedSVD(n_components=L, random_state=321)
svd.fit(train_mat)

proj = svd.transform(train_mat)

In [None]:
# Mean Euclidean distance in the original neuron space between observations
# that are `gap` time steps apart (a gap of 0 is 0 by definition).
gaps = np.arange(11)
mean_dist = [0.0] + [
    np.linalg.norm(train_mat[:-gap, :] - train_mat[gap:, :], axis=1).mean()
    for gap in gaps[1:]
]
plt.plot(gaps, mean_dist, marker="o")
plt.xticks(gaps)
plt.grid(True)
plt.xlabel("time step delta")
plt.ylabel("Mean orig space distance")

In [None]:
# Same measurement in the latent space: mean distance between latent states
# that are `gap` time steps apart.
gaps = np.arange(11)
mean_dist = [0.0] + [
    np.linalg.norm(proj[:-gap, :] - proj[gap:, :], axis=1).mean()
    for gap in gaps[1:]
]
plt.plot(gaps, mean_dist, marker="o")
plt.xticks(gaps)
plt.grid(True)
plt.xlabel("time step delta")
plt.ylabel("Mean latent space distance")

In [None]:
# Distribution of the log latent-space distance for several time gaps, each
# overlaid with a fitted normal density (dashed, same colour as its histogram).
for gap in (1, 2, 3, 4, 5, 10):
    step_norm = np.linalg.norm(proj[:-gap] - proj[gap:], axis=1)
    log_dist = np.log(step_norm)
    mu, sig = scipy.stats.norm.fit(log_dist)
    print(f"{mu=:.2f}, {sig=:.2f}")
    xvs = np.linspace(log_dist.min(), log_dist.max(), 100)
    fitted_pdf = scipy.stats.norm.pdf(xvs, mu, sig)
    # plt.hist returns (counts, bin edges, patches); the patches carry the colour.
    p = plt.hist(log_dist, histtype="step", label=f"{gap} steps", density=True)
    plt.plot(xvs, fitted_pdf, color=p[-1][0].get_edgecolor(), ls="dashed")
plt.legend()

In [None]:
# For every latent state, find its 10 nearest latent states, sorted by distance.
# NOTE(review): the query set equals the indexed set, so the first neighbour
# is presumably the point itself — confirm.
tree = sklearn.neighbors.BallTree(proj)
dist, ninds = tree.query(proj, k=10, return_distance=True, sort_results=True)

In [None]:
# Time-index offset of each neighbour relative to the first (nearest) hit.
dt = ninds[:, 1:] - ninds[:, :1]
# Number of latent states; replaces the hard-coded 3000 so the cell keeps
# working when the size of the training split changes.
num_times = proj.shape[0]
# order = np.argsort(dt, axis=1)
for i in np.unique(np.random.randint(0, num_times, size=5)):
    plt.scatter(
        dt[i],
        dist[i][1:],
    )
plt.ylabel("latent space distance")
plt.xlabel("time step separation")
plt.xlim(-num_times, num_times)
plt.ylim(0, None)
plt.grid(True)

## Add in the stimulus that we forgot about

Instead of using the SVD, directly optimize a linear encoder, evolver, and decoder.

In [None]:
# Latent space for neural activity

import itertools

L = 256

activity_tensor = torch.tensor(train_mat, device=device)
# NOTE(review): only the first 1736 stimulus columns are used — presumably
# the cells that receive the stimulus; confirm where this number comes from.
stim_tensor = torch.tensor(train_stimulus[:, :1736], device=device)

# The encoder sees activity and stimulus side by side; the decoder outputs activity only.
train_tensor = torch.cat((activity_tensor, stim_tensor), dim=-1)
num_inputs = train_tensor.shape[1]
num_outputs = activity_tensor.shape[1]

encoder = torch.nn.Linear(num_inputs, L).to(device)
decoder_activity = torch.nn.Linear(L, num_outputs).to(device)
# One-step linear update in the latent space.
evolver = torch.nn.Linear(L, L).to(device)


@torch.compile(fullgraph=True, mode="reduce-overhead")
def compute_loss(train_tensor, activity_tensor):
    """Reconstruction and one-step prediction losses of the linear model.

    Uses the notebook-level ``encoder``, ``evolver`` and ``decoder_activity``
    modules.

    Parameters
    ----------
    train_tensor : torch.Tensor
        Encoder input, one row per time step.
    activity_tensor : torch.Tensor
        Activity targets, one row per time step.

    Returns
    -------
    (evolve_loss, recon_loss) : tuple of torch.Tensor
        ``recon_loss`` is the MSE between the decoded latent state and the
        activity at the same time step; ``evolve_loss`` is the MSE between the
        decoded, evolved latent state and the activity one time step later.
    """
    latent = encoder(train_tensor)
    recon_now = decoder_activity(latent)
    recon_next = decoder_activity(evolver(latent))

    mse = torch.nn.functional.mse_loss
    recon_loss = mse(recon_now, activity_tensor)
    # The last row has no successor, so it is dropped from the prediction.
    evolve_loss = mse(recon_next[:-1], activity_tensor[1:])
    return evolve_loss, recon_loss


# Optimize the encoder, decoder and evolver jointly.
trainable = [
    *encoder.parameters(),
    *decoder_activity.parameters(),
    *evolver.parameters(),
]
optimizer = torch.optim.Adam(trainable, lr=1e-5)

# train_loop(evolve_mat, train_proj, optimizer)
loop = tqdm.trange(10_000, ncols=100)
for step in loop:
    evolve_loss, recon_loss = compute_loss(train_tensor, activity_tensor)
    # Both terms are weighted equally.
    loss = evolve_loss + recon_loss
    loop.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Continue training for another 10,000 steps with the same optimizer, this
# time reporting the two loss terms separately.
# train_loop(evolve_mat, train_proj, optimizer)
loop = tqdm.trange(10_000, ncols=100)
for t in loop:
    evolve_loss, recon_loss = compute_loss(train_tensor, activity_tensor)
    loss = evolve_loss + recon_loss
    loop.set_postfix(evolve=evolve_loss.item(), recon=recon_loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Latent code of every training time step (kept as a tensor for the rollout below) ...
proj_act = encoder(train_tensor)
# ... and its direct decoding back to activity, as a NumPy array.
recon_each = decoder_activity(proj_act).detach().cpu().numpy()

In [None]:
# Rollout length (number of evolver steps) and the training time index it starts from.
T = 50
t0 = 1000

In [None]:
# Roll the learned evolver forward T steps from the latent state at t0.
p0 = proj_act[t0 : t0 + 1]

# results[k] is the latent state after k applications of the evolver.
results = [p0]
for _ in range(T):
    results.append(evolver(results[-1]))

# Decode the whole latent path in one call.
latent_path = torch.cat(results, dim=0)
pred_trace = decoder_activity(latent_path).detach().cpu().numpy()
window = slice(t0, t0 + T + 1)
act_trace = train_mat[window]
recon_trace = recon_each[window]

In [None]:
# One example neuron per type, drawn with a fixed seed so the same neurons
# are shown on every run.
rng = np.random.default_rng(seed=123)
picks = [rng.choice(nixs) for nixs in neuron_ixs_by_type]
# Choose the types to display without replacement so no type is plotted twice.
plot_neuron_types = np.sort(rng.choice(neuron_type_name, 10, replace=False))
# ['R1', 'R7', 'C2', 'Mi11', 'Tm1', 'Tm4', 'Tm30']

_, ax = plt.subplots(len(plot_neuron_types), 1, figsize=(8, 12), sharex=True)
tvals = np.arange(t0, t0 + T + 1)

for i, ptype in enumerate(plot_neuron_types):
    nix = picks[neuron_type_index[ptype]]
    true_trace = train_mat[t0 : t0 + T + 1, nix]
    # Keep the handle of the true trace so the other curves reuse its colour.
    p = ax[i].plot(tvals, true_trace)
    ax[i].set_ylim(true_trace.min() * 0.8, true_trace.max() * 1.2)
    # time evolve
    ax[i].plot(
        tvals,
        pred_trace[:, nix],
        color=p[-1].get_color(),
        ls="dotted",
        label="learn linear evolver",
    )

    # decode each time point directly from its own latent code
    ax[i].plot(
        tvals,
        recon_trace[:, nix],
        color=p[-1].get_color(),
        ls="dashed",
        marker=".",
        label="reconstruct each time point",
    )
    ax[i].set_ylabel(ptype)

plt.subplots_adjust(hspace=0)

In [None]:
# Baseline: RMS error (across neurons) of predicting that the activity stays
# frozen at its value at t0.
frozen = train_mat[t0 : t0 + 1]
actual = train_mat[t0 : t0 + T + 1, :]
err_constant = np.sqrt(np.power(frozen - actual, 2).mean(axis=1))

In [None]:
# RMS error across neurons of the rollout at each step, compared with the
# constant-prediction baseline.
rollout_residual = pred_trace - train_mat[t0 : t0 + T + 1, :]
delta_t = np.sqrt(np.power(rollout_residual, 2).mean(axis=1))
plt.plot(delta_t, label="rms error")
plt.plot(err_constant, ls="dotted", label="rms error x(t) = x(t0)")
plt.legend()
plt.xlabel("time")
plt.grid(True)