# Set-up

## Installing libraries and libcudnn8

In [1]:
import os

# Google Drive file id of the zipped MI_EEG_ClassMeth code (later imported as
# MI_EEG_ClassMeth.FeatExtraction). IPython expands `$FILEID` in the `!wget`
# line below from this Python variable.
FILEID = "1h4FWB5fw7sBDCSM-EENK1UadqKSCqg24"

contents = os.listdir(os.getcwd())

# Download and unzip only when the MI_EEG_ClassMeth folder is not already in
# the working directory.
if 'MI_EEG_ClassMeth' not in contents:
    # Two-step wget: the inner call stores Drive cookies and scrapes a
    # `confirm=` token with sed (presumably to get past Drive's download
    # warning page); the outer call downloads with that token.
    # NOTE(review): this page-scraping trick breaks whenever Drive changes
    # its HTML; a dedicated downloader (e.g. gdown) may be more robust.
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$FILEID -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$FILEID -O MI_EEG_ClassMeth.zip && rm -rf /tmp/cookies.txt
    !unzip MI_EEG_ClassMeth.zip
else:
    print("MI_EEG_ClassMeth already downloaded!")

# Pin libcudnn8 8.1.1.33 built for CUDA 11.2. --allow-change-held-packages is
# required because the package is held; the recorded output shows this also
# removes libcudnn8-dev.
!apt-get install --allow-change-held-packages libcudnn8=8.1.1.33-1+cuda11.2 -y
# NOTE(review): these installs are unpinned, so re-running can pull different
# versions; consider pinning them and using %pip so they target the kernel's
# environment.
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.databases
!pip install mne
!pip install pickle5
!pip install gcpds.utils
!pip install scikeras[tensorflow]

MI_EEG_ClassMeth already downloaded!
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following packages will be REMOVED:
  libcudnn8-dev
The following held packages will be changed:
  libcudnn8
The following packages will be upgraded:
  libcudnn8
1 upgraded, 0 newly installed, 1 to remove and 91 not upgraded.
Need to get 421 MB of archives.
After this operation, 2621 MB disk space will be freed.
Get:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64  libcudnn8 8.1.1.33-1+cuda11.2 [421 MB]
Fetched 421 MB in 6s (67.8 MB/s)                                               
(Reading database ... 108827 files and directories currently installed.)
Removing libcudnn8-dev (8.0.5.39-1+cuda11.0) ...
update-alternatives: removing manually selected alternative - switching libcudnn to auto mode
(Reading database ... 108805 files and directories currently installed.)
Preparing to unpack .../libcudnn8_8.1.1.33-1+cuda11.2_amd64

## Import libraries

In [2]:
# freq filter 
from MI_EEG_ClassMeth.FeatExtraction import TimeFrequencyRpr

#EEG montage
from gcpds.utils.mne_handler import get_best_montage

# general
import numpy as np
from scipy.signal import resample
import pickle5 as pickle
import warnings
import mne
from time import time
warnings.filterwarnings('ignore')

# tensorlfow 
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Dropout, Conv2D, AveragePooling2D, BatchNormalization, Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.layers import Layer
from tensorflow.keras.regularizers import L1L2

# scikeras
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV,StratifiedShuffleSplit
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score, cohen_kappa_score, roc_auc_score

## Define functions

In [3]:
def kappa(y_true, y_pred):
    """Cohen's kappa for one-hot labels versus per-class predictions.

    Both arguments are 2-D (samples, classes). Each row is reduced to its
    arg-max class index before scoring, so ``y_pred`` may hold one-hot
    predictions or class scores.
    """
    true_labels = np.argmax(y_true, axis=1)
    predicted_labels = np.argmax(y_pred, axis=1)
    return cohen_kappa_score(true_labels, predicted_labels)

## PAIN dataset

In [4]:
def load_PAIN(db, sbj, f_bank, vwt, new_fs):
    """Load one subject of the BMOP motor pain recordings as filtered arrays.

    Parameters
    ----------
    db : str
        Directory prefix of the pickled recordings. The file name is appended
        directly, so it must end with a path separator.
    sbj : int
        Subject number, zero-padded to two digits in the file name
        (``BMOP_Motor_S01.pkl``, ...).
    f_bank, vwt :
        Frequency bank and window settings forwarded to ``TimeFrequencyRpr``.
    new_fs : float
        Target sampling rate. The last axis is resampled when it differs from
        the recording's rate.

    Returns
    -------
    X : ndarray
        ``TimeFrequencyRpr.transform`` output with singleton axes squeezed,
        resampled to ``new_fs`` if needed.
    y : ndarray
        Labels shifted by -2: event code 2 ('pain/high') -> 0 and
        3 ('resting') -> 1.
    age, sex : ndarray
        Flattened subject metadata stored in the pickle.
    fs : float
        The ORIGINAL sampling rate of the recording, not ``new_fs``.
    """
    # Electrode labels, assumed to follow the order of axis 1 of data['X'].
    channels_names = np.array(['Fp1','Fp2',
                      'F3','F4','C3','C4','P3','P4','O1','O2','F7','F8',
                      'T7','T8','P7','P8','Fz','Cz','Pz','Oz',
                      'FC1','FC2','CP1','CP2','FC5','FC6','CP5','CP6',
                      'TP9','TP10','LE','RE','P1','P2','C1','C2',
                      'FT9','FT10','AF3','AF4','FC3','FC4','CP3','CP4','PO3','PO4',
                      'F5','F6','C5','C6','P5','P6','PO9','Iz','FT7','FT8',
                      'TP7','TP8','PO7','PO8','Fpz','PO10','CPz','POz',
                      'Ne','Ma','Ext','ECG'])

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # trusted files.
    with open(f'{db}BMOP_Motor_S{sbj:02d}.pkl', 'rb') as f:
        data = pickle.load(f)

    X = data['X']  # (trials, channels, time)
    y = data['y']
    sex = data['sex'].ravel()
    age = data['age'].ravel()
    fs = float(data['fs'])

    tf_repr = TimeFrequencyRpr(sfreq=fs, f_bank=f_bank, vwt=vwt)

    # Keep only electrodes present in the best-matching standard MNE montage.
    best_montages = get_best_montage(channels_names)
    montage = best_montages.iloc[0]['montage']
    missing_channels = np.asarray(best_montages.iloc[0]['missings channels'], dtype=str)
    # Positions of the missing electrodes. This is an empty index array when
    # the montage covers every channel, which np.delete accepts as a no-op.
    channels_to_remove = np.flatnonzero(np.isin(channels_names, missing_channels))

    # Drop the missing channels from both the names and the data, by position.
    channels_names = np.delete(channels_names, channels_to_remove)
    X = np.delete(X, channels_to_remove, axis=1)

    # The channel count of X still exceeds the number of names, so the last
    # channel of X is discarded (noted by the author as having weird
    # amplitudes). NOTE(review): assumes exactly one extra trailing channel.
    X = X[:, :-1, :]

    info = mne.create_info(list(channels_names), sfreq=fs, ch_types="eeg")
    info.set_montage(montage)

    # Event codes stored in y.
    event_id = {
        'pain/high': 2,
        'resting': 3,
        }

    # One MNE event row per trial, using the trial index as the sample column.
    # `cls[0]` implies y is 2-D, presumably (trials, 1); TODO confirm.
    events = [[i, 1, cls[0]] for i, cls in enumerate(y)]

    epochs = mne.EpochsArray(X, info, events=events, tmin=0, event_id=event_id)
    X = epochs.get_data()
    y = y - 2  # 2 -> 0 (pain/high), 3 -> 1 (resting)
    X = np.squeeze(tf_repr.transform(X))

    if new_fs != fs:
        X = resample(X, int((X.shape[-1] / fs) * new_fs), axis=-1)

    return X, y, age, sex, fs

## Define the model (Gaussian functional connectivity network)

In [5]:
class GFC(Layer):
    """Gaussian functional-connectivity layer.

    For every sample and every feature map it builds a channel-by-channel
    Gaussian kernel

        A[i, j] = exp(-D[i, j] / (2 * 10**gammad * sigma**2)),

    where ``D`` holds squared Euclidean distances between channel signals and
    ``sigma`` is the median (50th percentile) of the off-diagonal distances
    ``sqrt(D)``, computed per sample and feature map. ``gammad`` is one
    trainable scalar that rescales the bandwidth on a log10 scale.

    Input shape:  (N, C, T, F), i.e. batch, channels, time, feature maps.
    Output shape: (N, F, C, C).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        # Zero init => 10**gammad == 1, i.e. training starts from the plain
        # median-distance bandwidth.
        self.gammad = self.add_weight(name='gammad',
                                      shape=(),
                                      initializer='zeros',
                                      trainable=True)
        super().build(batch_input_shape)

    def call(self, X):
        X = tf.transpose(X, perm=(0, 3, 1, 2))  # (N, F, C, T)
        # Pairwise squared distances via ||a||^2 - 2<a, b> + ||b||^2.
        R = tf.reduce_sum(tf.math.multiply(X, X), axis=-1, keepdims=True)  # (N, F, C, 1)
        D = R - 2 * tf.matmul(X, X, transpose_b=True) + tf.transpose(R, perm=(0, 1, 3, 2))  # (N, F, C, C)
        # Floating-point cancellation in the expansion above can leave tiny
        # negative values. sqrt() of those is NaN, which would poison sigma
        # and therefore every entry of A.
        D = tf.maximum(D, 0.)

        # Strict upper triangle (i < j): each channel pair once, no diagonal.
        ones = tf.ones_like(D[0, 0, ...])  # (C, C)
        upper = tf.linalg.band_part(ones, 0, -1)   # upper triangle incl. diagonal
        diagonal = tf.linalg.band_part(ones, 0, 0)  # diagonal only
        mask = tf.cast(upper - diagonal, dtype=tf.bool)  # (C, C)
        triu = tf.expand_dims(tf.boolean_mask(D, mask, axis=2), axis=-1)  # (N, F, C*(C-1)/2, 1)
        sigma = tfp.stats.percentile(tf.math.sqrt(triu), 50, axis=2, keepdims=True)  # (N, F, 1, 1)

        A = tf.math.exp(-1 / (2 * tf.pow(10., self.gammad) * tf.math.square(sigma)) * D)  # (N, F, C, C)
        A.set_shape(D.shape)
        return A

    def compute_output_shape(self, batch_input_shape):
        # Accept a TensorShape or a plain tuple/list.
        N, C, T, F = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape([N, F, C, C])

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}


class get_triu(Layer):
    """Flatten the strict upper triangle of square matrices.

    Input shape:  (N, F, C, C).
    Output shape: (N, F, C*(C-1)//2, 1). These are the off-diagonal entries
    with i < j, in the order ``tf.boolean_mask`` returns them, plus a trailing
    singleton axis so the result can feed 2-D pooling layers.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        super().build(batch_input_shape)

    @staticmethod
    def _n_pairs(C):
        """Number of unordered channel pairs, or None when C is unknown."""
        return None if C is None else C * (C - 1) // 2

    def call(self, X):
        N, F, C, _ = X.shape
        ones = tf.ones_like(X[0, 0, ...])  # (C, C)
        upper = tf.linalg.band_part(ones, 0, -1)   # upper triangle incl. diagonal
        diagonal = tf.linalg.band_part(ones, 0, 0)  # diagonal only
        mask = tf.cast(upper - diagonal, dtype=tf.bool)  # strict upper triangle
        triu = tf.expand_dims(tf.boolean_mask(X, mask, axis=2), axis=-1)  # (N, F, C*(C-1)/2, 1)

        # boolean_mask loses the static pair count; restore it for later layers.
        triu.set_shape([N, F, self._n_pairs(C), 1])
        return triu

    def compute_output_shape(self, batch_input_shape):
        # Accept a TensorShape or a plain tuple/list.
        N, F, C, _ = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape([N, F, self._n_pairs(C), 1])

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}
    
    
def GFC_triu_net_avg(nb_classes: int,
          Chans: int,
          Samples: int,
          l1: float = 0,
          l2: float = 0,
          dropoutRate: float = 0.5,
          filters: int = 1,
          maxnorm: float = 2.0,
          maxnorm_last_layer: float = 0.5,
          kernel_time_1: int = 20,
          strid_filter_time_1: int = 1,
          bias_spatial: bool = False) -> Model:
    """Build the Gaussian functional-connectivity (GFC) classifier.

    Pipeline: temporal Conv2D -> BatchNorm -> ELU -> GFC (per-map Gaussian
    channel-connectivity matrices) -> strict upper triangle -> average over
    feature maps -> BatchNorm -> ELU -> Dropout -> Dense -> softmax.

    Parameters
    ----------
    nb_classes : number of output classes (softmax units).
    Chans, Samples : channels and time samples of one trial. The model input
        is (Chans, Samples, 1).
    l1, l2 : L1 / L2 penalty weights on the final Dense kernel.
    dropoutRate : dropout rate before the final Dense layer.
    filters : number of temporal convolution filters (feature maps).
    maxnorm : max-norm constraint on each temporal filter.
    maxnorm_last_layer : max-norm constraint on the Dense kernel.
    kernel_time_1 : temporal kernel length, in samples.
    strid_filter_time_1 : temporal stride of the convolution.
    bias_spatial : whether the (temporal) convolution uses a bias term.

    Returns
    -------
    A Keras ``Model`` mapping (batch, Chans, Samples, 1) to class probabilities.
    """
    input_main = Input((Chans, Samples, 1), name='Input')

    # (1, kernel_time_1) kernels slide along time only, independently per channel.
    block = Conv2D(filters, (1, kernel_time_1), strides=(1, strid_filter_time_1),
                   use_bias=bias_spatial,
                   kernel_constraint=max_norm(maxnorm, axis=(0, 1, 2))
                   )(input_main)

    # NOTE(review): Keras BatchNormalization updates
    #   moving = moving * momentum + batch * (1 - momentum),
    # so momentum=0.1 makes the moving statistics track the latest batch almost
    # entirely. epsilon=1e-05 / momentum=0.1 are PyTorch's defaults, where 0.1
    # corresponds to Keras momentum=0.9. Confirm which was intended.
    block = BatchNormalization(epsilon=1e-05, momentum=0.1)(block)
    block = Activation('elu')(block)

    # (N, C, T', F) -> (N, F, C, C) -> (N, F, C*(C-1)/2, 1)
    block = GFC()(block)
    block = get_triu()(block)

    # Pool over the full feature-map axis (block.shape[1] == F), averaging
    # the connectivity vectors of all maps (assuming channels_last data format).
    block = AveragePooling2D(pool_size=(block.shape[1], 1), strides=(1, 1))(block)

    block = BatchNormalization(epsilon=1e-05, momentum=0.1)(block)
    block = Activation('elu')(block)
    block = Flatten()(block)
    block = Dropout(dropoutRate)(block)

    block = Dense(nb_classes, kernel_regularizer=L1L2(l1=l1, l2=l2), name='logits',
                  kernel_constraint=max_norm(maxnorm_last_layer)
                  )(block)

    softmax = Activation('softmax', name='output')(block)

    return Model(inputs=input_main, outputs=softmax)

# Experiment

## Experiment configuration 

In [6]:
import os

# Random seed and training length shared by every run.
seed = 23
folds = 5
epochs_train = 500

model_name = 'GFC'

# Sub-directory (relative to the working directory) receiving the results.
save_folder = 'Groups_Motor256_Gamma60Hz_Loso'

# Subjects are numbered 1..n_subjects.
n_subjects = 51

In [7]:
import os

# Absolute output directory for the saved weights and CV results.
PATH = os.path.join(os.getcwd(), save_folder)

## Run experiment

In [8]:
from sklearn.preprocessing import MinMaxScaler, RobustScaler
from scipy.stats import ks_2samp
from scipy.spatial.distance import squareform

# Seed TensorFlow's global random generator.
tf.random.set_seed(seed)

accs_dict = {}

# Subject ids 1..n_subjects, minus the entry at index 17 (i.e. subject 18).
subjects = np.delete(np.arange(1, n_subjects + 1), 17)

# Folder of the pickled recordings. load_PAIN appends the file name directly,
# so the trailing slash matters.
db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

# Keyword arguments shared by every load_PAIN call: one frequency band
# (4-60, presumably Hz), one window (0.5-2.5, presumably seconds), and
# resampling to 256.
load_args = {
    'db': db,
    'f_bank': np.asarray([[4., 60.]]),
    'vwt': np.asarray([[0.5, 2.5]]),
    'new_fs': 256.0,
    'sbj': 1,
}

# Load subject 1 on its own (a quick check of the loading pipeline).
X_train, y_train, age, sex, fs = load_PAIN(**load_args)

Not setting metadata
40 matching events found
No baseline correction applied
0 projection items activated


### Groups definition

In [9]:
# Subject-to-group assignment: (key, name, color, subject ids).
# Every subject 1..51 except 18 appears in exactly one group.
_GROUP_TABLE = (
    ("G1", "Group1", "Dark Blue",
     [2, 3, 7, 20, 21, 30, 31, 36, 40, 44, 49, 50]),
    ("G2", "Group2", "Light Blue",
     [1, 4, 5, 6, 8, 11, 14, 17, 22, 32, 33, 34, 35, 39, 42, 43, 45, 47, 51]),
    ("G3", "Group3", "Yellow",
     [9, 10, 12, 13, 15, 16, 19, 23, 25, 26, 27, 28, 29, 46]),
    ("G4", "Group4", "Brown",
     [24, 37, 38, 41, 48]),
)

groups = {
    key: {"Name": name, "Subjects": list(subject_ids), "Color": color}
    for key, name, color, subject_ids in _GROUP_TABLE
}

In [10]:
from sklearn.model_selection import LeaveOneGroupOut


def _load_group_data(subjects, base_args):
    """Load and stack every trial of `subjects` with `load_PAIN`.

    `base_args` holds the shared `load_PAIN` keyword arguments. `sbj` is
    filled in per subject without mutating the caller's dict.

    Returns ``(X, y, trial_groups)``, where ``trial_groups[i]`` is the 1-based
    position in `subjects` of the subject that trial ``i`` came from. This is
    the grouping vector `LeaveOneGroupOut` splits on.

    Raises ValueError if `subjects` is empty.
    """
    if len(subjects) == 0:
        raise ValueError("a group needs at least one subject")

    X_parts, y_parts, trial_groups = [], [], []
    for position, sbj in enumerate(subjects, start=1):
        print(f"Loading subject: {sbj}\n")
        X_sbj, y_sbj, _, _, _ = load_PAIN(**{**base_args, 'sbj': sbj})
        X_parts.append(X_sbj)
        y_parts.append(y_sbj)
        trial_groups += [position] * len(X_sbj)
        print("\n")

    # A single concatenate per group instead of re-copying a growing array
    # once per subject.
    return (np.concatenate(X_parts, axis=0),
            np.concatenate(y_parts, axis=0),
            trial_groups)


print("Starting experiment...\n")

# Settings shared by every group: the hyper-parameter grid, the metrics scored
# on each held-out subject ("Accuracy" selects the refit model), and
# leave-one-subject-out splitting.
param_grid = {
    'filters': [2, 4],
    'kernel_time_1': [25, 50],
}
scoring = {"AUC": 'roc_auc', "Accuracy": make_scorer(accuracy_score), 'Kappa': make_scorer(kappa)}
logo = LeaveOneGroupOut()

full_path = PATH
# exist_ok handles an already-existing folder without hiding real errors
# (e.g. missing permissions) the way a bare `except: pass` would.
os.makedirs(full_path, exist_ok=True)

t = time()

for group in list(groups):
    group_name = groups[group]["Name"]

    print("------------------------------------------------------------------------------------------\n")
    print(f"                               {group_name} starting...")
    print("------------------------------------------------------------------------------------------\n")

    print(f"Loading {group_name} subjects\n")
    X_train, Y_train, groups[group]["Groups"] = _load_group_data(groups[group]["Subjects"], load_args)
    Y_train = tf.keras.utils.to_categorical(Y_train, num_classes=num_class)

    # ----build model
    clf = KerasClassifier(
        GFC_triu_net_avg,
        random_state=seed,

        # ----model hyperparameters (matched to GFC_triu_net_avg's arguments by name)
        nb_classes=num_class,
        Chans=X_train.shape[1],
        Samples=X_train.shape[2],
        dropoutRate=0.5,
        l1=0, l2=0,
        filters=2, maxnorm=2.0, maxnorm_last_layer=0.5,
        kernel_time_1=25, strid_filter_time_1=1,
        bias_spatial=False,

        # ----model config
        verbose=0,
        # Full batch: every LOSO training fold is a subset of X_train, so no
        # fold is ever split into mini-batches (a fixed 500 was smaller than
        # the training folds of the larger groups).
        batch_size=len(X_train),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer="adam",
        # SciKeras only forwards `optimizer__<arg>` keyword arguments to the
        # optimizer; any other spelling is silently ignored.
        optimizer__learning_rate=0.1,
        metrics=['accuracy'],
        epochs=epochs_train)

    # ----grid search, holding out one subject per split
    cv = GridSearchCV(clf, param_grid, cv=logo,
                      verbose=0, n_jobs=1,
                      scoring=scoring,
                      refit="Accuracy")
    cv.fit(X=X_train, y=Y_train, groups=groups[group]["Groups"])

    # ----best score (elapsed time is measured from the start of the whole run)
    print('Group', group_name, 'Accuracy', cv.best_score_, 'elapsed time', time() - t)
    print('---------')

    # Keep the winning candidate's index alongside the per-candidate results.
    cv.cv_results_['best_index_'] = cv.best_index_

    cv.best_estimator_.model_.save_weights(os.path.join(full_path, f'{group_name}_loso_weights.h5'))
    with open(os.path.join(full_path, f'{group_name}_loso.p'), 'wb') as f:
        pickle.dump(cv.cv_results_, f)

Starting experiment...

------------------------------------------------------------------------------------------

                               Group1 starting...
------------------------------------------------------------------------------------------

Loading Group1 subjects

Loading subject: 2

Not setting metadata
40 matching events found
No baseline correction applied
0 projection items activated


Loading subject: 3

Not setting metadata
40 matching events found
No baseline correction applied
0 projection items activated


Loading subject: 7

Not setting metadata
40 matching events found
No baseline correction applied
0 projection items activated


Loading subject: 20

Not setting metadata
40 matching events found
No baseline correction applied
0 projection items activated


Loading subject: 21

Not setting metadata
40 matching events found
No baseline correction applied
0 projection items activated


Loading subject: 30

Not setting metadata
40 matching events found
No basel