This tutorial is modified from https://www.kaggle.com/rkuo2000/ecg-classification/comments#778183

The model design is based on the paper [ECG Heartbeat Classification: A Deep Transferable Representation](https://arxiv.org/pdf/1805.00794.pdf).

In [None]:
#@title Show GPU information
# Shell command (notebook `!` magic): print the NVIDIA driver/GPU status table.
!nvidia-smi

In [None]:
#@title Download data and unzip the file from google drive share-link
#@markdown Data source: https://www.kaggle.com/shayanfazeli/heartbeat
# NOTE(review): newer gdown releases take the file ID directly and deprecate
# the `--id` flag -- confirm against the installed gdown version.
!gdown --id 1O0_YAHtdEc2uO9R4rX5hYlN7DIY6y3qX
# -n: never overwrite files that already exist, -q: quiet output
!unzip -n -q 'heartbeat.zip'
print("... done")

# -- Import packages

In [None]:
import os
import math
import random
import pickle
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
from sklearn.utils import shuffle
from scipy.signal import resample
from scipy import interp
from itertools import cycle, product
import tensorflow.keras as keras
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Input, Dense, Conv1D, MaxPooling1D, Softmax, Add, Flatten, Activation, Dropout
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.optimizers import Adam

# 1 Dataset

In [None]:
#@title load dataset
# The CSV is read without a header row, so columns are numbered 0..N-1.
df = pd.read_csv("mitbih_train.csv", header=None)
# Bare expression: the notebook renders the DataFrame as the cell output.
df

In [None]:
# Show one sample: columns 0-186 are plotted; column 187 is the target label
idx = 1000
plt.plot(df.iloc[idx,:187])

In [None]:
#@title define dataset X and y
M = df.values  # numpy array M
X = M[:, :-1]  # columns 0-186: the signal samples
y = M[:, -1].astype(int)  # column 187: the class label, cast to int
print("dataset X shape =", X.shape)
print("dataset y shape =", y.shape)

In [None]:
## visualize dataset X
# Row indices of the samples belonging to each class label 0..4
C0_idx, C1_idx, C2_idx, C3_idx, C4_idx = (
    np.argwhere(y == label).flatten() for label in range(5)
)

# time-line in milliseconds with Sampling Frequency: 125Hz
t = np.arange(0, 187) * (1/125) * 1000

# plot the first sample of every class on one shared axis
plt.figure(figsize=(15,6))
first_beats = (
    (C0_idx, "C0: N"),
    (C1_idx, "C1: S"),
    (C2_idx, "C2: V"),
    (C3_idx, "C3: F"),
    (C4_idx, "C4: Q"),
)
for class_idx, legend_text in first_beats:
    plt.plot(t, X[class_idx, :][0], label=legend_text)
plt.legend()
plt.title("1-beat ECG for every category", fontsize=20)
plt.xlabel("Time (ms)", fontsize=15)
plt.ylabel("Amplitude", fontsize=15)
plt.show()

In [None]:
# show number of each class
class_num = [class_idx.shape[0]
             for class_idx in (C0_idx, C1_idx, C2_idx, C3_idx, C4_idx)]
print("number of each class: C0, C1, C2, C3, C4 =", class_num)

# Donut chart of the class counts: a pie with a white disc drawn over its centre
plt.figure(figsize=(6,6))
plt.pie(
    class_num,
    labels=['C0','C1','C2','C3','C4'],
    colors=['red','green','blue','skyblue','orange'],
    autopct='%1.1f%%',
)
center_disc = plt.Circle((0,0), 0.7, color='white')
plt.gca().add_artist(center_disc)
plt.show()

## split dataset to training_set and valid_set

In [None]:
# train_test_split (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)
val_set_ratio = 0.2
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=val_set_ratio, random_state=42
)

# set channel last: append a trailing axis so the shape becomes
# (number_of_data, X_length, channel)
X_train = X_train[..., np.newaxis]
X_val = X_val[..., np.newaxis]

# number of samples per class label 0..4 in each split
train_counts = [np.count_nonzero(y_train == label) for label in range(5)]
val_counts = [np.count_nonzero(y_val == label) for label in range(5)]
print("class count train : (C0:", train_counts[0], ", C1:", train_counts[1], ", C2:", train_counts[2], ", C3:", train_counts[3], ", C4:", train_counts[4], ")")
print("class count val   : (C0:", val_counts[0], ", C1:", val_counts[1], ", C2:", val_counts[2], ", C3:", val_counts[3], ", C4:", val_counts[4], ")")

for tag, array in (("X_train", X_train), ("y_train", y_train),
                   ("X_val", X_val), ("y_val", y_val)):
    print(tag + " shape =", array.shape)

## balance dataset via augmentation

In [None]:
def stretch(x):
    """Randomly stretch or compress a beat in time, keeping 187 samples.

    The signal is resampled to ``int(187 * s)`` points, with ``s`` drawn
    uniformly from [1 - 1/6, 1 + 1/6).  A shorter result is zero-padded at
    the end; a longer one is cut to its first 187 points.

    Args:
        x: 1-D signal (the notebook passes 187-sample beats).

    Returns:
        1-D array of length 187.
    """
    beat_len = 187
    new_len = int(beat_len * (1 + (random.random()-0.5)/3))
    resampled = resample(x, new_len)
    if new_len < beat_len:
        padded = np.zeros(shape=(beat_len, ))
        padded[:new_len] = resampled
        return padded
    return resampled[:beat_len]


def amplify(x):
    """Rescale a beat with a random, amplitude-dependent gain.

    Computes ``x * (1 + alpha * (1 - x))`` with ``alpha`` drawn uniformly
    from [-0.5, 0.5), so the values 0 and 1 map to themselves.

    Args:
        x: array of signal values.

    Returns:
        Array with the same shape as ``x``.
    """
    alpha = (random.random()-0.5)
    factor = -alpha*x + (1+alpha)
    return x*factor


def augment(x):
    """Return four randomly augmented copies of one beat.

    Each copy is, with roughly equal probability, stretched, amplified, or
    stretched and then amplified.

    Args:
        x: 1-D beat of 187 samples.

    Returns:
        Array of shape (4, 187), one augmented beat per row.
    """
    # 4 times augmentation
    result = np.zeros(shape=(4, 187))
    # Fill every allocated row.  The loop used to stop after three rows,
    # leaving the last one as zeros, i.e. a flat-line beat in the output.
    for i in range(result.shape[0]):
        # A single draw per copy; re-drawing inside the `elif` skewed the
        # branch probabilities away from the intended thirds.
        draw = random.random()
        if draw < 0.33:
            new_x = stretch(x)
        elif draw < 0.66:
            new_x = amplify(x)
        else:
            new_x = amplify(stretch(x))
        result[i, :] = new_x
    return result

# Overlay the first sample of X with one amplified and one stretched version.
# Both helpers draw from `random`, so the dashed curves differ on every run.
plt.title("Demo of one augmentation")
plt.plot(X[0, :], 'b', label="origin")
plt.plot(amplify(X[0, :]), '--', label="Aug amplify")
plt.plot(stretch(X[0, :]), '--', label="Aug stretch")
plt.legend()
plt.show()

In [None]:
print("before augmentation:")
# number of samples per class label 0..4 in each split
train_counts = [np.count_nonzero(y_train == label) for label in range(5)]
val_counts = [np.count_nonzero(y_val == label) for label in range(5)]
print("class count train : (C0:", train_counts[0], ", C1:", train_counts[1], ", C2:", train_counts[2], ", C3:", train_counts[3], ", C4:", train_counts[4], ")")
print("class count val   : (C0:", val_counts[0], ", C1:", val_counts[1], ", C2:", val_counts[2], ", C3:", val_counts[3], ", C4:", val_counts[4], ")")

In [None]:
#@title apply augmentation

# add the augmented copies produced by `augment` for every C3 sample
def augment_C3(X, y):
    """Append augmented copies of every class-3 sample to a dataset.

    Args:
        X: samples, shaped (n, 187, 1) as prepared by the split cell.
        y: integer class labels, shape (n,).

    Returns:
        Tuple ``(X, y)`` in which the rows returned by ``augment`` for each
        class-3 sample are stacked after the original rows and labelled 3.
        The inputs are returned unchanged when there is no class-3 sample.
    """
    C3_idx = np.argwhere(y==3).flatten()
    if C3_idx.size == 0:
        # np.apply_along_axis raises ValueError on an empty selection
        return X, y
    # apply `augment` to each beat, then flatten the per-beat copies into rows
    result = np.apply_along_axis(augment, axis=1, arr=X[C3_idx]).reshape(-1, 187, 1)
    classe = np.full(shape=(result.shape[0],), fill_value=3, dtype=int)
    # append aug_data to the dataset
    X = np.vstack([X, result])
    y = np.hstack([y, classe])
    return X, y

# NOTE(review): the validation split is augmented as well, so validation
# metrics include synthetic beats -- confirm this is intended.
X_train, y_train = augment_C3(X_train, y_train)
X_val, y_val = augment_C3(X_val, y_val)

print("after augmentation:")
# number of samples per class label 0..4 in each split
train_counts = [np.count_nonzero(y_train == label) for label in range(5)]
val_counts = [np.count_nonzero(y_val == label) for label in range(5)]
print("class count train : (C0:", train_counts[0], ", C1:", train_counts[1], ", C2:", train_counts[2], ", C3:", train_counts[3], ", C4:", train_counts[4], ")")
print("class count val   : (C0:", val_counts[0], ", C1:", val_counts[1], ", C2:", val_counts[2], ", C3:", val_counts[3], ", C4:", val_counts[4], ")")

## shuffle data

In [None]:
# Show the first labels before and after, to make the reordering visible.
print("Before shuffle, y_val[:30] =", y_val[:30])

# Samples and labels are passed together so they stay aligned; the fixed
# random_state makes the permutation reproducible.
X_train, y_train = shuffle(X_train, y_train, random_state=0)
X_val, y_val = shuffle(X_val, y_val, random_state=0)

print("After  shuffle, y_val[:30] =", y_val[:30])

## one-hot encoding y
https://www.tensorflow.org/api_docs/python/tf/keras/utils/to_categorical

In [None]:
print("Before one-hot encoding y_train[0] =", y_train[0])

# Pin the width to the 5 classes (labels 0..4).  Without num_classes the width
# is inferred from the largest label present, so a split that happens to lack
# class 4 would yield fewer columns than the 5-unit model output.
y_train = to_categorical(y_train, num_classes=5)
y_val = to_categorical(y_val, num_classes=5)

print("After one-hot encoding y_train[0] =", y_train[0])
print("y_train shape =", y_train.shape)
print("y_val shape =", y_val.shape)

# 2 Model
https://www.tensorflow.org/api_docs/python/tf/keras/Sequential
https://keras.io/layers/about-keras-layers/

In [None]:
#@title build paper model with Keras Functional-API
n_obs, feature, depth = X_train.shape


def _residual_block(block_input):
    """Apply one residual block: conv -> relu -> conv -> add skip -> relu -> pool.

    Args:
        block_input: Keras tensor with 32 channels (needed for the skip add).

    Returns:
        Keras tensor after the max-pooling layer.
    """
    net = Conv1D(filters=32, kernel_size=5, strides=1, padding='same')(block_input)
    net = Activation("relu")(net)
    net = Conv1D(filters=32, kernel_size=5, strides=1, padding='same')(net)
    # skip connection: add the block's own input back in
    net = Add()([net, block_input])
    net = Activation("relu")(net)
    return MaxPooling1D(pool_size=5, strides=2)(net)


inp = Input(shape=(feature, depth))
# stem convolution lifts the single input channel to 32 feature maps
net = Conv1D(filters=32, kernel_size=5, strides=1)(inp)

# Five identical residual blocks, chained.  Writing them out by hand had the
# second block's activation read the first block's sum, which left the second
# block's convolutions disconnected from the network.
for _ in range(5):
    net = _residual_block(net)

F1 = Flatten()(net)

# classifier head: two 32-unit dense layers, then a 5-way softmax
D1 = Dense(32)(F1)
A6 = Activation("relu")(D1)
D2 = Dense(32)(A6)
D3 = Dense(5)(D2)
A7 = Softmax()(D3)

model = Model(inputs=inp, outputs=A7)
model.summary()

In [None]:
#@title plot model
# Render the layer graph, including tensor shapes, to model_plot.png.
plot_model(model, to_file="model_plot.png", show_shapes=True, dpi=60)

In [None]:
#@title Compile Model
# categorical_crossentropy pairs with the one-hot labels and softmax output.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# 3 Train

In [None]:
# Training hyper-parameters
batch_size = 10
num_epochs = 10

# `history` keeps the per-epoch metrics, used for the plots further down.
history = model.fit(X_train, y_train,
                    epochs=num_epochs,
                    batch_size=batch_size,
                    validation_data=(X_val, y_val),
                    shuffle=True)

In [None]:
# The history.history attribute is a dictionary; list the metric names it holds
history.history.keys()

In [None]:
#@title Training history visualization
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,12))

# One row per panel: (axes, history key, title, y-label, legend position)
panels = (
    (ax1, 'accuracy', 'Model accuracy', 'Accuracy', 'upper left'),
    (ax2, 'loss', 'Model loss', 'Loss', 'upper right'),
)
for axis, key, title, ylabel, legend_loc in panels:
    axis.plot(history.history[key])           # training curve
    axis.plot(history.history['val_' + key])  # validation curve
    axis.set_title(title)
    axis.set(ylabel=ylabel, xlabel='Epoch')
    axis.legend(['Train', 'Valid'], loc=legend_loc)

# write the figure to disk before showing it
plt.savefig('train_history.png', dpi=72)
plt.show()

In [None]:
#@title save model
# Write the trained model to a single file; the test section reloads it.
model.save("ecg_arrhythmia.keras")

In [None]:
#@title evaluate model with valid_set
# evaluate() returns the loss followed by the compiled metrics (accuracy here)
loss, acc = model.evaluate(X_val,  y_val, verbose=2)
print("Valid_set accuracy: {:5.2f}%".format(100*acc))

# 4 Test

In [None]:
#@title Prepare test data
CLASSES = ['C0', 'C1', 'C2', 'C3', 'C4']

df2 = pd.read_csv("mitbih_test.csv", header=None)
M = df2.values  # numpy array M
X_test = M[:, :-1]             # every column but the last: signal samples
y_test = M[:, -1].astype(int)  # last column: class label
X_test = X_test[:,:,np.newaxis]  # as shape (number_of_data, X_length, channel)
# one-hot encoding; the width is pinned to the number of classes so it does
# not depend on which labels happen to occur in the test file
y_test = to_categorical(y_test, num_classes=len(CLASSES))

print("X_test:", X_test.shape)
print("y_test:", y_test.shape)

In [None]:
#@title restore model
# Reload the saved model from disk and freeze its weights for inference.
model = load_model('ecg_arrhythmia.keras')
model.trainable = False

In [None]:
#@title start inference model
# Predict in batches of 1000; each row of y_pred is the softmax output for
# one test sample.
y_pred = model.predict(X_test, batch_size=1000)
print("y_pred shape =", y_pred.shape)
y_pred

In [None]:
#@title plot ROC curve
# Compute ROC curve and ROC area for each class (one-vs-rest on the one-hot columns)
N_CLASSES = len(CLASSES)
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(N_CLASSES):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area

lw = 2

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(N_CLASSES)]))

# Then interpolate all ROC curves at these points.
# np.interp is used directly: `scipy.interp` was only a deprecated alias of it
# and is no longer available in recent SciPy releases.
mean_tpr = np.zeros_like(all_fpr)
for i in range(N_CLASSES):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
mean_tpr /= N_CLASSES

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111)

# Optional: micro- and macro-average curves (disabled)
# plt.plot(fpr["micro"], tpr["micro"],
#          label='micro-average ROC curve (area = {0:0.2f})'
#                ''.format(roc_auc["micro"]),
#          color='deeppink', linestyle=':', linewidth=4)

# plt.plot(fpr["macro"], tpr["macro"],
#          label='macro-average ROC curve (area = {0:0.2f})'
#                ''.format(roc_auc["macro"]),
#          color='navy', linestyle=':', linewidth=4)

# One distinct colour per class; with only four colours the cycle wrapped
# around and the first and last class were drawn in the same colour.
colors = cycle(['cornflowerblue', 'green', 'darkorange', 'red', 'purple'])
for i, color in zip(range(N_CLASSES), colors):
    ax.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve: {0}\n(area = {1:0.2f})'
             ''.format(CLASSES[i], roc_auc[i]))

# chance-level diagonal
ax.plot([0, 1], [0, 1], 'k--', lw=lw)
ax.set(xlim=(0.0, 1.0), ylim=(0.0, 1.0))
ax.axis('equal')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Some extension of Receiver operating characteristic to multi-class')
ax.legend(loc="lower right")
plt.show()

In [None]:
#@title plot confusion matrix
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html#sklearn.metrics.plot_confusion_matrix
# NOTE(review): the scikit-learn helper linked above may no longer exist in
# current releases; this cell defines its own function of the same name.
class_names = CLASSES
# larger default font for the figures that follow
plt.rcParams.update({'font.size': 12})

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map in a new figure.

    Args:
        cm: square 2-D array of counts; rows are true labels and columns are
            predicted labels.
        classes: tick labels, one per row/column.
        normalize: when True, every row is divided by its sum so each cell
            shows the fraction of that true class.
        title: axes title.
        cmap: matplotlib colormap used for the cells.

    Returns:
        None.  A new 5x5-inch figure is created; it is not shown here.
    """
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # Rows without any true sample stay at 0 instead of turning into NaN.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    fig = plt.figure(figsize=(5,5))
    ax = fig.add_subplot(111)
    ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    tick_marks = np.arange(len(classes))
    ax.axis('equal')
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    # Write each value into its cell, in white on dark cells and black on
    # light ones so the number stays readable.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    plt.tight_layout()

# Compute confusion matrix from the class indices (argmax of one-hot / softmax)
cnf_matrix = confusion_matrix(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1))
np.set_printoptions(precision=2)

# plot_confusion_matrix opens its own figure, so no plt.figure() call is made
# here; an extra one only produced an empty figure before each plot.

# Plot non-normalized confusion matrix
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

In [None]:
#@title classification report
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html?highlight=classification_report#sklearn.metrics.classification_report

# Per-class report on the test set, computed from the argmax class indices.
print(classification_report(y_test.argmax(axis=1), y_pred.argmax(axis=1), target_names=CLASSES))