In [9]:
import pandas as pd
import numpy as np
import wfdb
import ast

In [10]:
def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(path+f) for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(path+f) for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data

path = '~/Desktop/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3/'
sampling_rate=100

In [11]:
# load and convert annotation data
Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

In [12]:
# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

In [13]:
def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

In [14]:
# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

In [15]:
print(Y['diagnostic_superclass'].value_counts())

diagnostic_superclass
[NORM]                 9069
[MI]                   2532
[STTC]                 2400
[CD]                   1708
[MI, CD]               1297
[HYP, STTC]             781
[MI, STTC]              599
[HYP]                   535
[STTC, CD]              471
[]                      411
[NORM, CD]              407
[MI, HYP, STTC]         340
[HYP, CD]               300
[MI, STTC, CD]          223
[MI, HYP]               183
[MI, HYP, CD]           117
[CD, MI, HYP, STTC]      93
[CD, HYP, STTC]          89
[STTC, HYP, CD]          67
[HYP, STTC, CD]          55
[STTC, MI, HYP, CD]      48
[NORM, STTC]             28
[HYP, MI, STTC]          21
[HYP, MI, STTC, CD]      15
[NORM, STTC, CD]          5
[NORM, HYP, CD]           2
[NORM, HYP]               2
[NORM, MI, HYP, CD]       1
Name: count, dtype: int64


In [None]:
# Convert each class to its own label
diseases = ['NORM', 'MI', 'STTC', 'CD', 'HYP']
#Y['diagnostic_superclass'] = Y['diagnostic_superclass'].str.strip('[]').str.split(', ')

for disease in diseases:
    Y[disease] = Y['diagnostic_superclass'].apply(lambda x: 1 if disease in x else 0)

# Create the NORM column, excluding NORM from the calculation
Y['AD'] = Y[diseases[1:]].sum(axis=1).apply(lambda x: 1 if x == 0 else 0)

In [23]:
label_counts = Y[['MI', 'STTC', 'CD', 'HYP', 'AD']].sum()
print(label_counts)

MI      5469
STTC    5235
CD      4898
HYP     2649
AD      9480
dtype: int64


In [26]:
Y.to_json('./updated_ptbxl_database.json')