In [1]:
import glob
import os
import re
from itertools import chain

# imports
import mne
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mne.decoding import LinearModel
from mne.decoding import get_coef
from mne.time_frequency import psd_multitaper
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

In [None]:
# path to data
path = '/Users/ilamiheev/Downloads/eeg_data'
# create folders for results
path_res = '/Users/ilamiheev/Downloads/results_ihna'
path_subj = os.path.join(path_res, 'subjects_classification')
path_subj_topo = os.path.join(path_subj, 'subjects_topo')
path_unite_subj = os.path.join(path_res, 'unite_subjects')
path_unite_topo = os.path.join(path_unite_subj, 'unite_topo')
path_group = os.path.join(path_res, 'group_classification')
path_group_topo = os.path.join(path_group, 'group_topo')
# exist_ok=True keeps this cell idempotent: re-running the notebook must not
# fail just because the result folders were already created by an earlier run.
for pathn in [path_res, path_subj, path_subj_topo, path_unite_subj,
              path_unite_topo, path_group, path_group_topo]:
    os.makedirs(pathn, exist_ok=True)

In [None]:
# eeg parameters and subjects indexes
files = sorted(os.listdir(path))
# The first sorted entry is discarded (presumably a non-subject entry such as
# a hidden file -- TODO confirm against the data folder).
del files[0]
# Subject id = shortest prefix of the entry name that precedes an underscore.
indexes = [re.search('(.+?)_', entry).group(1) for entry in files]
# Channels removed from every recording before feature extraction.
chan_drop = [
    'E8', 'E14', 'E21', 'E25', 'E43', 'E48', 'E49', 'E56', 'E57', 'E63',
    'E64', 'E65', 'E68', 'E69', 'E73', 'E74', 'E81', 'E82', 'E88', 'E89',
    'E90', 'E94', 'E95', 'E99', 'E100', 'E107', 'E113', 'E119', 'E120',
    'E125', 'E126', 'E127', 'E128', 'Status',
]
montage = mne.channels.make_standard_montage('GSN-HydroCel-128')
events_list = [241, 242, 244]
# Frequency bands as [low, high] limits.
fr_bands = {
    "theta1": [4, 6],
    "theta2": [6, 8],
    "alpha1": [8, 10],
    "alpha2": [10, 12],
    "beta1": [12, 16],
    "beta2": [16, 20],
    "beta3": [20, 24],
}
# Pairwise contrasts; the values are positions inside events_list.
dict_cls = {
    "241/244": [0, 2],
    "242/244": [1, 2],
    "241/242": [0, 1],
}
mat = ['311', '312', '314', '315', '316', '317', '326', '327', '328', '330', '334', '335']
not_mat = [subj_id for subj_id in indexes if subj_id not in mat]
# Positions of both groups inside `indexes`.
index_mat = [indexes.index(subj_id) for subj_id in mat]
index_not_mat = [indexes.index(subj_id) for subj_id in not_mat]

In [None]:
# estimate relative power
def eeg_power_band(epochs_list):
    """Compute relative band power for every entry of ``epochs_list``.

    For each entry the multitaper PSD is computed and normalised so that each
    spectrum sums to one along the frequency axis; the normalised PSD is then
    averaged inside every band of the module-level ``fr_bands``.

    Parameters
    ----------
    epochs_list : sequence
        Objects accepted by ``psd_multitaper`` (the calling cell passes one
        epochs object per event code).

    Returns
    -------
    fin_feat : list of list of ndarray
        ``fin_feat[k][b]`` is the relative power of band ``b`` for entry ``k``,
        computed from the full PSD (frequency axis averaged out).
    fin_table : list of list of ndarray
        ``fin_table[k][b]`` is the same quantity computed from the PSD that was
        first averaged over axis 0.
    """
    fin_table, fin_feat = [], []
    for epochs in epochs_list:
        psds, freqs = psd_multitaper(epochs)
        # Average over axis 0 *before* the in-place normalisation below so the
        # table is built from the raw PSD; np.mean returns a new array.
        psds_table = np.mean(psds, axis=0)
        psds /= psds.sum(axis=-1)[..., None]
        psds_table /= psds_table.sum(axis=-1)[..., None]
        psd_table_list, psd_features_list = [], []
        for fmin, fmax in fr_bands.values():
            # Half-open interval [fmin, fmax): a bin lying exactly on a band
            # edge (e.g. 6 Hz) belongs to exactly one band. The previous
            # strict comparison on both sides dropped such bins from every band.
            freq_mask = (freqs >= fmin) & (freqs < fmax)
            psd_features_list.append(psds[..., freq_mask].mean(axis=-1))
            psd_table_list.append(psds_table[..., freq_mask].mean(axis=-1))
        fin_table.append(psd_table_list)
        fin_feat.append(psd_features_list)
    return fin_feat, fin_table
# logistic regression with l2 penalty and CV
def predict(x_train, x_test, y_train, y_test):
    """Fit a standardised, cross-validated L2 logistic regression and evaluate it.

    Parameters
    ----------
    x_train, x_test : array-like, 2-D
        Feature matrices for fitting and evaluation.
    y_train, y_test : array-like
        Binary class labels for the two splits.

    Returns
    -------
    results : ndarray, shape (1, 4)
        ``[accuracy, model.score(x_test, y_test), TPR, TNR]``. TPR/TNR treat
        the first class (row/column 0 of the confusion matrix) as positive.
        Callers label the second value "ROC/AUC"; presumably ``score`` follows
        the ``scoring='roc_auc'`` setting -- confirm for the installed versions.
    model : fitted pipeline (StandardScaler -> LinearModel(LogisticRegressionCV))
    """
    # NOTE(review): tol=10 is an unusually loose stopping tolerance for the
    # solver -- confirm that it is intentional.
    classifier = LogisticRegressionCV(
        Cs=list(np.power(10.0, np.arange(-10, 10))),
        penalty='l2',
        scoring='roc_auc',
        random_state=0,
        max_iter=10000,
        fit_intercept=True,
        solver='newton-cg',
        tol=10,
    )
    model = make_pipeline(StandardScaler(), LinearModel(classifier))
    model.fit(x_train, y_train)
    y_predict = model.predict(x_test)
    score = model.score(x_test, y_test)
    # cm[i, j] = number of samples whose true class is i and predicted class is j.
    cm = confusion_matrix(y_test, y_predict)
    # First class is the positive one. FN lives in the *row* of the positive
    # class and FP in its *column*; the earlier code had these two swapped,
    # which turned the reported "TPR"/"TNR" into precision-like values.
    TP, FN = cm[0, 0], cm[0, 1]
    FP, TN = cm[1, 0], cm[1, 1]
    results = np.zeros((1, 4))
    results[0, 0], results[0, 1] = accuracy_score(y_test, y_predict), score
    results[0, 2], results[0, 3] = TP/float(TP+FN), TN/float(TN+FP)
    return results, model
# plot topomap for each freq band
def plot_topo(ar):
    """Draw pattern and filter topomaps for every frequency band.

    Parameters
    ----------
    ar : ndarray
        Indexed as ``ar[kind, band, :]`` where ``kind`` 0 is plotted in the
        "Patterns" row and 1 in the "Filters" row, and ``band`` follows the
        order of the module-level ``fr_bands``.

    Returns
    -------
    fig : matplotlib figure with two rows and one column per band. Channel
        positions come from the module-level ``info``.
    """
    # Shared colour limits so all maps of one figure are comparable.
    vmin, vmax = np.amin(ar), np.amax(ar)
    band_limits = list(fr_bands.values())
    # One column per band instead of a hard-coded 7; squeeze=False keeps `axes`
    # two-dimensional even if only a single band is configured.
    fig, axes = plt.subplots(nrows=2, ncols=len(band_limits), figsize=(30, 20),
                             squeeze=False)
    for row in (0, 1):
        for col, limits in enumerate(band_limits):
            mne.viz.plot_topomap(ar[row, col, :], info, vmin=vmin, vmax=vmax,
                                 axes=axes[row, col], show=False)
            axes[row, col].set(title='{}-{} Hz'.format(*limits))
    # Layout is adjusted once for the finished grid rather than after every map.
    mne.viz.tight_layout(fig=fig)
    for pos, plot_name in zip((0.8, 0.5), ('Patterns', 'Filters')):
        plt.figtext(0.5, pos, '{}'.format(plot_name), va="center", ha="center",
                    size=24, fontweight='semibold')
    return fig

In [None]:
# create features
subj_list_table, subj_list_features = [], []
for file_name in files:
    # sorted(): glob yields files in arbitrary order; sorting makes the epoch
    # order reproducible between runs and machines.
    paths = sorted(glob.glob(path + '/{0}/Reals/*.edf'.format(file_name)))
    epochs_list = []
    for event_ind in events_list:
        event_id = dict(a=event_ind)
        # Reset for every event code: otherwise a condition without any usable
        # recording would silently reuse the epochs of the previous condition.
        epochs = None
        # Match the event code against the file name only, so digits in the
        # subject folder name can never select the wrong recordings.
        event_paths = [x for x in paths if '{0}'.format(event_ind) in os.path.basename(x)]
        for edf_path in event_paths:
            raw = mne.io.read_raw_edf(edf_path)
            # Skip recordings with fewer than 10 * 500 samples (500 is
            # presumably the sampling rate in Hz, i.e. < 10 s -- TODO confirm).
            if len(raw.times)//500 < 10:
                continue
            new_events = mne.make_fixed_length_events(raw, id=event_ind, start=5, duration=2, overlap=1)
            file_epochs = mne.Epochs(raw, new_events, event_id=event_id, tmin=0, tmax=2,
                                     baseline=None, flat=dict(eeg=1e-20), preload=True)
            if epochs is None:
                epochs = file_epochs
            else:
                # Newest file first, same order as before.
                epochs = mne.concatenate_epochs([file_epochs, epochs])
        if epochs is None:
            raise ValueError('no usable recording for event {} in {!r}'.format(event_ind, file_name))
        epochs_list.append(epochs.copy())
    for cond_epochs in epochs_list:
        # Normalise channel labels to the 'E<n>' names used by the montage.
        new_names = {
            ch_name: ch_name.replace('-', '').replace('Chan ', 'E').replace('CAR', '')
                            .replace('EEG ', '').replace('CA', '').replace(' ', '')
            for ch_name in cond_epochs.ch_names
        }
        cond_epochs.rename_channels(new_names)
        cond_epochs.set_montage(montage)
        cond_epochs.drop_channels(chan_drop)
    feat_list, tabl_list = eeg_power_band(epochs_list)
    subj_list_table.append(tabl_list)
    subj_list_features.append(feat_list)
# Channel names after dropping, taken from the last processed subject.
chan1 = epochs_list[0].ch_names

In [None]:
# info about chan positions
# Measurement info read by plot_topo() as a global. It comes from the epochs of
# the last subject processed by the feature-extraction loop, i.e. after the
# montage was set and the channels in chan_drop were removed.
info = epochs_list[1].info

In [None]:
# write relative powers to table
# One sheet per (event code, band): rows are subjects, columns are channels.
# Stack once; indexed below as [subject, event, band, channel].
power_all = np.array(subj_list_table)
n_events, n_bands = power_all.shape[1], power_all.shape[2]
# df1: raw arrays, df2: labelled tables -- both keyed by a running sheet
# number that advances band-by-band within each event.
df1 = {ev * n_bands + band: power_all[:, ev, band, :]
       for ev in range(n_events) for band in range(n_bands)}
df2 = {key: pd.DataFrame(values, index=indexes, columns=chan1)
       for key, values in df1.items()}
sheet_specs = [(event_ind, band_name) for event_ind in events_list for band_name in fr_bands]
# Context manager instead of writer.save(): save() no longer exists in
# pandas >= 2.0, and the file is closed even if writing a sheet fails.
with pd.ExcelWriter(os.path.join(path_res, 'subjects_relative_power.xlsx'), engine='xlsxwriter') as writer:
    for sheet_no, (event_ind, band_name) in enumerate(sheet_specs):
        df2[sheet_no].to_excel(writer, sheet_name='{},({},{})'.format(event_ind, *fr_bands[band_name]))

In [None]:
# classification for each subject in both groups
# predict() returns 4 metrics per contrast; they are stored side by side.
n_metrics = 4
results = np.zeros((len(indexes), n_metrics*len(dict_cls)))
# coefs[subject, contrast, 0=patterns / 1=filters, band, channel]
coefs = np.zeros((len(indexes), len(dict_cls), 2, len(fr_bands), len(chan1)))
for subj in range(len(indexes)):
    for i, (key, ind) in enumerate(dict_cls.items()):
        # Per condition: list of per-band arrays -> (epochs, bands, channels).
        A = np.stack(subj_list_features[subj][ind[0]], axis=1)
        B = np.stack(subj_list_features[subj][ind[1]], axis=1)
        y = ['0']*A.shape[0] + ['1']*B.shape[0]
        x = np.concatenate((A, B), axis=0)
        # Fixed seed makes the reported metrics reproducible; stratify keeps
        # both classes represented in the test split.
        x_train, x_test, y_train, y_test = train_test_split(
            x.reshape(x.shape[0], -1), y, test_size=0.3, random_state=0, stratify=y)
        results[subj, n_metrics*i:n_metrics*(i+1)], model = predict(x_train, x_test, y_train, y_test)
        for j, name in enumerate(['patterns_', 'filters_']):
            coef = get_coef(model, name, inverse_transform=True)
            # Undo the flattening applied to x: back to (bands, channels).
            coefs[subj, i, j, ...] = coef.reshape(len(fr_bands), -1)

In [None]:
# evaluate mean and var for all subjects/and for subjects in groups
metr_name = ['accuracy','ROC/AUC','TPR','TNR']
# Row order: all subjects, mathematicians, non-mathematicians.
list_gr = [results, results[index_mat,...], results[index_not_mat,...]]
# Iterate over the groups themselves; the earlier version indexed list_gr with
# the contrast counter, which only worked because both have three entries.
ar_mean_var = np.zeros((len(list_gr), 2*results.shape[1]))
for row, group_results in enumerate(list_gr):
    # Interleave as mean, var, mean, var, ... to match list_names_fin.
    ar_mean_var[row, 0::2] = group_results.mean(axis=0)
    ar_mean_var[row, 1::2] = group_results.var(axis=0)
# Column labels: contrast-major, then metric (then statistic).
list_names_subj = ['{} {}'.format(key, name) for key in dict_cls for name in metr_name]
list_names_fin = ['{} {} {}'.format(key, name, stat)
                  for key in dict_cls for name in metr_name for stat in ['mean','var']]

In [None]:
# save results to table
df_subj = pd.DataFrame(results, columns=list_names_subj, index=indexes)
df_fin = pd.DataFrame(ar_mean_var, columns=list_names_fin,
                      index=['all_subjects','mathematicians','not_mathematicians'])
# NOTE(review): the file name says "l1" although predict() uses penalty='l2';
# the name is kept unchanged so existing paths keep working -- confirm.
# Context manager instead of writer.save(): save() no longer exists in
# pandas >= 2.0, and the file is closed even if writing a sheet fails.
with pd.ExcelWriter(os.path.join(path_subj, 'subjects_classification_l1_final.xlsx'), engine='xlsxwriter') as writer:
    df_subj.to_excel(writer, sheet_name='all_subj')
    df_fin.to_excel(writer, sheet_name='mean_var')

In [None]:
# plot topomaps for each subject
for i, subj_ind in enumerate(indexes):
    for j, key in enumerate(dict_cls):
        fig = plot_topo(coefs[i, j])
        # Derive the suffix from the contrast key ('241/244' -> '241_244')
        # instead of a second hard-coded list that could drift from dict_cls.
        file_name = 'filters_patterns_{}_{}.png'.format(subj_ind, key.replace('/', '_'))
        fig.savefig(os.path.join(path_subj_topo, file_name), format='png', dpi=600)
        # Close explicitly: many large figures are created in this loop.
        plt.close(fig)

In [None]:
# classification using all subjects in both groups and all subjects
# Rows: all subjects, mathematicians, non-mathematicians; 4 metrics per contrast.
results_all = np.zeros((3, 4*len(dict_cls)))
# coefs_all[group, contrast, 0=patterns / 1=filters, band, channel]
coefs_all = np.zeros((3, len(dict_cls), 2, len(fr_bands), len(chan1)))
for group_ind, subj_type in enumerate([[*range(len(indexes))], index_mat, index_not_mat]):
    for i, (key, ind) in enumerate(dict_cls.items()):
        # Pool the epochs of every subject of the group for both conditions.
        # Subjects are indexed directly instead of rebuilding the whole
        # subject list for each element.
        A = np.concatenate([np.stack(subj_list_features[s][ind[0]], axis=1) for s in subj_type], axis=0)
        B = np.concatenate([np.stack(subj_list_features[s][ind[1]], axis=1) for s in subj_type], axis=0)
        y = ['0']*A.shape[0] + ['1']*B.shape[0]
        x = np.concatenate((A, B), axis=0)
        # Fixed seed makes the reported metrics reproducible; stratify keeps
        # both classes represented in the test split.
        x_train, x_test, y_train, y_test = train_test_split(
            x.reshape(x.shape[0], -1), y, test_size=0.3, random_state=0, stratify=y)
        results_all[group_ind, 4*i:4*(i+1)], model = predict(x_train, x_test, y_train, y_test)
        for j, name in enumerate(['patterns_', 'filters_']):
            coef = get_coef(model, name, inverse_transform=True)
            coefs_all[group_ind, i, j, ...] = coef.reshape(len(fr_bands), -1)

In [None]:
# save results to table
# Row label fixed from 'all_sibjects' so it matches the 'all_subjects' label
# used for the other exported tables and figures.
datafr_groups = pd.DataFrame(results_all, columns=list_names_subj,
                             index=['all_subjects','mathematicians','not_mathematicians'])
datafr_groups.to_excel(os.path.join(path_unite_subj, 'subj_in_groups_classification_l2.xlsx'), index=True)

In [None]:
# plot topomaps for subjects in groups
group_names = ['all_subjects','mathematicians','not_mathematicians']
for i, group_ind in enumerate(group_names):
    for j, key in enumerate(dict_cls):
        fig = plot_topo(coefs_all[i, j, ...])
        # Derive the suffix from the contrast key ('241/244' -> '241_244')
        # instead of a second hard-coded list that could drift from dict_cls.
        file_name = 'filters_patterns_{}_{}.png'.format(group_ind, key.replace('/', '_'))
        fig.savefig(os.path.join(path_unite_topo, file_name), format='png', dpi=600)
        plt.close(fig)

In [None]:
# classification mathematicians/non mathematicians
# One model per event code plus a final one on all event codes pooled.
# NOTE(review): epochs of the same subject can land in both the train and the
# test split, so scores may partly reflect subject identity rather than group
# membership -- consider a subject-wise split.
n_cond = len(events_list)
results_group = np.zeros((n_cond + 1, 4))
# coefs_group[model, 0=patterns / 1=filters, band, channel]
coefs_group = np.zeros((n_cond + 1, 2, len(fr_bands), len(chan1)))
list_A, list_B = [], []
for ind in range(n_cond + 1):
    if ind < n_cond:
        # Subjects are indexed directly instead of rebuilding the whole
        # subject list for each element.
        A = np.concatenate([np.stack(subj_list_features[s][ind], axis=1) for s in index_mat], axis=0)
        list_A.append(A)
        B = np.concatenate([np.stack(subj_list_features[s][ind], axis=1) for s in index_not_mat], axis=0)
        list_B.append(B)
    else:
        # Last pass: pool the epochs of every event code.
        A, B = np.concatenate(list_A, axis=0), np.concatenate(list_B, axis=0)
    y = ['0']*A.shape[0] + ['1']*B.shape[0]
    x = np.concatenate((A, B), axis=0)
    # Fixed seed makes the reported metrics reproducible; stratify keeps both
    # groups represented in the test split.
    x_train, x_test, y_train, y_test = train_test_split(
        x.reshape(x.shape[0], -1), y, test_size=0.3, random_state=0, stratify=y)
    results_group[ind, ...], model = predict(x_train, x_test, y_train, y_test)
    for i, name in enumerate(['patterns_', 'filters_']):
        coef = get_coef(model, name, inverse_transform=True)
        coefs_group[ind, i, ...] = coef.reshape(len(fr_bands), -1)

In [None]:
# save new results to table
# Rows: one per event code plus the pooled model; columns: metrics from predict().
metric_columns = ['accuracy score', 'roc_auc score', 'sensitivity(TPR)', 'specificity(TNR)']
row_labels = ['241', '242', '244', 'all']
datafr_math_not_math = pd.DataFrame(results_group, index=row_labels, columns=metric_columns)
group_xlsx = os.path.join(path_group, 'group_classification_l2.xlsx')
datafr_math_not_math.to_excel(group_xlsx, index=True)
 

In [None]:
# plot and save topomaps for mat/not mat
# One figure per model of the group classification, named after its data subset.
for suffix, model_coefs in zip(['241', '242', '244', 'all'], coefs_group):
    fig = plot_topo(model_coefs)
    target = os.path.join(path_group_topo, 'filters_patterns_mat_not_mat_{}.png'.format(suffix))
    fig.savefig(target, format='png', dpi=600)
    plt.close(fig)