# Set-up

## Installing libraries and libcudnn8

In [None]:
import os

# Google Drive file ID of the MI_EEG_ClassMeth archive (the package later
# imported as MI_EEG_ClassMeth.FeatExtraction for TimeFrequencyRpr).
FILEID = "1h4FWB5fw7sBDCSM-EENK1UadqKSCqg24"

contents = os.listdir(os.getcwd())

# Download and extract only when the package directory is not already present.
if 'MI_EEG_ClassMeth' not in contents:
    # IPython shell escape; $FILEID is substituted from the Python variable.
    # The inner wget scrapes a `confirm=` token from the Drive page (via sed),
    # which the outer wget sends back together with the saved cookies.
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$FILEID -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$FILEID -O MI_EEG_ClassMeth.zip && rm -rf /tmp/cookies.txt
    !unzip MI_EEG_ClassMeth.zip
else:
    print("MI_EEG_ClassMeth already downloaded!")

# Pin cuDNN 8.1.1 built for CUDA 11.2, overriding any held package version.
!apt-get install --allow-change-held-packages libcudnn8=8.1.1.33-1+cuda11.2 -y
# Python dependencies used by the cells below (datasets, MNE, pickle5 backport,
# gcpds utilities, scikeras wrapper, tf_keras_vis for class activation maps).
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.databases
!pip install mne
!pip install pickle5
!pip install gcpds.utils
!pip install scikeras[tensorflow]
!pip install tf_keras_vis

## Get Weights and Scores

### GFC

In [None]:
# Google Drive file ID of the GFC results archive (GFC_Motor_256_Gamma60Hz.zip).
FILEID = "12bxViJ8j3U4RLWVbZa0SuRLcxojIdJvZ"

contents = os.listdir(os.getcwd())

# Download and extract only if the zip is not already in the working directory.
# NOTE(review): the check is on the .zip, not on the extracted folder.
if 'GFC_Motor_256_Gamma60Hz.zip' not in contents:
    # IPython shell escape; $FILEID is substituted from the Python variable and
    # the inner wget scrapes the Drive `confirm=` token used by the outer one.
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$FILEID -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$FILEID -O GFC_Motor_256_Gamma60Hz.zip && rm -rf /tmp/cookies.txt
    !unzip GFC_Motor_256_Gamma60Hz.zip

### EEGNet

In [None]:
import shutil

FILEID = "1PfhH5Uj5oKK2AY3Mgltw1iD42smIiO9d"

contents = os.listdir(os.getcwd())

folder = 'MotorConditionEEGNet_60Hz'

if 'MotorConditionEEGNet_60Hz.zip' not in contents:
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$FILEID -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$FILEID -O MotorConditionEEGNet_60Hz.zip && rm -rf /tmp/cookies.txt
    !unzip MotorConditionEEGNet_60Hz.zip
    os.makedirs(folder)

    for content in contents:
        if '.p' in content or '.h5' in content:
            shutil.move(os.path.join(os.getcwd(), content), os.path.join(os.getcwd(), folder, content))

## Import libraries

In [None]:
# freq filter 
from MI_EEG_ClassMeth.FeatExtraction import TimeFrequencyRpr

#EEG montage
from gcpds.utils.mne_handler import get_best_montage

# general
import numpy as np
from scipy.signal import resample
import pickle5 as pickle
import warnings
import mne
from time import time
warnings.filterwarnings('ignore')

# tensorflow
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Dropout, Conv2D, AveragePooling2D, BatchNormalization, Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.layers import Layer
from tensorflow.keras.regularizers import L1L2

# scikeras
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV,StratifiedShuffleSplit
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score, cohen_kappa_score, roc_auc_score

## Define functions

In [None]:
def kappa(y_true, y_pred):
    """Cohen's kappa between two label matrices reduced to class indices.

    Both inputs are collapsed with ``argmax`` along axis 1 (e.g. one-hot
    targets against softmax outputs) before being scored.
    """
    true_idx, pred_idx = (np.argmax(matrix, axis=1) for matrix in (y_true, y_pred))
    return cohen_kappa_score(true_idx, pred_idx)

In [None]:
def multi_sort(*args, reverse=True):
    """Sort several parallel sequences together, keyed on their zipped tuples.

    The rows ``zip(*args)`` are sorted lexicographically (first sequence as the
    primary key, later ones breaking ties), then split back into one list per
    input sequence. Inputs of different lengths are truncated to the shortest,
    as ``zip`` does.

    Args:
        *args: Parallel sequences to sort together.
        reverse: Sort descending when True (the default).

    Returns:
        A tuple with one list per input sequence. When there are no rows to
        sort, this is a tuple of empty lists, so ``multi_sort(xs)[0]`` stays
        valid for an empty ``xs``.
    """
    rows = sorted(zip(*args), reverse=reverse)
    if not rows:
        # zip(*[]) would yield nothing and collapse the result to ().
        return tuple([] for _ in args)
    return tuple(list(column) for column in zip(*rows))

## PAIN dataset

In [None]:
def load_PAIN(db,sbj,f_bank,vwt,new_fs):
    """Load one subject of the BMOP motor dataset as filtered EEG trials.

    Reads ``{db}BMOP_Motor_S{sbj:02d}.pkl``, drops the channels that the best
    matching standard MNE montage does not contain (plus the last remaining
    channel), filters the trials with ``TimeFrequencyRpr`` and resamples them
    to ``new_fs`` when it differs from the file's sampling rate.

    Args:
        db: Directory prefix of the pickle files. It is concatenated directly
            with the file name, so it must end with a path separator.
        sbj: Subject number, zero-padded to two digits in the file name.
        f_bank: Frequency bands passed to ``TimeFrequencyRpr``.
        vwt: Time windows passed to ``TimeFrequencyRpr``.
        new_fs: Target sampling rate of the returned trials.

    Returns:
        X: Filtered (and possibly resampled) trials with time on the last axis.
        y: Labels minus 2. The raw codes match ``event_id`` (2 = 'pain/high',
           3 = 'resting'), so after the shift 0 = pain, 1 = rest.
        age, sex: Flattened demographic fields from the pickle.
        fs: The ORIGINAL sampling rate stored in the file. When resampling
            happens, ``X`` is at ``new_fs``, not at ``fs``.
        info: ``mne.Info`` for the kept channels with the montage applied.
    """
    channels_names = np.array(['Fp1','Fp2',
                      'F3','F4','C3','C4','P3','P4','O1','O2','F7','F8',
                      'T7','T8','P7','P8','Fz','Cz','Pz','Oz',
                      'FC1','FC2','CP1','CP2','FC5','FC6','CP5','CP6',
                      'TP9','TP10','LE','RE','P1','P2','C1','C2',
                      'FT9','FT10','AF3','AF4','FC3','FC4','CP3','CP4','PO3','PO4',
                      'F5','F6','C5','C6','P5','P6','PO9','Iz','FT7','FT8',
                      'TP7','TP8','PO7','PO8','Fpz','PO10','CPz','POz',
                      'Ne','Ma','Ext','ECG'])

    # Subject files are zero-padded to two digits (S01, S02, ..., S51).
    # NOTE: pickle executes code on load; only use trusted dataset files.
    with open(f'{db}BMOP_Motor_S{sbj:02d}.pkl', 'rb') as f:
        data = pickle.load(f)

    X = data['X']  # trials, channels, time
    y = data['y']
    sex = data['sex'].ravel()
    age = data['age'].ravel()
    fs = float(data['fs'])

    tf_repr = TimeFrequencyRpr(sfreq = fs, f_bank = f_bank, vwt = vwt)

    # Pick the best standard MNE montage for these electrode names.
    best_montages = get_best_montage(channels_names)
    montage = best_montages.iloc[0]['montage']
    no_channels = np.array(best_montages.iloc[0]['missings channels'])
    # Position of each missing channel. dtype=int keeps this a valid (empty)
    # index array when the montage covers every channel; the previous 2-D
    # `[:, 0]` indexing raised IndexError in that case.
    channels_to_remove = np.array(
        [np.argwhere(channels_names == no)[0, 0] for no in no_channels],
        dtype=int)

    # Delete the missing channels from names and data, keeping positions aligned.
    channels_names = np.delete(channels_names, channels_to_remove)
    X = np.delete(X, channels_to_remove, axis=1)

    # The number of channels still does not match the dimension of X, so the
    # last channel is discarded (it has anomalous amplitudes).
    X = X[:,:-1,:]

    info = mne.create_info(list(channels_names), sfreq=fs, ch_types="eeg")
    info.set_montage(montage)

    event_id = {
        'pain/high':2,
        'resting':3,
        }

    # One event per trial: (sample index, previous-event value, class code).
    events = [[i, 1, cls[0]] for i, cls in enumerate(y)]
    tmin = 0

    epochs = mne.EpochsArray(X, info, events=events, tmin=tmin, event_id=event_id)
    X = epochs.get_data()
    y = y-2
    X = np.squeeze(tf_repr.transform(X))

    # Resample along time so the returned trials are at new_fs.
    if new_fs != fs:
        X = resample(X, int((X.shape[-1]/fs)*new_fs), axis = -1)

    return X,y,age,sex,fs,info

## Define the model (Gaussian functional connectivity network)

In [None]:
class GFC(Layer):
    """Gaussian functional connectivity layer.

    Maps a feature tensor laid out as (N, C, T, F) -- samples, channels, time
    steps, filters -- to (N, F, C, C) Gaussian-kernel matrices

        A_ij = exp(-D_ij / (2 * 10**gammad * sigma**2))

    where D holds squared Euclidean distances between channel time series and
    sigma is the 50th percentile (median) of the pairwise, non-squared
    distances of each (sample, filter) slice. ``gammad`` is the layer's single
    trainable scalar, acting on a log10 scale.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        # log10 bandwidth multiplier; 'zeros' init gives a factor 10**0 = 1.
        self.gammad = self.add_weight(name = 'gammad',
                                shape = (),
                                initializer = 'zeros',
                                trainable = True)
        super().build(batch_input_shape)

    def call(self, X):
        X = tf.transpose(X, perm = (0, 3, 1, 2))  # (N, F, C, T)
        R = tf.reduce_sum(tf.math.multiply(X, X), axis = -1, keepdims = True)  # (N, F, C, 1)
        # ||x_i||^2 - 2 <x_i, x_j> + ||x_j||^2  -> (N, F, C, C)
        D = R - 2*tf.matmul(X, X, transpose_b = True) + tf.transpose(R, perm = (0, 1, 3, 2))
        # Floating-point cancellation can make this expansion slightly negative;
        # squared distances are >= 0, and a negative value would turn the sqrt
        # below into NaN and poison the median bandwidth.
        D = tf.maximum(D, 0.)

        ones = tf.ones_like(D[0,0,...])  # (C, C)
        mask_a = tf.linalg.band_part(ones, 0, -1)  # upper triangle incl. diagonal
        mask_b = tf.linalg.band_part(ones, 0, 0)   # diagonal only
        mask = tf.cast(mask_a - mask_b, dtype=tf.bool)  # strict upper triangle
        triu = tf.expand_dims(tf.boolean_mask(D, mask, axis = 2), axis = -1)  # (N, F, C*(C-1)/2, 1)
        sigma = tfp.stats.percentile(tf.math.sqrt(triu), 50, axis = 2, keepdims = True)  # (N, F, 1, 1)

        A = tf.math.exp(-1/(2*tf.pow(10., self.gammad)*tf.math.square(sigma))*D)  # (N, F, C, C)
        A.set_shape(D.shape)
        return A

    def compute_output_shape(self, batch_input_shape):
        # TensorShape(...) accepts both TensorShape and plain tuples/lists.
        N, C, T, F = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape([N, F, C, C])

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}


class get_triu(Layer):
    """Keep only the strict upper triangle of square matrices.

    Input (N, F, C, C) -> output (N, F, C*(C-1)//2, 1), i.e. every off-diagonal
    channel pair of each (sample, filter) matrix counted once.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        super().build(batch_input_shape)

    @staticmethod
    def _n_pairs(n_channels):
        """Number of strict upper-triangle entries, or None if C is unknown."""
        if n_channels is None:
            return None
        return n_channels * (n_channels - 1) // 2

    def call(self, X):
        n, f, _, c = X.shape
        ones = tf.ones_like(X[0,0,...])  # (C, C)
        upper = tf.linalg.band_part(ones, 0, -1)  # upper triangle incl. diagonal
        diag = tf.linalg.band_part(ones, 0, 0)    # diagonal only
        mask = tf.cast(upper - diag, dtype=tf.bool)  # strict upper triangle
        triu = tf.expand_dims(tf.boolean_mask(X, mask, axis = 2), axis = -1)  # (N, F, C*(C-1)/2, 1)

        # boolean_mask loses the static pair count; restore it for later layers.
        triu.set_shape([n, f, self._n_pairs(c), 1])
        return triu

    def compute_output_shape(self, batch_input_shape):
        # TensorShape(...) accepts both TensorShape and plain tuples/lists.
        n, f, _, c = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape([n, f, self._n_pairs(c), 1])

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}
    
    
def GFC_triu_net_avg(nb_classes: int,
          Chans: int,
          Samples: int,
          l1: float = 0,
          l2: float = 0,
          dropoutRate: float = 0.5,
          filters: int = 1,
          maxnorm: float = 2.0,
          maxnorm_last_layer: float = 0.5,
          kernel_time_1: int = 20,
          strid_filter_time_1: int = 1,
          bias_spatial: bool = False) -> Model:
    """Build the Gaussian-functional-connectivity classifier.

    Pipeline:
    1. Temporal Conv2D (1 x kernel_time_1), then BatchNorm and ELU.
    2. GFC connectivity matrices, reduced to their strict upper triangle.
    3. Average over the ``filters`` feature maps, then BatchNorm and ELU.
    4. Flatten and dropout.
    5. Dense logits with L1/L2 and max-norm, followed by softmax.

    Layer creation order is kept fixed because saved ``.h5`` weights are
    loaded into this topology.

    Args:
        nb_classes: Number of output classes.
        Chans: Number of EEG channels (input height).
        Samples: Number of time samples (input width).
        l1, l2: L1/L2 regularisation factors of the output Dense kernel.
        dropoutRate: Dropout rate before the output layer.
        filters: Number of temporal convolution filters.
        maxnorm: Max-norm constraint of the temporal convolution kernel.
        maxnorm_last_layer: Max-norm constraint of the output Dense kernel.
        kernel_time_1: Temporal kernel length.
        strid_filter_time_1: Temporal stride of the convolution.
        bias_spatial: Whether the temporal convolution uses a bias.

    Returns:
        An uncompiled Keras ``Model`` mapping (Chans, Samples, 1) inputs to
        class probabilities.
    """
    input_main   = Input((Chans, Samples, 1),name='Input')

    block        = Conv2D(filters,(1,kernel_time_1),strides=(1,strid_filter_time_1),
                            use_bias=bias_spatial, name='Conv2D_1',
                            kernel_constraint = max_norm(maxnorm, axis=(0,1,2))
                            )(input_main)

    block        = BatchNormalization(epsilon=1e-05, momentum=0.1)(block)

    block        = Activation('elu')(block)

    block        = GFC(name="gfc")(block)  # (N, F, C, C)

    block        = get_triu()(block)  # (N, F, C*(C-1)/2, 1)

    # Pool over the whole filter axis: average connectivity across filters.
    block        = AveragePooling2D(pool_size=(block.shape[1],1),strides=(1,1))(block)

    block        = BatchNormalization(epsilon=1e-05, momentum=0.1)(block)

    block        = Activation('elu')(block)

    block        = Flatten(name='fc')(block)

    block        = Dropout(dropoutRate)(block)

    block        = Dense(nb_classes, kernel_regularizer=L1L2(l1=l1,l2=l2),name='logits',
                              kernel_constraint = max_norm(maxnorm_last_layer)
                              )(block)

    softmax      = Activation('softmax',name='output')(block)

    return Model(inputs=input_main, outputs=softmax)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.layers import Conv2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm

def EEGNet(nb_classes, Chans = 64, Samples = 128,
             dropoutRate = 0.5, kernLength = 64, F1 = 8,
             D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """Build the compact EEGNet convolutional classifier.

    Block 1:
    - Temporal Conv2D (F1 filters, 1 x kernLength).
    - Depthwise spatial Conv2D over all channels (depth multiplier D,
      max-norm 1).
    - BatchNorm, ELU, average pooling (1 x 4) and dropout.

    Block 2:
    - SeparableConv2D (F2 filters, 1 x 16).
    - BatchNorm, ELU, average pooling (1 x 8) and dropout.

    Head: Dense(nb_classes) with max-norm ``norm_rate``, then softmax.

    Args:
        nb_classes: Number of output classes.
        Chans, Samples: Input height (channels) and width (time samples).
        dropoutRate: Dropout rate of both blocks.
        kernLength: Temporal kernel length of the first convolution.
        F1, D, F2: Temporal filters, depth multiplier, separable filters.
        norm_rate: Max-norm constraint of the output Dense kernel.
        dropoutType: 'Dropout' or 'SpatialDropout2D'.

    Returns:
        An uncompiled Keras ``Model``.

    Raises:
        ValueError: If ``dropoutType`` is not one of the two supported names.
    """
    if dropoutType == 'SpatialDropout2D':
        dropout_layer = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropout_layer = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    input1   = Input(shape = (Chans, Samples, 1))

    # No `input_shape` here: the shape comes from `input1` in the functional
    # API, and Keras warns when it is passed to a non-input layer.
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   name='Conv2D_1',
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((Chans, 1), use_bias = False,
                                   name='Depth_wise_Conv2D_1',
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropout_layer(dropoutRate)(block1)

    block2       = SeparableConv2D(F2, (1, 16),
                                   name='Separable_Conv2D_1',
                                   use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropout_layer(dropoutRate)(block2)

    flatten      = Flatten(name = 'flatten')(block2)

    dense        = Dense(nb_classes, name = 'output',
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'out_activation')(dense)

    return Model(inputs=input1, outputs=softmax)

# Experiment

## Experiment configuration 

In [None]:
# Global experiment settings: RNG seed and number of subject ids (1..n_subjects).
seed, n_subjects = 23, 51

## Run experiment

In [None]:
# Install the GCPDS visualization package (IPython shell escape) and import its
# topoplot helper, used for the per-band topographic maps below.
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.visualizations.git

from gcpds.visualizations.topoplots import topoplot

In [None]:
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, RobustScaler
from scipy.stats import ks_2samp
from scipy.spatial.distance import squareform

# Seed TensorFlow's global RNG.
tf.random.set_seed(seed)

# Subject ids 1..n_subjects.
subjects = np.arange(n_subjects)+1

# Directory prefix of the BMOP motor pickles (load_PAIN concatenates it with the file name).
db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

# Keyword arguments for load_PAIN; 'sbj' is filled in per subject below.
# NOTE(review): this 4-100 Hz band differs from the [4, 60] band used by the
# topoplot cells; here only the channel count and `sex` are used, so it likely
# does not matter -- confirm.
load_args = dict(db = db,
            f_bank = np.asarray([[4., 100.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 256.0)

PATH = os.path.join(os.getcwd())

# Load subject 1 only; its X_train supplies the channel count C used below.
# NOTE(review): the `sbj == 18` guard can never fire for the list [1].
for sbj in [1]:
    if sbj == 18:
        continue
        
    print(f"\nSubject: {sbj}")
    
    load_args['sbj'] = sbj

    X_train, y_train, age, sex, fs, _ = load_PAIN(**load_args)

    # Add a trailing singleton axis: (trials, channels, time, 1).
    X_train = X_train[..., np.newaxis]
    Y_train = tf.keras.utils.to_categorical(y_train,num_classes=num_class)
    y_train = np.array([y[0] for y in y_train])

# NOTE(review): full_path, gfcnet_scores, ks, pvalue, Nfilters_, kernel_time_
# and i are not read anywhere in the visible part of this notebook.
full_path = os.path.join(PATH)

# len(subjects)-1 rows: one per subject, minus the skipped subject 18.
gfcnet_scores = np.zeros((len(subjects)-1,8))

C = X_train.shape[1]
ks = np.zeros((1, int(C*(C-1)/2)))
pvalue = np.zeros((1, int(C*(C-1)/2)))

Nfilters_ = np.zeros((len(subjects) - 1))
kernel_time_ = np.zeros((len(subjects) - 1))

i = 0

sbjs_info = []  # per-subject dicts of best-CV metrics
acc_comp = []   # (mean test accuracy, subject id, sex) tuples used for ranking

# Folder with the GFC cross-validation results (Subject{sbj}.p).
gfc_scores_256_path = os.path.join(os.getcwd(), 'GFC_Motor_256_Gamma60Hz')

for sbj in subjects:
    # Subject 18 is skipped throughout this notebook (reason not stated here).
    if sbj == 18:
        continue
        
    load_args['sbj'] = sbj
    # NOTE(review): the full EEG is filtered and resampled here only to read
    # `sex`; y_train below is not used afterwards.
    X_train, y_train, age, sex, fs, _ = load_PAIN(**load_args)
    y_train = np.array([y[0] for y in y_train])

    sbjInfo = {}
    sbjInfo["Subject"] = sbj

    with open(os.path.join(gfc_scores_256_path, f'Subject{sbj}.p'), 'rb') as f:
        cv = pickle.load(f)

    # Metrics of the best hyper-parameter setting (cv['best_index_']).
    sbjInfo["Mean Test Acc 256Hz"] = cv['mean_test_Accuracy'][cv['best_index_']]
    sbjInfo["Std Test Acc 256Hz"] = cv['std_test_Accuracy'][cv['best_index_']]
    sbjInfo["Mean Test Kappa 256Hz"] = cv['mean_test_Kappa'][cv['best_index_']]
    sbjInfo["Std Test Kappa 256Hz"] = cv['std_test_Kappa'][cv['best_index_']]
    sbjInfo["Mean Test AUC 256Hz"] = cv['mean_test_AUC'][cv['best_index_']]
    sbjInfo["Std Test AUC 256Hz"] = cv['std_test_AUC'][cv['best_index_']]

    sbjs_info.append(sbjInfo)
    acc_comp.append((cv['mean_test_Accuracy'][cv['best_index_']], sbj, sex[0]))

# The (accuracy, subject, sex) tuples sorted descending, best accuracy first.
ordScores = multi_sort(acc_comp)[0]

In [None]:
# Pick three best, three middle and three worst subjects by mean test accuracy
# (ordScores is sorted best-first; entry[1] is the subject id).
_mid = len(ordScores) // 2
best_subs = [entry[1] for entry in ordScores[:3]]
mid_subs = [entry[1] for entry in ordScores[_mid - 1:_mid + 2]]
worst_subs = [entry[1] for entry in ordScores[-3:]]

subs_to_analyze = dict(Good=best_subs, Mean=mid_subs, Bad=worst_subs)

# Function to Automate Calculation

In [None]:
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, RobustScaler
from scipy.stats import ks_2samp
from scipy.spatial.distance import squareform

from tf_keras_vis import layercam, gradcam_plus_plus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore, InactiveScore

def _minmax_channel_map(signals):
    """Return a per-channel map of ``signals`` scaled to [0, 1].

    ``signals`` is averaged over axis 0 (trials), summed over axis 1 of that
    mean (time), and the resulting vector is min-max normalised.
    """
    channel_map = np.sum(np.mean(signals, axis=0), axis=1)
    channel_map = channel_map - np.min(channel_map, axis=0)
    return channel_map / np.max(channel_map, axis=0)


def _band_channel_maps(signals, band_filters):
    """Return {band: channel map}, min-max scaled jointly across all bands.

    Each filter output is indexed with ``[..., 0, 0]`` (presumably the single
    window / single band slot of TimeFrequencyRpr -- confirm), averaged over
    trials and summed over time. A shared min/max keeps the bands comparable
    on one colour scale.
    """
    raw = {band: np.sum(np.mean(band_filter.transform(signals)[..., 0, 0], axis=0), axis=1)
           for band, band_filter in band_filters.items()}
    offset = np.min(np.concatenate(list(raw.values()), axis=0))
    shifted = {band: values - offset for band, values in raw.items()}
    scale = np.max(np.concatenate(list(shifted.values()), axis=0))
    return {band: values / scale for band, values in shifted.items()}


def _pain_rest_cams(model, X, y, cam_type):
    """Return (pain_cam, rest_cam) for every trial in ``X``.

    CAMs target class 0 (pain) and class 1 (rest) of ``model`` (with its output
    activation replaced by a linear one), starting the search for the
    penultimate conv layer at ``model.layers[-2]``. Both maps are then divided
    by each trial's maximum over the two maps, so they share one scale.
    """
    cam_calc = cam_type(model, model_modifier=ReplaceToLinear(), clone=False)
    cam_kwargs = dict(seed_input=X,
                      penultimate_layer=model.layers[-2],
                      seek_penultimate_conv_layer=True,
                      normalize_cam=False,
                      expand_cam=True)
    pain_cam = cam_calc(score=[CategoricalScore(list(np.zeros_like(y)))], **cam_kwargs)
    rest_cam = cam_calc(score=[CategoricalScore(list(np.ones_like(y)))], **cam_kwargs)
    max_val = np.max(np.concatenate((pain_cam[..., np.newaxis], rest_cam[..., np.newaxis]), axis=-1),
                     axis=(1, 2, 3))[:, np.newaxis, np.newaxis]
    return pain_cam / max_val, rest_cam / max_val


def _plot_band_topomaps(results, channels, title):
    """Draw and show the 4x5 topoplot grid for one subject.

    Rows, top to bottom:
    - EEG Pain
    - CAMs Pain
    - EEG Rest
    - CAMs Rest

    Columns: Theta, Alpha, Beta, Gamma, Full Band.
    """
    fig, ax = plt.subplot_mosaic([['ul', 'um', 'um1', 'um2', 'ur'], ['ll', 'lm', 'lm1', 'lm2', 'lr'], ['ml', 'mm', 'mm1', 'mm2', 'mr'], ['bl', 'bm', 'bm1', 'bm2', 'br']], figsize=(16, 8))
    # Mosaic row prefix of each class: (EEG row, CAM row). Each class must go
    # to its own rows only, otherwise Rest overwrites the Pain rows.
    class_rows = {"Pain": ("u", "l"), "Rest": ("m", "b")}
    band_columns = (("Theta", "l"), ("Alpha", "m"), ("Beta", "m1"),
                    ("Gamma", "m2"), ("Full Band", "r"))

    colorbar_source = None
    for label, (eeg_row, cam_row) in class_rows.items():
        print(f"Plotting for {label}")
        for kind, row in (("EEG", eeg_row), ("Cams", cam_row)):
            for band, column in band_columns:
                im = topoplot(results[label][kind][band], channels, cmap='viridis',
                              ax=ax[row + column], vlim=(0.0, 1.0), show=False)
                # Take the colorbar mappable from THIS figure (all maps share vlim).
                if colorbar_source is None:
                    colorbar_source = im

    ax['ll'].set_ylabel('CAMs Pain', size=20)
    ax['ul'].set_ylabel('EEG Pain', size=20)
    ax['bl'].set_ylabel('CAMs Rest', size=20)
    ax['ml'].set_ylabel('EEG Rest', size=20)
    ax['ul'].set_title('Theta', size=20)
    ax['um'].set_title('Alpha', size=20)
    ax['um1'].set_title('Beta', size=20)
    ax['um2'].set_title('Gamma', size=20)
    ax['ur'].set_title('Full Band', size=20)

    fig.suptitle(title)

    fig.tight_layout()
    fig.subplots_adjust(right=0.75)
    cbar_ax = fig.add_axes([0.75, 0.05, 0.02, 0.8])
    fig.colorbar(colorbar_source[0], cax=cbar_ax)

    plt.show()


def calculate_topoplots_by_band(db_: str = '../input/brain-mediators-of-pain-motor/', fs_: float = 256.0, n_subjects_: int = 51, freq_bank = (4, 60),
                                folder_name_: str = 'GFC', seed_: int = 23, subs_to_analyze = None, cam_type = gradcam_plus_plus.GradcamPlusPlus, gamma_upper: float = 60.0):
    """Plot per-band EEG and class-activation topomaps for selected subjects.

    For every subject in ``subs_to_analyze`` (group name -> subject ids;
    subject 18 is skipped), this function:
    1. Loads the data with ``load_PAIN``.
    2. Rebuilds ``GFC_triu_net_avg`` with the best hyper-parameters stored in
       ``{folder_name_}/Subject{sbj}.p``.
    3. Loads ``Subject{sbj}_weights.h5``.
    4. Computes pain/rest CAMs with ``cam_type``.
    5. Shows theta/alpha/beta/gamma/full-band topoplots of EEG and CAMs.

    Args:
        db_: Dataset directory prefix passed to ``load_PAIN``.
        fs_: Sampling rate the trials are resampled to. It is also the rate
            used to design the band filters.
        n_subjects_: Unused; kept for backward compatibility.
        freq_bank: (low, high) band passed to ``load_PAIN``.
        folder_name_: Folder (relative to cwd) with the CV results and weights.
        seed_: TensorFlow global seed.
        subs_to_analyze: Mapping group name -> iterable of subject ids.
            ``None`` means nothing to analyze.
        cam_type: tf_keras_vis CAM class to instantiate.
        gamma_upper: Upper edge (Hz) of the gamma band.
    """
    if subs_to_analyze is None:
        subs_to_analyze = {}

    tf.random.set_seed(seed_)
    num_class = 2

    load_args = dict(db = db_,
                f_bank = np.asarray([freq_bank]),
                vwt = np.asarray([[0.5,2.5]]),
                new_fs = fs_)

    scores_path = os.path.join(os.getcwd(), folder_name_)

    # load_PAIN resamples the trials to fs_ but returns the file's ORIGINAL
    # rate, so the band filters must be designed at fs_ to match the data.
    band_filters = {
        "Theta": TimeFrequencyRpr(sfreq = fs_, f_bank = np.array([[4., 8.]])),
        "Alpha": TimeFrequencyRpr(sfreq = fs_, f_bank = np.array([[8., 13.]])),
        "Beta":  TimeFrequencyRpr(sfreq = fs_, f_bank = np.array([[13., 32.]])),
        "Gamma": TimeFrequencyRpr(sfreq = fs_, f_bank = np.array([[30., gamma_upper]])),
    }

    for type_, group_subjects in subs_to_analyze.items():
        for sbj in group_subjects:
            if sbj == 18:
                continue

            print(f"\n\nStarting calculation for sbj: {sbj}")

            load_args['sbj'] = sbj
            X_train, y_train, age, sex, _, info = load_PAIN(**load_args)

            X_train = X_train[..., np.newaxis]
            y_train = np.array([y[0] for y in y_train])
            channels = info['ch_names']

            with open(os.path.join(scores_path, f"Subject{sbj}.p"), 'rb') as f:
                cv = pickle.load(f)
            best_params = cv['params'][cv['best_index_']]

            model = GFC_triu_net_avg(nb_classes = num_class,
                    Chans = X_train.shape[1],
                    Samples = X_train.shape[2],
                    filters = int(best_params['filters']),
                    kernel_time_1 = int(best_params['kernel_time_1']))
            model.load_weights(os.path.join(scores_path, f"Subject{sbj}_weights.h5"))

            pain_cam, rest_cam = _pain_rest_cams(model, X_train, y_train, cam_type)

            per_class_inputs = {
                "Pain": {"EEG": np.array(X_train[y_train == 0][..., 0]), "Cams": pain_cam},
                "Rest": {"EEG": np.array(X_train[y_train == 1][..., 0]), "Cams": rest_cam},
            }

            results = {}
            for label, by_kind in per_class_inputs.items():
                print(f"Calculating results for {label}")
                results[label] = {}
                for kind, signals in by_kind.items():
                    maps = _band_channel_maps(signals, band_filters)
                    maps["Full Band"] = _minmax_channel_map(signals)
                    results[label][kind] = maps

            _plot_band_topomaps(
                results, channels,
                f'{type_} - Subject: {sbj} - Sex: {sex[0].upper()} - Age: {age[0]} - EEG vs Cams{" "*70}')

### Calculate for FS = 256 Hz and Freq Bank = [4, 60]

### KCS-GFCNet

In [None]:
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, RobustScaler
from scipy.stats import ks_2samp
from scipy.spatial.distance import squareform

from tf_keras_vis import layercam, gradcam_plus_plus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore, InactiveScore

# Re-seed TensorFlow's global RNG before the per-subject CAM computations below.
tf.random.set_seed(seed)

# Subject ids 1..n_subjects.
subjects = np.arange(n_subjects)+1

# Directory prefix of the BMOP motor pickles (load_PAIN concatenates it with the file name).
db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

# load_PAIN arguments for the 4-60 Hz band at 256 Hz; 'sbj' is set per subject below.
load_args = dict(db = db,
            f_bank = np.asarray([[4., 60.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 256.0)

# Folder holding Subject{sbj}.p (CV results) and Subject{sbj}_weights.h5 files.
scores_path = os.path.join(os.getcwd(), 'GFC_Motor_256_Gamma60Hz')

for type_ in list(subs_to_analyze.keys()):
    for sbj in subs_to_analyze[type_]:
        if sbj == 18:
            continue

        load_args['sbj'] = sbj

        X_train, y_train, age, sex, fs, info = load_PAIN(**load_args)

        X_train = X_train[..., np.newaxis]
        Y_train = tf.keras.utils.to_categorical(y_train,num_classes=num_class)
        y_train = np.array([y[0] for y in y_train])

        channels = info['ch_names']

        model_weights_path = os.path.join(scores_path, f"Subject{sbj}_weights.h5")
        params_path = os.path.join(scores_path, f"Subject{sbj}.p")
    
        with open(params_path, 'rb') as f:
            cv = pickle.load(f)
    
        nFilters_ = cv['params'][cv['best_index_']]['filters']
        kernel_time_ = cv['params'][cv['best_index_']]['kernel_time_1']
    
        model = GFC_triu_net_avg(nb_classes = num_class,
                Chans = X_train.shape[1],
                Samples = X_train.shape[2],
                filters = int(nFilters_), 
                kernel_time_1 =int(kernel_time_))
    
        model.load_weights(model_weights_path)
    
        pain_scores = [CategoricalScore(list(np.zeros_like(y_train)))]
        rest_scores = [CategoricalScore(list(np.ones_like(y_train)))]
    
        no_softmax_model = tf.keras.Model(inputs=model.input, outputs=model.layers[-2].output)
        
        lc = gradcam_plus_plus.GradcamPlusPlus(model, model_modifier=ReplaceToLinear(), clone=False)
    
        Penultimate_layer = model.layers[-2]
    
        pain_cam = lc(score = pain_scores,
                      seed_input = X_train,
                      penultimate_layer = Penultimate_layer,
                      seek_penultimate_conv_layer = True,
                      normalize_cam = False,
                      expand_cam = True)

        rest_cam = lc(score = rest_scores,
                      seed_input = X_train,
                      penultimate_layer = Penultimate_layer,
                      seek_penultimate_conv_layer = True,
                      normalize_cam = False,
                      expand_cam = True)

        max_val_cams = np.max(np.concatenate((pain_cam[...,np.newaxis],
                                                      rest_cam[...,np.newaxis]), axis=-1), axis=(1,2,3))[:,np.newaxis,np.newaxis]

        pain_cam/=max_val_cams
        rest_cam/=max_val_cams

        theta_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[4., 8.]]))
        alpha_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[8., 13.]]))
        beta_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[13., 32.]]))
        gamma_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[30., 60.]]))

        pain_signals = np.array(X_train[y_train == 0][...,0])
        rest_signals = np.array(X_train[y_train == 1][...,0])

        ##########
        mean_pain_signals = np.sum(np.mean(pain_signals, axis = 0), axis=1)
        min_mps = np.min(mean_pain_signals, axis=0)
        mean_pain_signals-=min_mps
        max_mps = np.max(mean_pain_signals, axis=0)
        mean_pain_signals/=max_mps

        mean_pain_cams = np.sum(np.mean(pain_cam, axis = 0), axis=1)
        min_mpc = np.min(mean_pain_cams, axis=0)
        mean_pain_cams-=min_mpc
        max_mpc = np.max(mean_pain_cams, axis=0)
        mean_pain_cams/=max_mpc

        mean_rest_signals = np.sum(np.mean(rest_signals, axis = 0), axis=1)
        min_mrs = np.min(mean_rest_signals, axis=0)
        mean_rest_signals-=min_mrs
        max_mrs = np.max(mean_rest_signals, axis=0)
        mean_rest_signals/=max_mrs

        mean_rest_cams = np.sum(np.mean(rest_cam, axis = 0), axis=1)
        min_mrc = np.min(mean_rest_cams, axis=0)
        mean_rest_cams-=min_mrc
        max_mrc = np.max(mean_rest_cams, axis=0)
        mean_rest_cams/=max_mrc   
        ##########

        pain_signal_theta = np.sum(np.mean(theta_filter.transform(pain_signals)[...,0,0], axis = 0), axis=1)
        pain_signal_alpha = np.sum(np.mean(alpha_filter.transform(pain_signals)[...,0,0], axis = 0), axis=1)
        pain_signal_beta = np.sum(np.mean(beta_filter.transform(pain_signals)[...,0,0], axis = 0), axis=1)
        pain_signal_gamma = np.sum(np.mean(gamma_filter.transform(pain_signals)[...,0,0], axis = 0), axis=1)
        pain_min_signal = np.min(np.concatenate((pain_signal_theta, pain_signal_alpha, pain_signal_beta, pain_signal_gamma), axis=0))

        pain_signal_theta -= pain_min_signal
        pain_signal_alpha -= pain_min_signal
        pain_signal_beta -= pain_min_signal
        pain_signal_gamma -= pain_min_signal
        pain_max_signal = np.max(np.concatenate((pain_signal_theta, pain_signal_alpha, pain_signal_beta, pain_signal_gamma), axis=0))     

        pain_signal_theta /= pain_max_signal
        pain_signal_alpha /= pain_max_signal
        pain_signal_beta /= pain_max_signal
        pain_signal_gamma /= pain_max_signal

        ####
        rest_signal_theta = np.sum(np.mean(theta_filter.transform(rest_signals)[...,0,0], axis = 0), axis=1)
        rest_signal_alpha = np.sum(np.mean(alpha_filter.transform(rest_signals)[...,0,0], axis = 0), axis=1)
        rest_signal_beta = np.sum(np.mean(beta_filter.transform(rest_signals)[...,0,0], axis = 0), axis=1)
        rest_signal_gamma = np.sum(np.mean(gamma_filter.transform(rest_signals)[...,0,0], axis = 0), axis=1)
        rest_min_signal = np.min(np.concatenate((rest_signal_theta, rest_signal_alpha, rest_signal_beta, rest_signal_gamma), axis=0))

        rest_signal_theta -= rest_min_signal
        rest_signal_alpha -= rest_min_signal
        rest_signal_beta -= rest_min_signal
        rest_signal_gamma -= rest_min_signal
        rest_max_signal = np.max(np.concatenate((rest_signal_theta, rest_signal_alpha, rest_signal_beta, rest_signal_gamma), axis=0))     

        rest_signal_theta /= rest_max_signal
        rest_signal_alpha /= rest_max_signal
        rest_signal_beta /= rest_max_signal
        rest_signal_gamma /= rest_max_signal
        ####

        pain_cam_theta = np.sum(np.mean(theta_filter.transform(pain_cam)[...,0,0], axis = 0), axis=1)
        pain_cam_alpha = np.sum(np.mean(alpha_filter.transform(pain_cam)[...,0,0], axis = 0), axis=1)
        pain_cam_beta = np.sum(np.mean(beta_filter.transform(pain_cam)[...,0,0], axis = 0), axis=1)
        pain_cam_gamma = np.sum(np.mean(gamma_filter.transform(pain_cam)[...,0,0], axis = 0), axis=1)
        pain_min_cam = np.min(np.concatenate((pain_cam_theta, pain_cam_alpha, pain_cam_beta, pain_cam_gamma), axis=0))

        pain_cam_theta -= pain_min_cam
        pain_cam_alpha -= pain_min_cam
        pain_cam_beta -= pain_min_cam
        pain_cam_gamma -= pain_min_cam
        pain_max_cam = np.max(np.concatenate((pain_cam_theta, pain_cam_alpha, pain_cam_beta, pain_cam_gamma), axis=0))

        pain_cam_theta /= pain_max_cam
        pain_cam_alpha /= pain_max_cam
        pain_cam_beta /= pain_max_cam
        pain_cam_gamma /= pain_max_cam

        ####
        rest_cam_theta = np.sum(np.mean(theta_filter.transform(rest_cam)[...,0,0], axis = 0), axis=1)
        rest_cam_alpha = np.sum(np.mean(alpha_filter.transform(rest_cam)[...,0,0], axis = 0), axis=1)
        rest_cam_beta = np.sum(np.mean(beta_filter.transform(rest_cam)[...,0,0], axis = 0), axis=1)
        rest_cam_gamma = np.sum(np.mean(gamma_filter.transform(rest_cam)[...,0,0], axis = 0), axis=1)
        rest_min_cam = np.min(np.concatenate((rest_cam_theta, rest_cam_alpha, rest_cam_beta, rest_cam_gamma), axis=0))

        rest_cam_theta -= rest_min_cam
        rest_cam_alpha -= rest_min_cam
        rest_cam_beta -= rest_min_cam
        rest_cam_gamma -= rest_min_cam
        rest_max_cam = np.max(np.concatenate((rest_cam_theta, rest_cam_alpha, rest_cam_beta), axis=0))

        rest_cam_theta /= rest_max_cam
        rest_cam_alpha /= rest_max_cam
        rest_cam_beta /= rest_max_cam
        rest_cam_gamma /= rest_max_cam
        ####

        fig, ax = plt.subplot_mosaic([['ul', 'um', 'um1', 'um2', 'ur'], ['ll', 'lm', 'lm1', 'lm2', 'lr'], ['ml', 'mm', 'mm1', 'mm2', 'mr'], ['bl', 'bm', 'bm1', 'bm2', 'br']], figsize=(16, 12), dpi=120)

        im_00 = topoplot(pain_signal_theta, channels, cmap='viridis', ax=ax['ul'], vlim=(0.0, 1.0), show=False)
        topoplot(pain_signal_alpha, channels, cmap='viridis', ax=ax['um'], vlim=(0.0, 1.0), show=False)
        topoplot(pain_cam_beta, channels, cmap='viridis', ax=ax['um1'], vlim=(0.0, 1.0), show=False)
        topoplot(pain_signal_gamma, channels, cmap='viridis', ax=ax['um2'], vlim=(0.0, 1.0), show=False)
        topoplot(mean_pain_signals, channels, cmap='viridis', ax=ax['ur'], vlim=(0.0, 1.0), show=False)

        topoplot(rest_signal_theta, channels, cmap='viridis', ax=ax['ml'], vlim=(0.0, 1.0), show=False)
        topoplot(rest_signal_alpha, channels, cmap='viridis', ax=ax['mm'], vlim=(0.0, 1.0), show=False)
        topoplot(rest_cam_beta, channels, cmap='viridis', ax=ax['mm1'], vlim=(0.0, 1.0), show=False)
        topoplot(rest_signal_gamma, channels, cmap='viridis', ax=ax['mm2'], vlim=(0.0, 1.0), show=False)
        topoplot(mean_pain_cams, channels, cmap='viridis', ax=ax['mr'], vlim=(0.0, 1.0), show=False)
        
        topoplot(pain_cam_theta, channels, cmap='viridis', ax=ax['ll'], vlim=(0.0, 1.0), show=False)
        topoplot(pain_cam_alpha, channels, cmap='viridis', ax=ax['lm'], vlim=(0.0, 1.0), show=False)
        topoplot(pain_signal_beta, channels, cmap='viridis', ax=ax['lm1'], vlim=(0.0, 1.0), show=False)
        topoplot(pain_cam_gamma, channels, cmap='viridis', ax=ax['lm2'], vlim=(0.0, 1.0), show=False)
        topoplot(mean_rest_signals, channels, cmap='viridis', ax=ax['lr'], vlim=(0.0, 1.0), show=False)

        topoplot(rest_cam_theta, channels, cmap='viridis', ax=ax['bl'], vlim=(0.0, 1.0), show=False)
        topoplot(rest_cam_alpha, channels, cmap='viridis', ax=ax['bm'], vlim=(0.0, 1.0), show=False)
        topoplot(rest_cam_beta, channels, cmap='viridis', ax=ax['bm1'], vlim=(0.0, 1.0), show=False)
        topoplot(rest_cam_gamma, channels, cmap='viridis', ax=ax['bm2'], vlim=(0.0, 1.0), show=False)
        topoplot(mean_rest_cams, channels, cmap='viridis', ax=ax['br'], vlim=(0.0, 1.0), show=False)

        ax['ll'].set_ylabel(f'CAMs Pain', size=20)
        ax['ul'].set_ylabel(f'EEG Pain', size=20)
        ax['bl'].set_ylabel(f'CAMs Rest', size=20)
        ax['ml'].set_ylabel(f'EEG Rest', size=20)
        ax['ul'].set_title(f'Theta', size=20)
        ax['um'].set_title(f'Alpha', size=20)
        ax['um1'].set_title(f'Beta', size=20)
        ax['um2'].set_title(f'Gamma', size=20)
        ax['ur'].set_title(f'Full Band', size=20)

        fig.suptitle(f'{type_} - Subject: {sbj} - Sex: {sex[0].upper()} - Age: {age[0]} - EEG vs Cams{" "*70}')

        fig.tight_layout()
        fig.subplots_adjust(right=0.75)
        cbar_ax = fig.add_axes([0.75, 0.05, 0.02, 0.8])
        fig.colorbar(im_00[0], cax=cbar_ax)
        plt.savefig(f'Subject{sbj}Cams_{type_}.png')

        plt.show()

In [None]:
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, RobustScaler
from scipy.stats import ks_2samp
from scipy.spatial.distance import squareform

from tf_keras_vis import layercam, gradcam_plus_plus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore, InactiveScore

tf.random.set_seed(seed)

subjects = np.arange(n_subjects)+1

db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

load_args = dict(db = db,
            f_bank = np.asarray([[4., 60.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 256.0)

scores_path = os.path.join(os.getcwd(), 'GFC_Motor_256_Gamma60Hz')

# Figure layout. Mosaic rows (top to bottom) are described by _CAM_ROW_LABELS,
# mosaic columns (left to right) by _CAM_COL_TITLES.
_CAM_MOSAIC = [['ul', 'um', 'um1', 'um2', 'ur'],
               ['ll', 'lm', 'lm1', 'lm2', 'lr'],
               ['ml', 'mm', 'mm1', 'mm2', 'mr'],
               ['bl', 'bm', 'bm1', 'bm2', 'br']]
_CAM_ROW_LABELS = ['EEG Pain', 'CAMs Pain', 'EEG Rest', 'CAMs Rest']
_CAM_COL_TITLES = ['Theta', 'Alpha', 'Beta', 'Gamma', 'Full Band']


def _channel_map(trials, band_filter=None):
    """Collapse a stack of trials to one value per channel.

    If ``band_filter`` is given, ``trials`` is first passed through its
    ``transform``. The result is averaged over axis 0 (trials) and then summed
    over axis 1 of the averaged array (time).

    ``trials`` is assumed to be (trials, channels, samples), as produced by
    ``X_train[..., 0]`` below. TODO: confirm that the CAM arrays share that
    layout.
    """
    if band_filter is not None:
        # Same [..., 0, 0] indexing as the original code. It presumably drops
        # the band/extra axes added by TimeFrequencyRpr.transform (confirm).
        trials = band_filter.transform(trials)[..., 0, 0]
    return np.sum(np.mean(trials, axis=0), axis=1)


def _joint_minmax(*maps):
    """Scale every array in ``maps`` to [0, 1] using one shared min and max.

    Sharing the bounds keeps the returned maps comparable with each other.
    Shift by the global minimum, then divide by the maximum of the shifted
    values (i.e. the global range).

    Returns a list in the same order as ``maps``.
    """
    lowest = np.min(np.concatenate(maps, axis=0))
    shifted = [m - lowest for m in maps]
    highest = np.max(np.concatenate(shifted, axis=0))
    return [m / highest for m in shifted]


def _plot_cam_grid(rows, channels, title):
    """Draw the 4x5 grid of topomaps and return the figure.

    rows: four lists of five per-channel maps. Rows follow _CAM_ROW_LABELS;
        columns follow _CAM_COL_TITLES.
    channels: channel names handed to ``topoplot``.
    title: figure suptitle.

    Every map uses a fixed (0, 1) colour range. The shared colour bar comes
    from the top-left map.
    """
    fig, ax = plt.subplot_mosaic(_CAM_MOSAIC, figsize=(16, 12), dpi=120)
    images = {}
    for keys, maps in zip(_CAM_MOSAIC, rows):
        for key, topo in zip(keys, maps):
            images[key] = topoplot(topo, channels, cmap='viridis', ax=ax[key],
                                   vlim=(0.0, 1.0), show=False)

    for keys, label in zip(_CAM_MOSAIC, _CAM_ROW_LABELS):
        ax[keys[0]].set_ylabel(label, size=20)
    for key, band in zip(_CAM_MOSAIC[0], _CAM_COL_TITLES):
        ax[key].set_title(band, size=20)

    fig.suptitle(title)

    fig.tight_layout()
    fig.subplots_adjust(right=0.75)
    cbar_ax = fig.add_axes([0.75, 0.05, 0.02, 0.8])
    fig.colorbar(images['ul'][0], cax=cbar_ax)
    return fig


for type_ in list(subs_to_analyze.keys()):
    for sbj in subs_to_analyze[type_]:
        if sbj == 18:
            continue

        load_args['sbj'] = sbj

        X_train, y_train, age, sex, fs, info = load_PAIN(**load_args)

        # Trailing singleton axis: model input is (trials, Chans, Samples, 1).
        X_train = X_train[..., np.newaxis]
        Y_train = tf.keras.utils.to_categorical(y_train,num_classes=num_class)
        y_train = np.array([y[0] for y in y_train])

        channels = info['ch_names']

        model_weights_path = os.path.join(scores_path, f"Subject{sbj}_weights.h5")
        params_path = os.path.join(scores_path, f"Subject{sbj}.p")

        with open(params_path, 'rb') as f:
            cv = pickle.load(f)

        # Rebuild the model with the hyper-parameters of the best CV candidate.
        best_params = cv['params'][cv['best_index_']]
        nFilters_ = best_params['filters']
        kernel_time_ = best_params['kernel_time_1']

        model = GFC_triu_net_avg(nb_classes = num_class,
                Chans = X_train.shape[1],
                Samples = X_train.shape[2],
                filters = int(nFilters_),
                kernel_time_1 =int(kernel_time_))

        model.load_weights(model_weights_path)

        # Score class 0 (pain) / class 1 (rest) for every input trial.
        pain_scores = [CategoricalScore(list(np.zeros_like(y_train)))]
        rest_scores = [CategoricalScore(list(np.ones_like(y_train)))]

        no_softmax_model = tf.keras.Model(inputs=model.input, outputs=model.layers[-2].output)

        lc = gradcam_plus_plus.GradcamPlusPlus(model, model_modifier=ReplaceToLinear(), clone=False)
        # lc = layercam.Layercam(model, model_modifier=ReplaceToLinear(), clone=False)

        Penultimate_layer = model.layers[-2]

        cam_args = dict(seed_input = X_train,
                        penultimate_layer = Penultimate_layer,
                        seek_penultimate_conv_layer = True,
                        normalize_cam = False,
                        expand_cam = True)
        pain_cam = lc(score = pain_scores, **cam_args)
        rest_cam = lc(score = rest_scores, **cam_args)

        # Per-trial normalisation: divide both CAMs of a trial by the larger
        # of their two maxima, so pain and rest stay comparable per trial.
        max_val_cams = np.max(np.concatenate((pain_cam[...,np.newaxis],
                                              rest_cam[...,np.newaxis]), axis=-1), axis=(1,2,3))[:,np.newaxis,np.newaxis]
        # An all-zero trial would give 0/0 = NaN and poison the trial average
        # below. Divide such trials by 1 instead; they stay all-zero.
        max_val_cams[max_val_cams == 0] = 1

        pain_cam/=max_val_cams
        rest_cam/=max_val_cams

        theta_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[4., 8.]]))
        alpha_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[8., 13.]]))
        beta_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[13., 32.]]))
        gamma_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[30., 60.]]))
        band_filters = (theta_filter, alpha_filter, beta_filter, gamma_filter)

        pain_signals = np.array(X_train[y_train == 0][...,0])
        rest_signals = np.array(X_train[y_train == 1][...,0])

        # Full-band maps. Pain and rest EEG share one scale; pain and rest
        # CAMs share another.
        mean_pain_signals, mean_rest_signals = _joint_minmax(
            _channel_map(pain_signals), _channel_map(rest_signals))
        mean_pain_cams, mean_rest_cams = _joint_minmax(
            _channel_map(pain_cam), _channel_map(rest_cam))

        # Band maps. All eight CAM maps share one scale; all eight EEG maps
        # share another.
        (pain_cam_theta, pain_cam_alpha, pain_cam_beta, pain_cam_gamma,
         rest_cam_theta, rest_cam_alpha, rest_cam_beta, rest_cam_gamma) = _joint_minmax(
            *(_channel_map(pain_cam, bf) for bf in band_filters),
            *(_channel_map(rest_cam, bf) for bf in band_filters))

        (pain_signal_theta, pain_signal_alpha, pain_signal_beta, pain_signal_gamma,
         rest_signal_theta, rest_signal_alpha, rest_signal_beta, rest_signal_gamma) = _joint_minmax(
            *(_channel_map(pain_signals, bf) for bf in band_filters),
            *(_channel_map(rest_signals, bf) for bf in band_filters))

        # One row per _CAM_ROW_LABELS entry; columns are theta, alpha, beta,
        # gamma, full band.
        fig = _plot_cam_grid(
            [[pain_signal_theta, pain_signal_alpha, pain_signal_beta, pain_signal_gamma, mean_pain_signals],
             [pain_cam_theta, pain_cam_alpha, pain_cam_beta, pain_cam_gamma, mean_pain_cams],
             [rest_signal_theta, rest_signal_alpha, rest_signal_beta, rest_signal_gamma, mean_rest_signals],
             [rest_cam_theta, rest_cam_alpha, rest_cam_beta, rest_cam_gamma, mean_rest_cams]],
            channels,
            f'{type_} - Subject: {sbj} - Sex: {sex[0].upper()} - Age: {age[0]} - EEG vs Cams{" "*70}')
#         plt.savefig(f'Subject{sbj}Cams_{type_}.png')

        plt.show()

### Groups

In [None]:
import shutil

FILEID = "17kltykY7KrLKs0w2UFcMMY0CAL2ZZrhB"

contents = os.listdir(os.getcwd())

folder = 'Groups_Motor256_Gamma60Hz'

if 'Groups_Motor256_Gamma60Hz.zip' not in contents:
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$FILEID -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$FILEID -O Groups_Motor256_Gamma60Hz.zip && rm -rf /tmp/cookies.txt
    !unzip Groups_Motor256_Gamma60Hz.zip
    os.makedirs(folder)
    
    contents = os.listdir(os.getcwd())

    for content in contents:
        if '.p' in content or '.h5' in content:
            shutil.move(os.path.join(os.getcwd(), content), os.path.join(os.getcwd(), folder, content))

In [None]:
# Subject groups for the group-level analysis: key -> {"Name", "Subjects", "Color"}.
# Insertion order matters, because the group cells iterate `groups` in this
# order. Subject 18 belongs to no group; the per-subject cell skips it too.
groups = {
    key: {"Name": name, "Subjects": members, "Color": color}
    for key, name, members, color in (
        ("G3", "Group3",
         [9, 15, 16, 20, 21, 23, 25, 27, 28, 29, 30],
         "Yellow"),
        ("G2", "Group2",
         [1, 4, 5, 6, 8, 10, 12, 13, 14, 17, 19, 26, 32, 33, 34, 42, 45, 46, 47, 49, 51],
         "Light Blue"),
        ("G1", "Group1",
         [2, 3, 7, 11, 24, 35, 36, 37, 38, 39, 40, 41, 44, 50],
         "Dark Blue"),
        ("G4", "Group4",
         [22, 31, 43, 48],
         "Brown"),
    )
}

In [None]:
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, RobustScaler
from scipy.stats import ks_2samp
from scipy.spatial.distance import squareform

from tf_keras_vis import layercam, gradcam_plus_plus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore, InactiveScore

tf.random.set_seed(seed)

subjects = np.arange(n_subjects)+1

db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

load_args = dict(db = db,
            f_bank = np.asarray([[4., 60.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 256.0)

scores_path = os.path.join(os.getcwd(), 'Groups_Motor256_Gamma60Hz')

# cg numbers the groups in the iteration order of `groups`. The Roman numeral
# in each figure title is therefore the group's position in that dict, not
# the digit in its key.
cg = 1

gInds = {1: 'I', 2: 'II', 3: 'III', 4: 'IV', 5: 'V'}

groups_keys = list(groups.keys())

# Figure layout. Mosaic rows (top to bottom) are described by _CAM_ROW_LABELS,
# mosaic columns (left to right) by _CAM_COL_TITLES.
_CAM_MOSAIC = [['ul', 'um', 'um1', 'um2', 'ur'],
               ['ll', 'lm', 'lm1', 'lm2', 'lr'],
               ['ml', 'mm', 'mm1', 'mm2', 'mr'],
               ['bl', 'bm', 'bm1', 'bm2', 'br']]
_CAM_ROW_LABELS = ['EEG Pain', 'CAMs Pain', 'EEG Rest', 'CAMs Rest']
_CAM_COL_TITLES = ['Theta', 'Alpha', 'Beta', 'Gamma', 'Full Band']


def _channel_map(trials, band_filter=None):
    """Collapse a stack of trials to one value per channel.

    If ``band_filter`` is given, ``trials`` is first passed through its
    ``transform``. The result is averaged over axis 0 (trials) and then summed
    over axis 1 of the averaged array (time).

    ``trials`` is assumed to be (trials, channels, samples), as produced by
    ``X_train[..., 0]`` below. TODO: confirm that the CAM arrays share that
    layout.
    """
    if band_filter is not None:
        # Same [..., 0, 0] indexing as the original code. It presumably drops
        # the band/extra axes added by TimeFrequencyRpr.transform (confirm).
        trials = band_filter.transform(trials)[..., 0, 0]
    return np.sum(np.mean(trials, axis=0), axis=1)


def _joint_minmax(*maps):
    """Scale every array in ``maps`` to [0, 1] using one shared min and max.

    Sharing the bounds keeps the returned maps comparable with each other.
    Shift by the global minimum, then divide by the maximum of the shifted
    values (i.e. the global range).

    Returns a list in the same order as ``maps``.
    """
    lowest = np.min(np.concatenate(maps, axis=0))
    shifted = [m - lowest for m in maps]
    highest = np.max(np.concatenate(shifted, axis=0))
    return [m / highest for m in shifted]


def _plot_cam_grid(rows, channels, title):
    """Draw the 4x5 grid of topomaps and return the figure.

    rows: four lists of five per-channel maps. Rows follow _CAM_ROW_LABELS;
        columns follow _CAM_COL_TITLES.
    channels: channel names handed to ``topoplot``.
    title: figure suptitle.

    Every map uses a fixed (0, 1) colour range. The shared colour bar comes
    from the top-left map.
    """
    fig, ax = plt.subplot_mosaic(_CAM_MOSAIC, figsize=(16, 12), dpi=120)
    images = {}
    for keys, maps in zip(_CAM_MOSAIC, rows):
        for key, topo in zip(keys, maps):
            images[key] = topoplot(topo, channels, cmap='viridis', ax=ax[key],
                                   vlim=(0.0, 1.0), show=False)

    for keys, label in zip(_CAM_MOSAIC, _CAM_ROW_LABELS):
        ax[keys[0]].set_ylabel(label, size=20)
    for key, band in zip(_CAM_MOSAIC[0], _CAM_COL_TITLES):
        ax[key].set_title(band, size=20)

    fig.suptitle(title)

    fig.tight_layout()
    fig.subplots_adjust(right=0.75)
    cbar_ax = fig.add_axes([0.75, 0.05, 0.02, 0.8])
    fig.colorbar(images['ul'][0], cax=cbar_ax)
    return fig


for group in groups_keys:
    group_name = groups[group]["Name"]

    group_subs = groups[group]["Subjects"]

    # Pool the trials of every subject in the group. Collect per-subject
    # arrays and concatenate once, rather than re-copying the growing array
    # for each subject. age/sex/fs/info keep the last subject's values.
    X_parts, y_parts = [], []
    for sbj in group_subs:
        print(f"Loading subject: {sbj}\n")
        load_args['sbj'] = sbj
        X_sbj, y_sbj, age, sex, fs, info = load_PAIN(**load_args)
        X_parts.append(X_sbj)
        y_parts.append(y_sbj)
    X_train = np.concatenate(X_parts, axis = 0)
    y_train = np.concatenate(y_parts, axis = 0)

    # Trailing singleton axis: model input is (trials, Chans, Samples, 1).
    X_train = X_train[..., np.newaxis]
    Y_train = tf.keras.utils.to_categorical(y_train,num_classes=num_class)
    y_train = np.array([y[0] for y in y_train])

    channels = info['ch_names']

    model_weights_path = os.path.join(scores_path, f"{group_name}_weights.h5")
    params_path = os.path.join(scores_path, f"{group_name}.p")

    with open(params_path, 'rb') as f:
        cv = pickle.load(f)

    # Rebuild the model with the hyper-parameters of the best CV candidate.
    best_params = cv['params'][cv['best_index_']]
    nFilters_ = best_params['filters']
    kernel_time_ = best_params['kernel_time_1']

    model = GFC_triu_net_avg(nb_classes = num_class,
            Chans = X_train.shape[1],
            Samples = X_train.shape[2],
            filters = int(nFilters_),
            kernel_time_1 =int(kernel_time_))

    model.load_weights(model_weights_path)

    # Score class 0 (pain) / class 1 (rest) for every input trial.
    pain_scores = [CategoricalScore(list(np.zeros_like(y_train)))]
    rest_scores = [CategoricalScore(list(np.ones_like(y_train)))]

    no_softmax_model = tf.keras.Model(inputs=model.input, outputs=model.layers[-2].output)

#     lc = gradcam_plus_plus.GradcamPlusPlus(model, model_modifier=ReplaceToLinear(), clone=False)
    lc = layercam.Layercam(model, model_modifier=ReplaceToLinear(), clone=False)

    Penultimate_layer = model.layers[-2]

    cam_args = dict(seed_input = X_train,
                    penultimate_layer = Penultimate_layer,
                    seek_penultimate_conv_layer = True,
                    normalize_cam = False,
                    expand_cam = True)
    pain_cam = lc(score = pain_scores, **cam_args)
    rest_cam = lc(score = rest_scores, **cam_args)

    # Per-trial normalisation: divide both CAMs of a trial by the larger of
    # their two maxima, so pain and rest stay comparable per trial.
    max_val_cams = np.max(np.concatenate((pain_cam[...,np.newaxis],
                                          rest_cam[...,np.newaxis]), axis=-1), axis=(1,2,3))[:,np.newaxis,np.newaxis]
    # An all-zero trial would give 0/0 = NaN and poison the trial average
    # below. Divide such trials by 1 instead; they stay all-zero.
    max_val_cams[max_val_cams == 0] = 1

    pain_cam/=max_val_cams
    rest_cam/=max_val_cams

    theta_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[4., 8.]]))
    alpha_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[8., 13.]]))
    beta_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[13., 32.]]))
    gamma_filter = TimeFrequencyRpr(sfreq = fs, f_bank = np.array([[30., 60.]]))
    band_filters = (theta_filter, alpha_filter, beta_filter, gamma_filter)

    pain_signals = np.array(X_train[y_train == 0][...,0])
    rest_signals = np.array(X_train[y_train == 1][...,0])

    # Full-band maps, each scaled to [0, 1] on its own.
    [mean_pain_signals] = _joint_minmax(_channel_map(pain_signals))
    [mean_pain_cams] = _joint_minmax(_channel_map(pain_cam))
    [mean_rest_signals] = _joint_minmax(_channel_map(rest_signals))
    [mean_rest_cams] = _joint_minmax(_channel_map(rest_cam))

    # Band maps. The four bands of each (source, condition) pair share one
    # scale, and all four bands (gamma included) set that scale's min and max.
    pain_signal_theta, pain_signal_alpha, pain_signal_beta, pain_signal_gamma = _joint_minmax(
        *(_channel_map(pain_signals, bf) for bf in band_filters))
    rest_signal_theta, rest_signal_alpha, rest_signal_beta, rest_signal_gamma = _joint_minmax(
        *(_channel_map(rest_signals, bf) for bf in band_filters))
    pain_cam_theta, pain_cam_alpha, pain_cam_beta, pain_cam_gamma = _joint_minmax(
        *(_channel_map(pain_cam, bf) for bf in band_filters))
    rest_cam_theta, rest_cam_alpha, rest_cam_beta, rest_cam_gamma = _joint_minmax(
        *(_channel_map(rest_cam, bf) for bf in band_filters))

    # One row per _CAM_ROW_LABELS entry; columns are theta, alpha, beta,
    # gamma, full band.
    fig = _plot_cam_grid(
        [[pain_signal_theta, pain_signal_alpha, pain_signal_beta, pain_signal_gamma, mean_pain_signals],
         [pain_cam_theta, pain_cam_alpha, pain_cam_beta, pain_cam_gamma, mean_pain_cams],
         [rest_signal_theta, rest_signal_alpha, rest_signal_beta, rest_signal_gamma, mean_rest_signals],
         [rest_cam_theta, rest_cam_alpha, rest_cam_beta, rest_cam_gamma, mean_rest_cams]],
        channels,
        f'G{gInds[cg]} - EEG vs Cams{" "*70}')
#     plt.savefig(f'G{gInds[cg]}_Cams.png')

    plt.show()
    cg+=1

#### Gender Based

In [None]:
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, RobustScaler
from scipy.stats import ks_2samp
from scipy.spatial.distance import squareform

from tf_keras_vis import layercam, gradcam_plus_plus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore, InactiveScore

# Gender-based CAM visualisation: for each sex and each subject group, load the
# group's subjects of that sex, compute Grad-CAM++ maps for the pain and rest
# classes with the group's trained GFC model, and plot per-band topomaps of the
# EEG and of the CAMs (rows: EEG Pain, CAMs Pain, EEG Rest, CAMs Rest).
tf.random.set_seed(seed)

subjects = np.arange(n_subjects)+1

db = '../input/brain-mediators-of-pain-motor/'

num_class = 2

load_args = dict(db = db,
            f_bank = np.asarray([[4., 60.]]),
            vwt = np.asarray([[0.5,2.5]]),
            new_fs = 256.0)

scores_path = os.path.join(os.getcwd(), 'Groups_Motor256_Gamma60Hz')

gInds = {1: 'I', 2: 'II', 3: 'III', 4: 'IV', 5: 'V'}

groups_keys = list(groups.keys())

genders = ['m', 'f']

# Band edges (Hz) for the Theta, Alpha, Beta and Gamma columns, in that order.
# NOTE(review): beta (13-32) and gamma (30-60) overlap at 30-32 Hz — confirm intended.
band_edges = [[4., 8.], [8., 13.], [13., 32.], [30., 60.]]
column_titles = ['Theta', 'Alpha', 'Beta', 'Gamma', 'Full Band']

# Mosaic rows top to bottom; the row labels are paired with these in the loop.
mosaic = [['ul', 'um', 'um1', 'um2', 'ur'],
          ['ll', 'lm', 'lm1', 'lm2', 'lr'],
          ['ml', 'mm', 'mm1', 'mm2', 'mr'],
          ['bl', 'bm', 'bm1', 'bm2', 'br']]


def _joint_minmax(*maps):
    """Return ``maps`` rescaled with one shared offset and scale.

    The minimum and the range are taken over all maps together, so the maps
    keep their relative magnitudes and their joint range becomes [0, 1].
    Called with a single map this is plain min-max normalisation.
    """
    stacked = np.concatenate(maps, axis=0)
    lo = np.min(stacked)
    span = np.max(stacked) - lo
    return [(m - lo) / span for m in maps]


def _band_topomaps(band_filters, trials):
    """Return one topomap per filter in ``band_filters``, jointly scaled to [0, 1].

    Each filter's ``transform`` output is indexed at ``[..., 0, 0]``, averaged
    over axis 0 and summed over axis 1 of that average; all resulting maps are
    then normalised together with ``_joint_minmax``.
    Assumes ``trials`` is (n_trials, n_channels, n_samples), giving one value
    per channel — TODO confirm.
    """
    maps = [np.sum(np.mean(flt.transform(trials)[..., 0, 0], axis=0), axis=1)
            for flt in band_filters]
    return _joint_minmax(*maps)


def _fullband_topomap(trials):
    """Average ``trials`` over axis 0, sum over axis 1, and min-max scale to [0, 1]."""
    (topo,) = _joint_minmax(np.sum(np.mean(trials, axis=0), axis=1))
    return topo


for sex_ in genders:
    sex_label = "Females" if sex_ == "f" else "Males"
    cg = 1
    for group in groups_keys:
        group_name = groups[group]["Name"]
        group_subs = groups[group]["Subjects"]

        # Keep only subjects whose sex matches ``sex_``; a group with no such
        # subject is skipped instead of being plotted with another subject's data.
        X_parts, y_parts = [], []
        for sbj in group_subs:
            print(f"Loading subject: {sbj}\n")
            load_args['sbj'] = sbj
            X_sbj, y_sbj, age, sex, fs, info = load_PAIN(**load_args)
            if sex[0] != sex_:
                print("Sex different, not loading")
                continue
            print("Sex equal, loading")
            X_parts.append(X_sbj)
            y_parts.append(y_sbj)

        if not X_parts:
            print(f"No subjects with sex '{sex_}' in {group_name}, skipping")
            cg += 1  # keep the roman numeral aligned with the group position
            continue

        X_train = np.concatenate(X_parts, axis=0)[..., np.newaxis]
        y_train = np.concatenate(y_parts, axis=0)
        Y_train = tf.keras.utils.to_categorical(y_train,num_classes=num_class)
        y_train = np.array([y[0] for y in y_train])

        channels = info['ch_names']

        model_weights_path = os.path.join(scores_path, f"{group_name}_weights.h5")
        params_path = os.path.join(scores_path, f"{group_name}.p")

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(params_path, 'rb') as f:
            cv = pickle.load(f)

        best_params = cv['params'][cv['best_index_']]

        model = GFC_triu_net_avg(nb_classes = num_class,
                Chans = X_train.shape[1],
                Samples = X_train.shape[2],
                filters = int(best_params['filters']),
                kernel_time_1 = int(best_params['kernel_time_1']))

        model.load_weights(model_weights_path)

        # Class 0 is scored as "pain" and class 1 as "rest" for every trial.
        pain_scores = [CategoricalScore(list(np.zeros_like(y_train)))]
        rest_scores = [CategoricalScore(list(np.ones_like(y_train)))]

        lc = gradcam_plus_plus.GradcamPlusPlus(model, model_modifier=ReplaceToLinear(), clone=False)
        # Alternative explainer: layercam.Layercam(model, model_modifier=ReplaceToLinear(), clone=False)

        cam_kwargs = dict(seed_input = X_train,
                          penultimate_layer = model.layers[-2],
                          seek_penultimate_conv_layer = True,
                          normalize_cam = False,
                          expand_cam = True)
        pain_cam = lc(score = pain_scores, **cam_kwargs)
        rest_cam = lc(score = rest_scores, **cam_kwargs)

        # Scale both CAM sets by each trial's maximum over both classes, so the
        # pain and rest CAMs of one trial share a common scale.
        max_val_cams = np.max(np.stack((pain_cam, rest_cam), axis=-1), axis=(1, 2, 3))[:, np.newaxis, np.newaxis]

        pain_cam /= max_val_cams
        rest_cam /= max_val_cams

        band_filters = [TimeFrequencyRpr(sfreq = fs, f_bank = np.array([edges]))
                        for edges in band_edges]

        pain_signals = np.array(X_train[y_train == 0][..., 0])
        rest_signals = np.array(X_train[y_train == 1][..., 0])

        # Full-band maps are scaled on their own; band maps are scaled jointly
        # per (condition, source) so the four bands stay comparable.
        mean_pain_signals = _fullband_topomap(pain_signals)
        mean_pain_cams = _fullband_topomap(pain_cam)
        mean_rest_signals = _fullband_topomap(rest_signals)
        mean_rest_cams = _fullband_topomap(rest_cam)

        pain_signal_bands = _band_topomaps(band_filters, pain_signals)
        rest_signal_bands = _band_topomaps(band_filters, rest_signals)
        pain_cam_bands = _band_topomaps(band_filters, pain_cam)
        rest_cam_bands = _band_topomaps(band_filters, rest_cam)

        # One entry per mosaic row (same order as ``mosaic``): the row label and
        # the maps for the Theta/Alpha/Beta/Gamma/Full Band columns.
        row_specs = [('EEG Pain', pain_signal_bands + [mean_pain_signals]),
                     ('CAMs Pain', pain_cam_bands + [mean_pain_cams]),
                     ('EEG Rest', rest_signal_bands + [mean_rest_signals]),
                     ('CAMs Rest', rest_cam_bands + [mean_rest_cams])]

        fig, ax = plt.subplot_mosaic(mosaic, figsize=(16, 12), dpi=120)

        images = {}
        for row_keys, (row_label, row_maps) in zip(mosaic, row_specs):
            for key, topo in zip(row_keys, row_maps):
                images[key] = topoplot(topo, channels, cmap='viridis', ax=ax[key], vlim=(0.0, 1.0), show=False)
            ax[row_keys[0]].set_ylabel(row_label, size=20)
        for key, column_title in zip(mosaic[0], column_titles):
            ax[key].set_title(column_title, size=20)

        fig.suptitle(f'G{gInds[cg]} - {sex_label} - EEG vs Cams{" "*70}')

        fig.tight_layout()
        fig.subplots_adjust(right=0.75)
        cbar_ax = fig.add_axes([0.75, 0.05, 0.02, 0.8])
        # topoplot returns a tuple whose first element is the image used for the colorbar.
        fig.colorbar(images['ul'][0], cax=cbar_ax)
        fig.savefig(f'G{gInds[cg]}_{sex_label}_Cams.png')

        plt.show()
        plt.close(fig)  # release the figure; one is created per (sex, group)
        cg+=1

In [None]:
# Move every PNG in the working directory into ``target_folder``.
target_folder = 'CamsImages'

cwd = os.getcwd()

# exist_ok makes re-running the cell harmless without hiding real failures
# (e.g. missing permissions, or a regular file already named ``target_folder``).
os.makedirs(os.path.join(cwd, target_folder), exist_ok=True)

contents = os.listdir(cwd)

for content in contents:
    if content.endswith('.png'):
        shutil.move(os.path.join(cwd, content), os.path.join(cwd, target_folder, content))

In [None]:
import os
import shutil

# Zip the CamsImages folder of the current working directory into
# GFC_Cams_256.zip. Resolving the folder from the cwd (rather than a fixed
# /kaggle/working path) matches where the PNGs are moved to.
archive_path = shutil.make_archive("GFC_Cams_256", 'zip', os.path.join(os.getcwd(), "CamsImages"))
print(archive_path)