In [None]:
import numpy as np
import pandas as pd 
import os
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from tensorflow.keras import utils
from collections import Counter
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import classification_report, confusion_matrix

In [None]:
DATA_DIR = '/kaggle/input/ptbxl-atrial-fibrillation-detection'
# Raw ECG signals, indexed later as X[record, sample, lead].
X = np.load(os.path.join(DATA_DIR, 'ecgeq-500hzsrfava.npy'))
# Per-record metadata; the 'ritmi' column holds the rhythm label ('SR', 'AF' or 'VA').
Y = pd.read_csv(os.path.join(DATA_DIR, 'coorteeqsrafva.csv'), sep=';')
# Labels are derived row-by-row from Y and split jointly with X, so the two must be aligned.
assert len(X) == len(Y), f"signal/label count mismatch: {len(X)} vs {len(Y)}"

In [None]:
# Number of records per rhythm label.
rhythm_counts = Counter(Y['ritmi'])
rhythm_counts

In [None]:
# Metadata rows vs. signal array shape (records must line up one-to-one).
print(Y.shape, X.shape, sep='\n')

In [None]:
def convert(x):
    """Map a rhythm label from the 'ritmi' column to an integer class index.

    Args:
        x: rhythm code, one of 'SR', 'AF', 'VA'.

    Returns:
        int: 0 for 'SR', 1 for 'AF', 2 for 'VA'. This order matches the
        ``target_names=['SR','AF','VA']`` used for the classification report.

    Raises:
        ValueError: if ``x`` is not one of the known rhythm codes.
    """
    mapping = {'SR': 0, 'AF': 1, 'VA': 2}
    try:
        return mapping[x]
    except KeyError:
        # An unknown label previously fell through the if/elif chain and
        # surfaced as an opaque UnboundLocalError; report the bad value instead.
        raise ValueError(
            f"unknown rhythm label {x!r}; expected one of {sorted(mapping)}"
        ) from None

In [None]:
# Integer class index per record (SR=0, AF=1, VA=2), aligned with the rows of X.
labels = Y['ritmi'].map(convert)

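These ECG records contain 12 leads, stored in the standard order I, II, III, aVR, aVL, aVF, V1–V6. The model below uses two of them: lead II (index 1) and lead V5 (index 10).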

In [None]:
# Leads are on the last axis of X: index 1 = lead II, index 10 = lead V5.
LEAD_IDX = [1, 10]
SEED = 42
# stratify keeps the SR/AF/VA proportions identical in train and test (important when
# some rhythms are rare); random_state makes the split reproducible across runs.
Xtrain, Xtest, pytrain, pytest = train_test_split(
    X[:, :, LEAD_IDX], labels, test_size=.2, stratify=labels, random_state=SEED
)

In [None]:
# One-hot encode with a fixed width of 3 classes (SR=0, AF=1, VA=2).
# Without num_classes, to_categorical uses max(label) + 1 columns, so a split that
# happens to lack class 2 would produce targets narrower than the 3-unit softmax.
N_CLASSES = 3
ytrain = utils.to_categorical(pytrain, num_classes=N_CLASSES)
ytest = utils.to_categorical(pytest, num_classes=N_CLASSES)

In [None]:
def _conv_bn_pair(x, nfilters, res):
    """Two ReLU Conv1D layers (stride 1, same padding) followed by BatchNormalization."""
    for _ in range(2):
        x = layers.Conv1D(nfilters, res, strides=1, activation='relu', padding='same')(x)
    return layers.BatchNormalization()(x)


def resnet1d(data, nfilters, res, n_res_blocks=15):
    """Single-lead 1-D residual CNN backbone producing a 32-d feature vector.

    Architecture: a stem of two convolutions + BatchNormalization, then
    ``n_res_blocks`` residual blocks (two convolutions + BatchNormalization,
    added to the block input), a final convolution with ``2 * nfilters``
    filters, global average pooling and a small dense head.

    Args:
        data: symbolic Keras tensor for one lead, presumably (batch, length)
            as produced by layers.Input(shape=(length,)) — TODO confirm for
            other callers.
        nfilters: number of filters in every stem/residual convolution.
        res: convolution kernel width, in samples.
        n_res_blocks: number of residual blocks after the stem. The default
            of 15 reproduces the original hand-unrolled depth.

    Returns:
        Keras tensor of shape (batch, 32).
    """
    # Add a channel axis with a Keras layer: raw TF ops such as tf.expand_dims
    # cannot be applied to symbolic KerasTensors under Keras 3.
    x = layers.Reshape((-1, 1))(data)

    # Stem block: no skip connection (its input has a single channel).
    x = _conv_bn_pair(x, nfilters, res)

    # Residual blocks: channel count stays nfilters and 'same' padding keeps
    # the length, so the identity shortcut can be added directly.
    for _ in range(n_res_blocks):
        x = layers.add([_conv_bn_pair(x, nfilters, res), x])

    x = layers.Conv1D(nfilters * 2, res, strides=1, activation='relu', padding='same')(x)

    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(64)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(.5)(x)
    x = layers.Dense(32)(x)
    x = layers.BatchNormalization()(x)

    return x

In [None]:
def inception_module(x, f1, f3, f5, name=None):
    """Naive 1-D Inception block.

    Runs three parallel ReLU Conv1D branches with kernel widths 1, 3 and 5
    ('same' padding, so the sequence length is unchanged) and concatenates
    them along axis 2, giving f1 + f3 + f5 output channels.

    Args:
        x: input Keras tensor.
        f1, f3, f5: filter counts of the width-1, width-3 and width-5 branches.
        name: optional name for the concatenation layer.

    Returns:
        The concatenated Keras tensor.
    """
    branches = [
        layers.Conv1D(n_filters, width, padding='same', activation='relu')(x)
        for n_filters, width in ((f1, 1), (f3, 3), (f5, 5))
    ]
    return layers.concatenate(branches, axis=2, name=name)

In [None]:
def incnet(data, n_modules=4):
    """Single-lead 1-D Inception backbone producing a 16-d feature vector.

    Args:
        data: symbolic Keras tensor for one lead, presumably (batch, length)
            as produced by layers.Input(shape=(length,)) — TODO confirm for
            other callers.
        n_modules: number of stacked inception_module + BatchNormalization
            stages (default 4, as in the original).

    Returns:
        Keras tensor of shape (batch, 16).
    """
    # Add a channel axis with a Keras layer: raw TF ops such as tf.expand_dims
    # cannot be applied to symbolic KerasTensors under Keras 3.
    x = layers.Reshape((-1, 1))(data)

    # Wide (100-sample) stride-2 stem; with 'same' padding this halves the length.
    x = layers.Conv1D(256, 100, strides=2, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)

    # Each inception_module emits 80 + 40 + 20 = 140 channels.
    for _ in range(n_modules):
        x = inception_module(x, f1=80, f3=40, f5=20)
        x = layers.BatchNormalization()(x)

    x = layers.Conv1D(32, 5, strides=1, activation='relu', padding='same')(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(32)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(.5)(x)
    x = layers.Dense(16)(x)
    x = layers.BatchNormalization()(x)

    return x

In [None]:
# Two-branch model: one independent Inception backbone per lead, fused by
# element-wise sum before the 3-way softmax classifier.
# (resnet1d(inp, 32, 25) is a drop-in alternative backbone for either branch.)

# Samples per lead (5000 in this dataset). Note that `shape=(5000)` is just the
# int 5000, not a tuple; Keras expects a shape tuple such as (5000,).
SIGNAL_LEN = Xtrain.shape[1]
inl2 = layers.Input(shape=(SIGNAL_LEN,))  # lead II
inv5 = layers.Input(shape=(SIGNAL_LEN,))  # lead V5

ol2 = incnet(inl2)
ov5 = incnet(inv5)

x = layers.add([ol2, ov5])
outputs = layers.Dense(3, activation='softmax')(x)

In [None]:
# Functional model: two single-lead inputs (lead II, lead V5) -> 3-class softmax.
model = models.Model(inputs=[inl2, inv5], outputs=outputs)

In [None]:
# One-hot targets -> categorical cross-entropy; report accuracy during training.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.count_params()

In [None]:
# Stop when validation accuracy stalls, then roll back to the best epoch seen.
# min_delta is an absolute change in val_accuracy, and the same threshold decides
# which epoch counts as "best" for restore_best_weights. A large value such as
# 0.05 (5 percentage points) would ignore real but smaller gains on both counts.
callbacks = [
    EarlyStopping(
        monitor='val_accuracy',
        mode='max',
        patience=30,
        min_delta=0.001,
        restore_best_weights=True,
    )
]

In [None]:
# Lead-axis columns 0 and 1 are lead II and lead V5, as selected at the train/test split.
# validation_split holds out the last 20% of the (already shuffled) training set.
train_inputs = [Xtrain[:, :, 0], Xtrain[:, :, 1]]
model.fit(
    train_inputs,
    ytrain,
    batch_size=12,
    epochs=100,
    validation_split=.2,
    callbacks=callbacks,
)

In [None]:
# Test Model
#yp = model.predict([Xtest[:,:,0],Xtest[:,:,1],Xtest[:,:,2],Xtest[:,:,3]])
yp = model.predict([Xtest[:,:,0],Xtest[:,:,1]])
yp = np.argmax(yp,axis=1)
ya = np.argmax(ytest,axis=1)
cm = confusion_matrix(yp, ya)
print(cm)

In [None]:
print(classification_report(ya, yp, target_names=['SR','AF','VA']))

In [None]:
FP = cm.sum(axis=1) - np.diag(cm)  
FN = cm.sum(axis=0) - np.diag(cm)
TP = np.diag(cm)
TN = cm.sum() - (FP + FN + TP)

# Sensitivity, hit rate, recall, or true positive rate
TPR = TP/(TP+FN)
# Specificity or true negative rate
TNR = TN/(TN+FP) 
# Precision or positive predictive value
PPV = TP/(TP+FP)
# Negative predictive value
NPV = TN/(TN+FN)
# Fall out or false positive rate
FPR = FP/(FP+TN)
# False negative rate
FNR = FN/(TP+FN)
# False discovery rate
FDR = FP/(TP+FP)

# Overall accuracy
ACC = (TP+TN)/(TP+FP+FN+TN)

In [None]:
# Per-class summary, one row per class in label-index order (SR=0, AF=1, VA=2).
pd.DataFrame(
    {
        'sensitivity (TPR)': TPR,
        'specificity (TNR)': TNR,
        'accuracy (one-vs-rest)': ACC,
    },
    index=['SR', 'AF', 'VA'],
)