# ECG Rhythm Classification
## 1. Create Training Dataset
### Sebastian D. Goodfellow, Ph.D.

# Setup Notebook

In [None]:
# Import 3rd party libraries
import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pylab as plt

# Import local Libraries
#sys.path.insert(0, r'C:\Users\sebastian goodfellow\Documents\code\deep_ecg')
sys.path.insert(0, r'D:\GIT\deepECG_MIP\deepecg')
from deepecg.config.config import DATA_DIR
from deepecg.training.data.ecg import ECG

# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

# 1. Load Training Labels

In [None]:
# Save path: directory that receives the processed waveforms and the merged labels
path_save = os.path.join(DATA_DIR, 'training', 'disc', 'train')

# Load training labels
# A context manager is used so the file handle is closed as soon as the JSON is parsed.
with open(os.path.join(path_save, 'labels', 'labels.json')) as labels_file:
    labels_train = json.load(labels_file)

# Label lookup: maps the label symbols found in the 'train_label' column to integer class ids
# NOTE(review): presumably N=normal, A=atrial fibrillation, O=other, ~=noisy - confirm.
label_lookup = {'N': 0, 'A': 1, 'O': 2, '~': 3}

# 2. MIT-BIH Atrial Fibrillation Database

In [None]:
def set_duration(waveform, length):
    """Truncate an ecg waveform to at most ``length`` samples.

    A waveform that is already ``length`` samples or shorter is returned
    unchanged (the very same object); a longer one is cut down to its
    first ``length`` samples. No padding is performed here.
    """
    # Guard clause: nothing to trim.
    if len(waveform) <= length:
        return waveform

    # Keep only the leading samples.
    return waveform[:length]

In [None]:
# Inputs
path_data = os.path.join(DATA_DIR, 'db2')
duration = 60.  # NOTE(review): presumably seconds - confirm
fs = 300  # sampling frequency passed to ECG; presumably Hz - confirm
length = int(duration * fs)  # target number of samples per saved waveform

# Load labels
labels = pd.read_csv(os.path.join(path_data, 'labels', 'labels.csv'))

# labels dictionary: file name (text before the first '.') -> integer class id
labels_dict_db2 = dict()

# Loop through files
for idx, row in labels.iterrows():

    # Load waveform
    waveform = np.load(os.path.join(path_data, 'waveforms', row['file_name']))

    try:
        # Resolve the integer label first so an unknown label skips the file
        # before any processing or writing happens.
        label = label_lookup[row['train_label']]

        # Process ECG waveform
        ecg = ECG(file_name=row['file_name'], label=row['train_label'], waveform=waveform, filter_bands=[3, 45], fs=fs)

        # Set waveform duration (truncates the filtered signal to at most `length` samples)
        waveform = set_duration(waveform=ecg.filtered, length=length)

        if len(waveform) < length:
            # Get remainder
            remainder = length - len(waveform)

            # Pad waveform with zeros, split as evenly as possible between both ends
            pad_before = remainder // 2
            waveform = np.pad(waveform, (pad_before, remainder - pad_before), 'constant', constant_values=0)

        # Save waveform
        np.save(os.path.join(path_save, 'waveforms', row['file_name']), waveform)

        # Register the label only after the waveform was written, so the label
        # dictionary never references a file that failed to save.
        labels_dict_db2[row['file_name'].split('.')[0]] = label

    except Exception as error:
        # Best effort: keep going with the remaining files, but report what was
        # skipped instead of hiding it. `Exception` (rather than a bare except)
        # lets KeyboardInterrupt/SystemExit stop the loop.
        print('Skipping {}: {}: {}'.format(row['file_name'], type(error).__name__, error))

# 3. MIT-BIH Normal Sinus Rhythm Database

In [None]:
# Inputs
path_data = os.path.join(DATA_DIR, 'db3')
duration = 60.  # NOTE(review): presumably seconds - confirm
fs = 300  # sampling frequency passed to ECG; presumably Hz - confirm
length = int(duration * fs)  # target number of samples per saved waveform

# Load labels
labels = pd.read_csv(os.path.join(path_data, 'labels', 'labels.csv'))

# labels dictionary: file name (text before the first '.') -> integer class id
labels_dict_db3 = dict()

# Loop through files
for idx, row in labels.iterrows():

    # Load waveform
    waveform = np.load(os.path.join(path_data, 'waveforms', row['file_name']))

    try:
        # Resolve the integer label first so an unknown label skips the file
        # before any processing or writing happens.
        label = label_lookup[row['train_label']]

        # Process ECG waveform
        ecg = ECG(file_name=row['file_name'], label=row['train_label'], waveform=waveform, filter_bands=[3, 45], fs=fs)

        # Set waveform duration (truncates the filtered signal to at most `length` samples)
        waveform = set_duration(waveform=ecg.filtered, length=length)

        if len(waveform) < length:
            # Get remainder
            remainder = length - len(waveform)

            # Pad waveform with zeros, split as evenly as possible between both ends
            pad_before = remainder // 2
            waveform = np.pad(waveform, (pad_before, remainder - pad_before), 'constant', constant_values=0)

        # Save waveform
        np.save(os.path.join(path_save, 'waveforms', row['file_name']), waveform)

        # Register the label only after the waveform was written, so the label
        # dictionary never references a file that failed to save.
        labels_dict_db3[row['file_name'].split('.')[0]] = label

    except Exception as error:
        # Best effort: keep going with the remaining files, but report what was
        # skipped instead of hiding it. `Exception` (rather than a bare except)
        # lets KeyboardInterrupt/SystemExit stop the loop.
        print('Skipping {}: {}: {}'.format(row['file_name'], type(error).__name__, error))

# 4. Merge Labels

In [None]:
# Fold the labels of both extra databases into the training labels.
# db3 entries are applied after db2 entries, so db3 wins on a duplicate key.
labels_train.update({**labels_dict_db2, **labels_dict_db3})

# Write the merged labels back to labels.json (keys sorted)
labels_json_path = os.path.join(path_save, 'labels', 'labels.json')
with open(labels_json_path, 'w') as labels_file:
    json.dump(labels_train, fp=labels_file, sort_keys=True)