
# Sleep Spindle Study

## Building Model

In this notebook, we build a model to detect the presence of sleep spindles in the entire EEG recording. 
        


## Imports

We import the libraries needed to process the data, build the model, and evaluate its performance.
        

In [1]:

import mne
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from keras.callbacks import EarlyStopping
from sklearn.model_selection import KFold
import json
import utils
import feature_extraction
import data_preparation
        

2023-12-25 19:38:10.113809: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-12-25 19:38:15.343750: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-25 19:38:15.343913: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-25 19:38:15.683605: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-25 19:38:17.213635: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-12-25 19:38:17.254766: I tensorflow/core/platform/cpu_feature_guard.cc:1

### Load data

We use the `processed_data` function from the previous step to load the concatenated raw recordings together with their corresponding preprocessing and features.

In [2]:
# NOTE(review): the saved output of this cell ends in
# "AttributeError: type object 'Annotations' has no attribute 'concatenate'",
# raised from inside data_preparation.processed_data (presumably an mne
# API/version mismatch — confirm). Every later cell needs X and labels, so this
# must be fixed before the notebook survives Restart & Run All.
RAW_RECORDINGS = [
    "../dataset/train_S002_night1_hackathon_raw.mat",
    "../dataset/train_S003_night5_hackathon_raw.mat",
]
LABEL_FILES = [
    "../dataset/train_S002_labeled.csv",
    "../dataset/train_S003_labeled.csv",
]

# fmin/fmax are presumably the band-pass limits in Hz used during
# preprocessing — confirm in data_preparation.processed_data.
X, labels = data_preparation.processed_data(
    RAW_RECORDINGS,
    LABEL_FILES,
    labels=["SS0", "SS1", "K0", "K1"],
    fmin=11,
    fmax=15,
    include_entire_recording=True,
)
        

Creating RawArray with float64 data, n_channels=1, n_times=4965399
    Range : 0 ... 4965398 =      0.000 ... 19861.592 secs
Ready.


AttributeError: type object 'Annotations' has no attribute 'concatenate'


#### Model

We chose an LSTM because the samples are time windows, and LSTMs are well suited to time-dependent data. We use 5-fold cross-validation: the data are split into 5 parts, and each part in turn serves as the test set while the remaining 4 are used for training.
        

In [None]:
import preprocess
import keras
import tensorflow as tf
from tensorflow.keras import backend as K



def weighted_binary_crossentropy(y_true, y_pred, weights, debug=False):
    """
    Custom weighted binary cross-entropy loss function.

    Computes the element-wise binary cross-entropy between ``y_true`` and
    ``y_pred``. For each sample, it scales the BCE by entries of ``weights``
    gathered along axis 0, using the sample's true label values (cast to int32)
    as indices.

    Args:
        y_true: Ground-truth labels. Each value is cast to int32 and used as an
            index into ``weights``, so values are expected to be 0 or 1.
        y_pred: Predicted probabilities, same shape as ``y_true``.
        weights: Array-like of loss weights, indexed along axis 0 by label value.
            NOTE(review): the exact shape comes from
            ``preprocess.compute_multi_label_loss_weights`` — confirm it matches
            this indexing.
        debug: If True, print the full unweighted BCE tensor on every call.
            Defaults to False. Printing unconditionally floods the training log
            on every batch.

    Returns:
        The mean of the weighted BCE over the last axis (one value per sample).
    """
    # Work in float32 throughout so the arithmetic below has a single dtype.
    weights = tf.cast(weights, dtype=tf.float32)
    y_true = tf.cast(y_true, dtype=tf.float32)
    y_pred = tf.cast(y_pred, dtype=tf.float32)

    # Clip predictions so log() never receives exactly 0 or 1.
    y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())

    # Element-wise binary cross-entropy.
    bce = -y_true * K.log(y_pred) - (1 - y_true) * K.log(1 - y_pred)

    if debug:
        # summarize=-1 prints every value instead of a truncated summary.
        tf.print("BCE: ", bce, summarize=-1)

    def apply_weights(args):
        """Scale one sample's BCE by the weights selected by its true labels."""
        y_true_slice, bce_slice = args
        return bce_slice * tf.gather(weights, tf.cast(y_true_slice, tf.int32), axis=0)

    # `fn_output_signature` replaces the deprecated `dtype` argument of tf.map_fn.
    weighted_bce = tf.map_fn(apply_weights, (y_true, bce), fn_output_signature=tf.float32)

    return K.mean(weighted_bce, axis=-1)


# NOTE(review): X and labels come from the data-loading cell, whose saved output
# shows it raised AttributeError, so this cell cannot run on a fresh kernel until
# that is fixed.
# NOTE(review): processed_data is called with four label names
# (SS0, SS1, K0, K1), but the model below has 2 outputs — confirm how the labels
# are encoded.
N_SPLITS = 5
N_EPOCHS = 30
SEED = 42


def build_lstm_model(input_shape):
    """Build the stacked-LSTM classifier used in every fold.

    Args:
        input_shape: ``(timesteps, features)`` of one sample, taken from
            ``X.shape[1]`` and ``X.shape[2]``.

    Returns:
        An uncompiled ``Sequential`` model ending in 2 sigmoid units
        (independent per-output probabilities).
    """
    net = Sequential()
    net.add(LSTM(50, input_shape=input_shape, return_sequences=True))
    net.add(LSTM(50, return_sequences=True))
    net.add(Dropout(0.4))
    net.add(LSTM(20, return_sequences=True))
    net.add(Dropout(0.3))
    net.add(LSTM(20))
    net.add(Dropout(0.2))
    net.add(Dense(2, activation='sigmoid'))
    return net


def make_weighted_loss(weights):
    """Return a Keras loss ``(y_true, y_pred)`` bound to one fold's weights.

    The weights are passed as a factory argument rather than read from the
    loop's global ``weights``. This keeps each fold's loss tied to its own
    weights even if that global is reassigned later.
    """
    def custom_loss(y_true, y_pred):
        return weighted_binary_crossentropy(y_true, y_pred, weights)
    return custom_loss


# Seed Python, NumPy and TensorFlow so weight init and dropout are reproducible.
tf.keras.utils.set_random_seed(SEED)

kfold = KFold(n_splits=N_SPLITS)
for fold_no, (train, test) in enumerate(kfold.split(X)):
    model = build_lstm_model((X.shape[1], X.shape[2]))

    # Loss weights are computed from the training split only, so nothing leaks
    # in from the test fold.
    weights = preprocess.compute_multi_label_loss_weights(labels[train])
    custom_loss = make_weighted_loss(weights)

    # With 2 sigmoid outputs and a binary cross-entropy loss, the string
    # 'accuracy' would be resolved by Keras (from shapes) to categorical
    # accuracy. BinaryAccuracy is the matching metric; name='accuracy' keeps
    # the history key unchanged for downstream plotting.
    model.compile(
        optimizer="adam",
        loss=custom_loss,
        metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy")],
    )

    history = model.fit(X[train], labels[train], epochs=N_EPOCHS)

    perf_metrics = utils.evaluate_model(model, X[test], labels[test])
    utils.save_model(model, history, perf_metrics, fold_no)

### Visualize plots and metrics

Determine the performance of the model.

Plot the accuracy and loss for training and validation.

In [None]:
import os

# List the saved metric files. os.listdir returns entries in arbitrary order,
# so sort them to keep this output reproducible between runs.
print(sorted(os.listdir("./ressources/models/metrics")))

# Saved-model name prefixes to plot; each is passed to utils.plot_fold_history.
filenames = [
    "SS_bp4_35Pre_0Features_LSTM_",
    "SS_0Pre_0Features_LSTM_",
    "SS_detrend_Pre_0Features_LSTM_",
    "SS_bp11_15Pre_0Features_LSTM_",
    "SS_VDM1_3Pre_0Features_LSTM_",
]

# Presumably the number of cross-validation folds (training uses
# KFold(n_splits=5)) — confirm against utils.plot_fold_history.
N_FOLDS = 5
for filename in filenames:
    utils.plot_fold_history(filename, N_FOLDS)

The performance of each fold is printed, along with the average performance across the cross-validation folds.

In [None]:
# NOTE(review): the plotting cell passes 5 as the second argument to
# utils.plot_fold_history for this same model, but 1 is passed here. If this
# argument is the number of folds, only one fold is reported, which contradicts
# the "each fold ... average" description above — confirm.
performance = utils.print_performances(
    "SS_bp4_35Pre_0Features_LSTM_",
    1,
)
print(performance)