In [None]:
from pathlib import Path
import torch.multiprocessing as mp
from fnn.data import load_training_data
from fnn.microns.build import network
from fnn.microns import load_network_from_params
from fnn.train.schedulers import CosineLr
from fnn.train.optimizers import SgdClip
from fnn.train.loaders import Batches
from fnn.train.objectives import NetworkLoss
from fnn import microns
from fnn.utils import logging
import torch
import pandas as pd
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import polars as pl
import rerun as rr
import torchinfo
import sklearn.decomposition
import sklearn.manifold
import scipy.signal
import rastermap
import mediapy as mp

# Downloaded data paths

In [None]:

# Which recording to analyse. The training/evaluation folders below hold the
# already-upsampled data for this session/scan, ready for training/evaluation.
session, scan_idx = 4, 7

# Trained model parameters
src_dir = "/groups/saalfeld/saalfeldlab/vijay/fnn/data/microns_digital_twin/params"

# Per-unit anatomy table (CSV)
unit_anatomy_path = "/groups/saalfeld/saalfeldlab/vijay/fnn/data/microns_digital_twin/properties/anatomy/units.csv"

# Training data (plain string)
training_data_dir = f"/groups/saalfeld/saalfeldlab/vijay/fnn/training_data_27203_{session}_{scan_idx}"

# Evaluation data (Path object)
evaluation_data_dir = Path(f"/groups/saalfeld/saalfeldlab/vijay/fnn/evaluation_data_27203_{session}_{scan_idx}")

In [None]:
# Drop cached CUDA allocations, then choose the GPU when one is available.
torch.cuda.empty_cache()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Keep only the anatomy rows belonging to the selected session and scan.
_is_this_scan = (pl.col("session") == session) & (pl.col("scan_idx") == scan_idx)
unit_anatomy_df = pl.read_csv(unit_anatomy_path).filter(_is_this_scan)
print(unit_anatomy_df.shape)
unit_anatomy_df["brain_area"]

# Training data

In [None]:
# Load the training set for this scan.
# NOTE(review): max_items=None presumably means "no limit" — confirm against load_training_data.
max_items = None
dataset = load_training_data(training_data_dir, max_items=max_items)


In [None]:
# `dataset` is a custom object that wraps a dataframe (exposed as `.df`); preview its first rows.
dataset.df.head()

In [None]:
# Load the model and initialize weights


In [None]:
# Build the network and its unit map for this session/scan from `src_dir`,
# then move the network to the selected device.
model, unit_map = microns.scan(session=session, scan_idx=scan_idx, directory=src_dir)
model = model.to(device)

In [None]:
# Print a module summary of the network, limited to nesting depth 2.
torchinfo.summary(model, depth=2)

In [None]:
# # Pick a stimulus
# stimuli = np.concatenate([x[:] for x in dataset.df.stimuli], axis=0)
# perspectives = np.concatenate([x[:] for x in dataset.df.perspectives], axis=0)
# modulations = np.concatenate([x[:] for x in dataset.df.modulations], axis=0)

# with torch.no_grad():
#     res = model.to_tensor(stimuli, perspectives, modulations)
#     exported = torch.export.export(model, res[:3])

# Run model

In [None]:
# Pick one training sample (row `index`) and take a full slice of each of its fields.
index = 10
stimuli, perspectives, modulations, units = (
    dataset.df[field].iloc[index][:]
    for field in ("stimuli", "perspectives", "modulations", "units")
)

In [None]:
# Run the network on the selected sample (stimulus + perspective + modulation inputs).
pred = model.predict(stimuli, perspectives, modulations)

In [None]:
# 8509 units were measured, but the network only predicts 7493 of them.
# Predictions therefore cover a subset of units; `unit_map` records which ones.
# Compare the three shapes to see this directly.
pred.shape, unit_map.shape, units.shape

In [None]:
# Display the dataset's underlying dataframe.
dataset.df

In [None]:
# `unit_map` is indexed by readout id (the network's output position); its
# "unit_id" column points at the experimentally measured unit.
# Check the index is exactly 0..N-1 so row order matches the prediction columns.
assert np.array_equal(unit_map.index, np.arange(len(unit_map)))
# Reorder/select measured units so column k lines up with prediction column k.
unit_gt = units[:, unit_map["unit_id"]]

In [None]:

# Predicted vs. measured activity as side-by-side rasters (neurons x time)
# sharing one colorbar.
fig = plt.figure(figsize=(8, 6))

grid = ImageGrid(
    fig, 111,  # similar to subplot(111)
    nrows_ncols=(1, 2),
    axes_pad=0.3,
    aspect=False,
    share_all=True,  # share x and y axes between all subplots
    cbar_mode="single",  # use a single colorbar for the grid
    cbar_location="right",
    cbar_pad=0.1,
)
ax = grid.axes_all
# Use the canonical "Greys_r" name (as the later raster cells in this notebook
# do); the "Grays" spelling is not available in every Matplotlib release.
ax[0].imshow(pred.T, aspect="auto", vmax=5, cmap="Greys_r")
ax[0].set_title("Predicted")
im = ax[1].imshow(unit_gt.T, aspect="auto", vmax=5, cmap="Greys_r")
ax[1].set_title("Measured")
# The colorbar is built from the measured panel's image.
grid.cbar_axes[0].colorbar(im)
ax[0].set_xlabel("Time")
ax[1].set_xlabel("Time")
ax[0].set_ylabel("Neurons")


In [None]:
# Play the stimulus, taking index 0 of the last axis, at 10 fps.
# `mp` is mediapy here: its import comes after (and so overrides) the torch.multiprocessing alias.
mp.show_video(stimuli[:, :, :, 0], fps=10)

In [None]:
# Fit Rastermap on the measured activity (transposed so neurons are rows);
# `rmap_gt.isort` is used below to order neurons in the rasters.
rmap_gt = rastermap.Rastermap(n_clusters=30)
rmap_gt.fit(unit_gt.T)


In [None]:

# Same comparison as before, but with neurons ordered by the Rastermap fit on
# the measured activity (`rmap_gt.isort`), applied to both panels.
fig = plt.figure(figsize=(8, 6))

grid = ImageGrid(
    fig, 111,  # similar to subplot(111)
    nrows_ncols=(1, 2),
    axes_pad=0.3,
    aspect=False,
    share_all=True,  # share x and y axes between all subplots
    cbar_mode="single",  # use a single colorbar for the grid
    cbar_location="right",
    cbar_pad=0.1,
)
ax = grid.axes_all
# Canonical "Greys_r" name; the "Grays" spelling is not available in every
# Matplotlib release and the later cells already use "Greys_r".
ax[0].imshow(pred.T[rmap_gt.isort, :], aspect="auto", vmax=5, cmap="Greys_r")
ax[0].set_title("Predicted")
im = ax[1].imshow(unit_gt.T[rmap_gt.isort, :], aspect="auto", vmax=5, cmap="Greys_r")
ax[1].set_title("Measured (rastermap source)")
grid.cbar_axes[0].colorbar(im)
ax[0].set_xlabel("Time")
ax[1].set_xlabel("Time")
ax[0].set_ylabel("Neurons")


In [None]:

# Two stacked panels with shared axes: the first two columns of the modulation
# input on top, the first two columns of the perspective input below.
fig = plt.figure(figsize=(8, 6))
grid = ImageGrid(
    fig, 111,  # similar to subplot(111)
    nrows_ncols=(2, 1),
    axes_pad=0.3,
    aspect=False,
    share_all=True,  # share x and y axes between all subplots
)
ax = grid.axes_all

for panel, signal, title in (
    (ax[0], modulations, "Modulation"),
    (ax[1], perspectives, "Perspective"),
):
    for column in (0, 1):
        panel.plot(signal[:, column])
    panel.set_title(title)

for panel in (ax[0], ax[1]):
    panel.grid(True)
    panel.set_xlim(0, 300)

In [None]:
# Measured activity alone, neurons ordered by the Rastermap fit on measured data.
plt.figure(figsize=(8, 6))
# Canonical "Greys_r" name; the "Grays" spelling is not available in every
# Matplotlib release.
plt.imshow(unit_gt.T[rmap_gt.isort, :], aspect="auto", vmax=5, cmap="Greys_r")
plt.title("Measured activity (rastermap sorted)")

In [None]:

# Fit a second Rastermap, this time on the predicted activity (neurons as rows),
# and show the predictions in that order.
rmap_pred = rastermap.Rastermap(n_clusters=100)
rmap_pred.fit(pred.T)

# Canonical "Greys_r" name; the "Grays" spelling is not available in every
# Matplotlib release.
plt.imshow(pred.T[rmap_pred.isort, :], aspect="auto", vmax=5, cmap="Greys_r")

In [None]:

# Predicted vs. measured rasters with neurons ordered by the Rastermap fit on
# the predictions (`rmap_pred.isort`), applied to both panels.
fig = plt.figure(figsize=(8, 6))

grid = ImageGrid(
    fig, 111,  # similar to subplot(111)
    nrows_ncols=(1, 2),
    axes_pad=0.3,
    aspect=False,
    share_all=True,  # share x and y axes between all subplots
    cbar_mode="single",  # use a single colorbar for the grid
    cbar_location="right",
    cbar_pad=0.1,
)
ax = grid.axes_all
# Canonical "Greys_r" name; the "Grays" spelling is not available in every
# Matplotlib release and the later cells already use "Greys_r".
ax[0].imshow(pred.T[rmap_pred.isort, :], aspect="auto", vmax=5, cmap="Greys_r")
ax[0].set_title("Predicted (rastermap source)")
im = ax[1].imshow(unit_gt.T[rmap_pred.isort, :], aspect="auto", vmax=5, cmap="Greys_r")
ax[1].set_title("Measured")
grid.cbar_axes[0].colorbar(im)
ax[0].set_xlabel("Time")
ax[1].set_xlabel("Time")
ax[0].set_ylabel("Neurons")


In [None]:
# Population-average trace over time: measured first, then predicted.
for activity, name in ((unit_gt, "measured"), (pred, "prediction")):
    plt.plot(activity.mean(1), label=name)
plt.legend()
plt.xlabel("time")
plt.ylabel("Averaged signal over neurons")

In [None]:
# Sample count from scaling 300 by 6.3/30. Not used by the decimation below.
# NOTE(review): presumably 300 frames at 30 Hz resampled to a 6.3 Hz scan rate — confirm.
nsamples = int(300*6.3/30)

# Low-pass filter and keep every 5th sample along the time axis (axis 0),
# for measured and predicted activity alike.
unit_gt_ds, pred_ds = (
    scipy.signal.decimate(activity, q=5, axis=0) for activity in (unit_gt, pred)
)

In [None]:

# Predicted vs. measured rasters after decimation in time, with a common
# colour range (vmin=1, vmax=5) and a single colorbar.
fig = plt.figure(figsize=(8, 6))

grid = ImageGrid(
    fig, 111,  # similar to subplot(111)
    nrows_ncols=(1, 2),
    axes_pad=0.3,
    aspect=False,
    share_all=True,  # share x and y axes between all subplots
    cbar_mode="single",  # use a single colorbar for the grid
    cbar_location="right",
    cbar_pad=0.1,
)
ax = grid.axes_all
# Canonical "Greys_r" name; the "Grays" spelling is not available in every
# Matplotlib release and the later cells already use "Greys_r".
ax[0].imshow(pred_ds.T, aspect="auto", vmin=1, vmax=5, cmap="Greys_r")
ax[0].set_title("Predicted")
im = ax[1].imshow(unit_gt_ds.T, aspect="auto", vmin=1, vmax=5, cmap="Greys_r")
ax[1].set_title("Measured")
grid.cbar_axes[0].colorbar(im)
ax[0].set_xlabel("Neuron recording frames")
ax[1].set_xlabel("Neuron recording frames")
ax[0].set_ylabel("Neurons")


In [None]:
# Population-average trace of the decimated signals: measured, then predicted.
for activity, name in ((unit_gt_ds, "measured"), (pred_ds, "prediction")):
    plt.plot(activity.mean(1), label=name)
plt.legend()
plt.xlabel("Neuronal recording frames")
plt.ylabel("Averaged signal over neurons")

In [None]:
# Overlay predicted (solid) and measured (dotted, same colour) traces for five
# randomly drawn units.
# np.random.seed(0)  # uncomment for a reproducible draw
# Draw indices from the actual number of predicted units instead of a
# hard-coded count, so the cell keeps working for other scans/models.
n_units = pred.shape[1]
idxs = np.random.randint(0, n_units, size=5)
plt.figure(figsize=(12, 8))
for i in idxs:
    line = plt.plot(pred[:, i], label=f"unit={i} (pred)")[-1]
    plt.plot(unit_gt[:, i], ls="dotted", label=f"unit={i} (meas)", c=line.get_color())
plt.legend(bbox_to_anchor=(1, 1))
plt.xlabel("time")
plt.ylabel("Ca measurement")

In [None]:
# Per-unit peak-to-mean ratio over time, for measured and predicted activity.
_ratio_meas = unit_gt.max(0) / np.mean(unit_gt, axis=0)
_ratio_pred = pred.max(0) / np.mean(pred, axis=0)

# The second histogram reuses the bin edges (res[1]) of the first so both share bins.
res = plt.hist(np.sort(_ratio_meas), bins=70, histtype="step", label="measured")
plt.hist(np.sort(_ratio_pred), bins=res[1], histtype="step", label="prediction")
plt.legend()
plt.xlim(0, 100)
plt.xlabel("Max signal over time/mean signal over time")

# Evaluation data (repeat stimuli)

In [None]:
# Imports used by the evaluation section (Path is also imported in the first cell).
from fnn.data import load_evaluation_data
from fnn import evaluate
from pathlib import Path


# Load the repeat-stimulus evaluation set from the configured directory.
evaluation_data = load_evaluation_data(evaluation_data_dir)

In [None]:
# Rebuild the network and unit map (same arguments as the earlier load) and
# place the network on the selected device.
model, unit_map = microns.scan(session=session, scan_idx=scan_idx, directory=src_dir)
model = model.to(device)

## Readout weights

In [None]:
# Collect every parameter tensor of the readout's "feature" sub-module, join
# them along dim 1, and flatten to a 2-D array with 512 columns.
_feature_params = list(model.module("readout").module("feature").parameters())
_stacked = torch.concatenate(_feature_params, axis=1)
weights = _stacked.reshape((-1, 512)).detach().cpu().numpy()

In [None]:
# Project the readout weights onto their first 50 principal components.
# A fixed random_state keeps the projection reproducible across kernel
# restarts whenever scikit-learn picks a randomized SVD solver.
pca = sklearn.decomposition.PCA(n_components=50, random_state=0)
proj = pca.fit_transform(weights)

In [None]:



# 2-D t-SNE embedding of the PCA-reduced readout weights.
# t-SNE is stochastic; fix random_state so the embedding (and every plot built
# on it) is reproducible under Restart & Run All.
tsne = sklearn.manifold.TSNE(n_components=2, random_state=0)
tsne_proj = tsne.fit_transform(proj)

In [None]:
# Readout weights in the plane of the first two principal components.
_pc1, _pc2 = proj[:, 0], proj[:, 1]
plt.scatter(_pc1, _pc2, marker=".", alpha=0.2)
plt.xlabel("PC1")
plt.ylabel("PC2")

In [None]:
# Readout weights in the 2-D t-SNE embedding.
_ts1, _ts2 = tsne_proj[:, 0], tsne_proj[:, 1]
plt.scatter(_ts1, _ts2, marker=".", alpha=0.2)
plt.xlabel("TSNE1")
plt.ylabel("TSNE2")

In [None]:
# Attach anatomy to each readout unit (left join keeps every unit_map row),
# then colour the t-SNE embedding by brain area.
joined = pl.DataFrame(unit_map.reset_index()).join(
    unit_anatomy_df,
    on=["session", "scan_idx", "unit_id"],
    how="left",
)

_areas = joined.sort("brain_area").group_by("brain_area", maintain_order=True)
for (grp,), piece in _areas:
    # readout_id indexes rows of tsne_proj
    ixs = piece["readout_id"].to_numpy()
    _points = tsne_proj[ixs]
    plt.scatter(_points[:, 0], _points[:, 1], marker=".", alpha=0.2, label=grp)

plt.legend()
plt.xlabel("TSNE1")
plt.ylabel("TSNE2")


## Run on evaluation data

In [None]:
# List the fields available in the evaluation set.
evaluation_data.keys()

In [None]:
from tqdm import tqdm

In [None]:
# Evaluation inputs, indexed as [video][repeat].
s = evaluation_data['stimuli']
p = evaluation_data['perspectives']
m = evaluation_data['modulations']

# Predict every repeat of every video; the result is a nested list
# units_pred[video][repeat], with a progress bar over videos.
units_pred = [
    [
        model.predict(stimuli=s[i][j], perspectives=p[i][j], modulations=m[i][j])
        for j in range(len(s[i]))
    ]
    for i in tqdm(range(len(s)), desc="Stimuli")
]

In [None]:
# Convert the nested list of predictions into a single array (rebinds `units_pred`).
units_pred = np.array(units_pred)

In [None]:
# Axis order: videos x repeat trials x time x neurons
units_pred.shape

In [None]:
# Measured responses as one array, with the last axis reordered/selected through
# unit_map["unit_id"] so it lines up with the network's readout order.
units_meas = np.array(evaluation_data['units'])[:, :, :, unit_map["unit_id"]]
units_meas.shape

## Rastermap analysis

In [None]:
# Pick a video (index into the first axis of units_meas / units_pred)
video_ix = 0

# Which trial to use for sorting (index into the repeat axis)
trial_ix = 0

In [None]:
# For each of the six videos take repeat `trial_ix`, keep index 0 of the last
# axis, and join the clips along their final remaining axis into one video.
# NOTE(review): presumably frames are (time, height, width, channel), so the
# clips end up side by side — confirm.
_clips = [evaluation_data['stimuli'][i][trial_ix][:, :, :, 0] for i in range(6)]
stimulus_videos = np.concatenate(_clips, axis=-1)
mp.show_video(stimulus_videos, fps=30)
# stimulus_videos.shape

In [None]:
# Re-read the anatomy table restricted to the selected session and scan.
_is_this_scan = (pl.col("session") == session) & (pl.col("scan_idx") == scan_idx)
unit_anatomy_df = pl.read_csv(unit_anatomy_path).filter(_is_this_scan)
print(unit_anatomy_df.shape)

# Left-join anatomy onto the unit map so every readout unit keeps its row.
joined = pl.DataFrame(unit_map.reset_index()).join(
    unit_anatomy_df,
    on=["session", "scan_idx", "unit_id"],
    how="left",
)
joined.head()

In [None]:
# labels = joined["brain_area"][order]
# unique_labels = labels.unique()
# colors = plt.cm.Set1(np.linspace(0, 1, len(unique_labels)))
# color_map = dict(zip(unique_labels, colors))
# bar_colors = [color_map[label] for label in labels]

# plt.bar(range(len(labels)), [1]*len(labels), color=bar_colors, width=1.0)
# handles = [plt.Rectangle((0,0),1,1, color=color_map[label]) for label in unique_labels]
# plt.legend(handles, unique_labels)
# plt.show()

In [None]:
# Shape of the measured responses for the selected video (repeat trials x time x neurons).
units_meas[video_ix].shape

In [None]:
# Gut check of the reshape used below: 3 trials x 2 time steps x 7 neurons,
# where trial k is filled with the value k (0, 1, 2).
x = np.stack([np.full((2, 7), float(trial)) for trial in range(3)], axis=0)
print(x.shape)

# After flattening and transposing, each neuron's row lists trial 0's time
# steps first, then trial 1's, then trial 2's.
x.reshape((-1, 7)).T

In [None]:
# For every video, order neurons by the time of their peak response within
# trial `trial_ix`. Result: orders[video] is a permutation of neuron indices.
#
# Changes from the earlier version of this cell:
#   * the loop no longer reuses (and clobbers) the notebook-level `video_ix`;
#   * video / frame / neuron counts come from units_meas.shape rather than
#     the hard-coded 6 / 300 / 7493.
# units_meas[v, trial_ix] is the (time x neurons) block that the reshape-and-
# slice form selected, so argmax over axis 0 gives each neuron's peak frame.
orders = np.array([
    np.argsort(np.argmax(units_meas[v, trial_ix], axis=0))
    for v in range(units_meas.shape[0])
])

In [None]:
import rerun as rr

In [None]:
# Named colours as uint8 RGB rows of shape (1, 3).
grey, red = (
    (np.array(mpl.colors.to_rgb(name)) * 255).astype(np.uint8)[np.newaxis]
    for name in ("grey", "red")
)


In [None]:
# Animate peak order in 3-D: all units are logged once as a static grey cloud,
# then one red marker per step visits the units in the order computed for the
# first video (orders[:1]).
rr.init("animate")
npos = joined.select(["unit_x", "unit_y", "unit_z"]).to_numpy()

# One colour row per plotted point; derive the count from the data instead of
# hard-coding 7493 so colours and positions can never disagree in length.
n_points = npos.shape[0]
rr.log("all", rr.Points3D(npos, colors=np.repeat(grey, n_points, axis=0), radii=2), static=True)

# Step through every rank in the ordering (orders has one column per neuron).
for i in range(orders.shape[1]):
    rr.set_time("peak", sequence=i)
    rr.log("neuron", rr.Points3D(npos[orders[:1, i]], colors=red, radii=[20]))
rr.notebook_show()

In [None]:
# Fit Rastermap (default settings) on one repeat of one video of the measured
# responses, neurons as rows.
video_ix, rmap_ix = 4, 0

rmap = rastermap.Rastermap()
rmap.fit(units_meas[video_ix, rmap_ix].T)


In [None]:

# All repeats of the selected video laid end to end along x (neurons x time),
# rows ordered by the current Rastermap fit, drawn at one pixel per sample.
# Trial / frame / neuron counts are read from the array instead of the
# hard-coded 10 / 300 / 7493.
n_trials, n_frames, n_units = units_meas[video_ix].shape
mat = units_meas[video_ix].reshape((-1, n_units)).T
order = rmap.isort
rows, cols = mat.shape
dpi = 100
# figsize in inches chosen so that, at `dpi`, the figure is cols x rows pixels.
plt.figure(figsize=(cols/dpi, rows/dpi))
plt.imshow(mat[order, :], vmin=1, vmax=10, cmap="Greys_r", aspect="auto")
# Dashed line at every boundary between consecutive repeats.
for k in range(1, n_trials):
    plt.axvline(k*n_frames, color="orange", ls="dashed")
plt.title(f"Video {video_ix+1}/6: Trials stacked along x")
plt.xlabel("Time")
plt.ylabel("Neurons")
# NOTE(review): the file name is fixed and does not follow `video_ix`.
plt.savefig("rastermap_sorted_neuron_activity_video4.png", dpi=dpi, bbox_inches="tight")




In [None]:

# Same stacked raster, keeping every 20th neuron of the sorted matrix.
# Uses `order` as left by the previous raster cell. Trial / frame / neuron
# counts are read from the array instead of the hard-coded 10 / 300 / 7493.
n_trials, n_frames, n_units = units_meas[video_ix].shape
mat = units_meas[video_ix].reshape((-1, n_units)).T
rows, cols = mat.shape
dpi = 100
# Figure size still follows the full (not downsampled) matrix.
plt.figure(figsize=(cols/dpi, rows/dpi))
plt.imshow(mat[order, :][::20, :], vmin=1, vmax=10, cmap="Greys_r", aspect="auto")
# Dashed line at every boundary between consecutive repeats.
for k in range(1, n_trials):
    plt.axvline(k*n_frames, color="orange", ls="dashed")
plt.title(f"Video {video_ix+1}/6: Trials stacked along x")
plt.xlabel("Time")
plt.ylabel("Neurons - downsampled 20x")
# NOTE(review): the file name is fixed and does not follow `video_ix`.
plt.savefig("rastermap_sorted_neuron_activity_video4_ds.png", dpi=dpi, bbox_inches="tight")
    



In [None]:
# Sorted activity clipped to [0, 10]; trace three rows of it across all repeats.
omat = np.clip(mat[order, :], 0, 10)
for row in (0, 100, 200):
    plt.plot(omat[row])
# Dashed separators at 300, 600, ..., 2700.
for boundary in range(300, 3000, 300):
    plt.axvline(boundary, color="k", ls="dashed")
plt.xlim(0, 3000)
plt.ylim(0, 10)

In [None]:
# Peak-time ordering for video 0, then a scatter of the first five neurons in
# that order across all stacked repeats (only values above 3 are in view).
# Sizes come from the array instead of the hard-coded 7493 / 300 / 3001.
video_ix = 0
n_trials, n_frames, n_units = units_meas[video_ix].shape
mat = units_meas[video_ix].reshape((-1, n_units)).T
# Frame of peak response within trial `trial_ix`, per neuron.
j = np.argmax(mat[:, trial_ix*n_frames:(trial_ix+1)*n_frames], axis=1)

order = np.argsort(j)

mat2 = mat[order, :]

# NOTE(review): `ixs` is drawn but not used by the loop below, which plots rows
# 0..4 of the sorted matrix. Confirm whether mat2[ixs[i]] was intended.
ixs = np.sort(np.random.randint(0, n_units, size=5))
time_axis = np.arange(1, mat2.shape[1] + 1)
for i in range(5):
    plt.scatter(time_axis, mat2[i, :], marker=".")

plt.ylim(3, None)


In [None]:
# Display the readout-id -> unit-id mapping table.
unit_map

In [None]:

# For a fixed random selection of units, stack pairs of (trials x time) images:
# measured response on the even row, predicted response directly beneath it,
# both drawn with the same colour limits.
fig = plt.figure(figsize=(8, 6))

# Number of units shown; the grid needs two rows per unit.
n_show = 10
grid = ImageGrid(
    fig, 111,  # similar to subplot(111)
    nrows_ncols=(2*n_show, 1),
    axes_pad=(0., 0.),
    aspect=False,
    share_all=True,  # share x and y axes between all subplots
)
ax = grid.axes_all
np.random.seed(123)
# Sample from the real number of units rather than a hard-coded 7493.
ixs = np.random.choice(units_meas.shape[-1], size=n_show)
for i, ix in enumerate(ixs):
    umeas = units_meas[video_ix, :, :, ix]
    upred = units_pred[video_ix, :, :, ix]
    # Upper colour limit: 98th percentile of this unit's measured response.
    vmax = np.percentile(umeas, 98)
    im = ax[2*i].imshow(umeas, vmin=1, vmax=vmax, cmap="Greys_r")
    vmin, vmax = im.get_clim()
    print(vmin, vmax)
    # NOTE(review): the predicted row uses the default colormap while the
    # measured row uses "Greys_r" — confirm this contrast is intentional.
    ax[2*i+1].imshow(upred, vmin=vmin, vmax=vmax)
    ax[2*i].set_yticks([])
    ax[2*i+1].set_yticks([])

In [None]:

# Fit Rastermap on log1p-transformed measured responses of one repeat
# (`rmap_ix`) of the current video, neurons as rows.
rmap_ix = 0
rmap = rastermap.Rastermap()
rmap.fit(np.log1p(units_meas[video_ix, rmap_ix].T))


In [None]:
# Measured responses for the current video with all repeats stacked along x,
# neurons ordered by the time of their peak within repeat `rmap_ix`.
# Fixes the stray comma in the slice (`mat[:, , ...]`), which was a syntax
# error, and reads trial / frame / neuron counts from the array.
n_trials, n_frames, n_units = units_meas[video_ix].shape
mat = units_meas[video_ix].reshape((-1, n_units)).T

# Frame of peak response within the chosen repeat, per neuron.
j = np.argmax(mat[:, rmap_ix*n_frames:(rmap_ix+1)*n_frames], axis=1)
# Alternative kept for reference: peak of the rastermap-sorted, clipped matrix.
# j = np.argmax(np.clip(mat[rmap.isort, rmap_ix*300:(rmap_ix+1)*300], 0, 5), axis=1)
order = np.argsort(j)

plt.figure(figsize=(12, 9))
plt.imshow(mat[order, :], vmax=10, cmap="Greys_r", aspect="auto")
# Dashed line at every boundary between consecutive repeats.
for k in range(1, n_trials):
    plt.axvline(k*n_frames, color="orange", ls="dashed")
plt.xticks(np.arange(0, n_trials*n_frames, n_frames))
plt.title("Repeat stimuli stacked together")
plt.xlabel("Time")
plt.ylabel("Neurons")

In [None]:
# Predicted responses for the current video, repeats stacked along x, rows in
# the order of the current Rastermap fit. Trial / frame / neuron counts are
# read from the array instead of the hard-coded 10 / 300 / 7493.
n_trials, n_frames, n_units = units_pred[video_ix].shape
mat = units_pred[video_ix].reshape((-1, n_units)).T
plt.figure(figsize=(12, 9))
plt.imshow(mat[rmap.isort, :], vmax=5, cmap="Greys_r", aspect="auto")
# Dashed line at every boundary between consecutive repeats.
for k in range(1, n_trials):
    plt.axvline(k*n_frames, color="orange", ls="dashed")
plt.xticks(np.arange(0, n_trials*n_frames, n_frames))
plt.title("Predicted responses stacked")
plt.xlabel("Time")
plt.ylabel("Neurons")

In [None]:
# Refit Rastermap, this time on the predicted responses of repeat 4 of video 0
# (neurons as rows). Rebinds `video_ix`, `rmap_ix` and `rmap`.
video_ix, rmap_ix = 0, 4
rmap = rastermap.Rastermap()
rmap.fit(units_pred[video_ix, rmap_ix].T)


In [None]:
# Predicted responses, repeats stacked along x, rows ordered by the Rastermap
# that was just fit on predictions. Trial / frame / neuron counts are read
# from the array instead of the hard-coded 10 / 300 / 7493.
n_trials, n_frames, n_units = units_pred[video_ix].shape
mat = units_pred[video_ix].reshape((-1, n_units)).T
plt.figure(figsize=(12, 9))
plt.imshow(mat[rmap.isort, :], vmax=5, cmap="Greys_r", aspect="auto")
# Dashed line at every boundary between consecutive repeats.
for k in range(1, n_trials):
    plt.axvline(k*n_frames, color="orange", ls="dashed")
plt.xticks(np.arange(0, n_trials*n_frames, n_frames))
plt.title("Predicted responses stacked")
plt.xlabel("Time")
plt.ylabel("Neurons")

In [None]:
# Measured responses, repeats stacked along x, shown in the row order of the
# current Rastermap fit. Trial / frame / neuron counts are read from the array
# instead of the hard-coded 10 / 300 / 7493.
n_trials, n_frames, n_units = units_meas[video_ix].shape
mat = units_meas[video_ix].reshape((-1, n_units)).T
plt.figure(figsize=(12, 9))
plt.imshow(mat[rmap.isort, :], vmax=5, cmap="Greys_r", aspect="auto")
# Dashed line at every boundary between consecutive repeats.
for k in range(1, n_trials):
    plt.axvline(k*n_frames, color="orange", ls="dashed")
plt.xticks(np.arange(0, n_trials*n_frames, n_frames))
plt.title("Measured units stacked together")
plt.xlabel("Time")
plt.ylabel("Neurons")

In [None]:
# Control: fit Rastermap on a structure-free binary matrix in which each of
# the 7500 x 300 entries is 1 with probability 0.05.
rmap = rastermap.Rastermap()
_sim_shape = (7500, 300)
_uniform = np.random.random(_sim_shape)
sim_mat = (_uniform < 0.05).astype(np.float32)
rmap.fit(sim_mat)

In [None]:
# Show the simulated matrix with rows in the order Rastermap assigned to it.
plt.imshow(sim_mat[rmap.isort, :], aspect="auto")

In [None]:
# Display the readout-id -> unit-id mapping table again for reference.
unit_map