In [12]:
import pandas as pd
import numpy as np
import wfdb
import ast
import json
import os

In [13]:
def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(path+f) for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(path+f) for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data

path = '../ptb-xl/'
sampling_rate=100

# Major Labels
- NORM
- STTC
- MI
- HYP
- CD

# Load up the default database

In [14]:
# load and convert annotation data
files = pd.read_csv(os.path.join(path,'ptbxl_database.csv'), index_col='ecg_id')
files.scp_codes = files.scp_codes.apply(lambda x: ast.literal_eval(x))
files[['age', 'sex', 'height', 'weight', 'scp_codes']].head(20)

Unnamed: 0_level_0,age,sex,height,weight,scp_codes
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,56.0,1,,63.0,"{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}"
2,19.0,0,,70.0,"{'NORM': 80.0, 'SBRAD': 0.0}"
3,37.0,1,,69.0,"{'NORM': 100.0, 'SR': 0.0}"
4,24.0,0,,82.0,"{'NORM': 100.0, 'SR': 0.0}"
5,19.0,1,,70.0,"{'NORM': 100.0, 'SR': 0.0}"
6,18.0,1,,58.0,"{'NORM': 100.0, 'SR': 0.0}"
7,54.0,0,,83.0,"{'NORM': 100.0, 'SR': 0.0}"
8,48.0,0,,95.0,"{'IMI': 35.0, 'ABQRS': 0.0, 'SR': 0.0}"
9,55.0,0,,70.0,"{'NORM': 100.0, 'SR': 0.0}"
10,22.0,1,,56.0,"{'NORM': 100.0, 'SR': 0.0}"


In [15]:
print(f'# of Height: {files.height.count()}')
print(f'# of Weight: {files.weight.count()}')
print(f'# of Age: {files.age.count()}')
print(f'# of Sex: {files.sex.count()}')


# of Height: 6974
# of Weight: 9421
# of Age: 21799
# of Sex: 21799


# Check which normal codes are less than 100 percent confidence, remove them

In [16]:
def less_than_100_check(row):
    codes = row['scp_codes']
    if 'NORM' in codes and codes['NORM'] < 100:
        return 1
    else:
        return 0
    
removal_mask = files.apply(less_than_100_check, axis=1).values
filtered_files = files[removal_mask == 0]
print(f"Number of normal codes with 100% confidence: {len(filtered_files)}")

Number of normal codes with 100% confidence: 19457


# Apply high level diagnostic codes for each of the classes

In [17]:
# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
filtered_files['diagnostic_superclass'] = filtered_files.scp_codes.apply(aggregate_diagnostic)
filtered_files['diagnostic_superclass'].head(20)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_files['diagnostic_superclass'] = filtered_files.scp_codes.apply(aggregate_diagnostic)


ecg_id
1     [NORM]
3     [NORM]
4     [NORM]
5     [NORM]
6     [NORM]
7     [NORM]
8       [MI]
9     [NORM]
10    [NORM]
13    [NORM]
14    [NORM]
15    [NORM]
16    [NORM]
17        []
18        []
19    [NORM]
20        []
21    [NORM]
22    [STTC]
23        []
Name: diagnostic_superclass, dtype: object

# Count how many of each class are present

In [18]:
diagnostic_superclass = filtered_files['diagnostic_superclass'].value_counts()
print(diagnostic_superclass)
classes = diagnostic_superclass.index
counts = diagnostic_superclass.values
frame = pd.DataFrame([co for co in zip(classes, counts)], columns=['class', 'counts'])

diagnostic_superclass
[NORM]                 7004
[MI]                   2532
[STTC]                 2400
[CD]                   1708
[CD, MI]               1297
[HYP, STTC]             781
[STTC, MI]              599
[HYP]                   535
[CD, STTC]              471
[]                      411
[HYP, STTC, MI]         361
[HYP, CD]               300
[CD, STTC, MI]          223
[HYP, MI]               183
[NORM, CD]              143
[HYP, CD, MI]           117
[HYP, STTC, CD]         109
[CD, HYP, STTC, MI]      99
[HYP, CD, STTC]          84
[HYP, CD, STTC, MI]      53
[NORM, STTC]             22
[CD, STTC, HYP]          18
[CD, STTC, HYP, MI]       4
[NORM, HYP]               2
[NORM, CD, STTC]          1
Name: count, dtype: int64


## Remove classes below the count of 20, as well as ones which contain NORM as well as other labels.

In [19]:
# Convert the index to a list and drop rows based on a condition
def convert_and_check(row):
    classes = row['class']
    count = row['counts']
    if 'NORM' in classes and len(classes) > 1:
        return 1    # Drop
    elif len(classes) == 0:
        return 1
    elif count < 20:
        return 1
    else:
        return 0

drop_indices = frame.apply(convert_and_check, axis=1).values
filtered_classes = frame[drop_indices == 0]
print(filtered_classes)

                  class  counts
0                [NORM]    7004
1                  [MI]    2532
2                [STTC]    2400
3                  [CD]    1708
4              [CD, MI]    1297
5           [HYP, STTC]     781
6            [STTC, MI]     599
7                 [HYP]     535
8            [CD, STTC]     471
10      [HYP, STTC, MI]     361
11            [HYP, CD]     300
12       [CD, STTC, MI]     223
13            [HYP, MI]     183
15        [HYP, CD, MI]     117
16      [HYP, STTC, CD]     109
17  [CD, HYP, STTC, MI]      99
18      [HYP, CD, STTC]      84
19  [HYP, CD, STTC, MI]      53


## Remove the undesired classes from the dataset

In [20]:
def all_values_match(x, y):
    '''Check if all values in x are in y'''
    if len(x) != len(y):
        return 0
    else:
        for value in x:
            if value not in y:
                return 0
    return 1

def in_class_list(row):
    '''Check if the diagnostic superclass is in the filtered classes'''
    try:
        for allowable_classes in filtered_classes['class'].values:
            if all_values_match(allowable_classes, row['diagnostic_superclass']):
                return 1
        return 0
    
    except:
        print(row['diagnostic_superclass'])
        print(filtered_classes['class'].values)

keep_indices = filtered_files.apply(in_class_list, axis=1)
filtered_files = filtered_files[keep_indices == 1]


# Label 0 for normal, 1 for abnormal

In [22]:
def is_normal(row):
    '''Check if the diagnostic superclass is normal'''
    if 'NORM' in row['diagnostic_superclass']:
        return 1
    else:
        return 0

In [23]:
norm_col = filtered_files.apply(is_normal, axis=1)

filtered_files['NORM'] = norm_col
filtered_files['ABNORM'] = (filtered_files['NORM'] == 0).astype(int)

In [24]:
filtered_files

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass,NORM,ABNORM
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM],1,0
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM],1,0
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM],1,0
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM],1,0
6,19005.0,18.0,1,,58.0,2.0,0.0,CS-12 E,1984-11-28 13:32:13,sinusrhythmus normales ekg,...,,,,,4,records100/00000/00006_lr,records500/00000/00006_hr,[NORM],1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,ventrikulÄre extrasystole(n) sinustachykardie ...,...,,,1ES,,7,records100/21000/21833_lr,records500/21000/21833_hr,[STTC],0,1
21834,20703.0,300.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,...,,,,,4,records100/21000/21834_lr,records500/21000/21834_hr,[NORM],1,0
21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,sinusrhythmus lagetyp normal t abnorm in anter...,...,,,,,2,records100/21000/21835_lr,records500/21000/21835_hr,[STTC],0,1
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,...,,,SVES,,8,records100/21000/21836_lr,records500/21000/21836_hr,[NORM],1,0


In [25]:
len(filtered_files[filtered_files['NORM'] == 0])/len(filtered_files)

0.6289861214111664

In [26]:
filtered_files.to_json(os.path.join(path, 'updated_ptbxl_database.json'))