In [1]:
from models.vae_hist_glm import VAE_Hist_GLM, NN_Encoder

import numpy as np
import torch
import torch.nn as nn
from tqdm import tnrange
import torch.optim as optim
import torch.nn.functional as F
from sklearn import metrics

In [2]:
base_dir = "/media/hdd01/sklee/"
experiment = "clust4-60"
cell_type = "CA1"

E_neural_file = "Espikes_neural.npy"
I_neural_file = "Ispikes_neural.npy"
V_file = "vdata_T10_Ne2000_gA0.6_tauA1_gN0.8_Ni200_gG0.1_gB0.1_Er0.5_Ir7.4_random_NR_rep10_stimseed1.npy"
C_syn_e_file = "handsub6_C_syn_e.npy"
C_syn_i_file = "handsub6_C_syn_i.npy"
C_den_file = "handsub6_C_den.npy"

E_neural = np.load(base_dir+cell_type+"_"+experiment+"/data/"+E_neural_file)
I_neural = np.load(base_dir+cell_type+"_"+experiment+"/data/"+I_neural_file)
V = np.load(base_dir+cell_type+"_"+experiment+"/data/"+V_file)[:,:50000].flatten()
C_syn_e = np.load(base_dir+cell_type+"_"+experiment+"/data/"+C_syn_e_file)
C_syn_i = np.load(base_dir+cell_type+"_"+experiment+"/data/"+C_syn_i_file)
C_den = np.load(base_dir+cell_type+"_"+experiment+"/data/"+C_den_file)

E_neural = torch.from_numpy(E_neural)
I_neural = torch.from_numpy(I_neural)
C_syn_e = torch.from_numpy(C_syn_e)
C_syn_i = torch.from_numpy(C_syn_i)
V = torch.from_numpy(V)
C_den = torch.from_numpy(C_den)
sub_no = C_den.shape[0]

In [3]:
T_train = 60 * 1000 * 5
T_test = 10 * 1000 * 5
T_no = 500
save_dir = base_dir+cell_type+"_"+experiment+"/"
device = torch.device("cuda")

batch_size = 100000
iter_no = 20000
epoch_no = 15
layer_no = 2

In [4]:
V_train = V[:T_train].to(device).float()
V_test = V[T_train:T_train + T_test].to(device).float()
test_E_neural = E_neural[T_train:T_train+T_test].float().to(device)
test_I_neural = I_neural[T_train:T_train+T_test].float().to(device)
train_E_neural = E_neural[:T_train].float().to(device)
train_I_neural = I_neural[:T_train].float().to(device)
C_syn_e = C_syn_e.float().to(device)
C_syn_i = C_syn_i.float().to(device)
C_den = C_den.float().to(device)

batch_no = (T_train - batch_size) * epoch_no
train_idx = np.empty((epoch_no, T_train - batch_size))
for i in range(epoch_no):
    part_idx = np.arange(T_train - batch_size)
    np.random.shuffle(part_idx)
    train_idx[i] = part_idx
train_idx = train_idx.flatten()
train_idx = torch.from_numpy(train_idx)

In [5]:
decoder = VAE_Hist_GLM(C_den, C_syn_e, C_syn_i, T_no, device)
encoder = NN_Encoder(C_syn_e, C_syn_i, T_no, layer_no, device)

enc_optimizer = torch.optim.Adam(encoder.parameters(), lr = 0.005)
dec_optimizer = torch.optim.Adam(decoder.parameters(), lr = 0.001)

encoder.to(device).float()
decoder.to(device).float()
print(sum(p.numel() for p in encoder.parameters() if p.requires_grad))
print(sum(p.numel() for p in decoder.parameters() if p.requires_grad))

mse_criterion = nn.MSELoss()

15220
355


In [6]:
for i in tnrange(iter_no):
    encoder.train()
    decoder.train()
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
        
    batch_idx = train_idx[i].long()
    batch_E_neural = train_E_neural[batch_idx : batch_idx+batch_size]
    batch_I_neural = train_I_neural[batch_idx : batch_idx+batch_size]
    batch_V = V_train[batch_idx : batch_idx+batch_size]
    
    V_hid_enc = encoder(batch_V, batch_E_neural, batch_I_neural)
    V_enc = torch.zeros(batch_size, sub_no).to(device)
    V_enc[:,1:] = V_enc[:,1:] + V_hid_enc
    V_enc[:,0] = V_enc[:,0] + batch_V
    
    V_pred, V_hid_dec, out_filters = decoder.train_forward(batch_E_neural,
                                                         batch_I_neural,
                                                         V_enc)
    
    prior_loss = torch.mean(V_hid_enc**2/0.1)
    var_loss = torch.var(batch_V - V_pred)
    mse_loss = mse_criterion(V_hid_dec, V_hid_enc.detach())
    
    loss = var_loss + prior_loss + mse_loss
    #loss = var_loss + mse_loss
    loss.backward()
    enc_optimizer.step()
    dec_optimizer.step()
    
    if i%1000 == 999:
        V_pred, out_filters = decoder.test_forward(test_E_neural, test_I_neural)
        
        var_exp = metrics.explained_variance_score(y_true=V_test.cpu().detach().numpy(),
                                                      y_pred=V_pred.cpu().detach().numpy())
        
        print(i, var_exp, torch.mean(V_hid_enc ,0))

  for i in tnrange(iter_no):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20000.0), HTML(value='')))

999 0.00018095970153808594 tensor([-3.9501e-01,  3.1065e-01, -7.7390e-01, -3.8447e-04,  4.2458e-03],
       device='cuda:0', grad_fn=<MeanBackward1>)
1999 0.00032979249954223633 tensor([-0.2115,  0.1445, -0.1231, -0.0004, -0.0029], device='cuda:0',
       grad_fn=<MeanBackward1>)
2999 0.0004658699035644531 tensor([-0.1496,  0.0827,  0.0128, -0.0003, -0.0017], device='cuda:0',
       grad_fn=<MeanBackward1>)
3999 0.0003864169120788574 tensor([-8.7255e-02,  3.0979e-02, -3.7906e-03, -2.4602e-04,  6.0895e-05],
       device='cuda:0', grad_fn=<MeanBackward1>)
4999 0.000156402587890625 tensor([-0.0567,  0.0180,  0.0389, -0.0001, -0.0002], device='cuda:0',
       grad_fn=<MeanBackward1>)
5999 9.107589721679688e-05 tensor([-1.9015e-02,  2.8906e-03,  1.1874e-03, -7.5151e-05,  2.1341e-04],
       device='cuda:0', grad_fn=<MeanBackward1>)
6999 5.0067901611328125e-05 tensor([-0.0047,  0.0005, -0.0029, -0.0002, -0.0005], device='cuda:0',
       grad_fn=<MeanBackward1>)



KeyboardInterrupt: 