# The variational autoencoder

First presented in 2013/2014, the variational autoencoder (VAE) has since become one of the most important models for unsupervised learning.

The VAE is a probabilistic autoencoder (AE) with strong advantages over a plain AE. It was published in https://arxiv.org/abs/1312.6114 and https://proceedings.mlr.press/v32/rezende14.html

This notebook (c) Patrick van der Smagt, March 2023, heavily building on various internet sources.  Please do not distribute this without Patrick's consent, since he nicked from the internet.

Sources used: https://github.com/ANLGBOY/VAE-with-PyTorch/, https://avandekleut.github.io/vae/, https://github.com/pytorch/examples/blob/main/vae/main.py

In [None]:
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torchvision import datasets, transforms
from torchvision.datasets import FashionMNIST, MNIST

import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


# Choose the compute device by priority: CUDA GPU, then Apple MPS, then plain CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
if device.type == "mps":
    # May be necessary on Macs with M1/M2 chips and some PyTorch versions: lets
    # operations that the MPS backend does not implement fall back to the CPU.
    # NOTE(review): some PyTorch versions may read this variable when torch is
    # imported -- confirm that setting it here (after the import) takes effect.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
print("device:", device)

Create a lightning module, with lots of magic numbers.

In [None]:
class VAE(pl.LightningModule):
    """Fully connected variational autoencoder for 1x28x28 images.

    The encoder maps a flattened image through two ReLU layers to the mean and
    log-variance of a per-dimension Gaussian over the latent code. The decoder
    mirrors the encoder and ends in a sigmoid, so reconstructions lie in (0, 1).

    The module also owns its data: ``setup`` downloads the torchvision dataset
    named by ``dataset`` and splits it into a training and a validation part.

    Args:
        latent_dim: Dimensionality of the latent space.
        dataset: Name of a dataset class in ``torchvision.datasets``,
            e.g. ``'MNIST'`` or ``'FashionMNIST'``.
    """

    def __init__(self, latent_dim: int, dataset: str):
        super().__init__()
        self.latent_dim = latent_dim
        self.data_dir = "./data"
        self.dataset = dataset

        self.batch_size = 128
        self.h1size = 512
        self.h2size = 256
        # Number of samples held out for validation; the rest is used for training
        self.val_size = 5000

        # Filled in by setup()
        self.dataset_train = None
        self.dataset_val = None

        # Hardcode some dataset specific attributes
        self.dims = (1, 28, 28)
        channels, height, width = self.dims
        self.nnsize = channels * height * width

        # Encoder
        self.fc1 = nn.Linear(self.nnsize, self.h1size)
        self.fc2 = nn.Linear(self.h1size, self.h2size)
        self.fc21 = nn.Linear(self.h2size, latent_dim)  # fc21 for mean of Z
        self.fc22 = nn.Linear(self.h2size, latent_dim)  # fc22 for log variance of Z
        # Decoder
        self.fc3 = nn.Linear(latent_dim, self.h2size)
        self.fc4 = nn.Linear(self.h2size, self.h1size)
        self.fc5 = nn.Linear(self.h1size, self.nnsize)

    def encode(self, x):
        """Map flattened images ``x`` to the latent mean and log-variance."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        mu = self.fc21(h2)
        # use logvar instead of var since the output of fc22 can be negative (var is always positive)
        logvar = self.fc22(h2)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Noise is drawn on every call, regardless of
        whether the module is in training or evaluation mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to flattened reconstructions in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc4(h3))
        return torch.sigmoid(self.fc5(h4))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(recon_x, mu, logvar)``; ``recon_x`` is flattened to
            ``nnsize`` values per sample.
        """
        # x: [batch size, 1, 28,28] -> x: [batch size, 784]
        x = x.view(-1, self.nnsize)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Return the negative ELBO: reconstruction loss plus KL divergence.

        Both terms are summed (not averaged) over the batch, so the value
        scales with the batch size.
        """
        # Reconstruction loss
        recon_loss = F.binary_cross_entropy(recon_x, x.view(-1, self.nnsize), reduction='sum')
        # KL divergence between N(mu, exp(logvar)) and the standard normal prior
        kld_loss = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
        return recon_loss + kld_loss

    def training_step(self, batch, batch_idx):
        """Compute and log the loss of one training batch (labels are ignored)."""
        x, _ = batch
        recon_x, mu, logvar = self(x)
        loss = self.loss_function(recon_x, x, mu, logvar)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the loss of one validation batch (labels are ignored)."""
        x, _ = batch
        recon_x, mu, logvar = self(x)
        loss = self.loss_function(recon_x, x, mu, logvar)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)  # this learning rate is the default one for adam in Pytorch
        return optimizer

    def setup(self, stage=None):
        """Download the dataset and split it into training and validation parts.

        The split is created only once. Without this guard every further
        ``trainer.fit(model)`` call would draw a new random split and move
        former validation samples into the training set.

        Raises:
            ValueError: If ``self.dataset`` does not name anything in
                ``torchvision.datasets``.
        """
        if self.dataset_train is not None:
            return

        # Look the class up by name instead of eval()-ing an arbitrary string
        dataset_cls = getattr(datasets, self.dataset, None)
        if dataset_cls is None:
            raise ValueError(f"unknown torchvision dataset: {self.dataset!r}")

        dataset_full = dataset_cls(self.data_dir, train=True, download=True, transform=transforms.ToTensor())
        # For the 60000-image MNIST/FashionMNIST training sets this yields 55000/5000
        n_train = len(dataset_full) - self.val_size
        self.dataset_train, self.dataset_val = random_split(dataset_full, [n_train, self.val_size])

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training split."""
        return DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=8)

    def val_dataloader(self):
        """Return a DataLoader over the validation split."""
        return DataLoader(self.dataset_val, batch_size=self.batch_size, num_workers=8)
    
# Build one VAE per dataset, each with a 2-dimensional latent space
# (the plotting cells of this notebook assume two latent dimensions).
vae_fashion, vae_mnist = (VAE(2, name) for name in ('FashionMNIST', 'MNIST'))

# The model all following cells operate on; assign vae_fashion here to use FashionMNIST instead.
model = vae_mnist

In [None]:
# Start with a single epoch of training. Values passed to self.log() in the
# model go to a TensorBoard logger rooted at "tensorlogs" with experiment name "vae".
logger = TensorBoardLogger(save_dir="tensorlogs", name="vae")
progress_bar = TQDMProgressBar(refresh_rate=20)
trainer = pl.Trainer(max_epochs=1, logger=logger, callbacks=[progress_bar])
trainer.fit(model)

In [None]:
%tensorboard --logdir logs/fit

## plot of reconstruction
Put an image in the encoder; then have it reconstructed at the decoder.  Show the images side-by-side.

In [None]:
# Reconstruction check: the top row shows input images, the bottom row the VAE output.
model.eval()  # put in evaluation mode; note the call -- a bare `model.eval` is a no-op
offset = 0  # you can plot other digits by increasing this
fig = plt.figure()
with torch.no_grad():  # plotting only, so no gradients are needed
    for i in range(4):
        # dataset_train.dataset is the full dataset underlying the training split
        image = model.dataset_train.dataset[i + offset][0]
        recon = model(image.to(model.device))[0].cpu().numpy().reshape(28, 28)

        plt.subplot(2, 4, i + 1)
        plt.imshow(image[0], cmap='gray', interpolation='none')
        plt.xticks([])
        plt.yticks([])

        plt.subplot(2, 4, i + 5)
        plt.imshow(recon, cmap='gray', interpolation='none')
        plt.xticks([])
        plt.yticks([])
plt.tight_layout()
model.train();  # back in training mode (the semicolon suppresses the notebook output)

## train more
Continue training for more epochs.  Afterwards you can re-run the reconstruction cell above to check whether the reconstruction improved.

In [None]:
# Raise the epoch limit on the existing trainer, then call fit() again on the same model.
# NOTE(review): presumably 20 is the total epoch count (including the epoch already
# run above), not 20 additional epochs -- confirm against the Lightning documentation.
trainer.fit_loop.max_epochs = 20
trainer.fit(model)

## now do some plotting
We show the distribution of classes in the 2-dimensional latent space

In [None]:
def plot_latent(autoencoder, data, num_batches=100):
    """Scatter-plot the first two latent mean coordinates, coloured by class label.

    Args:
        autoencoder: Model providing ``encode`` and ``nnsize``.
        data: Iterable of ``(x, y)`` batches, e.g. a DataLoader.
        num_batches: Maximum number of batches to plot.
    """
    device = next(autoencoder.parameters()).device
    points = None
    with torch.no_grad():  # plotting only, so no gradients are needed
        for i, (x, y) in enumerate(data):
            if i >= num_batches:
                break
            # encode() returns (mean, log-variance); only the mean is plotted
            mu, _ = autoencoder.encode(x.view(-1, autoencoder.nnsize).to(device))
            mu = mu.cpu().numpy()
            # Fixed colour limits keep the label -> colour mapping identical across
            # batches (labels 0..9 for the ten classes of MNIST / FashionMNIST)
            points = plt.scatter(mu[:, 0], mu[:, 1], c=y, cmap='tab10', vmin=0, vmax=9)
    # Draw the colorbar even when the data holds fewer than num_batches batches
    if points is not None:
        plt.colorbar(points)
plot_latent(model, model.train_dataloader())

In [None]:
def plot_reconstructed(autoencoder, r0=(-1.5, 1.5), r1=(-1.5, 1.5), n=12, w=28):
    """Decode an ``n`` x ``n`` grid of 2-D latent points and show the results as one mosaic.

    Args:
        autoencoder: Model whose ``decode`` maps a ``(batch, 2)`` tensor to
            ``w * w`` values per sample.
        r0: ``(min, max)`` range of the first latent coordinate (horizontal axis).
        r1: ``(min, max)`` range of the second latent coordinate (vertical axis).
        n: Number of grid points per axis.
        w: Side length in pixels of one decoded (square) image.
    """
    xs = np.linspace(*r0, n)
    ys = np.linspace(*r1, n)
    device = next(autoencoder.parameters()).device
    img = np.zeros((n * w, n * w))
    with torch.no_grad():  # plotting only, so no gradients are needed
        for i, y in enumerate(ys):
            # Decode one full row of the grid in a single batch
            z = torch.tensor(np.column_stack([xs, np.full(n, y)]), dtype=torch.float32, device=device)
            tiles = autoencoder.decode(z).reshape(n, w, w).cpu().numpy()
            # Image row 0 is drawn at the top, so the largest y goes first
            top = (n - 1 - i) * w
            img[top:top + w, :] = np.hstack(list(tiles))
    plt.imshow(img, extent=[*r0, *r1])
plot_reconstructed(model)