# Set-up

## Installing libraries and libcudnn8

In [None]:
import os

# Google Drive file id of the MI_EEG_ClassMeth archive (the package that
# provides FeatExtraction.TimeFrequencyRpr used below).
FILEID = "1h4FWB5fw7sBDCSM-EENK1UadqKSCqg24"

contents = os.listdir(os.getcwd())

# Download and unzip the archive only if the folder is not already present.
# The inner wget saves cookies and extracts the `confirm` token from Google
# Drive's response via sed; the outer wget reuses both to fetch the file.
# `$FILEID` is expanded by IPython from the Python variable above.
if 'MI_EEG_ClassMeth' not in contents:
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$FILEID -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$FILEID -O MI_EEG_ClassMeth.zip && rm -rf /tmp/cookies.txt
    !unzip MI_EEG_ClassMeth.zip
else:
    print("MI_EEG_ClassMeth already downloaded!")

# Pin cuDNN 8.1.1 for CUDA 11.2, then install the Python dependencies.
!apt-get install --allow-change-held-packages libcudnn8=8.1.1.33-1+cuda11.2 -y
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.databases
!pip install mne
!pip install pickle5
!pip install gcpds.utils
!pip install scikeras[tensorflow]
!pip install tf_keras_vis

## Get Weights and Scores

### GFC

In [None]:
# Google Drive file id of the GFC_Motor_500_l1l2_6gs archive; the Grad-CAM++
# cell below reads per-subject `Subject{n}_weights.h5` and `Subject{n}.p`
# files from the extracted folder.
FILEID = "1sCzHf_1XFS-28wlxE9qSib1qJIy4CZ3k"

contents = os.listdir(os.getcwd())

# NOTE(review): this checks for the .zip, not the extracted folder, so a
# workspace holding only the folder downloads the archive again. Also, the
# GFC_Motor_256_Gamma60Hz folder used by later cells is not downloaded here.
if 'GFC_Motor_500_l1l2_6gs.zip' not in contents:
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$FILEID -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$FILEID -O GFC_Motor_500_l1l2_6gs.zip && rm -rf /tmp/cookies.txt
    !unzip GFC_Motor_500_l1l2_6gs.zip

## Import libraries

In [None]:
# freq filter 
from MI_EEG_ClassMeth.FeatExtraction import TimeFrequencyRpr

#EEG montage
from gcpds.utils.mne_handler import get_best_montage

# general
import numpy as np
from scipy.signal import resample
import pickle5 as pickle
import warnings
import mne
from time import time
warnings.filterwarnings('ignore')

# tensorflow
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Dropout, Conv2D, AveragePooling2D, BatchNormalization, Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.layers import Layer
from tensorflow.keras.regularizers import L1L2

from sklearn.metrics import cohen_kappa_score

from scipy.stats import ks_2samp

## Define functions

In [None]:
def kappa(y_true, y_pred):
    """Return Cohen's kappa between two (n_samples, n_classes) score arrays.

    Each row of ``y_true`` and ``y_pred`` (e.g. one-hot targets and predicted
    probabilities) is reduced to its class index with ``argmax`` over axis 1
    before the agreement is scored.
    """
    true_labels = np.argmax(y_true, axis=1)
    pred_labels = np.argmax(y_pred, axis=1)
    return cohen_kappa_score(true_labels, pred_labels)

In [None]:
def multi_sort(*args, reverse=True):
    """Sort several parallel sequences together.

    The sequences are zipped into rows and the rows are sorted with ordinary
    tuple comparison: primarily by the first sequence, with ties broken by the
    following ones. The sorted rows are then split back into columns.

    Parameters
    ----------
    *args : iterables
        Parallel sequences. ``zip`` truncates them to the shortest one.
    reverse : bool
        Sort in descending order when True (the default).

    Returns
    -------
    tuple of list
        One sorted list per input sequence. Empty inputs yield one empty list
        per input, so ``a, b = multi_sort([], [])`` still unpacks.
    """
    rows = sorted(zip(*args), reverse=reverse)
    if not rows:
        # zip(*[]) yields nothing, which would return () instead of one empty
        # list per input sequence.
        return tuple([] for _ in args)
    return tuple(list(column) for column in zip(*rows))

In [None]:
def ks_connetivity_cal(A, y):
    """Run a two-sided two-sample Kolmogorov-Smirnov test for every feature.

    Parameters
    ----------
    A : np.ndarray
        Samples along axis 0; column ``d`` (``A[:, d]``) is the feature tested,
        for ``d`` in ``range(A.shape[-1])``.
    y : np.ndarray
        Per-sample labels aligned with axis 0 of ``A``. Only samples labelled
        0 and 1 are compared; other labels are ignored.

    Returns
    -------
    ks, pvalue : np.ndarray
        KS statistic and two-sided p-value per feature, each of length
        ``A.shape[-1]``.
    """
    n_features = A.shape[-1]
    ks = np.zeros(n_features)
    pvalue = np.zeros(n_features)
    # Split by class once instead of re-masking A for every feature.
    class_0 = A[y == 0]
    class_1 = A[y == 1]
    for d in range(n_features):
        # 'two-sided' is the documented spelling; the default method ('auto')
        # matches the previous deprecated `mode='auto'` argument.
        ks[d], pvalue[d] = ks_2samp(class_0[:, d], class_1[:, d], alternative='two-sided')
    return ks, pvalue

## PAIN dataset

In [None]:
def load_PAIN(db,sbj,f_bank,vwt,new_fs):
    """Load and preprocess one subject of the BMOP motor pain recordings.

    Reads ``{db}BMOP_Motor_S{NN}.pkl`` (``NN`` = subject number, zero-padded to
    two digits). It then drops the channels missing from the best-matching
    standard MNE montage, plus the last remaining channel of ``X``, and wraps
    the trials in an ``mne.EpochsArray``. Finally it applies
    ``TimeFrequencyRpr`` filtering and resamples to ``new_fs`` if needed.

    Parameters
    ----------
    db : str
        Path prefix, string-concatenated with the file name (include the
        trailing separator).
    sbj : int
        Subject number.
    f_bank, vwt :
        Forwarded to ``TimeFrequencyRpr``.
    new_fs : float
        Target sampling rate; resampling only happens if it differs from the
        file's ``fs``.

    Returns
    -------
    X : np.ndarray
        Filtered (and possibly resampled) trials, squeezed.
    y : np.ndarray
        Labels shifted by -2, so event code 2 ('pain/high') becomes 0 and 3
        ('resting') becomes 1.
    age, sex : np.ndarray
        Flattened metadata from the file.
    fs : float
        The file's ORIGINAL sampling rate (not ``new_fs``).
    info : mne.Info
        Channel info with the montage set; its ``sfreq`` is also the original
        rate.
    """
    channels_names = np.array(['Fp1','Fp2',
                      'F3','F4','C3','C4','P3','P4','O1','O2','F7','F8',
                      'T7','T8','P7','P8','Fz','Cz','Pz','Oz',
                      'FC1','FC2','CP1','CP2','FC5','FC6','CP5','CP6',
                      'TP9','TP10','LE','RE','P1','P2','C1','C2',
                      'FT9','FT10','AF3','AF4','FC3','FC4','CP3','CP4','PO3','PO4',
                      'F5','F6','C5','C6','P5','P6','PO9','Iz','FT7','FT8',
                      'TP7','TP8','PO7','PO8','Fpz','PO10','CPz','POz',
                      'Ne','Ma','Ext','ECG'])

    with open('{}BMOP_Motor_S{}.pkl'.format(db, str(sbj).zfill(2)), 'rb') as f:
        data = pickle.load(f)

    X = data['X']  # trials, channels, time
    y = data['y']
    sex = data['sex'].ravel()
    age = data['age'].ravel()
    fs = float(data['fs'])

    # Filtering happens before resampling, so it is designed at the file's fs.
    tf_repr = TimeFrequencyRpr(sfreq = fs, f_bank = f_bank, vwt = vwt)

    # Pick the standard MNE montage that best matches these channel names and
    # drop the channels it does not contain, from both the names and X.
    best_montages = get_best_montage(channels_names)
    montage = best_montages.iloc[0]['montage']
    missing_channels = np.asarray(best_montages.iloc[0]['missings channels'], dtype=str)
    # np.isin also handles an empty list of missing channels, where the
    # previous argwhere-based indexing raised IndexError.
    channels_to_remove = np.flatnonzero(np.isin(channels_names, missing_channels))

    channels_names = np.delete(channels_names, channels_to_remove)
    X = np.delete(X, channels_to_remove, axis=1)

    # The number of channels still does not match the dimension of X, so the
    # last channel of X (which has anomalous amplitudes) is discarded.
    X = X[:,:-1,:]

    info = mne.create_info(list(channels_names), sfreq=fs, ch_types="eeg")
    info.set_montage(montage)

    event_id = {
        'pain/high':2,
        'resting':3,
        }

    # One event per trial: [sample index, previous value, event code].
    # Presumably y is (n_trials, 1), given the cls[0] indexing — TODO confirm.
    events = [[i, 1, cls[0]] for i, cls in enumerate(y)]
    tmin = 0

    epochs = mne.EpochsArray(X, info, events=events, tmin=tmin, event_id=event_id)
    X = epochs.get_data()
    # Event codes 2 and 3 become class labels 0 and 1.
    y = y-2
    X = np.squeeze(tf_repr.transform(X))

    # Resample along the time axis only when the target rate differs.
    if new_fs != fs:
        X = resample(X, int((X.shape[-1]/fs)*new_fs), axis = -1)

    return X,y,age,sex,fs,info

## Define the model (Gaussian functional connectivity network)

In [None]:
class GFC(Layer):
    """Gaussian functional connectivity (GFC) layer.

    Maps inputs shaped (N, C, T, F) to one C x C connectivity matrix per
    sample and per feature map, shaped (N, F, C, C):

        A = exp(-D / (2 * 10**gammad * sigma**2))

    ``D`` holds the pairwise squared Euclidean distances between channel time
    series. ``sigma`` is the median (50th percentile) of the off-diagonal
    Euclidean distances, computed per sample and feature map. ``gammad`` is a
    single trainable scalar, initialised to 0, that rescales the bandwidth on
    a log10 scale.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        # Log10 bandwidth multiplier; 10**0 = 1 leaves sigma unchanged.
        self.gammad = self.add_weight(name = 'gammad',
                                shape = (),
                                initializer = 'zeros',
                                trainable = True)
        super().build(batch_input_shape)

    def call(self, X):
        X = tf.transpose(X, perm  = (0, 3, 1, 2)) #(N, F, C, T)
        R = tf.reduce_sum(tf.math.multiply(X, X), axis = -1, keepdims = True) #(N, F, C, 1) squared norms
        # Squared distances via ||a||^2 - 2 a.b + ||b||^2.
        D  = R - 2*tf.matmul(X, X, transpose_b = True) + tf.transpose(R, perm = (0, 1, 3, 2)) #(N, F, C, C)
        # Rounding can push the expansion slightly below zero. Clamp it so the
        # sqrt below cannot produce NaN, which would propagate through the
        # median into sigma and hence into the whole matrix.
        D = tf.maximum(D, 0.)

        ones = tf.ones_like(D[0,0,...]) #(C, C)
        mask_a = tf.linalg.band_part(ones, 0, -1) #Upper triangle including the diagonal (C, C)
        mask_b = tf.linalg.band_part(ones, 0, 0)  #Diagonal only (C, C)
        mask = tf.cast(mask_a - mask_b, dtype=tf.bool) #Strict upper triangle as a bool mask (C, C)
        triu = tf.expand_dims(tf.boolean_mask(D, mask, axis = 2), axis = -1) #(N, F, C*(C-1)/2, 1)
        sigma = tfp.stats.percentile(tf.math.sqrt(triu), 50, axis = 2, keepdims = True) #(N, F, 1, 1) median distance

        A = tf.math.exp(-1/(2*tf.pow(10., self.gammad)*tf.math.square(sigma))*D) #(N, F, C, C)
        A.set_shape(D.shape)
        return A

    def compute_output_shape(self, batch_input_shape):
        # tf.TensorShape accepts both a TensorShape (tf.keras 2) and the plain
        # tuple passed by Keras 3, which has no .as_list().
        N, C, T, F = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape([N, F, C, C])

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}


class get_triu(Layer):
    """Keep only the strict upper triangle of each C x C matrix.

    Input (N, F, C, C) -> output (N, F, C*(C-1)//2, 1): the entries above the
    diagonal, gathered with a boolean mask and given a trailing unit axis.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        super().build(batch_input_shape)

    def call(self, X):
        N, F, C, _ = X.shape
        ones = tf.ones_like(X[0,0,...]) #(C, C)
        mask_a = tf.linalg.band_part(ones, 0, -1) #Upper triangle including the diagonal (C, C)
        mask_b = tf.linalg.band_part(ones, 0, 0)  #Diagonal only (C, C)
        mask = tf.cast(mask_a - mask_b, dtype=tf.bool) #Strict upper triangle as a bool mask (C, C)
        triu = tf.expand_dims(tf.boolean_mask(X, mask, axis = 2), axis = -1) #(N, F, C*(C-1)/2, 1)

        # boolean_mask loses the static pair count; restore it (integer math).
        triu.set_shape([N, F, C * (C - 1) // 2, 1])
        return triu

    def compute_output_shape(self, batch_input_shape):
        # tf.TensorShape accepts both a TensorShape (tf.keras 2) and the plain
        # tuple passed by Keras 3, which has no .as_list().
        N, F, C, _ = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape([N, F, C * (C - 1) // 2, 1])

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}
    
    
def GFC_triu_net_avg(nb_classes: int,
                     Chans: int,
                     Samples: int,
                     l1: int = 0,
                     l2: int = 0,
                     dropoutRate: float = 0.5,
                     filters: int = 1,
                     maxnorm: float = 2.0,
                     maxnorm_last_layer: float = 0.5,
                     kernel_time_1: int = 20,
                     strid_filter_time_1: int = 1,
                     bias_spatial: bool = False) -> Model:
    """Build the GFC classifier network.

    Pipeline: temporal Conv2D -> BatchNorm -> ELU -> GFC connectivity ->
    strict upper triangle -> average over feature maps -> BatchNorm -> ELU ->
    flatten -> dropout -> Dense -> softmax.

    Parameters
    ----------
    nb_classes : number of output classes.
    Chans, Samples : input is shaped (Chans, Samples, 1).
    l1, l2 : L1/L2 penalties on the final Dense kernel.
    dropoutRate : dropout rate applied before the final Dense layer.
    filters : number of temporal Conv2D filters.
    maxnorm : max-norm constraint on the Conv2D kernel.
    maxnorm_last_layer : max-norm constraint on the final Dense kernel.
    kernel_time_1, strid_filter_time_1 : temporal kernel length and stride.
    bias_spatial : whether the Conv2D layer uses a bias.

    Returns
    -------
    Model mapping (batch, Chans, Samples, 1) to class probabilities.
    """
    inputs = Input((Chans, Samples, 1), name='Input')

    # Temporal filtering: `filters` kernels of length kernel_time_1 along time.
    temporal_conv = Conv2D(filters, (1, kernel_time_1),
                           strides=(1, strid_filter_time_1),
                           use_bias=bias_spatial,
                           name='Conv2D_1',
                           kernel_constraint=max_norm(maxnorm, axis=(0, 1, 2)))
    x = temporal_conv(inputs)
    x = BatchNormalization(epsilon=1e-05, momentum=0.1)(x)
    x = Activation('elu')(x)

    # (N, C, T', F) -> (N, F, C, C) connectivity -> (N, F, C*(C-1)/2, 1).
    x = GFC(name="gfc")(x)
    x = get_triu()(x)

    # Average the connectivity vectors over the F feature maps (axis 1).
    n_feature_maps = x.shape[1]
    x = AveragePooling2D(pool_size=(n_feature_maps, 1), strides=(1, 1))(x)
    x = BatchNormalization(epsilon=1e-05, momentum=0.1)(x)
    x = Activation('elu')(x)

    x = Flatten(name='fc')(x)
    x = Dropout(dropoutRate)(x)
    logits = Dense(nb_classes,
                   kernel_regularizer=L1L2(l1=l1, l2=l2),
                   name='logits',
                   kernel_constraint=max_norm(maxnorm_last_layer))(x)
    probabilities = Activation('softmax', name='output')(logits)

    return Model(inputs=inputs, outputs=probabilities)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.layers import Conv2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm

def EEGNet(nb_classes, Chans = 64, Samples = 128,
             dropoutRate = 0.5, kernLength = 64, F1 = 8,
             D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """Build an EEGNet-style convolutional classifier.

    The model runs in three stages:

    1. Temporal Conv2D (F1 filters, 1 x kernLength) and BatchNorm.
    2. Depthwise spatial conv over all Chans (depth multiplier D, max-norm 1),
       then BatchNorm, ELU, 1x4 average pooling and dropout.
    3. Separable conv (F2 filters, 1 x 16), then BatchNorm, ELU, 1x8 average
       pooling and dropout.

    The result is flattened and passed to a Dense layer (max-norm
    ``norm_rate``) with softmax.

    Parameters
    ----------
    nb_classes : number of output classes.
    Chans, Samples : input is shaped (Chans, Samples, 1).
    dropoutRate : rate for both dropout layers.
    kernLength : length of the first temporal kernel.
    F1, D, F2 : temporal filters, depth multiplier, separable filters.
    norm_rate : max-norm constraint on the final Dense kernel.
    dropoutType : 'Dropout' or 'SpatialDropout2D'.

    Raises
    ------
    ValueError
        If ``dropoutType`` is not one of the two supported names.
    """
    if dropoutType == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    input1   = Input(shape = (Chans, Samples, 1))

    # The input shape comes from `input1`. Also passing input_shape= to the
    # layer is redundant in the functional API and warns under Keras 3.
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   name='Conv2D_1',
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((Chans, 1), use_bias = False,
                                   name='Depth_wise_Conv2D_1',
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropoutRate)(block1)

    block2       = SeparableConv2D(F2, (1, 16),
                                   name='Separable_Conv2D_1',
                                   use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropoutRate)(block2)

    flatten      = Flatten(name = 'flatten')(block2)

    dense        = Dense(nb_classes, name = 'output',
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'out_activation')(dense)

    return Model(inputs=input1, outputs=softmax)

# Experiment

## Experiment configuration 

In [None]:
# Global random seed and number of subjects (subject ids run 1..n_subjects).
seed, n_subjects = 23, 51

## Run experiment

In [None]:
# Install the GCPDS visualisation package, which provides the topoplot and
# CircosConnectivity helpers used by the plotting cells below.
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.visualizations.git

from gcpds.visualizations.topoplots import topoplot
from gcpds.visualizations.connectivities import CircosConnectivity

In [None]:
tf.random.set_seed(seed)

subjects = np.arange(n_subjects)+1

db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

load_args = dict(db = db,
            f_bank = np.asarray([[4., 60.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 500.0)

# Load subject 1 once to obtain the data dimensions (number of channels C).
for sbj in [1]:
    print(f"\nSubject: {sbj}")

    load_args['sbj'] = sbj

    X_train, y_train, age, sex, fs, _ = load_PAIN(**load_args)

    X_train = X_train[..., np.newaxis]
    Y_train = tf.keras.utils.to_categorical(y_train,num_classes=num_class)
    y_train = np.array([y[0] for y in y_train])

full_path = os.getcwd()

# One row per analysed subject; subject 18 is skipped below, hence the -1.
gfcnet_scores = np.zeros((len(subjects)-1,8))

# One KS statistic / p-value per channel pair (strict upper triangle).
C = X_train.shape[1]
ks = np.zeros((1, C*(C-1)//2))
pvalue = np.zeros((1, C*(C-1)//2))

Nfilters_ = np.zeros((len(subjects) - 1))
kernel_time_ = np.zeros((len(subjects) - 1))

i = 0

sbjs_info = []
acc_comp = []

gfc_scores_256_path = os.path.join(os.getcwd(), 'GFC_Motor_256_Gamma60Hz')

for sbj in subjects:
    if sbj == 18:
        continue

    # Only `sex` is needed from the data here.
    load_args['sbj'] = sbj
    X_train, y_train, age, sex, fs, _ = load_PAIN(**load_args)
    y_train = np.array([y[0] for y in y_train])

    with open(os.path.join(gfc_scores_256_path, f'Subject{sbj}.p'), 'rb') as f:
        cv = pickle.load(f)

    # Mean CV test accuracy of the best hyper-parameter setting.
    best_accuracy = cv['mean_test_Accuracy'][cv['best_index_']]

    # Previously this dict was built and then discarded; keep it.
    sbjInfo = {}
    sbjInfo["Subject"] = sbj
    sbjInfo["Accuracy"] = best_accuracy
    sbjInfo["Sex"] = sex[0]
    sbjs_info.append(sbjInfo)

    acc_comp.append((best_accuracy, sbj, sex[0]))

# (accuracy, subject, sex) tuples, best accuracy first.
ordScores = multi_sort(acc_comp)[0]

In [None]:
# Pick three subjects from the top, the middle and the bottom of the ranking.
# ordScores is sorted best-first and holds (accuracy, subject, sex) tuples, so
# element 1 of each entry is the subject id.
# NOTE(review): the plotting cells below use hard-coded subject lists rather
# than subs_to_analyze.
_centre = len(ordScores) // 2
best_subs = [entry[1] for entry in ordScores[:3]]
mid_subs = [entry[1] for entry in ordScores[_centre - 1:_centre + 2]]
worst_subs = [entry[1] for entry in ordScores[-3:]]

subs_to_analyze = dict(Good=best_subs, Mean=mid_subs, Bad=worst_subs)

### Grad-CAM++ maps for fs = 500 Hz and frequency band [4, 60] Hz

In [None]:
# Delete .png/.pdf figures left in the working directory by earlier runs.
# Match the extension, not a substring (which would also hit e.g. "a.pdf.bak"),
# and skip directories, which os.remove cannot delete.
for content in os.listdir():
    if content.endswith(('.png', '.pdf')) and os.path.isfile(content):
        os.remove(content)

In [None]:
import matplotlib.pyplot as plt

from tf_keras_vis import gradcam_plus_plus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore

tf.random.set_seed(seed)

subjects = np.arange(n_subjects)+1

db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

load_args = dict(db = db,
            f_bank = np.asarray([[4., 60.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 500.0)

scores_path = os.path.join(os.getcwd(), 'GFC_Motor_500_l1l2_6gs')

# load_PAIN returns the file's ORIGINAL sampling rate as `fs`, but resamples X
# to `new_fs` when the two differ. Filters applied to the loaded data must
# therefore be designed at `new_fs`.
data_fs = load_args['new_fs']

# Band-pass filters for the per-band topographic maps, built once and reused.
# NOTE(review): beta [13, 32] Hz and gamma [30, 60] Hz overlap — confirm intended.
theta_filter = TimeFrequencyRpr(sfreq = data_fs, f_bank = np.array([[4., 8.]]))
alpha_filter = TimeFrequencyRpr(sfreq = data_fs, f_bank = np.array([[8., 13.]]))
beta_filter = TimeFrequencyRpr(sfreq = data_fs, f_bank = np.array([[13., 32.]]))
gamma_filter = TimeFrequencyRpr(sfreq = data_fs, f_bank = np.array([[30., 60.]]))
band_filters = (theta_filter, alpha_filter, beta_filter, gamma_filter)


def _channel_profile(x):
    """Average ``x`` over axis 0, then sum the result over its axis 1.

    For ``x`` shaped (trials, channels, time) this gives one value per channel.
    """
    return np.sum(np.mean(x, axis=0), axis=1)


def _minmax_scale(x):
    """Shift ``x`` so its minimum is 0, then divide so its maximum is 1."""
    shifted = x - np.min(x, axis=0)
    return shifted / np.max(shifted, axis=0)


def _joint_band_profiles(x, filters):
    """Channel profiles of ``x`` filtered by each of ``filters``, scaled jointly.

    Every band shares one offset (the global minimum) and one scale (the
    global maximum after the shift). All returned maps therefore lie in
    [0, 1] and stay comparable with each other.
    """
    profiles = [_channel_profile(flt.transform(x)[..., 0, 0]) for flt in filters]
    offset = np.min(np.concatenate(profiles, axis=0))
    profiles = [p - offset for p in profiles]
    scale = np.max(np.concatenate(profiles, axis=0))
    return [p / scale for p in profiles]


# Subject lists (the first one is analysed):
# [15, 19, 21, 25, 27]
# [1, 10, 12, 23, 29, 46]
for sbj in [15, 19, 21, 25, 27]:
    if sbj == 18:
        continue

    load_args['sbj'] = sbj

    X_train, y_train, age, sex, fs, info = load_PAIN(**load_args)

    X_train = X_train[..., np.newaxis]
    Y_train = tf.keras.utils.to_categorical(y_train,num_classes=num_class)
    y_train = np.array([y[0] for y in y_train])

    channels = info['ch_names']

    model_weights_path = os.path.join(scores_path, f"Subject{sbj}_weights.h5")
    params_path = os.path.join(scores_path, f"Subject{sbj}.p")

    with open(params_path, 'rb') as f:
        cv = pickle.load(f)

    # Hyper-parameters of the best cross-validated configuration.
    best_params = cv['params'][cv['best_index_']]
    nFilters_ = best_params['filters']
    kernel_time_ = best_params['kernel_time_1']

    model = GFC_triu_net_avg(nb_classes = num_class,
            Chans = X_train.shape[1],
            Samples = X_train.shape[2],
            filters = int(nFilters_),
            kernel_time_1 =int(kernel_time_))

    model.load_weights(model_weights_path)

    # Grad-CAM++ targets: class 0 (pain) for every trial, then class 1 (rest).
    pain_scores = [CategoricalScore(list(np.zeros_like(y_train)))]
    rest_scores = [CategoricalScore(list(np.ones_like(y_train)))]

    # ReplaceToLinear with clone=False swaps the softmax for a linear
    # activation on `model` itself.
    lc = gradcam_plus_plus.GradcamPlusPlus(model, model_modifier=ReplaceToLinear(), clone=False)

    Penultimate_layer = model.layers[-2]

    cam_kwargs = dict(seed_input = X_train,
                      penultimate_layer = Penultimate_layer,
                      seek_penultimate_conv_layer = True,
                      normalize_cam = False,
                      expand_cam = True)
    pain_cam = lc(score = pain_scores, **cam_kwargs)
    rest_cam = lc(score = rest_scores, **cam_kwargs)

    # Divide both CAMs of each trial by that trial's joint maximum, so the
    # pain and rest maps share one scale.
    max_val_cams = np.max(np.concatenate((pain_cam[...,np.newaxis],
                                                  rest_cam[...,np.newaxis]), axis=-1), axis=(1,2,3))[:,np.newaxis,np.newaxis]

    pain_cam/=max_val_cams
    rest_cam/=max_val_cams

    # NOTE(review): the "pain" CAM is replaced by (rest - pain) here, so the
    # top row below shows that difference — confirm this is intended.
    pain_cam = rest_cam - pain_cam

    pain_signals = np.array(X_train[y_train == 0][...,0])
    rest_signals = np.array(X_train[y_train == 1][...,0])

    # Full-band per-channel profiles, each scaled to [0, 1] on its own.
    mean_pain_signals = _minmax_scale(_channel_profile(pain_signals))
    mean_pain_cams = _minmax_scale(_channel_profile(pain_cam))
    mean_rest_signals = _minmax_scale(_channel_profile(rest_signals))
    mean_rest_cams = _minmax_scale(_channel_profile(rest_cam))

    # Per-band profiles, scaled jointly within each group. The *_signal_*
    # values are computed but not plotted below.
    (pain_signal_theta, pain_signal_alpha,
     pain_signal_beta, pain_signal_gamma) = _joint_band_profiles(pain_signals, band_filters)
    (rest_signal_theta, rest_signal_alpha,
     rest_signal_beta, rest_signal_gamma) = _joint_band_profiles(rest_signals, band_filters)
    (pain_cam_theta, pain_cam_alpha,
     pain_cam_beta, pain_cam_gamma) = _joint_band_profiles(pain_cam, band_filters)
    # Gamma now takes part in the rest-CAM max as well; it was previously left
    # out, so rest-gamma could exceed the shared [0, 1] colour range.
    (rest_cam_theta, rest_cam_alpha,
     rest_cam_beta, rest_cam_gamma) = _joint_band_profiles(rest_cam, band_filters)

    fig, ax = plt.subplot_mosaic([['ll', 'lm', 'lm1', 'lm2', 'lr'], ['bl', 'bm', 'bm1', 'bm2', 'br']], figsize=(14, 6))

    # Top row: pain-target CAMs; bottom row: rest-target CAMs.
    panels = {
        'll': pain_cam_theta, 'lm': pain_cam_alpha, 'lm1': pain_cam_beta,
        'lm2': pain_cam_gamma, 'lr': mean_pain_cams,
        'bl': rest_cam_theta, 'bm': rest_cam_alpha, 'bm1': rest_cam_beta,
        'bm2': rest_cam_gamma, 'br': mean_rest_cams,
    }
    images = {key: topoplot(values, channels, cmap='viridis', ax=ax[key], vlim=(0.0, 1.0), show=False)
              for key, values in panels.items()}
    im_00 = images['ll']

    ax['ll'].set_ylabel('CAMs Pain', size=20)
    ax['bl'].set_ylabel('CAMs Rest', size=20)
    for key, title in (('ll', 'Theta'), ('lm', 'Alpha'), ('lm1', 'Beta'),
                       ('lm2', 'Gamma'), ('lr', 'Full Band')):
        ax[key].set_title(title, size=20)

    fig.tight_layout()
    fig.subplots_adjust(right=0.75)
    cbar_ax = fig.add_axes([0.8, 0.3, 0.02, 0.5])
    fig.colorbar(im_00[0], cax=cbar_ax)
    plt.savefig(f'Subject{sbj}Cams.pdf', format='pdf', dpi=300)

    plt.show()
    # Free the figure so non-inline backends do not accumulate them.
    plt.close(fig)

### Connectivity

In [None]:
# Scalp regions used to group channels in the connectivity plots.
# NOTE(review): 'P9' and 'P10' do not appear in load_PAIN's channel list.
areas = dict([
    ('Frontal', ['Fpz', 'Fz']),
    ('Frontal Right', ['Fp2', 'AF4', 'F4', 'F6', 'F8']),
    ('Central Right', ['FC2', 'FC4', 'FC6', 'FT8', 'FT10', 'C2', 'C4', 'C6',
                       'T8', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10']),
    ('Posterior Right', ['P2', 'P4', 'P6', 'P8', 'P10', 'PO4', 'PO8', 'O2', 'PO10']),
    ('Posterior', ['Cz', 'CPz', 'Pz', 'POz', 'Oz', 'Iz']),
    ('Posterior Left', ['P1', 'P3', 'P5', 'P7', 'P9', 'PO3', 'PO7', 'O1', 'PO9']),
    ('Central Left', ['FC1', 'FC3', 'FC5', 'FT7', 'C1', 'C3', 'C5', 'T7',
                      'CP1', 'CP3', 'CP5', 'TP7', 'FT9', 'TP9']),
    ('Frontal Left', ['Fp1', 'AF3', 'F3', 'F5', 'F7']),
])

arcs = ['areas', 'channels']

# Colormaps for the region sectors and for the connection arcs.
areas_cmap, arcs_cmap = 'Set3', 'Purples'

In [None]:
import matplotlib.pyplot as plt

tf.random.set_seed(seed)

subjects = np.arange(n_subjects)+1

db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

load_args = dict(db = db,
            f_bank = np.asarray([[4., 60.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 256.0)

# NOTE(review): no visible cell downloads this folder (only
# GFC_Motor_500_l1l2_6gs is fetched) — confirm where it comes from.
scores_path = os.path.join(os.getcwd(), 'GFC_Motor_256_Gamma60Hz')

# Layer whose output (one value per channel pair) is tested with the KS test.
layer_name='fc'

# Subject lists (the first one is analysed):
# [15, 19, 21, 25, 27]
# [1, 10, 12, 23, 29, 46]
for sbj in [15, 19, 21, 25, 27]:
    if sbj == 18:
        continue

    load_args['sbj'] = sbj

    X_train, y_train, age, sex, fs, info = load_PAIN(**load_args)

    y_train_ = y_train.copy()
    X_train = X_train[..., np.newaxis]
    y_train = np.array([y[0] for y in y_train])

    channels = info['ch_names']

    model_weights_path = os.path.join(scores_path, f"Subject{sbj}_weights.h5")
    params_path = os.path.join(scores_path, f"Subject{sbj}.p")

    with open(params_path, 'rb') as f:
        cv = pickle.load(f)

    # Hyper-parameters of the best cross-validated configuration.
    best_params = cv['params'][cv['best_index_']]
    nFilters_ = best_params['filters']
    kernel_time_ = best_params['kernel_time_1']

    model = GFC_triu_net_avg(nb_classes = num_class,
            Chans = X_train.shape[1],
            Samples = X_train.shape[2],
            filters = int(nFilters_),
            kernel_time_1 =int(kernel_time_))

    model.load_weights(model_weights_path)

    fC_layer = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
    fc_layer_predict = fC_layer.predict(X_train)

    # Per-feature two-sample KS test between class-0 and class-1 trials.
    ks, pvalue = ks_connetivity_cal(fc_layer_predict, y_train_.reshape(-1))

    # Zero out features that are not significant at the 5% level, take each
    # feature's maximum over trials, and scale by the largest magnitude.
    mask = np.copy(fc_layer_predict)
    mask[:,pvalue > 0.05] = 0
    max_mask = np.max(mask, axis=0)
    abs_vals = abs(max_mask)
    max_val = np.max(abs_vals)
    if max_val == 0:
        # Nothing survived the mask: max_mask / max_val would be 0/0 and
        # produce an all-NaN connectivity vector, so skip the plot.
        print(f"Subject {sbj}: no significant connections (p <= 0.05); skipping plot.")
        continue
    v = abs(max_mask / max_val)

    plt.figure(figsize=(16,12), dpi=100)
    CircosConnectivity(v, channels=channels, areas=areas, labelsize=20, min_alpha=0.1, threshold=0.97,
                      areas_cmap=areas_cmap, arcs_cmap=arcs_cmap, size=20, show_emisphere=False, arcs_separation=30,
                      hemisphere_color='lightgray', channel_color='#f8f9fa', connection_width=0.2, small_separation=5,
                      big_separation=5, offset=-1.5, vmin=0.001,vmax=1.0)
    plt.savefig(f'Subject{sbj}_Connectivities.png')
    plt.show()
    # Free the figure so non-inline backends do not accumulate them.
    plt.close()