/
example_classification.py
82 lines (68 loc) · 2.37 KB
/
example_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_score
from pyriemann.embedding import SpectralEmbedding
from pyriemann.classification import MDM
from pyriemann.estimation import Covariances
from alphawaves.dataset import AlphaWaves
import matplotlib.pyplot as plt
import numpy as np
import mne
"""
=============================
Classification of the trials
=============================
This example shows how to extract the epochs from the dataset of a given
subject and then classify them with machine learning techniques based on
Riemannian geometry. The code also creates a figure showing the spectral
embedding of the epochs.
"""
# Authors: Pedro Rodrigues <pedro.rodrigues01@gmail.com>
#
# License: BSD (3-clause)
import warnings

# Silence library warnings so that the example output stays readable.
warnings.filterwarnings("ignore")

# Instantiate the dataset and load the recording of its first subject.
dataset = AlphaWaves()
subject = dataset.subject_list[0]
raw = dataset._get_single_subject_data(subject)

# Band-pass filter the signal between fmin and fmax (Hz), then resample
# it to 128 Hz.
fmin, fmax = 3, 40
raw.filter(l_freq=fmin, h_freq=fmax, verbose=False)
raw.resample(sfreq=128, verbose=False)

# Locate the trigger events and cut the signal into windows from 2.0 s to
# 8.0 s after each 'closed' (code 1) or 'open' (code 2) marker, with no
# baseline correction. Only the EEG channels are kept.
events = mne.find_events(raw, shortest_event=1, verbose=False)
event_id = {'closed': 1, 'open': 2}
epochs = mne.Epochs(
    raw,
    events,
    event_id=event_id,
    tmin=2.0,
    tmax=8.0,
    baseline=None,
    verbose=False,
    preload=True,
)
epochs.pick_types(eeg=True)
# Get the trials and their class codes. The labels must come from
# ``epochs.events`` rather than from the raw ``events`` array, for two
# reasons. First, ``mne.Epochs`` keeps only the events listed in
# ``event_id``. Second, it may drop epochs, e.g. windows that run past the
# end of the recording. ``events`` can therefore contain unrelated codes or
# have more rows than ``X``, which would misalign the labels.
X = epochs.get_data()
y = epochs.events[:, -1]

# 5-fold stratified cross-validation of a Riemannian pipeline: covariance
# estimation with the 'lwf' (Ledoit-Wolf shrinkage) estimator, followed by
# a Minimum Distance to Mean classifier.
skf = StratifiedKFold(n_splits=5)
clf = make_pipeline(Covariances(estimator='lwf'), MDM())
scr = cross_val_score(clf, X, y, cv=skf)

# Print the classification results (mean accuracy over the folds).
print('subject', subject)
print('mean accuracy :', scr.mean())
# Project the trial covariance matrices into a low-dimensional space with a
# spectral embedding based on the Riemannian metric.
C = Covariances(estimator='lwf').fit_transform(X)
emb = SpectralEmbedding(metric='riemann').fit_transform(C)

# Scatter plot of the embedded points. Use one vectorised scatter call per
# class instead of one call per trial. Each per-class call also carries the
# legend label, so no placeholder artists are needed for the legend.
fig = plt.figure(facecolor='white', figsize=(5.6, 5.2))
colors = {1: 'r', 2: 'b'}
labels = {1: 'closed', 2: 'open'}
for yi in np.unique(y):
    mask = y == yi
    plt.scatter(emb[mask, 0], emb[mask, 1], s=120, c=colors[yi],
                label=labels[yi])
plt.xticks([-1, -0.5, 0.0, +0.5, 1.0])
plt.yticks([-1, -0.5, 0.0, +0.5, 1.0])
plt.legend()
plt.title(f'Spectral embedding of the epochs from subject {subject}',
          fontsize=10)
plt.show()