In [None]:
#imports for VAE module
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn

#imports for training
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt

#imports for dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

#imports for notebook display
from IPython.display import HTML, display

## VAE Module
- Input img (x)
- Hidden dim
- mean, std
- reparameterization trick
- Decoder
- Output img

In [None]:

class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to two vectors (``mu`` and ``sigma``)
    describing q_phi(z|x); the decoder maps a latent sample ``z`` back to a
    vector of values in [0, 1], interpreted as p_theta(x|z).

    Args:
        input_dim: Length of the flattened input vector.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_z_dim: Dimensionality of the latent variable ``z``.
    """

    def __init__(self, input_dim=784, hidden_dim=200, latent_z_dim=20):
        super(VAE, self).__init__()

        # encoder
        self.img_to_hidden = nn.Linear(input_dim, hidden_dim)       # hidden layer
        self.hidden_to_mean = nn.Linear(hidden_dim, latent_z_dim)   # mean of the latent space
        self.hidden_to_sigma = nn.Linear(hidden_dim, latent_z_dim)  # sigma of the latent space

        # decoder
        self.latent_to_hidden = nn.Linear(latent_z_dim, hidden_dim) # hidden layer
        self.hidden_to_img = nn.Linear(hidden_dim, input_dim)       # output layer

    def encode(self, x):
        """Encoder q_phi(z|x): return ``(mu, sigma)`` for a batch of inputs.

        ``sigma`` is the raw output of a linear layer, so it is not
        constrained to be positive; the loss only ever uses ``sigma**2``.
        """
        # F.relu applies the activation. nn.ReLU(tensor) would instead build
        # a module object and break the following linear layers.
        hidden = F.relu(self.img_to_hidden(x))
        mu, sigma = self.hidden_to_mean(hidden), self.hidden_to_sigma(hidden)
        return mu, sigma

    def decode(self, z):
        """Decoder p_theta(x|z): map latent vectors to values in [0, 1]."""
        hidden = F.relu(self.latent_to_hidden(z))
        x_hat = self.hidden_to_img(hidden)

        # sigmoid to ensure the output is between 0 and 1
        return torch.sigmoid(x_hat)

    def reparameterize(self, mu, sigma):
        """Reparameterization trick: ``z = mu + sigma * eps`` with eps ~ N(0, I).

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``sigma``.
        """
        epsilon = torch.randn_like(mu)
        return mu + sigma * epsilon

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            Tuple ``(x_hat, mu, sigma)``: the reconstruction and the encoder
            outputs that produced the latent sample.
        """
        mu, sigma = self.encode(x)

        # sample a latent vector from the encoder outputs
        z = self.reparameterize(mu, sigma)

        # x_hat is the reconstructed input
        x_hat = self.decode(z)

        return x_hat, mu, sigma

    def loss_function(self, x, x_hat, mu, sigma, eps=1e-8):
        """Compute the summed VAE loss terms.

        Args:
            x: Target batch; values must lie in [0, 1] for binary cross entropy.
            x_hat: Reconstruction returned by ``forward``.
            mu: Encoder mean output.
            sigma: Encoder sigma output.
            eps: Small constant added inside the logarithm so that a sigma of
                exactly zero does not produce an infinite KL term.

        Returns:
            Tuple ``(recon_loss, kl_loss, total_loss)``, each summed over all
            elements of the batch (not averaged).
        """
        # reconstruction loss for a Bernoulli likelihood
        recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')

        # KL divergence between N(mu, sigma^2) and N(0, I)
        variance = sigma.pow(2)
        kl_loss = -0.5 * torch.sum(1 + torch.log(variance + eps) - mu.pow(2) - variance)

        # total loss
        return recon_loss, kl_loss, recon_loss + kl_loss

## Training Module

In [None]:
class Trainer():
    """Training loop for a VAE-style model.

    The model is expected to return ``(x_hat, mu, sigma)`` from its forward
    pass and to expose ``loss_function(x, x_hat, mu, sigma)`` returning
    ``(recon_loss, kl_loss, total_loss)``.

    Args:
        model: Module to optimise.
        device: Device each batch is moved to before the forward pass.
        epochs: Number of passes over ``dataloader``.
        batch_size: Stored for reference only; batching is done by ``dataloader``.
        learning_rate: Learning rate for the Adam optimiser.
        input_dim: Length each sample is flattened to.
        dataloader: Iterable yielding ``(inputs, labels)`` batches; labels are ignored.
    """

    def __init__(self, model, device, epochs=10, batch_size=100, learning_rate=0.001, input_dim=784, dataloader=None):

        #hyperparameters
        self.model = model
        self.device = device
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.input_dim = input_dim
        self.dataloader = dataloader

        #optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def train(self, fixed_x):
        """Train the model and plot reconstructions of ``fixed_x`` after each epoch.

        Args:
            fixed_x: Batch of flattened images reused every epoch so that
                reconstruction quality can be compared over time.

        Returns:
            The trained model (the same object as ``self.model``).
        """
        for epoch in range(self.epochs):
            for i, (x, _) in enumerate(self.dataloader):

                #flatten each sample and move the batch to the device
                x = x.reshape(-1, self.input_dim).to(self.device)

                #reconstruction
                x_hat, mu, sigma = self.model(x)

                #loss function
                recon_loss, kl_loss, loss = self.model.loss_function(x, x_hat, mu, sigma)

                #backpropagation
                self.optimizer.zero_grad()  #zero the gradients
                loss.backward()             #backpropagation
                self.optimizer.step()       #update the parameters

                #print loss every 20 batches
                if i % 20 == 0:
                    print(f"Epoch [{epoch+1}/{self.epochs}] | Batch [{i+1}/{len(self.dataloader)}] | Loss {loss.item():.4f} | Reconstruction Loss {recon_loss.item():.4f} | KL Divergence Loss {kl_loss.item():.4f}")

            #visualize after each epoch
            self._show_reconstructions(fixed_x, epoch)

        #return the trained model
        return self.model

    def _show_reconstructions(self, fixed_x, epoch, max_images=16):
        """Plot originals (top row) against reconstructions (bottom row)."""
        with torch.no_grad():
            # Use the trainer's own model (not a notebook global); the
            # forward pass returns exactly three values.
            recon, _, _ = self.model(fixed_x)

        # Images are assumed to be 28x28, matching the reshape below.
        recon = recon.view(-1, 1, 28, 28).cpu()

        # Never index past the images actually supplied.
        n_show = min(max_images, fixed_x.size(0))

        # squeeze=False keeps `axes` two-dimensional even for a single column.
        fig, axes = plt.subplots(2, n_show, figsize=(16, 2), squeeze=False)
        for j in range(n_show):
            axes[0, j].imshow(fixed_x[j].cpu().view(28, 28), cmap='gray')
            axes[0, j].axis('off')
            axes[1, j].imshow(recon[j].view(28, 28), cmap='gray')
            axes[1, j].axis('off')
        plt.suptitle(f"Top: Original, Bottom: Reconstruction (Epoch {epoch+1})")
        plt.show()

        # One figure is created per epoch; close it so they do not accumulate.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: build the MNIST dataloader, train the VAE, then
    # export the trained network to ONNX.

    #hyperparameters
    seed = 42
    epochs = 10
    batch_size = 100
    learning_rate = 0.001
    hidden_dim = 400
    latent_z_dim = 16
    input_dim = 784

    #seed the RNG so weight init, shuffling and latent sampling are repeatable
    torch.manual_seed(seed)

    #importing MNIST dataset
    dataset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
    #use the batch_size hyperparameter rather than a second hard-coded value
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    #device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #model & trainer module initialization
    #keyword arguments are required here: the VAE signature is
    #(input_dim, hidden_dim, latent_z_dim), so passing these three values
    #positionally in a different order builds a network with the wrong shapes
    model = VAE(input_dim=input_dim, hidden_dim=hidden_dim, latent_z_dim=latent_z_dim).to(device)

    #fixed batch of images, reused every epoch to visualize reconstructions
    fixed_x, fixed_y = next(iter(dataloader))
    fixed_x = fixed_x[:16].reshape(-1, input_dim).to(device)

    trainer = Trainer(model, device, epochs, batch_size, learning_rate, input_dim, dataloader)

    #train the model
    trainer.train(fixed_x)

    #save the model to onnx format (switch to inference mode before exporting)
    model.eval()
    torch.onnx.export(model, fixed_x, "vae.onnx", verbose=True)