### Train models with different latent sizes

In [11]:
%load_ext autoreload
%autoreload 2

import s3fs
import pandas as pd
import os
import numpy as np

from lfads import LFADS_Net, LFADS_SingleSession_Net
import torch
import torch.optim as opt
from scheduler import LFADS_Scheduler
from plotter import Plotter
from trainer import RunManager

import yaml

from synthetic_data import LorenzSystem, EmbeddedLowDNetwork
from objective import SVLAE_Loss, LFADS_Loss, LogLikelihoodPoisson, LogLikelihoodPoissonSimplePlusL1, LogLikelihoodPoissonSimple, LogLikelihoodGaussian

from utils import load_parameters
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import pickle


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
species = 'nhp'
subject = 'SA'
exp = 'WCST'
session = 20180802  # this is the session for which there are spikes at the moment. 

NHP_WCST_DIR = 'nhp-lfp/wcst-preprocessed/'


# grab behavioral data, spike data, trial numbers. 
fs = s3fs.S3FileSystem()
behavior_file = os.path.join(NHP_WCST_DIR, "rawdata", "sub-" + str(subject), "sess-" + str(session), "behavior", "sub-" + str(subject) + "_sess-" + str(session) + "_object_features.csv")
behavior_data = pd.read_csv(fs.open(behavior_file))
valid_beh = behavior_data[behavior_data.Response.isin(["Correct", "Incorrect"])]   
valid_beh["PrevResponse"] = valid_beh["Response"].shift()  
valid_beh = valid_beh[valid_beh.TrialNumber >= 57]



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_beh["PrevResponse"] = valid_beh["Response"].shift()


In [13]:
pre_interval = 0
post_interval = 350
interval_size = 50
frs = pd.read_pickle(fs.open(f"l2l.pqz317.scratch/firing_rates_{pre_interval}_crossfixation_{post_interval}_{interval_size}_bins.pickle"))
frs = frs[frs.TrialNumber.isin(valid_beh.TrialNumber)]

In [14]:
# get data in trials x time x neurons
num_time_bins = len(frs["TimeBins"].unique())
num_units = len(frs["UnitID"].unique())
num_trials = len(frs["TrialNumber"].unique())
sorted = frs.sort_values(by=["UnitID", "TimeBins", "TrialNumber"])
# currently in neurons x time x trials
spike_data = sorted["SpikeCounts"].to_numpy().reshape((num_units, num_time_bins, num_trials))
# want trials x time x neurons
spike_data = np.transpose(spike_data, (2, 1, 0))


In [15]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'; print(device)
device = 'cpu'
train_idxs, valid_idxs = train_test_split(
    np.arange(spike_data.shape[1]), 
    train_size=0.8,
    random_state=42
)
train_spikes = spike_data[:, train_idxs, :]

valid_spikes = spike_data[:, valid_idxs, :]

train_data  = torch.Tensor(train_spikes).to(device)
valid_data  = torch.Tensor(valid_spikes).to(device)
train_ds    = torch.utils.data.TensorDataset(train_data)
valid_ds    = torch.utils.data.TensorDataset(valid_data)
train_dl    = torch.utils.data.DataLoader(train_ds, batch_size = 50, shuffle=True)
valid_dl    = torch.utils.data.DataLoader(valid_ds, batch_size = valid_data.shape[0])

In [16]:
hyperparameter_path = 'lfads_cross_fixation_hyperparams.yaml'
hyperparams = load_parameters(hyperparameter_path)

dt = 0.05
loglikelihood = LogLikelihoodPoisson(dt=dt, device=device)
objective = LFADS_Loss(
    loglikelihood=loglikelihood,
    loss_weight_dict={
        "kl": hyperparams['objective']["kl"],
        "l2": hyperparams["objective"]["l2"]},
    l2_con_scale=hyperparams["objective"]["l2_con_scale"],
    l2_gen_scale=hyperparams["objective"]["l2_gen_scale"]
).to(device)



In [22]:
latent_sizes = np.arange(1, 20)
max_epocs = 500

for latent_size in latent_sizes:
    model = LFADS_SingleSession_Net(
        input_size           = 58,
        factor_size          = hyperparams['model']['factor_size'],
        g_encoder_size       = hyperparams['model']['g_encoder_size'],
        c_encoder_size       = hyperparams['model']['c_encoder_size'],
        g_latent_size        = int(latent_size),
        u_latent_size        = hyperparams['model']['u_latent_size'],
        controller_size      = hyperparams['model']['controller_size'],
        generator_size       = hyperparams['model']['generator_size'],
        prior                = hyperparams['model']['prior'],
        clip_val             = hyperparams['model']['clip_val'],
        dropout              = hyperparams['model']['dropout'],
        do_normalize_factors = hyperparams['model']['normalize_factors'],
        max_norm             = hyperparams['model']['max_norm'],
        device               = device
    ).to(device)
    optimizer = opt.Adam([p for p in model.parameters() if p.requires_grad],
                        lr=hyperparams['optimizer']['lr_init'],
                        betas=hyperparams['optimizer']['betas'],
                        eps=hyperparams['optimizer']['eps'])

    scheduler = LFADS_Scheduler(optimizer      = optimizer,
                                mode           = 'min',
                                factor         = hyperparams['scheduler']['scheduler_factor'],
                                patience       = hyperparams['scheduler']['scheduler_patience'],
                                verbose        = True,
                                threshold      = 1e-4,
                                threshold_mode = 'abs',
                                cooldown       = hyperparams['scheduler']['scheduler_cooldown'],
                                min_lr         = hyperparams['scheduler']['lr_min'])
    run_manager = RunManager(model = model,
                         objective  = objective,
                         optimizer  = optimizer,
                         scheduler  = scheduler,
                         train_dl   = train_dl,
                         valid_dl   = valid_dl,
                         transforms = None,  # transforms,
                         writer     = None,
                         plotter    = None,
                         max_epochs = 100,
                         save_loc   = f"data/lfadscrosslatent{latent_size}",
                         do_health_check = False)
    run_manager.run()
    with open(f"data/latent_{latent_size}_loss_data.pickle", "wb") as f:
        pickle.dump(run_manager.loss_dict, f)

Epoch     1, Epoch time = 0.429 s, Loss (train, valid):  recon (112.676, 43.649), kl (1.556, 0.341), total (117.909, 48.103), l2 (3.677)
Epoch     2, Epoch time = 0.426 s, Loss (train, valid):  recon (103.006, 43.617), kl (0.336, 0.402), total (107.049, 47.307), l2 (3.707)
Epoch     3, Epoch time = 0.426 s, Loss (train, valid):  recon (102.630, 43.721), kl (0.410, 0.411), total (105.974, 46.717), l2 (2.933)
Epoch     4, Epoch time = 0.451 s, Loss (train, valid):  recon (102.361, 44.188), kl (0.519, 0.309), total (105.185, 46.533), l2 (2.306)
Epoch     5, Epoch time = 0.431 s, Loss (train, valid):  recon (102.115, 43.761), kl (0.519, 0.254), total (104.458, 45.633), l2 (1.823)
Epoch     6, Epoch time = 0.427 s, Loss (train, valid):  recon (101.925, 43.688), kl (0.612, 0.296), total (103.991, 45.273), l2 (1.454)
Epoch     7, Epoch time = 0.452 s, Loss (train, valid):  recon (101.776, 44.021), kl (0.624, 0.261), total (103.567, 45.325), l2 (1.167)
Epoch     8, Epoch time = 0.417 s, Loss (