In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import audtorch.metrics.functional as audtorch

import time
from tqdm import tnrange, tqdm_notebook
import os
import matplotlib.pyplot as plt

SoX could not be found!

    If you do not have SoX, proceed here:
     - - - http://sox.sourceforge.net/ - - -

    If you do (or think that you should) have SoX, double-check your
    path variables.
    


In [2]:
class FullNonLinear(nn.Module):
    """Fully connected decoder from flattened neural responses to pixels.

    Architecture: Linear -> Tanh, then three (Linear -> Tanh) hidden layers of
    width ``h_dim``, then a linear read-out to ``p_dim`` outputs.

    Args:
        unit_no: number of units; together with ``t_dim`` it fixes the input
            width ``unit_no * t_dim``.
        t_dim: number of input columns per unit.
        h_dim: width of every hidden layer.
        p_dim: number of outputs (pixels).
        device: device the parameters are placed on. Defaults to ``"cuda"``,
            which reproduces the previous hard-coded ``.cuda()`` behaviour;
            pass ``"cpu"`` to build the model without a GPU.
    """

    def __init__(self, unit_no, t_dim, h_dim, p_dim, device="cuda"):
        super().__init__()
        self.unit_no = unit_no
        self.t_dim = t_dim
        self.h_dim = h_dim
        self.p_dim = p_dim

        # Attribute names (and the Sequential nesting) are kept unchanged so
        # saved state_dicts (input_layer.0.*, hidden1.0.*, ...) still load.
        self.input_layer = nn.Sequential(
            nn.Linear(self.unit_no * self.t_dim, self.h_dim), nn.Tanh())
        self.hidden1 = nn.Sequential(nn.Linear(self.h_dim, self.h_dim), nn.Tanh())
        self.hidden2 = nn.Sequential(nn.Linear(self.h_dim, self.h_dim), nn.Tanh())
        self.hidden3 = nn.Sequential(nn.Linear(self.h_dim, self.h_dim), nn.Tanh())
        self.output_layer = nn.Linear(self.h_dim, self.p_dim)

        # One move for the whole module instead of a .cuda() per layer.
        self.to(device)

    def forward(self, S):
        """Decode a batch.

        Args:
            S: tensor of shape (batch, unit_no * t_dim).

        Returns:
            Tensor of shape (batch, p_dim).
        """
        out = self.input_layer(S)
        for layer in (self.hidden1, self.hidden2, self.hidden3):
            out = layer(out)
        return self.output_layer(out)

In [2]:
class PartNonLinear(nn.Module):
    """Locally connected decoder: every pixel reads only from its own units.

    Each unit's block of ``t_dim`` input columns is compressed to ``f_dim``
    features by a per-unit linear layer. Each pixel ``x`` then gathers the
    features of the first ``k_dim`` units listed in ``pix_units[x]`` and runs
    them through its own Linear -> PReLU -> Linear(1) network.

    Args:
        unit_no: number of units; ``S`` is expected to have
            ``unit_no * t_dim`` columns.
        t_dim: number of input columns per unit.
        k_dim: number of units each pixel reads from.
        h_dim: hidden width of each per-pixel network.
        p_dim: number of output pixels.
        f_dim: number of features extracted per unit.
        device: device the parameters are placed on. Defaults to ``"cuda"``,
            which reproduces the previous hard-coded ``.cuda()`` behaviour.
    """

    def __init__(self, unit_no, t_dim, k_dim, h_dim, p_dim, f_dim, device="cuda"):
        super().__init__()
        self.unit_no = unit_no
        self.t_dim = t_dim
        self.h_dim = h_dim
        self.p_dim = p_dim
        self.k_dim = k_dim
        self.f_dim = f_dim

        # Attribute names are unchanged so existing state_dicts still load.
        self.featurize = nn.ModuleList(
            [nn.Linear(self.t_dim, self.f_dim) for _ in range(self.unit_no)])
        self.hidden1 = nn.ModuleList(
            [nn.Linear(self.k_dim * self.f_dim, self.h_dim) for _ in range(self.p_dim)])
        self.hidden1_act = nn.ModuleList([nn.PReLU() for _ in range(self.p_dim)])
        self.output_layer = nn.ModuleList(
            [nn.Linear(self.h_dim, 1) for _ in range(self.p_dim)])

        self.to(device)

    def forward(self, S, pix_units):
        """Decode pixel values.

        Args:
            S: tensor of shape (batch, unit_no * t_dim); unit ``n`` occupies
                columns ``n*t_dim`` to ``(n+1)*t_dim``.
            pix_units: indexable with at least ``p_dim`` rows; row ``x`` holds
                the unit indices for pixel ``x`` (only the first ``k_dim`` are
                used). A tensor, ndarray or nested list all work.

        Returns:
            Tensor of shape (batch, p_dim).
        """
        # Per-unit features laid side by side: unit n owns columns
        # n*f_dim .. (n+1)*f_dim. torch.cat keeps the parameters' device and
        # dtype (the old code forced a float32 CUDA buffer and shadowed the
        # module-level `F` alias for torch.nn.functional).
        feats = torch.cat(
            [self.featurize[n](S[:, n * self.t_dim:(n + 1) * self.t_dim])
             for n in range(self.unit_no)],
            dim=1,
        )

        # Column offsets inside one unit's feature block; loop-invariant.
        offsets = torch.arange(self.f_dim, device=feats.device)

        pixels = []
        for x in range(self.p_dim):
            unit_ids = torch.as_tensor(pix_units[x][:self.k_dim], device=feats.device)
            # Block i of the gathered vector holds unit_ids[i]'s features, in
            # the k_dim x f_dim layout hidden1[x] expects. Built in one
            # vectorised op instead of k_dim scalar writes (each of which
            # forced a host/device sync when pix_units lives on the GPU).
            feat_ids = (unit_ids.unsqueeze(1) * self.f_dim + offsets).reshape(-1).long()
            hidden = self.hidden1_act[x](self.hidden1[x](feats[:, feat_ids]))
            pixels.append(self.output_layer[x](hidden))  # (batch, 1)

        return torch.cat(pixels, dim=1)

In [3]:
# Model / data dimensions.
t_dim = 50       # input columns per unit in S
k_dim = 10       # units each pixel reads from
h_dim = 40       # hidden width of each per-pixel network
f_dim = 10       # features extracted per unit

unit_no = 2094   # must match the neural data loaded below
image_no = 2000  # full-size alternative: 9800
p_dim = 730      # full-size alternative: 95*146

epoch_no = 40
batch_size = 64
seed = 0         # fixes batch shuffling and (via torch) model initialisation

batch_no = epoch_no * image_no // batch_size

rng = np.random.default_rng(seed)
torch.manual_seed(seed)

# One independent permutation of the image indices per epoch. The epochs are
# concatenated and cut into consecutive batches, so a batch may straddle an
# epoch boundary. An incomplete trailing batch is dropped instead of making
# the reshape fail when epoch_no*image_no is not a multiple of batch_size.
batch_ids = np.stack([rng.permutation(image_no) for _ in range(epoch_no)])
batch_ids = batch_ids.reshape(-1)[: batch_no * batch_size].reshape((batch_no, batch_size))
batch_ids = torch.from_numpy(batch_ids).to("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Load the pre-computed data for the "i2000 / p730" setting.
# Earlier variants used: yass_lin_50_{train,test}_neural.npy (neural),
# smooth_{train,test}_images.npy and i2000_p730_hp_{train,test}_images.npy
# (images), yass_l1_pixel_units.npy (pixel -> unit map).
data_root = "/ssd/joon/2017_11_29_ns"
neural_dir = os.path.join(data_root, "yass", "neural")
image_dir = os.path.join(data_root, "images")

# Training tensors stay on the CPU; the training loop moves one batch at a
# time to the GPU.
S = torch.from_numpy(np.load(os.path.join(neural_dir, "yass_i2000_50_train_neural.npy")))
I = torch.from_numpy(np.load(os.path.join(image_dir, "i2000_p730_smooth_train_images.npy")))

# Row x holds the unit indices the model reads for pixel x; keep the first k_dim.
pixel_units = np.load(os.path.join(data_root, "yass", "l1_i2000_p730_pixel_units.npy"))[:, :k_dim]
pixel_units = torch.from_numpy(pixel_units).cuda()

test_S = torch.from_numpy(np.load(os.path.join(neural_dir, "yass_i2000_50_test_neural.npy")))
test_I = torch.from_numpy(np.load(os.path.join(image_dir, "i2000_p730_smooth_test_images.npy"))).cuda()

# Fail fast on mismatched files rather than deep inside the training loop.
assert S.shape[0] == I.shape[0] >= image_no, "train neural/image row counts disagree"
assert test_S.shape[0] == test_I.shape[0], "test neural/image row counts disagree"
assert S.shape[1] == test_S.shape[1] == unit_no * t_dim, "neural width != unit_no * t_dim"
assert I.shape[1] == test_I.shape[1] == p_dim, "image width != p_dim"
assert pixel_units.shape[0] >= p_dim and pixel_units.shape[1] == k_dim, "bad pixel_units shape"

In [5]:
# Build the locally connected decoder in float32 on the GPU (the same object
# is returned by each call in the chain).
model = PartNonLinear(unit_no, t_dim, k_dim, h_dim, p_dim, f_dim).float().cuda()

# Mean squared error between decoded and target pixel values.
loss_fn = torch.nn.MSELoss(reduction="mean")

# SGD with momentum and a small L2 penalty.
# Alternatives tried earlier: SGD(lr=0.001, momentum=0.9, weight_decay=5.0e-6),
# Adam(lr=0.0001, weight_decay=5.0e-6), and a MultiStepLR schedule
# (gamma=0.6 at 1/4, 1/2, 3/4 and 7/8 of batch_no) -- none currently enabled.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.0025,
    momentum=0.9,
    weight_decay=5.0e-6,
)

In [6]:
#%%capture output --no-stderr

# Evaluate on the test set and checkpoint every `eval_every` iterations.
# test_corr_array gets exactly one row per evaluation. Previously it was sized
# for one row per 125 iterations while evaluation ran every 25, so 4 of every
# 5 test correlations were silently overwritten.
eval_every = 25

# NOTE(review): the tag says lr 0.001 but the optimizer uses lr=0.0025; the
# tag is kept unchanged so existing checkpoint/result file names stay stable.
run_tag = "PART_LP_f10_h40_0.001_w5e6"

test_corr_array = np.empty((batch_no // eval_every, 2))   # columns: iteration, mean test corr
train_corr_array = np.empty((batch_no, 2))                # columns: iteration, mean batch corr

save_dir = "/ssd/joon/2017_11_29_ns/yass/low_pass/"


def checkpoint_path(iteration):
    """Path of the state_dict checkpoint written after `iteration`."""
    return os.path.join(save_dir, run_tag + "_i" + str(iteration) + "_nn.pt")


for i in tnrange(batch_no):
    start_time = time.time()
    ids = batch_ids[i]

    batch_S = S[ids, :].cuda()
    batch_I = I[ids, :].cuda()

    optimizer.zero_grad()
    dec_I = model(batch_S.float(), pixel_units)
    loss = loss_fn(batch_I, dec_I)

    # Monitoring only: detach so the metric does not extend the autograd graph.
    batch_corr = audtorch.pearsonr(batch_I.T, dec_I.detach().T)
    mean_batch_corr = torch.mean(batch_corr)

    loss.backward()
    optimizer.step()
    # scheduler.step()  # enable together with a MultiStepLR scheduler

    duration = time.time() - start_time

    train_corr_array[i, 0] = i
    train_corr_array[i, 1] = mean_batch_corr.item()

    if i % eval_every == eval_every - 1:
        torch.save(model.state_dict(), checkpoint_path(i))

        # Evaluation needs no gradients; without no_grad the full test-set
        # forward pass built (and held) an autograd graph.
        with torch.no_grad():
            test_dec = model(test_S.float().cuda(), pixel_units)
            test_corr = torch.mean(audtorch.pearsonr(test_I.T, test_dec.T))
        print("Test "+str(i)+": "+str(test_corr.item()))
        test_corr_array[i // eval_every, 0] = i
        test_corr_array[i // eval_every, 1] = test_corr.item()

    print("Iter " +str(i)+ " Batch_Corr: " + str(mean_batch_corr.item()) + " , Time: " +str(duration))

# Final checkpoint, named after the last iteration without relying on the
# leaked loop variable.
torch.save(model.state_dict(), checkpoint_path(batch_no - 1))

np.save(os.path.join(save_dir, run_tag + "_train_corr.npy"), train_corr_array)
np.save(os.path.join(save_dir, run_tag + "_test_corr.npy"), test_corr_array)

HBox(children=(IntProgress(value=0, max=1250), HTML(value='')))

Iter 0 Batch_Corr: -0.03023079481042335 , Time: 3.039952278137207
Iter 1 Batch_Corr: -0.030646174825039552 , Time: 2.6904923915863037
Iter 2 Batch_Corr: -0.01639013971785203 , Time: 2.4923322200775146
Iter 3 Batch_Corr: -0.022741421063888007 , Time: 2.6374802589416504
Iter 4 Batch_Corr: -0.015186197199058615 , Time: 2.8675410747528076
Iter 5 Batch_Corr: -0.022731659766744287 , Time: 2.624483823776245
Iter 6 Batch_Corr: -0.017358298816906485 , Time: 2.4552783966064453
Iter 7 Batch_Corr: -0.014709594750502243 , Time: 2.9126603603363037
Iter 8 Batch_Corr: -0.011417025938048633 , Time: 2.8263344764709473
Iter 9 Batch_Corr: 0.004767617023372542 , Time: 2.7588398456573486
Iter 10 Batch_Corr: -0.022021821761222606 , Time: 2.8985323905944824
Iter 11 Batch_Corr: -0.006329888526960469 , Time: 3.7020883560180664
Iter 12 Batch_Corr: -0.00788298032597906 , Time: 3.355409860610962
Iter 13 Batch_Corr: -0.0070871638649166535 , Time: 3.3649544715881348
Iter 14 Batch_Corr: 0.009175870708005542 , Time: 3

Iter 122 Batch_Corr: 0.9412882474302712 , Time: 3.1010730266571045
Iter 123 Batch_Corr: 0.9299230577898456 , Time: 2.9920263290405273
Test 124: 0.9283089537952309
Iter 124 Batch_Corr: 0.9218109676196342 , Time: 2.7498881816864014
Iter 125 Batch_Corr: 0.9308754019677153 , Time: 2.902625560760498
Iter 126 Batch_Corr: 0.9312107259533748 , Time: 3.215409994125366
Iter 127 Batch_Corr: 0.9178437295226182 , Time: 3.1196746826171875
Iter 128 Batch_Corr: 0.9362423665921134 , Time: 3.4339599609375
Iter 129 Batch_Corr: 0.9131025662596256 , Time: 3.1539273262023926
Iter 130 Batch_Corr: 0.9314380072384635 , Time: 2.6453616619110107
Iter 131 Batch_Corr: 0.9345240910456716 , Time: 2.690437078475952
Iter 132 Batch_Corr: 0.9447231217959877 , Time: 2.6992218494415283
Iter 133 Batch_Corr: 0.9438621448004797 , Time: 2.5508198738098145
Iter 134 Batch_Corr: 0.930945073603101 , Time: 2.5902116298675537
Iter 135 Batch_Corr: 0.9351012392579604 , Time: 2.803119421005249
Iter 136 Batch_Corr: 0.9261781341062616 ,

Iter 244 Batch_Corr: 0.9673706981441503 , Time: 2.8889763355255127
Iter 245 Batch_Corr: 0.9573492581553548 , Time: 2.7091755867004395
Iter 246 Batch_Corr: 0.9591767198124096 , Time: 2.426910877227783
Iter 247 Batch_Corr: 0.9615464754146555 , Time: 3.1115713119506836
Iter 248 Batch_Corr: 0.957155285939688 , Time: 3.211164951324463
Test 249: 0.9592588873079143
Iter 249 Batch_Corr: 0.9568204966238131 , Time: 3.1633429527282715
Iter 250 Batch_Corr: 0.9616502624106054 , Time: 2.710427761077881
Iter 251 Batch_Corr: 0.951123388700837 , Time: 3.312167167663574
Iter 252 Batch_Corr: 0.9635379803447602 , Time: 3.0571846961975098
Iter 253 Batch_Corr: 0.9599377642642646 , Time: 2.5998268127441406
Iter 254 Batch_Corr: 0.9567750722254079 , Time: 3.010876417160034
Iter 255 Batch_Corr: 0.9637145245287901 , Time: 2.8498921394348145
Iter 256 Batch_Corr: 0.9633961505502175 , Time: 2.573072671890259
Iter 257 Batch_Corr: 0.9673266370112616 , Time: 2.6118223667144775
Iter 258 Batch_Corr: 0.9542259276092225 ,

Iter 366 Batch_Corr: 0.973671768200715 , Time: 3.3742668628692627
Iter 367 Batch_Corr: 0.9649848205281242 , Time: 3.6725192070007324
Iter 368 Batch_Corr: 0.9689802603485062 , Time: 3.4500789642333984
Iter 369 Batch_Corr: 0.9632839249323946 , Time: 3.2170612812042236
Iter 370 Batch_Corr: 0.9652288652167017 , Time: 3.0280771255493164
Iter 371 Batch_Corr: 0.9700529173069853 , Time: 2.579479694366455
Iter 372 Batch_Corr: 0.9717812019943971 , Time: 2.7410898208618164
Iter 373 Batch_Corr: 0.9703521571354332 , Time: 2.798611879348755
Test 374: 0.9659245070590392
Iter 374 Batch_Corr: 0.9676602930966445 , Time: 2.8050005435943604
Iter 375 Batch_Corr: 0.9692023585221301 , Time: 2.5675082206726074
Iter 376 Batch_Corr: 0.9691125749464784 , Time: 2.810856342315674
Iter 377 Batch_Corr: 0.9728497881680498 , Time: 2.57262921333313
Iter 378 Batch_Corr: 0.9728296682482692 , Time: 2.565018892288208
Iter 379 Batch_Corr: 0.9705538613806545 , Time: 3.252687931060791
Iter 380 Batch_Corr: 0.9643408063987098 ,

Iter 488 Batch_Corr: 0.9698325310808545 , Time: 3.68986177444458
Iter 489 Batch_Corr: 0.9754712616124932 , Time: 3.507063150405884
Iter 490 Batch_Corr: 0.9706006785657176 , Time: 2.7949225902557373
Iter 491 Batch_Corr: 0.9736987790249693 , Time: 2.8268325328826904
Iter 492 Batch_Corr: 0.9657270955691833 , Time: 3.1559932231903076
Iter 493 Batch_Corr: 0.9689195308499352 , Time: 2.6296892166137695
Iter 494 Batch_Corr: 0.9707773225219585 , Time: 2.715104818344116
Iter 495 Batch_Corr: 0.9788342275809341 , Time: 2.9504783153533936
Iter 496 Batch_Corr: 0.9756420700151205 , Time: 2.5430448055267334
Iter 497 Batch_Corr: 0.9701688380497733 , Time: 2.6002371311187744
Iter 498 Batch_Corr: 0.9717476815204802 , Time: 2.5511152744293213
Test 499: 0.9693310363927434
Iter 499 Batch_Corr: 0.9709296785491406 , Time: 3.311034917831421
Iter 500 Batch_Corr: 0.9676221472055895 , Time: 3.3859851360321045
Iter 501 Batch_Corr: 0.9755283279897115 , Time: 3.2741055488586426
Iter 502 Batch_Corr: 0.966240491388457

Iter 610 Batch_Corr: 0.9773594233853564 , Time: 2.526840925216675
Iter 611 Batch_Corr: 0.9736025461902855 , Time: 2.6588354110717773
Iter 612 Batch_Corr: 0.9729114933697984 , Time: 2.7909739017486572
Iter 613 Batch_Corr: 0.9722982731677072 , Time: 2.6945297718048096
Iter 614 Batch_Corr: 0.9734234231943873 , Time: 3.2385520935058594
Iter 615 Batch_Corr: 0.9725352639885361 , Time: 3.084028720855713
Iter 616 Batch_Corr: 0.9755179344988563 , Time: 3.010833740234375
Iter 617 Batch_Corr: 0.974685109871927 , Time: 2.749748468399048
Iter 618 Batch_Corr: 0.9732484073281722 , Time: 3.158045768737793
Iter 619 Batch_Corr: 0.9755427094447756 , Time: 3.183983325958252
Iter 620 Batch_Corr: 0.9709237229146775 , Time: 3.164275646209717
Iter 621 Batch_Corr: 0.9727008144272847 , Time: 3.4108011722564697
Iter 622 Batch_Corr: 0.9748591209299883 , Time: 2.727170467376709
Iter 623 Batch_Corr: 0.9801408949304157 , Time: 2.674043655395508
Test 624: 0.9712739706525786
Iter 624 Batch_Corr: 0.9768314467567465 , T

Iter 732 Batch_Corr: 0.9754087507749755 , Time: 2.5098633766174316
Iter 733 Batch_Corr: 0.9757751152680378 , Time: 2.5355849266052246
Iter 734 Batch_Corr: 0.97693552913223 , Time: 3.260204315185547
Iter 735 Batch_Corr: 0.974148609623911 , Time: 3.569797992706299
Iter 736 Batch_Corr: 0.9734537364926231 , Time: 3.582897663116455
Iter 737 Batch_Corr: 0.9774571767008194 , Time: 3.2520596981048584
Iter 738 Batch_Corr: 0.9790559690653697 , Time: 2.652418375015259
Iter 739 Batch_Corr: 0.9759077268934262 , Time: 2.7681121826171875
Iter 740 Batch_Corr: 0.9753385535799894 , Time: 2.7715110778808594
Iter 741 Batch_Corr: 0.9779987159578654 , Time: 2.5817158222198486
Iter 742 Batch_Corr: 0.9765073012123049 , Time: 2.59297513961792
Iter 743 Batch_Corr: 0.9788811929256197 , Time: 2.6971261501312256
Iter 744 Batch_Corr: 0.976577051121093 , Time: 2.638035535812378
Iter 745 Batch_Corr: 0.9719817203892342 , Time: 2.607846736907959
Iter 746 Batch_Corr: 0.9744071879001492 , Time: 3.03414249420166
Iter 747 

Iter 854 Batch_Corr: 0.9783296743212238 , Time: 3.3164267539978027
Iter 855 Batch_Corr: 0.9780102062815345 , Time: 3.4426143169403076
Iter 856 Batch_Corr: 0.9777008712021377 , Time: 3.1062049865722656
Iter 857 Batch_Corr: 0.9780790938452496 , Time: 2.7283260822296143
Iter 858 Batch_Corr: 0.977976033340656 , Time: 2.644850730895996
Iter 859 Batch_Corr: 0.9755399347054389 , Time: 2.744342565536499
Iter 860 Batch_Corr: 0.9810560153469928 , Time: 2.6244843006134033
Iter 861 Batch_Corr: 0.9781559538068626 , Time: 2.616537570953369
Iter 862 Batch_Corr: 0.9756878355350698 , Time: 2.5967555046081543
Iter 863 Batch_Corr: 0.9760437615070727 , Time: 2.8427109718322754
Iter 864 Batch_Corr: 0.9770272926259992 , Time: 2.9376590251922607
Iter 865 Batch_Corr: 0.9817831827134903 , Time: 3.2429490089416504
Iter 866 Batch_Corr: 0.9789129057081105 , Time: 3.1393449306488037
Iter 867 Batch_Corr: 0.973438402996196 , Time: 3.2582037448883057
Iter 868 Batch_Corr: 0.9774404621712962 , Time: 3.159252166748047
I

Iter 976 Batch_Corr: 0.9799396782428361 , Time: 2.973196268081665
Iter 977 Batch_Corr: 0.9793367811257139 , Time: 2.7432146072387695
Iter 978 Batch_Corr: 0.9724327038197009 , Time: 2.7913382053375244
Iter 979 Batch_Corr: 0.9780663075957198 , Time: 2.674506187438965
Iter 980 Batch_Corr: 0.9792658967998832 , Time: 2.660825490951538
Iter 981 Batch_Corr: 0.9795397987420846 , Time: 2.58343243598938
Iter 982 Batch_Corr: 0.977209242319289 , Time: 2.8129401206970215
Iter 983 Batch_Corr: 0.9774038355965021 , Time: 2.5864317417144775
Iter 984 Batch_Corr: 0.9774675727498482 , Time: 3.0284171104431152
Iter 985 Batch_Corr: 0.977601459575902 , Time: 3.3596606254577637
Iter 986 Batch_Corr: 0.978055693915797 , Time: 3.2149834632873535
Iter 987 Batch_Corr: 0.9761474934781267 , Time: 3.1625494956970215
Iter 988 Batch_Corr: 0.9795912280124999 , Time: 3.5896382331848145
Iter 989 Batch_Corr: 0.9817709885579536 , Time: 3.2139670848846436
Iter 990 Batch_Corr: 0.9739506536528559 , Time: 2.744127035140991
Iter

Iter 1097 Batch_Corr: 0.9787532226685984 , Time: 2.982649326324463
Iter 1098 Batch_Corr: 0.9802693762755553 , Time: 2.843194007873535
Test 1099: 0.9742070913353197
Iter 1099 Batch_Corr: 0.9799415031740113 , Time: 2.665799617767334
Iter 1100 Batch_Corr: 0.9829731875375274 , Time: 2.8764488697052
Iter 1101 Batch_Corr: 0.9783968800136873 , Time: 2.672072649002075
Iter 1102 Batch_Corr: 0.9759381272581941 , Time: 2.634089708328247
Iter 1103 Batch_Corr: 0.9813827878151223 , Time: 2.528721332550049
Iter 1104 Batch_Corr: 0.978709120014155 , Time: 2.84165620803833
Iter 1105 Batch_Corr: 0.9783181941835664 , Time: 2.871370315551758
Iter 1106 Batch_Corr: 0.9788349758720569 , Time: 3.2075858116149902
Iter 1107 Batch_Corr: 0.9786967751631709 , Time: 3.652858257293701
Iter 1108 Batch_Corr: 0.9780191502606727 , Time: 3.454601526260376
Iter 1109 Batch_Corr: 0.9787821102384073 , Time: 2.816890001296997
Iter 1110 Batch_Corr: 0.9776346768498504 , Time: 2.759481430053711
Iter 1111 Batch_Corr: 0.97804547186

Iter 1217 Batch_Corr: 0.9804338895122842 , Time: 2.726539373397827
Iter 1218 Batch_Corr: 0.9776924839169384 , Time: 2.7100107669830322
Iter 1219 Batch_Corr: 0.9774841553299918 , Time: 2.743502378463745
Iter 1220 Batch_Corr: 0.9781315221108937 , Time: 2.6828200817108154
Iter 1221 Batch_Corr: 0.978175676167089 , Time: 2.522678852081299
Iter 1222 Batch_Corr: 0.9775777406916387 , Time: 2.7335312366485596
Iter 1223 Batch_Corr: 0.9749754838373942 , Time: 2.662855863571167
Test 1224: 0.9745457332829024
Iter 1224 Batch_Corr: 0.9810629421611281 , Time: 2.60384202003479
Iter 1225 Batch_Corr: 0.9798457978296534 , Time: 3.325079917907715
Iter 1226 Batch_Corr: 0.9777278474597089 , Time: 3.078739881515503
Iter 1227 Batch_Corr: 0.9781047770503625 , Time: 3.50791335105896
Iter 1228 Batch_Corr: 0.9830090939588608 , Time: 3.455166816711426
Iter 1229 Batch_Corr: 0.978732249704813 , Time: 3.002528429031372
Iter 1230 Batch_Corr: 0.9829795755920758 , Time: 2.766812324523926
Iter 1231 Batch_Corr: 0.981698737