In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics
from tqdm import tnrange
from tqdm.auto import trange

from GP_hGLM import GP_hGLM


# Hyperparameters

In [2]:
# Data lengths (in samples) and model hyperparameters.
train_T, test_T = 65000, 15000
N, M, R = 200, 30, 1  # passed straight through to GP_hGLM(C_den, S, N, M, R)
batch_size = 1500     # length of each training window

# Dendritic connectivity matrix: row 0 is marked as connected to every other
# subunit (entries [0, 1:]); every other entry is zero.
# NOTE(review): which axis means parent vs child is defined inside GP_hGLM -- confirm.
C_den = torch.zeros(5, 5)
C_den[0, 1:].fill_(1)

sub_no = C_den.size(0)  # number of dendritic subunits
S = 2 * sub_no

In [3]:
# Number of excitatory / inhibitory synapses assigned to each dendritic subunit.
# Synapses are numbered contiguously by subunit: subunit 0 owns the first
# Ensyn[0] excitatory synapses, subunit 1 the next Ensyn[1], and so on.
Ensyn = torch.tensor([0, 106, 213, 211, 99])
Insyn = torch.tensor([1, 22, 36, 42, 19])

E_no = torch.sum(Ensyn)
I_no = torch.sum(Insyn)


def membership_matrix(counts, n_rows):
    """Build a 0/1 assignment matrix for contiguously numbered items.

    Args:
        counts: 1-D integer tensor (or sequence); counts[s] is how many
            consecutive items belong to group s.
        n_rows: number of groups; must equal len(counts).

    Returns:
        Tensor of the default float dtype, shape (n_rows, sum(counts)), with
        entry [s, j] == 1 iff item j belongs to group s.

    Raises:
        ValueError: if len(counts) != n_rows.
    """
    counts = torch.as_tensor(counts, dtype=torch.long)
    if counts.numel() != n_rows:
        raise ValueError(f"expected {n_rows} counts, got {counts.numel()}")
    # owner[j] = index of the group that item j belongs to.
    owner = torch.repeat_interleave(torch.arange(n_rows), counts)
    return (torch.arange(n_rows).unsqueeze(1) == owner.unsqueeze(0)).to(torch.get_default_dtype())


C_syn_e = membership_matrix(Ensyn, sub_no)
C_syn_i = membership_matrix(Insyn, sub_no)

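# Train

Build the model, load the reference trace and the per-subunit synaptic inputs, split them into training and test segments, and fit with Adam.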

In [4]:
# Build the model and load the reference trace V_ref. The first train_T samples
# form the training segment; the next test_T samples form the test segment.
model = GP_hGLM(C_den.cuda(), S, N, M, R).double().cuda()

V_ref = np.load("/media/hdd01/sklee/L23_inputs/vdata_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4.npy").flatten()

# Slicing would silently return a shorter test segment; fail loudly instead.
if V_ref.shape[0] < train_T + test_T:
    raise ValueError(
        f"V_ref has {V_ref.shape[0]} samples; need train_T + test_T = {train_T + test_T}"
    )

train_V_ref = torch.from_numpy(V_ref[:train_T])                          # stays on CPU; moved per batch
test_V_ref = torch.from_numpy(V_ref[train_T:train_T + test_T]).cuda()

In [5]:
# Load presynaptic spike arrays and pool them per dendritic subunit.
# The matmul with C_syn_*.T (n_synapses, sub_no) requires each raw array to
# have n_synapses columns; axis 0 is assumed to be time (it is sliced with
# train_T / test_T below) -- confirm against the data files.
raw_E_neural = np.load("/media/hdd01/sklee/L23_inputs/Espikes_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4_neural.npy")
raw_I_neural = np.load("/media/hdd01/sklee/L23_inputs/Ispikes_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4_neural.npy")

E_neural = torch.matmul(torch.from_numpy(raw_E_neural).double(), C_syn_e.T.double())
I_neural = torch.matmul(torch.from_numpy(raw_I_neural).double(), C_syn_i.T.double())

# Slicing would silently truncate the train/test segments; fail loudly instead.
for input_name, pooled in (("E_neural", E_neural), ("I_neural", I_neural)):
    if pooled.shape[0] < train_T + test_T:
        raise ValueError(
            f"{input_name} has {pooled.shape[0]} time steps; need train_T + test_T = {train_T + test_T}"
        )

# Training inputs stay on CPU (batches are moved to the GPU in the loop);
# test inputs are already float64, so only the device changes.
train_S_E = E_neural[:train_T]
train_S_I = I_neural[:train_T]
test_S_E = E_neural[train_T:train_T + test_T].cuda()
test_S_I = I_neural[train_T:train_T + test_T].cuda()

In [6]:
# Start indices of the training windows [start, start + batch_size).
# Each repeat is an independent random permutation of 0 .. n_starts-1, and the
# repeats are concatenated. Seeded so that runs are reproducible.
# NOTE(review): the last valid start (train_T - batch_size) is excluded, as before.
SHUFFLE_SEED = 0
repeat_no = 1

n_starts = train_V_ref.shape[0] - batch_size
if n_starts <= 0:
    raise ValueError(
        f"batch_size ({batch_size}) must be smaller than the training length ({train_V_ref.shape[0]})"
    )
batch_no = n_starts * repeat_no

rng = np.random.default_rng(SHUFFLE_SEED)
# Integer dtype (the old np.empty buffer was float64); the explicit dtype also
# keeps repeat_no == 0 working (empty result).
train_idx = torch.from_numpy(
    np.array([rng.permutation(n_starts) for _ in range(repeat_no)], dtype=np.int64).reshape(-1)
)

In [7]:
# Adam over every model parameter. StepLR multiplies the learning rate by 0.5
# every 5000 scheduler steps (the training loop steps it once per batch).
optimizer = optim.Adam(model.parameters(), lr=0.002)
trainable_count = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(trainable_count)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=5000)

5281


In [None]:
# Training loop: one optimizer step per shuffled window of batch_size samples.
# Every EVAL_EVERY steps, evaluate on the held-out test segment.
KL_WEIGHT = 0.0   # weight of the GP KL term; 0 keeps the residual-variance-only objective
EVAL_EVERY = 50   # optimizer steps between test-set evaluations
LOG_EVERY = 100   # optimizer steps between rows of loss_array


def gaussian_kl(m_u, S_u, K_u, K_u_inv):
    """Sum over subunits s of KL( N(m_u[s], S_u[s]) || N(0, K_u[s]) ).

    Args:
        m_u: (n_sub, M) means (one vector per subunit).
        S_u: (n_sub, M, M) covariances of the approximating Gaussians.
        K_u: (n_sub, M, M) prior covariances.
        K_u_inv: (n_sub, M, M) inverses of K_u, as returned by the model.

    Returns:
        0-dim tensor holding the summed KL divergence.
    """
    m = m_u.double()
    K_inv = K_u_inv.double()
    trace = torch.diagonal(torch.matmul(K_inv, S_u.double()), dim1=-2, dim2=-1).sum(-1)
    mKm = torch.einsum("si,sij,sj->s", m, K_inv, m)
    # logdet(K) - logdet(S) instead of log(det K / det S): det() of an M x M
    # matrix easily under/overflows, logdet does not.
    # NOTE(review): logdet is nan for non-positive determinants; assumes K_u and
    # S_u are positive definite -- confirm against GP_hGLM.
    ln_det = torch.logdet(K_u) - torch.logdet(S_u)
    return 0.5 * (trace + mKm + ln_det - m_u.shape[1]).sum()


# Row k holds (step, test explained variance) for step k * LOG_EVERY; NaN until filled.
loss_array = np.full((batch_no // LOG_EVERY, 2), np.nan)

for i in trange(batch_no):
    model.train()
    optimizer.zero_grad()

    start = int(train_idx[i])  # works whether train_idx is float- or int-typed
    window = slice(start, start + batch_size)
    batch_S_E = train_S_E[window].double().cuda()
    batch_S_I = train_S_I[window].double().cuda()
    batch_pred, m_u, S_u, K_u, K_u_inv, F_e, F_i, u = model(batch_S_E, batch_S_I)
    batch_ref = train_V_ref[window].cuda()

    # Variance of the residual: a constant offset between prediction and
    # reference is not penalized.
    rec_loss = torch.var(batch_pred - batch_ref)
    loss = rec_loss
    if KL_WEIGHT:
        kl_loss = gaussian_kl(m_u, S_u, K_u, K_u_inv)
        loss = loss + KL_WEIGHT * kl_loss

    loss.backward()
    optimizer.step()
    scheduler.step()

    if i % EVAL_EVERY == 0:
        model.eval()
        # The evaluation pass is never backpropagated; skip building its autograd
        # graph over the whole test segment.
        with torch.no_grad():
            test_pred, test_m_u, test_S_u, test_K_u, test_K_u_inv, F_e, F_i, test_u = model(test_S_E, test_S_I)
            test_loss = torch.var(test_pred - test_V_ref)

        test_score = metrics.explained_variance_score(y_true=test_V_ref.cpu().numpy(),
                                                      y_pred=test_pred.detach().cpu().numpy(),
                                                      multioutput='uniform_average')
        train_score = metrics.explained_variance_score(y_true=batch_ref.detach().cpu().numpy(),
                                                       y_pred=batch_pred.detach().cpu().numpy(),
                                                       multioutput='uniform_average')
        # float() rather than .item(): sklearn may return a numpy scalar or a Python float.
        print("TEST", i, test_loss.item(), float(test_score), float(train_score))

        if i % LOG_EVERY == 0 and i // LOG_EVERY < len(loss_array):
            loss_array[i // LOG_EVERY] = (i, float(test_score))


  for i in tnrange(batch_no):


HBox(children=(FloatProgress(value=0.0, max=63500.0), HTML(value='')))

TEST 0 40.423210386657935 -0.002639862560209316 0.0
TEST 50 40.33116240966057 -0.00035674427180043544 0.0016165567787556068
TEST 100 40.32684300432957 -0.0002496076069289366 -0.0003039689596777695
TEST 150 40.32434192093503 -0.00018757181425899105 -0.000100093711632665
TEST 200 40.32291888736363 -0.0001522755039462531 -0.0014467723783113673
TEST 250 40.32243575095691 -0.0001402919971056349 0.0007295621994737322
TEST 300 40.32151318585167 -0.00011740909054802628 6.763280230526192e-05
TEST 350 40.31980294758806 -7.498907904235352e-05 -0.002601254459728919
TEST 400 40.318943222349205 -5.366482539947981e-05 -0.0006090356388028795


In [None]:
# Inspect one slice of F_e as returned by the most recent model(...) call.
fig, ax = plt.subplots()
ax.plot(F_e[2, 0, :].detach().cpu().numpy())
ax.set(title="F_e[2, 0, :]", xlabel="index along last axis of F_e")
plt.show()

In [None]:
# Inspect m_u for subunit index 4, from the most recent training forward pass.
fig, ax = plt.subplots()
ax.plot(m_u[4].detach().cpu().numpy())
ax.set(title="m_u[4]", xlabel="index within m_u[4]")
plt.show()

In [None]:
# Compare the model's prediction on the test segment with the reference trace.
fig, ax = plt.subplots()
ax.plot(test_V_ref.detach().cpu().numpy(), label="reference (test_V_ref)", alpha=0.7)
ax.plot(test_pred.detach().cpu().numpy(), label="prediction (test_pred)", alpha=0.7)
ax.set(title="Test segment: prediction vs reference", xlabel="time step within test segment")
ax.legend()
plt.show()