In [12]:
import argparse
import os

import numpy as np
import pandas as pd
import wfdb
from tqdm import tqdm


# Lead order used for every saved signal: each record's channels are looked up
# by these names in its "sig_name" list and re-ordered to match this sequence.
_LEAD_NAMES = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
# Alternative spelling with upper-case augmented leads (disabled) -- presumably
# for databases whose headers name them that way; confirm before enabling.
# _LEAD_NAMES = ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

In [3]:
def get_parser(argv=()):
    """Build the preprocessing options and return them already parsed.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    parser object. Every option has a default, so calling it with no arguments
    yields a fully populated namespace.

    Args:
        argv: Sequence of command-line style tokens to parse
            (e.g. ``['-o', './out/']``). Defaults to an empty sequence so that
            ``sys.argv`` is never consulted; inside a notebook ``sys.argv``
            belongs to the kernel process, not to this script.

    Returns:
        argparse.Namespace with the attributes ``input_dir``, ``output_dir``
        and ``index_path``.
    """
    parser = argparse.ArgumentParser(description="Process WFDB ECG database.")
    parser.add_argument('-i',
                        '--input_dir',
                        type=str,
                        default='/tf/physionet.org/files/challenge-2021/1.0.3/training/chapman_shaoxing/',
                        help="Path to the WFDB ECG database directory.")
    parser.add_argument('-o',
                        '--output_dir',
                        type=str,
                        default='./chapman/ecgs/',
                        help="Path to the directory where the preprocessed signals will be saved.")
    parser.add_argument('--index_path',
                        type=str,
                        default='./chapman/index.csv',
                        help="Path to the index file.")
    # Always pass an explicit list: parse_args(None) would fall back to sys.argv.
    return parser.parse_args(list(argv))

In [4]:
args = get_parser()
args

Namespace(input_dir='/tf/physionet.org/files/challenge-2021/1.0.3/training/chapman_shaoxing/', output_dir='./chapman/ecgs/', index_path='./chapman/index.csv')

get_parser() 실행 완료 → `args`가 설정됨. 이후 셀들을 위에서부터 순서대로 실행하면 됩니다.

In [5]:
def find_records(root_dir):
    """Find all the .hea files in the root directory and its subdirectories.

    Args:
        root_dir (str): The directory to search for .hea files.
    Returns:
        records (list): Sorted, de-duplicated record names, i.e. the path of
                        each header file relative to ``root_dir`` with the
                        '.hea' extension removed
                        (e.g., ['database/1/ecg001', 'database/1/ecg002', ..., 'database/9/ecg991'])
    """
    header_extension = '.hea'
    unique_records = set()
    for root, _, files in os.walk(root_dir):
        for file in files:
            if os.path.splitext(file)[1] != header_extension:
                continue
            rel_path = os.path.relpath(os.path.join(root, file), root_dir)
            # Strip the extension via splitext instead of a hard-coded slice length.
            unique_records.add(os.path.splitext(rel_path)[0])
    return sorted(unique_records)

In [6]:
# Identify the header files: one entry per .hea file under input_dir, given as
# a path relative to input_dir with the extension removed.
record_rel_paths = find_records(args.input_dir)
record_rel_paths

['g1/JS00001',
 'g1/JS00002',
 'g1/JS00004',
 'g1/JS00005',
 'g1/JS00006',
 'g1/JS00007',
 'g1/JS00008',
 'g1/JS00009',
 'g1/JS00010',
 'g1/JS00011',
 'g1/JS00012',
 'g1/JS00013',
 'g1/JS00014',
 'g1/JS00015',
 'g1/JS00016',
 'g1/JS00017',
 'g1/JS00018',
 'g1/JS00019',
 'g1/JS00020',
 'g1/JS00021',
 'g1/JS00022',
 'g1/JS00023',
 'g1/JS00024',
 'g1/JS00025',
 'g1/JS00026',
 'g1/JS00027',
 'g1/JS00029',
 'g1/JS00030',
 'g1/JS00031',
 'g1/JS00032',
 'g1/JS00033',
 'g1/JS00034',
 'g1/JS00036',
 'g1/JS00037',
 'g1/JS00038',
 'g1/JS00039',
 'g1/JS00040',
 'g1/JS00041',
 'g1/JS00042',
 'g1/JS00043',
 'g1/JS00044',
 'g1/JS00045',
 'g1/JS00046',
 'g1/JS00047',
 'g1/JS00048',
 'g1/JS00049',
 'g1/JS00050',
 'g1/JS00051',
 'g1/JS00052',
 'g1/JS00053',
 'g1/JS00054',
 'g1/JS00055',
 'g1/JS00056',
 'g1/JS00057',
 'g1/JS00058',
 'g1/JS00059',
 'g1/JS00060',
 'g1/JS00061',
 'g1/JS00062',
 'g1/JS00063',
 'g1/JS00064',
 'g1/JS00065',
 'g1/JS00066',
 'g1/JS00067',
 'g1/JS00068',
 'g1/JS00069',
 'g1/JS000

In [7]:
len(record_rel_paths), record_rel_paths[:4], record_rel_paths[-4:]

(10247,
 ['g1/JS00001', 'g1/JS00002', 'g1/JS00004', 'g1/JS00005'],
 ['g9/JS09369', 'g9/JS09370', 'g9/JS09371', 'g9/JS09372'])

In [8]:
print(f"Found {len(record_rel_paths)} records.")

Found 10247 records.


In [13]:
# Empty index table holding the four fields recorded for every saved crop.
index_df = pd.DataFrame(
    columns=[
        "RELATIVE_FILE_PATH",
        "FILE_NAME",
        "SAMPLE_RATE",
        "SOURCE",
    ],
)
index_df

Unnamed: 0,RELATIVE_FILE_PATH,FILE_NAME,SAMPLE_RATE,SOURCE


In [14]:
def moving_window_crop(x: np.ndarray, crop_length: int, crop_stride: int) -> list:
    """Crop the input sequence with a moving window.

    Windows are taken along axis 1 of ``x``. The first window starts at index 0
    and each following one starts ``crop_stride`` positions later. A trailing
    remainder shorter than ``crop_length`` is dropped.

    Args:
        x: Array with at least two dimensions; axis 1 is the one being cropped.
        crop_length: Number of positions along axis 1 in every window.
            Must be positive and no larger than ``x.shape[1]``.
        crop_stride: Distance between the starts of consecutive windows.
            Must be positive.

    Returns:
        A list of arrays, each equal to ``x[:, i:i + crop_length]``.

    Raises:
        ValueError: If ``crop_length`` or ``crop_stride`` is not positive, or
            if ``crop_length`` exceeds ``x.shape[1]``.
    """
    total_length = x.shape[1]
    if crop_length <= 0:
        raise ValueError(f"crop_length must be positive (got {crop_length}).")
    if crop_stride <= 0:
        # A zero stride would otherwise fail inside the range construction with
        # an unrelated message, and a negative one would silently yield no crops.
        raise ValueError(f"crop_stride must be positive (got {crop_stride}).")
    if crop_length > total_length:
        raise ValueError(f"crop_length must not exceed the length of x ({total_length}).")
    return [x[:, start:start + crop_length]
            for start in range(0, total_length - crop_length + 1, crop_stride)]

In [16]:
# Crop every record into non-overlapping 10-second segments, pickle each
# segment as a float32 array and collect one index row per saved file.
CROP_SECONDS = 10
index_columns = ["RELATIVE_FILE_PATH", "FILE_NAME", "SAMPLE_RATE", "SOURCE"]
# Rows are gathered in a plain list and turned into a DataFrame once at the
# end: enlarging a DataFrame with .loc on every iteration re-allocates it each
# time and would also leave stale rows behind when this cell is re-run.
index_rows = []
for record_rel_path in tqdm(record_rel_paths):
    record_rel_dir, record_name = os.path.split(record_rel_path)
    save_dir = os.path.join(args.output_dir, record_rel_dir)
    os.makedirs(save_dir, exist_ok=True)
    # The first path component of the record's directory is used as its source label.
    source_name = record_rel_dir.split("/")[0]
    signal, record_info = wfdb.rdsamp(os.path.join(args.input_dir, record_rel_path))
    # Re-order the channels so that they follow _LEAD_NAMES.
    lead_idx = np.array([record_info["sig_name"].index(lead_name) for lead_name in _LEAD_NAMES])
    signal = signal[:, lead_idx]
    fs = record_info["fs"]
    signal_length = record_info["sig_len"]
    crop_length = CROP_SECONDS * fs
    if signal_length < crop_length:  # Exclude the ECGs with lengths of less than 10 seconds
        continue
    # The signal is transposed so that cropping runs along the sample axis.
    cropped_signals = moving_window_crop(signal.T, crop_length=crop_length, crop_stride=crop_length)
    for idx, cropped_signal in enumerate(cropped_signals):
        # Skip crops of unexpected length (defensive) or containing NaN samples.
        if cropped_signal.shape[1] != crop_length or np.isnan(cropped_signal).any():
            continue
        file_name = f"{record_name}_{idx}.pkl"
        pd.to_pickle(cropped_signal.astype(np.float32),
                     os.path.join(save_dir, file_name))
        index_rows.append([f"{record_rel_path}_{idx}.pkl",
                           file_name,
                           fs,
                           source_name])

index_df = pd.DataFrame(index_rows, columns=index_columns)
num_saved = len(index_df)

print(f"Saved {num_saved} cropped signals.")
index_dir = os.path.dirname(args.index_path)
if index_dir:  # os.makedirs('') raises when index_path is a bare file name
    os.makedirs(index_dir, exist_ok=True)
index_df.to_csv(args.index_path, index=False)

100%|████████████████████████████████████████████████████████████| 10247/10247 [48:46<00:00,  3.50it/s]


Saved 10247 cropped signals.


In [17]:
index_df

Unnamed: 0,RELATIVE_FILE_PATH,FILE_NAME,SAMPLE_RATE,SOURCE
0,g1/JS00001_0.pkl,JS00001_0.pkl,500,g1
1,g1/JS00002_0.pkl,JS00002_0.pkl,500,g1
2,g1/JS00004_0.pkl,JS00004_0.pkl,500,g1
3,g1/JS00005_0.pkl,JS00005_0.pkl,500,g1
4,g1/JS00006_0.pkl,JS00006_0.pkl,500,g1
...,...,...,...,...
10242,g9/JS09367_0.pkl,JS09367_0.pkl,500,g9
10243,g9/JS09369_0.pkl,JS09369_0.pkl,500,g9
10244,g9/JS09370_0.pkl,JS09370_0.pkl,500,g9
10245,g9/JS09371_0.pkl,JS09371_0.pkl,500,g9


In [None]:
# Save all the cropped signals
num_saved = 0
for record_rel_path in tqdm(record_rel_paths):
    record_rel_dir, record_name = os.path.split(record_rel_path)
    save_dir = os.path.join(args.output_dir, record_rel_dir)
    os.makedirs(save_dir, exist_ok=True)
    source_name = record_rel_dir.split("/")[0]
    signal, record_info = wfdb.rdsamp(os.path.join(args.input_dir, record_rel_path))
    lead_idx = np.array([record_info["sig_name"].index(lead_name) for lead_name in _LEAD_NAMES])
    signal = signal[:, lead_idx]
    fs = record_info["fs"]
    signal_length = record_info["sig_len"]
    # if signal_length < 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
    #     continue
    # cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
    # for idx, cropped_signal in enumerate(cropped_signals):
    #     if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
    #         continue
    #     pd.to_pickle(cropped_signal.astype(np.float32),
    #                  os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
    #     index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
    #                                f"{record_name}_{idx}.pkl",
    #                                fs,
    #                                source_name]
    #     num_saved += 1
    break

In [None]:
record_rel_dir, record_name, save_dir, source_name

In [None]:
signal

In [None]:
record_info

In [None]:
# Debug cell: crop and save the single record loaded by the debug loading
# cell. The original used `continue` to skip short records, which is a
# SyntaxError outside a loop; the skip is expressed as a guard instead.
if signal_length >= 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
    cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
    for idx, cropped_signal in enumerate(cropped_signals):
        # Skip crops of unexpected length or containing NaN samples.
        if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
            continue
        pd.to_pickle(cropped_signal.astype(np.float32),
                     os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
        # Writes index row number num_saved, overwriting any row already there.
        index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
                                   f"{record_name}_{idx}.pkl",
                                   fs,
                                   source_name]
        num_saved += 1

In [None]:
# Print 1 when the loaded record is shorter than 10 seconds, otherwise 0.
print(1 if signal_length < 10 * fs else 0)

In [None]:
signal_length

In [None]:
fs * 10

In [None]:
cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
cropped_signals

In [None]:
np.array(signal).shape

In [None]:
np.array(cropped_signals).shape

In [None]:
for idx, cropped_signal in enumerate(cropped_signals):
    print(cropped_signal)
    break

In [None]:
[f"{record_rel_path}_{idx}.pkl", f"{record_name}_{idx}.pkl", fs, source_name]

In [None]:
Y = pd.read_csv('/tf/physionet.org/files/ptb-xl/1.0.3/ptbxl_database.csv', index_col='ecg_id')
Y

In [None]:
Y.scp_codes # = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

In [None]:
import ast

# The scp_codes column is parsed with ast.literal_eval, i.e. each cell is
# expected to hold the text of a Python literal. Values that are no longer
# strings (already parsed by an earlier run of this cell) are passed through
# unchanged, so re-running the cell does not raise.
Y["scp_codes"] = Y["scp_codes"].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)
Y.scp_codes