In this tutorial, we demonstrate:
1. how to load the pretrained model,
2. how to build dataloaders for ECG data, and
3. how to compute representations of the ECGs.

# Load the pretrained model

In [1]:
import numpy as np
import torch
from models import load_encoder

# Location of the pretrained checkpoint.
# The download link is listed at https://github.com/sehunfromdaegu/ECG_JEPA
ckpt_dir = './weights/multiblock_epoch100.pth'

# `load_encoder` gives back the encoder together with `dim`, the size of its latent space.
encoder, dim = load_encoder(ckpt_dir=ckpt_dir)

  from .autonotebook import tqdm as notebook_tqdm


# Dataloader construction

In [2]:
from ecg_data import ECGDataset, ECGDataset_pretrain

# Dummy ECG data
n_samples = 32
waves = np.random.randn(n_samples, 8, 2500) # n_samples, 8 leads, 2500 timesteps 
labels = np.random.randint(0, 2, 32)

# Dataset with labels.
dataset_with_labels = ECGDataset(waves, labels)

# Dataset without labels. If you want to compute the representations without label information, use this.
dataset_wo_labels = ECGDataset_pretrain(waves)

# Create a dataloader
dataloader_with_labels = torch.utils.data.DataLoader(dataset_with_labels, batch_size=4, shuffle=True)
dataloader_wo_labels = torch.utils.data.DataLoader(dataset_wo_labels, batch_size=4, shuffle=True)



# Compute the ECG representations

In [3]:
# Feature extraction is pure inference:
# - eval() turns off training-only behaviour such as dropout
#   (assumes `encoder` is a torch.nn.Module — TODO confirm).
# - no_grad() stops autograd from recording the forward pass, which would only waste memory here.
encoder.eval()
with torch.no_grad():
    for wave, _target in dataloader_with_labels:  # labels are not needed to compute representations
        representation = encoder.representation(wave)  # (bs, 8, 2500) -> (bs, dim)
        print(f'Representation shape: {representation.shape}')
        break

    for wave in dataloader_wo_labels:
        representation = encoder.representation(wave)
        break




Representation shape: torch.Size([4, 768])
