In [None]:
import os

import numpy as np
import scipy
from scipy.stats import norm
import pandas as pd
import torch
import matplotlib.pyplot as plt

from data import load_data

In [None]:
def signal_to_impulse_train(x, power=4, window_len=11):
    """
    Converts a 1-D signal into a smooth, non-negative "impulse train" that is
    large where the signal has large excursions and near zero elsewhere.
    ----
    x: 1-D numpy array.
    power: exponent applied to the scaled signal. Use an even value so the
        result is non-negative (the log at the end requires 1 + x > 0).
    window_len: length of the Hamming window used for smoothing.

    Returns a numpy array. Because np.convolve(..., 'same') returns
    max(len(x), window_len) samples, the output only has len(x) samples when
    len(x) >= window_len. A flat (zero-variance) signal yields all zeros
    instead of the NaN/inf a division by zero would produce.
    """
    x = np.asarray(x)

    # Scale by twice the standard deviation (no mean subtraction).
    scale = 2 * np.std(x)
    if scale > 0:
        x = x / scale
    else:
        # No excursions to detect; keep the shape and produce a zero train.
        x = np.zeros_like(x)

    # Raise to a power so large excursions dominate small ones.
    x = np.power(x, power)

    # Smooth the signal.
    x = np.convolve(x, np.hamming(window_len), 'same')

    # Compress the dynamic range so tall peaks do not dwarf smaller ones.
    return np.log1p(x)

def make_impulse_trains(x):
    """
    Generates an impulse train for each example and channel.
    ----
    x: input of shape (N, H, W). A torch tensor on any device, possibly
       requiring grad, or anything torch.as_tensor accepts. It is detached
       and copied to the CPU before the numpy processing.

    Returns impulse_trains: a float32 CPU tensor of shape (N, H, W) where
    impulse_trains[i, h] = signal_to_impulse_train(x[i, h]).
    """
    x = torch.as_tensor(x)
    N, H, W = x.shape

    # .numpy() fails on tensors that require grad or live off the CPU, so
    # convert once up front instead of per row.
    x_np = x.detach().cpu().numpy()

    impulse_trains = torch.zeros((N, H, W))
    for i in range(N):  # Loop over batch elements.
        for h in range(H):  # Loop over ECG channels (height).
            im_tr = signal_to_impulse_train(x_np[i, h])
            impulse_trains[i, h] = torch.from_numpy(im_tr).float()

    return impulse_trains

class EncoderResidualBlock(torch.nn.Module):
    """
    Pre-activation residual block for the encoder.

    Applies (ReLU -> Conv2d) twice with (1, kernel_sz) kernels and 'same'
    padding, so only the width axis is mixed and the shape is preserved.
    It then optionally downsamples the width with a stride-2 1x1 convolution,
    adds the (downsampled, channel-zero-padded) input as a skip connection,
    and applies dropout (p=0.1).

    Input and output are (N, C, H, W). With downsample=True the output width
    is ceil(W / 2); otherwise W is unchanged. in_ch must not exceed out_ch,
    otherwise the skip addition has mismatched channel counts.
    """

    def __init__(self, in_ch, out_ch, kernel_sz, downsample=True):
        super().__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.downsample = downsample

        self.conv_1 = torch.nn.Sequential(
            #torch.nn.BatchNorm2d(in_ch),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=(1, kernel_sz), padding='same', stride=(1,1)),
        )

        self.conv_2 = torch.nn.Sequential(
            #torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=(1, kernel_sz), padding='same', stride=(1,1)),
        )

        # 1x1 Convolution with a horizontal stride of 2.
        # Two are needed: when in_ch < out_ch the main path (out_ch channels)
        # and the identity path (in_ch channels) cannot share one layer.
        # Both are always constructed so the state_dict keys are the same
        # whether or not downsampling is enabled.
        self.downsampler_id = torch.nn.Conv2d(in_ch, in_ch, kernel_size = (1, 1), padding=(0, 0), stride=(1,2))
        self.downsampler_out = torch.nn.Conv2d(out_ch, out_ch, kernel_size = (1, 1), padding=(0, 0), stride=(1,2))

        self.dropout = torch.nn.Dropout(p=0.1)

    def forward(self, X):
        """X: (N, in_ch, H, W) -> (N, out_ch, H, W or ceil(W / 2))."""

        # Keep the input for the skip connection.
        Id = X

        # Two pre-activation convolutions.
        X = self.conv_1(X)
        X = self.conv_2(X)

        # Downsample both paths so their widths match.
        if self.downsample:
            X = self.downsampler_out(X)
            Id = self.downsampler_id(Id)

        # Skip connection. If the channel count grows, zero-pad the identity's
        # channel dim so it is added to the first in_ch output channels.
        # TODO: add twice if out_ch = 2 * in_ch.
        if self.in_ch < self.out_ch:
            n_pad = self.out_ch - self.in_ch
            # F.pad pairs run from the last dim backwards: (W, H, C, N).
            Id = torch.nn.functional.pad(Id, (0, 0, 0, 0, 0, n_pad, 0, 0))
        X = X + Id

        # Apply Dropout.
        X = self.dropout(X)

        return X
            
class DecoderResidualBlock(torch.nn.Module):
    """
    Decoder stage: (ReLU -> ConvTranspose2d) twice with (1, kernel_sz)
    kernels, then an optional stride-2 transposed 1x1 convolution.

    With an odd kernel_sz the two transposed convolutions keep the width W;
    with upsample=True the final layer maps W -> 2W. Input and output are
    (N, C, H, W).

    Despite the name, forward has no skip connection. The dropout module is
    constructed (so it is part of the module tree) but is not applied.
    """

    def __init__(self, in_ch, out_ch, kernel_sz, upsample=True):
        super().__init__()

        self.in_ch, self.out_ch, self.upsample = in_ch, out_ch, upsample

        half_width = int(kernel_sz / 2.0)
        kernel = (1, kernel_sz)
        pad = (0, half_width)

        # Layer creation order is kept fixed: it determines parameter
        # ordering (optimizer state) and the RNG draws used for init.
        self.conv_1 = torch.nn.Sequential(
            #torch.nn.BatchNorm2d(in_ch),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel, padding=pad, stride=(1, 1)),
        )
        self.conv_2 = torch.nn.Sequential(
            #torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(out_ch, out_ch, kernel_size=kernel, padding=pad, stride=(1, 1)),
        )

        # Kernel-1, stride-2 transposed conv; output_padding supplies the last
        # column so the width becomes exactly 2W.
        self.upsampler = torch.nn.ConvTranspose2d(
            out_ch, out_ch,
            kernel_size=(1, 1), padding=(0, 0), stride=(1, 2), output_padding=(0, 1),
        )

        self.dropout = torch.nn.Dropout(p=0.1)

    def forward(self, X):
        """X: (N, in_ch, H, W) -> (N, out_ch, H, W or 2W)."""
        features = self.conv_2(self.conv_1(X))
        return self.upsampler(features) if self.upsample else features
    
class CNNVariationalAutoencoder(torch.nn.Module):
    """
    Convolutional VAE that reconstructs the signal directly.

    Geometry (from the layer sizes): five stride-2 encoder blocks reduce an
    input of shape (N, 1, 2048) to 80 channels x 1 x 64, i.e. 5120 flattened
    features. The decoder maps a latent vector back through five upsampling
    blocks to an output of shape (N, 1, 1, 2048). Note the extra channel dim
    compared to the (N, H, W) input.
    """

    def __init__(self):
        super().__init__()

        # Size of the latent vector z (plain int, not a parameter).
        self.latent_dim = 60

        # Encoder.
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, 1),
            EncoderResidualBlock(16, 16, 19, downsample=False),
            EncoderResidualBlock(16, 16, 19, downsample=True),
            EncoderResidualBlock(16, 32, 19, downsample=True),
            EncoderResidualBlock(32, 48, 19, downsample=True),
            EncoderResidualBlock(48, 64, 19, downsample=True),
            EncoderResidualBlock(64, 64, 19, downsample=True),
            EncoderResidualBlock(64, 80, 9, downsample=False),
            EncoderResidualBlock(80, 80, 9, downsample=False),
            EncoderResidualBlock(80, 80, 9, downsample=False),
            torch.nn.Flatten()
        )

        # Decoder.
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, 5120), # Upscale latent vector.
            torch.nn.Unflatten(1, (80, 1, 64)),
            DecoderResidualBlock(80, 80, 9, upsample=False),
            DecoderResidualBlock(80, 80, 9, upsample=False),
            DecoderResidualBlock(80, 64, 9, upsample=False),
            DecoderResidualBlock(64, 64, 19, upsample=True),
            DecoderResidualBlock(64, 48, 19, upsample=True),
            DecoderResidualBlock(48, 32, 19, upsample=True),
            DecoderResidualBlock(32, 16, 19, upsample=True),
            DecoderResidualBlock(16, 16, 19, upsample=True),
            DecoderResidualBlock(16, 16, 19, upsample=False),
            torch.nn.ConvTranspose2d(16, 1, 1)
        )

        # Layers for predicting parameters of q(z|x).
        self.latent_mu = torch.nn.Linear(5120, self.latent_dim)
        self.latent_log_var = torch.nn.Linear(5120, self.latent_dim)

    def sample_gaussian(self, mu, log_var):
        """
        Reparameterized draw z = mu + eps * exp(log_var / 2), eps ~ N(0, I).

        eps matches mu's shape, dtype and device, so this works after the
        model is moved off the CPU and for any latent size.
        """
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(log_var / 2.0) # multiply by std. dev.

    def forward(self, x):
        """
        x: input of shape (N, H, W).

        Returns (dec_out, z_mu, z_log_var). dec_out has an extra channel dim,
        (N, 1, H', W'), so squeeze dim 1 before comparing it with x.
        """

        # Reshape x before passing thru CNN.
        x = x.unsqueeze(1) # (N, H, W) => (N, 1, H, W)

        # Get encoder output.
        enc_out = self.encoder(x)

        # Predict parameters of q(z|x).
        z_mu = self.latent_mu(enc_out)
        z_log_var = self.latent_log_var(enc_out)

        # Reparameterization "trick": z ~ N(z_mu, diag(exp(z_log_var))).
        z = self.sample_gaussian(z_mu, z_log_var)

        # Pass latent sample through decoder.
        dec_out = self.decoder.forward(z)

        return dec_out, z_mu, z_log_var

    def sample(self, n):
        """Decodes n latent vectors drawn from N(0, I), without tracking gradients."""
        with torch.no_grad():
            # Draw z on the same device as the decoder weights.
            device = next(self.decoder.parameters()).device
            z = torch.randn((n, self.latent_dim), device=device)
            dec_out = self.decoder.forward(z)
            return dec_out
    
class ModifiedVariationalAutoencoder(torch.nn.Module):
    """
    VAE whose decoder produces a per-example convolution filter rather than a
    signal. The reconstruction is that filter convolved with a precomputed
    impulse train (see apply_filters).

    Geometry (from the layer sizes): six stride-2 encoder blocks reduce an
    input of shape (N, 1, 2048) to 80 channels x 1 x 32 = 2560 flattened
    features. The decoder maps a latent_dim vector to a filter of shape
    (N, 1, 1, 632), i.e. 79 widened by three stride-2 upsampling blocks.
    """

    def __init__(self):
        super().__init__()

        # Device on which noise is sampled. nn.Module.to() does not update this
        # plain attribute, so callers set it after moving the model.
        self.device = "cpu"

        self.latent_dim = 60

        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, 1),
            EncoderResidualBlock(16, 16, 19, downsample=False),
            EncoderResidualBlock(16, 16, 19, downsample=True),
            EncoderResidualBlock(16, 32, 19, downsample=True),
            EncoderResidualBlock(32, 48, 19, downsample=True),
            EncoderResidualBlock(48, 64, 19, downsample=True),
            EncoderResidualBlock(64, 64, 19, downsample=True),
            EncoderResidualBlock(64, 64, 19, downsample=True),
            EncoderResidualBlock(64, 80, 9, downsample=False),
            EncoderResidualBlock(80, 80, 9, downsample=False),
            EncoderResidualBlock(80, 80, 9, downsample=False),
            torch.nn.Flatten(),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, 6320), # Upscale latent vector.
            torch.nn.Unflatten(1, (80, 1, 79)),
            DecoderResidualBlock(80, 80, 9, upsample=False),
            DecoderResidualBlock(80, 80, 9, upsample=False),
            DecoderResidualBlock(80, 64, 9, upsample=False),
            DecoderResidualBlock(64, 64, 19, upsample=False),
            DecoderResidualBlock(64, 64, 19, upsample=False),
            DecoderResidualBlock(64, 48, 19, upsample=True),
            DecoderResidualBlock(48, 32, 19, upsample=True),
            DecoderResidualBlock(32, 16, 19, upsample=True),
            DecoderResidualBlock(16, 16, 19, upsample=False),
            torch.nn.ConvTranspose2d(16, 1, 1), # -> (N, 1, 1, 632)
        )

        # Layers for predicting parameters of q(z|x). Weights start at zero so
        # the initial outputs equal the biases.
        self.linear_mu = torch.nn.Linear(2560, self.latent_dim)
        torch.nn.init.constant_(self.linear_mu.weight, 0)
        self.linear_var =  torch.nn.Linear(2560, self.latent_dim)
        torch.nn.init.constant_(self.linear_var.weight, 0)

    def apply_filters(self, f, im):
        """
        Convolves each impulse train with its own filter.
        ----
        f: tensor of filters of shape (N, 1, H, K)
        im: tensor of impulse trains of shape (N, H, W)

        Returns a tensor of shape (N, H, W) whose [i, h] row is
        conv1d(im[i, h], f[i, 0, h], padding='same'). As with
        torch.nn.functional.conv1d, this is a cross-correlation.
        """
        N, H, W = im.shape
        if N == 0 or H == 0:
            return im.new_zeros((N, H, W))

        K = f.shape[-1]
        # Treat every (example, height) pair as its own channel and run one
        # grouped conv1d: group g applies only filter g to input channel g.
        # This is the same per-row convolution without N*H separate calls
        # and repeated torch.cat copies.
        weight = f.permute(0, 2, 1, 3).reshape(N * H, 1, K)  # (N*H, 1, K)
        inp = im.reshape(1, N * H, W)                          # (1, N*H, W)
        out = torch.nn.functional.conv1d(inp, weight, padding='same', groups=N * H)
        return out.reshape(N, H, W)

    def sample_gaussian(self, mu, log_var):
        """Reparameterized draw z = mu + eps * exp(log_var / 2), eps ~ N(0, I) on self.device."""
        N, D = mu.shape
        eps = torch.randn((N, D), device=self.device)
        return mu + eps * torch.exp(log_var / 2.0) # multiply by std. dev.

    def forward(self, x, im):
        """
        x: input of shape (N, H, W)
        im: impulse trains of shape (N, H, W)

        Returns (out, z_mu, z_log_var) with out of shape (N, H, W).
        """

        # Reshape x before passing thru CNN.
        x = x.unsqueeze(1) # (N, H, W) => (N, 1, H, W)

        # Get encoder output.
        eo = self.encoder(x)

        # Predict parameters of q(z|x).
        z_mu = self.linear_mu(eo)
        z_log_var = self.linear_var(eo)

        # Sample a multivariate gaussian via the reparameterization "trick".
        z = self.sample_gaussian(z_mu, z_log_var)

        # Get filter from decoder.
        f = self.decoder(z)

        # Convolve impulse trains with filters.
        out = self.apply_filters(f, im)

        return out, z_mu, z_log_var

    def generate(self, im):
        """
        Generates signals for the given impulse trains from prior samples.
        ----
        im: tensor of shape (N, H, W)

        Draws z ~ N(0, I) on self.device (one per example), decodes it into
        filters, and returns apply_filters(filters, im) without tracking
        gradients.
        """
        with torch.no_grad():

            N = im.shape[0]

            # Sample latent vector on the model's device (not always the CPU).
            z = torch.randn((N, self.latent_dim), device=self.device)

            # Get filter from decoder.
            f = self.decoder(z)

            # Convolve impulse trains with filters.
            return self.apply_filters(f, im)
    
def vae_loss(x, r, z_mu, z_log_var, alpha=1):
    """
    VAE objective: mean-squared reconstruction error plus alpha times the
    KL divergence KL(N(z_mu, diag(exp(z_log_var))) || N(0, I)).
    ----
    x, r: input and reconstruction (compared with mse_loss, mean reduction).
    z_mu, z_log_var: (N, D) parameters of q(z|x).
    alpha: weight on the KL term (e.g. from an annealing schedule).

    Returns (total, reconstruction, kl) as scalar tensors; kl is summed over
    the latent dims and averaged over the batch.
    """
    reconstruction = torch.nn.functional.mse_loss(x, r, reduction='mean')

    # Closed-form Gaussian KL, one value per example, then batch mean.
    per_example_kl = -0.5 * (1 + z_log_var - z_mu.pow(2) - z_log_var.exp()).sum(dim=1)
    kl = per_example_kl.mean()

    return reconstruction + alpha * kl, reconstruction, kl

In [None]:
### frange_cycle_linear is adapted from Fu et al., "Cyclical Annealing Schedule: A Simple Approach to Mitigating KL Vanishing" (NAACL 2019): https://arxiv.org/abs/1903.10145
def frange_cycle_linear(start, stop, n_epoch, n_cycle=4, ratio=0.5):
    """
    Cyclical linear annealing schedule.

    Splits n_epoch steps into n_cycle periods. Within each period the value
    ramps linearly from `start` (in increments of
    (stop - start) / (period * ratio)) until it would exceed `stop`; the rest
    of the period stays at 1.0, the array's fill value.

    Returns a numpy array of length n_epoch.
    """
    schedule = np.ones(n_epoch)
    cycle_len = n_epoch / n_cycle
    increment = (stop - start) / (cycle_len * ratio)  # linear schedule

    for cycle in range(n_cycle):
        offset = cycle * cycle_len
        value = start
        step_idx = 0
        while value <= stop:
            pos = int(step_idx + offset)
            if pos >= n_epoch:
                break
            schedule[pos] = value
            value += increment
            step_idx += 1

    return schedule

In [None]:
def train_vae(mdl, x, opt, batch_size=25, num_epochs=10, save_checkpoints=0, checkpoint_dir=""):
    """
    Trains a CNNVariationalAutoencoder-style model with a cyclical KL weight.
    ----
    mdl: model whose forward(x_batch) returns (reconstruction, z_mu, z_log_var).
    x: inputs of shape (N, H, W).
    opt: optimizer over mdl's parameters.
    batch_size: examples per step; the N % batch_size examples left over after
        shuffling are skipped each epoch.
    num_epochs: number of passes over x. The KL weight alpha follows
        frange_cycle_linear(0, 1, num_epochs, n_cycle=4, ratio=0.5), one value per epoch.
    save_checkpoints: 0 = never save; k > 0 = save every k epochs; -1 = save only after the last epoch.
    checkpoint_dir: directory for "mdl_epoch_<epoch>.pt" files.

    Raises ValueError when N < batch_size, since no batch (and no loss) would exist.
    """

    N, H, W = x.shape

    num_batches = N // batch_size
    if num_batches == 0:
        raise ValueError(
            "need at least batch_size={} examples, got {}".format(batch_size, N))

    alphas = frange_cycle_linear(0.0, 1.0, num_epochs, n_cycle = 4, ratio=0.5)

    for e_i in range(num_epochs):

        # shuffle the indices
        idx = np.arange(N)
        np.random.shuffle(idx)

        # iterate over batches
        for b_i in range(num_batches):

            # Make batch.
            batch_idx = idx[b_i * batch_size : (b_i + 1) * batch_size]
            x_batch = x[batch_idx,:,:]

            # Zero gradients.
            opt.zero_grad()

            # Forward pass.
            out, z_mu, z_log_var = mdl.forward(x_batch)

            # The CNN decoder returns (N, 1, H, W) while x_batch is (N, H, W).
            # Left as-is, mse_loss broadcasts to (N, N, H, W) and compares
            # every reconstruction against every input in the batch.
            if out.dim() == x_batch.dim() + 1:
                out = out.squeeze(1)

            # Calculate loss.
            alpha = alphas[e_i]
            loss, recl, dkl = vae_loss(x_batch, out, z_mu, z_log_var, alpha=alpha)
            print(loss.item(), recl.item(), dkl.item(), alpha)

            # Compute gradients.
            loss.backward()

            # Clip gradients.
            torch.nn.utils.clip_grad_norm_(mdl.parameters(), max_norm=0.1, norm_type="inf")

            # Update parameters.
            opt.step()

        # Print Loss.
        print("Epoch #{0} Loss:{1}".format(e_i + 1, loss.item()))

        # Save model checkpoint.
        if ((save_checkpoints > 0 and (e_i + 1) % save_checkpoints == 0) or
            (save_checkpoints == -1 and e_i + 1 == num_epochs)):

            # os.path.join also works when checkpoint_dir lacks a trailing slash.
            PATH = os.path.join(checkpoint_dir, "mdl_epoch_{0}.pt".format(e_i + 1))
            torch.save({
                'epoch': e_i + 1,
                'model_state_dict': mdl.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': loss.item(),
                }, PATH)

def train_modified_vae(mdl, x, im, opt, batch_size=25, num_epochs=10, save_checkpoints=0, checkpoint_dir=""):
    """
    Trains a ModifiedVariationalAutoencoder with a cyclical KL weight.
    ----
    mdl: model whose forward(x_batch, im_batch) returns (reconstruction, z_mu, z_log_var).
    x: inputs of shape (N, H, W).
    im: impulse trains aligned with x, shape (N, H, W).
    opt: optimizer over mdl's parameters.
    batch_size: examples per step; the N % batch_size examples left over after
        shuffling are skipped each epoch.
    num_epochs: number of passes over the data. The KL weight alpha follows
        frange_cycle_linear(0, 1, num_epochs, n_cycle=4, ratio=0.5), one value per epoch.
    save_checkpoints: if 0, will not save checkpoints; if k > 0, saves every k
        epochs; if -1, will save only at the last epoch.
    checkpoint_dir: directory for "mdl_epoch_<epoch>.pt" files.

    Raises ValueError when N < batch_size, since no batch (and no loss) would exist.
    """

    N, H, W = x.shape

    num_batches = N // batch_size
    if num_batches == 0:
        raise ValueError(
            "need at least batch_size={} examples, got {}".format(batch_size, N))

    # One KL weight per epoch.
    alphas = frange_cycle_linear(0.0, 1.0, num_epochs, n_cycle = 4, ratio=0.5)

    for e_i in range(num_epochs):

        # shuffle the indices
        idx = np.arange(N)
        np.random.shuffle(idx)

        # iterate over batches
        for b_i in range(num_batches):

            # Make batch (the same rows from x and im).
            batch_idx = idx[b_i * batch_size : (b_i + 1) * batch_size]
            x_batch = x[batch_idx,:,:]
            im_batch = im[batch_idx,:,:]

            # Zero gradients.
            opt.zero_grad()

            # Forward pass.
            out, z_mu, z_log_var = mdl.forward(x_batch, im_batch)

            # Calculate loss.
            alpha = alphas[e_i]
            loss, recl, dkl = vae_loss(x_batch, out, z_mu, z_log_var, alpha=alpha)
            print(loss.item(), recl.item(), dkl.item(), alpha)

            # Compute gradients.
            loss.backward()

            # Clip gradients.
            torch.nn.utils.clip_grad_norm_(mdl.parameters(), max_norm=0.1, norm_type="inf")

            # Update parameters.
            opt.step()

        # Print Loss.
        print("Epoch #{0} Loss:{1}".format(e_i + 1, loss.item()))

        # Save model checkpoint.
        if ((save_checkpoints > 0 and (e_i + 1) % save_checkpoints == 0) or
            (save_checkpoints == -1 and e_i + 1 == num_epochs)):

            # os.path.join also works when checkpoint_dir lacks a trailing slash.
            PATH = os.path.join(checkpoint_dir, "mdl_epoch_{0}.pt".format(e_i + 1))
            torch.save({
                'epoch': e_i + 1,
                'model_state_dict': mdl.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': loss.item(),
                }, PATH)
            
def load_model(PATH, map_location="cpu"):
    """
    Restores a ModifiedVariationalAutoencoder and its Adam optimizer from a
    checkpoint dict containing 'model_state_dict' and 'optimizer_state_dict'.
    ----
    PATH: path to the .pt checkpoint file.
    map_location: forwarded to torch.load. The "cpu" default lets checkpoints
        saved from an MPS/GPU run load on machines without that device. The
        model is built on the CPU, so its state ends up there either way.

    Returns (mdl, opt). opt.load_state_dict restores the saved param-group
    hyperparameters, so the lr/weight_decay given to Adam below are only
    placeholders.
    """

    mdl = ModifiedVariationalAutoencoder()
    opt = torch.optim.Adam(mdl.parameters(), lr = 1e-4, weight_decay = 1e-8)

    checkpoint = torch.load(PATH, map_location=map_location)
    mdl.load_state_dict(checkpoint['model_state_dict'])
    opt.load_state_dict(checkpoint['optimizer_state_dict'])

    return mdl, opt


In [None]:
#### Train Modified VAE

# Configuration: everything a reader might tune.
SEED = 0
N_LOAD = 1000        # argument to load_data; its meaning is defined in data.py
SIGNAL_LEN = 2048    # samples kept per record; the encoder's Linear(2560, ...) is sized for this width
BATCH_SIZE = 1
NUM_EPOCHS = 20000
LR = 1e-5
WEIGHT_DECAY = 1e-9
# Checkpoint directory: override with VAE_CHECKPOINT_DIR rather than hard-coding
# a user-specific absolute path. The default assumes the notebook runs from the
# project root (previously /Users/.../Project/models/vae_run_4/) - TODO confirm.
PATH = os.environ.get("VAE_CHECKPOINT_DIR", "models/vae_run_4/")

# Seed the RNGs used below: numpy (batch shuffling) and torch (init, sampling).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load data.
x, y = load_data(N_LOAD)
x = torch.permute(x, (1, 2, 0)) # (W, N, H) => (N, H, W)
x = x[:, :1, :SIGNAL_LEN] # first lead

# Skip problematic files: records whose first-lead norm truncates to 0 (norm < 1).
norms = [int(np.linalg.norm(x[i, 0])) for i in range(x.shape[0])]
skip = [i for i, nrm in enumerate(norms) if nrm == 0]
skip_set = set(skip)  # O(1) membership instead of scanning the list per index
skip_idxs = [i for i in range(x.shape[0]) if i not in skip_set]
x = x[skip_idxs]
assert x.shape[0] > 0, "every record was skipped"

# Get impulse trains from inputs.
im = make_impulse_trains(x)

# Build model.
mdl = ModifiedVariationalAutoencoder()

# Move to GPU
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = x.to(mps_device)
    im = im.to(mps_device)
    mdl = mdl.to(mps_device)
    mdl.device = mps_device  # used when sampling noise inside the model
else:
    print ("MPS device not found.")

# Train.
opt = torch.optim.Adam(mdl.parameters(), lr = LR, weight_decay = WEIGHT_DECAY)
train_modified_vae(mdl, x, im, opt, batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, save_checkpoints=0, checkpoint_dir=PATH)