In [1]:
import argparse
import os

import numpy as np
import pandas as pd
import wfdb
from tqdm import tqdm


_LEAD_NAMES = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
# _LEAD_NAMES = ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

In [2]:
def get_parser():
    description = "Process WFDB ECG database."
    # parser = argparse.ArgumentParser(description=description)
    parser = argparse.ArgumentParser()
    parser.add_argument('-i',
                        '--input_dir',
                        type=str,
                        # required=True,
                        default='/tf/physionet.org/files/challenge-2021/1.0.3/training/ningbo/',
                        help="Path to the WFDB ECG database directory.")
    parser.add_argument('-o',
                        '--output_dir',
                        type=str,
                        # required=True,
                        default='./ningbo/ecgs/',
                        help="Path to the directory where the preprocessed signals will be saved.")
    parser.add_argument('--index_path',
                        type=str,
                        default='./ningbo/index.csv',
                        help="Path to the index file.")
    args = parser.parse_args("")
    return args

In [3]:
args = get_parser()
args

Namespace(input_dir='/tf/physionet.org/files/challenge-2021/1.0.3/training/ningbo/', output_dir='./ningbo/ecgs/', index_path='./ningbo/index.csv')

get_parser() 부분 실행 완 => args 처리됨. 이후 run 코드 순서대로 실행하면 됨!!

In [4]:
def find_records(root_dir):
    """Find all the .hea files in the root directory and its subdirectories.
    Args:
        root_dir (str): The directory to search for .hea files.
    Returns:
        records (set): A set of record names.
                       (e.g., ['database/1/ecg001', 'database/1/ecg001', ..., 'database/9/ecg991'])
    """
    records = set()
    for root, _, files in os.walk(root_dir):
        for file in files:
            extension = os.path.splitext(file)[1]
            if extension == '.hea':
                record = os.path.relpath(os.path.join(root, file), root_dir)[:-4]
                records.add(record)
    records = sorted(records)
    return records

In [5]:
# Identify the header fiels
record_rel_paths = find_records(args.input_dir)
record_rel_paths

['g1/JS10647',
 'g1/JS10648',
 'g1/JS10649',
 'g1/JS10650',
 'g1/JS10651',
 'g1/JS10652',
 'g1/JS10653',
 'g1/JS10654',
 'g1/JS10655',
 'g1/JS10656',
 'g1/JS10657',
 'g1/JS10658',
 'g1/JS10659',
 'g1/JS10660',
 'g1/JS10661',
 'g1/JS10662',
 'g1/JS10663',
 'g1/JS10664',
 'g1/JS10665',
 'g1/JS10666',
 'g1/JS10667',
 'g1/JS10668',
 'g1/JS10669',
 'g1/JS10670',
 'g1/JS10671',
 'g1/JS10672',
 'g1/JS10673',
 'g1/JS10674',
 'g1/JS10675',
 'g1/JS10676',
 'g1/JS10677',
 'g1/JS10678',
 'g1/JS10679',
 'g1/JS10680',
 'g1/JS10681',
 'g1/JS10682',
 'g1/JS10683',
 'g1/JS10684',
 'g1/JS10685',
 'g1/JS10686',
 'g1/JS10687',
 'g1/JS10688',
 'g1/JS10689',
 'g1/JS10690',
 'g1/JS10691',
 'g1/JS10692',
 'g1/JS10693',
 'g1/JS10694',
 'g1/JS10695',
 'g1/JS10696',
 'g1/JS10697',
 'g1/JS10698',
 'g1/JS10699',
 'g1/JS10700',
 'g1/JS10701',
 'g1/JS10702',
 'g1/JS10703',
 'g1/JS10704',
 'g1/JS10705',
 'g1/JS10706',
 'g1/JS10707',
 'g1/JS10708',
 'g1/JS10709',
 'g1/JS10710',
 'g1/JS10711',
 'g1/JS10712',
 'g1/JS107

In [6]:
len(record_rel_paths), record_rel_paths[:4], record_rel_paths[-4:]

(34905,
 ['g1/JS10647', 'g1/JS10648', 'g1/JS10649', 'g1/JS10650'],
 ['g9/JS19642', 'g9/JS19643', 'g9/JS19644', 'g9/JS19645'])

In [7]:
print(f"Found {len(record_rel_paths)} records.")

Found 34905 records.


In [8]:
# Prepare an index dataframe
index_df = pd.DataFrame(columns = ["RELATIVE_FILE_PATH", "FILE_NAME", "SAMPLE_RATE", "SOURCE"])
index_df

Unnamed: 0,RELATIVE_FILE_PATH,FILE_NAME,SAMPLE_RATE,SOURCE


In [9]:
def moving_window_crop(x: np.ndarray, crop_length: int, crop_stride: int) -> np.ndarray:
    """Crop the input sequence with a moving window.
    """
    if crop_length > x.shape[1]:
        raise ValueError(f"crop_length must be smaller than the length of x ({x.shape[1]}).")
    start_idx = np.arange(0, x.shape[1] - crop_length + 1, crop_stride)
    return [x[:, i:i + crop_length] for i in start_idx]

In [10]:
num_saved = 0
for record_rel_path in tqdm(record_rel_paths):
    record_rel_dir, record_name = os.path.split(record_rel_path)
    save_dir = os.path.join(args.output_dir, record_rel_dir)
    os.makedirs(save_dir, exist_ok=True)
    source_name = record_rel_dir.split("/")[0]
    signal, record_info = wfdb.rdsamp(os.path.join(args.input_dir, record_rel_path))
    lead_idx = np.array([record_info["sig_name"].index(lead_name) for lead_name in _LEAD_NAMES])
    signal = signal[:, lead_idx]
    fs = record_info["fs"]
    signal_length = record_info["sig_len"]
    if signal_length < 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
        continue
    cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
    for idx, cropped_signal in enumerate(cropped_signals):
        if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
            continue
        pd.to_pickle(cropped_signal.astype(np.float32),
                     os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
        index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
                                   f"{record_name}_{idx}.pkl",
                                   fs,
                                   source_name]
        num_saved += 1

print(f"Saved {num_saved} cropped signals.")
os.makedirs(os.path.dirname(args.index_path), exist_ok=True)
index_df.to_csv(args.index_path, index=False)

100%|██████████████████████████████████████████████████████████| 34905/34905 [2:21:30<00:00,  4.11it/s]


Saved 34808 cropped signals.


In [11]:
index_df

Unnamed: 0,RELATIVE_FILE_PATH,FILE_NAME,SAMPLE_RATE,SOURCE
0,g1/JS10647_0.pkl,JS10647_0.pkl,500,g1
1,g1/JS10648_0.pkl,JS10648_0.pkl,500,g1
2,g1/JS10649_0.pkl,JS10649_0.pkl,500,g1
3,g1/JS10650_0.pkl,JS10650_0.pkl,500,g1
4,g1/JS10651_0.pkl,JS10651_0.pkl,500,g1
...,...,...,...,...
34803,g9/JS19641_0.pkl,JS19641_0.pkl,500,g9
34804,g9/JS19642_0.pkl,JS19642_0.pkl,500,g9
34805,g9/JS19643_0.pkl,JS19643_0.pkl,500,g9
34806,g9/JS19644_0.pkl,JS19644_0.pkl,500,g9


In [None]:
# Save all the cropped signals
num_saved = 0
for record_rel_path in tqdm(record_rel_paths):
    record_rel_dir, record_name = os.path.split(record_rel_path)
    save_dir = os.path.join(args.output_dir, record_rel_dir)
    os.makedirs(save_dir, exist_ok=True)
    source_name = record_rel_dir.split("/")[0]
    signal, record_info = wfdb.rdsamp(os.path.join(args.input_dir, record_rel_path))
    lead_idx = np.array([record_info["sig_name"].index(lead_name) for lead_name in _LEAD_NAMES])
    signal = signal[:, lead_idx]
    fs = record_info["fs"]
    signal_length = record_info["sig_len"]
    # if signal_length < 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
    #     continue
    # cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
    # for idx, cropped_signal in enumerate(cropped_signals):
    #     if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
    #         continue
    #     pd.to_pickle(cropped_signal.astype(np.float32),
    #                  os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
    #     index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
    #                                f"{record_name}_{idx}.pkl",
    #                                fs,
    #                                source_name]
    #     num_saved += 1
    break

In [None]:
record_rel_dir, record_name, save_dir, source_name

In [None]:
signal

In [None]:
record_info

In [None]:
if signal_length < 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
    continue
cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
for idx, cropped_signal in enumerate(cropped_signals):
    if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
        continue
    pd.to_pickle(cropped_signal.astype(np.float32),
                 os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
    index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
                               f"{record_name}_{idx}.pkl",
                               fs,
                               source_name]
    num_saved += 1

In [None]:
if signal_length < 10 * fs:
    print(1)
else: print(0)

In [None]:
signal_length

In [None]:
fs * 10

In [None]:
cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
cropped_signals

In [None]:
np.array(signal).shape

In [None]:
np.array(cropped_signals).shape

In [None]:
for idx, cropped_signal in enumerate(cropped_signals):
    print(cropped_signal)
    break

In [None]:
[f"{record_rel_path}_{idx}.pkl", f"{record_name}_{idx}.pkl", fs, source_name]

In [None]:
Y = pd.read_csv('/tf/physionet.org/files/ptb-xl/1.0.3/ptbxl_database.csv', index_col='ecg_id')
Y

In [None]:
Y.scp_codes # = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

In [None]:
import ast
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
Y.scp_codes