# S1 Fig. Model reliance without normalization for group sizes.

In [None]:
import numpy as np
import pandas as pd
import matplotlib as mpl 
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import seaborn as sns
import scipy.io as io
import os
import functions.model_reliance as mr
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier 
from functions.helpers import *

# Notebook-wide plotting defaults: Arial typeface with a 20 pt base font size.
font = dict(size=20)
mpl.rcParams['font.family'] = 'Arial'
mpl.rcParams['font.size'] = font['size']

In [None]:
# Load the feature-group definitions that are passed as `perm_groups` to the
# model-reliance computation: one grouping by ROI, one by frequency band.
# `json` is imported explicitly; it is not imported at the top of this notebook
# and would otherwise only be available if functions.helpers happens to export it.
# NOTE(review): the structure of these JSON files (e.g. group name -> feature
# indices) is not visible here — confirm against functions.model_reliance.
import json

with open('feature_groups/allelecs_allfreq_roi.json', encoding='utf-8') as f:
    allelecs_allfreq_roi = json.load(f)
with open('feature_groups/allelecs_allfreq_freq.json', encoding='utf-8') as f:
    allelecs_allfreq_freq = json.load(f)

In [None]:
# Compute group model reliance (MR) for every subject, once with the ROI feature
# groups and once with the frequency-band groups. Then average over CV folds and
# reshape the per-subject results into long format for plotting.
roi_results = []
freq_results = []
k = 10                 # passed to mr.model_reliance_cv as `cv`
p = 10                 # passed as `p` (presumably permutations per group — TODO confirm)
n_trees = 5000         # number of trees in each random forest
normalise_mr = False   # this figure shows MR without normalisation for group size
random_state = None    # set to an int for reproducible forests; None = unseeded (original behaviour)

# Feature-group columns to plot, in plotting order.
ROI_FEATURES = ['leftocci', 'rightocci', 'leftcentral',
                'rightcentral', 'leftfrontal', 'rightfrontal',
                'frontocentral', 'central', 'occipitocentral']
FREQ_FEATURES = ['delta', 'theta', 'alpha', 'beta', 'gamma']


def _mean_over_folds(fold_results):
    """Concatenate per-fold MR results and return their column means as a one-row DataFrame."""
    return pd.concat(fold_results).mean().to_frame().T


def _to_long_percent(wide, features):
    """Melt a one-row-per-subject MR table (with an 'id' column) to long format, MR in %."""
    long_df = pd.melt(wide, id_vars=['id'], value_vars=features,
                      var_name='feature', value_name='mr')
    long_df['mr'] *= 100  # rescale to %
    return long_df


for subj in subjects:
    print('Running model reliance on: ' + subj)

    # Load this subject's feature matrix and labels.
    # os.path.join works whether or not `inputdirectory` ends with a separator.
    # NOTE(review): assumes `inputdirectory` is a directory path, not a filename prefix.
    datapath = os.path.join(inputdirectory, subj, 'Xallelecs_allfreq.mat')
    data = io.loadmat(datapath)
    y = data['y'].flatten()
    X = data['X']

    clf = RandomForestClassifier(n_estimators=n_trees, n_jobs=-1,
                                 random_state=random_state)

    # ROI groups first, then frequency-band groups (same order as before).
    for perm_groups, results in ((allelecs_allfreq_roi, roi_results),
                                 (allelecs_allfreq_freq, freq_results)):
        res = mr.model_reliance_cv(clf, X, y, perm_groups=perm_groups, p=p,
                                   cv=k, normalise=normalise_mr)
        results.append(_mean_over_folds(res))

# ROI: one row per subject (fresh 0..n-1 index), attach ids, reshape for plotting.
roi_results = pd.concat(roi_results, ignore_index=True)
roi_mr = roi_results.copy()
roi_mr['id'] = subjects
roi_mr_m = _to_long_percent(roi_mr, ROI_FEATURES)

# Frequency bands: same reshaping.
freq_results = pd.concat(freq_results, ignore_index=True)
freq_mr = freq_results.copy()
freq_mr['id'] = subjects
freq_mr_m = _to_long_percent(freq_mr, FREQ_FEATURES)

In [None]:
# Two-panel figure: per-subject MR distributions for the ROI groups (left) and
# the frequency-band groups (right), saved as plots/S1_fig.eps.
# NOTE(review): plt.style.use('default') resets rcParams, which discards the
# Arial / 20 pt defaults set at the top of the notebook — confirm this is intended.
plt.style.use('default')
sns.set_style("whitegrid")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: ROI groups.
sns.boxplot(y="feature", x="mr", data=roi_mr_m, color="#4a7ab7", ax=ax1)
ax1.grid(False)
ax1.xaxis.grid(True)  # vertical gridlines only
ax1.set_xlim(-4, 10)
ax1.set_xticks(np.arange(-4, 10, 2))  # np.arange excludes the endpoint: ticks -4..8
ax1.set_ylabel("")
ax1.set_xlabel("MR [%]")
ax1.set_title("ROI")

# Right panel: frequency bands. Labels are set on ax2 explicitly rather than
# through pyplot's implicit "current axes".
sns.boxplot(y="feature", x="mr", data=freq_mr_m, color="#4a7ab7", ax=ax2)
ax2.grid(False)
ax2.xaxis.grid(True)  # vertical gridlines only
ax2.set_xlim(-8, 40)
ax2.set_xticks(np.arange(-8, 40, 8))  # ticks -8..32
ax2.invert_yaxis()  # reverse the band order on this panel only
ax2.set_ylabel("")
ax2.set_title("Frequency bands")
ax2.set_xlabel("MR [%]")

fig.tight_layout()
# savefig does not create missing directories.
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/S1_fig.eps')