In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import audtorch.metrics.functional as audtorch

import time
from tqdm import tnrange, tqdm_notebook
import os
import matplotlib.pyplot as plt

SoX could not be found!

    If you do not have SoX, proceed here:
     - - - http://sox.sourceforge.net/ - - -

    If you do (or think that you should) have SoX, double-check your
    path variables.
    


In [2]:
class FullNonLinear(nn.Module):
    """Fully connected decoder from flattened neural responses to pixel values.

    Architecture: ``Linear(unit_no * t_dim, h_dim) + Tanh``, three
    ``Linear(h_dim, h_dim) + Tanh`` hidden layers, then a ``Linear(h_dim, p_dim)``
    read-out with no output non-linearity.

    Args:
        unit_no: number of units; together with ``t_dim`` sets the input width.
        t_dim: input columns per unit; the input width is ``unit_no * t_dim``.
        h_dim: width of every hidden layer.
        p_dim: number of outputs (pixels).
        device: where the parameters are placed. Defaults to ``"cuda"``, which
            matches the original hard-coded ``.cuda()`` calls; pass ``"cpu"`` to
            build the model on a machine without a GPU.
    """

    def __init__(self, unit_no, t_dim, h_dim, p_dim, device="cuda"):
        super(FullNonLinear, self).__init__()
        self.unit_no = unit_no
        self.t_dim = t_dim
        self.h_dim = h_dim
        self.p_dim = p_dim

        # Attribute names are unchanged so existing state_dict checkpoints load.
        self.input_layer = nn.Sequential(
            nn.Linear(self.unit_no * self.t_dim, self.h_dim),
            nn.Tanh())
        self.hidden1 = nn.Sequential(nn.Linear(self.h_dim, self.h_dim), nn.Tanh())
        self.hidden2 = nn.Sequential(nn.Linear(self.h_dim, self.h_dim), nn.Tanh())
        self.hidden3 = nn.Sequential(nn.Linear(self.h_dim, self.h_dim), nn.Tanh())
        self.output_layer = nn.Linear(self.h_dim, self.p_dim)
        # Move every parameter at once instead of calling .cuda() per layer.
        self.to(device)

    def forward(self, S):
        """Decode a batch.

        Args:
            S: tensor whose last dimension is ``unit_no * t_dim``, on the same
                device as the model.

        Returns:
            Tensor with the same leading dimensions as ``S`` and a last
            dimension of ``p_dim``.
        """
        out = self.input_layer(S)
        for layer in (self.hidden1, self.hidden2, self.hidden3):
            out = layer(out)
        return self.output_layer(out)

In [2]:
class PartNonLinear(nn.Module):
    """Decoder in which every pixel reads only from a few assigned units.

    Each unit ``n`` has its own ``Linear(t_dim, f_dim)`` featurizer, applied to
    columns ``[n * t_dim, (n + 1) * t_dim)`` of the input. Each pixel ``x`` then
    concatenates the features of the first ``k_dim`` units listed in
    ``pix_units[x]`` and maps them through its own
    ``Linear(k_dim * f_dim, h_dim) -> PReLU -> Linear(h_dim, 1)`` head.

    Args:
        unit_no: number of units (one featurizer each).
        t_dim: input columns per unit.
        k_dim: number of units each pixel reads from.
        h_dim: hidden width of each pixel head.
        p_dim: number of output pixels (one head each).
        f_dim: features produced per unit.
        device: where the parameters are placed. Defaults to ``"cuda"``, which
            matches the original hard-coded ``.cuda()`` calls; pass ``"cpu"`` to
            build the model without a GPU.
    """

    def __init__(self, unit_no, t_dim, k_dim, h_dim, p_dim, f_dim, device="cuda"):
        super(PartNonLinear, self).__init__()
        self.unit_no = unit_no
        self.t_dim = t_dim
        self.h_dim = h_dim
        self.p_dim = p_dim
        self.k_dim = k_dim
        self.f_dim = f_dim

        # Attribute names are unchanged so existing state_dict checkpoints load.
        self.featurize = nn.ModuleList(
            [nn.Linear(self.t_dim, self.f_dim) for _ in range(self.unit_no)])
        self.hidden1 = nn.ModuleList(
            [nn.Linear(self.k_dim * self.f_dim, self.h_dim) for _ in range(self.p_dim)])
        self.hidden1_act = nn.ModuleList([nn.PReLU() for _ in range(self.p_dim)])
        self.output_layer = nn.ModuleList(
            [nn.Linear(self.h_dim, 1) for _ in range(self.p_dim)])
        self.to(device)

    def _feature_index(self, pix_units, device):
        """Return a ``(p_dim, k_dim * f_dim)`` long tensor of feature-column indices.

        Row ``x`` holds, for each of the first ``k_dim`` units ``u`` in
        ``pix_units[x]``, the ``f_dim`` consecutive columns starting at
        ``u * f_dim``, i.e. where unit ``u``'s features sit in the concatenated
        feature matrix built by ``forward``.

        Raises:
            IndexError: if ``pix_units`` is not 2-D or has fewer than ``p_dim``
                rows or fewer than ``k_dim`` columns (the original per-element
                loop also raised IndexError in those cases).
        """
        units = torch.as_tensor(pix_units)
        if units.dim() != 2 or units.shape[0] < self.p_dim or units.shape[1] < self.k_dim:
            raise IndexError(
                "pix_units must be 2-D with at least p_dim={} rows and k_dim={} "
                "columns, got shape {}".format(self.p_dim, self.k_dim, tuple(units.shape)))
        units = units[: self.p_dim, : self.k_dim].to(device=device, dtype=torch.long)
        offsets = torch.arange(self.f_dim, device=device)
        # (p_dim, k_dim, 1) * f_dim + (f_dim,) -> (p_dim, k_dim, f_dim); flattening
        # the last two axes keeps the unit-major order the original loop produced.
        return (units.unsqueeze(-1) * self.f_dim + offsets).reshape(
            self.p_dim, self.k_dim * self.f_dim)

    def forward(self, S, pix_units):
        """Decode one batch.

        Args:
            S: 2-D tensor; columns ``[n * t_dim, (n + 1) * t_dim)`` feed unit
                ``n``'s featurizer. Must be on the same device as the model.
            pix_units: 2-D integer tensor or array whose row ``x`` lists the
                units pixel ``x`` reads from; only the first ``p_dim`` rows and
                the first ``k_dim`` columns are used.

        Returns:
            Tensor of shape ``(S.shape[0], p_dim)``.
        """
        # (batch, unit_no * f_dim): unit n's features occupy columns
        # [n * f_dim, (n + 1) * f_dim). torch.cat replaces filling a preallocated
        # CUDA buffer, so the result follows S's device and dtype.
        feats = torch.cat(
            [layer(S[:, n * self.t_dim:(n + 1) * self.t_dim])
             for n, layer in enumerate(self.featurize)],
            dim=1)
        # Computed once per call, vectorized. The original rebuilt each pixel's
        # index list element by element, with host/device transfers.
        feat_index = self._feature_index(pix_units, feats.device)

        pixels = []
        for x in range(self.p_dim):
            pix_feat = self.hidden1_act[x](self.hidden1[x](feats[:, feat_index[x]]))
            pixels.append(self.output_layer[x](pix_feat).reshape(-1))
        return torch.stack(pixels, dim=1)

In [3]:
# Model dimensions.
t_dim = 50          # input columns per unit (presumably time bins -- confirm)
k_dim = 10          # units each pixel reads from
h_dim = 40          # hidden width of each pixel head
f_dim = 10          # features per unit

unit_no = 2094
image_no = 9800
#image_no = 2000
p_dim = 95*146      # number of output pixels (presumably a 95 x 146 image)
#p_dim = 730

# Training schedule.
epoch_no = 16
batch_size = 64
batch_no = epoch_no * image_no // batch_size
SEED = 0            # makes the batch order reproducible across runs

# Draw one independent permutation of all images per epoch, concatenate the
# permutations, and cut them into `batch_no` batches of `batch_size`. Trailing
# indices that do not fill a whole batch are dropped; the previous
# reshape raised instead.
rng = np.random.default_rng(SEED)
batch_ids = np.concatenate([rng.permutation(image_no) for _ in range(epoch_no)])
batch_ids = batch_ids[: batch_no * batch_size].reshape((batch_no, batch_size))
# Kept on the CPU: these indices select rows of the CPU tensors `S` and `I`
# before each batch is moved to the GPU. Recent PyTorch rejects CUDA indices
# into a CPU tensor, and they would force a device sync on every batch.
batch_ids = torch.from_numpy(batch_ids)

In [4]:
# All inputs live under one experiment directory; change DATA_DIR to run elsewhere.
DATA_DIR = "/ssd/joon/2017_11_29_ns"
NEURAL_DIR = os.path.join(DATA_DIR, "yass", "neural")
IMAGE_DIR = os.path.join(DATA_DIR, "images")

# Training set. An alternative smaller set was used before:
#   neural: yass_i2000_50_train_neural.npy   images: i2000_p730_hp_train_images.npy
S = torch.from_numpy(np.load(os.path.join(NEURAL_DIR, "yass_lin_50_train_neural.npy")))
I = torch.from_numpy(np.load(os.path.join(IMAGE_DIR, "smooth_train_images.npy")))

# Per-pixel unit assignments; only the first k_dim units of each pixel are kept.
pixel_units = np.load(os.path.join(DATA_DIR, "yass", "yass_l1_pixel_units.npy"))[:, :k_dim]
pixel_units = torch.from_numpy(pixel_units).cuda()

# Held-out set: test_S stays on the CPU here, test_I is moved to the GPU.
test_S = torch.from_numpy(np.load(os.path.join(NEURAL_DIR, "yass_lin_50_test_neural.npy")))
test_I = torch.from_numpy(np.load(os.path.join(IMAGE_DIR, "smooth_test_images.npy"))).cuda()

# Invariants the model and training loop rely on.
assert S.shape[0] >= image_no and I.shape[0] >= image_no, "batch_ids index up to image_no - 1"
assert S.shape[1] >= unit_no * t_dim, "each unit's featurizer reads t_dim columns of S"
assert I.shape[1] == p_dim and test_I.shape[1] == p_dim, "targets must have p_dim columns"
assert test_S.shape[0] == test_I.shape[0], "test inputs and targets must pair up"
assert pixel_units.shape[0] >= p_dim and pixel_units.shape[1] == k_dim

In [5]:
# Build the decoder, force float32 parameters, and make sure it sits on the GPU.
# Module.float() and Module.cuda() both return the module itself.
model = PartNonLinear(unit_no, t_dim, k_dim, h_dim, p_dim, f_dim).float().cuda()

# Mean squared error averaged over every element of the (batch, pixel) output.
loss_fn = torch.nn.MSELoss(reduction="mean")

# SGD with momentum and a small L2 weight decay. Alternatives tried earlier and
# left disabled: Adam(lr=0.001, weight_decay=5.0e-6), and a MultiStepLR schedule
# multiplying the learning rate by 0.6 at 1/4, 1/2, 3/4 and 7/8 of `batch_no`.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=5.0e-6,
)

In [None]:
# Evaluate on the held-out set and checkpoint every EVAL_EVERY iterations.
EVAL_EVERY = 50
# Prefix shared by every file this run writes.
# NOTE(review): the tag says "FULL" and "0.01", but the model is PartNonLinear
# trained with lr=0.005. It is kept byte-identical so downstream readers still
# find the files; rename it deliberately.
RUN_TAG = "FULL_LP_f10_h40_0.01_w5e6"
save_dir = "/ssd/joon/2017_11_29_ns/yass/low_pass/"


def evaluate_test_corr(model, test_S, test_I, pixel_units):
    """Return the mean audtorch Pearson correlation of the model on the test set.

    Correlations are taken between ``test_I.T`` and the decoded output ``.T``
    (presumably one row per pixel, correlated across test images) and then
    averaged. Runs under ``torch.no_grad()`` so no autograd graph is kept for
    the whole test set.
    """
    with torch.no_grad():
        test_dec = model(test_S.float().cuda(), pixel_units)
        return torch.mean(audtorch.pearsonr(test_I.T, test_dec.T)).item()


# One row per evaluation / iteration: (iteration, mean correlation).
# NaN-filled rather than np.empty, so any row that is never written is obvious.
test_corr_array = np.full((batch_no // EVAL_EVERY, 2), np.nan)
train_corr_array = np.full((batch_no, 2), np.nan)

for i in tnrange(batch_no):
    start_time = time.perf_counter()
    # .cpu() makes indexing of the CPU tensors S / I valid wherever batch_ids lives.
    ids = batch_ids[i].cpu()

    batch_S = S[ids, :].cuda().float()
    batch_I = I[ids, :].cuda().float()

    optimizer.zero_grad()
    dec_I = model(batch_S, pixel_units)
    loss = loss_fn(batch_I, dec_I)
    loss.backward()
    optimizer.step()

    # Monitoring only: keep the correlation out of the autograd graph.
    with torch.no_grad():
        mean_batch_corr = torch.mean(audtorch.pearsonr(batch_I.T, dec_I.T)).item()

    duration = time.perf_counter() - start_time
    train_corr_array[i] = (i, mean_batch_corr)

    if i % EVAL_EVERY == EVAL_EVERY - 1:
        torch.save(model.state_dict(),
                   os.path.join(save_dir, RUN_TAG + "_i" + str(i) + "_nn.pt"))

        test_corr = evaluate_test_corr(model, test_S, test_I, pixel_units)
        print("Test " + str(i) + ": " + str(test_corr))
        # One slot per evaluation. Indexing with i // 125 overwrote earlier
        # results and ran past the end of the array at i = 2399.
        test_corr_array[i // EVAL_EVERY] = (i, test_corr)

    print("Iter " + str(i) + " Batch_Corr: " + str(mean_batch_corr) + " , Time: " + str(duration))

# Final checkpoint, named after the last iteration (not the leaked loop variable).
torch.save(model.state_dict(),
           os.path.join(save_dir, RUN_TAG + "_i" + str(batch_no - 1) + "_nn.pt"))

np.save(os.path.join(save_dir, RUN_TAG + "_train_corr.npy"), train_corr_array)
np.save(os.path.join(save_dir, RUN_TAG + "_test_corr.npy"), test_corr_array)

HBox(children=(IntProgress(value=0, max=2450), HTML(value='')))

Iter 0 Batch_Corr: -0.0028036960501038456 , Time: 38.523348569869995
Iter 1 Batch_Corr: -0.003907877821444815 , Time: 37.92855882644653
Iter 2 Batch_Corr: -0.004114058391939924 , Time: 39.94805574417114
Iter 3 Batch_Corr: -0.0032817184931996377 , Time: 38.26425838470459
Iter 4 Batch_Corr: -0.005554183451873845 , Time: 39.878594398498535
Iter 5 Batch_Corr: -0.002499122022147872 , Time: 39.57295823097229
Iter 6 Batch_Corr: -0.0032812531503774145 , Time: 39.933507442474365
Iter 7 Batch_Corr: -0.0028352754474645922 , Time: 40.24218964576721
Iter 8 Batch_Corr: -0.0033364332799539803 , Time: 40.56747055053711
Iter 9 Batch_Corr: -0.0014747615344647448 , Time: 38.6614043712616
Iter 10 Batch_Corr: -0.0017777980974500996 , Time: 40.68708515167236
Iter 11 Batch_Corr: -0.002602323429681393 , Time: 39.462889194488525
Iter 12 Batch_Corr: -0.00402081087921542 , Time: 40.382874965667725
Iter 13 Batch_Corr: -0.0009213680007150545 , Time: 38.764851808547974
Iter 14 Batch_Corr: -0.00046821303339890125 , 

Iter 123 Batch_Corr: 0.3178720566741881 , Time: 28.5900559425354
Iter 124 Batch_Corr: 0.18702737573490905 , Time: 27.365731716156006
Iter 125 Batch_Corr: 0.2570167250226839 , Time: 28.361891984939575
Iter 126 Batch_Corr: 0.2563946692383081 , Time: 27.24531054496765
Iter 127 Batch_Corr: 0.3392872092679133 , Time: 27.32044005393982
Iter 128 Batch_Corr: 0.15718789947577802 , Time: 28.197855234146118
Iter 129 Batch_Corr: 0.23663890186255726 , Time: 27.105314016342163
Iter 130 Batch_Corr: 0.3102343447079436 , Time: 29.16761803627014
Iter 131 Batch_Corr: 0.12677488268267142 , Time: 27.29450011253357
Iter 132 Batch_Corr: 0.2601857584410556 , Time: 27.221988439559937
Iter 133 Batch_Corr: 0.25842020791603076 , Time: 28.716936349868774
Iter 134 Batch_Corr: 0.20451128346929268 , Time: 26.972769260406494
Iter 135 Batch_Corr: 0.3258598219143477 , Time: 28.585195779800415
Iter 136 Batch_Corr: 0.18729253691330913 , Time: 27.44684076309204
