<a href="https://colab.research.google.com/github/supertime1/OSA/blob/main/sleep_staging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#1.Import Dependency

In [6]:
import os
import mne
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import glob
import pickle

#2.Process the data

In [None]:
def load_data(file_path, sf=128, epoch_duration=30):

    ecg_samples = []
    ecg_labels = []
    total_epoches = 0
    
    for signal_file in glob.glob(file_path + '*[0-9].edf'):
        
        ecg_epoches = []
        
        data = mne.io.read_raw_edf(signal_file)
        ecg_ch = [i for i, v in enumerate(data.info.ch_names) if v == 'ECG']
        ecg_signal = data.get_data()[ecg_ch[0]]
        
        num_of_sample_per_epoch = sf * epoch_duration
        num_of_epoches = len(ecg_signal) // (num_of_sample_per_epoch)
        total_epoches += num_of_epoches
        
        print(f'{signal_file[-12:]} has {num_of_epoches} epoches')

        for i in range(num_of_epoches):
            ecg_epoch = ecg_signal[i*num_of_sample_per_epoch : (i+1)*num_of_sample_per_epoch]
            ecg_epoches.append(ecg_epoch)
        ecg_samples.append(ecg_epoches)
    
    for label_file in glob.glob(file_path + '*stage.txt'):
        print(f'reading {label_file}')
        ecg_labels.append(np.loadtxt(label_file))
        
    assert len(ecg_samples) == len(ecg_labels)

    for i in range(len(ecg_samples)):
        new_length = len(ecg_labels[i])
        ecg_samples[i] = ecg_samples[i][:new_length]
    
    return ecg_samples, ecg_labels, total_epoches

fp = "C:/Users/57lzhang.US04WW4008/Downloads/ucd/files/"
ecg_samples, ecg_labels, total_epoches = load_data(fp, 128, 30)

In [7]:
with open("C:/Users/57lzhang.US04WW4008/Downloads/ucd/files/processed_data/ecg_samples.pkl", "wb") as fp:
    pickle.dump(ecg_samples, fp)

with open("C:/Users/57lzhang.US04WW4008/Downloads/ucd/files/processed_data/ecg_labels.pkl", "wb") as fp:
    pickle.dump(ecg_labels, fp)

In [8]:
# split by patient id 
def split_data(ecg_samples, ecg_labels, train_ratio, test_ratio):
    test_num = round(len(ecg_samples) * test_ratio)
    train_num = round((len(ecg_samples) - test_num) * train_ratio)
    val_num = len(ecg_samples) - test_num - train_num

    np.random.seed(seed=7)
    np.random.shuffle(ecg_samples)
    train_samples = ecg_samples[:train_num]
    val_samples = ecg_samples[train_num:train_num+val_num]
    test_samples = ecg_samples[-test_num:]

    np.random.seed(seed=7)
    np.random.shuffle(ecg_labels)
    train_labels = ecg_labels[:train_num]
    val_labels = ecg_labels[train_num:train_num+val_num]
    test_labels = ecg_labels[-test_num:]

    return train_samples, train_labels, val_samples, val_labels, test_samples, test_labels 

train_samples, train_labels, val_samples, val_labels, test_samples, test_labels = split_data(ecg_samples, ecg_labels, 0.8, 0.12)
print(f'There are {len(train_samples)} subjects in training dataset')
print(f'There are {len(val_samples)} subjects in validation dataset')
print(f'There are {len(test_samples)} subjects in testing dataset')

There are 18 subjects in training dataset
There are 4 subjects in validation dataset
There are 3 subjects in testing dataset


In [9]:
def preprocess_data(num_epoch, epoch_duration, sf, ecg_samples, ecg_labels, oversample=True):
    """
    preprocess data with the matched dimension for training
    [num_epoch, 30*sampling_frequency, 1]

    params:
    epoch - number of epoches for each training sample
    sf - sampling frequency of ECG signal
    file_path - file path of raw EDF file
    """
    model_signal_input = []
    model_label_input = []

    num_of_sample_per_epoch = sf * epoch_duration

    for i in range(len(ecg_samples)):
        
        if oversample:
            overlap = int(0.9 * num_epoch)
            for j in range(len(ecg_samples[i])):
                signal_segment = np.asarray(ecg_samples[i][j*(num_epoch - overlap): j*(num_epoch - overlap) + num_epoch])
                if len(signal_segment) == num_epoch:
                    new_signal_seg = np.reshape(signal_segment, (num_epoch, num_of_sample_per_epoch, 1))
                    model_signal_input.append(new_signal_seg)
                
                # apply to labels as well
                label_segment = np.asarray(ecg_labels[i][j*(num_epoch - overlap): j*(num_epoch - overlap) + num_epoch])
                if len(label_segment) == num_epoch:
                    model_label_input.append(label_segment)
        
        else:
            for j in range(len(ecg_samples[i])):
                signal_segment = np.asarray(ecg_samples[i][j*num_epoch: (j+1)*num_epoch]) 
                if len(signal_segment) == num_epoch:
                    new_signal_seg = np.reshape(signal_segment, (num_epoch, num_of_sample_per_epoch, 1))
                    model_signal_input.append(new_signal_seg)

                # apply to labels as well
                label_segment = np.asarray(ecg_labels[i][j*num_epoch: (j+1)*num_epoch])
                if len(label_segment) == num_epoch:
                    model_label_input.append(label_segment)
        
    print(f'shape of processed signal data: {np.asarray(model_signal_input).shape}')
    print(f'shape of processed label data: {np.asarray(model_label_input).shape}')

    return np.asarray(model_signal_input), np.asarray(model_label_input)

In [10]:
def helper(samples, labels):
    for i in range(len(samples)):
        print(f'{len(samples[i])}, {len(labels[i])}')  

In [11]:
helper(ecg_samples, ecg_labels)

882, 882
768, 768
774, 774
789, 789
826, 826
711, 711
864, 864
752, 752
916, 916
748, 748
893, 893
925, 925
908, 908
913, 913
721, 721
811, 811
787, 787
900, 900
822, 822
907, 907
861, 861
808, 808
838, 838
813, 813
852, 852


In [12]:
train_signal_input, train_label_input = preprocess_data(100, 30, 128, train_samples, train_labels, oversample=True)
val_signal_input, val_label_input = preprocess_data(100, 30, 128, val_samples, val_labels, oversample=False)
test_signal_input, test_label_input = preprocess_data(100, 30, 128, test_samples, test_labels, oversample=False)

shape of processed signal data: (1319, 100, 3840, 1)
shape of processed label data: (1319, 100)
shape of processed signal data: (33, 100, 3840, 1)
shape of processed label data: (33, 100)
shape of processed signal data: (24, 100, 3840, 1)
shape of processed label data: (24, 100)


In [18]:
def join_labels(label_input):
    for i in range(len(label_input)):
        label_input[label_input > 1] = 2
    return label_input

In [29]:
train_label_input_join = join_labels(train_label_input)
val_label_input_join = join_labels(val_label_input)
test_label_input_join = join_labels(test_label_input)

#3.Model

In [30]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.models import load_model 
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras.layers import Conv1D, BatchNormalization, Input, Add, Activation,\
MaxPooling1D,Dropout,Flatten,TimeDistributed,Bidirectional,Dense,LSTM, ZeroPadding1D, \
AveragePooling1D,GlobalMaxPooling1D, Concatenate, Permute, Dot, Multiply, RepeatVector,\
Lambda, Average
from tensorflow.keras.initializers import glorot_uniform
import tensorflow_datasets as tfds

In [31]:
batch_size = 16
train_dataset = tf.data.Dataset.from_tensor_slices((train_signal_input, train_label_input_join))
val_dataset = tf.data.Dataset.from_tensor_slices((val_signal_input, val_label_input_join))
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(1024).repeat().batch(batch_size, drop_remainder=True)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
val_dataset = val_dataset.repeat().batch(batch_size, drop_remainder=True)

##Callbacks

In [None]:
# callbacks
log_dir = r"C:\Users\57lzhang.US04WW4008\Desktop\OSA\TensorBoard\logs\fit\\" + \
          datetime.now().strftime("%Y%m%d-%H%M%S") + "test"

## confusion matrix callback
def log_confusion_matrix(epoch, logs):
    # Use the model to predict the values from the test_images.
    
    val_pred_raw = model.predict(val_signal_input)

    val_pred = np.argmax(val_pred_raw, axis=-1)
    #test_labels = val_labels.reshape([len(t_val_labels)])

    # Calculate the confusion matrix using sklearn.metrics
    cm = sklearn.metrics.confusion_matrix(val_label_input_join, val_pred)

    figure = model_util.plot_confusion_matrix(cm, class_names=class_names, normalize=True)
    cm_image = model_util.plot_to_image(figure)

    # Log the confusion matrix as an image summary.
    with file_writer_cm.as_default():
        tf.summary.image("Confusion Matrix", cm_image, step=epoch)

class_names = ['Wake','NREM','REM']
file_writer_cm = tf.summary.create_file_writer(log_dir + '/cm')
cm_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)

## tensorboard callback
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

## checkpoint callback
filepath = r"C:\Users\57lzhang.US04WW4008\Desktop\OSA\models\test-oversample-128Hz-{epoch:02d}-{loss:.4f}"
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')

## early stop
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)

## learning rate decay callback
lr_schedule = tf.keras.callbacks.LearningRateScheduler(model_util.decay)

callbacks_list = [tensorboard_callback, cm_callback, checkpoint, early_stop, lr_schedule]

##Model

In [32]:
cnn = tf.keras.Sequential([
    #1st Conv1D
    tf.keras.layers.Conv1D(8, 1, strides=1, 
                          activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling1D(pool_size=2,strides=2),
    tf.keras.layers.Dropout(0.2),
    #2nd Conv1D
    tf.keras.layers.Conv1D(16, 3, strides=1,
                          activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling1D(pool_size=2,strides=2),
    tf.keras.layers.Dropout(0.2),
    #3rd Conv1D
    tf.keras.layers.Conv1D(32, 3, strides=1,
                          activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling1D(pool_size=2,strides=2),
    tf.keras.layers.Dropout(0.2),
    #4th Conv1D
    tf.keras.layers.Conv1D(64, 3, strides=1,
                          activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling1D(pool_size=2,strides=2),
    tf.keras.layers.Dropout(0.2),
    #5th Conv1D
    tf.keras.layers.Conv1D(16, 1, strides=1,
                          activation='relu'),
    tf.keras.layers.BatchNormalization(),
    #Full connection layer
    tf.keras.layers.Flatten()
])

#combine with LSTM
model = tf.keras.Sequential([
        tf.keras.layers.TimeDistributed(cnn,input_shape=(100,3840,1)),                   
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(256,return_sequences=True)),
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences = True)),
        tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(8))
])

model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
time_distributed (TimeDistri (None, 100, 3808)         9776      
_________________________________________________________________
bidirectional (Bidirectional (None, 100, 512)          8325120   
_________________________________________________________________
bidirectional_1 (Bidirection (None, 100, 256)          656384    
_________________________________________________________________
time_distributed_1 (TimeDist (None, 100, 8)            2056      
Total params: 8,993,336
Trainable params: 8,993,064
Non-trainable params: 272
_________________________________________________________________


In [33]:
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
model.fit(train_dataset,
          epochs=10,
          steps_per_epoch=len(train_signal_input)//batch_size,
          verbose=1,
          validation_data=val_dataset,
          validation_steps=len(val_signal_input)//batch_size
          )

Epoch 1/10
12/82 [===>..........................] - ETA: 29s - loss: 1.5343 - accuracy: 0.4036