In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import os
import pandas as pd
import numpy as np
from torchinfo import summary
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import wfdb
import seaborn as sns
import matplotlib.patheffects as path_effects
#from sklearn.model_selection import train_test_split

# Set seed

In [None]:
# Seed every random number generator the notebook draws from so that
# Restart Kernel -> Run All reproduces the same weights, shuffles and results.
seed = 100
np.random.seed(seed)  # NumPy global RNG
torch.manual_seed(seed)  # PyTorch RNG (CPU and all CUDA devices)
torch.cuda.manual_seed_all(seed)  # explicit for multi-GPU runs; silently ignored without CUDA
# Uncomment for bit-exact cuDNN results (at the cost of slower convolutions):
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

# Define custom dataset

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset backed by a single ``.npy`` file of multichannel samples.

    Despite its name, ``root_dir`` is the path of the ``.npy`` file itself, not a
    directory. The whole array is read into memory once, in ``__init__``.

    NOTE(review): ``__getitem__`` uses a plain C-order reshape to ``(2, 1, 512)``.
    That is the correct channel split only if each stored sample lists all of
    channel 0 before channel 1. The evaluation cells reshape samples back to
    ``(512, 2)`` for plotting, which hints at a channel-last layout on disk. If
    the data really is channel-last, a transpose (not a reshape) is needed.
    Confirm the file layout.
    """

    def __init__(self, root_dir):
        self.root_path = root_dir
        self.alldata = np.load(root_dir)

    def __len__(self):
        # One sample per entry along the first axis of the loaded array.
        return len(self.alldata)

    def __getitem__(self, index):
        """Return sample ``index`` reshaped to ``(2, 1, 512)``: the (C, H, W) layout the 2-D convolutions consume."""
        sample = self.alldata[index]
        return np.reshape(sample, (2, 1, 512))

# Define Multi-Channel Convolutional Autoencoder

In [None]:
# Model for 1x8 convolution kernels on 2-channel, 512-sample-wide inputs
class Autoencoder(nn.Module):
    """Fully convolutional autoencoder for 2-channel inputs of shape (N, 2, 1, 512).

    Every kernel is 1x8, so the convolutions only slide along the last (width)
    axis, and the height axis stays 1 throughout. For a (N, 2, 1, 512) input, the
    encoder produces a (N, 2, 1, 32) code. The decoder maps it back to
    (N, 2, 1, 512).

    NOTE: ``nn.Sequential`` names sub-modules by position ("encoder.0",
    "encoder.1", ...). The state_dict keys of saved checkpoints therefore depend
    on the exact layer order below. Reordering, inserting or removing layers
    makes existing checkpoints unloadable.
    """
    def __init__(self):
        super(Autoencoder, self).__init__()

        # Encoder. A 2-tuple ConstantPad2d pads only the width axis, as
        # (left, right). Four stride-2 stages (pad 3+3, then a 1x8 conv) halve
        # the width each time. The last stage (pad 4+3, stride 1) keeps width 32
        # and projects down to 2 channels. Every conv is followed by BatchNorm2d
        # and ELU.
        self.encoder = nn.Sequential(
            nn.ConstantPad2d((3, 3), 0),
            nn.Conv2d(2, 40, (1, 8), (1, 2)),    # width 512 -> 256
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.ConstantPad2d((3, 3), 0),
            nn.Conv2d(40, 20, (1, 8), (1, 2)),   # width 256 -> 128
            nn.BatchNorm2d(20),
            nn.ELU(),
            nn.ConstantPad2d((3, 3), 0),
            nn.Conv2d(20, 20, (1, 8), (1, 2)),   # width 128 -> 64
            nn.BatchNorm2d(20),
            nn.ELU(),
            nn.ConstantPad2d((3, 3), 0),
            nn.Conv2d(20, 40, (1, 8), (1, 2)),   # width 64 -> 32
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.ConstantPad2d((4, 3), 0),
            nn.Conv2d(40, 2, (1, 8), (1, 1)),    # width 32 -> 32, 40 -> 2 channels (bottleneck)
            nn.BatchNorm2d(2),
            nn.ELU(),
        )

        # Decoder. Each 4-tuple ConstantPad2d is (left, right, top, bottom) =
        # (1, 0, 4, 4). ConvTranspose2d's padding=4 also applies to the height
        # axis and would make the height of 1 negative. The 4+4 top/bottom
        # padding cancels it, so the height stays 1. With the extra left column,
        # the stride-1 transpose keeps width 32, and each stride-2 transpose
        # doubles the width.
        self.decoder = nn.Sequential(
            nn.ConstantPad2d((1, 0, 4, 4), 0),
            nn.ConvTranspose2d(2, 2, (1, 8), (1, 1), padding=4),    # width 32 -> 32
            nn.BatchNorm2d(2),
            nn.ELU(),
            nn.ConstantPad2d((1, 0, 4, 4), 0),
            nn.ConvTranspose2d(2, 40, (1, 8), (1, 2), padding=4),   # width 32 -> 64
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.ConstantPad2d((1, 0, 4, 4), 0),
            nn.ConvTranspose2d(40, 20, (1, 8), (1, 2), padding=4),  # width 64 -> 128
            nn.BatchNorm2d(20),
            nn.ELU(),
            nn.ConstantPad2d((1, 0, 4, 4), 0),
            nn.ConvTranspose2d(20, 20, (1, 8), (1, 2), padding=4),  # width 128 -> 256
            nn.BatchNorm2d(20),
            nn.ELU(),
            nn.ConstantPad2d((1, 0, 4, 4), 0),
            nn.ConvTranspose2d(20, 40, (1, 8), (1, 2), padding=4),  # width 256 -> 512
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.ConstantPad2d((4, 3), 0),
            # Final 1x8 conv: 40 -> 2 channels at width 512. No activation
            # follows it, so the reconstruction is unbounded.
            nn.Conv2d(40, 2, (1, 8), 1),
        )

    def forward(self, x):
        """Encode then decode ``x``. For a (N, 2, 1, 512) input, the output has the same shape."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Display model architecture

In [None]:
# Build a throwaway model instance only to print layer-by-layer shapes and
# parameter counts. Fall back to the CPU so this cell also runs on machines
# without a GPU (a bare .cuda() raises there).
summary_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Autoencoder().to(summary_device)
# Input size (batch=64, channels=2, height=1, width=512) matches the training batches.
summary(model, (64, 2, 1, 512), col_names=["input_size", "output_size", "num_params"], device=summary_device)

In [None]:
# All sample files live in one data directory; build each path from it
# instead of repeating (and mistyping) the prefix in every literal.
DATA_DIR = "../data"
train_data = CustomDataset(os.path.join(DATA_DIR, "train_data_multichannel.npy"))
# NOTE(review): the test file follows a different naming scheme
# ("nrs_18patient_minmax_...") from the "*_multichannel" train/valid files.
# Confirm it was produced with the same preprocessing and sample layout.
test_data = CustomDataset(os.path.join(DATA_DIR, "nrs_18patient_minmax_data_test.npy"))
valid_data = CustomDataset(os.path.join(DATA_DIR, "valid_data_multichannel.npy"))

# Dataset and DataLoader

In [None]:
# Reshuffle the training set every epoch so mini-batches are not always built
# from the same consecutive samples. The shuffle order comes from the (seeded)
# torch RNG, so runs stay reproducible. Validation/test order does not affect the
# averaged losses, so those loaders are not shuffled. The test loader uses
# batch_size=1 because the evaluation cell scores only index 0 of each batch.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=1)
valid_dataloader = DataLoader(valid_data, batch_size=64)

# Train

In [None]:
cuda_available = torch.cuda.is_available()
print(cuda_available)
# Train on the GPU when one is present, otherwise on the CPU.
device = torch.device('cuda' if cuda_available else 'cpu')
net = Autoencoder().to(device)

In [None]:
# Training hyperparameters.
epochs = 200
learning_rate = 1e-4
adam_betas = (0.9, 0.99)
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=adam_betas)
# Reconstruction loss: mean squared error between the network output and its input.
criterion = nn.MSELoss().to(device)

In [None]:
def plot_curve(train_loss, test_loss=None):
    """Plot per-epoch loss curves.

    Args:
        train_loss: sequence of training losses, one per epoch.
        test_loss: optional sequence of held-out losses, one per epoch (the
            training loop passes its validation losses here). When ``None``,
            only the training curve is drawn, so ``plot_curve(train_ls)`` works.
    """
    plt.plot(train_loss, color='red', label='Train Loss')
    if test_loss is not None:
        plt.plot(test_loss, color='blue', label='Test Loss')
    plt.legend()
    plt.title('Loss')
    # Epochs run along the x-axis; the loss value is on the y-axis.
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()

In [None]:
# Train with early stopping on the validation loss. The weights with the lowest
# validation loss are checkpointed, and training stops after `max_patience`
# consecutive epochs without improvement. The checkpoint is written to the same
# path the evaluation cell loads from.
best_model_path = "../model/FCVE_multi_epoch200.pt"
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

train_ls, test_ls = [], []  # per-epoch mean training / validation loss
max_patience = 10
patience = max_patience
best_test_loss = float('inf')
for i in range(epochs):
    print('----------Epoch:{}----------'.format(i+1))

    # --- training pass ---
    net.train()
    local_loss = 0
    for batch in tqdm(train_dataloader):
        x = batch.to(device=device, dtype=torch.float)
        optimizer.zero_grad()
        # Autoencoder objective: reconstruct the input itself.
        y = net(x)
        loss = criterion(y, x)
        loss.backward()
        optimizer.step()
        local_loss += loss.item()
    # Mean of the per-batch losses (every batch weighted equally).
    train_ls.append(local_loss / len(train_dataloader))
    print("Epoch: " + str(i+1) + " | training loss: {}".format(train_ls[-1]))

    # --- validation pass ---
    net.eval()
    test_sum_loss = 0
    with torch.no_grad():
        for batch in tqdm(valid_dataloader):
            x = batch.to(device, dtype=torch.float)
            generate = net(x)
            test_sum_loss += criterion(generate, x).item()
    test_ls.append(test_sum_loss / len(valid_dataloader))
    print("Epoch: " + str(i+1) + " | test loss: {}".format(test_ls[-1]))

    # --- checkpointing and early stopping (both driven by validation loss) ---
    if test_ls[-1] < best_test_loss:
        best_test_loss = test_ls[-1]
        patience = max_patience
        torch.save(net.state_dict(), best_model_path)
    else:
        patience -= 1
        if patience == 0:
            print("early stopping")
            break

plot_curve(train_ls, test_ls)
print("Done!")

# Test

In [None]:
def PRD(y_true, y_pred):
    """Percent root-mean-square difference (PRD) of the first sample in a batch.

    PRD = 100 * sqrt( sum((x - x_hat)^2) / sum(x^2) ), where the sums run over
    every element of sample 0.

    Args:
        y_true: reference batch tensor; only ``y_true[0]`` is used.
        y_pred: reconstructed batch tensor with the same per-sample shape;
            only ``y_pred[0]`` is used.

    Returns:
        The PRD in percent, as a NumPy float. If the reference sample is all
        zeros, the result is ``inf`` (or ``nan``) and NumPy emits a warning.

    Tensors are detached and moved to the CPU first, so GPU outputs are
    accepted. Because the sums cover all elements, no particular signal length
    or channel count is assumed.
    """
    reference = y_true.detach().cpu().numpy()[0]
    reconstruction = y_pred.detach().cpu().numpy()[0]
    error_energy = np.square(reference - reconstruction).sum()
    signal_energy = np.square(reference).sum()
    return np.sqrt(error_energy / signal_energy) * 100

In [None]:
# Evaluate the saved checkpoint on the test set, one sample per batch.
# The network stays on the CPU so that outputs can be turned into NumPy arrays
# for PRD() and the waveform plots; the weights are therefore loaded to the CPU.
# NOTE(review): confirm this checkpoint path matches where the training cell
# saves the best model.
n_plotted = 20  # number of test samples whose input/reconstruction waveforms are plotted
net = Autoencoder()
loaded = torch.load('../model/FCVE_multi_epoch200.pt', map_location='cpu')
net.load_state_dict(loaded)
net.eval()
mse = nn.MSELoss()
testing_loss = []  # per-sample MSE
PRD_list = []  # per-sample PRD (%)
# Inference only: no_grad avoids building an autograd graph for every sample.
with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        # Same dtype as in training, used as both the model input and the loss target.
        x = batch.to(dtype=torch.float)
        if i < n_plotted:
            # NOTE(review): a plain reshape of a (2, 1, 512) sample to (512, 2)
            # interleaves the two channels unless the data is stored channel-last;
            # `.reshape(2, 512).T` is the channel-major equivalent. Confirm the
            # layout. fs=128 is presumably the recording sampling rate in Hz.
            img1 = batch[0].numpy().reshape((512, 2))
            wfdb.plot_items(signal=img1, fs=128, title='Test waveform')

        output = net(x)
        testing_loss.append(mse(output, x).item())
        PRD_list.append(PRD(batch, output))

        if i < n_plotted:
            out = output[0].numpy().reshape((512, 2))
            wfdb.plot_items(signal=out, fs=128, title='Generate waveform')
    

In [None]:
plt.rcParams['figure.figsize'] = [12, 8]
mean_loss = np.mean(testing_loss)
max_loss = np.max(testing_loss)
percent_75 = np.percentile(testing_loss, 75)
# Keep the Axes returned by seaborn. Chaining `.set()` would bind `ax` to
# set()'s return value (a list), not to the Axes.
ax = sns.boxplot(x=testing_loss)
ax.set(title='box plot of test loss', xlabel='MSE')
print("mean loss: {}, max_loss: {}, 75% percentile: {}".format(mean_loss, max_loss, percent_75))

# Evaluate Performance

In [None]:
def add_median_labels(ax, fmt='.1f'):
    """Write each box's median value on top of its median line in a box plot.

    Args:
        ax: matplotlib Axes holding a box plot (e.g. drawn by ``sns.boxplot``).
        fmt: format spec applied to the median value (default: one decimal).

    Assumes every box contributes the same number of Line2D objects, and that
    the median is the 5th line (index 4) of each group: two whiskers, two caps,
    then the median. NOTE(review): this ordering comes from matplotlib's
    boxplot artist creation order; confirm it still holds for the installed
    seaborn/matplotlib versions.

    Does nothing if ``ax`` has no box patches, or has fewer lines than boxes.
    """
    lines = ax.get_lines()
    boxes = [c for c in ax.get_children() if type(c).__name__ == 'PathPatch']
    if not boxes:
        # Nothing to label; this also avoids a division by zero below.
        return
    lines_per_box = len(lines) // len(boxes)
    if lines_per_box == 0:
        # A zero slice step would raise; there are no median lines to find.
        return
    for median in lines[4::lines_per_box]:
        x, y = (data.mean() for data in median.get_data())
        # A vertical median segment (constant x) means a horizontal box plot,
        # so the median value is x; otherwise it is y.
        value = x if (median.get_xdata()[1] - median.get_xdata()[0]) == 0 else y
        text = ax.text(x, y, f'{value:{fmt}}', ha='center', va='center',
                       fontweight='bold', color='white')
        # Median-colored outline around the white text, for contrast.
        text.set_path_effects([
            path_effects.Stroke(linewidth=3, foreground=median.get_color()),
            path_effects.Normal(),
        ])

In [None]:
# Horizontal box plot of the per-sample PRD values, annotated with the median.
plt.rcParams['figure.figsize'] = [15, 3]
single_PRD_box = sns.boxplot(x=PRD_list, linewidth=2.5)
single_PRD_box.set(title='Box Plot of PRD (Multiple Channels FCAE)', xlabel='PRD (%)')
add_median_labels(single_PRD_box)

In [None]:
# Per-sample PRD (%) summary, printed in this order:
# mean, 25th percentile, 75th percentile, max, min.
prd_summary = (
    np.mean(PRD_list),
    np.percentile(PRD_list, 25),
    np.percentile(PRD_list, 75),
    np.max(PRD_list),
    np.min(PRD_list),
)
print(*prd_summary)