In [1]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa
import os
import random

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc, precision_score, recall_score
from sklearn.utils import class_weight
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

import scipy.signal as signal
from scipy.signal import hilbert
from scipy.signal import stft

from scipy.stats import friedmanchisquare
from scipy.stats import ttest_rel, wilcoxon, shapiro

import pywt
import cv2

SEED = 15

# NOTE: assigning PYTHONHASHSEED via os.environ only affects Python processes
# started after this line (e.g. subprocesses). The running kernel's str/bytes
# hash randomisation was fixed at interpreter start-up; export the variable
# before launching Jupyter if that matters.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Seed the global RNGs of Python's `random`, NumPy and TensorFlow.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Ask TensorFlow / cuDNN for deterministic kernels.
# NOTE(review): these are set after `import tensorflow`; confirm the installed
# TF version still honours them at this point (newer TF documents
# tf.config.experimental.enable_op_determinism() for this purpose).
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

In [2]:
# Show matplotlib figures in separate interactive Qt windows (a Qt binding must
# be installed); `%matplotlib inline` would embed static figures instead.
%matplotlib qt

## CNN + GRU models

In [3]:
def build_cnn_model_downsampled(input_shape=(300,1)):
    """Build and compile a single-input multi-scale CNN + GRU binary classifier.

    Three parallel convolutional blocks look at the same input with kernel
    sizes 5, 11 and 21 (three temporal scales). Each block is
    ZeroPadding1D(k // 2) -> Conv1D(10 filters, kernel k, stride 1, 'valid')
    -> LeakyReLU(0.01) -> MaxPooling1D(2) -> BatchNormalization.
    The k // 2 padding with a 'valid' convolution preserves the time axis, so
    the three block outputs can be concatenated along the channel axis. The
    concatenated sequence is summarised by GRU(64), then Dense(64, relu) and a
    single sigmoid output unit.

    Parameters
    ----------
    input_shape : tuple of int, default (300, 1)
        Shape of one example: (time steps, channels).

    Returns
    -------
    tf.keras.Model
        Compiled with the Adam optimizer, binary cross-entropy loss and an
        accuracy metric.
    """
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Blocks are built one after another (pad, conv, activation, pool, norm),
    # in the same order as three hand-written blocks would be.
    conv_blocks = []
    for kernel_size in (5, 11, 21):
        x = tf.keras.layers.ZeroPadding1D(padding=kernel_size // 2)(input_layer)
        x = tf.keras.layers.Conv1D(filters=10, kernel_size=kernel_size, strides=1, padding='valid')(x)
        # `negative_slope` is the current Keras name for the leak coefficient
        # (`alpha` is deprecated); build_multi_input_cnn_model_filtered uses it too.
        x = tf.keras.layers.LeakyReLU(negative_slope=0.01)(x)
        x = tf.keras.layers.MaxPooling1D(pool_size=2)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        conv_blocks.append(x)

    concatenated = tf.keras.layers.Concatenate()(conv_blocks)
    gru = tf.keras.layers.GRU(64)(concatenated)
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    model = tf.keras.models.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [4]:
def build_multi_input_cnn_model_filtered():
    """Build and compile a three-input multi-scale CNN + GRU binary classifier.

    Inputs (each of shape (300, 1)), in order:
    'raw_input', 'filtered_so_input', 'filtered_spindles_input'.

    Every input gets its own multi-scale convolutional branch (kernel sizes
    5, 11, 21) and its own GRU(64). The three GRU summaries are concatenated
    and classified by Dense(64, relu) -> Dense(1, sigmoid).

    Returns
    -------
    tf.keras.Model
        Compiled with the Adam optimizer, binary cross-entropy loss and an
        accuracy metric.
    """
    def multi_scale_branch(tensor, kernel_sizes=(5, 11, 21)):
        # One Conv1D block per kernel size. Zero padding of kernel // 2 followed
        # by a 'valid' convolution keeps the time axis, so the per-kernel
        # feature maps can be stacked along the channel axis.
        feature_maps = []
        for kernel in kernel_sizes:
            h = tf.keras.layers.ZeroPadding1D(padding=kernel // 2)(tensor)
            h = tf.keras.layers.Conv1D(filters=10, kernel_size=kernel, strides=1, padding='valid')(h)
            h = tf.keras.layers.LeakyReLU(negative_slope=0.01)(h)
            h = tf.keras.layers.MaxPooling1D(pool_size=2)(h)
            h = tf.keras.layers.BatchNormalization()(h)
            feature_maps.append(h)
        return tf.keras.layers.Concatenate()(feature_maps)

    input_names = ('raw_input', 'filtered_so_input', 'filtered_spindles_input')
    model_inputs = [tf.keras.Input(shape=(300, 1), name=input_name) for input_name in input_names]

    # All branches first, then all GRUs: layers are created in a fixed order,
    # which is what Keras' auto-generated layer names depend on.
    branches = [multi_scale_branch(tensor) for tensor in model_inputs]
    summaries = [tf.keras.layers.GRU(64)(branch) for branch in branches]

    merged = tf.keras.layers.Concatenate()(summaries)
    hidden = tf.keras.layers.Dense(64, activation='relu')(merged)
    prediction = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)

    model = tf.keras.Model(inputs=model_inputs, outputs=prediction)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [5]:
def build_multi_input_cnn_model_freq(n_stft_bins=13):
    """Build and compile a four-input multi-scale CNN + GRU binary classifier.

    Inputs, in order:
      - 'raw_input', 'filtered_so_input', 'filtered_spindles_input': (300, 1)
      - 'stft_input': (n_stft_bins, 1)

    Each input goes through its own multi-scale convolutional branch (kernel
    sizes 5, 11, 21) and its own GRU(64). The four GRU summaries are
    concatenated and classified by Dense(64, relu) -> Dense(1, sigmoid).

    Parameters
    ----------
    n_stft_bins : int, default 13
        Length of the STFT feature vector. The default keeps the (13, 1) input
        shape used so far.

    Returns
    -------
    tf.keras.Model
        Compiled with the Adam optimizer, binary cross-entropy loss and an
        accuracy metric.
    """
    input_raw = tf.keras.Input(shape=(300, 1), name='raw_input')
    input_filtered_so = tf.keras.Input(shape=(300, 1), name='filtered_so_input')
    input_filtered_spindles = tf.keras.Input(shape=(300, 1), name='filtered_spindles_input')
    input_stft = tf.keras.Input(shape=(n_stft_bins, 1), name='stft_input')

    def conv_branch(input_layer, kernel_sizes=(5, 11, 21)):
        # One Conv1D block per kernel size. k // 2 zero padding + a 'valid'
        # convolution keeps the sequence length, so outputs can be concatenated.
        outputs = []
        for k in kernel_sizes:
            x = tf.keras.layers.ZeroPadding1D(padding=k // 2)(input_layer)
            x = tf.keras.layers.Conv1D(filters=10, kernel_size=k, strides=1, padding='valid')(x)
            # `negative_slope` is the current Keras name (also used by
            # build_multi_input_cnn_model_filtered); `alpha` is deprecated.
            x = tf.keras.layers.LeakyReLU(negative_slope=0.01)(x)
            x = tf.keras.layers.MaxPooling1D(pool_size=2)(x)
            x = tf.keras.layers.BatchNormalization()(x)
            outputs.append(x)
        return tf.keras.layers.Concatenate()(outputs)

    # Convolutional branches
    branch_raw = conv_branch(input_raw)
    branch_filtered_so = conv_branch(input_filtered_so)
    branch_filtered_spindles = conv_branch(input_filtered_spindles)
    # NOTE(review): the STFT branch reuses the time-domain kernel sizes; with
    # 13 bins, a large part of each 21-wide window is zero padding. Confirm
    # this is intended.
    branch_stft = conv_branch(input_stft)

    # Each branch through its own GRU
    gru_raw = tf.keras.layers.GRU(64)(branch_raw)
    gru_filtered_so = tf.keras.layers.GRU(64)(branch_filtered_so)
    gru_filtered_spindles = tf.keras.layers.GRU(64)(branch_filtered_spindles)
    gru_stft = tf.keras.layers.GRU(64)(branch_stft)

    # Concatenate the fixed-length GRU outputs and classify
    merged = tf.keras.layers.Concatenate()([gru_raw, gru_filtered_so, gru_filtered_spindles, gru_stft])
    x = tf.keras.layers.Dense(64, activation='relu')(merged)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=[input_raw, input_filtered_so, input_filtered_spindles, input_stft], outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

## Spindle detection functions

In [6]:
def detect_spindles_times(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    """Detect sleep spindles on channel 'Fz' and return their start/end times.

    Steps:
      1. Copy the recording, keep only 'Fz' and optionally band-pass 12-16 Hz.
      2. Optionally resample to ``downsample_rate`` Hz.
      3. Amplitude envelope = absolute value of the Hilbert analytic signal.
      4. Smooth the envelope with a 0.2 s moving average.
      5. Threshold at the 75th percentile of the smoothed envelope.
      6. Each run of consecutive above-threshold samples whose duration
         (last index - first index) / sfreq lies in [0.5, 3] s is a spindle.

    Parameters
    ----------
    eeg_raw : mne.io.Raw
        Recording with an 'Fz' channel. It is copied, not modified.
    do_filter : bool
        Apply the 12-16 Hz band-pass filter before detection.
    do_downsample : bool
        Resample to ``downsample_rate`` before detection.
    downsample_rate : float
        Target sampling rate (Hz) used when ``do_downsample`` is True.

    Returns
    -------
    list of tuple(float, float)
        (start_s, end_s) of every detected spindle, in seconds from the start
        of the analysed data, in chronological order. ``end_s`` is the time of
        the run's last above-threshold sample.
    """
    data = eeg_raw.copy().pick_channels(['Fz'])
    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    if do_downsample:
        data.resample(downsample_rate)

    # Read sfreq after the optional resampling: every index -> seconds
    # conversion below must use the rate of the data actually analysed.
    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]

    envelope = np.abs(hilbert(channel_data))

    # 0.2 s moving average (e.g. 20 samples at 100 Hz). At least one sample, so
    # a very low sampling rate cannot produce an empty (divide-by-zero) kernel.
    window = max(1, int(0.2 * sfreq))
    smoothed_envelope = np.convolve(envelope, np.ones(window) / window, mode='same')

    threshold = np.percentile(smoothed_envelope, 75)
    above = np.flatnonzero(smoothed_envelope > threshold)
    if above.size == 0:
        return []

    # Split the above-threshold indices into runs of consecutive samples:
    # a jump > 1 between neighbouring indices closes one run and opens the next.
    breaks = np.flatnonzero(np.diff(above) > 1)
    run_starts = np.concatenate(([above[0]], above[breaks + 1]))
    run_ends = np.concatenate((above[breaks], [above[-1]]))

    spindles = []
    for start_idx, end_idx in zip(run_starts, run_ends):
        duration = (end_idx - start_idx) / sfreq
        # only keep spindles lasting 0.5 to 3 seconds
        if 0.5 <= duration <= 3:
            spindles.append((start_idx / sfreq, end_idx / sfreq))
    return spindles
    

def detect_spindles_peaks(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    """Detect sleep spindles on channel 'Fz' and return the time of each peak.

    Detection is the same as in ``detect_spindles_times``:
      1. Copy, keep 'Fz', optionally band-pass 12-16 Hz, optionally resample.
      2. Envelope = |Hilbert analytic signal|, smoothed with a 0.2 s moving average.
      3. Threshold at the 75th percentile of the smoothed envelope.
      4. Runs of consecutive above-threshold samples lasting 0.5-3 s
         ((last index - first index) / sfreq) are spindles.
    The peak of a spindle is the sample with the largest value of the analysed
    channel (band-passed when ``do_filter`` is True) within its run.

    Parameters
    ----------
    eeg_raw : mne.io.Raw
        Recording with an 'Fz' channel. It is copied, not modified.
    do_filter : bool
        Apply the 12-16 Hz band-pass filter before detection.
    do_downsample : bool
        Resample to ``downsample_rate`` before detection.
    downsample_rate : float
        Target sampling rate (Hz) used when ``do_downsample`` is True.

    Returns
    -------
    list of float
        Peak time (seconds from the start of the analysed data) of every
        detected spindle, in chronological order.
    """
    data = eeg_raw.copy().pick_channels(['Fz'])
    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    if do_downsample:
        data.resample(downsample_rate)

    # Read sfreq after the optional resampling so index -> seconds conversions
    # use the rate of the data actually analysed.
    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]

    envelope = np.abs(hilbert(channel_data))

    # 0.2 s moving average; at least one sample to avoid an empty kernel.
    window = max(1, int(0.2 * sfreq))
    smoothed_envelope = np.convolve(envelope, np.ones(window) / window, mode='same')

    threshold = np.percentile(smoothed_envelope, 75)
    above = np.flatnonzero(smoothed_envelope > threshold)
    if above.size == 0:
        return []

    # Runs of consecutive above-threshold indices (a jump > 1 starts a new run).
    breaks = np.flatnonzero(np.diff(above) > 1)
    run_starts = np.concatenate(([above[0]], above[breaks + 1]))
    run_ends = np.concatenate((above[breaks], [above[-1]]))

    peaks = []
    for start_idx, end_idx in zip(run_starts, run_ends):
        duration = (end_idx - start_idx) / sfreq
        # only keep spindles lasting 0.5 to 3 seconds
        if 0.5 <= duration <= 3:
            # end_idx is the run's last above-threshold sample, so it is
            # included in the peak search (hence end_idx + 1).
            peak_idx = start_idx + np.argmax(channel_data[start_idx:end_idx + 1])
            peaks.append(peak_idx / sfreq)
    return peaks

## Slow oscillation detection functions

In [7]:
def detect_slow_oscillations_times(combined_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    """Detect slow oscillations (SOs) on channel 'Fz' and return their start/end times.

    Follows the approach of Klinzing et al. (2016):
      1. Copy the recording, keep 'Fz', optionally band-pass 0.16-1.25 Hz and
         optionally resample to ``downsample_rate`` Hz.
      2. Find positive-to-negative zero crossings (the sign decreases between
         two consecutive samples).
      3. Every interval between two consecutive such crossings that lasts
         0.8-2.0 s is a candidate. Record its most negative value and its
         peak-to-peak amplitude (max - min).
      4. Keep candidates whose peak-to-peak amplitude is >= the 75th
         percentile AND whose negative peak is <= the 25th percentile, both
         computed over all candidates.

    Parameters
    ----------
    combined_raw : mne.io.Raw
        Recording with an 'Fz' channel. It is copied, not modified.
    do_filter : bool
        Apply the 0.16-1.25 Hz band-pass filter before detection.
    do_downsample : bool
        Resample to ``downsample_rate`` before detection.
    downsample_rate : float
        Target sampling rate (Hz) used when ``do_downsample`` is True.

    Returns
    -------
    list of tuple(float, float)
        (start_s, end_s) of every detected SO, in seconds from the start of the
        analysed data, in chronological order. Empty when there is no
        candidate interval at all.
    """
    data = combined_raw.copy().pick_channels(['Fz'])
    if do_filter:
        data.filter(l_freq=0.16, h_freq=1.25)
    if do_downsample:
        data.resample(downsample_rate)

    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]

    # np.diff(np.sign(x))[i] < 0 whenever the sign drops from sample i to i + 1:
    # -2 for +1 -> -1, and -1 for transitions through an exact zero.
    zero_crossings = np.flatnonzero(np.diff(np.sign(channel_data)) < 0)
    # Each interval starts at the first sample after a crossing.
    boundaries = zero_crossings + 1

    # (start_idx, end_idx, negative_peak, peak_to_peak) per candidate interval
    candidates = []
    for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
        if 0.8 <= (end_idx - start_idx) / sfreq <= 2.0:
            segment = channel_data[start_idx:end_idx]
            negative_peak = np.min(segment)
            candidates.append((start_idx, end_idx, negative_peak, np.max(segment) - negative_peak))

    if not candidates:
        # np.percentile cannot be taken over an empty list; no candidates means no SOs.
        return []

    # keep the most negative troughs (<= 25th percentile) and the largest
    # peak-to-peak amplitudes (>= 75th percentile)
    negative_peak_threshold = np.percentile([c[2] for c in candidates], 25)
    peak_to_peak_amplitude_threshold = np.percentile([c[3] for c in candidates], 75)

    return [
        (start_idx / sfreq, end_idx / sfreq)
        for start_idx, end_idx, negative_peak, peak_to_peak_amplitude in candidates
        if peak_to_peak_amplitude >= peak_to_peak_amplitude_threshold
        and negative_peak <= negative_peak_threshold
    ]

def detect_slow_oscillations_peaks(combined_raw, do_filter=True, do_downsample=True, downsample_rate=100):
    """Detect slow oscillations on 'Fz'; returns their (start_s, end_s) TIMES.

    NOTE(review): despite its name, this function returns the same
    (start_s, end_s) time tuples as ``detect_slow_oscillations_times``. The
    detection algorithm is identical (Klinzing et al., 2016), so it delegates
    to that function. The only difference is the default
    ``do_downsample=True``. Callers such as
    ``detect_slow_oscillations_spindles_coupling_so_times`` rely on this
    return format, so it is kept. Negative/positive peak amplitudes are not
    returned.

    Parameters
    ----------
    combined_raw : mne.io.Raw
        Recording with an 'Fz' channel. It is copied, not modified.
    do_filter : bool
        Apply the 0.16-1.25 Hz band-pass filter before detection.
    do_downsample : bool
        Resample to ``downsample_rate`` before detection.
    downsample_rate : float
        Target sampling rate (Hz) used when ``do_downsample`` is True.

    Returns
    -------
    list of tuple(float, float)
        (start_s, end_s) of every detected slow oscillation, in seconds.
    """
    return detect_slow_oscillations_times(
        combined_raw,
        do_filter=do_filter,
        do_downsample=do_downsample,
        downsample_rate=downsample_rate,
    )

## Slow oscillation–spindle coupling detection

In [8]:
def detect_slow_oscillations_spindles_coupling_so_times(combined_raw, do_filter=True, do_downsample=True, downsample_rate=100):
    """Return the time windows of slow oscillations (SOs) that contain a spindle peak.

    SOs come from ``detect_slow_oscillations_times`` and spindle peak times
    from ``detect_spindles_peaks``. Both are run with the same filtering and
    resampling options. An SO (start_s, end_s) counts as coupled with a
    spindle when the spindle's peak time p satisfies start_s < p < end_s.

    Parameters
    ----------
    combined_raw : mne.io.Raw
        Recording with an 'Fz' channel (not modified).
    do_filter, do_downsample, downsample_rate
        Forwarded unchanged to both detectors.

    Returns
    -------
    list of tuple(float, float)
        One (start_s, end_s) entry per coupled (SO, spindle) pair, ordered by
        SO and then by spindle. An SO containing k spindle peaks therefore
        appears k times.
    """
    # ``detect_slow_oscillations_peaks`` returns the same (start, end) tuples
    # as ``detect_slow_oscillations_times``, so a single SO detection pass is
    # enough (it avoids filtering/resampling the recording twice).
    slow_oscillations_times = detect_slow_oscillations_times(combined_raw, do_filter=do_filter, do_downsample=do_downsample, downsample_rate=downsample_rate)
    spindles_peaks = detect_spindles_peaks(combined_raw, do_filter=do_filter, do_downsample=do_downsample, downsample_rate=downsample_rate)

    # NOTE(review): the intent may have been to require the spindle peak to lie
    # between the SO's negative and positive peaks. Peak *times* are not
    # produced by the detectors, so the whole SO window is used.
    return [
        (start_time, end_time)
        for start_time, end_time in slow_oscillations_times
        for peak in spindles_peaks
        if start_time < peak < end_time
    ]

## Epoching and epoch-labelling functions

In [9]:
def create_fixed_length_epochs(raw, duration=3.0, overlap=0.0, preload=True, reject_by_annotation=False):
    """Cut ``raw`` into fixed-length epochs via ``mne.make_fixed_length_epochs``.

    Every argument is forwarded unchanged. ``duration`` and ``overlap`` are in
    seconds; MNE converts them to samples using the recording's sampling
    frequency.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous recording to segment.
    duration : float, default 3.0
        Epoch length in seconds.
    overlap : float, default 0.0
        Overlap between consecutive epochs in seconds.
    preload : bool, default True
        Load epoch data into memory.
    reject_by_annotation : bool, default False
        Whether MNE drops epochs overlapping 'bad' annotations.

    Returns
    -------
    mne.Epochs
    """
    epoch_options = dict(
        duration=duration,
        overlap=overlap,
        preload=preload,
        reject_by_annotation=reject_by_annotation,
    )
    return mne.make_fixed_length_epochs(raw, **epoch_options)


def label_coupling_epochs_strict(epochs, coupling_starts, coupling_ends, epoch_length_sec=3.0, overlap_fraction=0.8):
    """Label each fixed-length epoch 1 if it contains enough of a coupling event, else 0.

    Epoch ``i`` is taken to span ``[i * epoch_length_sec, (i + 1) * epoch_length_sec]``,
    i.e. consecutive, non-overlapping epochs starting at time 0 with none dropped.
    NOTE(review): this matches epochs made with overlap=0 and reject_by_annotation=False;
    the coupling times must be on the same time axis as those epochs — confirm.

    Parameters
    ----------
    epochs : sized
        Only ``len(epochs)`` is used.
    coupling_starts, coupling_ends : sequences of float
        Start and end time of each coupling event, paired by position.
    epoch_length_sec : float
        Length of one epoch, in the same unit as the coupling times.
    overlap_fraction : float
        Fraction of a coupling event's duration that must fall inside an epoch for
        that epoch to be labelled 1 (default 0.8, i.e. 80% of the event).

    Returns
    -------
    np.ndarray of int, shape (len(epochs),)
        1 for epochs containing at least one sufficiently-overlapping event, else 0.
    """
    n_epochs = len(epochs)
    epoch_starts = np.arange(n_epochs) * epoch_length_sec
    epoch_ends = epoch_starts + epoch_length_sec
    epoch_labels = np.zeros(n_epochs, dtype=int)

    for coupling_start, coupling_end in zip(coupling_starts, coupling_ends):
        required_overlap = overlap_fraction * (coupling_end - coupling_start)
        # Overlap of this event with every epoch at once; negative means disjoint.
        overlap_duration = np.minimum(coupling_end, epoch_ends) - np.maximum(coupling_start, epoch_starts)
        epoch_labels[overlap_duration >= required_overlap] = 1

    return epoch_labels

## Importing data

In [10]:
# 5-fold cross-validation: load the pre-processed recordings
# train_1_raw/test_1_raw ... train_5_raw/test_5_raw for every split.
# NOTE(review): absolute Windows path — change DATA_DIR when running on another machine.
DATA_DIR = r"C:\EEG DATA\combined_sets"
N_SPLITS = 5

split_files = {
    f'split_{i}': {
        'train': os.path.join(DATA_DIR, f"train_{i}_raw.fif"),
        'test': os.path.join(DATA_DIR, f"test_{i}_raw.fif"),
    }
    for i in range(1, N_SPLITS + 1)
}

raw_splits = {}
failed_splits = []  # splits that could not be loaded and are therefore skipped later
for split_name, files in split_files.items():
    print(f"Loading data for {split_name}...")
    try:
        train_raw = mne.io.read_raw_fif(files['train'], preload=True)
        test_raw = mne.io.read_raw_fif(files['test'], preload=True)
    except FileNotFoundError as e:
        # most likely cause on another machine: DATA_DIR points to the wrong place
        print(f"Error: File not found for {split_name}: {e}")
        failed_splits.append(split_name)
        continue
    except Exception as e:
        print(f"Error loading data for {split_name}: {e}")
        failed_splits.append(split_name)
        continue
    raw_splits[split_name] = {'train': train_raw, 'test': test_raw}
    print(f"Loaded train and test data for {split_name}")

# A split that failed is simply absent from raw_splits, so the cross-validation
# below would silently run on fewer folds than intended; make that visible.
if failed_splits:
    print(f"WARNING: {len(failed_splits)} of {len(split_files)} splits failed to load "
          f"and will be skipped: {failed_splits}")

Loading data for split_1...
Opening raw data file C:\EEG DATA\combined_sets\train_1_raw.fif...
Isotrak not found
    Range : 90000 ... 32700095 =    180.000 ... 65400.190 secs
Ready.
Reading 0 ... 32610095  =      0.000 ... 65220.190 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_1_raw.fif...
Isotrak not found
    Range : 1470000 ... 7845026 =   2940.000 ... 15690.052 secs
Ready.
Reading 0 ... 6375026  =      0.000 ... 12750.052 secs...
Loaded train and test data for split_1
Loading data for split_2...
Opening raw data file C:\EEG DATA\combined_sets\train_2_raw.fif...
Isotrak not found
    Range : 1470000 ... 32985091 =   2940.000 ... 65970.182 secs
Ready.
Reading 0 ... 31515091  =      0.000 ... 63030.182 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_2_raw.fif...
Isotrak not found
    Range : 90000 ... 7560030 =    180.000 ... 15120.060 secs
Ready.
Reading 0 ... 7470030  =      0.000 ... 14940.060 secs...
Loaded train and test data for split_2
Loading data

## Training

In [11]:
# Evaluate three CNN variants on every cross-validation split and store the
# test-set F1 / precision / recall of each (model, split) pair:
#   model 1 'raw'                        raw EEG downsampled to 100 Hz
#   model 2 'raw_and_filtered'           raw + SO band (0.16-1.25 Hz) + spindle band (12-16 Hz)
#   model 3 'raw_and_filtered_and_stft'  model 2 inputs + an STFT-derived profile
# All three models share the same labels, derived from SO-spindle coupling detected
# on the raw recording of each split.


def _stack_epochs(epochs):
    """Return the epochs' data as an (n_epochs, n_values, 1) array for the 1-D CNNs."""
    return np.array(epochs).reshape(len(epochs), -1, 1)


def _minmax_per_epoch(x, eps=1e-8):
    """Min-max scale every epoch (slice along axis 0) to [0, 1] independently.

    eps keeps the division finite for a constant epoch.
    """
    flat = x.reshape(len(x), -1)
    broadcast_shape = (-1,) + (1,) * (x.ndim - 1)
    lows = flat.min(axis=1).reshape(broadcast_shape)
    highs = flat.max(axis=1).reshape(broadcast_shape)
    return (x - lows) / (highs - lows + eps)


# per-model lists of per-split scores
model_metrics = {
    name: {'f1_scores': [], 'precision_scores': [], 'recall_scores': []}
    for name in ('raw', 'raw_and_filtered', 'raw_and_filtered_and_stft')
}

models_to_evaluate = {
    'raw': build_cnn_model_downsampled,
    'raw_and_filtered': build_multi_input_cnn_model_filtered,
    'raw_and_filtered_and_stft': build_multi_input_cnn_model_freq
}

for split_name, raw_data in raw_splits.items():
    # progress marker, useful on long runs
    print(f"\n--- Processing Split: {split_name} ---")
    train_raw = raw_data['train']
    test_raw = raw_data['test']

    # --- Coupling detection on the raw recording (labels for all models) -------
    # The detector filters and downsamples internally.
    coupling_train_times_raw_downsampled = detect_slow_oscillations_spindles_coupling_so_times(train_raw, do_filter=True, do_downsample=True)
    coupling_test_times_raw_downsampled = detect_slow_oscillations_spindles_coupling_so_times(test_raw, do_filter=True, do_downsample=True)

    coupling_starts_train_raw_downsampled, coupling_ends_train_raw_downsampled = zip(*coupling_train_times_raw_downsampled) if coupling_train_times_raw_downsampled else ([], [])
    coupling_starts_test_raw_downsampled, coupling_ends_test_raw_downsampled = zip(*coupling_test_times_raw_downsampled) if coupling_test_times_raw_downsampled else ([], [])

    # --- Signals: raw, SO band and spindle band, all resampled to 100 Hz --------
    # Band-pass filtering happens at the original rate, before resampling.
    train_raw_downsampled = train_raw.copy().resample(100)
    test_raw_downsampled = test_raw.copy().resample(100)

    train_filtered_downsampled_so = train_raw.copy().filter(l_freq=0.16, h_freq=1.25).resample(100)
    test_filtered_downsampled_so = test_raw.copy().filter(l_freq=0.16, h_freq=1.25).resample(100)

    train_filtered_downsampled_spindles = train_raw.copy().filter(l_freq=12, h_freq=16).resample(100)
    test_filtered_downsampled_spindles = test_raw.copy().filter(l_freq=12, h_freq=16).resample(100)

    # --- Fixed-length 3 s epochs --------------------------------------------------
    epochs_train_raw_downsampled = create_fixed_length_epochs(train_raw_downsampled, duration=3.0, overlap=0.0)
    epochs_test_raw_downsampled = create_fixed_length_epochs(test_raw_downsampled, duration=3.0, overlap=0.0)

    epochs_train_filtered_downsampled_so = create_fixed_length_epochs(train_filtered_downsampled_so)
    epochs_test_filtered_downsampled_so = create_fixed_length_epochs(test_filtered_downsampled_so)

    epochs_train_filtered_downsampled_spindles = create_fixed_length_epochs(train_filtered_downsampled_spindles)
    epochs_test_filtered_downsampled_spindles = create_fixed_length_epochs(test_filtered_downsampled_spindles)

    # --- Labels ---------------------------------------------------------------------
    # They depend only on the raw epochs and the coupling windows, so they are
    # computed once per split and shared by all three models.
    y_train = label_coupling_epochs_strict(epochs_train_raw_downsampled, coupling_starts_train_raw_downsampled, coupling_ends_train_raw_downsampled)
    y_test = label_coupling_epochs_strict(epochs_test_raw_downsampled, coupling_starts_test_raw_downsampled, coupling_ends_test_raw_downsampled)

    # --- STFT input for model 3 (from the raw downsampled epochs) ------------------
    fs = train_raw_downsampled.info['sfreq']
    nperseg = 50
    noverlap = nperseg // 2

    epochs_train_stft_downsampled = np.squeeze(np.array(epochs_train_raw_downsampled))
    epochs_test_stft_downsampled = np.squeeze(np.array(epochs_test_raw_downsampled))

    # scipy.signal.stft transforms along the last axis, so all epochs go through in
    # one call; for single-channel epochs Zxx is (n_epochs, n_freqs, n_segments).
    f, t, Zxx_train = stft(epochs_train_stft_downsampled, fs=fs, nperseg=nperseg, noverlap=noverlap)
    _, _, Zxx_test = stft(epochs_test_stft_downsampled, fs=fs, nperseg=nperseg, noverlap=noverlap)
    X_train_stft_transformed = np.abs(Zxx_train)
    X_test_stft_transformed = np.abs(Zxx_test)

    # NOTE(review): for single-channel epochs axis=1 is the FREQUENCY axis, so this
    # averages over frequencies and keeps the time-segment profile, whereas the intent
    # stated earlier was to "keep the frequency dimension". Averaging over axis=-1
    # would keep frequencies — only change it after checking the 'stft_input' shape
    # expected by build_multi_input_cnn_model_freq.
    X_train_stft_freq = np.mean(X_train_stft_transformed, axis=1)[..., np.newaxis]  # add channel dim for the CNN
    X_test_stft_freq = np.mean(X_test_stft_transformed, axis=1)[..., np.newaxis]

    X_train_stft_freq_norm = _minmax_per_epoch(X_train_stft_freq)
    X_test_stft_freq_norm = _minmax_per_epoch(X_test_stft_freq)

    # --- Time-domain inputs ---------------------------------------------------------
    X_train_raw = _stack_epochs(epochs_train_raw_downsampled)
    X_test_raw = _stack_epochs(epochs_test_raw_downsampled)

    X_train_filtered_so = _stack_epochs(epochs_train_filtered_downsampled_so)
    X_test_filtered_so = _stack_epochs(epochs_test_filtered_downsampled_so)

    X_train_filtered_spindles = _stack_epochs(epochs_train_filtered_downsampled_spindles)
    X_test_filtered_spindles = _stack_epochs(epochs_test_filtered_downsampled_spindles)

    filtered_train_inputs = {
        'raw_input': X_train_raw,
        'filtered_so_input': X_train_filtered_so,
        'filtered_spindles_input': X_train_filtered_spindles,
    }
    filtered_test_inputs = {
        'raw_input': X_test_raw,
        'filtered_so_input': X_test_filtered_so,
        'filtered_spindles_input': X_test_filtered_spindles,
    }
    # model name -> (train inputs, test inputs). Looking inputs up by key makes a model
    # added to models_to_evaluate without inputs fail loudly (KeyError) instead of
    # silently reusing the previous model's arrays.
    inputs_per_model = {
        'raw': (X_train_raw, X_test_raw),
        'raw_and_filtered': (filtered_train_inputs, filtered_test_inputs),
        'raw_and_filtered_and_stft': (
            {**filtered_train_inputs, 'stft_input': X_train_stft_freq_norm},
            {**filtered_test_inputs, 'stft_input': X_test_stft_freq_norm},
        ),
    }

    for model_name, build_model_func in models_to_evaluate.items():
        print(f"\n--- Evaluating Model: {model_name} on {split_name} ---")
        X_train_input, X_test_input = inputs_per_model[model_name]

        # multi-input models receive a dict of arrays, the raw model a single array
        print(f"Training data shapes: { {k: v.shape for k, v in X_train_input.items()} if isinstance(X_train_input, dict) else X_train_input.shape}, labels={y_train.shape}")
        print(f"Test data shapes: { {k: v.shape for k, v in X_test_input.items()} if isinstance(X_test_input, dict) else X_test_input.shape}, labels={y_test.shape}")

        print("Building and compiling model...")
        # Only the single-input model gets an explicit (n_samples, n_channels) shape;
        # the multi-input builders are called without arguments (presumably relying
        # on their default input shapes — confirm in their definitions).
        input_shape = (X_train_input.shape[1], X_train_input.shape[2]) if model_name == 'raw' else None
        model = build_model_func(input_shape) if input_shape is not None else build_model_func()

        # stop when validation loss has not improved for 5 epochs; keep the best weights
        early_stop = EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        )

        print("Training the model...")
        # NOTE(review): Keras takes validation_split from the LAST 20% of the samples
        # (before shuffling), i.e. the end of the recording rather than a random subset.
        # NOTE(review): class_weight is imported but unused, and the logged training
        # accuracy stays constant across epochs (e.g. 0.9363 on split_1), which suggests
        # the models may be predicting the majority class only — consider class weights.
        history = model.fit(
            X_train_input,
            y_train,
            validation_split=0.2,
            epochs=20,
            batch_size=128,
            callbacks=[early_stop],
        )
        print("Training finished.")

        # evaluate on the split's held-out test data (compare with training/validation
        # accuracy and loss to spot overfitting)
        print(f"Evaluating on {split_name}'s test data...")
        loss, accuracy = model.evaluate(X_test_input, y_test, verbose=0)
        print(f"Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        # threshold the predicted probabilities at 0.5 and flatten (n, 1) -> (n,)
        y_pred_proba = model.predict(X_test_input, verbose=0)
        y_pred_labels = (y_pred_proba > 0.5).astype(int).ravel()

        # zero_division=0: a fold without predicted (or true) positives scores 0, the
        # same value sklearn returns by default, but without UndefinedMetricWarning.
        split_f1 = f1_score(y_test, y_pred_labels, zero_division=0)
        split_precision = precision_score(y_test, y_pred_labels, zero_division=0)
        split_recall = recall_score(y_test, y_pred_labels, zero_division=0)
        print(f"F1 Score for {model_name} on {split_name}: {split_f1:.4f}")
        print(f"Precision for {model_name} on {split_name}: {split_precision:.4f}")
        print(f"Recall for {model_name} on {split_name}: {split_recall:.4f}")

        model_metrics[model_name]['f1_scores'].append(split_f1)
        model_metrics[model_name]['precision_scores'].append(split_precision)
        model_metrics[model_name]['recall_scores'].append(split_recall)

        # free TensorFlow graph/memory before the next model
        tf.keras.backend.clear_session()

print("\n--- Evaluation finished for all models across all splits ---")


--- Processing Split: split_1 ---
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 96 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.16
- Lower transition bandwidth: 0.16 Hz (-6 dB cutoff frequency: 0.08 Hz)
- Upper passband edge: 1.25 Hz
- Upper transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 2.25 Hz)
- Filter length: 10313 samples (20.626 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 96 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) 



Training the model...
Epoch 1/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 123ms/step - accuracy: 0.8728 - loss: 0.3426 - val_accuracy: 0.9556 - val_loss: 0.1835
Epoch 2/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 110ms/step - accuracy: 0.9363 - loss: 0.2305 - val_accuracy: 0.9556 - val_loss: 0.1757
Epoch 3/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 114ms/step - accuracy: 0.9363 - loss: 0.2042 - val_accuracy: 0.9556 - val_loss: 0.1558
Epoch 4/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 115ms/step - accuracy: 0.9363 - loss: 0.1767 - val_accuracy: 0.9556 - val_loss: 0.1481
Epoch 5/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 102ms/step - accuracy: 0.9363 - loss: 0.1664 - val_accuracy: 0.9556 - val_loss: 0.1453
Epoch 6/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 101ms/step - accuracy: 0.9361 - loss: 0.1576 - val_accuracy: 0.9556 - val_l





--- Evaluating Model: raw_and_filtered on split_1 ---
Training data shapes: {'raw_input': (21740, 300, 1), 'filtered_so_input': (21740, 300, 1), 'filtered_spindles_input': (21740, 300, 1)}, labels=(21740,)
Test data shapes: {'raw_input': (4250, 300, 1), 'filtered_so_input': (4250, 300, 1), 'filtered_spindles_input': (4250, 300, 1)}, labels=(4250,)
Building and compiling model...
Training the model...
Epoch 1/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 521ms/step - accuracy: 0.8955 - loss: 0.3059 - val_accuracy: 0.9554 - val_loss: 0.1899
Epoch 2/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 468ms/step - accuracy: 0.9363 - loss: 0.2190 - val_accuracy: 0.9556 - val_loss: 0.1808
Epoch 3/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 474ms/step - accuracy: 0.9363 - loss: 0.2263 - val_accuracy: 0.9556 - val_loss: 0.1778
Epoch 4/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 478ms/step - accuracy:



Training the model...
Epoch 1/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 607ms/step - accuracy: 0.9062 - loss: 0.2702 - val_accuracy: 0.9547 - val_loss: 0.1899
Epoch 2/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 529ms/step - accuracy: 0.9372 - loss: 0.1815 - val_accuracy: 0.9545 - val_loss: 0.1566
Epoch 3/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 508ms/step - accuracy: 0.9368 - loss: 0.1657 - val_accuracy: 0.9552 - val_loss: 0.1226
Epoch 4/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 511ms/step - accuracy: 0.9372 - loss: 0.1586 - val_accuracy: 0.9554 - val_loss: 0.1121
Epoch 5/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 515ms/step - accuracy: 0.9375 - loss: 0.1482 - val_accuracy: 0.9554 - val_loss: 0.1179
Epoch 6/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 526ms/step - accuracy: 0.9371 - loss: 0.1603 - val_accuracy: 0.9547 - val_l



Epoch 1/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 164ms/step - accuracy: 0.9092 - loss: 0.3141 - val_accuracy: 0.9507 - val_loss: 0.2004
Epoch 2/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 149ms/step - accuracy: 0.9427 - loss: 0.2114 - val_accuracy: 0.9507 - val_loss: 0.1814
Epoch 3/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 151ms/step - accuracy: 0.9427 - loss: 0.1920 - val_accuracy: 0.9507 - val_loss: 0.1734
Epoch 4/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 146ms/step - accuracy: 0.9427 - loss: 0.1682 - val_accuracy: 0.9507 - val_loss: 0.1650
Epoch 5/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 123ms/step - accuracy: 0.9427 - loss: 0.1541 - val_accuracy: 0.9507 - val_loss: 0.1644
Epoch 6/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 120ms/step - accuracy: 0.9427 - loss: 0.1499 - val_accuracy: 0.9507 - val_loss: 0.1617
Epoch 7/20



Training the model...
Epoch 1/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m87s[0m 569ms/step - accuracy: 0.9264 - loss: 0.2529 - val_accuracy: 0.9510 - val_loss: 0.2143
Epoch 2/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 554ms/step - accuracy: 0.9430 - loss: 0.1749 - val_accuracy: 0.9507 - val_loss: 0.1796
Epoch 3/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 568ms/step - accuracy: 0.9435 - loss: 0.1600 - val_accuracy: 0.9507 - val_loss: 0.1514
Epoch 4/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 564ms/step - accuracy: 0.9432 - loss: 0.1563 - val_accuracy: 0.9507 - val_loss: 0.1389
Epoch 5/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 570ms/step - accuracy: 0.9432 - loss: 0.1451 - val_accuracy: 0.9510 - val_loss: 0.1218
Epoch 6/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 573ms/step - accuracy: 0.9440 - loss: 0.1331 - val_accuracy: 0.9522 - val_l



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 159ms/step - accuracy: 0.9351 - loss: 0.2819 - val_accuracy: 0.9504 - val_loss: 0.2096
Epoch 2/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 143ms/step - accuracy: 0.9465 - loss: 0.2040 - val_accuracy: 0.9504 - val_loss: 0.1956
Epoch 3/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 145ms/step - accuracy: 0.9465 - loss: 0.1868 - val_accuracy: 0.9504 - val_loss: 0.1788
Epoch 4/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 135ms/step - accuracy: 0.9465 - loss: 0.1566 - val_accuracy: 0.9504 - val_loss: 0.1580
Epoch 5/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 140ms/step - accuracy: 0.9465 - loss: 0.1431 - val_accuracy: 0.9504 - val_loss: 0.1527
Epoch 6/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 131ms/step - accuracy: 0.9464 - loss: 0.1367 - val_accuracy: 0.9507 - val_loss: 0.1438
Epoch 7/20
[1m130/13



Training the model...
Epoch 1/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 607ms/step - accuracy: 0.9328 - loss: 0.2452 - val_accuracy: 0.9504 - val_loss: 0.2259
Epoch 2/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 625ms/step - accuracy: 0.9470 - loss: 0.1698 - val_accuracy: 0.9504 - val_loss: 0.1823
Epoch 3/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m84s[0m 650ms/step - accuracy: 0.9472 - loss: 0.1572 - val_accuracy: 0.9511 - val_loss: 0.1525
Epoch 4/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 588ms/step - accuracy: 0.9469 - loss: 0.1480 - val_accuracy: 0.9504 - val_loss: 0.1432
Epoch 5/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 599ms/step - accuracy: 0.9482 - loss: 0.1407 - val_accuracy: 0.9502 - val_loss: 0.1272
Epoch 6/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 607ms/step - accuracy: 0.9479 - loss: 0.1302 - val_accuracy: 0.9502 - val_l



[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 177ms/step - accuracy: 0.9366 - loss: 0.2925 - val_accuracy: 0.9537 - val_loss: 0.2014
Epoch 2/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 148ms/step - accuracy: 0.9427 - loss: 0.2134 - val_accuracy: 0.9537 - val_loss: 0.1936
Epoch 3/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 150ms/step - accuracy: 0.9427 - loss: 0.1929 - val_accuracy: 0.9537 - val_loss: 0.1793
Epoch 4/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 144ms/step - accuracy: 0.9427 - loss: 0.1956 - val_accuracy: 0.9537 - val_loss: 0.1704
Epoch 5/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 141ms/step - accuracy: 0.9427 - loss: 0.1761 - val_accuracy: 0.9537 - val_loss: 0.1601
Epoch 6/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 168ms/step - accuracy: 0.9427 - loss: 0.1648 - val_accuracy: 0.9537 - val_loss: 0.1644
Epoch 7/20
[1m128/12



Training the model...
Epoch 1/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 700ms/step - accuracy: 0.9145 - loss: 0.2627 - val_accuracy: 0.9488 - val_loss: 0.2291
Epoch 2/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 589ms/step - accuracy: 0.9428 - loss: 0.1724 - val_accuracy: 0.9520 - val_loss: 0.1768
Epoch 3/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 586ms/step - accuracy: 0.9437 - loss: 0.1599 - val_accuracy: 0.9493 - val_loss: 0.1463
Epoch 4/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 602ms/step - accuracy: 0.9437 - loss: 0.1477 - val_accuracy: 0.9493 - val_loss: 0.1299
Epoch 5/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 605ms/step - accuracy: 0.9445 - loss: 0.1405 - val_accuracy: 0.9503 - val_loss: 0.1248
Epoch 6/20
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 601ms/step - accuracy: 0.9469 - loss: 0.1330 - val_accuracy: 0.9493 - val_



Training the model...
Epoch 1/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 177ms/step - accuracy: 0.9361 - loss: 0.2852 - val_accuracy: 0.9637 - val_loss: 0.1650
Epoch 2/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 172ms/step - accuracy: 0.9391 - loss: 0.2187 - val_accuracy: 0.9637 - val_loss: 0.1458
Epoch 3/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 160ms/step - accuracy: 0.9391 - loss: 0.1995 - val_accuracy: 0.9637 - val_loss: 0.1271
Epoch 4/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 159ms/step - accuracy: 0.9391 - loss: 0.1786 - val_accuracy: 0.9637 - val_loss: 0.1021
Epoch 5/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 152ms/step - accuracy: 0.9391 - loss: 0.1591 - val_accuracy: 0.9637 - val_loss: 0.0994
Epoch 6/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 153ms/step - accuracy: 0.9391 - loss: 0.1526 - val_accuracy: 0.9637 - val_l



Training the model...
Epoch 1/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 758ms/step - accuracy: 0.9272 - loss: 0.2649 - val_accuracy: 0.9635 - val_loss: 0.1564
Epoch 2/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 646ms/step - accuracy: 0.9390 - loss: 0.1773 - val_accuracy: 0.9642 - val_loss: 0.1161
Epoch 3/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m84s[0m 663ms/step - accuracy: 0.9405 - loss: 0.1663 - val_accuracy: 0.9635 - val_loss: 0.1075
Epoch 4/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 637ms/step - accuracy: 0.9408 - loss: 0.1563 - val_accuracy: 0.9617 - val_loss: 0.1043
Epoch 5/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 643ms/step - accuracy: 0.9417 - loss: 0.1481 - val_accuracy: 0.9573 - val_loss: 0.1057
Epoch 6/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 626ms/step - accuracy: 0.9441 - loss: 0.1408 - val_accuracy: 0.9595 - val_

## Display average metrics and statistics

In [12]:
# Summarise each model's per-split scores as "mean ± std" strings in one table.
# np.std is used with its default ddof=0 (population standard deviation).


def _format_mean_std(scores):
    """Format a list of per-split scores as 'mean ± std' with 4 decimals."""
    return f"{np.mean(scores):.4f} ± {np.std(scores):.4f}"


rows = []
for model_name, metrics in model_metrics.items():
    row = {"Model": model_name}
    row["F1 Score (mean ± std)"] = _format_mean_std(metrics['f1_scores'])
    row["Precision (mean ± std)"] = _format_mean_std(metrics['precision_scores'])
    row["Recall (mean ± std)"] = _format_mean_std(metrics['recall_scores'])
    rows.append(row)

summary_df = pd.DataFrame(rows)

print("\n--- Average Metrics Across Splits For Coupling Detection ---\n")
print(summary_df.to_string(index=False))


# add statistics

# definition
def compare_models(f1_a, f1_b, model_a_name, model_b_name, alpha=0.05, n_comparisons=1):
    # first alpha is adjusted for multiple comparisons
    corrected_alpha = alpha / n_comparisons
    print(f"\nBonferroni corrected alpha: {corrected_alpha:.4f} (original alpha={alpha} / {n_comparisons} comparisons)")

    # differences computed
    diff = np.array(f1_b) - np.array(f1_a)

    # plot distribution to assess normality
    plt.figure(figsize=(6, 4))
    sns.histplot(diff, kde=True, bins=8)
    plt.title(f'Distribution of F1 Score Differences: {model_b_name} - {model_a_name}')
    plt.xlabel('F1 Score Difference')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # shapiro wilk test to assess normality
    # determines which type of t-test to do
    w_stat, p_norm = shapiro(diff)
    print(f"Shapiro-Wilk normality test: W = {w_stat:.4f}, p = {p_norm:.4f}")

    # then perform t-test (paired t-test or wilcoxon signed-rank test)
    print(f"\n--- Statistical Comparison: {model_b_name} vs {model_a_name} ---")
    if p_norm > corrected_alpha:
        print("Paired t-test (normal distribution)")
        t_stat, p_val = ttest_rel(f1_b, f1_a)
        print(f"t-statistic = {t_stat:.4f}, p-value = {p_val:.4f}")
    else:
        print("Wilcoxon signed-rank test (non-normal distribution)")
        w_stat, p_val = wilcoxon(f1_b, f1_a)
        print(f"W-statistic = {w_stat:.4f}, p-value = {p_val:.4f}")

    # then compare with the Bonferroni corrected alpha
    if p_val < corrected_alpha:
        print(f"Significant difference at corrected alpha = {corrected_alpha:.4f}")
    else:
        print(f"No significant difference at corrected alpha = {corrected_alpha:.4f}")

# then do the three model comparisons

compare_models(
    f1_a=model_metrics['raw']['f1_scores'],
    f1_b=model_metrics['raw_and_filtered']['f1_scores'],
    model_a_name='raw',
    model_b_name='raw_and_filtered',
    alpha=0.05,
    n_comparisons=3
)

compare_models(
    f1_a=model_metrics['raw']['f1_scores'],
    f1_b=model_metrics['raw_and_filtered_and_stft']['f1_scores'],
    model_a_name='raw',
    model_b_name='raw_and_filtered',
    alpha=0.05,
    n_comparisons=3
)

compare_models(
    f1_a=model_metrics['raw_and_filtered']['f1_scores'],
    f1_b=model_metrics['raw_and_filtered_and_stft']['f1_scores'],
    model_a_name='raw_and_filtered',
    model_b_name='raw_and_filtered_and_stft',
    alpha=0.05,
    n_comparisons=3
)




--- Average Metrics Across Splits For Coupling Detection ---

                    Model F1 Score (mean ± std) Precision (mean ± std) Recall (mean ± std)
                      raw       0.2421 ± 0.1300        0.5539 ± 0.1021     0.1881 ± 0.1405
         raw_and_filtered       0.4281 ± 0.0654        0.5186 ± 0.0423     0.3814 ± 0.1070
raw_and_filtered_and_stft       0.3080 ± 0.0997        0.4653 ± 0.0564     0.2525 ± 0.1089

Bonferroni corrected alpha: 0.0167 (original alpha=0.05 / 3 comparisons)
Shapiro-Wilk normality test: W = 0.9209, p = 0.5359

--- Statistical Comparison: raw_and_filtered vs raw ---
Paired t-test (normal distribution)
t-statistic = 2.6110, p-value = 0.0594
No significant difference at corrected alpha = 0.0167

Bonferroni corrected alpha: 0.0167 (original alpha=0.05 / 3 comparisons)
Shapiro-Wilk normality test: W = 0.9335, p = 0.6203

--- Statistical Comparison: raw_and_filtered vs raw ---
Paired t-test (normal distribution)
t-statistic = 0.7545, p-value = 0.4925
No 

In [13]:
# this is to get nicer visualisations

# have a table only with F1 scores
model_names = list(model_metrics.keys())
average_f1s = [round(np.mean(model_metrics[name]['f1_scores']), 2) for name in model_names]
std_f1s = [round(np.std(model_metrics[name]['f1_scores']), 2) for name in model_names]

summary_data = {
    'Model': model_names,
    'Mean F1 Score': average_f1s,
    'F1 Score Std Dev': std_f1s
}
summary_df = pd.DataFrame(summary_data)

print("\n--- Summary of F1 Scores Across 5 Folds for Coupling Detection ---")
display(summary_df)

# create bar plot with F1 score for each model
plt.figure(figsize=(10, 6))
bars = plt.bar(summary_df['Model'], summary_df['Mean F1 Score'],
               yerr=summary_df['F1 Score Std Dev'], capsize=5,
               color=['lightcoral', 'lightgreen', 'skyblue'])

# the F1 values are on top of each bar
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2.0, yval + 0.01, f'{yval:.2f}', va='bottom', ha='center', fontsize=10)

plt.ylabel("Mean F1 Score")
plt.title("Comparison of Model Performance on Coupling Detection for 5 participants")
plt.ylim(0, 1.05)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.xticks(rotation=15)
plt.tight_layout()
plt.show()



--- Summary of F1 Scores Across 5 Folds for Coupling Detection ---


Unnamed: 0,Model,Mean F1 Score,F1 Score Std Dev
0,raw,0.24,0.13
1,raw_and_filtered,0.43,0.07
2,raw_and_filtered_and_stft,0.31,0.1
