In [48]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from greedy_base_hGLM import Greedy_Base_hGLM

In [49]:
def make_C_den(raw):
    """Build the square 0/1 subunit connectivity matrix from a parent-index vector.

    Subunit 0 has no entry in ``raw``. Each later subunit ``j`` (its 1-based
    position in ``raw``) contributes ``C_den[raw[j - 1], j] = 1``. All other
    entries are 0.

    Parameters
    ----------
    raw : torch.Tensor
        1-D integer tensor of length ``sub_no - 1``. Presumably ``raw[j - 1]``
        is the parent of dendritic subunit ``j`` in the tree handed to
        ``Greedy_Base_hGLM`` -- TODO confirm against that class.

    Returns
    -------
    torch.Tensor
        ``(len(raw) + 1, len(raw) + 1)`` tensor of the default float dtype.
    """
    n_sub = raw.shape[0] + 1
    conn = torch.zeros(n_sub, n_sub)
    # Column `child` gets a single 1, in the row named by raw[child - 1].
    for child, parent in enumerate(raw, start=1):
        conn[parent, child] = 1
    return conn

# Which fitted model to analyse; must be a key of CELL_CONFIGS below.
cell_type = "CA1"
T_no = 100  # passed straight through to Greedy_Base_hGLM

CKPT_DIR = "/media/hdd01/sklee/greedy"

# Per-cell-type settings in one table instead of a copy-pasted if/elif.
# Adding a cell type becomes a one-entry change. An unknown `cell_type` now
# fails loudly instead of silently leaving `model`, `E_no`, `I_no`, `scores`
# undefined -- or stale from an earlier kernel run.
CELL_CONFIGS = {
    "CA1": {
        # 9 parent indices -> 10 subunits via make_C_den
        "C_den_raw": torch.tensor([0, 1, 0, 0, 4, 0, 2, 4, 7]),
        "E_no": 2000,
        "I_no": 200,
        "ckpt": f"{CKPT_DIR}/greedybaseGLM_CA1_sub10-7.pt",
        "scores": [0.7740, 0.8197, 0.8389, 0.8398, 0.8513, 0.8539, 0.8554, 0.8719, 0.8714],
    },
    "L23": {
        # Only the first 3 parent indices are used -> 4 subunits via make_C_den.
        # Presumably this is what "sub4" in the checkpoint name refers to.
        "C_den_raw": torch.tensor([0, 1, 1, 0, 0, 0, 0, 1, 1])[:3],
        "E_no": 629,
        "I_no": 120,
        "ckpt": f"{CKPT_DIR}/greedybaseGLM_L23_sub4-0.pt",
        "scores": [0.8863, 0.8896, 0.9104, 0.9139, 0.9146, 0.9108, 0.9137, 0.9156, 0.9148],
    },
}

if cell_type not in CELL_CONFIGS:
    raise ValueError(
        f"unknown cell_type {cell_type!r}; expected one of {sorted(CELL_CONFIGS)}"
    )
cfg = CELL_CONFIGS[cell_type]

C_den_raw = cfg["C_den_raw"]
E_no = cfg["E_no"]
I_no = cfg["I_no"]
C_den = make_C_den(C_den_raw)
model = Greedy_Base_hGLM(C_den.cuda(), E_no, I_no, T_no)
model.load_state_dict(torch.load(cfg["ckpt"]))
scores = np.asarray(cfg["scores"])


In [38]:
def one_hot_argmax(logits):
    """Return a 0/1 mask with a single 1 per column, at that column's argmax row.

    Vectorised replacement for the per-column loop. Row choice per column is
    exactly ``torch.argmax(logits[:, c])``, including its tie-breaking.

    Parameters
    ----------
    logits : torch.Tensor
        2-D tensor, one column per synapse. This is assumed from the plotting
        code that indexes synapses along dim 1 -- TODO confirm the layout of
        ``model.C_syn_*_logit``.

    Returns
    -------
    torch.Tensor
        Same shape, dtype and device as ``logits``, detached from autograd.
    """
    with torch.no_grad():
        winners = torch.argmax(logits, dim=0)
        columns = torch.arange(logits.shape[1], device=logits.device)
        mask = torch.zeros_like(logits)
        mask[winners, columns] = 1
    return mask


# Hard (one-hot) subunit assignment of every excitatory / inhibitory synapse.
e_raw = model.C_syn_e_logit
e_clean = one_hot_argmax(e_raw)

i_raw = model.C_syn_i_logit
i_clean = one_hot_argmax(i_raw)
    

In [39]:
def assigned_rows(one_hot):
    """Return, for each column of a 0/1 assignment matrix, the row holding its 1.

    Parameters
    ----------
    one_hot : torch.Tensor
        2-D tensor expected to have exactly one entry equal to 1 per column.

    Returns
    -------
    torch.Tensor
        1-D CPU tensor of the default float dtype (what the former per-column
        loop produced), one entry per column.

    Raises
    ------
    ValueError
        If any column does not contain exactly one 1.
    """
    hits = one_hot == 1
    per_column = hits.sum(dim=0)
    if not bool((per_column == 1).all()):
        raise ValueError("every column must contain exactly one 1")
    # nonzero() on the transpose lists (column, row) pairs ordered by column,
    # so column 1 of the result is the assigned row for each column in turn.
    rows = torch.nonzero(hits.T)[:, 1]
    return rows.cpu().to(torch.get_default_dtype())


# Subunit ID assigned to each synapse, as a flat vector for scatter plotting.
e_plot = assigned_rows(e_clean)
i_plot = assigned_rows(i_clean)

In [41]:
# One panel per synapse population: each synapse ID against its assigned subunit.
panels = [
    ("Excitatory", e_plot, "blue"),
    ("Inhibitory", i_plot, "red"),
]
fig, axs = plt.subplots(nrows=len(panels), figsize=(15, 10))
for ax, (population, subunit_ids, color) in zip(axs, panels):
    # x comes from the data length rather than E_no/I_no, so the plot
    # cannot break if the configured counts disagree with the checkpoint.
    ax.scatter(np.arange(len(subunit_ids)), subunit_ids, s=0.5, color=color)
    ax.set_ylabel("Subunit ID")
    ax.set_xlabel("Synapse ID")
    ax.set_title(f"{cell_type} {population} Subunit-Synapse Pairing", fontsize=13)

# NOTE(review): the path below is hard-coded to L23/sub4; adjust it before
# uncommenting when cell_type is not "L23".
#plt.savefig("/media/hdd01/sklee/greedy/L23_sub4_clusters.png", dpi=150, bbox_inches='tight')
#plt.close()

In [52]:
# Variance explained versus number of subunits; scores[0] is plotted at 2 subunits.
n_subunits = np.arange(scores.shape[0]) + 2
fig_scores, ax_scores = plt.subplots(figsize=(10, 8))
ax_scores.plot(n_subunits, scores)
ax_scores.set_title(cell_type + " Variance Explained Scores", fontsize=14)
ax_scores.set_ylabel("Variance Explained")
ax_scores.set_xlabel("Subunit Number")

# NOTE(review): the path below is hard-coded to CA1; adjust it before
# uncommenting when cell_type is not "CA1".
#plt.savefig("/media/hdd01/sklee/greedy/CA1_search_scores.png", dpi=150, bbox_inches='tight')
#plt.close()