In [1]:
import os
import pickle
import random

import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc
from sklearn.utils import class_weight

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

import scipy.signal as signal
from scipy.signal import hilbert
from scipy.signal import stft

# --- Reproducibility configuration ---------------------------------------
# One seed shared by every random number generator used in this notebook.
SEED = 82

# Environment flags first, then the generators.
# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so setting
# it here can only influence processes started from this kernel. Export it
# before launching Jupyter if the kernel's own hash seed must be fixed.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Ask TensorFlow / cuDNN for deterministic kernels.
# NOTE(review): whether these flags are honoured depends on the installed
# TensorFlow version - confirm against its release notes.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

# Seed the Python, NumPy and TensorFlow random number generators.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [2]:
# Select matplotlib's Qt backend for every figure created in this notebook.
%matplotlib qt

### One-input CNN model

In [3]:
def build_cnn_model_downsampled(input_shape=(300,1)):
    """Build and compile the one-input multi-scale CNN + GRU binary classifier.

    The input is processed in parallel by three convolutional branches with
    kernel sizes 5, 11 and 21. Each branch zero-pads the input by
    ``kernel_size // 2`` samples on both sides so that the 'valid' convolution
    preserves the temporal length, then applies LeakyReLU, max-pooling with
    pool size 2 and batch normalisation. The branch outputs are concatenated,
    summarised by a 64-unit GRU and passed through a 64-unit dense layer to a
    single sigmoid unit.

    Parameters
    ----------
    input_shape : tuple, default (300, 1)
        Shape of one sample, forwarded to ``tf.keras.layers.Input``
        (presumably (time points, channels) - TODO confirm with the caller).

    Returns
    -------
    tf.keras.Model
        Model compiled with the 'adam' optimizer, binary cross-entropy loss
        and the accuracy metric.
    """
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # One branch per kernel size; each filter learns a different short-time feature.
    branch_outputs = []
    for kernel_size in (5, 11, 21):
        branch = tf.keras.layers.ZeroPadding1D(padding=kernel_size // 2)(input_layer)
        branch = tf.keras.layers.Conv1D(filters=10, kernel_size=kernel_size, strides=1, padding='valid')(branch)
        # `negative_slope` is the argument name used by the other model builders
        # in this notebook; `alpha` is its deprecated predecessor.
        branch = tf.keras.layers.LeakyReLU(negative_slope=0.01)(branch)
        branch = tf.keras.layers.MaxPooling1D(pool_size=2)(branch)
        branch = tf.keras.layers.BatchNormalization()(branch)
        branch_outputs.append(branch)

    # Concatenate the outputs of all branches
    concatenated = tf.keras.layers.Concatenate()(branch_outputs)

    # GRU layer reduces the sequence to one fixed-length vector
    gru = tf.keras.layers.GRU(64)(concatenated)

    # Fully connected (dense) layer.
    # A Dropout(0.5) after this layer and a dual-task softmax head were tried
    # earlier and are currently disabled.
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)

    # Single sigmoid unit for binary classification
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    model = tf.keras.models.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

### Three-input CNN model

In [4]:
def build_multi_input_cnn_model_time():
    """Build and compile the three-input CNN + GRU binary classifier.

    Inputs (in this order): 'raw_input' and 'filtered_input', each of shape
    (300, 1) - 300 time points of raw and filtered EEG - and 'stft_input' of
    shape (26, 1) - 26 time points from the STFT transform.

    Every input gets its own multi-scale convolutional branch followed by its
    own 64-unit GRU; the three GRU outputs are concatenated into one
    192-dimensional vector that feeds a 64-unit dense layer and a single
    sigmoid unit.

    Returns
    -------
    tf.keras.Model
        Model compiled with the 'adam' optimizer, binary cross-entropy loss
        and the accuracy metric.
    """
    raw_in = tf.keras.Input(shape=(300, 1), name='raw_input')
    filtered_in = tf.keras.Input(shape=(300, 1), name='filtered_input')
    stft_in = tf.keras.Input(shape=(26, 1), name='stft_input')
    model_inputs = [raw_in, filtered_in, stft_in]

    def multi_scale_features(tensor, kernel_sizes=(5, 11, 21)):
        """Apply one conv block per kernel size to `tensor` and concatenate the results."""
        per_scale = []
        for width in kernel_sizes:
            # Pad by width // 2 on both sides so the 'valid' convolution
            # keeps the sequence length for these odd kernel sizes.
            h = tf.keras.layers.ZeroPadding1D(padding=width // 2)(tensor)
            h = tf.keras.layers.Conv1D(filters=10, kernel_size=width, strides=1, padding='valid')(h)
            h = tf.keras.layers.LeakyReLU(negative_slope=0.01)(h)
            # Halve the temporal dimension, then normalise the activations.
            h = tf.keras.layers.MaxPooling1D(pool_size=2)(h)
            h = tf.keras.layers.BatchNormalization()(h)
            per_scale.append(h)
        return tf.keras.layers.Concatenate()(per_scale)

    # All convolutional branches are created first, then the GRUs, in input order.
    conv_features = [multi_scale_features(tensor) for tensor in model_inputs]
    recurrent_summaries = [tf.keras.layers.GRU(64)(features) for features in conv_features]

    # Fuse the three fixed-length GRU vectors and classify.
    fused = tf.keras.layers.Concatenate()(recurrent_summaries)
    hidden = tf.keras.layers.Dense(64, activation='relu')(fused)
    probability = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)

    classifier = tf.keras.Model(inputs=model_inputs, outputs=probability)
    classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return classifier

In [5]:
def build_multi_input_cnn_model_freq():
    """Build and compile the three-input CNN + GRU binary classifier (13-step STFT input).

    Same architecture as the 26-step variant in this notebook, except that
    'stft_input' has shape (13, 1) (presumably 13 frequency bins, judging by
    the function name - TODO confirm). 'raw_input' and 'filtered_input' both
    have shape (300, 1).

    Each input passes through its own multi-scale convolutional branch and its
    own 64-unit GRU; the concatenated GRU outputs feed a 64-unit dense layer
    and a single sigmoid unit.

    Returns
    -------
    tf.keras.Model
        Model compiled with the 'adam' optimizer, binary cross-entropy loss
        and the accuracy metric.
    """
    raw_in = tf.keras.Input(shape=(300, 1), name='raw_input')
    filtered_in = tf.keras.Input(shape=(300, 1), name='filtered_input')
    stft_in = tf.keras.Input(shape=(13, 1), name='stft_input')
    model_inputs = [raw_in, filtered_in, stft_in]

    def multi_scale_features(tensor, kernel_sizes=(5, 11, 21)):
        """Apply one conv block per kernel size to `tensor` and concatenate the results."""
        per_scale = []
        for width in kernel_sizes:
            # Zero-padding of width // 2 keeps the length through the 'valid' convolution.
            h = tf.keras.layers.ZeroPadding1D(padding=width // 2)(tensor)
            h = tf.keras.layers.Conv1D(filters=10, kernel_size=width, strides=1, padding='valid')(h)
            h = tf.keras.layers.LeakyReLU(negative_slope=0.01)(h)
            h = tf.keras.layers.MaxPooling1D(pool_size=2)(h)
            h = tf.keras.layers.BatchNormalization()(h)
            per_scale.append(h)
        return tf.keras.layers.Concatenate()(per_scale)

    # Convolutional branches for all inputs first, then one GRU per branch.
    conv_features = [multi_scale_features(tensor) for tensor in model_inputs]
    recurrent_summaries = [tf.keras.layers.GRU(64)(features) for features in conv_features]

    # Concatenate the fixed-length GRU vectors and classify.
    fused = tf.keras.layers.Concatenate()(recurrent_summaries)
    hidden = tf.keras.layers.Dense(64, activation='relu')(fused)
    probability = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)

    classifier = tf.keras.Model(inputs=model_inputs, outputs=probability)
    classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return classifier

### Slow oscillation detection function

In [6]:
def _extract_fz_signal(combined_raw, do_filter, do_downsample, downsample_rate):
    """Return ``(channel_data, sfreq)`` for the Fz channel of a copy of ``combined_raw``.

    The copy is optionally band-pass filtered (0.16-1.25 Hz) and resampled to
    ``downsample_rate``; ``combined_raw`` itself is left untouched.
    """
    # pick() is the replacement for the legacy pick_channels().
    data = combined_raw.copy().pick(['Fz'])

    if do_filter:
        data.filter(l_freq=0.16, h_freq=1.25)

    if do_downsample:
        data.resample(downsample_rate)

    return data.get_data()[0], data.info['sfreq']


def _find_slow_oscillations(channel_data, sfreq):
    """Detect slow oscillations in a 1-D signal sampled at ``sfreq``.

    Returns a list of ``(start_time, end_time, negative_peak, positive_peak)``
    tuples, with times in seconds relative to the first sample. The list is
    empty when no segment satisfies the duration criterion.
    """
    # np.sign gives -1 / 0 / +1 per sample; np.diff of it is negative wherever
    # the sign decreases from one sample to the next, which is how
    # positive-to-negative zero-crossings are located here.
    sign_steps = np.diff(np.sign(channel_data))
    zero_crossings = np.where(sign_steps < 0)[0]

    # Entry i of np.diff describes the step from sample i to sample i + 1, so
    # the first sample after a crossing is i + 1.
    boundaries = zero_crossings + 1

    # Every pair of consecutive crossings delimits one candidate; only
    # candidates lasting 0.8-2.0 s are kept.
    candidates = []
    for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
        segment_length = (end_idx - start_idx) / sfreq
        if 0.8 <= segment_length <= 2.0:
            segment = channel_data[start_idx:end_idx]
            negative_peak = np.min(segment)
            positive_peak = np.max(segment)
            candidates.append(
                (start_idx, end_idx, negative_peak, positive_peak, positive_peak - negative_peak)
            )

    # Without candidates there is nothing to compute percentiles from.
    if not candidates:
        return []

    # Keep the candidates whose negative peak is at or below the 25th
    # percentile AND whose peak-to-peak amplitude is at or above the 75th.
    negative_peak_threshold = np.percentile([c[2] for c in candidates], 25)
    peak_to_peak_amplitude_threshold = np.percentile([c[4] for c in candidates], 75)

    return [
        (start_idx / sfreq, end_idx / sfreq, negative_peak, positive_peak)
        for start_idx, end_idx, negative_peak, positive_peak, peak_to_peak_amplitude in candidates
        if peak_to_peak_amplitude >= peak_to_peak_amplitude_threshold
        and negative_peak <= negative_peak_threshold
    ]


def detect_slow_oscillations_times(combined_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    """Detect slow oscillations (SO) on channel Fz and return their time spans.

    According to methods from Klinzing et al. (2016): the signal is optionally
    band-pass filtered (0.16-1.25 Hz) and resampled, split at its
    positive-to-negative zero-crossings, and segments lasting 0.8-2.0 s are
    kept as candidates. A candidate is accepted when its negative peak is at or
    below the 25th percentile of all candidate negative peaks and its
    peak-to-peak amplitude is at or above the 75th percentile.

    Parameters
    ----------
    combined_raw : object
        Recording offering ``copy()``, ``pick()``, ``filter()``, ``resample()``,
        ``get_data()`` and ``info['sfreq']`` (presumably an ``mne.io.Raw``
        that contains a channel named 'Fz').
    do_filter : bool, default True
        Band-pass filter the copy before detection.
    do_downsample : bool, default False
        Resample the copy to ``downsample_rate`` before detection.
    downsample_rate : float, default 100
        Target sampling rate used when ``do_downsample`` is True.

    Returns
    -------
    list of tuple
        ``(start_time, end_time)`` in seconds for each detected slow
        oscillation; empty when nothing is detected.
    """
    channel_data, sfreq = _extract_fz_signal(combined_raw, do_filter, do_downsample, downsample_rate)
    return [(start, end) for start, end, _, _ in _find_slow_oscillations(channel_data, sfreq)]


def detect_slow_oscillations_peaks(combined_raw, do_filter=True, do_downsample=True, downsample_rate=100, return_peaks=False):
    """Detect slow oscillations on channel Fz, optionally with their peak values.

    Uses the same detection procedure as ``detect_slow_oscillations_times``
    but resamples by default (``do_downsample=True``).

    Parameters
    ----------
    combined_raw, do_filter, do_downsample, downsample_rate
        See ``detect_slow_oscillations_times``.
    return_peaks : bool, default False
        When True, additionally return the ``(negative_peak, positive_peak)``
        pair of every detected slow oscillation.

    Returns
    -------
    list of tuple
        ``(start_time, end_time)`` in seconds per detected slow oscillation.
        With ``return_peaks=True`` the return value is the pair
        ``(slow_oscillations, slow_oscillations_peaks)``, the second list being
        aligned with the first.
    """
    channel_data, sfreq = _extract_fz_signal(combined_raw, do_filter, do_downsample, downsample_rate)
    detected = _find_slow_oscillations(channel_data, sfreq)

    slow_oscillations = [(start, end) for start, end, _, _ in detected]
    if not return_peaks:
        return slow_oscillations

    # Each event is paired with its OWN positive peak (not the value left over
    # from the last candidate inspected).
    slow_oscillations_peaks = [
        (negative_peak, positive_peak) for _, _, negative_peak, positive_peak in detected
    ]
    return slow_oscillations, slow_oscillations_peaks

### Epochs function

In [7]:
def create_fixed_length_epochs(raw, duration=3.0, overlap=0.0, preload=True, reject_by_annotation=False):
    """Split `raw` into fixed-length epochs via ``mne.make_fixed_length_epochs``.

    All keyword arguments are forwarded unchanged.

    Parameters
    ----------
    raw : object
        Recording passed as the first argument to
        ``mne.make_fixed_length_epochs`` (presumably an ``mne.io.Raw``).
    duration : float, default 3.0
        Forwarded as ``duration``.
    overlap : float, default 0.0
        Forwarded as ``overlap``.
    preload : bool, default True
        Forwarded as ``preload``.
    reject_by_annotation : bool, default False
        Forwarded as ``reject_by_annotation``.

    Returns
    -------
    The object returned by ``mne.make_fixed_length_epochs``.
    """
    # An earlier variant cropped the recording first so that it starts at 0;
    # that step is disabled and `raw` is used as-is.
    epoch_options = {
        'duration': duration,
        'overlap': overlap,
        'preload': preload,
        'reject_by_annotation': reject_by_annotation,
    }
    return mne.make_fixed_length_epochs(raw, **epoch_options)
# mne.make_fixed_length_epochs takes the sampling frequency of the data into account.

# Because the overlap can be changed in the function above,
# the following function can be changed as well.


def label_so_epochs(epochs, so_starts, so_ends, epoch_length_sec=3.0):
    """Label every epoch that overlaps at least one slow oscillation (SO).

    Epoch ``i`` is taken to span ``[i * epoch_length_sec, (i + 1) * epoch_length_sec)``,
    i.e. the epochs are assumed to be contiguous, non-overlapping and to start
    at time 0. Only ``len(epochs)`` is read from `epochs`.

    Parameters
    ----------
    epochs : sized object
        Collection of epochs; only its length is used.
    so_starts, so_ends : sequence of float
        Start and end times (seconds) of the detected slow oscillations,
        paired element-wise.
    epoch_length_sec : float, default 3.0
        Duration of one epoch in seconds.

    Returns
    -------
    numpy.ndarray of int, shape (len(epochs),)
        1 where an SO satisfies ``start < epoch_end and end > epoch_start``,
        otherwise 0.
    """
    n_epochs = len(epochs)
    epoch_starts = np.arange(n_epochs) * epoch_length_sec
    epoch_ends = epoch_starts + epoch_length_sec

    # all labels start as 0
    epoch_labels = np.zeros(n_epochs, dtype=int)

    for so_start, so_end in zip(so_starts, so_ends):
        # The SO started before the epoch ends and ended after the epoch
        # started; evaluated for all epochs at once instead of a Python loop.
        overlaps = (so_start < epoch_ends) & (so_end > epoch_starts)
        epoch_labels[overlaps] = 1

    return epoch_labels

def _label_epochs_by_so_fraction(n_epochs, so_starts, so_ends, epoch_length_sec, min_fraction):
    """Label epochs that contain at least `min_fraction` of a slow oscillation's duration.

    Epoch ``i`` is taken to span ``[i * epoch_length_sec, (i + 1) * epoch_length_sec]``.
    Returns an int array of length `n_epochs` holding 0/1 labels.
    """
    epoch_starts = np.arange(n_epochs) * epoch_length_sec
    epoch_ends = epoch_starts + epoch_length_sec
    epoch_labels = np.zeros(n_epochs, dtype=int)

    for so_start, so_end in zip(so_starts, so_ends):
        so_duration = so_end - so_start
        required_overlap = min_fraction * so_duration

        # Overlap between the SO and every epoch at once; the value is
        # negative for epochs that do not touch the SO.
        overlap_start = np.maximum(so_start, epoch_starts)
        overlap_end = np.minimum(so_end, epoch_ends)
        overlap_duration = overlap_end - overlap_start

        epoch_labels[overlap_duration >= required_overlap] = 1

    return epoch_labels


def label_so_epochs_strict(epochs, so_starts, so_ends, epoch_length_sec=3.0):
    """Label an epoch 1 only if it contains at least 80% of a slow oscillation's duration.

    Parameters
    ----------
    epochs : sized object
        Collection of epochs; only its length is used. Epoch ``i`` is assumed
        to start at ``i * epoch_length_sec``.
    so_starts, so_ends : sequence of float
        Start and end times (seconds) of the slow oscillations, paired element-wise.
    epoch_length_sec : float, default 3.0
        Duration of one epoch in seconds.

    Returns
    -------
    numpy.ndarray of int, shape (len(epochs),)
    """
    return _label_epochs_by_so_fraction(len(epochs), so_starts, so_ends, epoch_length_sec, 0.8)


def label_so_epochs_moderate(epochs, so_starts, so_ends, epoch_length_sec=3.0):
    """Label an epoch 1 only if it contains at least 50% of a slow oscillation's duration.

    Parameters
    ----------
    epochs : sized object
        Collection of epochs; only its length is used. Epoch ``i`` is assumed
        to start at ``i * epoch_length_sec``.
    so_starts, so_ends : sequence of float
        Start and end times (seconds) of the slow oscillations, paired element-wise.
    epoch_length_sec : float, default 3.0
        Duration of one epoch in seconds.

    Returns
    -------
    numpy.ndarray of int, shape (len(epochs),)
    """
    return _label_epochs_by_so_fraction(len(epochs), so_starts, so_ends, epoch_length_sec, 0.5)



In [8]:
# def label_so_epochs(epochs, so_starts, so_ends):
#     sfreq = epochs.info['sfreq']
#     epoch_starts = epochs.events[:, 0] / sfreq
#     # gives the sample index at which each epoch starts
#     # divides by sfreq converts these to start times in seconds
#     epoch_length_sec = epochs.get_data().shape[2] / sfreq
#     # gives the duration of each epoch

#     epoch_labels = np.zeros(len(epochs), dtype=int)
#     # initialize all the labels as 0 initially

#     for start, end in zip(so_starts, so_ends):
#         # loop through the start and end times of detected spindles 
#         for i, epoch_start in enumerate(epoch_starts):
#             # loop through the one-second epochs that are not labelled yet
#             epoch_end = epoch_start + epoch_length_sec
#             # for each epoch, calculate the epoch end time
#             # which is epoch_start + length of epoch
#             # so now have the time range of each epoch
#             if (start < epoch_end) and (end > epoch_start):
#                 # if the spindle started before the epoch ends
#                 # and the spindle ended after the epoch started
#                 epoch_labels[i] = 1
                
#     return epoch_labels

### Importing data

In [9]:
# file paths
# NOTE(review): absolute, machine-specific Windows paths - adjust them (or
# derive them from one configurable data directory) before running elsewhere.
train_file = r"C:\EEG DATA\combined_sets\train_raw.fif"
test_file = r"C:\EEG DATA\combined_sets\test_raw.fif"

# load raw files (preload=True is passed to the reader for both recordings)
train_raw = mne.io.read_raw_fif(train_file, preload=True)
test_raw = mne.io.read_raw_fif(test_file, preload=True)

# Disabled: cropping both recordings starting at their first time point.
#train_raw.crop(tmin=train_raw.times[0], include_tmax=True)
#test_raw.crop(tmin=test_raw.times[0], include_tmax=True)

Opening raw data file C:\EEG DATA\combined_sets\train_raw.fif...
Isotrak not found
    Range : 1470000 ... 23295072 =   2940.000 ... 46590.144 secs
Ready.
Reading 0 ... 21825072  =      0.000 ... 43650.144 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_raw.fif...
Isotrak not found
    Range : 825000 ... 17985049 =   1650.000 ... 35970.098 secs
Ready.
Reading 0 ... 17160049  =      0.000 ... 34320.098 secs...


## With raw data

### Slow oscillation detection

In [10]:
# Detect slow oscillations in the training and test recordings. With
# do_filter=True and do_downsample=True the detector band-pass filters a copy
# of channel Fz and resamples it (default rate 100) before detection.
so_train_times_raw_downsampled = detect_slow_oscillations_times(train_raw, do_filter=True, do_downsample=True)
so_test_times_raw_downsampled = detect_slow_oscillations_times(test_raw, do_filter=True, do_downsample=True)

# Split the (start, end) tuples into separate start / end sequences; fall back
# to two empty lists when nothing was detected.
so_starts_train_raw_downsampled, so_ends_train_raw_downsampled = zip(*so_train_times_raw_downsampled) if so_train_times_raw_downsampled else([],[])
so_starts_test_raw_downsampled, so_ends_test_raw_downsampled = zip(*so_test_times_raw_downsampled) if so_test_times_raw_downsampled else([],[])

# Sanity check: number of starts and ends per data set.
print(len(so_starts_train_raw_downsampled))
print(len(so_ends_train_raw_downsampled))

print(len(so_starts_test_raw_downsampled))
print(len(so_ends_test_raw_downsampled))

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.16
- Lower transition bandwidth: 0.16 Hz (-6 dB cutoff frequency: 0.08 Hz)
- Upper passband edge: 1.25 Hz
- Upper transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 2.25 Hz)
- Filter length: 10313 samples (20.626 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194

### Downsample

In [11]:
# Resample copies of both recordings to a sampling rate of 100; the original
# `train_raw` / `test_raw` objects are not modified.
train_raw_downsampled = train_raw.copy().resample(100)
test_raw_downsampled = test_raw.copy().resample(100)

# Confirm the new sampling frequency.
print(train_raw_downsampled.info['sfreq'])
print(test_raw_downsampled.info['sfreq'])

100.0
100.0


### Epoch the data

In [12]:
# Cut the downsampled recordings into 3 s epochs without overlap.
epochs_train_raw_downsampled = create_fixed_length_epochs(train_raw_downsampled, duration=3.0, overlap=0.0)
epochs_test_raw_downsampled = create_fixed_length_epochs(test_raw_downsampled, duration=3.0, overlap=0.0)
# Shape of the epoch data arrays for both sets.
print(epochs_train_raw_downsampled.get_data().shape)
print(epochs_test_raw_downsampled.get_data().shape)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped
Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped
(14550, 1, 300)
(11440, 1, 300)


In [13]:
%%time

# Label each epoch with label_so_epochs: 1 if any detected slow oscillation
# overlaps the epoch, otherwise 0.

# Train set

epoch_labels_train_raw_downsampled = label_so_epochs(epochs_train_raw_downsampled, so_starts_train_raw_downsampled, so_ends_train_raw_downsampled)

# Quick visual check: first SO start times next to the first epoch labels.
print(f"Train data first 10 SO: {so_starts_train_raw_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_raw_downsampled[:10]}")

# Test set

epoch_labels_test_raw_downsampled = label_so_epochs(epochs_test_raw_downsampled, so_starts_test_raw_downsampled, so_ends_test_raw_downsampled)

print(f"\nTest data first 10 SO: {so_starts_test_raw_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_raw_downsampled[:10]}")

Train data first 10 SO: (np.float64(7.82), np.float64(83.41), np.float64(108.08), np.float64(262.01), np.float64(274.78), np.float64(308.33), np.float64(330.51), np.float64(378.52), np.float64(394.11), np.float64(478.54))
Train labels first 10 epochs: [0 0 1 0 0 0 0 0 0 0]

Test data first 10 SO: (np.float64(65.96), np.float64(120.42), np.float64(161.02), np.float64(224.46), np.float64(257.48), np.float64(353.86), np.float64(404.85), np.float64(406.44), np.float64(440.22), np.float64(459.74))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 0]
CPU times: total: 11.4 s
Wall time: 11.4 s


In [None]:
%%time

# Label each epoch with label_so_epochs_moderate: 1 only if the epoch contains
# at least half of a slow oscillation's duration.

# Train set

epoch_labels_train_raw_downsampled_moderate = label_so_epochs_moderate(epochs_train_raw_downsampled, so_starts_train_raw_downsampled, so_ends_train_raw_downsampled)

# Quick visual check: first SO start times next to the first epoch labels.
print(f"Train data first 10 SO: {so_starts_train_raw_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_raw_downsampled_moderate[:10]}")

# Test set

epoch_labels_test_raw_downsampled_moderate = label_so_epochs_moderate(epochs_test_raw_downsampled, so_starts_test_raw_downsampled, so_ends_test_raw_downsampled)

print(f"\nTest data first 10 SO: {so_starts_test_raw_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_raw_downsampled_moderate[:10]}")

Train data first 10 SO: (np.float64(7.82), np.float64(83.41), np.float64(108.08), np.float64(262.01), np.float64(274.78), np.float64(308.33), np.float64(330.51), np.float64(378.52), np.float64(394.11), np.float64(478.54))
Train labels first 10 epochs: [0 0 1 0 0 0 0 0 0 0]


In [None]:

# Class balance of the labels produced by label_so_epochs, per data set.

# Train set
train_labels = epoch_labels_train_raw_downsampled
train_counts, train_total = np.bincount(train_labels), len(train_labels)

print("\nTrain label distribution:")
for label in range(len(train_counts)):
    count = train_counts[label]
    proportion = count / train_total
    print(f"Label {label}: count = {count}, proportion = {proportion:.2%}")

# Test set
test_labels = epoch_labels_test_raw_downsampled
test_counts, test_total = np.bincount(test_labels), len(test_labels)

print("\nTest label distribution:")
for label in range(len(test_counts)):
    count = test_counts[label]
    proportion = count / test_total
    print(f"Label {label}: count = {count}, proportion = {proportion:.2%}")


In [None]:

# Class balance of the labels produced by label_so_epochs_moderate, per data set.

# Train set
train_labels = epoch_labels_train_raw_downsampled_moderate
train_counts, train_total = np.bincount(train_labels), len(train_labels)

print("\nTrain label distribution:")
for label in range(len(train_counts)):
    count = train_counts[label]
    proportion = count / train_total
    print(f"Label {label}: count = {count}, proportion = {proportion:.2%}")

# Test set
test_labels = epoch_labels_test_raw_downsampled_moderate
test_counts, test_total = np.bincount(test_labels), len(test_labels)

print("\nTest label distribution:")
for label in range(len(test_counts)):
    count = test_counts[label]
    proportion = count / test_total
    print(f"Label {label}: count = {count}, proportion = {proportion:.2%}")


In [None]:
# Number of labelled epochs in the train and test sets.
print(len(epoch_labels_train_raw_downsampled))
print(len(epoch_labels_test_raw_downsampled))

### Prepare EEG data for CNN Inputs

#### X and y train and test sets

In [None]:
# Reshape arrays
# Convert the epochs to NumPy arrays and reshape them to
# (number of epochs, inferred length, 1).
# NOTE(review): the -1 dimension equals the number of time points only while
# each epoch holds a single channel - verify if more channels are added.

epochs_train_np_raw_downsampled = np.array(epochs_train_raw_downsampled).reshape(len(epochs_train_raw_downsampled), -1, 1)
# number of epochs N, time dimension (automatically inferred), channel dimension
epochs_test_np_raw_downsampled = np.array(epochs_test_raw_downsampled).reshape(len(epochs_test_raw_downsampled), -1, 1)
                                                                 
# Define X and y sets
# The targets are the "moderate" labels (label_so_epochs_moderate).
                                                                 
X_train_raw_downsampled = epochs_train_np_raw_downsampled
y_train_raw_downsampled = epoch_labels_train_raw_downsampled_moderate

X_test_raw_downsampled = epochs_test_np_raw_downsampled
y_test_raw_downsampled = epoch_labels_test_raw_downsampled_moderate

# Print shapes

print(f"X_train shape: {X_train_raw_downsampled.shape}")
print(f"y_train shape: {y_train_raw_downsampled.shape}")

print(f"\nX_test shape: {X_test_raw_downsampled.shape}")
print(f"y_test shape: {y_test_raw_downsampled.shape}")

#### Train the model

In [None]:
# Early-stopping callback: watches the validation loss, uses a patience of
# 5 epochs and is configured with restore_best_weights=True.
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [None]:
# show model architecture
# Build the one-input CNN for samples of shape (300, 1) and print its summary.
input_shape = (300, 1)
model_raw_downsampled_so = build_cnn_model_downsampled(input_shape)
model_raw_downsampled_so.summary()

In [None]:
%%time

# Train for at most 30 epochs (batch size 128) with validation_split=0.2 and
# the early-stopping callback defined earlier.
training_info_raw_downsampled_so = model_raw_downsampled_so.fit(X_train_raw_downsampled, y_train_raw_downsampled, validation_split=0.2, epochs=30, batch_size=128, callbacks=[early_stop])

#### Plot the training history

In [None]:
def plot_training_history(
    training_info,
    title="Training history for one-input CNN model with raw EEG data for SO detection",
):
    """Plot the loss curves and, when recorded, the accuracy curves of a training run.

    Parameters
    ----------
    training_info : object with a ``history`` mapping
        The value returned by ``model.fit``. ``history`` must contain
        ``'loss'`` and ``'val_loss'``; ``'accuracy'`` and ``'val_accuracy'``
        are plotted only if present.
    title : str, optional
        Figure title. Defaults to the title of the raw-EEG one-input model so
        existing single-argument calls render exactly as before.
    """
    history = training_info.history
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))

    # Left panel: loss is always recorded by Keras.
    axs[0].plot(history['loss'], label="training set")
    axs[0].plot(history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()

    # Right panel: accuracy exists only if the model was compiled with that
    # metric. Check for the keys explicitly instead of hiding every possible
    # error behind a bare ``except``.
    accuracy_curves = [
        (key, label)
        for key, label in (("accuracy", "training set"), ("val_accuracy", "validation set"))
        if key in history
    ]
    if accuracy_curves:
        for key, label in accuracy_curves:
            axs[1].plot(history[key], label=label)
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    else:
        # Nothing to show: hide the panel rather than leave an empty axes.
        axs[1].set_visible(False)

    fig.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()


plot_training_history(training_info_raw_downsampled_so)

In [None]:
# Loss and compiled metrics of the raw-EEG model on the held-out test epochs.
model_raw_downsampled_so.evaluate(
    x=X_test_raw_downsampled,
    y=y_test_raw_downsampled,
)

In [None]:
# Convert the model outputs to hard 0/1 labels with a 0.5 decision threshold.
y_pred = model_raw_downsampled_so.predict(X_test_raw_downsampled)
y_pred_labels = np.greater(y_pred, 0.5).astype(int)

# Text summary of test performance: confusion matrix first, then per-class metrics.
for _summarise in (confusion_matrix, classification_report):
    print(_summarise(y_test_raw_downsampled, y_pred_labels))

In [None]:
# Hard labels from the model outputs (0.5 decision threshold).
y_pred = model_raw_downsampled_so.predict(X_test_raw_downsampled)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix wrapped in a labelled frame so the heatmap axes are readable.
cm = confusion_matrix(y_test_raw_downsampled, y_pred_labels)
cm_df = pd.DataFrame(
    cm,
    index=["Actual 0", "Actual 1"],
    columns=["Predicted 0", "Predicted 1"],
)

# Classification metrics as a frame with one row per report entry.
report = classification_report(y_test_raw_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).T

# Figure 1: confusion-matrix heatmap.
plt.figure(figsize=(6, 5))
heatmap_ax = sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
heatmap_ax.set_title("Confusion Matrix for SO Detection \nusing a one-input CNN with raw EEG data")
heatmap_ax.set_ylabel("True Label")
heatmap_ax.set_xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Figure 2: the classification report rendered as a table on a blank axes.
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(
    cellText=report_df.round(2).to_numpy(),
    colLabels=report_df.columns,
    rowLabels=report_df.index,
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
ax.set_title("Classification Report for SO Detection \nusing a one-input CNN with raw EEG data", fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# roc_curve needs 1-D scores; ravel covers a (n_samples, 1) prediction array.
# Uses y_pred left by the preceding prediction cell.
y_pred_proba = y_pred.ravel()

# ROC points and the area under the curve.
fpr, tpr, thresholds = roc_curve(y_test_raw_downsampled, y_pred_proba)
roc_auc = auc(fpr, tpr)

# ROC curve against the chance diagonal.
roc_fig, roc_ax = plt.subplots()
roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("Receiver Operating Characteristic for SO Detection \nusing a one-input CNN with raw EEG data")
roc_ax.legend(loc="lower right")
roc_ax.grid(True)
plt.tight_layout()
plt.show()

## With filtered data

In [None]:
# Band-pass copies of the raw recordings to 0.16-1.25 Hz so a model trained on
# filtered data can be compared with the one trained on unfiltered data.
# Working on copies keeps train_raw / test_raw untouched.
so_band_hz = {"l_freq": 0.16, "h_freq": 1.25}
train_filtered_downsampled = train_raw.copy().filter(**so_band_hz)
test_filtered_downsampled = test_raw.copy().filter(**so_band_hz)

# Resample both filtered recordings to 100 Hz (the result is not reassigned,
# so this relies on resample modifying the object it is called on).
train_filtered_downsampled.resample(100)
test_filtered_downsampled.resample(100)

### Slow oscillation detection

In [None]:
# The recordings were already filtered and downsampled above, so the detector
# is told to skip both of its own preprocessing steps.
so_train_times_filtered_downsampled = detect_slow_oscillations_times(
    train_filtered_downsampled, do_filter=False, do_downsample=False
)
so_test_times_filtered_downsampled = detect_slow_oscillations_times(
    test_filtered_downsampled, do_filter=False, do_downsample=False
)

# Unzip the detections into parallel start/end sequences; fall back to two
# empty lists when nothing was detected (zip(*[]) cannot be unpacked).
if so_train_times_filtered_downsampled:
    so_starts_train_filtered_downsampled, so_ends_train_filtered_downsampled = zip(*so_train_times_filtered_downsampled)
else:
    so_starts_train_filtered_downsampled, so_ends_train_filtered_downsampled = [], []

if so_test_times_filtered_downsampled:
    so_starts_test_filtered_downsampled, so_ends_test_filtered_downsampled = zip(*so_test_times_filtered_downsampled)
else:
    so_starts_test_filtered_downsampled, so_ends_test_filtered_downsampled = [], []

# Starts and ends should have matching counts within each set.
for _so_times in (
    so_starts_train_filtered_downsampled,
    so_ends_train_filtered_downsampled,
    so_starts_test_filtered_downsampled,
    so_ends_test_filtered_downsampled,
):
    print(len(_so_times))

### Epoch the data

In [None]:
# Cut both filtered recordings into consecutive, non-overlapping 3-second epochs.
epochs_train_filtered_downsampled = create_fixed_length_epochs(
    train_filtered_downsampled, duration=3.0, overlap=0.0
)
epochs_test_filtered_downsampled = create_fixed_length_epochs(
    test_filtered_downsampled, duration=3.0, overlap=0.0
)

# Shape of the epoched data for each set.
for _epochs in (epochs_train_filtered_downsampled, epochs_test_filtered_downsampled):
    print(_epochs.get_data().shape)

### Labels for 3-second epochs

In [None]:
%%time

# Label every training epoch from the detected SO start/end times
# using the "moderate" labelling rule.
epoch_labels_train_filtered_downsampled_moderate = label_so_epochs_moderate(
    epochs_train_filtered_downsampled,
    so_starts_train_filtered_downsampled,
    so_ends_train_filtered_downsampled,
)

print(f"Train data first 10 slow oscillations: {so_starts_train_filtered_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_filtered_downsampled_moderate[:10]}")

# Same labelling for the test epochs.
epoch_labels_test_filtered_downsampled_moderate = label_so_epochs_moderate(
    epochs_test_filtered_downsampled,
    so_starts_test_filtered_downsampled,
    so_ends_test_filtered_downsampled,
)

print(f"\nTest data first 10 slow oscillations: {so_starts_test_filtered_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_filtered_downsampled_moderate[:10]}")

### Prepare EEG data for CNN input

#### X and y train and test sets

In [None]:
# Turn the filtered epochs into arrays shaped (n_epochs, n_values, 1):
# the first axis is pinned to the epoch count, -1 lets NumPy infer the
# flattened per-epoch axis, and the trailing 1 is the channel axis of the
# (300, 1) CNN input.
epochs_train_np_filtered_downsampled = np.array(epochs_train_filtered_downsampled).reshape(
    (len(epochs_train_filtered_downsampled), -1, 1)
)
epochs_test_np_filtered_downsampled = np.array(epochs_test_filtered_downsampled).reshape(
    (len(epochs_test_filtered_downsampled), -1, 1)
)

# Pair each input array with its per-epoch "moderate" SO labels.
X_train_filtered_downsampled, y_train_filtered_downsampled = (
    epochs_train_np_filtered_downsampled,
    epoch_labels_train_filtered_downsampled_moderate,
)
X_test_filtered_downsampled, y_test_filtered_downsampled = (
    epochs_test_np_filtered_downsampled,
    epoch_labels_test_filtered_downsampled_moderate,
)

# Report the resulting shapes so mismatched epoch counts are visible at a glance.
print(f"X_train shape: {X_train_filtered_downsampled.shape}")
print(f"y_train shape: {y_train_filtered_downsampled.shape}")

print(f"\nX_test shape: {X_test_filtered_downsampled.shape}")
print(f"y_test shape: {y_test_filtered_downsampled.shape}")

#### Train the model

In [None]:
# Fresh callback for the filtered-data run: stop after 5 consecutive epochs
# without a lower validation loss and restore the best epoch's weights.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

In [None]:
# Build a new one-input CNN for the filtered epochs, with
# (time steps, channels) = (300, 1) inputs, and print its layer summary.
input_shape = (300, 1)
model_filtered_downsampled = build_cnn_model_downsampled(input_shape=input_shape)
model_filtered_downsampled.summary()

In [None]:
%%time

# TODO: rename to training_info_filtered_downsampled_so on the next full run,
# to match the naming used for the raw-EEG model.

# Train on the filtered epochs. validation_split=0.2 holds out a fraction of
# the training arrays for validation; early_stop may end the run before epoch 30.
training_info_filtered_downsampled = model_filtered_downsampled.fit(
    X_train_filtered_downsampled,
    y_train_filtered_downsampled,
    validation_split=0.2,
    epochs=30,
    batch_size=128,
    callbacks=[early_stop],
)

#### Plot the training history

In [None]:
def plot_training_history(
    training_info,
    title="Training history for one-input CNN model with filtered EEG data for SO detection",
):
    """Plot the loss curves and, when recorded, the accuracy curves of a training run.

    Parameters
    ----------
    training_info : object with a ``history`` mapping
        The value returned by ``model.fit``. ``history`` must contain
        ``'loss'`` and ``'val_loss'``; ``'accuracy'`` and ``'val_accuracy'``
        are plotted only if present.
    title : str, optional
        Figure title. Defaults to the title of the filtered-EEG one-input
        model so existing single-argument calls render exactly as before.
    """
    history = training_info.history
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))

    # Left panel: loss is always recorded by Keras.
    axs[0].plot(history['loss'], label="training set")
    axs[0].plot(history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()

    # Right panel: accuracy exists only if the model was compiled with that
    # metric. Check for the keys explicitly instead of hiding every possible
    # error behind a bare ``except``.
    accuracy_curves = [
        (key, label)
        for key, label in (("accuracy", "training set"), ("val_accuracy", "validation set"))
        if key in history
    ]
    if accuracy_curves:
        for key, label in accuracy_curves:
            axs[1].plot(history[key], label=label)
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    else:
        # Nothing to show: hide the panel rather than leave an empty axes.
        axs[1].set_visible(False)

    fig.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()


plot_training_history(training_info_filtered_downsampled)

In [None]:
# Loss and compiled metrics of the filtered-data model on the filtered test epochs.
model_filtered_downsampled.evaluate(
    x=X_test_filtered_downsampled,
    y=y_test_filtered_downsampled,
)

In [None]:
# Score the filtered-data model on the filtered test set. It was trained on
# band-passed epochs, so feeding it the raw test epochs (as this cell
# previously did) measures performance on a different input distribution
# and makes the filtered-vs-raw comparison meaningless.
y_pred = model_filtered_downsampled.predict(X_test_filtered_downsampled)
# Hard 0/1 labels with a 0.5 decision threshold.
y_pred_labels = (y_pred > 0.5).astype(int)

# Compare against the labels that belong to the filtered epochs.
print(confusion_matrix(y_test_filtered_downsampled, y_pred_labels))
print(classification_report(y_test_filtered_downsampled, y_pred_labels))

In [None]:
# Predict on the filtered test epochs. The model was trained on band-passed
# data and the metrics below use the filtered labels, so the inputs must come
# from the same filtered test set (this cell previously predicted on the raw
# test epochs).
y_pred = model_filtered_downsampled.predict(X_test_filtered_downsampled)
# Hard 0/1 labels with a 0.5 decision threshold.
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix wrapped in a labelled frame so the heatmap axes are readable.
cm = confusion_matrix(y_test_filtered_downsampled, y_pred_labels)
cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Predicted 0", "Predicted 1"])

# Classification metrics as a frame with one row per report entry.
report = classification_report(y_test_filtered_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Figure 1: confusion-matrix heatmap.
plt.figure(figsize=(6, 5))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix for SO Detection \nusing a one-input CNN with filtered EEG data")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Figure 2: the classification report rendered as a table on a blank axes.
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=report_df.round(2).values,
                 colLabels=report_df.columns,
                 rowLabels=report_df.index,
                 cellLoc='center',
                 loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
plt.title("Classification Report for SO Detection \nusing a one-input CNN with filtered EEG data", fontsize=14)
plt.tight_layout()
plt.show()


## With STFT

In [None]:
# Detect slow oscillations from the unprocessed recordings, letting the
# detector do its own filtering and downsampling this time.
so_train_times_stft_downsampled = detect_slow_oscillations_times(
    train_raw, do_filter=True, do_downsample=True
)
so_test_times_stft_downsampled = detect_slow_oscillations_times(
    test_raw, do_filter=True, do_downsample=True
)

In [None]:
# Unzip the detections into parallel start/end sequences; fall back to two
# empty lists when nothing was detected (zip(*[]) cannot be unpacked).
if so_train_times_stft_downsampled:
    so_starts_train_stft_downsampled, so_ends_train_stft_downsampled = zip(*so_train_times_stft_downsampled)
else:
    so_starts_train_stft_downsampled, so_ends_train_stft_downsampled = [], []

if so_test_times_stft_downsampled:
    so_starts_test_stft_downsampled, so_ends_test_stft_downsampled = zip(*so_test_times_stft_downsampled)
else:
    so_starts_test_stft_downsampled, so_ends_test_stft_downsampled = [], []

# Starts and ends should have matching counts within each set.
for _so_times in (
    so_starts_train_stft_downsampled,
    so_ends_train_stft_downsampled,
    so_starts_test_stft_downsampled,
    so_ends_test_stft_downsampled,
):
    print(len(_so_times))

### Epoch the data

In [None]:
# Cut the downsampled (unfiltered) recordings into consecutive,
# non-overlapping 3-second epochs; these feed the STFT below.
epochs_train_stft_downsampled = create_fixed_length_epochs(
    train_raw_downsampled, duration=3.0, overlap=0.0
)
epochs_test_stft_downsampled = create_fixed_length_epochs(
    test_raw_downsampled, duration=3.0, overlap=0.0
)

# Shape of the epoched data for each set.
for _epochs in (epochs_train_stft_downsampled, epochs_test_stft_downsampled):
    print(_epochs.get_data().shape)

In [None]:
%%time

# Label every training epoch from the detected SO start/end times
# using the "moderate" labelling rule.
epoch_labels_train_stft_downsampled_moderate = label_so_epochs_moderate(
    epochs_train_stft_downsampled,
    so_starts_train_stft_downsampled,
    so_ends_train_stft_downsampled,
)

print(f"Train data first 10 SO: {so_starts_train_stft_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_stft_downsampled_moderate[:10]}")

# Same labelling for the test epochs.
epoch_labels_test_stft_downsampled_moderate = label_so_epochs_moderate(
    epochs_test_stft_downsampled,
    so_starts_test_stft_downsampled,
    so_ends_test_stft_downsampled,
)

print(f"\nTest data first 10 SO: {so_starts_test_stft_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_stft_downsampled_moderate[:10]}")

### Apply STFT to each epoch

##### get rid of channel dimension

In [None]:
# Drop all length-1 axes (the channel axis) from the epoched data.
# NOTE: this rebinds the names from epochs objects to plain arrays, so cells
# above that call .get_data() on them cannot be re-run after this one.
epochs_train_stft_downsampled = np.squeeze(epochs_train_stft_downsampled)
epochs_test_stft_downsampled = np.squeeze(epochs_test_stft_downsampled)

for _squeezed in (epochs_train_stft_downsampled, epochs_test_stft_downsampled):
    print(_squeezed.shape)

In [None]:
# Sampling rate of the downsampled training recording, as stored in its info.
sampling_rate_hz = train_raw_downsampled.info['sfreq']
print(sampling_rate_hz)

In [None]:
fs = train_raw_downsampled.info['sfreq']

# STFT window settings.
# - A smaller nperseg gives finer time resolution and coarser frequency
#   resolution; frequency bins are spaced fs / nperseg apart.
# - noverlap must be smaller than nperseg; 50% overlap is the usual choice.
# NOTE(review): if fs is 100 Hz, nperseg = 50 gives 2 Hz bins, which is wider
# than the 0.16-1.25 Hz band used for the filtered model - confirm this
# resolution is intended for SO detection.
nperseg = 50
noverlap = nperseg // 2

# stft works along the last axis, so each (epochs, samples) stack is
# transformed in one call; the magnitude of the complex result is the
# spectrogram of every epoch.
f, t, Zxx = stft(epochs_train_stft_downsampled, fs=fs, nperseg=nperseg, noverlap=noverlap)
epochs_train_stft_transformed_downsampled = np.abs(Zxx)

f, t, Zxx = stft(epochs_test_stft_downsampled, fs=fs, nperseg=nperseg, noverlap=noverlap)
epochs_test_stft_transformed_downsampled = np.abs(Zxx)

print("Train STFT shape:", epochs_train_stft_transformed_downsampled.shape)
print("Test STFT shape:", epochs_test_stft_transformed_downsampled.shape)

# Layout of both arrays: (number of epochs, frequency_bins, time_bins)


#### Define X and y sets

In [None]:
# Append a trailing singleton axis to the spectrogram stacks, giving inputs
# shaped (n_epochs, n_freq_bins, n_time_bins, 1), and pair them with the
# per-epoch "moderate" SO labels.
X_train_stft_downsampled = np.expand_dims(epochs_train_stft_transformed_downsampled, axis=-1)
y_train_stft_downsampled = epoch_labels_train_stft_downsampled_moderate

X_test_stft_downsampled = np.expand_dims(epochs_test_stft_transformed_downsampled, axis=-1)
y_test_stft_downsampled = epoch_labels_test_stft_downsampled_moderate

# Report the resulting shapes.
print(f"X_train shape: {X_train_stft_downsampled.shape}")
print(f"y_train shape: {y_train_stft_downsampled.shape}")

print(f"\nX_test shape: {X_test_stft_downsampled.shape}")
print(f"y_test shape: {y_test_stft_downsampled.shape}")

### Normalization of data

In [None]:
def _min_max_scale(spectrogram):
    """Min-max scale one spectrogram by its own range.

    The 1e-8 added to the range avoids a division by zero for a constant
    spectrogram, so the maximum ends up marginally below 1.
    """
    lowest = np.min(spectrogram)
    highest = np.max(spectrogram)
    return (spectrogram - lowest) / (highest - lowest + 1e-8)


# Normalise every epoch independently, using only that epoch's min and max.
X_train_stft_norm_downsampled = np.array(
    [_min_max_scale(epoch) for epoch in X_train_stft_downsampled]
)
X_test_stft_norm_downsampled = np.array(
    [_min_max_scale(epoch) for epoch in X_test_stft_downsampled]
)

In [None]:
# Check the value range before and after min-max scaling; the normalised
# arrays should lie between 0 and 1.
for _heading, _train_values, _test_values in (
    ("Before normalisation:", X_train_stft_downsampled, X_test_stft_downsampled),
    ("\nAfter normalisation:", X_train_stft_norm_downsampled, X_test_stft_norm_downsampled),
):
    print(_heading)
    print("Max train value:", np.max(_train_values))
    print("Min train value:", np.min(_train_values))

    print("Max test value:", np.max(_test_values))
    print("Min test value:", np.min(_test_values))

## Three-input model

### Reducing STFT dimension

#### Only keeping time axis

In [None]:
# Average the normalised spectrograms over axis 2.
# NOTE(review): the STFT cell documents its output layout as
# (epochs, frequency_bins, time_bins), and a trailing channel axis is appended
# afterwards, so axis=2 is the time-bin axis. Averaging over it keeps the
# frequency axis, which contradicts the "_time" name and the
# "Only keeping time axis" heading. If a per-time-bin profile is intended this
# should be axis=1 - confirm, and check the input shape the multi-input model
# expects before changing it.
X_train_stft_time = np.mean(X_train_stft_norm_downsampled, axis=2)
X_test_stft_time =  np.mean(X_test_stft_norm_downsampled, axis=2)

print(f"Shape of X_train_stft_time:{X_train_stft_time.shape}")
print(f"Shape of X_test_stft_time:{X_test_stft_time.shape}")

#### Only keeping frequency axis

In [None]:
# Average the normalised spectrograms over axis 1.
# NOTE(review): the STFT cell documents its output layout as
# (epochs, frequency_bins, time_bins), and a trailing channel axis is appended
# afterwards, so axis=1 is the frequency-bin axis. Averaging over it keeps the
# time axis, which contradicts the "_freq" name and the
# "Only keeping frequency axis" heading. If a per-frequency profile is
# intended this should be axis=2 - confirm, and check the 'stft_input' shape
# the multi-input model expects before changing it.
X_train_stft_freq = np.mean(X_train_stft_norm_downsampled, axis=1) 
X_test_stft_freq = np.mean(X_test_stft_norm_downsampled, axis=1)

print(f"Shape of X_train_stft_freq:{X_train_stft_freq.shape}")
print(f"Shape of X_test_stft_freq:{X_test_stft_freq.shape}")

### 3-input model with frequency for the STFT

In [None]:
# Inputs of the three-input model, keyed by input name: raw epochs,
# band-passed epochs and the reduced STFT array. All three arrays must
# describe the same epochs in the same order.
X_train_dict_freq = dict(
    raw_input=X_train_raw_downsampled,
    filtered_input=X_train_filtered_downsampled,
    stft_input=X_train_stft_freq,
)

X_test_dict_freq = dict(
    raw_input=X_test_raw_downsampled,
    filtered_input=X_test_filtered_downsampled,
    stft_input=X_test_stft_freq,
)

In [None]:
# Instantiate the three-input CNN with its default configuration
# and print its layer-by-layer summary.
cnn_model_multi_input_freq = build_multi_input_cnn_model_freq()

cnn_model_multi_input_freq.summary()

In [None]:
# Fresh callback for the three-input run: stop after 5 consecutive epochs
# without a lower validation loss and restore the best epoch's weights.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

In [None]:
%%time

# Train the three-input model on the input dict, using the raw-section labels
# as targets. validation_split=0.2 holds out a fraction of the training data;
# early_stop may end the run before epoch 20.
training_info_multiple_inputs_freq = cnn_model_multi_input_freq.fit(
    X_train_dict_freq,
    y_train_raw_downsampled,
    validation_split=0.2,
    epochs=20,
    batch_size=128,
    callbacks=[early_stop],
)

In [None]:
def plot_training_history(
    training_info,
    title="Training History for three-input CNN model with frequency component of STFT for SO detection",
):
    """Plot the loss curves and, when recorded, the accuracy curves of a training run.

    Parameters
    ----------
    training_info : object with a ``history`` mapping
        The value returned by ``model.fit``. ``history`` must contain
        ``'loss'`` and ``'val_loss'``; ``'accuracy'`` and ``'val_accuracy'``
        are plotted only if present.
    title : str, optional
        Figure title. Defaults to the title of the three-input model so
        existing single-argument calls render exactly as before.
    """
    history = training_info.history
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))

    # Left panel: loss is always recorded by Keras.
    axs[0].plot(history['loss'], label="training set")
    axs[0].plot(history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()

    # Right panel: accuracy exists only if the model was compiled with that
    # metric. Check for the keys explicitly instead of hiding every possible
    # error behind a bare ``except``.
    accuracy_curves = [
        (key, label)
        for key, label in (("accuracy", "training set"), ("val_accuracy", "validation set"))
        if key in history
    ]
    if accuracy_curves:
        for key, label in accuracy_curves:
            axs[1].plot(history[key], label=label)
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    else:
        # Nothing to show: hide the panel rather than leave an empty axes.
        axs[1].set_visible(False)

    fig.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()


plot_training_history(training_info_multiple_inputs_freq)

In [None]:
# Loss and compiled metrics of the three-input model on the held-out test epochs.
cnn_model_multi_input_freq.evaluate(
    x=X_test_dict_freq,
    y=y_test_raw_downsampled,
)

In [None]:
# Convert the model outputs to hard 0/1 labels with a 0.5 decision threshold.
y_pred = cnn_model_multi_input_freq.predict(X_test_dict_freq)
y_pred_labels = np.greater(y_pred, 0.5).astype(int)

# Text summary of test performance: confusion matrix first, then per-class metrics.
for _summarise in (confusion_matrix, classification_report):
    print(_summarise(y_test_raw_downsampled, y_pred_labels))

In [None]:
# Hard labels from the model outputs (0.5 decision threshold).
y_pred = cnn_model_multi_input_freq.predict(X_test_dict_freq)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix wrapped in a labelled frame so the heatmap axes are readable.
cm = confusion_matrix(y_test_raw_downsampled, y_pred_labels)
cm_df = pd.DataFrame(
    cm,
    index=["Actual 0", "Actual 1"],
    columns=["Predicted 0", "Predicted 1"],
)

# Classification metrics as a frame with one row per report entry.
report = classification_report(y_test_raw_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).T

# Figure 1: confusion-matrix heatmap.
plt.figure(figsize=(6, 5))
heatmap_ax = sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
heatmap_ax.set_title("Confusion Matrix for SO Detection \nusing a three-input CNN with frequency \ncomponent of STFT")
heatmap_ax.set_ylabel("True Label")
heatmap_ax.set_xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Figure 2: the classification report rendered as a table on a blank axes.
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(
    cellText=report_df.round(2).to_numpy(),
    colLabels=report_df.columns,
    rowLabels=report_df.index,
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
ax.set_title("Classification Report for SO Detection \nusing a three-input CNN with frequency \ncomponent of STFT", fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# roc_curve needs 1-D scores; ravel covers a (n_samples, 1) prediction array.
# Uses y_pred left by the preceding prediction cell.
y_pred_proba = y_pred.ravel()

# ROC points and the area under the curve.
fpr, tpr, thresholds = roc_curve(y_test_raw_downsampled, y_pred_proba)
roc_auc = auc(fpr, tpr)

# ROC curve against the chance diagonal.
roc_fig, roc_ax = plt.subplots()
roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("Receiver Operating Characteristic for a three-input CNN model \nwith the STFT frequency component for SO detection")
roc_ax.legend(loc="lower right")
roc_ax.grid(True)
plt.tight_layout()
plt.show()