### Imports

In [None]:
import ast
import wfdb

import pandas as pd
import numpy as np

### Helper functions from `example_physionet.py`

In [None]:
def load_raw_data(df, sampling_rate, path, *, max_workers=None):
    """Load the raw WFDB signal of every record listed in ``df``.

    The sampling rate selects which filename column is read:
    ``filename_lr`` when ``sampling_rate == 100``, otherwise ``filename_hr``.
    Records are read in a thread pool so that file reads can overlap; the
    returned array keeps the row order of ``df``.

    Args:
        df (DataFrame): Table with columns 'filename_lr' (low resolution) and
            'filename_hr' (high resolution) holding record names.
        sampling_rate (int): Sampling rate in Hz. 100 selects the
            low-resolution files; any other value selects the high-resolution
            files.
        path (str): Prefix prepended verbatim to every record name, so it
            should include the trailing separator (e.g. 'data/').
        max_workers (int | None): Keyword-only thread-pool size. ``None`` lets
            ``ThreadPoolExecutor`` pick its default; ``1`` reads the records
            sequentially without a pool.

    Returns:
        numpy.ndarray: The signal arrays returned by ``wfdb.rdsamp``, stacked
        along a new leading axis (one entry per row of ``df``).
    """
    from concurrent.futures import ThreadPoolExecutor

    filenames = df.filename_lr if sampling_rate == 100 else df.filename_hr
    record_paths = [path + f for f in filenames]

    def _read_signal(record_path):
        # rdsamp returns (signal, metadata); only the signal is kept.
        signal, _meta = wfdb.rdsamp(record_path)
        return signal

    if max_workers == 1:
        signals = [_read_signal(p) for p in record_paths]
    else:
        # pool.map yields results in input order, so rows stay aligned with df;
        # any exception raised while reading a record is re-raised here.
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            signals = list(pool.map(_read_signal, record_paths))
    return np.array(signals)


def aggregate_diagnostic(y_dic, lookup=None):
    """Map a record's SCP codes to the diagnostic classes they belong to.

    Each key of ``y_dic`` that appears in the index of the lookup table is
    translated to that row's ``diagnostic_class``. Keys absent from the table
    are ignored. Duplicates are removed while keeping first-seen order, so the
    result is deterministic. A ``set`` would order string classes differently
    from run to run under Python's hash randomization.

    Args:
        y_dic (dict): SCP codes for one record. Only the keys are used.
        lookup (DataFrame | None): Table indexed by SCP code with a
            'diagnostic_class' column. Defaults to the global ``agg_df`` so
            that ``Y.scp_codes.apply(aggregate_diagnostic)`` keeps working.

    Returns:
        list: Unique diagnostic classes, in the order their codes appear in
        ``y_dic``.
    """
    table = agg_df if lookup is None else lookup
    # Index the single column by label instead of materializing a whole row
    # with ``table.loc[key]`` for every code.
    classes = table['diagnostic_class']
    found = [classes.loc[key] for key in y_dic if key in classes.index]
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(found))

### Variables

In [None]:
# Data directory (concatenated as a plain string prefix, so keep the trailing
# '/') and the sampling rate in Hz (100 selects the low-resolution files).
path, sampling_rate = 'data/', 100

### Load data

TODO: optimize the cell below — loading the data is slow.

In [None]:
# Load and convert annotation data. scp_codes holds the text of a Python dict
# literal; ast.literal_eval parses it without executing arbitrary code.
Y = pd.read_csv(path + 'ptbxl_database.csv', index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(ast.literal_eval)

# Load raw signal data. Reading every WFDB record is the slow step, so the
# stacked array is cached as .npy in the data directory (one file per sampling
# rate) and reused on later runs.
cache_file = f'{path}X_{sampling_rate}hz.npy'
try:
    X = np.load(cache_file)
except FileNotFoundError:
    X = None
if X is None or len(X) != len(Y):
    # No cache yet, or it was built from a different record set: rebuild it.
    X = load_raw_data(Y, sampling_rate, path)
    try:
        np.save(cache_file, X)
    except OSError as err:
        # Caching is best-effort; an unwritable data directory must not
        # stop the notebook.
        print(f'Could not write signal cache {cache_file}: {err}')

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path + 'scp_statements.csv', index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

# Apply diagnostic superclass (aggregate_diagnostic reads the global agg_df,
# so it must be defined above this line).
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

In [None]:
# Split records into train and test by their stratified fold; the records
# whose strat_fold equals test_fold form the test set.
test_fold = 10
# Positional boolean masks: X rows follow the row order of Y.
is_train = (Y.strat_fold != test_fold).to_numpy()
is_test = (Y.strat_fold == test_fold).to_numpy()

X_train, y_train = X[is_train], Y.diagnostic_superclass[is_train]
X_test, y_test = X[is_test], Y.diagnostic_superclass[is_test]

In [6]:
# Sanity check: dimensions of the stacked signal array.
print(np.shape(X))

(21799, 1000, 12)
