In [1]:
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.sparse
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics
from sklearn.metrics import explained_variance_score
from tqdm import tnrange
from tqdm.auto import trange

from models.gru import GRU
from models.gru_multilayer import GRU_Multilayer
from models.gru_stacked import GRU_Stacked
from models.sub_cos_glm import Sub_Cos_GLM
from models.sub_cos_glm_multilayer import Sub_Cos_GLM_Multilayer
from models.sub_cos_glm_stacked import Sub_Cos_GLM_Stacked
#from models.sub_tcn import Sub_TCN
from models.tcn_multilayer import TCN_Multilayer

In [2]:
base_dir = "/scratch/yjk27/"
experiment = "clust4-60"
cell_type = "CA1"
E_neural_file = "Espikes_neural.npz"
I_neural_file = "Ispikes_neural.npz"
clust_mode = "whole"     # one of "hand", "whole", "global", "rand"
model_type = "tcnmulti"  # one of "gru", "grustack", "glm", "glmstack", "grumulti", "glmmulti", "tcnmulti"

# Every input file for this run lives under <base_dir><cell_type>_<experiment>/data/.
# Build that path once from the config above instead of hard-coding it per file,
# so changing base_dir / cell_type / experiment actually changes what is loaded.
data_dir = base_dir+cell_type+"_"+experiment+"/data/"

# Sparse excitatory / inhibitory input spike matrices.
E_neural = scipy.sparse.load_npz(data_dir+E_neural_file)
I_neural = scipy.sparse.load_npz(data_dir+I_neural_file)

# Synapse-clustering matrices: "hand", "whole" and "global" all use the
# "handsub5" files, "rand" uses the "randsub5" files.
C_syn_file_prefix = {
    "hand": "handsub5",
    "whole": "handsub5",
    "global": "handsub5",
    "rand": "randsub5",
}
if clust_mode not in C_syn_file_prefix:
    # Fail here with a clear message instead of a NameError on C_syn_e later.
    raise ValueError("unknown clust_mode: {!r}".format(clust_mode))
C_syn_e = torch.from_numpy(np.load(data_dir+C_syn_file_prefix[clust_mode]+"_C_syn_e.npy")).float()
C_syn_i = torch.from_numpy(np.load(data_dir+C_syn_file_prefix[clust_mode]+"_C_syn_i.npy")).float()

In [3]:
# Architecture sizes. These must match the checkpoint loaded below; H_no,
# layer_no, sub_no and sub_no_file also appear in checkpoint/output file names.
H_no, layer_no = 40, 3
sub_no, sub_no_file = 5, 5

# Input widths (E_no + I_no is the TCN input size) and kernel length in samples.
E_no, I_no = 2000, 200
T_no = 350

device = torch.device("cuda:1")

In [4]:
# Build the requested architecture. Constructor signatures differ per family:
# the GRU variants take no T_no, and the TCN takes the raw input width instead
# of the C_syn clustering matrices.
if model_type == "gru":
    model = GRU(C_syn_e.to(device), C_syn_i.to(device), H_no, device)
elif model_type == "grustack":
    model = GRU_Stacked(C_syn_e.to(device), C_syn_i.to(device), H_no, device)
# "tcn" (Sub_TCN) is currently disabled; its import is commented out above.
elif model_type == "glm":
    model = Sub_Cos_GLM(C_syn_e.to(device), C_syn_i.to(device), T_no, H_no, device)
elif model_type == "glmstack":
    model = Sub_Cos_GLM_Stacked(C_syn_e.to(device), C_syn_i.to(device), T_no, H_no, device)
elif model_type == "grumulti":
    model = GRU_Multilayer(C_syn_e.to(device), C_syn_i.to(device), H_no, device)
elif model_type == "glmmulti":
    model = Sub_Cos_GLM_Multilayer(C_syn_e.to(device), C_syn_i.to(device), T_no, H_no, device)
elif model_type == "tcnmulti":
    model = TCN_Multilayer(T_no-1, E_no+I_no, layer_no, H_no, device)
else:
    # Without this, an unknown model_type only surfaces as a NameError on `model`.
    raise ValueError("unknown model_type: {!r}".format(model_type))

model.to(device).float()

ckpt_path = base_dir+cell_type+"_"+experiment+"/"+clust_mode+"/"+model_type+"_l"+str(layer_no)+"_h"+str(H_no)+".pt"
# Other runs used the naming schemes
#   <model_type>_s<sub_no_file>_h<H_no>.pt  and  <model_type>_s<sub_no_file>_h<H_no>_set5.pt
# under the same directory; swap ckpt_path accordingly when loading those.

# map_location=device loads the weights directly onto the model's GPU instead of
# staging them on cuda:0 first and then copying them across to `device`.
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

30715361


In [None]:
N_CHUNKS = 20      # number of evaluation chunks taken from the end of the recording
CHUNK_LEN = 50000  # rows of the spike matrices fed to the model per forward pass
out_prefix = base_dir+cell_type+"_"+experiment+"/"+clust_mode+"/"+model_type


def _load_chunk(chunk_idx):
    """Densify one evaluation chunk of the E and I spike matrices.

    Chunk ``chunk_idx`` (0-based) covers rows
    ``[(chunk_idx - N_CHUNKS) * CHUNK_LEN, (chunk_idx - N_CHUNKS + 1) * CHUNK_LEN)``
    counted from the end, so all chunks together span the final
    ``N_CHUNKS * CHUNK_LEN`` rows of ``E_neural`` / ``I_neural``.

    Returns:
        (part_E_neural, part_I_neural): float32 tensors on ``device`` with a
        leading batch dimension of size 1.
    """
    start = (chunk_idx - N_CHUNKS) * CHUNK_LEN
    # For the last chunk the stop index would be 0, and a [-k:0] slice is empty,
    # so slice to the end of the matrix instead.
    stop = start + CHUNK_LEN if start + CHUNK_LEN < 0 else None
    return tuple(
        torch.from_numpy(spikes[start:stop].toarray()).to(device).float().unsqueeze(0)
        for spikes in (E_neural, I_neural)
    )


if model_type in ("gru", "grustack", "grumulti"):
    test = np.zeros((N_CHUNKS, CHUNK_LEN))
    sub_out = np.zeros((N_CHUNKS, sub_no, CHUNK_LEN))

    # Inference only: no_grad stops autograd from recording every activation of
    # each chunk's forward pass, which would otherwise be held in GPU memory.
    with torch.no_grad():
        for i in trange(N_CHUNKS):
            part_E_neural, part_I_neural = _load_chunk(i)
            # To evaluate with inhibition removed (the "_pos_output" variant), use:
            # part_I_neural = torch.zeros_like(part_I_neural)
            part_test, part_sub_out = model(part_E_neural, part_I_neural)
            test[i] = part_test.cpu().detach().numpy().flatten()
            sub_out[i] = part_sub_out.squeeze(0).T.cpu().detach().numpy()

    E_scale = np.exp(model.E_scale.cpu().detach().numpy())
    # For the zeroed-inhibition variant, save to
    # out_prefix+"_s"+str(sub_no)+"_h"+str(H_no)+"_pos_output.npz" instead.
    np.savez(out_prefix+"_s"+str(sub_no)+"_h"+str(H_no)+"_output.npz",
             test=test,
             sub_out=sub_out,
             E_scale=E_scale)

elif model_type == "tcnmulti":
    test = np.zeros((N_CHUNKS, CHUNK_LEN))
    sub_out = np.zeros((N_CHUNKS, sub_no, CHUNK_LEN))  # not filled by the TCN model

    with torch.no_grad():
        for i in trange(N_CHUNKS):
            part_E_neural, part_I_neural = _load_chunk(i)
            part_test = model(part_E_neural, part_I_neural)
            test[i] = part_test.cpu().detach().numpy().flatten()

    np.savez(out_prefix+"_s"+str(sub_no)+"_h"+str(H_no)+"_output.npz",
             test=test)

elif model_type in ("glm", "glmstack", "glmmulti"):
    test = np.zeros((N_CHUNKS, CHUNK_LEN))
    # NOTE(review): the per-chunk copies into nonlin_in / sub_out are disabled
    # below, so both are saved as all-zero placeholders.
    nonlin_in = np.zeros((N_CHUNKS, sub_no, H_no, CHUNK_LEN))
    sub_out = np.zeros((N_CHUNKS, sub_no, CHUNK_LEN))

    with torch.no_grad():
        for i in trange(N_CHUNKS):
            part_E_neural, part_I_neural = _load_chunk(i)
            part_test, part_sub_out, part_nonlin_in = model(part_E_neural, part_I_neural)
            test[i] = part_test.cpu().detach().numpy().flatten()
            #sub_out[i] = part_sub_out.squeeze(0).T.cpu().detach().numpy()
            #nonlin_in[i] = part_nonlin_in.squeeze(0).reshape(sub_no, H_no, -1).cpu().detach().numpy()

    # Raised-cosine temporal basis on a log time axis:
    #   basis_k(t) = 0.5*cos(scale*log(t + shift) - k*pi/2) + 0.5,
    # zeroed where scale*log(t + shift) lies outside [k*pi/2 - pi, k*pi/2 + pi].
    cos_basis_no = 30
    scale = 7.5
    shift = 1

    # The warped time axis does not depend on k, so compute it once.
    x_in = torch.arange(0, T_no, 1)
    raw_cos = scale * torch.log(x_in + shift + 1e-7)

    kern_basis = torch.zeros(cos_basis_no, T_no).to(device)
    for i in range(cos_basis_no):
        phi = (np.pi / 2) * i
        basis = 0.5*torch.cos(raw_cos - phi) + 0.5
        basis[(raw_cos < phi - np.pi) | (raw_cos > phi + np.pi)] = 0.0
        kern_basis[i] = basis

    # Project the learned basis weights onto the basis to get time-domain kernels,
    # reshaped to (sub_no, H_no, T_no).
    e_kern = torch.matmul(model.W_e_layer1, kern_basis).reshape(sub_no, H_no, T_no).cpu().detach().numpy()
    i_kern = torch.matmul(model.W_i_layer1, kern_basis).reshape(sub_no, H_no, T_no).cpu().detach().numpy()

    E_scale = np.exp(model.E_scale.cpu().detach().numpy())
    np.savez(out_prefix+"_s"+str(sub_no_file)+"_h"+str(H_no)+"_output.npz",
             test=test,
             nonlin_in=nonlin_in,
             sub_out=sub_out,
             e_kern=e_kern,
             i_kern=i_kern,
             E_scale=E_scale)


  for i in tnrange(20):


  0%|          | 0/20 [00:00<?, ?it/s]