In [1]:
import os
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr
import pingouin as pg
from matplotlib import pyplot as plt
import seaborn as sns
from itertools import groupby

pd.set_option('display.max_rows', 47)
pd.set_option('display.max_columns', 47)
pd.set_option('display.width', 1000)

from palettable.scientific.sequential import Batlow_20, Batlow_20_r, GrayC_3, Davos_3_r, Oslo_3_r, LaPaz_20_r
from palettable.scientific.diverging import Roma_20, Roma_20_r


from utils.settings import periods, frequency_bands

# # # EEG 64 dataset
ROI = {'frontal_left': [' ', ' ', 'Frontal Left', ' ', ' '],
       #'new1': [' '],
       'frontal_central': [' ', ' ', 'Frontal Central', ' ', ' ', ' '],
       #'new2': [' '],
       'frontal_right': [' ', ' ', 'Frontal Right', ' ', ' '],
       #'new3': [' '],
       'temporal_left': [' ', 'Temporal Left', ' '],
       #'new4': [' '],
       'central': [' ', 'Central', ' '],
       #'new5': [' '],
       'temporal_right': [' ', 'Temporal Right', ' '],
       #'new6': [' '],
       'parietal_left': [' ', ' ', 'Parietal Left', ' ', ' '],
       #'new7': [' '],
       'parietal_central': [' ', ' ', 'Parietal Central', ' ', ' ', ' '],
       #'new8': [' '],
       'parietal_right': [' ', ' ', 'Parietal Right', ' ', ' '],
        #'new9': [' '],
       'occipital_left': [' ', 'Occipital Left', ' '],
       #'new10': [' '],
       'occipital_right': [' ', 'Occipital Right', ' ']}

# ROI = {'Left \n           ': ['F7', 'F5', 'F3', 'FC5', 'FC3'],
#        #'new1': [' '],
#        'Central \n Frontal': ['F1', 'Fz', 'F2', 'FC1', 'FCz', 'FC2'],
#        #'new2': [' '],
#        'Right \n          ': ['F4', 'F6', 'F8', 'FC4', 'FC6'],
#        #'new3': [' '],
#        'Left \n Temporal': ['FT7', 'T7', 'TP7'],
#        #'new4': [' '],
#        '  \n Central': ['C3', 'Cz', 'C4'],
#        #'new5': [' '],
#        'Right \n Temporal': ['FT8', 'T8', 'TP8'],
#        #'new6': [' '],
#        '         \n Left': ['CP5', 'CP3', 'P7', 'P5', 'P3'],
#        #'new7': [' '],
#        'Parietal \n Central': ['CP1', 'CPz', 'CP2', 'P1', 'Pz', 'P2'],
#        #'new8': [' '],
#        '         \n Right': ['CP4', 'CP6', 'P4', 'P6', 'P8'],
#         #'new9': [' '],
#        '          \n Left': ['PO3', 'PO7', 'O1'],
#        #'new10': [' '],
#        'Occipital \n Right': ['PO4', 'PO8', 'O2']}

freqs = []
for mode in list(frequency_bands.keys()):
    freqs.extend(list(frequency_bands[mode].keys()))
    
channels_in_order = []
for channels in list(ROI.values()):
    channels_in_order.extend(channels)

In [2]:
ROI

{'frontal_left': [' ', ' ', 'Frontal Left', ' ', ' '],
 'frontal_central': [' ', ' ', 'Frontal Central', ' ', ' ', ' '],
 'frontal_right': [' ', ' ', 'Frontal Right', ' ', ' '],
 'temporal_left': [' ', 'Temporal Left', ' '],
 'central': [' ', 'Central', ' '],
 'temporal_right': [' ', 'Temporal Right', ' '],
 'parietal_left': [' ', ' ', 'Parietal Left', ' ', ' '],
 'parietal_central': [' ', ' ', 'Parietal Central', ' ', ' ', ' '],
 'parietal_right': [' ', ' ', 'Parietal Right', ' ', ' '],
 'occipital_left': [' ', 'Occipital Left', ' '],
 'occipital_right': [' ', 'Occipital Right', ' ']}

In [3]:
def load_behaviour_data_from_path(path):
    """
    Loads behavior data from given path. Renames columns to make it more readable.
    Returns a dataframe with the selected columns.
    """
    print(f'\nLoading behaviour data from: {path}')
    df = pd.read_csv(path, sep=';', decimal=",")

    columns_readable = {'hldif_Session_I': 's1_hldif_rt',
                        'hldif_Session_I_acc': 's1_hldif_acc',
                        'hldif4': 's2_hldif_rt',
                        'hldif4_acc': 's2_hldif_acc'}

    df.rename(columns=columns_readable, inplace=True)

    selected_columns = ['Subject'] + ['s1_hldif_rt', 's1_hldif_acc',
                                      's2_hldif_rt', 's2_hldif_acc',
                                      'konsz_rt', 'konsz_acc',
                                      'gs_konsz_rt', 'gs_konsz_acc']

    print(df[selected_columns].head(10))

    return df[selected_columns]

In [4]:
def load_connectivity_matrix_from_path(path: str) -> np.ndarray:
    """
    Loads a channel connectivity matrix (.npy) into multidimensional array.
    Checks the shape of the matrix to match the expected number of dimensions:
    (n_subjects, n_periods, n_frequencies, n_channels, n_channels)

    Args:
    str path: the path to the channel connectivity matrix
    :return: the loaded channel matrix with a shape of
    (n_subjects, n_periods, n_frequencies, n_channels, n_channels)
    """
    
    print(f'Reading matrix from {path}\n')
    conn = np.load(path)
    assert len(conn.shape) == 5
    print(f'Shape of matrix: {conn.shape}\n'
                 f'Description of dimensions: \n'
                 f'Number of subjects: {conn.shape[0]}\n'
                 f'Number of periods: {conn.shape[1]}\n'
                 f'Number of frequency bands: {conn.shape[2]}\n'
                 f'Channel connectivity matrix: '
                 f'{conn.shape[3]} x {conn.shape[4]}\n')

    return conn

In [5]:
# Specify path to behavior data 
behav_path = '/Users/weian/research/crnl/eeg-connectivity-analysis/ASRT_connect_consol_behav_selected_zs.csv'

# Load behavior data
behav_data = load_behaviour_data_from_path(path=behav_path)

# Specify path to resting data
rs_path = '/Users/weian/research/crnl/eeg-connectivity-analysis/result/rs/wpli'

# Specify path to ASRT data
asrt_path = '/Users/weian/research/crnl/eeg-connectivity-analysis/result/asrt_merged/wpli'

ch_conn_file_name = 'subjects_wpli_ch_conn.npy'

# Load resting channel connectivity data
rs_ch_conn = load_connectivity_matrix_from_path(os.path.join(rs_path, ch_conn_file_name))

# Load ASRT channel connectivity data
asrt_ch_conn = load_connectivity_matrix_from_path(os.path.join(asrt_path, ch_conn_file_name))


Loading behaviour data from: /Users/weian/research/crnl/eeg-connectivity-analysis/ASRT_connect_consol_behav_selected_zs.csv
   Subject  s1_hldif_rt  s1_hldif_acc  s2_hldif_rt  s2_hldif_acc  konsz_rt  konsz_acc  gs_konsz_rt  gs_konsz_acc
0        1     7.500000      0.039356         16.0      0.037931       9.5  -0.011716        -38.5      0.063055
1        2    14.833333      0.046963         15.0      0.104822       6.0   0.013116         -4.0      0.033390
2        3     3.000000     -0.018952         -3.5     -0.028202      -9.5  -0.047495        -35.5      0.132737
3        4    10.166667      0.008487          9.0      0.010783      -4.0  -0.027699         -3.5      0.037792
4        5     4.333333      0.055089         11.0      0.044036      -2.0  -0.006763        -11.0      0.029546
5        6    20.166667      0.041596         13.5      0.057851     -10.5   0.042018        -22.0     -0.016030
6        7     8.500000      0.024006         16.5      0.053761       6.5   0.03275

In [6]:
def uncorrected_spearmanr(conn, behav_vector, method):
    """
    Calculates Spearman correlations between connectivity matrix (n_subjects, n_channels, n_channels) 
    and behavior data (n_subjects,). 
    The correlations are calculated by iterating over the NxN connectivity matrix
    and selecting the connectivity values at each cell for every subject. For each iteration the Spearman 
    correlation is computed between the connectivity measures (n_subjects,) and the behavior data (n_subjects,).
    Returns a dataFrame with the observed R values for each cell in an NxN correlation matrix.
    """

    if conn.shape[-1] > len(list(ROI.keys())):
        labels = channels_in_order
    else:
        labels = list(ROI.keys())
        
    assert conn.shape[0] == len(behav_vector)
    observed_corr = np.zeros((conn.shape[-1], conn.shape[-1]))
    
    # Remove missing values if there are any
    conn, behav_vector = remove_missing_data(conn, behav_vector)

    # calculate observed r-values
    for row in range(conn.shape[-1]):
        for column in range(conn.shape[-1]):
            # ignore upper triangle with zero values
            if all(v == 0 for v in conn[:, row, column]):
                observed_corr[row][column] = 0.0
            else:
            
                corr_output = pg.corr(x=conn[:, row, column], y=behav_vector,
                        tail = 'two-sided', method=method)
                

                observed_corr[row][column] = corr_output['r']
                #observed_pval[row][column] = corr_output['p-val']


    df = pd.DataFrame(data=observed_corr, index=labels, columns=labels)
    return df

In [7]:
    def remove_missing_data(conn, behav_vector):
        missing_data_indices = []
        for subject in range(conn.shape[0]):
            if (conn[subject] == 0).all():
                print(f'Found missing data for subject {behav_data["Subject"][subject]}')
                missing_data_indices.append(subject)

        if missing_data_indices:
            conn_missing_data_removed = np.delete(conn, missing_data_indices, axis=0)
            behav_vector_missing_data_removed = [val for ind, val in enumerate(behav_vector) if ind not in missing_data_indices]
            print(f'Missing data is removed.')
            return conn_missing_data_removed, behav_vector_missing_data_removed
        else:
            print('No missing data found.')
            return conn, behav_vector

In [8]:
def plot_corr_heatmap(df, threshold, title):
    mask = np.zeros_like(df.values)
    mask[abs(df.values) <= threshold] = True
    mask[df.values == 0] = True

    fig, ax = plt.subplots(figsize=(16, 8))
    ax.set_title(f'| r | > {threshold}', size=12)
    sns.set(style='white', font_scale=1.)
    sns.heatmap(df,
                mask=mask,
                square=True,
                vmin=-.6,
                vmax=.6,
                cbar=True,
                annot=False,
                linewidths=.01,
                linecolor='lightgrey',
                cmap=Batlow_20.mpl_colormap, #"RdBu_r",
                cbar_kws={"shrink": .5},
                xticklabels = 1,
                yticklabels=1,
                ax=ax)
    ax.plot([0, 1], [1, 0], transform=ax.transAxes, c='black', alpha=0.1)
    cax = plt.gcf().axes[0]
    cax.tick_params(labelsize=12)
    fig.tight_layout()
    # uncomment to save figure
#     plt.savefig(
#       os.path.join(f'{title.replace(" ", "_")}.png'),
#       dpi=200,
#       transparent=False)
    plt.show()
    plt.close(fig)

In [9]:
def conn_array_to_df(conn):
    if conn.shape[-1] > len(list(ROI.keys())):
        labels = []
        for channels in list(ROI.values()):
            labels.extend(channels)
    else:
        labels = list(ROI.keys())
    return pd.DataFrame(data=conn, index=labels, columns=labels)

In [10]:
# Define v_max based on biggest value in conn matrix
max_val = []
for condition in ['rs', 'asrt_merged']:
    for period in periods[condition]:
        conn_index = periods[condition].index(period)
        for freq in freqs:
            freq_index = freqs.index(freq)
            if condition == 'rs':
                rs_conn = rs_ch_conn[:, conn_index, freq_index]
                # check and remove missing data before averaging
                rs_rm, _ = remove_missing_data(rs_conn, [])
                rs_avg = rs_rm.mean(axis=0)
                print(condition, period, freq, rs_avg.max())
                max_val.append(rs_avg.max())
            else:
                asrt_conn = asrt_ch_conn[:, conn_index, freq_index]
                asrt_rm, _ = remove_missing_data(asrt_conn, [])
                asrt_avg = asrt_rm.mean(axis=0)
                print(condition, period, freq, asrt_avg.max())
                max_val.append(asrt_avg.max())

print(f'v_max={max(max_val)}')
                
        

No missing data found.
rs ny_1 delta 0.29320570789444755
No missing data found.
rs ny_1 theta 0.2529778637830557
No missing data found.
rs ny_1 alpha 0.31275705146305693
No missing data found.
rs ny_1 beta 0.16971394236346185
No missing data found.
rs ny_1 gamma 0.37060460459754124
Found missing data for subject 2
Missing data is removed.
rs ny_2 delta 0.2991877656550653
Found missing data for subject 2
Missing data is removed.
rs ny_2 theta 0.27253114477822055
Found missing data for subject 2
Missing data is removed.
rs ny_2 alpha 0.3416156956864872
Found missing data for subject 2
Missing data is removed.
rs ny_2 beta 0.1853401458320333
Found missing data for subject 2
Missing data is removed.
rs ny_2 gamma 0.3316583674997671
No missing data found.
rs ny_3 delta 0.2845204552346023
No missing data found.
rs ny_3 theta 0.2550901702565312
No missing data found.
rs ny_3 alpha 0.2978306256462764
No missing data found.
rs ny_3 beta 0.1752267747530452
No missing data found.
rs ny_3 gamma 0.

In [11]:
asrt_avg.shape

(47, 47)

In [12]:
# Calculate channel level correlations with behavior data

freq_name = 'delta'
rs_1_index = periods['rs'].index('ny_1')
freq_index = freqs.index(freq_name)
rs_1_delta = rs_ch_conn[:, rs_1_index, freq_index, ...]

# check and remove missing data before averaging
rs_1_delta, _ = remove_missing_data(rs_1_delta, [])
rs_1_delta_avg = rs_1_delta.mean(axis=0)

# rs_1_delta_avg = conn_array_to_df(rs_1_delta_avg)

# index = pd.MultiIndex.from_tuples([(group[0],subgroup) for group in ROI.items() for subgroup in group[1]],
#    names=['group', 'subgroup'])

# rs_1_delta_avg = rs_1_delta_avg.set_index(index)

z = rs_1_delta_avg
i_upper = np.triu_indices(47, 1)
z[i_upper] = z.T[i_upper]

No missing data found.


In [13]:
rs_1_delta_avg.max()

  **kwargs


0.29320570789444755

In [14]:
z = conn_array_to_df(z)
z

Unnamed: 0,Unnamed: 1,Unnamed: 2,Frontal Left,Unnamed: 4,Unnamed: 5,Unnamed: 6,Unnamed: 7,Frontal Central,Unnamed: 9,Unnamed: 10,Unnamed: 11,Unnamed: 12,Unnamed: 13,Frontal Right,Unnamed: 15,Unnamed: 16,Unnamed: 17,Temporal Left,Unnamed: 19,Unnamed: 20,Central,Unnamed: 22,Unnamed: 23,Temporal Right,Unnamed: 25,Unnamed: 26,Unnamed: 27,Parietal Left,Unnamed: 29,Unnamed: 30,Unnamed: 31,Unnamed: 32,Parietal Central,Unnamed: 34,Unnamed: 35,Unnamed: 36,Unnamed: 37,Unnamed: 38,Parietal Right,Unnamed: 40,Unnamed: 41,Unnamed: 42,Occipital Left,Unnamed: 44,Unnamed: 45,Occipital Right,Unnamed: 47
,0.0,0.229187,0.207226,0.203064,0.201569,0.192914,0.179792,0.180602,0.196673,0.184076,0.174836,0.162094,0.12291,0.1135,0.157502,0.118373,0.179808,0.174686,0.165636,0.202869,0.185357,0.142109,0.126875,0.130488,0.13217,0.174785,0.188852,0.15873,0.174642,0.183502,0.190854,0.180521,0.165333,0.183264,0.177877,0.168609,0.148198,0.140582,0.155739,0.15035,0.146253,0.188067,0.170204,0.181339,0.164638,0.156617,0.166781
,0.229187,0.0,0.160158,0.163974,0.160909,0.163777,0.147253,0.14593,0.166444,0.156181,0.147887,0.130425,0.121239,0.128356,0.132716,0.12689,0.2412,0.204986,0.181929,0.160083,0.16803,0.12617,0.16285,0.166782,0.153227,0.159648,0.157496,0.148179,0.15449,0.15732,0.170724,0.167534,0.153993,0.166817,0.166401,0.15566,0.136168,0.139343,0.146172,0.147741,0.142123,0.162856,0.150662,0.160062,0.155273,0.146296,0.156769
Frontal Left,0.207226,0.160158,0.0,0.178765,0.159838,0.167033,0.145261,0.138358,0.170807,0.154284,0.145824,0.131259,0.138036,0.153572,0.129258,0.141192,0.258593,0.220731,0.187654,0.160678,0.168495,0.132525,0.19117,0.189064,0.168254,0.165553,0.158751,0.148045,0.153008,0.16075,0.172053,0.168789,0.157861,0.170119,0.170939,0.160281,0.145851,0.150869,0.154667,0.155237,0.150099,0.169892,0.156126,0.169594,0.165329,0.15889,0.169841
,0.203064,0.163974,0.178765,0.0,0.21398,0.178276,0.168283,0.16123,0.199575,0.179364,0.171276,0.144215,0.11582,0.124852,0.14643,0.113332,0.263685,0.181568,0.135417,0.209794,0.187839,0.128145,0.147762,0.138118,0.124569,0.147457,0.182206,0.128009,0.156306,0.17736,0.198299,0.184211,0.164961,0.184416,0.177688,0.162357,0.128845,0.114849,0.145152,0.133619,0.121388,0.178984,0.155788,0.174333,0.159793,0.142785,0.161417
,0.201569,0.160909,0.159838,0.21398,0.0,0.148565,0.133061,0.128325,0.16341,0.14244,0.131168,0.12572,0.130195,0.15513,0.120592,0.141113,0.270527,0.236895,0.182744,0.150677,0.158957,0.11953,0.195136,0.189058,0.160076,0.156849,0.134289,0.128074,0.130665,0.135366,0.159719,0.160412,0.142377,0.154301,0.157255,0.144121,0.126181,0.13852,0.133464,0.131824,0.129535,0.155058,0.136886,0.155902,0.14594,0.138597,0.149058
,0.192914,0.163777,0.167033,0.178276,0.148565,0.0,0.129525,0.111863,0.160869,0.151262,0.137965,0.114734,0.139757,0.166031,0.119748,0.152799,0.254248,0.228463,0.189797,0.154464,0.169496,0.137718,0.213452,0.205657,0.174446,0.16942,0.15213,0.154301,0.152043,0.151921,0.165068,0.170114,0.15936,0.161961,0.168824,0.157705,0.143792,0.156366,0.15376,0.154,0.153869,0.164159,0.158159,0.167769,0.162559,0.159925,0.165828
,0.179792,0.147253,0.145261,0.168283,0.133061,0.129525,0.0,0.119671,0.149989,0.164794,0.149094,0.125585,0.151184,0.181858,0.136304,0.16583,0.238137,0.217567,0.177988,0.151172,0.17097,0.148389,0.226527,0.213838,0.179391,0.164998,0.155117,0.156993,0.153186,0.153614,0.164125,0.17136,0.167016,0.161937,0.16809,0.163752,0.152159,0.163839,0.160066,0.159992,0.160405,0.16474,0.16151,0.169227,0.164199,0.160609,0.167428
Frontal Central,0.180602,0.14593,0.138358,0.16123,0.128325,0.111863,0.119671,0.0,0.127587,0.149853,0.153341,0.138357,0.171416,0.207114,0.142737,0.188659,0.226473,0.207636,0.173406,0.14216,0.166494,0.148675,0.247792,0.224987,0.181823,0.159672,0.149029,0.152897,0.150307,0.148858,0.153253,0.163032,0.16152,0.154664,0.163886,0.158391,0.150589,0.16363,0.154172,0.153463,0.153611,0.16053,0.158346,0.165573,0.157969,0.151555,0.162116
,0.196673,0.166444,0.170807,0.199575,0.16341,0.160869,0.149989,0.127587,0.0,0.125463,0.115714,0.126529,0.14375,0.173625,0.112886,0.156597,0.258151,0.23236,0.18238,0.144864,0.158046,0.127791,0.218602,0.208918,0.170812,0.166761,0.140805,0.137182,0.135584,0.134373,0.147861,0.157546,0.14554,0.14287,0.153813,0.142076,0.132088,0.150258,0.137637,0.136158,0.136823,0.148818,0.141165,0.154012,0.145383,0.141232,0.150117
,0.184076,0.156181,0.154284,0.179364,0.14244,0.151262,0.164794,0.149853,0.125463,0.0,0.127455,0.146657,0.161042,0.192651,0.135616,0.176943,0.240155,0.216038,0.17519,0.14063,0.156298,0.144235,0.232526,0.220311,0.176812,0.162923,0.140434,0.142025,0.139294,0.132223,0.136091,0.145029,0.144873,0.135258,0.147253,0.14012,0.141178,0.160132,0.140101,0.142212,0.143118,0.146384,0.144186,0.152152,0.145204,0.14237,0.149402


In [15]:
new_row_indices = [int(val+ind) for ind, val in enumerate(np.cumsum([len(roi) for roi in ROI.values()]))]
new_row_indices

[5, 12, 18, 22, 26, 30, 36, 43, 49, 53, 57]

In [16]:
#df.insert(2, 'new-col', data)
#new_row_indices =  list(np.cumsum([len(roi) if len(roi) == 5 else len(roi)+1 for roi in ROI.values()]))

#rs_1_delta_avg.insert(5, f'new-5', np.zeros((len(channels_in_order),)))
rs_1_delta_avg.insert(12, f'new-12', np.zeros((len(channels_in_order),)))

AttributeError: 'numpy.ndarray' object has no attribute 'insert'

In [None]:
np.insert(rs_1_delta_avg, np.cumsum([len(roi) for roi in list(ROI.values())[:-1]]), 0.0, axis=1)

In [None]:
# RUN FROM HERE
# needs the old ROI size
y = np.insert(rs_1_delta_avg, np.cumsum([len(roi) for roi in list(ROI.values())[:-1]]), 0.0, axis=1)
z = np.insert(y, np.cumsum([len(roi) for roi in list(ROI.values())[:-1]]), 0.0, axis=0)

In [None]:
z.shape

In [None]:
i_upper = np.triu_indices(57, 1)
z[i_upper] = z.T[i_upper]

In [None]:
ROI = {'Frontal \nLeft': ['F7', 'F5', 'F3', 'FC5', 'FC3'],
       'new1': ['new1'],
       'Frontal \nCentral': ['F1', 'Fz', 'F2', 'FC1', 'FCz', 'FC2'],
       'new2': ['new2'],
       'Frontal \nRight': ['F4', 'F6', 'F8', 'FC4', 'FC6'],
       'new3': ['new3'],
       'Temporal \nLeft': ['FT7', 'T7', 'TP7'],
       'new4': ['new4'],
       'Central': ['C3', 'Cz', 'C4'],
       'new5': ['new5'],
       'Temporal \nRight': ['FT8', 'T8', 'TP8'],
       'new6': ['new6'],
       'Parietal \nLeft': ['CP5', 'CP3', 'P7', 'P5', 'P3'],
       'new7': ['new7'],
       'Parietal \nCentral': ['CP1', 'CPz', 'CP2', 'P1', 'Pz', 'P2'],
       'new8': ['new8'],
       'Parietal \nRight': ['CP4', 'CP6', 'P4', 'P6', 'P8'],
        'new9': ['new9'],
       'Occipital \nLeft': ['PO3', 'PO7', 'O1'],
       'new10': ['new10'],
       'Occipital \nRight': ['PO4', 'PO8', 'O2']}

    
channels_in_order = []
for channels in list(ROI.values()):
    channels_in_order.extend(channels)

z = pd.DataFrame(data=z, index=channels_in_order, columns=channels_in_order)
index = pd.MultiIndex.from_tuples([(group[0],subgroup) for group in ROI.items() for subgroup in group[1]],
   names=['group', 'subgroup'])

z = z.set_index(index)

In [None]:
df = uncorrected_spearmanr(rs_1_delta, behav_data['s1_hldif_rt'], 'spearman')
plot_corr_heatmap(df, 0.2, '')

In [None]:
# check and remove missing data before averaging
# DELTA
rs_rm, _ = remove_missing_data(rs_ch_conn[:, 0, 0], [])
rs_1_delta_avg = rs_rm.mean(axis=0)
i_upper = np.triu_indices(47, 1)
rs_1_delta_avg[i_upper] = rs_1_delta_avg.T[i_upper]
rs_1_delta_avg_df = conn_array_to_df(rs_1_delta_avg)

asrt_rm, _ = remove_missing_data(asrt_ch_conn[:, 0, 0], [])
asrt_delta_avg = asrt_rm.mean(axis=0)
asrt_delta_avg[i_upper] = asrt_delta_avg.T[i_upper]
asrt_delta_avg_df = conn_array_to_df(asrt_delta_avg)

rs_rm, _ = remove_missing_data(rs_ch_conn[:, 1, 0], [])
rs_2_delta_avg = rs_rm.mean(axis=0)
rs_2_delta_avg[i_upper] = rs_2_delta_avg.T[i_upper]
rs_2_delta_avg_df = conn_array_to_df(rs_2_delta_avg)


# THETA
rs_rm, _ = remove_missing_data(rs_ch_conn[:, 0, 1], [])
rs_1_theta_avg = rs_rm.mean(axis=0)
i_upper = np.triu_indices(47, 1)
rs_1_theta_avg[i_upper] = rs_1_theta_avg.T[i_upper]
rs_1_theta_avg_df = conn_array_to_df(rs_1_theta_avg)

asrt_rm, _ = remove_missing_data(asrt_ch_conn[:, 0, 1], [])
asrt_theta_avg = asrt_rm.mean(axis=0)
asrt_theta_avg[i_upper] = asrt_theta_avg.T[i_upper]
asrt_theta_avg_df = conn_array_to_df(asrt_theta_avg)

rs_rm, _ = remove_missing_data(rs_ch_conn[:, 1, 1], [])
rs_2_theta_avg = rs_rm.mean(axis=0)
rs_2_theta_avg[i_upper] = rs_2_theta_avg.T[i_upper]
rs_2_theta_avg_df = conn_array_to_df(rs_2_theta_avg)

# ALPHA
rs_rm, _ = remove_missing_data(rs_ch_conn[:, 0, 2], [])
rs_1_alpha_avg = rs_rm.mean(axis=0)
i_upper = np.triu_indices(47, 1)
rs_1_alpha_avg[i_upper] = rs_1_alpha_avg.T[i_upper]
rs_1_alpha_avg_df = conn_array_to_df(rs_1_alpha_avg)

asrt_rm, _ = remove_missing_data(asrt_ch_conn[:, 0, 2], [])
asrt_alpha_avg = asrt_rm.mean(axis=0)
asrt_alpha_avg[i_upper] = asrt_alpha_avg.T[i_upper]
asrt_alpha_avg_df = conn_array_to_df(asrt_alpha_avg)

rs_rm, _ = remove_missing_data(rs_ch_conn[:, 1, 2], [])
rs_2_alpha_avg = rs_rm.mean(axis=0)
rs_2_alpha_avg[i_upper] = rs_2_alpha_avg.T[i_upper]
rs_2_alpha_avg_df = conn_array_to_df(rs_2_alpha_avg)

# BETA
rs_rm, _ = remove_missing_data(rs_ch_conn[:, 0, 3], [])
rs_1_beta_avg = rs_rm.mean(axis=0)
i_upper = np.triu_indices(47, 1)
rs_1_beta_avg[i_upper] = rs_1_beta_avg.T[i_upper]
rs_1_beta_avg_df = conn_array_to_df(rs_1_beta_avg)

asrt_rm, _ = remove_missing_data(asrt_ch_conn[:, 0, 3], [])
asrt_beta_avg = asrt_rm.mean(axis=0)
asrt_beta_avg[i_upper] = asrt_beta_avg.T[i_upper]
asrt_beta_avg_df = conn_array_to_df(asrt_beta_avg)

rs_rm, _ = remove_missing_data(rs_ch_conn[:, 1, 3], [])
rs_2_beta_avg = rs_rm.mean(axis=0)
rs_2_beta_avg[i_upper] = rs_2_beta_avg.T[i_upper]
rs_2_beta_avg_df = conn_array_to_df(rs_2_beta_avg)


# GAMMA
rs_rm, _ = remove_missing_data(rs_ch_conn[:, 0, 4], [])
rs_1_gamma_avg = rs_rm.mean(axis=0)
i_upper = np.triu_indices(47, 1)
rs_1_gamma_avg[i_upper] = rs_1_gamma_avg.T[i_upper]
rs_1_gamma_avg_df = conn_array_to_df(rs_1_gamma_avg)

asrt_rm, _ = remove_missing_data(asrt_ch_conn[:, 0, 4], [])
asrt_gamma_avg = asrt_rm.mean(axis=0)
asrt_gamma_avg[i_upper] = asrt_gamma_avg.T[i_upper]
asrt_gamma_avg_df = conn_array_to_df(asrt_gamma_avg)

rs_rm, _ = remove_missing_data(rs_ch_conn[:, 1, 4], [])
rs_2_gamma_avg = rs_rm.mean(axis=0)
rs_2_gamma_avg[i_upper] = rs_2_gamma_avg.T[i_upper]
rs_2_gamma_avg_df = conn_array_to_df(rs_2_gamma_avg)

In [None]:
%matplotlib inline
fig, ax = plt.subplots(nrows=5, ncols=3, sharex=True, sharey=True, 
                       figsize=(8.27, 11.69), dpi=200, constrained_layout=True, gridspec_kw={'hspace':0.0001, 'wspace':0.0001, 'left':0.3, 'right':0.7}) # dpi=200, #gridspec_kw = {'wspace':0, 'hspace':0, 'left':None, 'bottom':None, 'right':None, 'top':None}
# index = pd.MultiIndex.from_tuples([(group[0],subgroup) for group in ROI.items() for subgroup in group[1]],
#    names=['group', 'subgroup'])
# rs_1_delta_avg_df = rs_1_delta_avg_df.set_index(index)
ax[0,0].set_title('RS1',fontweight="bold", size=10)
ax_conn_heatmap(rs_1_delta_avg_df, ax[0, 0])
ax[0,0].set_ylabel('Delta',fontweight="bold", size=9)
ax[0,1].set_title('ASRT',fontweight="bold", size=10)
ax_conn_heatmap(asrt_delta_avg_df, ax[0, 1])
ax[0,2].set_title('RS2',fontweight="bold", size=10)
ax_conn_heatmap(rs_2_delta_avg_df, ax[0, 2])

ax_conn_heatmap(rs_1_theta_avg_df, ax[1,0])
ax[1,0].set_ylabel('Theta',fontweight="bold", size=9)
ax_conn_heatmap(asrt_theta_avg_df, ax[1,1])
ax_conn_heatmap(rs_2_theta_avg_df, ax[1,2])

ax_conn_heatmap(rs_1_alpha_avg_df, ax[2,0])
ax[2,0].set_ylabel('Alpha',fontweight="bold", size=9)
ax_conn_heatmap(asrt_alpha_avg_df, ax[2,1])
ax_conn_heatmap(rs_2_alpha_avg_df, ax[2,2])

ax_conn_heatmap(rs_1_beta_avg_df, ax[3,0])
ax[3,0].set_ylabel('Beta',fontweight="bold", size=9)
ax_conn_heatmap(asrt_beta_avg_df, ax[3,1])
ax_conn_heatmap(rs_2_beta_avg_df, ax[3,2])

ax_conn_heatmap(rs_1_gamma_avg_df, ax[4,0])
ax[4,0].set_ylabel('Gamma',fontweight="bold", size=9)
ax_conn_heatmap(asrt_gamma_avg_df, ax[4,1])
ax_conn_heatmap(rs_2_gamma_avg_df, ax[4,2])

#fig.tight_layout()
#plt.subplot_tool()
fig.savefig("wpli_ch_conn_bw.eps", papertype = 'a4', orientation = 'portrait', format = 'eps', bbox_inches='tight', dpi=220)
fig.show()

In [None]:
def ax_conn_heatmap(df, ax):
    mask = np.zeros_like(df.values)
    mask[df.values == 0] = True
    sns.heatmap(df,
                mask=mask,
                square=True,
                vmin=0.,
                vmax=.4,
                cbar=False,
                annot=False,
                linewidths=.01,
                linecolor='white',
                cmap=GrayC_3.mpl_colormap,
                cbar_kws={"shrink": .5},
                xticklabels=1,
                yticklabels=1,
                ax=ax)
#     ax.set_xticklabels('')
#     ax.set_xlabel('')
    ax.set_xticklabels(channels_in_order, fontsize=6, rotation=90, ha='center') #channels_in_order
    ax.set_yticklabels(channels_in_order, fontsize=6) #channels_in_order
#     for color, tick in zip(colors_names, ax.yaxis.get_major_ticks()):
#         tick.label1.set_color(color) #set the color
#     for color, tick in zip(colors_names, ax.xaxis.get_major_ticks()):
#         tick.label1.set_color(color) #set the color
#     ax.set_ylabel('')
    #ax.axis('off')
    ax.tick_params(tick1On=False) #"for left and bottom ticks"
    ax.tick_params(tick2On=False) #"for right and top ticks, which are off by default"
    plt.axvline(x=0, linewidth=1, c='white')
    plt.axhline(y=0, linewidth=1, c='white')
    #ax3.set(title='Nice', ylabel='$C_y$')
    #plt.setp(ax.get_xticklabels(), visible=False)
    #plt.setp(ax.get_yticklabels(), visible=False)
    #ax.tick_params(axis='both', which='both', length=0)
    #ax.get_xaxis().set_visible(False)
    #ax.get_yaxis().set_visible(False)
    borders = np.cumsum([len(roi) for roi in ROI.values()])
    for border in borders:
        ax.axvline(x=border, linewidth=1, c='white')
        ax.axhline(y=border, linewidth=1, c='white')
    
#     #label_group_bar_table(ax, df)
#     cax = plt.gcf().axes[0]
#     cax.tick_params(labelsize=12)


In [None]:
colors_map = {'Frontal Left': '#0f4452', #'tab:orange', 
              'Frontal Central': '#0f4452',
              'Frontal Right': '#0f4452',
              'Temporal Left': '#18a59a', #'tab:purple',
              'Central': '#dfa277',#'tab:green',
              'Temporal Right': '#18a59a',
              'Parietal Left': '#f3505d',#'tab:brown',
              'Parietal Central': '#f3505d',
              'Parietal Right': '#f3505d',
              'Occipital Left': '#c70951', #'tab:pink',
              'Occipital Right': '#c70951'}

In [None]:
colors_names = []
for channel in channels_in_order:
    if channel in colors_map:
        colors_names.append(colors_map[channel])
    else:
        colors_names.append('white')

In [None]:
len(colors_names) == len(channels_in_order)

In [None]:
def label_group_bar_table(ax, df):
    ypos = -.05
    scale = 1./df.index.size
    for level in range(df.index.nlevels)[::-1]:
        pos = 0
        for label, rpos in label_len(df.index,level):
            lxpos = (pos + .5 * rpos)*scale
            if 'new' in label:
                pos += rpos
                continue
            if label not in list(ROI.keys()):
                ax.text(lxpos, ypos, label, ha='center', va='bottom', transform=ax.transAxes, fontsize=6)
                #ax.text(ypos, lxpos, label, ha='center', va='bottom', transform=ax.transAxes, fontsize=12)
            else:
                ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes, fontsize=5)
                #ax.text(ypos, lxpos, label, ha='center', transform=ax.transAxes, fontsize=14)
                #add_line(ax, pos*scale, ypos)
            pos += rpos
        #add_line(ax, pos*scale , ypos)
        ypos -= .05
        
def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='lightgrey', linewidth=0.5)
    line.set_clip_on(False)
    ax.add_line(line)

def label_len(my_index,level):
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for i in g)) for k,g in groupby(labels)]

In [None]:
from matplotlib.patches import Rectangle
def plot_conn_heatmap(df, title=''):
    mask = np.zeros_like(df.values)
    #mask[abs(df.values) <= threshold] = True
    mask[df.values == 0] = True

    fig, ax = plt.subplots(figsize=(4, 4), dpi=220)
    #ax.set_title(f'| r | > {threshold}', size=12)
    sns.heatmap(df,
                mask=mask,
                square=True,
                vmin=0.,
                vmax=.4,
                cbar=True,
                annot=False,
                linewidths=.01,
                linecolor='white',
                cmap=GrayC_3.mpl_colormap,
                #cbar_kws={"shrink": .5},
                xticklabels = 1,
                yticklabels=1,
                ax=ax)
    #ax.plot([0, 1], [1, 0], transform=ax.transAxes, c='black', alpha=0.1)
    ax.set_xticklabels('')
    ax.set_xlabel('')
    ax.set_yticklabels(channels_in_order)
    ax.set_ylabel('')
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=8)
#     row = [i for i, x in enumerate(df.index) if x in ['F7', 'F5', 'F3', 'FC5', 'FC3']]
#     col = [i for i, x in enumerate(df.columns) if x in ['F7', 'F5', 'F3', 'FC5', 'FC3']][0]

    #ax.add_patch(Rectangle((0, 0), 5, 47, edgecolor='blue', fill=False, lw=2.5))
    #ax.add_patch(Rectangle((0, 0), 47, 5, edgecolor='blue', fill=False, lw=2.5))
    plt.axvline(x=0, linewidth=2.3, c='white')
    plt.axhline(y=0, linewidth=2.3, c='white')
    borders = np.cumsum([len(roi) for roi in list(ROI.values())])
    for border in borders:
        plt.axvline(x=border, linewidth=2.3, c='white')
        plt.axhline(y=border, linewidth=2.3, c='white')
    
    #ax.add_patch(Rectangle((0, 5), 6, 47, edgecolor='green', fill=False, lw=2))
    #ax.add_patch(Rectangle((0, 5), 47, 6, edgecolor='green', fill=False, lw=2))
    #label_group_bar_table(ax, df)
    cax = plt.gcf().axes[0]
    cax.tick_params(labelsize=12)
    fig.tight_layout()
    # uncomment to save figure
    plt.savefig(
      os.path.join(f'gray_colorbar.eps'), dpi=220, transparent=True)
    plt.show()
    plt.close(fig)

In [None]:
[i for i, x in enumerate(df.index) if x in ['F7', 'F5', 'F3', 'FC5', 'FC3']][0]

In [None]:
def label_group_bar_table(ax, df):
    ypos = -.05
    scale = 1./df.index.size
    for level in range(df.index.nlevels)[::-1]:
        pos = 0
        for label, rpos in label_len(df.index,level):
            lxpos = (pos + .5 * rpos)*scale
            if 'new' in label:
                pos += rpos
                continue
            if label not in list(ROI.keys()):
                print('skipping')
                #ax.text(lxpos, ypos, label, ha='center', va='bottom', transform=ax.transAxes, fontsize=12, rotation=90)
                #ax.text(ypos, lxpos, label, ha='center', va='bottom', transform=ax.transAxes, fontsize=12)
            else:
                ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes, fontsize=14)
                #ax.text(ypos, lxpos, label, ha='center', transform=ax.transAxes, fontsize=14)
                #add_line(ax, pos*scale, ypos)
            pos += rpos
        #add_line(ax, pos*scale , ypos)
        ypos -= .05
        
def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='lightgrey', linewidth=0.5)
    line.set_clip_on(False)
    ax.add_line(line)

def label_len(my_index,level):
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for i in g)) for k,g in groupby(labels)]

In [None]:
plot_conn_heatmap(rs_1_delta_avg_df)

In [None]:
[i for i, x in enumerate(df.index) if x in ['F7', 'F5', 'F3', 'FC5', 'FC3']]
#ax.imshow(data, cmap=Batlow_10.mpl_colormap)

In [None]:
Davos_3.mpl_colormap

In [None]:
[len(roi) for roi in list(ROI.values())]

In [None]:
for i, x in enumerate(df.index):
    print(i, x)

In [None]:
z.index