In [11]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics
from tqdm import tnrange
from tqdm.notebook import trange

from RATIO_HVAE_shGLM import RATIO_HVAE_shGLM

In [17]:
# Scratch check of a Gumbel-softmax style relaxation on a two-entry vector.
X_hot = torch.tensor([0.8, 0.2])
eps = 1e-10   # keeps both logarithms away from log(0)
tau = 0.5     # divisor applied before the softmax

# Uniform noise pushed through -log(-log(u)) to obtain the perturbation g.
u = torch.rand_like(X_hot)
g = -(eps - (u + eps).log()).log()

# Perturb, scale by 1/tau, and normalise so the entries sum to one.
Z_out = torch.softmax((X_hot + g) / tau, dim=0).flatten()
print(Z_out)

tensor([0.2425, 0.7575])


# Hyperparams

In [2]:
# Number of samples taken for the training window and for the test window
# that immediately follows it.
train_T, test_T = 20000, 8000

# Square 0/1 matrix handed to RATIO_HVAE_shGLM as its first argument.
# Only row 0 is populated: columns 1..4 are set to one.
# NOTE(review): presumably row 0 is the root subunit receiving the other four
# -- confirm against the model class.
sub_no = 5
C_den = torch.zeros(sub_no, sub_no)
C_den[0, 1:].fill_(1)

# Length of each contiguous training window.
batch_size = 1500

# Remaining constructor arguments for RATIO_HVAE_shGLM (their exact meaning is
# defined by that class, not in this notebook).
syn_basis_no = hist_basis_no = 2
T_hist = T_syn = T_enc = 201
hid_dim = 256
temp = 0.5

In [3]:
def _block_rows(counts, n_rows, n_cols):
    """Build an (n_rows, n_cols) 0/1 matrix with one contiguous run of ones per row.

    Row ``r`` gets ``counts[r]`` ones, starting where row ``r - 1`` stopped.
    Returns the matrix together with the final running offset.
    """
    layout = torch.zeros(n_rows, n_cols)
    offset = 0
    for row in range(n_rows):
        width = counts[row]
        layout[row, offset:offset + width] = 1
        offset = offset + width
    return layout, offset


# Per-row column counts for the two input populations.
Ensyn = torch.tensor([0, 106, 213, 211, 99])
Insyn = torch.tensor([1, 22, 36, 42, 19])

# Totals are kept as 0-dim tensors (E_no is also passed to the model later).
E_no = torch.sum(Ensyn)
I_no = torch.sum(Insyn)

# Each column belongs to exactly one row: consecutive column blocks are
# assigned to consecutive rows.
C_syn_e, E_count = _block_rows(Ensyn, sub_no, E_no)
C_syn_i, I_count = _block_rows(Insyn, sub_no, I_no)

# Train Params

In [4]:
# Instantiate the model with the connectivity matrices already on the GPU,
# then cast it to float32 and move it to the GPU as well.
model = RATIO_HVAE_shGLM(
    C_den.cuda(), C_syn_e.cuda(), C_syn_i.cuda(),
    E_no, T_syn, syn_basis_no,
    T_hist, hist_basis_no, T_enc, hid_dim, temp,
).float().cuda()

# Reference trace, read with np.fromfile's default dtype. The first sample and
# the last two samples are dropped.
# NOTE(review): the reason for the [1:-2] trim is not visible here -- confirm.
V_ref = np.fromfile("/media/hdd01/sklee/cont_shglm/inputs/vdata_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b0.bin")[1:-2]

# First train_T samples are for training and stay on the CPU (they are sliced
# and moved to the GPU per batch); the next test_T samples go to the GPU now.
train_V_ref = torch.from_numpy(V_ref[:train_T])
test_V_ref = torch.from_numpy(V_ref[train_T:train_T+test_T]).cuda()

In [5]:
# Load the two input arrays from disk.
E_neural = np.load("/media/hdd01/sklee/cont_shglm/inputs/Espikes_d48000_r1_rep1_Ne629_e5_E20_neural.npy")
I_neural = np.load("/media/hdd01/sklee/cont_shglm/inputs/Ispikes_d48000_r1_rep1_Ni120_i20_I30_neural.npy")

# Rows [0, train_T) form the training window; rows [train_T, train_T + test_T)
# form the test window.
test_rows = slice(train_T, train_T + test_T)

# Training inputs stay on the CPU in their stored dtype; each mini-batch is
# cast and moved to the GPU inside the training loop.
train_S_E = torch.from_numpy(E_neural[:train_T])
train_S_I = torch.from_numpy(I_neural[:train_T])

# Test inputs are cast to float32 and moved to the GPU once, up front.
test_S_E = torch.from_numpy(E_neural[test_rows]).float().cuda()
test_S_I = torch.from_numpy(I_neural[test_rows]).float().cuda()

In [6]:
# Schedule of mini-batch start offsets.
# Every offset in [0, len(train_V_ref) - batch_size) is visited once per pass,
# in a fresh random order, for `repeat_no` passes.
repeat_no = 2
n_starts = train_V_ref.shape[0] - batch_size
batch_no = n_starts * repeat_no

# The offsets are later used as slice bounds, so store them as integers
# instead of relying on a float64 array that must be cast back with .long().
train_idx = np.empty((repeat_no, n_starts), dtype=np.int64)
for i in range(repeat_no):
    part_idx = np.arange(n_starts)
    np.random.shuffle(part_idx)
    train_idx[i] = part_idx

# Flatten the passes into one sequence consumed step by step by the training loop.
train_idx = torch.from_numpy(train_idx.flatten())

print(batch_no)
print(train_idx.shape[0])

37000
37000


In [7]:
# Adam over every parameter of the model, with a fixed step size.
learning_rate = 0.005
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Training loop.
# Each step takes one contiguous window of `batch_size` samples starting at
# train_idx[i], runs the model, and minimises rec_loss + beta * KL_loss.
# Every 50 steps the model is scored on the held-out window; every 100 steps
# a checkpoint is written.

# NaN-initialised so that, if the run is interrupted, the steps that never
# executed are distinguishable from real loss values.
loss_array = np.full(batch_no, np.nan)
beta = 0

# The held-out reference never changes, so convert it for sklearn only once.
test_V_ref_np = test_V_ref.cpu().detach().numpy()

for i in trange(batch_no):

    # KL weight warm-up: +0.1 at i = 249, 499, ..., 2499 (ten increments).
    if i%250 == 249 and i <= 2500:
        beta += 0.1

    model.train()
    optimizer.zero_grad()

    # Slice the window on the CPU, then cast to float32 and move it to the GPU.
    batch_idx = train_idx[i].long()
    batch_S_E = train_S_E[batch_idx : batch_idx+batch_size].float().cuda()
    batch_S_I = train_S_I[batch_idx : batch_idx+batch_size].float().cuda()
    batch_ref = train_V_ref[batch_idx:batch_idx+batch_size].float().cuda()

    batch_pred, rec_loss, KL_loss, Z_P_prior, Z_P_post = model(batch_ref,
                                                              batch_S_E,
                                                              batch_S_I,
                                                              beta)

    loss = rec_loss + beta*KL_loss
    loss_array[i] = loss.item()

    print(i, rec_loss.item(), KL_loss.item())

    loss.backward()
    optimizer.step()

    if i%50 == 0:
        model.eval()
        # Evaluation needs no gradients; skipping autograd tracking avoids
        # holding a graph over the whole test window.
        with torch.no_grad():
            test_pred, test_Y, test_spikes, test_ratios = model.decode(test_S_E, test_S_I)

        # Reported "test loss" is the variance of the residual.
        test_diff = test_V_ref - test_pred
        test_loss = torch.var(test_diff)
        test_score = metrics.explained_variance_score(y_true=test_V_ref_np,
                                                      y_pred=test_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        train_score = metrics.explained_variance_score(y_true=batch_ref.cpu().detach().numpy(),
                                                      y_pred=batch_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        # float() accepts both NumPy scalars and plain Python floats.
        print("TEST", i, test_loss.item(), float(test_score), float(train_score))
        print(torch.mean(test_spikes, 0).cpu().detach().numpy())
        if i%100 == 0:
            torch.save(model.state_dict(), "/media/hdd01/sklee/lvae_shglm/RATIO_HVAE_sub5_s2_h2_shglm_i"+str(i)+".pt")
    


  for i in tnrange(batch_no):


HBox(children=(FloatProgress(value=0.0, max=37000.0), HTML(value='')))

0 4.389269828796387 0.14818760752677917
TEST 0 13.655240087332258 0.011300200775855584 0.011595666408538818
[0.37762803 0.39168617 0.38898978 0.37526685]
1 11.830157279968262 1.3684355020523071
2 6.315035343170166 1.1137489080429077
3 12.8737154006958 1.1092668771743774
4 22.493431091308594 1.1050376892089844
5 1.771389365196228 1.0978119373321533
6 3.965461254119873 1.0825203657150269
7 4.596217155456543 1.0990195274353027
8 5.894040584564209 1.1010977029800415
9 8.889212608337402 1.077635407447815
10 5.012247562408447 1.0987311601638794
11 12.626554489135742 1.0939123630523682
12 17.353513717651367 1.096547245979309
13 16.50937271118164 1.0904375314712524
14 4.262353420257568 1.1013611555099487
15 7.618457317352295 1.0978909730911255
16 6.197743892669678 1.096578598022461
17 4.307802677154541 1.0549242496490479
18 15.275097846984863 1.0881836414337158
19 5.097107887268066 1.0746549367904663
20 3.2105600833892822 1.0499855279922485
21 4.562482833862305 1.0924155712127686
22 1.73084723

191 7.323297023773193 1.0235570669174194
192 2.198894739151001 0.9646327495574951
193 6.236310958862305 1.0187106132507324
194 19.724353790283203 1.0414416790008545
195 1.5336350202560425 0.9571194648742676
196 2.6682686805725098 0.9597228169441223
197 2.67560076713562 0.958729088306427
198 5.544233322143555 1.025828242301941
199 1.9467600584030151 0.9857099652290344
200 2.2681992053985596 0.981041669845581
TEST 200 12.388573382533963 0.10301247451627427 0.32215195894241333
[0.37316236 0.3746463  0.39160985 0.40066588]
201 11.976812362670898 1.0319886207580566
202 2.9632327556610107 0.9900763034820557
203 10.633423805236816 1.051002025604248
204 2.9305615425109863 0.9751130938529968
205 1.0104162693023682 0.9847822785377502
206 5.0208306312561035 1.0534437894821167
207 2.1139886379241943 0.972931981086731
208 3.240513801574707 0.9673862457275391
209 5.547379970550537 1.0521748065948486
210 12.638267517089844 1.0463600158691406
211 3.100154399871826 0.9826443791389465
212 1.113462209701

378 3.545628070831299 0.34480541944503784
379 2.2009940147399902 0.3476046919822693
380 0.9960578083992004 0.3561238646507263
381 0.7964209914207458 0.3517501950263977
382 5.890131950378418 0.32136356830596924
383 1.940882921218872 0.3395739495754242
384 6.406184196472168 0.32254651188850403
385 2.500368356704712 0.33889737725257874
386 0.7582316994667053 0.3442772626876831
387 1.0923866033554077 0.3359399735927582
388 1.276839017868042 0.3363669812679291
389 12.45949649810791 0.31419551372528076
390 1.2319345474243164 0.3372029662132263
391 2.4863884449005127 0.33143872022628784
392 8.454360961914062 0.31217247247695923
393 2.5786240100860596 0.32807570695877075
394 4.235644817352295 0.3057597279548645
395 2.2599735260009766 0.3113223910331726
396 1.9408237934112549 0.30860915780067444
397 8.134782791137695 0.30119308829307556
398 1.398772954940796 0.3146338164806366
399 0.8027179837226868 0.3160960376262665
400 2.742828845977783 0.30646008253097534
TEST 400 8.956781044481566 0.351489

562 2.9006450176239014 0.1184355691075325
563 3.6546475887298584 0.12463144958019257
564 2.2896502017974854 0.14280426502227783
565 4.518709182739258 0.1162136048078537
566 1.597840428352356 0.13550038635730743
567 3.0694940090179443 0.11651013046503067
568 1.3419865369796753 0.1314607709646225
569 0.7866390347480774 0.13214151561260223
570 1.6218680143356323 0.13027825951576233
571 10.785926818847656 0.10491430759429932
572 3.8090662956237793 0.11106625944375992
573 0.9325951933860779 0.12955759465694427
574 4.211172580718994 0.10855395346879959
575 0.61715167760849 0.13328999280929565
576 14.03885269165039 0.09838258475065231
577 2.588651418685913 0.13426564633846283
578 1.4912981986999512 0.13249701261520386
579 8.074457168579102 0.10405714809894562
580 1.1945266723632812 0.12925198674201965
581 3.164687156677246 0.10732801258563995
582 0.9865989685058594 0.12409140914678574
583 1.4291635751724243 0.11471638083457947
584 14.192780494689941 0.09041576832532883
585 21.463884353637695 

747 1.7609591484069824 0.05887812376022339
748 2.074272394180298 0.06156589090824127
749 0.9153826236724854 0.06836920231580734
750 7.315086841583252 0.05378575623035431
TEST 750 7.034291974580305 0.4906861380905754 0.432536780834198
[0.04522897 0.04379841 0.47828713 0.477002  ]
751 7.197703838348389 0.057625964283943176
752 3.002312183380127 0.05245338752865791
753 5.628872871398926 0.052124377340078354
754 2.692615270614624 0.05679301917552948
755 2.0095901489257812 0.0647902637720108
756 1.8742895126342773 0.06589894741773605
757 13.585930824279785 0.047150686383247375
758 6.798871040344238 0.055340301245450974
759 1.2295247316360474 0.06974358856678009
760 3.5702059268951416 0.05450063943862915
761 1.8832762241363525 0.06494341790676117
762 3.3143837451934814 0.04983992129564285
763 5.652383327484131 0.050569020211696625
764 0.837783694267273 0.06308779865503311
765 9.221359252929688 0.0470426082611084
766 2.055614709854126 0.05440938100218773
767 2.901982069015503 0.04835454374551

929 1.5454919338226318 0.03607742115855217
930 4.90593957901001 0.0263210441917181
931 2.0064151287078857 0.02693217433989048
932 1.8185920715332031 0.03107823058962822
933 2.6581921577453613 0.029566166922450066
934 2.576075315475464 0.026593200862407684
935 2.538714647293091 0.02494209259748459
936 2.064779281616211 0.027857976034283638
937 0.9594019055366516 0.0328860878944397
938 1.310793399810791 0.03450324013829231
939 1.329155683517456 0.03263089060783386
940 1.5767751932144165 0.0336710624396801
941 1.7757858037948608 0.03503471240401268
942 0.9705548286437988 0.03418372943997383
943 2.2157766819000244 0.027439719066023827
944 2.0813465118408203 0.02769247628748417
945 1.3979002237319946 0.03635422885417938
946 3.623317241668701 0.027214091271162033
947 2.2654731273651123 0.024397628381848335
948 1.076354742050171 0.033607468008995056
949 3.0952141284942627 0.027593662962317467
950 1.8001141548156738 0.03572404012084007
TEST 950 5.995447147509298 0.5659031001285453 0.5936419963

1107 2.9558005332946777 0.016346391290426254
1108 1.0849218368530273 0.025558289140462875
1109 1.1125985383987427 0.02414710633456707
1110 0.7673636078834534 0.02533094584941864
1111 0.9664798378944397 0.024677693843841553
1112 3.0915520191192627 0.020010536536574364
1113 1.5458852052688599 0.02310122735798359
1114 2.5790295600891113 0.024909941479563713
1115 1.8196823596954346 0.02425743080675602
1116 11.593074798583984 0.01687956042587757
1117 3.1582441329956055 0.022940808907151222
1118 1.8700766563415527 0.024166107177734375
1119 12.742186546325684 0.016076240688562393
1120 5.54781436920166 0.018426142632961273
1121 1.0184624195098877 0.023946816101670265
1122 0.9828745126724243 0.0231862161308527
1123 6.118014335632324 0.016855336725711823
1124 1.5195090770721436 0.017269564792513847
1125 1.2485140562057495 0.02063423953950405
1126 0.8413779735565186 0.021864736452698708
1127 0.7137703895568848 0.021871816366910934
1128 1.9557421207427979 0.01650143601000309
1129 0.976476490497589

1285 0.9586272835731506 0.015222491696476936
1286 1.0716334581375122 0.014773467555642128
1287 4.832239627838135 0.011900658719241619
1288 1.3005194664001465 0.013741228729486465
1289 1.0789839029312134 0.013534118421375751
1290 1.7640177011489868 0.010948588140308857
1291 7.549202919006348 0.010265802033245564
1292 3.1559367179870605 0.01100828405469656
1293 1.2532161474227905 0.015271621756255627
1294 0.7351120710372925 0.015697255730628967
1295 0.814480185508728 0.015907559543848038
1296 0.9181444644927979 0.014888491481542587
1297 7.047488689422607 0.0106684984639287
1298 1.534665584564209 0.015209553763270378
1299 1.623762845993042 0.011201426386833191
1300 0.7876745462417603 0.014388056471943855
TEST 1300 5.15443262308056 0.6267962728117742 0.5595122575759888
[0.00943581 0.00971714 0.4881583  0.4880933 ]
1301 2.5688669681549072 0.010763419792056084
1302 4.293274402618408 0.009694023989140987
1303 0.7403606176376343 0.014116381295025349
1304 5.99028205871582 0.00962159689515829
13

1460 17.081527709960938 0.0058845947496593
1461 4.222360610961914 0.008762853220105171
1462 1.0205858945846558 0.01155575830489397
1463 0.8166041970252991 0.010754513554275036
1464 8.188667297363281 0.00790061429142952
1465 1.0864441394805908 0.010601655580103397
1466 1.2800697088241577 0.010007725097239017
1467 1.1312519311904907 0.010166493244469166
1468 1.0095542669296265 0.011035638861358166
1469 1.6518008708953857 0.008663526736199856
1470 4.572585582733154 0.008399508893489838
1471 1.399417519569397 0.008296973071992397
1472 1.2421835660934448 0.01051357388496399
1473 3.3546035289764404 0.007092874962836504
1474 0.8607603907585144 0.009276200085878372
1475 8.169156074523926 0.006733363959938288
1476 11.096061706542969 0.006777121219784021
1477 1.4154157638549805 0.010233916342258453
1478 12.130773544311523 0.006102537736296654
1479 1.3657050132751465 0.010927528142929077
1480 3.2526795864105225 0.011211799457669258
1481 1.489647388458252 0.010903928428888321
1482 3.33268475532531

1637 2.7744107246398926 0.005487394984811544
1638 2.921255111694336 0.0053554424084723
1639 1.1170531511306763 0.007638637442141771
1640 4.872859477996826 0.006408294662833214
1641 1.543316125869751 0.008044796995818615
1642 1.0400311946868896 0.007114392705261707
1643 1.4182536602020264 0.008081971667706966
1644 1.0333242416381836 0.005722345784306526
1645 2.1562838554382324 0.004840394482016563
1646 1.11537504196167 0.007070295978337526
1647 1.0823400020599365 0.007329742424190044
1648 0.8880414962768555 0.007602855563163757
1649 1.4652572870254517 0.007114104926586151
1650 0.9058699607849121 0.006962539628148079
TEST 1650 4.646995825714313 0.6635369420238979 0.6786016821861267
[0.00496482 0.00566229 0.49541247 0.50277215]
1651 2.043578624725342 0.004652615636587143
1652 1.3408135175704956 0.005392927676439285
1653 3.6186017990112305 0.005844746716320515
1654 1.2011460065841675 0.005456168204545975
1655 3.5900182723999023 0.005428601521998644
1656 2.070890188217163 0.0055436547845602

1811 1.2467793226242065 0.004028879571706057
1812 1.1655994653701782 0.004226903431117535
1813 1.331523060798645 0.005735862534493208
1814 0.5938028693199158 0.005413509439677
1815 0.7383754849433899 0.005711388774216175
1816 0.5749728083610535 0.004927119705826044
1817 0.9094311594963074 0.00475721713155508
1818 4.549152374267578 0.0036005775909870863
1819 17.545942306518555 0.0027887700125575066
1820 1.410609245300293 0.00403301976621151
1821 1.0397992134094238 0.005520121194422245
1822 0.9604212045669556 0.005585978273302317
1823 2.7823703289031982 0.004275443032383919
1824 2.5548291206359863 0.004669002257287502
1825 7.384054660797119 0.0038838214240968227
1826 1.7466119527816772 0.004098282661288977
1827 0.7627809643745422 0.0056785405613482
1828 0.7191957831382751 0.005427586380392313
1829 0.770535409450531 0.004546914249658585
1830 5.4524922370910645 0.003081372007727623
1831 2.3630599975585938 0.003102385438978672
1832 2.114640474319458 0.0034704767167568207
1833 2.014466285705

KeyboardInterrupt: 

In [None]:
# Compare the latest held-out prediction from the training loop with the
# reference trace it is scored against.
fig, ax = plt.subplots()
ax.plot(test_V_ref.cpu().detach().numpy(), label="test_V_ref")
ax.plot(test_pred.cpu().detach().numpy(), label="test_pred")
ax.set(xlabel="sample index", ylabel="value", title="Test window: reference vs. prediction")
ax.legend();

In [None]:
# Draws the latest held-out prediction on its own.
# NOTE(review): this repeats the plot from the previous cell -- confirm whether
# both cells are needed.
pred_trace = test_pred.cpu().detach().numpy()
plt.plot(pred_trace)