In [1]:
import wfdb
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as signal
import os

In [2]:
# Names of the MIT-BIH Arrhythmia Database records this notebook draws windows from.
RECORDS_TO_USE = (
    '100 101 103 105 106 108 109 111 112 113 '
    '114 115 116 117 118 119 121 122 123 124 '
    '200 201 202 203 205 207 208 209 210 212 '
    '213 214 215 219 220 221 222 223 228 230 '
    '231 232 233 234'
).split()

In [3]:
# Selects which records are used for training and which are used for evaluation.
# The split is per record, so no record contributes windows to both sets.
# The legacy global NumPy RNG is seeded here because both this split and the
# window-start jitter in GenerateTrainingData (np.random.rand) draw from it;
# running the cells in order is then reproducible on Restart & Run All.
RANDOM_SEED = 42
EVAL_FRACTION = 0.1
np.random.seed(RANDOM_SEED)

NUM_RECORDS = len(RECORDS_TO_USE)
NUM_EVAL_RECORDS = int(np.ceil(EVAL_FRACTION * NUM_RECORDS))
EVAL_RECORD_INDICES = np.random.choice(NUM_RECORDS, NUM_EVAL_RECORDS, replace=False)

# Set lookup instead of scanning the index array once per record.
_eval_index_set = set(EVAL_RECORD_INDICES.tolist())
EVAL_RECORDS = [rec for i, rec in enumerate(RECORDS_TO_USE) if i in _eval_index_set]
TRAIN_RECORDS = [rec for i, rec in enumerate(RECORDS_TO_USE) if i not in _eval_index_set]

In [4]:
# Window length in samples (29 * 256 = 7424). At the 250 Hz rate that
# GenerateTrainingData resamples to, one window spans about 29.7 seconds.
SAMPLE_LENGTH = 29 * 256

In [5]:
# Record path (without extension) of the first selected record.
# The dataset-building cells below reassign `path` per record before using it.
path = os.path.join(
    './mit-bih-arrhythmia-database-1.0.0/',
    RECORDS_TO_USE[0],
)

In [6]:
def FindBeats(start, end, beats, samples):
    """Return the annotation symbols whose position lies in [start, end].

    Both bounds are inclusive. `beats` and `samples` are parallel NumPy arrays
    (annotation symbols and their sample positions), as returned by PrepRecord.
    """
    in_window = (samples >= start) & (samples <= end)
    return beats[in_window]

def CheckBeats(beats):
    """Return True when every symbol in `beats` is 'N', otherwise False.

    An empty sequence counts as all-'N' and yields True. The result is always
    a plain Python bool (callers compare it with `is True`).
    """
    return all(symbol == 'N' for symbol in beats)

In [7]:
def PrepRecord(path, new_fs):
    """Load a WFDB record with its 'atr' annotations and resample channel 0.

    Args:
        path: record path without extension, as accepted by wfdb.rdrecord/rdann.
        new_fs: target sampling rate in Hz.

    Returns:
        [resampled_signal, symbols, positions]:
        - resampled_signal: channel 0 of p_signal resampled with
          scipy.signal.resample to round(len * new_fs / fs) samples.
        - symbols: np.array of the annotation symbols.
        - positions: np.array of the annotation sample indices rescaled to
          new_fs (floats, not rounded).
    """
    record = wfdb.rdrecord(path)
    annotation = wfdb.rdann(path, 'atr')

    first_channel = record.p_signal[:, 0]
    target_length = int(np.round(first_channel.shape[0] * new_fs / record.fs))
    resampled = signal.resample(first_channel, target_length)

    symbols = np.array(annotation.symbol)
    # Same per-element arithmetic as (s / fs) * new_fs, done vectorised.
    positions = np.asarray(annotation.sample) / record.fs * new_fs
    return [resampled, symbols, positions]

In [28]:
def GenerateTrainingData(path, rng=None):
    """Cut labelled, non-overlapping windows of SAMPLE_LENGTH samples from one record.

    The record is loaded and resampled to 250 Hz via PrepRecord. The first
    10 seconds are skipped. Each window then starts a random 0-2 s
    (0-500 samples) after the end of the previous window. Windows containing
    any of the annotation symbols '~', 'U', '?', '|', 'Q' or '+' are dropped,
    but they still advance the cursor. A kept window is labelled 0 if every
    annotation inside it is 'N' and 1 otherwise.

    NOTE(review): any non-'N' symbol outside the skip list (including non-beat
    annotations) yields label 1; confirm this is intended.
    NOTE(review): FindBeats uses an inclusive end, so an annotation exactly at
    end_index is counted although data[end_index] is outside the window.

    Args:
        path: record path (without extension) passed to PrepRecord.
        rng: optional np.random.Generator used for the start jitter. When None,
            the legacy global np.random state is used, as before.

    Returns:
        (training_data, training_label): float arrays of shape
        (n, SAMPLE_LENGTH, 1) and (n, 1), where n is the number of kept windows
        (printed before returning).
    """
    fs = 250
    data, beats, samples = PrepRecord(path, fs)

    # Annotation symbols that disqualify a window entirely.
    skip_symbols = {'~', 'U', '?', '|', 'Q', '+'}
    # Ignores the first 10 seconds.
    cursor = 10 * fs
    # Start is varied by up to 2 seconds for each window.
    jitter_range = 2 * fs
    draw = rng.random if rng is not None else np.random.rand

    # Grow lists instead of a fixed-size buffer: the old 500-row buffer raised
    # IndexError if a record produced more than 500 windows.
    windows = []
    labels = []
    while True:
        start_index = int(draw() * jitter_range + cursor)
        end_index = start_index + SAMPLE_LENGTH
        if end_index > data.shape[0]:
            break

        b = FindBeats(start_index, end_index, beats, samples)
        # Advance even if the window is rejected below.
        cursor = end_index

        if skip_symbols.intersection(b):
            continue

        labels.append(0 if CheckBeats(b) else 1)
        windows.append(data[start_index:end_index])

    n = len(windows)
    print(n)

    training_data = np.zeros((n, SAMPLE_LENGTH, 1))
    training_label = np.zeros((n, 1))
    if n:
        training_data[:, :, 0] = np.stack(windows)
        training_label[:, 0] = labels
    return training_data, training_label

In [29]:
# Window every training record (each call prints its window count), then stack
# the per-record arrays along the first axis into one training set.
per_record = [
    GenerateTrainingData(os.path.join('./mit-bih-arrhythmia-database-1.0.0/', rec))
    for rec in TRAIN_RECORDS
]
train_data = np.concatenate([windows for windows, _ in per_record], 0)
train_labels = np.concatenate([labels for _, labels in per_record], 0)
del per_record

58
53
19
25
30
56
54
53
58
53
54
57
52
17
53
57
58
50
51
24
54
44
25
32
45
51
39
43
42
53
42
22
38
27
12
48
45
32
52


In [30]:
# Same procedure as the training set, applied to the held-out evaluation records.
per_record = [
    GenerateTrainingData(os.path.join('./mit-bih-arrhythmia-database-1.0.0/', rec))
    for rec in EVAL_RECORDS
]
eval_data = np.concatenate([windows for windows, _ in per_record], 0)
eval_labels = np.concatenate([labels for _, labels in per_record], 0)
del per_record

54
52
9
36
42


In [34]:
# Persist both splits as .npy files (np.save appends the .npy extension).
# np.save does not create missing directories, so make sure ./data exists first.
os.makedirs('./data', exist_ok=True)
np.save('./data/train_data', train_data)
np.save('./data/train_labels', train_labels)
np.save('./data/eval_data', eval_data)
np.save('./data/eval_labels', eval_labels)

In [35]:
# For each split: window array shape, label array shape, and the number of
# label-1 (not all-'N') windows.
for split_data, split_labels in ((train_data, train_labels), (eval_data, eval_labels)):
    print(split_data.shape)
    print(split_labels.shape)
    print(np.sum(split_labels))

(1678, 7424, 1)
(1678, 1)
930.0
(193, 7424, 1)
(193, 1)
89.0
