In [1]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc
from sklearn.utils import class_weight

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

import scipy.signal as signal
from scipy.signal import hilbert

In [2]:
%matplotlib qt

### CNN Model

In [3]:
def build_cnn_model(input_shape=(300,1)):

    # linear embedding layer
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Three convolutional blocks (like having three pattern detectors)

    # First convolution block, kernel size of 5
    padded1 = tf.keras.layers.ZeroPadding1D(padding=2)(input_layer)
    conv1 = tf.keras.layers.Conv1D(filters=10, kernel_size=5, strides=1, padding='valid')(padded1)
    # each filter learns a different type of short-time feature
    # stride of 1, moves one step at a time
    conv1 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv1)
    conv1 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv1)
    # K = 2
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # Second convolution block, kernel size of 11
    padded2 = tf.keras.layers.ZeroPadding1D(padding=5)(input_layer)
    conv2 = tf.keras.layers.Conv1D(filters=10, kernel_size=11, strides=1, padding='valid')(padded2)
    conv2 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv2)
    conv2 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv2)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # Third convolution block, kernel size of 21
    padded3 = tf.keras.layers.ZeroPadding1D(padding=10)(input_layer)
    conv3 = tf.keras.layers.Conv1D(filters=10, kernel_size=21, strides=1, padding='valid')(padded3)
    conv3 = tf.keras.layers.LeakyReLU(alpha=0.01)(conv3)
    conv3 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv3)
    conv3 = tf.keras.layers.BatchNormalization()(conv3)

    # Concatenate the outputs of all blocks
    concatenated = tf.keras.layers.Concatenate()([conv1, conv2, conv3])

    # GRU Layer
    gru = tf.keras.layers.GRU(64)(concatenated)

    # Fully connected (dense) layer
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)
    # add a Dropout layer to prevent overfitting
    #dense = tf.keras.layers.Dropout(0.5)(dense)

    # Two softmax outputs for dual-task classification
    #output_task1 = tf.keras.layers.Dense(2, activation='softmax', name='task1')(dense)
    #output_task2 = tf.keras.layers.Dense(2, activation='softmax', name='task2')(dense)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    # Create the model
    #model = tf.keras.models.Model(inputs=input_layer, outputs=[output_task1, output_task2])
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)

    # Compile the model
    #model.compile(optimizer='adam', loss={'task1': 'categorical_crossentropy', 'task2': 'categorical_crossentropy'}, metrics={'task1': 'accuracy', 'task2': 'accuracy'})
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Return the compiled model
    return model

# show model architecture
input_shape = (300, 1)
model = build_cnn_model(input_shape)
model.summary()



### Spindle detection

In [4]:
def detect_spindles_times(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)

    if do_downsample:
        data.resample(downsample_rate)
    
    sfreq = data.info['sfreq']  
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    # threshold is 75th percentile of the smoothed envelope
    # will look at the duration later
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                # so starting from the second index
                # and comparing each index to the one before
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    spindles.append((start_idx / sfreq, end_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((start_idx / sfreq, end_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)
    
    return spindles
    

def detect_spindles_peaks(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)
    
    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']  
    # update to new sampling frequency
    # because used later in the code
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    # 75th percentile as criteria

    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    #spindles.append(f"Spindle detected from {start_idx / sfreq:.2f}s to {end_idx / sfreq:.2f}s, peak at {peak_idx / sfreq:.2f}s")
                    spindles.append((peak_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((peak_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)

    
    return spindles

### Epochs

In [5]:
def create_fixed_length_epochs(raw, duration=3.0, overlap=0.0, preload=True, reject_by_annotation=False):

    return mne.make_fixed_length_epochs(
        raw,
        duration=duration,
        overlap=overlap,
        preload=preload,
        reject_by_annotation=reject_by_annotation
    )
# function mne.make_fixed_length_epochs takes into account the sampling frequency of the data

def label_spindle_epochs(epochs, spindle_starts, spindle_ends, epoch_length_sec=3.0):

    epoch_starts = np.arange(len(epochs)) * epoch_length_sec
    # new np array with the start time of each epoch
    # epoch_starts[i] is the start time of each epoch

    epoch_labels = np.zeros(len(epochs), dtype=int)
    # initialize all the labels as 0 initially

    for start, end in zip(spindle_starts, spindle_ends):
        # loop through the start and end times of detected spindles by YASA
        for i, epoch_start in enumerate(epoch_starts):
            # loop through the one-second epochs that are not labelled yet
            epoch_end = epoch_start + epoch_length_sec
            # for each epoch, calculate the epoch end time
            # which is epoch_start + length of epoch
            # so now have the time range of each epoch
            if (start < epoch_end) and (end > epoch_start):
                # if the spindle started before the epoch ends
                # and the spindle ended after the epoch started
                epoch_labels[i] = 1
                
    return epoch_labels

### Importing data

In [6]:
# file paths
train_file = r"C:\EEG DATA\combined_sets\train_raw.fif"
test_file = r"C:\EEG DATA\combined_sets\test_raw.fif"

# load raw files
train_raw = mne.io.read_raw_fif(train_file, preload=True)
test_raw = mne.io.read_raw_fif(test_file, preload=True)

Opening raw data file C:\EEG DATA\combined_sets\train_raw.fif...
Isotrak not found
    Range : 1470000 ... 23295072 =   2940.000 ... 46590.144 secs
Ready.
Reading 0 ... 21825072  =      0.000 ... 43650.144 secs...
Opening raw data file C:\EEG DATA\combined_sets\test_raw.fif...
Isotrak not found
    Range : 825000 ... 17985049 =   1650.000 ... 35970.098 secs
Ready.
Reading 0 ... 17160049  =      0.000 ... 34320.098 secs...


## With filtered data

In [7]:
# Apply bandpass filter between 12 and 16 Hz
# to compare performance of filtered dataset to unfiltered one
train_filtered = train_raw.copy().filter(l_freq=12, h_freq=16)
test_filtered = test_raw.copy().filter(l_freq=12, h_freq=16)

# Downsample to 100 Hz
train_filtered.resample(100)
test_filtered.resample(100)

Filtering raw data in 73 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passban

Unnamed: 0,General,General.1
,Filename(s),test_raw.fif
,MNE object type,Raw
,Measurement date,2023-09-06 at 23:28:45 UTC
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,09:32:01 (HH:MM:SS)
,Sampling frequency,100.00 Hz
,Time points,3432010
,Channels,Channels


### Spindle detection

In [8]:
spindles_train_times_filtered = detect_spindles_times(train_filtered, do_filter=False, do_downsample=False)
spindles_test_times_filtered = detect_spindles_times(test_filtered, do_filter=False, do_downsample=False)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


In [9]:
print(len(spindles_train_times_filtered))
print(len(spindles_test_times_filtered))

7123
5794


In [10]:
spindles_starts_train_filtered, spindles_ends_train_filtered = zip(*spindles_train_times_filtered) if spindles_train_times_filtered else([],[])
spindles_starts_test_filtered, spindles_ends_test_filtered = zip(*spindles_test_times_filtered) if spindles_test_times_filtered else([],[])

In [11]:
print(len(spindles_starts_train_filtered))
print(len(spindles_ends_train_filtered))

print(len(spindles_starts_test_filtered))
print(len(spindles_ends_test_filtered))

7123
7123
5794
5794


### Epoch the data

In [12]:
epochs_train_filtered = create_fixed_length_epochs(train_filtered)

Not setting metadata
14550 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 14550 events and 300 original time points ...
0 bad epochs dropped


In [13]:
epochs_test_filtered = create_fixed_length_epochs(test_filtered)

Not setting metadata
11440 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11440 events and 300 original time points ...
0 bad epochs dropped


In [14]:
print(epochs_train_filtered.get_data().shape)
print(epochs_test_filtered.get_data().shape)

(14550, 1, 300)
(11440, 1, 300)


### Labels for 3-second epochs

In [15]:
%%time

# Train set

epoch_labels_train_filtered = label_spindle_epochs(epochs_train_filtered, spindles_starts_train_filtered, spindles_ends_train_filtered)

print(f"Train data first 10 spindles: {spindles_starts_train_filtered[:10]}")
print(f"Train labels first 10 epochs: {epoch_labels_train_filtered[:10]}")

# Test set

epoch_labels_test_filtered = label_spindle_epochs(epochs_test_filtered, spindles_starts_test_filtered, spindles_ends_test_filtered)

print(f"\nTest data first 10 spindles: {spindles_starts_test_filtered[:10]}")
print(f"Test labels first 10 epochs: {epoch_labels_test_filtered[:10]}")

Train data first 10 spindles: (np.float64(0.05), np.float64(4.95), np.float64(10.01), np.float64(19.41), np.float64(27.23), np.float64(30.01), np.float64(31.05), np.float64(41.95), np.float64(47.79), np.float64(53.14))
Train labels first 10 epochs: [1 1 1 1 0 0 1 0 0 1]

Test data first 10 spindles: (np.float64(29.89), np.float64(49.64), np.float64(61.33), np.float64(78.56), np.float64(87.9), np.float64(88.5), np.float64(97.5), np.float64(112.37), np.float64(119.92), np.float64(147.43))
Test labels first 10 epochs: [0 0 0 0 0 0 0 0 0 1]
CPU times: total: 53.3 s
Wall time: 54 s


### Prepare EEG data for CNN input

#### X and y train and test sets

In [16]:
# Reshape arrays

epochs_train_np_filtered = np.array(epochs_train_filtered).reshape(len(epochs_train_filtered), -1, 1)
# number of epochs N, sampling frequency (time dimension automatically inferred), channel dimension
epochs_test_np_filtered = np.array(epochs_test_filtered).reshape(len(epochs_test_filtered), -1, 1)
                                                                 
# Define X and y sets
                                                                 
X_train_filtered = epochs_train_np_filtered
y_train_filtered = epoch_labels_train_filtered

X_test_filtered = epochs_test_np_filtered
y_test_filtered = epoch_labels_test_filtered

# Print shapes

print(f"X_train shape: {X_train_filtered.shape}")
print(f"y_train shape: {y_train_filtered.shape}")

print(f"\nX_test shape: {X_test_filtered.shape}")
print(f"y_test shape: {y_test_filtered.shape}")
                                                                 

X_train shape: (14550, 300, 1)
y_train shape: (14550,)

X_test shape: (11440, 300, 1)
y_test shape: (11440,)


### Train the model

#### Inspect class imbalance

In [17]:
# Count spindle and non-spindle labels in train and test sets
unique_train_filtered, counts_train_filtered = np.unique(y_train_filtered, return_counts=True)
unique_test_filtered, counts_test_filtered = np.unique(y_test_filtered, return_counts=True)

print("Train label distribution:")
for label, count in zip(unique_train_filtered, counts_train_filtered):
    print(f"Label {label}: {count}")

print("\nTest label distribution:")
for label, count in zip(unique_test_filtered, counts_test_filtered):
    print(f"Label {label}: {count}")

Train label distribution:
Label 0: 7103
Label 1: 7447

Test label distribution:
Label 0: 5154
Label 1: 6286


#### Training

In [19]:
early_stop = EarlyStopping(
    monitor='val_loss',      
    patience=5,               
    restore_best_weights=True 
)
# stop after 5 epochs with no improvement

In [20]:
%%time

training_info = model.fit(X_train_filtered, y_train_filtered, validation_split=0.2, epochs=30, batch_size=128, callbacks=[early_stop])

Epoch 1/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 170ms/step - accuracy: 0.5920 - loss: 0.6493 - val_accuracy: 0.5780 - val_loss: 0.7163
Epoch 2/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 163ms/step - accuracy: 0.7226 - loss: 0.5401 - val_accuracy: 0.6182 - val_loss: 0.6570
Epoch 3/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 171ms/step - accuracy: 0.6273 - loss: 0.6229 - val_accuracy: 0.6533 - val_loss: 0.6407
Epoch 4/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 164ms/step - accuracy: 0.6633 - loss: 0.5930 - val_accuracy: 0.8553 - val_loss: 0.3555
Epoch 5/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 166ms/step - accuracy: 0.8434 - loss: 0.3481 - val_accuracy: 0.8842 - val_loss: 0.2843
Epoch 6/30
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 172ms/step - accuracy: 0.8757 - loss: 0.2922 - val_accuracy: 0.8955 - val_loss: 0.2463
Epoch 7/30
[1m91/91[

In [21]:
def plot_training_history(training_info):
  fig, axs = plt.subplots(1, 2, figsize=(16, 5))
  axs[0].plot(training_info.history['loss'], label="training set")
  axs[0].plot(training_info.history['val_loss'], label="validation set")
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("Loss")
  axs[0].grid(True)
  axs[0].legend()
  try:
    axs[1].plot(training_info.history['accuracy'], label="training set")
    axs[1].plot(training_info.history['val_accuracy'], label="validation set")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].grid(True)
    axs[1].legend()
  except:
    pass
  plt.show()

plot_training_history(training_info)

In [22]:
model.evaluate(X_test_filtered, y_test_filtered)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 26ms/step - accuracy: 0.8999 - loss: 0.2338 


[0.24101856350898743, 0.8963286876678467]

### Metrics

In [23]:
y_pred = model.predict(X_test_filtered)
y_pred_labels = (y_pred > 0.5).astype(int)

print(confusion_matrix(y_test_filtered, y_pred_labels))
print(classification_report(y_test_filtered, y_pred_labels))

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 28ms/step
[[5042  112]
 [1074 5212]]
              precision    recall  f1-score   support

           0       0.82      0.98      0.89      5154
           1       0.98      0.83      0.90      6286

    accuracy                           0.90     11440
   macro avg       0.90      0.90      0.90     11440
weighted avg       0.91      0.90      0.90     11440



In [24]:
# Flatten in case y_pred has shape (n_samples, 1)
y_pred_proba = y_pred.ravel()

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test_filtered, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plotting
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic")
plt.legend(loc="lower right")
plt.grid(True)
plt.show()