In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import os
import pandas as pd
import numpy as np
from torchinfo import summary
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import wfdb
import matplotlib.patheffects as path_effects
import seaborn as sns
from einops.layers.torch import Reduce
#from sklearn.model_selection import train_test_split

# Set seed

In [None]:
# Seed every random number generator the notebook relies on so that weight
# initialisation and the VAE sampling noise are repeatable between runs.
seed = 100
np.random.seed(seed)              # NumPy global generator
torch.manual_seed(seed)           # PyTorch default generator
torch.cuda.manual_seed_all(seed)  # every CUDA device (covers single- and multi-GPU)
# cuDNN determinism switches, intentionally left disabled:
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

# Define custom dataset

In [None]:
class CustomDataset(Dataset):
    """Dataset backed by a single ``.npy`` file that is loaded fully into memory.

    Each item is one entry of the loaded array, reshaped to
    ``(1, 1, sample_length)``.

    Args:
        root_dir: Path of the ``.npy`` file to load. Despite the name this is
            passed straight to ``np.load``, i.e. it is a file path.
        sample_length: Number of values in one sample. Defaults to 512, the
            length that was previously hard-coded.

    Raises:
        ValueError: If the first entry of the file does not contain exactly
            ``sample_length`` values.
    """

    def __init__(self, root_dir, sample_length=512):
        self.root_path = root_dir
        self.sample_length = sample_length
        self.alldata = np.load(self.root_path)

        # Fail at load time instead of on the first __getitem__ call deep
        # inside a DataLoader iteration.
        if len(self.alldata) > 0 and np.size(self.alldata[0]) != sample_length:
            raise ValueError(
                "expected samples with {} values, but the first entry of {!r} has {}".format(
                    sample_length, self.root_path, np.size(self.alldata[0])))

    def __len__(self):
        """Return the number of entries along the first axis of the loaded array."""
        return len(self.alldata)

    def __getitem__(self, index):
        """Return entry ``index`` as an array of shape ``(1, 1, sample_length)``."""
        return self.alldata[index].reshape((1, 1, self.sample_length))

# Define Variational Autoencoder

In [None]:
# Convolutional VAE for inputs shaped (batch, 1, 1, 512); every (de)convolution
# uses a (1, 8) kernel with stride (1, 2).
class Autoencoder(nn.Module):
    """Convolutional variational autoencoder for single-channel signals.

    The encoder halves the last axis four times (512 -> 256 -> 128 -> 64 -> 32)
    while growing the channel count 1 -> 16 -> 32 -> 48 -> 64. Two linear heads
    turn the flattened 64*32 features into the latent mean and log-variance.
    The decoder mirrors the encoder with transposed convolutions. Its last
    layer has no activation, so reconstructions are unbounded.

    Layer order and attribute names (``encoder``, ``decoder``, ``encFC1``,
    ``encFC2``, ``decFC1``) are part of the saved state-dict layout and must
    not be changed.

    Args:
        feature_dim: Size of the latent vector. Defaults to 70, the value that
            was previously hard-coded.
    """

    # Per-sample shape of the encoder output / decoder input is (64, 1, 32).
    _BOTTLENECK_CHANNELS = 64
    _BOTTLENECK_WIDTH = 32
    _FLAT_DIM = _BOTTLENECK_CHANNELS * _BOTTLENECK_WIDTH

    def __init__(self, feature_dim=70):
        super().__init__()

        self.feature_dim = feature_dim

        # The Sequential containers stay flat so parameter keys such as
        # "encoder.1.weight" remain identical to the original layout.
        self.encoder = nn.Sequential(
            *self._down_stage(1, 16),
            *self._down_stage(16, 32),
            *self._down_stage(32, 48),
            *self._down_stage(48, 64),
        )

        self.decoder = nn.Sequential(
            *self._up_stage(64, 48),
            *self._up_stage(48, 32),
            *self._up_stage(32, 16),
            # Output stage: no normalisation or activation after the last layer.
            nn.ConstantPad2d((1, 0, 4, 4), 0),
            nn.ConvTranspose2d(16, 1, (1, 8), (1, 2), padding=4),
        )

        # mu and sigma (the second head predicts the log-variance)
        self.encFC1 = nn.Linear(self._FLAT_DIM, feature_dim)
        self.encFC2 = nn.Linear(self._FLAT_DIM, feature_dim)

        self.decFC1 = nn.Linear(feature_dim, self._FLAT_DIM)

    @staticmethod
    def _down_stage(in_channels, out_channels):
        """Layers of one encoder stage; halves the last axis."""
        return [
            # Pads the last axis by 3 on each side so that the kernel-8,
            # stride-2 convolution yields exactly half the input width.
            nn.ConstantPad2d((3, 3), 0),
            nn.Conv2d(in_channels, out_channels, (1, 8), (1, 2)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    @staticmethod
    def _up_stage(in_channels, out_channels):
        """Layers of one decoder stage; doubles the last axis."""
        return [
            # One extra column on the left plus four rows above and below. The
            # rows are consumed again by padding=4 of the transposed
            # convolution, so the height stays 1 while the width doubles.
            nn.ConstantPad2d((1, 0, 4, 4), 0),
            nn.ConvTranspose2d(in_channels, out_channels, (1, 8), (1, 2), padding=4),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def encode(self, x):
        """Map ``x`` to ``(mu, logvar)``, each shaped ``(batch, 1, 1, feature_dim)``."""
        features = self.encoder(x)
        features = features.view(-1, 1, 1, self._FLAT_DIM)
        mu = self.encFC1(features)
        logvar = self.encFC2(features)
        return mu, logvar

    def reparameterize(self, mu, logVar):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(logVar / 2)``."""
        std = torch.exp(logVar * 0.5)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        """Reconstruct a signal from the latent tensor ``z``."""
        hidden = F.relu(self.decFC1(z))
        hidden = hidden.view((-1, self._BOTTLENECK_CHANNELS, 1, self._BOTTLENECK_WIDTH))
        return self.decoder(hidden)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z)
        return out, mu, logvar

# Display model architecture

In [None]:
# Throwaway model instance used only to print the layer-by-layer summary for a
# batch of 64 inputs shaped (1, 1, 512). The device is chosen at run time
# because an unconditional .cuda() raises on machines without a GPU.
summary_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Autoencoder().to(summary_device)
summary(model, (64, 1, 1, 512), col_names=["input_size", "output_size", "num_params"],)

# Dataset and DataLoader

In [None]:
# One in-memory dataset per split; each is read from its own .npy file.
valid_data = CustomDataset("../data/valid_data.npy")
train_data = CustomDataset("../data/train_data.npy")
# The held-out test split lives in a differently named file.
test_data = CustomDataset("../data/nrs_18patient_minmax_data_onechannel_test.npy")

In [None]:
# Training batches are reshuffled every epoch so the optimiser does not see the
# samples in the same on-disk order each time. The order stays reproducible
# because the global PyTorch seed is fixed at the top of the notebook.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
# Evaluation loaders keep the stored order. The test loader yields one sample
# per batch because the test cell plots and scores signals individually.
test_dataloader = DataLoader(test_data, batch_size=1)
valid_dataloader = DataLoader(valid_data, batch_size=64)

In [None]:
def vae_loss(recon_loss, mu, logvar, beta):
    """Combine a reconstruction loss with a beta-weighted KL divergence.

    The KL term is the closed form for a diagonal Gaussian with mean ``mu``
    and log-variance ``logvar`` against a standard normal, summed over every
    element of ``mu`` / ``logvar``.

    Args:
        recon_loss: Reconstruction loss already computed by the caller.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.
        beta: Weight applied to the KL term.

    Returns:
        ``recon_loss + beta * KL``.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(per_element)
    total = recon_loss + beta * kl_divergence
    return total

# Train

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
print(_cuda_available)
device = torch.device('cuda' if _cuda_available else 'cpu')
# Freshly initialised network that the training cells optimise.
net = Autoencoder().to(device)

In [None]:
# Upper bound on the number of passes over the training set.
epochs = 200
# Reconstruction error is the *summed* squared error over the whole batch,
# which matches the summed KL term in vae_loss.
criterion = nn.MSELoss(reduction='sum').to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.99))

In [None]:
def plot_curve(train_loss, test_loss=None):
    """Plot per-epoch loss curves and show the figure.

    Args:
        train_loss: Sequence of training losses, one value per epoch.
        test_loss: Optional sequence of evaluation losses, one value per
            epoch. When omitted, only the training curve is drawn.
    """
    plt.plot(train_loss, color='red', label='Train Loss')
    if test_loss is not None:
        plt.plot(test_loss, color='blue', label='Test Loss')
    plt.legend()
    plt.title('Loss')
    # The horizontal axis is the epoch index; the vertical axis is the loss.
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()

In [None]:
# Per-epoch histories. All values are means over the batches of one epoch.
train_ls, test_ls = [], []              # total loss on the training / validation set
reconsturction_loss, KL_loss = [], []   # training reconstruction term / beta-weighted KL term

# Early stopping: give up after `max_patience` consecutive epochs without a new
# best validation loss.
max_patience = 10
patience = max_patience
best_train_loss = float('inf')
best_test_loss = float('inf')

# KL warm-up: beta grows linearly from 0 during the first `warmup_epochs`
# epochs and is then held constant at warmup_epochs * beta_step.
fea_dim = 70
warmup_epochs = 100
beta_step = 0.00001/(fea_dim / 100)
# beta = 0.00001
beta = 0

for i in range(epochs):
    print('----------Epoch:{}----------'.format(i+1))

    if i < warmup_epochs:
        beta += beta_step

    # ---- training pass ----
    net.train()
    local_loss = 0
    local_recons = 0
    for batch in tqdm(train_dataloader):

        # Sending batch to device (GPU or CPU)
        x = batch.to(device=device, dtype=torch.float)

        # Erasing the gradients stored
        optimizer.zero_grad()

        # Sending batch to the Autoencoder and computing the loss
        y, mu, logVar = net(x)
        recons_loss = criterion(y, x)
        loss = vae_loss(recons_loss, mu, logVar, beta)

        # Backpropagating gradients and running the optimizer
        loss.backward()
        optimizer.step()

        local_loss += loss.item()
        local_recons += recons_loss.item()

    # Log epoch means as plain floats rather than the tensors of the last
    # batch, which are noisy and still attached to the autograd graph.
    n_train_batches = len(train_dataloader)
    train_ls.append(local_loss / n_train_batches)
    reconsturction_loss.append(local_recons / n_train_batches)
    KL_loss.append((local_loss - local_recons) / n_train_batches)
    best_train_loss = min(best_train_loss, train_ls[i])
    print("Epoch: " + str(i+1) + " | training loss: {}".format(train_ls[i]))
    print("reconstruction_loss: {}, weighted_KL_diver: {}, loss: {}".format(
        reconsturction_loss[i], KL_loss[i], train_ls[i]))

    # ---- validation pass (reported under the historical "test loss" label) ----
    net.eval()
    test_sum_loss = 0
    with torch.no_grad():
        for batch in tqdm(valid_dataloader):
            x = batch.to(device, dtype=torch.float)
            generate, mu, logVar = net(x)
            test_loss = vae_loss(criterion(generate, x), mu, logVar, beta)
            test_sum_loss += test_loss.item()
    test_ls.append(test_sum_loss / len(valid_dataloader))
    print("Epoch: " + str(i+1) + " | test loss: {}".format(test_ls[i]))

    # ---- checkpointing and early stopping ----
    if test_ls[i] < best_test_loss:
        # save best model
        best_test_loss = test_ls[i]
        patience = max_patience
        torch.save(net.state_dict(), "VAE.pt")
    elif i >= warmup_epochs:
        # While beta is still growing the objective itself changes every epoch,
        # so a rising loss during warm-up is not counted against patience.
        # NOTE(review): best_test_loss can still stem from a warm-up epoch with
        # a smaller beta; confirm that this is the checkpoint you want to keep.
        patience -= 1
        if patience == 0:
            print("early stopping")
            break

plot_curve(train_ls, test_ls)

# Test

In [None]:
def PRD(y_true, y_pred):
    """Percentage root-mean-square difference between a signal and its reconstruction.

    Computes ``100 * sqrt(sum((y_true - y_pred)^2) / sum(y_true^2))`` over all
    values of the FIRST sample in each batch; any further samples are ignored.
    Signals of any length are accepted.

    Args:
        y_true: Tensor holding the reference signal(s); index 0 is used.
        y_pred: Tensor holding the reconstruction(s); index 0 is used. Its
            first sample must contain as many values as that of ``y_true``.

    Returns:
        The PRD in percent as a NumPy scalar. If the reference signal is all
        zeros the division yields ``inf`` or ``nan`` (NumPy semantics).
    """
    reference = y_true.detach().cpu().numpy()[0].reshape(-1)
    estimate = y_pred.detach().cpu().numpy()[0].reshape(-1)
    error_energy = np.square(reference - estimate).sum()
    reference_energy = np.square(reference).sum()
    return np.sqrt(error_energy / reference_energy) * 100

In [None]:
fea_dim = 70
# Same value beta reaches at the end of the training warm-up
# (100 increments of 0.00001/(fea_dim / 100)).
beta = 0.001/(fea_dim / 100)

# Evaluation runs on the CPU with a freshly constructed network that receives
# the saved weights.
device = torch.device('cpu')
net = Autoencoder()
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
loaded = torch.load('../model/VAE_fea70_epoch200.pt', map_location=device)
net.load_state_dict(loaded)
net.eval()

testing_loss = []   # VAE loss of every test sample
PRD_list = []       # PRD (%) of every test sample
criterion = nn.MSELoss(reduction='sum')
n_preview = 20      # number of leading samples plotted before/after reconstruction

# No gradients are needed for evaluation; skipping the autograd graph saves
# memory and time.
with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        show_waveforms = i < n_preview
        if show_waveforms:
            img1 = batch.numpy().reshape(512, 1)
            wfdb.plot_items(signal=img1, fs=128, title='Test waveform')
        batch = batch.to(device=device, dtype=torch.float)
        output, mu, logVar = net(batch)

        temp = vae_loss(criterion(output, batch), mu, logVar, beta).item()
        testing_loss.append(temp)
        PRD_list.append(PRD(batch, output))

        if show_waveforms:
            out = output.detach().cpu().numpy()[0].reshape(512, 1)
            wfdb.plot_items(signal=out, fs=128, title='Generate waveform')

# Evaluate Performance

In [None]:
# Distribution of the per-sample test loss: box plot plus summary statistics.
plt.rcParams['figure.figsize'] = [12, 8]
mean_loss = np.mean(testing_loss)
max_loss = np.max(testing_loss)
percent_75 = np.percentile(testing_loss, 75)
# Keep the Axes itself in `ax`; chaining .set() onto the call would bind `ax`
# to the return value of .set() instead.
ax = sns.boxplot(x=testing_loss)
ax.set(title = 'box plot of test loss')
print("mean loss: {}, max_loss: {}, 75% percentile: {}".format(mean_loss, max_loss, percent_75))

In [None]:
def add_median_labels(ax, fmt='.1f'):
    """Write a numeric label onto selected lines of a box plot.

    The lines of ``ax`` are split into equally sized groups, one per
    ``PathPatch`` child of ``ax``. Starting at index 4, one line per group is
    labelled with the mean of its data along the axis it spans.
    NOTE(review): this presumes that index 4 of each group is the median line
    drawn by seaborn/matplotlib - confirm for the installed versions.

    Args:
        ax: Axes holding the box plot.
        fmt: Format specification applied to the label value.

    Raises:
        ZeroDivisionError: If ``ax`` has no ``PathPatch`` children.
    """
    all_lines = ax.get_lines()
    box_patches = [child for child in ax.get_children()
                   if type(child).__name__ == 'PathPatch']
    group_size = int(len(all_lines) / len(box_patches))

    for median_line in all_lines[4:len(all_lines):group_size]:
        x_data, y_data = median_line.get_data()
        x, y = x_data.mean(), y_data.mean()

        # choose value depending on horizontal or vertical plot orientation:
        # a line with identical x end points is vertical, so it marks an x value
        x_ends = median_line.get_xdata()
        if (x_ends[1] - x_ends[0]) == 0:
            value = x
        else:
            value = y

        label = ax.text(x, y, f'{value:{fmt}}', ha='center', va='center',
                        fontweight='bold', color='white')
        # create median-colored border around white text for contrast
        outline = path_effects.Stroke(linewidth=3, foreground=median_line.get_color())
        label.set_path_effects([outline, path_effects.Normal()])

In [None]:
# Wide, flat figure for a single horizontal box plot of the per-sample PRD.
plt.rcParams['figure.figsize'] = [15, 3]
single_PRD_box = sns.boxplot(x=PRD_list, linewidth=2.5)
single_PRD_box.set(title = 'Box Plot of PRD (Single Channel VAE)', xlabel='PRD (%)')
# Annotate the box with its numeric label.
add_median_labels(single_PRD_box)

In [None]:
# Summary of the per-sample PRD (%): mean, 25th percentile, 75th percentile, max, min.
prd_q25 = np.percentile(PRD_list, 25)
prd_q75 = np.percentile(PRD_list, 75)
print(np.mean(PRD_list), prd_q25, prd_q75, np.max(PRD_list), np.min(PRD_list))