In [1]:
# !pip install mne
# !pip install plotly

Collecting mne
  Downloading mne-0.15.2.tar.gz (6.0MB)
    100% |████████████████████████████████| 6.0MB 98kB/s
Building wheels for collected packages: mne
  Running setup.py bdist_wheel for mne ... done
  Stored in directory: /home/jovyan/.cache/pip/wheels/bd/68/89/053b95464866970188323cec27e724bfaf0fe9adb4fe325061
Successfully built mne
Installing collected packages: mne
Successfully installed mne-0.15.2


In [2]:
import mne
import csv
import numpy as np
from sklearn import preprocessing, metrics
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression

In [3]:
# Subject identifiers; each matches the '<Subject>-' prefix of the recordings in `filenames`.
# NOTE(review): not referenced by any of the cells visible here.
names = 'Beam Eye Fluke Fong Joke Pod Tau Toey Tong'.split()

In [4]:
# Base names of the trimmed recordings, formatted '<Subject>-t<trial> (1024Hz)'.
# The loader appends '_data_label.csv' to each one.
# Trials are listed per subject in recording order; missing trial numbers
# (e.g. Beam t4/t5, Pod t22) are simply not part of the dataset list.
filenames = [
    '{}-t{} (1024Hz)'.format(subject, trial)
    for subject, trials in (
        ('Beam', (1, 2, 3, 6, 7, 8)),
        ('Eye', (2, 3, 4, 5, 6, 7)),
        ('Fluke', (20, 21, 22, 23, 24, 25)),
        ('Fong', (3, 4, 5, 6, 7, 8)),
        ('Joke', (1, 2, 3, 4, 5, 6)),
        ('Pod', (21, 23, 24, 25, 26)),
        ('Tau', (13, 14, 15, 16, 17, 18)),
        ('Toey', (1, 2, 3, 4, 5, 6, 7)),
        ('Tong', (1, 2, 3, 4, 5, 6)),
    )
    for trial in trials
]

In [5]:
def standardize(X):
    """Standardize each column of `X` to zero mean and unit variance.

    Thin wrapper around ``sklearn.preprocessing.scale`` with its defaults
    (``axis=0``, ``copy=True``), so every column is centred and scaled
    independently and a new array is returned.

    NOTE(review): only referenced from a commented-out line in the loading cell.
    """
    scaled = preprocessing.scale(X)
    return scaled

## Importing EEG Data

In [6]:
# Initialize an info structure describing the recorded CSV columns:
# the EEG channels below, followed by one 'event' stim channel holding the labels.
ch_names = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T3', 'C3', 'Cz', 'C4', 'T4', 'T5', 'P3', 'Pz', 'P4', 'T6', 'O1', 'O2']
# Derived from the list (19) instead of hard-coded, so the two cannot drift apart
n_channel = len(ch_names)
ch_types = ['eeg'] * n_channel + ['stim']  # last channel ('event') is the stim/label channel
sfreq = 1024 # Sampling rate 1024 Hz
# Electrode positions for these channel names from MNE's 'standard_1005' montage
montage = mne.channels.read_montage('standard_1005', ch_names)
mne_info = mne.create_info(
    ch_names = ch_names + ['event'],
    sfreq = sfreq,
    ch_types = ch_types,
    montage = montage
)
# Stim-channel codes used to cut epochs; other codes in the channel are ignored
event_id = {'idle': 0, 'onset': 1}
# Indices of the EEG channels only (excludes the 'event' stim channel)
picks = mne.pick_types(mne_info, eeg = True)

In [7]:
# Assemble a classifier: CSP spatial filtering -> flatten -> LDA
lda = LinearDiscriminantAnalysis()
vectorizer = mne.decoding.Vectorizer()
# One CSP component per EEG channel (tied to n_channel rather than a literal 19);
# transform_into = 'csp_space' keeps the projected time courses rather than average power,
# which the Vectorizer then flattens into one feature vector per epoch.
csp = mne.decoding.CSP(n_components = n_channel, cov_est = 'epoch', transform_into = 'csp_space', norm_trace = True)

# Wrap in a scikit-learn Pipeline so CSP is re-fit on the training folds only
# when the pipeline is passed to cross_val_predict.
clf = Pipeline([('CSP', csp), ('VEC', vectorizer), ('LDA', lda)])
# clf = Pipeline([('VEC', vectorizer), ('LDA', lda)])

In [9]:
data_dir = './data/trim/'
result_dir = './data/classification/result/'
# result_filename = 'iir01-30_resample64_csp_lda'
# result_filename = 'iir01-30_csp_lda'
# result_filename = 'csp_lda'

# result_filename = 'stdscale_lda'
# result_filename = 'stdscale_resample64_lda'
# result_filename = 'iir01-30_stdscale_resample64_lda'
# result_filename = 'iir01-30_stdscale_resample64_csp_lda'
# result_filename = 'iir01-30_stdscale_resample64_csp_lda_kfold5'
result_filename = 'iir01-30_stdscale_resample64_csp_lda_allsubjects'
target_sfreq = 64  # Hz; epochs are downsampled from `sfreq` to this rate

# ---- Load, filter and epoch every recording, pooling all subjects ----
epochs_all = None
for filename in filenames:
    print('processing ' + filename + '...')
    # One row per sample; columns follow `mne_info` (EEG channels, then 'event')
    np_raw_data = np.genfromtxt(data_dir + filename + '_data_label.csv', delimiter = ',')
    # np_raw_data[:, np.arange(0, n_channel, 1)] = standardize(np_raw_data[:, np.arange(0, n_channel, 1)])
    mne_raw_data = mne.io.RawArray(np_raw_data.T, mne_info)
    # Band-pass filter the EEG channels (0.1-30 Hz, IIR)
    mne_raw_data = mne_raw_data.filter(0.1, 30, picks = picks, method = 'iir')
    events = mne.find_events(mne_raw_data, stim_channel = 'event', output = 'step', consecutive = True)
    # 0-2 s epochs after every 'idle'/'onset' event, EEG channels only, no baseline correction
    epochs = mne.Epochs(mne_raw_data, events, event_id, 0, 2, proj = False,
                        picks = picks, baseline = None, preload = True)
    # NOTE(review): re-concatenating the growing pool copies every earlier epoch on each
    # iteration (quadratic in the number of files); collecting the per-file epochs and
    # concatenating once would be cheaper if memory allows.
    if epochs_all is None:
        epochs_all = epochs
    else:
        epochs_all = mne.concatenate_epochs([epochs_all, epochs])

# Standard Scale (channel-wise standardization of the pooled epochs)
# NOTE(review): the scaler is fit on ALL epochs before cross-validation, so statistics of
# the test folds leak into training. Fitting it inside the CV pipeline (after the stateless
# resampling step) would remove the leak, but changes X and the reported scores.
scaler = mne.decoding.Scaler(scalings = 'mean', with_mean = True, with_std = True)
transformed_epochs_data = scaler.fit_transform(epochs_all.get_data())

# Resample to 64Hz; the decimation factor is derived from the acquisition rate (1024 / 64 = 16)
transformed_epochs_data = mne.filter.resample(transformed_epochs_data, down = sfreq / target_sfreq,
                                              npad = 'auto', pad = 'edge')

# Data and Labels
X = transformed_epochs_data
y = epochs_all.events[:, -1]  # event code of each epoch (0 = idle, 1 = onset)

print('Classifying...')
predicted = cross_val_predict(clf, X, y, cv = 10)
# predicted = cross_val_predict(clf, X, y, cv = 5)

acc = metrics.accuracy_score(y, predicted)
f1 = metrics.f1_score(y, predicted)
# Fix the label order so ravel() always yields exactly 4 counts, even if a class is absent
tn, fp, fn, tp = metrics.confusion_matrix(
    y, predicted, labels = [event_id['idle'], event_id['onset']]).ravel()
tpr = tp / (tp + fn)
fpr = fp / (fp + tn)

# Open (and truncate) the result file only now, so a failure during the long
# processing above no longer wipes a previously exported result.
print('Exporting result...')
with open(result_dir + result_filename + '.csv', 'w', newline='') as fout:
    fieldnames = ['Name', 'Acc', 'F1', 'TPR', 'FPR', 'TP', 'TN', 'FP', 'FN']
    writer = csv.DictWriter(fout, fieldnames = fieldnames)
    writer.writeheader()
    writer.writerow({
        'Name': 'All Subjects',
        'Acc': '{:.10f}'.format(acc),
        'F1': '{:.10f}'.format(f1),
        'TPR': '{:.10f}'.format(tpr),
        'FPR': '{:.10f}'.format(fpr),
        'TP': '{:.10f}'.format(tp),
        'TN': '{:.10f}'.format(tn),
        'FP': '{:.10f}'.format(fp),
        'FN': '{:.10f}'.format(fn)
    })
print('Done.')

processing Beam-t1 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=245760
    Range : 0 ... 245759 =      0.000 ...   239.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
121 events found
Events id: [0 1 2]
61 matching events found
Loading data for 61 events and 2049 original time points ...
1 bad epochs dropped
processing Beam-t2 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=192512
    Range : 0 ... 192511 =      0.000 ...   187.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
101 events found
Events id: [0 1 2]
51 matching events found
Loading data for 51 events and 2049 original time points ...
1 bad epochs dropped
110 matching events found
0 bad epochs dropped
processing Beam-t3 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=207872
    Range : 0 ... 207871 =      0.000 ...   202.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
113 events found
Events id: [0 1 2]
57 matchin

1 bad epochs dropped
1216 matching events found
0 bad epochs dropped
processing Fong-t6 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=198656
    Range : 0 ... 198655 =      0.000 ...   193.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
117 events found
Events id: [0 1 2]
59 matching events found
Loading data for 59 events and 2049 original time points ...
1 bad epochs dropped
1274 matching events found
0 bad epochs dropped
processing Fong-t7 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=203776
    Range : 0 ... 203775 =      0.000 ...   198.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
115 events found
Events id: [0 1 2]
58 matching events found
Loading data for 58 events and 2049 original time points ...
1 bad epochs dropped
1331 matching events found
0 bad epochs dropped
processing Fong-t8 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=212992
    Range : 0 ... 212991 =      0.000

123 events found
Events id: [0 1 2]
62 matching events found
Loading data for 62 events and 2049 original time points ...
1 bad epochs dropped
2487 matching events found
0 bad epochs dropped
processing Toey-t2 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=242688
    Range : 0 ... 242687 =      0.000 ...   236.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
115 events found
Events id: [0 1 2]
58 matching events found
Loading data for 58 events and 2049 original time points ...
1 bad epochs dropped
2544 matching events found
0 bad epochs dropped
processing Toey-t3 (1024Hz)...
Creating RawArray with float64 data, n_channels=20, n_times=245760
    Range : 0 ... 245759 =      0.000 ...   239.999 secs
Ready.
Setting up band-pass filter from 0.1 - 30 Hz
115 events found
Events id: [0 1 2]
58 matching events found
Loading data for 58 events and 2049 original time points ...
1 bad epochs dropped
2601 matching events found
0 bad epochs dropped
processing To