In [None]:
import os
import torch
import torchvision.datasets as dataset
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Hyper-parameters
image_size = 784      # length of a flattened 28x28 MNIST image
h_dim = 400           # hidden-layer width used by both encoder and decoder
z_dim = 20            # dimensionality of the latent code
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# Seed the RNGs the notebook relies on (data shuffling, weight init,
# reparameterization noise, prior sampling) so Restart & Run All reproduces
# the same results. torch.manual_seed also seeds CUDA devices.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Download (if needed) both MNIST splits into the current directory; each
# sample is converted to a tensor by ToTensor.
mnist_train = dataset.MNIST(
    "./",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
mnist_val = dataset.MNIST(
    "./",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batches of the training split.
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
)

**[TO DO]** 

Now we will implement the variational autoencoder (VAE), which is built entirely from linear layers. The encoder is a single linear layer followed by a ReLU activation function; it projects the flattened image into a hidden dimension of 400.

- Define the decoder in the `__init__` function.
- Fill in the `encode` and `reparameterize` functions, and define the layers they need in the `__init__` function. The encoder outputs a mean and a log-variance (from which the standard deviation is computed), each with a latent dimension of 20. The decoder consists of 2 linear layers: the first projects the latent dimension to the hidden dimension and is followed by a ReLU activation function; the second projects the result back to the image size.
- Define the `forward` function.

In [None]:
# VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    Encoder: image -> Linear -> ReLU -> two heads giving ``(mu, log_var)``,
    each of size ``z_dim``.
    Decoder: z -> Linear -> ReLU -> Linear -> Sigmoid, so reconstructions lie
    in [0, 1].

    Args:
        image_size: number of input features (length of a flattened image).
        h_dim: width of the hidden layer in encoder and decoder.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()

        self.encoder = nn.Sequential(nn.Linear(image_size, h_dim),
                                     nn.ReLU())

        # Posterior heads. ``self.var`` produces the *log*-variance (it is
        # exponentiated in ``reparameterize``); the attribute names are kept
        # unchanged so existing state_dict checkpoints still load.
        self.mean = nn.Linear(h_dim, z_dim)
        self.var = nn.Linear(h_dim, z_dim)

        self.decoder = nn.Sequential(nn.Linear(z_dim, h_dim),
                                     nn.ReLU(),
                                     nn.Linear(h_dim, image_size),
                                     nn.Sigmoid())

    def encode(self, x):
        """Map a batch of flattened images to posterior parameters.

        Args:
            x: tensor whose last dimension is ``image_size``.

        Returns:
            ``(mu, log_var)``, each with last dimension ``z_dim``.
        """
        hidden = self.encoder(x)
        return self.mean(hidden), self.var(hidden)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

        Keeps sampling differentiable w.r.t. ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)
        # randn_like allocates noise on std's device and dtype, so the model
        # works wherever it lives instead of relying on a global ``device``
        # (and avoids a CPU->GPU copy every step).
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        """Map latent codes to reconstructed (flattened) images in [0, 1]."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            ``(x_reconst, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

In [None]:
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Squared error summed over every pixel of every image in the batch. Summing
# keeps it on the same scale as the KL term in the training loop, which is
# also summed over the batch and the latent dimensions.
mse_loss = nn.MSELoss(reduction = 'sum')

# If you switch to reduction='mean', beware: MSELoss then averages over *all*
# elements (batch_size * image_size), whereas the per-sample KL average
#   kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim = 1), dim = 0)
# only averages over the batch, which shifts the reconstruction/KL balance by a
# factor of image_size. To keep the same balance, average the reconstruction
# over the batch only, e.g.
#   reconst_loss = F.mse_loss(x_reconst, image, reduction='sum') / image.size(0)

**[TO DO]** Write the loss, which is the sum of the reconstruction loss and the KL divergence loss.

In [None]:
# Start training
model.train()  # no dropout/batch-norm in this model, but make the mode explicit
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(train_loader):
        # Forward pass: flatten (B, 1, 28, 28) -> (B, image_size)
        image = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(image)

        # Reconstruction loss: squared error summed over all pixels in the batch.
        reconst_loss = mse_loss(x_reconst, image)
        # Closed-form KL( N(mu, exp(log_var)) || N(0, I) ), summed over the
        # batch and the latent dimensions.
        kl_div = 0.5 * torch.sum(torch.exp(log_var) + mu**2 - 1 - log_var)
        loss = reconst_loss + kl_div

        # Backprop and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logged values are batch sums, not per-image averages.
        if (i+1) % 100 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(train_loader), reconst_loss.item(), kl_div.item()))

**[TO DO]** Generate new samples: sample latent vectors from the standard normal distribution and feed them to the decoder. Make sure to reshape each flattened output image to 2D (28×28) for visualization!

In [None]:
# generate new images

num_images = 10
side = int(round(image_size ** 0.5))  # 784 -> 28x28 images

# Inference only: no_grad avoids building an autograd graph, so no .detach()
# is needed afterwards.
with torch.no_grad():
    # Sample latent codes from the prior N(0, I) and decode them.
    latent = torch.randn(num_images, z_dim, device=device)
    gen_flat = model.decode(latent)

# (num_images, image_size) -> (num_images, side, side) NumPy array for plotting.
gen_images = gen_flat.reshape(num_images, side, side).cpu().numpy()

In [None]:
# squeeze=False keeps `axs` 2-D even when num_images == 1, so indexing is uniform.
fig, axs = plt.subplots(1, num_images, figsize=(2 * num_images, 2.5), squeeze=False)
for ax, img in zip(axs[0], gen_images):
    ax.imshow(img, cmap='gray')
    ax.axis('off')
fig.suptitle('Samples decoded from z ~ N(0, I)')
plt.show()