In [1]:
import glob
import mne
from mne.time_frequency import psd_welch
import os
import numpy as np
import findspark
findspark.init()

import pyspark

from pyspark.sql import SparkSession
from pyspark.sql import SQLContext


spark = SparkSession \
    .builder \
    .master('local[*]') \
    .config("spark.driver.memory", "2g") \
    .appName("sleepdata") \
    .getOrCreate()
sc = spark.sparkContext

def eeg_power_band(epochs):
    """EEG relative power band feature extraction.

    This function takes an ``mne.Epochs`` object and creates EEG features based
    on relative power in specific frequency bands that are compatible with
    scikit-learn.

    Parameters
    ----------
    epochs : Epochs
        The data.

    Returns
    -------
    X : numpy array of shape [n_samples, 5]
        Transformed data.
    """
    # specific frequency bands
    FREQ_BANDS = {"delta": [0.5, 4.5],
                  "theta": [4.5, 8.5],
                  "alpha": [8.5, 11.5],
                  "sigma": [11.5, 15.5],
                  "beta": [15.5, 30]}

    psds, freqs = psd_welch(epochs, picks='eeg', fmin=0.5, fmax=30.)
    # Normalize the PSDs
    psds /= np.sum(psds, axis=-1, keepdims=True)

    X = []
    for fmin, fmax in FREQ_BANDS.values():
        psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=1)

os.chdir('/media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/')
files = glob.glob('*.edf')
files = np.array(files)

# taken from raw.info in the loop below
cassette_mapping = {
    'EOG horizontal': 'eog',
    'Resp oro-nasal': 'misc',
    'EMG submental': 'misc',
    'Temp rectal': 'misc',
    'Event marker': 'misc'}

telemetry_mapping = {
    'EOG horizontal': 'eog',
    'EMG submental': 'misc',
    'Marker': 'misc'}

annotation_desc_2_event_id = {'Sleep stage W': 1,
                              'Sleep stage 1': 2,
                              'Sleep stage 2': 3,
                              'Sleep stage 3': 4,
                              'Sleep stage 4': 4,
                              'Sleep stage R': 5}

event_ids = {
    'Sleep stage W': 1,
    'Sleep stage 1': 2,
    'Sleep stage 2': 3,
    'Sleep stage 3/4': 4,
    'Sleep stage R': 5}

patient_ids = []
rdd = None
for patient_files in files:
    if patient_files[:7] not in patient_ids:
        patient_ids.append(patient_files[:7])
for patient_id in patient_ids:
    files = glob.glob(f'{patient_id}*')
    for file in files:
        if 'PSG' in file:
            raw = mne.io.read_raw_edf(file)
            # print(raw.info)
        else:
            annot = mne.read_annotations(file)
    
    raw.set_annotations(annot, emit_warning=False)
    raw.set_channel_types(telemetry_mapping)
    # raw.plot(duration=60, scalings='auto')
    #print(patient_id, psg, hypnogram)
    events, _ = mne.events_from_annotations(
        raw, event_id=annotation_desc_2_event_id, chunk_duration=30.)

    tmax = 30. - 1. / raw.info['sfreq'] 
    
    epochs = mne.Epochs(raw=raw, events=events,
                          event_id=event_ids, tmin=0., tmax=tmax, baseline=None, on_missing='ignore')
    
    epochs.drop_bad()
    
    y = epochs.events[:, 2]
    bands = eeg_power_band(epochs)
    
    # remove nan's from eeg_power_band transform
    nan_idx = np.argwhere([np.any(np.isnan(x)) for x in bands]).flatten()
    for idx in np.flip(nan_idx):
        bands  = np.delete(bands, idx, axis=0)
        y = np.delete(y, idx)
    # print(bands.shape, y.shape, bands, y)
    
    # combine inputs and labels into the rdd
    y = y[:, None]
    print(bands.shape, y.shape)
    bands_y = zip(bands, y)
    if rdd is None:
        rdd = sc.parallelize(bands_y)
    else:
        bands_y = sc.parallelize(bands_y)
        rdd = rdd.union(bands_y)

results = rdd.collect()
X = np.array([row[0] for row in results])
y = np.array([row[1][0] for row in results])

Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7011J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
Not setting metadata
1092 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1092 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1092 events and 3000 original time points ...
Effective window size : 2.560 (s)
(1092, 10) (1092, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7012J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setti



0 bad epochs dropped
Loading data for 1034 events and 3000 original time points ...
Effective window size : 2.560 (s)
(1032, 10) (1032, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7061J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
Not setting metadata
1008 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1008 events and 3000 original time points ...




0 bad epochs dropped
Loading data for 1008 events and 3000 original time points ...
Effective window size : 2.560 (s)
(1008, 10) (1008, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7062J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
Not setting metadata
952 matching events found
No baseline correction applied
0 projection items activated
Loading data for 952 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 952 events and 3000 original time points ...
Effective window size : 2.560 (s)
(952, 10) (952, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7071J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descripti



0 bad epochs dropped
Loading data for 1045 events and 3000 original time points ...
Effective window size : 2.560 (s)
(1045, 10) (1045, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7161J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
Not setting metadata
1057 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1057 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 1057 events and 3000 original time points ...
Effective window size : 2.560 (s)
(1057, 10) (1057, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7162J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used



0 bad epochs dropped
Loading data for 959 events and 3000 original time points ...
Effective window size : 2.560 (s)
(959, 10) (959, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7182J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
Not setting metadata
967 matching events found
No baseline correction applied
0 projection items activated
Loading data for 967 events and 3000 original time points ...
0 bad epochs dropped
Loading data for 967 events and 3000 original time points ...
Effective window size : 2.560 (s)
(967, 10) (967, 1)
Extracting EDF parameters from /media/data/sleep-edf-database-expanded-1.0.0/sleep-telemetry/ST7191J0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotat

In [2]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
model = RandomForestClassifier(n_estimators=10, random_state=42)
model.fit(X, y)
y_pred = model.predict(X)
print(confusion_matrix(y, y_pred))

[[ 4396    11     0     0     2]
 [   41  3531    30     0    51]
 [    8    14 19772    30    27]
 [    3     0   152  6258     2]
 [    5    23   174     3  8144]]
