# S2 Fig. Classification accuracy for training and testing on individual feature groups.

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl # perhaps change this later
import matplotlib.pyplot as plt
import scipy.io as io
import os
import functions.model_reliance as mr
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier 
from functions.helpers import *

# Global plotting defaults used by every figure below:
# Arial typeface and a 20 pt base font size.
font = {'size': 20}
mpl.rcParams['font.family'] = 'Arial'
mpl.rcParams['font.size'] = font['size']

In [None]:
# Load the feature-group index maps. Each JSON file maps a group label
# (an ROI name such as 'leftocci', or a frequency band such as 'delta')
# to the list of column indices of X that make up that group.
# `json` is not among the explicit top-of-file imports, so import it here
# rather than relying on it leaking through `from functions.helpers import *`.
import json

with open('feature_groups/allelecs_allfreq_roi.json', encoding='utf-8') as f:
    allelecs_allfreq_roi = json.load(f)
with open('feature_groups/allelecs_allfreq_freq.json', encoding='utf-8') as f:
    allelecs_allfreq_freq = json.load(f)

In [None]:
# Training and testing on individual ROI feature groups: for every subject,
# fit a random forest on each ROI's columns alone and record the mean
# 10-fold cross-validated accuracy.
#
# oob_score is intentionally NOT enabled. cross_val_score clones the
# estimator and scores it on held-out folds. An out-of-bag estimate would be
# computed after every fit (5000 trees x 10 folds x groups x subjects) and
# then thrown away. Dropping it does not change the fitted trees or the
# reported accuracies.
clf = RandomForestClassifier(n_jobs=-1,
                             class_weight='balanced',  # for imb. datasets
                             random_state=42,
                             n_estimators=5000)
# Shuffled, stratified folds with a fixed seed so results are reproducible.
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# One single-row DataFrame per subject; columns are the ROI labels.
ind_roi_results = []

for subj in subjects:
    print('Training and testing on: ' + subj)

    # load data. NOTE(review): `subjects` and `inputdirectory` are not defined
    # in this view -- presumably provided by functions.helpers; confirm.
    # os.path.join works whether or not inputdirectory ends with a separator.
    datapath = os.path.join(inputdirectory, subj, 'Xallelecs_allfreq.mat')
    data = io.loadmat(datapath)
    y = data['y'].flatten()
    X = data['X']

    score = {}
    for label, cols in allelecs_allfreq_roi.items():
        print('running cv on feature ' + label)
        # NOTE(review): both the forest and cross_val_score request all cores
        # (n_jobs=-1); consider setting one of them to 1 to avoid
        # oversubscription.
        res = np.mean(cross_val_score(clf, X[:, cols], y, cv=cv, n_jobs=-1))
        score[label] = res
        print(res)

    ind_roi_results.append(pd.DataFrame.from_dict(score, 'index').transpose())


In [None]:
# Reshape the ROI results for plotting:
#   1. stack the per-subject single-row frames into one table,
#   2. tag each row with its subject id and give it a clean 0..n-1 index,
#   3. melt to long format (one row per subject x ROI) for seaborn.
ind_roi_results = pd.concat(ind_roi_results)
ind_roi_mr = ind_roi_results.copy()
ind_roi_mr['id'] = subjects
ind_roi_mr = ind_roi_mr.reset_index(drop=True)
ind_roi_mr_m = ind_roi_mr.melt(
    id_vars=['id'],
    value_vars=[
        'leftocci', 'rightocci',
        'leftcentral', 'rightcentral',
        'leftfrontal', 'rightfrontal',
        'frontocentral', 'central', 'occipitocentral',
    ],
    var_name='feature',
    value_name='mr',
)

In [None]:
# Training and testing on individual frequency-band feature groups: for
# every subject, fit a random forest on each band's columns alone and record
# the mean 10-fold cross-validated accuracy.
#
# oob_score is intentionally NOT enabled. cross_val_score clones the
# estimator and scores it on held-out folds. An out-of-bag estimate would be
# computed after every fit (5000 trees x 10 folds x bands x subjects) and
# then thrown away. Dropping it does not change the fitted trees or the
# reported accuracies.
clf = RandomForestClassifier(n_jobs=-1,
                             class_weight='balanced',  # for imb. datasets
                             random_state=42,
                             n_estimators=5000)
# Shuffled, stratified folds with a fixed seed so results are reproducible.
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# One single-row DataFrame per subject; columns are the band labels.
ind_freq_results = []

for subj in subjects:
    print('Training and testing on: ' + subj)

    # load data. NOTE(review): `subjects` and `inputdirectory` are not defined
    # in this view -- presumably provided by functions.helpers; confirm.
    # os.path.join works whether or not inputdirectory ends with a separator.
    datapath = os.path.join(inputdirectory, subj, 'Xallelecs_allfreq.mat')
    data = io.loadmat(datapath)
    y = data['y'].flatten()
    X = data['X']

    score = {}
    for label, cols in allelecs_allfreq_freq.items():
        print('running cv on feature ' + label)
        # NOTE(review): both the forest and cross_val_score request all cores
        # (n_jobs=-1); consider setting one of them to 1 to avoid
        # oversubscription.
        res = np.mean(cross_val_score(clf, X[:, cols], y, cv=cv, n_jobs=-1))
        score[label] = res
        print(res)

    ind_freq_results.append(pd.DataFrame.from_dict(score, 'index').transpose())

In [None]:
# Reshape the frequency-band results for plotting:
#   1. stack the per-subject single-row frames into one table,
#   2. tag each row with its subject id and give it a clean 0..n-1 index,
#   3. melt to long format (one row per subject x band) for seaborn.
ind_freq_results = pd.concat(ind_freq_results)
ind_freq_mr = ind_freq_results.copy()
ind_freq_mr['id'] = subjects
ind_freq_mr = ind_freq_mr.reset_index(drop=True)
ind_freq_mr_m = ind_freq_mr.melt(
    id_vars=['id'],
    value_vars=['delta', 'theta', 'alpha', 'beta', 'gamma'],
    var_name='feature',
    value_name='mr',
)

In [None]:
# Convert mean CV accuracies from fractions to percent for the plot axes.
# Not idempotent: re-running this cell multiplies by 100 again.
ind_freq_mr_m['mr'] = ind_freq_mr_m['mr'].mul(100)
ind_roi_mr_m['mr'] = ind_roi_mr_m['mr'].mul(100)

In [None]:
%config InlineBackend.figure_format = 'retina'
plt.style.use('default')
sns.set_style("whitegrid")


plt.subplot(2, 2, 1)

sns.boxplot(y="feature", x="mr", data=ind_roi_mr_m,
        color="#4a7ab7")

ax = plt.gca()
ax.grid(False)
ax.yaxis.grid(True)
ax.set_xlim(30,60)

plt.ylabel("")
plt.xlabel("Accuracy [%]")
plt.title("ROI") 



sns.set_style("whitegrid")

plt.subplot(2, 2, 2)



sns.boxplot(y="feature", x="mr", data=ind_freq_mr_m,
        color="#4a7ab7")
ax = plt.gca()
ax.grid(False)
ax.yaxis.grid(True)
ax.set_xlim(30,60)
ax.invert_yaxis()

plt.ylabel("")
plt.title("Freuquency Band") 
plt.xlabel("Accuracy [%]")
plt.tight_layout()
plt.savefig('plots/S2_fig.eps') 