In [None]:
%matplotlib inline


import os
import itertools
import numpy as np
import torch
import torch.jit
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import sklearn.decomposition
import sklearn.manifold
import sklearn.neighbors
import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
import polars as pl
import scipy.stats
from mpl_toolkits.axes_grid1 import ImageGrid
import scipy.sparse
import matplotlib.pyplot as plt


In [None]:
# The model, validation tensors and Jacobian computations below all run on the GPU.
has_cuda = torch.cuda.is_available()
assert has_cuda

In [None]:
# Root directory holding all training runs. Override with the NEURALGRAPH_RUN_DIR
# environment variable to run this notebook from another machine/checkout; the
# default is the original cluster location.
base_run_dir = os.environ.get(
    "NEURALGRAPH_RUN_DIR", "/groups/saalfeld/home/kumarv4/repos/NeuralGraph/runs"
)

## Analyze runs

In [None]:

from pathlib import Path
import glob
import yaml
from LatentEvolution.latent import ModelParams, LatentModel
from typing import Any, Dict, List, MutableMapping, Tuple

In [None]:
# Name prefix of the sweep to analyze; the date/commit suffix is resolved by glob.
expt_code_prefix = "input_skips_sweep"

# Fully qualified path(s) of matching sweep directories; exactly one is expected.
cand_dirs = sorted(glob.glob(f"{base_run_dir}/{expt_code_prefix}*"))
assert len(cand_dirs) == 1, (
    f"expected exactly one entry matching {expt_code_prefix!r}* under {base_run_dir}, "
    f"found {len(cand_dirs)}: {cand_dirs}"
)
expt_code = os.path.basename(cand_dirs[0])
print(expt_code)


In [None]:
# List every run config in the sweep. Sorted so the listing is identical across re-runs
# (rglob order depends on the filesystem).
for cfg_path in sorted(Path(cand_dirs[0]).rglob("config.yaml")):
    print(cfg_path)

In [None]:
# Raw simulation data, stored in a sibling of the runs directory. Keep the trailing "/":
# a later cell builds a path as f"{sim_dir}x_list_0.npy".
dataset_name = "fly_N9_62_1"
sim_dir = f"{base_run_dir}/../graphs_data/fly/{dataset_name}/"

## Analyze one model

In [None]:
# Target device for the model and all tensors built from the validation data.
accelerator = "cuda"
device = torch.device(accelerator)

In [None]:
# Run directory of the single model analyzed below: one grid point of the sweep (see the
# directory names). Built from `base_run_dir`/`expt_code` so it always lies inside the sweep
# resolved above instead of duplicating that path as a second absolute literal.
pick_run_dir = (
    Path(base_run_dir)
    / expt_code
    / "encoder_params_use_input_skipsTrue"
    / "decoder_params_use_input_skipsTrue"
    / "stimulus_encoder_params_use_input_skipsTrue"
    / "evolver_params_learnable_diagonalTrue"
    / "ep200"
    / "a7309e"
)
assert pick_run_dir.is_dir(), f"run directory not found: {pick_run_dir}"

with open(pick_run_dir / "config.yaml") as fin:
    raw = yaml.safe_load(fin)
model_params = ModelParams.model_validate(raw)

# Rebuild the model from its config and load the trained weights directly onto `device`.
model = LatentModel(model_params).to(device)
state_dict = torch.load(pick_run_dir / "model_final.pt", map_location=device)
model.load_state_dict(state_dict)
# Trailing ";" keeps the notebook from printing the full module repr.
model.eval();


## Load flyvis connectivity

In [None]:

def _load_sim_tensor(name):
    """Load the tensor saved at `{sim_dir}/{name}.pt` onto the CPU and return it as numpy."""
    return torch.load(f"{sim_dir}/{name}.pt", map_location="cpu").numpy()


wt = _load_sim_tensor("weights")
edge_index = _load_sim_tensor("edge_index")
voltage_rest = _load_sim_tensor("V_i_rest")
taus = _load_sim_tensor("taus")

# Sparse connectivity: entry (edge_index[0][k], edge_index[1][k]) = wt[k]; duplicate
# (row, col) pairs are summed by the COO -> CSR conversion.
# An explicit square shape is used because scipy would otherwise infer each axis separately
# (max row + 1, max col + 1) and could produce a non-square matrix.
# NOTE(review): neurons with no edges at all above the highest edge index are still not
# covered; consider the true neuron count once it is known.
n_nodes = int(edge_index.max()) + 1
wmat = scipy.sparse.csr_matrix(
    (wt, (edge_index[0], edge_index[1])), shape=(n_nodes, n_nodes)
)

## Neuron traces

In [None]:
from LatentEvolution.load_flyvis import SimulationResults, FlyVisSim

In [None]:
# Load the simulation results stored as x_list_0.npy in `sim_dir`. os.path.join gives the
# same path whether or not `sim_dir` ends with a slash.
sim_data = SimulationResults.load(os.path.join(sim_dir, "x_list_0.npy"))

In [None]:
neuron_data = sim_data.neuron_data
# Position of "Mi12" in the type-name list, i.e. the integer type id used in neuron_data.type.
type_names = neuron_data.TYPE_NAMES
tindex = type_names.index("Mi12")



In [None]:
# Neuron coordinates.
# NOTE(review): assumes `pos` is laid out as (2, n_neurons) so pos[0]/pos[1] are the x/y rows;
# if it is (n_neurons, 2) these should be pos[:, 0] / pos[:, 1] — confirm.
xpos, ypos = neuron_data.pos[0], neuron_data.pos[1]


In [None]:
# For every neuron, the minimum number of hops from any stimulus neuron, following edges
# row -> column of `wmat` (i.e. edge_index[0] -> edge_index[1]). Unreachable neurons keep +inf.
from collections import deque


def hops_from_sources(adj, sources, n_nodes):
    """Multi-source BFS hop distance from the nearest node in `sources` to every node.

    adj: scipy sparse matrix; node i's out-neighbours are the column indices stored in row i
        (every stored entry counts as an edge, whatever its value).
    sources: iterable of integer node ids, assigned distance 0.
    n_nodes: length of the result; nodes >= adj.shape[0] are treated as having no out-edges.

    Returns a float32 array of length n_nodes with hop counts, +inf where unreachable.
    """
    adj = scipy.sparse.csr_matrix(adj)
    sources = np.asarray(sources, dtype=np.int64)
    hops = np.full(n_nodes, np.inf, dtype=np.float32)
    hops[sources] = 0
    queue = deque(sources.tolist())
    # FIFO processing means each node is first reached along a shortest path.
    while queue:
        node = queue.popleft()
        if node >= adj.shape[0]:
            continue  # no row in adj -> no out-edges (avoids indexing past indptr)
        for nbr in adj.indices[adj.indptr[node]:adj.indptr[node + 1]]:
            if np.isinf(hops[nbr]):
                hops[nbr] = hops[node] + 1
                queue.append(nbr)
    return hops


# Presumably the first 1736 neurons are the stimulus-driven ones (the same count is used as
# keep_first_n_limit for the stimulus split) — TODO confirm.
stimuli_neurons = np.arange(1736)
nhops = hops_from_sources(wmat, stimuli_neurons, len(neuron_data.type))
visited = np.isfinite(nhops)



In [None]:
# Per-neuron table: type id `t`, in/out degree (count of nonzero entries in the neuron's
# column/row of `wmat`), hop distance from the stimulus neurons, plus the type name.
edge_mask = wmat != 0
ndf = pl.DataFrame(
    {
        "t": neuron_data.type,
        # asarray(...).ravel() flattens the (1, n) / (n, 1) sums to 1-D for both scipy
        # sparse matrices and sparse arrays.
        "n_in": np.asarray(edge_mask.sum(axis=0)).ravel(),
        "n_out": np.asarray(edge_mask.sum(axis=1)).ravel(),
        "nhops": nhops,
    }
)
type_names_df = (
    pl.DataFrame({"name": neuron_data.TYPE_NAMES})
    .with_row_index("t")
    # Cast the row index to the dtype of the per-neuron type ids so the join keys match.
    .with_columns(pl.col("t").cast(ndf.schema["t"]))
)
ndf = ndf.join(type_names_df, on="t", how="left")

# Per type: mean hop distance over neurons with at least one input, and the neuron count.
with pl.Config(tbl_rows=100):
    print(
        ndf.group_by("name")
        .agg(pl.col("nhops").filter(pl.col("n_in") > 0).mean(), pl.len())
        .sort("name")
    )

In [None]:
# Split voltages and stimulus with the split stored in the run's config. The third part
# (presumably the test split) gets an explicit name: assigning to `_` would stop IPython
# from updating its `_`/`__`/`___` output history.
# keep_first_n_limit=1736 presumably keeps only the first 1736 stimulus channels — TODO confirm.
split = model_params.training.data_split
train_mat, val_mat, test_mat = sim_data.split_column(FlyVisSim.VOLTAGE, split)
stim_train, stim_val, stim_test = sim_data.split_column(
    FlyVisSim.STIMULUS, split, keep_first_n_limit=1736
)

In [None]:
# Copy the validation voltages and stimulus onto the model's device.
val_data, val_stim = (torch.tensor(arr, device=device) for arr in (val_mat, stim_val))

In [None]:
# Reconstruction: encode validation voltages into the latent space and decode back.
# Inference only: no_grad avoids keeping an autograd graph for the whole validation set on the GPU.
with torch.no_grad():
    proj = model.encoder(val_data)
    recon = model.decoder(proj)

In [None]:
# One-step model evolution on the validation set; evolved[t] is scored against
# val_data[t + 1] in the next cell. Inference only, so no autograd graph is recorded.
with torch.no_grad():
    evolved = model(val_data, val_stim[:, :1736])

In [None]:
def _mse_over_time(residual):
    """Mean of residual**2 over dim 0 (time), returned as a CPU numpy array (one value per neuron)."""
    return residual.pow(2).mean(dim=0).detach().cpu().numpy()


# One-step prediction error: evolved[t] vs. val_data[t + 1].
delta = evolved[:-1] - val_data[1:]
evolve_mse = _mse_over_time(delta)

# Autoencoder reconstruction error.
delta = recon - val_data
recon_mse = _mse_over_time(delta)

In [None]:
# Per-neuron metrics: validation errors and voltage variance over time (train and val).
metrics = pl.DataFrame(
    {
        "recon": recon_mse,
        "evolve": evolve_mse,
        "var_val": np.var(val_mat, axis=0),
        "var_train": np.var(train_mat, axis=0),
    }
)
# Horizontal concat pads the shorter frame with nulls instead of failing, so make sure
# both tables describe the same set of neurons before combining them.
assert metrics.height == ndf.height, f"{metrics.height} metric rows vs {ndf.height} neurons"
join = pl.concat([ndf, metrics], how="horizontal")

In [None]:
# Per-type summary: mean errors, max validation variance, median graph statistics.
res = join.group_by("name").agg(
    pl.col("recon").mean(),
    pl.col("evolve").mean(),
    pl.col("var_val").max(),
    pl.col("nhops").median(),
    pl.col("n_in").median(),
    pl.col("n_out").median(),
)

# Every neuron as a point (x vs. y, coloured by c), plus one text label per cell type placed
# at that type's aggregated (x, y) from `res`.
# Note: on log axes, neurons with a zero x or y value are not drawn.
x = "var_val"
y = "n_in"
c = "nhops"
fig, ax = plt.subplots()
points = ax.scatter(join[x], join[y], c=join[c], marker=".", alpha=0.3, cmap="Oranges")
for row in res.rows(named=True):
    ax.text(row[x], row[y], row["name"], fontsize=8)
cbar = fig.colorbar(points, ax=ax)
cbar.ax.set_ylabel(c)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel(x)
ax.set_ylabel(y)
ax.set_title(f"{y} vs. {x} per neuron, coloured by {c}")
plt.show()

## Compute the Jacobian

In [None]:
from torch.func import jacrev, vmap


n_neurons = 13741  # size of the voltage state x fed to the model
n_stim = 1736      # size of the stimulus s fed to the model


def model_combined(xs):
    """Single-sample model step on a flat input, for use with torch.func transforms.

    xs: concatenated [x, s] of shape (n_neurons + n_stim,) = (13741 + 1736,)
    Returns model(x[None], s[None]) with the leading batch dimension removed.
    """
    x = xs[:n_neurons]
    s = xs[n_neurons:]
    return model(x.unsqueeze(0), s.unsqueeze(0)).squeeze(0)


# Evaluation points: every voltage and stimulus at 0 except x[0], swept over [-20, 20].
n_points = 10
x_init = np.zeros((n_points, n_neurons), dtype=np.float32)
x_init[:, 0] = np.linspace(-20, 20, n_points)
x_points = torch.tensor(x_init, device=device)
s_points = torch.zeros((n_points, n_stim), device=device)

# Jacobian of model_combined w.r.t. its input at each point.
# The outer no_grad stops autograd from also recording a graph w.r.t. the model parameters;
# torch.func transforms still compute the requested derivatives inside it.
# Memory: n_points x 13741 x 15477 float32 (~8.5 GB) on the GPU, then again on the host.
with torch.no_grad():
    jac_combined_all = vmap(jacrev(model_combined))(
        torch.cat([x_points, s_points], dim=1)
    ).detach().cpu().numpy()  # shape: (num_points, 13741, 15477)

In [None]:
# Distributions at evaluation point 5:
#   - randomly sampled Jacobian entries,
#   - Jacobian entries at connectivity edges, J[edge_index[1], edge_index[0]],
#   - the connectivity weights themselves, scaled by 0.02 to share the x-range.
# A seeded generator keeps the samples (and the figure) identical across re-runs.
rng = np.random.default_rng(0)
point = 5
n_samples = 1000
bins = np.linspace(-0.02, 0.02, 51)
rvals = rng.choice(jac_combined_all[point].ravel(), n_samples)
evals = rng.choice(jac_combined_all[point, edge_index[1], edge_index[0]], n_samples)
tvals = rng.choice(wmat.data, n_samples)

fig, ax = plt.subplots()
ax.hist(rvals, bins=bins, alpha=0.2, label="Jacobian, all entries")
ax.hist(evals, bins=bins, alpha=0.2, label="Jacobian, at edges")
ax.hist(tvals * 0.02, bins=bins, alpha=0.2, label="weights x 0.02")
ax.set_xlabel("value")
ax.set_ylabel("count")
ax.legend()
plt.show()

In [None]:
# Jacobian at evaluation point 5: output rows 0..216 against input columns 1736..1952.
# NOTE(review): model_combined orders its input as [x (13741 voltages), s (1736 stimuli)],
# so columns 1736..1952 are voltages x[1736:1953], not stimulus channels — confirm this is
# the intended block.
block_rows = slice(0, 217)
block_cols = slice(1736, 1736 + 217)
disp_mat = jac_combined_all[5, block_rows, block_cols]
plt.imshow(disp_mat, cmap="Greys_r")
plt.xlabel("neurons")
plt.ylabel("neurons")
plt.title("Jacobian")
plt.colorbar()

In [None]:
# Connectivity block: rows 217..433, columns 1736..1952 of wmat
# (row = edge_index[0], column = edge_index[1]).
# NOTE(review): the histogram cell pairs J[edge_index[1], edge_index[0]] with the weights, so
# the weights matching the Jacobian block J[0:217, 1736:1953] shown above would be
# wmat[1736:1953, 0:217].T — this block covers different neurons; confirm intent.
w_block = wmat[217:217 + 217, 1736:1736 + 217]
plt.imshow(w_block.todense(), cmap="Greys_r")
plt.colorbar()