In [1]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa
import os
import random

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc, precision_score, recall_score
from sklearn.utils import class_weight
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

import scipy.signal as signal
from scipy.signal import hilbert
from scipy.signal import stft

from scipy.stats import friedmanchisquare

import pywt
import cv2

SEED = 15

os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

In [2]:
%matplotlib qt

### CNN one input Models

In [3]:
def build_cnn_model_downsampled(input_shape=(300,1)):

    # linear embedding layer
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Three convolutional blocks (like having three pattern detectors)

    # First convolution block, kernel size of 5
    padded1 = tf.keras.layers.ZeroPadding1D(padding=2)(input_layer)
    conv1 = tf.keras.layers.Conv1D(filters=10, kernel_size=5, strides=1, padding='valid')(padded1)
    # each filter learns a different type of short-time feature
    # stride of 1, moves one step at a time
    conv1 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv1)
    conv1 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv1)
    # K = 2
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # Second convolution block, kernel size of 11
    padded2 = tf.keras.layers.ZeroPadding1D(padding=5)(input_layer)
    conv2 = tf.keras.layers.Conv1D(filters=10, kernel_size=11, strides=1, padding='valid')(padded2)
    conv2 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv2)
    conv2 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv2)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # Third convolution block, kernel size of 21
    padded3 = tf.keras.layers.ZeroPadding1D(padding=10)(input_layer)
    conv3 = tf.keras.layers.Conv1D(filters=10, kernel_size=21, strides=1, padding='valid')(padded3)
    conv3 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv3)
    conv3 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv3)
    conv3 = tf.keras.layers.BatchNormalization()(conv3)

    # Concatenate the outputs of all blocks
    concatenated = tf.keras.layers.Concatenate()([conv1, conv2, conv3])

    # GRU Layer
    gru = tf.keras.layers.GRU(64)(concatenated)

    # Fully connected (dense) layer
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)
    # add a Dropout layer to prevent overfitting
    #dense = tf.keras.layers.Dropout(0.5)(dense)

    # Two softmax outputs for dual-task classification
    #output_task1 = tf.keras.layers.Dense(2, activation='softmax', name='task1')(dense)
    #output_task2 = tf.keras.layers.Dense(2, activation='softmax', name='task2')(dense)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    # Create the model
    #model = tf.keras.models.Model(inputs=input_layer, outputs=[output_task1, output_task2])
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)

    # Compile the model
    #model.compile(optimizer='adam', loss={'task1': 'categorical_crossentropy', 'task2': 'categorical_crossentropy'}, metrics={'task1': 'accuracy', 'task2': 'accuracy'})
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Return the compiled model
    return model

In [4]:
def build_multi_input_cnn_model_freq():
    # Inputs
    input_raw = tf.keras.Input(shape=(300, 1), name='raw_input')
    input_filtered = tf.keras.Input(shape=(300, 1), name='filtered_input')
    input_stft = tf.keras.Input(shape=(13, 1), name='stft_input')  

    def conv_branch(input_layer, kernel_sizes=[5, 11, 21]):
        outputs = []
        for k in kernel_sizes:
            pad = k // 2
            x = tf.keras.layers.ZeroPadding1D(padding=pad)(input_layer)
            x = tf.keras.layers.Conv1D(filters=10, kernel_size=k, strides=1, padding='valid')(x)
            x = tf.keras.layers.LeakyReLU(negative_slope=0.01)(x)
            x = tf.keras.layers.MaxPooling1D(pool_size=2)(x)
            x = tf.keras.layers.BatchNormalization()(x)
            outputs.append(x)
        return tf.keras.layers.Concatenate()(outputs)

    # Convolutional branches
    branch_raw = conv_branch(input_raw)
    branch_filtered = conv_branch(input_filtered)
    branch_stft = conv_branch(input_stft)

    # Each branch through its own GRU
    gru_raw = tf.keras.layers.GRU(64)(branch_raw)
    gru_filtered = tf.keras.layers.GRU(64)(branch_filtered)
    gru_stft = tf.keras.layers.GRU(64)(branch_stft)

    # Concatenate GRU outputs (fixed-length vectors)
    merged = tf.keras.layers.Concatenate()([gru_raw, gru_filtered, gru_stft])

    # Dense layers
    x = tf.keras.layers.Dense(64, activation='relu')(merged)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    # Build model
    model = tf.keras.Model(inputs=[input_raw, input_filtered, input_stft], outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

### Slow oscillation detection function

In [5]:
def detect_slow_oscillations_times(combined_raw, do_filter=True, do_downsample=False, downsample_rate=100):

    # according to methods from Klinzing et al.(2016)

    data = combined_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=0.16, h_freq=1.25)

    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]
    
    # 3. find all positive-to-negative zero-crossings
    
    # zero_crossings = np.where( S!= 0)[0]
    # can also save this somewhere for further detection of spindles
    
    S = np.diff(np.sign(channel_data))
    # np.sign returns an array with 1 (positive), 0 (zero), -1 (negative)
    # np.diff calculates the difference between consecutive elements in an array
    # positive value: transition from negative to positive
    # negative value: transition from positive to negative
    # when it's a zero, means that value stayed the same
    zero_crossings = np.where(S < 0)[0]
    # -2 is when a positive-to-negative zero-crossing occurs
    # goes from 1 to -1 
    # -1 - 1 = -2
    # [0] extracts the actual array
    # extracts the indices of interest from current_data (not S)
    #signs = np.sign(current_data)
    #pos_to_neg = np.where((signs[:-1] > 0) & (signs[1:] < 0))[0]
    # detect +1 to -1
    #neg_to_pos = np.where((signs[:-1] <  0) & (signs[1:] > 0))[0]
    # detect -1 to +1

    # 4. Detect peak potentials in each pair
    slow_oscillations = []
    negative_peaks = []
    positive_peaks = []
    peak_to_peak_amplitudes = []
    candidate_indices = []

    # for loop for each pair
    # to collect all the negative and positive peaks
    # to further apply criteria
    count = 0
    for i in range(0, len(zero_crossings)-1, 1):
        # loop through all the zero_crossings
        # step of 1 (with step of 2, miss some zero_crossings)
        start_idx = zero_crossings[i] + 1
        # assigns index of zero-crossing (representing start of potential SO)
        # to start_idx
        end_idx = zero_crossings[i + 1] + 1
        # assigns index of next zero-crossing (representing end of potential SO)
        # to end_idx

        # find the negative to positive crossing in between
        #mid_crossings = neg_to_pos[(neg_to_pos > start_idx) & (neg_to_pos < end_idx)]

        #if len(mid_crossings) != 1:
            #continue

        #mid_idx = mid_crossings [0]

        #duration = (end_idx - start_idx) / sfreq
        #if not (0.8 <= duration <= 2.0):
  
        
        segment_length = (end_idx - start_idx) / sfreq

        # need to add +1 because of way extract segment later

        # have identified index for the pair
        
        # extract data segment between crossings
        
        # find peaks
        if 0.8 <= segment_length <= 2.0:
            count += 1
            segment = channel_data[start_idx:end_idx]
            positive_peak = np.max(segment)
            negative_peak = np.min(segment)
            peak_to_peak_amplitude = positive_peak - negative_peak

        # store values
            candidate_indices.append((start_idx, end_idx))
            positive_peaks.append(positive_peak)
            negative_peaks.append(negative_peak)
            peak_to_peak_amplitudes.append(peak_to_peak_amplitude)

    # calculate mean values for comparison
    #mean_negative_peak = np.mean(negative_peaks)
    # mean_negative_peak = np.mean(negative_peaks) if negative_peaks else 0
    #mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes)
    # mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes) if peak_to_peak_amplitudes else 0

    negative_peak_threshold = np.percentile(negative_peaks, 25)
    # keep lowest negative peaks (under the 25th percentile)
    peak_to_peak_amplitude_threshold = np.percentile(peak_to_peak_amplitudes, 75)
    # keep largest peak-to-peak amplitude (over 75th percentile)

    for (start_idx, end_idx), negative_peak, peak_to_peak_amplitude in zip(candidate_indices, negative_peaks, peak_to_peak_amplitudes):
        if peak_to_peak_amplitude >= peak_to_peak_amplitude_threshold and negative_peak <= negative_peak_threshold:
            slow_oscillations.append((start_idx / sfreq, end_idx / sfreq))
            
    return slow_oscillations
    # returns a list of tuples, in which each tuple represents the start and end times of
    # a detected slow oscillation

def detect_slow_oscillations_peaks(combined_raw, do_filter=True, do_downsample=True, downsample_rate=100):

    # according to methods from Klinzing et al.(2016)

    data = combined_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=0.16, h_freq=1.25)

    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]
    
    # 3. find all positive-to-negative zero-crossings
    
    # zero_crossings = np.where( S!= 0)[0]
    # can also save this somewhere for further detection of spindles
    
    S = np.diff(np.sign(channel_data))
    # np.sign returns an array with 1 (positive), 0 (zero), -1 (negative)
    # np.diff calculates the difference between consecutive elements in an array
    # positive value: transition from negative to positive
    # negative value: transition from positive to negative
    # when it's a zero, means that value stayed the same
    zero_crossings = np.where(S < 0)[0]
    # -2 is when a positive-to-negative zero-crossing occurs
    # goes from 1 to -1 
    # -1 - 1 = -2
    # [0] extracts the actual array
    # extracts the indices of interest from current_data (not S)


    # 4. Detect peak potentials in each pair
    slow_oscillations = []
    slow_oscillations_peaks = []
    negative_peaks = []
    positive_peaks = []
    peak_to_peak_amplitudes = []
    candidate_indices =  []

    # for loop for each pair
    # to collect all the negative and positive peaks
    # to further apply criteria
    count = 0
    for i in range(0, len(zero_crossings) - 1, 1):
        # loop through all the zero_crossings
        # step of 1 (with step of 2, miss some zero_crossings)
        start_idx = zero_crossings[i] + 1
        # assigns index of zero-crossing (representing start of potential SO)
        # to start_idx
        end_idx = zero_crossings[i + 1] + 1
        # assigns index of next zero-crossing (representing end of potential SO)
        # to end_idx
        segment_length = (end_idx - start_idx) / sfreq

        # need to add +1 because of way extract segment later

        # have identified index for the pair
        
        # extract data segment between crossings
        
        # find peaks
        if 0.8 <= segment_length <= 2.0:
            count += 1
            segment = channel_data[start_idx:end_idx]
            positive_peak = np.max(segment)
            negative_peak = np.min(segment)
            peak_to_peak_amplitude = positive_peak - negative_peak

        # store values
            candidate_indices.append((start_idx, end_idx))
            positive_peaks.append(positive_peak)
            negative_peaks.append(negative_peak)
            peak_to_peak_amplitudes.append(peak_to_peak_amplitude)

    # calculate mean values for comparison
    #mean_negative_peak = np.mean(negative_peaks)
    # mean_negative_peak = np.mean(negative_peaks) if negative_peaks else 0
    #mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes)
    # mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes) if peak_to_peak_amplitudes else 0

    negative_peak_threshold = np.percentile(negative_peaks, 25)
    peak_to_peak_amplitude_threshold = np.percentile(peak_to_peak_amplitudes, 75)

    for (start_idx, end_idx), negative_peak, peak_to_peak_amplitude in zip(candidate_indices, negative_peaks, peak_to_peak_amplitudes):
        if peak_to_peak_amplitude >= peak_to_peak_amplitude_threshold and negative_peak <= negative_peak_threshold:
            slow_oscillations.append((start_idx / sfreq, end_idx / sfreq))
            slow_oscillations_peaks.append((negative_peak, positive_peak))

            
    return slow_oscillations
    # returns a list of tuples, in which each tuple represents the start and end times of
    # a detected slow oscillation

### Epochs function

In [6]:
def create_fixed_length_epochs(raw, duration=3.0, overlap=0.0, preload=True, reject_by_annotation=False):

    return mne.make_fixed_length_epochs(
        raw,
        duration=duration,
        overlap=overlap,
        preload=preload,
        reject_by_annotation=reject_by_annotation
    )
# function mne.make_fixed_length_epochs takes into account the sampling frequency of the data


def label_so_epochs_moderate(epochs, so_starts, so_ends, epoch_length_sec=3.0):
    epoch_starts = np.arange(len(epochs)) * epoch_length_sec
    epoch_labels = np.zeros(len(epochs), dtype=int)

    for so_start, so_end in zip(so_starts, so_ends):
        so_duration = so_end - so_start
        required_overlap = 0.5 * so_duration  
        # only label 1 if epoch contains 80% of the SO duration

        for i, epoch_start in enumerate(epoch_starts):
            epoch_end = epoch_start + epoch_length_sec

            # Calculate overlap between SO and epoch
            overlap_start = max(so_start, epoch_start)
            overlap_end = min(so_end, epoch_end)
            overlap_duration = overlap_end - overlap_start

            if overlap_duration >= required_overlap:
                epoch_labels[i] = 1

    return epoch_labels

### Importing data

In [7]:
# for 5-fold validation
# load the all the files needed that were pre-processed before
# from train_1_raw and test_1_raw to train_5_raw and test_5_raw
split_files = {
    f'split_{i}': {
        'train': fr"C:\EEG DATA\combined_sets\train_{i}_raw.fif",
        'test': fr"C:\EEG DATA\combined_sets\test_{i}_raw.fif"
    } for i in range(1, 6) 
}

raw_splits = {}
for split_name, files in split_files.items():
    print(f"Loading data for {split_name}...")
    try:
        train_raw = mne.io.read_raw_fif(files['train'], preload=True)
        test_raw = mne.io.read_raw_fif(files['test'], preload=True)
        raw_splits[split_name] = {'train': train_raw, 'test': test_raw}
        print(f"Loaded train and test data for {split_name}")
    except FileNotFoundError as e:
        print(f"Error: File not found for {split_name}: {e}")
    except Exception as e:
        print(f"Error loading data for {split_name}: {e}")
        # errors in case dictionaries not found

Loading data for split_1...
Opening raw data file C:\EEG DATA\combined_sets\train_1_raw.fif...
Isotrak not found
    Range : 90000 ... 32700095 =    180.000 ... 65400.190 secs
Ready.
Reading 0 ... 32610095  =      0.000 ... 65220.190 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_1_raw.fif...
Isotrak not found
    Range : 1470000 ... 7845026 =   2940.000 ... 15690.052 secs
Ready.
Reading 0 ... 6375026  =      0.000 ... 12750.052 secs...
Loaded train and test data for split_1
Loading data for split_2...
Opening raw data file C:\EEG DATA\combined_sets\train_2_raw.fif...
Isotrak not found
    Range : 1470000 ... 32985091 =   2940.000 ... 65970.182 secs
Ready.
Reading 0 ... 31515091  =      0.000 ... 63030.182 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_2_raw.fif...
Isotrak not found
    Range : 90000 ... 32700095 =    180.000 ... 65400.190 secs
Ready.
Reading 0 ... 32610095  =      0.000 ... 65220.190 secs...
Loaded train and test data for split_2
Loading da

In [8]:
# we want to evaluate the models on all these scores
model_metrics = {
    'one_input_cnn_raw': {
        'f1_scores': [],
        'accuracy_scores': [],
        'loss_scores': [],
        'precision_scores': [],
        'recall_scores': []
    },
    'one_input_cnn_filtered': {
        'f1_scores': [],
        'accuracy_scores': [],
        'loss_scores': [],
        'precision_scores': [],
        'recall_scores': []
    },
    'three_input_cnn_freq': {
        'f1_scores': [],
        'accuracy_scores': [],
        'loss_scores': [],
        'precision_scores': [],
        'recall_scores': []
    }
}


# List of models to evaluate
models_to_evaluate = {
    'one_input_cnn_raw': build_cnn_model_downsampled,
    'one_input_cnn_filtered': build_cnn_model_downsampled,
    # this has the same model architecture but different input data
    'three_input_cnn_freq': build_multi_input_cnn_model_freq
}

# then we go through each split
# so create a for loop
for split_name, raw_data in raw_splits.items():
    print(f"\n--- Processing Split: {split_name} ---")

    train_raw = raw_data['train']
    test_raw = raw_data['test']

    # this is for each split

    # Slow oscillation detection on raw data for raw/three-input model labels
    so_train_times_raw_downsampled = detect_slow_oscillations_times(train_raw, do_filter=True, do_downsample=True)
    so_test_times_raw_downsampled = detect_slow_oscillations_times(test_raw, do_filter=True, do_downsample=True)
    so_starts_train_raw_downsampled, so_ends_train_raw_downsampled = zip(*so_train_times_raw_downsampled) if so_train_times_raw_downsampled else ([], [])
    so_starts_test_raw_downsampled, so_ends_test_raw_downsampled = zip(*so_test_times_raw_downsampled) if so_test_times_raw_downsampled else ([], [])
    # Downsample raw data for one input
    train_raw_downsampled = train_raw.copy().resample(100)
    test_raw_downsampled = test_raw.copy().resample(100)

    # Filtered data for filtered input and downsample
    train_filtered_downsampled = train_raw.copy().filter(l_freq=0.16, h_freq=1.25)
    test_filtered_downsampled = test_raw.copy().filter(l_freq=0.16, h_freq=1.25)
    train_filtered_downsampled.resample(100)
    test_filtered_downsampled.resample(100)
    # resample because already copied before

    # SO detection on filtered and downsampled data for filtered model labels
    so_train_times_filtered_downsampled = detect_slow_oscillations_times(train_filtered_downsampled, do_filter=False, do_downsample=False)
    so_test_times_filtered_downsampled = detect_slow_oscillations_times(test_filtered_downsampled, do_filter=False, do_downsample=False)
    # since filtering and downsampling before, do not filter and downsample again in function

    so_starts_train_filtered_downsampled, so_ends_train_filtered_downsampled = zip(*so_train_times_filtered_downsampled) if so_train_times_filtered_downsampled else([],[])
    so_starts_test_filtered_downsampled, so_ends_test_filtered_downsampled = zip(*so_test_times_filtered_downsampled) if so_test_times_filtered_downsampled else([],[])


    # then fixed length epochs (using downsampled raw and downsampled and filtered data)
    epochs_train_raw_downsampled = create_fixed_length_epochs(train_raw_downsampled, duration=3.0, overlap=0.0)
    epochs_test_raw_downsampled = create_fixed_length_epochs(test_raw_downsampled, duration=3.0, overlap=0.0)

    epochs_train_filtered_downsampled = create_fixed_length_epochs(train_filtered_downsampled, duration=3.0, overlap=0.0)
    epochs_test_filtered_downsampled = create_fixed_length_epochs(test_filtered_downsampled, duration=3.0, overlap=0.0)


    # Prepare STFT input (using downsampled raw data)
    epochs_train_stft_downsampled = np.squeeze(np.array(epochs_train_raw_downsampled))
    epochs_test_stft_downsampled = np.squeeze(np.array(epochs_test_raw_downsampled))

    fs = train_raw_downsampled.info['sfreq']
    nperseg = 50
    noverlap = nperseg // 2

    X_train_stft_transformed = []
    for epoch in epochs_train_stft_downsampled:
        f, t, Zxx = stft(epoch, fs=fs, nperseg=nperseg, noverlap=noverlap)
        spectrogram = np.abs(Zxx)
        X_train_stft_transformed.append(spectrogram)
    X_train_stft_transformed = np.array(X_train_stft_transformed)

    X_test_stft_transformed = []
    for epoch in epochs_test_stft_downsampled:
        f, t, Zxx = stft(epoch, fs=fs, nperseg=nperseg, noverlap=noverlap)
        spectrogram = np.abs(Zxx)
        X_test_stft_transformed.append(spectrogram)
    X_test_stft_transformed = np.array(X_test_stft_transformed)

    # Reduce STFT dimension to frequency axis only
    X_train_stft_freq = np.mean(X_train_stft_transformed, axis=1)
    X_test_stft_freq = np.mean(X_test_stft_transformed, axis=1)
    X_train_stft_freq = X_train_stft_freq[..., np.newaxis] # Add channel dimension
    X_test_stft_freq = X_test_stft_freq[..., np.newaxis] # Add channel dimension

    # Normalize STFT frequency data
    X_train_stft_freq_norm = np.array([
        (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
        for epoch in X_train_stft_freq
    ])
    X_test_stft_freq_norm = np.array([
        (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
        for epoch in X_test_stft_freq
    ])


    # Prepare raw and filtered epoch data (reshaping)
    X_train_raw = np.array(epochs_train_raw_downsampled).reshape(len(epochs_train_raw_downsampled), -1, 1)
    X_test_raw = np.array(epochs_test_raw_downsampled).reshape(len(epochs_test_raw_downsampled), -1, 1)

    X_train_filtered = np.array(epochs_train_filtered_downsampled).reshape(len(epochs_train_filtered_downsampled), -1, 1)
    X_test_filtered = np.array(epochs_test_filtered_downsampled).reshape(len(epochs_test_filtered_downsampled), -1, 1)


    # still in the same split
    # now iterate through the models
    for model_name, build_model_func in models_to_evaluate.items():
        print(f"\n--- Evaluating Model: {model_name} on {split_name} ---")

        # Prepare data and labels based on model input requirements
        if model_name == 'one_input_cnn_raw':
            X_train_input = X_train_raw
            X_test_input = X_test_raw
            y_train = label_so_epochs_moderate(epochs_train_raw_downsampled, so_starts_train_raw_downsampled, so_ends_train_raw_downsampled)
            y_test = label_so_epochs_moderate(epochs_test_raw_downsampled, so_starts_test_raw_downsampled, so_ends_test_raw_downsampled)
            input_shape = (X_train_input.shape[1], X_train_input.shape[2])

        elif model_name == 'one_input_cnn_filtered':
             X_train_input = X_train_filtered
             X_test_input = X_test_filtered
             y_train = label_so_epochs_moderate(epochs_train_filtered_downsampled, so_starts_train_filtered_downsampled, so_ends_train_filtered_downsampled)
             y_test = label_so_epochs_moderate(epochs_test_filtered_downsampled, so_starts_test_filtered_downsampled, so_ends_test_filtered_downsampled)
             input_shape = (X_train_input.shape[1], X_train_input.shape[2])

        elif model_name == 'three_input_cnn_freq':
            X_train_input = {
                'raw_input': X_train_raw,
                'filtered_input': X_train_filtered,
                'stft_input': X_train_stft_freq_norm 
            }
            X_test_input = {
                'raw_input': X_test_raw,
                'filtered_input': X_test_filtered,
                'stft_input': X_test_stft_freq_norm 
            }
            # Labels for the three-input model come from the raw downsampled data
            y_train = label_so_epochs_moderate(epochs_train_raw_downsampled, so_starts_train_raw_downsampled, so_ends_train_raw_downsampled)
            y_test = label_so_epochs_moderate(epochs_test_raw_downsampled, so_starts_test_raw_downsampled, so_ends_test_raw_downsampled)

            input_shape = None # Shape is handled by the model definition for multi-input

        print(f"Training data shapes: { {k: v.shape for k, v in X_train_input.items()} if isinstance(X_train_input, dict) else X_train_input.shape}, labels={y_train.shape}")
        print(f"Test data shapes: { {k: v.shape for k, v in X_test_input.items()} if isinstance(X_test_input, dict) else X_test_input.shape}, labels={y_test.shape}")
        # to check whether a dictionary or not 
        # because it is a dictionary for the three-input model
        # but not for the other models


        # build the models

        print("Building and compiling model...")
        if model_name in ['one_input_cnn_raw', 'one_input_cnn_filtered']:
             model = build_model_func(input_shape)
        else:
            model = build_model_func()


        # Define early stopping
        early_stop = EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        )

        # Train the model
        print("Training the model...")
        # Use a validation split from the training data if desired, or train on the full training set
        # For cross-validation on pre-defined splits, it's common to train on the full training set
        history = model.fit(
            X_train_input,
            y_train,
            validation_split=0.2,
            epochs=20, # Adjust epochs as needed
            batch_size=128, # Adjust batch size as needed
            callbacks=[early_stop], # Optional: Use early stopping
        )
        print("Training finished.")

        # Evaluate the model on the test data of the current split
        print(f"Evaluating on {split_name}'s test data...")
        loss, accuracy = model.evaluate(X_test_input, y_test, verbose=0)
        print(f"Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        # Get predictions and calculate F1 score
        y_pred_proba = model.predict(X_test_input, verbose=0)
        y_pred_labels = (y_pred_proba > 0.5).astype(int)

        
        split_f1 = f1_score(y_test, y_pred_labels)
        split_precision = precision_score(y_test, y_pred_labels)
        split_recall = recall_score(y_test, y_pred_labels)
        print(f"F1 Score for {model_name} on {split_name}: {split_f1:.4f}")
        print(f"Precision for {model_name} on {split_name}: {split_precision:.4f}")
        print(f"Recall for {model_name} on {split_name}: {split_recall:.4f}")

        # Store the metrics for the current model type and split
        model_metrics[model_name]['f1_scores'].append(split_f1)
        model_metrics[model_name]['accuracy_scores'].append(accuracy)
        model_metrics[model_name]['loss_scores'].append(loss)
        model_metrics[model_name]['precision_scores'].append(split_precision)
        model_metrics[model_name]['recall_scores'].append(split_recall)

        # Clear TensorFlow session to free up memory
        tf.keras.backend.clear_session()

print("\n--- Evaluation finished for all models across all splits ---")


--- Processing Split: split_1 ---
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 96 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 27 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) metho



Training the model...
Epoch 1/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 114ms/step - accuracy: 0.5630 - loss: 0.6829 - val_accuracy: 0.5584 - val_loss: 0.6818
Epoch 2/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 114ms/step - accuracy: 0.6277 - loss: 0.6515 - val_accuracy: 0.6789 - val_loss: 0.6132
Epoch 3/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 115ms/step - accuracy: 0.6702 - loss: 0.6035 - val_accuracy: 0.6391 - val_loss: 0.6110
Epoch 4/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 115ms/step - accuracy: 0.6152 - loss: 0.6469 - val_accuracy: 0.5345 - val_loss: 0.6915
Epoch 5/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 112ms/step - accuracy: 0.5911 - loss: 0.6700 - val_accuracy: 0.5964 - val_loss: 0.6575
Epoch 6/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 112ms/step - accuracy: 0.6353 - loss: 0.6422 - val_accuracy: 0.6095 - val_l





--- Evaluating Model: one_input_cnn_filtered on split_1 ---
Training data shapes: (21740, 300, 1), labels=(21740,)
Test data shapes: (4250, 300, 1), labels=(4250,)
Building and compiling model...
Training the model...
Epoch 1/20




[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 107ms/step - accuracy: 0.6297 - loss: 0.6495 - val_accuracy: 0.7054 - val_loss: 0.7377
Epoch 2/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 100ms/step - accuracy: 0.7370 - loss: 0.5306 - val_accuracy: 0.8300 - val_loss: 0.4304
Epoch 3/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 100ms/step - accuracy: 0.8414 - loss: 0.3727 - val_accuracy: 0.8528 - val_loss: 0.3258
Epoch 4/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 102ms/step - accuracy: 0.8856 - loss: 0.2690 - val_accuracy: 0.8949 - val_loss: 0.2446
Epoch 5/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 103ms/step - accuracy: 0.8924 - loss: 0.2535 - val_accuracy: 0.9154 - val_loss: 0.2125
Epoch 6/20
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 100ms/step - accuracy: 0.9070 - loss: 0.2229 - val_accuracy: 0.9151 - val_loss: 0.1952
Epoch 7/20
[1m136/13



Training the model...
Epoch 1/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 163ms/step - accuracy: 0.5603 - loss: 0.6851 - val_accuracy: 0.5676 - val_loss: 0.6803
Epoch 2/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 187ms/step - accuracy: 0.6094 - loss: 0.6627 - val_accuracy: 0.5369 - val_loss: 0.6957
Epoch 3/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 218ms/step - accuracy: 0.5570 - loss: 0.6875 - val_accuracy: 0.5781 - val_loss: 0.6784
Epoch 4/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 217ms/step - accuracy: 0.6063 - loss: 0.6559 - val_accuracy: 0.6840 - val_loss: 0.5726
Epoch 5/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 212ms/step - accuracy: 0.7452 - loss: 0.5060 - val_accuracy: 0.6790 - val_loss: 0.6044
Epoch 6/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 213ms/step - accuracy: 0.7647 - loss: 0.4789 - val_accuracy: 0.7587 - val_l



[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 118ms/step - accuracy: 0.5918 - loss: 0.6676 - val_accuracy: 0.5942 - val_loss: 0.6450
Epoch 2/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 114ms/step - accuracy: 0.7977 - loss: 0.4959 - val_accuracy: 0.6894 - val_loss: 0.5984
Epoch 3/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 111ms/step - accuracy: 0.7415 - loss: 0.5329 - val_accuracy: 0.8596 - val_loss: 0.3559
Epoch 4/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 112ms/step - accuracy: 0.8537 - loss: 0.3407 - val_accuracy: 0.8936 - val_loss: 0.2537
Epoch 5/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 112ms/step - accuracy: 0.8948 - loss: 0.2501 - val_accuracy: 0.9131 - val_loss: 0.2188
Epoch 6/20
[1m132/132[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 112ms/step - accuracy: 0.9128 - loss: 0.2087 - val_accuracy: 0.9208 - val_loss: 0.1980
Epoch 7/20
[1m132/13



Training the model...
Epoch 1/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 189ms/step - accuracy: 0.6018 - loss: 0.6609 - val_accuracy: 0.6055 - val_loss: 0.6608
Epoch 2/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 175ms/step - accuracy: 0.7148 - loss: 0.5570 - val_accuracy: 0.8013 - val_loss: 0.4425
Epoch 3/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 174ms/step - accuracy: 0.8281 - loss: 0.4013 - val_accuracy: 0.8503 - val_loss: 0.3562
Epoch 4/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 171ms/step - accuracy: 0.8658 - loss: 0.3120 - val_accuracy: 0.8710 - val_loss: 0.3142
Epoch 5/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 170ms/step - accuracy: 0.8853 - loss: 0.2729 - val_accuracy: 0.8748 - val_loss: 0.2939
Epoch 6/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 171ms/step - accuracy: 0.8944 - loss: 0.2535 - val_accuracy: 0.8948 - val_l



Epoch 1/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 188ms/step - accuracy: 0.6741 - loss: 0.5824 - val_accuracy: 0.8797 - val_loss: 0.3038
Epoch 2/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 186ms/step - accuracy: 0.8632 - loss: 0.3338 - val_accuracy: 0.8600 - val_loss: 0.3148
Epoch 3/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 185ms/step - accuracy: 0.8768 - loss: 0.2805 - val_accuracy: 0.9206 - val_loss: 0.1917
Epoch 4/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 184ms/step - accuracy: 0.9173 - loss: 0.1987 - val_accuracy: 0.9342 - val_loss: 0.1624
Epoch 5/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 185ms/step - accuracy: 0.9291 - loss: 0.1783 - val_accuracy: 0.9396 - val_loss: 0.1534
Epoch 6/20
[1m339/339[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 182ms/step - accuracy: 0.9334 - loss: 0.1671 - val_accuracy: 0.9443 - val_loss: 0.1442
Epoch 7/20

KeyboardInterrupt: 

In [9]:
# Step 6: Calculate and display average metrics for all models across splits

rows = []

for model_name, metrics in model_metrics.items():
    average_f1 = np.mean(metrics['f1_scores'])
    std_f1 = np.std(metrics['f1_scores'])

    average_accuracy = np.mean(metrics['accuracy_scores'])
    std_accuracy = np.std(metrics['accuracy_scores'])

    average_loss = np.mean(metrics['loss_scores'])
    std_loss = np.std(metrics['loss_scores'])

    average_precision = np.mean(metrics['precision_scores'])
    std_precision = np.std(metrics['precision_scores'])

    average_recall = np.mean(metrics['recall_scores'])
    std_recall = np.std(metrics['recall_scores'])

    # Append a row as a dict
    rows.append({
        "Model": model_name,
        "F1 Score (mean ± std)": f"{average_f1:.4f} ± {std_f1:.4f}",
        "Precision (mean ± std)": f"{average_precision:.4f} ± {std_precision:.4f}",
        "Recall (mean ± std)": f"{average_recall:.4f} ± {std_recall:.4f}",
        "Accuracy (mean ± std)": f"{average_accuracy:.4f} ± {std_accuracy:.4f}",
        "Loss (mean ± std)": f"{average_loss:.4f} ± {std_loss:.4f}"
    })

# Create DataFrame
summary_df = pd.DataFrame(rows)

# Print a title and the table
print("\n--- Average Metrics Across Splits For SO Detection ---\n")
print(summary_df.to_string(index=False))

# Step 7: Perform statistical test

# Prepare the data for the Friedman test
# Each row is a split, each column is a model
data_for_friedman = [
    model_metrics['one_input_cnn_raw']['f1_scores'],
    model_metrics['one_input_cnn_filtered']['f1_scores'],
    model_metrics['three_input_cnn_freq']['f1_scores']
]

print(f"\n--- Friedman Test for Comparing F1 Scores Across Splits ---")

# Check if we have enough data points for the Friedman test (at least k=3 models and n>=2 splits/ranks)
if len(data_for_friedman[0]) >= 2 and len(data_for_friedman) >= 3:
    statistic, p_value = friedmanchisquare(*data_for_friedman)

    print(f"Statistic: {statistic:.4f}")
    print(f"P-value: {p_value:.4f}")

    alpha = 0.05
    if p_value < alpha:
        print("\nInterpretation: The p-value is less than the significance level (alpha = 0.05).")
        print("This suggests that there is a statistically significant difference in the median F1 scores among the three models across the splits.")
        print("You may want to perform post-hoc tests (e.g., Wilcoxon signed-rank tests with a multiple comparison correction like Bonferroni or Holm) to determine which specific model pairs are significantly different.")
    else:
        print("\nInterpretation: The p-value is greater than the significance level (alpha = 0.05).")
        print("This suggests that there is no statistically significant difference in the median F1 scores among the three models across the splits.")
else:
    print("\nCannot perform Friedman test: Need at least 3 models and 2 splits with valid F1 scores.")

# You now have the individual scores stored in the model_metrics dictionary
# if you want to perform further analysis on them.


--- Average Metrics Across Splits ---


KeyError: 'precision_scores'

In [None]:
# Step 6: Calculate and display average metrics for all models across splits
# This part is already in the previous cell, but we'll re-calculate averages for the table/plot

model_names = list(model_metrics.keys())
average_f1s = [np.mean(model_metrics[name]['f1_scores']) for name in model_names]
std_f1s = [np.std(model_metrics[name]['f1_scores']) for name in model_names]

# Create a summary table (DataFrame)
summary_data = {
    'Model': model_names,
    'Average F1 Score (Label 1)': average_f1s,
    'Standard Deviation (F1 Score)': std_f1s
}
summary_df = pd.DataFrame(summary_data)

print("\n--- Summary Table of Average Metrics Across Splits for SO Detection ---")
display(summary_df)

# Create a bar plot with error bars
plt.figure(figsize=(10, 6))
bars = plt.bar(summary_df['Model'], summary_df['Average F1 Score (Label 1)'], yerr=summary_df['Standard Deviation (F1 Score)'], capsize=5, color=['skyblue', 'lightcoral', 'lightgreen'])

# Add the average F1 score on top of each bar
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2.0, yval, f'{yval:.4f}', va='bottom', ha='center') # va: vertical alignment

plt.ylabel("Average F1 Score (Label 1)")
plt.title("Average F1 Score (Label 1) per Model Across 5 Splits")
plt.ylim(0, 1.05) # F1 score is between 0 and 1
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()