<h1> 
Case Study: MOMENT for ECG Classification using PTB-XL, a large publicly available electrocardiography dataset
</h1>
<hr>

## Contents
### 1. PTB-XL dataset
#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 1.1 Download PTB-XL dataset
### 2. Loading MOMENT
### 3. Method 1: Learning a Statistical ML Classifier on MOMENT Embeddings
#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 3.1 Load PTB-XL dataset
#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 3.2 Diagnostic label Classification using raw ECG signal
#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 3.3 Diagnostic label Classification using MOMENT embedding on ECG signal
### 4. Method 2: Finetuning Linear Classification Head
### 5. Method 3: Full Finetuning MOMENT
#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 5.1 Assess MOMENT embedding with SVM after finetuning the encoder
#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 5.2 Training with Multiple GPUs and Parameter Efficient FineTuning (PEFT)

## 1. Problem: Classifying Abnormal ECG using PTB-XL

[PTB-XL](https://www.nature.com/articles/s41597-020-0495-6) is a large publicly available electrocardiography (ECG) dataset. It comprises 21,837 12-lead, 10-second ECG recordings collected from 18,885 patients. The ECG waveform data was annotated by up to two cardiologists as a multi-label dataset, with diagnostic labels further aggregated into superclasses and subclasses. In this notebook, we will classify each 10-second ECG recording into one of 5 SCP-ECG classes: (1) Normal ECG, (2) Conduction Disturbance, (3) Myocardial Infarction, (4) Hypertrophy, and (5) ST/T change.

### 1.1 Download PTB-XL dataset

PTB-XL is available on PhysioNet, and can be downloaded [here](https://physionet.org/content/ptb-xl/1.0.3/).

## 2. Loading MOMENT

We will first install the MOMENT package, load some essential packages and the pre-trained model. 

MOMENT can be loaded in 4 modes: (1) `reconstruction`, (2) `embedding`, (3) `forecasting`, and (4) `classification`.

In the `reconstruction` mode, MOMENT reconstructs input time series, potentially containing missing values. This mode is suitable for solving imputation and anomaly detection tasks. During pre-training, MOMENT is trained to predict the missing values within uniformly randomly masked patches (disjoint sub-sequences) of the input time series, leveraging information from observed data in other patches. As a result, MOMENT comes equipped with a pre-trained reconstruction head, enabling it to address imputation and anomaly detection challenges in a zero-shot manner! Check out the `anomaly_detection.ipynb` and `imputation.ipynb` notebooks for more details!
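
The patch-level masking described above can be sketched as follows. This is a minimal illustration, not MOMENT's actual pre-training code: the patch length of 8 is read off the `Linear(in_features=8, ...)` layer in the model printout further down, while the function name `random_patch_mask` and the 30% mask ratio are assumptions made for the example.

```python
import numpy as np

def random_patch_mask(seq_len=512, patch_len=8, mask_ratio=0.3, seed=0):
    """Return a per-timestep mask (1 = observed, 0 = masked) made of whole patches.

    mask_ratio is an illustrative assumption, not a value taken from this notebook.
    """
    assert seq_len % patch_len == 0, "seq_len must be a multiple of patch_len"
    n_patches = seq_len // patch_len
    rng = np.random.default_rng(seed)

    # Pick patches to hide uniformly at random, without replacement
    n_masked = int(round(mask_ratio * n_patches))
    masked_idx = rng.choice(n_patches, size=n_masked, replace=False)

    patch_mask = np.ones(n_patches, dtype=np.int64)
    patch_mask[masked_idx] = 0

    # Expand the patch-level mask to one entry per time step
    return np.repeat(patch_mask, patch_len)

mask = random_patch_mask()
print(mask.shape, int((mask == 0).sum()))
```

Masking whole patches rather than individual time steps means the model has to reconstruct each hidden sub-sequence from the surrounding observed patches.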

In the `embedding` mode, MOMENT learns a $d$-dimensional embedding (e.g., $d=1024$ for `MOMENT-1-large`) for each input time series. These embeddings can be used for clustering and classification. MOMENT can learn embeddings in a zero-shot setting! Check out the `classification.ipynb` notebook for more details!

The `forecasting` and `classification` modes are used for forecasting and classification tasks, respectively. In these modes, MOMENT learns representations which are subsequently mapped to the forecast horizon or the number of classes, using linear forecasting and classification heads. Both the forecasting and classification heads are randomly initialized, and therefore must be fine-tuned before use. Check out the `forecasting.ipynb` notebook for more details!

In [None]:
# !pip install numpy pandas scikit-learn matplotlib tqdm
# !pip install git+https://github.com/moment-timeseries-foundation-model/moment.git

In [1]:
from momentfm import MOMENTPipeline

model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large", 
    model_kwargs={
        'task_name': 'classification',
        'n_channels': 12, # number of input channels
        'num_class': 5,
        'freeze_encoder': True, # Freeze the transformer encoder
        'freeze_embedder': True, # Freeze the patch embedding layer
        'freeze_head': False, # The linear classification head must be trained
        ## NOTE: Disable gradient checkpointing to suppress the warning when linear probing the model, as the MOMENT encoder is frozen
        'enable_gradient_checkpointing': False,
        # Choose how embedding is obtained from the model: One of ['mean', 'concat']
        # Multi-channel embeddings are obtained by either averaging or concatenating patch embeddings 
        # along the channel dimension. 'concat' results in embeddings of size (n_channels * d_model), 
        # while 'mean' results in embeddings of size (d_model)
        'reduction': 'mean',
    },
    # local_files_only=True,  # Whether or not to only look at local files (i.e., do not try to download the model).
    )
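
The difference between the `'mean'` and `'concat'` reductions mentioned in the comments above can be illustrated on a dummy tensor of per-channel patch embeddings. This is a sketch under assumptions: the array layout `[batch, n_channels, n_patches, d_model]` and the exact concatenation order are illustrative choices, and a small `d_model` of 16 stands in for the 1024 used by `MOMENT-1-large`.

```python
import numpy as np

# Small d_model for illustration only (MOMENT-1-large uses 1024)
batch, n_channels, n_patches, d_model = 2, 12, 64, 16
rng = np.random.default_rng(0)
patch_emb = rng.normal(size=(batch, n_channels, n_patches, d_model))

# 'mean': average the channels -> [batch, n_patches, d_model]
mean_emb = patch_emb.mean(axis=1)

# 'concat': stack the channels along the feature axis
# -> [batch, n_patches, d_model * n_channels]
concat_emb = patch_emb.transpose(0, 2, 3, 1).reshape(
    batch, n_patches, d_model * n_channels
)

# Averaging over patches gives one vector per recording
mean_vec = mean_emb.mean(axis=1)      # [batch, d_model]
concat_vec = concat_emb.mean(axis=1)  # [batch, d_model * n_channels]
print(mean_vec.shape, concat_vec.shape)
```

With 12 leads, `'concat'` produces a feature vector 12 times larger than `'mean'`, which keeps per-lead information at the cost of a larger classification head.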

In [2]:
model.init()
print(model)

MOMENTPipeline(
  (normalizer): RevIN()
  (tokenizer): Patching()
  (patch_embedding): PatchEmbedding(
    (value_embedding): Linear(in_features=8, out_features=1024, bias=False)
    (position_embedding): PositionalEmbedding()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
  

In [3]:
# Number of parameters in the encoder
num_params = sum(p.numel() for p in model.encoder.parameters())
print(f"Number of parameters: {num_params}")

Number of parameters: 341231104
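
To put the encoder size in perspective, the parameter count printed above can be converted into an approximate memory footprint. This back-of-the-envelope sketch covers the weights only; gradients, optimizer states and activations add to it during fine-tuning. The helper name `weights_gib` is introduced here for illustration.

```python
NUM_ENCODER_PARAMS = 341_231_104  # value printed by the cell above

def weights_gib(num_params: int, bytes_per_param: int) -> float:
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bytes_per_param / 2**30

fp32_gib = weights_gib(NUM_ENCODER_PARAMS, 4)  # float32: 4 bytes per parameter
bf16_gib = weights_gib(NUM_ENCODER_PARAMS, 2)  # bfloat16: 2 bytes per parameter
print(f"float32: {fp32_gib:.2f} GiB, bfloat16: {bf16_gib:.2f} GiB")
```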


In [4]:
import random
import os 
import torch 
import numpy as np 

def control_randomness(seed: int = 42):
    """Function to control randomness in the code."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
control_randomness(42)

## 3. Method 1: Learning a Statistical ML Classifier

First we will use MOMENT to generate powerful representations of ECG time series in zero-shot settings. Note that the PTB-XL dataset was not observed during pre-training. 

We will use these representations and labels from the PTB-XL dataset to train a Support Vector Machine (SVM) classifier. To illustrate the value of MOMENT's representations, we will also train another SVM classifier with the raw ECG data. This setting is common in the field of unsupervised representation learning, where the goal is to learn meaningful time series representations without any labeled data (see [TS2Vec](https://arxiv.org/pdf/2106.10466) for a recent example). The quality of these representations is evaluated based on the performance of the downstream classifier (in this case, SVM). This is also the setting that we consider in our [paper](https://arxiv.org/abs/2402.03885).
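
The evaluation protocol described above, fitting an SVM on fixed features and scoring it on held-out data, can be sketched with scikit-learn. This is a stand-in, not the implementation of `fit_svm` from `momentfm`, whose internals are not shown in this notebook. The synthetic Gaussian "embeddings", the RBF kernel and the grid of `C` values are all assumptions made for the example.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

rng = np.random.default_rng(0)

# Synthetic stand-in for embeddings: 5 classes, one Gaussian blob each
n_per_class, n_classes, dim = 40, 5, 32
centers = rng.normal(scale=3.0, size=(n_classes, dim))
X = np.concatenate([c + rng.normal(size=(n_per_class, dim)) for c in centers])
y = np.repeat(np.arange(n_classes), n_per_class)

# Shuffle, then hold out the last 50 samples for testing
perm = rng.permutation(len(y))
X, y = X[perm], y[perm]
X_train, y_train, X_test, y_test = X[:150], y[:150], X[150:], y[150:]

# Choose the regularization strength C by cross-validation on the training set
search = GridSearchCV(SVC(kernel="rbf", gamma="scale"), {"C": [0.1, 1, 10, 100]}, cv=5)
search.fit(X_train, y_train)

test_accuracy = search.score(X_test, y_test)
print(f"Best C: {search.best_params_['C']}, test accuracy: {test_accuracy:.2f}")
```

Because the feature extractor is never updated in this protocol, any difference in SVM accuracy between two feature sets reflects the quality of the features themselves.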

Check out `representation_learning.ipynb` for details on how we can use MOMENT to embed time series data.

### 3.1 Load PTB-XL dataset

Once you have downloaded the PTB-XL dataset, make sure to unzip it! The `PTBXL_dataset` class will read and pre-process the data.
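
At a sampling rate of 100 Hz, a 10-second recording has 1,000 samples, while the model in this notebook takes sequences of length 512. How `PTBXL_dataset` bridges that gap is not shown here, so the sketch below is one hypothetical way to do it: crop long recordings, and left-pad short ones with zeros while tracking the padding in an input mask. The function `crop_or_pad` and the left-padding convention are assumptions for illustration.

```python
import numpy as np

def crop_or_pad(ecg: np.ndarray, seq_len: int = 512):
    """Bring a [n_channels x n_timesteps] recording to exactly seq_len time steps.

    Returns the resized recording and an input mask (1 = real sample, 0 = padding).
    Hypothetical helper; not part of the momentfm package.
    """
    n_channels, n_timesteps = ecg.shape
    input_mask = np.ones(seq_len, dtype=np.int64)

    if n_timesteps >= seq_len:
        # Keep the first seq_len samples
        return ecg[:, :seq_len], input_mask

    # Left-pad with zeros and mark the padded positions in the mask
    pad = seq_len - n_timesteps
    padded = np.pad(ecg, ((0, 0), (pad, 0)))
    input_mask[:pad] = 0
    return padded, input_mask

long_ecg = np.random.default_rng(0).normal(size=(12, 1000))  # 10 s at 100 Hz
cropped, mask_long = crop_or_pad(long_ecg)
short_ecg = long_ecg[:, :300]
padded, mask_short = crop_or_pad(short_ecg)
print(cropped.shape, padded.shape, int(mask_short.sum()))
```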

In [5]:
from momentfm.data.ptbxl_classification_dataset import PTBXL_dataset
import torch 

class Config:
    # Path to the unzipped PTB-XL dataset folder
    basepath = '/zfsauton/project/public/Mononito/ptb-xl' # 'path/to/ptbxl_dataset'

    # Path to a cache directory for the preprocessed dataset, if needed.
    # Preprocessing is time-consuming, so caching the result is recommended.
    cache_dir = '/home/scratch/mgoswami/' # 'path/to/cache_dir'
    load_cache = True

    #sampling frequency, choose from 100 or 500
    fs = 100

    # Class to predict
    code_of_interest = 'diagnostic_class'
    output_type = 'Single'

    #sequence length, only support 512 for now
    seq_len = 512

args = Config()

#create dataloader for training and testing
train_dataset = PTBXL_dataset(args, phase='train')
test_dataset = PTBXL_dataset(args, phase='test')
val_dataset = PTBXL_dataset(args, phase='val')

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)

[INFO] Loading PTB-XL



### 3.2 Diagnostic label Classification using raw ECG signal

In this setting, we turn each 12-lead ECG into a single 1-dimensional time series, either by averaging over the channel dimension or by concatenating the channels, and feed it directly into an SVM. The goal is to provide a baseline against which to assess MOMENT embeddings.

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm 
import numpy as np 

def get_timeseries(dataloader: DataLoader, agg='mean'):
    '''
    Convert each 12-lead ECG (2-dimensional) into a 1-dimensional feature vector for SVM training.

    Two aggregation methods are provided:
    - 'mean': average over all channels, giving [1 x seq_len] for each time series
    - 'channel': concatenate all channels, giving [1 x seq_len * num_channels] for each time series

    labels: [num_samples]
    ts: [num_samples x seq_len] or [num_samples x seq_len * num_channels]

    Note: concatenating all channels results in a much larger feature dimensionality,
    which makes the fitting process much slower.
    '''
    if agg not in ('mean', 'channel'):
        raise ValueError(f"agg must be 'mean' or 'channel', got {agg!r}")

    ts, labels = [], []

    with torch.no_grad():
        for batch_x, batch_labels in tqdm(dataloader, total=len(dataloader)):
            # batch_x: [batch_size x 12 x 512]
            if agg == 'mean':
                batch_x = batch_x.mean(dim=1)
            else:
                # reshape (rather than view) also handles non-contiguous tensors
                batch_x = batch_x.reshape(batch_x.size(0), -1)
            ts.append(batch_x.detach().cpu().numpy())
            labels.append(batch_labels.detach().cpu().numpy())

    ts, labels = np.concatenate(ts), np.concatenate(labels)
    return ts, labels

In [None]:
# Fit an SVM classifier on the channel-averaged raw ECG signals (agg='mean')
from momentfm.models.statistical_classifiers import fit_svm

train_embeddings, train_labels = get_timeseries(train_loader, agg='mean')
clf = fit_svm(features=train_embeddings, y=train_labels)
train_accuracy = clf.score(train_embeddings, train_labels)

# Use the same aggregation for the test set as for training
test_embeddings, test_labels = get_timeseries(test_loader, agg='mean')
test_accuracy = clf.score(test_embeddings, test_labels)

print(f"Train accuracy: {train_accuracy:.2f}")
print(f"Test accuracy: {test_accuracy:.2f}")

### 3.3 Diagnostic label Classification using MOMENT embedding on ECG signal

In this setting, we use MOMENT to embed the ECG time series (see `representation_learning.ipynb`). Next, we train a Support Vector Machine (SVM) classifier on these embeddings and the corresponding labels. This setting is common in the field of unsupervised representation learning, where the goal is to learn meaningful time series representations without any labeled data (see [TS2Vec](https://arxiv.org/pdf/2106.10466) for a recent example). The quality of these representations is evaluated based on the performance of the downstream classifier (in this case, SVM). This is also the setting that we consider in our [paper](https://arxiv.org/abs/2402.03885).

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm 
import numpy as np 
from momentfm.models.statistical_classifiers import fit_svm

def get_embeddings(model, device, reduction, dataloader: DataLoader):
    '''
    labels: [num_samples]
    embeddings: [num_samples x d_model]
    '''
    embeddings, labels = [], []
    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch_x, batch_labels in tqdm(dataloader, total=len(dataloader)):
            # [batch_size x 12 x 512]
            batch_x = batch_x.to(device).float()
            # [batch_size x num_patches x d_model (=1024)]
            output = model(x_enc=batch_x, reduction=reduction) 
            #mean over patches dimension, [batch_size x d_model]
            embedding = output.embeddings.mean(dim=1)
            embeddings.append(embedding.detach().cpu().numpy())
            labels.append(batch_labels)        

    embeddings, labels = np.concatenate(embeddings), np.concatenate(labels)
    return embeddings, labels

In [None]:
#set device to be 'cuda:0' or 'cuda' if you only have one GPU
device = 'cuda:6'
reduction = 'mean'
train_embeddings, train_labels = get_embeddings(model, device, reduction, train_loader)
clf = fit_svm(features=train_embeddings, y=train_labels)
train_accuracy = clf.score(train_embeddings, train_labels)

test_embeddings, test_labels = get_embeddings(model, device, reduction, test_loader)
test_accuracy = clf.score(test_embeddings, test_labels)

print(f"Train accuracy: {train_accuracy:.2f}")
print(f"Test accuracy: {test_accuracy:.2f}")

We saw that MOMENT embeddings improve test accuracy from 60% (raw signal) to 76%! Note that PTB-XL ECG signals do NOT appear in MOMENT's pre-training data. This improvement demonstrates MOMENT's ability to generate high-quality representations in a zero-shot setting.

## 4. Method 2: Finetuning the Linear Classification Head only

In this setting, we freeze the MOMENT encoder and finetune only the linear classification head using cross-entropy loss. The MOMENT encoder is frozen by default.
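
The idea of training only the head on top of a frozen encoder (often called linear probing) can be shown on a toy model. This is a minimal sketch: the two-layer `encoder` and the `head` below are hypothetical stand-ins for MOMENT's transformer encoder and classification head, and the data is random.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for the pre-trained encoder and the classification head
encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
head = nn.Linear(32, 5)

# Freeze the encoder: its parameters receive no gradients
for p in encoder.parameters():
    p.requires_grad = False

# Only the head's parameters are handed to the optimizer
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 16)
y = torch.randint(0, 5, (8,))

encoder_before = [p.clone() for p in encoder.parameters()]
head_before = head.weight.clone()

# One training step
optimizer.zero_grad()
loss = criterion(head(encoder(x)), y)
loss.backward()
optimizer.step()

encoder_unchanged = all(
    torch.equal(a, b) for a, b in zip(encoder_before, encoder.parameters())
)
head_changed = not torch.equal(head_before, head.weight)
print(encoder_unchanged, head_changed)
```

Passing only `head.parameters()` to the optimizer mirrors what the training cell below does with `model.head.parameters()`.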

In [None]:
def train_epoch(model, device, train_dataloader, criterion, optimizer, scheduler, reduction='mean'):
    '''
    Train only classification head
    '''
    model.to(device)
    model.train()
    losses = []

    for batch_x, batch_labels in train_dataloader:
        optimizer.zero_grad()
        batch_x = batch_x.to(device).float()
        batch_labels = batch_labels.to(device)

        # Note: the MOMENT encoder is based on T5 and can be numerically unstable in float16, so we use bfloat16 where supported
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float32):
            output = model(x_enc=batch_x, reduction=reduction)
            loss = criterion(output.logits, batch_labels)
        loss.backward()

        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
    
    avg_loss = np.mean(losses)
    return avg_loss

In [None]:
def evaluate_epoch(dataloader, model, criterion, device, phase='val', reduction='mean'):
    model.eval()
    model.to(device)
    total_loss, total_correct = 0, 0

    with torch.no_grad():
        for batch_x, batch_labels in dataloader:
            batch_x = batch_x.to(device).float()
            batch_labels = batch_labels.to(device)

            output = model(x_enc=batch_x, reduction=reduction)
            loss = criterion(output.logits, batch_labels)
            total_loss += loss.item()
            total_correct += (output.logits.argmax(dim=1) == batch_labels).sum().item()
    
    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    return avg_loss, accuracy

In [None]:
from tqdm import tqdm
import numpy as np 

epoch = 5
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=epoch * len(train_loader))
device = 'cuda:3'

for i in tqdm(range(epoch)):
    train_loss = train_epoch(model, device, train_loader, criterion, optimizer, scheduler)
    val_loss, val_accuracy = evaluate_epoch(val_loader, model, criterion, device, phase='val')
    print(f'Epoch {i}, train loss: {train_loss}, val loss: {val_loss}, val accuracy: {val_accuracy}')

test_loss, test_accuracy = evaluate_epoch(test_loader, model, criterion, device, phase='test')
print(f'Test loss: {test_loss}, test accuracy: {test_accuracy}')
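
The training loop above calls `scheduler.step()` once per batch, which is why `total_steps` is set to `epoch * len(train_loader)`. The sketch below traces the learning rate that `OneCycleLR` produces under the same settings, using a dummy parameter and a hypothetical 20 batches per epoch. With PyTorch's default `div_factor` of 25, the schedule starts at `max_lr / 25`, so the `lr` passed to the optimizer is overridden by the scheduler.

```python
import torch

# Hypothetical sizes; in the notebook steps_per_epoch is len(train_loader)
epochs, steps_per_epoch = 5, 20
total_steps = epochs * steps_per_epoch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, total_steps=total_steps
)

lrs = []
for _ in range(total_steps):
    lrs.append(scheduler.get_last_lr()[0])
    optimizer.step()   # no gradients here, so the parameter is not updated
    scheduler.step()   # one scheduler step per batch, as in train_epoch

peak_step = max(range(total_steps), key=lambda i: lrs[i])
print(f"start lr: {lrs[0]:.1e}, peak lr: {max(lrs):.1e} at step {peak_step}, end lr: {lrs[-1]:.1e}")
```

Stepping the scheduler more than `total_steps` times raises an error, so the two numbers need to stay in sync if the number of epochs or the batch size changes.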

## 5. Method 3: Full Finetuning MOMENT

In this section, we unfreeze the MOMENT encoder and finetune the full model on the PTB-XL dataset.

In [None]:
# Load MOMENT with the encoder unfrozen
from momentfm import MOMENTPipeline

model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large", 
    model_kwargs={
        'task_name': 'classification',
        'n_channels': 12,
        'num_class': 5,
        'freeze_encoder': False,
        'freeze_embedder': False,
        'reduction': 'mean',
    },
)
model.init()

In [None]:
# the learning rate should be smaller to guide the encoder to learn the task without forgetting the pre-trained knowledge
import torch
from tqdm import tqdm
import numpy as np

epoch = 5
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-4, total_steps=epoch * len(train_loader))
device = 'cuda:3'

for i in tqdm(range(epoch)):
    train_loss = train_epoch(model, device, train_loader, criterion, optimizer, scheduler)
    val_loss, val_accuracy = evaluate_epoch(val_loader, model, criterion, device, phase='val')
    print(f'Epoch {i}, train loss: {train_loss}, val loss: {val_loss}, val accuracy: {val_accuracy}')

test_loss, test_accuracy = evaluate_epoch(test_loader, model, criterion, device, phase='test')
print(f'Test loss: {test_loss}, test accuracy: {test_accuracy}')

### 5.1 Assess MOMENT embedding with SVM after finetuning the encoder

In [None]:
#set device to be 'cuda:0' or 'cuda' if you only have one GPU
device = 'cuda:3'
reduction = 'mean'
train_embeddings, train_labels = get_embeddings(model, device, reduction, train_loader)
clf = fit_svm(features=train_embeddings, y=train_labels)
train_accuracy = clf.score(train_embeddings, train_labels)

test_embeddings, test_labels = get_embeddings(model, device, reduction, test_loader)
test_accuracy = clf.score(test_embeddings, test_labels)

print(f"Train accuracy: {train_accuracy:.2f}")
print(f"Test accuracy: {test_accuracy:.2f}")

We see that after the MOMENT encoder is finetuned on the downstream dataset, its embeddings give better test accuracy with an SVM.

### 5.2 Training with Multi-GPU and Parameter Efficient FineTuning (PEFT)

This section shows how to train MOMENT on multiple GPUs, optionally with parameter-efficient finetuning (PEFT). We provide a script for this, `tutorials/finetune_demo/classification.py`.

Note that the number of processes should be adjusted in the config file `finetune_demo/ds.yaml` according to your setup.
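
In the common one-process-per-GPU setup, the number of processes in the config matches the number of GPUs exposed through `CUDA_VISIBLE_DEVICES`. The helper below is a hypothetical convenience, not part of the tutorial scripts, that derives that count from the environment variable's value.

```python
def count_visible_gpus(cuda_visible_devices: str) -> int:
    """Count the device IDs in a CUDA_VISIBLE_DEVICES-style string such as '3,4'."""
    return len([d for d in cuda_visible_devices.split(",") if d.strip()])

# The launch command in this section exposes GPUs 3 and 4
n_processes = count_visible_gpus("3,4")
print(n_processes)
```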

In [None]:
!CUDA_VISIBLE_DEVICES=3,4 accelerate launch --config_file tutorials/finetune_demo/ds.yaml \
    tutorials/finetune_demo/classification.py \
    --base_path "path to your ptbxl base folder" \
    --cache_dir "path to cache directory for preprocessed dataset" \
    --mode full_finetuning \
    --output_path "path to store train log and checkpoint"

The code also supports [LoRA](https://arxiv.org/abs/2106.09685) for parameter-efficient finetuning. To use LoRA, add the `--lora` flag to the command above. Currently, LoRA does not work well with DeepSpeed ZeRO stage 3, so consider switching to stage 2 in `finetune_demo/ds.yaml` when using LoRA.

In [None]:
!CUDA_VISIBLE_DEVICES=3,4 accelerate launch --config_file tutorials/finetune_demo/ds.yaml \
    tutorials/finetune_demo/classification.py \
    --base_path "path to your ptbxl base folder" \
    --cache_dir "path to cache directory for preprocessed dataset" \
    --mode full_finetuning \
    --output_path "path to store train log and checkpoint" \
    --lora
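
For intuition about why LoRA is parameter-efficient, the sketch below applies a low-rank update to a single 1024 x 1024 weight matrix, the size of the attention projections in the model printout earlier in this notebook. It is a NumPy illustration of the idea, not the code path used by the `--lora` flag. The rank `r = 8` and scaling `alpha = 16` are illustrative assumptions.

```python
import numpy as np

d_in, d_out = 1024, 1024
r, alpha = 8, 16  # illustrative rank and scaling, not values from the tutorial script

rng = np.random.default_rng(0)
W = rng.normal(size=(d_out, d_in))          # frozen pre-trained weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable, small random init
B = np.zeros((d_out, r))                    # trainable, zero init

def lora_forward(x: np.ndarray) -> np.ndarray:
    """Frozen projection plus the scaled low-rank update (alpha / r) * B @ A."""
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

x = rng.normal(size=(4, d_in))
out = lora_forward(x)

full_params = W.size           # parameters updated by full finetuning
lora_params = A.size + B.size  # parameters updated by LoRA
print(f"full: {full_params}, LoRA: {lora_params} ({100 * lora_params / full_params:.2f}%)")
```

Because `B` starts at zero, the adapted layer initially behaves exactly like the pre-trained one, and training only has to learn the low-rank correction.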