In [1]:
import numpy as np
import pandas as pd
import torch

In [377]:
class Encoder(torch.nn.Module):
    """Convolutional encoder that maps a single-channel 2-D input to a flat vector.

    The first convolution uses a (12, 3) kernel with no vertical padding, so a
    height-12 input is collapsed to height 1. Every following convolution uses
    a (1, 3) kernel with horizontal padding 1, which keeps the width unchanged.
    Channel counts go 1 -> 128 -> 64 -> 32 -> 16 -> 6, with a ReLU between
    consecutive convolutions (none after the last one).

    For an input of shape (B, 1, 12, L) the output has shape (B, 6 * L).
    """

    def __init__(self):
        super().__init__()

        # Channel count before/after each convolution.
        widths = [1, 128, 64, 32, 16, 6]

        stack = []
        for layer_no, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if layer_no > 0:
                # Non-linearity sits between convolutions only.
                stack.append(torch.nn.ReLU())
            # Only the first convolution spans the full input height of 12.
            kernel_height = 12 if layer_no == 0 else 1
            stack.append(
                torch.nn.Conv2d(c_in, c_out, kernel_size=(kernel_height, 3), padding=(0, 1))
            )

        self.cnn = torch.nn.Sequential(*stack)
        self.flatten = torch.nn.Flatten()

    def forward(self, X):
        """Run the convolution stack and flatten everything but the batch axis."""
        return self.flatten(self.cnn(X))
        
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder that expands a flat vector back to 2-D.

    The flat input is first reshaped to (B, 6, 1, 5000). Four (1, 3)
    transposed convolutions with horizontal padding 1 then widen the channels
    6 -> 16 -> 32 -> 64 -> 128 while keeping the spatial size, each followed by
    a ReLU. The final (12, 3) transposed convolution reduces to one channel and
    grows the height from 1 to 12, giving an output of shape (B, 1, 12, 5000).
    """

    def __init__(self):
        super().__init__()

        # Channel count before/after each transposed convolution.
        widths = [6, 16, 32, 64, 128, 1]
        final_layer = len(widths) - 2

        stack = []
        for layer_no, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            is_final = layer_no == final_layer
            # Only the final layer expands the height (1 -> 12).
            kernel_height = 12 if is_final else 1
            stack.append(
                torch.nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=(kernel_height, 3),
                    padding=(0, 1),
                    output_padding=(0, 0),
                )
            )
            if not is_final:
                stack.append(torch.nn.ReLU())

        self.cnn = torch.nn.Sequential(*stack)

        # Inverse of the encoder's flatten: 30000 values -> (6, 1, 5000).
        self.unflatten = torch.nn.Unflatten(1, (6, 1, 5000))

    def forward(self, X):
        """Reshape the flat input and run it through the transposed convolutions."""
        return self.cnn(self.unflatten(X))
    
class Classifier(torch.nn.Module):
    """Multi-label classifier: Encoder features -> linear layer -> sigmoid.

    The linear head expects 30,000 input features and produces 56 outputs, each
    squashed independently into (0, 1) by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.linear = torch.nn.Linear(30000, 56)

    def forward(self, X):
        """Return per-output sigmoid activations for the batch X."""
        features = self.enc(X)
        scores = self.linear(features)
        return torch.sigmoid(scores)
    
class AutoEncoder(torch.nn.Module):
    """Plain autoencoder: an Encoder followed directly by a Decoder."""

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.dec = Decoder()

    def forward(self, X):
        """Encode X to a flat code and decode it back.

        The submodules are invoked by calling them rather than through
        `.forward()`, so any hooks registered on them are executed.
        """
        code = self.enc(X)
        return self.dec(code)
    
class VariationalAutoEncoder(torch.nn.Module):
    """Variational autoencoder built around the convolutional Encoder/Decoder.

    The encoder output (30,000 features) is projected by two linear layers to
    the mean and log-variance of a diagonal Gaussian q(z|x) of size
    `hidden_dim`. A latent sample is projected back to 30,000 features and
    passed to the decoder.

    Args:
        hidden_dim: Dimension of the latent variable z.
    """

    def __init__(self, hidden_dim=10000):
        super().__init__()

        # Dimension of latent variable.
        self.hidden_dim = hidden_dim

        # Encoder: takes a sample and reduces down to a 30,000-dim vector.
        self.enc = Encoder()

        # Layers for predicting parameters of q(z|x).
        self.latent_mu = torch.nn.Linear(30000, hidden_dim)
        self.latent_log_var = torch.nn.Linear(30000, hidden_dim)

        # Linear layer to upscale input to decoder.
        self.dec_lin = torch.nn.Linear(hidden_dim, 30000)

        # Decoder: takes a 30,000-dim vector and upsamples to (12, 5000).
        self.dec = Decoder()

    def sample_gaussian(self, mu, log_var):
        """Draw z = mu + eps * std with eps ~ N(0, I) (reparameterization).

        Args:
            mu: Means, shape (batch, latent_dim).
            log_var: Log-variances, same shape as `mu`.

        Returns:
            A sample with the same shape, dtype and device as `mu`.
        """
        # randn_like keeps the noise on mu's device and in mu's dtype, so the
        # model keeps working after .to(device) / .double() / .half().
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(log_var / 2.0)  # multiply by std. dev.

    def forward(self, X):
        """Encode X, sample a latent, and decode it.

        Returns:
            Tuple (dec_out, z_mu, z_log_var): the reconstruction plus the
            predicted mean and log-variance of q(z|x).
        """
        # Pass input through the Encoder (module call, so hooks run).
        enc_out = self.enc(X)

        # Predict parameters of q(z|x).
        z_mu = self.latent_mu(enc_out)
        z_log_var = self.latent_log_var(enc_out)

        # Reparameterization "trick": sample from the Gaussian with mean z_mu
        # and diagonal covariance exp(z_log_var).
        z = self.sample_gaussian(z_mu, z_log_var)

        # Pass latent sample through decoder.
        dec_out = self.dec(self.dec_lin(z))

        return dec_out, z_mu, z_log_var

    def sample(self, n):
        """Generate `n` outputs by decoding latents drawn from N(0, I).

        Runs under `torch.no_grad()`, so the result carries no gradient.
        """
        with torch.no_grad():
            # Create the latent where the decoder's weights live.
            ref = self.dec_lin.weight
            z = torch.randn((n, self.hidden_dim), device=ref.device, dtype=ref.dtype)
            return self.dec(self.dec_lin(z))
    

In [425]:
from pathlib import Path
from tqdm import tqdm

def load_data(n = -1, data_dir = "data", return_fnames = False):
    """Load ECG recordings and their multi-hot condition labels.

    Reads `Diagnostics.xlsx` and `ConditionNames.xlsx` from `data_dir`, then
    loads the `*.csv` files found in `data_dir/ECGData` (in sorted order, so
    the example order is reproducible). Each CSV is expected to parse to a
    (5000, 12) array.

    Args:
        n: Maximum number of examples to load. -1 (or any negative value, or
            None) means "as many as there are rows in Diagnostics.xlsx".
        data_dir: Directory holding the spreadsheets and the `ECGData` folder,
            resolved relative to the current working directory.
        return_fnames: When True, also return the list of file names (without
            extension) in the same order as the examples.

    Returns:
        (X, Y), or (X, Y, fnames) when `return_fnames` is True.
        X is a float32 tensor of shape (5000, num_examples, 12).
        Y is a float32 tensor of shape (num_examples, n_cond) with a 1 for
        every condition listed in the example's "Beat" column.
        `num_examples` is the number of files actually loaded.

    Raises:
        ValueError: If a CSV file has no matching "FileName" row in
            Diagnostics.xlsx.
        KeyError: If a beat label is not listed in ConditionNames.xlsx.
    """
    data_root = Path.cwd() / data_dir

    # Load Diagnostics.xlsx
    diagnostics = pd.read_excel(data_root / "Diagnostics.xlsx")

    # Load Conditions.xlsx and map each acronym to its column in Y.
    conditions = pd.read_excel(data_root / "ConditionNames.xlsx")
    n_cond = conditions["Acronym Name"].size
    cond_map = { name : k for k, name in enumerate(conditions["Acronym Name"]) }

    # One dictionary lookup per file instead of scanning the frame each time.
    beats_by_file = dict(zip(diagnostics["FileName"], diagnostics["Beat"]))

    # Decide which files to load. Sorting makes the order deterministic, and
    # sizing the arrays from the files actually found avoids all-zero rows.
    limit = diagnostics.shape[0] if (n is None or n < 0) else n
    files = sorted((data_root / "ECGData").glob("*.csv"))[:limit]
    num_examples = len(files)

    # Allocate directly as float32 (the tensors are returned as float32).
    X = np.zeros((5000, num_examples, 12), dtype=np.float32)
    Y = np.zeros((num_examples, n_cond), dtype=np.float32)
    fnames = []

    print("Loading", num_examples, "Examples.")
    for i, file in enumerate(tqdm(files, total = num_examples)):

        # get name of file.
        fname = str(file.name).split('.')[0]
        if fname not in beats_by_file:
            raise ValueError("No row in Diagnostics.xlsx for file {0!r}".format(fname))

        # make multi-hot vector for condition names.
        for b in beats_by_file[fname].split(' '):
            if b != "NONE":
                Y[i, cond_map[b]] = 1

        X[:, i, :] = pd.read_csv(file).to_numpy()
        fnames.append(fname)
    print("Done!")

    Y = torch.from_numpy(Y)
    X = torch.from_numpy(X)

    if return_fnames:
        return X, Y, fnames
    return X, Y

import matplotlib.pyplot as plt

def plot_ecg(y, n_leads = 12, sample_rate = 500):
    """Plot the first `n_leads` rows of `y` as vertically stacked traces.

    Args:
        y: 2-D tensor whose rows are plotted one per subplot; each row is
            converted with `.numpy()`. Must have at least `n_leads` rows.
        n_leads: Number of rows (subplots) to draw.
        sample_rate: Samples per second, used to build the time axis. With the
            default of 500 a 5000-sample trace spans 0 to 10 seconds.
    """
    n_samples = y.shape[-1]
    duration = n_samples / sample_rate
    x = np.linspace(0, duration, n_samples)

    # squeeze=False always yields a 2-D array of axes, so n_leads == 1 works;
    # sharex=True applies the x-limits to every panel, not just the last one.
    fig, axs = plt.subplots(n_leads, figsize=(8, 8), sharex=True, squeeze=False)
    axs = axs[:, 0]

    for i in range(n_leads):
        axs[i].plot(x, y[i,:].numpy(), color='black')

    # Label the bottom panel only; the x-axis is shared by all panels.
    axs[-1].set_xlabel("Seconds")
    axs[-1].set_ylabel("Microvolts")
    axs[-1].set_xlim([0, duration])

    plt.show()
    

In [426]:
### Loss for VAE
def vae_loss(X, dec_out, z_mu, z_log_var, alpha=1, verbose=False):
    """VAE objective: weighted reconstruction error plus KL divergence.

    Args:
        X: Original input batch.
        dec_out: Reconstruction of `X`, same shape as `X`.
        z_mu: Predicted latent means, shape (batch, latent_dim).
        z_log_var: Predicted latent log-variances, same shape as `z_mu`.
        alpha: Weight applied to the reconstruction term.
        verbose: When True, print both terms (useful while debugging).

    Returns:
        Scalar tensor `alpha * recl + dkl`, where `recl` is the mean squared
        error over all elements and `dkl` is the KL divergence between
        N(z_mu, exp(z_log_var)) and N(0, I), summed over the latent dimension
        and averaged over the batch.
    """
    # Reconstruction loss. The functional form is used so this function does
    # not rely on a loss object created in some other notebook cell.
    recl = torch.nn.functional.mse_loss(dec_out, X)

    # KL Divergence (closed form for a diagonal Gaussian vs. a standard normal).
    dkl = -0.5 * torch.sum(1 + z_log_var - torch.pow(z_mu, 2) - torch.exp(z_log_var),
                          dim=1)
    dkl = dkl.mean()

    if verbose:
        print("recl:", recl.item(), "dkl:", dkl.item())

    return alpha * recl + dkl

In [427]:
def reshape_for_cnn(X):
    """
    Rearranges a (X_LEN, B_SZ, X_DIM) tensor (e.g. (5000, 10, 12)) into the
    (B_SZ, 1, X_DIM, X_LEN) layout expected by the 2-D convolutions, i.e.
    batch first with a single channel axis inserted.
    """
    seq_len, batch, n_feat = X.shape
    batch_first = X.permute(1, 2, 0)  # -> (B_SZ, X_DIM, X_LEN)
    return batch_first.reshape(batch, 1, n_feat, seq_len)

In [None]:
# Data dimensions and training hyper-parameters.
X_DIM = 12
X_LEN = 5000
B_SZ = 50
N_EPOCHS = 1
SEED = 0

# Seed both RNGs so shuffling, weight init and latent sampling are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Only the inputs and labels are needed here; slicing keeps this cell working
# whether load_data returns two values or more.
X, Y = load_data()[:2]
N_EX = X.shape[1]

# Shrink the data in accordance with above.
X = X[0:X_LEN,0:N_EX,0:X_DIM]

# (X_LEN, N_EX, X_DIM) -> (N_EX, 1, X_DIM, X_LEN); reshape_for_cnn is defined
# in its own cell earlier in the notebook, so it is not redefined here.
X = reshape_for_cnn(X)

# Models, loss modules and optimizer. The loss modules live at notebook level
# so that loss helpers defined in other cells can use them.
vae = VariationalAutoEncoder()
cl = Classifier()
bce = torch.nn.BCELoss()
mse = torch.nn.MSELoss()
optimizer = torch.optim.Adam(vae.parameters(), lr = 1e-4, weight_decay = 1e-8)

### Train
# The optimizer holds the VAE's parameters, so the VAE objective is trained.
# Training the classifier instead would need an optimizer built from
# cl.parameters() and the bce loss on Y.
print("Starting training...")
for e_i in range(N_EPOCHS):

    # shuffle the indices
    idx = np.arange(N_EX)
    np.random.shuffle(idx)

    # iterate over batches (the last batch may be smaller than B_SZ)
    epoch_loss = 0.0
    num_batches = 0
    for start in range(0, N_EX, B_SZ):

        # Make batch
        batch_idx = idx[start : start + B_SZ]
        X_batch = X[batch_idx,:,:,:]

        # Zero gradients
        optimizer.zero_grad()

        # Forward Pass
        dec_out, z_mu, z_log_var = vae(X_batch)

        # Calculate loss
        loss = vae_loss(X_batch, dec_out, z_mu, z_log_var)

        # Update parameters
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        print(batch_loss)

    # Report the mean batch loss; guard against an empty data set.
    mean_loss = epoch_loss / max(num_batches, 1)
    print("Epoch #{0} Loss:{1}".format(e_i + 1, mean_loss))

Loading 10646 Examples.


 35%|█████████████                        | 3762/10645 [00:17<00:30, 222.59it/s]