# VAE
Variational autoencoder (VAE) models [1] inherit the autoencoder architecture but use a variational approach for latent representation learning. In this homework, we will implement a VAE and quantitatively measure the quality of its generated samples via the Inception score [2,3].

[1] Auto-Encoding Variational Bayes, Diederik P. Kingma, Max Welling, 2013.
https://arxiv.org/abs/1312.6114

[2] Improved Techniques for Training GANs, Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X., 2016.
In Advances in Neural Information Processing Systems.

[3] A Note on the Inception Score, Shane Barratt, Rishi Sharma, 2018.
https://arxiv.org/abs/1801.01973


# PART I. Train a good VAE model

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

import numpy as np
import os

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# A bunch of utility functions

def show_images(images):
    images = images.view(images.shape[0], -1).detach().cpu().numpy()
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    plt.show()

def preprocess_img(x):
    return 2 * x - 1.0

def deprocess_img(x):
    return (x + 1.0) / 2.0

def rel_error(x, y):
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

## Dataset
We will be working on the MNIST dataset, which consists of 60,000 training and 10,000 test images. Each picture contains a centered white digit (0 through 9) on a black background. This was one of the first datasets used to train convolutional neural networks, and it is fairly easy: a standard CNN model can easily exceed 99% accuracy.

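**Heads-up**: With `transforms.ToTensor()`, the torchvision MNIST dataset returns each batch as a tensor of shape (batch, 1, 28, 28), of type `torch.float32` and bounded in [0, 1]. The VAE in this notebook works on flattened vectors, so we reshape a batch to (batch, 784) with `x.view(x.size(0), -1)`, and back to (batch, 1, 28, 28) when we want to treat it as images again.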
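A quick shape check of the two reshapes, using a random tensor in place of a real MNIST batch (the names `dummy_batch`, `flat` and `restored` are illustrative only):

```python
import torch

# Random stand-in with the layout produced by transforms.ToTensor() on MNIST
dummy_batch = torch.rand(16, 1, 28, 28)          # (batch, channels, height, width), values in [0, 1)

flat = dummy_batch.view(dummy_batch.size(0), -1)  # (batch, 784): the layout the MLPs consume
restored = flat.view(-1, 1, 28, 28)               # back to image layout, e.g. for plotting

print(flat.shape)      # torch.Size([16, 784])
print(restored.shape)  # torch.Size([16, 1, 28, 28])
```

`view` does not copy data, so `restored` holds exactly the same values as `dummy_batch`.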

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),                      # [0,1]
    # transforms.Lambda(lambda x: preprocess_img(x))  # [-1,1]
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
batch_size = 16
mnist_loader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Show a batch
data_iter = iter(mnist_loader)
images, labels = next(data_iter)
show_images(images)

In [None]:
X_DIM = images[0].numel()
num_samples = 100000
num_to_show = 100

# Hyperparameters. Your job is to find these.
# TODO:
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
num_epochs = None
batch_size = None
Z_DIM = None
learning_rate = None
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

## Encoder
Our first step is to build a variational encoder network $q_\phi(z \mid x)$. 

**Hint:** Use four Linear layers.

The encoder should return two tensors of shape `[batch_size, z_dim]`, which correspond to the mean $\mu(x_i)$ and the diagonal log variance $\log \sigma(x_i)^2$ for each of the `batch_size` input images. Note that we return the log of the variance rather than the variance itself for numerical stability: the network output is unconstrained, whereas a variance must be positive.

**WARNING:** Do not apply any non-linearity to the last activation.
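One possible layout is sketched below, assuming hidden widths of 512, 256 and 128 with ReLU activations (both arbitrary choices, not requirements of the assignment). `SketchEncoder` is a hypothetical name so that it does not clash with the `Encoder` class you implement. The last layer outputs `2 * z_dim` values that are split into $\mu(x)$ and $\log \sigma(x)^2$, with no activation applied to it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchEncoder(nn.Module):
    """Hypothetical 4-layer MLP encoder; the hidden widths are arbitrary choices."""
    def __init__(self, z_dim=2, x_dim=784, hidden=(512, 256, 128)):
        super().__init__()
        self.z_dim = z_dim
        h1, h2, h3 = hidden
        self.fc1 = nn.Linear(x_dim, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        # Single head with 2 * z_dim outputs: first half is mu, second half is log-variance
        self.fc4 = nn.Linear(h3, 2 * z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))
        out = self.fc4(h)  # no non-linearity on the last layer
        mu, log_var = out[:, :self.z_dim], out[:, self.z_dim:]
        return mu, log_var

enc = SketchEncoder(z_dim=8)
mu_demo, log_var_demo = enc(torch.rand(16, 784))
print(mu_demo.shape, log_var_demo.shape)  # torch.Size([16, 8]) torch.Size([16, 8])
```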

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, z_dim=Z_DIM, x_dim=X_DIM):
        super(Encoder, self).__init__()
        self.z_dim = z_dim
        self.x_dim = x_dim
        # TODO: implement here

    def forward(self, x):
        # TODO: implement here
        mu, log_var = out[:, :self.z_dim], out[:, self.z_dim:]
        return mu, log_var

In [None]:
# TODO: implement reparameterization trick
def sample_z(mu, log_var):
    # Your code here for the reparameterization trick.
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    samples = None
    
    pass
    
    return samples
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
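
The reparameterization trick writes a sample from $q_\phi(z \mid x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ as $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$ and $\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$, so gradients can flow back into $\mu$ and $\log\sigma^2$ while the randomness sits in $\epsilon$. A minimal sketch (the function name `sketch_sample_z` is hypothetical, chosen so it does not overwrite your `sample_z`):

```python
import torch

def sketch_sample_z(mu, log_var):
    """Reparameterized sample z = mu + sigma * eps with eps ~ N(0, I)."""
    std = torch.exp(0.5 * log_var)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)     # noise with the same shape/device, no gradient
    return mu + std * eps           # differentiable w.r.t. mu and log_var

torch.manual_seed(0)
mu_demo = torch.zeros(4, 3, requires_grad=True)
log_var_demo = torch.zeros(4, 3, requires_grad=True)
z_demo = sketch_sample_z(mu_demo, log_var_demo)
z_demo.sum().backward()
print(mu_demo.grad)  # a tensor of ones, since dz/dmu = 1
```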

## Decoder
Now we build a decoder network $p_\theta(x \mid z)$. Use four linear layers.

In this exercise, we will use a continuous Bernoulli MLP decoder, where $p_\theta(x \mid z)$ is modeled as a multivariate continuous Bernoulli distribution rather than the Gaussian distribution discussed in the lecture, as follows (see Appendix C.1 of [1] and https://arxiv.org/abs/1907.06845 for more details):

$\log p_\theta(x \mid z) = \sum_{i=1}^{D} \left[ x_i \log \lambda(z)_i + (1-x_i) \log (1-\lambda(z)_i) + \log C(\lambda(z)_i) \right]$,

where $D$ is the number of pixels (`x_dim` = 784), $\lambda(z)_i \in (0, 1)$ is the parameter of the continuous Bernoulli distribution for the $i$-th pixel, and $C(\lambda) = \frac{2 \tanh^{-1}(1 - 2\lambda)}{1 - 2\lambda}$ for $\lambda \neq 0.5$ (with $C(0.5) = 2$) is its normalizing constant. (Note that $\lambda(z)$ corresponds to $\mathrm{sigmoid}(x\_\text{logit})$ in this implementation, and it can also be seen as the decoded image for latent $z$.)
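The $\log C(\lambda(z)_i)$ term in the equation is the part that a plain binary cross-entropy does not cover. A NumPy sketch of it follows, using $C(\lambda) = 2\tanh^{-1}(1-2\lambda)/(1-2\lambda)$ with $C(0.5) = 2$. The function name `cb_log_norm`, the clipping `eps`, and the switch to a Taylor expansion near $\lambda = 0.5$ (where the formula is $0/0$) are implementation assumptions.

```python
import numpy as np

def cb_log_norm(lam, eps=1e-6):
    """Element-wise log C(lambda) of the continuous Bernoulli distribution.

    C(lambda) = 2 * artanh(1 - 2*lambda) / (1 - 2*lambda) for lambda != 0.5, C(0.5) = 2.
    """
    lam = np.clip(np.asarray(lam, dtype=np.float64), eps, 1.0 - eps)
    t = 1.0 - 2.0 * lam
    near_half = np.abs(t) < 1e-3
    t_safe = np.where(near_half, 0.5, t)              # dummy value avoids 0/0 in the exact branch
    exact = np.log(2.0 * np.arctanh(t_safe) / t_safe)
    taylor = np.log(2.0) + np.log1p(t ** 2 / 3.0)     # artanh(t)/t = 1 + t^2/3 + O(t^4)
    return np.where(near_half, taylor, exact)

print(cb_log_norm(0.5))  # log(2)
```

PyTorch also provides `torch.distributions.ContinuousBernoulli`, whose `log_prob` includes this normalizing term.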

Note that the decoder output should have shape `[batch_size, x_dim]` and contain unnormalized logits, i.e., $\log \frac{\lambda(z)_i}{1 - \lambda(z)_i}$ for each pixel.

**WARNING:** Do not apply any non-linearity to the last activation.
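A possible decoder that mirrors a four-layer MLP encoder, sketched under stated assumptions: the name `SketchDecoder` is hypothetical, and the hidden widths 128, 256, 512 and the ReLU activations are arbitrary choices. It returns raw logits, and `torch.sigmoid` of them gives $\lambda(z)$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchDecoder(nn.Module):
    """Hypothetical 4-layer MLP decoder; the hidden widths are arbitrary choices."""
    def __init__(self, z_dim=2, x_dim=784, hidden=(128, 256, 512)):
        super().__init__()
        h1, h2, h3 = hidden
        self.fc1 = nn.Linear(z_dim, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc4 = nn.Linear(h3, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))
        return self.fc4(h)  # unnormalized logits: no non-linearity on the last layer

dec = SketchDecoder(z_dim=8)
x_logit_demo = dec(torch.randn(16, 8))  # shape [16, 784]
lam_demo = torch.sigmoid(x_logit_demo)  # lambda(z), every entry in (0, 1)
```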

In [None]:
class Decoder(nn.Module):
    def __init__(self, z_dim=Z_DIM, x_dim=X_DIM):
        super(Decoder, self).__init__()
        # TODO: implement here

    def forward(self, z):
        # TODO: implement here
        return x_logit

## Loss definition
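Compute the VAE loss.
1. For the reconstruction loss, you might find `F.binary_cross_entropy_with_logits` useful. Note that it computes only the negated first two terms of $\log p_\theta(x \mid z)$ above; the term $\log C(\lambda(z)_i)$ must be added separately if you want the exact continuous Bernoulli likelihood.
2. For the KL loss, use the closed-form KL divergence between two Gaussians discussed in the lecture, here between $q_\phi(z \mid x) = \mathcal{N}(\mu(x), \mathrm{diag}(\sigma(x)^2))$ and the prior $p(z) = \mathcal{N}(0, I)$.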
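For reference, with a standard normal prior the KL term is $D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\tfrac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The NumPy sketch below computes it together with a per-sample binary cross-entropy from logits; both are summed over dimensions to give one value per sample, which is the shape that `vae_loss` averages over the batch. The helper names are hypothetical; in the notebook the PyTorch equivalent of the second one is `F.binary_cross_entropy_with_logits(x_logit, x, reduction='none').sum(dim=1)`.

```python
import numpy as np

def gaussian_kl(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims -> shape [batch]."""
    return -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var), axis=1)

def bce_with_logits(logits, x):
    """Binary cross-entropy from logits, summed over pixels -> shape [batch].

    Uses max(l, 0) - l * x + log(1 + exp(-|l|)), an overflow-safe rewrite of
    -[x * log(sigmoid(l)) + (1 - x) * log(1 - sigmoid(l))].
    """
    per_pixel = np.maximum(logits, 0.0) - logits * x + np.log1p(np.exp(-np.abs(logits)))
    return per_pixel.sum(axis=1)

# Per-sample loss, then the batch mean (as vae_loss does with torch.mean)
rng = np.random.default_rng(0)
x = rng.uniform(size=(4, 784))
logits = rng.normal(size=(4, 784))
mu, log_var = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
loss = np.mean(bce_with_logits(logits, x) + gaussian_kl(mu, log_var))
```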

In [None]:
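def vae_loss(x, x_logit, z_mu, z_logvar):
    recon_loss = None  # per-sample reconstruction loss summed over pixels, shape [batch_size]
    kl_loss = None     # per-sample KL divergence summed over latent dims, shape [batch_size]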
    
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    pass
    
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    loss = torch.mean(recon_loss + kl_loss)  # average the per-sample loss over the batch
    return loss, torch.mean(recon_loss)

## Optimizing our loss


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Q = Encoder().to(device)
P = Decoder().to(device)

optimizer = torch.optim.Adam(list(Q.parameters()) + list(P.parameters()), lr=learning_rate)
# Rebuild the MNIST DataLoader with the chosen batch_size, reshuffling every epoch (shuffle=True)
transform = transforms.Compose([
    transforms.ToTensor(),
])
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_loader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

Visualize generated samples before training

In [None]:
z_gen = torch.randn(num_to_show, Z_DIM).to(device)
x_gen = P(z_gen)
imgs_numpy = torch.sigmoid(x_gen).detach().cpu()
show_images(imgs_numpy)
plt.show()

## Training a VAE!
If everything works, your batch-average reconstruction loss (summed over the 784 pixels of each image) should drop below 95.

In [None]:
iter_count = 0
show_every = 200

# ----- Training Loop -----
for epoch in range(num_epochs):
    for x_i, _ in mnist_loader:
        x_i = x_i.view(x_i.size(0), -1).to(device)   # Flatten and move to device
        
        z_mu, z_logvar = Q(preprocess_img(x_i))
        z_i = sample_z(z_mu, z_logvar)
        x_logit = P(z_i)

        loss, recon_loss = vae_loss(x_i, x_logit, z_mu, z_logvar)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iter_count % show_every == 0:
            print(f'Epoch: {epoch}, Iter: {iter_count}, Loss: {loss.item():.4f}, Recon: {recon_loss.item():.4f}')
            # imgs_numpy = torch.sigmoid(x_logit).detach().cpu()
            # show_images(imgs_numpy[:16])
            # plt.show()
        iter_count += 1

Visualize generated samples after training

In [None]:
z_gen = torch.randn(num_to_show, Z_DIM).to(device)
x_gen = P(z_gen)
imgs_numpy = torch.sigmoid(x_gen).detach().cpu()
show_images(imgs_numpy)
plt.show()

# PART II. Compute the inception score for your trained VAE model
In this part, we will quantitatively measure how good your VAE model is.

### Train a classifier
We first need to train a classifier that predicts the digit label $y$ of an image $x$, i.e., $p(y \mid x)$.

In [None]:
# ----- Hyperparameters -----
batch_size = 128
num_classes = 10
epochs = 20
# ----- Data Preparation -----
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

print(f'{len(train_dataset)} train samples')
print(f'{len(test_dataset)} test samples')

# ----- Model Definition -----
class MLPClassifier(nn.Module):
    def __init__(self):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.2, training=self.training)  # turned off by model.eval()
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.2, training=self.training)
        x = self.fc3(x)
        return x

    def prob(self, x):
        x = self.forward(x)
        prob = F.softmax(x, dim=-1)
        return prob

model = MLPClassifier().to(device)
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9)
criterion = nn.CrossEntropyLoss()

# ----- Training -----
for epoch in range(epochs):
    model.train()
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}')

# ----- Evaluation -----
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        outputs = model(batch_x)
        predicted = torch.argmax(outputs, dim=1)
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

print('Test accuracy:', correct / total)

### Verify the trained classifier on the generated samples
Generate samples and visually inspect whether the predicted labels on the samples match the digits actually shown in the generated images.

In [None]:
with torch.no_grad():  # no gradients needed; avoids storing activations for all num_samples samples
    z_gen = torch.randn(num_samples, Z_DIM).to(device)
    x_gen = P(z_gen)
imgs_numpy = torch.sigmoid(x_gen[:num_to_show]).cpu()
show_images(imgs_numpy)
plt.show()

In [None]:
preds = torch.argmax(model(torch.sigmoid(x_gen[:20])), dim=1)
print(preds.cpu().numpy())

### Implement the inception score
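Implement Equation 1 in reference [3]. Replace the expectation in the equation with an empirical average over the `num_samples` generated samples, and estimate the marginal $p(y)$ the same way. Don't forget the exponentiation at the end. You should get an Inception score of at least 9.0 (the maximum possible value is the number of classes, 10).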
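The inception score is $\mathrm{IS} = \exp\big(\mathbb{E}_{x}\, D_{\mathrm{KL}}(p(y \mid x) \,\|\, p(y))\big)$ with $p(y) = \mathbb{E}_x\, p(y \mid x)$. Given an `[N, num_classes]` array of classifier probabilities, a NumPy sketch that replaces both expectations with sample means could look like this (`inception_score_from_probs` is a hypothetical name; the small `eps` only guards against `log(0)`):

```python
import numpy as np

def inception_score_from_probs(probs, eps=1e-12):
    """exp(mean_x KL(p(y|x) || p(y))), with p(y) estimated as the mean of p(y|x)."""
    probs = np.asarray(probs, dtype=np.float64)
    p_y = probs.mean(axis=0, keepdims=True)                    # empirical marginal p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))                            # exponentiate the average KL

# Confident and perfectly balanced predictions reach the maximum, num_classes
print(inception_score_from_probs(np.tile(np.eye(10), (5, 1))))  # close to 10.0
```

In the notebook, the probabilities could come from `model.prob(torch.sigmoid(x_gen)).cpu().numpy()` computed under `torch.no_grad()` with the classifier in eval mode. If every sample receives the same predicted distribution, the score drops to 1.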

In [None]:
with torch.no_grad():
    # TODO: implement here
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    inception_score = None
    pass
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

print(f'Inception score: {inception_score:.4f}')

### Plot the histogram of predicted labels
Let's additionally inspect the class diversity of the generated samples.
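As a numeric complement to the histogram, the sketch below (hypothetical helper `label_stats`) computes per-class frequencies of the predicted labels with `np.bincount` and the exponentiated entropy of that distribution, which can be read as an effective number of classes: 10 for perfectly balanced predictions, 1 if every sample gets the same label.

```python
import numpy as np

def label_stats(preds, num_classes=10):
    """Return per-class frequencies and exp(entropy) of the predicted-label distribution."""
    counts = np.bincount(preds, minlength=num_classes)
    freqs = counts / counts.sum()
    nz = freqs[freqs > 0]                    # skip empty classes: 0 * log 0 is taken as 0
    entropy = -np.sum(nz * np.log(nz))
    return freqs, float(np.exp(entropy))

freqs, eff = label_stats(np.repeat(np.arange(10), 5))
print(eff)  # close to 10.0 for perfectly balanced labels
```

With the notebook's predictions, this would be `freqs, eff = label_stats(hist_preds)`.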

In [None]:
with torch.no_grad():  # inference only: avoid building a graph over all generated samples
    hist_preds = torch.argmax(model(torch.sigmoid(x_gen)), dim=1).cpu().numpy()
plt.hist(hist_preds, bins=np.arange(11)-0.5, rwidth=0.8, density=True)
plt.xticks(range(10))
plt.show()