In [1]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa
import os
import random

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc
from sklearn.utils import class_weight
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

import scipy.signal as signal
from scipy.signal import hilbert
from scipy.signal import stft

import pywt
import cv2


SEED = 56

os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

In [2]:
%matplotlib qt

### CNN one input Models

In [3]:
def build_cnn_model_downsampled(input_shape=(300,1)):

    # linear embedding layer
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Three convolutional blocks (like having three pattern detectors)

    # First convolution block, kernel size of 5
    padded1 = tf.keras.layers.ZeroPadding1D(padding=2)(input_layer)
    conv1 = tf.keras.layers.Conv1D(filters=10, kernel_size=5, strides=1, padding='valid')(padded1)
    # each filter learns a different type of short-time feature
    # stride of 1, moves one step at a time
    conv1 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv1)
    conv1 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv1)
    # K = 2
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # Second convolution block, kernel size of 11
    padded2 = tf.keras.layers.ZeroPadding1D(padding=5)(input_layer)
    conv2 = tf.keras.layers.Conv1D(filters=10, kernel_size=11, strides=1, padding='valid')(padded2)
    conv2 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv2)
    conv2 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv2)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # Third convolution block, kernel size of 21
    padded3 = tf.keras.layers.ZeroPadding1D(padding=10)(input_layer)
    conv3 = tf.keras.layers.Conv1D(filters=10, kernel_size=21, strides=1, padding='valid')(padded3)
    conv3 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv3)
    conv3 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv3)
    conv3 = tf.keras.layers.BatchNormalization()(conv3)

    # Concatenate the outputs of all blocks
    concatenated = tf.keras.layers.Concatenate()([conv1, conv2, conv3])

    # GRU Layer
    gru = tf.keras.layers.GRU(64)(concatenated)

    # Fully connected (dense) layer
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)
    # add a Dropout layer to prevent overfitting
    #dense = tf.keras.layers.Dropout(0.5)(dense)

    # Two softmax outputs for dual-task classification
    #output_task1 = tf.keras.layers.Dense(2, activation='softmax', name='task1')(dense)
    #output_task2 = tf.keras.layers.Dense(2, activation='softmax', name='task2')(dense)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    # Create the model
    #model = tf.keras.models.Model(inputs=input_layer, outputs=[output_task1, output_task2])
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)

    # Compile the model
    #model.compile(optimizer='adam', loss={'task1': 'categorical_crossentropy', 'task2': 'categorical_crossentropy'}, metrics={'task1': 'accuracy', 'task2': 'accuracy'})
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Return the compiled model
    return model

In [4]:
def build_multi_input_cnn_model_freq():
    # Inputs
    input_raw = tf.keras.Input(shape=(300, 1), name='raw_input')
    input_filtered = tf.keras.Input(shape=(300, 1), name='filtered_input')
    input_stft = tf.keras.Input(shape=(13, 1), name='stft_input')  

    def conv_branch(input_layer, kernel_sizes=[5, 11, 21]):
        outputs = []
        for k in kernel_sizes:
            pad = k // 2
            x = tf.keras.layers.ZeroPadding1D(padding=pad)(input_layer)
            x = tf.keras.layers.Conv1D(filters=10, kernel_size=k, strides=1, padding='valid')(x)
            x = tf.keras.layers.LeakyReLU(negative_slope=0.01)(x)
            x = tf.keras.layers.MaxPooling1D(pool_size=2)(x)
            x = tf.keras.layers.BatchNormalization()(x)
            outputs.append(x)
        return tf.keras.layers.Concatenate()(outputs)

    # Convolutional branches
    branch_raw = conv_branch(input_raw)
    branch_filtered = conv_branch(input_filtered)
    branch_stft = conv_branch(input_stft)

    # Each branch through its own GRU
    gru_raw = tf.keras.layers.GRU(64)(branch_raw)
    gru_filtered = tf.keras.layers.GRU(64)(branch_filtered)
    gru_stft = tf.keras.layers.GRU(64)(branch_stft)

    # Concatenate GRU outputs (fixed-length vectors)
    merged = tf.keras.layers.Concatenate()([gru_raw, gru_filtered, gru_stft])

    # Dense layers
    x = tf.keras.layers.Dense(64, activation='relu')(merged)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    # Build model
    model = tf.keras.Model(inputs=[input_raw, input_filtered, input_stft], outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

### Spindle detection function

In [5]:
def detect_spindles_times(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)

    if do_downsample:
        data.resample(downsample_rate)
    
    sfreq = data.info['sfreq']  
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    # threshold is 75th percentile of the smoothed envelope
    # will look at the duration later
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                # so starting from the second index
                # and comparing each index to the one before
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    spindles.append((start_idx / sfreq, end_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((start_idx / sfreq, end_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)
    
    return spindles
    

def detect_spindles_peaks(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)
    
    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']  
    # update to new sampling frequency
    # because used later in the code
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    # 75th percentile as criteria

    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    #spindles.append(f"Spindle detected from {start_idx / sfreq:.2f}s to {end_idx / sfreq:.2f}s, peak at {peak_idx / sfreq:.2f}s")
                    spindles.append((peak_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((peak_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)

    
    return spindles

### Epochs function

In [6]:
def create_fixed_length_epochs(raw, duration=3.0, overlap=0.0, preload=True, reject_by_annotation=False):

    return mne.make_fixed_length_epochs(
        raw,
        duration=duration,
        overlap=overlap,
        preload=preload,
        reject_by_annotation=reject_by_annotation
    )
# function mne.make_fixed_length_epochs takes into account the sampling frequency of the data


def label_spindle_epochs_moderate(epochs, spindle_starts, spindle_ends, epoch_length_sec=3.0):
    epoch_starts = np.arange(len(epochs)) * epoch_length_sec
    epoch_labels = np.zeros(len(epochs), dtype=int)

    for spindle_start, spindle_end in zip(spindle_starts, spindle_ends):
        spindle_duration = spindle_end - spindle_start
        required_overlap = 0.5 * spindle_duration  
        # only label 1 if epoch contains 50% of the spindle duration

        for i, epoch_start in enumerate(epoch_starts):
            epoch_end = epoch_start + epoch_length_sec

            # Calculate overlap between spindle and epoch
            overlap_start = max(spindle_start, epoch_start)
            overlap_end = min(spindle_end, epoch_end)
            overlap_duration = overlap_end - overlap_start

            if overlap_duration >= required_overlap:
                epoch_labels[i] = 1

    return epoch_labels

### Importing data

In [7]:
# file paths
train_file = r"C:\EEG DATA\combined_sets\train_raw.fif"
test_file = r"C:\EEG DATA\combined_sets\test_raw.fif"

# load raw files
train_raw = mne.io.read_raw_fif(train_file, preload=True)
test_raw = mne.io.read_raw_fif(test_file, preload=True)

Opening raw data file C:\EEG DATA\combined_sets\train_raw.fif...
Isotrak not found
    Range : 1470000 ... 23295072 =   2940.000 ... 46590.144 secs
Ready.
Reading 0 ... 21825072  =      0.000 ... 43650.144 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_raw.fif...
Isotrak not found
    Range : 825000 ... 17985049 =   1650.000 ... 35970.098 secs
Ready.
Reading 0 ... 17160049  =      0.000 ... 34320.098 secs...


## With raw data

### Spindle detection

In [8]:
spindles_train_times_raw_downsampled = detect_spindles_times(train_raw, do_filter=True, do_downsample=True)
spindles_test_times_raw_downsampled = detect_spindles_times(test_raw, do_filter=True, do_downsample=True)

spindles_starts_train_raw_downsampled, spindles_ends_train_raw_downsampled = zip(*spindles_train_times_raw_downsampled) if spindles_train_times_raw_downsampled else([],[])
spindles_starts_test_raw_downsampled, spindles_ends_test_raw_downsampled = zip(*spindles_test_times_raw_downsampled) if spindles_test_times_raw_downsampled else([],[])

print(len(spindles_starts_train_raw_downsampled))
print(len(spindles_ends_train_raw_downsampled))

print(len(spindles_starts_test_raw_downsampled))
print(len(spindles_ends_test_raw_downsampled))

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pass

### Downsample

In [9]:
train_raw_downsampled = train_raw.copy().resample(100)
test_raw_downsampled = test_raw.copy().resample(100)

print(train_raw_downsampled.info['sfreq'])
print(test_raw_downsampled.info['sfreq'])

100.0
100.0


### Epoch the data

In [10]:
epochs_train_raw_downsampled = create_fixed_length_epochs(train_raw_downsampled)
epochs_test_raw_downsampled = create_fixed_length_epochs(test_raw_downsampled)
print(epochs_train_raw_downsampled.get_data().shape)
print(epochs_test_raw_downsampled.get_data().shape)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped
Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped
(14550, 1, 300)
(11440, 1, 300)


### Labels for 3-second epochs

In [11]:
%%time

# Train set

epoch_labels_train_raw_downsampled_moderate = label_spindle_epochs_moderate(epochs_train_raw_downsampled, spindles_starts_train_raw_downsampled, spindles_ends_train_raw_downsampled)

print(f"Train data first 10 spindles: {spindles_starts_train_raw_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_raw_downsampled_moderate[:10]}")

# Test set

epoch_labels_test_raw_downsampled_moderate = label_spindle_epochs_moderate(epochs_test_raw_downsampled, spindles_starts_test_raw_downsampled, spindles_ends_test_raw_downsampled)

print(f"\nTest data first 10 spindles: {spindles_starts_test_raw_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_raw_downsampled_moderate[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 0 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 0]
CPU times: total: 1min 23s
Wall time: 1min 23s


### Prepare EEG data for CNN input

#### X and y train and test sets

In [12]:
# Reshape arrays

epochs_train_np_raw_downsampled = np.array(epochs_train_raw_downsampled).reshape(len(epochs_train_raw_downsampled), -1, 1)
# number of epochs N, sampling frequency (time dimension automatically inferred), channel dimension
epochs_test_np_raw_downsampled = np.array(epochs_test_raw_downsampled).reshape(len(epochs_test_raw_downsampled), -1, 1)
                                                                 
# Define X and y sets
                                                                 
X_train_raw_downsampled = epochs_train_np_raw_downsampled
y_train_raw_downsampled = epoch_labels_train_raw_downsampled_moderate

X_test_raw_downsampled = epochs_test_np_raw_downsampled
y_test_raw_downsampled = epoch_labels_test_raw_downsampled_moderate

# Print shapes

print(f"X_train shape: {X_train_raw_downsampled.shape}")
print(f"y_train shape: {y_train_raw_downsampled.shape}")

print(f"\nX_test shape: {X_test_raw_downsampled.shape}")
print(f"y_test shape: {y_test_raw_downsampled.shape}")

X_train shape: (14550, 300, 1)
y_train shape: (14550,)

X_test shape: (11440, 300, 1)
y_test shape: (11440,)


### Train the model

In [13]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [14]:
# show model architecture
input_shape = (300, 1)
model_raw_downsampled = build_cnn_model_downsampled(input_shape)
model_raw_downsampled.summary()



In [15]:
%%time

training_info_raw_downsampled = model_raw_downsampled.fit(X_train_raw_downsampled, y_train_raw_downsampled, validation_split=0.2, epochs=20, batch_size=128, callbacks=[early_stop])

Epoch 1/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 118ms/step - accuracy: 0.5338 - loss: 0.6915 - val_accuracy: 0.6093 - val_loss: 0.6750
Epoch 2/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 111ms/step - accuracy: 0.5634 - loss: 0.6833 - val_accuracy: 0.6605 - val_loss: 0.6461
Epoch 3/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 119ms/step - accuracy: 0.5973 - loss: 0.6691 - val_accuracy: 0.6801 - val_loss: 0.5634
Epoch 4/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 118ms/step - accuracy: 0.5934 - loss: 0.6623 - val_accuracy: 0.6378 - val_loss: 0.6628
Epoch 5/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 124ms/step - accuracy: 0.6098 - loss: 0.6523 - val_accuracy: 0.6485 - val_loss: 0.6534
Epoch 6/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 117ms/step - accuracy: 0.5574 - loss: 0.6845 - val_accuracy: 0.6577 - val_loss: 0.6570
Epoch 7/20
[1m91/91[

#### Plot the training history

In [16]:
def plot_training_history(training_info):
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))
    axs[0].plot(training_info.history['loss'], label="training set")
    axs[0].plot(training_info.history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()
    try:
        axs[1].plot(training_info.history['accuracy'], label="training set")
        axs[1].plot(training_info.history['val_accuracy'], label="validation set")
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    except:
        pass
  
    fig.suptitle("Training history for one-input CNN model with raw EEG data for spindle detection", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

plot_training_history(training_info_raw_downsampled)

#### Evaluation on test set

In [17]:
model_raw_downsampled.evaluate(X_test_raw_downsampled, y_test_raw_downsampled)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5792 - loss: 0.6605


[0.6644614338874817, 0.5788461565971375]

In [18]:
y_pred = model_raw_downsampled.predict(X_test_raw_downsampled)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_raw_downsampled, y_pred_labels))
print(classification_report(y_test_raw_downsampled, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step
[[5755  559]
 [4259  867]]
              precision    recall  f1-score   support

           0       0.57      0.91      0.70      6314
           1       0.61      0.17      0.26      5126

    accuracy                           0.58     11440
   macro avg       0.59      0.54      0.48     11440
weighted avg       0.59      0.58      0.51     11440



In [19]:
# Get predictions
y_pred = model_raw_downsampled.predict(X_test_raw_downsampled)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix
cm = confusion_matrix(y_test_raw_downsampled, y_pred_labels)
cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Predicted 0", "Predicted 1"])

# Classification report as a dataframe
report = classification_report(y_test_raw_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Confusion matrix plotted as heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix for Spindle Detection \nusing a one-input CNN with raw EEG data")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Classification report as a table
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=report_df.round(2).values,
                 colLabels=report_df.columns,
                 rowLabels=report_df.index,
                 cellLoc='center',
                 loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
plt.title("Classification Report for Spindle Detection \nusing a one-input CNN with raw EEG data", fontsize=14)
plt.tight_layout()
plt.show()

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 15ms/step


In [20]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_raw_downsampled, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic for Spindle Detection \nusing a one-input CNN with raw EEG data")
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

## With filtered data

In [21]:
# Apply bandpass filter between 12 and 16 Hz
# to compare performance of filtered dataset to unfiltered one
train_filtered = train_raw.copy().filter(l_freq=12, h_freq=16)
test_filtered = test_raw.copy().filter(l_freq=12, h_freq=16)

# Downsample to 100 Hz
train_filtered.resample(100)
test_filtered.resample(100)

Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passban

Unnamed: 0,General,General.1
,Filename(s),test_raw.fif
,MNE object type,Raw
,Measurement date,2023-09-06 at 23:28:45 UTC
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,09:32:01 (HH:MM:SS)
,Sampling frequency,100.00 Hz
,Time points,3432010
,Channels,Channels


### Spindle detection

In [22]:
spindles_train_times_filtered = detect_spindles_times(train_filtered, do_filter=False, do_downsample=False)
spindles_test_times_filtered = detect_spindles_times(test_filtered, do_filter=False, do_downsample=False)
# since filtering and downsampling before, do not filter and downsample again in function

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


In [23]:
spindles_starts_train_filtered, spindles_ends_train_filtered = zip(*spindles_train_times_filtered) if spindles_train_times_filtered else([],[])
spindles_starts_test_filtered, spindles_ends_test_filtered = zip(*spindles_test_times_filtered) if spindles_test_times_filtered else([],[])

In [24]:
print(len(spindles_starts_train_filtered))
print(len(spindles_ends_train_filtered))

print(len(spindles_starts_test_filtered))
print(len(spindles_ends_test_filtered))

7123
7123
5794
5794


### Epoch the data

In [25]:
epochs_train_filtered = create_fixed_length_epochs(train_filtered)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped


In [26]:
epochs_test_filtered = create_fixed_length_epochs(test_filtered)

Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped


In [27]:
print(epochs_train_filtered.get_data().shape)
print(epochs_test_filtered.get_data().shape)

(14550, 1, 300)
(11440, 1, 300)


### Labels for 3-second epochs

In [28]:
%%time

# Train set

epoch_labels_train_filtered_moderate = label_spindle_epochs_moderate(epochs_train_filtered, spindles_starts_train_filtered, spindles_ends_train_filtered)

print(f"Train data first 10 spindles: {spindles_starts_train_filtered[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_filtered_moderate[:10]}")

# Test set

epoch_labels_test_filtered_moderate = label_spindle_epochs_moderate(epochs_test_filtered, spindles_starts_test_filtered, spindles_ends_test_filtered)

print(f"\nTest data first 10 spindles: {spindles_starts_test_filtered[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_filtered_moderate[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 0 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 0]
CPU times: total: 1min 21s
Wall time: 1min 21s


### Prepare EEG data for CNN input

#### X and y train and test sets

In [29]:
# Reshape arrays

epochs_train_np_filtered = np.array(epochs_train_filtered).reshape(len(epochs_train_filtered), -1, 1)
# number of epochs N, sampling frequency (time dimension automatically inferred), channel dimension
epochs_test_np_filtered = np.array(epochs_test_filtered).reshape(len(epochs_test_filtered), -1, 1)
                                                                 
# Define X and y sets
                                                                 
X_train_filtered = epochs_train_np_filtered
y_train_filtered = epoch_labels_train_filtered_moderate

X_test_filtered = epochs_test_np_filtered
y_test_filtered = epoch_labels_test_filtered_moderate

# Print shapes

print(f"X_train shape: {X_train_filtered.shape}")
print(f"y_train shape: {y_train_filtered.shape}")

print(f"\nX_test shape: {X_test_filtered.shape}")
print(f"y_test shape: {y_test_filtered.shape}")
                                                                 

X_train shape: (14550, 300, 1)
y_train shape: (14550,)

X_test shape: (11440, 300, 1)
y_test shape: (11440,)


#### Training

In [30]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [31]:
# show model architecture
input_shape = (300, 1)
model_filtered_downsampled = build_cnn_model_downsampled(input_shape)
model_filtered_downsampled.summary()



In [32]:
%%time

training_info_filtered_downsampled = model_filtered_downsampled.fit(X_train_filtered, y_train_filtered, validation_split=0.2, epochs=20, batch_size=128, callbacks=[early_stop])

Epoch 1/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 116ms/step - accuracy: 0.5944 - loss: 0.6711 - val_accuracy: 0.6076 - val_loss: 0.6743
Epoch 2/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 112ms/step - accuracy: 0.6373 - loss: 0.6473 - val_accuracy: 0.6021 - val_loss: 0.6717
Epoch 3/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 112ms/step - accuracy: 0.6030 - loss: 0.6644 - val_accuracy: 0.6436 - val_loss: 0.6531
Epoch 4/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 115ms/step - accuracy: 0.6140 - loss: 0.6601 - val_accuracy: 0.6522 - val_loss: 0.6443
Epoch 5/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 110ms/step - accuracy: 0.6148 - loss: 0.6574 - val_accuracy: 0.6601 - val_loss: 0.6388
Epoch 6/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 131ms/step - accuracy: 0.6199 - loss: 0.6541 - val_accuracy: 0.6708 - val_loss: 0.6302
Epoch 7/20
[1m91/91[

In [33]:
def plot_training_history(training_info):
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))
    axs[0].plot(training_info.history['loss'], label="training set")
    axs[0].plot(training_info.history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()
    try:
        axs[1].plot(training_info.history['accuracy'], label="training set")
        axs[1].plot(training_info.history['val_accuracy'], label="validation set")
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    except:
        pass
  
    fig.suptitle("Training history for one-input CNN model with filtered EEG data for spindle detection", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

plot_training_history(training_info_filtered_downsampled)

In [34]:
model_filtered_downsampled.evaluate(X_test_filtered, y_test_filtered)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.9201 - loss: 0.2041


[0.21022364497184753, 0.9190559387207031]

### Metrics

In [35]:
y_pred = model_filtered_downsampled.predict(X_test_filtered)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_filtered, y_pred_labels))
print(classification_report(y_test_filtered, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 15ms/step
[[6103  211]
 [ 715 4411]]
              precision    recall  f1-score   support

           0       0.90      0.97      0.93      6314
           1       0.95      0.86      0.91      5126

    accuracy                           0.92     11440
   macro avg       0.92      0.91      0.92     11440
weighted avg       0.92      0.92      0.92     11440



In [36]:
# Get predictions
y_pred = model_filtered_downsampled.predict(X_test_filtered)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix
cm = confusion_matrix(y_test_raw_downsampled, y_pred_labels)
cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Predicted 0", "Predicted 1"])

# Classification report as a dataframe
report = classification_report(y_test_raw_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Confusion matrix plotted as heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix for Spindle Detection \nusing a one-input CNN with filtered EEG data (12-16 Hz)")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Classification report as a table
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=report_df.round(2).values,
                 colLabels=report_df.columns,
                 rowLabels=report_df.index,
                 cellLoc='center',
                 loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
plt.title("Classification Report for Spindle Detection \nusing a one-input CNN with filtered EEG data (12-16 Hz)", fontsize=14)
plt.tight_layout()
plt.show()


[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step


In [37]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_filtered, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic for Spindle Detection \nusing a one-input CNN with filtered EEG data (12-16 Hz)")
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

## With STFT

In [38]:
spindles_train_times_stft_downsampled = detect_spindles_times(train_raw, do_filter=True, do_downsample=True)
spindles_test_times_stft_downsampled = detect_spindles_times(test_raw, do_filter=True, do_downsample=True)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pass

In [39]:
spindles_starts_train_stft_downsampled, spindles_ends_train_stft_downsampled = zip(*spindles_train_times_stft_downsampled) if spindles_train_times_stft_downsampled else([],[])
spindles_starts_test_stft_downsampled, spindles_ends_test_stft_downsampled = zip(*spindles_test_times_stft_downsampled) if spindles_test_times_stft_downsampled else([],[])

print(len(spindles_starts_train_stft_downsampled))
print(len(spindles_ends_train_stft_downsampled))

print(len(spindles_starts_test_stft_downsampled))
print(len(spindles_ends_test_stft_downsampled))

7123
7123
5794
5794


### Downsample

In [40]:
train_raw_downsampled = train_raw.copy().resample(100)
test_raw_downsampled = test_raw.copy().resample(100)

print(train_raw_downsampled.info['sfreq'])
print(test_raw_downsampled.info['sfreq'])

100.0
100.0


### Epoch the data

In [41]:
epochs_train_stft_downsampled = create_fixed_length_epochs(train_raw_downsampled)
epochs_test_stft_downsampled = create_fixed_length_epochs(test_raw_downsampled)

print(epochs_train_stft_downsampled.get_data().shape)
print(epochs_test_stft_downsampled.get_data().shape)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped
Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped
(14550, 1, 300)
(11440, 1, 300)


### Labels for 3-second epochs

In [42]:
%%time

# Train set

epoch_labels_train_stft_downsampled_moderate = label_spindle_epochs_moderate(epochs_train_stft_downsampled, spindles_starts_train_stft_downsampled, spindles_ends_train_stft_downsampled)

print(f"Train data first 10 spindles: {spindles_starts_train_stft_downsampled[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_stft_downsampled_moderate[:10]}")

# Test set

epoch_labels_test_stft_downsampled_moderate = label_spindle_epochs_moderate(epochs_test_stft_downsampled, spindles_starts_test_stft_downsampled, spindles_ends_test_stft_downsampled)

print(f"\nTest data first 10 spindles: {spindles_starts_test_stft_downsampled[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_stft_downsampled_moderate[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 0 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 0]
CPU times: total: 1min 9s
Wall time: 1min 9s


### Apply STFT to each epoch

#### get rid of channel dimension

In [43]:
epochs_train_stft_downsampled = np.squeeze(epochs_train_stft_downsampled)
epochs_test_stft_downsampled = np.squeeze(epochs_test_stft_downsampled)

print(epochs_train_stft_downsampled.shape)
print(epochs_test_stft_downsampled.shape)

(14550, 300)
(11440, 300)


In [44]:
print(train_raw_downsampled.info['sfreq'])

100.0


In [45]:
fs = train_raw_downsampled.info['sfreq']  
# smaller nperseg means higher time resolution
# noverlap must be less than nperseg

# I want to focus more on the frequency resolution
# at the expense of the time resolution
# for spindles, want a 1-2 Hz resolution (frequency bins every 1-2 Hz)
# let's start with 2 Hz first 
nperseg = 50
noverlap = nperseg // 2
# common practice is to set noverlap to 50% of nperseg


epochs_train_stft_transformed_downsampled = []

for epoch in epochs_train_stft_downsampled:
    f, t, Zxx = stft(epoch, fs=fs, nperseg=nperseg, noverlap=noverlap)
    spectrogram = np.abs(Zxx)  
    epochs_train_stft_transformed_downsampled.append(spectrogram)

epochs_test_stft_transformed_downsampled = []

for epoch in epochs_test_stft_downsampled:
    f, t, Zxx = stft(epoch, fs=fs, nperseg=nperseg, noverlap=noverlap)
    spectrogram = np.abs(Zxx)  
    epochs_test_stft_transformed_downsampled.append(spectrogram)
    
# convert into numpy arrays
epochs_train_stft_transformed_downsampled = np.array(epochs_train_stft_transformed_downsampled)
epochs_test_stft_transformed_downsampled = np.array(epochs_test_stft_transformed_downsampled)

print("Train STFT shape:", epochs_train_stft_transformed_downsampled.shape)
print("Test STFT shape:", epochs_test_stft_transformed_downsampled.shape)

# shape is number of epochs, frequency_bins, time_bins


Train STFT shape: (14550, 26, 13)
Test STFT shape: (11440, 26, 13)


### X and y train and test sets

In [46]:
# Define X and y sets and reshape

X_train_stft_downsampled = epochs_train_stft_transformed_downsampled[..., np.newaxis]  # Shape: (14550, 65, 25, 1)
y_train_stft_downsampled = epoch_labels_train_stft_downsampled_moderate

X_test_stft_downsampled = epochs_test_stft_transformed_downsampled[..., np.newaxis]    # Shape: (11440, 65, 25, 1)
y_test_stft_downsampled = epoch_labels_test_stft_downsampled_moderate
                                                                 

# Print shapes

print(f"X_train shape: {X_train_stft_downsampled.shape}")
print(f"y_train shape: {y_train_stft_downsampled.shape}")

print(f"\nX_test shape: {X_test_stft_downsampled.shape}")
print(f"y_test shape: {y_test_stft_downsampled.shape}")

X_train shape: (14550, 26, 13, 1)
y_train shape: (14550,)

X_test shape: (11440, 26, 13, 1)
y_test shape: (11440,)


### Normalization of data

In [47]:
# this scales the spectrogram in range 0,1
# this is min-max normalisation

X_train_stft_norm_downsampled = np.array([
    (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
    for epoch in X_train_stft_downsampled
])

X_test_stft_norm_downsampled = np.array([
    (epoch - np.min(epoch)) / (np.max(epoch) - np.min(epoch) + 1e-8)
    for epoch in X_test_stft_downsampled
])

In [48]:
# should have values between 0 and 1 

print("Before normalisation:")
print("Max train value:", np.max(X_train_stft_downsampled))
print("Min train value:", np.min(X_train_stft_downsampled))

print("Max test value:", np.max(X_test_stft_downsampled))
print("Min test value:", np.min(X_test_stft_downsampled))

print("\nAfter normalisation:")
print("Max train value:", np.max(X_train_stft_norm_downsampled))
print("Min train value:", np.min(X_train_stft_norm_downsampled))

print("Max test value:", np.max(X_test_stft_norm_downsampled))
print("Min test value:", np.min(X_test_stft_norm_downsampled))

Before normalisation:
Max train value: 866.3673727627713
Min train value: 2.4061107274064854e-06
Max test value: 1011.6755263560262
Min test value: 5.496909294180341e-06

After normalisation:
Max train value: 0.9999999999884571
Min train value: 0.0
Max test value: 0.9999999999901138
Min test value: 0.0


## Three-channel CNN

### Reducing STFT dimension

The STFT shape is (number of epochs, time windows, frequency bins, channels).
If you only keep the frequency dimension, it will be (number of epochs, frequency bins, channels), so it will be all the frequencies for the average time.

#### Only keeping frequency axis

In [49]:
X_train_stft_freq = np.mean(X_train_stft_norm_downsampled, axis=1) 
X_test_stft_freq = np.mean(X_test_stft_norm_downsampled, axis=1)

print(f"Shape of X_train_stft_freq:{X_train_stft_freq.shape}")
print(f"Shape of X_test_stft_freq:{X_test_stft_freq.shape}")

Shape of X_train_stft_freq:(14550, 13, 1)
Shape of X_test_stft_freq:(11440, 13, 1)


### 3-input model with frequency for the STFT

In [50]:
X_train_dict_freq = {
    'raw_input': X_train_raw_downsampled,
    'filtered_input': X_train_filtered,
    'stft_input': X_train_stft_freq
}

X_test_dict_freq = {
    'raw_input': X_test_raw_downsampled,
    'filtered_input': X_test_filtered,
    'stft_input': X_test_stft_freq
}

In [51]:
cnn_model_multi_input_freq = build_multi_input_cnn_model_freq()
cnn_model_multi_input_freq.summary()

In [52]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [53]:
%%time

training_info_multiple_inputs_freq = cnn_model_multi_input_freq.fit(X_train_dict_freq, y_train_raw_downsampled, validation_split=0.2, epochs=20, batch_size=128, callbacks=[early_stop])

Epoch 1/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 251ms/step - accuracy: 0.5999 - loss: 0.6663 - val_accuracy: 0.4265 - val_loss: 1.0271
Epoch 2/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 250ms/step - accuracy: 0.7577 - loss: 0.4870 - val_accuracy: 0.8570 - val_loss: 0.3692
Epoch 3/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 277ms/step - accuracy: 0.8187 - loss: 0.4450 - val_accuracy: 0.8509 - val_loss: 0.3417
Epoch 4/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 270ms/step - accuracy: 0.7220 - loss: 0.5128 - val_accuracy: 0.6942 - val_loss: 0.5713
Epoch 5/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 261ms/step - accuracy: 0.7259 - loss: 0.5354 - val_accuracy: 0.8457 - val_loss: 0.3354
Epoch 6/20
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 253ms/step - accuracy: 0.8519 - loss: 0.3267 - val_accuracy: 0.9110 - val_loss: 0.2118
Epoch 7/20
[1m91/91[

In [54]:
def plot_training_history(training_info):
  fig, axs = plt.subplots(1, 2, figsize=(16, 5))
  axs[0].plot(training_info.history['loss'], label="training set")
  axs[0].plot(training_info.history['val_loss'], label="validation set")
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("Loss")
  axs[0].grid(True)
  axs[0].legend()
  try:
    axs[1].plot(training_info.history['accuracy'], label="training set")
    axs[1].plot(training_info.history['val_accuracy'], label="validation set")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].grid(True)
    axs[1].legend()
  except:
    pass

    fig.suptitle("Training History for three-input CNN model with frequency component of STFT for spindle detection", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

plot_training_history(training_info_multiple_inputs_freq)

In [55]:
cnn_model_multi_input_freq.evaluate(X_test_dict_freq, y_test_raw_downsampled)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 29ms/step - accuracy: 0.9146 - loss: 0.2334


[0.23678061366081238, 0.913811206817627]

In [56]:
y_pred = cnn_model_multi_input_freq.predict(X_test_dict_freq)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_raw_downsampled, y_pred_labels))
print(classification_report(y_test_raw_downsampled, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 32ms/step
[[6121  193]
 [ 793 4333]]
              precision    recall  f1-score   support

           0       0.89      0.97      0.93      6314
           1       0.96      0.85      0.90      5126

    accuracy                           0.91     11440
   macro avg       0.92      0.91      0.91     11440
weighted avg       0.92      0.91      0.91     11440



In [57]:
# Get predictions
y_pred = cnn_model_multi_input_freq.predict(X_test_dict_freq)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix
cm = confusion_matrix(y_test_raw_downsampled, y_pred_labels)
cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Predicted 0", "Predicted 1"])

# Classification report as a dataframe
report = classification_report(y_test_raw_downsampled, y_pred_labels, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Confusion matrix plotted as heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix for Spindle Detection \nusing a three-input CNN with frequency \ncomponent of STFT")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()

# Classification report as a table
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=report_df.round(2).values,
                 colLabels=report_df.columns,
                 rowLabels=report_df.index,
                 cellLoc='center',
                 loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.2)
plt.title("Classification Report for Spindle Detection \nusing a three-input CNN with frequency \ncomponent of STFT", fontsize=14)
plt.tight_layout()
plt.show()


[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 28ms/step


In [58]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_raw_downsampled, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic for a three-input CNN model \nwith the STFT frequency component")
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
plt.show()