# Libraries

In [None]:
import os
from glob             import glob

import h5py
import torch

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from sklearn.metrics  import r2_score

import matplotlib.pyplot as plt
import scipy.signal      as signal
import torch.nn          as nn
import numpy             as np
import pandas            as pd


# ECG Definitions

ECG metadata

In [None]:
samplingFrequency: int = 400  # Hz; passed as `fs` when designing the filter in `transform`

ECG file headers

In [None]:
# Lead name for each column of a tracing; column positions are looked up with `ecgHeaders.index`.
# Limb leads, augmented limb leads, then the six precordial leads V1..V6.
ecgHeaders = ["LI", "LII", "LIII", "aVR", "aVL", "aVF"] + [f"V{n}" for n in range(1, 7)]

ECG plot definitions

In [None]:
# Row-major order of the 3x4 subplot grid used by the plotting helpers (one string per grid row).
ecgPlotHeaders = (
	"LI aVR V1 V4 "
	"LII aVL V2 V5 "
	"LIII aVF V3 V6"
).split()

# Trace colour for each lead, in the same order as the grid above.
ecgPlotColors = dict(
	LI   = "seagreen",
	aVR  = "black",
	V1   = "gold",
	V4   = "orangered",
	LII  = "cornflowerblue",
	aVL  = "seagreen",
	V2   = "gold",
	V5   = "crimson",
	LIII = "cornflowerblue",
	aVF  = "cornflowerblue",
	V3   = "orangered",
	V6   = "crimson",
)

# The Dataset: CODE-15

Metadata

In [None]:
dataFolder: str = "../../../data/CODE15/hdf5/"  # folder globbed for *.hdf5 exam files

In [None]:
# Leads fed to the model as input; every name must appear in `ecgHeaders`.
ecgFeatures = [
	"LI", "LII", "LIII",
	"aVR", "aVL", "aVF",
	"V1", "V2", "V3", "V4", "V5", "V6",
]

# Column positions of those leads inside a tracing.
ecgFeaturesIndexes = list(map(ecgHeaders.index, ecgFeatures))
ecgFeaturesIndexes

In [None]:
# Leads the model must reconstruct; every name must appear in `ecgHeaders`.
ecgTarget = [
	"LI", "LII", "LIII",
	"aVR", "aVL", "aVF",
	"V1", "V2", "V3", "V4", "V5", "V6",
]

# Column positions of those leads inside a tracing.
ecgTargetIndexes = list(map(ecgHeaders.index, ecgTarget))
ecgTargetIndexes

Dataset class

In [None]:
class Code15Dataset(Dataset):
	"""Map-style dataset over one or more CODE-15 HDF5 files, one item per exam.

	Only the exam counts are read at construction time. Each ``__getitem__`` opens the
	owning file, reads one tracing, optionally applies ``transform`` and splits the lead
	columns into an input ``X`` and a target ``Y`` (both ``float32`` tensors).

	Args:
		hdf5Files: Iterable of HDF5 paths, each containing ``exam_id`` and ``tracings``
			datasets. Materialised into a list, so generators/iterators are accepted.
		features:  Column indexes (leads) of a tracing used as model input.
		target:    Column indexes (leads) of a tracing used as reconstruction target.
		transform: Callable applied to the raw tracing before column selection, or
			``None`` to use the raw tracing unchanged.
	"""

	def __init__(self, hdf5Files, features, target, transform = None):

		super().__init__()

		# Materialise once: the paths are enumerated here AND indexed again in __getitem__,
		# which would fail for a one-shot iterator.
		self.hdf5Files = list(hdf5Files)
		self.features  = features
		self.target    = target
		self.transform = transform

		# Flat dataset index -> (file index, row of that exam inside the file).
		self.indexMap = []
		for fileIndex, path in enumerate(self.hdf5Files):
			with h5py.File(path, "r") as f:
				samplesCount = f['exam_id'].shape[0]

			self.indexMap.extend((fileIndex, i) for i in range(samplesCount))

	def __len__(self):
		"""Total number of exams across all files."""
		return len(self.indexMap)

	def __getitem__(self, idx):
		"""Return ``(X, Y)`` for the exam at flat index ``idx``."""
		fileIndex, examIdx = self.indexMap[idx]

		# The file is reopened per item instead of caching a handle — presumably so the
		# dataset stays safe to use from DataLoader worker processes (num_workers > 0).
		with h5py.File(self.hdf5Files[fileIndex], "r") as file:
			tracing = np.array(file['tracings'][examIdx])

		if self.transform is not None:
			tracing = self.transform(tracing)

		X = torch.tensor(tracing[:, self.features], dtype = torch.float32)
		Y = torch.tensor(tracing[:, self.target],   dtype = torch.float32)

		return X, Y


Transform Function

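- First-order high-pass Butterworth filter with cutoff $f_c = 1$ Hz, applied forward and backward (`filtfilt`, zero phase)
- Truncation of $N = 600$ samples from each end of the signal
- Gain of 5 (note: the z-score step below divides any constant gain back out)
- Per-lead z-score normalization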

In [None]:
def transform(ecg, cutoff = 1, order = 1, gain = 5, trim = 600, fs = None):
	"""Pre-process one raw tracing: high-pass filter, gain, edge trim, per-lead z-score.

	Args:
		ecg:    Array indexed as ``[sample, lead]``; filtering runs along axis 0.
		cutoff: High-pass cutoff frequency in Hz (default 1).
		order:  Butterworth filter order (default 1).
		gain:   Scalar gain applied after filtering (default 5). NOTE: the z-score below
			divides a constant gain back out, so it only interacts with the 1e-8 stabiliser.
		trim:   Samples removed from EACH end after filtering (default 600) — presumably
			to discard filtfilt edge transients.
		fs:     Sampling rate in Hz; ``None`` uses the notebook-level ``samplingFrequency``.

	Returns:
		Array of shape ``(len(ecg) - 2 * trim, n_leads)``, zero-mean and ~unit variance per lead.

	Raises:
		ValueError: if the tracing has no samples left after trimming.
	"""
	if fs is None:
		fs = samplingFrequency

	nSamples = ecg.shape[0]
	if nSamples <= 2 * trim:
		raise ValueError(f"Tracing has {nSamples} samples; cannot trim {trim} from each end.")

	b, a = signal.butter(
		N     = order,
		Wn    = cutoff,
		btype = 'high',
		fs    = fs
	)

	ecgFiltered = signal.filtfilt(b, a, ecg, axis = 0)
	ecgWithGain = gain * ecgFiltered
	# Explicit stop index: `[trim:-trim]` would yield an empty array when trim == 0.
	ecgClean    = ecgWithGain[trim: nSamples - trim, :]

	ecgMean = np.mean(ecgClean, axis = 0, keepdims = True)
	ecgStd  = np.std(ecgClean,  axis = 0, keepdims = True) + 1e-8  # guards flat (zero-variance) leads

	return (ecgClean - ecgMean) / ecgStd


Full dataset and holdout (train/test) split

In [None]:
# Sort the paths: glob() returns files in filesystem order, which differs between
# machines and would change the sample indexing (and thus the seeded train/test split).
hdf5Paths = sorted(glob(f"{dataFolder}/*.hdf5"))
assert hdf5Paths, f"No .hdf5 files found in {dataFolder!r}"

dataset = Code15Dataset(
	hdf5Files = hdf5Paths,
	features  = ecgFeaturesIndexes,
	target    = ecgTargetIndexes,
	transform = transform
)

In [None]:
# Read and pre-process the first exam once, then show input/target shapes.
sampleX, sampleY = dataset[0]
sampleX.shape, sampleY.shape

In [None]:
# Total number of exams indexed across all HDF5 files (one dataset item per exam).
len(dataset)

In [None]:
SEED = 14

# Global torch RNG: model weight init, DataLoader shuffling and the VAE sampling noise.
torch.manual_seed(SEED)

# Dedicated RNG for the train/test split, so the split does not depend on other RNG use.
generator = torch.Generator().manual_seed(SEED)

In [None]:
trainRatio = 0.80

# Train gets floor(80 %) of the exams; the remainder goes to test so the sizes sum to len(dataset).
trainSize = int(trainRatio * len(dataset))
testSize  = len(dataset) - trainSize

print("Train dataset Len =", trainSize)
print("Test dataset Len =", testSize)

In [None]:
# Pass the seeded generator; without it random_split draws from the global RNG and the
# train/test membership changes on every run.
trainDataset, testDataset = random_split(
	dataset   = dataset,
	lengths   = [trainSize, testSize],
	generator = generator
)

Dataloaders

In [None]:
# Training batches are reshuffled every epoch; the evaluation order stays fixed.
trainDataloader = DataLoader(trainDataset, batch_size = 32, shuffle = True,  num_workers = 4)
testDataloader  = DataLoader(testDataset,  batch_size = 32, shuffle = False, num_workers = 4)

In [None]:
# Load the first exam once instead of once per print.
firstX, firstY = dataset[0]

print("X shape =", firstX.shape)
print("Y shape =", firstY.shape)

# Model definition

In [None]:
class ECGReconstructor(nn.Module):
	"""Convolutional variational autoencoder for multi-lead ECG windows.

	``forward`` takes a time-major batch ``(batch, inputLength, inChannels)`` and returns
	``(YHat, mu, logvar)`` with ``YHat`` of shape ``(batch, inputLength, outChannels)``.

	The encoder applies four stride-2 convolutions, so ``inputLength`` must be a multiple
	of 16 for the transposed-convolution decoder to restore it exactly. With the default
	``inputLength = 2896`` the bottleneck sees a ``256 x 181`` feature map.

	In training mode the latent code is sampled with the reparameterization trick; in
	eval mode the posterior mean is used so reconstructions are deterministic.

	Args:
		latentDim:   Size of the latent vector.
		inChannels:  Number of input leads (default 12).
		outChannels: Number of reconstructed leads (default 12).
		inputLength: Samples per window (default 2896).

	Raises:
		ValueError: if ``inputLength`` is not a positive multiple of 16.
	"""

	def __init__(self, latentDim=256, inChannels=12, outChannels=12, inputLength=2896):
		super(ECGReconstructor, self).__init__()

		if inputLength <= 0 or inputLength % 16 != 0:
			raise ValueError(f"inputLength must be a positive multiple of 16, got {inputLength}")

		self.latentDim     = latentDim
		self.encodedLength = inputLength // 16   # 2896 -> 181

		# Layer order/indices are kept stable so existing state_dict checkpoints still load.
		self.encoder = nn.Sequential(
			nn.Conv1d(inChannels, 32, kernel_size=5, stride=2, padding=2),  # (batch, 32, L/2)
			nn.LeakyReLU(0.2),
			nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),          # (batch, 64, L/4)
			nn.LeakyReLU(0.2),
			nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),         # (batch, 128, L/8)
			nn.LeakyReLU(0.2),
			nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2),        # (batch, 256, L/16)
			nn.LeakyReLU(0.2),
		)

		self.flattenedSize = 256 * self.encodedLength

		# Bottleneck: posterior parameters, then the projection back to the feature map.
		self.fc_mu     = nn.Linear(self.flattenedSize, latentDim)
		self.fc_logvar = nn.Linear(self.flattenedSize, latentDim)

		self.fc_decode = nn.Linear(latentDim, self.flattenedSize)

		self.decoder = nn.Sequential(
			nn.ConvTranspose1d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),          # (batch, 128, L/8)
			nn.LeakyReLU(0.2),
			nn.ConvTranspose1d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),           # (batch, 64, L/4)
			nn.LeakyReLU(0.2),
			nn.ConvTranspose1d(64, 32, kernel_size=5, stride=2, padding=2, output_padding=1),            # (batch, 32, L/2)
			nn.LeakyReLU(0.2),
			nn.ConvTranspose1d(32, outChannels, kernel_size=5, stride=2, padding=2, output_padding=1),   # (batch, outChannels, L)
		)

	def encode(self, X):
		"""Map a channels-first batch ``(batch, inChannels, L)`` to posterior ``(mu, logvar)``."""
		X = self.encoder(X)
		X = torch.flatten(X, start_dim = 1)

		mu     = self.fc_mu(X)
		logvar = self.fc_logvar(X)

		return mu, logvar

	def reparameterize(self, mu, logvar):
		"""Sample ``z = mu + sigma * eps`` while training; return ``mu`` in eval mode."""
		if not self.training:
			# Deterministic latent code so evaluation metrics and plots are reproducible.
			return mu

		std = torch.exp(0.5 * logvar)
		eps = torch.randn_like(std)

		return mu + eps * std

	def decode(self, Z):
		"""Map latent vectors ``(batch, latentDim)`` to a time-major ``(batch, L, outChannels)``."""
		Y = self.fc_decode(Z)
		Y = Y.view(-1, 256, self.encodedLength)
		Y = self.decoder(Y)

		return Y.permute(0, 2, 1)

	def forward(self, X):
		"""Reconstruct a time-major batch ``(batch, L, inChannels)``; returns ``(YHat, mu, logvar)``."""
		X = X.permute(0, 2, 1)   # (batch, L, C) -> (batch, C, L) as expected by Conv1d

		mu, logvar = self.encode(X)
		Z          = self.reparameterize(mu, logvar)
		YHat       = self.decode(Z)

		return YHat, mu, logvar

# Training

Loss function and training configuration

In [None]:
def vaeLoss(Y, yHat, mean, logvar, klWeight = 1e-3):
	"""VAE objective: reconstruction MSE plus a weighted KL-divergence term.

	Args:
		Y:        Target tensor; ``Y.shape[0]`` is treated as the batch size.
		yHat:     Reconstruction, same shape as ``Y``.
		mean:     Posterior means (presumably ``(batch, latentDim)`` from the encoder).
		logvar:   Posterior log-variances, same shape as ``mean``.
		klWeight: Multiplier on the KL term (default 1e-3).

	Returns:
		Scalar tensor: MSE averaged over every element of ``Y`` plus ``klWeight`` times
		the KL divergence to N(0, I), summed over all latent entries and divided by the
		batch size.
	"""
	reconstructionLoss = nn.functional.mse_loss(yHat, Y)

	# Closed-form KL( N(mean, exp(logvar)) || N(0, I) ).
	kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()) / Y.shape[0]

	return reconstructionLoss + klWeight * kld


In [None]:
# Hardware, model, optimiser and objective used by the training loop.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = ECGReconstructor(latentDim = 256)
model = model.to(device)

optimizer = torch.optim.Adam(params = model.parameters(), lr = 1e-3)
criterion = vaeLoss
epochs    = 50

print(device)

Run

In [None]:
trainingLoss = list()  # mean per-batch training loss, one entry appended per epoch

In [None]:
model.train()

for epochIdx in range(epochs):
	runningLoss = 0

	for batchX, batchY in trainDataloader:
		batchX = batchX.to(device)
		batchY = batchY.to(device)

		# Clearing before the forward pass is equivalent: the forward never touches .grad.
		optimizer.zero_grad()

		reconstruction, mu, logvar = model(batchX)
		batchLoss = criterion(batchY, reconstruction, mu, logvar)

		batchLoss.backward()
		optimizer.step()

		runningLoss += batchLoss.item()

	# Record the mean loss per batch for this epoch.
	trainingLoss.append(runningLoss / len(trainDataloader))

	print(f"Epoch {epochIdx + 1}: loss = {trainingLoss[-1]:.4f}")

Training loss per epoch

In [None]:
# Plot against the number of epochs actually recorded (1-based, like the training log),
# so the cell still works if the training cell was run more than once.
epochAxis = range(1, len(trainingLoss) + 1)

plt.scatter(epochAxis, trainingLoss, c = "blue", marker = "x")

plt.title("Training Loss")

plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.grid()
plt.tight_layout()
plt.show()

# Test

In [None]:
model.eval()

testLoss = 0
testR2   = 0

totalSamples = 0

with torch.no_grad():
	for X, Y in testDataloader:
		X, Y = X.to(device), Y.to(device)

		YHat, mean, logvar = model(X)
		loss               = criterion(Y, YHat, mean, logvar)

		# Collapse (batch, samples, leads) into rows of leads for sklearn.
		YFlat    = Y.cpu().numpy().reshape(-1, Y.shape[-1])
		YHatFlat = YHat.cpu().numpy().reshape(-1, YHat.shape[-1])

		testLoss += loss.item()
		# Per-batch R^2 (uniform average over leads), weighted by rows: the final value is a
		# row-weighted mean over batches, an approximation of the global R^2.
		testR2   += r2_score(YFlat, YHatFlat) * YFlat.shape[0]

		totalSamples += YFlat.shape[0]

# Mean loss per batch — the same reduction the training loop reports.
testLoss /= len(testDataloader)
testR2   /= totalSamples

print(f"Test Loss: {testLoss:.4f}")
print(f"Test R^2:  {testR2:.4f}")


# Plotting

Plotting functions

In [None]:
def plotECG(ecg, headers, colors):
	"""Draw an ECG on a 3x4 grid of subplots, one lead per panel.

	Args:
		ecg:     Column-indexable table (e.g. a DataFrame) with one column per lead.
		headers: Lead names in row-major panel order (at most 12).
		colors:  Mapping from lead name to a matplotlib colour.
	"""
	figure, axes = plt.subplots(nrows = 3, ncols = 4, sharex = True, figsize = (16, 9))

	figure.suptitle("ECG 12-Lead")
	figure.supxlabel("Sample")
	figure.supylabel("Dpp")

	flatAxes = axes.ravel()

	for position, lead in enumerate(headers):
		panel = flatAxes[position]
		panel.plot(ecg[lead], color = colors[lead])
		panel.set_title(f"{lead}")

	plt.tight_layout(pad = 1.5)
	plt.show()
	plt.close(figure)

In [None]:
def comparativeFullEcgPlot(ecgOring, ecgRec, headers):
	"""Overlay original (blue) and reconstructed (red) leads on a 3x4 grid.

	Each panel title reports the Pearson correlation and the R^2 of the reconstruction
	against the original, rounded to 3 decimals.

	Args:
		ecgOring: Original ECG as a DataFrame with one column per lead (parameter name
			kept as-is for existing keyword callers).
		ecgRec:   Reconstructed ECG with the same columns and row index as ``ecgOring``.
		headers:  Lead names in row-major panel order (at most 12).
	"""
	figure, axes = plt.subplots(
		nrows   = 3,
		ncols   = 4,
		sharex  = True,
		figsize = (16, 9)
	)

	figure.suptitle("Comparison: ECG 12-Lead")
	figure.supxlabel("Sample")
	figure.supylabel("Dpp")

	axes = axes.flatten()

	for idx, header in enumerate(headers):
		original      = ecgOring[header]
		reconstructed = ecgRec[header]

		corr = np.round(original.corr(reconstructed), 3)
		r2   = np.round(r2_score(original, reconstructed), 3)

		axes[idx].plot(original,      color = "blue", alpha = 0.75, label = "Original")
		axes[idx].plot(reconstructed, color = "red",  alpha = 0.75, label = "Reconstructed")

		axes[idx].set_title(f"{header} CORR = {corr} r2 = {r2}")

	# One shared legend: otherwise nothing on the figure says which colour is which.
	handles, labels = axes[0].get_legend_handles_labels()
	if handles:
		figure.legend(handles, labels, loc = "upper right")

	plt.tight_layout(pad = 1.5)
	plt.show()
	plt.close(figure)

Load the saved model and extract a sample ECG from the dataset

In [None]:
# NOTE: this rebinds `model` to the weights stored on disk, discarding the weights trained
# earlier in this session unless they were written first (see the "Save model" cell).
modelPath = "../../../models/t3/ae-t3-v1.pth"

model = ECGReconstructor(latentDim = 256).to(device)

model.load_state_dict(
    torch.load(
        modelPath,
        weights_only = True,
        map_location = device   # a checkpoint saved on GPU must still load on a CPU-only machine
    )
)

# A freshly built module starts in training mode; switch to inference for the cells below.
model.eval()

model

In [None]:
# Take the example from the held-out split so the plots show how the model generalises,
# not how well it reproduces an exam it was trained on.
sampleECGTensor, _ = testDataset[150]

sampleECG = pd.DataFrame(
	columns = ecgFeatures,
	data    = sampleECGTensor.numpy()
)

sampleECG.head()


In [None]:
# Run the single exam through the model as a batch of one, then drop the batch axis.
with torch.no_grad():
    modelInput = sampleECGTensor.unsqueeze(0).to(device)
    reconstructedBatch, _, _ = model(modelInput)

reconstructedArray = reconstructedBatch.cpu().numpy()[0]

sampleECGReconstructed = pd.DataFrame(
    columns = ecgTarget,
    data    = reconstructedArray
)

In [None]:
# Input exam, laid out on the standard 3x4 grid.
plotECG(sampleECG, ecgPlotHeaders, ecgPlotColors)

In [None]:
# Model reconstruction, same layout and colours as the input plot.
plotECG(sampleECGReconstructed, ecgPlotHeaders, ecgPlotColors)

In [None]:
# Overlay over the full trimmed window.
comparativeFullEcgPlot(sampleECG, sampleECGReconstructed, ecgPlotHeaders)

In [None]:
viewWindow = np.s_[1000:2000]  # rows 1000..1999, used to zoom the comparison plot

In [None]:
# Zoomed overlay; `.iloc` makes the positional row slice explicit (same rows as `df[slice]` here).
comparativeFullEcgPlot(
	ecgOring = sampleECG.iloc[viewWindow],
	ecgRec   = sampleECGReconstructed.iloc[viewWindow],
	headers  = ecgPlotHeaders
)

# Save model

In [None]:
# NOTE: at this point `model` holds whatever the earlier cells left bound to it (the
# checkpoint-reload cell rebinds it), not necessarily the weights trained in this session.
checkpointPath = "../../../models/t3/ae-t3-v1.pth"

# Create the target folder so the save does not fail on a fresh checkout.
os.makedirs(os.path.dirname(checkpointPath), exist_ok = True)
torch.save(model.state_dict(), checkpointPath)