In [1]:
import os, glob, platform
import numpy as np
import numpy.matlib
import pickle
import pandas as pd
import pathlib
import matplotlib
import matplotlib.pyplot as plt
import mne
mne.__version__
from mne.viz import plot_alignment, snapshot_brain_montage
import shutil
from mne.datasets import eegbci
from sklearn.model_selection import train_test_split

# from mne_bids import write_raw_bids, BIDSPath, print_dir_tree, make_dataset_description
# from mne_bids.stats import count_events
import sys


In [2]:
path_utils = '/decoding_toolbox_py/helper_funcs' 
sys.path.append(path_utils)

In [3]:
''' VARIABLES '''

dataset = 'eeg'

amount_of_subjects = 27 # Change the range so the process is faster
if amount_of_subjects > 26: amount_of_subjects = 26
subjs_list = ['s{:02d}'.format(i) for i in range(1, amount_of_subjects+1) if i != 6 ] 
print(subjs_list)
nSubj = len(subjs_list)

numC = 8

angles = [i * 180./numC for i in range(numC)]

x_labels = np.array(angles)

resample = True # speeds up the procees but showing worse results overall
if resample: resample_frequency = 20 # in Hz, original freq is 500Hz

cfg_stim = dict()
cfg_stim['kappa'] = 4
cfg_stim['NumC'] = numC
cfg_stim['Tuning'] = 'vonmises'
# cfg_stim['Tuning'] = 'halfRectCos'
cfg_stim['offset'] = 0

cfg_train = dict()
cfg_train['gamma'] = 0.1
cfg_train['demean'] = True
cfg_train['returnPattern'] = True

cfg_test = dict()
cfg_test['demean'] = 'traindata'

['s01', 's02', 's03', 's04', 's05', 's07', 's08', 's09', 's10', 's11', 's12', 's13', 's14', 's15', 's16', 's17', 's18', 's19', 's20', 's21', 's22', 's23', 's24', 's25', 's26']


In [4]:
'''EEG Dataset'''
def read_data(
        number_of_repetition=3,
        resample=False,
        subjs_list = subjs_list,
        task = 'stim'
        ):
    path = 'Cond_CJ_EEG'

    epochs = []
    all_epochs = []
    all_rawdata = []
    all_st_epochs = []
    all_st_rawdata = []
    for subject_id in subjs_list:
        preproc_path = os.path.join(path, subject_id)

        if task == 'main':
            epoch = mne.read_epochs(os.path.join(preproc_path, 'main_epo.fif'), verbose=False)
            epochs.append(epoch.average())
            all_epochs.append(epoch)
            all_rawdata.append({'epoch_dat': epoch.get_data(), 'metadata': epoch.metadata})
            
        if task == 'stim':
        
            st_epoch = mne.read_epochs(os.path.join(preproc_path, 'mainstim_epo.fif'), verbose=False)
            # print(st_epoch.info['sfreq'])
            if resample: 
                print('Frequency before:', st_epoch.info['sfreq'])
                st_epoch = st_epoch.resample(resample_frequency)
                print('Frequency after:' ,st_epoch.info['sfreq'])
                
            all_st_epochs.append(st_epoch)
            all_st_rawdata.append(
                {
                'epoch_dat': st_epoch.get_data()[st_epoch.metadata['nrep'] == number_of_repetition,:,:] ,
                'metadata': st_epoch.metadata[st_epoch.metadata['nrep'] == number_of_repetition]
                }
                )
    if task == 'main':
        return all_rawdata
    else:
        return all_st_rawdata


In [5]:
all_rawdata = read_data (task = 'main')

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy functi

In [6]:
all_rawdata[0].keys()

dict_keys(['epoch_dat', 'metadata'])

In [7]:
all_rawdata[0]['epoch_dat'].shape

(250, 32, 2876)

In [8]:
display(all_rawdata[0]['metadata'])

Unnamed: 0,index,subj,nblock,ntrial,nrep,trial_type,cond-1,cond,rDV,DV,...,d5,d6,o1,o2,o3,o4,o5,o6,confi-1,conf_lvl-1
883,0,s01,0,0,0,repeat,0,0,0.205880,-0.18,...,-0.091426,-0.944496,2.697000,0.715000,1.903000,1.576000,1.214000,1.549000,0.00,L
896,1,s01,0,0,1,repeat,0,0,0.205880,-0.18,...,-0.091426,-0.944496,2.697000,0.715000,1.903000,1.576000,1.214000,1.549000,-0.10,L
909,2,s01,0,0,2,repeat,0,0,0.205880,-0.18,...,-0.091426,-0.944496,2.697000,0.715000,1.903000,1.576000,1.214000,1.549000,0.60,L
922,3,s01,0,1,0,nonrepeat,0,0,0.745740,-0.08,...,-0.246761,0.711234,1.675000,2.694000,1.910000,1.876000,1.275000,0.672000,0.85,H
935,4,s01,0,1,1,nonrepeat,0,0,0.745740,-0.08,...,-0.246761,0.711234,1.675000,2.694000,1.910000,1.876000,1.275000,0.672000,0.50,L
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4240,258,s01,3,20,0,repeat,1,0,0.717009,-0.31,...,0.368968,-0.664902,3.012000,1.723000,0.300000,2.710000,2.604000,3.010000,0.85,H
4253,259,s01,3,20,1,repeat,0,0,0.717009,-0.31,...,0.368968,-0.664902,3.012000,1.723000,0.300000,2.710000,2.604000,3.010000,0.45,H
4266,260,s01,3,20,2,repeat,0,0,0.717009,-0.31,...,0.368968,-0.664902,3.012000,1.723000,0.300000,2.710000,2.604000,3.010000,0.75,H
4279,261,s01,3,21,0,nonrepeat,0,1,0.769212,0.21,...,0.890997,-0.651651,2.571000,1.207000,0.681000,1.910000,2.399000,1.434000,0.50,H


In [9]:
X = all_rawdata[0]['epoch_dat']
#X = X.reshape((250, -1))
y = all_rawdata[0]['metadata']['deci']

In [10]:
X = all_rawdata[0]['epoch_dat']
y = all_rawdata[0]['metadata']['deci']
for i in range(1, nSubj):
    X = np.concatenate((X, all_rawdata[i]['epoch_dat']))
    y = np.concatenate((y, all_rawdata[i]['metadata']['deci'] ))
print(X.shape, y.shape)

KeyboardInterrupt: 

In [None]:
np.unique(X)

array([-0.01591576, -0.01591313, -0.01591171, ...,  0.0163828 ,
        0.01638285,  0.01638598])

In [None]:
X = X.reshape((6205, -1))

In [None]:
stop

NameError: name 'stop' is not defined

In [None]:
from sklearn.svm import SVC
# from pyrcn.echo_state_network import ESNClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from catboost import CatBoostClassifier

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = CatBoostClassifier(
    task_type="GPU",
    devices='0:1'
    )
# clf = RandomForestClassifier()
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print("Accuracy on test set:", accuracy)



Learning rate set to 0.089998
0:	learn: 1.0908343	total: 910ms	remaining: 15m 9s
1:	learn: 1.0840340	total: 1.63s	remaining: 13m 32s
2:	learn: 1.0772471	total: 2.35s	remaining: 12m 59s
3:	learn: 1.0701694	total: 3.06s	remaining: 12m 42s
4:	learn: 1.0640090	total: 3.78s	remaining: 12m 32s
5:	learn: 1.0587263	total: 4.5s	remaining: 12m 25s
6:	learn: 1.0535447	total: 5.21s	remaining: 12m 19s
7:	learn: 1.0475873	total: 5.93s	remaining: 12m 14s
8:	learn: 1.0431013	total: 6.63s	remaining: 12m 10s
9:	learn: 1.0385733	total: 7.35s	remaining: 12m 8s
10:	learn: 1.0332393	total: 8.07s	remaining: 12m 5s
11:	learn: 1.0283730	total: 8.79s	remaining: 12m 3s
12:	learn: 1.0234112	total: 9.5s	remaining: 12m 1s
13:	learn: 1.0197981	total: 10.2s	remaining: 11m 59s
14:	learn: 1.0152834	total: 10.9s	remaining: 11m 57s
15:	learn: 1.0110196	total: 11.6s	remaining: 11m 55s
16:	learn: 1.0067201	total: 12.4s	remaining: 11m 54s
17:	learn: 1.0025694	total: 13.1s	remaining: 11m 52s
18:	learn: 0.9983333	total: 13.8s

RF all subjects 80-20 = 0.556809024979855