<a href="https://colab.research.google.com/github/samymessal/EEG_octo/blob/full_sleep_multi_label_classification/files/multilabel_full_sleep_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Sleep Spindle Study

## Building Model

In this notebook, we build a model to detect the presence of sleep spindles in the entire EEG recording.
        

In [6]:
!pip install mne
!pip install vmdpy
!pip install yasa

Collecting yasa
  Downloading yasa-0.6.3.tar.gz (33.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.8/33.8 MB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting outdated (from yasa)
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collecting antropy (from yasa)
  Downloading antropy-0.1.6.tar.gz (17 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting tensorpac>=0.6.5 (from yasa)
  Downloading tensorpac-0.6.5-py3-none-any.whl (423 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m423.6/423.6 kB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyriemann>=0.2.7 (from yasa)
  Downloading pyriemann-0.5.tar.gz (119 kB)

In [2]:
!git clone -b full_sleep_multi_label_classification https://github.com/samymessal/EEG_octo

Cloning into 'EEG_octo'...
remote: Enumerating objects: 299, done.[K
remote: Counting objects: 100% (101/101), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 299 (delta 44), reused 68 (delta 20), pack-reused 198[K
Receiving objects: 100% (299/299), 510.52 MiB | 17.44 MiB/s, done.
Resolving deltas: 100% (51/51), done.
Updating files: 100% (36/36), done.


In [3]:
import sys
sys.path.append('/content/EEG_octo/files')

import os
cwd = os.getcwd()

print(cwd)
items = os.listdir('/content/EEG_octo/files')

# Print the list of items
for item in items:
    print(item)

/content
.ipynb_checkpoints
utils.py
Starting_kit_ntx_data_challenge_v0_1_Data_Exploration.ipynb
preprocess.py
feature_extraction.py
data_loading.py
multilabel_full_sleep_classification.ipynb
data_preparation.py
__pycache__



## Imports

We will import the necessary libraries that are needed for processing the data, building the model, and evaluating its performance.
        

In [32]:
import mne
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from keras.callbacks import EarlyStopping
from sklearn.model_selection import KFold
import json
import data_preparation
import preprocess
import keras
import tensorflow as tf
from tensorflow.keras import backend as K
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tensorflow.keras.callbacks import Callback
import json
from tensorflow.keras.metrics import Metric
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Dropout, LSTM, Dense, BatchNormalization, Flatten, LayerNormalization
import tensorflow.keras.layers
from tensorflow.keras.models import Sequential
from tensorflow.keras import regularizers
from keras.utils import timeseries_dataset_from_array
from scipy.io import loadmat
from scipy.signal import detrend
import yasa
from scipy.signal import welch
from tensorflow.keras.layers import Conv1D, BatchNormalization, Dropout, Flatten, Dense, Input, concatenate, Lambda
from tensorflow.keras import Model, regularizers


DEFAULT_DIVIDER = 10000000

### Download data

Using the `processed_data` function from the previous step to download our concatenated raw with its correspondent preprocessing and features.

In [52]:
def load_eeg_data(mat_file_path):
    # Load the .mat file using scipy
    mat = loadmat(mat_file_path)
    # Extract EEG data
    return mat['EEG'][0, 0]['data']

def mk_raw_obj(eeg_data, sfreq=250):
    info = mne.create_info(
        ch_names=[f'EEG{i}' for i in range(len(eeg_data))],
        sfreq=sfreq,
        ch_types=['eeg' for _ in range(len(eeg_data))]
    )

    return mne.io.RawArray(eeg_data, info)

def load_data(file_path, labels_path):
    raw_mat = load_eeg_data(file_path)
    raw = mk_raw_obj(raw_mat)
    raw_data = raw.get_data()
    labels = pd.read_csv(labels_path)
    labels.sort_values("Timestamp", inplace=True)
    return raw, raw_data, labels

def preprocess_recording_data(recording_data, frequency_band=None, sampling_freq=250):
    # Detrending
    recording_data = detrend(recording_data)
    # band pass filtering
    raw_obj = mk_raw_obj(recording_data, sfreq=sampling_freq)
    bp_filter_raw_obj = raw_obj.filter(frequency_band[0], frequency_band[1], verbose=0)
    recording_data = bp_filter_raw_obj.get_data()

    return recording_data

def hypnogram_propas(recording_data, sampling_freq=250):
    """
    Computes the propabilites of the each sleep stages at each 30s epoch.
    Then, upsamples the probabilites to match the shape of the recording.
    ### Parameters:
    recording_data: ndarray of the recording
    ### Returns:
    Tuple of shape four, each item is a 1D array of the probability of a sleep stage at a given timestamp.
    Four for the four sleep stages: awake, REM, NREM1, NREM2, NREM3
    """
    # For some reason, yasa doesn't work properly with the unscaled data.
    scalled_raw_obj = mk_raw_obj(recording_data / DEFAULT_DIVIDER, sfreq=sampling_freq)
    sls = yasa.SleepStaging(scalled_raw_obj, eeg_name="EEG0")
    hypno_proba = sls.predict_proba()
    return [yasa.hypno_upsample_to_data(hypno_proba[column], 1/30, scalled_raw_obj, verbose=False) for column in hypno_proba.columns]

def band_psd_ratio(recording_data, band1, band2, window_size, sfreq=250):
    recording_data = recording_data.squeeze()
    num_windows = len(recording_data) - window_size + 1
    print(num_windows)
    print(len(recording_data))
    print("recording_data.shape:", recording_data.shape)
    print("recording_data.squeeze.shape:", recording_data.squeeze().shape)
    print(window_size)
    power_ratios = np.empty((num_windows))

    for i in range(num_windows):
        f, psd = welch(recording_data[i:i+window_size].squeeze(), sfreq, nperseg=int(sfreq * 2))
        # Calculate power in the designated frequency bands
        band1_power = psd[(f >= band1[0]) & (f <= band1[1])].mean()
        band2_power = psd[(f >= band2[0]) & (f <= band2[1])].mean()
        ratio = band1_power / band2_power
        power_ratios[i] = ratio

    power_ratios = np.pad(
        power_ratios,
        pad_width=(0, len(recording_data) - num_windows),
        mode='constant',
        constant_values=(power_ratios[0], power_ratios[-1]))
    return power_ratios


def dataset_from_files(
        recording_files,
        labels_files=None,
        target_label=None,
        sampling_freq=250,
        frequency_band=None,
        include_hypno_proba=True,
        window_size_in_seconds=None,
        band1=None,
        band2=None,
        shuffle=False
        ):
    """
    Loads and preprocesses the EEG recordings.
    Returns a dataset keras obj.

    ### Parameters:
    recording_files: List of tuples of (.mat single channel eeg_recording
    labels_files: .csv recording labels
    sampling_freq: sampling frequency of the recording.
    target_label: target label
    frequency_band: tuple (min frequency, max frequency), if not None, used to band pass filter the recordings

    ### Returns:
    Timeseries dataset keras obj
    """
    window_size = int(window_size_in_seconds * sampling_freq)
    time_series = []
    for recording_file in recording_files:
        recording_data = load_eeg_data(recording_file)
        preprocessed_recording_data = preprocess_recording_data(recording_data, frequency_band=frequency_band)
        hypno_propas = hypnogram_propas(recording_data, sampling_freq=sampling_freq) if include_hypno_proba else ()
        # psd_ratio = band_psd_ratio(recording_data, band1, band2, window_size) if window_size is not None else ()
        print(preprocessed_recording_data.shape)
        print(hypno_propas[0].shape)
        time_serie = np.column_stack((
            preprocessed_recording_data.squeeze(),
            *hypno_propas,
            # *psd_ratio
            ))
        print("time_serie.shape:", time_serie.shape)
        time_series.append(time_serie)
    concat_time_serie = np.concatenate(time_series)
    print("concat_time_serie.shape:", concat_time_serie.shape)

    if labels_files is not None:
        assert target_label is not None, "labels_files was set but not target_label."
        target_arrays = []
        for time_serie, labels_file in zip(time_series, labels_files):
            labels_df = pd.read_csv(labels_file)
            presence_incdices = labels_df[labels_df[target_label] == 1]['Timestamp']
            target_array = np.zeros(time_serie.shape[0])
            target_array[presence_incdices] = 1
            target_arrays.append(target_array)

        concat_target_array = np.concatenate(target_arrays)
        print(concat_target_array[:10])
    else:
        concat_target_array = None



    return timeseries_dataset_from_array(concat_time_serie, concat_target_array, window_size, shuffle=shuffle)



dataset = dataset_from_files(
    ["/content/EEG_octo/dataset/train_S002_night1_hackathon_raw.mat",
    "/content/EEG_octo/dataset/train_S003_night5_hackathon_raw.mat"
    ],
    ["/content/EEG_octo/dataset/train_S002_labeled.csv",
    "/content/EEG_octo/dataset/train_S003_labeled.csv"
    ],
    target_label="SS1",
    frequency_band=(8, 16),
    window_size_in_seconds=2.5,
    band1=(8, 13),
    band2= (13, 16),
)

Creating RawArray with float64 data, n_channels=1, n_times=4965399
    Range : 0 ... 4965398 =      0.000 ... 19861.592 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=4965399
    Range : 0 ... 4965398 =      0.000 ... 19861.592 secs
Ready.
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).




(1, 4965399)
(4965399,)
time_serie.shape: (4965399, 6)
Creating RawArray with float64 data, n_channels=1, n_times=5772730
    Range : 0 ... 5772729 =      0.000 ... 23090.916 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=5772730
    Range : 0 ... 5772729 =      0.000 ... 23090.916 secs
Ready.
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).




(1, 5772730)
(5772730,)
time_serie.shape: (5772730, 6)
concat_time_serie.shape: (10738129, 6)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]



#### Model

The chosen model is an LSTM, since we are dealing with timeframes, LSTM are known to deal well with time depending samples. A k-cross validation is implemented, partitioning the data into 5 parts and alterning between the 4 parts for training and the 1 for testing.
        

In [25]:
class F1Score(Metric):
    def __init__(self, name='f1_score', **kwargs):
        super(F1Score, self).__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()
        self.f1_score = self.add_weight(name='f1', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)
        p = self.precision.result()
        r = self.recall.result()
        self.f1_score.assign(2 * ((p * r) / (p + r + tf.keras.backend.epsilon())))

    def result(self):
        return self.f1_score

    def reset_states(self):
        self.precision.reset_states()
        self.recall.reset_states()
        self.f1_score.assign(0)

In [53]:
window_size = int(2.5 * 250)

def create_model():
    input_layer = keras.Input(shape=(window_size, 6))

    # Input layer for the EEG time series
    input_eeg = Lambda(lambda y: y[:, :, 0:2])(input_layer)
    print(tf.shape(input_layer, out_type=None, name=None))
    print(tf.shape(input_eeg, out_type=None, name=None))

    # Layer normalization for EEG
    norm_eeg = LayerNormalization()(input_eeg)

    x = Conv1D(
        filters=32, kernel_size=3, strides=1, activation="relu", padding="same"
    )(norm_eeg)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    x = Conv1D(
        filters=64, kernel_size=3, strides=1, activation="relu", padding="same"
    )(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    x = Conv1D(
        filters=128, kernel_size=5, strides=1, activation="relu", padding="same"
    )(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    # Now you can flatten the output if you haven't applied global pooling before
    x = Flatten()(x)

    # Input layer for the other features
    first_elements = Lambda(lambda y: y[:, 0, 2:])(input_layer)
    # Concatenate the CNN output and the first elements of the other features
    concatenated = concatenate([x, first_elements])

    x = Dense(
        2048, activation="relu",
        kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
        bias_regularizer=regularizers.L2(1e-4),
    )(concatenated)
    x = Dropout(0.2)(x)

    x = Dense(
        1024,
        activation="relu",
        kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
        bias_regularizer=regularizers.L2(1e-4),
    )(x)
    x = Dropout(0.2)(x)
    x = Dense(
        128,
        activation="relu",
        kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
        bias_regularizer=regularizers.L2(1e-4),
    )(x)
    output_layer = Dense(1, activation="sigmoid")(x)

    return Model(inputs=input_layer, outputs=output_layer)


In [None]:
# kfold = KFold(n_splits=3, shuffle=True)
# for fold_no, (train, test) in enumerate(kfold.split(X, labels)):
#     print("train indices:", train.shape)
#     print("test indices:", test.shape)
#     # Define the model architecture
#     model = create_model()

#     # Compile the model
#     model.compile(
#         optimizer=keras.optimizers.Adam(),
#         loss="binary_crossentropy",
#         metrics=[
#             'accuracy',
#             tf.keras.metrics.Precision(),
#             tf.keras.metrics.Recall(),
#             F1Score(),
#         ]
#     )

#     # Train the model
#     history = model.fit(
#         X[train],
#         labels[train],
#         epochs=30,
#         validation_data=(X[test], labels[test]),
#     )


#     training_f1_scores = history.history['f1_score']
#     validation_f1_scores = history.history['val_f1_score']

#     plt.plot(training_f1_scores, label='Training F1 Score')
#     plt.plot(validation_f1_scores, label='Validation F1 Score')
#     plt.xlabel('Epochs')
#     plt.ylabel('F1 Score')
#     plt.legend()
#     plt.show()

for x, y in dataset.take(1):
    print("Shape of x:", x.shape)
    print("Shape of y:", y.shape)


# Define the model architecture
model = create_model()

# Compile the model
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss="binary_crossentropy",
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
        F1Score(),
    ]
)

# Train the model
history = model.fit(
    dataset,
    epochs=30
)


training_f1_scores = history.history['f1_score']
validation_f1_scores = history.history['val_f1_score']

plt.plot(training_f1_scores, label='Training F1 Score')
plt.plot(validation_f1_scores, label='Validation F1 Score')
plt.xlabel('Epochs')
plt.ylabel('F1 Score')
plt.legend()
plt.show()

Shape of x: (128, 625, 6)
Shape of y: (128,)
KerasTensor(type_spec=TensorSpec(shape=(3,), dtype=tf.int32, name=None), inferred_value=[None, 625, 6], name='tf.compat.v1.shape_6/Shape:0', description="created by layer 'tf.compat.v1.shape_6'")
KerasTensor(type_spec=TensorSpec(shape=(3,), dtype=tf.int32, name=None), inferred_value=[None, 625, 2], name='tf.compat.v1.shape_7/Shape:0', description="created by layer 'tf.compat.v1.shape_7'")
Epoch 1/30
 4273/83887 [>.............................] - ETA: 3:04:02 - loss: 0.3672 - accuracy: 0.9998 - precision_8: 0.0000e+00 - recall_8: 0.0000e+00 - f1_score: 0.0000e+00