In [85]:
%load_ext autoreload
%autoreload 2

from lfads import LFADS_Net, LFADS_SingleSession_Net
import torch
import torch.optim as opt
from scheduler import LFADS_Scheduler
from plotter import Plotter
from trainer import RunManager

import yaml

from synthetic_data import LorenzSystem, EmbeddedLowDNetwork
from objective import SVLAE_Loss, LFADS_Loss, LogLikelihoodPoisson, LogLikelihoodPoissonSimplePlusL1, LogLikelihoodPoissonSimple, LogLikelihoodGaussian

from utils import load_parameters
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [86]:
# just using provided defaults here, 
dt = 0.01
num_steps = 1000
num_inits = 100
lorenz = LorenzSystem(num_inits=num_inits,
                      dt= dt)
net = EmbeddedLowDNetwork(low_d_system = lorenz, net_size=64, base_rate=1.0, dt= dt)
# no burn steps, no inputs, simulate for 10s?
# returns time_steps x inits x dimensions
rates = net.integrate(burn_steps=0, num_steps=num_steps, inputs=None)

latents = net.low_d_system.result

# rates x dt, which is 0.01
spikes = np.random.poisson(rates * dt)

In [87]:
rates.shape

(1000, 100, 64)

In [88]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'; print(device)
device = 'cpu'
train_idxs, valid_idxs = train_test_split(np.arange(num_inits), train_size=0.8)
train_rates = rates[:, train_idxs, :]
train_latents = latents[:, train_idxs, :]
train_spikes = spikes[:, train_idxs, :]

valid_rates = rates[:, valid_idxs, :]
valid_latents = latents[:, valid_idxs, :]
valid_spikes = spikes[:, valid_idxs, :]

train_data  = torch.Tensor(train_spikes).to(device)
valid_data  = torch.Tensor(valid_spikes).to(device)
print(train_data.device)
train_ds    = torch.utils.data.TensorDataset(train_data)
valid_ds    = torch.utils.data.TensorDataset(valid_data)
train_dl    = torch.utils.data.DataLoader(train_ds, batch_size = 25, shuffle=True)
valid_dl    = torch.utils.data.DataLoader(valid_ds, batch_size = valid_data.shape[0])

cpu


In [89]:
hyperparameter_path = '../hierarchical_lfads/hyperparameters/lorenz/lfads.yaml'
hyperparams = load_parameters(hyperparameter_path)



In [90]:
model = LFADS_SingleSession_Net(
    input_size           = 64,
    factor_size          = hyperparams['model']['factor_size'],
    g_encoder_size       = hyperparams['model']['g_encoder_size'],
    c_encoder_size       = hyperparams['model']['c_encoder_size'],
    g_latent_size        = hyperparams['model']['g_latent_size'],
    u_latent_size        = hyperparams['model']['u_latent_size'],
    controller_size      = hyperparams['model']['controller_size'],
    generator_size       = hyperparams['model']['generator_size'],
    prior                = hyperparams['model']['prior'],
    clip_val             = hyperparams['model']['clip_val'],
    dropout              = hyperparams['model']['dropout'],
    do_normalize_factors = hyperparams['model']['normalize_factors'],
    max_norm             = hyperparams['model']['max_norm'],
    device               = device
).to(device)

# model = LFADS_Net(
#     input_size           = 64,
#     factor_size          = 3,
#     g_encoder_size       = hyperparams['model']['g_encoder_size'],
#     c_encoder_size       = hyperparams['model']['c_encoder_size'],
#     g_latent_size        = 3,
#     u_latent_size        = hyperparams['model']['u_latent_size'],
#     controller_size      = hyperparams['model']['controller_size'],
#     generator_size       = hyperparams['model']['generator_size'],
#     prior                = hyperparams['model']['prior'],
#     clip_val             = hyperparams['model']['clip_val'],
#     dropout              = hyperparams['model']['dropout'],
#     do_normalize_factors = hyperparams['model']['normalize_factors'],
#     max_norm             = hyperparams['model']['max_norm'],
#     device               = 'cuda:0'
# ).to(device)

In [91]:
loglikelihood = LogLikelihoodPoisson(dt=dt, device=device)
objective = LFADS_Loss(
    loglikelihood=loglikelihood,
    loss_weight_dict={
        "kl": hyperparams['objective']["kl"],
        "l2": hyperparams["objective"]["l2"]},
    l2_con_scale=hyperparams["objective"]["l2_con_scale"],
    l2_gen_scale=hyperparams["objective"]["l2_gen_scale"]
).to(device)

In [92]:


optimizer = opt.Adam([p for p in model.parameters() if p.requires_grad],
                     lr=hyperparams['optimizer']['lr_init'],
                     betas=hyperparams['optimizer']['betas'],
                     eps=hyperparams['optimizer']['eps'])

scheduler = LFADS_Scheduler(optimizer      = optimizer,
                            mode           = 'min',
                            factor         = hyperparams['scheduler']['scheduler_factor'],
                            patience       = hyperparams['scheduler']['scheduler_patience'],
                            verbose        = True,
                            threshold      = 1e-4,
                            threshold_mode = 'abs',
                            cooldown       = hyperparams['scheduler']['scheduler_cooldown'],
                            min_lr         = hyperparams['scheduler']['lr_min'])

TIME = torch.arange(0, num_steps*dt, dt)

plotter = {
    'train' : Plotter(time=TIME, truth={
        'rates'   : train_rates,
        'spikes'  : train_spikes,
        'latent'  : train_latents}),
    'valid' : Plotter(time=TIME, truth={
        'rates'   : valid_rates,
        'spikes'  : valid_spikes,
        'latent'  : valid_latents}),
}

In [93]:
valid_spikes.shape

(1000, 20, 64)

In [94]:
# transforms  = trf.Compose([])

run_manager = RunManager(model      = model,
                         objective  = objective,
                         optimizer  = optimizer,
                         scheduler  = scheduler,
                         train_dl   = train_dl,
                         valid_dl   = valid_dl,
                         transforms = None,  # transforms,
                         writer     = None,
                         plotter    = plotter,
                         max_epochs = 5,
                         save_loc   = "data/lorenz",
                         do_health_check = False)

In [95]:
run_manager.run()

Epoch     1, Epoch time = 6.066 s, Loss (train, valid):  recon (348.897, 85.977), kl (1.694, 1.553), total (350.636, 87.624), l2 (0.045)
Epoch     2, Epoch time = 6.057 s, Loss (train, valid):  recon (345.893, 85.606), kl (0.457, 0.068), total (346.490, 85.864), l2 (0.141)
Epoch     3, Epoch time = 6.125 s, Loss (train, valid):  recon (345.201, 85.581), kl (0.312, 0.094), total (345.747, 85.954), l2 (0.233)
Epoch     4, Epoch time = 6.136 s, Loss (train, valid):  recon (345.092, 85.573), kl (0.342, 0.120), total (345.755, 86.055), l2 (0.320)
Epoch     5, Epoch time = 6.238 s, Loss (train, valid):  recon (345.073, 85.559), kl (0.342, 0.138), total (345.815, 86.134), l2 (0.400)
