In [None]:
# Optional dependency installs (commented out; uncomment to run).
# NOTE(review): versions are unpinned. This notebook calls
# mne.channels.read_montage and mne.create_info(montage=...), which are the
# legacy (pre-0.20) MNE API — pin a compatible MNE release, and prefer
# `%pip install pkg==X.Y` so the kernel's own environment is targeted.
# !pip install mne
# !pip install plotly

In [5]:
import mne
import csv
import numpy as np
from sklearn import preprocessing, metrics
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression

In [2]:
# Subject identifiers; each one is the prefix (before '-') of that subject's
# recording files in `filenames`, and results are written one row per subject.
names = 'Beam Eye Fluke Fong Joke Pod Tau Toey Tong'.split()

In [3]:
# Recording file stems, '<Subject>-t<trial> (1024Hz)', grouped by subject in the
# same order as `names` with trials ascending. Non-contiguous trial lists
# (Beam, Pod) are spelled out explicitly; the rest are contiguous ranges.
filenames = [
    '{}-t{} (1024Hz)'.format(subject, trial)
    for subject, trials in (
        ('Beam', (1, 2, 3, 6, 7, 8)),
        ('Eye', range(2, 8)),
        ('Fluke', range(20, 26)),
        ('Fong', range(3, 9)),
        ('Joke', range(1, 7)),
        ('Pod', (21, 23, 24, 25, 26)),
        ('Tau', range(13, 19)),
        ('Toey', range(1, 8)),
        ('Tong', range(1, 7)),
    )
    for trial in trials
]

In [4]:
def standardize(X):
    """Standardize ``X`` column-wise to zero mean and unit variance.

    Thin wrapper around ``sklearn.preprocessing.scale`` with its default
    settings written out (axis=0, centre and scale); a new array is returned.
    """
    scaled = preprocessing.scale(X, axis=0, with_mean=True, with_std=True)
    return scaled

## Importing EEG Data

In [5]:
# Initialize an info structure: the EEG channels listed below plus one trailing
# stim channel, 'event', which RawArray maps to the last column of each CSV.
ch_names = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T3', 'C3', 'Cz', 'C4', 'T4', 'T5', 'P3', 'Pz', 'P4', 'T6', 'O1', 'O2']
n_channel = len(ch_names)  # derived (19) so it can never disagree with ch_names
ch_types = ['eeg'] * n_channel + ['stim']
sfreq = 1024  # Sampling rate 1024 Hz
# NOTE(review): read_montage and create_info(montage=...) are the legacy
# (pre-0.20) MNE API; pin an older MNE or port to
# mne.channels.make_standard_montage + info.set_montage.
montage = mne.channels.read_montage('standard_1005', ch_names)
mne_info = mne.create_info(
    ch_names=ch_names + ['event'],
    sfreq=sfreq,
    ch_types=ch_types,
    montage=montage,
)
# Epoch label mapping. NOTE(review): the recorded output reports event ids
# [0 1 2]; id 2 is not mapped here, so those events are never epoched —
# confirm that is intended.
event_id = {'idle': 0, 'onset': 1}
picks = mne.pick_types(mne_info, eeg=True)  # EEG channel indices only (excludes 'event')

In [16]:
# Assemble a classifier: CSP spatial filtering -> flatten -> LDA.
lda = LinearDiscriminantAnalysis()
vectorizer = mne.decoding.Vectorizer()
# transform_into='csp_space' outputs the CSP-projected time series (see the
# mne.decoding.CSP docs); the Vectorizer flattens them into one feature row
# per epoch for the LDA.
# NOTE(review): with n_components equal to the number of EEG channels, CSP is
# an invertible linear transform of the channels and performs no
# dimensionality reduction.
csp = mne.decoding.CSP(n_components=n_channel, cov_est='epoch',
                       transform_into='csp_space', norm_trace=True)

# scikit-learn Pipeline; it is evaluated with cross_val_predict, which re-fits
# every step on each training fold only.
clf = Pipeline([('CSP', csp), ('VEC', vectorizer), ('LDA', lda)])
# Alternative without spatial filtering:
# clf = Pipeline([('VEC', vectorizer), ('LDA', lda)])

In [17]:
data_dir = './data/trim/'
result_dir = './data/classification/result/'
# result_filename = 'iir01-30_resample64_csp_lda'
# result_filename = 'iir01-30_csp_lda'
# result_filename = 'csp_lda'

# result_filename = 'stdscale_lda'
# result_filename = 'stdscale_resample64_lda'
# result_filename = 'iir01-30_stdscale_resample64_lda'
# result_filename = 'iir01-30_stdscale_resample64_csp_lda'
result_filename = 'iir01-30_stdscale_resample64_csp_lda_kfold5'

n_folds = 5           # cross-validation folds ('kfold5' in result_filename)
resample_down = 16.   # 1024 Hz / 16 = 64 Hz ('resample64' in result_filename)
class_labels = sorted(event_id.values())  # [0, 1]; 1 ('onset') is the positive class


def _load_subject_epochs(subject):
    """Load, band-pass filter and epoch every recording belonging to ``subject``.

    A recording belongs to ``subject`` when the part of its name before the
    first '-' equals ``subject``. Each file is read from
    ``data_dir + <filename> + '_data_label.csv'``, IIR band-pass filtered
    0.1-30 Hz on the EEG channels, and cut into epochs starting at each event
    found on the 'event' stim channel and lasting 2 s.

    Returns:
        mne.Epochs holding all of the subject's epochs, or None when no
        recording matches.
    """
    subject_epochs = []
    for filename in filenames:
        # maxsplit=1 so a '-' later in a filename cannot break the unpacking.
        if filename.split('-', 1)[0] != subject:
            continue
        np_raw_data = np.genfromtxt(data_dir + filename + '_data_label.csv', delimiter=',')
        # np_raw_data[:, np.arange(0, n_channel, 1)] = standardize(np_raw_data[:, np.arange(0, n_channel, 1)])
        mne_raw_data = mne.io.RawArray(np_raw_data.T, mne_info)
        # Band-pass Filter
        mne_raw_data = mne_raw_data.filter(0.1, 30, picks=picks, method='iir')
        events = mne.find_events(mne_raw_data, stim_channel='event', output='step', consecutive=True)
        subject_epochs.append(mne.Epochs(mne_raw_data, events, event_id, 0, 2, proj=False,
                                         picks=picks, baseline=None, preload=True))
    if not subject_epochs:
        return None
    if len(subject_epochs) == 1:
        return subject_epochs[0]
    # Concatenate once rather than re-concatenating a growing result per file.
    return mne.concatenate_epochs(subject_epochs)


def _binary_scores(y_true, y_pred):
    """Compute accuracy, F1, TPR, FPR and confusion counts for ``class_labels``.

    The confusion matrix is built with explicit ``labels`` so it is always
    2x2, even if a class is missing from both ``y_true`` and ``y_pred``.
    TPR / FPR are NaN when their denominator is zero.
    """
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred, labels=class_labels).ravel()
    return {
        'Acc': metrics.accuracy_score(y_true, y_pred),
        'F1': metrics.f1_score(y_true, y_pred),
        'TPR': tp / (tp + fn) if tp + fn else float('nan'),
        'FPR': fp / (fp + tn) if fp + tn else float('nan'),
        'TP': tp,
        'TN': tn,
        'FP': fp,
        'FN': fn,
    }


with open(result_dir + result_filename + '.csv', 'w', newline='') as fout:
    fieldnames = ['Name', 'Acc', 'F1', 'TPR', 'FPR', 'TP', 'TN', 'FP', 'FN']
    writer = csv.DictWriter(fout, fieldnames=fieldnames)
    writer.writeheader()

    for name in names:
        print('Processing ' + name + '...')
        epochs_all = _load_subject_epochs(name)
        if epochs_all is None:
            print('No recordings found for ' + name + ', skipping.')
            continue

        # Resample to 64Hz. Resampling acts on each epoch/channel signal
        # independently, so doing it before cross-validation leaks nothing.
        X = mne.filter.resample(epochs_all.get_data(), down=resample_down, npad='auto', pad='edge')
        y = epochs_all.events[:, -1]

        # Standard scale *inside* the CV pipeline: fitting the Scaler on all
        # epochs before cross_val_predict let each test fold's channel
        # means/stds leak into training. The scaler now sees the resampled X,
        # i.e. exactly what cross_val_predict splits.
        scaler = mne.decoding.Scaler(scalings='mean', with_mean=True, with_std=True)
        model = Pipeline([('SCALER', scaler), ('CLF', clf)])

        print('Classifying...')
        # Alternative: cv=10
        predicted = cross_val_predict(model, X, y, cv=n_folds)

        scores = _binary_scores(y, predicted)

        print('Exporting result...')
        row = {'Name': name}
        row.update({key: '{:.10f}'.format(scores[key]) for key in ('Acc', 'F1', 'TPR', 'FPR')})
        # Confusion counts are integers; write them as such.
        row.update({key: int(scores[key]) for key in ('TP', 'TN', 'FP', 'FN')})
        writer.writerow(row)
print('Done.')

Processing Beam...
Creating RawArray with float64 data, n_channels=20, n_times=245760
    Range : 0 ... 245759 =      0.000 ...   239.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
121 events found
Events id: [0 1 2]
61 matching events found
Loading data for 61 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=192512
    Range : 0 ... 192511 =      0.000 ...   187.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
101 events found
Events id: [0 1 2]
51 matching events found
Loading data for 51 events and 2049 original time points ...
1 bad epochs dropped
110 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=207872
    Range : 0 ... 207871 =      0.000 ...   202.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
113 events found
Events id: [0 1 2]
57 matching events found
Loading data for 57 events and 2049 original time points ..



Exporting result...
Processing Eye...
Creating RawArray with float64 data, n_channels=20, n_times=206848
    Range : 0 ... 206847 =      0.000 ...   201.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
121 events found
Events id: [0 1 2]
61 matching events found
Loading data for 61 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=210944
    Range : 0 ... 210943 =      0.000 ...   205.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
115 events found
Events id: [0 1 2]
58 matching events found
Loading data for 58 events and 2049 original time points ...
1 bad epochs dropped
117 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=193536
    Range : 0 ... 193535 =      0.000 ...   188.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
117 events found
Events id: [0 1 2]
59 matching events found
Loading data for 59 events and 2049 orig



Exporting result...
Processing Fluke...
Creating RawArray with float64 data, n_channels=20, n_times=194560
    Range : 0 ... 194559 =      0.000 ...   189.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
117 events found
Events id: [0 1 2]
59 matching events found
Loading data for 59 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=210944
    Range : 0 ... 210943 =      0.000 ...   205.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
123 events found
Events id: [0 1 2]
62 matching events found
Loading data for 62 events and 2049 original time points ...
1 bad epochs dropped
119 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=181248
    Range : 0 ... 181247 =      0.000 ...   176.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
111 events found
Events id: [0 1 2]
56 matching events found
Loading data for 56 events and 2049 or



Exporting result...
Processing Fong...
Creating RawArray with float64 data, n_channels=20, n_times=203776
    Range : 0 ... 203775 =      0.000 ...   198.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
119 events found
Events id: [0 1 2]
60 matching events found
Loading data for 60 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=198656
    Range : 0 ... 198655 =      0.000 ...   193.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
117 events found
Events id: [0 1 2]
59 matching events found
Loading data for 59 events and 2049 original time points ...
1 bad epochs dropped
117 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=185344
    Range : 0 ... 185343 =      0.000 ...   180.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
109 events found
Events id: [0 1 2]
55 matching events found
Loading data for 55 events and 2049 ori



Exporting result...
Processing Joke...
Creating RawArray with float64 data, n_channels=20, n_times=219136
    Range : 0 ... 219135 =      0.000 ...   213.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
111 events found
Events id: [0 1 2]
56 matching events found
Loading data for 56 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=262144
    Range : 0 ... 262143 =      0.000 ...   255.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
109 events found
Events id: [0 1 2]
55 matching events found
Loading data for 55 events and 2049 original time points ...
1 bad epochs dropped
109 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=188416
    Range : 0 ... 188415 =      0.000 ...   183.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
103 events found
Events id: [0 1 2]
52 matching events found
Loading data for 52 events and 2049 ori



Exporting result...
Processing Pod...
Creating RawArray with float64 data, n_channels=20, n_times=164864
    Range : 0 ... 164863 =      0.000 ...   160.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
101 events found
Events id: [0 1 2]
51 matching events found
Loading data for 51 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=177152
    Range : 0 ... 177151 =      0.000 ...   172.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
119 events found
Events id: [0 1 2]
60 matching events found
Loading data for 60 events and 2049 original time points ...
1 bad epochs dropped
109 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=250880
    Range : 0 ... 250879 =      0.000 ...   244.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
119 events found
Events id: [0 1 2]
60 matching events found
Loading data for 60 events and 2049 orig



Exporting result...
Processing Tau...
Creating RawArray with float64 data, n_channels=20, n_times=227328
    Range : 0 ... 227327 =      0.000 ...   221.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
147 events found
Events id: [0 1 2]
74 matching events found
Loading data for 74 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=217088
    Range : 0 ... 217087 =      0.000 ...   211.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
127 events found
Events id: [0 1 2]
64 matching events found
Loading data for 64 events and 2049 original time points ...
1 bad epochs dropped
136 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=252928
    Range : 0 ... 252927 =      0.000 ...   246.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
131 events found
Events id: [0 1 2]
66 matching events found
Loading data for 66 events and 2049 orig



Exporting result...
Processing Toey...
Creating RawArray with float64 data, n_channels=20, n_times=477184
    Range : 0 ... 477183 =      0.000 ...   465.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
123 events found
Events id: [0 1 2]
62 matching events found
Loading data for 62 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=242688
    Range : 0 ... 242687 =      0.000 ...   236.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
115 events found
Events id: [0 1 2]
58 matching events found
Loading data for 58 events and 2049 original time points ...
1 bad epochs dropped
118 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=245760
    Range : 0 ... 245759 =      0.000 ...   239.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
115 events found
Events id: [0 1 2]
58 matching events found
Loading data for 58 events and 2049 ori



Exporting result...
Processing Tong...
Creating RawArray with float64 data, n_channels=20, n_times=262144
    Range : 0 ... 262143 =      0.000 ...   255.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
107 events found
Events id: [0 1 2]
54 matching events found
Loading data for 54 events and 2049 original time points ...
1 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=284672
    Range : 0 ... 284671 =      0.000 ...   277.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
131 events found
Events id: [0 1 2]
66 matching events found
Loading data for 66 events and 2049 original time points ...
1 bad epochs dropped
118 matching events found
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=20, n_times=243712
    Range : 0 ... 243711 =      0.000 ...   237.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
133 events found
Events id: [0 1 2]
67 matching events found
Loading data for 67 events and 2049 ori



Exporting result...
Done.


