# Libraries

In [None]:
import h5py
import torch
import random

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from sklearn.metrics  import r2_score
from glob             import glob

import matplotlib.pyplot as plt
import scipy.signal      as signal
import torch.nn          as nn
import numpy  			 as np
import pandas 		     as pd


# ECG Definitions

ECG metadata

In [None]:
# Sampling rate of the ECG recordings, in Hz; passed as `fs` to the Butterworth filter design in transform().
samplingFrequency = 400

ECG file headers

In [None]:
# Canonical lead order. Lead names are turned into column positions with
# ecgHeaders.index(...), so this is presumably the column order of the stored
# tracings — TODO confirm against the dataset files.
ecgHeaders = [
    "LI", "LII", "LIII",                    # limb leads
    "aVR", "aVL", "aVF",                    # augmented limb leads
    "V1", "V2", "V3", "V4", "V5", "V6",     # precordial leads
]

ECG plot definitions

In [None]:
# Lead order for the 3x4 plot grid, row by row (one grid row per line).
ecgPlotHeaders = (
    ["LI", "aVR", "V1", "V4"]
    + ["LII", "aVL", "V2", "V5"]
    + ["LIII", "aVF", "V3", "V6"]
)

# Line colour of each lead, written in the same grid layout as ecgPlotHeaders.
ecgPlotColors = {
    "LI": "seagreen", "aVR": "black", "V1": "gold", "V4": "orangered",
    "LII": "cornflowerblue", "aVL": "seagreen", "V2": "gold", "V5": "crimson",
    "LIII": "cornflowerblue", "aVF": "cornflowerblue", "V3": "orangered", "V6": "crimson",
}

# The Dataset: CODE-15

Metadata

In [None]:
# Directory that is globbed for *.hdf5 files when the dataset is built further down.
dataFolder = "../../../data/"

In [None]:
# Leads that make up the model input X. Edit this list to feed fewer leads.
ecgFeatures = [
    "LI", "LII", "LIII",
    "aVR", "aVL", "aVF",
    "V1", "V2", "V3", "V4", "V5", "V6",
]

# Position of every input lead inside ecgHeaders (shown as the cell output).
ecgFeaturesIndexes = list(map(ecgHeaders.index, ecgFeatures))
ecgFeaturesIndexes

In [None]:
# Leads that make up the reconstruction target Y.
ecgTarget = [
    "LI", "LII", "LIII",
    "aVR", "aVL", "aVF",
    "V1", "V2", "V3", "V4", "V5", "V6",
]

# Position of every target lead inside ecgHeaders (shown as the cell output).
ecgTargetIndexes = list(map(ecgHeaders.index, ecgTarget))
ecgTargetIndexes

Dataset class

In [None]:
class Code15RandomLeadsDataset(Dataset):
    """
    Dataset of ECG exams in which a random subset of leads is kept in the input.

    Every item is a pair ``(X, Y)`` of float32 tensors:

    * ``Y`` is the transformed tracing restricted to the ``target`` columns.
    * ``X`` is the transformed tracing with every lead outside the exam's
      pre-computed random subset set to zero, restricted to the ``features``
      columns.

    Parameters
    ----------
    hdf5Files : sequence of str
        Paths of the HDF5 files. Each file must expose an ``'exam_id'`` and a
        ``'tracings'`` dataset.
    features : sequence of int
        Column indexes that form ``X``.
    target : sequence of int
        Column indexes that form ``Y``.
    transform : callable
        Applied to the raw tracing read from the file; must return a 2-D array
        indexed as ``[sample, lead]``.
    subsetSize : int, optional
        When given, only a random sample of at most this many exams is kept.
    seed : int, optional
        Seeds both the exam sampling and the lead-subset draw, so the same seed
        always yields the same exams with the same visible leads. With
        ``seed=None`` the exam sampling uses the global ``random`` state and
        the lead subsets use fresh entropy.
    """

    def __init__(self, hdf5Files, features, target, transform, subsetSize=None, seed=None):
        super().__init__()

        self.hdf5Files = hdf5Files
        self.indexMap  = []
        self.features  = features
        self.target    = target
        self.transform = transform
        self.seed      = seed

        self.nLeads      = 12
        self.leadSubsets = []

        # Flat list of (file index, row inside that file) covering every exam.
        for fileIndex, path in enumerate(self.hdf5Files):
            with h5py.File(path, "r") as f:
                samplesCount = f['exam_id'].shape[0]

            self.indexMap.extend((fileIndex, i) for i in range(samplesCount))

        if subsetSize is not None:
            # A private Random instance draws the same sample as
            # random.seed(seed) followed by random.sample(...), without
            # overwriting the global RNG state of the notebook.
            sampler = random if seed is None else random.Random(seed)

            self.indexMap = sampler.sample(self.indexMap, min(subsetSize, len(self.indexMap)))

        # Draw the lead subsets right away so indexing works even when the
        # caller never invokes preComputeRandomLeadSubset() explicitly.
        self.preComputeRandomLeadSubset()

    def preComputeRandomLeadSubset(self, minLeads=3, maxLeads=9):
        """
        Draw, for every exam, the set of leads that stays visible in X.

        The number of leads is uniform in ``[minLeads, maxLeads]``. Leads that
        were chosen less often so far get a proportionally higher probability,
        which keeps the usage of the leads balanced over the dataset.

        The draw uses a generator seeded with ``self.seed``; calling the method
        again with the same seed reproduces the same subsets.

        Raises
        ------
        ValueError
            If the bounds do not satisfy ``1 <= minLeads <= maxLeads <= nLeads``.
        """
        if not 1 <= minLeads <= maxLeads <= self.nLeads:
            raise ValueError(
                f"Expected 1 <= minLeads <= maxLeads <= {self.nLeads}, "
                f"got minLeads={minLeads}, maxLeads={maxLeads}"
            )

        rng        = np.random.default_rng(self.seed)
        leadCounts = np.zeros(self.nLeads)
        subsets    = []

        for _ in range(len(self.indexMap)):
            # integers() excludes the upper bound, hence the + 1.
            nLeadsToPick = int(rng.integers(minLeads, maxLeads + 1))

            # Inverse-frequency weights, normalised into probabilities.
            odds =  1 / (leadCounts + 1)
            odds /= odds.sum()

            chosen = rng.choice(
                a       = self.nLeads,
                size    = nLeadsToPick,
                replace = False,
                p       = odds
            )

            subsets.append(chosen)

            leadCounts[chosen] += 1

        self.leadSubsets = subsets

    def __len__(self):
        """Number of exams kept in the dataset."""
        return len(self.indexMap)

    def __getitem__(self, idx):
        """Return the ``(X, Y)`` tensor pair of exam ``idx``."""
        fileIndex, examIdx = self.indexMap[idx]
        hdf5File           = self.hdf5Files[fileIndex]

        # The file is opened for each item, so no open handle is stored on the
        # dataset object.
        with h5py.File(hdf5File, "r") as file:
            tracing = np.array(file['tracings'][examIdx])

        tracing     = self.transform(tracing)
        randomLeads = self.leadSubsets[idx]

        # Start from zeros and copy only the visible leads.
        X = np.zeros_like(tracing)
        X[:, randomLeads] = tracing[:, randomLeads]

        X = X[:, self.features]
        X = torch.tensor(X, dtype = torch.float32)

        Y = tracing[:, self.target]
        Y = torch.tensor(Y, dtype = torch.float32)

        return X, Y

Transform Function

- High-pass Butterworth filter with $f_c = 1$ Hz
- Truncation of $N = 600$ samples from each tail
- Gain of 5
- Z-score normalization of each lead

In [None]:
def transform(ecg, fs=None, cutoff=1, gain=5, trim=600):
    """
    Filter, trim and standardise one ECG tracing.

    Steps, in order: zero-phase first-order high-pass Butterworth filter,
    multiplication by ``gain``, removal of ``trim`` samples from each end, and
    per-lead z-score normalisation.

    Parameters
    ----------
    ecg : array-like, indexed as [sample, lead]
        Raw tracing.
    fs : float, optional
        Sampling frequency in Hz. Defaults to the notebook-level
        ``samplingFrequency``.
    cutoff : float
        High-pass cut-off frequency in Hz.
    gain : float
        Factor applied to the filtered signal. Because the result is z-scored
        afterwards, it only matters relative to the 1e-8 stabiliser.
    trim : int
        Number of samples dropped from each end of the tracing.

    Returns
    -------
    numpy.ndarray
        Array with ``ecg.shape[0] - 2 * trim`` rows and one column per lead.

    Raises
    ------
    ValueError
        If the tracing has no samples left after trimming.
    """
    if fs is None:
        fs = samplingFrequency

    ecg      = np.asarray(ecg)
    nSamples = ecg.shape[0]

    # Without this check a short tracing would come back empty (with NaN
    # statistics) and only fail much later, inside the training loop.
    if nSamples <= 2 * trim:
        raise ValueError(
            f"ECG has {nSamples} samples; more than {2 * trim} are required "
            f"to trim {trim} from each tail"
        )

    b, a = signal.butter(
        N     = 1,
        Wn    = cutoff,
        btype = 'high',
        fs    = fs
    )

    ecgFiltred  = signal.filtfilt(b, a, ecg, axis = 0)
    ecgWithGain = gain * ecgFiltred

    # An explicit end index keeps trim = 0 meaningful ([0:-0] would be empty).
    ecgClean = ecgWithGain[trim: nSamples - trim, :]

    ecgMean = np.mean(ecgClean, axis = 0, keepdims = True)
    ecgStd  = np.std(ecgClean,  axis = 0, keepdims = True) + 1e-8

    return (ecgClean - ecgMean) / ecgStd


Holdout dataset

In [None]:
# glob() returns paths in arbitrary order, and the seeded exam sampling depends
# on the order of the files. Sorting makes the selected subset identical on
# every machine.
hdf5Paths = sorted(glob(f"{dataFolder}/*.hdf5"))

# Fail early instead of silently building an empty dataset.
if not hdf5Paths:
    raise FileNotFoundError(f"No .hdf5 files found in {dataFolder!r}")

randomLeadsDataset = Code15RandomLeadsDataset(
    hdf5Files  = hdf5Paths,
    features   = ecgFeaturesIndexes,
    target     = ecgTargetIndexes,
    transform  = transform,
    subsetSize = 20_000,
    seed       = 42
)

# Draw the visible-lead subset of every exam before any item is read.
randomLeadsDataset.preComputeRandomLeadSubset()

In [None]:
# Number of exams kept in the dataset (at most subsetSize).
len(randomLeadsDataset)

In [None]:
# Seeded generator handed to random_split so the train/test partition is repeatable.
generator = torch.Generator().manual_seed(42)

In [None]:
# 80 % of the exams go to training; the remainder (including any rounding leftover) goes to test.
trainSize = int(0.80 * len(randomLeadsDataset))
testSize  = len(randomLeadsDataset) - trainSize

print("Train size =", trainSize)
print("Test size  =", testSize)

In [None]:
# Disjoint random partition of the dataset, driven by the seeded generator.
splitLengths = [trainSize, testSize]

trainSet, testSet = random_split(
    dataset   = randomLeadsDataset,
    lengths   = splitLengths,
    generator = generator
)

Dataloaders

In [None]:
# Number of exams per batch, shared by the training and the test loader.
batchSize = 32

In [None]:
# Training batches are reshuffled on every pass over the loader.
trainDataloader = DataLoader(
	dataset     = trainSet,
	batch_size  = batchSize,
	shuffle     = True,
	num_workers = 4
)

# Test batches keep a fixed order.
testDataloader = DataLoader(
	dataset     = testSet,
	batch_size  = batchSize,
	shuffle     = False,
	num_workers = 4
)

# Model definition

In [None]:
class ECGReconstructor(nn.Module):
    """
    Convolutional encoder/decoder mapping a 12-channel signal to a 12-channel signal.

    Every layer is a 1-D convolution with kernel size 5 and padding 2, so the
    number of time steps is the same at the input and at the output.

    Parameters
    ----------
    latentDim : int
        Number of channels produced by the encoder.
    hiddenDim : int
        Number of channels of the intermediate layer on each side.
    """

    def __init__(self, latentDim, hiddenDim):
        super().__init__()

        def conv(inChannels, outChannels):
            # Length-preserving convolution shared by all layers.
            return nn.Conv1d(inChannels, outChannels, kernel_size = 5, padding = 2)

        # Layers are created encoder first, in forward order, so the parameter
        # names and the initialisation order stay as they were.
        encoderLayers = [
            conv(12, 6),
            nn.ReLU(),
            conv(6, hiddenDim),
            nn.ReLU(),
            conv(hiddenDim, latentDim),
        ]

        decoderLayers = [
            conv(latentDim, hiddenDim),
            nn.ReLU(),
            conv(hiddenDim, 6),
            nn.ReLU(),
            conv(6, 12),
        ]

        self.encoder = nn.Sequential(*encoderLayers)
        self.decoder = nn.Sequential(*decoderLayers)

    def forward(self, x):
        """Reconstruct ``x`` given as (batch, time, lead); the output has the same layout."""
        # Conv1d expects channels on axis 1, so time and lead axes are swapped
        # on the way in and swapped back on the way out.
        channelsFirst  = x.permute(0, 2, 1)
        latent         = self.encoder(channelsFirst)
        reconstruction = self.decoder(latent)

        return reconstruction.permute(0, 2, 1)

# Training

Metadata

In [None]:
# Seed torch's global RNG so the initial weights, and the shuffling of the
# training loader (which draws from the same RNG), are repeatable across runs.
torch.manual_seed(42)

device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model     = ECGReconstructor(128, 32).to(device)  # latentDim = 128, hiddenDim = 32
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
criterion = nn.MSELoss()
epochs    = 30

print(device)

Run

In [None]:
# Mean training loss of each epoch. The training cell appends to this list, so re-run this cell to reset the history.
trainingLoss = []

In [None]:
# Put the model in training mode before optimising.
model.train()

for epoch in range(epochs):
	# Running sum of the batch losses of this epoch.
	totalLoss = 0

	for X, Y in trainDataloader:
		X, Y = X.to(device), Y.to(device)

		prediction = model(X)
		loss       = criterion(prediction, Y)

		# Clear the gradients of the previous step, back-propagate, update the weights.
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		totalLoss += loss.item()

	# Epoch loss = mean of the batch losses.
	trainingLoss.append(totalLoss / len(trainDataloader))

	print(f"Epoch {epoch + 1}: loss = {trainingLoss[-1]:.4f}")

Loss along the epochs

In [None]:
# One x position per recorded loss, numbered from 1 like the training log.
# len(trainingLoss) is used instead of `epochs` so x and y always have the same
# length, even when the training cell was run more than once and the history
# holds more than `epochs` entries.
epochNumbers = range(1, len(trainingLoss) + 1)

plt.scatter(epochNumbers, trainingLoss, c = "blue", marker = "x")

plt.title("Training Loss")

plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.grid()
plt.tight_layout()
plt.show()

# Test

In [None]:
# Evaluation mode, no gradient tracking.
model.eval()

testLoss = 0
testR2   = 0

totalSamples = 0

with torch.no_grad():
    for X, Y in testDataloader:
        X, Y       =  X.to(device), Y.to(device)
        prediction =  model(X)
        loss       =  criterion(prediction, Y)
        testLoss   += loss.item()

        # Collapse batch and time into rows, keeping one column per lead.
        YFlat          = Y.cpu().numpy().reshape(-1, Y.shape[-1])
        predictionFlat = prediction.cpu().numpy().reshape(-1, prediction.shape[-1])

        # R^2 is computed per batch and weighted by its number of rows; this
        # approximates the R^2 over the full test set.
        testR2 += r2_score(YFlat, predictionFlat) * YFlat.shape[0]

        totalSamples += YFlat.shape[0]

# criterion already averages inside each batch, so dividing the running sum by
# the number of batches gives a value on the same scale as the training loss.
testLoss /= len(testDataloader)
testR2   /= totalSamples

print(f"Test Loss: {testLoss:.4f}")
print(f"Test R^2:  {testR2:.4f}")


# Plotting

Plotting functions

In [None]:
def plotECG(ecg, headers, colors):
    """
    Draw every lead of an ECG in its own panel of a 3x4 grid.

    Parameters
    ----------
    ecg : mapping
        Indexed by lead name; each entry is the sequence to plot (for example
        a DataFrame with one column per lead).
    headers : sequence of str
        Lead names in panel order, row by row. The grid has 12 panels.
    colors : mapping
        Line colour for each lead name.
    """
    gridOptions = dict(nrows = 3, ncols = 4, sharex = True, figsize = (16, 9))
    fig, grid   = plt.subplots(**gridOptions)

    fig.suptitle("ECG 12-Lead")
    fig.supxlabel("Sample")
    fig.supylabel("Dpp")

    panels = grid.flatten()

    for position, lead in enumerate(headers):
        panel = panels[position]

        panel.plot(ecg[lead], color = colors[lead])
        panel.set_title(f"{lead}")

    plt.tight_layout(pad = 1.5)
    plt.show()
    plt.close()

In [None]:
def comparativeFullEcgPlot(ecgOring, ecgRec, headers):
    """
    Overlay an original ECG (blue) and its reconstruction (red), lead by lead.

    Leads are drawn in a 3x4 grid. Each panel title reports the correlation
    and the R^2 between the two traces, rounded to three decimals.

    Parameters
    ----------
    ecgOring : mapping
        Original signal indexed by lead name; each entry must offer ``.corr``
        (for example the columns of a pandas DataFrame).
    ecgRec : mapping
        Reconstructed signal, indexed the same way.
    headers : sequence of str
        Lead names in panel order, row by row. The grid has 12 panels.
    """
    gridOptions = dict(nrows = 3, ncols = 4, sharex = True, figsize = (16, 9))
    fig, grid   = plt.subplots(**gridOptions)

    fig.suptitle("Comparison: ECG 12-Lead")
    fig.supxlabel("Sample")
    fig.supylabel("Dpp")

    panels = grid.flatten()

    for position, lead in enumerate(headers):
        original      = ecgOring[lead]
        reconstructed = ecgRec[lead]

        corr = np.round(original.corr(reconstructed), 3)
        r2   = np.round(r2_score(original, reconstructed), 3)

        panel = panels[position]

        panel.plot(original,      color = "blue", alpha = 0.75)
        panel.plot(reconstructed, color = "red",  alpha = 0.75)

        panel.set_title(f"{lead} CORR = {corr} r2 = {r2}")

    plt.tight_layout(pad = 1.5)
    plt.show()
    plt.close()

Extract a sample ECG from dataset

In [None]:
# NOTE(review): the example is taken from trainSet, so the plots below show an exam the model was trained on; index testSet instead to inspect held-out behaviour.
sampleX, sampleY = trainSet[13]

In [None]:
# Add a batch axis, run the model without gradient tracking, then drop the
# batch axis and bring the result back as a NumPy array on the CPU.
with torch.no_grad():
    prediction = model(sampleX.unsqueeze(0)\
        .to(device))\
        .squeeze(0)\
        .cpu()\
        .numpy()

In [None]:
# The tensors are converted to NumPy arrays before reaching pandas, so the
# frames hold plain floats. sampleY and the prediction contain the target
# leads and are labelled with ecgTarget; only the masked input uses ecgFeatures.
sampleECG              = pd.DataFrame(sampleY.numpy(), columns = ecgTarget)
sampleRandomLeadECG    = pd.DataFrame(sampleX.numpy(), columns = ecgFeatures)
sampleECGReconstructed = pd.DataFrame(prediction,      columns = ecgTarget)

In [None]:
# Masked model input: leads outside the exam's random subset were zeroed, so they plot as flat lines.
plotECG(
	ecg     = sampleRandomLeadECG,
	headers = ecgPlotHeaders,
	colors  = ecgPlotColors
)

In [None]:
# Target signal (sampleY) the model is asked to reproduce.
plotECG(
	ecg     = sampleECG,
	headers = ecgPlotHeaders,
	colors  = ecgPlotColors
)

In [None]:
# Model output for the masked input.
plotECG(
	ecg     = sampleECGReconstructed,
	headers = ecgPlotHeaders,
	colors  = ecgPlotColors
)

In [None]:
# Target (blue) against reconstruction (red) over the whole signal.
comparativeFullEcgPlot(
	ecgOring = sampleECG,
	ecgRec   = sampleECGReconstructed,
	headers  = ecgPlotHeaders
)

In [None]:
# Row range (in samples) used for the zoomed comparison below.
viewWindow = slice(1000, 2000)

In [None]:
# Same comparison restricted to the rows selected by viewWindow.
comparativeFullEcgPlot(
	ecgOring = sampleECG[viewWindow],
	ecgRec   = sampleECGReconstructed[viewWindow],
	headers  = ecgPlotHeaders
)

# Save model

In [None]:
# torch.save(model.state_dict(), "../../../models/ae-t4-v1.pth")