In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import scienceplots
import os, pickle

import torch

from spks import viz
from spks import event_aligned as ea

from joblib import Parallel, delayed

from lib import data, spks_utils, fitlvm_utils, fit_models, models, utils

# Enable IPython's autoreload extension so edits to imported modules
# (e.g. the local `lib` package) are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2
# NOTE(review): this looks redundant right after %load_ext above -- confirm
# whether it is still needed before removing it.
%reload_ext autoreload

In [None]:
# pretty plots
# The 'nature' style sheet comes from the `scienceplots` package imported in
# the first cell.
plt.style.use(['nature'])
# Render inline figures at a higher resolution.
plt.rcParams['figure.dpi'] = 200
# NOTE(review): on older IPython versions activating the inline backend can
# reset some rcParams (including figure.dpi) -- confirm the dpi setting above
# survives this line in the environment used.
%matplotlib inline
# Do not let the inline backend crop figures with bbox_inches='tight'; figures
# are emitted at their declared size.
%config InlineBackend.print_figure_kwargs = {'bbox_inches':None}

## One Session

In [None]:
# Fit one model per latent dimensionality (1..8) for a single session, in
# parallel. The range stops at 8 because every downstream cell consumes exactly
# eight fits (range(8), np.arange(1, 9)); with n_jobs=8 a ninth fit would force
# a second batch and its result would be discarded.
# NOTE(review): unit_spike_times, trial_data, session_data, regions, subj_idx
# and sess_idx are not defined in the visible cells -- they must come from an
# earlier cell for this to run on a fresh kernel.
out = Parallel(n_jobs=8, backend='loky')(
    delayed(fit_models.fit_sess)(
        unit_spike_times,
        trial_data,
        session_data,
        regions,
        subj_idx,
        sess_idx,
        num_latents=num_latents,
    )
    for num_latents in range(1, 8 + 1)
)

In [None]:
# Split each fit result into its first element (kept in `das`) and its second
# element (kept in `figs`, presumably the diagnostic figures -- TODO confirm
# against fit_models.fit_sess). Only the first eight fits are used.
das  = [out[latent_idx][0] for latent_idx in range(8)]
figs = [out[latent_idx][1] for latent_idx in range(8)]

# Sanity check: the i-th affine model must expose exactly i + 1 gain latents.
# np.array_equal handles a length mismatch cleanly (returns False) instead of
# relying on list == ndarray broadcasting.
gain_latent_counts = [
    fit['affine']['model'].gain_mu.get_weights().shape[1] for fit in das
]
assert np.array_equal(gain_latent_counts, np.arange(1, len(das) + 1)), (
    f"unexpected number of gain latents per fit: {gain_latent_counts}"
)

In [None]:
# Transposed gain-readout weights of the second fit in `das` (index 1).
# NOTE(review): presumably this is the 2-latent model, giving one row per unit
# and one column per latent -- confirm against the model definition.
coupling = (
    das[1]['affine']['model']
    .readout_gain
    .weight
    .data[:]
    .T
)

In [None]:
# Compare the two coupling columns for the rows belonging to one region.
# NOTE(review): `reg` is assigned but never used in this cell; the selection
# below is driven by the hard-coded key 0, which presumably encodes 'ACC' --
# confirm against wherever `reg_keys` is built (not defined in the visible cells).
reg = 'ACC'
# Row indices whose region key equals 0.
idxs = np.where(reg_keys==0)[0]
coupling_reg = coupling[idxs]
# Displayed as the cell output: elementwise equality of the first two columns.
coupling_reg[:,0] == coupling_reg[:,1]

### Grid Search

In [None]:
def fit_sess_latents_gs(subj_idx, sess_idx, max_latents=8, n_jobs=8):
    """Fit one session once per latent dimensionality, in parallel.

    Runs ``fit_models.fit_sess`` for ``num_latents = 1 .. max_latents`` on the
    session selected by ``subj_idx`` / ``sess_idx``. The per-session inputs are
    read from the notebook-level containers ``unit_spike_times_all``,
    ``trial_data_all``, ``session_data_all`` and ``regions_all``, each indexed
    as ``[subj_idx][sess_idx]``.

    Parameters
    ----------
    subj_idx : int
        Index of the subject in the ``*_all`` containers.
    sess_idx : int
        Index of the session within that subject.
    max_latents : int, optional
        Largest number of latents to fit (inclusive). Defaults to 8.
    n_jobs : int, optional
        Number of joblib workers. Defaults to 8.

    Returns
    -------
    list
        One ``fit_models.fit_sess`` result per value of ``num_latents``.
    """
    # Look the session inputs up once instead of once per submitted job.
    unit_spike_times = unit_spike_times_all[subj_idx][sess_idx]
    trial_data = trial_data_all[subj_idx][sess_idx]
    session_data = session_data_all[subj_idx][sess_idx]
    regions = regions_all[subj_idx][sess_idx]

    return Parallel(n_jobs=n_jobs, backend='loky')(
        delayed(fit_models.fit_sess)(
            unit_spike_times,
            trial_data,
            session_data,
            regions,
            subj_idx,
            sess_idx,
            num_latents=num_latents,
        )
        for num_latents in range(1, max_latents + 1)
    )

In [None]:
# out = fit_sess_latents_gs(subj_idx=0, sess_idx=3) # 230m 32s
# das = [out[latent_idx][0] for latent_idx in range(8)]
# NOTE(review): with the two lines above commented out, `das` is whatever an
# earlier cell left in the kernel -- confirm it really is the subject 0 /
# session 3 fit before overwriting the cache file.
# Save next to the other cached results: the matching load reads from
# 'vars/das_03_latents.npy', so writing to the working directory would leave
# that load pointing at a missing or stale file.
os.makedirs('vars', exist_ok=True)
np.save('vars/das_03_latents.npy', das)

In [None]:
# Alternative source (single-session cache):
# das = np.load('vars/das_03_latents.npy', allow_pickle=True)
# Load the cached grid-search results and keep only the entry at [0][3]
# (first index 0, second index 3).
# allow_pickle=True unpickles arbitrary Python objects, so only load files
# produced by this project.
das = np.load('vars/das_gs_011526.npy', allow_pickle=True)[0][3]

In [None]:
# Mean test R2 per fit, converted to plain Python floats so numpy never has to
# coerce torch tensors implicitly.
# The 'tv' entry is averaged over all fits and plotted as the single x = 0
# point; the 'affine' entries follow at x = 1 .. len(das).
r2s_tv = [np.mean([torch.mean(fit['tv']['r2test']).item() for fit in das])]
r2s_affine = [torch.mean(fit['affine']['r2test']).item() for fit in das]

fig, ax = plt.subplots()
# x positions are derived from the number of fits instead of a hard-coded 0..8.
ax.plot(range(len(das) + 1), np.concatenate((r2s_tv, r2s_affine)))
ax.set_xlabel("Number of Latents"); ax.set_ylabel("R2 of Affine Model")
# The plotted fits are the [0][3] entry of the grid search (subject 0,
# session 3), so the title must name session index 3 as well.
fig.suptitle(f"{data.subject_ids[0]}, {data.session_ids[0][3]}")
fig.tight_layout()