# CNJC Variational Bayesian Methods

This notebook walks through a variational auto-encoder (VAE), as introduced by [Kingma & Welling 2013](https://arxiv.org/abs/1312.6114), on the classic MNIST dataset.

# Setup


## Add GPU to Colab notebook


Edit menu -> Notebook Settings -> Hardware accelerator -> select "GPU" -> Save

## imports

In [0]:
%matplotlib inline

In [0]:
import torch
from torch import nn, optim, utils
from torchvision import datasets, transforms
from torch.nn import functional as F
import torchvision.utils
import numpy as np
import moviepy.editor as mpy
from moviepy.video.io.bindings import mplfig_to_npimage
import matplotlib.pyplot as plt
from functools import partial
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU if no GPU is attached
torch.manual_seed(20190710) # reproducible analysis


## Data

In [0]:
!mkdir -p mnist

In [0]:
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True}
mnist_train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
mnist_test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

# shape is batch_size x 1 x 28 x 28

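To make the shape comment in the data cell concrete, here is a minimal sketch of the flattening that `train`, `test` and `VAE.forward` rely on. It uses a random tensor as a hypothetical stand-in for one MNIST batch (so no download is needed); only the shapes matter here.

```python
import torch

# Hypothetical stand-in for one MNIST batch: batch_size x 1 x 28 x 28, values in [0, 1]
# (transforms.ToTensor() scales pixel values to [0, 1]).
batch = torch.rand(128, 1, 28, 28)

# Features per image, computed the same way as train()/test(): data[0].numel()
nfeatures = batch[0].numel()          # 1 * 28 * 28 = 784

# Flatten each image into a row vector, as VAE.forward does with x.view(-1, nfeatures)
flat = batch.view(-1, nfeatures)
print(nfeatures, tuple(flat.shape))   # 784 (128, 784)
```

Each row of `flat` is one image read out row by row, which is what the fully connected encoder expects.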


# Variational Auto-encoder

## Implementation

Implementation based on the PyTorch [basic VAE example](https://github.com/pytorch/examples/tree/master/vae).

In [0]:
# subclass PyTorch Module so that reverse-mode automatic differentiation
# (autograd) handles backpropagation of the loss gradient
class VAE(nn.Module):
    
    def __init__(self, nfeatures, nlatent=20):
        super(VAE, self).__init__()
        self.nfeatures = nfeatures
        self.nhidden = nfeatures // 5
        
        # nn.Linear is a "dense" layer of form y = Ax + b
        
        # Encoder layers
        self.hidden_encoder = nn.Linear(nfeatures, self.nhidden)
        # mean encoding layer 
        self.mean_encoder = nn.Linear(self.nhidden, nlatent)
        # log variance encoding layer 
        self.logvar_encoder = nn.Linear(self.nhidden, nlatent)
        
        # Decoder layers
        self.hidden_decoder = nn.Linear(nlatent, self.nhidden)
        self.reconstruction_decoder = nn.Linear(self.nhidden, nfeatures)

    def encode(self, x):
        # we use a ReLU (rectified linear unit) activation function
        h1 = F.relu(self.hidden_encoder(x))
        return self.mean_encoder(h1), self.logvar_encoder(h1)

    def reparameterize(self, mean, logvar):
        """Reparameterize our stochastic node so the gradient can propagate
           deterministically."""

        if self.training:
            standard_deviation = torch.exp(0.5*logvar)
            # sample from a standard Gaussian (mean 0, variance 1) with the same shape as standard_deviation
            epsilon = torch.randn_like(standard_deviation)
            # TODO: write this line. Stuck? see answers at bottom of notebook
            raise NotImplementedError()
        else:
            return mean

    def decode(self, z):
        h3 = F.relu(self.hidden_decoder(z))
        # use sigmoid to bound output to (0,1)
        return torch.sigmoid(self.reconstruction_decoder(h3))

    
    def forward(self, x):
        "A special method in PyTorch modules that is called by __call__"
        
        # flatten batch x height x width into batch x nFeatures, then encode
        mean, logvar = self.encode(x.view(-1, self.nfeatures))
        # sample an embedding, z
        z = self.reparameterize(mean, logvar)
        # return the (sampled) reconstruction, mean, and log variance
        return self.decode(z), mean, logvar

In [0]:
def loss_function(recon_x, x, mu, logvar, nfeatures):
    "Reconstruction + KL divergence losses summed over all elements and batch."
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, nfeatures), reduction='sum')

    # we want KLD = - 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # where sigma is standard deviation and mu is mean
    # (interested? check out Appendix B of https://arxiv.org/abs/1312.6114)
    # In pytorch, x^2 is written as x.pow(2), e^x is written as x.exp(),
    # and sum_{i=1}^n (x_i + y_i) for x,y of length n
    # can be written as torch.sum(x+y)
    # TODO: replace this line with KLD = ... Stuck? see answers at bottom of notebook
    raise NotImplementedError()

    return BCE + KLD
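For reference, `loss_function` returns a single-sample estimate of the negative evidence lower bound (ELBO) of Kingma & Welling (2013), summed over the batch. Per image, writing $\hat{x}$ for the sigmoid output of `decode`, $D$ for `nfeatures`, $J$ for `nlatent` and $\log\sigma^2$ for `logvar`, the two terms are as follows (reading the sigmoid outputs as Bernoulli means is the usual interpretation of the BCE term, stated here as an assumption):

$$
\text{loss}(x) = \underbrace{-\sum_{i=1}^{D}\Big[x_i \log \hat{x}_i + (1 - x_i)\log(1 - \hat{x}_i)\Big]}_{\texttt{BCE}} \;\underbrace{-\;\frac{1}{2}\sum_{j=1}^{J}\Big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\Big)}_{\texttt{KLD}}
$$

with $\hat{x} = \text{decode}(z)$, $z = \mu + \sigma \odot \epsilon$ and $\epsilon \sim \mathcal{N}(0, I)$. The KLD term is $\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\big)$, matching the comment inside `loss_function`, so minimizing the loss maximizes the ELBO.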

## train & test functions

In [0]:
def train(epoch, model, optimizer, train_loader, log_interval=10):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data[0].to(device)  # we ignore any labels & transfer to GPU
        nfeatures = data[0].numel()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, nfeatures)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch, model, test_loader,folder="results"):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            data = data[0].to(device)
            nfeatures = data[0].numel()
            n = min(data.size(0), 15)
            if len(data.shape) == 3:
                # zebrafish: batch x height x width
                _, H, W = data.shape
                dat = data[:n, None]
            elif len(data.shape) == 4:
                # MNIST: batch x channel x height x width
                _, _, H, W = data.shape
                dat = data[:n]
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar, nfeatures).item()
            if i == 0:
                # first n originals (top row) above their reconstructions (bottom row)
                comparison = torch.cat([dat,
                                        recon_batch.view(-1, 1, H, W)[:n]])
                torchvision.utils.save_image(comparison.cpu(),
                         folder + '/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

## MNIST

In [0]:
# run cell to reset model
nfeatures = 28**2
# we use a latent space of dimension 2 so as to get an easy-to-visualize manifold
# (see mnist/sample_*.png while running next cell)
nlatent = 2
mnist_model = VAE(nfeatures, nlatent=nlatent).to(device)
mnist_optimizer = optim.Adam(mnist_model.parameters(), lr=1e-3)
!rm -f mnist/*


In [0]:
# this will take two minutes to run.
# As it does, check out the mnist folder!
# each epoch, reconstruction examples are saved (original on top) 
# select files on right, then refresh, and double click image
# click bottom right corner of image to resize

nepochs = 2
H, W = (28,28)

# make an nrow x nrow grid of (z1, z2), where z1 and z2 each take nrow evenly spaced values in [-3.5, 3.5]
nrow = 25
latents = torch.zeros(nrow,nrow,nlatent)
z1_tick = np.linspace(-3.5,3.5,nrow)
z2_tick = np.linspace(-3.5,3.5,nrow)
for i, z1 in enumerate(z1_tick):
    for j, z2 in enumerate(z2_tick):
        latents[i,j,[0,1]] = torch.tensor([z1,z2])
latents = latents.to(device)

for epoch in range(1, nepochs + 1):
    train(epoch, mnist_model, mnist_optimizer, mnist_train_loader)
    test(epoch, mnist_model, mnist_test_loader,folder='mnist')
    with torch.no_grad():
        latent_space = mnist_model.decode(latents.view(-1,nlatent)).cpu()
        torchvision.utils.save_image(latent_space.view(-1, 1, H, W),
                   'mnist/sample_' + str(epoch) + '.png',nrow=nrow)

# Stuck? Here are the answers

## reparameterize

`return epsilon * standard_deviation + mean`
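To see why this line lets gradients flow, here is a small check on made-up values (a batch of 4 with 2 latent dimensions, both hypothetical). All randomness sits in `epsilon`, which does not depend on the parameters, so `z` is a deterministic, differentiable function of `mean` and `logvar`. Backpropagating `z.sum()` should give dz/dmean = 1 and dz/dlogvar = 0.5 * epsilon * standard_deviation.

```python
import torch

torch.manual_seed(0)
mean = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)

# Reparameterization: the sample is a deterministic function of (mean, logvar, epsilon)
standard_deviation = torch.exp(0.5 * logvar)
epsilon = torch.randn_like(standard_deviation)
z = epsilon * standard_deviation + mean

# Any scalar function of z can now be differentiated with respect to mean and logvar
z.sum().backward()
print(mean.grad)    # all ones, since dz/dmean = 1
print(logvar.grad)  # equals 0.5 * epsilon * standard_deviation
```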

## loss_function

`KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())`
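As a sanity check on this closed form, the sketch below compares it with `torch.distributions.kl_divergence` between q = N(mu, sigma^2) and the prior p = N(0, 1), applied per latent dimension and summed. The `mu` and `logvar` values are random placeholders, not outputs of the trained model.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 2)      # placeholder batch of means
logvar = torch.randn(8, 2)  # placeholder batch of log variances

# Closed-form KL(q || p), summed over batch and latent dimensions (the answer above)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions; Normal takes the standard deviation, not the variance
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
KLD_reference = kl_divergence(q, p).sum()

print(KLD.item(), KLD_reference.item())  # should agree up to float rounding
```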