In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorly as tl
from tensorly.decomposition import parafac, non_negative_parafac, tucker, Tucker
from scipy.stats import zscore
from scipy.spatial.distance import jensenshannon
import random
import pickle
import os

# Root directory of the project: it is put on sys.path so the project-local
# modules (helpers, Temporal_Community_Detection) can be imported, and the data /
# output directories below are built from it. Set GENERAL_DIAGNOSTICS_BASE to run
# on another machine; os.path.join(..., "") guarantees a trailing separator so the
# later `base_path + 'G_ESCR/'` concatenation stays valid.
base_path = os.path.join(
    os.environ.get(
        "GENERAL_DIAGNOSTICS_BASE",
        "/projects/academic/smuldoon/bengieru/Community_Detection/general_diagnostics_00/",
    ),
    "",
)

import sys
sys.path.insert(0, base_path)

from helpers import *
from Temporal_Community_Detection import temporal_network

In [None]:
# Input dataset directory and the output directory for the PARAFAC diagnostics.
data_path = base_path + 'G_ESCR/'
path = data_path + 'Tensor_Parafac/'

# NOTE(review): pickle.load can execute arbitrary code; these files are assumed to be
# trusted, project-generated artefacts.
with open(data_path + 'spikes.pkl', 'rb') as spikes_file:
    spikes = pickle.load(spikes_file)
with open(data_path + 'comm_size.pkl', 'rb') as sizes_file:
    comm_sizes = pickle.load(sizes_file)

# Total neuron count = sum of the ground-truth community sizes.
num_neurons = sum(comm_sizes)
layers = 7  # number of (un-padded) time windows, one adjacency matrix each

window_size = 1000   # frames per time window; also passed to bin_time_series as its binning argument
standard_dev = 1.2   # sigma of the Gaussian kernel used when binning
k = 5                # spike-jitter amount (not used by the cells in this notebook section)
pad = True           # if True, copies of the first/last adjacency matrices are added at the ends

In [None]:
# Bin (with Gaussian smoothing) the spike trains, then build one
# cross-correlation adjacency matrix per time window.
binned_spikes = bin_time_series(spikes, window_size, gaussian = True, sigma = standard_dev)

adjacency_matrices = [cross_correlation_matrix(binned_spikes[w])[0] for w in range(layers)]

if pad:
    # Repeat the first and last adjacency matrices at the ends of the layer list.
    padded_adjacencies = [adjacency_matrices[0]] + adjacency_matrices + [adjacency_matrices[-1]]
else:
    # Without padding, pass the layers through unchanged so `padded_adjacencies`
    # (consumed by the temporal_network cell) is always defined.
    padded_adjacencies = list(adjacency_matrices)

# Number of layers actually handed to temporal_network.
# NOTE: this overwrites the global `layers`; re-run the config cell before
# re-running this cell, otherwise the comprehension above reads too many windows.
layers = len(padded_adjacencies)

In [None]:
# Multilayer temporal network over the (possibly padded) per-window adjacency
# matrices; omega and kind are forwarded unchanged to temporal_network.
TN = temporal_network(
    num_neurons,
    layers,
    window_size,
    data='list__adjacency',
    list_adjacency=padded_adjacencies,
    omega=1,
    kind='ordinal',
)

# Raster plot of the raw spikes, saved alongside the other outputs.
fig, ax = plt.subplots(1, 1, figsize=(25, 15))
TN.raster_plot(spikes, ax)
plt.savefig(path + 'raster_plot.pdf')

In [None]:
# Sweep grid for the run without an update step: ranks 2..49, 16 thresholds in [0.2, 0.5].
ranks = np.arange(2, 50)
thresholds = np.linspace(0.2, 0.5, 16)

path_no_update = path + 'no_update/'
os.makedirs(path_no_update, exist_ok=True)

In [None]:
# Run the 'PARA_FACT' detector over the rank x threshold grid with update_method=None
# and no consensus step.
membership_no_update, labels_no_update = TN.run_community_detection(
    'PARA_FACT',
    ranks = ranks,
    thresholds = thresholds,
    update_method = None,
    consensus = False,
)

In [None]:
# One panel per (rank, threshold): rows are ranks, columns are thresholds.
# squeeze=False keeps `ax` 2-D even when only one rank or one threshold is swept;
# with the default squeeze, ax[row][col] would fail in that case.
fig, ax = plt.subplots(len(ranks), len(thresholds),
                       figsize = (10*len(thresholds), 10*len(ranks)+5),
                       squeeze = False)
layer_ticks = list(range(layers))
neuron_ticks = [10*t for t in range(int(num_neurons // 10) + 1)]  # a tick every 10 neurons
for row, rank in enumerate(ranks):
    rank_memberships = membership_no_update['rank=%d'%rank]
    for col, threshold in enumerate(thresholds):
        panel = ax[row][col]
        comms, c = TN.community(rank_memberships[col], panel)
        panel.set_xticks(layer_ticks)
        panel.set_yticks(neuron_ticks)
        panel.tick_params(axis = 'both', labelsize = 12)
        panel.set_xlabel('Layers (Time)', fontsize = 25)
        panel.set_ylabel('Neuron ID', fontsize = 25)
        panel.set_title('%d Communities, Thresholds:%.3f'%(len(c), threshold), fontsize=29)
plt.tight_layout()
plt.savefig(path_no_update + 'communities.pdf')

In [None]:
# information_recovery with the 'Scattered' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
scattered_pdf = path_no_update + 'recovery_scattered.pdf'
information_recovery(labels_no_update, comm_sizes, 'Scattered', ranks, thresholds, 'grow')
plt.savefig(scattered_pdf)

In [None]:
# information_recovery with the 'Integrated' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
integrated_pdf = path_no_update + 'recovery_integrated.pdf'
information_recovery(labels_no_update, comm_sizes, 'Integrated', ranks, thresholds, 'grow')
plt.savefig(integrated_pdf)

In [None]:
# Persist the sweep results.
# NOTE(review): the labels file has no .pkl extension (unlike the memberships file);
# the name is kept so existing readers still find it.
no_update_outputs = (
    ("memberships_no_update.pkl", membership_no_update),
    ("labels_no_update", labels_no_update),
)
for file_name, obj in no_update_outputs:
    with open(path_no_update + file_name, "wb") as fp:
        pickle.dump(obj, fp)

In [None]:
# Sweep grid for the 'local' update run: ranks 2..79, 16 thresholds in [0.2, 0.5].
ranks = np.arange(2, 80)
thresholds = np.linspace(0.2, 0.5, 16)

path_local_update = path + 'local_update/'
os.makedirs(path_local_update, exist_ok=True)

In [None]:
# Run the 'PARA_FACT' detector over the rank x threshold grid with update_method='local'
# and no consensus step.
membership_local_update, labels_local_update = TN.run_community_detection(
    'PARA_FACT',
    ranks = ranks,
    thresholds = thresholds,
    update_method = 'local',
    consensus = False,
)

In [None]:
# One panel per (rank, threshold): rows are ranks, columns are thresholds.
# squeeze=False keeps `ax` 2-D even when only one rank or one threshold is swept;
# with the default squeeze, ax[row][col] would fail in that case.
fig, ax = plt.subplots(len(ranks), len(thresholds),
                       figsize = (10*len(thresholds), 10*len(ranks)+5),
                       squeeze = False)
layer_ticks = list(range(layers))
neuron_ticks = [10*t for t in range(int(num_neurons // 10) + 1)]  # a tick every 10 neurons
for row, rank in enumerate(ranks):
    rank_memberships = membership_local_update['rank=%d'%rank]
    for col, threshold in enumerate(thresholds):
        panel = ax[row][col]
        comms, c = TN.community(rank_memberships[col], panel)
        panel.set_xticks(layer_ticks)
        panel.set_yticks(neuron_ticks)
        panel.tick_params(axis = 'both', labelsize = 12)
        panel.set_xlabel('Layers (Time)', fontsize = 25)
        panel.set_ylabel('Neuron ID', fontsize = 25)
        panel.set_title('%d Communities, Thresholds:%.3f'%(len(c), threshold), fontsize=29)
plt.tight_layout()
plt.savefig(path_local_update + 'communities.pdf')

In [None]:
# information_recovery with the 'Scattered' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
scattered_pdf = path_local_update + 'recovery_scattered.pdf'
information_recovery(labels_local_update, comm_sizes, 'Scattered', ranks, thresholds, 'grow')
plt.savefig(scattered_pdf)

In [None]:
# information_recovery with the 'Integrated' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
integrated_pdf = path_local_update + 'recovery_integrated.pdf'
information_recovery(labels_local_update, comm_sizes, 'Integrated', ranks, thresholds, 'grow')
plt.savefig(integrated_pdf)

In [None]:
# Persist the sweep results.
# NOTE(review): the labels file has no .pkl extension (unlike the memberships file);
# the name is kept so existing readers still find it.
local_update_outputs = (
    ("memberships_local_update.pkl", membership_local_update),
    ("labels_local_update", labels_local_update),
)
for file_name, obj in local_update_outputs:
    with open(path_local_update + file_name, "wb") as fp:
        pickle.dump(obj, fp)

In [None]:
# Sweep grid for the 'global' update run: ranks 2..79, 16 thresholds in [0.2, 0.5].
ranks = np.arange(2, 80)
thresholds = np.linspace(0.2, 0.5, 16)

path_global_update = path + 'global_update/'
os.makedirs(path_global_update, exist_ok=True)

In [None]:
# Run the 'PARA_FACT' detector over the rank x threshold grid with update_method='global'
# and no consensus step.
membership_global_update, labels_global_update = TN.run_community_detection(
    'PARA_FACT',
    ranks = ranks,
    thresholds = thresholds,
    update_method = 'global',
    consensus = False,
)

In [None]:
# One panel per (rank, threshold): rows are ranks, columns are thresholds.
# squeeze=False keeps `ax` 2-D even when only one rank or one threshold is swept;
# with the default squeeze, ax[row][col] would fail in that case.
fig, ax = plt.subplots(len(ranks), len(thresholds),
                       figsize = (10*len(thresholds), 10*len(ranks)+5),
                       squeeze = False)
layer_ticks = list(range(layers))
neuron_ticks = [10*t for t in range(int(num_neurons // 10) + 1)]  # a tick every 10 neurons
for row, rank in enumerate(ranks):
    rank_memberships = membership_global_update['rank=%d'%rank]
    for col, threshold in enumerate(thresholds):
        panel = ax[row][col]
        comms, c = TN.community(rank_memberships[col], panel)
        panel.set_xticks(layer_ticks)
        panel.set_yticks(neuron_ticks)
        panel.tick_params(axis = 'both', labelsize = 12)
        panel.set_xlabel('Layers (Time)', fontsize = 25)
        panel.set_ylabel('Neuron ID', fontsize = 25)
        panel.set_title('%d Communities, Thresholds:%.3f'%(len(c), threshold), fontsize=29)
plt.tight_layout()
plt.savefig(path_global_update + 'communities.pdf')

In [None]:
# information_recovery with the 'Scattered' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
scattered_pdf = path_global_update + 'recovery_scattered.pdf'
information_recovery(labels_global_update, comm_sizes, 'Scattered', ranks, thresholds, 'grow')
plt.savefig(scattered_pdf)

In [None]:
# information_recovery with the 'Integrated' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
integrated_pdf = path_global_update + 'recovery_integrated.pdf'
information_recovery(labels_global_update, comm_sizes, 'Integrated', ranks, thresholds, 'grow')
plt.savefig(integrated_pdf)

In [None]:
# Persist the sweep results.
# NOTE(review): the labels file has no .pkl extension (unlike the memberships file);
# the name is kept so existing readers still find it.
global_update_outputs = (
    ("memberships_global_update.pkl", membership_global_update),
    ("labels_global_update", labels_global_update),
)
for file_name, obj in global_update_outputs:
    with open(path_global_update + file_name, "wb") as fp:
        pickle.dump(obj, fp)

In [None]:
# Sweep grid for the 'neighborhood' update run: ranks 2..79, 16 thresholds in [0.2, 0.5].
ranks = np.arange(2, 80)
thresholds = np.linspace(0.2, 0.5, 16)

path_nbr_update = path + 'nbr_update/'
os.makedirs(path_nbr_update, exist_ok=True)

In [None]:
# Run the 'PARA_FACT' detector over the rank x threshold grid with
# update_method='neighborhood' and no consensus step.
membership_nbr_update, labels_nbr_update = TN.run_community_detection(
    'PARA_FACT',
    ranks = ranks,
    thresholds = thresholds,
    update_method = 'neighborhood',
    consensus = False,
)

In [None]:
# One panel per (rank, threshold): rows are ranks, columns are thresholds.
# squeeze=False keeps `ax` 2-D even when only one rank or one threshold is swept;
# with the default squeeze, ax[row][col] would fail in that case.
fig, ax = plt.subplots(len(ranks), len(thresholds),
                       figsize = (10*len(thresholds), 10*len(ranks)+5),
                       squeeze = False)
layer_ticks = list(range(layers))
neuron_ticks = [10*t for t in range(int(num_neurons // 10) + 1)]  # a tick every 10 neurons
for row, rank in enumerate(ranks):
    rank_memberships = membership_nbr_update['rank=%d'%rank]
    for col, threshold in enumerate(thresholds):
        panel = ax[row][col]
        comms, c = TN.community(rank_memberships[col], panel)
        panel.set_xticks(layer_ticks)
        panel.set_yticks(neuron_ticks)
        panel.tick_params(axis = 'both', labelsize = 12)
        panel.set_xlabel('Layers (Time)', fontsize = 25)
        panel.set_ylabel('Neuron ID', fontsize = 25)
        panel.set_title('%d Communities, Thresholds:%.3f'%(len(c), threshold), fontsize=29)
plt.tight_layout()
plt.savefig(path_nbr_update + 'communities.pdf')

In [None]:
# information_recovery with the 'Scattered' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
scattered_pdf = path_nbr_update + 'recovery_scattered.pdf'
information_recovery(labels_nbr_update, comm_sizes, 'Scattered', ranks, thresholds, 'grow')
plt.savefig(scattered_pdf)

In [None]:
# information_recovery with the 'Integrated' setting; the resulting current figure is saved.
# NOTE(review): relies on information_recovery leaving its figure as the current one.
integrated_pdf = path_nbr_update + 'recovery_integrated.pdf'
information_recovery(labels_nbr_update, comm_sizes, 'Integrated', ranks, thresholds, 'grow')
plt.savefig(integrated_pdf)

In [None]:
# Persist the sweep results.
# NOTE(review): the labels file has no .pkl extension (unlike the memberships file);
# the name is kept so existing readers still find it.
nbr_update_outputs = (
    ("memberships_nbr_update.pkl", membership_nbr_update),
    ("labels_nbr_update", labels_nbr_update),
)
for file_name, obj in nbr_update_outputs:
    with open(path_nbr_update + file_name, "wb") as fp:
        pickle.dump(obj, fp)