In [2]:
import numpy as np
import pandas as pd

In [11]:
import os
import re

In [4]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable # colorbar

In [5]:
# Root directory of the summary results: it holds one sub-folder per session,
# and the cells below read the .npy files (label.npy, time.npy, FC_*.npy) from them.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
summary_path = '/home/rudy/Python2/regression_linear2/results/summary/'


In [7]:
# get list of session: one entry per sub-directory of summary_path.
# - sorted, because os.listdir returns entries in arbitrary order and the
#   processing / plotting order should be reproducible between runs;
# - directories only, because every later cell builds
#   summary_path + <session> + '/<file>.npy' and a stray file would make np.load fail.
session = sorted(
    name for name in os.listdir(summary_path)
    if os.path.isdir(os.path.join(summary_path, name))
)

In [None]:
#############################################################################################

In [None]:
# get dict
def get_dico_cortex():
    """Return a fresh dict mapping each cortex name to the list of its area names."""

    parietal = ['AIP', 'LIP', 'MIP', 'PIP', 'TPt', 'VIP',
                'a23', 'a5', 'a7A', 'a7B', 'a7M', 'a7op']
    subcortical = ['Caudate', 'Claustrum', 'Putamen', 'Thal']
    auditory = ['Core', 'MB', 'PBr']
    visual = ['DP', 'FST', 'MST', 'MT', 'TEOM', 'TEO', 'TEpd',
              'V1', 'V2', 'V3', 'V3A', 'V4', 'V4t', 'V6A']
    motor = ['F1', 'F2', 'F3', 'F5', 'F6', 'F7']
    temporal = ['Ins', 'STPc']
    prefrontal = ['OPRO', 'a9', 'a11', 'a12', 'a13', 'a14',
                  'a24D', 'a24c', 'a32', 'a44', 'a45A', 'a45B',
                  'a46D', 'a46V', 'a8B', 'a8L', 'a8M', 'a8r',
                  'a9/46D', 'a9/46V']
    somatosensory = ['SII', 'a1', 'a2', 'a3']

    # keys are inserted in the same order as the historical literal
    dico_cortex = {}
    dico_cortex['Parietal'] = parietal
    dico_cortex['Subcortical'] = subcortical
    dico_cortex['Auditory'] = auditory
    dico_cortex['Visual'] = visual
    dico_cortex['Motor'] = motor
    dico_cortex['Temporal'] = temporal
    dico_cortex['Prefrontal'] = prefrontal
    dico_cortex['Somatosensory'] = somatosensory
    return dico_cortex

def get_dico_area_to_cortex():
    """Return the reverse lookup of get_dico_cortex(): area name -> cortex name."""
    # iterate cortices in dict order, then areas in list order, so that a
    # repeated area name would end up mapped to the last cortex listing it
    return {
        area_name: cortex_name
        for cortex_name, area_names in get_dico_cortex().items()
        for area_name in area_names
    }

In [None]:
########################################################################################################################

In [None]:
############

In [None]:
# Collect one row [session, area1, area2, r2] per ordered pair of distinct channels.
# NOTE(review): lowcut, highcut and order are not assigned in any visible cell —
# they have to exist in the kernel before this cell runs; confirm where they come from.
columns = ['session', 'area1', 'area2', 'r2']
data = []
for sess_no in session:
    directory = summary_path + sess_no + '/'

    # both axes of the FC matrix are described by the same label file
    label1 = np.load(directory + 'label.npy')  # input label
    label2 = np.load(directory + 'label.npy')  # output label

    FC_time_course = np.load(
        directory
        + 'FC_all_channels_low' + str(lowcut)
        + 'high' + str(highcut)
        + 'order' + str(order)
        + '_all_intervals.npy'
    )
    # collapse the third axis by averaging
    FC = np.mean(FC_time_course, axis=2)

    # i != j: don't keep electrode to the same electrode
    data.extend(
        [sess_no, label_in, label_out, FC[i, j]]
        for i, label_in in enumerate(label1)
        for j, label_out in enumerate(label2)
        if i != j
    )

In [None]:
# one row per (session, area1, area2) channel pair, columns named by `columns`
df_all_r2 = pd.DataFrame(data=data, columns=columns)



In [None]:
#############

In [None]:
# mean for each session: average r2 over all channel pairs that share the same
# (session, area1, area2) triple.
# The former `acc = df_all_r2.copy()` was dead code (overwritten right away), and
# the aggregation is named by string because passing the np.mean callable to
# groupby().agg() is deprecated in recent pandas.
acc = (
    df_all_r2
    .groupby(['session', 'area1', 'area2'])
    .agg({'r2': 'mean'})
    .reset_index()
)

In [None]:
# mean on session
acc['n_session_available'] = 1
acc = acc.groupby(['area1', 'area2']).agg({
    'r2': np.mean,
    'n_session_available': 'count',
}).reset_index()

In [None]:
# add cortex information
dico_area_to_cortex = get_dico_area_to_cortex()
acc['cortex1'] = acc.apply(lambda row : dico_area_to_cortex[row.area1], axis=1)
acc['cortex2'] = acc.apply(lambda row : dico_area_to_cortex[row.area2], axis=1)

In [34]:
# number of area pairs with r2 above 0.05 that were observed in more than 5 sessions
int(((acc['r2'] > 0.05) & (acc['n_session_available'] > 5)).sum())

In [None]:
# order rows by cortex pair first, then by area pair (original index labels are kept)
acc = acc.sort_values(by=['cortex1', 'cortex2', 'area1', 'area2'])

In [9]:
####################

In [None]:


# Plot, for every pair of areas inside every pair of cortices, the FC time courses
# aligned on 'sample' (left column) and on 'match' (right column):
#   top row    -> every electrode-pair time course of every session
#   bottom row -> mean over sessions of the per-session mean time course
#
# NOTE(review): lowcut, highcut and order are not assigned in any visible cell —
# they have to exist in the kernel before this cell runs.

cortex_to_plot = ['Visual', 'Prefrontal', 'Parietal', 'Motor', 'Somatosensory']

# Reset on every run of the cell so a change of lowcut/highcut/order is picked up.
_session_cache = {}


def _load_session(sess_no):
    """Return (FC_time_course, label, time) arrays of one session, loading them once.

    The files are read from summary_path + sess_no + '/'. Results are memoised in
    _session_cache so that the area-pair loops do not reload the same files.
    """
    if sess_no not in _session_cache:
        directory = summary_path + sess_no + '/'
        fc_all = np.load(directory + 'FC_all_channels_low'+str(lowcut)+'high'+str(highcut)+'order'+str(order)+'_all_intervals.npy')
        label = np.load(directory + 'label.npy')
        time = np.load(directory + 'time.npy')
        _session_cache[sess_no] = (fc_all, label, time)
    return _session_cache[sess_no]


def _interval_centre(interval_name):
    """Return the midpoint of the first two integers found in an interval name.

    NOTE(review): the pattern only matches digits, so a leading minus sign would be
    dropped — confirm that interval bounds are never negative.
    """
    bounds = re.findall(r'\d+', interval_name)
    return (float(bounds[0]) + float(bounds[1])) / 2


def _plot_alignment(ax_all, ax_mean, area1, area2, align_on):
    """Draw the time courses from area1 to area2 for one alignment.

    ax_all   : Axes receiving every electrode-pair time course of every session.
    ax_mean  : Axes receiving the mean over sessions of the per-session mean.
    align_on : third '_'-separated field of the interval names to keep
               ('sample' or 'match').

    Sessions where one of the areas or the alignment is missing are skipped. When
    area1 == area2, sessions with a single electrode are skipped and the diagonal
    (an electrode paired with itself) is excluded from the mean.
    Assumes every contributing session has the same intervals for this alignment,
    otherwise the per-session means cannot be summed — TODO confirm.
    """
    mean_sum = 0
    n_session = 0
    time_label = None

    for sess_no in session:
        fc_all, label, time = _load_session(sess_no)

        # select idx
        ind1 = (label == area1)
        ind2 = (label == area2)
        ind3 = np.array([el.split('_')[2] == align_on for el in time], dtype=bool)

        # drop if not available
        if np.sum(ind1) == 0 or np.sum(ind2) == 0 or np.sum(ind3) == 0:
            continue
        # same area with only one electrode: nothing but the self pair is left
        if area1 == area2 and np.sum(ind1) == 1:
            continue

        # select it: (electrodes of area1, electrodes of area2, kept intervals)
        fc = fc_all[ind1, :, :][:, ind2, :][:, :, ind3]
        time_label = [_interval_centre(el) for el in time[ind3]]

        # plot all time course, one line per electrode pair, on the numeric time axis
        ax_all.plot(time_label, np.round(fc.reshape((-1, fc.shape[2])).T, 3))

        # per-session mean time course (one value per interval)
        if area1 != area2:
            session_mean = np.mean(fc, axis=(0, 1))
        else:
            # don't take into account the prediction to the same electrode:
            # remove the diagonal from the sum and divide by the off-diagonal count
            n_elec = fc.shape[0]
            diag_sum = np.sum(np.diagonal(fc, axis1=0, axis2=1), axis=1)
            session_mean = (np.sum(fc, axis=(0, 1)) - diag_sum) / (n_elec * n_elec - n_elec)

        mean_sum = mean_sum + session_mean
        n_session += 1

    # nothing to draw (and nothing to divide by) when no session contributed
    if n_session > 0:
        ax_mean.plot(time_label, mean_sum / n_session)


for cortex1 in cortex_to_plot:
    for cortex2 in cortex_to_plot:

        pairs = acc[(acc['cortex1'] == cortex1) & (acc['cortex2'] == cortex2)]
        areas1 = pairs['area1'].unique()
        areas2 = pairs['area2'].unique()

        for area1 in areas1:
            for area2 in areas2:

                # plot time_course
                fig, axarr = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(15,5))

                _plot_alignment(axarr[0, 0], axarr[1, 0], area1, area2, 'sample')
                _plot_alignment(axarr[0, 1], axarr[1, 1], area1, area2, 'match')

                # set legend, etc
                fig.suptitle('{} to {}'.format(area1, area2))

                # left column: aligned on sample
                for ax, title in ((axarr[0, 0], 'all time course'), (axarr[1, 0], 'mean')):
                    ax.axvline(x= 0, color='r',label='sample on')
                    ax.axvline(x= 500, color='r', linestyle='--',label='sample off')
                    ax.legend()
                    ax.set_title(title)

                # right column: aligned on match
                for ax, title in ((axarr[0, 1], 'all time course'), (axarr[1, 1], 'mean')):
                    ax.axvline(x= 0, color='r',label='match on')
                    ax.legend()
                    ax.set_title(title)

                plt.show()
                # release the figure: many are created in this loop
                plt.close(fig)
                
                