In [1]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_curve, auc
from sklearn.utils import class_weight

import tensorflow as tf
from tensorflow.keras import layers, models

import scipy.signal as signal
from scipy.signal import hilbert

In [2]:
%matplotlib qt

### CNN Model

In [3]:
def _conv_branch(inputs, kernel_size, filters=10):
    """Build one convolutional branch of the multi-scale front end.

    The branch is: zero-pad -> Conv1D -> LeakyReLU -> max-pool (2) -> batch-norm.

    Parameters
    ----------
    inputs : Keras tensor
        Output of the model's input layer.
    kernel_size : int
        Width of the convolution kernel (5, 11 and 21 are used by the model).
    filters : int, optional
        Number of convolution filters in the branch.

    Returns
    -------
    Keras tensor
        The batch-normalised, pooled feature map of this branch.
    """
    # Padding by (kernel_size - 1) // 2 on both sides before a 'valid' convolution
    # keeps the time axis at its input length for odd kernels, so all branches
    # end up with the same length and can be concatenated.
    padded = tf.keras.layers.ZeroPadding1D(padding=(kernel_size - 1) // 2)(inputs)
    branch = tf.keras.layers.Conv1D(
        filters=filters, kernel_size=kernel_size, strides=1, padding='valid'
    )(padded)
    # NOTE(review): newer Keras releases appear to call this argument
    # `negative_slope`; `alpha` is kept as in the original — confirm against the
    # installed version.
    branch = tf.keras.layers.LeakyReLU(alpha=0.01)(branch)
    branch = tf.keras.layers.MaxPooling1D(pool_size=2)(branch)
    return tf.keras.layers.BatchNormalization()(branch)


def build_cnn_model(input_shape=(500, 1)):
    """Build and compile the multi-scale CNN + GRU binary classifier.

    Three parallel convolution branches (kernel sizes 5, 11 and 21, 10 filters
    each) read the same input, their outputs are concatenated, and the result is
    fed to a 64-unit GRU, a 64-unit ReLU dense layer with 50 % dropout and a
    single sigmoid output unit.

    Parameters
    ----------
    input_shape : tuple, optional
        Shape of one sample passed to the Keras ``Input`` layer, without the
        batch axis. Defaults to ``(500, 1)``.

    Returns
    -------
    tf.keras.Model
        Model compiled with the Adam optimiser, binary cross-entropy loss and
        accuracy as the only metric.
    """
    # Entry point of the network (a plain input placeholder, not an embedding).
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Multi-scale front end: the branches differ only in kernel width, so each
    # one picks up patterns at a different time scale.
    branches = [_conv_branch(input_layer, kernel_size) for kernel_size in (5, 11, 21)]
    concatenated = tf.keras.layers.Concatenate()(branches)

    # Recurrent layer summarises the concatenated feature sequence.
    gru = tf.keras.layers.GRU(64)(concatenated)

    # Fully connected head, with dropout to limit overfitting.
    dense = tf.keras.layers.Dense(64, activation='relu')(gru)
    dense = tf.keras.layers.Dropout(0.5)(dense)

    # Single sigmoid unit: one score per input window for the binary task.
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    model = tf.keras.models.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

In [4]:
# Instantiate the classifier for windows of 500 samples x 1 channel and print
# its layer-by-layer summary.
input_shape = (500, 1)
cnn_model = build_cnn_model(input_shape)
cnn_model.summary()



## Spindle detection

In [5]:
#import scipy.signal as signal
#from scipy.signal import find_peaks

def detect_spindles_times(eeg_raw, channel='Fz', l_freq=12, h_freq=16,
                          threshold_percentile=75, min_duration=0.5,
                          max_duration=3, smoothing_window_sec=0.2):
    """Detect candidate spindles and return their start and end times.

    The recording is band-passed on one channel, the Hilbert amplitude envelope
    is smoothed with a moving average, and every contiguous run of samples whose
    smoothed envelope exceeds a percentile threshold is kept if its duration
    lies inside ``[min_duration, max_duration]``.

    Parameters
    ----------
    eeg_raw : MNE Raw-like object
        Must provide ``copy()``, ``pick()``, ``filter()``, ``info['sfreq']`` and
        ``get_data()``. It is copied first, so the caller's object is not
        modified.
    channel : str, optional
        Name of the channel to analyse.
    l_freq, h_freq : float, optional
        Band-pass edges in Hz.
    threshold_percentile : float, optional
        Percentile of the smoothed envelope used as detection threshold.
    min_duration, max_duration : float, optional
        Accepted event duration in seconds (both bounds inclusive).
    smoothing_window_sec : float, optional
        Length of the moving-average window in seconds.

    Returns
    -------
    list of tuple
        One ``(start_sec, end_sec)`` tuple per accepted event, in chronological
        order. Times are sample indices into the channel data divided by the
        sampling frequency. The list is empty when nothing qualifies.
    """
    # 1. Band-pass a single-channel copy of the recording.
    filtered_data = eeg_raw.copy().pick([channel])
    filtered_data.filter(l_freq=l_freq, h_freq=h_freq)

    # The data is NOT resampled: everything below runs at the native rate.
    sfreq = filtered_data.info['sfreq']
    channel_data = filtered_data.get_data()[0]
    if channel_data.size == 0:
        return []

    # 2. Amplitude envelope = magnitude of the analytic (Hilbert) signal.
    envelope = np.abs(hilbert(channel_data))

    # 3. Moving average over `smoothing_window_sec`; the kernel sums to 1 and
    #    mode='same' keeps the output aligned with the input. At least one
    #    sample is used so the kernel is never empty.
    sliding_window = max(1, int(smoothing_window_sec * sfreq))
    smoothed_envelope = np.convolve(
        envelope, np.ones(sliding_window) / sliding_window, mode='same'
    )

    # 4. Percentile threshold computed over the whole recording. By construction
    #    roughly (100 - threshold_percentile) % of all samples exceed it, so the
    #    duration filter below does the actual event selection.
    threshold = np.percentile(smoothed_envelope, threshold_percentile)
    above_threshold = np.flatnonzero(smoothed_envelope > threshold)
    if above_threshold.size == 0:
        return []

    # 5. Split the supra-threshold indices into runs of consecutive samples:
    #    a jump larger than 1 between neighbouring indices marks a gap.
    gaps = np.flatnonzero(np.diff(above_threshold) > 1)
    run_starts = above_threshold[np.concatenate(([0], gaps + 1))]
    run_ends = above_threshold[np.concatenate((gaps, [above_threshold.size - 1]))]

    # 6. Keep runs with an acceptable duration. `end_idx` is the last sample
    #    that is still above threshold.
    spindles = []
    for start_idx, end_idx in zip(run_starts, run_ends):
        duration = (end_idx - start_idx) / sfreq
        if min_duration <= duration <= max_duration:
            spindles.append((start_idx / sfreq, end_idx / sfreq))

    return spindles

In [6]:
import scipy.signal as signal
from scipy.signal import find_peaks

def detect_spindles_peaks(eeg_raw, channel='Fz', l_freq=12, h_freq=16,
                          threshold_percentile=75, min_duration=0.5,
                          max_duration=3, smoothing_window_sec=0.2):
    """Detect candidate spindles and return the time of each event's peak.

    The recording is band-passed on one channel, the Hilbert amplitude envelope
    is smoothed with a moving average, and every contiguous run of samples whose
    smoothed envelope exceeds a percentile threshold is kept if its duration
    lies inside ``[min_duration, max_duration]``. For each kept run the peak is
    the sample with the largest value of the band-passed signal (not of the
    envelope).

    Parameters
    ----------
    eeg_raw : MNE Raw-like object
        Must provide ``copy()``, ``pick()``, ``filter()``, ``info['sfreq']`` and
        ``get_data()``. It is copied first, so the caller's object is not
        modified.
    channel : str, optional
        Name of the channel to analyse.
    l_freq, h_freq : float, optional
        Band-pass edges in Hz.
    threshold_percentile : float, optional
        Percentile of the smoothed envelope used as detection threshold.
    min_duration, max_duration : float, optional
        Accepted event duration in seconds (both bounds inclusive).
    smoothing_window_sec : float, optional
        Length of the moving-average window in seconds.

    Returns
    -------
    list of float
        Peak time in seconds of every accepted event, in chronological order.
        Times are sample indices into the channel data divided by the sampling
        frequency. The list is empty when nothing qualifies.
    """
    # 1. Band-pass a single-channel copy of the recording.
    filtered_data = eeg_raw.copy().pick([channel])
    filtered_data.filter(l_freq=l_freq, h_freq=h_freq)

    # The data is NOT resampled: everything below runs at the native rate.
    sfreq = filtered_data.info['sfreq']
    channel_data = filtered_data.get_data()[0]
    if channel_data.size == 0:
        return []

    # 2. Amplitude envelope = magnitude of the analytic (Hilbert) signal.
    envelope = np.abs(hilbert(channel_data))

    # 3. Moving average over `smoothing_window_sec`; the kernel sums to 1 and
    #    mode='same' keeps the output aligned with the input. At least one
    #    sample is used so the kernel is never empty.
    sliding_window = max(1, int(smoothing_window_sec * sfreq))
    smoothed_envelope = np.convolve(
        envelope, np.ones(sliding_window) / sliding_window, mode='same'
    )

    # 4. Percentile threshold computed over the whole recording.
    threshold = np.percentile(smoothed_envelope, threshold_percentile)
    above_threshold = np.flatnonzero(smoothed_envelope > threshold)
    if above_threshold.size == 0:
        return []

    # 5. Split the supra-threshold indices into runs of consecutive samples:
    #    a jump larger than 1 between neighbouring indices marks a gap.
    gaps = np.flatnonzero(np.diff(above_threshold) > 1)
    run_starts = above_threshold[np.concatenate(([0], gaps + 1))]
    run_ends = above_threshold[np.concatenate((gaps, [above_threshold.size - 1]))]

    # 6. Keep runs with an acceptable duration and locate their peak.
    spindles = []
    for start_idx, end_idx in zip(run_starts, run_ends):
        duration = (end_idx - start_idx) / sfreq
        if min_duration <= duration <= max_duration:
            # `end_idx` is the last sample still above threshold, so the slice
            # is inclusive of it (this also keeps the segment non-empty).
            segment = channel_data[start_idx:end_idx + 1]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append(peak_idx / sfreq)

    return spindles

## Epochs 

### Create fixed length epochs

In [7]:

def create_fixed_length_epochs(raw, duration=1.0, overlap=0.0, preload=True, reject_by_annotation=False):
    """Cut a continuous recording into equally long epochs.

    Thin wrapper that forwards every argument unchanged to
    ``mne.make_fixed_length_epochs``; only the default values are chosen here.

    Parameters
    ----------
    raw : object accepted by ``mne.make_fixed_length_epochs``
        The continuous recording to split.
    duration : float, optional
        Forwarded as ``duration``.
    overlap : float, optional
        Forwarded as ``overlap``.
    preload : bool, optional
        Forwarded as ``preload``.
    reject_by_annotation : bool, optional
        Forwarded as ``reject_by_annotation``.

    Returns
    -------
    Whatever ``mne.make_fixed_length_epochs`` returns for these arguments.
    """
    epoching_options = {
        "duration": duration,
        "overlap": overlap,
        "preload": preload,
        "reject_by_annotation": reject_by_annotation,
    }
    return mne.make_fixed_length_epochs(raw, **epoching_options)

### Epoch labels

In [8]:
def label_spindle_epochs(epochs, spindle_starts, spindle_ends, epoch_length_sec=1):
    """Label each fixed-length epoch with 1 if it overlaps a spindle, else 0.

    Epoch ``i`` is taken to span ``[i * epoch_length_sec, (i + 1) * epoch_length_sec)``,
    i.e. the epochs are assumed to be contiguous, to start at time 0 and to have
    no gaps. A spindle overlaps an epoch when it starts strictly before the
    epoch ends and ends strictly after the epoch starts, so a spindle that only
    touches an epoch boundary does not count for that epoch.

    Parameters
    ----------
    epochs : sized object
        Only ``len(epochs)`` is used.
    spindle_starts, spindle_ends : iterable of float
        Start and end time of each spindle in seconds, on the same time axis as
        the epochs. They are paired element-wise; surplus elements of the longer
        iterable are ignored.
    epoch_length_sec : float, optional
        Epoch duration in seconds; expected to be positive.

    Returns
    -------
    numpy.ndarray
        Integer array of length ``len(epochs)`` holding 0/1 labels.
    """
    n_epochs = len(epochs)
    epoch_starts = np.arange(n_epochs) * epoch_length_sec
    epoch_ends = epoch_starts + epoch_length_sec

    # Pair starts with ends exactly like zip() would, then split into columns.
    pairs = np.array(list(zip(spindle_starts, spindle_ends)), dtype=float).reshape(-1, 2)
    starts, ends = pairs[:, 0], pairs[:, 1]

    # Instead of testing every spindle against every epoch, find for each
    # spindle the half-open range [first, stop) of epochs it overlaps:
    #   first = number of epochs that end at or before the spindle starts
    #   stop  = number of epochs that start strictly before the spindle ends
    first = np.searchsorted(epoch_ends, starts, side='right')
    stop = np.searchsorted(epoch_starts, ends, side='left')
    overlaps_any = stop > first

    # Difference array: +1 where a spindle's range opens, -1 where it closes.
    # A positive running sum means at least one spindle covers that epoch.
    delta = np.zeros(n_epochs + 1, dtype=int)
    np.add.at(delta, first[overlaps_any], 1)
    np.add.at(delta, stop[overlaps_any], -1)

    return (np.cumsum(delta[:-1]) > 0).astype(int)

## Participant 067

### Importing concatenated data

In [9]:
# Load the concatenated recording of participant 067 into memory (preload=True).
# NOTE(review): this is an absolute, machine-specific Windows path — consider a
# configurable data directory so the notebook runs elsewhere.
participant_067_concatenated_file = r"C:\EEG DATA\067\concatenated\067_concatenated_raw.fif"
combined_raw_067 = mne.io.read_raw_fif(participant_067_concatenated_file, preload=True)

Opening raw data file C:\EEG DATA\067\concatenated\067_concatenated_raw.fif...
Isotrak not found
    Range : 825000 ... 9180031 =   1650.000 ... 18360.062 secs
Ready.
Reading 0 ... 8355031  =      0.000 ... 16710.062 secs...


### Spindle detection

In [10]:
# Run the envelope-threshold detector on the whole recording. The result is a
# list with one (start, end) tuple in seconds per candidate spindle.
spindles_067_times = detect_spindles_times(combined_raw_067)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 32 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)



In [11]:
len(spindles_067_times)

2912

In [12]:
# Turn the list of (start, end) tuples into two parallel sequences:
# one holding every spindle start and one holding every spindle end.
if spindles_067_times:
    spindles_starts_067, spindles_ends_067 = zip(*spindles_067_times)
else:
    # No spindle was found: fall back to two empty lists so the cells
    # further down still receive something iterable.
    spindles_starts_067, spindles_ends_067 = [], []

In [13]:
print(len(spindles_starts_067))
print(len(spindles_ends_067))

2912
2912


## Epoch the data

### Define 1-second epochs of EEG data

In [14]:
# Split the recording into fixed-length epochs with the wrapper defined above
# (its defaults: 1-second duration, no overlap).

epochs_067 = create_fixed_length_epochs(combined_raw_067)
# The wrapper's default reject_by_annotation=False is relied on here so that the
# epochs stay aligned with the spindle times when labels are assigned.

Not setting metadata
16710 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 16710 events and 500 original time points ...
0 bad epochs dropped


In [15]:
# Array layout is (n_epochs, n_channels, n_samples_per_epoch).
epochs_067.get_data().shape

(16710, 1, 500)

In [16]:
#epochs_067._data *= 1e6

# to convert from volts to microvolts
# otherwise units are really small and harder to detect events for CNN model

### Define labels for each 1-second epoch

In [17]:
# One label per epoch: 1 = the epoch overlaps at least one detected spindle,
# 0 = it overlaps none.
epoch_labels_067 = label_spindle_epochs(epochs_067, spindles_starts_067, spindles_ends_067)

In [18]:
epoch_labels_067[:30]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1])

In [19]:
spindles_starts_067[:10]

(np.float64(29.872),
 np.float64(47.342),
 np.float64(49.626),
 np.float64(61.314),
 np.float64(78.548),
 np.float64(87.148),
 np.float64(87.89),
 np.float64(97.47),
 np.float64(112.362),
 np.float64(119.904))

In [20]:
spindles_ends_067[:10]

(np.float64(30.49),
 np.float64(47.844),
 np.float64(50.196),
 np.float64(62.216),
 np.float64(79.516),
 np.float64(87.656),
 np.float64(90.13),
 np.float64(98.174),
 np.float64(114.392),
 np.float64(120.946))

## Building the CNN

### Prepare EEG data for CNN input

### Build the model

In [21]:
# Bring the epochs into the (n_samples, n_timesteps, n_features) layout the
# model input expects.
# get_data() hands back the whole array at once instead of iterating over the
# Epochs object epoch by epoch; np.array() makes an independent copy of it.
epochs_067_np = np.array(epochs_067.get_data())
# drop the singleton channel axis: (16710, 1, 500) -> (16710, 500)
epochs_067_np = epochs_067_np.squeeze(axis=1)
# append a trailing feature axis: (16710, 500) -> (16710, 500, 1)
epochs_067_np = epochs_067_np[..., np.newaxis]
print(epochs_067_np.shape)
print(epoch_labels_067.shape)

(16710, 500, 1)
(16710,)


In [22]:
epochs_067_np[:1]
# do I need to convert back to correct unit?

array([[[ 1.80000003e+01],
        [ 1.82500003e+01],
        [ 1.55500002e+01],
        [ 1.48500002e+01],
        [ 1.74000003e+01],
        [ 1.49500002e+01],
        [ 1.33000002e+01],
        [ 1.59500002e+01],
        [ 1.81000003e+01],
        [ 1.54500002e+01],
        [ 1.77000003e+01],
        [ 2.23000003e+01],
        [ 2.13500003e+01],
        [ 1.94000003e+01],
        [ 2.11000003e+01],
        [ 1.81500003e+01],
        [ 1.56000002e+01],
        [ 1.65500002e+01],
        [ 1.74000003e+01],
        [ 1.66000002e+01],
        [ 2.03000003e+01],
        [ 2.11000003e+01],
        [ 1.78000003e+01],
        [ 1.60500002e+01],
        [ 1.42500002e+01],
        [ 7.05000011e+00],
        [ 2.10000003e+00],
        [ 1.30000002e+00],
        [ 5.50000008e-01],
        [-1.95000003e+00],
        [ 3.50000005e-01],
        [ 1.20000002e+00],
        [-4.90000007e+00],
        [-8.65000013e+00],
        [-6.70000010e+00],
        [-1.00000001e+01],
        [-1.07000002e+01],
 

### Prepare X and y train and test sets

In [23]:
# Features (X) are the reshaped epochs, targets (y) the 0/1 spindle labels.
X = epochs_067_np
y = epoch_labels_067

# Hold out 20 % of the epochs as a test set; random_state fixes the shuffle.
# NOTE(review): the epochs are consecutive windows of one recording, so a random
# split places neighbouring windows in both sets — consider whether a
# chronological or blocked split is more appropriate.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# stratify=y keeps the proportion of the two classes the same in the training and test sets

In [24]:
print(X_train.shape)
print(y_train.shape)

print(X_test.shape)
print(y_test.shape)

(13368, 500, 1)
(13368,)
(3342, 500, 1)
(3342,)


### Train the model

In [25]:
#class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
#class_weight_dict ={0: class_weights[0], 1: class_weights[1]}

In [26]:
# Train for 30 epochs with mini-batches of 64; 20 % of the training data is set
# aside for validation. The returned History object is plotted further down.
# NOTE(review): `cnn_model` was created in an earlier cell, so re-running this
# cell presumably continues from the current weights rather than restarting —
# rebuild the model first for a clean run.
# NOTE(review): no class weights are passed although the classes are imbalanced
# (see the support column of the classification report) — confirm this is intended.
training_info = cnn_model.fit(X_train, y_train, validation_split=0.2, epochs=30, batch_size=64)

Epoch 1/30
[1m168/168[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 167ms/step - accuracy: 0.6853 - loss: 0.6264 - val_accuracy: 0.6743 - val_loss: 0.6121
Epoch 2/30
[1m168/168[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 182ms/step - accuracy: 0.6975 - loss: 0.5925 - val_accuracy: 0.7364 - val_loss: 0.5291
Epoch 3/30
[1m168/168[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 190ms/step - accuracy: 0.7179 - loss: 0.5607 - val_accuracy: 0.7636 - val_loss: 0.4990
Epoch 4/30
[1m168/168[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 199ms/step - accuracy: 0.7451 - loss: 0.5261 - val_accuracy: 0.7610 - val_loss: 0.5250
Epoch 5/30
[1m168/168[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 198ms/step - accuracy: 0.7473 - loss: 0.5394 - val_accuracy: 0.6889 - val_loss: 0.5969
Epoch 6/30
[1m168/168[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 199ms/step - accuracy: 0.7023 - loss: 0.5804 - val_accuracy: 0.7311 - val_loss: 0.5477
Epoch 7/30

### Evaluate the model

##### Plot the training history

In [27]:
def plot_training_history(training_info):
    """Plot the loss and, when available, the accuracy curves of a training run.

    Parameters
    ----------
    training_info : object with a ``history`` mapping
        ``history`` maps metric names to per-epoch value lists. ``'loss'`` is
        required; ``'val_loss'``, ``'accuracy'`` and ``'val_accuracy'`` are
        plotted only when present.

    Raises
    ------
    KeyError
        If ``history`` has no ``'loss'`` entry.
    """
    history = training_info.history
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))

    # Left panel: loss. The training loss must be there; the validation curve is
    # only drawn when it was recorded.
    axs[0].plot(history['loss'], label="training set")
    if 'val_loss' in history:
        axs[0].plot(history['val_loss'], label="validation set")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()

    # Right panel: accuracy. The keys are checked explicitly so that a model
    # compiled without this metric is handled, while genuine plotting errors
    # are no longer swallowed.
    accuracy_curves = [
        (key, label)
        for key, label in (("accuracy", "training set"), ("val_accuracy", "validation set"))
        if key in history
    ]
    if accuracy_curves:
        for key, label in accuracy_curves:
            axs[1].plot(history[key], label=label)
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Accuracy")
        axs[1].grid(True)
        axs[1].legend()
    else:
        # Nothing to show: hide the empty panel instead of leaving blank axes.
        axs[1].set_visible(False)

    plt.show()

plot_training_history(training_info)

##### Assess the model on the test set

In [28]:
cnn_model.evaluate(X_test, y_test)

[1m105/105[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 26ms/step - accuracy: 0.8627 - loss: 0.3202


[0.33498018980026245, 0.855176568031311]

##### Metrics: precision, recall, f1-score

In [29]:
# Sigmoid scores for the held-out epochs, turned into hard 0/1 labels with a
# 0.5 decision threshold.
y_pred = cnn_model.predict(X_test)
y_pred_labels = (y_pred > 0.5).astype(int)

# Confusion matrix, then per-class precision, recall and F1.
print(confusion_matrix(y_test, y_pred_labels))
print(classification_report(y_test, y_pred_labels))

[1m105/105[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 43ms/step
[[2093  182]
 [ 302  765]]
              precision    recall  f1-score   support

           0       0.87      0.92      0.90      2275
           1       0.81      0.72      0.76      1067

    accuracy                           0.86      3342
   macro avg       0.84      0.82      0.83      3342
weighted avg       0.85      0.86      0.85      3342



##### Plotting the ROC curve

In [30]:
# Make the scores one-dimensional (the predictions may come as (n_samples, 1)).
y_pred_proba = y_pred.ravel()

# Operating points of the ROC curve and the area underneath it.
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Draw the classifier's curve together with the chance diagonal.
fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random chance')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("Receiver Operating Characteristic")
ax.legend(loc="lower right")
ax.grid(True)
plt.show()

In [None]:
# Sanity check: build a fresh, untrained instance of the architecture (bound to
# `model`, separate from the trained `cnn_model`) and show the input shape it
# expects.
model = build_cnn_model()
print(model.input_shape)

In [None]:
X_train.shape

In [None]:
print("X shape:", X_train.shape)
print("Model input shape:", model.input_shape)