In [1]:
import os, glob, platform
import numpy as np
import numpy.matlib
import pickle
import pandas as pd
import pathlib
import matplotlib
import matplotlib.pyplot as plt
import mne
mne.__version__
from mne.viz import plot_alignment, snapshot_brain_montage
import shutil
from mne.datasets import eegbci
from sklearn.model_selection import train_test_split

# from mne_bids import write_raw_bids, BIDSPath, print_dir_tree, make_dataset_description
# from mne_bids.stats import count_events
import sys


In [2]:
# Make the decoding-toolbox helper functions importable.
# NOTE(review): this is a hard-coded absolute path; adjust it per machine.
path_utils = '/decoding_toolbox_py/helper_funcs'
# Guard the append so re-running this cell does not keep stacking duplicate
# entries onto sys.path.
if path_utils not in sys.path:
    sys.path.append(path_utils)

In [3]:
# ---------------------------------------------------------------------------
# VARIABLES: notebook-wide configuration
# ---------------------------------------------------------------------------

dataset = 'eeg'

# Lower this number to iterate faster on a subset of subjects.
# It is capped at 26, and subject 6 is always left out of the list.
amount_of_subjects = 27
amount_of_subjects = min(amount_of_subjects, 26)
subjs_list = ['s{:02d}'.format(i) for i in range(1, amount_of_subjects + 1) if i != 6]
print(subjs_list)
nSubj = len(subjs_list)

# Number of channels/classes and their evenly spaced angles over [0, 180).
numC = 8
angles = [i * 180. / numC for i in range(numC)]
x_labels = np.array(angles)

# Resampling speeds up the process but gives worse results overall.
resample = True
# Target rate in Hz (original freq is 500Hz). Defined unconditionally so that
# code reading it does not hit a NameError when `resample` is switched off.
resample_frequency = 20

# Stimulus / tuning-curve settings.
cfg_stim = {
    'kappa': 4,
    'NumC': numC,
    'Tuning': 'vonmises',  # alternative: 'halfRectCos'
    'offset': 0,
}

# Training settings.
cfg_train = {
    'gamma': 0.1,
    'demean': True,
    'returnPattern': True,
}

# Test settings.
cfg_test = {'demean': 'traindata'}

['s01', 's02', 's03', 's04', 's05', 's07', 's08', 's09', 's10', 's11', 's12', 's13', 's14', 's15', 's16', 's17', 's18', 's19', 's20', 's21', 's22', 's23', 's24', 's25', 's26']


In [4]:
def read_data(
        number_of_repetition=3,
        resample=False,
        subjs_list = subjs_list,
        task = 'stim',
        data_dir='Cond_CJ_EEG',
        target_sfreq=None,
        ):
    '''EEG Dataset: load the preprocessed epochs of every requested subject.

    Parameters
    ----------
    number_of_repetition : int
        Only used when ``task == 'stim'``: keep the epochs whose ``nrep``
        metadata value equals this number.
    resample : bool
        Only used when ``task == 'stim'``: resample the epochs before
        extracting the data.
    subjs_list : list of str
        Subject folder names located under ``data_dir``.
    task : {'main', 'stim'}
        ``'main'`` reads ``main_epo.fif``; ``'stim'`` reads
        ``mainstim_epo.fif``.
    data_dir : str
        Root folder holding one sub-folder per subject.
    target_sfreq : float or None
        Sampling rate to resample to. ``None`` falls back to the
        notebook-level ``resample_frequency`` variable.

    Returns
    -------
    list of dict
        One dict per subject, in ``subjs_list`` order, with keys
        ``'epoch_dat'`` (array returned by ``Epochs.get_data()``) and
        ``'metadata'`` (the matching metadata DataFrame).

    Raises
    ------
    ValueError
        If ``task`` is not ``'main'`` or ``'stim'``, or if resampling is
        requested but no target rate can be determined.
    '''
    if task not in ('main', 'stim'):
        # Previously an unknown task silently produced an empty list.
        raise ValueError("task must be 'main' or 'stim', got {!r}".format(task))

    if task == 'stim' and resample and target_sfreq is None:
        try:
            target_sfreq = resample_frequency
        except NameError:
            raise ValueError(
                "resample=True requires target_sfreq or a notebook-level "
                "resample_frequency variable"
            ) from None

    records = []
    for subject_id in subjs_list:
        preproc_path = os.path.join(data_dir, subject_id)

        if task == 'main':
            epoch = mne.read_epochs(os.path.join(preproc_path, 'main_epo.fif'), verbose=False)
            records.append({'epoch_dat': epoch.get_data(), 'metadata': epoch.metadata})
            continue

        st_epoch = mne.read_epochs(os.path.join(preproc_path, 'mainstim_epo.fif'), verbose=False)
        if resample:
            print('Frequency before:', st_epoch.info['sfreq'])
            st_epoch = st_epoch.resample(target_sfreq)
            print('Frequency after:', st_epoch.info['sfreq'])

        # Build the repetition mask once and apply it to both the data array
        # and the metadata so the two stay aligned row-for-row.
        metadata = st_epoch.metadata
        keep = metadata['nrep'] == number_of_repetition
        records.append(
            {
                'epoch_dat': st_epoch.get_data()[keep.to_numpy(), :, :],
                'metadata': metadata[keep],
            }
        )

    return records


In [5]:
# Load the 'main' epochs for every subject in the default subject list
# (one dict per subject with 'epoch_dat' and 'metadata').
all_rawdata = read_data(task='main')

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy functi

In [6]:
# Check which fields the first subject's record carries.
all_rawdata[0].keys()

dict_keys(['epoch_dat', 'metadata'])

In [7]:
# Shape of the first subject's data array.
# NOTE(review): MNE documents Epochs.get_data() as returning
# (n_epochs, n_channels, n_times); confirm against the installed version.
all_rawdata[0]['epoch_dat'].shape

(250, 32, 2876)

In [8]:
# Show the first subject's trial metadata table (one row per epoch).
display(all_rawdata[0]['metadata'])

Unnamed: 0,index,subj,nblock,ntrial,nrep,trial_type,cond-1,cond,rDV,DV,...,d5,d6,o1,o2,o3,o4,o5,o6,confi-1,conf_lvl-1
883,0,s01,0,0,0,repeat,0,0,0.205880,-0.18,...,-0.091426,-0.944496,2.697000,0.715000,1.903000,1.576000,1.214000,1.549000,0.00,L
896,1,s01,0,0,1,repeat,0,0,0.205880,-0.18,...,-0.091426,-0.944496,2.697000,0.715000,1.903000,1.576000,1.214000,1.549000,-0.10,L
909,2,s01,0,0,2,repeat,0,0,0.205880,-0.18,...,-0.091426,-0.944496,2.697000,0.715000,1.903000,1.576000,1.214000,1.549000,0.60,L
922,3,s01,0,1,0,nonrepeat,0,0,0.745740,-0.08,...,-0.246761,0.711234,1.675000,2.694000,1.910000,1.876000,1.275000,0.672000,0.85,H
935,4,s01,0,1,1,nonrepeat,0,0,0.745740,-0.08,...,-0.246761,0.711234,1.675000,2.694000,1.910000,1.876000,1.275000,0.672000,0.50,L
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4240,258,s01,3,20,0,repeat,1,0,0.717009,-0.31,...,0.368968,-0.664902,3.012000,1.723000,0.300000,2.710000,2.604000,3.010000,0.85,H
4253,259,s01,3,20,1,repeat,0,0,0.717009,-0.31,...,0.368968,-0.664902,3.012000,1.723000,0.300000,2.710000,2.604000,3.010000,0.45,H
4266,260,s01,3,20,2,repeat,0,0,0.717009,-0.31,...,0.368968,-0.664902,3.012000,1.723000,0.300000,2.710000,2.604000,3.010000,0.75,H
4279,261,s01,3,21,0,nonrepeat,0,1,0.769212,0.21,...,0.890997,-0.651651,2.571000,1.207000,0.681000,1.910000,2.399000,1.434000,0.50,H


In [9]:
# List every metadata column name for the first subject.
all_rawdata[0]['metadata'].columns

Index(['index', 'subj', 'nblock', 'ntrial', 'nrep', 'trial_type', 'cond-1',
       'cond', 'rDV', 'DV', 'resp', 'deci-2', 'deci-1', 'deci', 'corr-1',
       'r_map', 'correct', 'confi', 'RT', 'd1', 'conf_lvl', 'correct-1', 'd2',
       'd3', 'd4', 'd5', 'd6', 'o1', 'o2', 'o3', 'o4', 'o5', 'o6', 'confi-1',
       'conf_lvl-1'],
      dtype='object')

In [10]:
# Build a single-subject design matrix: one row per epoch, with the remaining
# axes flattened into one feature vector. The row count is taken from the
# array itself instead of being hard-coded, so this also works for subjects
# with a different number of epochs.
first_subject = all_rawdata[0]
X = first_subject['epoch_dat']
X = X.reshape((X.shape[0], -1))
# Target labels: the 'deci' metadata column.
y = first_subject['metadata']['deci']

In [11]:
# Disabled alternative: pool the epochs and 'deci' labels of all subjects into
# one X / y pair instead of using only the first subject.
# NOTE(review): the reshape below hard-codes 6205 rows; if this is re-enabled,
# use X.shape[0] so it does not break when the subject list changes.
# X = all_rawdata[0]['epoch_dat']
# y = all_rawdata[0]['metadata']['deci']
# for i in range(1, nSubj):
#     X = np.concatenate((X, all_rawdata[i]['epoch_dat']))
#     y = np.concatenate((y, all_rawdata[i]['metadata']['deci'] ))
# print(X.shape, y.shape)
# X = X.reshape(6205,-1)
# print(X.shape)

In [16]:
from sklearn.svm import SVC
# from pyrcn.echo_state_network import ESNClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier,GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.linear_model import LogisticRegression

# Hold out 20% of the epochs for testing; the fixed seed keeps the split
# identical across runs so classifiers are compared on the same data.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(X_train.shape, y_train.shape)

# Candidate classifiers, keyed by a short name. Only the selected one is
# instantiated, instead of building several estimators and overwriting `clf`
# until the last assignment wins.
MODEL_FACTORIES = {
    'rf': RandomForestClassifier,
    'ada': AdaBoostClassifier,
    'gb': GradientBoostingClassifier,
    'logistic': LogisticRegression,
    'svc': SVC,
    'xgb': XGBClassifier,
    'catboost': CatBoostClassifier,
}
# Change this key to try another classifier. 'catboost' matches what this
# cell previously ended up fitting.
# NOTE(review): the saved output shows CatBoost logging every iteration at
# roughly 2 s each on this feature matrix; expect a long run.
MODEL_NAME = 'catboost'

clf = MODEL_FACTORIES[MODEL_NAME]()
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print("Accuracy on test set:", accuracy)



(200, 92032) (200,)
Learning rate set to 0.005182
0:	learn: 0.6908393	total: 2.07s	remaining: 34m 30s
1:	learn: 0.6882918	total: 4.25s	remaining: 35m 19s
2:	learn: 0.6856815	total: 6.38s	remaining: 35m 22s
3:	learn: 0.6835724	total: 8.51s	remaining: 35m 19s
4:	learn: 0.6809235	total: 10.6s	remaining: 35m 19s
5:	learn: 0.6784813	total: 12.8s	remaining: 35m 21s
6:	learn: 0.6763187	total: 15s	remaining: 35m 21s
7:	learn: 0.6743694	total: 17.1s	remaining: 35m 18s
8:	learn: 0.6717551	total: 19.2s	remaining: 35m 16s
9:	learn: 0.6695325	total: 21.4s	remaining: 35m 13s
10:	learn: 0.6674916	total: 23.5s	remaining: 35m 13s
11:	learn: 0.6650108	total: 25.6s	remaining: 35m 10s
12:	learn: 0.6627648	total: 27.8s	remaining: 35m 7s
13:	learn: 0.6610048	total: 29.9s	remaining: 35m 6s
14:	learn: 0.6591129	total: 32s	remaining: 35m 3s
15:	learn: 0.6563911	total: 34.2s	remaining: 35m 1s
16:	learn: 0.6545469	total: 36.3s	remaining: 35m
17:	learn: 0.6525900	total: 38.5s	remaining: 34m 58s
18:	learn: 0.65061

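Random forest (RF), all subjects, 80/20 train/test split: accuracy = 0.556809024979855

RF, logistic regression, SVC, 1 subject: accuracy = 0.66

XGBoost (XGB), 1 subject: accuracy = 0.62

XGB, 26 subjects: accuracy = 0.5495568090249798

AdaBoost (Ada): accuracy = 0.6

Gradient boosting (GB): accuracy = 0.62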

0.66
