In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable # colorbar
import pandas as pd

import os

In [None]:
# Root directory that holds one sub-directory per recording session.
# It can be overridden with the REGRESSION_FIGURE_PATH environment variable so
# the notebook is not tied to a single machine; the default is the original
# location. os.path.join(..., '') guarantees the trailing separator that the
# later cells rely on when they build paths as figure_path + sess_no + '/'.
figure_path = os.path.join(
    os.environ.get('REGRESSION_FIGURE_PATH',
                   '/home/rudy/Python2/regression_linear/figure/'),
    '')

In [None]:
# Names of the session directories found under figure_path.
# os.listdir returns its entries in arbitrary order and also lists plain files,
# so sort the names (reproducible session[0]) and keep directories only
# (every later cell reads '<figure_path><name>/label1.npy').
session = sorted(
    name for name in os.listdir(figure_path)
    if os.path.isdir(os.path.join(figure_path, name))
)

In [None]:
# Number of entries in the session list.
len(session)

In [None]:
# dict 

In [None]:
# get dict
def get_dico_cortex():
    """Return a new dict mapping each cortex name to the list of its areas.

    dico_cortex['cortex'] = list of areas in cortex
    """
    cortex_to_areas = {
        'Parietal': [
            'AIP', 'LIP', 'MIP', 'PIP', 'TPt', 'VIP',
            'a23', 'a5', 'a7A', 'a7B', 'a7M', 'a7op',
        ],
        'Subcortical': ['Caudate', 'Claustrum', 'Putamen', 'Thal'],
        'Auditory': ['Core', 'MB', 'PBr'],
        'Visual': [
            'DP', 'FST', 'MST', 'MT', 'TEOM', 'TEO', 'TEpd',
            'V1', 'V2', 'V3', 'V3A', 'V4', 'V4t', 'V6A',
        ],
        'Motor': ['F1', 'F2', 'F3', 'F5', 'F6', 'F7'],
        'Temporal': ['Ins', 'STPc'],
        'Prefrontal': [
            'OPRO', 'a9', 'a11', 'a12', 'a13', 'a14',
            'a24D', 'a24c', 'a32', 'a44',
            'a45A', 'a45B', 'a46D', 'a46V',
            'a8B', 'a8L', 'a8M', 'a8r',
            'a9/46D', 'a9/46V',
        ],
        'Somatosensory': ['SII', 'a1', 'a2', 'a3'],
    }
    return cortex_to_areas

def get_dico_area_to_cortex():
    """Return the reverse lookup of get_dico_cortex(): dico[area] = cortex."""
    return {
        area: cortex
        for cortex, areas in get_dico_cortex().items()
        for area in areas
    }

In [None]:
# Sanity check on a single session: load its labels and matrix, then plot it

In [None]:
# 
# Use the first entry of the session list for the single-session check below.
sess_no = session[0]
# Directory (with trailing '/') that holds this session's .npy files.
base_path = figure_path + sess_no + '/'


In [None]:
# load numpy files

label1 = np.load( base_path + 'label1.npy') # input label (indexes the first axis of FC)
label2 = np.load( base_path + 'label2.npy') # output label (indexes the second axis of FC)
# Channel-to-channel matrix of this session; its entries are the values stored
# in the 'r2' column of the data frame built further down.
# NOTE(review): "low_7_high_12" presumably denotes a 7-12 Hz band -- confirm.
FC = np.load( base_path + 'channel_to_channel_all_cortex_low_7_high_12.npy')

In [None]:
# Heat map of the channel-to-channel matrix FC of the session loaded above.
fig, ax = plt.subplots(figsize=(15, 15))
im = ax.imshow(FC, vmin=0)

# Give the colorbar its own axes so that it has the same height as the image.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.10)
fig.colorbar(im, cax=cax)

# One tick per channel: columns are labelled with label2, rows with label1.
ax.set_xticks(np.arange(len(label2)))
ax.set_yticks(np.arange(len(label1)))
ax.set_xticklabels(label2)
ax.set_yticklabels(label1)

# A figure must stand on its own: name both axes and the session shown.
ax.set_xlabel('output channel (label2)')
ax.set_ylabel('input channel (label1)')
ax.set_title('session ' + sess_no)

# Rotate the x tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

# TODO: once an error-bar file is available, annotate every cell whose value
# exceeds a threshold (e.g. 0.05) with its rounded value, choosing the text
# colour according to the brightness of the pixel.

plt.show()
    

In [None]:
######################################

# df

In [None]:
# create data frame 

In [None]:
columns = ['session', 'area1', 'area2', 'r2']
data = []
for sess_no in session:
    base_path = figure_path + sess_no + '/'
    label1 = np.load(base_path + 'label1.npy')  # input label
    label2 = np.load(base_path + 'label2.npy')  # output label
    FC = np.load(base_path + 'channel_to_channel_all_cortex_low_7_high_12.npy')

    # One record per ordered channel pair; an electrode paired with itself
    # (same index on both axes) is left out.
    data.extend(
        [sess_no, label1[i], label2[j], FC[i, j]]
        for i in range(len(label1))
        for j in range(len(label2))
        if i != j
    )
    

In [None]:
# Long-format table: one row per session and ordered channel pair, holding the
# area label of the input channel, of the output channel, and the FC value.
df_all_r2 = pd.DataFrame(data, columns=columns)

In [None]:
# Display the channel-pair table.
df_all_r2

In [None]:
# dico_area_to_cortex = get_dico_area_to_cortex()
# df_all_r2['cortex1'] = df_all_r2.apply(lambda row : dico_area_to_cortex[row.area1], axis=1)
# df_all_r2['cortex2'] = df_all_r2.apply(lambda row : dico_area_to_cortex[row.area2], axis=1)

In [None]:
# Mean r2 over all channel pairs sharing the same (area1, area2), computed separately for each session

In [None]:

# Per-session mean: average r2 over all channel pairs that share the same
# (area1, area2) within one session.
# The aggregation is named by the string 'mean' rather than by np.mean: recent
# pandas versions warn when a NumPy callable is passed to agg, and the string
# selects the same group-wise mean.
# (The former `acc = df_all_r2.copy()` was dropped: it was overwritten at once.)
acc = df_all_r2.groupby(['session', 'area1', 'area2']).agg({
    'r2': 'mean',
}).reset_index()

In [None]:
# Display acc ordered from highest to lowest r2 (acc itself is not modified).
acc.sort_values('r2', ascending=False)

In [None]:
# mean on session
# Average the per-session means over sessions and count how many sessions
# contribute to each (area1, area2) pair.
# The per-session table is rebuilt from df_all_r2 instead of being read from
# `acc`: this cell overwrites `acc` with its own aggregate, so reading `acc`
# made a second run of the cell silently reset every n_session_available to 1.
acc_per_session = df_all_r2.groupby(['session', 'area1', 'area2']).agg({
    'r2': 'mean',
}).reset_index()
acc = acc_per_session.groupby(['area1', 'area2']).agg({
    'r2': 'mean',
    'session': 'count',
}).rename(columns={'session': 'n_session_available'}).reset_index()

In [None]:
# add cortex information
dico_area_to_cortex = get_dico_area_to_cortex()
# Look every area up directly instead of going through a row-wise
# DataFrame.apply: same values, far less overhead, and well defined on an empty
# frame. An area missing from the dictionary still raises KeyError.
acc['cortex1'] = [dico_area_to_cortex[area] for area in acc['area1']]
acc['cortex2'] = [dico_area_to_cortex[area] for area in acc['area2']]

In [None]:
# Display the area-pair table ordered from highest to lowest mean r2
# (acc itself is not modified).
acc.sort_values('r2', ascending=False)

In [None]:
# Number of (area1, area2) rows with r2 strictly above 0.05 and
# n_session_available strictly above 5.
len( acc[(acc['r2']>0.05) & (acc['n_session_available']> 5)].values )

In [None]:
threshold_r2 = 0.05
threshold_session = 5

# Order acc in place from highest to lowest r2.
acc.sort_values('r2', ascending=False, inplace=True)

# Rows that pass both thresholds (strictly greater), taken once and then split
# into the three arrays used afterwards.
above_threshold = acc[
    (acc['r2'] > threshold_r2)
    & (acc['n_session_available'] > threshold_session)
]
areas1 = above_threshold['area1'].values
areas2 = above_threshold['area2'].values
r2_score = above_threshold['r2'].values

In [None]:
# Only the last expression of a cell is displayed, so show both lengths as one
# tuple instead of silently discarding the first.
len(areas1), len(areas2)

In [None]:
# Distinct cortex names found in the cortex1 column (cortex of area1).
acc.cortex1.unique()

In [None]:
# Distinct cortex names found in the cortex2 column (cortex of area2).
acc.cortex2.unique()

In [None]:
######################
# PLOT TIME COURSE ###
######################

In [None]:
# Display the area-pair table.
acc

In [None]:
# Sort acc in place by cortex pair, then by area pair.
acc.sort_values(['cortex1', 'cortex2', 'area1','area2'], inplace=True)

In [None]:
# acc.groupby(['cortex1', 'cortex2', 'area1','area2'], inplace=True)
# for idx, row in acc.iterrows():
#     ...

In [None]:
# to plot the time course

In [None]:
# Plot only the pairs that reach the thresholds, but print the statistics of
# every (area1, area2) pair found in acc.
threshold_r2 = 0.05
threshold_session = 5
# NOTE(review): the selection cell earlier in the notebook uses strict '>' on
# both thresholds, whereas here a pair exactly at a threshold is still plotted.


def _plot_pair_time_course(area1, area2, n_time_points=43):
    """Plot the time courses from area1 (input label) to area2 (output label).

    Left panel: every selected channel pair of every usable session.
    Right panel: the average over sessions of the per-session mean time course.

    Reads the module-level names `session` and `figure_path`. Sessions in which
    one of the two areas is absent, or whose time course does not have
    `n_time_points` samples, are skipped. Defined in this cell so that the cell
    is self-contained.
    """
    fig, axarr = plt.subplots(ncols=2, figsize=(15, 5))

    session_means = []   # one mean time course per contributing session
    time_axis = None     # time vector of a session that actually contributed
    for sess_no in session:
        base_path = figure_path + sess_no + '/'

        # Check the labels first so that the time-course array is only loaded
        # when both areas are present in this session.
        label1 = np.load(base_path + 'label1.npy')  # input label
        label2 = np.load(base_path + 'label2.npy')  # output label
        ind1 = (label1 == area1)
        ind2 = (label2 == area2)
        if np.sum(ind1) == 0 or np.sum(ind2) == 0:
            continue

        time = np.load(base_path + 'time_for_time_course.npy')
        FC_time_course = np.load(base_path + 'channel_to_channel_all_cortex_low_7_high_12time_course.npy')

        # Keep the rows of area1 and the columns of area2.
        FC_time_course = FC_time_course[ind1, :, :][:, ind2, :]
        if FC_time_course.shape[2] != n_time_points:
            continue

        # One line per selected channel pair.
        axarr[0].plot(time, np.round(FC_time_course.reshape((-1, FC_time_course.shape[2])).T, 3))

        if area1 != area2:
            session_mean = np.mean(FC_time_course, axis=(0, 1))
        else:
            # Within one area, leave out the pairs with equal row and column
            # index (an electrode with itself), as the r2 table does.
            off_diagonal = ~np.eye(FC_time_course.shape[0],
                                   FC_time_course.shape[1], dtype=bool)
            if not off_diagonal.any():
                # A single channel offers no pair to average. Counting such a
                # session (as was done before) pulled the mean towards zero.
                continue
            session_mean = FC_time_course[off_diagonal].mean(axis=0)

        session_means.append(session_mean)
        # Remember the time vector of a session that was really used; the
        # vector of a skipped session could have a different length.
        time_axis = time

    # plot mean of the mean
    print(len(session_means))
    if session_means:
        axarr[1].plot(time_axis, np.mean(session_means, axis=0))
    # else: nothing to average (this used to end in a division of 0 by 0);
    # the right panel is left empty.

    # Event markers and titles, identical on both panels.
    # NOTE(review): x positions presumably in ms -- confirm against the time file.
    for ax, title in zip(axarr, ('all session all channels \n', 'mean \n')):
        ax.axvline(x=0, color='r', label='sample on ')
        ax.axvline(x=500, color='r', linestyle='--', label='sample off')
        ax.axvline(x=2000, color='g', label='match on ')
        ax.legend()
        ax.set_title(title + area1 + ' to ' + area2)

    plt.show()
    plt.close(fig)  # many figures are created in the loop below; free each one


cortex_names = ['Visual', 'Prefrontal', 'Parietal', 'Motor', 'Somatosensory']

for cortex1 in cortex_names:
    for cortex2 in cortex_names:

        in_cortex_pair = acc[(acc['cortex1'] == cortex1) & (acc['cortex2'] == cortex2)]
        areas1 = in_cortex_pair['area1'].unique()
        areas2 = in_cortex_pair['area2'].unique()

        for area1 in areas1:
            for area2 in areas2:

                pair_rows = acc[(acc['area1'] == area1) & (acc['area2'] == area2)]
                if len(pair_rows) == 0:
                    continue
                r2_score = pair_rows['r2'].values[0]
                N = pair_rows['n_session_available'].values[0]

                print(area1, ' to ', area2)
                print('r2 :', r2_score)
                print('number of session :', N)

                if r2_score < threshold_r2 or N < threshold_session:
                    continue

                _plot_pair_time_course(area1, area2)
        
        
        

In [None]:
# For every (area1, area2) pair that reaches the thresholds, plot the time
# course area1 -> area2 and then the opposite direction area2 -> area1.
# threshold_r2 and threshold_session come from the cells above.


def _pair_stats(src_area, dst_area):
    """Return (r2, n_session_available) of src_area -> dst_area from acc.

    Returns None when acc has no row for that ordered pair.
    """
    rows = acc[(acc['area1'] == src_area) & (acc['area2'] == dst_area)]
    if len(rows) == 0:
        return None
    return rows['r2'].values[0], rows['n_session_available'].values[0]


def _plot_directed_time_course(area1, area2, n_time_points=43):
    """Plot the time courses from area1 (input label) to area2 (output label).

    Left panel: every selected channel pair of every usable session.
    Right panel: the average over sessions of the per-session mean time course.

    Reads the module-level names `session` and `figure_path`. Sessions in which
    one of the two areas is absent, or whose time course does not have
    `n_time_points` samples, are skipped. Defined in this cell so that the cell
    is self-contained.
    """
    fig, axarr = plt.subplots(ncols=2, figsize=(15, 5))

    session_means = []   # one mean time course per contributing session
    time_axis = None     # time vector of a session that actually contributed
    for sess_no in session:
        base_path = figure_path + sess_no + '/'

        # Check the labels first so that the time-course array is only loaded
        # when both areas are present in this session.
        label1 = np.load(base_path + 'label1.npy')  # input label
        label2 = np.load(base_path + 'label2.npy')  # output label
        ind1 = (label1 == area1)
        ind2 = (label2 == area2)
        if np.sum(ind1) == 0 or np.sum(ind2) == 0:
            continue

        time = np.load(base_path + 'time_for_time_course.npy')
        FC_time_course = np.load(base_path + 'channel_to_channel_all_cortex_low_7_high_12time_course.npy')

        # Keep the rows of area1 and the columns of area2.
        FC_time_course = FC_time_course[ind1, :, :][:, ind2, :]
        if FC_time_course.shape[2] != n_time_points:
            continue

        # One line per selected channel pair.
        axarr[0].plot(time, np.round(FC_time_course.reshape((-1, FC_time_course.shape[2])).T, 3))

        if area1 != area2:
            session_mean = np.mean(FC_time_course, axis=(0, 1))
        else:
            # Within one area, leave out the pairs with equal row and column
            # index (an electrode with itself), as the r2 table does.
            off_diagonal = ~np.eye(FC_time_course.shape[0],
                                   FC_time_course.shape[1], dtype=bool)
            if not off_diagonal.any():
                # A single channel offers no pair to average. Counting such a
                # session (as was done before) pulled the mean towards zero.
                continue
            session_mean = FC_time_course[off_diagonal].mean(axis=0)

        session_means.append(session_mean)
        # Remember the time vector of a session that was really used; the
        # vector of a skipped session could have a different length.
        time_axis = time

    # plot mean of the mean
    print(len(session_means))
    if session_means:
        axarr[1].plot(time_axis, np.mean(session_means, axis=0))
    # else: nothing to average (this used to end in a division of 0 by 0);
    # the right panel is left empty.

    # Event markers and titles, identical on both panels.
    # NOTE(review): x positions presumably in ms -- confirm against the time file.
    for ax, title in zip(axarr, ('all session all channels \n', 'mean \n')):
        ax.axvline(x=0, color='r', label='sample on ')
        ax.axvline(x=500, color='r', linestyle='--', label='sample off')
        ax.axvline(x=2000, color='g', label='match on ')
        ax.legend()
        ax.set_title(title + area1 + ' to ' + area2)

    plt.show()
    plt.close(fig)  # many figures are created in the loop below; free each one


cortex_names = ['Visual', 'Prefrontal', 'Parietal', 'Motor', 'Somatosensory']

for cortex1 in cortex_names:
    for cortex2 in cortex_names:

        in_cortex_pair = acc[(acc['cortex1'] == cortex1) & (acc['cortex2'] == cortex2)]
        areas1 = in_cortex_pair['area1'].unique()
        areas2 = in_cortex_pair['area2'].unique()

        for area1 in areas1:
            for area2 in areas2:

                forward = _pair_stats(area1, area2)
                if forward is None:
                    continue
                r2_score, N = forward

                if r2_score < threshold_r2 or N < threshold_session:
                    continue

                print(area1, ' to ', area2)
                print('r2 :', r2_score)
                print('number of session :', N)

                _plot_directed_time_course(area1, area2)

                ### plot in the other way
                # The loop variables are NOT swapped any more: swapping
                # area1/area2 in place left the outer loop variable changed for
                # all remaining iterations of the inner loop.
                if area1 == area2:
                    # Same area on both sides: the reverse plot would be an
                    # exact duplicate of the one just shown.
                    continue

                print(area2, ' to ', area1)
                reverse = _pair_stats(area2, area1)
                if reverse is None:
                    print('no r2 available for this direction')
                    continue
                # Use the statistics of the reverse pair itself; the forward r2
                # and the forward session counter used to be reported and
                # thresholded here.
                r2_score_reverse, N_reverse = reverse
                print('r2 :', r2_score_reverse)
                print('number of session :', N_reverse)

                if r2_score_reverse < threshold_r2 or N_reverse < threshold_session:
                    continue

                _plot_directed_time_course(area2, area1)