In [1]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tnrange
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tnrange
import matplotlib.pyplot as plt

In [2]:
class TCN(nn.Module):
    """Stack of 1-D convolutions turning two summed input traces into per-sample probabilities.

    Layer layout (for ``L_no >= 2``):
      * first layer: ``Conv1d(2 -> S_no, groups=2)`` followed by ``LeakyReLU``,
        so each of the two input channels feeds its own ``S_no / 2`` filters;
      * middle layers: depthwise ``Conv1d(S_no -> S_no, groups=S_no)`` followed
        by ``LeakyReLU``;
      * last layer: ``Conv1d(S_no -> 1)`` with no activation (the sigmoid is
        applied in ``forward``).
    With ``L_no == 1`` the network is a single ``Conv1d(2 -> 1)``.

    Args:
        L_no: number of convolution layers; must be at least 1.
        T_no: kernel size of every convolution. The padding used is
            ``(T_no - 1) // 2`` on both sides, which keeps the sequence length
            only for odd ``T_no``; an even ``T_no`` shortens the sequence by
            one sample per layer.
        S_no: channel count of the hidden layers. Must be even whenever the
            grouped first layer is built (``groups=2``), otherwise ``Conv1d``
            rejects it.

    Raises:
        ValueError: if ``L_no`` is smaller than 1.
    """

    def __init__(self, L_no, T_no, S_no):
        super().__init__()

        if L_no < 1:
            raise ValueError("L_no must be at least 1")

        self.L_no = L_no
        self.T_no = T_no
        self.S_no = S_no

        # Symmetric zero padding shared by every layer.
        padding = (self.T_no - 1) // 2

        modules = []
        for i in range(L_no):
            if i == L_no - 1:
                # Output layer. Tested before the "first layer" case so that a
                # one-layer network still ends in a single output channel
                # instead of emitting S_no channels.
                modules.append(nn.Conv1d(in_channels=2 if i == 0 else self.S_no,
                                         out_channels=1,
                                         kernel_size=self.T_no,
                                         padding=padding))

            elif i == 0:
                # Input layer: the two input channels are filtered separately.
                modules.append(nn.Conv1d(in_channels=2,
                                         out_channels=self.S_no,
                                         kernel_size=self.T_no,
                                         padding=padding,
                                         groups=2))
                modules.append(nn.LeakyReLU())

            else:
                # Hidden layer: one filter per channel, no channel mixing.
                modules.append(nn.Conv1d(in_channels=self.S_no,
                                         out_channels=self.S_no,
                                         kernel_size=self.T_no,
                                         padding=padding,
                                         groups=self.S_no))
                modules.append(nn.LeakyReLU())

        self.sequential = nn.Sequential(*modules)

    def forward(self, S_e, S_i):
        """Run the network on one sequence.

        Args:
            S_e: tensor that is summed over dim 1 to form the first input
                channel (assumed 2-D, time along dim 0 -- TODO confirm).
            S_i: tensor that is summed over dim 1 to form the second input
                channel; must yield as many samples as ``S_e``.

        Returns:
            1-D tensor of sigmoid outputs, one value per output sample.
        """
        S_e = torch.sum(S_e, 1)
        S_i = torch.sum(S_i, 1)

        # Two column vectors side by side, then transposed and given a batch
        # dimension: the result has shape (1, 2, n_samples) as Conv1d expects.
        S = torch.hstack((S_e.reshape(-1, 1), S_i.reshape(-1, 1)))
        S = S.T.unsqueeze(0)

        out = self.sequential(S)
        return torch.sigmoid(out.flatten())
        

In [3]:
# --- Network shape (passed to TCN) ---
L_no = 4      # number of Conv1d layers
T_no = 151    # kernel size of every Conv1d
S_no = 6      # channel count of the hidden Conv1d layers

# --- Optimisation schedule ---
iter_no = 10000       # number of gradient steps taken by the training loop
batch_size = 50000    # number of consecutive samples in one training window
epoch_no = 15         # number of shuffled passes over the window start indices

# --- Split sizes, in samples along the first axis of the data ---
T_train = 5 * 60 * 1000
T_test = 5 * 10 * 1000

In [4]:
# Location of the recorded arrays: <base_dir><cell_type>_<experiment>/data/
base_dir = "/media/hdd01/sklee/"
experiment = "clust4-60"
cell_type = "CA1"

E_neural_file = "Espikes_neural.npy"
I_neural_file = "Ispikes_neural.npy"
Z_file = "spk_loc.npy"

data_dir = base_dir + cell_type + "_" + experiment + "/data/"

# Load the three arrays and wrap them as (CPU) tensors; Z is flattened to 1-D.
E_neural = torch.from_numpy(np.load(data_dir + E_neural_file))
I_neural = torch.from_numpy(np.load(data_dir + I_neural_file))
Z = torch.from_numpy(np.load(data_dir + Z_file).flatten())

# Per-sample loss weights: 1 everywhere, 2 at samples where Z == 1.
# Done as a single indexed update instead of one GPU scalar write per spike.
loss_weights = torch.ones(Z.shape[0]).cuda()
Z_idx = torch.where(Z == 1)[0]
loss_weights[Z_idx.cuda()] *= 2
        

In [5]:
# Contiguous split along the first axis: the first T_train samples are used
# for training, the following T_test samples for testing.
train_slice = slice(None, T_train)
test_slice = slice(T_train, T_train + T_test)

Z_train = Z[train_slice].cuda().float()
Z_test = Z[test_slice].cuda().float()
loss_weights_train = loss_weights[train_slice].float()
loss_weights_test = loss_weights[test_slice].float()
train_E_neural = E_neural[train_slice].float().cuda()
train_I_neural = I_neural[train_slice].float().cuda()
test_E_neural = E_neural[test_slice].float().cuda()
test_I_neural = I_neural[test_slice].float().cuda()

# Every valid start position of a batch_size-long training window.
n_starts = T_train - batch_size
batch_no = n_starts * epoch_no

# One independently shuffled ordering of the start positions per epoch,
# concatenated into a single flat schedule (stored as float64, cast on use).
train_idx = np.empty((epoch_no, n_starts))
for epoch in range(epoch_no):
    order = np.arange(n_starts)
    np.random.shuffle(order)
    train_idx[epoch] = order
train_idx = torch.from_numpy(train_idx.flatten())

In [6]:
# Build the network, then move it to the GPU in single precision
# (both calls modify the module in place).
model = TCN(L_no=L_no, T_no=T_no, S_no=S_no)
model.cuda().float()

# Adam over all model parameters as a single parameter group.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# reduction="none" keeps one loss value per sample so that the per-sample
# weights can be applied before averaging.
bce_criterion = nn.BCELoss(reduction="none")

# Number of trainable parameters.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_trainable)

3643


In [None]:
# A predicted spike counts as GOOD when a true test spike lies within
# +/- hit_window samples of it, otherwise as BAD.
hit_window = 15

# Z_test never changes, so the "is there a true spike within hit_window
# samples?" mask is computed once: a stride-1 max-pool of width
# 2 * hit_window + 1 dilates every true spike over its neighbourhood.
with torch.no_grad():
    near_spike = F.max_pool1d((Z_test == 1).float().reshape(1, 1, -1),
                              kernel_size=2 * hit_window + 1,
                              stride=1,
                              padding=hit_window).flatten() > 0

# NOTE(review): the recorded output suggests tnrange emits a deprecation
# warning; tqdm.auto.trange is presumably the replacement -- confirm.
for i in tnrange(iter_no):
    # ---- one optimisation step on a random contiguous training window ----
    model.train()
    optimizer.zero_grad()

    # train_idx is stored as float64; int() recovers the window start index.
    batch_idx = int(train_idx[i])
    window = slice(batch_idx, batch_idx + batch_size)
    batch_E_neural = train_E_neural[window]
    batch_I_neural = train_I_neural[window]
    batch_Z = Z_train[window]
    batch_loss_weights = loss_weights_train[window]

    Z_pred = model(batch_E_neural, batch_I_neural)

    # Weighted binary cross-entropy, averaged over the window.
    bce_loss = torch.mean(bce_criterion(Z_pred, batch_Z) * batch_loss_weights)
    loss = bce_loss

    loss.backward()
    optimizer.step()

    # ---- periodic evaluation on the held-out segment ----
    if i % 50 == 0:
        model.eval()
        # No gradients are needed here; skipping the autograd graph avoids
        # holding activations for the whole test sequence.
        with torch.no_grad():
            Z_pred = model(test_E_neural, test_I_neural)
            bce_loss = torch.mean(bce_criterion(Z_pred, Z_test) * loss_weights_test)

            # Threshold the probabilities at 0.5.
            Z_pred_disc = (Z_pred > 0.5).to(Z_pred.dtype)
            predicted = Z_pred_disc == 1

            good_no = int((predicted & near_spike).sum())
            bad_no = int((predicted & ~near_spike).sum())

        print(i, "BCE: ", bce_loss.item(), "GOOD: ", good_no, "BAD: ", bad_no)

  for i in tnrange(iter_no):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10000.0), HTML(value='')))

0 BCE:  0.6907568573951721 GOOD:  692 BAD:  18729
50 BCE:  0.03519582748413086 GOOD:  0 BAD:  0
100 BCE:  0.02358625829219818 GOOD:  0 BAD:  0
150 BCE:  0.023114724084734917 GOOD:  0 BAD:  0
200 BCE:  0.021646102890372276 GOOD:  0 BAD:  0
250 BCE:  0.017256680876016617 GOOD:  0 BAD:  0
300 BCE:  0.016147589311003685 GOOD:  0 BAD:  0
350 BCE:  0.01641637086868286 GOOD:  0 BAD:  0
400 BCE:  0.016081979498267174 GOOD:  0 BAD:  0
450 BCE:  0.016022788360714912 GOOD:  0 BAD:  0
500 BCE:  0.01623048633337021 GOOD:  0 BAD:  0
550 BCE:  0.015951091423630714 GOOD:  0 BAD:  0
600 BCE:  0.016077497974038124 GOOD:  0 BAD:  0
650 BCE:  0.01587727665901184 GOOD:  0 BAD:  0
700 BCE:  0.015911336988210678 GOOD:  0 BAD:  0
750 BCE:  0.01582174561917782 GOOD:  0 BAD:  0
800 BCE:  0.015972783789038658 GOOD:  0 BAD:  0
850 BCE:  0.015953581780195236 GOOD:  0 BAD:  0
900 BCE:  0.015752548351883888 GOOD:  0 BAD:  0
950 BCE:  0.01602337695658207 GOOD:  0 BAD:  0
1000 BCE:  0.015776189044117928 GOOD:  0 BAD: 

In [None]:
# Kernel of the first convolution layer: output channel 0, input channel 0.
first_kernel = model.sequential[0].weight[0, 0, :]
plt.plot(first_kernel.detach().cpu().numpy())

In [11]:
# Conv1d weights are laid out as (out_channels, in_channels // groups, kernel_size).
# NOTE(review): the recorded output below ([5, 2, 151]) does not match the TCN
# defined in this notebook (S_no = 6, groups=2 gives [6, 1, 151]); it looks
# like it came from an earlier model -- re-run on a fresh kernel to confirm.
print(model.sequential[0].weight.size())

torch.Size([5, 2, 151])
