# Data Loader for waveform data

In [1]:
import pandas as pd
import ast
import numpy as np

scp_statements_path = '../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/scp_statements.csv'
database_path = '../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/ptbxl_database.csv'

In [2]:
# scp statements file
df = pd.read_csv(scp_statements_path)
df.head()

Unnamed: 0.1,Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
0,NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
1,NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
2,DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
3,LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
4,NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7


In [3]:
# database
df2 = pd.read_csv(database_path)
df2.head()

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
0,1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,...,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr
1,2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,...,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr
2,3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,...,True,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr
3,4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,...,True,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr
4,5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,...,True,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr


In [4]:
print(len(df2))
print(len(df2.loc[df2['validated_by_human'] == True]))

21799
16056


In [5]:
# load and convert annotation data
Y = pd.read_csv(database_path, index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

In [6]:
# Apply lambda function to extract keys 
Y['scp_labels'] = Y.scp_codes.apply(lambda scp: [key for key, value in scp.items()])

In [7]:
Y.scp_labels

ecg_id
1               [NORM, LVOLT, SR]
2                   [NORM, SBRAD]
3                      [NORM, SR]
4                      [NORM, SR]
5                      [NORM, SR]
                   ...           
21833    [NDT, PVC, VCLVH, STACH]
21834           [NORM, ABQRS, SR]
21835                 [ISCAS, SR]
21836                  [NORM, SR]
21837                  [NORM, SR]
Name: scp_labels, Length: 21799, dtype: object

In [8]:
# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(scp_statements_path, index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

In [9]:
Y.head(20)

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,scp_labels,diagnostic_superclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,"[NORM, LVOLT, SR]",[NORM]
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,"[NORM, SBRAD]",[NORM]
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,"[NORM, SR]",[NORM]
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,"[NORM, SR]",[NORM]
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,"[NORM, SR]",[NORM]
6,19005.0,18.0,1,,58.0,2.0,0.0,CS-12 E,1984-11-28 13:32:13,sinusrhythmus normales ekg,...,,,,,,4,records100/00000/00006_lr,records500/00000/00006_hr,"[NORM, SR]",[NORM]
7,16193.0,54.0,0,,83.0,2.0,0.0,CS-12 E,1984-11-28 13:32:22,"sinusrhythmus linkstyp t abnormal, wahrscheinl...",...,,,,,,7,records100/00000/00007_lr,records500/00000/00007_hr,"[NORM, SR]",[NORM]
8,11275.0,48.0,0,,95.0,2.0,0.0,CS-12 E,1984-12-01 14:49:52,sinusrhythmus linkstyp qrs(t) abnormal infe...,...,", I-AVF,",,,,,9,records100/00000/00008_lr,records500/00000/00008_hr,"[IMI, ABQRS, SR]",[MI]
9,18792.0,55.0,0,,70.0,2.0,0.0,CS-12 E,1984-12-08 09:44:43,sinusrhythmus normales ekg,...,", I-AVR,",,,,,10,records100/00000/00009_lr,records500/00000/00009_hr,"[NORM, SR]",[NORM]
10,9456.0,22.0,1,,56.0,2.0,0.0,CS-12 E,1984-12-12 14:12:46,sinusrhythmus normales ekg,...,,,,,,9,records100/00000/00010_lr,records500/00000/00010_hr,"[NORM, SR]",[NORM]


In [10]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Step 1: Flatten the lists in both columns to find unique values
all_superclasses = [item for sublist in Y['diagnostic_superclass'] for item in sublist]
all_labels = [item for sublist in Y['scp_labels'] for item in sublist]

# Step 2: Create separate LabelEncoders for each column
superclass_encoder = LabelEncoder()
label_encoder = LabelEncoder()

# Fit the encoders on their respective unique values
superclass_encoder.fit(all_superclasses)
label_encoder.fit(all_labels)

# Step 3: Create mappings and save them to a text file
with open('label_mappings.txt', 'w') as f:
    f.write('Diagnostic Superclass Mappings:\n')
    for key, value in zip(superclass_encoder.classes_, superclass_encoder.transform(superclass_encoder.classes_)):
        f.write(f'{key}: {value}\n')
    
    f.write('\nSCP Labels Mappings:\n')
    for key, value in zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)):
        f.write(f'{key}: {value}\n')

# Step 4: Apply the encodings to both columns
Y['diagnostic_superclass_encoded'] = Y['diagnostic_superclass'].apply(lambda x: superclass_encoder.transform(x).tolist())
Y['scp_labels_encoded'] = Y['scp_labels'].apply(lambda x: label_encoder.transform(x).tolist())

In [11]:
Y.head()

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,scp_labels,diagnostic_superclass,diagnostic_superclass_encoded,scp_labels_encoded
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,"[NORM, LVOLT, SR]",[NORM],[3],"[46, 44, 61]"
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,"[NORM, SBRAD]",[NORM],[3],"[46, 59]"
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,"[NORM, SR]",[NORM],[3],"[46, 61]"
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,"[NORM, SR]",[NORM],[3],"[46, 61]"
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,"[NORM, SR]",[NORM],[3],"[46, 61]"


In [12]:
# Split data into train and test
test_fold = 10

# Train
y_train = Y[(Y.strat_fold != test_fold)][['diagnostic_superclass_encoded', 'scp_labels_encoded']]
# Test
y_test = Y[Y.strat_fold == test_fold][['diagnostic_superclass_encoded', 'scp_labels_encoded']]

In [13]:
y_train.head(10)

Unnamed: 0_level_0,diagnostic_superclass_encoded,scp_labels_encoded
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1,[3],"[46, 44, 61]"
2,[3],"[46, 59]"
3,[3],"[46, 61]"
4,[3],"[46, 61]"
5,[3],"[46, 61]"
6,[3],"[46, 61]"
7,[3],"[46, 61]"
8,[2],"[18, 3, 61]"
10,[3],"[46, 61]"
11,[3],"[46, 58]"


## Run the following cells to download test set (IMAGE DATA)

In [92]:
import os
import pandas as pd
import ast
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms

split = 'test'
freq = 500
prefix = f"../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records{freq}_ground_truth/records{freq}_ground_truth_"

# Load the database file
ptb_xl_database_df = pd.read_csv(database_path, index_col='ecg_id')
ptb_xl_database_df.scp_codes = ptb_xl_database_df.scp_codes.apply(lambda x: ast.literal_eval(x))

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(scp_statements_path, index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
ptb_xl_database_df['diagnostic_superclass'] = ptb_xl_database_df.scp_codes.apply(aggregate_diagnostic)
Y = ptb_xl_database_df
# Apply lambda function to extract keys 
Y['scp_labels'] = Y.scp_codes.apply(lambda scp: [key for key, value in scp.items()])

# Split data into train and test
test_fold = 10
y_train = Y[Y.strat_fold != test_fold]
y_test = Y[Y.strat_fold == test_fold]

if split == 'test':
    y = y_test
else:
    y = y_train

if freq == 100:
    y['path'] = y.filename_lr.apply(lambda x: prefix + x.split('/')[-2] + '/' + x.split('/')[-1] + '-0.png')
    y = y[['path', 'diagnostic_superclass', 'scp_labels']]
else:
    y['path'] = y.filename_hr.apply(lambda x: prefix + x.split('/')[-2] + '/' + x.split('/')[-1] + '-0.png')
    y = y[['path', 'diagnostic_superclass', 'scp_labels']]

# Exclude rows where diagnostic_superclass_encoded or scp_encoded list is empty
y = y[y['diagnostic_superclass'].apply(lambda x: len(x) > 0)]
y = y[y['scp_labels'].apply(lambda x: len(x) > 0)]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  y['path'] = y.filename_hr.apply(lambda x: prefix + x.split('/')[-2] + '/' + x.split('/')[-1] + '-0.png')


In [93]:
y.iloc[0]['path']

'../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records500_ground_truth/records500_ground_truth_00000/00009_hr-0.png'

In [94]:
import os

img_dir = f"../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records{freq}_ground_truth"

image_paths = []
for root, _, files in os.walk(img_dir):
    for file in files:
        if file.endswith('.png'):
            image_paths.append(os.path.join(root, file))
            
image_paths[0]

'../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records500_ground_truth/records500_ground_truth_00000/00545_hr-0.png'

In [95]:
# Collect indices to drop
indices_to_drop = []
for index, row in y.iterrows():
    p1 = row['path']
    if p1 not in image_paths:
        indices_to_drop.append(index)

# Drop rows by index
y = y.drop(indices_to_drop)

In [96]:
# Reset the index
y.reset_index(drop=True, inplace=True)
y.head()

Unnamed: 0,path,diagnostic_superclass,scp_labels
0,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"
1,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"
2,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"
3,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"
4,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"


In [97]:
len(y)

714

In [98]:
# Define mappings
diagnostic_mapping = {
    'CD': 0, 'HYP': 1, 'MI': 2, 'NORM': 3, 'STTC': 4
}
scp_mapping = {
    '1AVB': 0, '2AVB': 1, '3AVB': 2, 'ABQRS': 3, 'AFIB': 4, 'AFLT': 5, 'ALMI': 6, 'AMI': 7, 
    'ANEUR': 8, 'ASMI': 9, 'BIGU': 10, 'CLBBB': 11, 'CRBBB': 12, 'DIG': 13, 'EL': 14, 'HVOLT': 15, 
    'ILBBB': 16, 'ILMI': 17, 'IMI': 18, 'INJAL': 19, 'INJAS': 20, 'INJIL': 21, 'INJIN': 22, 
    'INJLA': 23, 'INVT': 24, 'IPLMI': 25, 'IPMI': 26, 'IRBBB': 27, 'ISCAL': 28, 'ISCAN': 29, 
    'ISCAS': 30, 'ISCIL': 31, 'ISCIN': 32, 'ISCLA': 33, 'ISC_': 34, 'IVCD': 35, 'LAFB': 36, 
    'LAO/LAE': 37, 'LMI': 38, 'LNGQT': 39, 'LOWT': 40, 'LPFB': 41, 'LPR': 42, 'LVH': 43, 
    'LVOLT': 44, 'NDT': 45, 'NORM': 46, 'NST_': 47, 'NT_': 48, 'PAC': 49, 'PACE': 50, 'PMI': 51, 
    'PRC(S)': 52, 'PSVT': 53, 'PVC': 54, 'QWAVE': 55, 'RAO/RAE': 56, 'RVH': 57, 'SARRH': 58, 
    'SBRAD': 59, 'SEHYP': 60, 'SR': 61, 'STACH': 62, 'STD_': 63, 'STE_': 64, 'SVARR': 65, 
    'SVTAC': 66, 'TAB_': 67, 'TRIGU': 68, 'VCLVH': 69, 'WPW': 70
}

# Inverse mappings for one-hot encoding
num_diagnostic_classes = len(diagnostic_mapping)
num_scp_classes = len(scp_mapping)

def encode_labels(label_list, mapping, num_classes):
    one_hot = np.zeros(num_classes, dtype=int)
    for label in label_list:
        if label in mapping:
            one_hot[mapping[label]] = 1
    return one_hot

# Apply encoding
y['diagnostic_one_hot'] = y['diagnostic_superclass'].apply(
    lambda labels: encode_labels(labels, diagnostic_mapping, num_diagnostic_classes)
)
y['scp_one_hot'] = y['scp_labels'].apply(
    lambda labels: encode_labels(labels, scp_mapping, num_scp_classes)
)

print(y.head(10))

                                                path diagnostic_superclass  \
0  ../../../../../data/padmalab_external/special_...                [NORM]   
1  ../../../../../data/padmalab_external/special_...                [NORM]   
2  ../../../../../data/padmalab_external/special_...                [NORM]   
3  ../../../../../data/padmalab_external/special_...                [NORM]   
4  ../../../../../data/padmalab_external/special_...                [NORM]   
5  ../../../../../data/padmalab_external/special_...                  [MI]   
6  ../../../../../data/padmalab_external/special_...                  [CD]   
7  ../../../../../data/padmalab_external/special_...                [NORM]   
8  ../../../../../data/padmalab_external/special_...                [NORM]   
9  ../../../../../data/padmalab_external/special_...                [NORM]   

          scp_labels diagnostic_one_hot  \
0         [NORM, SR]    [0, 0, 0, 1, 0]   
1         [NORM, SR]    [0, 0, 0, 1, 0]   
2         [N

In [99]:
y['scp_one_hot'][0].shape

(71,)

In [100]:
from PIL import Image
# Iterate over all paths and check if they are valid
for path in y['path']:
    try:
        ecg = Image.open(path)
        # print(f"Successfully loaded: {path}")
    except Exception as e:
        print(f"Failed to load: {path}, Error: {e}")

In [101]:
# Save the DataFrame as a CSV file
y.to_csv(f'{split}-{freq}.csv', index=False)

## Run the following cell to download train set (IMAGE DATA)

In [102]:
import os
import pandas as pd
import ast
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms

split = 'train'
freq = 100
prefix = f"../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records{freq}_ground_truth/records{freq}_ground_truth_"

# Load the database file
ptb_xl_database_df = pd.read_csv(database_path, index_col='ecg_id')
ptb_xl_database_df.scp_codes = ptb_xl_database_df.scp_codes.apply(lambda x: ast.literal_eval(x))

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(scp_statements_path, index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
ptb_xl_database_df['diagnostic_superclass'] = ptb_xl_database_df.scp_codes.apply(aggregate_diagnostic)
Y = ptb_xl_database_df
# Apply lambda function to extract keys 
Y['scp_labels'] = Y.scp_codes.apply(lambda scp: [key for key, value in scp.items()])

# Split data into train and test
test_fold = 10
y_train = Y[Y.strat_fold != test_fold]
y_test = Y[Y.strat_fold == test_fold]

if split == 'test':
    y = y_test
else:
    y = y_train

if freq == 100:
    y['path'] = y.filename_lr.apply(lambda x: prefix + x.split('/')[-2] + '/' + x.split('/')[-1] + '-0.png')
    y = y[['path', 'diagnostic_superclass', 'scp_labels']]
else:
    y['path'] = y.filename_hr.apply(lambda x: prefix + x.split('/')[-2] + '/' + x.split('/')[-1] + '-0.png')
    y = y[['path', 'diagnostic_superclass', 'scp_labels']]

# Exclude rows where diagnostic_superclass_encoded or scp_encoded list is empty
y = y[y['diagnostic_superclass'].apply(lambda x: len(x) > 0)]
y = y[y['scp_labels'].apply(lambda x: len(x) > 0)]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  y['path'] = y.filename_lr.apply(lambda x: prefix + x.split('/')[-2] + '/' + x.split('/')[-1] + '-0.png')


In [103]:
import os

img_dir = f"../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records{freq}_ground_truth"

image_paths = []
for root, _, files in os.walk(img_dir):
    for file in files:
        if file.endswith('.png'):
            image_paths.append(os.path.join(root, file))
            
image_paths[0]

'../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records100_ground_truth/records100_ground_truth_08000/08588_lr-0.png'

In [104]:
# Collect indices to drop
indices_to_drop = []
for index, row in y.iterrows():
    p1 = row['path']
    if p1 not in image_paths:
        indices_to_drop.append(index)

# Drop rows by index
y = y.drop(indices_to_drop)

In [105]:
# Reset the index
y.reset_index(drop=True, inplace=True)
y.head()

Unnamed: 0,path,diagnostic_superclass,scp_labels
0,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, LVOLT, SR]"
1,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SBRAD]"
2,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"
3,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"
4,../../../../../data/padmalab_external/special_...,[NORM],"[NORM, SR]"


In [106]:
len(y)

15783

In [107]:
# Define mappings
diagnostic_mapping = {
    'CD': 0, 'HYP': 1, 'MI': 2, 'NORM': 3, 'STTC': 4
}
scp_mapping = {
    '1AVB': 0, '2AVB': 1, '3AVB': 2, 'ABQRS': 3, 'AFIB': 4, 'AFLT': 5, 'ALMI': 6, 'AMI': 7, 
    'ANEUR': 8, 'ASMI': 9, 'BIGU': 10, 'CLBBB': 11, 'CRBBB': 12, 'DIG': 13, 'EL': 14, 'HVOLT': 15, 
    'ILBBB': 16, 'ILMI': 17, 'IMI': 18, 'INJAL': 19, 'INJAS': 20, 'INJIL': 21, 'INJIN': 22, 
    'INJLA': 23, 'INVT': 24, 'IPLMI': 25, 'IPMI': 26, 'IRBBB': 27, 'ISCAL': 28, 'ISCAN': 29, 
    'ISCAS': 30, 'ISCIL': 31, 'ISCIN': 32, 'ISCLA': 33, 'ISC_': 34, 'IVCD': 35, 'LAFB': 36, 
    'LAO/LAE': 37, 'LMI': 38, 'LNGQT': 39, 'LOWT': 40, 'LPFB': 41, 'LPR': 42, 'LVH': 43, 
    'LVOLT': 44, 'NDT': 45, 'NORM': 46, 'NST_': 47, 'NT_': 48, 'PAC': 49, 'PACE': 50, 'PMI': 51, 
    'PRC(S)': 52, 'PSVT': 53, 'PVC': 54, 'QWAVE': 55, 'RAO/RAE': 56, 'RVH': 57, 'SARRH': 58, 
    'SBRAD': 59, 'SEHYP': 60, 'SR': 61, 'STACH': 62, 'STD_': 63, 'STE_': 64, 'SVARR': 65, 
    'SVTAC': 66, 'TAB_': 67, 'TRIGU': 68, 'VCLVH': 69, 'WPW': 70
}

# Inverse mappings for one-hot encoding
num_diagnostic_classes = len(diagnostic_mapping)
num_scp_classes = len(scp_mapping)

def encode_labels(label_list, mapping, num_classes):
    one_hot = np.zeros(num_classes, dtype=int)
    for label in label_list:
        if label in mapping:
            one_hot[mapping[label]] = 1
    return one_hot

# Apply encoding
y['diagnostic_one_hot'] = y['diagnostic_superclass'].apply(
    lambda labels: encode_labels(labels, diagnostic_mapping, num_diagnostic_classes)
)
y['scp_one_hot'] = y['scp_labels'].apply(
    lambda labels: encode_labels(labels, scp_mapping, num_scp_classes)
)

print(y.head(10))


                                                path diagnostic_superclass  \
0  ../../../../../data/padmalab_external/special_...                [NORM]   
1  ../../../../../data/padmalab_external/special_...                [NORM]   
2  ../../../../../data/padmalab_external/special_...                [NORM]   
3  ../../../../../data/padmalab_external/special_...                [NORM]   
4  ../../../../../data/padmalab_external/special_...                [NORM]   
5  ../../../../../data/padmalab_external/special_...                [NORM]   
6  ../../../../../data/padmalab_external/special_...                [NORM]   
7  ../../../../../data/padmalab_external/special_...                  [MI]   
8  ../../../../../data/padmalab_external/special_...                [NORM]   
9  ../../../../../data/padmalab_external/special_...                [NORM]   

          scp_labels diagnostic_one_hot  \
0  [NORM, LVOLT, SR]    [0, 0, 0, 1, 0]   
1      [NORM, SBRAD]    [0, 0, 0, 1, 0]   
2         [N

In [108]:
y['scp_one_hot'][0]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0])

In [109]:
from PIL import Image
# Iterate over all paths and check if they are valid
for path in y['path']:
    try:
        ecg = Image.open(path)
        # print(f"Successfully loaded: {path}")
    except Exception as e:
        print(f"Failed to load: {path}, Error: {e}")

In [110]:
y.to_csv(f'{split}-{freq}.csv', index=False)

# Data Loader and dataset class

In [125]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import ast

class ECGImageDataset(Dataset):
    def __init__(self, info_df_path, transform=None):
        self.info_df = pd.read_csv(info_df_path)
        self.transform = transform

    def __len__(self):
        return len(self.info_df)

    def __getitem__(self, idx):
        img_path = self.info_df.iloc[idx]['path']
        image = Image.open(img_path).convert('RGB')
        
        # Convert diagnostic_superclass_encoded
        diagnostic_encoded = self.info_df.iloc[idx]['diagnostic_one_hot']
        try:
            if isinstance(diagnostic_encoded, str):
                label_1 = torch.tensor(ast.literal_eval(diagnostic_encoded)).long()
            else:
                label_1 = torch.tensor([diagnostic_encoded]).long()
        except (SyntaxError, ValueError):
            # Handle invalid format
            label_1 = torch.tensor([0] * 5).long()  # or any default value or error handling

        # Convert scp_labels_encoded
        scp_encoded = self.info_df.iloc[idx]['scp_one_hot']
        try:
            if isinstance(scp_encoded, str):
                label_2 = torch.tensor(ast.literal_eval(scp_encoded)).long()
            else:
                label_2 = torch.tensor([scp_encoded]).long()
        except (SyntaxError, ValueError):
            # Handle invalid format
            label_2 = torch.tensor([0] * 71).long()  # or any default value or error handling

        if self.transform:
            image = self.transform(image)

        return image, (label_1, label_2)

# Define transformations
img_size = 224  
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    normalize,
])

# Create dataset and dataloader
dataset = ECGImageDataset('test-500.csv', transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Iterate through the DataLoader
for imgs, labels in dataloader:
    print(imgs.shape, labels[0].shape, labels[1].shape)
    

torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.Size([32, 3, 224, 224]) torch.Size([32, 5]) torch.Size([32, 71])
torch.