In [None]:
#!pip install wfdb

In [1]:
# Importing needed libraries
from matplotlib import pyplot as plt
from wfdb.io import get_record_list
from wfdb import rdsamp, processing
import numpy as np
import random
from scipy.signal import resample_poly
from pickle import dump, load
import torch
import torch.nn as nn
from torch.utils import data
from sklearn.metrics import recall_score
import tqdm

import utils
from data_generator import dataset_gen

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook touches (Python `random`, NumPy, PyTorch) so that the
# shuffled DataLoader below (and any randomness in data_generator, if it uses these
# RNGs) produces the same batches on every Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

### Loading data

In [None]:
load_flag = True   # True: fetch and parse the records with wfdb; False: read the cached pickles
save_flag = False  # True: cache freshly parsed records as pickles for faster later runs
db_name = 'ltafdb'
cache_dir = './tests db/'

if load_flag:
    # Fetch the record names of the database and parse channel 0 of every record
    records = get_record_list(db_name)
    signals, beats, beat_types = utils.data_from_records(records, channel=0, db=db_name)
    if save_flag:
        # `with` guarantees each pickle is flushed and closed
        with open(cache_dir + db_name + '_signals.pkl', 'wb') as f:
            dump(signals, f)
        with open(cache_dir + db_name + '_beats.pkl', 'wb') as f:
            dump(beats, f)
        with open(cache_dir + db_name + 'beat_types.pkl', 'wb') as f:
            dump(beat_types, f)
else:
    # NOTE: pickle.load can execute arbitrary code - only load cache files you created.
    with open(cache_dir + db_name + '_signals.pkl', 'rb') as f:
        signals = load(f)
    with open(cache_dir + db_name + '_beats.pkl', 'rb') as f:
        beats = load(f)
    with open(cache_dir + db_name + 'beat_types.pkl', 'rb') as f:
        beat_types = load(f)


  5%|▍         | 4/84 [01:20<29:41, 22.27s/it]

## Beat types annotated at the R-peaks

In [None]:
# Flatten the per-record beat symbols into one array and count each distinct symbol
all_symbols = np.asarray([symbol for record_symbols in beat_types for symbol in record_symbols])
u, c = np.unique(all_symbols, return_counts=True)

# Meanings for different heart beat codings
label_meanings = {
    "N": "Normal beat",
    "L": "Left bundle branch block beat",
    "R": "Right bundle branch block beat",
    "V": "Premature ventricular contraction",
    "/": "Paced beat",
    "A": "Atrial premature beat",
    "f": "Fusion of paced and normal beat",
    "F": "Fusion of ventricular and normal beat",
    "j": "Nodal (junctional) escape beat",
    "a": "Aberrated atrial premature beat",
    "E": "Ventricular escape beat",
    "J": "Nodal (junctional) premature beat",
    "Q": "Unclassifiable beat",
    "e": "Atrial escape beat",
    "S": "Supraventricular premature or ectopic"
}

# (symbol, count) pairs, most frequent first (stable sort keeps ties in symbol order)
label_counts = sorted(zip(u.tolist(), c.tolist()), key=lambda tup: tup[1], reverse=True)

# Print the number of instances of each beat type. Symbols missing from
# label_meanings are printed verbatim instead of raising KeyError.
for symbol, count in label_counts:
    meaning = label_meanings.get(symbol, symbol)
    print(meaning, "-"*(40-len(meaning)), count)

In [None]:
# Plot the first 1000 samples of the first two records side by side.
# NOTE(review): the titles assume record 0 shows normal beats and record 1 shows
# premature ventricular contractions - confirm against this database's annotations.
plt.style.use('ggplot')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5))
ax1.plot(signals[0][:1000])
ax1.set_title('Normal beats')
ax2.plot(signals[1][:1000])
ax2.set_title('Premature ventricular contractions')
plt.show()

## Converting beat symbols to numerical labels

In [None]:
# Binary numeric encoding of the beat symbols: 'N' (normal) -> 1, every other symbol -> -1.
# Keys follow the frequency order of label_counts.
label_codings = {
    symbol: 1 if symbol == 'N' else -1
    for symbol, _count in label_counts
}
label_codings

In [None]:
# Convert each record's beat symbols to a NumPy array of numeric labels.
# A distinct loop variable is used so the global `beat_types` (one symbol list per
# record) is not overwritten; rebinding it would break re-running this cell.
labels = [
    np.asarray([label_codings[symbol] for symbol in record_types])
    for record_types in beat_types
]

labels[:5]

In [None]:
# Plot a 200-sample window centred on the first annotated beat of record 0 and
# mark the annotated R-peak position with a dashed vertical line.
half_window = 100
peak = beats[0][0]
# Clamp the start: a beat within `half_window` samples of the record start would
# otherwise give a negative slice index, which wraps around to the end of the signal.
start = max(peak - half_window, 0)
plt.figure(figsize=(10, 5))
plt.plot(signals[0][start:peak + half_window])
plt.axvline(x=peak - start, color='k', linestyle='--', alpha=0.5)
plt.title('Normal beats where R-peak location occurs at the bottom of a valley')
plt.show()

In [None]:
# NOTE(review): utils.fix_labels is defined outside this notebook; presumably it
# reconciles the numeric labels with the signals / beat annotations - confirm.
# This cell rebinds `labels` to the function's output, so running it a second time
# feeds already-fixed labels back in; re-run the label-encoding cell first if needed.
labels = utils.fix_labels(signals, beats, labels)

## Load and plot the noise records

In [None]:
# Fetch the baseline-wander and muscle-artifact records from the PhysioNet nstdb directory
baseline_wander = rdsamp('bw', pn_dir='nstdb')
muscle_artifact = rdsamp('ma', pn_dir='nstdb')


def _stack_and_resample(record, target_fs=250):
    """Join both channels of an rdsamp (samples, fields) record end-to-end and resample to target_fs Hz."""
    samples, fields = record
    joined = np.concatenate((samples[:, 0], samples[:, 1]))
    return resample_poly(joined, up=target_fs, down=fields['fs'])


ma = _stack_and_resample(muscle_artifact)
bw = _stack_and_resample(baseline_wander)

# Show the first 5000 samples of each noise type side by side
fig, (ax_bw, ax_ma) = plt.subplots(1, 2, figsize=(18, 7))
ax_bw.plot(bw[:5000])
ax_bw.set_title('Baseline wander')
ax_ma.plot(ma[:5000])
ax_ma.set_title('Muscle artifact')
plt.show()

## Create data loader

In [None]:
# Training-data configuration: every record is selected
n_batch = 84       # DataLoader batch size
win_size = 1000    # window size handed to dataset_gen
index_list = list(range(len(signals)))
print('Number of examples is %d' % len(index_list))


def _select(items):
    """Return the entries of `items` at the positions listed in `index_list`."""
    return [items[idx] for idx in index_list]


# Noise injection is switched off, so no noise arrays are passed in
train_dataset = dataset_gen(
    signals=_select(signals),
    peaks=_select(beats),
    labels=_select(labels),
    ma=None,
    bw=None,
    win_size=win_size,
    add_noise=False,
)

train_loader = data.DataLoader(train_dataset, batch_size=n_batch, shuffle=True)


## Create dataset from loader

In [None]:
X = []
y_true = []
epochs = 50

# Make `epochs` passes over the shuffled loader and keep every batch as NumPy arrays.
# No model is run here; gradient tracking is simply switched off once for the whole loop.
with torch.no_grad():
    for epoch in tqdm.tqdm(range(epochs)):
        for X_batch, y_true_batch in train_loader:
            # Squeeze axis 2 and convert to NumPy on the CPU
            X.append(X_batch.cpu().squeeze(2).numpy())
            y_true.append(y_true_batch.cpu().squeeze(2).numpy())


### Save the generated dataset to pickle files

In [None]:
# Cache the generated windows and labels so the evaluation section can reload them.
# `with` guarantees each pickle is flushed and closed.
with open('./tests db/X_' + db_name + '.pkl', 'wb') as f:
    dump(X, f)
with open('./tests db/y_true_' + db_name + '.pkl', 'wb') as f:
    dump(y_true, f)

# Test Different Models

In [None]:
# Databases available for evaluation; the next cell picks one through `db_name`.
db = 'nsrdb svdb incartdb edb'.split()

In [None]:
db_name = "edb"
# NOTE: pickle.load can execute arbitrary code - only load cache files produced by this notebook.
with open('./tests db/X_' + db_name + '.pkl', 'rb') as f:
    X = load(f)
with open('./tests db/y_true_' + db_name + '.pkl', 'rb') as f:
    y_true = load(f)

In [None]:
# Load the trained model. torch.load returns the model object itself here (it is
# moved to the device and called directly below).
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
# On newer PyTorch versions where weights_only defaults to True, whole-model
# checkpoints may also need weights_only=False - confirm for your torch version.
model = torch.load('./transformer_inception_label_correction_model.pt', map_location=device)
model.to(device)
# Inference mode: layers such as dropout / batch-norm (if present) must not use training behaviour
model.eval()

# empty lists for per-batch ground truth and predictions
y = []
y_pred = []

In [None]:
# Run the model over every cached batch without tracking gradients.
# The accumulators are (re)created here so the cell can be re-run on its own: at the
# end of this cell `y` / `y_pred` become NumPy arrays, which have no `.append`.
y = []
y_pred = []
with torch.no_grad():
    for X_batch, y_true_batch in tqdm.tqdm(zip(X, y_true), total=min(len(X), len(y_true))):
        # Add a trailing channel axis and move the batch to the model's device
        inputs = torch.from_numpy(X_batch).unsqueeze(2).to(device)
        outputs = model(inputs)

        # Store flattened ground truth and rounded predictions as Python lists
        y.append(y_true_batch.flatten().tolist())
        y_pred.append(torch.round(outputs.cpu().flatten()).numpy().tolist())

# Flatten the per-batch lists into 1-D integer arrays
y = np.array([value for batch_values in y for value in batch_values]).astype(int)
y_pred = np.array([value for batch_values in y_pred for value in batch_values]).astype(int)

In [None]:
# Recall of the positive class (pos_label=1, sklearn's default) and specificity
# (recall of the negative class, pos_label=0), both in percent.
# NOTE(review): assumes y / y_pred use 0/1 labels after utils.fix_labels and rounding - confirm.
recall = recall_score(y, y_pred) * 100
specificity = recall_score(y, y_pred, pos_label=0) * 100
print('For the transformer model on %s, Recall is %.3f and Specificity is %.3f' % (db_name, recall, specificity))