In [2]:
import pandas as pd
import numpy as np
import wfdb
import ast
import torch

def load_raw_data(df, sampling_rate, path):
    if(sampling_rate == 100):
        data = [wfdb.rdsamp(path + f) for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(path + f) for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data


path = "ptb_xl/"
sampling_rate = 100

# load and convert annotation data
Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

# Load raw signal data
X = load_raw_data(Y, sampling_rate, path)

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

# Split data into train and test
test_fold = 10
# Train
X_train = X[np.where(Y.strat_fold != test_fold)]
y_train = Y[(Y.strat_fold != test_fold)].diagnostic_superclass
# Test
X_test = X[np.where(Y.strat_fold == test_fold)]
y_test = Y[Y.strat_fold == test_fold].diagnostic_superclass

def multihot_encoder(labels, n_categories = 1, dtype=torch.float32):
    label_set = set()
    for label_list in labels:
        label_set = label_set.union(set(label_list))
    label_set = sorted(label_set)

    multihot_vectors = []
    for label_list in labels:
        multihot_vectors.append([1 if x in label_list else 0 for x in label_set])
    if dtype is None:
        return pd.DataFrame(multihot_vectors, columns=label_set)
    return torch.Tensor(multihot_vectors).to(dtype)

y_train = multihot_encoder(y_train, n_categories = 5)
y_test = multihot_encoder(y_test, n_categories= 5)

In [3]:
y_train

tensor([[0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        ...,
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.]])

In [5]:
dataset_train = torch.utils.data.TensorDataset(torch.tensor(X_train), torch.tensor(y_train))

  dataset_train = torch.utils.data.TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
