# Notebook for Generating ECG Embeddings from ECG-FM

This notebook contains the deidentified code we used to generate embeddings using the [ECG-FM](https://arxiv.org/abs/2408.05178) [1] models from standard 12-lead electrocardiograms.

Because the unique identifiers (UIDs) and the way they map to patients and waveform files are institution-dependent, and because sharing them would risk identifying subjects from our own clinical center, we provide only the code necessary to save the embeddings. Please note that our ECGs were already saved as `.npy` files before loading.

Please find details regarding the models and weights on [HuggingFace](https://huggingface.co/wanglab/ecg-fm/tree/main) and in the original authors' [GitHub repository](https://github.com/bowang-lab/ECG-FM). Additionally, you will want to reference the [`fairseq-signals`](https://github.com/Jwoo5/fairseq-signals) framework. You will have to use the environment in the `fairseq_signals_env.yaml` file or any appropriate `Python` environment with `fairseq-signals` installed.

[1] McKeen, Kaden, et al. "ECG-FM: An open electrocardiogram foundation model." arXiv preprint arXiv:2408.05178 (2024).

In [None]:
import os
import numpy as np
import pandas as pd
import torch

from huggingface_hub import hf_hub_download
from fairseq_signals.utils.store import MemmapReader
from fairseq_signals.models import build_model_from_checkpoint

from scipy.interpolate import interp1d
from scipy.io import loadmat
import matplotlib.pyplot as plt
import pickle

In [None]:
class ecg_dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing ECG waveforms with their labels.

    Parameters
    ----------
    waveforms : indexable of numpy.ndarray
        ``waveforms[idx]`` must be a NumPy array; it is converted with
        ``torch.from_numpy``.
    labels : indexable
        ``labels[idx]`` may be a NumPy array, a NumPy scalar (as produced by
        indexing a 1-D label array) or a plain Python number / sequence.
    """

    def __init__(self, waveforms, labels):

        self.waveforms = waveforms
        self.labels = labels

    def __len__(self):
        # The dataset length is defined by the waveforms alone.
        return len(self.waveforms)

    def __getitem__(self, idx):
        """Return ``{'waveform': ..., 'label': ...}`` as float32 CPU tensors."""

        waveform = torch.from_numpy(self.waveforms[idx]).type(torch.FloatTensor)

        # Indexing a 1-D label array yields a NumPy scalar, which
        # torch.from_numpy rejects; np.array promotes scalars (and plain
        # Python values) to a float ndarray copy first.
        label = np.array(self.labels[idx], dtype=float)
        label = torch.from_numpy(label).type(torch.FloatTensor)

        return {'waveform':waveform, 'label':label}

def get_files_of_type(parent_path:str, filetype:str, as_dict:bool=False) :
    """Recursively collect files under ``parent_path`` whose name ends with ``filetype``.

    Parameters
    ----------
    parent_path : str
        Directory to walk; must exist.
    filetype : str
        Suffix that file names must end with (e.g. ``'.npy'``).
    as_dict : bool
        When true, return a mapping from each file name with its final
        dot-separated component removed to the file's path. Files in
        different folders that share such a key overwrite each other,
        with the path that sorts last winning.

    Returns
    -------
    list[str] or dict[str, str]
        The sorted paths, or the name-to-path mapping described above.

    Raises
    ------
    AssertionError
        If ``parent_path`` is not a directory or no matching file exists.
    """
    assert os.path.isdir(parent_path), f'{parent_path} is not a valid directory.'

    matches = []
    for folder, _, names in os.walk(parent_path) :
        for name in names :
            if name.endswith(filetype) :
                matches.append(os.path.join(folder, name))
    matches.sort()

    assert len(matches) > 0, f'{parent_path} contains 0 files of file type {filetype}.'

    if not as_dict :
        return matches

    # Key = base name up to (not including) its last dot.
    return {os.path.basename(match).rpartition(".")[0]:match for match in matches}


# Code from original fairseq-signals authors
# https://github.com/Jwoo5/fairseq-signals/blob/6b9e1375bbe6bc2e55ff4d9b95eabe8eee3adec7/scripts/preprocess/ecg/preprocess.py#L49C1-L62C77
def resample(feats, curr_sample_rate, desired_sample_rate):
    """
    Linearly interpolate ``feats`` along its last axis to a new sample rate.

    The output length is ``int(n * desired_sample_rate / curr_sample_rate)``
    where ``n`` is the size of the last axis. When both rates are equal the
    input object itself is returned unchanged.
    """
    if curr_sample_rate != desired_sample_rate:
        n_in = feats.shape[-1]
        n_out = int(n_in * (desired_sample_rate / curr_sample_rate))

        # Place the original samples evenly over [0, n_out - 1] so that the
        # first and last samples coincide with the ends of the new grid.
        source_positions = np.linspace(0, n_out - 1, n_in)
        interpolator = interp1d(source_positions, feats, kind='linear')

        return interpolator(np.arange(n_out))

    return feats

# Code adapted from original fairseq-signals authors
# https://github.com/Jwoo5/fairseq-signals/blob/6b9e1375bbe6bc2e55ff4d9b95eabe8eee3adec7/scripts/preprocess/ecg/preprocess.py#L139
def lead_std_divide(feats, std = None, constant_lead_strategy='zero'):
    """
    Divide ``feats`` by a standard deviation, with special handling of zero stds.

    Parameters
    ----------
    feats : numpy.ndarray
        Array to scale. The input array itself is not modified.
    std : numpy.ndarray, optional
        Standard deviation to divide by. Anything that is not a NumPy array
        (including the default ``None``) is ignored and replaced by
        ``feats.std(axis=1, keepdims=True)``.
    constant_lead_strategy : {'zero', 'constant', 'nan'}
        What to do where ``std == 0``:
        ``'zero'`` sets those entries of the result to 0; ``'constant'``
        leaves them undivided; ``'nan'`` divides by the raw std anyway, so
        the result is non-finite there (NaN or inf).

    Returns
    -------
    tuple
        ``(scaled_feats, std)`` where ``std`` is the array actually used
        (before any zero replacement).

    Raises
    ------
    ValueError
        If a zero std is present and the strategy is not one of the above.
        The strategy is not validated when no std is zero.
    """
    if not isinstance(std, np.ndarray) :
        std = feats.std(axis=1, keepdims=True)

    is_constant = std == 0

    # Plain division is enough when nothing is constant, or when the caller
    # explicitly asked for the raw (non-finite) result.
    needs_special_handling = is_constant.any() and constant_lead_strategy != 'nan'
    if not needs_special_handling:
        return feats / std, std

    # Substitute 1 for zero stds so the division itself is well defined.
    safe_std = np.where(is_constant, 1, std)
    scaled = feats / safe_std

    if constant_lead_strategy == 'zero':
        # Blank out every entry that belongs to a zero-std row.
        scaled[np.broadcast_to(is_constant, scaled.shape)] = 0
    elif constant_lead_strategy != 'constant':
        # 'constant' keeps the values as divided by 1; anything else is invalid.
        raise ValueError("Unexpected constant lead strategy.")

    return scaled, std

def standard_normalize_ecg(waveform) :
    """
    Z-score normalize ``waveform`` along axis 1 and return the result.

    The mean over axis 1 is subtracted and the centred array is passed to
    ``lead_std_divide`` with its defaults (std computed over axis 1, entries
    with zero std set to 0).

    The input array is not modified: centring is done on a new array rather
    than in place, so the caller's data is preserved and integer-typed
    waveforms are supported.

    # assumes waveform is laid out as (leads, samples) - TODO confirm
    """
    centered = waveform - waveform.mean(axis = 1, keepdims = True)

    return lead_std_divide(centered)[0]

# Code adapted from code by fairseq-signals authors
# https://github.com/Jwoo5/fairseq-signals/blob/6b9e1375bbe6bc2e55ff4d9b95eabe8eee3adec7/fairseq_signals/models/classification/ecg_transformer_classifier.py#L55
def sum_and_divide(x) :
    """
    Average ``x`` over dim 1, counting only its non-zero entries.

    The sum over dim 1 is divided by the number of non-zero elements along
    dim 1. Positions with no non-zero element divide by zero and therefore
    give a non-finite result.
    """
    totals = x.sum(dim=1)
    nonzero_counts = (x != 0).sum(dim=1)

    return totals / nonzero_counts

def get_all_ecg_embeddings(model, train_loader, val_loader, test_loader) :
    """
    Run ``model`` over every batch of the three loaders and collect its outputs.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(source=waveform)``; must return a dict.
    train_loader, val_loader, test_loader : iterable
        Each yields batches that are dicts with a ``'waveform'`` tensor.

    Returns
    -------
    dict
        ``{'train': [...], 'val': [...], 'test': [...]}`` where each list
        holds one output dict per batch, in iteration order. CUDA tensors in
        the output are moved to the CPU; other values are stored as returned.

    Notes
    -----
    Inputs are moved to the device the model's parameters live on, so the
    function also works on CPU-only machines.

    NOTE(review): the model's train/eval mode is not changed here; call
    ``model.eval()`` beforehand if deterministic embeddings are required.
    """
    # Follow the model's own device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    if first_param is not None :
        device = first_param.device
    else :
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    loaders = {'train':train_loader, 'val':val_loader, 'test':test_loader}
    all_embeddings = {data_split:[] for data_split in loaders}

    for data_split,current_loader in loaders.items() :
        for batch in current_loader :
            waveform = batch['waveform'].to(device)
            with torch.no_grad() :
                emb = model(source = waveform)
                # Keep results on the CPU so GPU memory is not exhausted while accumulating.
                emb = {k:v.cpu() if isinstance(v, torch.Tensor) and v.is_cuda else v for k,v in emb.items()}
                all_embeddings[data_split].append(emb)

    return all_embeddings

In [None]:
# TO-DO: Define paths for data loading and saving
HF_LOCAL_DIR = None               # parent directory of the 'ckpts' folder the checkpoint is downloaded into
COMBINED_DATA_FOLDER_PATH = None  # directory searched recursively for '.npy' files
EMBEDDINGS_FOLDER_PATH = None     # directory where the embeddings pickle is written / read
COMBINED_TABLE_PATH = None        # not used by the code in this notebook
fairseq_signals_root = None       # only checked for existence below

# Fail early, naming the offending setting, before any download or model work starts.
assert os.path.isdir(HF_LOCAL_DIR), f'{HF_LOCAL_DIR=} is not a valid directory.'
assert os.path.isdir(COMBINED_DATA_FOLDER_PATH), f'{COMBINED_DATA_FOLDER_PATH=} is not a valid directory.'
assert os.path.isdir(EMBEDDINGS_FOLDER_PATH), f'{EMBEDDINGS_FOLDER_PATH=} is not a valid directory.'
assert os.path.isdir(fairseq_signals_root), f'{fairseq_signals_root=} is not a valid directory.'

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'{device=}')

In [None]:
# Fetch the pretrained checkpoint and its YAML config from the
# 'wanglab/ecg-fm-preprint' HuggingFace repository into <HF_LOCAL_DIR>/ckpts.
# The values returned by hf_hub_download are discarded.
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='mimic_iv_ecg_physionet_pretrained.pt',
    local_dir=os.path.join(HF_LOCAL_DIR, 'ckpts'),
)
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='mimic_iv_ecg_physionet_pretrained.yaml',
    local_dir=os.path.join(HF_LOCAL_DIR, 'ckpts'),
)

# Build the model from the downloaded checkpoint file and move it to the
# device selected earlier ("cuda" when available, otherwise "cpu").
# NOTE(review): presumably the YAML config must sit next to the checkpoint
# for build_model_from_checkpoint to work - confirm against fairseq-signals.
model = build_model_from_checkpoint(checkpoint_path=os.path.join(HF_LOCAL_DIR, 'ckpts/mimic_iv_ecg_physionet_pretrained.pt'))
model.to(device)

In [None]:
# "We sampled the waveforms at 500 Hz, performed z-score normalization, and segmented the signals into non-overlapping 5 s segments to produce the model inputs.""
# https://arxiv.org/abs/2408.05178
# Sample rate of the stored waveforms and the rate the inputs are resampled to.
current_sample_rate = 250
desired_sample_rate = 500

# Mapping from each '.npy' file name (without its final extension) to its path.
npy_paths = get_files_of_type(COMBINED_DATA_FOLDER_PATH, filetype = '.npy', as_dict = True)

# TO-DO: Load data
X_train = None
y_train = None
X_val = None
y_val = None
X_test = None
y_test = None


def _prepare_waveforms(X, curr_sample_rate, target_sample_rate) :
    """
    Bring one data split into the layout and scale fed to the model.

    Steps, in order:
    1. drop axis 1 when it has size 1;
    2. when the last axis has size 12, swap the last two axes
       (presumably (N, samples, 12 leads) -> (N, 12 leads, samples) - confirm);
    3. resample every waveform when the two sample rates differ;
    4. z-score normalize every waveform with ``standard_normalize_ecg``.

    Returns the processed split as a new NumPy array.
    """
    if X.shape[1] == 1 :
        X = X.squeeze(axis = 1)

    if X.shape[-1] == 12 :
        X = X.transpose(0, 2, 1)

    if curr_sample_rate != target_sample_rate :
        X = np.array([resample(feats = waveform, curr_sample_rate = curr_sample_rate, desired_sample_rate = target_sample_rate) for waveform in X])

    return np.array([standard_normalize_ecg(waveform) for waveform in X])


# The same preprocessing is applied independently to every split.
X_train = _prepare_waveforms(X_train, current_sample_rate, desired_sample_rate)
X_val = _prepare_waveforms(X_val, current_sample_rate, desired_sample_rate)
X_test = _prepare_waveforms(X_test, current_sample_rate, desired_sample_rate)

print(f'{X_train.shape=} {X_val.shape=} {X_test.shape=}')

train_dataset = ecg_dataset(waveforms = X_train, labels = y_train)
val_dataset = ecg_dataset(waveforms = X_val, labels = y_val)
test_dataset = ecg_dataset(waveforms = X_test, labels = y_test)

# One ECG per batch, in stored order, so the i-th saved embedding belongs to the i-th sample.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 1, shuffle = False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle = False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 1, shuffle = False)

In [None]:
# Raw model outputs are cached on disk so the forward passes only run once.
full_ecg_embeddings_path = os.path.join(EMBEDDINGS_FOLDER_PATH, 'full_ecg_embeddings_dict.pkl')

if os.path.isfile(full_ecg_embeddings_path) :

    # Reuse the cached outputs from a previous run.
    # NOTE: unpickling can execute arbitrary code - only load a cache file you created yourself.
    with open(full_ecg_embeddings_path, 'rb') as handle:
        all_embeddings = pickle.load(handle)

else :

    # No cache yet: compute the outputs for all splits, then persist them.
    all_embeddings = get_all_ecg_embeddings(model, train_loader, val_loader, test_loader)

    with open(full_ecg_embeddings_path, 'wb') as handle:
        pickle.dump(all_embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Reduce each batch's 'features' tensor over dim 1 with sum_and_divide, per split.
processed_embbedings = {
    data_split:[sum_and_divide(emb['features']) for emb in all_embeddings[data_split]]
    for data_split in ['train', 'val', 'test']
}