In [3]:
%matplotlib inline


import os
import itertools
import numpy as np
import torch
import torch.jit
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import sklearn.decomposition
import sklearn.manifold
import sklearn.neighbors
import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
import polars as pl
import scipy.stats
from mpl_toolkits.axes_grid1 import ImageGrid
import scipy.sparse
import matplotlib.pyplot as plt


In [4]:
# Fail fast: later cells place the model and the validation tensors on a CUDA device.
assert torch.cuda.is_available(), (
    "CUDA is not available: this notebook moves the model and data to a GPU. "
    "Run it on a machine with a visible CUDA device."
)

AssertionError: 

In [None]:
# Root directory that holds all training runs. The default is the original cluster
# location; set NEURALGRAPH_RUNS_DIR to point the notebook at another checkout.
base_run_dir = os.environ.get(
    "NEURALGRAPH_RUNS_DIR",
    "/groups/saalfeld/home/kumarv4/repos/NeuralGraph/runs",
)

# Tinker

In [11]:
# Scratch check of advanced indexing: a (3, 1) column of row offsets broadcast
# against an empty (1, 0) window gives a (3, 0) index, so the gather is empty.
a = torch.rand((50, 100))
row_starts = torch.tensor([10, 20, 30])[:, None]
window = torch.arange(0)[None, :]
inds = row_starts + window
a[inds, :].shape

torch.Size([3, 0, 100])

## Analyze runs

In [None]:

from pathlib import Path
import glob
import yaml
from LatentEvolution.latent import ModelParams, LatentModel
from typing import Any, Dict, List, MutableMapping, Tuple

In [None]:
expt_code_prefix = "checkpoint_20251118_20251118"

# Resolve the prefix to the single fully qualified experiment directory.
cand_dirs = glob.glob(f"{base_run_dir}/{expt_code_prefix}*")
assert len(cand_dirs) == 1, (
    f"expected exactly one run directory matching "
    f"{base_run_dir}/{expt_code_prefix}*, found {len(cand_dirs)}: {cand_dirs}"
)
expt_code = os.path.basename(cand_dirs[0])
print(expt_code)


In [None]:
# List every config.yaml found recursively under the experiment directory.
# NOTE: `cfg_path` remains bound to the last path after the loop ends; avoid
# relying on that leaked value in later cells.
for cfg_path in Path(cand_dirs[0]).rglob("config.yaml"):
    print(cfg_path)

In [None]:
# Raw data directory, given relative to the runs directory.
# Keep the trailing slash: any code that builds file paths by plain string
# concatenation onto sim_dir depends on it.
sim_dir = f"{base_run_dir}/../graphs_data/fly/fly_N9_62_1/"

## Load model

In [None]:
# Device on which the model and the tensors built in this notebook are placed.
device = torch.device("cuda")

In [None]:
# Collect the run configs explicitly rather than depending on a loop variable left
# over from another cell. The last match in rglob order is the run analysed here.
cfg_paths = list(Path(cand_dirs[0]).rglob("config.yaml"))
assert cfg_paths, f"no config.yaml found under {cand_dirs[0]}"
pick_run_dir = cfg_paths[-1].parent

with open(pick_run_dir / "config.yaml") as fin:
    raw = yaml.safe_load(fin)
model_params = ModelParams.model_validate(raw)

# load model
# map_location places the checkpoint tensors on `device` regardless of the device
# they were saved from.
model = LatentModel(model_params).to(device)
model.load_state_dict(torch.load(f"{pick_run_dir}/model_final.pt", map_location=device))
model.eval()


## Load data

In [None]:
from LatentEvolution.load_flyvis import SimulationResults, FlyVisSim
# os.path.join yields the correct path whether or not sim_dir ends with a slash.
sim_data = SimulationResults.load(os.path.join(sim_dir, "x_list_0.npy"))
neuron_data = sim_data.neuron_data


In [None]:
# Split the simulation columns using the data split stored in the run's training
# config, then move the second ("val") part of each split to `device`.
split = model_params.training.data_split
# split_column returns three parts; the names suggest train / validation / other.
# Only the validation part is used in the cells visible here.
train_mat, val_mat, _ = sim_data.split_column(FlyVisSim.VOLTAGE, split)
# keep_first_n_limit=1736 presumably restricts the stimulus to its first 1736
# columns (the stimulus width assumed elsewhere in this notebook) — TODO confirm.
stim_train, val_stim_mat, _ = sim_data.split_column(FlyVisSim.STIMULUS, split, keep_first_n_limit=1736)
val_data = torch.tensor(val_mat, device=device)
val_stim = torch.tensor(val_stim_mat, device=device)

## Load flyvis connectivity

In [None]:

# Load the saved weights, edge index, resting voltages and time constants as
# NumPy arrays on the CPU.
wt = torch.load(os.path.join(sim_dir, "weights.pt"), map_location="cpu").numpy()
edge_index = torch.load(os.path.join(sim_dir, "edge_index.pt"), map_location="cpu").numpy()
voltage_rest = torch.load(os.path.join(sim_dir, "V_i_rest.pt"), map_location="cpu").numpy()
taus = torch.load(os.path.join(sim_dir, "taus.pt"), map_location="cpu").numpy()
# this is compatible with cedric's conventions
# Note: this is the transpose of what is in the flyvis paper so don't be confused
# The shape is passed explicitly: inferring it from the largest edge index gives a
# smaller matrix whenever the highest-numbered neurons have no edges, which would
# break per-neuron indexing into wmat.
n_neurons = len(neuron_data.type)
wmat = scipy.sparse.csr_matrix(
    (wt, (edge_index[1], edge_index[0])), shape=(n_neurons, n_neurons)
)

In [None]:
# Histogram of the edge weights on a log count axis. Weights outside [-1, 1] fall
# outside the bin range and are not shown.
fig, ax = plt.subplots()
ax.hist(wt, bins=np.linspace(-1, 1, 100))
ax.axvline(0.0, color="k")
ax.set_yscale("log")
ax.set_xlabel("edge weight")
ax.set_ylabel("count")
ax.set_title("Distribution of connectivity weights");


In [2]:
# Leftover sanity check that only evaluates the `torch` module object. It raises
# NameError when run before the import cell and can be deleted.
torch

NameError: name 'torch' is not defined

In [None]:
# compute n hops to stimuli for each neuron
# Breadth-first search seeded at the first 1736 neurons. A node's neighbours are
# the column indices stored in its CSR row of wmat. Neurons that are never
# reached keep nhops = inf.
from collections import deque

n_total = len(neuron_data.type)
stimuli_neurons = np.arange(1736)

nhops = np.full(n_total, np.inf, dtype=np.float32)
visited = np.zeros(n_total, dtype=bool)
nhops[stimuli_neurons] = 0
visited[stimuli_neurons] = True

process = deque(stimuli_neurons.tolist())
while process:
    # FIFO order means nodes are expanded in non-decreasing hop distance.
    node = process.popleft()
    row_start, row_end = wmat.indptr[node], wmat.indptr[node + 1]
    for nbr in wmat.indices[row_start:row_end]:
        if visited[nbr]:
            continue
        # First visit is along a shortest path.
        visited[nbr] = True
        nhops[nbr] = nhops[node] + 1
        process.append(nbr)



In [None]:
# Per-neuron table: type id, nonzero counts along each axis of wmat, hop count,
# and the type name joined in from TYPE_NAMES by its row position.
nonzero_mask = np.abs(wmat) != 0.
per_column_counts = np.array(nonzero_mask.sum(axis=0))[0]
per_row_counts = np.array(nonzero_mask.sum(axis=1))[:, 0]

type_names = (
    pl.DataFrame({"name": neuron_data.TYPE_NAMES})
    .with_row_index("t")
    .with_columns(pl.col("t").cast(pl.UInt8))
)
ndf = pl.DataFrame(
    {
        "t": neuron_data.type,
        "n_in": per_column_counts,
        "n_out": per_row_counts,
        "nhops": nhops,
    }
).join(type_names, on="t", how="left")

# Mean hop count per type over neurons with n_in > 0, plus the group size.
mean_hops = pl.col("nhops").filter(pl.col("n_in") > 0).mean()
with pl.Config(tbl_rows=100):
    print(ndf.group_by("name").agg(mean_hops, pl.len()).sort("name"))


## Compute the Jacobian

In [None]:
from torch.func import jacrev, vmap


def model_combined(xs, n_neurons=13741):
    """
    Evaluate the model on one concatenated [state, stimulus] vector.

    Taking a single flat input lets jacrev differentiate with respect to the
    state and the stimulus in one pass.

    Args:
        xs: concatenated [x, s] of shape (n_neurons + stimulus_dim,)
        n_neurons: index at which xs is split into state and stimulus
            (default 13741, matching the original hard-coded value)

    Returns:
        The model output for this single sample, with the batch axis removed.
    """
    x = xs[:n_neurons]
    s = xs[n_neurons:]
    # The model is called with a leading batch axis of size 1.
    return model(x.unsqueeze(0), s.unsqueeze(0)).squeeze(0)
# Probe points: every input is zero except state entry 0, which is swept over
# 10 evenly spaced values in [-20, 20]. The stimulus part is all zeros.
x_points = np.zeros((10, 13741), dtype=np.float32)
x_points[:, 0] = np.linspace(-20, 20, 10)
x_points = torch.tensor(x_points, device=device)
s_points = torch.zeros((10, 1736), device=device)
# For multiple points
probe_inputs = torch.cat([x_points, s_points], dim=1)
batched_jacobian = vmap(jacrev(model_combined))
jac_combined_all = batched_jacobian(probe_inputs).detach().cpu().numpy()  # shape: (num_points, 13741, 15477)

In [None]:
# Overlay three samples of 1000 values each (drawn with replacement) for probe point 5:
#   rvals - Jacobian entries drawn from the whole matrix
#   evals - Jacobian entries at the (row, column) positions used to build wmat
#   tvals - stored weights of wmat, rescaled by 0.02 for display on the same axis
bins = np.linspace(-0.02, 0.02, 51)
rvals = np.random.choice(jac_combined_all[5].ravel(), 1000)
evals = np.random.choice(jac_combined_all[5, edge_index[1], edge_index[0]], 1000)
tvals = np.random.choice(wmat.data, 1000)

fig, ax = plt.subplots()
ax.hist(rvals, bins=bins, alpha=0.2, label="Jacobian, random entries")
ax.hist(evals, bins=bins, alpha=0.2, label="Jacobian, entries at edges")
ax.hist(tvals*.02, bins=bins, alpha=0.2, label="true weights x 0.02")
ax.set_xlabel("value")
ax.set_ylabel("count")
ax.legend();

In [None]:
# Scatter, for every edge, the true weight against the Jacobian entry (probe
# point 5) at the same (row, column) position used to build wmat.
# All edges are plotted; no subsampling is applied.
fig, ax = plt.subplots()
ax.scatter(wt, jac_combined_all[5, edge_index[1], edge_index[0]], s=0.01, alpha=0.2)
ax.set_xlabel("true weight")
ax.set_ylabel("Jacobian entry (probe point 5)");

In [None]:
# Build disp_mat in this cell so it runs on a fresh kernel: it is the Jacobian
# block for probe point 0 (rows 0..433, columns 1736..2169) divided by 0.02.
disp_mat = jac_combined_all[0, 0:217*2, 1736:1736 + 217*2]/.02
plt.hist(disp_mat.ravel(), bins=100)

In [None]:
# Jacobian block for probe point 0: rows 0..433 and columns 1736..2169, divided by 0.02.
n_block = 217*2
col_start = 1736
disp_mat = jac_combined_all[0, 0:n_block, col_start:col_start + n_block]/.02

fig, ax = plt.subplots()
im = ax.imshow(
    disp_mat,
    cmap="Greys_r",
    vmax=0.3,
    vmin=0.0,
    extent=[col_start, col_start + n_block, n_block, 0],
)
ax.set_xlabel("neurons")
ax.set_ylabel("neurons")
ax.set_title("Jacobian neuron-neuron (same subset)")
fig.colorbar(im, ax=ax)

In [None]:
# Block of the connectivity matrix: rows 0..433 and columns 1736..2169.
n_block = 217 + 217
col_start = 1736
wmat_block = wmat[0:n_block, col_start:col_start + n_block].todense()

fig, ax = plt.subplots()
im = ax.imshow(wmat_block, cmap="Greys_r", extent=[col_start, col_start + n_block, n_block, 0])
fig.colorbar(im, ax=ax)
ax.set_title("True weight matrix (subset)")

## Multi-step rollout evolution

In [None]:
def evolve_n_steps(model, initial_state, stimulus, n_steps):
    """
    Roll the model forward autoregressively for n_steps time steps.

    At every step the previous prediction is fed back in as the current state,
    together with the stimulus row for that step.

    Args:
        model: Callable invoked as model(state, stimulus) with state of shape
            (1, neurons) and stimulus of shape (1, stimulus_dim); its output is
            squeezed on axis 0 and used as the next state
        initial_state: Initial state tensor of shape (neurons,)
        stimulus: Stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
        n_steps: Number of time steps to evolve (must be >= 0)

    Returns:
        predicted_trace: Tensor of shape (n_steps, neurons) with predicted states;
            an empty (0, neurons) tensor when n_steps is 0

    Raises:
        ValueError: If n_steps is negative or stimulus has fewer than n_steps rows
    """
    if n_steps < 0:
        raise ValueError(f"n_steps must be non-negative, got {n_steps}")
    if stimulus.shape[0] < n_steps:
        # A short stimulus would otherwise yield empty slices that broadcast silently.
        raise ValueError(
            f"stimulus has {stimulus.shape[0]} rows but n_steps={n_steps} were requested"
        )
    if n_steps == 0:
        # torch.stack rejects an empty list, so build the empty result directly.
        return initial_state.new_empty((0, *initial_state.shape))

    predicted_trace = []
    current_state = initial_state

    for t in range(n_steps):
        # Slice (not index) so the stimulus keeps its leading batch axis: (1, stimulus_dim)
        current_stimulus = stimulus[t:t+1]

        # Evolve by one step
        next_state = model(current_state.unsqueeze(0), current_stimulus).squeeze(0)
        predicted_trace.append(next_state)

        # Use predicted state as input for next step
        current_state = next_state

    return torch.stack(predicted_trace, dim=0)


def compare_traces(model, real_trace, stimulus, n_steps, start_idx=0):
    """
    Compare real trace with predicted trace over n time steps.

    The rollout starts from real_trace[start_idx] and is driven by
    stimulus[start_idx : start_idx + n_steps]. Step k of the prediction is
    compared against real_trace[start_idx + 1 + k].

    Args:
        model: The LatentModel to evolve
        real_trace: Real state tensor of shape (T, neurons)
        stimulus: Stimulus tensor of shape (T, stimulus_dim)
        n_steps: Number of time steps to predict
        start_idx: Starting index in the trace

    Returns:
        real_segment: Real trace segment of shape (n_steps, neurons)
        predicted_segment: Predicted trace segment of shape (n_steps, neurons)
        mse_per_step: MSE at each time step, shape (n_steps,)
        cumulative_mse: Cumulative average MSE up to each time step, shape (n_steps,)

    Raises:
        ValueError: If the requested window does not fit inside real_trace or
            stimulus, i.e. either slice has fewer than n_steps rows
    """
    # Get initial state
    initial_state = real_trace[start_idx]

    # Get stimulus segment
    stimulus_segment = stimulus[start_idx:start_idx + n_steps]

    # Get real trace segment (ground truth for the next n_steps)
    real_segment = real_trace[start_idx + 1:start_idx + n_steps + 1]

    # Slicing truncates silently at the end of the data. A truncated ground-truth
    # segment of length 1 would even broadcast against the prediction and produce
    # a wrong MSE, so check lengths before doing the rollout.
    if stimulus_segment.shape[0] != n_steps:
        raise ValueError(
            f"stimulus provides {stimulus_segment.shape[0]} rows from start_idx={start_idx}, "
            f"but n_steps={n_steps} are required"
        )
    if real_segment.shape[0] != n_steps:
        raise ValueError(
            f"real_trace provides {real_segment.shape[0]} target rows after start_idx={start_idx}, "
            f"but n_steps={n_steps} are required"
        )

    # Predict n steps
    predicted_segment = evolve_n_steps(model, initial_state, stimulus_segment, n_steps)

    # Compute MSE per time step (mean over the neuron axis)
    mse_per_step = torch.pow(predicted_segment - real_segment, 2).mean(dim=1)

    # Running mean of the per-step MSE: cumulative sum divided by 1..n_steps
    cumulative_mse = torch.cumsum(mse_per_step, dim=0) / torch.arange(1, n_steps + 1, device=mse_per_step.device)

    return real_segment, predicted_segment, mse_per_step, cumulative_mse


In [None]:
# Example: evolve for 2000 time steps
n_steps = 2000
start_idx = 100  # Start from index 100 in the validation data

# Inference only: with autograd disabled, the rollout does not retain the
# computation graph of every step in memory.
with torch.no_grad():
    real_segment, predicted_segment, mse_per_step, cumulative_mse = compare_traces(
        model, val_data, val_stim[:, :1736], n_steps, start_idx
    )

print(f"Shape of real segment: {real_segment.shape}")
print(f"Shape of predicted segment: {predicted_segment.shape}")
print(f"Final MSE (averaged over all neurons): {mse_per_step[-1].item():.6f}")
print(f"Final cumulative MSE: {cumulative_mse[-1].item():.6f}")

In [None]:
# Visualize MSE growth over time: per-step MSE on the left, running average on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (series, extra line kwargs, y label, title) for each panel, left to right
panels = [
    (mse_per_step, {}, 'MSE (averaged over neurons)', 'MSE at Each Time Step'),
    (cumulative_mse, {'color': 'orange'}, 'Cumulative Average MSE', 'Cumulative Average MSE Over Time'),
]
for ax, (series, line_kwargs, ylabel, title) in zip(axes, panels):
    ax.plot(series.detach().cpu().numpy(), marker='o', markersize=3, **line_kwargs)
    ax.set_xlabel('Time Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()

In [None]:
# Visualize traces for specific cell types
# Cell types picked from GNN_PlotFigure.py line 4824
target_cell_types = ['R1', 'R7', 'C2', 'Mi11', 'Tm1', 'Tm4', 'Tm30']

# One randomly chosen neuron per requested cell type, looked up through the
# position of the type name in TYPE_NAMES.
neuron_indices = [
    np.random.choice(neuron_data.indices_per_type[neuron_data.TYPE_NAMES.index(cell_type)])
    for cell_type in target_cell_types
]

# One stacked panel per selected neuron; squeeze=False keeps `axes` an array even
# when a single panel is requested.
num_neurons_to_plot = len(neuron_indices)
fig, axes = plt.subplots(
    num_neurons_to_plot, 1, figsize=(18, 2.5 * num_neurons_to_plot), squeeze=False
)
axes = axes[:, 0]

for ax, neuron_idx in zip(axes, neuron_indices):
    # Recover the type name from the neuron's own type id
    cell_type_name = neuron_data.TYPE_NAMES[neuron_data.type[neuron_idx]]

    real_trace_cpu = real_segment[:, neuron_idx].detach().cpu().numpy()
    pred_trace_cpu = predicted_segment[:, neuron_idx].detach().cpu().numpy()

    ax.plot(real_trace_cpu, label='Real', linewidth=2, alpha=0.7)
    ax.plot(pred_trace_cpu, label='Predicted', linewidth=2, alpha=0.7, linestyle='--')
    ax.set_xlabel('Time Step')
    ax.set_ylabel('Voltage')
    ax.set_title(f'{cell_type_name} (neuron {neuron_idx})')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()

In [None]:
# Spot-check a single entry of the connectivity matrix (row 1736, column 0).
wmat[1736, 0]