# Variational Autoencoders

A variational autoencoder (VAE) is a generative model that combines neural networks with latent variable models. It consists of an encoder, which maps an input to a compressed latent representation, and a decoder, which seeks to reconstruct the input from that latent representation. The VAE is a good example of how variational inference can be used to solve machine learning problems. We demonstrate how to implement a VAE in Pyro.


First, let's import the packages and modules we need. We really only require PyTorch and Pyro; everything else is optional if you prefer to use your own data loader, visualization library, etc. For this example, we will use [visdom](https://github.com/facebookresearch/visdom) to visualize the output images.

In [2]:
import argparse
import numpy as np

# import pyro!
import pyro
from pyro.infer.kl_qp import KL_QP
from pyro.distributions import DiagNormal, Normal
from pyro.util import ng_zeros, ng_ones

# modules from torch
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

import visdom # (optional) for visualization

Next, let's load some data. For simplicity we'll use [MNIST](http://yann.lecun.com/exdb/mnist/), which contains 28x28 grayscale images of handwritten digits; each image is flattened into a vector of length 784 before it is passed to the encoder.

In [3]:
# path to data
root = './data'
download = True
trans = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dset.MNIST(
    root=root,
    train=True,
    transform=trans,
    download=download)
test_set = dset.MNIST(root=root, train=False, transform=trans)

# Use batch size of 128
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False, **kwargs)
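
To see what the encoder actually receives, the transform pipeline can be mimicked in NumPy. `ToTensor()` scales 8-bit pixels from [0, 255] to [0, 1], and `Normalize((0.5,), (1.0,))` then subtracts 0.5 and divides by 1.0, so pixel values end up in [-0.5, 0.5]. The `preprocess` helper and the synthetic image below are hypothetical illustrations; the final reshape mirrors the `x.view(-1, 784)` performed in the encoder.

```python
import numpy as np


def preprocess(image_uint8, mean=0.5, std=1.0):
    """Hypothetical helper mimicking ToTensor() + Normalize((mean,), (std,)) + flattening."""
    x = image_uint8.astype(np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    x = (x - mean) / std                        # Normalize: [0, 1] -> [-0.5, 0.5]
    return x.reshape(-1)                        # 28x28 -> vector of length 784


# a synthetic 28x28 "image" whose pixels cover the full 0..255 range
img = (np.arange(784) % 256).reshape(28, 28).astype(np.uint8)
vec = preprocess(img)
print(vec.shape, vec.min(), vec.max())
```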

## Encoder

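Our goal is to generate high-quality images of handwritten digits. We first define our encoder, a neural net with one fully connected hidden layer of 200 units followed by two fully connected output layers. It maps an input vector of size 784 through a ReLU non-linearity to the mean $\mu$ and standard deviation $\sigma$ of a Gaussian over a 20-dimensional latent representation; the exponential keeps $\sigma$ positive.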

In [None]:
class Encoder(nn.Module):

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(784, 200)
        self.fc21 = nn.Linear(200, 20)
        self.fc22 = nn.Linear(200, 20)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 784)
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), torch.exp(self.fc22(h1))
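
To make the shapes concrete, here is a NumPy sketch of the same forward pass. Randomly initialized weights stand in for the trained `nn.Linear` layers, and the seed, weight scale, and input batch are illustrative assumptions. A batch of 128 MNIST-shaped inputs comes out as a 128 x 20 mean and a strictly positive 128 x 20 standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)


def linear(in_dim, out_dim):
    # random weights and zero bias standing in for a trained nn.Linear layer
    return rng.normal(scale=0.01, size=(in_dim, out_dim)), np.zeros(out_dim)


W1, b1 = linear(784, 200)   # like fc1
W21, b21 = linear(200, 20)  # like fc21 (mean head)
W22, b22 = linear(200, 20)  # like fc22 (log-std head)


def encode(x):
    x = x.reshape(-1, 784)              # same as x.view(-1, 784)
    h1 = np.maximum(0.0, x @ W1 + b1)   # ReLU hidden layer
    z_mu = h1 @ W21 + b21               # mean of the latent Gaussian
    z_sigma = np.exp(h1 @ W22 + b22)    # exp keeps the std positive
    return z_mu, z_sigma


batch = rng.uniform(-0.5, 0.5, size=(128, 1, 28, 28))
z_mu, z_sigma = encode(batch)
print(z_mu.shape, z_sigma.shape, bool((z_sigma > 0).all()))
```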

## Decoder

Next, our decoder reconstructs the original input from the latent representation produced by the encoder. It takes a latent vector of size 20, passes it through a hidden layer of 200 units with a ReLU non-linearity, and outputs a vector of size 2 * 784: a mean $\mu$ and a standard deviation $\sigma$ for each pixel.

In [3]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc3 = nn.Linear(20, 200)
        self.fc4 = nn.Linear(200, 2 * 784)
        self.relu = nn.ReLU()

    def forward(self, z):
        h3 = self.relu(self.fc3(z))
        rv = self.fc4(h3)

        # reshape to (batch, 784, 2): one (mu, log sigma) pair per pixel
        rvs = rv.view(z.size(0), -1, 2)

        # return per-pixel mu and sigma; exp keeps sigma positive
        return rvs[:, :, 0], torch.exp(rvs[:, :, 1])
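
The reshape in `Decoder.forward` implies that the 2 * 784 outputs are interleaved: position 2i holds the mean of pixel i and position 2i + 1 holds its log standard deviation. The NumPy sketch below checks that layout; it assumes NumPy's row-major `reshape` matches `view` on a contiguous tensor, and the random "decoder output" is only a stand-in.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_pixels = 2, 784

# stand-in for the raw fc4 output: 2 * 784 values per example
rv = rng.normal(size=(batch, 2 * n_pixels))

rvs = rv.reshape(batch, -1, 2)     # same layout as rv.view(z.size(0), -1, 2)
img_mu = rvs[:, :, 0]              # even positions: per-pixel means
img_sigma = np.exp(rvs[:, :, 1])   # odd positions: per-pixel log std, exponentiated

# the split is equivalent to taking every other element of the flat output
assert np.array_equal(img_mu, rv[:, 0::2])
assert np.allclose(np.log(img_sigma), rv[:, 1::2])
print(img_mu.shape, img_sigma.shape)
```

Interleaving is a design choice; two separate output layers, as in the encoder, would be an equally valid way to produce the two parameter vectors.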


## Model
Now we define our model and guide for inference. The model samples a latent vector $\mathbf{z}$ from the standard normal prior $\mathcal{N}(\mu=0, \sigma=1)$ and runs the decoder on the sample to compute the parameters of the image distribution. It then scores the image data against a Gaussian parameterized by the $\mu, \sigma$ produced by the decoder. This scoring is done with an `observe()` statement.

The guide is the approximating distribution from which latent samples are drawn when inference is run. It samples from a `DiagNormal()` parameterized by the $\mu$ and $\sigma$ produced by the encoder. Because the model and the guide both name their latent sample site `"latent"`, Pyro can pair the guide's sample with the model's prior.

In [2]:
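# instantiate the networks used by the model (decoder) and guide (encoder)
pt_encode = Encoder()
pt_decode = Decoder()


def model(data):
    # the inference algorithm calls the model with a minibatch of data

    # register the decoder's parameters with Pyro -- required
    decoder = pyro.module("decoder", pt_decode)

    # parameters of the standard normal prior over the 20-dim latent code
    z_mu, z_sigma = ng_zeros(
        [data.size(0), 20]), ng_ones([data.size(0), 20])

    # sample the latent code (during inference, the value comes from the guide)
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    # decode into per-pixel mu and sigma of the image distribution
    img_mu, img_sigma = decoder.forward(z)

    # score the observed images against the decoded distribution
    pyro.observe("obs", DiagNormal(img_mu, img_sigma), data.view(-1, 784))


def guide(data):
    # register the encoder's parameters with Pyro -- required
    encoder = pyro.module("encoder", pt_encode)

    # use the encoder to get mu and sigma of the approximate posterior
    z_mu, z_sigma = encoder.forward(data)

    # the site name "latent" must match the sample statement in the model
    pyro.sample("latent", DiagNormal(z_mu, z_sigma))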
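
Mathematically, the `DiagNormal` sites in the model and guide involve two operations: drawing a sample, which for a diagonal Gaussian can be written as $\mathbf{z} = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, and evaluating a log-density summed over dimensions, which is what scoring the images with `observe()` contributes to the joint log-probability. The NumPy helpers below are a hypothetical illustration of that math, not Pyro's implementation.

```python
import numpy as np


def diag_normal_log_prob(x, mu, sigma):
    """Sum over the last axis of log N(x | mu, sigma^2), one value per example."""
    return np.sum(
        -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2,
        axis=-1,
    )


def diag_normal_sample(mu, sigma, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I)."""
    return mu + sigma * rng.standard_normal(mu.shape)


rng = np.random.default_rng(0)
# the prior used in model(): standard normal over a 20-dim latent, batch of 4
z_mu, z_sigma = np.zeros((4, 20)), np.ones((4, 20))
z = diag_normal_sample(z_mu, z_sigma, rng)
log_p = diag_normal_log_prob(z, z_mu, z_sigma)
print(z.shape, log_p.shape)
```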

## Inference
For this example, we use variational inference to approximate the posterior by minimizing the [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between the approximate latent distribution and the true posterior. This is equivalent to maximizing the evidence lower bound (ELBO):

\begin{equation*}
\mathrm{ELBO} =
\mathbb{E}_{q(\mathbf{z})} [ \log p(\mathbf{x}, \mathbf{z}) ]
- \mathbb{E}_{q(\mathbf{z})} [ \log q(\mathbf{z}) ].
\end{equation*}
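
Because $p(\mathbf{x}, \mathbf{z}) = p(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})$, the ELBO can also be written as $\mathbb{E}_{q(\mathbf{z})}[\log p(\mathbf{x} \mid \mathbf{z})] - \mathrm{KL}(q(\mathbf{z}) \,\|\, p(\mathbf{z}))$, a reconstruction term minus a KL term. For a diagonal Gaussian guide and the standard normal prior used in `model()`, the KL term has a closed form. The NumPy sketch below checks that closed form against a Monte Carlo estimate of $\mathbb{E}_{q}[\log q(\mathbf{z}) - \log p(\mathbf{z})]$; the guide outputs `mu` and `sigma` are made-up values, and whether Pyro's inference uses the closed form or a sampled estimate is not assumed here.

```python
import numpy as np


def kl_diag_normal_std_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dims."""
    return np.sum(0.5 * (mu ** 2 + sigma ** 2 - 1.0) - np.log(sigma), axis=-1)


def log_normal(x, mu, sigma):
    """Diagonal Gaussian log-density, summed over the last axis."""
    return np.sum(
        -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2,
        axis=-1,
    )


rng = np.random.default_rng(0)
mu, sigma = np.array([0.5, -1.0]), np.array([0.8, 1.5])  # hypothetical guide outputs
z = mu + sigma * rng.standard_normal((200_000, 2))        # samples from q(z)

kl = kl_diag_normal_std_normal(mu, sigma)
mc_kl = np.mean(log_normal(z, mu, sigma) - log_normal(z, 0.0, 1.0))
print(kl, mc_kl)
```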

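In this version of Pyro, ELBO maximization is handled by the `KL_QP` class from `pyro.infer.kl_qp`, which takes the model, the guide, and an optimizer. We'll use the [Adam](https://arxiv.org/abs/1412.6980) algorithm as the optimizer, with a learning rate of 0.0001.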

In [6]:
kl_optim = KL_QP(model, guide, pyro.optim(optim.Adam, {"lr": .0001}))

We now have all the components we need for the VAE. To train it, we loop over the training set for a number of epochs, taking one optimization step (`kl_optim.step`) per minibatch.

In [None]:
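def main(num_epochs=10):
    # num_epochs is a free choice; 10 is only an illustrative default
    for i in range(num_epochs):
        epoch_loss = 0.
        # iterate over minibatches; the digit labels are not needed by the VAE
        for batch_data, _ in train_loader:
            # wrap the batch in a Variable, as this (pre-0.4) PyTorch API expects
            batch_data = Variable(batch_data)
            # take one gradient step; step() returns the loss for the batch
            epoch_loss += kl_optim.step(batch_data)
        # report the loss averaged over all training examples
        n_train = len(train_loader.dataset)
        print("epoch {} avg loss {}".format(i, epoch_loss / float(n_train)))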
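
If you prefer your own data loader over `DataLoader`, minibatching can be done by hand with a list of batch boundaries sliced pairwise as `data[start:end]`. The `batch_boundaries` helper below is hypothetical: it returns each minibatch's start offset followed by the total number of examples, so consecutive pairs give `[start, end)` slices and the last batch may be smaller than the rest.

```python
def batch_boundaries(n_examples, batch_size):
    """Start offset of every minibatch, followed by n_examples as the final end offset."""
    return list(range(0, n_examples, batch_size)) + [n_examples]


all_batches = batch_boundaries(1000, 128)
# consecutive pairs give [start, end) slices; the last batch holds the remainder
sizes = [end - start for start, end in zip(all_batches[:-1], all_batches[1:])]
print(sizes)
```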

And voila! We visualize the results with `visdom`.



See the full code on [GitHub](https://github.com/uber/pyro/blob/dev/examples/vae.py).