In [1]:
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
torch.manual_seed(0)

def get_free_gpu():
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlDeviceGetCount
    nvmlInit()

    return np.argmax([
        nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)).free
        for i in range(nvmlDeviceGetCount())
    ])

if torch.cuda.is_available():
    cuda_id = get_free_gpu()
    device = 'cuda:%d' % (cuda_id, )
    print('Selected %s' % (device, ))
else:
    device = 'cpu'
    print('WARNING: using cpu!')

### please, don't remove the following line
x = torch.tensor([1], dtype=torch.float32).to(device)

Selected cuda:2


In [3]:
# Training dataset
binarizing_transform = transforms.Compose([
    transforms.ToTensor(),
    lambda x: x>0.5,
    lambda x: x.float()])
train_loader = torch.utils.data.DataLoader(
    MNIST(root='../../data', train=True, download=True,
          transform=binarizing_transform),
    batch_size=100, shuffle=True, pin_memory=True)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    MNIST(root='../../data', train=False,
          transform=binarizing_transform),
    batch_size=100, shuffle=True, pin_memory=True)

# Practical Session. Variational Autoencoders

In this session you will implement a vanilla VAE on the MNIST dataset. The implementation will be based on the [*torch.distributions*](https://pytorch.org/docs/stable/distributions.html) module.

To complete the task, you will need to construct the loss function using the distributions module and then train the model.

### Goal:
The goal is to train and study a variational auto-encoder that approximates the dataset distribution, namely a distribution of low-resolution binary hand-written digits. To achieve the goal you only need to implement the loss function for the variational auto-encoder.

### Data:

The dataset consists of 70000 binary images of hand-written digits, each of size $28 \times 28$.

### Criteria:

The main quality criterion for variational auto-encoders is the average log-likelihood on the test set. For simplicity, we will use the evidence lower bound (ELBO, a lower bound on the log-likelihood) instead of the likelihood. Recall that the evidence lower bound is the training objective of VAE. With a correct implementation of the loss function, the average loss (negative ELBO) on the test set should be close to 95.

After training, the model should be able to reconstruct input digits (run *plot_reconstructions*, are outputs reasonable?) and interpolate between digits in the latent space (run *plot_interpolations*, are outputs reasonable?).

In the last cell we project the latent codes of the test data onto a 2D plane. The color of each point represents the digit in the original image. The projection algorithm preserves local distances (points that are close in the picture were close in the original space). The points on the plane form clusters for different digits, and similar digits should have nearby clusters.

# Motivation: AEs  vs. VAEs

Below are pairs of MNIST images and their reconstructions built with an auto-encoder. Notably, auto-encoders can provide good reconstruction quality.

![Autoencoder reconstructions](https://github.com/bayesgroup/deepbayes-2018/blob/master/day2_vae/ae_reconstructions.png?raw=true)

Still, the model has no control over the learned latent representations. For example, an interpolation of latent representations of two digits is typically not a latent representation for a digit:

![Autoencoder interpolations](https://github.com/bayesgroup/deepbayes-2018/blob/master/day2_vae/ae_interpolations.png?raw=true)

On the other hand, a standard VAE model forces the latent representations to fit a multivariate Gaussian distribution. As a result, an interpolation of two latent representations is likely to be a latent representation of a digit.

# Distributions for VAE

For the assignment, we will need two types of distributions to define the probabilistic model:
1. For the latent variable $z$ we need a vector of independent [normally distributed](https://pytorch.org/docs/stable/distributions.html#normal) scalars.
2. For observations $x$, we need a vector of independent [Bernoulli](https://pytorch.org/docs/stable/distributions.html#bernoulli) random variables.

By default, these classes model a tensor of independent random **variables**. To represent a matrix of independent random variables as a batch of random **vectors** you may also use the [Independent](https://pytorch.org/docs/stable/distributions.html#independent) class.
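
The difference between a tensor of independent scalars and a batch of random vectors shows up in the shapes returned by *log_prob()*. The following is a minimal sketch; the batch size 3 and the dimension 4 are arbitrary values chosen for illustration.

```python
import torch
from torch.distributions import Normal, Independent

# A batch of 3 latent vectors of dimension 4 (illustrative shapes).
mu = torch.zeros(3, 4)
sigma = torch.ones(3, 4)

# Without Independent: 3 x 4 independent scalar random variables.
scalars = Normal(mu, sigma)
# With Independent: a batch of 3 random vectors, each of dimension 4.
vectors = Independent(Normal(mu, sigma), 1)

z = torch.randn(3, 4)
log_prob_scalars = scalars.log_prob(z)  # shape (3, 4): one value per coordinate
log_prob_vectors = vectors.log_prob(z)  # shape (3,): summed over the last dimension

print(scalars.batch_shape, vectors.batch_shape, vectors.event_shape)
```

Wrapping a distribution in *Independent* is equivalent to summing the per-coordinate log-probabilities over the last dimension, which is the quantity needed per image or per latent code.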

### Bernoulli random vector

In the task, the class models $p(x | z)$ parametrized by the output of the decoder. To define the loss function you will need to compute $\log p(x | z)$ for input images using the *log_prob()* method.

*Tip:* While the class can be initialized both with probabilities and logits, the best practice is to initialize the class with logits. Otherwise, computing logarithm of probability can be highly unstable. 
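
The effect described in the tip can be seen on a single pixel. The sketch below assumes float32 tensors and an illustrative logit of 40; it compares the log-probability of an "off" pixel under the two parametrizations.

```python
import torch
from torch.distributions import Bernoulli

logit = torch.tensor(40.0)  # a very confident decoder output (illustrative value)
x = torch.tensor(0.0)       # the pixel is actually off

# Parametrized by logits: the log-probability is computed from the logit itself.
lp_logits = Bernoulli(logits=logit).log_prob(x)

# Parametrized by probabilities: sigmoid(40) rounds to 1.0 in float32,
# so the information about how confident the decoder was is already lost.
lp_probs = Bernoulli(probs=torch.sigmoid(logit)).log_prob(x)

print(lp_logits.item(), lp_probs.item())
```

With logits the result is close to $-40$, as expected from $\log(1 - \operatorname{sigmoid}(40))$. With probabilities the result no longer reflects the original logit, because the probability was rounded before the logarithm was taken.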


### Normal Distribution

In this task, you will use the class to define the approximate posterior distribution $q(z | x)$ and the latent variable distribution $p(z)$. Again, you will use the *log_prob()* method to compute the loss function.

**Importantly,** VAE generates a sample from $q(z | x)$ to pass it to the decoder. To implement the reparametrization trick the class defines a specific sampling method *rsample()*, which computes $z = \mu(x) + \varepsilon \odot \sigma(x)$ for standard Gaussian noise $\varepsilon$. Notice that the implementation of the *rsample()* method differs from the implementation of the *sample()* method.
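
The practical difference between the two sampling methods is whether gradients can flow back to the distribution parameters. A minimal sketch, with a 4-dimensional distribution chosen only for illustration:

```python
import torch
from torch.distributions import Normal

mu = torch.zeros(4, requires_grad=True)
sigma = torch.ones(4, requires_grad=True)
q = Normal(mu, sigma)

z_reparam = q.rsample()  # differentiable with respect to mu and sigma
z_plain = q.sample()     # detached from mu and sigma

# Backpropagate through the reparametrized sample.
z_reparam.sum().backward()

# dz/dmu = 1 for every coordinate, because z = mu + eps * sigma.
print(z_reparam.requires_grad, z_plain.requires_grad, mu.grad)
```

Only the output of *rsample()* is connected to the computational graph, so it is the one to use when the sample is passed to the decoder during training.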




In [4]:
from torch.distributions import Normal, Bernoulli, Independent

# Quick references:

## Vanilla VAE Specification

Probabilistic model: 
\begin{align}
& p(x, z \mid \theta) =  p(z) p(x \mid z, \theta) \\
& p(z) = \mathcal N(z \mid 0, I) \\
& p(x \mid z, \theta) = \prod_{i = 1}^D f_i(z, \theta)^{x_i} (1 - f_i(z, \theta))^{1 - x_i}.
\end{align}
Inference model:
\begin{equation}
q(z \mid x, \phi) = \mathcal N(z \mid \mu(x, \phi), \operatorname{diag}(\sigma^2(x, \phi))).
\end{equation}
Objective for a single sample $x$:
$$ \mathcal L(x, \theta, \phi) = \mathbb E_{q(z \mid x, \phi)} \left[ \log p(x \mid z, \theta) + \log p(z) - \log q(z \mid x, \phi) \right] $$
Objective estimate for a single sample $x$:
\begin{align*}
& \log p(x \mid z_0, \theta) + \log p(z_0) - \log q(z_0 \mid x, \phi) \\
& z_0 = \mu(x, \phi) + \sigma(x, \phi) \odot \varepsilon_0 \\
& \varepsilon_0 \sim \mathcal N(0, I)
\end{align*}

Tip: to train the model we average the lower bound values over the minibatch of $x$ and then maximize the estimate with gradient ascent.
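
The last two terms of the objective, $\mathbb E_{q}[\log p(z) - \log q(z \mid x)]$, equal minus the KL divergence between $q(z \mid x)$ and $p(z)$. The sketch below checks this numerically by comparing a Monte Carlo estimate with the closed form from `torch.distributions.kl_divergence`. The parameter values of $q$ are made up for illustration, and using the closed form is an alternative to the sampled estimate above, not a requirement of the task.

```python
import torch
from torch.distributions import Normal, Independent, kl_divergence

torch.manual_seed(0)

# Illustrative parameters of q(z | x) for one x and a 4-dimensional latent code.
mu = torch.tensor([0.5, -1.0, 0.0, 2.0])
sigma = torch.tensor([1.0, 0.5, 2.0, 0.3])

q = Independent(Normal(mu, sigma), 1)
p = Independent(Normal(torch.zeros(4), torch.ones(4)), 1)

# Monte Carlo estimate of E_q[log p(z) - log q(z | x)] from many samples.
z = q.rsample((100_000,))
mc_estimate = (p.log_prob(z) - q.log_prob(z)).mean()

# The same expectation in closed form: minus the KL divergence.
analytic = -kl_divergence(q, p)

print(mc_estimate.item(), analytic.item())
```

The training objective uses a single sample $z_0$ per image instead of many, which gives a noisier but still unbiased estimate of the same quantity.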

## Encoder and decoder parametrization

VAE uses two neural networks to parametrize the two distributions introduced above:

- Encoder (*enc*) takes $x$ as input and returns a $2 \times d$-dimensional vector that parametrizes the mean and standard deviation of $q(z \mid x, \phi)$.
- Decoder (*dec*) takes a latent representation $z$ and returns the logits of the distribution $p(x \mid z, \theta)$.

The computational graph has a simple structure of autoencoder. The only difference is that now it uses a stochastic variable $\varepsilon$:

![vae](https://github.com/bayesgroup/deepbayes-2018/blob/master/day2_vae/vae.png?raw=true)

Below we initialize a couple of simple fully-connected networks to model the two distributions. 

In [5]:
d, nh, D = 32, 100, 28 * 28

enc = nn.Sequential(
    nn.Linear(D, nh),
    nn.ReLU(),
    nn.Linear(nh, nh),
    nn.ReLU(),
    nn.Linear(nh, 2 * d)) # note that the final layer outputs real values

dec = nn.Sequential(
    nn.Linear(d, nh),
    nn.ReLU(),
    nn.Linear(nh, nh),
    nn.ReLU(),
    nn.Linear(nh, D)) # the final layer outputs the logits of p(x | z)

enc = enc.to(device)
dec = dec.to(device)
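
Since the encoder outputs $2 \times d$ unconstrained real values, they have to be split into a mean and a positive standard deviation. The sketch below shows one way to do this. The stand-in network `toy_enc` is hypothetical, and mapping the second half through softplus is an assumption (taking the exponent is another common choice). The first half is treated as the mean, which matches how the visualisation cells later read `[:, :d]`.

```python
import torch
from torch import nn

d, D = 32, 28 * 28
# Hypothetical stand-in with the same output size as the encoder above.
toy_enc = nn.Linear(D, 2 * d)

x = torch.zeros(5, D)
out = toy_enc(x)                      # shape (5, 2 * d), unconstrained real values
mu, raw_sigma = out.chunk(2, dim=-1)  # two tensors of shape (5, d)

# Assumption: softplus maps the second half to positive standard deviations.
sigma = nn.functional.softplus(raw_sigma)

print(mu.shape, sigma.shape, bool((sigma > 0).all()))
```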

## Task: VAE Loss function

Implement the loss function for the variational autoencoder.

In [0]:
def loss_vae(x, encoder, decoder):
    """
    TODO

    input:
    x is a tensor of shape (batch_size x 784)
    encoder is a nn.Module that follows the above specification
    decoder is a nn.Module that follows the above specification

    output:
    1. the average value of negative ELBO across the minibatch x
    2. and the output of the decoder
    """
    loss = ...  # TODO: scalar tensor, negative ELBO averaged over the minibatch
    decoder_output = ...  # TODO: decoder logits of shape (batch_size x 784)
    return loss, decoder_output

After implementing the loss, run the following cells to train the VAE and visualise several examples.
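
For reference, here is a hedged sketch of how the pieces from the specification could be combined. It is one possible implementation under stated assumptions, not the reference solution: the name `loss_vae_sketch` and the toy networks are hypothetical, and the softplus mapping for the standard deviation is an assumption.

```python
import torch
from torch import nn
from torch.distributions import Normal, Bernoulli, Independent

def loss_vae_sketch(x, encoder, decoder):
    """Negative single-sample ELBO estimate, averaged over the minibatch."""
    # Split the encoder output into mean and (assumed) softplus-activated std.
    mu, raw_sigma = encoder(x).chunk(2, dim=-1)
    sigma = nn.functional.softplus(raw_sigma)

    q_z = Independent(Normal(mu, sigma), 1)  # q(z | x)
    p_z = Independent(Normal(torch.zeros_like(mu), torch.ones_like(sigma)), 1)  # p(z)

    z = q_z.rsample()                        # reparametrized sample
    logits = decoder(z)
    p_x = Independent(Bernoulli(logits=logits), 1)  # p(x | z)

    # One ELBO estimate per image, shape (batch_size,).
    elbo = p_x.log_prob(x) + p_z.log_prob(z) - q_z.log_prob(z)
    return -elbo.mean(), logits

# Toy check on random binary "images" with small, untrained networks.
torch.manual_seed(0)
d_toy, D_toy = 2, 28 * 28
toy_enc = nn.Sequential(nn.Linear(D_toy, 16), nn.ReLU(), nn.Linear(16, 2 * d_toy))
toy_dec = nn.Sequential(nn.Linear(d_toy, 16), nn.ReLU(), nn.Linear(16, D_toy))
x_toy = (torch.rand(8, D_toy) > 0.5).float()

loss_toy, logits_toy = loss_vae_sketch(x_toy, toy_enc, toy_dec)
print(loss_toy.item(), logits_toy.shape)
```

The function returns the decoder logits rather than probabilities, which is what the plotting cells expect, since they apply `torch.sigmoid` to the second output themselves.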

## Training
The cell below implements a simple training function for the model.

In [0]:
from itertools import chain

def train_model(loss, model, batch_size=100, num_epochs=3, learning_rate=1e-3):
    # note: batch_size is not used here; the data loaders above fix it at 100
    gd = torch.optim.Adam(
        chain(*[x.parameters() for x in model if isinstance(x, nn.Module)]),
        lr=learning_rate)
    train_losses = []
    for _ in range(num_epochs):
        for i, (batch, _) in enumerate(train_loader):
            total = len(train_loader)
            gd.zero_grad()
            batch = batch.view(-1, D).to(device)
            loss_value, _ = loss(batch, *model)
            loss_value.backward()
            train_losses.append(loss_value.item())
            if (i + 1) % 10 == 0:
                print('\rTrain loss:', train_losses[-1],
                      'Batch', i + 1, 'of', total, ' ' * 10, end='', flush=True)
            gd.step()
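        # evaluate on the test set without building the autograd graph
        test_loss = 0.
        with torch.no_grad():
            for i, (batch, _) in enumerate(test_loader):
                batch = batch.view(-1, D).to(device)
                batch_loss, _ = loss(batch, *model)
                # running mean of the per-batch losses
                test_loss += (batch_loss.item() - test_loss) / (i + 1)
        print('\nTest loss after an epoch: {}'.format(test_loss))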

In [0]:
# my implementation has test loss ~96
train_model(loss_vae, model=[enc, dec], num_epochs=16)

## Visualisations

- How do the reconstructions compare to the reconstructions of the autoencoder?
- How do the interpolations compare?
- Is the latent space regularly covered?
- Is there any dependence between the t-SNE embedding and the digit label?

In [0]:
def sample_vae(dec, n_samples=50):
    # this function returns the mean of p(x | z), not a sample from p(x | z)
    # so, being accurate, the output is not a sample from vae
    # but the mean value is more informative for visualization
    with torch.no_grad():
        samples = torch.sigmoid(dec(torch.randn(n_samples, d).to(device)))
        samples = samples.view(n_samples, 28, 28).cpu().numpy()
    return samples
    
def plot_samples(samples, h=5, w=10):
    fig, axes = plt.subplots(nrows=h,
                             ncols=w,
                             figsize=(int(1.4 * w), int(1.4 * h)),
                             subplot_kw={'xticks': [], 'yticks': []})
    for i, ax in enumerate(axes.flatten()):
        ax.imshow(samples[i], cmap='gray')

In [0]:
plot_samples(sample_vae(dec=dec))
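
As the comment in *sample_vae* notes, the plotted images are the means of $p(x \mid z)$. For comparison, the sketch below draws actual binary samples from $p(x \mid z)$. The untrained linear layer `toy_dec` is a hypothetical stand-in for the trained decoder, so the images themselves are meaningless; only the mechanics are shown.

```python
import torch
from torch import nn
from torch.distributions import Bernoulli

torch.manual_seed(0)
d, D = 32, 28 * 28
toy_dec = nn.Linear(d, D)  # hypothetical, untrained stand-in for the decoder

with torch.no_grad():
    z = torch.randn(6, d)                               # z ~ p(z) = N(0, I)
    logits = toy_dec(z)
    mean_images = torch.sigmoid(logits)                 # what sample_vae returns
    sampled_images = Bernoulli(logits=logits).sample()  # binary draws from p(x | z)

print(mean_images.shape, sampled_images.unique())
```

The sampled images contain only zeros and ones and usually look noisier than the means, which is why the means are used for visualisation.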

In [0]:
def plot_reconstructions(loss, model):
    with torch.no_grad():
        batch = (test_loader.dataset.data[:25].float() / 255.)
        batch = (batch > 0.5).float()
        batch = batch.view(-1, D).to(device)
        _, rec = loss(batch, *model)
        rec = torch.sigmoid(rec)
        rec = rec.view(-1, 28, 28).cpu().numpy()
        batch = batch.view(-1, 28, 28).cpu().numpy()
    
        fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(14, 7),
                                 subplot_kw={'xticks': [], 'yticks': []})
        for i in range(25):
            axes[i % 5, 2 * (i // 5)].imshow(batch[i], cmap='gray')
            axes[i % 5, 2 * (i // 5) + 1].imshow(rec[i], cmap='gray')
        

In [0]:
plot_reconstructions(loss_vae, [enc, dec])

In [0]:
def plot_interpolations(encoder, decoder):
    with torch.no_grad():
        batch = (test_loader.dataset.data[:10].float() / 255.)
        batch = (batch > 0.5).float()
        batch = batch.view(-1, D).to(device)
        batch = encoder(batch)
        z_0 = batch[:5, :d].view(5, 1, d)
        z_1 = batch[5:, :d].view(5, 1, d)
        
        alpha = torch.linspace(0., 1., 10).to(device)
        alpha = alpha.view(1, 10, 1)
        
        interpolations_z = (z_0 * alpha + z_1 * (1 - alpha))
        interpolations_z = interpolations_z.view(50, d)
        interpolations_x = torch.sigmoid(decoder(interpolations_z))
        interpolations_x = interpolations_x.view(5, 10, 28, 28).cpu().numpy()
    
    fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(14, 7),
                             subplot_kw={'xticks': [], 'yticks': []})
    for i in range(50):
        axes[i // 10, i % 10].imshow(interpolations_x[i // 10, i % 10], cmap='gray')

In [0]:
plot_interpolations(enc, dec)

In [0]:
def plot_tsne(objects, labels):
    from sklearn.manifold import TSNE
    embeddings = TSNE(n_components=2).fit_transform(objects)
    plt.figure(figsize=(8, 8))
    for k in range(10):
        embeddings_for_k = embeddings[labels == k]
        plt.scatter(embeddings_for_k[:, 0], embeddings_for_k[:, 1],
                    label='{}'.format(k))
    plt.legend()

In [0]:
with torch.no_grad():
    batch = (test_loader.dataset.data[:1000].float() / 255.)
    batch = (batch > 0.5).float()
    batch = batch.view(-1, D).to(device)

    latent_variables = enc(batch)[:, :d]
    latent_variables = latent_variables.cpu().numpy()
    labels = test_loader.dataset.targets[:1000].numpy()
  
plot_tsne(latent_variables, labels)

The notebook implements a basic variational auto-encoder. Even with a toy architecture the model manages to learn a crude approximation of the data distribution.