> This is a self-correcting activity generated by [nbgrader](https://nbgrader.readthedocs.io). Fill in any place that says `YOUR CODE HERE` or `YOUR ANSWER HERE`. Run subsequent cells to check your code.

---

# Generate handwritten digits with a VAE (PyTorch)

The goal here is to train a VAE to generate handwritten digits.

![VAE digits](images/vae_digits.png)

## Environment setup

In [1]:
import os
import math

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

In [2]:
# Setup plots
%matplotlib inline
plt.rcParams['figure.figsize'] = 10, 8
%config InlineBackend.figure_format = 'retina'

In [15]:
# Import ML packages (edit this list if needed)
import torch
print(f'PyTorch version: {torch.__version__}')

import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device = ", device)

PyTorch version: 1.4.0
device =  cpu


## Data loading

In [6]:
# Load MNIST dataset
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transforms.ToTensor(), download=True
)
testset = torchvision.datasets.MNIST(
    root="./data", train=False, transform=transforms.ToTensor(), download=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100.1%

Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


113.5%

Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.4%

Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw
Processing...
Done!


### Question

Create batch data loaders `trainloader` and `testloader` resp. for training and test datasets.

In [12]:
batch_size = 128

# YOUR CODE HERE
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

## Model definition

### Question

Complete the following class to create a variational autoencoder.

In [17]:
# VAE model
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(hidden_dim, latent_dim)
        self.fc4 = nn.Linear(latent_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)
        
    def encode(self, x):
        """Encode input into its latent representation
        Returns mean and standard deviation"""
        
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def sample(self, mu, log_var):
        """Sample a random codings vector from a gaussian distribution
        Takes mean and log_var (gamma) as parameters"""
        
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode codings"""
        
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))
    
    def forward(self, x):
        """Encode inputs to obtain mean and standard deviation
           Sample codings from gaussian distribution using mean and std
           Returns decoded codings, mean and standard deviation"""
        # YOUR CODE HERE
        mu, log_var = self.encode(x)
        sampled = self.sample(mu, log_var)
        decoded = self.decode(sampled)
        return decoded, mu, log_var

## Model training

### Question

Complete the following training loop to:
- instantiate the variational autoencoder on target device.
- instanciate the Adam optimizer.
- implement forward pass and gradient descent.

In [30]:
input_dim = 784
hidden_dim = 400
latent_dim = 20
num_epochs = 15
learning_rate = 1e-3
prints_per_epoch = 1  # Increase to see more feedback during training

# Instanciate VAE and optimizer
# YOUR CODE HERE
vae = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)


# Train model
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(trainloader):
        # Forward pass
        # YOUR CODE HERE
        x = x.to(device).view(-1, input_dim)
        x_reconst, mu, log_var = vae(x) #appel implicite à la méthode forward

        # Compute reconstruction loss and KL divergence
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction="sum")
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = reconst_loss + kl_div

        # Backprop and optimize
        # YOUR CODE HERE
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print losses at regular intervals
        step_count = len(trainloader)
        print_threshold = math.ceil(step_count / prints_per_epoch)
        if (i + 1) % print_threshold == 0 or (i + 1) == step_count:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}]"
                f", step [{i + 1}/{step_count}]"
                f", reconst loss: {reconst_loss.item():.4f}"
                f", KL div: {kl_div.item():.4f}"
            )

Epoch [1/15], step [469/469], reconst loss: 9645.1650, KL div: 2002.2250
Epoch [2/15], step [469/469], reconst loss: 8252.0947, KL div: 2194.8657
Epoch [3/15], step [469/469], reconst loss: 8215.7480, KL div: 2345.3354
Epoch [4/15], step [469/469], reconst loss: 8252.9414, KL div: 2338.0591
Epoch [5/15], step [469/469], reconst loss: 7969.6001, KL div: 2369.6467
Epoch [6/15], step [469/469], reconst loss: 8530.3330, KL div: 2452.9790
Epoch [7/15], step [469/469], reconst loss: 7970.4634, KL div: 2365.6082
Epoch [8/15], step [469/469], reconst loss: 7711.5215, KL div: 2486.0381
Epoch [9/15], step [469/469], reconst loss: 7952.0791, KL div: 2383.5613
Epoch [10/15], step [469/469], reconst loss: 7854.0146, KL div: 2466.0996
Epoch [11/15], step [469/469], reconst loss: 8159.8701, KL div: 2370.4321
Epoch [12/15], step [469/469], reconst loss: 7377.4419, KL div: 2415.7322
Epoch [13/15], step [469/469], reconst loss: 7571.5806, KL div: 2408.0161
Epoch [14/15], step [469/469], reconst loss: 76

## Reconstructions visualization¶

In [None]:
def plot_image(image):
    plt.imshow(image.numpy().squeeze(), cmap="binary")
    plt.axis("off")

def show_reconstructions(model, images, n_images=8):
    """Show original and reconstructed images side-by-side"""
    
    inputs = images.reshape(-1, 28*28).to(device)
    reconstructions, _, _ = model(inputs)
    
    fig = plt.figure(figsize=(n_images * 1.5, 3))
    for image_index in range(n_images):
        plt.subplot(2, n_images, 1 + image_index)
        plot_image(images[image_index])
        plt.subplot(2, n_images, 1 + n_images + image_index)
        plot_image(reconstructions[image_index].view(1, 28, 28))

### Question

Show reconstructions for one batch of test data.

In [31]:
# YOUR CODE HERE
dataiter = iter(testloader)
images, _ = dataiter.next()
with torch.no_grad():
    show_reconstruction(vae, images)

NameError: name 'show_reconstruction' is not defined

## Generating new images¶

In [None]:
def plot_multiple_images(images, n_cols=None):
    """Show a series of images"""

    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1
    if images.shape[-1] == 1:
        images = np.squeeze(images, axis=-1)
    plt.figure(figsize=(n_cols * 1.5, 3))
    for index, image in enumerate(images):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(image.numpy().squeeze(), cmap="binary")
        plt.axis("off")

### Question

Use the VAE to show several generated digits.

In [None]:
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    # YOUR CODE HERE
    gen_images = vae.decode(z).view(-1, 1, 28, 28)
    plot_multiple_image(gen_images, n_cols=8)