In [1]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc
from sklearn.utils import class_weight
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

import scipy.signal as signal
from scipy.signal import hilbert
from scipy.signal import stft

import pywt
import cv2

In [2]:
%matplotlib qt

### CNN one input Models

In [24]:
def build_cnn_model(input_shape=(500,1)):

    # linear embedding layer
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Three convolutional blocks (like having three pattern detectors)

    # First convolution block, kernel size of 5
    padded1 = tf.keras.layers.ZeroPadding1D(padding=2)(input_layer)
    conv1 = tf.keras.layers.Conv1D(filters=10, kernel_size=5, strides=1, padding='valid')(padded1)
    # each filter learns a different type of short-time feature
    # stride of 1, moves one step at a time
    conv1 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv1)
    conv1 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv1)
    # K = 2
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # Second convolution block, kernel size of 11
    padded2 = tf.keras.layers.ZeroPadding1D(padding=5)(input_layer)
    conv2 = tf.keras.layers.Conv1D(filters=10, kernel_size=11, strides=1, padding='valid')(padded2)
    conv2 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv2)
    conv2 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv2)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # Third convolution block, kernel size of 21
    padded3 = tf.keras.layers.ZeroPadding1D(padding=10)(input_layer)
    conv3 = tf.keras.layers.Conv1D(filters=10, kernel_size=21, strides=1, padding='valid')(padded3)
    conv3 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv3)
    conv3 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv3)
    conv3 = tf.keras.layers.BatchNormalization()(conv3)

    # Concatenate the outputs of all blocks
    concatenated = tf.keras.layers.Concatenate()([conv1, conv2, conv3])

    # GRU Layer
    gru = tf.keras.layers.GRU(64)(concatenated)

    # Fully connected (dense) layer
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)
    # add a Dropout layer to prevent overfitting
    #dense = tf.keras.layers.Dropout(0.5)(dense)

    # Two softmax outputs for dual-task classification
    #output_task1 = tf.keras.layers.Dense(2, activation='softmax', name='task1')(dense)
    #output_task2 = tf.keras.layers.Dense(2, activation='softmax', name='task2')(dense)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    # Create the model
    #model = tf.keras.models.Model(inputs=input_layer, outputs=[output_task1, output_task2])
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)

    # Compile the model
    #model.compile(optimizer='adam', loss={'task1': 'categorical_crossentropy', 'task2': 'categorical_crossentropy'}, metrics={'task1': 'accuracy', 'task2': 'accuracy'})
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Return the compiled model
    return model

# show model architecture
input_shape = (500, 1)
model = build_cnn_model(input_shape)
model.summary()



In [29]:
def build_cnn_model_downsampled(input_shape=(300,1)):

    # linear embedding layer
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Three convolutional blocks (like having three pattern detectors)

    # First convolution block, kernel size of 5
    padded1 = tf.keras.layers.ZeroPadding1D(padding=2)(input_layer)
    conv1 = tf.keras.layers.Conv1D(filters=10, kernel_size=5, strides=1, padding='valid')(padded1)
    # each filter learns a different type of short-time feature
    # stride of 1, moves one step at a time
    conv1 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv1)
    conv1 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv1)
    # K = 2
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # Second convolution block, kernel size of 11
    padded2 = tf.keras.layers.ZeroPadding1D(padding=5)(input_layer)
    conv2 = tf.keras.layers.Conv1D(filters=10, kernel_size=11, strides=1, padding='valid')(padded2)
    conv2 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv2)
    conv2 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv2)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # Third convolution block, kernel size of 21
    padded3 = tf.keras.layers.ZeroPadding1D(padding=10)(input_layer)
    conv3 = tf.keras.layers.Conv1D(filters=10, kernel_size=21, strides=1, padding='valid')(padded3)
    conv3 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv3)
    conv3 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv3)
    conv3 = tf.keras.layers.BatchNormalization()(conv3)

    # Concatenate the outputs of all blocks
    concatenated = tf.keras.layers.Concatenate()([conv1, conv2, conv3])

    # GRU Layer
    gru = tf.keras.layers.GRU(64)(concatenated)

    # Fully connected (dense) layer
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)
    # add a Dropout layer to prevent overfitting
    #dense = tf.keras.layers.Dropout(0.5)(dense)

    # Two softmax outputs for dual-task classification
    #output_task1 = tf.keras.layers.Dense(2, activation='softmax', name='task1')(dense)
    #output_task2 = tf.keras.layers.Dense(2, activation='softmax', name='task2')(dense)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    # Create the model
    #model = tf.keras.models.Model(inputs=input_layer, outputs=[output_task1, output_task2])
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)

    # Compile the model
    #model.compile(optimizer='adam', loss={'task1': 'categorical_crossentropy', 'task2': 'categorical_crossentropy'}, metrics={'task1': 'accuracy', 'task2': 'accuracy'})
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Return the compiled model
    return model

In [30]:
def build_2d_cnn_model(input_shape=(65, 25, 1)):
    # Input: STFT spectrogram (freq_bins x time_bins x 1)
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # First convolution block, kernel size (3x3)
    padded1 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(input_layer)
    conv1 = tf.keras.layers.Conv2D(filters=10, kernel_size=(3, 3), strides=(1, 1), padding='valid')(padded1)
    conv1 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv1)
    conv1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # Second convolution block, kernel size (5x5)
    padded2 = tf.keras.layers.ZeroPadding2D(padding=(2, 2))(input_layer)
    conv2 = tf.keras.layers.Conv2D(filters=10, kernel_size=(5, 5), strides=(1, 1), padding='valid')(padded2)
    conv2 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv2)
    conv2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # Third convolution block, kernel size (7x7)
    padded3 = tf.keras.layers.ZeroPadding2D(padding=(3, 3))(input_layer)
    conv3 = tf.keras.layers.Conv2D(filters=10, kernel_size=(7, 7), strides=(1, 1), padding='valid')(padded3)
    conv3 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv3)
    conv3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    conv3 = tf.keras.layers.BatchNormalization()(conv3)

    # Concatenate all three conv outputs
    concatenated = tf.keras.layers.Concatenate(axis=-1)([conv1, conv2, conv3])

    # Flatten before passing to RNN
    reshaped = tf.keras.layers.Reshape((-1, concatenated.shape[-1]))(concatenated)

    # GRU layer for temporal modeling over time slices
    gru = tf.keras.layers.GRU(64)(reshaped)

    # Fully connected dense layer
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)

    # Output layer for binary classification
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    # Define and compile model
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

# show model architecture
input_shape = (65, 25, 1)
cnn_2d_model = build_2d_cnn_model(input_shape)
cnn_2d_model.summary()




In [25]:
def build_2d_cnn_model_downsampled(input_shape=(26, 13, 1)):
    # Input: STFT spectrogram (freq_bins x time_bins x 1)
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # First convolution block, kernel size (3x3)
    padded1 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(input_layer)
    conv1 = tf.keras.layers.Conv2D(filters=10, kernel_size=(3, 3), strides=(1, 1), padding='valid')(padded1)
    conv1 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv1)
    conv1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # Second convolution block, kernel size (5x5)
    padded2 = tf.keras.layers.ZeroPadding2D(padding=(2, 2))(input_layer)
    conv2 = tf.keras.layers.Conv2D(filters=10, kernel_size=(5, 5), strides=(1, 1), padding='valid')(padded2)
    conv2 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv2)
    conv2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # Third convolution block, kernel size (7x7)
    padded3 = tf.keras.layers.ZeroPadding2D(padding=(3, 3))(input_layer)
    conv3 = tf.keras.layers.Conv2D(filters=10, kernel_size=(7, 7), strides=(1, 1), padding='valid')(padded3)
    conv3 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv3)
    conv3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    conv3 = tf.keras.layers.BatchNormalization()(conv3)

    # Concatenate all three conv outputs
    concatenated = tf.keras.layers.Concatenate(axis=-1)([conv1, conv2, conv3])

    # Flatten before passing to RNN
    reshaped = tf.keras.layers.Reshape((-1, concatenated.shape[-1]))(concatenated)

    # GRU layer for temporal modeling over time slices
    gru = tf.keras.layers.GRU(64)(reshaped)

    # Fully connected dense layer
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)

    # Output layer for binary classification
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    # Define and compile model
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

# show model architecture
input_shape = (26, 13, 1)
cnn_2d_model_downsampled = build_2d_cnn_model_downsampled(input_shape)
cnn_2d_model_downsampled.summary()




### Spindle detection function

In [3]:
def detect_spindles_times(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)

    if do_downsample:
        data.resample(downsample_rate)
    
    sfreq = data.info['sfreq']  
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    # threshold is 75th percentile of the smoothed envelope
    # will look at the duration later
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                # so starting from the second index
                # and comparing each index to the one before
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    spindles.append((start_idx / sfreq, end_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((start_idx / sfreq, end_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)
    
    return spindles
    

def detect_spindles_peaks(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)
    
    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']  
    # update to new sampling frequency
    # because used later in the code
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    # 75th percentile as criteria

    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    #spindles.append(f"Spindle detected from {start_idx / sfreq:.2f}s to {end_idx / sfreq:.2f}s, peak at {peak_idx / sfreq:.2f}s")
                    spindles.append((peak_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((peak_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)

    
    return spindles

### Epochs function

In [4]:
def create_fixed_length_epochs(raw, duration=3.0, overlap=0.0, preload=True, reject_by_annotation=False):

    return mne.make_fixed_length_epochs(
        raw,
        duration=duration,
        overlap=overlap,
        preload=preload,
        reject_by_annotation=reject_by_annotation
    )
# function mne.make_fixed_length_epochs takes into account the sampling frequency of the data

def label_spindle_epochs(epochs, spindle_starts, spindle_ends, epoch_length_sec=3.0):

    epoch_starts = np.arange(len(epochs)) * epoch_length_sec
    # new np array with the start time of each epoch
    # epoch_starts[i] is the start time of each epoch

    epoch_labels = np.zeros(len(epochs), dtype=int)
    # initialize all the labels as 0 initially

    for start, end in zip(spindle_starts, spindle_ends):
        # loop through the start and end times of detected spindles by YASA
        for i, epoch_start in enumerate(epoch_starts):
            # loop through the one-second epochs that are not labelled yet
            epoch_end = epoch_start + epoch_length_sec
            # for each epoch, calculate the epoch end time
            # which is epoch_start + length of epoch
            # so now have the time range of each epoch
            if (start < epoch_end) and (end > epoch_start):
                # if the spindle started before the epoch ends
                # and the spindle ended after the epoch started
                epoch_labels[i] = 1
                
    return epoch_labels

### Importing data

In [5]:
# file paths
train_file = r"C:\EEG DATA\combined_sets\train_raw.fif"
test_file = r"C:\EEG DATA\combined_sets\test_raw.fif"

# load raw files
train_raw = mne.io.read_raw_fif(train_file, preload=True)
test_raw = mne.io.read_raw_fif(test_file, preload=True)

Opening raw data file C:\EEG DATA\combined_sets\train_raw.fif...
Isotrak not found
    Range : 1470000 ... 23295072 =   2940.000 ... 46590.144 secs
Ready.
Reading 0 ... 21825072  =      0.000 ... 43650.144 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_raw.fif...
Isotrak not found
    Range : 825000 ... 17985049 =   1650.000 ... 35970.098 secs
Ready.
Reading 0 ... 17160049  =      0.000 ... 34320.098 secs...


## With raw data

### Spindle detection

In [45]:
spindles_train_times_raw_downsampled = detect_spindles_times(train_raw, do_filter=True, do_downsample=True)
spindles_test_times_raw_downsampled = detect_spindles_times(test_raw, do_filter=True, do_downsample=True)

spindles_starts_train_raw_downsampled, spindles_ends_train_raw_downsampled = zip(*spindles_train_times_raw_downsampled) if spindles_train_times_raw_downsampled else([],[])
spindles_starts_test_raw_downsampled, spindles_ends_test_raw_downsampled = zip(*spindles_test_times_raw_downsampled) if spindles_test_times_raw_downsampled else([],[])

print(len(spindles_starts_train_raw_downsampled))
print(len(spindles_ends_train_raw_downsampled))

print(len(spindles_starts_test_raw_downsampled))
print(len(spindles_ends_test_raw_downsampled))

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pass

### Downsample

In [46]:
train_raw_downsampled = train_raw.copy().resample(100)
test_raw_downsampled = test_raw.copy().resample(100)

print(train_raw_downsampled.info['sfreq'])
print(test_raw_downsampled.info['sfreq'])

100.0
100.0


### Epoch the data

In [47]:
epochs_train_raw_downsampled = create_fixed_length_epochs(train_raw_downsampled)
epochs_test_raw_downsampled = create_fixed_length_epochs(test_raw_downsampled)
print(epochs_train_raw_downsampled.get_data().shape)
print(epochs_test_raw_downsampled.get_data().shape)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped
Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped
(14550, 1, 300)
(11440, 1, 300)


### Labels for 3-second epochs

In [48]:
%%time

# Train set

epoch_labels_train_raw_downsampled = label_spindle_epochs(epochs_train_raw_downsampled, spindles_starts_train_raw_downsampled, spindles_ends_train_raw_downsampled)

print(f"Train data first 10 spindles: {spindles_starts_train_raw_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_raw_downsampled[:10]}")

# Test set

epoch_labels_test_raw_downsampled = label_spindle_epochs(epochs_test_raw_downsampled, spindles_starts_test_raw_downsampled, spindles_ends_test_raw_downsampled)

print(f"\nTest data first 10 spindles: {spindles_starts_test_raw_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_raw_downsampled[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 1 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 1]
CPU times: total: 28 s
Wall time: 28.1 s


### Prepare EEG data for CNN input

#### X and y train and test sets

In [49]:
# Reshape arrays

epochs_train_np_raw_downsampled = np.array(epochs_train_raw_downsampled).reshape(len(epochs_train_raw_downsampled), -1, 1)
# number of epochs N, sampling frequency (time dimension automatically inferred), channel dimension
epochs_test_np_raw_downsampled = np.array(epochs_test_raw_downsampled).reshape(len(epochs_test_raw_downsampled), -1, 1)
                                                                 
# Define X and y sets
                                                                 
X_train_raw_downsampled = epochs_train_np_raw_downsampled
y_train_raw_downsampled = epoch_labels_train_raw_downsampled

X_test_raw_downsampled = epochs_test_np_raw_downsampled
y_test_raw_downsampled = epoch_labels_test_raw_downsampled

# Print shapes

print(f"X_train shape: {X_train_raw_downsampled.shape}")
print(f"y_train shape: {y_train_raw_downsampled.shape}")

print(f"\nX_test shape: {X_test_raw_downsampled.shape}")
print(f"y_test shape: {y_test_raw_downsampled.shape}")

X_train shape: (14550, 300, 1)
y_train shape: (14550,)

X_test shape: (11440, 300, 1)
y_test shape: (11440,)


#### (Normalisation of data)

Only used to look at impact of normalisation but did normalise in the end.

In [50]:
# min-max normalization
# and add a very small constant to avoid division by zero

X_train_raw_norm_downsampled = np.array([
    (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
    for epoch in X_train_raw_downsampled
])

X_test_raw_norm_downsampled = np.array([
    (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
    for epoch in X_test_raw_downsampled
])

In [17]:
print("Before normalisation:")
print("Max train value:", np.max(X_train_raw_downsampled))
print("Min train value:", np.min(X_train_raw_downsampled))
print("Max test value:", np.max(X_test_raw_downsampled))
print("Min test value:", np.min(X_test_raw_downsampled))

print("\nAfter normalisation:")
print("Max train value:", np.max(X_train_raw_norm_downsampled))
print("Min train value:", np.min(X_train_raw_norm_downsampled))
print("Max test value:", np.max(X_test_raw_norm_downsampled))
print("Min test value:", np.min(X_test_raw_norm_downsampled))

Before normalisation:
Max train value: 882.1938427673039
Min train value: -508.80421083518894
Max test value: 1008.3762701536886
Min test value: -1078.8024424752498

After normalisation:
Max train value: 0.9999999999888312
Min train value: 0.0
Max test value: 0.999999999992055
Min test value: 0.0


### Train the model

In [31]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [32]:
# show model architecture
input_shape = (300, 1)
model_raw_downsampled = build_cnn_model_downsampled(input_shape)
model_raw_downsampled.summary()

In [33]:
%%time

training_info_raw_downsampled = model_raw_downsampled.fit(X_train_raw_downsampled, y_train_raw_downsampled, validation_split=0.2, epochs=30, batch_size=128, callbacks=[early_stop])

Epoch 1/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 122ms/step - accuracy: 0.5353 - loss: 0.6892 - val_accuracy: 0.5271 - val_loss: 0.6918
Epoch 2/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 107ms/step - accuracy: 0.5878 - loss: 0.6668 - val_accuracy: 0.4550 - val_loss: 0.7338
Epoch 3/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 115ms/step - accuracy: 0.5761 - loss: 0.6717 - val_accuracy: 0.6560 - val_loss: 0.6489
Epoch 4/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 112ms/step - accuracy: 0.6513 - loss: 0.6356 - val_accuracy: 0.7189 - val_loss: 0.5760
Epoch 5/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 111ms/step - accuracy: 0.6394 - loss: 0.6272 - val_accuracy: 0.7189 - val_loss: 0.5896
Epoch 6/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 108ms/step - accuracy: 0.7004 - loss: 0.5716 - val_accuracy: 0.7838 - val_loss: 0.4943
Epoch 7/30
[1m91/91[

#### Plot the training history

In [34]:
def plot_training_history(training_info):
  fig, axs = plt.subplots(1, 2, figsize=(16, 5))
  axs[0].plot(training_info.history['loss'], label="training set")
  axs[0].plot(training_info.history['val_loss'], label="validation set")
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("Loss")
  axs[0].grid(True)
  axs[0].legend()
  try:
    axs[1].plot(training_info.history['accuracy'], label="training set")
    axs[1].plot(training_info.history['val_accuracy'], label="validation set")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].grid(True)
    axs[1].legend()
  except:
    pass
  plt.show()

plot_training_history(training_info_raw_downsampled)

#### Evaluation on test set

In [35]:
model_raw_downsampled.evaluate(X_test_raw_downsampled, y_test_raw_downsampled)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step - accuracy: 0.8443 - loss: 0.3606


[0.34629422426223755, 0.8513985872268677]

In [36]:
y_pred = model_raw_downsampled.predict(X_test_raw_downsampled)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_raw_downsampled, y_pred_labels))
print(classification_report(y_test_raw_downsampled, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step
[[4684  470]
 [1230 5056]]
              precision    recall  f1-score   support

           0       0.79      0.91      0.85      5154
           1       0.91      0.80      0.86      6286

    accuracy                           0.85     11440
   macro avg       0.85      0.86      0.85     11440
weighted avg       0.86      0.85      0.85     11440



In [34]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_raw_downsampled, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic")
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

## With filtered data

In [51]:
# Apply bandpass filter between 12 and 16 Hz
# to compare performance of filtered dataset to unfiltered one
train_filtered = train_raw.copy().filter(l_freq=12, h_freq=16)
test_filtered = test_raw.copy().filter(l_freq=12, h_freq=16)

# Downsample to 100 Hz
train_filtered.resample(100)
test_filtered.resample(100)

Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passban

Unnamed: 0,General,General.1
,Filename(s),test_raw.fif
,MNE object type,Raw
,Measurement date,2023-09-06 at 23:28:45 UTC
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,09:32:01 (HH:MM:SS)
,Sampling frequency,100.00 Hz
,Time points,3432010
,Channels,Channels


### Spindle detection

In [52]:
spindles_train_times_filtered = detect_spindles_times(train_filtered, do_filter=False, do_downsample=False)
spindles_test_times_filtered = detect_spindles_times(test_filtered, do_filter=False, do_downsample=False)
# since filtering and downsampling before, do not filter and downsample again in function

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


In [53]:
print(len(spindles_train_times_filtered))
print(len(spindles_test_times_filtered))

7123
5794


In [54]:
spindles_starts_train_filtered, spindles_ends_train_filtered = zip(*spindles_train_times_filtered) if spindles_train_times_filtered else([],[])
spindles_starts_test_filtered, spindles_ends_test_filtered = zip(*spindles_test_times_filtered) if spindles_test_times_filtered else([],[])

In [55]:
print(len(spindles_starts_train_filtered))
print(len(spindles_ends_train_filtered))

print(len(spindles_starts_test_filtered))
print(len(spindles_ends_test_filtered))

7123
7123
5794
5794


### Epoch the data

In [56]:
epochs_train_filtered = create_fixed_length_epochs(train_filtered)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped


In [57]:
epochs_test_filtered = create_fixed_length_epochs(test_filtered)

Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped


In [58]:
print(epochs_train_filtered.get_data().shape)
print(epochs_test_filtered.get_data().shape)

(14550, 1, 300)
(11440, 1, 300)


### Labels for 3-second epochs

In [59]:
%%time

# Train set

epoch_labels_train_filtered = label_spindle_epochs(epochs_train_filtered, spindles_starts_train_filtered, spindles_ends_train_filtered)

print(f"Train data first 10 spindles: {spindles_starts_train_filtered[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_filtered[:10]}")

# Test set

epoch_labels_test_filtered = label_spindle_epochs(epochs_test_filtered, spindles_starts_test_filtered, spindles_ends_test_filtered)

print(f"\nTest data first 10 spindles: {spindles_starts_test_filtered[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_filtered[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 1 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 1]
CPU times: total: 28.3 s
Wall time: 28.5 s


### Prepare EEG data for CNN input

#### X and y train and test sets

In [60]:
# Reshape arrays

epochs_train_np_filtered = np.array(epochs_train_filtered).reshape(len(epochs_train_filtered), -1, 1)
# number of epochs N, sampling frequency (time dimension automatically inferred), channel dimension
epochs_test_np_filtered = np.array(epochs_test_filtered).reshape(len(epochs_test_filtered), -1, 1)
                                                                 
# Define X and y sets
                                                                 
X_train_filtered = epochs_train_np_filtered
y_train_filtered = epoch_labels_train_filtered

X_test_filtered = epochs_test_np_filtered
y_test_filtered = epoch_labels_test_filtered

# Print shapes

print(f"X_train shape: {X_train_filtered.shape}")
print(f"y_train shape: {y_train_filtered.shape}")

print(f"\nX_test shape: {X_test_filtered.shape}")
print(f"y_test shape: {y_test_filtered.shape}")
                                                                 

X_train shape: (14550, 300, 1)
y_train shape: (14550,)

X_test shape: (11440, 300, 1)
y_test shape: (11440,)


### Train the model

In [None]:
%%time

training_info = model.fit(X_train_filtered, y_train_filtered, validation_split=0.2, epochs=30, batch_size=128, callbacks=[early_stop])

#### Inspect class imbalance

In [17]:
# Count spindle and non-spindle labels in train and test sets
unique_train_filtered, counts_train_filtered = np.unique(y_train_filtered, return_counts=True)
unique_test_filtered, counts_test_filtered = np.unique(y_test_filtered, return_counts=True)

print("Train label distribution:")
for label, count in zip(unique_train_filtered, counts_train_filtered):
    print(f"Label {label}: {count}")

print("\nTest label distribution:")
for label, count in zip(unique_test_filtered, counts_test_filtered):
    print(f"Label {label}: {count}")

Train label distribution:
Label 0: 7103
Label 1: 7447

Test label distribution:
Label 0: 5154
Label 1: 6286


#### Training

In [19]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [20]:
%%time

training_info = model.fit(X_train_filtered, y_train_filtered, validation_split=0.2, epochs=30, batch_size=128, callbacks=[early_stop])

Epoch 1/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 170ms/step - accuracy: 0.5920 - loss: 0.6493 - val_accuracy: 0.5780 - val_loss: 0.7163
Epoch 2/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 163ms/step - accuracy: 0.7226 - loss: 0.5401 - val_accuracy: 0.6182 - val_loss: 0.6570
Epoch 3/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 171ms/step - accuracy: 0.6273 - loss: 0.6229 - val_accuracy: 0.6533 - val_loss: 0.6407
Epoch 4/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 164ms/step - accuracy: 0.6633 - loss: 0.5930 - val_accuracy: 0.8553 - val_loss: 0.3555
Epoch 5/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 166ms/step - accuracy: 0.8434 - loss: 0.3481 - val_accuracy: 0.8842 - val_loss: 0.2843
Epoch 6/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 172ms/step - accuracy: 0.8757 - loss: 0.2922 - val_accuracy: 0.8955 - val_loss: 0.2463
Epoch 7/30
[1m91/91[

In [21]:
def plot_training_history(training_info):
  fig, axs = plt.subplots(1, 2, figsize=(16, 5))
  axs[0].plot(training_info.history['loss'], label="training set")
  axs[0].plot(training_info.history['val_loss'], label="validation set")
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("Loss")
  axs[0].grid(True)
  axs[0].legend()
  try:
    axs[1].plot(training_info.history['accuracy'], label="training set")
    axs[1].plot(training_info.history['val_accuracy'], label="validation set")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].grid(True)
    axs[1].legend()
  except:
    pass
  plt.show()

plot_training_history(training_info)

In [22]:
model.evaluate(X_test_filtered, y_test_filtered)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 26ms/step - accuracy: 0.8999 - loss: 0.2338 


[0.24101856350898743, 0.8963286876678467]

### Metrics

In [23]:
y_pred = model.predict(X_test_filtered)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_filtered, y_pred_labels))
print(classification_report(y_test_filtered, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 28ms/step
[[5042  112]
 [1074 5212]]
              precision    recall  f1-score   support

           0       0.82      0.98      0.89      5154
           1       0.98      0.83      0.90      6286

    accuracy                           0.90     11440
   macro avg       0.90      0.90      0.90     11440
weighted avg       0.91      0.90      0.90     11440



In [24]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_filtered, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic")
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

## With STFT

In [61]:
spindles_train_times_stft_downsampled = detect_spindles_times(train_raw, do_filter=True, do_downsample=True)
spindles_test_times_stft_downsampled = detect_spindles_times(test_raw, do_filter=True, do_downsample=True)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pass

In [62]:
spindles_starts_train_stft_downsampled, spindles_ends_train_stft_downsampled = zip(*spindles_train_times_stft_downsampled) if spindles_train_times_stft_downsampled else([],[])
spindles_starts_test_stft_downsampled, spindles_ends_test_stft_downsampled = zip(*spindles_test_times_stft_downsampled) if spindles_test_times_stft_downsampled else([],[])

print(len(spindles_starts_train_stft_downsampled))
print(len(spindles_ends_train_stft_downsampled))

print(len(spindles_starts_test_stft_downsampled))
print(len(spindles_ends_test_stft_downsampled))

7123
7123
5794
5794


### Downsample

In [63]:
train_raw_downsampled = train_raw.copy().resample(100)
test_raw_downsampled = test_raw.copy().resample(100)

print(train_raw_downsampled.info['sfreq'])
print(test_raw_downsampled.info['sfreq'])

100.0
100.0


### Epoch the data

In [64]:
epochs_train_stft_downsampled = create_fixed_length_epochs(train_raw_downsampled)
epochs_test_stft_downsampled = create_fixed_length_epochs(test_raw_downsampled)

print(epochs_train_stft_downsampled.get_data().shape)
print(epochs_test_stft_downsampled.get_data().shape)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped
Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped
(14550, 1, 300)
(11440, 1, 300)


### Labels for 3-second epochs

In [65]:
%%time

# Train set

epoch_labels_train_stft_downsampled = label_spindle_epochs(epochs_train_stft_downsampled, spindles_starts_train_stft_downsampled, spindles_ends_train_stft_downsampled)

print(f"Train data first 10 spindles: {spindles_starts_train_stft_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_stft_downsampled[:10]}")

# Test set

epoch_labels_test_stft_downsampled = label_spindle_epochs(epochs_test_stft_downsampled, spindles_starts_test_stft_downsampled, spindles_ends_test_stft_downsampled)

print(f"\nTest data first 10 spindles: {spindles_starts_test_stft_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_stft_downsampled[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 1 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 1]
CPU times: total: 27.3 s
Wall time: 27.6 s


### Apply STFT to each epoch

#### get rid of channel dimension

In [66]:
epochs_train_stft_downsampled = np.squeeze(epochs_train_stft_downsampled)
epochs_test_stft_downsampled = np.squeeze(epochs_test_stft_downsampled)

print(epochs_train_stft_downsampled.shape)
print(epochs_test_stft_downsampled.shape)

(14550, 300)
(11440, 300)


In [67]:
print(train_raw_downsampled.info['sfreq'])

100.0


In [68]:
fs = train_raw_downsampled.info['sfreq']  
# smaller nperseg means higher time resolution
# noverlap must be less than nperseg

# I want to focus more on the frequency resolution
# at the expense of the time resolution
# for spindles, want a 1-2 Hz resolution (frequency bins every 1-2 Hz)
# let's start with 2 Hz first 
nperseg = 50
noverlap = nperseg // 2
# common practice is to set noverlap to 50% of nperseg


epochs_train_stft_transformed_downsampled = []

for epoch in epochs_train_stft_downsampled:
    f, t, Zxx = stft(epoch, fs=fs, nperseg=nperseg, noverlap=noverlap)
    spectrogram = np.abs(Zxx)  
    epochs_train_stft_transformed_downsampled.append(spectrogram)

epochs_test_stft_transformed_downsampled = []

for epoch in epochs_test_stft_downsampled:
    f, t, Zxx = stft(epoch, fs=fs, nperseg=nperseg, noverlap=noverlap)
    spectrogram = np.abs(Zxx)  
    epochs_test_stft_transformed_downsampled.append(spectrogram)
    
# convert into numpy arrays
epochs_train_stft_transformed_downsampled = np.array(epochs_train_stft_transformed_downsampled)
epochs_test_stft_transformed_downsampled = np.array(epochs_test_stft_transformed_downsampled)

print("Train STFT shape:", epochs_train_stft_transformed_downsampled.shape)
print("Test STFT shape:", epochs_test_stft_transformed_downsampled.shape)

# shape is number of epochs, frequency_bins, time_bins


Train STFT shape: (14550, 26, 13)
Test STFT shape: (11440, 26, 13)


### X and y train and test sets

In [69]:
# Define X and y sets and reshape

X_train_stft_downsampled = epochs_train_stft_transformed_downsampled[..., np.newaxis]  # Shape: (14550, 65, 25, 1)
y_train_stft_downsampled = epoch_labels_train_stft_downsampled

X_test_stft_downsampled = epochs_test_stft_transformed_downsampled[..., np.newaxis]    # Shape: (11440, 65, 25, 1)
y_test_stft_downsampled = epoch_labels_test_stft_downsampled
                                                                 

# Print shapes

print(f"X_train shape: {X_train_stft_downsampled.shape}")
print(f"y_train shape: {y_train_stft_downsampled.shape}")

print(f"\nX_test shape: {X_test_stft_downsampled.shape}")
print(f"y_test shape: {y_test_stft_downsampled.shape}")

X_train shape: (14550, 26, 13, 1)
y_train shape: (14550,)

X_test shape: (11440, 26, 13, 1)
y_test shape: (11440,)


### Normalization of data

In [70]:
# this scales the spectrogram in range 0,1
# this is min-max normalisation

X_train_stft_norm_downsampled = np.array([
    (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
    for epoch in X_train_stft_downsampled
])

X_test_stft_norm_downsampled = np.array([
    (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
    for epoch in X_test_stft_downsampled
])

In [71]:
# should have values between 0 and 1 

print("Before normalisation:")
print("Max train value:", np.max(X_train_stft_downsampled))
print("Min train value:", np.min(X_train_stft_downsampled))

print("Max test value:", np.max(X_test_stft_downsampled))
print("Min test value:", np.min(X_test_stft_downsampled))

print("\nAfter normalisation:")
print("Max train value:", np.max(X_train_stft_norm_downsampled))
print("Min train value:", np.min(X_train_stft_norm_downsampled))

print("Max test value:", np.max(X_test_stft_norm_downsampled))
print("Min test value:", np.min(X_test_stft_norm_downsampled))

Before normalisation:
Max train value: 866.3673727627713
Min train value: 2.4061107274064854e-06
Max test value: 1011.6755263560262
Min test value: 5.496909294180341e-06

After normalisation:
Max train value: 0.9999999999884571
Min train value: 0.0
Max test value: 0.9999999999901138
Min test value: 0.0


### Training

In [29]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [31]:
%%time

training_info_cnn_2d_model_downsampled = cnn_2d_model_downsampled.fit(X_train_stft_norm_downsampled, y_train_stft_downsampled, validation_split=0.2, epochs=30, batch_size=128)

Epoch 1/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 71ms/step - accuracy: 0.5340 - loss: 0.6919 - val_accuracy: 0.5467 - val_loss: 0.6929
Epoch 2/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 71ms/step - accuracy: 0.5514 - loss: 0.6870 - val_accuracy: 0.4852 - val_loss: 0.6934
Epoch 3/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 71ms/step - accuracy: 0.5497 - loss: 0.6878 - val_accuracy: 0.5076 - val_loss: 0.6923
Epoch 4/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 66ms/step - accuracy: 0.5396 - loss: 0.6886 - val_accuracy: 0.4873 - val_loss: 0.6943
Epoch 5/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 60ms/step - accuracy: 0.5404 - loss: 0.6879 - val_accuracy: 0.5076 - val_loss: 0.6925
Epoch 6/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 60ms/step - accuracy: 0.5509 - loss: 0.6871 - val_accuracy: 0.4784 - val_loss: 0.6952
Epoch 7/30
[1m91/91[0m [32m━━━

In [23]:
def plot_training_history(training_info):
  fig, axs = plt.subplots(1, 2, figsize=(16, 5))
  axs[0].plot(training_info.history['loss'], label="training set")
  axs[0].plot(training_info.history['val_loss'], label="validation set")
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("Loss")
  axs[0].grid(True)
  axs[0].legend()
  try:
    axs[1].plot(training_info.history['accuracy'], label="training set")
    axs[1].plot(training_info.history['val_accuracy'], label="validation set")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].grid(True)
    axs[1].legend()
  except:
    pass
  plt.show()

plot_training_history(training_info_cnn_2d_model_downsampled)

In [37]:
cnn_2d_model_downsampled.evaluate(X_test_stft_norm_downsampled, y_test_stft_downsampled)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 18ms/step - accuracy: 0.5458 - loss: 0.6880


[0.6867818832397461, 0.5568181872367859]

In [32]:
y_pred = cnn_2d_model_downsampled.predict(X_test_stft_norm_downsampled)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_stft_downsampled, y_pred_labels))
print(classification_report(y_test_stft_downsampled, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step
[[1373 3781]
 [1337 4949]]
              precision    recall  f1-score   support

           0       0.51      0.27      0.35      5154
           1       0.57      0.79      0.66      6286

    accuracy                           0.55     11440
   macro avg       0.54      0.53      0.50     11440
weighted avg       0.54      0.55      0.52     11440



In [46]:
print("Predicted label distribution:", np.unique(y_pred_labels, return_counts=True))

Predicted label distribution: (array([0]), array([11440]))


In [47]:
print("True label distribution:", np.unique(y_test_stft, return_counts=True))

True label distribution: (array([0, 1]), array([5108, 6332]))


In [82]:
y_pred_probs = cnn_2d_model.predict(X_test_stft_norm)
print("Min prob:", y_pred_probs.min())
print("Max prob:", y_pred_probs.max())

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 38ms/step
Min prob: 0.45536286
Max prob: 0.59533376


In [69]:
unique, counts = np.unique(y_test_stft, return_counts=True)
print(dict(zip(unique, counts)))

{np.int64(0): np.int64(5108), np.int64(1): np.int64(6332)}


### Simplified CNN model

In [46]:
def build_simplified_cnn(input_shape=(65, 25, 1)):
    input_layer = tf.keras.Input(shape=input_shape)

    x = tf.keras.layers.Conv2D(32, (5, 5), padding='same', activation='relu')(input_layer)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)

    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

In [47]:
%%time

cnn_2d_model_simplified_downsampled = build_simplified_cnn(input_shape=(33, 11, 1))
training_info_cnn_2d_model_simplified_downsampled = cnn_2d_model_simplified_downsampled.fit(X_train_stft_norm_downsampled, y_train_stft_downsampled,
                                 validation_split=0.2,
                                 epochs=30,
                                 batch_size=128)

Epoch 1/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 22ms/step - accuracy: 0.5520 - loss: 0.7290 - val_accuracy: 0.5787 - val_loss: 0.6821
Epoch 2/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 19ms/step - accuracy: 0.6897 - loss: 0.5730 - val_accuracy: 0.5787 - val_loss: 0.7022
Epoch 3/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 19ms/step - accuracy: 0.7386 - loss: 0.5168 - val_accuracy: 0.5787 - val_loss: 0.7785
Epoch 4/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 19ms/step - accuracy: 0.7640 - loss: 0.4923 - val_accuracy: 0.5787 - val_loss: 0.8820
Epoch 5/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 19ms/step - accuracy: 0.7893 - loss: 0.4623 - val_accuracy: 0.5804 - val_loss: 0.9179
Epoch 6/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 21ms/step - accuracy: 0.7808 - loss: 0.4616 - val_accuracy: 0.5945 - val_loss: 0.7765
Epoch 7/30
[1m91/91[0m [32m━━━━

In [48]:
def plot_training_history(training_info):
  fig, axs = plt.subplots(1, 2, figsize=(16, 5))
  axs[0].plot(training_info.history['loss'], label="training set")
  axs[0].plot(training_info.history['val_loss'], label="validation set")
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("Loss")
  axs[0].grid(True)
  axs[0].legend()
  try:
    axs[1].plot(training_info.history['accuracy'], label="training set")
    axs[1].plot(training_info.history['val_accuracy'], label="validation set")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].grid(True)
    axs[1].legend()
  except:
    pass
  plt.show()

plot_training_history(training_info_cnn_2d_model_simplified_downsampled)

In [49]:
y_pred = cnn_2d_model_simplified_downsampled.predict(X_test_stft_norm_downsampled)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_stft_downsampled, y_pred_labels))
print(classification_report(y_test_stft_downsampled, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step  
[[2803 2351]
 [ 645 5641]]
              precision    recall  f1-score   support

           0       0.81      0.54      0.65      5154
           1       0.71      0.90      0.79      6286

    accuracy                           0.74     11440
   macro avg       0.76      0.72      0.72     11440
weighted avg       0.75      0.74      0.73     11440



In [86]:
y_pred_probs = cnn_2d_model_simplified.predict(X_test_stft_norm)
print("Min prob:", y_pred_probs.min())
print("Max prob:", y_pred_probs.max())

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step
Min prob: 0.09505223
Max prob: 1.0


## With wavelet transform

In [6]:
spindles_train_times_wavelet_downsampled = detect_spindles_times(train_raw, do_filter=True, do_downsample=True)
spindles_test_times_wavelet_downsampled = detect_spindles_times(test_raw, do_filter=True, do_downsample=True)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pass

In [7]:
spindles_starts_train_wavelet_downsampled, spindles_ends_train_wavelet_downsampled = zip(*spindles_train_times_wavelet_downsampled) if spindles_train_times_wavelet_downsampled else([],[])
spindles_starts_test_wavelet_downsampled, spindles_ends_test_wavelet_downsampled = zip(*spindles_test_times_wavelet_downsampled) if spindles_test_times_wavelet_downsampled else([],[])

print(len(spindles_starts_train_wavelet_downsampled))
print(len(spindles_ends_train_wavelet_downsampled))

print(len(spindles_starts_test_wavelet_downsampled))
print(len(spindles_ends_test_wavelet_downsampled))

7123
7123
5794
5794


### Downsample

In [8]:
train_raw_downsampled = train_raw.copy().resample(100)
test_raw_downsampled = test_raw.copy().resample(100)

print(train_raw_downsampled.info['sfreq'])
print(test_raw_downsampled.info['sfreq'])

100.0
100.0


### Epoch the data

In [9]:
epochs_train_wavelet_downsampled = create_fixed_length_epochs(train_raw_downsampled)
epochs_test_wavelet_downsampled = create_fixed_length_epochs(test_raw_downsampled)

print(epochs_train_wavelet_downsampled.get_data().shape)
print(epochs_test_wavelet_downsampled.get_data().shape)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped
Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped
(14550, 1, 300)
(11440, 1, 300)


### Label the epochs

In [10]:
%%time

# Train set

epoch_labels_train_wavelet_downsampled = label_spindle_epochs(epochs_train_wavelet_downsampled, spindles_starts_train_wavelet_downsampled, spindles_ends_train_wavelet_downsampled)

print(f"Train data first 10 spindles: {spindles_starts_train_wavelet_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_wavelet_downsampled[:10]}")

# Test set

epoch_labels_test_wavelet_downsampled = label_spindle_epochs(epochs_test_wavelet_downsampled, spindles_starts_test_wavelet_downsampled, spindles_ends_test_wavelet_downsampled)

print(f"\nTest data first 10 spindles: {spindles_starts_test_wavelet_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_wavelet_downsampled[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 1 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 1]
CPU times: total: 27.3 s
Wall time: 27.5 s


In [19]:
epochs_train_wavelet_downsampled = np.squeeze(epochs_train_wavelet_downsampled)
epochs_test_wavelet_downsampled = np.squeeze(epochs_test_wavelet_downsampled)

print(epochs_train_wavelet_downsampled.shape)
print(epochs_test_wavelet_downsampled.shape)

(14550, 300)
(11440, 300)


### Apply wavelet transform

In [16]:
def compute_scalogram_full(eeg_epoch, sfreq, img_shape=(64, 64)):
    # the default output image size is 64, 64
    
    scales = np.arange(1, 128)
    # array of wavelet scales from 1 to 127
    # small scale (1): captures high frequency
    # large scale (127): captures low frequency

    # time frequency decomposition
    coef, freqs = pywt.cwt(eeg_epoch, scales, 'cmor1.5-1.0', sampling_period=1/sfreq)
    # cwt: continuous wavelet transform
    # more precisely, the complex morlet wavelet (cmor1.5-1.0)
    # sampling_period looks at the sampling frequency
    # calculates the CWT for each scale
    # result is 2D matrix: shape(127, time_length): shows
    # how the power at each frequency changes over time

    # power computation
    power = np.abs(coef)**2  
    # computes the power (energy) at each scale and time point
    # np.abs: takes the magnitude of the complex coefficients
    # and **2 squares the magnitude to get power

    # normalization min-max
    # as with STFT but here incorporate it directly into function
    scaler = MinMaxScaler()
    normalized = scaler.fit_transform(power)
    # applies the scaling to each column of the power matrix

    # resizing
    # fixed image size for CNN input
    resized = cv2.resize(normalized, img_shape, interpolation=cv2.INTER_AREA)
    # uses OpenCV to resize power image to a fixed img_shape
    
    # add back the channel dimension
    # easier to put this directly in the function
    return resized[..., np.newaxis]

# from 1D EEG epoch to 2D image representing time-frequency power using 
# Continuous Wavelet Transform 

In [24]:
%%time

sfreq = train_raw_downsampled.info['sfreq']  

X_train_wavelet_downsampled = np.array([compute_scalogram_full(epoch, sfreq) for epoch in epochs_train_wavelet_downsampled])
y_train_wavelet_downsampled = np.array(epoch_labels_train_wavelet_downsampled)

X_test_wavelet_downsampled = np.array([compute_scalogram_full(epoch, sfreq) for epoch in epochs_test_wavelet_downsampled])
y_test_wavelet_downsampled = np.array(epoch_labels_test_wavelet_downsampled)

# Print shapes

print(f"X_train shape: {X_train_wavelet_downsampled.shape}")
print(f"y_train shape: {y_train_wavelet_downsampled.shape}")

print(f"\nX_test shape: {X_test_wavelet_downsampled.shape}")
print(f"y_test shape: {y_test_wavelet_downsampled.shape}")

X_train shape: (14550, 64, 64, 1)
y_train shape: (14550,)

X_test shape: (11440, 64, 64, 1)
y_test shape: (11440,)
CPU times: total: 10min 7s
Wall time: 10min 11s


In [23]:
X_train_wavelet_downsampled[2]

array([[[1.29820837e-03],
        [8.09250209e-04],
        [3.62803224e-04],
        ...,
        [9.94897171e-05],
        [4.83987876e-05],
        [5.28998654e-05]],

       [[6.49103194e-04],
        [3.03108213e-04],
        [1.86522278e-04],
        ...,
        [2.76810017e-04],
        [1.43850501e-03],
        [1.59868261e-03]],

       [[2.25915546e-03],
        [1.01363215e-03],
        [1.20906460e-03],
        ...,
        [1.08826389e-03],
        [2.63661447e-03],
        [2.03844854e-03]],

       ...,

       [[4.25665807e-01],
        [3.83832658e-01],
        [3.48251638e-01],
        ...,
        [7.65048739e-01],
        [7.79405269e-01],
        [8.21926234e-01]],

       [[3.84760529e-01],
        [3.48006428e-01],
        [3.04634528e-01],
        ...,
        [7.58691040e-01],
        [7.55295818e-01],
        [7.70743503e-01]],

       [[3.67188432e-01],
        [3.09190140e-01],
        [2.53313625e-01],
        ...,
        [7.46529125e-01],
        [7.6574

In [29]:
input_shape = (64, 64, 1)
cnn_2d_model_downsampled_wavelet = build_2d_cnn_model_downsampled(input_shape)
cnn_2d_model_downsampled_wavelet.summary()

In [30]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [31]:
%%time

training_info_cnn_2d_model_downsampled_wavelet = cnn_2d_model_downsampled_wavelet.fit(X_train_wavelet_downsampled, y_train_wavelet_downsampled, validation_split=0.2, epochs=30, batch_size=128)

Epoch 1/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m146s[0m 2s/step - accuracy: 0.5065 - loss: 0.7130 - val_accuracy: 0.4213 - val_loss: 0.7007
Epoch 2/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 2s/step - accuracy: 0.5201 - loss: 0.6950 - val_accuracy: 0.4636 - val_loss: 0.6962
Epoch 3/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 2s/step - accuracy: 0.5206 - loss: 0.6927 - val_accuracy: 0.4251 - val_loss: 0.7018
Epoch 4/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m154s[0m 2s/step - accuracy: 0.5306 - loss: 0.6896 - val_accuracy: 0.4447 - val_loss: 0.6986
Epoch 5/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m172s[0m 2s/step - accuracy: 0.5360 - loss: 0.6912 - val_accuracy: 0.4320 - val_loss: 0.7058
Epoch 6/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m181s[0m 2s/step - accuracy: 0.5206 - loss: 0.6926 - val_accuracy: 0.4784 - val_loss: 0.6974
Epoch 7/30
[1m91/91[0m [32m━━━━

In [32]:
def plot_training_history(training_info):
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))
    axs[0].plot(training_info.history['loss'], label="training set")
    axs[0].plot(training_info.history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()
    try:
        axs[1].plot(training_info.history['accuracy'], label="training set")
        axs[1].plot(training_info.history['val_accuracy'], label="validation set")
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    except:
        pass
  
    fig.suptitle("Training history for one-input 2D CNN model with wavelet EEG data for spindle detection", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

plot_training_history(training_info_cnn_2d_model_downsampled_wavelet)

In [33]:
# Get predictions
y_pred = cnn_2d_model_downsampled_wavelet.predict(X_test_wavelet_downsampled)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix
cm = confusion_matrix(y_test_wavelet_downsampled, y_pred_labels)
cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Predicted 0", "Predicted 1"])

# Classification report as a dataframe
report = classification_report(y_test_wavelet_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Confusion matrix plotted as heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix for spindle detection \nusing a one-input 2D CNN with wavelet EEG data")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Classification report as a table
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=report_df.round(2).values,
                 colLabels=report_df.columns,
                 rowLabels=report_df.index,
                 cellLoc='center',
                 loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
plt.title("Classification Report for spindle detection \nusing a one-input 2D CNN with wavelet EEG data", fontsize=14)
plt.tight_layout()
plt.show()


[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 101ms/step


## Three-channel CNN

### Reducing STFT dimension

The STFT shape is (number of epochs, time windows, frequency bins, channels).
If you only keep the frequency dimension, it will be (number of epochs, frequency bins, channels), so it will be all the frequencies for the average time.

#### Only keeping time axis

In [72]:
X_train_stft_time = np.mean(X_train_stft_norm_downsampled, axis=2)
# collapses the frequency axis by averaging them
X_test_stft_time =  np.mean(X_test_stft_norm_downsampled, axis=2)

print(f"Shape of X_train_stft_time:{X_train_stft_time.shape}")
print(f"Shape of X_test_stft_time:{X_test_stft_time.shape}")

Shape of X_train_stft_time:(14550, 26, 1)
Shape of X_test_stft_time:(11440, 26, 1)


#### Only keeping frequency axis

In [73]:
X_train_stft_freq = np.mean(X_train_stft_norm_downsampled, axis=1) 
X_test_stft_freq = np.mean(X_test_stft_norm_downsampled, axis=1)

print(f"Shape of X_train_stft_freq:{X_train_stft_freq.shape}")
print(f"Shape of X_test_stft_freq:{X_test_stft_freq.shape}")

Shape of X_train_stft_freq:(14550, 13, 1)
Shape of X_test_stft_freq:(11440, 13, 1)


### 3-input model with time for STFT

#### Dictionary for the input

For the y train and test data, keeping the y_train_raw_downsampled is sufficient because the labels for the spindles are always the same.

In [113]:
X_train_dict_time = {
    'raw_input': X_train_raw_downsampled,
    'filtered_input': X_train_filtered,
    'stft_input': X_train_stft_time
}

X_test_dict_time = {
    'raw_input': X_test_raw_downsampled,
    'filtered_input': X_test_filtered,
    'stft_input': X_test_stft_time
}

#### The model

In [116]:
def build_multi_input_cnn_model_time():
    
    input_raw = tf.keras.Input(shape=(300, 1), name='raw_input')
    input_filtered = tf.keras.Input(shape=(300, 1), name='filtered_input')
    input_stft = tf.keras.Input(shape=(26, 1), name='stft_input')  
    # 300 time points of raw and filtered EEG
    # 26 time points from the STFT transform

    def conv_branch(input_layer, kernel_sizes=[5, 11, 21]):
        outputs = []
        for k in kernel_sizes:
            pad = k // 2
            x = tf.keras.layers.ZeroPadding1D(padding=pad)(input_layer)
            # applies zero-padding so that the length after convolution stays the same
            x = tf.keras.layers.Conv1D(filters=10, kernel_size=k, strides=1, padding='valid')(x)
            # performs the convolution
            x = tf.keras.layers.LeakyReLU(negative_slope=0.01)(x)
            # applies LeakyReLU 
            x = tf.keras.layers.MaxPooling1D(pool_size=2)(x)
            # reduce temporal dimension by 2
            x = tf.keras.layers.BatchNormalization()(x)
            # normalize the output
            outputs.append(x)
        return tf.keras.layers.Concatenate()(outputs)

    # Convolutional branches: each input has its own CNN branch
    branch_raw = conv_branch(input_raw)
    branch_filtered = conv_branch(input_filtered)
    branch_stft = conv_branch(input_stft)

    # Each branch through its own GRU
    # this captures temporal dependencies
    gru_raw = tf.keras.layers.GRU(64)(branch_raw)
    gru_filtered = tf.keras.layers.GRU(64)(branch_filtered)
    gru_stft = tf.keras.layers.GRU(64)(branch_stft)

    # Concatenate GRU outputs (fixed-length vectors)
    # this leads to one 192-dimensional vector
    merged = tf.keras.layers.Concatenate()([gru_raw, gru_filtered, gru_stft])

    # Dense layers
    x = tf.keras.layers.Dense(64, activation='relu')(merged)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    # Build model
    model = tf.keras.Model(inputs=[input_raw, input_filtered, input_stft], outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [117]:
cnn_model_multi_input_time = build_multi_input_cnn_model_time()
cnn_model_multi_input_time.summary()

#### Training

In [118]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [119]:
%%time

training_info_multiple_inputs_time = cnn_model_multi_input_time.fit(X_train_dict, y_train_raw_downsampled, validation_split=0.2, epochs=20, batch_size=128, callbacks=[early_stop])

Epoch 1/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 206ms/step - accuracy: 0.6123 - loss: 0.6465 - val_accuracy: 0.6347 - val_loss: 0.6817
Epoch 2/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 221ms/step - accuracy: 0.7697 - loss: 0.5290 - val_accuracy: 0.8313 - val_loss: 0.5050
Epoch 3/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 217ms/step - accuracy: 0.8345 - loss: 0.4249 - val_accuracy: 0.7416 - val_loss: 0.5437
Epoch 4/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 216ms/step - accuracy: 0.7712 - loss: 0.4929 - val_accuracy: 0.8096 - val_loss: 0.4063
Epoch 5/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 189ms/step - accuracy: 0.8522 - loss: 0.3828 - val_accuracy: 0.8811 - val_loss: 0.3276
Epoch 6/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 216ms/step - accuracy: 0.7747 - loss: 0.4535 - val_accuracy: 0.6533 - val_loss: 0.6370
Epoch 7/20
[1m91/91[

In [121]:
def plot_training_history(training_info):
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))
    axs[0].plot(training_info.history['loss'], label="training set")
    axs[0].plot(training_info.history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()
    try:
        axs[1].plot(training_info.history['accuracy'], label="training set")
        axs[1].plot(training_info.history['val_accuracy'], label="validation set")
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    except:
        pass
  
    fig.suptitle("Training History for three-input CNN model with time component of STFT", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

plot_training_history(training_info_multiple_inputs_time)

In [122]:
cnn_model_multi_input_time.evaluate(X_test_dict_time, y_test_raw_downsampled)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 20ms/step - accuracy: 0.8990 - loss: 0.2316


[0.23758018016815186, 0.8965908885002136]

In [123]:
y_pred_time = cnn_model_multi_input_time.predict(X_test_dict)
y_pred_labels_time = (y_pred_freq > 0.5).astype(int)

print(confusion_matrix(y_test_raw_downsampled, y_pred_labels_time))
print(classification_report(y_test_raw_downsampled, y_pred_labels_time))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 22ms/step
[[5011  143]
 [1040 5246]]
              precision    recall  f1-score   support

           0       0.83      0.97      0.89      5154
           1       0.97      0.83      0.90      6286

    accuracy                           0.90     11440
   macro avg       0.90      0.90      0.90     11440
weighted avg       0.91      0.90      0.90     11440



In [124]:
# Get predictions
y_pred = cnn_model_multi_input_time.predict(X_test_dict)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix
cm = confusion_matrix(y_test_raw_downsampled, y_pred_labels)
cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Predicted 0", "Predicted 1"])

# Classification report as a dataframe
report = classification_report(y_test_raw_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Confusion matrix plotted as heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix for Spindle Detection \nusing a three-input CNN with time \ncomponent of STFT")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Classification report as a table
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=report_df.round(2).values,
                 colLabels=report_df.columns,
                 rowLabels=report_df.index,
                 cellLoc='center',
                 loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
plt.title("Classification Report for Spindle Detection \nusing a three-input CNN with time \ncomponent of STFT", fontsize=14)
plt.tight_layout()
plt.show()


[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 20ms/step


In [125]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred_time.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_raw_downsampled, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic for a three-input CNN model \nwith the STFT time component")
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
plt.show()

### 3-input model with frequency for the STFT

In [126]:
X_train_dict_freq = {
    'raw_input': X_train_raw_downsampled,
    'filtered_input': X_train_filtered,
    'stft_input': X_train_stft_freq
}

X_test_dict_freq = {
    'raw_input': X_test_raw_downsampled,
    'filtered_input': X_test_filtered,
    'stft_input': X_test_stft_freq
}

In [127]:
def build_multi_input_cnn_model_freq():
    # Inputs
    input_raw = tf.keras.Input(shape=(300, 1), name='raw_input')
    input_filtered = tf.keras.Input(shape=(300, 1), name='filtered_input')
    input_stft = tf.keras.Input(shape=(13, 1), name='stft_input')  

    def conv_branch(input_layer, kernel_sizes=[5, 11, 21]):
        outputs = []
        for k in kernel_sizes:
            pad = k // 2
            x = tf.keras.layers.ZeroPadding1D(padding=pad)(input_layer)
            x = tf.keras.layers.Conv1D(filters=10, kernel_size=k, strides=1, padding='valid')(x)
            x = tf.keras.layers.LeakyReLU(negative_slope=0.01)(x)
            x = tf.keras.layers.MaxPooling1D(pool_size=2)(x)
            x = tf.keras.layers.BatchNormalization()(x)
            outputs.append(x)
        return tf.keras.layers.Concatenate()(outputs)

    # Convolutional branches
    branch_raw = conv_branch(input_raw)
    branch_filtered = conv_branch(input_filtered)
    branch_stft = conv_branch(input_stft)

    # Each branch through its own GRU
    gru_raw = tf.keras.layers.GRU(64)(branch_raw)
    gru_filtered = tf.keras.layers.GRU(64)(branch_filtered)
    gru_stft = tf.keras.layers.GRU(64)(branch_stft)

    # Concatenate GRU outputs (fixed-length vectors)
    merged = tf.keras.layers.Concatenate()([gru_raw, gru_filtered, gru_stft])

    # Dense layers
    x = tf.keras.layers.Dense(64, activation='relu')(merged)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    # Build model
    model = tf.keras.Model(inputs=[input_raw, input_filtered, input_stft], outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [128]:
cnn_model_multi_input_freq = build_multi_input_cnn_model_freq()
cnn_model_multi_input_freq.summary()

In [129]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [130]:
%%time

training_info_multiple_inputs_freq = cnn_model_multi_input_freq.fit(X_train_dict_freq, y_train_raw_downsampled, validation_split=0.2, epochs=20, batch_size=128, callbacks=[early_stop])

Epoch 1/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 268ms/step - accuracy: 0.6013 - loss: 0.6463 - val_accuracy: 0.6485 - val_loss: 1.0037
Epoch 2/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 333ms/step - accuracy: 0.7826 - loss: 0.5185 - val_accuracy: 0.7010 - val_loss: 0.6207
Epoch 3/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 336ms/step - accuracy: 0.6853 - loss: 0.5871 - val_accuracy: 0.6570 - val_loss: 0.6333
Epoch 4/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 306ms/step - accuracy: 0.6376 - loss: 0.6239 - val_accuracy: 0.8289 - val_loss: 0.4649
Epoch 5/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 336ms/step - accuracy: 0.7844 - loss: 0.4919 - val_accuracy: 0.7763 - val_loss: 0.4501
Epoch 6/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 303ms/step - accuracy: 0.8606 - loss: 0.3475 - val_accuracy: 0.8770 - val_loss: 0.3017
Epoch 7/20
[1m91/91[

In [132]:
def plot_training_history(training_info):
  fig, axs = plt.subplots(1, 2, figsize=(16, 5))
  axs[0].plot(training_info.history['loss'], label="training set")
  axs[0].plot(training_info.history['val_loss'], label="validation set")
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("Loss")
  axs[0].grid(True)
  axs[0].legend()
  try:
    axs[1].plot(training_info.history['accuracy'], label="training set")
    axs[1].plot(training_info.history['val_accuracy'], label="validation set")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].grid(True)
    axs[1].legend()
  except:
    pass

    fig.suptitle("Training History for three-input CNN model with frequency component of STFT", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

plot_training_history(training_info_multiple_inputs_freq)

In [133]:
cnn_model_multi_input_freq.evaluate(X_test_dict_freq, y_test_raw_downsampled)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 22ms/step - accuracy: 0.8895 - loss: 0.2646


[0.26418402791023254, 0.8936188817024231]

In [134]:
y_pred = cnn_model_multi_input_freq.predict(X_test_dict_freq)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_raw_downsampled, y_pred_labels))
print(classification_report(y_test_raw_downsampled, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 25ms/step
[[4999  155]
 [1062 5224]]
              precision    recall  f1-score   support

           0       0.82      0.97      0.89      5154
           1       0.97      0.83      0.90      6286

    accuracy                           0.89     11440
   macro avg       0.90      0.90      0.89     11440
weighted avg       0.91      0.89      0.89     11440



In [98]:
# Get predictions
y_pred = cnn_model_multi_input_freq.predict(X_test_dict_freq)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix
cm = confusion_matrix(y_test_raw_downsampled, y_pred_labels)
cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Predicted 0", "Predicted 1"])

# Classification report as a dataframe
report = classification_report(y_test_raw_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Confusion matrix plotted as heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix for Spindle Detection \nusing a three-input CNN with frequency \ncomponent of STFT")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Classification report as a table
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=report_df.round(2).values,
                 colLabels=report_df.columns,
                 rowLabels=report_df.index,
                 cellLoc='center',
                 loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
plt.title("Classification Report for Spindle Detection \nusing a three-input CNN with frequency \ncomponent of STFT", fontsize=14)
plt.tight_layout()
plt.show()


[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step


In [100]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_raw_downsampled, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic for a three-input CNN model \nwith the STFT frequency component")
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
plt.show()