<a href="https://colab.research.google.com/github/samymessal/EEG_octo/blob/full_sleep_multi_label_classification/files/full_sleep_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Sleep Spindle Study

## Building Model

In this notebook, we build a model to detect the presence of sleep spindles in the entire EEG recording.
        

In [1]:
!pip install mne -q
!pip install vmdpy -q
!pip install yasa -q

Collecting mne
  Downloading mne-1.6.0-py3-none-any.whl (8.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.3/8.3 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mne
Successfully installed mne-1.6.0
Collecting vmdpy
  Downloading vmdpy-0.2-py2.py3-none-any.whl (6.5 kB)
Installing collected packages: vmdpy
Successfully installed vmdpy-0.2
Collecting yasa
  Downloading yasa-0.6.3.tar.gz (33.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.8/33.8 MB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting outdated (from yasa)
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collecting antropy (from yasa)
  Downloading antropy-0.1.6.tar.gz (17 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build w

In [2]:
!git clone -b full_sleep_multi_label_classification https://github.com/samymessal/EEG_octo
import sys
sys.path.append('/content/EEG_octo/files')

Cloning into 'EEG_octo'...
remote: Enumerating objects: 307, done.[K
remote: Counting objects: 100% (109/109), done.[K
remote: Compressing objects: 100% (89/89), done.[K
remote: Total 307 (delta 51), reused 66 (delta 20), pack-reused 198[K
Receiving objects: 100% (307/307), 510.53 MiB | 23.17 MiB/s, done.
Resolving deltas: 100% (58/58), done.



## Imports

We will import the necessary libraries that are needed for processing the data, building the model, and evaluating its performance.
        

In [4]:
import mne
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from keras.callbacks import EarlyStopping
from sklearn.model_selection import KFold
import json
import data_preparation
import preprocess
import keras
import tensorflow as tf
from tensorflow.keras import backend as K
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tensorflow.keras.callbacks import Callback
import json
from tensorflow.keras.metrics import Metric
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Dropout, LSTM, Dense, BatchNormalization, Flatten, LayerNormalization
import tensorflow.keras.layers
from tensorflow.keras.models import Sequential
from keras.utils import timeseries_dataset_from_array
from scipy.io import loadmat
from scipy.signal import detrend
import yasa
from scipy.signal import welch
from tensorflow.keras.layers import Conv1D, BatchNormalization, Dropout, Flatten, Dense, Input, concatenate, Lambda
from tensorflow.keras import Model, regularizers
from sklearn.utils.class_weight import compute_class_weight


DEFAULT_DIVIDER = 10000000

### Download data

Using the `processed_data` function from the previous step to download our concatenated raw with its correspondent preprocessing and features.

In [45]:
def load_eeg_data(mat_file_path):
    # Load the .mat file using scipy
    mat = loadmat(mat_file_path)
    # Extract EEG data
    return mat['EEG'][0, 0]['data']

def mk_raw_obj(eeg_data, sfreq=250):
    info = mne.create_info(
        ch_names=[f'EEG{i}' for i in range(len(eeg_data))],
        sfreq=sfreq,
        ch_types=['eeg' for _ in range(len(eeg_data))]
    )

    return mne.io.RawArray(eeg_data, info)

def load_data(file_path, labels_path):
    raw_mat = load_eeg_data(file_path)
    raw = mk_raw_obj(raw_mat)
    raw_data = raw.get_data()
    labels = pd.read_csv(labels_path)
    labels.sort_values("Timestamp", inplace=True)
    return raw, raw_data, labels

def mne_events_from_labels_df(recording_raw_obj: mne.io.Raw, labels_df: pd.DataFrame, target_label: str,):
    presence_mask = labels_df[target_label] == 1
    nb_events = len(labels_df[presence_mask])

    return np.column_stack((labels_df['Timestamp'][presence_mask], np.ones(nb_events), np.ones(nb_events)))


def preprocess_data(recording_data,
                    frequency_band,
                    sampling_freq,
                    resampling_freq,
                    labels_df,
                    target_label,
                    window_size_in_seconds):
    # Detrending
    recording_data = detrend(recording_data)

    # Create mne raw obj
    recording_raw_obj = mk_raw_obj(recording_data)
    events = mne_events_from_labels_df(recording_raw_obj, labels_df, target_label)

    # Resampling
    recording_raw_obj, events = recording_raw_obj.resample(resampling_freq, events=events)
    labels_df = pd.DataFrame(events, columns=["Timestamp", "ignore", target_label])

    # Band pass filtering
    raw_obj = mk_raw_obj(recording_data, sfreq=sampling_freq)
    bp_filter_raw_obj = raw_obj.filter(frequency_band[0], frequency_band[1], verbose=0)
    recording_data = bp_filter_raw_obj.get_data()

    return recording_raw_obj, labels_df

def hypnogram_propas(recording_raw_obj, sampling_freq=250):
    """
    Computes the propabilites of the each sleep stages at each 30s epoch.
    Then, upsamples the probabilites to match the shape of the recording.
    ### Parameters:
    recording_data: ndarray of the recording
    ### Returns:
    Tuple of shape four, each item is a 1D array of the probability of a sleep stage at a given timestamp.
    Four for the four sleep stages: awake, REM, NREM1, NREM2, NREM3
    """
    # For some reason, yasa doesn't work properly with the unscaled data.
    scalled_raw_obj = mk_raw_obj(recording_raw_obj.get_data() / DEFAULT_DIVIDER, sfreq=sampling_freq)
    sls = yasa.SleepStaging(scalled_raw_obj, eeg_name="EEG0")
    hypno_proba = sls.predict_proba()
    return [yasa.hypno_upsample_to_data(hypno_proba[column], 1/30, scalled_raw_obj, verbose=False) for column in hypno_proba.columns]

def dataset_from_files(
        recording_files,
        labels_files,
        target_label,
        resampling_freq,
        sampling_freq,
        frequency_band,
        window_size_in_seconds,
        shuffle=False
        ):
    """
    Loads and preprocesses the EEG recordings.
    Returns a dataset keras obj.

    ### Parameters:
    recording_files: List of tuples of (.mat single channel eeg_recording
    labels_files: .csv recording labels
    sampling_freq: sampling frequency of the recording.
    target_label: target label
    frequency_band: tuple (min frequency, max frequency), if not None, used to band pass filter the recordings

    ### Returns:
    (dataset, window_size)
    dataset: Timeseries dataset keras obj
    window_size: window size in timestamps after resampling
    """
    time_series = []
    target_timeseries = []
    for recording_file, labels_file in zip(recording_files, labels_files):
        # Loading
        labels_df = pd.read_csv(labels_file)
        recording_data = load_eeg_data(recording_file)

        # Preprocessing
        preprocessed_recording_raw_obj, labels_df = preprocess_data(
            recording_data,
            frequency_band,
            sampling_freq,
            resampling_freq,
            labels_df,
            target_label,
            window_size_in_seconds
            )

        # Feature engineering
        hypno_propas = hypnogram_propas(preprocessed_recording_raw_obj, sampling_freq=sampling_freq)

        # Features formatting
        time_serie = np.column_stack((
            preprocessed_recording_raw_obj.get_data().squeeze(),
            *hypno_propas,
            ))
        target_timeserie = np.zeros(time_serie.shape[0])
        for timestamp in labels_df["Timestamp"]:
          target_timeserie[int(timestamp)] = 1
        time_series.append(time_serie)
        target_timeseries.append(target_timeserie)

    concat_timeseries = np.concatenate(time_series)
    concat_target_timeseries = np.concatenate(target_timeseries)

    class_weights = compute_class_weight("balanced", classes=np.unique(concat_target_timeseries), y=concat_target_timeseries)
    window_size = int(window_size_in_seconds * resampling_freq)
    dataset = timeseries_dataset_from_array(
        concat_timeseries,
        concat_target_timeseries,
        window_size,
        shuffle=shuffle
        )


    def filter_func(input, label):
      # Replace 'overrepresented_class' with the actual class label
      # Replace 'keep_prob' with the probability of keeping an instance of the overrepresented class
      return tf.cond(tf.equal(label, 0),
                    lambda: tf.less(tf.random.uniform(shape=[], minval=0, maxval=1), 0.5),
                    lambda: tf.constant(True))


    # Apply the filter
    # Unbatch the dataset if it's already batched
    dataset = dataset.unbatch()

    # Apply the filter function
    filtered_dataset = dataset.filter(filter_func)

    # Re-batch the dataset
    batch_size = 128  # Or your desired batch size
    filtered_dataset = filtered_dataset.batch(batch_size)


    return dataset, window_size, class_weights


dataset, window_size, class_weights = dataset_from_files(
    ["/content/EEG_octo/dataset/train_S002_night1_hackathon_raw.mat",
    "/content/EEG_octo/dataset/train_S003_night5_hackathon_raw.mat"
    ],
    ["/content/EEG_octo/dataset/train_S002_labeled.csv",
    "/content/EEG_octo/dataset/train_S003_labeled.csv"
    ],
    sampling_freq=250,
    resampling_freq=20,
    target_label="SS1",
    frequency_band=(8, 16),
    window_size_in_seconds=2.5,
    shuffle=False
)

Creating RawArray with float64 data, n_channels=1, n_times=4965399
    Range : 0 ... 4965398 =      0.000 ... 19861.592 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=4965399
    Range : 0 ... 4965398 =      0.000 ... 19861.592 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=397232
    Range : 0 ... 397231 =      0.000 ...  1588.924 secs
Ready.
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).




Creating RawArray with float64 data, n_channels=1, n_times=5772730
    Range : 0 ... 5772729 =      0.000 ... 23090.916 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=5772730
    Range : 0 ... 5772729 =      0.000 ... 23090.916 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=461818
    Range : 0 ... 461817 =      0.000 ...  1847.268 secs
Ready.
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).





#### Model

The chosen model is an LSTM, since we are dealing with timeframes, LSTM are known to deal well with time depending samples. A k-cross validation is implemented, partitioning the data into 5 parts and alterning between the 4 parts for training and the 1 for testing.
        

In [21]:
class F1Score(Metric):
    def __init__(self, name='f1_score', **kwargs):
        super(F1Score, self).__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()
        self.f1_score = self.add_weight(name='f1', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)
        p = self.precision.result()
        r = self.recall.result()
        self.f1_score.assign(2 * ((p * r) / (p + r + tf.keras.backend.epsilon())))

    def result(self):
        return self.f1_score

    def reset_states(self):
        self.precision.reset_states()
        self.recall.reset_states()
        self.f1_score.assign(0)

In [48]:
def create_model():
    input_layer = keras.Input(shape=(window_size, 6))

    # Input layer for the EEG time series
    input_eeg = Lambda(lambda y: y[:, :, 0:2])(input_layer)

    # Layer normalization for EEG
    norm_eeg = LayerNormalization()(input_eeg)

    x = Conv1D(
        filters=32, kernel_size=3, strides=1, activation="relu", padding="same"
    )(norm_eeg)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    x = Conv1D(
        filters=64, kernel_size=3, strides=1, activation="relu", padding="same"
    )(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    x = Conv1D(
        filters=128, kernel_size=5, strides=1, activation="relu", padding="same"
    )(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    # Now you can flatten the output if you haven't applied global pooling before
    x = Flatten()(x)

    # Input layer for the other features
    first_elements = Lambda(lambda y: y[:, 0, 2:])(input_layer)
    # Concatenate the CNN output and the first elements of the other features
    concatenated = concatenate([x, first_elements])

    x = Dense(
        2048, activation="relu",
        kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
        bias_regularizer=regularizers.L2(1e-4),
    )(concatenated)
    x = Dropout(0.2)(x)

    x = Dense(
        1024,
        activation="relu",
        kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
        bias_regularizer=regularizers.L2(1e-4),
    )(x)
    x = Dropout(0.2)(x)
    x = Dense(
        128,
        activation="relu",
        kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
        bias_regularizer=regularizers.L2(1e-4),
    )(x)
    output_layer = Dense(1, activation="sigmoid")(x)

    return Model(inputs=input_layer, outputs=output_layer)


In [49]:
for x, y in dataset.take(3):
    print("Shape of x:", x.shape)
    print("Shape of y:", y.shape)


# Define the model architecture
model = create_model()

# Compile the model
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss="binary_crossentropy",
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
        F1Score(),
    ]
)

# Train the model
history = model.fit(
    dataset,
    epochs=30,
    class_weight={label: weight for label, weight in enumerate(class_weights)}
)


training_f1_scores = history.history['f1_score']
validation_f1_scores = history.history['val_f1_score']

plt.plot(training_f1_scores, label='Training F1 Score')
plt.plot(validation_f1_scores, label='Validation F1 Score')
plt.xlabel('Epochs')
plt.ylabel('F1 Score')
plt.legend()
plt.show()

Shape of x: (50, 6)
Shape of y: ()
Shape of x: (50, 6)
Shape of y: ()
Shape of x: (50, 6)
Shape of y: ()
Epoch 1/30


ValueError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1401, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1384, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1373, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1150, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/input_spec.py", line 298, in assert_input_compatibility
        raise ValueError(

    ValueError: Input 0 of layer "model_10" is incompatible with the layer: expected shape=(None, 50, 6), found shape=(None, 6)
