In [None]:
# `os` is imported explicitly (it was previously only reachable through the
# helpers star import). The remaining order is kept as-is: `from helpers import *`
# sits between other imports, and moving it could change which same-named
# definitions end up bound.
import os
import numpy as np
from Temporal_Community_Detection import temporal_network
from helpers import *
import matplotlib.pyplot as plt
import pickle
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, f1_score
from math import floor
import random
import csv
import matplotlib as mpl

In [None]:
path = '/projects/academic/smuldoon/bengieru/Community_Detection/calcium_data_test/' ## base path
output = 'Johan_Clean_Traces_Features_and_Spikes/' #spikes and traces file
roi = 'sarah_ROI/' #roi file
subjects = load_obj(path + 'subjects/', 'subjects')
epochs = ['_baseline', '_early', '_pre']
# First entry is the single wt subject; all remaining entries are het subjects.
trackable = [subjects['wt'][3], subjects['het'][6], subjects['het'][8], subjects['het'][13], subjects['het'][4], subjects['het'][3], subjects['het'][2]]

# Load each pv table once (previously pv_het was re-loaded for every het subject).
pv_wt = load_obj(path + 'subjects/', 'pv_wt')
pv_het = load_obj(path + 'subjects/', 'pv_het')

pvs = {}
# NOTE(review): wt entries are stored as-is while het entries are shifted by -1
# below; confirm pv_wt is already 0-based, otherwise it needs the same shift.
pvs['%s' % trackable[0]] = pv_wt['%s' % trackable[0]]
for subject in trackable[1:]:
    # shift het indices down by one (adjust to python's 0-based indexing)
    pvs['%s' % subject] = (np.array(pv_het['%s' % subject]) - 1).tolist()

In [None]:
# Width of each time bin; every bin becomes one layer of the temporal network.
time = 2000
# Number of layers. 8000 is presumably the recording length in samples -- TODO confirm.
layers = 8000 // time

spikes_trackable = {}
num_rois_trackable = {}
for subject in trackable:
    for epoch in epochs:
        recording = subject + epoch
        # spike data for this subject/epoch recording
        spikes_trackable['%s' % recording] = read_csv(path, output, recording, roi)
        # number of ROIs in this recording
        num_rois_trackable['%s' % recording] = read_roi(path, roi, recording)

In [None]:
# Binarize each recording's spikes, bin them into windows of `time` with Gaussian
# smoothing, and pickle the (unbinned) binary spikes per subject directory.
binned_binary_spikes_trackable = {}
for e in trackable:
    path_nbr_update = path + 'infomap/neighborhood_update/%s/'%(e)
    os.makedirs(path_nbr_update, exist_ok = True)
    for f in epochs:
        key = '%s'%(e+f)
        binary_trackable = binarize(spikes_trackable[key])
        # Gaussian smoothing with sigma=1.25. The kernel length is decided inside
        # bin_time_series (an earlier note said 9) -- TODO confirm.
        binned_binary_spikes_trackable[key] = bin_time_series(binary_trackable,
                                                              time,
                                                              gaussian = True,
                                                              sigma = 1.25)

        with open(path_nbr_update + "binary_spikes_%s.pkl"%(e+f), "wb") as fp:
            pickle.dump(binary_trackable, fp)

In [None]:
# Build one cross-correlation adjacency matrix per layer for every subject/epoch
# recording and save a one-row panel figure of all layers to PDF.
adjacency_matrices = {}
for e in trackable:
    path_nbr_update = path + 'infomap/neighborhood_update/%s/'%(e)
    for f in epochs:
        key = '%s'%(e+f)
        binned = binned_binary_spikes_trackable[key]
        # only element [0] of cross_correlation_matrix's result is used as the adjacency
        adjacencies = [cross_correlation_matrix(binned[k])[0] for k in range(layers)]
        adjacency_matrices[key] = adjacencies

        n_rois = num_rois_trackable[key]
        roi_ticks = [t*10 for t in range(int(n_rois/10)+1)]
        # Shared colour limits: a single colorbar is drawn for all panels, which is
        # only correct if every panel uses the same normalisation.
        vmin = min(np.nanmin(a) for a in adjacencies)
        vmax = max(np.nanmax(a) for a in adjacencies)

        # One panel per layer (previously hard-coded to 4); 15 inches per panel
        # reproduces the old (60, 20) size when layers == 4.
        fig, ax = plt.subplots(1, layers, figsize = (15*layers, 20), squeeze = False)
        for layer, panel in enumerate(ax[0]):
            im = panel.imshow(adjacencies[layer],
                              origin = 'lower',
                              interpolation = 'nearest',
                              aspect = 'auto',
                              vmin = vmin,
                              vmax = vmax,
                              extent = [0, n_rois, 0, n_rois])
            panel.set_title('Adjacency Matrix (Layer %d)'%(layer + 1), fontsize = 30)
            panel.set_xticks(roi_ticks)
            panel.set_yticks(roi_ticks)
            panel.tick_params(axis = 'both', labelsize = 25)
        cbar = fig.colorbar(im, ax = ax.flat, orientation = 'horizontal')
        cbar.ax.tick_params(labelsize = 30)
        fig.savefig(path_nbr_update + 'adjacencies_%s.pdf'%(e+f))
        # Figure is saved to PDF; close it so the loop does not accumulate open figures.
        plt.close(fig)

In [None]:
# One temporal_network per subject/epoch recording, built from that recording's
# per-layer adjacency matrices.
TNs = {}
for subject in trackable:
    for epoch in epochs:
        key = '%s' % (subject + epoch)
        # NOTE(review): 'list__adjacency' (double underscore) is passed verbatim;
        # confirm it matches the value temporal_network expects.
        TNs[key] = temporal_network(
            num_rois_trackable[key],
            layers,
            time,
            data='list__adjacency',
            list_adjacency=adjacency_matrices[key],
            omega=1,
            kind='ordinal',
        )

In [None]:
# Run infomap community detection (neighborhood update) over a grid of
# interlayer values x thresholds, pickle the labels and memberships, and save a
# grid of community plots (rows = interlayer, columns = threshold) per recording.
threshs = np.linspace(0.1, 0.5, 9)
inters = np.linspace(0, 0.8, 9)
for e in trackable:
    path_nbr_update = path + 'infomap/neighborhood_update/%s/'%(e)
    for f in epochs:
        key = '%s'%(e+f)
        tn = TNs[key]
        membership_nbr_update, labels_nbr_update = tn.run_community_detection('infomap',
                                                                              update_method = 'neighborhood',
                                                                              interlayers = inters,
                                                                              thresholds = threshs)
        with open(path_nbr_update + "labels_%s.pkl"%(e+f), "wb") as fp:
            pickle.dump(labels_nbr_update, fp)

        with open(path_nbr_update + "membership_%s.pkl"%(e+f), "wb") as fp:
            pickle.dump(membership_nbr_update, fp)

        n_rois = num_rois_trackable[key]
        roi_ticks = [t*10 for t in range(int(n_rois/10)+1)]
        # Grid size follows the sweep (previously hard-coded to 9x9).
        n_rows, n_cols = len(inters), len(threshs)
        fig, ax = plt.subplots(n_rows, n_cols, figsize = (n_cols*15+5, n_rows*15), squeeze = False)
        for row, inter in enumerate(inters):
            # memberships for this interlayer value, one entry per threshold
            memberships = membership_nbr_update['interlayer=%.3f'%inter]
            for col, thresh in enumerate(threshs):
                panel = ax[row][col]
                # len() of the second return value is reported as the community count
                _, community_ids = tn.community(memberships[col], panel)
                panel.set_xticks(list(range(layers)))
                panel.set_yticks(roi_ticks)
                panel.tick_params(axis = 'both', labelsize = 12)
                panel.set_xlabel('Layers (Time)', fontsize = 25)
                panel.set_ylabel('Neuron ID', fontsize = 25)
                panel.set_title('%d Communities, interlayer:%.3f, threshold:%.3f'%(len(community_ids), inter, thresh), fontsize=29)
        fig.tight_layout()
        fig.savefig(path_nbr_update + 'communities_%s.pdf'%(e+f))
        # Figure is saved to PDF; close it so the loop does not accumulate open figures.
        plt.close(fig)