In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tnrange
import torch.optim as optim
import torch.nn.functional as F
from LVAE_shGLM import LVAE_shGLM
from sklearn import metrics

# Hyperparams

In [2]:
train_T = 20000
test_T = 8000
T_syn = 201
C_den = torch.zeros(5,5)
C_den[0,1:] = 1
sub_no = C_den.shape[0]

batch_size = 1500

syn_basis_no = 2
hist_basis_no = 2
spike_status = True
T_hist = 201

T_V = 201 
hid_dim = 128
fix_var = 10000

theta_spike_init = -250
W_spike_init = 100

In [3]:
Ensyn = torch.tensor([0, 106, 213, 211, 99])
Insyn = torch.tensor([1, 22, 36, 42, 19])
E_no = torch.sum(Ensyn)
I_no = torch.sum(Insyn)

C_syn_e = torch.zeros(sub_no, E_no)
C_syn_i = torch.zeros(sub_no, I_no)

E_count = 0
for s in range(sub_no):
    C_syn_e[s,E_count:E_count+Ensyn[s]] = 1
    E_count += Ensyn[s]

I_count = 0
for s in range(sub_no):
    C_syn_i[s,I_count:I_count+Insyn[s]] = 1
    I_count += Insyn[s]

# Train

In [4]:
model = LVAE_shGLM(C_den.cuda(), C_syn_e.cuda(), C_syn_i.cuda(), T_syn, syn_basis_no,
                T_hist, hist_basis_no, hid_dim, fix_var, T_V, theta_spike_init, W_spike_init)

model = model.float().cuda()

V_ref = np.fromfile("/media/hdd01/sklee/cont_shglm/inputs/vdata_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b0.bin")
V_ref = V_ref[1:-2]

train_V_ref = V_ref[:train_T]
test_V_ref = V_ref[train_T:train_T+test_T]

test_V_ref = torch.from_numpy(test_V_ref).float().cuda()
train_V_ref = torch.from_numpy(train_V_ref)

In [5]:
E_neural = np.load("/media/hdd01/sklee/cont_shglm/inputs/Espikes_d48000_r1_rep1_Ne629_e5_E20_neural.npy")
I_neural = np.load("/media/hdd01/sklee/cont_shglm/inputs/Ispikes_d48000_r1_rep1_Ni120_i20_I30_neural.npy")

train_S_E = E_neural[:train_T]
train_S_I = I_neural[:train_T]
test_S_E = E_neural[train_T:train_T+test_T]
test_S_I = I_neural[train_T:train_T+test_T]

train_S_E = torch.from_numpy(train_S_E)
train_S_I = torch.from_numpy(train_S_I)
test_S_E = torch.from_numpy(test_S_E).float().cuda()
test_S_I = torch.from_numpy(test_S_I).float().cuda()

In [6]:
repeat_no = 2
batch_no = (train_V_ref.shape[0] - batch_size) * repeat_no
train_idx = np.empty((repeat_no, train_V_ref.shape[0] - batch_size))
for i in range(repeat_no):
    part_idx = np.arange(train_V_ref.shape[0] - batch_size)
    np.random.shuffle(part_idx)
    train_idx[i] = part_idx
train_idx = train_idx.flatten()
train_idx = torch.from_numpy(train_idx)

print(batch_no)
print(train_idx.shape[0])

37000
37000


In [7]:
optimizer = optim.Adam(model.parameters(), lr=0.005)

In [None]:
loss_array = np.empty((batch_no))
beta = 0

#import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

for i in tnrange(batch_no):
    if i%100 == 0:
        beta += 0.1
    
    
    model.train()
    optimizer.zero_grad()
    batch_idx = train_idx[i].long().cuda()
    batch_S_E = train_S_E[batch_idx : batch_idx+batch_size].float().cuda()
    batch_S_I = train_S_I[batch_idx : batch_idx+batch_size].float().cuda()
    batch_ref = train_V_ref[batch_idx:batch_idx+batch_size].float().cuda()
    rec_loss, kl_loss, batch_pred, post_prob, down_prob, post_mu, down_mu = model.loss(batch_ref, batch_S_E, batch_S_I, beta)
    
    var_loss = torch.var((batch_pred - batch_ref))
    loss = var_loss + beta*kl_loss
    
    loss_array[i] = loss.item()
    print(i, np.round(loss.item(), 4),
          np.round(post_prob.cpu().detach().numpy()[:2], 4),
          np.round(down_prob.cpu().detach().numpy()[:2], 4),
         np.round(post_mu.cpu().detach().numpy()[:2], 4),
          np.round(down_mu.cpu().detach().numpy()[:2], 4))
    
    loss.backward()
    optimizer.step()
    
    if i%50 == 0:
        model.eval()
        test_pred, post_mu, down_mu = model.Decoder(test_S_E, test_S_I)
        test_diff = (test_V_ref - test_pred) ** 1
        test_loss = torch.var(test_diff)
        test_score = metrics.explained_variance_score(y_true=test_V_ref.cpu().detach().numpy(),
                                                      y_pred=test_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        train_score = metrics.explained_variance_score(y_true=batch_ref.cpu().detach().numpy(),
                                                      y_pred=batch_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        print("TEST", i, round(test_loss.item(), 4),
              round(test_score, 4), round(train_score, 4))

        test_spikes = torch.sigmoid(down_mu + torch.randn(down_mu.shape[0], down_mu.shape[1]).cuda()*fix_var**(0.5))
        print(np.round(torch.mean(test_spikes, 0).cpu().detach().numpy(), 4))
        print(np.round(torch.mean(down_mu, 0).cpu().detach().numpy(), 4))
        
        if i%100 == 0:
            torch.save(model.state_dict(), "/media/hdd01/sklee/lvae_shglm/VAR_sub5_s2_h2_w100_t-250_shglm_i"+str(i)+".pt")
    


  for i in tnrange(batch_no):


HBox(children=(FloatProgress(value=0.0, max=37000.0), HTML(value='')))

0 77.09245300292969 [0.14665814 0.1569892 ] [0.02900044 0.02969376] [-101.78544 -103.59173] [-199.12779 -198.06252]
TEST 0 13.809616088867188 -0.0006620883941650391 0.002375304698944092
[0.02476789 0.0325149  0.03111712 0.02428926]
[-195.8027  -189.13513 -190.02148 -194.98978]
1 18.648040771484375 [0.07922976 0.0456385 ] [0.02451217 0.02129729] [-151.70944 -171.24829] [-198.79863 -190.4432 ]
2 53.430423736572266 [0.00319807 0.00197227] [0.02527126 0.03329073] [-274.31833 -266.38602] [-194.11452 -187.19046]
3 18.171680450439453 [0.02568953 0.04114217] [0.02903485 0.03166571] [-199.36353 -172.68228] [-191.1606  -181.71597]
4 12.10753059387207 [0.0612693  0.08620682] [0.02609771 0.0301838 ] [-155.22522 -136.55612] [-194.46088 -187.25142]
5 17.97801971435547 [0.06774576 0.07197417] [0.02934957 0.03435155] [-149.7874  -142.36037] [-196.19292 -188.0453 ]
6 13.526897430419922 [0.05224308 0.05240481] [0.03054169 0.03367425] [-162.12265 -162.58784] [-190.98717 -185.48587]
7 10.976739883422852 [

In [None]:
plt.plot(batch_pred.cpu().detach().numpy())

In [None]:
plt.plot(test_pred.cpu().detach().numpy())