In [1]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa
import os
import random

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc, precision_score, recall_score
from sklearn.utils import class_weight
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

import scipy.signal as signal
from scipy.signal import hilbert
from scipy.signal import stft

from scipy.stats import friedmanchisquare
from scipy.stats import ttest_rel, wilcoxon, shapiro

import pywt
import cv2

# Single seed shared by every random number generator seeded below.
SEED = 15

# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts, so
# setting it here presumably only affects child processes, not this kernel — confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Seed Python's `random`, NumPy's legacy global generator and TensorFlow.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Request deterministic TensorFlow / cuDNN kernels.
# NOTE(review): these variables are set after `import tensorflow`; verify that the
# installed TensorFlow version still honours them at this point.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

In [2]:
%matplotlib qt

## CNN Models

In [3]:
def build_cnn_model_downsampled(input_shape=(300,1)):
    """Build and compile the single-input multi-scale CNN + GRU binary classifier.

    The input sequence is processed by three parallel convolution blocks that
    differ only in kernel size (5, 11 and 21 samples). Their feature maps are
    concatenated along the channel axis, summarised by a GRU and passed through
    a dense layer to a single sigmoid unit.

    Parameters
    ----------
    input_shape : tuple
        Shape of one sample as ``(timesteps, channels)``; defaults to ``(300, 1)``.

    Returns
    -------
    tf.keras.Model
        Model compiled with the Adam optimizer, binary cross-entropy loss and
        accuracy as metric.
    """
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # One block per kernel size; layers are created in the order 5 -> 11 -> 21.
    branch_outputs = []
    for kernel_size in (5, 11, 21):
        # Explicit zero padding of kernel_size // 2 on each side followed by a
        # 'valid' convolution keeps the sequence length unchanged for odd kernels.
        x = tf.keras.layers.ZeroPadding1D(padding=kernel_size // 2)(input_layer)
        x = tf.keras.layers.Conv1D(filters=10, kernel_size=kernel_size, strides=1, padding='valid')(x)
        # `negative_slope` is the keyword used by the other builders in this
        # notebook; `alpha` is the deprecated spelling of the same argument.
        x = tf.keras.layers.LeakyReLU(negative_slope=0.01)(x)
        x = tf.keras.layers.MaxPooling1D(pool_size=2)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        branch_outputs.append(x)

    # All branches share the same temporal length, so they can be stacked
    # along the feature axis.
    concatenated = tf.keras.layers.Concatenate()(branch_outputs)

    # The GRU collapses the sequence into one fixed-length vector.
    gru = tf.keras.layers.GRU(64)(concatenated)
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)

    # Single sigmoid unit: one probability per input sequence.
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    model = tf.keras.models.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [4]:
def build_multi_input_cnn_model_filtered():
    """Build and compile the two-input (raw + filtered) CNN + GRU classifier.

    Two inputs of shape ``(300, 1)`` named ``'raw_input'`` and
    ``'filtered_input'`` each pass through their own multi-scale convolution
    stack and their own GRU. The two GRU vectors are concatenated and fed to a
    dense layer and a single sigmoid output unit.

    Returns
    -------
    tf.keras.Model
        Model compiled with Adam, binary cross-entropy and accuracy.
    """
    keras_layers = tf.keras.layers

    def multi_scale_features(sequence, kernel_sizes=(5, 11, 21)):
        """Run `sequence` through one conv block per kernel size and concatenate."""
        feature_maps = []
        for width in kernel_sizes:
            # pad by width // 2 so the 'valid' convolution keeps the length
            block = keras_layers.ZeroPadding1D(padding=width // 2)(sequence)
            block = keras_layers.Conv1D(filters=10, kernel_size=width, strides=1, padding='valid')(block)
            block = keras_layers.LeakyReLU(negative_slope=0.01)(block)
            block = keras_layers.MaxPooling1D(pool_size=2)(block)
            block = keras_layers.BatchNormalization()(block)
            feature_maps.append(block)
        return keras_layers.Concatenate()(feature_maps)

    # Model inputs
    input_raw = tf.keras.Input(shape=(300, 1), name='raw_input')
    input_filtered = tf.keras.Input(shape=(300, 1), name='filtered_input')

    # Convolution stacks first (raw, then filtered), then one GRU per branch
    features_raw = multi_scale_features(input_raw)
    features_filtered = multi_scale_features(input_filtered)
    summary_raw = keras_layers.GRU(64)(features_raw)
    summary_filtered = keras_layers.GRU(64)(features_filtered)

    # Fuse the fixed-length GRU vectors and classify
    fused = keras_layers.Concatenate()([summary_raw, summary_filtered])
    hidden = keras_layers.Dense(64, activation='relu')(fused)
    probability = keras_layers.Dense(1, activation='sigmoid')(hidden)

    model = tf.keras.Model(inputs=[input_raw, input_filtered], outputs=probability)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [5]:
def build_multi_input_cnn_model_freq():
    """Build and compile the three-input (raw + filtered + STFT) CNN + GRU classifier.

    Inputs are ``'raw_input'`` and ``'filtered_input'`` of shape ``(300, 1)``
    and ``'stft_input'`` of shape ``(13, 1)``. Every input gets its own
    multi-scale convolution stack and its own GRU; the three GRU vectors are
    concatenated and fed to a dense layer and a single sigmoid output unit.

    Returns
    -------
    tf.keras.Model
        Model compiled with Adam, binary cross-entropy and accuracy.
    """
    keras_layers = tf.keras.layers

    def multi_scale_features(sequence, kernel_sizes=(5, 11, 21)):
        """Run `sequence` through one conv block per kernel size and concatenate."""
        feature_maps = []
        for width in kernel_sizes:
            # pad by width // 2 so the 'valid' convolution keeps the length
            block = keras_layers.ZeroPadding1D(padding=width // 2)(sequence)
            block = keras_layers.Conv1D(filters=10, kernel_size=width, strides=1, padding='valid')(block)
            block = keras_layers.LeakyReLU(negative_slope=0.01)(block)
            block = keras_layers.MaxPooling1D(pool_size=2)(block)
            block = keras_layers.BatchNormalization()(block)
            feature_maps.append(block)
        return keras_layers.Concatenate()(feature_maps)

    # Model inputs
    input_raw = tf.keras.Input(shape=(300, 1), name='raw_input')
    input_filtered = tf.keras.Input(shape=(300, 1), name='filtered_input')
    input_stft = tf.keras.Input(shape=(13, 1), name='stft_input')
    all_inputs = [input_raw, input_filtered, input_stft]

    # All convolution stacks are created first, then one GRU per branch,
    # always in the order raw -> filtered -> stft.
    branch_features = [multi_scale_features(tensor) for tensor in all_inputs]
    branch_summaries = [keras_layers.GRU(64)(features) for features in branch_features]

    # Fuse the fixed-length GRU vectors and classify
    fused = keras_layers.Concatenate()(branch_summaries)
    hidden = keras_layers.Dense(64, activation='relu')(fused)
    probability = keras_layers.Dense(1, activation='sigmoid')(hidden)

    model = tf.keras.Model(inputs=all_inputs, outputs=probability)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

## Spindle detection function

In [6]:
def _detect_spindle_events(eeg_raw, do_filter, do_downsample, downsample_rate,
                           channel='Fz', band=(12, 16), smooth_sec=0.2,
                           percentile=75, min_duration=0.5, max_duration=3.0):
    """Shared spindle detector used by `detect_spindles_times` / `detect_spindles_peaks`.

    Steps
    -----
    1. Work on a copy of `eeg_raw` restricted to `channel`.
    2. Optionally band-pass filter to `band` (Hz) and resample to `downsample_rate`.
    3. Take the magnitude of the Hilbert analytic signal as amplitude envelope.
    4. Smooth the envelope with a moving average of `smooth_sec` seconds.
    5. Mark samples whose smoothed envelope exceeds its own `percentile`.
    6. Keep runs of consecutive marked samples lasting between
       `min_duration` and `max_duration` seconds.

    Returns
    -------
    events : list of (start_idx, end_idx, peak_idx)
        Sample indices into the processed signal. `end_idx` is the last sample
        above threshold; `peak_idx` is the largest signal value in
        ``[start_idx, end_idx)``.
    sfreq : float
        Sampling frequency of the processed signal (after optional resampling).
    """
    # `pick` replaces the legacy `pick_channels` (MNE prints a notice for the latter).
    data = eeg_raw.copy().pick([channel])

    if do_filter:
        data.filter(l_freq=band[0], h_freq=band[1])
    if do_downsample:
        data.resample(downsample_rate)

    # Read the rate only after the optional resampling so that sample indices
    # convert to seconds correctly.
    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]

    # Instantaneous amplitude of the signal.
    envelope = np.abs(hilbert(channel_data))

    # Moving average over `smooth_sec` seconds (20 samples at 100 Hz); the window
    # is at least one sample so that very low sampling rates cannot produce an
    # empty kernel. mode='same' keeps the output aligned with the input.
    sliding_window = max(1, int(smooth_sec * sfreq))
    kernel = np.ones(sliding_window) / sliding_window
    smoothed_envelope = np.convolve(envelope, kernel, mode='same')

    # Data-driven threshold: by construction roughly the top (100 - percentile) %
    # of samples are candidates; the duration criterion below does the selection.
    threshold = np.percentile(smoothed_envelope, percentile)
    above_threshold = np.flatnonzero(smoothed_envelope > threshold)
    if above_threshold.size == 0:
        return [], sfreq

    # A gap larger than one sample between consecutive indices ends a run.
    breaks = np.flatnonzero(np.diff(above_threshold) > 1)
    run_starts = above_threshold[np.concatenate(([0], breaks + 1))]
    run_ends = above_threshold[np.concatenate((breaks, [above_threshold.size - 1]))]

    events = []
    for start_idx, end_idx in zip(run_starts, run_ends):
        duration = (end_idx - start_idx) / sfreq
        if min_duration <= duration <= max_duration:
            # Peak of the signal itself (not of the envelope) within the run.
            peak_idx = start_idx + np.argmax(channel_data[start_idx:end_idx])
            events.append((start_idx, end_idx, peak_idx))
    return events, sfreq


def detect_spindles_times(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    """Detect sleep spindles on channel 'Fz' and return their time spans.

    Parameters
    ----------
    eeg_raw : mne.io.Raw-like
        Recording containing an 'Fz' channel; it is copied, never modified.
    do_filter : bool
        Band-pass filter the copy to 12-16 Hz before detection.
    do_downsample : bool
        Resample the copy to `downsample_rate` before detection.
    downsample_rate : float
        Target sampling rate in Hz, used only when `do_downsample` is True.

    Returns
    -------
    list of (float, float)
        ``(start_sec, end_sec)`` of every detected spindle, measured from the
        first sample of the data, in chronological order.
    """
    events, sfreq = _detect_spindle_events(eeg_raw, do_filter, do_downsample, downsample_rate)
    return [(start_idx / sfreq, end_idx / sfreq) for start_idx, end_idx, _ in events]


def detect_spindles_peaks(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    """Detect sleep spindles on channel 'Fz' and return their peak times.

    Takes the same parameters and runs the same detection as
    `detect_spindles_times`.

    Returns
    -------
    list of float
        Time in seconds (from the first sample of the data) of the largest
        signal value inside every detected spindle, in chronological order.
    """
    events, sfreq = _detect_spindle_events(eeg_raw, do_filter, do_downsample, downsample_rate)
    return [peak_idx / sfreq for _, _, peak_idx in events]

## Epochs function

In [7]:
def create_fixed_length_epochs(raw, duration=3.0, overlap=0.0, preload=True, reject_by_annotation=False):
    """Cut `raw` into fixed-length epochs with `mne.make_fixed_length_epochs`.

    `duration` and `overlap` are given in seconds; MNE converts them to samples
    using the sampling frequency of `raw`. All arguments are forwarded
    unchanged and the resulting epochs object is returned.
    """
    epoch_options = {
        'duration': duration,
        'overlap': overlap,
        'preload': preload,
        'reject_by_annotation': reject_by_annotation,
    }
    fixed_epochs = mne.make_fixed_length_epochs(raw, **epoch_options)
    return fixed_epochs


def label_spindle_epochs_moderate(epochs, spindle_starts, spindle_ends, epoch_length_sec=3.0):
    """Label each epoch 1 if it contains at least half of some spindle, else 0.

    Epoch ``i`` is taken to span ``[i * epoch_length_sec, (i + 1) * epoch_length_sec)``,
    i.e. the epochs are assumed to be contiguous, non-overlapping and to start
    at time zero. Only ``len(epochs)`` is read from `epochs`.

    Parameters
    ----------
    epochs : sized
        Any object supporting ``len()`` (e.g. an MNE Epochs object).
    spindle_starts, spindle_ends : iterable of float
        Start and end time of every spindle in seconds; paired element-wise.
    epoch_length_sec : float
        Length of one epoch in seconds.

    Returns
    -------
    numpy.ndarray of int, shape (len(epochs),)
        1 where the epoch overlaps some spindle by at least 50 % of that
        spindle's duration, 0 elsewhere.
    """
    n_epochs = len(epochs)
    epoch_starts = np.arange(n_epochs) * epoch_length_sec
    epoch_ends = epoch_starts + epoch_length_sec
    epoch_labels = np.zeros(n_epochs, dtype=int)

    for spindle_start, spindle_end in zip(spindle_starts, spindle_ends):
        required_overlap = 0.5 * (spindle_end - spindle_start)

        # Overlap of this spindle with every epoch at once; the value is
        # negative for epochs that do not touch the spindle, so they never
        # reach the (non-negative) required overlap.
        overlap_duration = (np.minimum(spindle_end, epoch_ends)
                            - np.maximum(spindle_start, epoch_starts))
        epoch_labels[overlap_duration >= required_overlap] = 1

    return epoch_labels

## Importing data

In [8]:
# 5-fold cross-validation: every fold has its own pre-processed train / test
# recording on disk, named train_1_raw.fif / test_1_raw.fif ... train_5_raw.fif /
# test_5_raw.fif.
split_files = dict(
    (
        f'split_{i}',
        {
            'train': fr"C:\EEG DATA\combined_sets\train_{i}_raw.fif",
            'test': fr"C:\EEG DATA\combined_sets\test_{i}_raw.fif",
        },
    )
    for i in range(1, 6)
)

# Load every fold fully into memory. A fold that cannot be loaded is reported
# and left out of `raw_splits` instead of aborting the whole cell, which helps
# when the notebook is run on a machine where the files live elsewhere.
raw_splits = {}
for split_name in split_files:
    files = split_files[split_name]
    print(f"Loading data for {split_name}...")
    try:
        train_raw = mne.io.read_raw_fif(files['train'], preload=True)
        test_raw = mne.io.read_raw_fif(files['test'], preload=True)
    except FileNotFoundError as e:
        # one of the two files does not exist
        print(f"Error: File not found for {split_name}: {e}")
    except Exception as e:
        # any other problem while reading the files
        print(f"Error loading data for {split_name}: {e}")
    else:
        raw_splits[split_name] = {'train': train_raw, 'test': test_raw}
        print(f"Loaded train and test data for {split_name}")

Loading data for split_1...
Opening raw data file C:\EEG DATA\combined_sets\train_1_raw.fif...
Isotrak not found
    Range : 90000 ... 32700095 =    180.000 ... 65400.190 secs
Ready.
Reading 0 ... 32610095  =      0.000 ... 65220.190 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_1_raw.fif...
Isotrak not found
    Range : 1470000 ... 7845026 =   2940.000 ... 15690.052 secs
Ready.
Reading 0 ... 6375026  =      0.000 ... 12750.052 secs...
Loaded train and test data for split_1
Loading data for split_2...
Opening raw data file C:\EEG DATA\combined_sets\train_2_raw.fif...
Isotrak not found
    Range : 1470000 ... 32985091 =   2940.000 ... 65970.182 secs
Ready.
Reading 0 ... 31515091  =      0.000 ... 63030.182 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_2_raw.fif...
Isotrak not found
    Range : 90000 ... 7560030 =    180.000 ... 15120.060 secs
Ready.
Reading 0 ... 7470030  =      0.000 ... 14940.060 secs...
Loaded train and test data for split_2
Loading data

## Training

In [9]:
# The three architectures under comparison, keyed by the name used in all
# reports below:
#   one_input_cnn   - raw signal only
#   two_input_cnn   - raw + band-pass filtered signal
#   three_input_cnn - raw + filtered + STFT-derived input
models_to_evaluate = {
    'one_input_cnn': build_cnn_model_downsampled,
    'two_input_cnn': build_multi_input_cnn_model_filtered,
    'three_input_cnn': build_multi_input_cnn_model_freq
}

# Per-model score store: one list per metric, one entry appended per split.
model_metrics = {
    model_name: {
        'f1_scores': [],
        'precision_scores': [],
        'recall_scores': []
    }
    for model_name in models_to_evaluate
}

# Train and evaluate every model on every cross-validation split.


def _stft_profile(epoch_signals, fs, nperseg=50):
    """Return the per-epoch, min-max normalised STFT magnitude profile.

    `epoch_signals` has shape (n_epochs, n_samples). The STFT uses windows of
    `nperseg` samples with 50 % overlap. The result has shape
    (n_epochs, n_values, 1) with every epoch scaled to [0, 1).
    """
    _, _, zxx = stft(epoch_signals, fs=fs, nperseg=nperseg, noverlap=nperseg // 2)
    magnitude = np.abs(zxx)  # (n_epochs, n_frequencies, n_time_frames)

    # NOTE(review): axis=1 is the frequency axis, so this mean collapses the
    # frequencies and keeps one value per STFT time frame (13 frames for a
    # 300-sample epoch, matching the (13, 1) 'stft_input' of the model). If a
    # frequency profile was intended, average over axis=2 instead and change
    # the model input shape accordingly.
    profile = magnitude.mean(axis=1)

    lowest = profile.min(axis=1, keepdims=True)
    highest = profile.max(axis=1, keepdims=True)
    normalised = (profile - lowest) / (highest - lowest + 1e-8)
    return normalised[..., np.newaxis]  # trailing channel axis for Conv1D


def _input_shapes(model_input):
    """Shape of a single array input, or a name -> shape dict for multi-input models."""
    if isinstance(model_input, dict):
        return {k: v.shape for k, v in model_input.items()}
    return model_input.shape


for split_name, raw_data in raw_splits.items():
    print(f"\n--- Processing Split: {split_name} ---")
    train_raw = raw_data['train']
    test_raw = raw_data['test']

    # ------------------------------------------------------------------
    # Spindle detection (filter 12-16 Hz, then downsample) -> label source
    # ------------------------------------------------------------------
    spindles_train_times_raw = detect_spindles_times(train_raw, do_filter=True, do_downsample=True)
    spindles_test_times_raw = detect_spindles_times(test_raw, do_filter=True, do_downsample=True)
    spindles_starts_train_raw, spindles_ends_train_raw = zip(*spindles_train_times_raw) if spindles_train_times_raw else ([], [])
    spindles_starts_test_raw, spindles_ends_test_raw = zip(*spindles_test_times_raw) if spindles_test_times_raw else ([], [])

    # Raw signal downsampled to 100 Hz (model input 'raw_input')
    train_raw_downsampled = train_raw.copy().resample(100)
    test_raw_downsampled = test_raw.copy().resample(100)

    # Band-pass filtered, then downsampled signal (model input 'filtered_input')
    train_filtered = train_raw.copy().filter(l_freq=12, h_freq=16).resample(100)
    test_filtered = test_raw.copy().filter(l_freq=12, h_freq=16).resample(100)

    # Spindles detected on the already filtered / downsampled recordings.
    # NOTE(review): these are not used for labelling anywhere in this cell
    # (all models are labelled from the *_raw detections); kept in case later
    # cells rely on them.
    spindles_train_times_filtered = detect_spindles_times(train_filtered, do_filter=False, do_downsample=False)
    spindles_test_times_filtered = detect_spindles_times(test_filtered, do_filter=False, do_downsample=False)
    spindles_starts_train_filtered, spindles_ends_train_filtered = zip(*spindles_train_times_filtered) if spindles_train_times_filtered else ([], [])
    spindles_starts_test_filtered, spindles_ends_test_filtered = zip(*spindles_test_times_filtered) if spindles_test_times_filtered else ([], [])

    # ------------------------------------------------------------------
    # 3-second epochs and model inputs
    # ------------------------------------------------------------------
    epochs_train_raw_downsampled = create_fixed_length_epochs(train_raw_downsampled)
    epochs_test_raw_downsampled = create_fixed_length_epochs(test_raw_downsampled)
    epochs_train_filtered_downsampled = create_fixed_length_epochs(train_filtered)
    epochs_test_filtered_downsampled = create_fixed_length_epochs(test_filtered)

    # (n_epochs, n_samples, 1) arrays; each epochs object is materialised once.
    X_train_raw = np.array(epochs_train_raw_downsampled).reshape(len(epochs_train_raw_downsampled), -1, 1)
    X_test_raw = np.array(epochs_test_raw_downsampled).reshape(len(epochs_test_raw_downsampled), -1, 1)
    X_train_filtered = np.array(epochs_train_filtered_downsampled).reshape(len(epochs_train_filtered_downsampled), -1, 1)
    X_test_filtered = np.array(epochs_test_filtered_downsampled).reshape(len(epochs_test_filtered_downsampled), -1, 1)

    # STFT-derived input for the three-input model, computed from the raw
    # downsampled epochs.
    fs = train_raw_downsampled.info['sfreq']
    X_train_stft_freq_norm = _stft_profile(X_train_raw[:, :, 0], fs)
    X_test_stft_freq_norm = _stft_profile(X_test_raw[:, :, 0], fs)

    # Labels are identical for all three models (raw-epoch grid + spindles
    # detected on the raw recording), so compute them once per split.
    y_train = label_spindle_epochs_moderate(epochs_train_raw_downsampled, spindles_starts_train_raw, spindles_ends_train_raw)
    y_test = label_spindle_epochs_moderate(epochs_test_raw_downsampled, spindles_starts_test_raw, spindles_ends_test_raw)

    # Every input must provide exactly one row per label.
    assert len(X_train_raw) == len(X_train_filtered) == len(X_train_stft_freq_norm) == len(y_train)
    assert len(X_test_raw) == len(X_test_filtered) == len(X_test_stft_freq_norm) == len(y_test)

    # (train input, test input) per model; dict keys match the Keras Input names.
    inputs_by_model = {
        'one_input_cnn': (X_train_raw, X_test_raw),
        'two_input_cnn': (
            {'raw_input': X_train_raw, 'filtered_input': X_train_filtered},
            {'raw_input': X_test_raw, 'filtered_input': X_test_filtered},
        ),
        'three_input_cnn': (
            {'raw_input': X_train_raw, 'filtered_input': X_train_filtered, 'stft_input': X_train_stft_freq_norm},
            {'raw_input': X_test_raw, 'filtered_input': X_test_filtered, 'stft_input': X_test_stft_freq_norm},
        ),
    }

    # ------------------------------------------------------------------
    # Train and score every model on this split
    # ------------------------------------------------------------------
    for model_name, build_model_func in models_to_evaluate.items():
        print(f"\n--- Evaluating Model: {model_name} on {split_name} ---")

        X_train_input, X_test_input = inputs_by_model[model_name]
        print(f"Training data shapes: {_input_shapes(X_train_input)}, labels={y_train.shape}")
        print(f"Test data shapes: {_input_shapes(X_test_input)}, labels={y_test.shape}")

        print("Building and compiling model...")
        if model_name == 'one_input_cnn':
            # only the single-input builder takes the input shape as argument
            input_shape = (X_train_input.shape[1], X_train_input.shape[2])
            model = build_model_func(input_shape)
        else:
            model = build_model_func()

        # Stop when the validation loss has not improved for 5 epochs and roll
        # back to the best weights seen.
        early_stop = EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        )

        print("Training the model...")
        # validation_split holds out 20 % of the training samples, which makes
        # overfitting visible during training
        history = model.fit(
            X_train_input,
            y_train,
            validation_split=0.2,
            epochs=20,
            batch_size=128,
            callbacks=[early_stop],
        )
        print("Training finished.")

        # Performance on the unseen test recording of this split
        print(f"Evaluating on {split_name}'s test data...")
        loss, accuracy = model.evaluate(X_test_input, y_test, verbose=0)
        print(f"Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        # Threshold the sigmoid output at 0.5 and flatten (n, 1) -> (n,) so the
        # predictions have the same shape as y_test.
        y_pred_proba = model.predict(X_test_input, verbose=0)
        y_pred_labels = (y_pred_proba > 0.5).astype(int).ravel()

        split_f1 = f1_score(y_test, y_pred_labels)
        split_precision = precision_score(y_test, y_pred_labels)
        split_recall = recall_score(y_test, y_pred_labels)
        print(f"F1 Score for {model_name} on {split_name}: {split_f1:.4f}")
        print(f"Precision for {model_name} on {split_name}: {split_precision:.4f}")
        print(f"Recall for {model_name} on {split_name}: {split_recall:.4f}")

        model_metrics[model_name]['f1_scores'].append(split_f1)
        model_metrics[model_name]['precision_scores'].append(split_precision)
        model_metrics[model_name]['recall_scores'].append(split_recall)

        # release the Keras graph state before the next model is built
        tf.keras.backend.clear_session()

print("\n--- Evaluation finished for all models across all splits ---")


--- Processing Split: split_1 ---
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 96 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 27 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) metho



[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 112ms/step - accuracy: 0.5630 - loss: 0.6829 - val_accuracy: 0.5584 - val_loss: 0.6818
Epoch 2/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 110ms/step - accuracy: 0.6277 - loss: 0.6515 - val_accuracy: 0.6789 - val_loss: 0.6132
Epoch 3/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 110ms/step - accuracy: 0.6702 - loss: 0.6035 - val_accuracy: 0.6391 - val_loss: 0.6110
Epoch 4/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 110ms/step - accuracy: 0.6152 - loss: 0.6469 - val_accuracy: 0.5345 - val_loss: 0.6915
Epoch 5/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 103ms/step - accuracy: 0.5911 - loss: 0.6700 - val_accuracy: 0.5964 - val_loss: 0.6575
Epoch 6/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 106ms/step - accuracy: 0.6353 - loss: 0.6422 - val_accuracy: 0.6095 - val_loss: 0.6597
Epoch 7/20
[1m136/13





--- Evaluating Model: two_input_cnn on split_1 ---
Training data shapes: {'raw_input': (21740, 300, 1), 'filtered_input': (21740, 300, 1)}, labels=(21740,)
Test data shapes: {'raw_input': (4250, 300, 1), 'filtered_input': (4250, 300, 1)}, labels=(4250,)
Building and compiling model...
Training the model...
Epoch 1/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 237ms/step - accuracy: 0.6207 - loss: 0.6527 - val_accuracy: 0.4632 - val_loss: 0.7164
Epoch 2/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 185ms/step - accuracy: 0.6154 - loss: 0.6052 - val_accuracy: 0.7447 - val_loss: 0.4998
Epoch 3/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 187ms/step - accuracy: 0.7984 - loss: 0.4361 - val_accuracy: 0.8735 - val_loss: 0.2876
Epoch 4/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 226ms/step - accuracy: 0.8717 - loss: 0.2918 - val_accuracy: 0.8873 - val_loss: 0.2606
Epoch 5/20
[1m136/136[0m [32



Training the model...
Epoch 1/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 118ms/step - accuracy: 0.5596 - loss: 0.6873 - val_accuracy: 0.5607 - val_loss: 0.6844
Epoch 2/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 113ms/step - accuracy: 0.5829 - loss: 0.6730 - val_accuracy: 0.7030 - val_loss: 0.6006
Epoch 3/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 112ms/step - accuracy: 0.6259 - loss: 0.6480 - val_accuracy: 0.7120 - val_loss: 0.5794
Epoch 4/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 114ms/step - accuracy: 0.6498 - loss: 0.5904 - val_accuracy: 0.5471 - val_loss: 0.6917
Epoch 5/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 118ms/step - accuracy: 0.5683 - loss: 0.6796 - val_accuracy: 0.5897 - val_loss: 0.6695
Epoch 6/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 114ms/step - accuracy: 0.6259 - loss: 0.6450 - val_accuracy: 0.6287 - val_l



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 113ms/step - accuracy: 0.5562 - loss: 0.6853 - val_accuracy: 0.5883 - val_loss: 0.6744
Epoch 2/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 110ms/step - accuracy: 0.5959 - loss: 0.6674 - val_accuracy: 0.6534 - val_loss: 0.6306
Epoch 3/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 116ms/step - accuracy: 0.6432 - loss: 0.6220 - val_accuracy: 0.7755 - val_loss: 0.5001
Epoch 4/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 113ms/step - accuracy: 0.6467 - loss: 0.6105 - val_accuracy: 0.6998 - val_loss: 0.5966
Epoch 5/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 108ms/step - accuracy: 0.7229 - loss: 0.5489 - val_accuracy: 0.8058 - val_loss: 0.4151
Epoch 6/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 111ms/step - accuracy: 0.7949 - loss: 0.4317 - val_accuracy: 0.8365 - val_loss: 0.3687
Epoch 7/20
[1m130/13



Training the model...
Epoch 1/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 117ms/step - accuracy: 0.5512 - loss: 0.6875 - val_accuracy: 0.6033 - val_loss: 0.6721
Epoch 2/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 107ms/step - accuracy: 0.5900 - loss: 0.6682 - val_accuracy: 0.5874 - val_loss: 0.6667
Epoch 3/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 109ms/step - accuracy: 0.6054 - loss: 0.6501 - val_accuracy: 0.6885 - val_loss: 0.6026
Epoch 4/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 107ms/step - accuracy: 0.6758 - loss: 0.6060 - val_accuracy: 0.7400 - val_loss: 0.5348
Epoch 5/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 109ms/step - accuracy: 0.7089 - loss: 0.5722 - val_accuracy: 0.7667 - val_loss: 0.4989
Epoch 6/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 106ms/step - accuracy: 0.7698 - loss: 0.4803 - val_accuracy: 0.7801 - val_l



Training the model...
Epoch 1/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 172ms/step - accuracy: 0.5499 - loss: 0.6862 - val_accuracy: 0.6260 - val_loss: 0.6614
Epoch 2/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 167ms/step - accuracy: 0.6053 - loss: 0.6662 - val_accuracy: 0.7209 - val_loss: 0.5866
Epoch 3/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 159ms/step - accuracy: 0.6178 - loss: 0.6469 - val_accuracy: 0.6879 - val_loss: 0.5652
Epoch 4/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 165ms/step - accuracy: 0.6357 - loss: 0.6329 - val_accuracy: 0.7520 - val_loss: 0.5102
Epoch 5/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 154ms/step - accuracy: 0.6788 - loss: 0.6001 - val_accuracy: 0.7122 - val_loss: 0.5639
Epoch 6/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 150ms/step - accuracy: 0.7221 - loss: 0.5468 - val_accuracy: 0.7751 - val_l

## Display average metrics and statistics

In [10]:
# Summarise F1, precision and recall per model as "mean ± std" strings
# computed over that model's per-split scores.

# table column header -> key of the per-split score list in model_metrics[model]
metric_columns = {
    "F1 Score (mean ± std)": "f1_scores",
    "Precision (mean ± std)": "precision_scores",
    "Recall (mean ± std)": "recall_scores",
}

rows = []
for model_name, metrics in model_metrics.items():
    # one dict per model; key order defines the column order of the table
    row = {"Model": model_name}
    for column, key in metric_columns.items():
        scores = metrics[key]
        # np.std uses its default ddof=0 (population standard deviation)
        row[column] = f"{np.mean(scores):.4f} ± {np.std(scores):.4f}"
    rows.append(row)

# build the data frame once from the collected rows
summary_df = pd.DataFrame(rows)

# title followed by the plain-text table (index column hidden)
print("\n--- Average Metrics Across Splits For Spindle Detection ---\n")
print(summary_df.to_string(index=False))


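# statistical comparison of the models' per-split F1 scores

# helper that runs one pairwise significance test (model B vs model A)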
def compare_models(f1_a, f1_b, model_a_name, model_b_name, alpha=0.05, n_comparisons=1,
                   show_plot=True):
    """Test whether two models' paired F1 scores differ significantly.

    The paired differences ``f1_b - f1_a`` are checked for normality with the
    Shapiro-Wilk test. If normality is not rejected a paired t-test is used,
    otherwise the Wilcoxon signed-rank test. The resulting p-value is compared
    against a Bonferroni-corrected threshold ``alpha / n_comparisons``.

    Parameters
    ----------
    f1_a, f1_b : 1-D sequences of float
        F1 scores of model A and model B. They must have the same length and
        be paired element by element (same split at the same position).
    model_a_name, model_b_name : str
        Names used in the printed report and the plot title.
    alpha : float, default 0.05
        Nominal significance level. It is used as-is for the normality check
        and divided by ``n_comparisons`` for the model comparison itself.
    n_comparisons : int, default 1
        Number of pairwise comparisons in the family (Bonferroni divisor).
    show_plot : bool, default True
        Draw a histogram of the paired differences. Set to False to run the
        statistics without opening a figure.

    Returns
    -------
    dict
        ``test`` ("paired t-test" or "wilcoxon signed-rank"), ``statistic``,
        ``p_value``, ``p_normality``, ``corrected_alpha`` and ``significant``.

    Raises
    ------
    ValueError
        If the inputs are not 1-D arrays of equal length, contain fewer than
        three pairs (Shapiro-Wilk needs at least three), or if
        ``n_comparisons`` is smaller than 1.

    Notes
    -----
    With very few pairs the exact Wilcoxon test cannot produce small p-values
    (for 5 pairs the smallest two-sided p-value is 2 / 2**5 = 0.0625), so the
    non-parametric branch may be unable to reach a corrected threshold.
    """
    scores_a = np.asarray(f1_a, dtype=float)
    scores_b = np.asarray(f1_b, dtype=float)

    # Reject unpaired input explicitly: a length-1 array would otherwise be
    # broadcast silently against the other model's scores.
    if scores_a.ndim != 1 or scores_a.shape != scores_b.shape:
        raise ValueError(
            f"f1_a and f1_b must be 1-D and of equal length, "
            f"got shapes {scores_a.shape} and {scores_b.shape}"
        )
    if scores_a.size < 3:
        raise ValueError("at least 3 paired scores are required for the normality test")
    if n_comparisons < 1:
        raise ValueError(f"n_comparisons must be >= 1, got {n_comparisons}")

    # first alpha is adjusted for multiple comparisons
    corrected_alpha = alpha / n_comparisons
    print(f"\nBonferroni corrected alpha: {corrected_alpha:.4f} (original alpha={alpha} / {n_comparisons} comparisons)")

    # paired differences (model B minus model A)
    diff = scores_b - scores_a

    # plot distribution to assess normality visually
    if show_plot:
        fig, ax = plt.subplots(figsize=(6, 4))
        sns.histplot(diff, kde=True, bins=8, ax=ax)
        ax.set_title(f'Distribution of F1 Score Differences: {model_b_name} - {model_a_name}')
        ax.set_xlabel('F1 Score Difference')
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    # Shapiro-Wilk test decides between the parametric and non-parametric test
    normality_stat, p_norm = shapiro(diff)
    print(f"Shapiro-Wilk normality test: W = {normality_stat:.4f}, p = {p_norm:.4f}")

    print(f"\n--- Statistical Comparison: {model_b_name} vs {model_a_name} ---")
    # The normality check is an assumption check, not one of the
    # n_comparisons hypotheses, so it is judged at the nominal alpha.
    # Using the Bonferroni-corrected alpha here would make the check laxer
    # the more comparisons are run.
    if p_norm > alpha:
        test_name = "paired t-test"
        print("Paired t-test (normal distribution)")
        statistic, p_val = ttest_rel(scores_b, scores_a)
        print(f"t-statistic = {statistic:.4f}, p-value = {p_val:.4f}")
    else:
        test_name = "wilcoxon signed-rank"
        print("Wilcoxon signed-rank test (non-normal distribution)")
        statistic, p_val = wilcoxon(scores_b, scores_a)
        print(f"W-statistic = {statistic:.4f}, p-value = {p_val:.4f}")

    # then compare with the Bonferroni corrected alpha
    significant = bool(p_val < corrected_alpha)
    if significant:
        print(f"Significant difference at corrected alpha = {corrected_alpha:.4f}")
    else:
        print(f"No significant difference at corrected alpha = {corrected_alpha:.4f}")

    return {
        "test": test_name,
        "statistic": float(statistic),
        "p_value": float(p_val),
        "p_normality": float(p_norm),
        "corrected_alpha": corrected_alpha,
        "significant": significant,
    }

# Run the three pairwise comparisons on the per-split F1 scores.
# Each tuple is (model A, model B); differences are reported as B - A.
comparison_pairs = [
    ('one_input_cnn', 'two_input_cnn'),
    ('one_input_cnn', 'three_input_cnn'),
    ('two_input_cnn', 'three_input_cnn'),
]

for name_a, name_b in comparison_pairs:
    compare_models(
        f1_a=model_metrics[name_a]['f1_scores'],
        f1_b=model_metrics[name_b]['f1_scores'],
        model_a_name=name_a,
        model_b_name=name_b,
        alpha=0.05,
        # Bonferroni divisor: the size of this comparison family (3)
        n_comparisons=len(comparison_pairs),
    )




--- Average Metrics Across Splits For Spindle Detection ---

          Model F1 Score (mean ± std) Precision (mean ± std) Recall (mean ± std)
  one_input_cnn       0.8247 ± 0.0394        0.8220 ± 0.1198     0.8524 ± 0.0807
  two_input_cnn       0.8932 ± 0.0512        0.8883 ± 0.1120     0.9137 ± 0.0570
three_input_cnn       0.8948 ± 0.0570        0.8841 ± 0.1178     0.9208 ± 0.0430

Bonferroni corrected alpha: 0.0167 (original alpha=0.05 / 3 comparisons)
Shapiro-Wilk normality test: W = 0.8203, p = 0.1174

--- Statistical Comparison: two_input_cnn vs one_input_cnn ---
Paired t-test (normal distribution)
t-statistic = 4.9079, p-value = 0.0080
Significant difference at corrected alpha = 0.0167

Bonferroni corrected alpha: 0.0167 (original alpha=0.05 / 3 comparisons)
Shapiro-Wilk normality test: W = 0.9434, p = 0.6904

--- Statistical Comparison: three_input_cnn vs one_input_cnn ---
Paired t-test (normal distribution)
t-statistic = 4.8558, p-value = 0.0083
Significant difference at corre

In [11]:
# Nicer visualisation of the F1 results: a compact table plus a bar chart.

# Per-model mean and standard deviation (np.std default, ddof=0) of the
# per-split F1 scores. Full precision is kept so the bars and error bars are
# drawn from the actual values; rounding is applied only where numbers are
# shown as text.
model_names = list(model_metrics.keys())
average_f1s = [np.mean(model_metrics[name]['f1_scores']) for name in model_names]
std_f1s = [np.std(model_metrics[name]['f1_scores']) for name in model_names]

summary_data = {
    'Model': model_names,
    'Mean F1 Score': average_f1s,
    'F1 Score Std Dev': std_f1s
}
summary_df = pd.DataFrame(summary_data)

# NOTE(review): the "5" in the two titles below is hard-coded; update it if
# the number of splits changes.
print("\n--- Summary of F1 Scores Across 5 Folds for Spindle Detection ---")
# round for display only (non-numeric columns are left untouched)
display(summary_df.round(2))

# bar plot with the mean F1 score of each model, std as error bars
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(summary_df['Model'], summary_df['Mean F1 Score'],
              yerr=summary_df['F1 Score Std Dev'], capsize=5,
              color=['lightcoral', 'lightgreen', 'skyblue'])

# Print the mean on top of each bar. The label is placed above the upper
# error-bar whisker so that it does not overlap the error-bar line, which runs
# through the horizontal centre of the bar.
for bar, std in zip(bars, summary_df['F1 Score Std Dev']):
    yval = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2.0, yval + std + 0.01, f'{yval:.2f}',
            va='bottom', ha='center', fontsize=10)

ax.set_ylabel("Mean F1 Score")
ax.set_title("Comparison of Model Performance on Spindle Detection for 5 participants")
ax.set_ylim(0, 1.05)
ax.grid(axis='y', linestyle='--', alpha=0.6)
ax.tick_params(axis='x', labelrotation=15)
fig.tight_layout()
plt.show()



--- Summary of F1 Scores Across 5 Folds for Spindle Detection ---


Unnamed: 0,Model,Mean F1 Score,F1 Score Std Dev
0,one_input_cnn,0.82,0.04
1,two_input_cnn,0.89,0.05
2,three_input_cnn,0.89,0.06
