In [63]:
# !pip install mne

In [64]:
import numpy as np
import matplotlib.pyplot as plt

import mne
from mne.datasets.sleep_physionet.age import fetch_data
from mne.time_frequency import psd_welch

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

from tqdm import tqdm

In [72]:
train_subjects = fetch_data(subjects=list(range(18)), recording=[1])
raw_train, annotations_train = [], []
for i, subj in enumerate(train_subjects): 
  raw_train.append(mne.io.read_raw_edf(subj[0], stim_channel='marker', misc=['rectal']))
  annotations_train.append(mne.read_annotations(subj[1]))
  raw_train[i].set_annotations(annotations_train[i], emit_warning=False)

Using default location ~/mne_data for PHYSIONET_SLEEP...
Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4001E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4011E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4021E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4031E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4041E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4051E0-PSG.edf...
EDF fi

In [73]:
annotation_desc_2_event_id = {'Sleep stage W': 1,
                              'Sleep stage 1': 2,
                              'Sleep stage 2': 3,
                              'Sleep stage 3': 4,
                              'Sleep stage 4': 4,
                              'Sleep stage R': 5}

# keep last 30-min wake events before sleep and first 30-min wake events after
# sleep and redefine annotations on raw data
for i, annot_train in enumerate(annotations_train):
  annotations_train[i].crop(annot_train[1]['onset'] - 30 * 60,
                  annot_train[-2]['onset'] + 30 * 60)
  raw_train[i].set_annotations(annot_train, emit_warning=False)

  events_train, _ = mne.events_from_annotations(
      raw_train[i], event_id=annotation_desc_2_event_id, chunk_duration=30.)

  # create a new event_id that unifies stages 3 and 4
  event_id = {'Sleep stage W': 1,
              'Sleep stage 1': 2,
              'Sleep stage 2': 3,
              'Sleep stage 3/4': 4,
              'Sleep stage R': 5}

  # plot events
  # fig = mne.viz.plot_events(events_train, event_id=event_id,
  #                           sfreq=raw_train[i].info['sfreq'],
  #                           first_samp=events_train[0, 0])

  # # keep the color-code for further plotting
  # stage_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage R', 'Sleep stage W']
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 

In [74]:
# fig, axs = plt.subplots(len(train_subjects), figsize=(25, 50))
epochs_train_lst = []
for i, raw in tqdm(enumerate(raw_train)):
  tmax = 30. - 1. / raw.info['sfreq']  # tmax in included

  epochs_train = mne.Epochs(raw=raw, events=events_train,
                            event_id=event_id, tmin=0., tmax=tmax, baseline=None)
  epochs_train_lst.append(epochs_train)
#   stages = sorted(event_id.keys())

#   for stage, color in zip(stages, stage_colors):
#     epochs_train[stage].plot_psd(area_mode=None, color=color, ax=axs[i],
#                           fmin=0.1, fmax=20., show=False,
#                           average=True, spatial_colors=False)
#     axs[i].set(title="Subject1", xlabel='Frequency (Hz)')

#   axs[i].set(ylabel='µV^2/Hz (dB)')
#   axs[i].legend(axs[i].lines[2::3], stages)
# plt.show()

0it [00:00, ?it/s]

Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated


4it [00:00, 31.08it/s]

Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated


8it [00:00, 34.86it/s]

Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated


12it [00:00, 35.91it/s]

Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated


16it [00:00, 32.72it/s]

Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
Not setting metadata
1002 matching events found
No baseline correction applied
0 projection items activated


18it [00:00, 32.46it/s]


In [75]:
def eeg_power_band(epochs):
    """EEG relative power band feature extraction.

    This function takes an ``mne.Epochs`` object and creates EEG features based
    on relative power in specific frequency bands that are compatible with
    scikit-learn.

    Parameters
    ----------
    epochs : Epochs
        The data.

    Returns
    -------
    X : numpy array of shape [n_samples, 5]
        Transformed data.
    """
    # specific frequency bands
    FREQ_BANDS = {"delta": [0.5, 4.5],
                  "theta": [4.5, 8.5],
                  "alpha": [8.5, 11.5],
                  "sigma": [11.5, 15.5],
                  "beta": [15.5, 30]}

    psds, freqs = psd_welch(epochs, picks='eeg', fmin=0.5, fmax=30.)
    # Normalize the PSDs
    psds /= np.sum(psds, axis=-1, keepdims=True)

    X = []
    for fmin, fmax in FREQ_BANDS.values():
        psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=1)

In [76]:
test_subject1, test_subject2 = fetch_data(subjects=[20, 21], recording=[1])

raw_test = mne.io.read_raw_edf(test_subject1[0], stim_channel='marker', misc=['rectal'])

annotations_test = mne.read_annotations(test_subject1[1])
raw_test.set_annotations(annotations_test, emit_warning=False)

Using default location ~/mne_data for PHYSIONET_SLEEP...
Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4201E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


0,1
Measurement date,"May 08, 1989 16:18:00 GMT"
Experimenter,Unknown
Digitized points,Not available
Good channels,7 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,100.00 Hz
Highpass,0.50 Hz
Lowpass,100.00 Hz


In [77]:
raw_test = mne.io.read_raw_edf(test_subject1[0], stim_channel='marker',
                               misc=['rectal'])
annot_test = mne.read_annotations(test_subject1[1])
annot_test.crop(annot_test[1]['onset'] - 30 * 60,
                annot_test[-2]['onset'] + 30 * 60)
raw_test.set_annotations(annot_test, emit_warning=False)
events_test, _ = mne.events_from_annotations(
    raw_test, event_id=annotation_desc_2_event_id, chunk_duration=30.)
epochs_test = mne.Epochs(raw=raw_test, events=events_test, event_id=event_id,
                         tmin=0., tmax=tmax, baseline=None)

print(epochs_test)

Extracting EDF parameters from /root/mne_data/physionet-sleep-data/SC4201E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
Not setting metadata
1022 matching events found
No baseline correction applied
0 projection items activated
<Epochs |  1022 events (good & bad), 0 - 29.99 sec, baseline off, ~12 kB, data not loaded,
 'Sleep stage 1': 39
 'Sleep stage 2': 539
 'Sleep stage 3/4': 4
 'Sleep stage R': 177
 'Sleep stage W': 263>


In [78]:
%%time
pipe = make_pipeline(FunctionTransformer(eeg_power_band, validate=False),
                     RandomForestClassifier(n_estimators=100, random_state=42))

# Train
epochs_train_concat = mne.concatenate_epochs(epochs_train_lst)
y_train = epochs_train_concat.events[:, 2]
pipe.fit(epochs_train_concat, y_train)

# Test
y_pred = pipe.predict(epochs_test)

# Assess the results
y_test = epochs_test.events[:, 2]
acc = accuracy_score(y_test, y_pred)

print("Accuracy score: {}".format(acc))

Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1002 events and 3000 original time points ...
0 bad epochs 

In [79]:
print(confusion_matrix(y_test, y_pred))

[[171   0  10  17  65]
 [ 12   0  12   0  15]
 [ 11   0 411  31  86]
 [  0   0   2   2   0]
 [118   0  11   5  43]]


In [80]:
print(classification_report(y_test, y_pred, target_names=event_id.keys()))

                 precision    recall  f1-score   support

  Sleep stage W       0.55      0.65      0.59       263
  Sleep stage 1       0.00      0.00      0.00        39
  Sleep stage 2       0.92      0.76      0.83       539
Sleep stage 3/4       0.04      0.50      0.07         4
  Sleep stage R       0.21      0.24      0.22       177

       accuracy                           0.61      1022
      macro avg       0.34      0.43      0.34      1022
   weighted avg       0.66      0.61      0.63      1022



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
