# Libraries

In [None]:
import h5py
import torch

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.metrics  import r2_score
from glob             import glob

import matplotlib.pyplot as plt
import scipy.signal      as signal
import scipy.stats 		 as stats
import torch.nn          as nn
import numpy  			 as np
import pandas 		     as pd
import random


# ECG Definitions

ECG metadata

In [None]:
# Sampling frequency in Hz; passed as `fs` when the Butterworth filter is designed in `transform`.
samplingFrequency = 400

ECG file headers

In [None]:
# Lead names; a lead's position in this list is the column index used for it elsewhere.
ecgHeaders = [
    "LI", "LII", "LIII",
    "aVR", "aVL", "aVF",
    "V1", "V2", "V3",
    "V4", "V5", "V6",
]

ECG plot definitions

In [None]:
# Lead order used when plotting.
ecgPlotHeaders = (
    ["LI", "aVR", "V1", "V4"]
    + ["LII", "aVL", "V2", "V5"]
    + ["LIII", "aVF", "V3", "V6"]
)

# Line colour per lead; keys are inserted in the same order as `ecgPlotHeaders`.
ecgPlotColors = dict(
    LI   = "seagreen",
    aVR  = "black",
    V1   = "gold",
    V4   = "orangered",
    LII  = "cornflowerblue",
    aVL  = "seagreen",
    V2   = "gold",
    V5   = "crimson",
    LIII = "cornflowerblue",
    aVF  = "cornflowerblue",
    V3   = "orangered",
    V6   = "crimson",
)

# The Dataset: CODE-15

Metadata

In [None]:
# Folder searched for the `*.hdf5` files; a relative path, so it resolves against the current working directory.
dataFolder = "../../../data/"

In [None]:
# Leads used as model input (X); edit this list to restrict the input leads.
ecgFeatures = [
    "LI", "LII", "LIII",
    "aVR", "aVL", "aVF",
    "V1", "V2", "V3",
    "V4", "V5", "V6",
]

# Position of each selected lead inside `ecgHeaders`.
ecgFeaturesIndexes = list(map(ecgHeaders.index, ecgFeatures))
ecgFeaturesIndexes

In [None]:
# Leads the model has to reconstruct (Y); edit this list to restrict the target leads.
ecgTarget = [
    "LI", "LII", "LIII",
    "aVR", "aVL", "aVF",
    "V1", "V2", "V3",
    "V4", "V5", "V6",
]

# Position of each target lead inside `ecgHeaders`.
ecgTargetIndexes = list(map(ecgHeaders.index, ecgTarget))
ecgTargetIndexes

Dataset class

In [None]:
class Code15RandomLeadsDataset(Dataset):
    """Dataset of tracings in which only a random subset of leads is kept as input.

    Every item is a pair ``(X, Y)`` of ``float32`` tensors built from one
    transformed tracing: ``Y`` is the tracing restricted to the ``target``
    columns, and ``X`` is the same tracing with every lead outside the
    pre-computed random subset set to zero, restricted to the ``features``
    columns.

    ``preComputeRandomLeadSubset()`` must be called before items are read.

    Args:
        hdf5Files: Paths of the HDF5 files. Each file must hold an ``exam_id``
            dataset (its length gives the number of exams) and a ``tracings``
            dataset.
        features: Column indexes selected for ``X``.
        target: Column indexes selected for ``Y``.
        transform: Callable applied to each raw tracing; its result is indexed
            as ``[sample, lead]``.
        subsetSize: When given, keep only this many randomly drawn exams
            (or all of them when fewer are available).
        seed: Seed for the exam subsampling and for the lead subsets. With a
            seed both are reproducible and the process-wide ``random`` /
            ``np.random`` states are left untouched. With ``None`` the exam
            subsampling uses the global ``random`` module and the lead subsets
            use a freshly seeded generator.
    """

    def __init__(self, hdf5Files, features, target, transform, subsetSize=None, seed=None):
        super().__init__()

        self.hdf5Files = list(hdf5Files)
        self.indexMap  = []
        self.features  = features
        self.target    = target
        self.transform = transform
        self.seed      = seed

        self.nLeads      = 12
        self.leadSubsets = []

        # Map a flat item index onto (file index, row inside that file).
        for fileIndex, path in enumerate(self.hdf5Files):

            with h5py.File(path, "r") as f:
                samplesCount = f['exam_id'].shape[0]

            self.indexMap.extend((fileIndex, i) for i in range(samplesCount))

        if subsetSize is not None:
            # A private Random(seed) draws exactly the sample that
            # `random.seed(seed); random.sample(...)` would, without reseeding
            # the global generator as a side effect.
            sampler = random if seed is None else random.Random(seed)

            self.indexMap = sampler.sample(self.indexMap, min(subsetSize, len(self.indexMap)))

    def preComputeRandomLeadSubset(self, minLeads=3, maxLeads=9):
        """Draw, for every item, the set of leads that stays visible in ``X``.

        The number of leads per item is uniform in ``[minLeads, maxLeads]``.
        Leads are drawn without replacement with probability proportional to
        ``1 / (timesChosenSoFar + 1)``, which favours leads that have been
        picked less often. The generator is re-created from ``self.seed`` on
        every call, so calling this again with the same seed yields the same
        subsets.

        Args:
            minLeads: Smallest number of leads kept for an item.
            maxLeads: Largest number of leads kept for an item.

        Raises:
            ValueError: If the bounds do not satisfy
                ``1 <= minLeads <= maxLeads <= self.nLeads``.
        """
        if not 1 <= minLeads <= maxLeads <= self.nLeads:
            raise ValueError(
                f"Expected 1 <= minLeads <= maxLeads <= {self.nLeads}, "
                f"got minLeads={minLeads}, maxLeads={maxLeads}"
            )

        rng         = np.random.default_rng(self.seed)
        leadCounts  = np.zeros(self.nLeads)
        leadSubsets = []

        for _ in range(len(self.indexMap)):
            # `integers` excludes the upper bound, hence the + 1.
            nLeadsToPick = int(rng.integers(minLeads, maxLeads + 1))

            odds =  1 / (leadCounts + 1)
            odds /= odds.sum()

            chosen = rng.choice(
                self.nLeads,
                size    = nLeadsToPick,
                replace = False,
                p       = odds
            )

            leadSubsets.append(chosen)

            leadCounts[chosen] += 1

        self.leadSubsets = leadSubsets

    def __len__(self):
        """Return the number of exams available (after optional subsampling)."""
        return len(self.indexMap)

    def __getitem__(self, idx):
        """Return the ``(X, Y)`` tensors of item ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If the lead subsets have not been computed for the
                current ``indexMap``.
        """
        # Looked up first so that an out-of-range index still raises IndexError.
        fileIndex, examIdx = self.indexMap[idx]

        if len(self.leadSubsets) != len(self.indexMap):
            raise RuntimeError(
                "Lead subsets are missing or stale; call "
                "preComputeRandomLeadSubset() before reading items."
            )

        hdf5File = self.hdf5Files[fileIndex]

        # The file is opened per item, so no open handle is kept on the instance.
        with h5py.File(hdf5File, "r") as file:
            tracing = np.array(file['tracings'][examIdx])

        tracing     = self.transform(tracing)
        randomLeads = self.leadSubsets[idx]

        # Keep the chosen leads, leave every other column at zero.
        X = np.zeros_like(tracing)
        X[:, randomLeads] = tracing[:, randomLeads]

        X = X[:, self.features]
        X = torch.tensor(X, dtype = torch.float32)

        Y = tracing[:, self.target]
        Y = torch.tensor(Y, dtype = torch.float32)

        return X, Y

Transform Function

- First-order high-pass Butterworth filter with $f_c = 1$ Hz, applied forward and backward (zero phase)
- Truncation of $N = 600$ samples at each end of the recording
- Gain of 5
- Z-score normalization of each lead

In [None]:
def transform(ecg, fs=None, cutoffHz=1, trimSamples=600, gain=5):
    """High-pass filter, trim and z-score an ECG recording.

    Steps, in order: a first-order high-pass Butterworth filter applied with
    ``filtfilt`` along axis 0, multiplication by ``gain``, removal of
    ``trimSamples`` rows at each end, and per-column z-score normalization.

    Because the gain is applied before the z-score, it only influences the
    result through the ``1e-8`` term added to the standard deviation.

    Args:
        ecg: 2-D array-like indexed as ``[sample, lead]``.
        fs: Sampling frequency in Hz. Defaults to the notebook-level
            ``samplingFrequency``.
        cutoffHz: Cutoff frequency of the high-pass filter in Hz.
        trimSamples: Number of rows discarded at each end after filtering.
        gain: Factor applied to the filtered signal.

    Returns:
        Array of shape ``(nSamples - 2 * trimSamples, nLeads)``.

    Raises:
        ValueError: If ``trimSamples`` is negative or the recording has no rows
            left after trimming.
    """
    if fs is None:
        fs = samplingFrequency

    if trimSamples < 0:
        raise ValueError(f"trimSamples must be >= 0, got {trimSamples}")

    ecg      = np.asarray(ecg)
    nSamples = ecg.shape[0]

    # Without this check the slice below is empty and mean/std become NaN.
    if nSamples <= 2 * trimSamples:
        raise ValueError(
            f"Recording has {nSamples} samples; more than {2 * trimSamples} are "
            f"needed to trim {trimSamples} at each end"
        )

    b, a = signal.butter(
        N     = 1,
        Wn    = cutoffHz,
        btype = 'high',
        fs    = fs
    )

    ecgFiltred  = signal.filtfilt(b, a, ecg, axis = 0)
    ecgWithGain = gain * ecgFiltred

    # An explicit end index keeps trimSamples == 0 working ([0:-0] would be empty).
    ecgClean = ecgWithGain[trimSamples: nSamples - trimSamples, :]

    ecgMean = np.mean(ecgClean, axis = 0, keepdims = True)
    ecgStd  = np.std(ecgClean,  axis = 0, keepdims = True) + 1e-8  # avoids division by zero

    ecgNormalized = (ecgClean - ecgMean) / ecgStd

    return ecgNormalized


In [None]:
# glob returns paths in arbitrary order; sorting makes the (file, row) index map,
# and therefore the seeded subset, independent of the filesystem listing order.
hdf5Paths = sorted(glob(f"{dataFolder}/*.hdf5"))

if not hdf5Paths:
    raise FileNotFoundError(f"No .hdf5 files found in {dataFolder!r}")

randomLeadsDataset = Code15RandomLeadsDataset(
    hdf5Files  = hdf5Paths,
    features   = ecgFeaturesIndexes,
    target     = ecgTargetIndexes,
    transform  = transform,
    subsetSize = 20_000,
    seed       = 42
)

# Must run before any item is read: __getitem__ looks up the lead subset per index.
randomLeadsDataset.preComputeRandomLeadSubset()

In [None]:
# Number of exams in the dataset after the random subsampling (at most `subsetSize`).
len(randomLeadsDataset)

Dataloader

In [None]:
# One exam per batch, visited in dataset order, loaded by 4 worker processes.
dataloader = DataLoader(
    randomLeadsDataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

# Model loading

In [None]:
class ECGReconstructor(nn.Module):
    """Convolutional encoder/decoder that maps 12 input channels to 12 output channels.

    Every convolution uses kernel size 5 with padding 2, so the sequence length
    is unchanged. ``forward`` takes and returns tensors laid out as
    (batch, samples, channels).

    Args:
        latentDim: Number of channels produced by the encoder.
        hiddenDim: Number of channels of the intermediate layer on each side.
    """

    def __init__(self, latentDim, hiddenDim):
        super().__init__()

        def conv(inChannels, outChannels):
            # Length-preserving 1-D convolution shared by every layer.
            return nn.Conv1d(
                in_channels  = inChannels,
                out_channels = outChannels,
                kernel_size  = 5,
                padding      = 2,
            )

        # Attribute names and positions inside each Sequential form the
        # state_dict keys, so they have to stay as they are.
        self.encoder = nn.Sequential(
            conv(12, 6),
            nn.ReLU(),
            conv(6, hiddenDim),
            nn.ReLU(),
            conv(hiddenDim, latentDim),
        )

        self.decoder = nn.Sequential(
            conv(latentDim, hiddenDim),
            nn.ReLU(),
            conv(hiddenDim, 6),
            nn.ReLU(),
            conv(6, 12),
        )

    def forward(self, x):
        """Run the input through the encoder and the decoder."""
        # Conv1d expects (batch, channels, samples); swap the last two axes
        # on the way in and swap them back on the way out.
        channelsFirst = x.permute(0, 2, 1)
        reconstructed = self.decoder(self.encoder(channelsFirst))

        return reconstructed.permute(0, 2, 1)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ECGReconstructor(latentDim = 128, hiddenDim = 32).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded when this notebook falls back to the CPU.
model.load_state_dict(
    torch.load(
        "../../../models/ae-t4-v1.pth",
        map_location = device,
        weights_only = True
    )
)

model

# Graphical analysis of reconstruction performance

In [None]:
# One row per exam and one column per target lead, NaN until filled.
# dtype=float keeps the columns numeric; a frame created without data would
# otherwise be object-typed.
r2Scores = pd.DataFrame(
    columns = ecgTarget,
    index   = range(len(randomLeadsDataset)),
    dtype   = float
)

correlations = pd.DataFrame(
    columns = ecgTarget,
    index   = range(len(randomLeadsDataset)),
    dtype   = float
)

Computing the Pearson correlation and the $R^2$ score of every target lead, for each ECG

In [None]:
model.eval()

# Rows are collected in plain lists and turned into float DataFrames once at
# the end, instead of being written one by one into pre-allocated frames.
r2Rows          = []
correlationRows = []

with torch.no_grad():
    for X, Y in dataloader:
        # Only the input has to be on the model's device; the metrics are computed on the CPU.
        predictions = model(X.to(device)).cpu().numpy()
        targets     = Y.numpy()

        # Walk through the batch so that every exam is scored for any batch_size.
        for target, prediction in zip(targets, predictions):
            r2Row          = []
            correlationRow = []

            for j in range(len(ecgTarget)):

                yTrue = target[:, j]
                yPred = prediction[:, j]

                r2Row.append(r2_score(yTrue, yPred))

                # Pearson's r is undefined for a constant signal; report 0 in that case.
                if np.std(yTrue) == 0 or np.std(yPred) == 0:
                    correlationRow.append(0.0)
                else:
                    correlationRow.append(stats.pearsonr(yTrue, yPred).statistic)

            r2Rows.append(r2Row)
            correlationRows.append(correlationRow)

# One row per exam (in dataloader order), one column per target lead.
r2Scores     = pd.DataFrame(r2Rows,          columns = ecgTarget, dtype = float)
correlations = pd.DataFrame(correlationRows, columns = ecgTarget, dtype = float)

The comparative plot function

In [None]:
def methodComparativePlot(df, derivation, method, bins=50):
    """Plot one metric column as a scatter over the index and as a relative-frequency histogram.

    The column is converted to floats and non-finite entries (NaN, +/-inf) are
    dropped before plotting, so object-typed or partially filled frames are
    handled. The mean of the remaining values, rounded to 3 decimals, is drawn
    as a dashed red line on both panels.

    Args:
        df: DataFrame holding one metric value per row.
        derivation: Name of the column (lead) to plot.
        method: Metric label used in titles and axis labels.
        bins: Number of histogram bins.

    Raises:
        ValueError: If the column holds no finite value.
    """
    values = pd.to_numeric(df[derivation], errors = "coerce").astype(float)
    values = values[np.isfinite(values)]

    if values.empty:
        raise ValueError(f"No finite values to plot for {derivation!r}")

    dfMean = np.round(values.mean(), 3)

    figure, axes = plt.subplots(nrows = 1, ncols = 2, figsize = (12, 6))

    axes[0].set_title(f"{method}($ {derivation} $, $ {derivation}_{{rec}} $)")
    axes[1].set_title(f"Histograma - {method}($ {derivation} $, $ {derivation}_{{rec}} $)")

    axes[0].set_xlabel("n")
    axes[0].set_ylabel(f"{method}")

    axes[1].set_xlabel(f"{method}")
    axes[1].set_ylabel("Frequência")

    # Left panel: one point per row, against the row index.
    axes[0].scatter(
        values.index,
        values.to_numpy()
    )
    axes[0].axhline(
        dfMean,
        color     = 'r',
        linestyle = '--',
        label     = f"Média = {dfMean}"
    )

    # Right panel: bin counts divided by the number of plotted values.
    counts, binEdges = np.histogram(values.to_numpy(), bins)
    axes[1].stairs(counts / len(values), binEdges, fill = True)
    axes[1].axvline(
        dfMean,
        color     = 'r',
        linestyle = '--',
        label     = f"Média = {dfMean}"
    )

    axes[1].legend()
    axes[0].legend()

    plt.show()

    # Close this figure explicitly; the function is called once per lead.
    plt.close(figure)

## Results

In [None]:
# One figure per target lead, built from the R² table.
for derivation in ecgTarget:
    methodComparativePlot(r2Scores, derivation, "R²")

In [None]:
# One figure per target lead, built from the Pearson correlation table.
for derivation in ecgTarget:
    methodComparativePlot(correlations, derivation, "CORR")