In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tnrange
import torch.optim as optim
import torch.nn.functional as F
from GP_hGLM import GP_hGLM
from sklearn import metrics


# Hyperparams

In [2]:
train_T = 65000
test_T = 15000
N = 200
M = 10
R=1
batch_size = 1500

C_den = torch.zeros(5,5)
C_den[0,1:] = 1

sub_no = C_den.shape[0]

In [3]:
Ensyn = torch.tensor([0, 106, 213, 211, 99])
Insyn = torch.tensor([1, 22, 36, 42, 19])

E_no = torch.sum(Ensyn)
I_no = torch.sum(Insyn)

C_syn_e = torch.zeros(sub_no, E_no)
C_syn_i = torch.zeros(sub_no, I_no)

E_count = 0
for s in range(sub_no):
    C_syn_e[s,E_count:E_count+Ensyn[s]] = 1
    E_count += Ensyn[s]

I_count = 0
for s in range(sub_no):
    C_syn_i[s,I_count:I_count+Insyn[s]] = 1
    I_count += Insyn[s]

# Train

In [4]:
model = GP_hGLM(C_den.cuda(), sub_no, N, M, R).double().cuda()

V_ref = np.load("/media/hdd01/sklee/L23_inputs/vdata_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4.npy").flatten()

train_V_ref = V_ref[:train_T]
test_V_ref = V_ref[train_T:train_T+test_T]
test_V_ref = torch.from_numpy(test_V_ref).cuda()
train_V_ref = torch.from_numpy(train_V_ref)

In [5]:
raw_E_neural = np.load("/media/hdd01/sklee/L23_inputs/Espikes_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4_neural.npy")
raw_I_neural = np.load("/media/hdd01/sklee/L23_inputs/Ispikes_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4_neural.npy")

E_neural = torch.matmul(torch.from_numpy(raw_E_neural).double(), C_syn_e.T.double())
I_neural = torch.matmul(torch.from_numpy(raw_I_neural).double(), C_syn_i.T.double())

train_S_E = E_neural[:train_T]
train_S_I = I_neural[:train_T]
test_S_E = E_neural[train_T:train_T+test_T].double().cuda()
test_S_I = I_neural[train_T:train_T+test_T].double().cuda()

In [6]:
repeat_no = 1
batch_no = (train_V_ref.shape[0] - batch_size) * repeat_no
train_idx = np.empty((repeat_no, train_V_ref.shape[0] - batch_size))
for i in range(repeat_no):
    part_idx = np.arange(train_V_ref.shape[0] - batch_size)
    np.random.shuffle(part_idx)
    train_idx[i] = part_idx
train_idx = train_idx.flatten()
train_idx = torch.from_numpy(train_idx)

In [7]:
#optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=0.002)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.5)

801


In [8]:
loss_array = np.empty((batch_no//100, 2))

for i in tnrange(batch_no):
    model.train()
    optimizer.zero_grad()
    batch_idx = train_idx[i].long()
    batch_S_E = train_S_E[batch_idx : batch_idx+batch_size].double().cuda()
    batch_S_I = train_S_I[batch_idx : batch_idx+batch_size].double().cuda()
    batch_pred, m_u, S_u, K_u, K_u_inv, F_e, F_i, u = model(batch_S_E, batch_S_I)
    batch_ref = train_V_ref[batch_idx:batch_idx+batch_size].cuda()
        
    #diff = (batch_ref - batch_pred) ** 1
    #var_loss = torch.var(diff)

    #rec_loss = torch.mean((batch_pred - batch_ref) ** 2)
    rec_loss = torch.var(batch_pred - batch_ref)
    kl_loss = 0
    for s in range(m_u.shape[0]):
        trace = torch.trace(torch.matmul(K_u_inv[s], S_u[s]))
        mKm = torch.matmul(m_u[s].double() , torch.matmul(K_u_inv[s].double(), m_u[s].double()).double())
        ln_det = torch.logdet(K_u[s]) - torch.logdet(S_u[s])
        kl_loss += 0.5*(trace + mKm + ln_det - m_u.shape[1])
        
    loss = rec_loss + kl_loss
    print(i, rec_loss.item(), trace.item(), mKm.item(), torch.logdet(K_u[s]).item(), torch.logdet(S_u[s]).item())
    print(torch.mean(K_u).item())
    print(torch.mean(K_u_inv).item())

    loss.backward()
    optimizer.step()
    scheduler.step()
    
    if i%50 == 0:
        model.eval()
        test_pred, test_m_u, test_S_u, test_K_u, test_K_u_inv, F_e, F_i, test_u = model(test_S_E, test_S_I)
        test_loss = torch.var(test_pred - test_V_ref)

        test_score = metrics.explained_variance_score(y_true=test_V_ref.cpu().detach().numpy(),
                                                      y_pred=test_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        train_score = metrics.explained_variance_score(y_true=batch_ref.cpu().detach().numpy(),
                                                      y_pred=batch_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        print("TEST", i, test_loss.item(), test_score.item(), train_score.item())
        #print(torch.mean(test_spikes, 0).cpu().detach().numpy())
        #print(torch.mean(spike_probs, 0).cpu().detach().numpy())
        #if i%100 == 0:
            #torch.save(model.state_dict(), "/media/hdd01/sklee/cont_shglm/baseGLM_CA1_sub6_i"+str(i)+".pt")
            #loss_array[i//100,0] = i
            #loss_array[i//100,1] = test_score
            #np.save("/media/hdd01/sklee/cont_shglm/baseGLM_CA1_sub6_test_scores.npy", loss_array)


  for i in tnrange(batch_no):


HBox(children=(FloatProgress(value=0.0, max=63500.0), HTML(value='')))

0 24.55654155800823 216903.703125 0.0 -65.70870971679688 2.1475515365600586
124.78349304199219
0.0005000000237487257
TEST 0 40.311365858748296 0.00013428082982858136 0.0
1 31.782675434364574 2570.96875 6.8107154855897445 -63.485015869140625 2.141152858734131
124.82308197021484
0.0008750000270083547
2 29.84699558181422 -1350997.875 -51.99808700833903 -64.7318344116211 2.134291172027588
124.87257385253906
0.0007656250381842256
3 34.63971823679795 398793.625 5.709601078839164 nan 2.1435787677764893
124.92440032958984
0.000937500037252903
4 25.135303683296012 -126535.171875 0.3090400582818531 -62.98835372924805 2.1482110023498535
124.91181945800781
0.0009765625
5 34.47315592667721 158894.9375 97.17697815122756 nan 2.1524808406829834
124.89532470703125
0.000750000006519258
6 28.692784843115717 133738.6875 21.85117793013163 nan 2.1552274227142334
124.8956069946289
0.000937500037252903
7 41.907388605077905 -283589.9375 -22.648757652215743 nan 2.155406951904297
124.89986419677734
0.00095312506

KeyboardInterrupt: 

In [9]:
print(torch.eig(K_u[0])[0][0,0] - torch.eig(K_u[0])[0][-1,0])
print(torch.eig(K_u_inv[0])[0][0,0] - torch.eig(K_u_inv[0])[0][-1,0])

tensor(1257.3439, device='cuda:0', grad_fn=<SubBackward0>)
tensor(907499.8750, device='cuda:0', grad_fn=<SubBackward0>)


In [None]:
plt.plot(F_e[2,0,:].cpu().detach().numpy())

In [None]:
plt.plot(m_u[4].cpu().detach().numpy())
print()

In [None]:
plt.plot(test_pred.cpu().detach().numpy())