In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tnrange
import torch.optim as optim
import torch.nn.functional as F
from LVAE_shGLM import LVAE_shGLM
from sklearn import metrics

# Hyperparams

In [2]:
train_T = 20000
test_T = 8000
T_syn = 201
C_den = torch.zeros(5,5)
C_den[0,1:] = 1
sub_no = C_den.shape[0]

batch_size = 1500

syn_basis_no = 2
hist_basis_no = 2
spike_status = True
T_hist = 201

T_V = 201 
hid_dim = 128
fix_var = 10000

theta_spike_init = -250
W_spike_init = 100

In [3]:
Ensyn = torch.tensor([0, 106, 213, 211, 99])
Insyn = torch.tensor([1, 22, 36, 42, 19])
E_no = torch.sum(Ensyn)
I_no = torch.sum(Insyn)

C_syn_e = torch.zeros(sub_no, E_no)
C_syn_i = torch.zeros(sub_no, I_no)

E_count = 0
for s in range(sub_no):
    C_syn_e[s,E_count:E_count+Ensyn[s]] = 1
    E_count += Ensyn[s]

I_count = 0
for s in range(sub_no):
    C_syn_i[s,I_count:I_count+Insyn[s]] = 1
    I_count += Insyn[s]

# Train

In [4]:
model = LVAE_shGLM(C_den.cuda(), C_syn_e.cuda(), C_syn_i.cuda(), T_syn, syn_basis_no,
                T_hist, hist_basis_no, hid_dim, fix_var, T_V, theta_spike_init, W_spike_init)

model = model.float().cuda()

V_ref = np.fromfile("/media/hdd01/sklee/cont_shglm/inputs/vdata_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b0.bin")
V_ref = V_ref[1:-2]

train_V_ref = V_ref[:train_T]
test_V_ref = V_ref[train_T:train_T+test_T]

test_V_ref = torch.from_numpy(test_V_ref).float().cuda()
train_V_ref = torch.from_numpy(train_V_ref)

In [5]:
E_neural = np.load("/media/hdd01/sklee/cont_shglm/inputs/Espikes_d48000_r1_rep1_Ne629_e5_E20_neural.npy")
I_neural = np.load("/media/hdd01/sklee/cont_shglm/inputs/Ispikes_d48000_r1_rep1_Ni120_i20_I30_neural.npy")

train_S_E = E_neural[:train_T]
train_S_I = I_neural[:train_T]
test_S_E = E_neural[train_T:train_T+test_T]
test_S_I = I_neural[train_T:train_T+test_T]

train_S_E = torch.from_numpy(train_S_E)
train_S_I = torch.from_numpy(train_S_I)
test_S_E = torch.from_numpy(test_S_E).float().cuda()
test_S_I = torch.from_numpy(test_S_I).float().cuda()

In [6]:
repeat_no = 2
batch_no = (train_V_ref.shape[0] - batch_size) * repeat_no
train_idx = np.empty((repeat_no, train_V_ref.shape[0] - batch_size))
for i in range(repeat_no):
    part_idx = np.arange(train_V_ref.shape[0] - batch_size)
    np.random.shuffle(part_idx)
    train_idx[i] = part_idx
train_idx = train_idx.flatten()
train_idx = torch.from_numpy(train_idx)

print(batch_no)
print(train_idx.shape[0])

37000
37000


In [7]:
optimizer = optim.Adam(model.parameters(), lr=0.005)

In [None]:
loss_array = np.empty((batch_no))
beta = 0

#import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

for i in tnrange(batch_no):
    if i%100 == 0:
        beta += 0.1
    
    
    model.train()
    optimizer.zero_grad()
    batch_idx = train_idx[i].long().cuda()
    batch_S_E = train_S_E[batch_idx : batch_idx+batch_size].float().cuda()
    batch_S_I = train_S_I[batch_idx : batch_idx+batch_size].float().cuda()
    batch_ref = train_V_ref[batch_idx:batch_idx+batch_size].float().cuda()
    rec_loss, kl_loss, batch_pred, post_prob, down_prob, post_mu, down_mu = model.loss(batch_ref, batch_S_E, batch_S_I, beta)
    
    var_loss = torch.var((batch_pred - batch_ref))
    loss = var_loss + beta*kl_loss
    
    loss_array[i] = loss.item()
    print(i, np.round(loss.item(), 4),
          np.round(post_prob.cpu().detach().numpy()[:2], 4),
          np.round(down_prob.cpu().detach().numpy()[:2], 4),
         np.round(post_mu.cpu().detach().numpy()[:2], 4),
          np.round(down_mu.cpu().detach().numpy()[:2], 4))
    
    loss.backward()
    optimizer.step()
    
    if i%50 == 0:
        model.eval()
        test_pred, post_mu, down_mu = model.Decoder(test_S_E, test_S_I)
        test_diff = (test_V_ref - test_pred) ** 1
        test_loss = torch.var(test_diff)
        test_score = metrics.explained_variance_score(y_true=test_V_ref.cpu().detach().numpy(),
                                                      y_pred=test_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        train_score = metrics.explained_variance_score(y_true=batch_ref.cpu().detach().numpy(),
                                                      y_pred=batch_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        print("TEST", i, round(test_loss.item(), 4),
              round(test_score, 4), round(train_score, 4))

        test_spikes = torch.sigmoid(down_mu + torch.randn(down_mu.shape[0], down_mu.shape[1]).cuda()*fix_var**(0.5))
        print(np.round(torch.mean(test_spikes, 0).cpu().detach().numpy(), 4))
        print(np.round(torch.mean(down_mu, 0).cpu().detach().numpy(), 4))
        
        if i%100 == 0:
            torch.save(model.state_dict(), "/media/hdd01/sklee/lvae_shglm/VAR_sub5_s2_h2_w100_t-250_shglm_i"+str(i)+".pt")
    


  for i in tnrange(batch_no):


HBox(children=(FloatProgress(value=0.0, max=37000.0), HTML(value='')))

0 77.8411 [0.1627 0.1715] [0.0236 0.0241] [-98.7672 -98.3554] [-198.926  -197.7616]
TEST 0 13.8096 -0.0007 0.0007
[0.0271 0.0313 0.0264 0.0262]
[-195.8038 -189.1364 -190.0197 -194.9902]
1 27.8361 [0.1275 0.0908] [0.0235 0.0327] [-121.4088 -137.2201] [-194.5637 -187.6411]
2 51.1992 [0.0133 0.0068] [0.0198 0.0306] [-230.8159 -247.7715] [-192.1505 -183.1173]
3 8.5632 [0.0185 0.0127] [0.0226 0.0352] [-213.9743 -209.4619] [-192.9467 -189.4427]
4 24.957 [0.0492 0.0682] [0.037  0.0345] [-176.7502 -147.5988] [-188.2707 -179.8287]
5 24.2115 [0.0733 0.0968] [0.0257 0.0443] [-157.7332 -132.5573] [-190.5396 -179.9467]
6 17.3936 [0.062 0.078] [0.0309 0.0405] [-150.2951 -139.045 ] [-189.6804 -177.861 ]
7 12.9667 [0.0635 0.0632] [0.0336 0.038 ] [-154.9106 -155.2461] [-187.5073 -175.2577]
8 7.2709 [0.0389 0.0295] [0.036 0.04 ] [-175.693  -179.9974] [-188.7746 -178.2958]
9 12.3505 [0.0193 0.0274] [0.024 0.032] [-199.0773 -200.9341] [-187.2613 -182.2103]
10 19.3805 [0.027  0.0386] [0.0353 0.0494] [-200.

94 5.3165 [0.0469 0.0365] [0.0438 0.051 ] [-175.0842 -163.6345] [-174.9413 -158.1873]
95 5.0028 [0.045  0.0475] [0.0403 0.0558] [-176.1076 -163.657 ] [-175.2687 -158.5958]
96 3.0376 [0.0461 0.058 ] [0.0336 0.0539] [-176.195  -159.5509] [-179.5239 -161.1613]
97 5.5382 [0.0437 0.052 ] [0.0427 0.0579] [-171.881  -153.3171] [-173.0624 -158.4552]
98 6.4809 [0.041  0.0647] [0.0478 0.0526] [-175.8998 -155.4988] [-172.7412 -158.533 ]
99 4.3035 [0.0384 0.05  ] [0.0351 0.0556] [-178.2142 -162.5429] [-172.5338 -164.0315]
100 4.8272 [0.0342 0.053 ] [0.0427 0.0491] [-178.0384 -167.3691] [-179.4887 -165.7096]
TEST 100 13.8096 0.0003 0.0004
[0.0384 0.0524 0.0506 0.0369]
[-174.5703 -160.6628 -162.5152 -177.3643]
101 17.9728 [0.05   0.0611] [0.0376 0.0545] [-165.5235 -156.5831] [-173.6136 -156.4159]
102 23.2579 [0.0445 0.0584] [0.0597 0.0527] [-176.1945 -161.1953] [-167.646  -160.0493]
103 17.8957 [0.0385 0.0645] [0.0393 0.061 ] [-175.3139 -157.3818] [-172.8922 -156.9417]
104 2.9784 [0.0354 0.0577] [0.

186 5.3019 [0.0378 0.0465] [0.0397 0.0522] [-182.4482 -167.3168] [-177.5042 -157.9419]
187 4.3191 [0.0504 0.0577] [0.0405 0.0589] [-174.8536 -160.0966] [-172.8707 -160.6502]
188 8.2135 [0.062  0.0744] [0.0397 0.0593] [-163.7583 -147.5274] [-172.4345 -156.5019]
189 13.6124 [0.0469 0.0688] [0.0464 0.0646] [-165.0617 -148.6076] [-172.7178 -155.5904]
190 4.158 [0.0381 0.0452] [0.0396 0.0474] [-180.6188 -165.842 ] [-172.6717 -160.5316]
191 9.9322 [0.0314 0.0514] [0.0397 0.0576] [-176.9494 -162.5379] [-172.6013 -155.6764]
192 3.5182 [0.0333 0.0534] [0.0485 0.0399] [-175.3591 -159.4514] [-178.264  -158.6166]
193 4.56 [0.0489 0.0683] [0.049 0.062] [-170.3416 -153.2568] [-174.4686 -158.7104]
194 4.2867 [0.0408 0.0555] [0.0365 0.0606] [-173.803  -154.3753] [-174.2623 -158.7158]
195 3.1175 [0.043  0.0498] [0.0347 0.0557] [-181.8995 -160.7742] [-178.0176 -158.4089]
196 2.2507 [0.0355 0.0561] [0.045 0.064] [-180.097  -163.8991] [-173.15   -159.2653]
197 3.6187 [0.0464 0.0625] [0.0355 0.0582] [-170.

278 20.8245 [0.048  0.0607] [0.0361 0.0652] [-169.0411 -154.3334] [-169.856  -154.8026]
279 3.4654 [0.0373 0.0487] [0.0456 0.0473] [-180.6615 -166.292 ] [-174.029  -156.5696]
280 9.2326 [0.0463 0.06  ] [0.0432 0.0599] [-167.9788 -153.3054] [-171.3152 -154.4833]
281 8.8191 [0.0489 0.0857] [0.0392 0.0646] [-165.1534 -146.8495] [-172.5455 -153.7977]
282 7.9513 [0.0446 0.0617] [0.0549 0.0555] [-170.9818 -149.7219] [-172.6274 -153.8872]
283 9.8957 [0.037  0.0546] [0.0384 0.0625] [-182.6117 -159.5586] [-175.2192 -154.7826]
284 2.2943 [0.0428 0.0564] [0.0367 0.0556] [-180.8284 -161.5111] [-174.1512 -156.2778]
285 9.0216 [0.0486 0.0654] [0.0439 0.0614] [-164.9883 -150.1698] [-171.0894 -154.5214]
286 5.4586 [0.0474 0.0534] [0.0472 0.0479] [-166.6242 -153.785 ] [-173.6565 -159.8411]
287 6.5604 [0.0416 0.0474] [0.0434 0.0604] [-171.0894 -156.1756] [-171.516  -153.6166]
288 3.9482 [0.0441 0.0506] [0.0476 0.0596] [-180.3906 -163.5787] [-171.5146 -156.0557]
289 6.3062 [0.0412 0.0689] [0.0418 0.0646]

370 9.7984 [0.0401 0.0684] [0.0545 0.0679] [-167.0042 -155.4985] [-163.6048 -152.3051]
371 7.413 [0.0373 0.0469] [0.0441 0.0534] [-169.4291 -158.7092] [-164.5156 -152.171 ]
372 6.876 [0.0453 0.0582] [0.0552 0.0632] [-161.4813 -151.2018] [-164.0329 -152.0124]
373 4.869 [0.0513 0.0549] [0.0528 0.0564] [-160.4414 -150.2407] [-165.7336 -155.8864]
374 2.7037 [0.0518 0.0654] [0.0436 0.0719] [-161.3982 -150.8536] [-163.9752 -154.2783]
375 8.9244 [0.0402 0.0737] [0.046  0.0649] [-162.3924 -151.9378] [-162.801 -152.341]
376 5.9297 [0.0441 0.0527] [0.0423 0.0579] [-169.3417 -157.8081] [-164.8815 -152.2036]
377 6.9962 [0.0474 0.0565] [0.0431 0.0621] [-168.2574 -156.1407] [-164.1444 -151.8864]
378 5.5507 [0.0559 0.0662] [0.0458 0.0616] [-164.7145 -151.3196] [-165.232  -152.2125]
379 6.035 [0.0479 0.0794] [0.0428 0.0764] [-161.3338 -146.8167] [-166.271  -152.1668]
380 5.7344 [0.0579 0.0692] [0.0387 0.0589] [-162.9424 -148.1678] [-165.9921 -152.2194]
381 9.9897 [0.0398 0.0728] [0.0555 0.0545] [-164.

463 3.2851 [0.0529 0.0661] [0.046  0.0635] [-160.8856 -150.7405] [-163.2012 -152.2565]
464 6.6241 [0.049  0.0689] [0.0494 0.0658] [-160.6707 -153.0382] [-159.3532 -151.3632]
465 2.3639 [0.0584 0.0572] [0.0591 0.0772] [-163.2184 -156.9454] [-158.8572 -153.0009]
466 4.1575 [0.0516 0.0521] [0.0617 0.0531] [-157.4938 -151.8923] [-159.4901 -153.5057]
467 5.9596 [0.0515 0.0656] [0.0547 0.0778] [-155.4416 -148.9557] [-160.2436 -151.6002]
468 8.6902 [0.0548 0.0742] [0.058  0.0608] [-154.3565 -147.8168] [-158.8071 -151.8593]
469 2.986 [0.0472 0.0641] [0.0695 0.0548] [-165.0804 -157.7684] [-158.64 -152.84]
470 9.2674 [0.0389 0.0585] [0.0547 0.0646] [-163.1608 -153.6964] [-162.5628 -151.9209]
471 1.9579 [0.0621 0.0717] [0.0437 0.066 ] [-159.847  -152.1543] [-158.7337 -153.0141]
472 9.8918 [0.0586 0.0712] [0.0549 0.0778] [-153.6691 -146.4637] [-157.8937 -151.391 ]
473 4.3026 [0.0575 0.0598] [0.0489 0.0552] [-160.568  -154.7287] [-160.1135 -154.4842]
474 18.4919 [0.0617 0.0649] [0.0562 0.0574] [-15

555 10.6746 [0.0691 0.0553] [0.0589 0.0655] [-150.9129 -146.604 ] [-154.5779 -151.4768]
556 10.5092 [0.0635 0.0669] [0.0666 0.0754] [-150.5247 -148.3881] [-154.4752 -151.5962]
557 5.5866 [0.0572 0.0509] [0.0548 0.0626] [-156.278  -155.0109] [-155.458  -151.4064]
558 21.51 [0.0552 0.0599] [0.0695 0.0842] [-154.5633 -153.6557] [-152.8854 -151.791 ]
559 6.9443 [0.0636 0.0604] [0.0587 0.0685] [-159.0989 -155.7541] [-155.0689 -151.4589]
560 1.7644 [0.0629 0.0716] [0.0574 0.0624] [-156.0176 -152.2603] [-154.7144 -153.0404]
561 9.8971 [0.0727 0.0669] [0.0658 0.0516] [-148.3639 -144.5814] [-154.8987 -151.7942]
562 5.0209 [0.072  0.0596] [0.0612 0.0552] [-152.9985 -150.5096] [-157.4791 -154.4871]
563 9.3969 [0.0672 0.0618] [0.063  0.0665] [-155.7624 -154.4568] [-154.5106 -151.7961]
564 6.5655 [0.0661 0.0398] [0.0634 0.0542] [-162.8108 -161.8905] [-155.8805 -154.691 ]
565 3.0052 [0.0554 0.0605] [0.0601 0.0603] [-157.8083 -153.937 ] [-157.4285 -152.2008]
566 4.5389 [0.0708 0.0644] [0.0519 0.069 ]

648 6.2332 [0.0705 0.0677] [0.0617 0.0596] [-153.762  -150.7105] [-153.3089 -150.7062]
649 6.7576 [0.0567 0.0769] [0.0629 0.0655] [-153.1871 -150.7277] [-152.9445 -151.0631]
650 9.0033 [0.0675 0.066 ] [0.0507 0.069 ] [-149.032 -147.153] [-152.9032 -150.8819]
TEST 650 13.8096 0.0 -0.0001
[0.0625 0.0627 0.0663 0.0531]
[-152.952  -150.9691 -151.2222 -158.9403]
651 3.5388 [0.0592 0.0563] [0.0611 0.0678] [-153.2619 -151.8438] [-153.8547 -151.3238]
652 6.485 [0.068  0.0607] [0.0617 0.0675] [-154.9477 -153.8604] [-152.9328 -150.7114]
653 2.0923 [0.0631 0.054 ] [0.0609 0.0689] [-156.1255 -153.871 ] [-153.7381 -151.4385]
654 9.9373 [0.0736 0.0693] [0.0694 0.0584] [-149.1231 -146.2275] [-152.2021 -150.7258]
655 6.8131 [0.0701 0.0726] [0.0631 0.0592] [-151.6428 -147.988 ] [-153.4165 -150.7272]
656 7.9422 [0.0722 0.0562] [0.0665 0.0609] [-153.4954 -150.5301] [-153.0444 -150.8688]
657 4.564 [0.0574 0.0592] [0.0653 0.0735] [-158.4533 -155.6074] [-154.9715 -151.3412]
658 2.7455 [0.0592 0.0646] [0.065

740 1.9669 [0.0681 0.0594] [0.0677 0.0604] [-155.7594 -154.2094] [-153.0076 -151.7237]
741 3.703 [0.0587 0.0639] [0.0708 0.0604] [-153.8133 -152.0472] [-152.6954 -151.3305]
742 24.82 [0.062  0.0737] [0.0687 0.0666] [-146.3321 -145.2427] [-151.091  -150.9799]
743 2.7486 [0.0643 0.0564] [0.0684 0.0709] [-156.1995 -155.0694] [-152.7238 -151.6563]
744 13.9286 [0.0674 0.0778] [0.061  0.0659] [-151.318 -150.565] [-151.9969 -150.8448]
745 4.3707 [0.0726 0.0632] [0.0715 0.0588] [-155.0418 -155.073 ] [-153.2575 -153.2792]
746 15.1697 [0.0649 0.0737] [0.056  0.0681] [-147.645  -147.4764] [-151.8794 -151.0743]
747 20.8615 [0.0704 0.0771] [0.0632 0.0601] [-147.4768 -147.7923] [-150.8942 -151.0642]
748 8.7898 [0.0488 0.0649] [0.0688 0.0665] [-155.1277 -154.7076] [-152.0822 -150.7841]
749 10.5251 [0.064  0.0715] [0.066  0.0679] [-155.4562 -154.6916] [-151.8738 -150.9577]
750 8.5626 [0.0659 0.0681] [0.0672 0.0648] [-152.4539 -151.0037] [-152.1607 -151.0788]
TEST 750 13.8096 -0.0 -0.0
[0.0627 0.0644 0

832 4.8268 [0.0588 0.0642] [0.0538 0.0725] [-153.771  -153.0042] [-153.2079 -152.2538]
833 2.7105 [0.0654 0.0499] [0.0656 0.0616] [-156.3639 -155.4534] [-152.695 -152.718]
834 2.086 [0.0633 0.0644] [0.0507 0.0739] [-154.8601 -153.9156] [-152.5186 -152.68  ]
835 25.6525 [0.0788 0.0739] [0.0622 0.0695] [-145.73   -145.4246] [-151.3782 -151.9479]
836 1.9375 [0.0753 0.0609] [0.0573 0.0726] [-154.0946 -154.2818] [-152.4662 -152.6414]
837 8.6615 [0.0746 0.0688] [0.0677 0.0647] [-151.3125 -151.8551] [-151.452  -151.2888]
838 13.0443 [0.0724 0.0574] [0.0638 0.0561] [-151.7481 -152.2152] [-151.9155 -151.3472]
839 4.4412 [0.0675 0.0568] [0.0596 0.0648] [-155.4195 -155.9356] [-152.3543 -152.4636]
840 4.6993 [0.0545 0.0673] [0.0548 0.0605] [-154.4997 -153.7664] [-153.4567 -152.1547]
841 8.7088 [0.0674 0.0636] [0.0695 0.0608] [-148.3424 -147.2348] [-152.0327 -151.3475]
842 2.3897 [0.0703 0.0576] [0.0592 0.0699] [-151.5512 -150.5084] [-152.3951 -152.6517]
843 5.5522 [0.0621 0.0602] [0.0646 0.0501] [

924 5.8085 [0.0723 0.0657] [0.0694 0.0652] [-151.1517 -150.6198] [-152.0363 -151.2678]
925 4.3688 [0.0574 0.0616] [0.0701 0.0708] [-150.7222 -151.501 ] [-150.9966 -152.6975]
926 1.7723 [0.0617 0.0679] [0.0651 0.0566] [-151.4463 -152.3859] [-151.3979 -152.5541]
927 3.6738 [0.0741 0.0542] [0.0569 0.0557] [-151.4294 -152.2479] [-152.0212 -152.0156]
928 5.9745 [0.0646 0.0712] [0.059  0.0559] [-149.9823 -150.6514] [-151.637  -151.1142]
929 13.9452 [0.0675 0.063 ] [0.0664 0.0711] [-149.3975 -149.8986] [-151.3094 -151.1298]
930 15.6408 [0.0636 0.0724] [0.076  0.0618] [-152.945  -152.9207] [-151.3915 -151.1235]
931 10.5726 [0.065  0.0526] [0.0648 0.0655] [-155.5797 -155.083 ] [-151.2537 -151.1528]
932 2.187 [0.0646 0.0587] [0.0681 0.0652] [-154.4115 -154.0556] [-151.0831 -152.2249]
933 23.5929 [0.0792 0.0766] [0.0724 0.0636] [-143.6144 -143.3314] [-150.5376 -151.2752]
934 11.0545 [0.0778 0.074 ] [0.059  0.0656] [-146.73   -146.4237] [-151.4163 -150.9354]
935 18.4598 [0.0678 0.0634] [0.0586 0.0

1016 2.6558 [0.0584 0.0643] [0.0699 0.0701] [-154.8305 -155.7262] [-150.9732 -152.3142]
1017 7.9528 [0.054  0.0552] [0.0646 0.0673] [-153.2787 -153.2732] [-151.358  -151.3712]
1018 2.4714 [0.0765 0.0625] [0.0656 0.0642] [-152.8859 -152.7436] [-153.2625 -153.4634]
1019 7.1647 [0.0614 0.0756] [0.0677 0.0576] [-147.616  -147.1376] [-151.4367 -151.1586]
1020 2.2536 [0.0669 0.0701] [0.0775 0.0646] [-150.7732 -151.2518] [-150.9016 -152.4729]
1021 13.3045 [0.0732 0.0706] [0.0608 0.0661] [-149.4528 -149.7795] [-151.098  -151.4311]
1022 4.4936 [0.0621 0.0683] [0.0619 0.0799] [-156.0027 -156.5583] [-151.925  -151.6618]
1023 5.2677 [0.0634 0.0589] [0.0562 0.0658] [-155.0668 -155.4818] [-152.2436 -151.7042]
1024 10.1011 [0.0629 0.076 ] [0.0661 0.0755] [-147.4927 -147.7617] [-150.7767 -150.9478]
1025 2.1026 [0.0672 0.0781] [0.0561 0.0711] [-150.1023 -150.2975] [-151.1904 -151.976 ]
1026 1.6734 [0.0596 0.065 ] [0.0735 0.0628] [-151.0351 -150.8353] [-151.6933 -151.952 ]
1027 2.6873 [0.0647 0.0615] [0

1107 18.4263 [0.0637 0.0809] [0.0746 0.0603] [-143.1085 -143.2024] [-150.8272 -150.7855]
1108 12.8547 [0.0816 0.0729] [0.0645 0.0581] [-144.9417 -144.836 ] [-150.4829 -150.5909]
1109 10.6856 [0.0667 0.065 ] [0.0662 0.0636] [-149.5566 -148.1905] [-150.8295 -150.4975]
1110 18.4866 [0.053  0.0698] [0.0558 0.066 ] [-153.4246 -153.1383] [-150.9739 -150.6083]
1111 10.236 [0.061  0.0553] [0.0693 0.0657] [-160.4093 -160.3138] [-152.4274 -151.267 ]
1112 4.0395 [0.0569 0.0663] [0.0666 0.0603] [-156.323  -157.0734] [-150.6424 -151.7166]
1113 12.5232 [0.081  0.0698] [0.0781 0.0739] [-145.3926 -145.1773] [-150.9342 -150.4969]
1114 5.0272 [0.0736 0.0781] [0.0638 0.0667] [-145.8039 -145.5634] [-151.5675 -151.3601]
1115 10.6547 [0.0792 0.0774] [0.0617 0.0672] [-144.5908 -145.959 ] [-150.9105 -151.0801]
1116 6.3843 [0.0518 0.0492] [0.0725 0.0564] [-153.8077 -157.0613] [-151.5783 -151.561 ]
1117 3.2545 [0.0597 0.0679] [0.0602 0.0718] [-152.1268 -150.8209] [-152.5633 -151.305 ]
1118 2.389 [0.0513 0.0688]

1199 6.3637 [0.0581 0.0621] [0.0549 0.0693] [-154.3019 -154.3261] [-150.9014 -150.6646]
1200 4.594 [0.0588 0.0574] [0.0726 0.0655] [-154.6805 -155.0966] [-151.2015 -151.2077]
TEST 1200 13.8096 -0.0 -0.0
[0.0677 0.0658 0.0624 0.0619]
[-150.7559 -150.9258 -150.5568 -150.8113]
1201 5.7773 [0.0591 0.0649] [0.0558 0.0707] [-152.2417 -152.7303] [-150.7382 -150.8299]
1202 9.6652 [0.0641 0.0843] [0.0536 0.0713] [-145.8631 -146.2828] [-150.6601 -150.9088]
1203 8.8463 [0.0737 0.0687] [0.0577 0.0731] [-147.6381 -148.1105] [-150.9337 -151.3748]
1204 18.803 [0.0867 0.0713] [0.067  0.0578] [-146.468  -146.6856] [-150.5972 -150.6093]
1205 6.4798 [0.051  0.0644] [0.0664 0.0629] [-155.2298 -156.302 ] [-151.534  -153.2798]
1206 7.293 [0.0562 0.0574] [0.0671 0.0721] [-152.2003 -152.7651] [-150.5547 -151.0628]
1207 3.9703 [0.0577 0.0605] [0.0642 0.0686] [-154.458  -154.8266] [-151.3402 -151.2963]
1208 7.7701 [0.0694 0.069 ] [0.0751 0.0705] [-148.3785 -148.7922] [-150.7651 -150.9812]
1209 11.0425 [0.0752 0

1290 18.1496 [0.0655 0.067 ] [0.0638 0.066 ] [-148.0007 -148.6755] [-150.3299 -150.486 ]
1291 5.0349 [0.0655 0.0607] [0.0638 0.0613] [-152.9673 -153.8095] [-150.7632 -151.1063]
1292 5.4833 [0.0566 0.0697] [0.0588 0.0609] [-151.4904 -152.0317] [-151.0704 -151.0869]
1293 4.6248 [0.0677 0.0684] [0.0582 0.0662] [-152.6645 -153.987 ] [-151.232  -153.3093]
1294 2.579 [0.0665 0.0618] [0.073  0.0629] [-151.3213 -151.885 ] [-151.138  -151.8113]
1295 12.0206 [0.0655 0.0556] [0.0614 0.0662] [-146.9296 -147.4379] [-150.3294 -150.9214]
1296 1.8509 [0.0578 0.0547] [0.0594 0.0702] [-151.0024 -151.5981] [-151.2754 -151.9837]
1297 21.6971 [0.0802 0.0652] [0.0607 0.0529] [-146.0213 -146.8644] [-150.1291 -151.1479]
1298 4.0268 [0.0555 0.0745] [0.063  0.0554] [-154.5056 -155.9572] [-150.4117 -152.3575]
1299 23.547 [0.0659 0.0711] [0.0704 0.0715] [-150.7985 -151.7228] [-150.5776 -151.0649]
1300 6.4692 [0.058  0.0655] [0.0741 0.0788] [-154.3814 -154.7721] [-151.2328 -150.696 ]
TEST 1300 13.8096 0.0 -0.0
[0.

In [None]:
plt.plot(batch_pred.cpu().detach().numpy())

In [None]:
plt.plot(test_pred.cpu().detach().numpy())