In [3]:
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ruptures as rpt
import torch
import torch.nn as nn
import torch.optim as optim
from numpy.typing import NDArray
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import src.transformer_model
from src.change_points import detect_changes
from src.dataset import PTBXL
from src.positional_encoding import positional_encoding
from src.tokenization import tokenize
from src.utils import load_data, pad_window_interval

# Load Dataset

In [4]:
path = './dataset/ecg/WFDB_PTBXL/ptbxl/'
sampling_rate = 100
X_train, y_train, X_test, y_test = load_data(path, sampling_rate)

# Drop training records whose label entry is not a non-empty list (i.e. records
# without any label). The same boolean mask is applied positionally to signals
# and labels so the two stay aligned regardless of the pandas index.
mask = y_train.apply(lambda x: isinstance(x, list) and len(x) > 0)
y_train = y_train[mask]
X_train = X_train[mask.to_numpy()]

# Keep only the first label of each record (multi-label records are reduced to a
# single class) and encode it as an integer category code.
y_train = y_train.apply(lambda x: x[0]).astype('category')
# Preserve the code -> label mapping before the categorical Series is discarded:
# label_names[k] is the label encoded as integer k in y_train.
label_names = list(y_train.cat.categories)
y_train = torch.tensor(y_train.cat.codes.values, dtype=torch.long)

assert len(X_train) == len(y_train), "signals and labels are misaligned after filtering"

print(torch.unique(y_train))

tensor([0, 1, 2, 3, 4])


# Visualize detected change points

In [5]:
def detect_changes(
    time_series: NDArray,
    model: str = "rbf",
    min_size: int = 10,
    jump: int = 20,
    pen: float = 5,
) -> list:
    """Detect change points in a (possibly multivariate) signal with ruptures' PELT.

    NOTE: this definition shadows the ``detect_changes`` imported from
    ``src.change_points`` in the import cell.

    Args:
        time_series: Signal passed straight to ``ruptures``. In the plotting cell
            it is indexed as ``time_series[:, lead]``, i.e. presumably
            (time, lead) -- TODO confirm for other callers.
        model: ruptures cost model. "rbf" (the default used in this notebook)
            is kernel-based; "l2" is a cheaper mean-shift cost.
        min_size: Minimum number of samples between two change points.
        jump: Subsampling step; only every ``jump``-th index is considered as a
            candidate change point.
        pen: Penalty for adding a change point; larger values give fewer points.

    Returns:
        Change-point indices reported by ruptures, excluding any index equal to
        (or beyond) ``len(time_series)``.
    """
    algo = rpt.Pelt(model=model, min_size=min_size, jump=jump).fit(time_series)
    breakpoints = algo.predict(pen=pen)
    # ruptures reports the end of the signal as the last breakpoint; it is not
    # a real change, so it is filtered out.
    n_samples = len(time_series)
    return [bp for bp in breakpoints if bp < n_samples]

In [6]:
# Lead names used to label the ECG subplots, one per signal column.
# Assumed to match the column order of the loaded signal arrays -- TODO confirm
# against the loader.
ECG_LABELS = [
    "I", "II", "III",                      # limb leads
    "aVR", "aVL", "aVF",                   # augmented limb leads
    "V1", "V2", "V3", "V4", "V5", "V6",    # precordial leads
]

In [None]:
# Number of training records to plot (one figure per record).
N_PLOT_SAMPLES = 1

# Fonts are global matplotlib state: set them once, before any figure exists,
# so every text element of every figure picks them up.
plt.rcParams.update({'font.family': 'Times New Roman'})

for time_series_idx, time_series in enumerate(X_train[:N_PLOT_SAMPLES]):
    # Signal is indexed as time_series[:, lead]; derive the lead count from the
    # data instead of hard-coding 12.
    n_leads = time_series.shape[1]

    # One row per lead in a single column. squeeze=False keeps `axs` 2-D even
    # for a single lead, so flattening the first column always works.
    fig, axs = plt.subplots(n_leads, 1, figsize=(12, 12), sharey=True, squeeze=False)
    axs = axs[:, 0]

    change_points = detect_changes(time_series)

    for i, ax in enumerate(axs):
        ax.plot(time_series[:, i], label=f"Channel {i+1}", color='#d62728')

        # Draw every detected change point on every lead; only the first lead's
        # lines carry a label (unused unless a legend is added).
        for cp in change_points:
            ax.axvline(x=cp, color='#1f77b4', label=f"Change Point {cp}" if i == 0 else "")

        if i == n_leads - 1:
            ax.set_xlabel('Time')
        else:
            # Only the bottom subplot shows the time axis ticks.
            ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)

        # Fall back to a generic name if the data has more leads than labels.
        lead_name = ECG_LABELS[i] if i < len(ECG_LABELS) else f"Lead {i+1}"
        ax.set_ylabel(lead_name, rotation=0, labelpad=10)
        ax.yaxis.set_label_coords(0.02, 0)

    # Adjust layout to avoid overlap (top 4% of the figure is left empty).
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.text(-0.02, 0.5, 'Amplitude', va='center', rotation='vertical', fontsize=14)

    # Keep the original filename for the first record; later records get an
    # index suffix instead of overwriting it.
    out_path = "ecg_plot.png" if time_series_idx == 0 else f"ecg_plot_{time_series_idx}.png"
    fig.savefig(out_path, format="png", dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so plotting many records does not accumulate memory.
    plt.close(fig)