In [1]:
import numpy as np
import torch 
import torch.nn as nn
from tqdm import tnrange
import torch.optim as optim
import torch.nn.functional as F
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import explained_variance_score
import scipy
import time

In [2]:
class Switching(nn.Module):
    """Recurrent model whose subunits switch between discrete states.

    Each synaptic input is scaled by a learned gain (``spike``) and pooled into
    ``N_no`` subunits through ``C_syn``. Every subunit carries a one-of-``K_no``
    state ``S`` and an ``H_no``-dimensional hidden vector ``Z``. The state picks
    (or, when soft, mixes) the per-state weights that produce the next-state
    logits and the hidden update. The output is a linear readout of all
    subunits' hidden vectors.

    ``train_forward`` drives the hidden dynamics with soft states from a
    bidirectional-GRU encoder and also reports the hard states the model would
    have chosen itself. ``test_forward`` runs autonomously on its own argmax
    states.

    Args:
        sub_no: number of subunits (``N_no``).
        state_no: number of discrete states per subunit (``K_no``).
        hid_no: hidden size per subunit (``H_no``).
        C_syn: ``(N_no, in_no)`` matrix pooling synapses into subunits.
        device: device on which per-call state tensors are allocated.
    """

    def __init__(self, sub_no, state_no, hid_no, C_syn, device):
        super().__init__()
        
        self.N_no = sub_no
        self.K_no = state_no
        self.H_no = hid_no
        self.device = device
        # Plain attribute, not a registered buffer: Module.to() does not move it,
        # so it must already live on the same device as the inputs.
        self.C_syn = C_syn
        self.in_no = C_syn.shape[1]
        
        # Bidirectional GRU (10 units per direction -> 20 features), mapped to
        # K_no state logits for each of the N_no subunits at every time step.
        self.encoder = nn.GRU(self.N_no, 10, num_layers=1, batch_first=True, bidirectional=True)
        self.encoder_lin = nn.Linear(20, self.N_no*self.K_no)
        # Learned per-synapse input gain.
        self.spike = nn.Parameter(torch.ones(self.in_no), requires_grad=True)
        
        # Next-state logits, indexed [subunit, previous state, next state, (hidden)].
        self.W_sz = nn.Parameter(torch.randn(self.N_no, self.K_no, self.K_no, self.H_no)*0.1, requires_grad=True)
        self.W_sx = nn.Parameter(torch.randn(self.N_no, self.K_no, self.K_no)*0.1, requires_grad=True)
        self.b_s = nn.Parameter(torch.randn(self.N_no, self.K_no, self.K_no)*0.1, requires_grad=True)
        
        # Hidden update, indexed [subunit, current state, hidden out, (hidden in)].
        self.W_zz = nn.Parameter(torch.randn(self.N_no, self.K_no, self.H_no, self.H_no)*0.1, requires_grad=True)
        self.W_zx = nn.Parameter(torch.randn(self.N_no, self.K_no, self.H_no)*0.1, requires_grad=True)
        self.b_z = nn.Parameter(torch.randn(self.N_no, self.K_no, self.H_no)*0.1, requires_grad=True)
        
        # Linear readout over the concatenated hidden vectors of all subunits.
        self.W_yz = nn.Parameter(torch.randn(self.N_no * self.H_no)*0.1, requires_grad=True)
        self.b_y = nn.Parameter(torch.randn(1)*0.1, requires_grad=True)

    def _subunit_inputs(self, X):
        """Scale synapses by ``spike`` and pool them: (batch, T, in_no) -> (batch, T, N_no)."""
        X_scaled = X * self.spike.reshape(1, 1, -1)
        return torch.matmul(X_scaled, self.C_syn.T)

    def _initial_state(self, batch_size, dtype):
        """Initial condition for every subunit: state 0 (one-hot) and Z = 0."""
        S0 = torch.zeros(batch_size, self.N_no, self.K_no, dtype=dtype, device=self.device)
        S0[:, :, 0] = 1
        Z0 = torch.zeros(batch_size, self.N_no, self.H_no, dtype=dtype, device=self.device)
        return S0, Z0

    def _state_logits(self, prev_S, prev_Z, curr_X):
        """Next-state logits (batch, N_no, K_no) for all subunits at once.

        prev_S: (batch, N_no, K_no) state weights at t-1 (one-hot or soft).
        prev_Z: (batch, N_no, H_no) hidden vectors at t-1.
        curr_X: (batch, N_no) pooled input at t.
        """
        # Mix the per-previous-state weights by prev_S (sum over k).
        W_sz_part = torch.einsum('bnk,nkjh->bnjh', prev_S, self.W_sz)  # (batch, N_no, K_no, H_no)
        W_sx_part = torch.einsum('bnk,nkj->bnj', prev_S, self.W_sx)    # (batch, N_no, K_no)
        b_s_part = torch.einsum('bnk,nkj->bnj', prev_S, self.b_s)      # (batch, N_no, K_no)
        return (torch.einsum('bnjh,bnh->bnj', W_sz_part, prev_Z)
                + W_sx_part * curr_X.unsqueeze(2)
                + b_s_part)

    def _hidden_update(self, curr_S, prev_Z, curr_X):
        """New hidden vectors tanh(W_zz Z + W_zx x + b_z), with weights mixed by curr_S.

        curr_S: (batch, N_no, K_no); prev_Z: (batch, N_no, H_no); curr_X: (batch, N_no).
        Returns (batch, N_no, H_no).
        """
        W_zz_part = torch.einsum('bnk,nkhg->bnhg', curr_S, self.W_zz)  # (batch, N_no, H_no, H_no)
        W_zx_part = torch.einsum('bnk,nkh->bnh', curr_S, self.W_zx)    # (batch, N_no, H_no)
        b_z_part = torch.einsum('bnk,nkh->bnh', curr_S, self.b_z)      # (batch, N_no, H_no)
        Z_in = (torch.einsum('bnhg,bng->bnh', W_zz_part, prev_Z)
                + W_zx_part * curr_X.unsqueeze(2)
                + b_z_part)
        return torch.tanh(Z_in)

    def _hard_state(self, S_in):
        """One-hot encoding of the argmax over the last (state) axis; not differentiable."""
        return F.one_hot(torch.argmax(S_in, -1), self.K_no).to(S_in.dtype)

    def _readout(self, Z):
        """Linear readout: (batch, T, N_no, H_no) -> (batch, T)."""
        return torch.matmul(Z.flatten(2), self.W_yz) + self.b_y

    def train_forward(self, X, temp):
        """Teacher-forced pass: hidden dynamics driven by the encoder's soft states.

        Args:
            X: (batch_size, T_data, in_no) synaptic input.
            temp: softmax temperature applied to the encoder's state logits.

        Returns:
            Y_out: (batch_size, T_data) predicted output.
            Z_decode: (batch_size, T_data, N_no, H_no) hidden vectors.
            S_decode: (batch_size, T_data, N_no, K_no) one-hot states the model
                predicts from the previous encoder state (no gradient).
            S_encode: (batch_size, T_data, N_no, K_no) encoder soft states,
                aligned step-for-step with ``S_decode``.
        """
        batch_size, T_data = X.shape[0], X.shape[1]
        X_sub = self._subunit_inputs(X)  # (batch, T_data, N_no)

        encoder_out, _ = self.encoder(X_sub)
        S_encode_raw = self.encoder_lin(encoder_out.reshape(batch_size, T_data, -1))  # (batch, T_data, N_no*K_no)
        S_encode = F.softmax(S_encode_raw.reshape(batch_size, T_data, self.N_no, self.K_no) / temp, 3)

        # Entry 0 of each list is the fixed initial condition; it is dropped on return.
        prev_S, prev_Z = self._initial_state(batch_size, X_sub.dtype)
        S_steps, Z_steps = [prev_S], [prev_Z]
        for t in range(T_data):
            curr_X = X_sub[:, t]     # (batch, N_no)
            curr_S = S_encode[:, t]  # (batch, N_no, K_no)
            # The model's own guess for the state at t, from the encoder state at t-1.
            S_steps.append(self._hard_state(self._state_logits(prev_S, prev_Z, curr_X)))
            curr_Z = self._hidden_update(curr_S, prev_Z, curr_X)
            Z_steps.append(curr_Z)
            prev_S, prev_Z = curr_S, curr_Z

        # Stack once rather than writing slices into a preallocated buffer, which
        # records one full-buffer autograd copy per in-place write.
        Z_decode = torch.stack(Z_steps, 1)[:, 1:]
        S_decode = torch.stack(S_steps, 1)[:, 1:]
        return self._readout(Z_decode), Z_decode, S_decode, S_encode

    def test_forward(self, X):
        """Autonomous pass: each subunit follows its own argmax states.

        Args:
            X: (batch_size, T_data, in_no) synaptic input.

        Returns:
            Y_out: (batch_size, T_data) predicted output.
            Z_decode: (batch_size, T_data, N_no, H_no) hidden vectors.
            S_decode: (batch_size, T_data, N_no, K_no) one-hot decoded states.
        """
        batch_size = X.shape[0]
        X_sub = self._subunit_inputs(X)  # (batch, T_data, N_no)

        prev_S, prev_Z = self._initial_state(batch_size, X_sub.dtype)
        S_steps, Z_steps = [prev_S], [prev_Z]
        for t in range(X_sub.shape[1]):
            curr_X = X_sub[:, t]
            curr_S = self._hard_state(self._state_logits(prev_S, prev_Z, curr_X))
            curr_Z = self._hidden_update(curr_S, prev_Z, curr_X)
            S_steps.append(curr_S)
            Z_steps.append(curr_Z)
            prev_S, prev_Z = curr_S, curr_Z

        Z_decode = torch.stack(Z_steps, 1)[:, 1:]
        S_decode = torch.stack(S_steps, 1)[:, 1:]
        return self._readout(Z_decode), Z_decode, S_decode

In [3]:
base_dir = "/media/hdd01/sklee/"
experiment = "clust4-60"
cell_type = "CA1"
E_neural_file = "Espikes_neural.npz"
V_file = "V_diff.npy"
eloc_file = "Elocs_T10_Ne2000_gA0.6_tauA1_gN0.8_Ni200_gG0.1_gB0.1_Er0.5_Ir7.4_random_NR_rep1000_stimseed1.npy"

# Every input file lives under <base_dir><cell_type>_<experiment>/data/.
data_dir = base_dir + cell_type + "_" + experiment + "/data/"
E_neural = scipy.sparse.load_npz(data_dir + E_neural_file)
V = torch.from_numpy(np.load(data_dir + V_file))
eloc = np.load(data_dir + eloc_file)

# Distinct eloc[:, 0] values (presumably dendrite ids) among rows 880-1119,
# then the indices of every eloc row whose column 0 is one of them.
den_idx = np.unique(eloc[880:1120, 0])
e_idx = torch.from_numpy(np.flatnonzero(np.isin(eloc[:, 0], den_idx)))

In [4]:
T_train = 999 * 1000 * 50  # leading samples of V / E_neural used for training
T_test = 1 * 1000 * 50     # samples right after the training span, used for testing
hid_no = 2 # H
sub_no = 4 # N
state_no = 3 # K
# Must match the number of selected synapses: e_idx columns are fed to the model
# and multiplied by a per-synapse gain of length in_no.
in_no = len(e_idx)
save_dir = base_dir+cell_type+"_"+experiment+"/"
device = torch.device("cuda")

batch_length = 50000
batch_size = 9
iter_no = 20000
# Epochs needed to supply iter_no windows of batch_length*batch_size samples.
# Ceiling division: floor gave 180 epochs * 111 windows = 19980 < iter_no
# start offsets, so the training loop indexed past the end of train_idx.
epoch_no = -(-iter_no*batch_length*batch_size // T_train)

In [5]:
V_train = V[:T_train].float()
V_test = V[T_train:T_train + T_test].to(device).float()

train_E_neural = E_neural[:T_train]
test_E_neural = torch.from_numpy(E_neural[T_train:T_train+T_test].toarray()).float().to(device)

# Start offsets of the training windows. Each iteration reads one contiguous
# window of batch_length*batch_size samples. Only windows lying entirely inside
# the training span are used, and they are reshuffled once per epoch.
window = batch_length * batch_size
windows_per_epoch = T_train // window
assert windows_per_epoch > 0, "T_train is shorter than a single training window"
window_starts = np.arange(windows_per_epoch) * window
# Draw enough epochs that train_idx has at least iter_no entries, even if
# epoch_no was rounded down.
epochs_drawn = max(epoch_no, -(-iter_no // windows_per_epoch))
train_idx = np.concatenate([np.random.permutation(window_starts) for _ in range(epochs_drawn)])
train_idx = torch.from_numpy(train_idx)

In [6]:
# Synapse-to-subunit pooling matrix. Synapse `syn` (the syn-th entry of e_idx)
# goes to subunit k when eloc[., 0] (presumably its dendrite id) equals
# den_idx[k] for k in 0..3. Synapses matching none of those keep a zero column.
C_syn = torch.zeros(sub_no, in_no).to(device)
for syn in range(in_no):
    dend = eloc[e_idx[syn], 0]
    for sub in range(4):
        if dend == den_idx[sub]:
            C_syn[sub, syn] = 1
            break

In [7]:
# Move the model to its device *before* handing its parameters to the optimizer,
# as the torch.optim documentation requires.
model = Switching(sub_no, state_no, hid_no, C_syn, device)
model.to(device).float()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.002)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.75)

# Number of trainable scalars, printed as a sanity check.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

bce_criterion = nn.BCELoss(reduction="mean")

1760


In [None]:
# Encoder softmax temperature: log-spaced from 10**-0.5 down to 10**-3, advanced
# one step every 50 iterations and then held at the last value.
temp_list = np.logspace(-0.5, -3, 50)
temp_count = 0

for i in tnrange(iter_no):
    s = time.time()
    
    model.train()
    optimizer.zero_grad()
    
    if (i % 50 == 49) and (temp_count < len(temp_list) - 1):
        temp_count += 1
    temp = temp_list[temp_count]
    
    # One contiguous window of batch_size*batch_length samples, split into
    # batch_size sequences of batch_length steps.
    batch_idx = train_idx[i].long()
    batch_E_neural = train_E_neural[batch_idx : batch_idx+batch_length*batch_size].toarray().reshape(batch_size, batch_length, -1)
    batch_E_neural = torch.from_numpy(batch_E_neural).float().to(device)
    batch_V = V_train[batch_idx : batch_idx+batch_length*batch_size].reshape(batch_size, -1).to(device)
    
    V_pred, Z_dec, S_dec, S_enc = model.train_forward(batch_E_neural[:,:,e_idx], temp)
    
    print("forward")
    print(time.time()-s)
    s = time.time()
    
    V_loss = torch.mean((V_pred - batch_V)**2)
    # S_dec is a hard one-hot built from argmax (no gradient), so this term can
    # only reach the encoder through BCELoss's `target` argument.
    S_loss = bce_criterion(S_dec, S_enc)
    # The original cell called loss.backward() without ever defining `loss`.
    # NOTE(review): the unweighted sum of both terms is an assumption — confirm.
    loss = V_loss + S_loss
    
    loss.backward()
    optimizer.step()
    #scheduler.step()
    
    print("backward")
    print(time.time() - s)
    
    if (i % 50 == 49) or (i == 0):
        model.eval()
        # Evaluation only: skip autograd bookkeeping over the whole test rollout.
        with torch.no_grad():
            test_V_pred, test_Z_dec, test_S_dec = model.test_forward(test_E_neural[:,e_idx].unsqueeze(0))
        test_V_pred = test_V_pred.flatten()
        
        test_score = explained_variance_score(V_test.cpu().numpy(), test_V_pred.cpu().numpy())
        test_mse = torch.mean((V_test-test_V_pred)**2).item()
        
        print(i, np.round(test_score,6),
              np.round(test_mse,6))


  for i in tnrange(iter_no):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20000.0), HTML(value='')))