In [1]:
%matplotlib inline

from pathlib import Path
import yaml
import os
import itertools
import numpy as np
import torch
import torch.jit
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import sklearn.decomposition
import sklearn.manifold
import sklearn.neighbors
import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
import polars as pl
import scipy.stats
from mpl_toolkits.axes_grid1 import ImageGrid
import scipy.sparse
import matplotlib.pyplot as plt


In [2]:
# The model, the compiled rollout and the Jacobian below all run on the GPU.
assert torch.cuda.is_available(), "CUDA is required: this notebook places the model and data on the GPU"
device = torch.device("cuda")

In [None]:
# Root of the NeuralGraph `runs/` directory. Set NEURALGRAPH_RUNS_DIR to point
# somewhere else instead of editing this absolute path.
base_run_dir = os.environ.get(
    "NEURALGRAPH_RUNS_DIR", "/groups/saalfeld/home/kumarv4/repos/NeuralGraph/runs"
)

# raw data dir (graphs_data/ is a sibling of runs/)
data_dir = os.path.abspath(f"{base_run_dir}/../graphs_data/fly/fly_N9_62_1_youtube-vos_calcium")

## Load model

In [None]:
from LatentEvolution.latent import LatentModel
from LatentEvolution.eed_model import ModelParams

# Run whose config.yaml and model_final.pt are analysed below.
pick_run_dir = Path(f"{base_run_dir}/tu20_youtube_baseline_20260121_b4c18b9/39d474")

# The run's config.yaml is validated into the model's hyper-parameters.
with open(pick_run_dir / "config.yaml") as fin:
    raw = yaml.safe_load(fin)
model_params = ModelParams.model_validate(raw)

# load model; map_location puts the checkpoint tensors on `device` regardless of
# which device they were saved from
model = LatentModel(model_params).to(device)
model.load_state_dict(torch.load(pick_run_dir / "model_final.pt", map_location=device))
model.eval()


## Load data

In [None]:
# Seeded generator for the training-window start index and the per-neuron phases.
rng = np.random.default_rng(123)

In [None]:
from LatentEvolution.load_flyvis import load_column_slice, FlyVisSim, load_metadata, NeuronData
data_path = f"{data_dir}/x_list_0"
metadata = load_metadata(data_path)
neuron_data = NeuronData.from_metadata(metadata)

# Random window [train_start_ix, train_start_ix + T) of the simulation.
train_start_ix = int(rng.integers(10_000, 900_000))
tu = 20        # steps between successive observations of the same neuron
ems = 1        # observations per neuron inside the window
T = ems * tu   # number of timesteps loaded


def _load_window(column, neuron_limit):
    """Load `column` for the current time window and move it to `device`."""
    window = load_column_slice(
        data_path,
        column,
        time_start=train_start_ix,
        time_end=train_start_ix + T,
        neuron_limit=neuron_limit,
    )
    return torch.from_numpy(window).to(device)


train_data = _load_window(FlyVisSim.VOLTAGE, None)
# NOTE(review): 1736 appears to be the number of stimulus-receiving neurons — confirm
stim_data = _load_window(FlyVisSim.STIMULUS, 1736)

In [None]:
N = train_data.size(1)  # number of neurons (columns of the voltage slice)
L = 256  # latent width used for the zero initial state; presumably the model's latent size — TODO confirm


In [None]:

# acquisition phase for each neuron within 0...20 window. Randomly assigned.
# Each neuron is observed at its own random offset in [0, tu) within every window.
phase_offsets = rng.integers(0, tu, size=N)
neuron_phases = torch.from_numpy(phase_offsets).to(device)


In [None]:
# Reference latents from the trained encoder. Parameters are only frozen in a later
# cell, so disable autograd here to avoid building (and holding) an unused graph.
with torch.no_grad():
    z_trained = model.encoder(train_data)

In [None]:

# freeze learned model
# Freeze every learned weight: only the initial latent state is optimised below.
model.requires_grad_(False)
model.eval()
model.to(device)


In [None]:
# Stimulus passed through the model's stimulus encoder; `train` slices it along dim 0 per timestep.
encode_stimulus = model.stimulus_encoder
proj_stim = encode_stimulus(stim_data)

In [None]:
# obs_times[k, i] = k * tu + phase_i: timestep at which neuron i is observed in window k.
window_starts = torch.arange(0, T, tu, device=device)
obs_times = window_starts[:, None] + neuron_phases[None, :]
obs_times

In [None]:
# Parallel (time, neuron) index vectors. Row-major flattening of obs_times lines up
# with neuron ids 0..N-1 repeated once per window.
obs_t = obs_times.reshape(-1)
neuron_ixs = torch.arange(N, device=device).repeat(ems)

In [None]:
# One measured voltage per (time, neuron) observation pair.
measured = train_data[(obs_t, neuron_ixs)]

In [None]:
@torch.compile(mode="reduce-overhead", fullgraph=True)
def train(z0):
    """Roll the frozen latent model forward from `z0` and score it against the observations.

    Despite the name this updates nothing: it returns the loss that the optimisation
    loop backpropagates into `z0`.

    Args:
        z0: initial latent state (callers in this notebook pass shape (1, L)).

    Returns:
        (loss, zpred): MSE between the decoded rollout sampled at (obs_t, neuron_ixs)
        and `measured`, and the stacked latent trajectory with T rows (assuming
        `model.evolver` preserves z's shape).
    """
    z = z0
    zs = [z]
    # Roll out over the whole observation span. obs_t reaches T - 1 (= ems * tu - 1),
    # so stopping after tu steps would index past the end of xpred whenever ems > 1.
    for t in range(T - 1):
        z = model.evolver(z, proj_stim[t:t+1])
        zs.append(z)
    zpred = torch.cat(zs, dim=0)
    xpred = model.decoder(zpred)
    pred = xpred[obs_t, neuron_ixs]
    loss = torch.nn.functional.mse_loss(pred, measured)
    return loss, zpred

In [None]:

# Baseline rollout from an all-zero initial latent.
z_zero = torch.zeros(1, L, device=device)
loss, zpred = train(z_zero)

In [None]:
# Host copies for plotting: predicted rollout vs. the encoder's latents.
mat, mat_gt = zpred.detach().cpu(), z_trained.detach().cpu()

def make_plot(mat, mat_gt):
    """Plot a predicted latent trajectory above a reference latent trajectory.

    Both panels share axes and the colour limits of the top panel, so the two
    images are directly comparable.

    Args:
        mat: predicted latents, rows = time, columns = latent dimension.
        mat_gt: reference latents with the same layout.
    """
    _, ax = plt.subplots(2, 1, sharex=True, sharey=True)
    im = ax[0].imshow(mat, aspect=3)
    ax[0].set_title("pred latent")
    # Reuse the prediction's colour limits so both panels use the same scale.
    vmin, vmax = im.get_clim()
    ax[1].imshow(mat_gt, vmin=vmin, vmax=vmax, aspect=3)
    ax[1].set_title("model latent")
    for a in ax:
        a.set_ylabel("time")
    ax[1].set_xlabel("latent dimension")
    plt.tight_layout()

make_plot(mat=mat, mat_gt=mat_gt)

In [None]:
measured.size()

In [None]:
# Fit only the initial latent state z0 (model weights are frozen) so that the rollout
# reproduces the phase-shifted observations in `measured`.
N_STEPS = 1000
LR = 1e-3

torch.manual_seed(1214)
# Warm start from the encoder. NOTE(review): `measured` mixes values from different
# timesteps (each neuron at its own phase), so this is only an approximate initial
# state, and with ems > 1 it no longer has exactly N entries — confirm intent.
z0 = torch.nn.Parameter(model.encoder(measured.unsqueeze(0)))

optimizer = torch.optim.Adam([z0], lr=LR)
loop = tqdm.trange(N_STEPS)
for i in loop:
    optimizer.zero_grad(set_to_none=True)
    loss, zpred = train(z0)
    loss.backward()
    loop.set_postfix({"loss": loss.detach().item()})
    optimizer.step()



In [None]:
# zpred comes from the last forward pass of the loop, i.e. before its final optimizer.step().
zpred_host = zpred.detach().cpu()
make_plot(zpred_host, mat_gt)

In [None]:
proj_stim.size()


In [None]:
# One forward pass of the full model on the loaded window.
model_out = model(train_data, stim_data)
model_out

## Load FlyVis connectivity (weights, edges, resting potentials, time constants)

In [None]:

# NOTE(review): `sim_dir` was never defined in this notebook, so this cell raised a
# NameError on a fresh kernel. Assuming the connectivity tensors live in the same
# simulation directory as the raw data — confirm.
sim_dir = data_dir


def _load_numpy(name):
    """Load `<sim_dir>/<name>` with torch onto the CPU and return it as a numpy array."""
    return torch.load(f"{sim_dir}/{name}", map_location="cpu").numpy()


wt = _load_numpy("weights.pt")
edge_index = _load_numpy("edge_index.pt")
voltage_rest = _load_numpy("V_i_rest.pt")
taus = _load_numpy("taus.pt")

# this is compatible with cedric's conventions
# Note: this is the transpose of what is in the flyvis paper so don't be confused
# wmat[edge_index[1], edge_index[0]] = wt. The shape is explicit so every neuron gets a
# row/column even if no edge touches the highest indices (scipy would otherwise infer
# the shape from the largest index present).
n_neurons = len(neuron_data.type)
wmat = scipy.sparse.csr_matrix(
    (wt, (edge_index[1], edge_index[0])), shape=(n_neurons, n_neurons)
)

In [None]:
# Sparse weight matrix wmat[edge_index[1], edge_index[0]] = wt with one row/column per
# neuron, so its size matches the per-neuron table built further down.
n_neurons = len(neuron_data.type)
wmat = scipy.sparse.csr_matrix(
    (wt, (edge_index[1], edge_index[0])), shape=(n_neurons, n_neurons)
)

In [None]:
# Estimate what fraction of all neurons appear as inputs (row entries of wmat) to a
# random set of 100 neurons, repeated over 100 random sets.
# Sizes come from wmat itself rather than a hard-coded 13741, so the CSR indptr lookups
# and the mask length always agree with the matrix.
n_sample_rows, n_mask_cols = wmat.shape
n_trials = 100
n_targets = 100
sample_rng = np.random.default_rng(seed=0)  # local and seeded: re-running gives the same vals

vals = []
for _ in range(n_trials):
    needed_neurons = np.zeros(n_mask_cols, dtype=bool)
    for i in sample_rng.choice(n_sample_rows, replace=False, size=n_targets):
        # which other neurons impact its value
        ixs = wmat.indices[wmat.indptr[i]:wmat.indptr[i+1]]
        needed_neurons[ixs] = True
    vals.append(needed_neurons.mean())

# TODO(review): unfinished notes kept from the original cell:
# x_t = true initial activity at time t, x_t_1 = activity at t+1
# constrain
# x_t_1[i] must be
np.mean(vals)

In [None]:
# Spectrum of the connectivity matrix. TruncatedSVD's default solver is randomized;
# fixing random_state makes the singular values reproducible across runs.
svd = sklearn.decomposition.TruncatedSVD(n_components=5000, random_state=0)
svd.fit(wmat)

In [None]:
fig, ax = plt.subplots()
ax.plot(svd.singular_values_)
ax.set_xlim(0, 65)
ax.set_xlabel("component index")
ax.set_ylabel("singular value")
ax.set_title("Leading singular values of wmat");


In [None]:
# Sum of all stored weights in the sparse matrix.
stored_weights = wmat.data
stored_weights.sum()

In [None]:
# compute n hops to stimuli for each neuron
from collections import deque

# Breadth-first search from the stimulus neurons (indices 0..1735) over wmat's CSR rows:
# nhops[i] = number of hops from the stimulus set to neuron i (inf if never reached).
# NOTE(review): row `current` of wmat lists that row's stored column indices. If wmat is
# indexed [target, source] (wmat[edge_index[1], edge_index[0]] = wt with edge_index[0]
# presynaptic), those are the neuron's *inputs*, so this walks upstream from the stimuli
# rather than downstream — confirm the intended direction.
n_total = len(neuron_data.type)
stimuli_neurons = np.arange(1736)

nhops = np.full(n_total, np.inf, dtype=np.float32)
visited = np.zeros(n_total, dtype=bool)
nhops[stimuli_neurons] = 0
visited[stimuli_neurons] = True

frontier = deque(stimuli_neurons.tolist())
while frontier:
    current = frontier.popleft()  # FIFO => nodes leave in nondecreasing hop order
    next_hop = nhops[current] + 1
    row_entries = wmat.indices[wmat.indptr[current]:wmat.indptr[current + 1]]
    for nbr in row_entries:
        if visited[nbr]:
            continue
        # first time we reach nbr, so this is its shortest hop count
        visited[nbr] = True
        nhops[nbr] = next_hop
        frontier.append(nbr)



In [None]:
# Per-neuron table: type id, degree counts and BFS hop count, joined to the type names.
# n_in counts nonzeros per wmat column (sum over axis 0); n_out counts nonzeros per row.
# NOTE(review): if wmat is indexed [target, source], columns are presynaptic neurons and
# these two labels would be swapped — confirm against the edge_index convention.
nonzero = np.abs(wmat) != 0.
col_counts = np.asarray(nonzero.sum(axis=0)).ravel()
row_counts = np.asarray(nonzero.sum(axis=1)).ravel()

type_names = (
    pl.DataFrame({"name": neuron_data.TYPE_NAMES})
    .with_row_index("t")
    .with_columns(pl.col("t").cast(pl.UInt8))
)
ndf = pl.DataFrame(
    {"t": neuron_data.type, "n_in": col_counts, "n_out": row_counts, "nhops": nhops}
).join(type_names, on="t", how="left")

summary = (
    ndf.group_by("name")
    .agg(pl.col("nhops").filter(pl.col("n_in") > 0).mean(), pl.len())
    .sort("name")
)
with pl.Config(tbl_rows=100):
    print(summary)


## Compute the Jacobian

In [None]:
from torch.func import jacrev, vmap


def model_combined(xs):
    """One model step written as a function of a single flat vector, for `jacrev`.

    xs: concatenation [x, s] of the N neuron voltages followed by the stimulus values,
        shape (N + n_stim,). The first N entries are x; the remainder is s.

    Returns the model output for that single input, with the batch dim removed.
    """
    # Split at N (the neuron count of the loaded data) rather than a hard-coded 13741,
    # so the function stays correct if another simulation is loaded.
    x = xs[:N]
    s = xs[N:]
    return model(x.unsqueeze(0), s.unsqueeze(0)).squeeze(0)
# Jacobian of one model step at n_points states: neuron 0's voltage is swept over
# [-20, 20]; all other voltages and all stimulus values are zero.
n_points = 10
n_stim = stim_data.shape[1]  # presumably 1736 — the stimulus slice was loaded with neuron_limit=1736
x_points = np.zeros((n_points, N), dtype=np.float32)
x_points[:, 0] = np.linspace(-20, 20, n_points)
x_points = torch.tensor(x_points, device=device)
s_points = torch.zeros((n_points, n_stim), device=device)
xs_points = torch.cat([x_points, s_points], dim=1)

# One jacrev call per point instead of vmap over all points at once: peak GPU memory
# then scales with a single Jacobian rather than all n_points of them. Each result is
# copied into a preallocated host array of shape (n_points, n_out, N + n_stim), which
# for the full fly network is about 10 x 13741 x 15477 float32 (~8.5 GB).
jac_fn = jacrev(model_combined)
jac_combined_all = None
for k, xs in enumerate(xs_points):
    jac_k = jac_fn(xs).detach().cpu().numpy()
    if jac_combined_all is None:
        jac_combined_all = np.empty((n_points, *jac_k.shape), dtype=jac_k.dtype)
    jac_combined_all[k] = jac_k

In [None]:
# Distributions at Jacobian point 5: (a) all Jacobian entries, (b) Jacobian entries at
# the synaptic edges, (c) the true weights scaled by 0.02.
# NOTE(review): 0.02 presumably converts weights to a per-step gain (simulation dt?) — confirm.
WEIGHT_SCALE = 0.02
hist_rng = np.random.default_rng(seed=0)  # seeded so re-running draws the same samples
bins = np.linspace(-0.02, 0.02, 51)
edge_jac = jac_combined_all[5, edge_index[1], edge_index[0]]

rvals = hist_rng.choice(jac_combined_all[5].ravel(), 1000)
evals = hist_rng.choice(edge_jac, 1000)
tvals = hist_rng.choice(wmat.data, 1000)

fig, ax = plt.subplots()
ax.hist(rvals, bins=bins, alpha=0.2, label="all Jacobian entries")
ax.hist(evals, bins=bins, alpha=0.2, label="Jacobian at edges")
ax.hist(tvals * WEIGHT_SCALE, bins=bins, alpha=0.2, label=f"weights x {WEIGHT_SCALE}")
ax.set_xlabel("value")
ax.set_ylabel("count (1000 samples each)")
ax.legend();

In [None]:
# True weight vs. the Jacobian entry at the same (edge_index[1], edge_index[0]) position,
# on a random subset of 10000 edges. `ix` selects the edges that are plotted; larger
# markers than the all-edge version because far fewer points are drawn.
scatter_rng = np.random.default_rng(seed=0)
ix = scatter_rng.integers(0, edge_index.shape[1], 10000)
fig, ax = plt.subplots()
ax.scatter(wt[ix], jac_combined_all[5, edge_index[1, ix], edge_index[0, ix]], s=1, alpha=0.2)
ax.set_xlabel("weight (wt)")
ax.set_ylabel("Jacobian entry at [edge_index[1], edge_index[0]]")
ax.set_title("Jacobian vs. true weight at sampled edges");

In [None]:
# disp_mat is defined in the following cell as well; compute it here so this cell runs
# on a fresh kernel. Jacobian at point 0: rows 0..433 vs. input columns 1736..2169, / 0.02.
disp_mat = jac_combined_all[0, 0:217*2, 1736:1736 + 217*2] / .02
plt.hist(disp_mat.ravel(), bins=100)
plt.xlabel("Jacobian entry / 0.02")
plt.ylabel("count");

In [None]:
# Jacobian block at point 0: output neurons 0..433 (rows) vs. input voltages of neurons
# 1736..2169 (columns), divided by 0.02 (the factor the histogram cell scales weights by).
sub_size = 217 * 2
col_start = 1736
disp_mat = jac_combined_all[0, 0:sub_size, col_start:col_start + sub_size] / .02
plt.imshow(disp_mat, cmap="Greys_r", vmax=0.3, vmin=0.0,
           extent=[col_start, col_start + sub_size, sub_size, 0])
plt.xlabel("input neuron")
plt.ylabel("output neuron")
plt.title("Jacobian neuron-neuron (same subset)")
plt.colorbar()

In [None]:
# True weights for the same block: rows 0..433, columns 1736..2169 of wmat.
sub_size = 217 * 2
col_start = 1736
plt.imshow(wmat[0:sub_size, col_start:col_start + sub_size].toarray(), cmap="Greys_r",
           extent=[col_start, col_start + sub_size, sub_size, 0])
plt.colorbar()
plt.xlabel("column index (edge_index[0])")
plt.ylabel("row index (edge_index[1])")
plt.title("True weight matrix (subset)")