<a id="title"></a>
# Variational Autoencoder MNIST Tutorial using PyTorch
***
## Learning Goals:
By the end of this tutorial, you will:
- build a variational autoencoder (VAE)
- train and evaluate a VAE
- visualize the latent space of a VAE
- generate samples from the latent space

## Table of Contents
[Introduction](#intro) <br>
[0. Imports](#imports) <br>
[1. MNIST Dataset and Scaling](#mnist) <br>
[2. Build a VAE](#build) <br>
[3. Test Model Functionality](#test) <br>
[4. Set Training and Test Sets](#set) <br>
[5. Hyperparameters and Loading](#hyper) <br>
[6. Train Model](#train) <br>
[7. Plot Loss Function and R2](#plot) <br>
[8. Analyze Samples](#analyze) <br>
[9. Visualize the Latent Space](#latent) <br>
[10. Generate Samples](#detect) <br>
[11. Conclusions](#con) <br>
[Additional Resources](#add) <br>
[About this Notebook](#about) <br>
[Citations](#cite) <br>

## Introduction <a id="intro"></a>

The main purpose of this notebook is to build a variational autoencoder in [PyTorch](https://pytorch.org/), a deep learning Python library. This tutorial is not an exhaustive introduction to machine learning and assumes the user is familiar with its vocabulary (supervised vs. unsupervised learning, neural networks, loss functions, backpropagation, etc.) and methodology (model selection, feature selection, hyperparameter tuning, etc.). This notebook also assumes the user is familiar with autoencoders and the [MNIST handwritten digit dataset](http://yann.lecun.com/exdb/mnist/). See [Additional Resources](#add) for more complete machine learning guides. The paragraphs below serve as a brief introduction to variational autoencoders.

A [variational autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) is an autoencoder that forces the latent space to approximate a given probability distribution. Most VAEs push the latent space toward a standard normal distribution. When the latent space approximates a normal distribution, it becomes smooth: nearby points decode to similar outputs, so decoding is more robust across the latent space than in a traditional autoencoder. The use cases of traditional autoencoders still apply, with the addition that a VAE is a generative model. Since a VAE learns a probability distribution over the latent space, it can generate new, synthetic samples that resemble the training data.

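Encoding each input as a distribution rather than a single point means the latent vector is sampled, $z \sim \mathcal{N}(\mu, \sigma^2)$, and sampling is not differentiable with respect to $\mu$ and $\sigma$. The [reparameterization trick](https://stats.stackexchange.com/questions/199605/how-does-the-reparameterization-trick-for-vaes-work-and-why-is-it-important) works around this by writing $z = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$: the randomness lives entirely in $\epsilon$, so gradients can flow through $\mu$ and $\sigma$ during backpropagation.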
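
A minimal sketch, independent of the model built below, showing that gradients reach $\mu$ and $\log(\sigma^2)$ through $z = \mu + \sigma \epsilon$. The batch size, latent dimension, and zero-valued "encoder outputs" are arbitrary assumptions for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 3 samples, 2-dimensional latent space
mu = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)

# Reparameterization: all randomness lives in epsilon ~ N(0, 1)
sigma = torch.exp(0.5 * logvar)
epsilon = torch.randn_like(sigma)
z = mu + sigma * epsilon

# Any scalar function of z can now be backpropagated to mu and logvar
z.sum().backward()

print(mu.grad)      # all ones, since dz/dmu = 1
print(logvar.grad)  # equals 0.5 * sigma * epsilon
```

If `z` were drawn directly with a sampling call instead, there would be no differentiable path from `z` back to `mu` and `logvar`.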
    
**In this notebook, we will build a variational autoencoder using PyTorch to learn a representation of the MNIST handwritten digit dataset.**

## 0. Imports <a id="imports"></a>

If you are running this notebook on Google Colab, you shouldn't have to install anything. If you are running this notebook in Jupyter, this notebook assumes you created the virtual environment defined in `environment.yml`. If not, close this notebook and run the following lines in a terminal window:

`conda env create -f environment.yml`

`conda activate deepwfc3_env`

We import the following libraries:
- *numpy* for handling arrays
- *matplotlib* for plotting
- *tqdm* for keeping track of loop speed
- *torchvision* for accessing MNIST images 
- *torch* as our machine learning framework

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

import torchvision

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

## 1. MNIST Dataset and Scaling<a id="mnist"></a>

The MNIST dataset is available through `torchvision`, which stores the images and labels as `torch.Tensor`s. We'll download the training and test sets, which are unpacked as `x_train` for training features, `y_train` for training labels, `x_test` for testing features, and `y_test` for testing labels. In addition, we'll convert the datasets to `np.array`s for easier data manipulation.

In [None]:
root = 'mnist'
train_dataset = torchvision.datasets.MNIST(root, train=True, download=True, transform=torchvision.transforms.ToTensor())
test_dataset  = torchvision.datasets.MNIST(root, train=False, download=True, transform=torchvision.transforms.ToTensor())

x_train = train_dataset.data.numpy()
y_train = train_dataset.targets.numpy()
x_test = test_dataset.data.numpy()
y_test = test_dataset.targets.numpy()

We'll also define some frequently used global variables. `x_train_size` is the number of images in the training set, `x_test_size` is the number of images in the test set, and `x_length` is the length/width of an image. In addition, we min-max scale our images to have a minimum value of 0 and a maximum value of 1.

In [None]:
x_train_size = x_train.shape[0]
x_test_size = x_test.shape[0]
x_length = x_train.shape[1]

norm = x_train.max()
x_train_scale = x_train / norm
x_test_scale = x_test / norm
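
For reference, general min-max scaling maps a value $x$ to $(x - x_{min}) / (x_{max} - x_{min})$. MNIST pixels are stored as 8-bit integers whose minimum is 0, so dividing by the maximum, as above, is the same operation. Note that the test set is scaled with the training set's maximum, so both sets share one scaling. A small sketch on hypothetical pixel values:

```python
import numpy as np

def min_max_scale(x, x_min, x_max):
    """Scale values linearly so that x_min maps to 0 and x_max maps to 1."""
    return (x - x_min) / (x_max - x_min)

# Hypothetical 8-bit pixel values (MNIST pixels are uint8 values in [0, 255])
pixels = np.array([0, 51, 255], dtype=np.uint8)

scaled = min_max_scale(pixels.astype(np.float64), 0.0, 255.0)
print(scaled)  # 0.0, 0.2 and 1.0

# When the minimum is 0, min-max scaling reduces to dividing by the maximum
assert np.allclose(scaled, pixels / pixels.max())
```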

## 2. Build a VAE <a id="build"></a>

PyTorch provides a base class for datasets, `torch.utils.data.Dataset`. A subclass implements `__len__` to return the number of samples and `__getitem__` to return the sample at a given index. Datasets built from this class are used as inputs to `torch.utils.data.DataLoader`, which batches (and optionally shuffles) the data for training. Since an autoencoder isn't trained using labels, `__getitem__` returns only the image.

In [None]:
class LoadDataset(Dataset):
    
    def __init__(self, images):
        self.images = images
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        image = self.images[index]
        
        return image
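
As a quick sanity check, here is a sketch that wraps a small random array (a hypothetical stand-in for the scaled MNIST images) in a `Dataset` subclass with the same structure as `LoadDataset`, then batches it with `DataLoader`. The class name `ToyImageDataset` is made up for this example; the default collate function stacks the NumPy arrays into `torch.Tensor` batches.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ToyImageDataset(Dataset):
    """Hypothetical copy of LoadDataset: returns only the image, no label."""
    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index]

# Stand-in for scaled images: 10 single-channel 28x28 arrays
rng = np.random.default_rng(0)
images = rng.random((10, 1, 28, 28)).astype(np.float32)

loader = DataLoader(ToyImageDataset(images), batch_size=4, shuffle=False)
batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)  # [(4, 1, 28, 28), (4, 1, 28, 28), (2, 1, 28, 28)]
```

The last batch is smaller because 10 samples do not divide evenly into batches of 4.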

First, we define the functions and layers to build our encoder. The constructor has our model hyperparameters as inputs:

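- `filters`: a list of channel counts, starting with the number of input channels (1 for grayscale MNIST) followed by the number of filters each convolutional layer will learn
- `neurons`: a list of fully connected layer sizes; the last entry is the dimension of the latent space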
- `sub_array_size`: the image length/width
- `kernel_size`: the size of the filter being learned

Using the constructor's parameters, we define the encoder's layers and functions. We use the [rectified linear unit (ReLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) as our activation function to add nonlinearity to our model and [batch normalization](https://en.wikipedia.org/wiki/Batch_normalization) to rescale our features in each convolutional layer.

The `forward` function builds our encoder from the layers and functions defined in the constructor. With the hyperparameters used in [Section 3](#test) (`filters=[1, 16, 32, 64]`, `neurons=[256, 128, 4]`, `kernel_size=5`), the encoder is built as follows:
- Encoder Layer 1
    - convolve 1 28x28 image into 16 24x24 feature maps
    - perform batch normalization
    - activate the feature maps using ReLU
- Encoder Layer 2
    - convolve 16 24x24 feature maps into 32 20x20 feature maps
    - perform batch normalization
    - activate the feature maps using ReLU
- Encoder Layer 3
    - convolve 32 20x20 feature maps into 64 16x16 feature maps
    - perform batch normalization
    - activate the feature maps using ReLU
- Flatten the 64 16x16 feature maps to a 1D array of 64 * 16 * 16 = 16,384 values
- Hidden Layers
    - use the flattened 1D array as inputs for a 256-neuron hidden layer, activated using ReLU
    - use the 256-neuron layer as inputs for a 128-neuron hidden layer, activated using ReLU
- Latent Space
    - use the 128-neuron hidden layer as inputs for two separate linear layers that output the 4-dimensional mean vector $\mu$ and log-variance vector $\log(\sigma^2)$
    - perform the reparameterization trick to sample the latent vector
        - $z = \mu + \sigma \epsilon$, where $\sigma = \exp(\frac{1}{2}\log(\sigma^2))$ and $\epsilon \sim \mathcal{N}(0, 1)$

In [None]:
# define functions and build encoder

class Encoder(nn.Module):
    def __init__(self, 
                 params_dict):

        super(Encoder, self).__init__()

        # Params
        filters = params_dict['filters']
        neurons = params_dict['neurons']
        sub_array_size = params_dict['sub_array_size']
        kernel_size = params_dict['kernel_size']
        
        # The Rectified Linear Unit (ReLU)
        self.relu = nn.ReLU()
        
        # Flattens the feature map to a 1D array
        self.flatten = nn.Flatten(start_dim=1)

        # ---- CONVOLUTION ----
        self.conv1 = nn.Conv2d(in_channels=filters[0], out_channels=filters[1], kernel_size=kernel_size)
        self.conv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[2], kernel_size=kernel_size)
        self.conv3 = nn.Conv2d(in_channels=filters[2], out_channels=filters[3], kernel_size=kernel_size)
        
        # ---- BATCH NORMALIZATION ----
        self.batch1 = nn.BatchNorm2d(filters[1])
        self.batch2 = nn.BatchNorm2d(filters[2])
        self.batch3 = nn.BatchNorm2d(filters[3])
        
        # ---- LATENT ----
        last_feature_map_size = sub_array_size + (1 - kernel_size) * (len(filters) - 1)
        size_before_latent = filters[-1] * last_feature_map_size ** 2
        self.linear1a = nn.Linear(size_before_latent, neurons[0])
        self.linear1b = nn.Linear(neurons[0], neurons[1])
        self.linear2_mu = nn.Linear(neurons[1], neurons[2])
        self.linear2_logvar = nn.Linear(neurons[1], neurons[2])
        
        # ---- REPARAM -----
        self.N = torch.distributions.Normal(0, 1)

    def forward(self, x):
        
        # Layer 1
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.relu(x)

        # Layer 2
        x = self.conv2(x)
        x = self.batch2(x)
        x = self.relu(x)
        
        # Layer 3
        x = self.conv3(x)
        x = self.batch3(x)
        x = self.relu(x)

        # Hidden Layer
        x = self.flatten(x)
        x = self.linear1a(x)
        x = self.relu(x)
        x = self.linear1b(x)
        x = self.relu(x)
        
        # Mu and LogVar
        mu = self.linear2_mu(x)
        logvar = self.linear2_logvar(x)
        
        # Reparameterize (epsilon is moved to the same device as sigma, e.g. a GPU)
        sigma = torch.exp(0.5 * logvar)
        epsilon = self.N.sample(sigma.size()).to(sigma.device)
        z = mu + sigma * epsilon
        
        return z, mu, logvar

Then, we define the functions and layers to build our decoder. The constructor uses the same parameters and functions as the encoder.

The `forward` function builds our decoder from the layers and functions defined in the constructor. With the same hyperparameters, the decoder mirrors the encoder and is built as follows:

- Latent Space
    - use the 4-dimensional latent vector $z$ as inputs for a 128-neuron hidden layer
    - activate using ReLU
- Hidden Layers
    - use the 128-neuron layer as inputs for a 256-neuron hidden layer, activated using ReLU
    - use the 256-neuron layer as inputs for a layer of 64 * 16 * 16 = 16,384 neurons, activated using ReLU
- Unflatten the 1D array of 16,384 values into 64 16x16 feature maps
- Decoder Layer 1
    - transpose convolve 64 16x16 feature maps into 32 20x20 feature maps
    - perform batch normalization
    - activate using ReLU
- Decoder Layer 2
    - transpose convolve 32 20x20 feature maps into 16 24x24 feature maps
    - perform batch normalization
    - activate using ReLU
- Decoder Layer 3
    - transpose convolve 16 24x24 feature maps into 1 28x28 feature map
    - activate using the [Sigmoid function](https://en.wikipedia.org/wiki/Sigmoid_function) (all outputs are between 0 and 1)

In [None]:
# define functions and build decoder

class Decoder(nn.Module):
    def __init__(self, 
                 params_dict):

        super(Decoder, self).__init__()
        
        # Params
        filters = params_dict['filters']
        neurons = params_dict['neurons']
        sub_array_size = params_dict['sub_array_size']
        kernel_size = params_dict['kernel_size']
        
        last_feature_map_size = sub_array_size + (1 - kernel_size) * (len(filters) - 1)
        unflattened_size = (filters[-1], last_feature_map_size, last_feature_map_size)
        
        # Activations: ReLU for hidden layers, Sigmoid for the output layer
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
        # Unflattens the 1D array to feature maps
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=unflattened_size)

        # ---- CONVTRANSPOSE ----
        self.trans1 = nn.ConvTranspose2d(in_channels=filters[3], out_channels=filters[2], kernel_size=kernel_size)
        self.trans2 = nn.ConvTranspose2d(in_channels=filters[2], out_channels=filters[1], kernel_size=kernel_size)
        self.trans3 = nn.ConvTranspose2d(in_channels=filters[1], out_channels=filters[0], kernel_size=kernel_size)
        
        # ---- BATCH NORMALIZATION ----
        self.batch1 = nn.BatchNorm2d(filters[2])
        self.batch2 = nn.BatchNorm2d(filters[1])
        
        # ---- LATENT ----
        size_before_latent = filters[-1] * (last_feature_map_size) ** 2
        self.linear1a = nn.Linear(neurons[2], neurons[1])
        self.linear1b = nn.Linear(neurons[1], neurons[0])
        self.linear2 = nn.Linear(neurons[0], size_before_latent)

    def forward(self, x):
        
        # Out of Latent
        x = self.linear1a(x)
        x = self.relu(x)
        x = self.linear1b(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.unflatten(x)        
        
        # Layer 1
        x = self.trans1(x)
        x = self.batch1(x)
        x = self.relu(x)
        
        # Layer 2
        x = self.trans2(x)
        x = self.batch2(x)
        x = self.relu(x)
        
        # Layer 3
        x = self.trans3(x)
        x = self.sigmoid(x)

        return x    

Finally, we define our VAE using the encoder and decoder. **Note: one could define the entire VAE in a single class, but it's good practice to build more complex neural network architectures from separate classes and combine them. It also makes it easier to call the encoder or decoder on its own.**

In [None]:
class VariationalAutoencoder(nn.Module):
    def __init__(self, params_dict):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(params_dict)
        self.decoder = Decoder(params_dict)

    def forward(self, x):
        z, mu, logvar = self.encoder(x)
        x = self.decoder(z)
        return x, mu, logvar

**Tip: many CNNs use max pooling as a downsampling method for more robust feature extraction, but we do not use a downsampling layer here. Instead, our convolutional layers downsample the feature maps themselves because we don't use zero padding: with `kernel_size=5`, each convolution shrinks the feature maps by 4 pixels in each dimension (28 → 24 → 20 → 16). The author found it difficult to train a VAE with a normally distributed latent space when using max pooling. When trained with max pooling, all the samples essentially collapsed onto a mean of 0, which does not represent a normal distribution.**
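
With stride 1 and no padding, each convolution shrinks the spatial size by `kernel_size - 1` and each transposed convolution grows it by the same amount. A sketch, assuming the hyperparameters from Section 3 (`filters=[1, 16, 32, 64]`, `kernel_size=5`), that traces the shapes with standalone, untrained layers rather than the `Encoder` and `Decoder` classes:

```python
import torch
from torch import nn

filters = [1, 16, 32, 64]  # assumed, matching params_dict in Section 3
kernel_size = 5

x = torch.zeros(1, filters[0], 28, 28)
sizes = [x.shape[-1]]

# Encoder direction: unpadded convolutions, 28 -> 24 -> 20 -> 16
for c_in, c_out in zip(filters[:-1], filters[1:]):
    x = nn.Conv2d(c_in, c_out, kernel_size)(x)
    sizes.append(x.shape[-1])

# Decoder direction: transposed convolutions, 16 -> 20 -> 24 -> 28
reversed_filters = filters[::-1]
for c_in, c_out in zip(reversed_filters[:-1], reversed_filters[1:]):
    x = nn.ConvTranspose2d(c_in, c_out, kernel_size)(x)
    sizes.append(x.shape[-1])

print(sizes)           # [28, 24, 20, 16, 20, 24, 28]
print(tuple(x.shape))  # (1, 1, 28, 28)

# Closed form used in the Encoder/Decoder constructors
last_feature_map_size = 28 + (1 - kernel_size) * (len(filters) - 1)
print(last_feature_map_size)  # 16
```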

## 3. Test Model Functionality <a id="test"></a>

Before training, we need to make sure our model is properly built, i.e. the expected input (2D 28x28 array) will return the expected output (2D 28x28 array). An error indicates the architecture is inconsistent in some way, such as unexpected input and output filters, unexpected input and output neurons, etc. Some ways to "break" the model are listed below:
- comment out a method in the constructor or forward
- manually change arguments in the methods to a different value

To start off, we define some hyperparameters and build our model.

In [None]:
params_dict = {'filters': [1, 16, 32, 64],
               'neurons': [256, 128, 4],
               'sub_array_size': x_length,
               'kernel_size': 5
              }

model = VariationalAutoencoder(params_dict)

Next, we change the shape of our image to be compatible with PyTorch. The input dimensions for images are (number of samples, number of input channels, y dimension, x dimension), which in our case is (1, 1, 28, 28).

In [None]:
index = 0
test_image = x_train_scale[index].reshape(1,1,x_length,x_length) 

After the dimensions are changed, we convert the image from a `np.array` to a `torch.Tensor`.

In [None]:
test_image_torch = torch.Tensor(test_image)

Now we can "reconstruct" our input image.

In [None]:
testoutput_torch, test_mu, test_logvar = model(test_image_torch)

Since there is no error, we know the model's layers fit together. We then detach the output from the computation graph using the `detach()` method and convert the `torch.Tensor` to a `np.array` using the `numpy()` method.

In [None]:
testoutput = testoutput_torch.detach().numpy()

Let's check the shape of the output to make sure it is what we expect. If it's not, we have to fix the parameters used to define the model.

In [None]:
print ('The shape of the output is {}.'.format(testoutput.shape))

Now let's plot the input and output. Since the model hasn't been trained, the output should look like random noise.

In [None]:
fig, axs = plt.subplots(1,2,figsize=[10,5])
axs[0].set_title('Training Scaled Image {}'.format(index))
axs[0].imshow(test_image[0,0].reshape(x_length, x_length))
axs[1].set_title('Reconstructed Output')
axs[1].imshow(testoutput[0,0].reshape(x_length, x_length))

In addition, it's good practice to know how many trainable parameters are in our model. The number of trainable parameters can be used as a proxy for estimating total training time. We define [a counting function](https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model) that prints the size of each trainable parameter tensor and returns the total number of trainable parameters in our model.

In [None]:
def count_parameters(model):
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        print([name, param])
        total_params+=param
    return total_params

In [None]:
count_parameters(model)
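
The counts printed above follow from simple formulas: a convolutional or transposed convolutional layer has `in_channels * out_channels * kernel_size**2` weights plus `out_channels` biases, a `Linear` layer has `in_features * out_features` weights plus `out_features` biases, and a `BatchNorm2d` layer has 2 trainable parameters per channel. A sketch checking these formulas against standalone PyTorch layers, using the sizes from `params_dict` above (the helper names are hypothetical):

```python
from torch import nn

def n_params(module):
    """Count the trainable parameters of a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

def conv_count(c_in, c_out, k):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out

def linear_count(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output neuron."""
    return n_in * n_out + n_out

flat = 64 * 16 * 16  # flattened size of the last encoder feature maps

# Check the formulas against PyTorch's own layers
assert n_params(nn.Conv2d(1, 16, 5)) == conv_count(1, 16, 5)
assert n_params(nn.ConvTranspose2d(64, 32, 5)) == conv_count(64, 32, 5)
assert n_params(nn.Linear(flat, 256)) == linear_count(flat, 256)
assert n_params(nn.BatchNorm2d(64)) == 2 * 64  # one scale and one shift per channel

encoder_linear = linear_count(flat, 256)  # Encoder.linear1a
decoder_linear = linear_count(256, flat)  # Decoder.linear2
print(encoder_linear, decoder_linear, encoder_linear + decoder_linear)
```

Together, the two layers connecting the 16,384 flattened values to the 256-neuron layer hold about 8.4 million parameters, the large majority of the model, so they are a natural first place to look when trying to shrink it.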

## 4. Set Training and Test Sets <a id="set"></a>

`DataLoader` accepts any object that supports `len()` and integer indexing, so either a NumPy array or a `Dataset` can be used. Here we show two ways to format the data to be PyTorch compatible.

1. **Use arrays:** Experienced Python users are more likely to be comfortable using and manipulating arrays. We will just reshape our images to have an input channel of 1, i.e. (1, 28, 28).

2. **Use LoadDataset:** In [Section 2](#build), we defined the `LoadDataset` class to format the data to be PyTorch compatible. The `Dataset` class comes with additional functionality specifically for PyTorch, but is beyond the scope of this tutorial.

We choose option 1 as the default, but option 2 can be uncommented below. Either choice gives the same training behavior and is up to user preference.

In [None]:
train_set = x_train_scale.reshape(x_train_size, 1, x_length, x_length)
val_set = x_test_scale.reshape(x_test_size, 1, x_length, x_length)

In [None]:
# LoadDataset class
#train_set = LoadDataset(x_train_scale.reshape(x_train_size, 1, x_length, x_length))
#val_set = LoadDataset(x_test_scale.reshape(x_test_size, 1, x_length, x_length))

We also need to define a baseline for our model to outperform. The baseline helps us understand whether our model is learning anything at all. We choose each image's mean pixel value as our baseline, i.e. a model that learned nothing useful might reconstruct every image as a flat image of its mean pixel value. By calculating the squared error between each test image and its flat mean image, summed over pixels and averaged over the test set (this notebook refers to the per-image summed squared error as the [Mean Squared Error (MSE)](https://en.wikipedia.org/wiki/Mean_squared_error)), we have an established baseline to outperform.

In [None]:
# find mean pixel values of each image
mean = np.mean(x_test_scale, axis=(1,2)).reshape(x_test_size,1,1)

# create mean pixel value images
ones = np.ones((x_test_size, x_length, x_length))
mean_ones = mean * ones

# calculate baseline
baseline = np.sum(np.square(x_test_scale - mean_ones)) / (x_test_size)

baseline
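
To make the baseline concrete, a sketch on two hypothetical 2x2 "images": each image is replaced by a flat image of its own mean, the squared errors are summed over pixels, and the sums are averaged over images. This is the same computation as the cell above, using NumPy broadcasting instead of `np.ones`.

```python
import numpy as np

# Two hypothetical 2x2 images with pixel values in [0, 1]
images = np.array([[[0.0, 1.0],
                    [1.0, 0.0]],
                   [[0.5, 0.5],
                    [0.5, 0.5]]])

# Per-image mean pixel value, shaped (n_images, 1, 1) for broadcasting
means = images.mean(axis=(1, 2), keepdims=True)

# Squared error summed over pixels, then averaged over images
per_image_sse = np.sum((images - means) ** 2, axis=(1, 2))
baseline_toy = per_image_sse.mean()

print(per_image_sse)  # first image: 4 * 0.5**2 = 1.0, second image: 0.0
print(baseline_toy)   # 0.5
```

A flat image already matches its own mean exactly, so only images with contrast contribute to the baseline error.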

## 5. Hyperparameters and Loading <a id="hyper"></a>

We must set some other hyperparameters for training: batch size, shuffle, and number of workers. Batch size can be tuned as needed to improve results. Shuffle should almost always be `True` so that batches are not drawn from the data in a fixed order. The number of workers defaults to 0, which loads the data in the main process. We also choose the number of epochs we wish to train for.

In [None]:
torch.manual_seed(42)

# Prepping arguments we have to feed to `DataLoader`
params = {
        'batch_size': 128,
        'shuffle': True,
        'num_workers': 0
    }

# Number of epochs to train for
num_epochs = 10

Another useful metric to calculate is how many updates our model will perform during training. We can calculate this by finding the number of batches in the training set (number of training samples / batch size, rounded up, since `DataLoader` keeps the final partial batch by default) and multiplying it by the number of epochs. Knowing how many batches our model might need to be well trained can be a good place to start when tuning hyperparameters.

In [None]:
print ('The model will train using a total of {} batches'.format(num_epochs * 
                                                       int(np.ceil(x_train_size / params['batch_size']))))
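
Note that `DataLoader` keeps the final partial batch unless `drop_last=True`, so one epoch over 60,000 images with a batch size of 128 is 469 batches (468 full batches plus one of 96 images). A sketch confirming this with `len()` on a `DataLoader` built over a placeholder array of the same length (the placeholder is an assumption; no MNIST data is needed):

```python
import math
import numpy as np
from torch.utils.data import DataLoader

n_samples, batch_size, num_epochs = 60_000, 128, 10

# Placeholder data with the same number of samples as the MNIST training set
placeholder = np.zeros((n_samples, 1), dtype=np.float32)

loader = DataLoader(placeholder, batch_size=batch_size, shuffle=True)
loader_dropped = DataLoader(placeholder, batch_size=batch_size, drop_last=True)

print(len(loader))               # 469 = ceil(60000 / 128)
print(len(loader_dropped))       # 468, the final partial batch of 96 is dropped
print(num_epochs * len(loader))  # 4690 updates over 10 epochs
assert len(loader) == math.ceil(n_samples / batch_size)
```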

With our hyperparameters set, we can load our training and test set using `DataLoader`. 

**Note: the variable and function names in this notebook refer to validation sets, but we use them for the test set instead.** That is, we use the terms validation set and test set interchangeably here.

In [None]:
# TRAINING SET
train_loader = DataLoader(train_set, **params)

# VALIDATION SET
valid_loader = DataLoader(val_set, **params)

We will initialize our model again to be sure we are starting from scratch.

In [None]:
model = VariationalAutoencoder(params_dict)

Now we define our loss function, which is the sum of a reconstruction loss and a divergence.
- The reconstruction loss ensures the decoder accurately reconstructs our input. We choose [Binary Cross Entropy (BCE)](https://en.wikipedia.org/wiki/Cross_entropy) as the reconstruction loss. MSE is an alternative, but there are mathematical motivations beyond the scope of this tutorial as to why it isn't the best choice here (see [Additional Resources](#add) for more information). In short, MNIST pixels are mostly close to 0 or 1, so they are better approximated by independent Bernoulli distributions (which BCE corresponds to) than by a multivariate Gaussian distribution (which MSE corresponds to). In the cell below, you can uncomment the MSE reconstruction loss to use it instead of BCE. **Note: using a ReLU for the final activation instead of a sigmoid may better optimize MSE, since the background will likely become exactly 0 instead of close to 0 (i.e. the summation of millions of 1e-6 residuals becomes noticeable).**
- A divergence measures the statistical distance between two probability distributions. We use the [Kullback-Leibler divergence (also known as KL divergence or $D_{KL}$)](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) as our divergence. Since we want the latent space to approximate a standard normal distribution, the KL divergence enforces that constraint. In our encoder, the final outputs are the $\mu$ and $\log(\sigma^2)$ vectors, which we use to compute the divergence between the encoder's distribution $\mathcal{N}(\mu, \sigma^2)$ and the standard normal distribution. Some notes about KL divergence:
    - Here is a [derivation](https://leenashekhar.github.io/2019-01-30-KL-Divergence/) of the KL divergence between normal distributions
    - KL divergence is asymmetric: in general, $D_{KL}(p||q) \neq D_{KL}(q||p)$
    - KL divergence between a distribution and itself is 0: $D_{KL}(p||p)=0$
    - KL divergence is non-negative: $D_{KL}(p||q) \geq 0$
In [None]:
reconstruction_loss = nn.BCELoss(reduction='sum')
#reconstruction_loss = nn.MSELoss(reduction='sum')
def distance(output, data, mu, logvar):
    recon_loss = reconstruction_loss(output, data)
    KLD = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar)).sum()
    return recon_loss, KLD

Then we choose [Adam](https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam) as our optimizer, since it adapts the step size of each parameter automatically and typically trains faster than plain [Stochastic Gradient Descent (SGD)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent). We keep Adam's default learning rate and add a small weight decay (`weight_decay=1e-5`) as L2 regularization.

In [None]:
optimizer = torch.optim.Adam(model.parameters(),  weight_decay=1e-5)

If a CUDA-capable GPU is available, it will be used for training. If not, the model will train on the CPU.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device);

Let's print the device to make sure we know what's available.

In [None]:
device

## 6. Train Model <a id="train"></a>

In order to train our model, we have to loop through our data manually. This is probably the biggest difference between PyTorch and the `fit` method of [TensorFlow](https://www.tensorflow.org/)'s Keras API, but it allows for more hands-on control of how training is performed, which can be advantageous. We will train our model as follows:
1. Switch the model to training mode, so batch normalization uses per-batch statistics and updates its running averages
2. Initialize the training losses to 0
3. Loop through each batch of features by:
    - Putting the data onto your device
    - Calculating the outputs and the loss
    - Performing backpropagation and adding the batch training losses to the total training losses
4. Normalize the total training losses by the number of samples

In [None]:
# Define train loop

def train_model(train_loader):

    # Change model to training mode (batch normalization uses batch statistics)
    model.train()
    
    # Initialize training loss
    train_loss_recon = 0
    train_loss_kld = 0
    
    # Loop through batches of training data
    for data in train_loader:
        
        # Put training batch on device
        data = data.float().to(device)

        # Calculate output and loss from training batch
        output, mu, logvar = model(data)
        recon_loss, kld = distance(output, data, mu, logvar)
        loss = recon_loss + kld
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_recon += recon_loss.item()
        train_loss_kld += kld.item()
    
    # Normalize training loss from one epoch
    norm = len(train_loader.dataset)  # works for both arrays and Dataset objects
    train_loss_recon_norm = train_loss_recon / norm
    train_loss_kld_norm = train_loss_kld / norm
    train_loss_norm = [train_loss_recon_norm, train_loss_kld_norm]
    
    return train_loss_norm

In addition, we define a similar loop for evaluating the test set at each epoch, which tells us whether our model is generalizing. We will test our model as follows:
1. Switch the model to evaluation mode, so batch normalization uses its running statistics (gradient tracking is turned off separately with `torch.no_grad()`)
2. Initialize test losses and variances to 0
3. Loop through each batch of features by:
    - Putting the data onto your device
    - Calculating the outputs and the loss
4. Normalize the total test losses by the number of samples
5. Calculate the R2 score (coefficient of determination), the fraction of pixel variance explained by the reconstructions: 1 is a perfect reconstruction, 0 is no better than predicting each image's mean pixel value, and negative values are worse than that
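
Written out, the R2 score computed in `validate_model` is the following, where $i$ indexes test images, $p$ indexes pixels, $\hat{x}$ is the reconstruction, and $\bar{x}_i$ is the mean pixel value of image $i$:

$$R^2 = 1 - \frac{\sum_i \sum_p \left(x_{i,p} - \hat{x}_{i,p}\right)^2}{\sum_i \sum_p \left(x_{i,p} - \bar{x}_i\right)^2}$$

The denominator is the mean-pixel baseline error from [Section 4](#set), summed rather than averaged over images, so the score compares the model's reconstruction error directly against that baseline.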

In [None]:
# Define validation loop

def validate_model(valid_loader):

    # Change model to evaluation mode (batch normalization uses running statistics)
    model.eval()
    
    # Initialize validation loss and variances
    val_loss_recon = 0
    val_loss_kld = 0
    data_variance = 0
    res_variance = 0
    
    # Do not calculate gradients for the loop
    with torch.no_grad():
        
        # Loop through batches of validation data
        for data in valid_loader:
            
            # Put validation batch on device
            data = data.float().to(device)
            
            # Calculate output and loss from validation batch
            output, mu, logvar = model(data)
            recon_loss, kld = distance(output, data, mu, logvar)
            
            # Calculate variances
            data = data.detach().cpu().numpy()
            output = output.detach().cpu().numpy()
            
            data_mean = np.mean(data, axis=(1,2,3)).reshape(data.shape[0], 1, 1, 1)
            data_var = np.nansum((data - data_mean)**2)
            res_var = np.nansum((data - output)**2)
            
            data_variance += data_var
            res_variance += res_var
            
            val_loss_recon += recon_loss.item()
            val_loss_kld += kld.item()
    
    # Normalize validation loss from one epoch
    norm = len(valid_loader.dataset)  # works for both arrays and Dataset objects
    val_loss_recon_norm = val_loss_recon / norm
    val_loss_kld_norm = val_loss_kld / norm
    val_loss_norm = [val_loss_recon_norm, val_loss_kld_norm]
    
    # Calculate r2 score
    r2_score = 1 - res_variance / data_variance
    
    return val_loss_norm, r2_score

Finally, we can train our model! We will print out the train and test losses and the R2 score per epoch to keep track of performance. The loop below runs the training and validation loops defined above and records our metrics. **Warning: training may take close to an hour; we want our decoder to do a decent job reconstructing the inputs and our latent space to be as smooth as possible. If you have time or computational constraints, feel free to decrease the number of filters and/or neurons for faster but less accurate training. In practice, the author has found the trade-off between increased performance and increased training time to be worth it.**

In [None]:
# keep track of metrics
lst_train_loss = []
lst_val_loss = []
lst_r2_score = []

# training loop
for epoch in tqdm(range(num_epochs), total=num_epochs):

    # Go through loops
    train_loss = train_model(train_loader)
    val_loss, r2_score = validate_model(valid_loader)

    # Append metrics
    lst_train_loss.append(train_loss)
    lst_val_loss.append(val_loss)
    lst_r2_score.append(r2_score)

    # Log
    print('Recon Loss Epoch {} - Train loss: {:.4f} - Val Loss: {:.4f}'.format(
            epoch, train_loss[0], val_loss[0]))
    print('KLD Epoch {} - Train loss: {:.4f} - Val Loss: {:.4f}'.format(
            epoch, train_loss[1], val_loss[1]))
    print('R2 Score: {:.4f}'.format(r2_score))

## 7. Plot Loss Function and R2 <a id="plot"></a>

We plot the train/test loss and R2 scores to determine how well converged our model is.

In [None]:
lst_train_loss = np.array(lst_train_loss)
lst_val_loss = np.array(lst_val_loss)

fig, axs = plt.subplots(1, 3, figsize=[15,5])

axs[0].set_title('Reconstruction Loss')
axs[0].plot(np.arange(num_epochs), lst_train_loss[:, 0], label='train')
axs[0].plot(np.arange(num_epochs), lst_val_loss[:, 0], label='val')
axs[0].set_xlabel('Epochs')
axs[0].legend()

axs[1].set_title('KLD')
axs[1].plot(np.arange(num_epochs), lst_train_loss[:, 1], label='train')
axs[1].plot(np.arange(num_epochs), lst_val_loss[:, 1], label='val')
axs[1].set_xlabel('Epochs')
axs[1].legend()

axs[2].set_title('R2 Score')
axs[2].plot(np.arange(num_epochs), lst_r2_score, color='C1')
axs[2].set_xlabel('Epochs')

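Don't worry if the KL divergence increases during training; it acts as a regularizer. The reconstruction loss pushes the encoder to store information about each input in the latent space, while the KL divergence pushes the latent space toward a standard normal distribution, so the two terms compete.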

Let's also compare our baseline to our model over the test set to see how good our reconstructions are on average. First, we predict the outputs of our test set.

In [None]:
model.eval()
with torch.no_grad():
    output, mu, logvar = model(torch.Tensor(val_set).to(device))
recon = output.cpu().numpy()
mse = np.sum((val_set - recon) ** 2, axis=(1,2,3))

mu = mu.cpu().numpy()
logvar = logvar.cpu().numpy()

Now we can compare the baseline and the model.

In [None]:
print ('The model is performing {:.4f} times better than the baseline'.format(baseline / (mse.sum() / x_test_size)))

## 8. Analyze Samples <a id="analyze"></a>
    
Now that our model is trained, let's plot random samples, their reconstructions, and their squared residuals to see how well our decoder reconstructs inputs.

In [None]:
# choose random image and corresponding output from test set
rand_index = np.random.randint(x_test_size)
rand_image = x_test_scale[rand_index]
rand_recon = recon[rand_index][0]
rand_sq_res = (rand_image - rand_recon) ** 2
rand_mse = mse[rand_index]

# plot input, output, and squared residuals
fig, axs = plt.subplots(2,2,figsize=[10,10])
axs[0,0].set_title('Testing Scaled Image {}'.format(rand_index))
axs[0,0].imshow(rand_image)
axs[0,1].set_title('Reconstructed Output'.format(rand_index))
axs[0,1].imshow(rand_recon)
axs[1,0].set_title('Squared Residual Image')
axs[1,0].imshow(rand_sq_res)
axs[1,1].set_title('Squared Residual Image (0-1 min-max)')
axs[1,1].imshow(rand_sq_res, vmin=0, vmax=1)
plt.tight_layout()

print ('MSE: {:.4f}'.format(rand_mse))

Now, let's plot the distribution of MSEs to get a better understanding of how well our model reconstructs each test sample.

In [None]:
plt.figure(figsize=[10,5])
plt.title('Test Set MSE Distribution')
plt.hist(mse, bins=50)
plt.xlabel('mse')
plt.ylabel('frequency')

In addition, let's see whether the loss differs by class. If the per-class distributions overlap and cover similar ranges, the model generalizes to all classes.

In [None]:
plt.figure(figsize=[10,5])
plt.title('Test Set MSE Distribution (by class)')
for digit in range(10):
    plt.hist(mse[y_test == digit], bins=50, label=digit, alpha=0.25)
plt.xlabel('mse')
plt.ylabel('frequency')
plt.legend()
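Ten overlapping histograms can be hard to read, so a numeric summary is a useful complement. This minimal sketch assumes `mse` and `y_test` are NumPy arrays of equal length, as in the cell above. It reports the median and 95th percentile because the error distribution has a long right tail. `per_class_summary` is a hypothetical helper, and the toy data below is synthetic.

```python
import numpy as np

def per_class_summary(mse, labels):
    """Return {class: (count, median, 95th percentile)} of the per-sample errors."""
    summary = {}
    for c in np.unique(labels):
        errs = mse[labels == c]
        summary[int(c)] = (errs.size, float(np.median(errs)), float(np.percentile(errs, 95)))
    return summary

# synthetic stand-ins for the notebook's `mse` and `y_test`
rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=300)
mse_toy = rng.gamma(2.0, 5.0, size=300) + 10.0 * (labels == 2)  # class 2 made harder on purpose

for c, (count, med, p95) in per_class_summary(mse_toy, labels).items():
    print(f'class {c}: n={count:4d}  median={med:7.2f}  p95={p95:7.2f}')
```

With the notebook's arrays, call `per_class_summary(mse, y_test)`. A class whose median sits well above the others is one the model reconstructs noticeably worse.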

Although the model performs much better than the baseline, there are still samples it struggles with. Let's see how many samples have an MSE more than 3 sigma above the mean.

In [None]:
threshold = mse.mean() + 3 * mse.std()
mask = mse > threshold

print ('There are {} MSEs above {:.4f}.'.format(mask.sum(), threshold))
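Because `mask` is a boolean array, a position within `mse[mask]` is not the same as an index into the full test set. `np.flatnonzero(mask)` converts the mask back into the original indices, which is handy for reporting or revisiting specific samples. Here is a minimal sketch with a hypothetical helper `outlier_indices` and toy errors:

```python
import numpy as np

def outlier_indices(errors, n_sigma=3.0):
    """Indices (into the full array) of errors more than n_sigma std devs above the mean."""
    threshold = errors.mean() + n_sigma * errors.std()
    return np.flatnonzero(errors > threshold), threshold

# twenty typical errors and one clear outlier at index 20
errors = np.array([1.0] * 20 + [10.0])
idx, thr = outlier_indices(errors)
print(idx)  # [20]
```

With the notebook's arrays, `outlier_indices(mse)` returns the same threshold as the cell above, along with the test-set indices of the flagged samples.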

Now with our mask, we can look through "poorly" reconstructed samples. **Note: because the model applies the reparameterization trick, it decodes a random sample from each input's latent distribution, so a poor reconstruction could be partly due to high variance added to the sample in the latent space.**

In [None]:
# choose a random poorly reconstructed image and its output from the test set
# (np.flatnonzero maps the boolean mask back to test-set indices)
masked_indices = np.flatnonzero(mask)
rand_index = np.random.choice(masked_indices)
rand_image = x_test_scale[rand_index]
rand_recon = recon[rand_index][0]
rand_sq_res = (rand_image - rand_recon) ** 2
rand_mse = mse[rand_index]

# plot input, output, and squared residuals
fig, axs = plt.subplots(2,2,figsize=[10,10])
axs[0,0].set_title('Testing Masked Scaled Image {}'.format(rand_index))
axs[0,0].imshow(rand_image)
axs[0,1].set_title('Reconstructed Output')
axs[0,1].imshow(rand_recon)
axs[1,0].set_title('Squared Residual Image')
axs[1,0].imshow(rand_sq_res)
axs[1,1].set_title('Squared Residual Image (0-1 min-max)')
axs[1,1].imshow(rand_sq_res, vmin=0, vmax=1)
plt.tight_layout()

print ('MSE: {:.4f}'.format(rand_mse))
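To see how much noise sampling adds, here is a minimal NumPy sketch of the reparameterization trick. It assumes the usual form $z = \mu + \sigma\epsilon$ with $\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$ and $\epsilon \sim N(0, 1)$. The notebook's own implementation lives in the model class defined earlier; `reparameterize` here is a hypothetical standalone version.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps, with sigma = exp(0.5 * logvar) and eps ~ N(0, I)."""
    sigma = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps

rng = np.random.default_rng(42)
mu_toy = np.array([[0.5, -1.0]])
logvar_toy = np.log(np.array([[0.25, 0.25]]))  # variance 0.25, so sigma = 0.5

# encode the "same image" 10,000 times: every draw lands on a different latent point
draws = np.concatenate([reparameterize(mu_toy, logvar_toy, rng) for _ in range(10_000)])
print(draws.mean(axis=0))  # close to [0.5, -1.0]
print(draws.std(axis=0))   # close to [0.5, 0.5]
```

Averaged over many draws, the samples recover $\mu$ and $\sigma$. Any single draw, however, can land up to a few $\sigma$ from the mean, so an input encoded with a large variance may decode poorly. Decoding `mu` directly instead of a sample is a common way to get deterministic reconstructions when diagnosing errors.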

## 9. Visualize the Latent Space <a id="latent"></a>

As mentioned in the [Introduction](#intro), one use case of an autoencoder is reducing the dimensionality of a dataset. Because the decoder reconstructs our inputs well from the latent space alone, the latent representation must retain most of the information in the inputs. By using the encoder as a feature extractor, we can visualize the data in the latent space.

In [None]:
fig, axs = plt.subplots(1, 2, figsize=[20,10])
for digit in range(10):
    axs[0].scatter(mu[:, 0][y_test==digit], mu[:, 2][y_test==digit], label=digit, alpha=0.25)
    axs[1].scatter(mu[:, 1][y_test==digit], mu[:, 3][y_test==digit], label=digit, alpha=0.25)

axs[0].set_xlabel('mu 0')
axs[0].set_ylabel('mu 2')
axs[1].set_xlabel('mu 1')
axs[1].set_ylabel('mu 3')

axs[0].legend()
axs[1].legend()

plt.tight_layout()

Even if the classes do not separate distinctly, it's still impressive that the decoder can work out what a digit looks like from a two-dimensional space! There are a few things you can try to get a better representation in the latent space:
- Increase the number of dimensions of the latent space: currently we are using 2 dimensions because they are easy to plot and visualize. However, the more dimensions the latent space has, the more information it can store. The learned manifold should improve noticeably with slightly more dimensions (5-10), which is still far fewer than the original input space (784).
- Increase the depth/width of the VAE: currently we are using 3 convolutional layers and 2 fully connected layers. A deeper or wider network can extract higher-level features that may better represent the digits.
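With more than two latent dimensions, the latent means can no longer be plotted directly. One option, which the notebook itself does not use, is to project them onto their first two principal components. Here is a minimal NumPy sketch via the SVD, using a hypothetical helper `pca_2d` and random stand-in data:

```python
import numpy as np

def pca_2d(latents):
    """Project (n_samples, latent_dim) latent vectors onto their top two principal components."""
    centered = latents - latents.mean(axis=0)
    # rows of vt are the principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

# random stand-in for an 8-dimensional `mu` array
rng = np.random.default_rng(1)
latents = rng.standard_normal((500, 8))
proj = pca_2d(latents)
print(proj.shape)  # (500, 2)
```

With the notebook's arrays, `proj = pca_2d(mu)` followed by `plt.scatter(proj[:, 0], proj[:, 1], c=y_test)` would give a single overview plot. PCA is linear, so curved structure in the latent space can still overlap in the projection.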

The encoded means should approximate a standard normal distribution, $N(\mu\approx0,\sigma^2\approx1)$, although how close counts as "approximate" is subjective. Let's check whether that's true. **Note: the $\log(\sigma^2)$ values, however, can follow any distribution.**

In [None]:
# mean and variance of the encoded means, for every latent dimension
for k in range(mu.shape[1]):
    print('Latent {} Mean {:.3f} and Variance {:.3f}'.format(k, mu[:, k].mean(), mu[:, k].var()))
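The check above uses only the encoder means. Each test image is encoded as a Gaussian with mean $\mu_i$ and variance $\sigma_i^2 = \exp(\log\sigma_i^2)$. By the law of total variance, the pooled latent distribution therefore has variance $\mathrm{Var}_i[\mu_i] + \mathbb{E}_i[\sigma_i^2]$, so a variance of `mu` below 1 does not by itself mean the latent space is too narrow. Here is a minimal sketch; the helper `aggregate_latent_moments` and the synthetic data are hypothetical.

```python
import numpy as np

def aggregate_latent_moments(mu, logvar):
    """Per-dimension mean and variance of z pooled over all samples.

    Each sample i is encoded as N(mu_i, exp(logvar_i)); by the law of total
    variance, Var[z] = Var_i[mu_i] + E_i[exp(logvar_i)].
    """
    mean = mu.mean(axis=0)
    var = mu.var(axis=0) + np.exp(logvar).mean(axis=0)
    return mean, var

# synthetic encodings: means spread with variance 0.36, each encoding with variance 0.64
rng = np.random.default_rng(7)
mu_toy = rng.normal(0.0, 0.6, size=(20_000, 2))
logvar_toy = np.full_like(mu_toy, np.log(0.64))

mean, var = aggregate_latent_moments(mu_toy, logvar_toy)
print(mu_toy.var(axis=0))  # roughly 0.36 per dimension: the means alone look too narrow
print(var)                 # roughly 1.0 per dimension once the encoder variance is added
```

With the notebook's arrays, `aggregate_latent_moments(mu, logvar)` gives the pooled mean and variance for each latent dimension.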

Finally, we can [plot the learned manifold](https://github.com/eugeniaring/Pytorch-tutorial/blob/main/VAE_mnist.ipynb) of our latent space to see how samples smoothly change across it.

In [None]:
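def plot_reconstructed(encoder, decoder, N=3, n=15):
    # decode an n x n grid of latent points spanning [-N, N] along latent dims 0 and 2
    # (the encoder argument is unused; it is kept so the call below still works)
    r0 = (-N, N)
    r1 = (-N, N)
    plt.figure(figsize=(10,10))
    w = 28
    img = np.zeros((n*w, n*w))
    with torch.no_grad():
        for i, y in enumerate(np.linspace(*r1, n)):
            for j, x in enumerate(np.linspace(*r0, n)):
                # all other latent dims are fixed at 0
                z = torch.Tensor([[x, 0, y, 0]]).to(device)
                x_hat = decoder(z)
                x_hat = x_hat.reshape(w, w).cpu().numpy()
                # fill rows from the bottom up so larger y appears higher in the image
                img[(n-1-i)*w:(n-i)*w, j*w:(j+1)*w] = x_hat
    plt.imshow(img, extent=[*r0, *r1], cmap='gist_gray')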

In [None]:
plot_reconstructed(model.encoder, model.decoder)  
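The function above decodes one latent point at a time and hard-codes a four-dimensional latent vector with dimensions 1 and 3 fixed at zero. A more general version can split the work into two steps: building the grid of latent points (`latent_grid`) and stitching the decoded images into one canvas (`tile`). Both are hypothetical helpers. The four-dimensional example below mirrors the `[x, 0, y, 0]` vector used above.

```python
import numpy as np

def latent_grid(n, extent, latent_dim, dims=(0, 1)):
    """Latent vectors on an n x n grid over two chosen dims; all other dims stay at 0.

    Rows are in row-major order with the top row first, so y decreases down the
    image, matching how imshow draws the canvas.
    """
    xs = np.linspace(-extent, extent, n)
    ys = np.linspace(extent, -extent, n)  # top of the image = largest y
    gy, gx = np.meshgrid(ys, xs, indexing='ij')
    z = np.zeros((n * n, latent_dim))
    z[:, dims[0]] = gx.ravel()
    z[:, dims[1]] = gy.ravel()
    return z

def tile(images, n):
    """Arrange n*n images of shape (..., h, w) into one (n*h, n*w) canvas, row-major."""
    h, w = images.shape[-2:]
    return images.reshape(n, n, h, w).transpose(0, 2, 1, 3).reshape(n * h, n * w)

# 3 x 3 grid over latent dims 0 and 2 of a 4-dimensional latent space
grid = latent_grid(3, 1.0, latent_dim=4, dims=(0, 2))
print(grid.shape)  # (9, 4)

fake = np.arange(4 * 2 * 2, dtype=float).reshape(4, 1, 2, 2)  # four 2x2 "images"
print(tile(fake, 2).shape)  # (4, 4)
```

Assuming the decoder accepts a batch of latent vectors, something like `z = torch.Tensor(latent_grid(15, 3, latent_dim=4, dims=(0, 2))).to(device)`, then `model.decoder(z)` inside `torch.no_grad()`, then `plt.imshow(tile(decoded.cpu().numpy(), 15), extent=[-3, 3, -3, 3], cmap='gist_gray')` would decode the whole grid in a single forward pass instead of $n^2$ separate calls.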

## 10. Generate Samples <a id="detect"></a>

As mentioned in the [Introduction](#intro), a VAE can also generate new samples. Because the KL term pushes the latent space toward a standard normal distribution, we can draw random points from $N(0, 1)$ and decode them into new digits.

In [None]:
n = 6
# draw n**2 latent vectors from the standard normal prior
z = torch.Tensor(np.random.normal(0, 1, (n**2, params_dict['neurons'][-1]))).to(device)
with torch.no_grad():
    generated_z = model.decoder(z).cpu().numpy()

Now let's plot our generated images.

In [None]:
# one subplot per generated sample, laid out as an n x n grid
fig, axs = plt.subplots(n, n, figsize=[10,10])
for i in range(n):
    for j in range(n):
        axs[i, j].imshow(generated_z[i*n+j][0])
plt.tight_layout()

We can definitely tell that some of them are fake, but it's not a bad start considering these 28x28 images are decoded from a 2D space! We can improve the quality of the generated samples by training a wider/deeper VAE, increasing the number of latent dimensions, and/or training until full convergence.

## 11. Conclusions <a id="con"></a>

The variational autoencoder is a popular and powerful unsupervised learning method for dimensionality reduction and sample generation. It extends the traditional autoencoder by regularizing the latent space toward a normal distribution, which makes the latent space smooth enough to sample from. There's a reason it's become a staple for discovering the structure of complex datasets.

Thank you for walking through this notebook. Now you should be more familiar with:
- building a variational autoencoder (VAE)
- training and evaluating a VAE
- visualizing the latent space of a VAE
- generating samples from the latent space

**Congratulations, you have completed the notebook!**

## Additional Resources <a id="add"></a>

Machine learning is a dense and rapidly evolving field of study. Becoming an expert takes years of practice and patience, but hopefully this notebook brought you a step closer. Here are some of the author's favorite resources for learning about machine learning and data science:

- [Google Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/ml-intro)
- [scikit-learn Python Library](https://scikit-learn.org/stable/index.html) (go-to for most ML algorithms besides neural networks)
- [StatQuest YouTube Channel](https://www.youtube.com/c/joshstarmer)
- [DeepLearningAI YouTube Channel](https://www.youtube.com/c/Deeplearningai/videos)
- [Towards Data Science](https://towardsdatascience.com/) (articles about data science and machine learning, some involving example blocks of code)
- Advanced search on [arxiv](https://arxiv.org/search/advanced) (e.g., search for the term "machine learning" in Abstract with Subject astro-ph) to see what others are currently doing
- Google, YouTube, and Wikipedia in general
- [Variational Autoencoder Original Paper](https://arxiv.org/abs/1312.6114)
- MSE vs BCE
    - [A Tutorial on VAEs](https://arxiv.org/pdf/2006.10273.pdf) (see section 5.1 for arguments against MSE)
    - [Maximizing Log Likelihood](https://www.expunctis.com/2019/01/27/Loss-functions.html)
    - [StatsStackExchange](https://stats.stackexchange.com/questions/350211/loss-function-autoencoder-vs-variational-autoencoder-or-mse-loss-vs-binary-cross)
- Supplementary GitHub Repos used for learning VAEs in PyTorch:
    - [PyTorch Beginner - VAE](https://github.com/L1aoXingyu/pytorch-beginner/blob/master/08-AutoEncoder/Variational_autoencoder.py)
    - [PyTorch Tutorials - VAE](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/variational_autoencoder/main.py)
    - [Medium Articles - VAE](https://github.com/eugeniaring/Medium-Articles/blob/main/Pytorch/VAE_mnist.ipynb)
    

## About this Notebook <a id="about"></a>

**Author:** Fred Dauphin, DeepWFC3

**Updated on:** 2021-12-14

## Citations <a id="cite"></a>

If you use `numpy`, `matplotlib`, or `torch` for published research, please cite the authors. Follow these links for more information about citing `numpy`, `matplotlib`, and `torch`:

* [Citing `numpy`](https://numpy.org/doc/stable/license.html)
* [Citing `matplotlib`](https://matplotlib.org/stable/users/project/license.html#:~:text=Matplotlib%20only%20uses%20BSD%20compatible,are%20acceptable%20in%20matplotlib%20toolkits.)
* [Citing `torch`](https://github.com/pytorch/pytorch/blob/master/LICENSE)

***
[Top of Page](#title)