<a href="https://colab.research.google.com/github/pelinsuciftcioglu/VAE/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Variational Auto-Encoder (VAE)**

VAE implementation inspired by:
- [Tomczak, J. M. (2021). Introduction to Deep Generative Modeling.](https://github.com/jmtomczak/intro_dgm)
- [CreativeAI: Deep Learning for Graphics Tutorial Code](https://github.com/smartgeometry-ucl/dl4g/blob/master/variational_autoencoder.ipynb)





In [3]:
import math
import os
from math import ceil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.utils
import torchvision
from torch import autograd  # TO USE "with autograd.detect_anomaly():"

# Request GPU execution. The device is only chosen later, where `device` is
# built, and falls back to the CPU when CUDA is not available.
use_gpu = True

In [1]:
# Display name for every digit class; entry i is the label shown for digit i.
class_names = 'Zero One Two Three Four Five Six Seven Eight Nine'.split()

In [2]:
import math

def plot_data(file_name, num_images, images, labels):
    """Draw the first `num_images` images on a square grid and save the figure.

    Args:
        file_name: Path the finished figure is written to.
        num_images: How many leading entries of `images` / `labels` to draw.
        images: Indexable collection; each item must be reshapeable to 28x28.
        labels: Indexable collection of integer class ids, used to look up the
            caption in `class_names`.
    """
    # Smallest square grid that can hold `num_images` panels.
    grid = math.ceil(math.sqrt(num_images))
    plt.figure(figsize=(grid * 2, grid * 2))
    for i in range(num_images):
        plt.subplot(grid, grid, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i].reshape(28, 28))
        plt.xlabel(class_names[labels[i]])
    # Save to the path that was passed in. The previous code read a global
    # named `file`, which only worked if a caller happened to define it.
    plt.savefig(file_name)

In [None]:
# Preview one batch of raw training images before any training happens.
# NOTE: `train_dataloader` is created in the data-loading cell further down the
# notebook, so that cell has to be executed before this one.
file = 'before.png'

# A single (images, labels) batch from the training loader.
x_train, y_train = next(iter(train_dataloader))

# Plot the first 25 images of that batch and write the figure to `file`.
plot_data(file, 25, x_train, y_train)

In [2]:
def save_images(name, images, num_samples_x, num_samples_y, epoch=None):
    """Write a grid of 28x28 grayscale images to a PDF file.

    Nothing is written unless `epoch` is None or one of the checkpoint epochs
    (10, 30, 50, 70, 100).

    Args:
        name: Output path prefix; '_images_epoch_<epoch>.pdf' is appended.
        images: Tensor whose first dimension indexes the images; every image
            must hold 28*28 values.
        num_samples_x: Number of grid rows.
        num_samples_y: Number of grid columns.
        epoch: Epoch number used both for the gating above and in the file name.
    """
    epoch_gen = (10, 30, 50, 70, 100)
    if epoch is not None and epoch not in epoch_gen:
        return

    x = images.detach().cpu().numpy()

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 or 1xN grid, where
    # plt.subplots would otherwise return a bare Axes / 1-D array.
    fig, axes = plt.subplots(num_samples_x, num_samples_y, squeeze=False)

    for i, axis in enumerate(axes.flatten()):
        # Grid cells beyond the number of images stay empty.
        if i < x.shape[0]:
            axis.imshow(np.reshape(x[i], (28, 28)), cmap='gray')
        axis.axis('off')

    fig.savefig(name + '_images' + '_epoch_' + str(epoch) + '.pdf', bbox_inches='tight')
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

In [18]:
# Pi as a 0-dim float64 tensor so it can be handed directly to torch.log.
PI = torch.tensor(np.pi, dtype=torch.float64)
# Margin used to clamp probabilities away from 0 and 1 before taking logs.
EPS = 1e-5

# DISTRIBUTIONS FOR THE DATA (INPUT)

def log_categorical(x, x_new, num_classes, reduction=None, dim=None):
    """Log-probability of integer targets under predicted class probabilities.

    Args:
        x: Target values; cast to long and one-hot encoded with `num_classes`.
        x_new: Predicted probabilities with a trailing class dimension that
            matches the one-hot encoding of `x`.
        num_classes: Number of classes used for the one-hot encoding.
        reduction: 'avg' or 'sum' to reduce over `dim`; anything else returns
            the unreduced tensor.
        dim: Dimension(s) handed to the reduction.

    Returns:
        One-hot targets times the clamped log-probabilities, optionally reduced.
    """
    targets = F.one_hot(x.long(), num_classes)
    # Clamp first so log never sees exactly 0 or 1.
    safe_probs = torch.clamp(x_new, EPS, 1. - EPS)
    log_p = targets * torch.log(safe_probs)

    if reduction == 'sum':
        return torch.sum(log_p, dim)
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    return log_p

def log_bernoulli(x, x_new, reduction=None, dim=None):
    """Element-wise Bernoulli log-likelihood of `x` given probabilities `x_new`.

    Args:
        x: Targets (used as weights of the two log terms).
        x_new: Predicted probabilities; clamped to [EPS, 1 - EPS].
        reduction: 'avg' or 'sum' to reduce over `dim`; anything else returns
            the unreduced tensor.
        dim: Dimension(s) handed to the reduction.
    """
    p = torch.clamp(x_new, EPS, 1. - EPS)
    log_p = x * torch.log(p) + (1. - x) * torch.log(1. - p)

    if reduction == 'sum':
        return torch.sum(log_p, dim)
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    return log_p

# DISTRIBUTION FOR THE VARIATIONAL INFERENCE

# log(2*pi): the per-dimension normalisation term of a Gaussian log-density.
_LOG_2PI = math.log(2. * math.pi)


def _reduce_log_prob(log_p, reduction, dim):
    """Apply the optional 'avg' / 'sum' reduction used by the Gaussian helpers."""
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    if reduction == 'sum':
        return torch.sum(log_p, dim)
    return log_p


def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Element-wise log-density of a diagonal Gaussian N(mu, exp(log_var)).

    Each element carries its own -0.5*log(2*pi) term, so summing the result
    over the feature dimension gives the joint log-density. The previous
    version multiplied that term by D = x.shape[1] for every element, which
    over-counted it by a factor of D once summed and failed on 1-D inputs.
    `log_standard_normal` is changed in the same way, so the difference of the
    two (as used for the KL term in VAE.loss) is unaffected.

    Args:
        x: Points at which to evaluate the density.
        mu: Means, broadcastable against `x`.
        log_var: Log-variances, broadcastable against `x`.
        reduction: 'avg' or 'sum' to reduce over `dim`; anything else returns
            the unreduced tensor.
        dim: Dimension(s) handed to the reduction.
    """
    log_p = -0.5 * _LOG_2PI - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu) ** 2.
    return _reduce_log_prob(log_p, reduction, dim)


# PRIOR DISTRIBUTIONS for p(z)

def log_standard_normal(x, reduction=None, dim=None):
    """Element-wise log-density of the standard normal N(0, 1).

    Uses the same per-element normalisation as `log_normal_diag`.

    Args:
        x: Points at which to evaluate the density.
        reduction: 'avg' or 'sum' to reduce over `dim`; anything else returns
            the unreduced tensor.
        dim: Dimension(s) handed to the reduction.
    """
    log_p = -0.5 * _LOG_2PI - 0.5 * x ** 2.
    return _reduce_log_prob(log_p, reduction, dim)

In [19]:
class Encoder(nn.Module):
    """MLP that maps an input vector to the mean and log-variance of a
    diagonal Gaussian over an L-dimensional latent.

    Args:
        D: Input dimensionality.
        H: Width of the first hidden layer (the second one is H // 2).
        L: Latent dimensionality; the network emits 2 * L values per input.
    """

    def __init__(self, D, H, L):
        super().__init__()

        layers = [
            nn.Linear(D, H),
            nn.ReLU(),
            nn.Linear(H, H // 2),
            nn.ReLU(),
            nn.Linear(H // 2, 2 * L),
        ]
        self.encoder_net = nn.Sequential(*layers)

    def forward(self, x):
        """Return (mu, log_var): the two halves of the network output along dim 1."""
        mu, log_var = self.encoder_net(x).chunk(2, dim=1)
        return mu, log_var

    def encode(self, x):
        """Alias of `forward`."""
        return self.forward(x)

    def sample(self, mu, log_var):
        """Draw z = mu + std * eps with eps ~ N(0, I) (reparameterization trick)."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def log_prob(self, mu, log_var, z):
        """Element-wise log-density of `z` under N(mu, exp(log_var))."""
        return log_normal_diag(z, mu, log_var)

In [20]:
class Decoder(nn.Module):
    """MLP that maps a latent vector to the parameters of the data distribution.

    Args:
        D: Number of output variables (pixels).
        H: Width of the last hidden layer (the first one is H // 2).
        L: Latent dimensionality.
        distribution: 'categorical' or 'bernoulli'.
        num_vals: Values predicted per output variable; the final layer has
            D * num_vals units.
    """

    def __init__(self, D, H, L, distribution, num_vals):
        super().__init__()
        self.D = D
        self.distribution = distribution
        self.num_vals = num_vals

        self.decoder_net = nn.Sequential(
            nn.Linear(L, H // 2),
            nn.ReLU(),
            nn.Linear(H // 2, H),
            nn.ReLU(),
            nn.Linear(H, D * num_vals),
        )

    def decode(self, z):
        """Alias of `forward`."""
        return self.forward(z)

    def forward(self, z):
        """Return the distribution parameters for every latent in `z`.

        Returns:
            'categorical': probabilities of shape (batch, D, num_vals), softmax
            over the last dimension. 'bernoulli': sigmoid of the network output.

        Raises:
            ValueError: If `self.distribution` is not a supported name. This
                used to fall through and return None without any error.
        """
        logits = self.decoder_net(z)

        if self.distribution == 'categorical':
            logits = logits.reshape(logits.shape[0], self.D, self.num_vals)
            return torch.softmax(logits, 2)

        if self.distribution == 'bernoulli':
            return torch.sigmoid(logits)

        raise ValueError(
            "Unsupported decoder distribution %r; expected 'categorical' or 'bernoulli'"
            % (self.distribution,)
        )

    def sample(self, z):
        """Sample data given latents `z`.

        Returns:
            'categorical': sampled class indices of shape (batch, D).
            'bernoulli': a 0/1 tensor with the same shape as the probabilities.
        """
        probs = self.decode(z)

        if self.distribution == 'categorical':
            batch, num_vars = probs.shape[0], probs.shape[1]
            # multinomial wants 2-D input: one row of class probabilities per variable.
            flat = probs.reshape(-1, self.num_vals)
            return torch.multinomial(flat, num_samples=1).view(batch, num_vars)

        # decode() has already rejected everything except 'bernoulli' here.
        return torch.bernoulli(probs)

    

In [21]:
class Prior(nn.Module):
    """Prior p(z) over an L-dimensional latent.

    Args:
        L: Latent dimensionality.
        prior_distribution: Name of the prior; only 'standard normal' has a
            log-density implemented.
    """

    def __init__(self, L, prior_distribution):
        super().__init__()
        self.L = L
        self.distribution = prior_distribution

    def sample(self, batch_size):
        """Return `batch_size` draws from N(0, I) with shape (batch_size, L).

        The draw is standard normal regardless of `self.distribution`, and the
        tensor is created on the default device; callers move it as needed.
        """
        return torch.randn((batch_size, self.L))

    def log_prob(self, z):
        """Element-wise log-density of `z` under the prior.

        Raises:
            ValueError: If `self.distribution` is not supported. This used to
                return None silently, which only failed later in the loss.
        """
        if self.distribution == 'standard normal':
            return log_standard_normal(z)

        raise ValueError(
            "Unsupported prior distribution %r; expected 'standard normal'"
            % (self.distribution,)
        )

In [22]:
class VAE(nn.Module):
    """Variational auto-encoder built from an Encoder, a Decoder and a Prior.

    Args:
        D: Input dimensionality (number of pixels after flattening).
        H: Hidden width handed to the encoder and decoder.
        L: Latent dimensionality.
        distribution: Data likelihood, 'categorical' or 'bernoulli'.
        num_vals: Values predicted per pixel by the decoder.
        prior_distribution: Name of the prior over the latent.
    """

    def __init__(self, D, H, L, distribution, num_vals, prior_distribution):
        super().__init__()
        self.encoder = Encoder(D, H, L)
        self.decoder = Decoder(D, H, L, distribution, num_vals)
        self.prior = Prior(L, prior_distribution)

        self.num_vals = num_vals
        self.distribution = distribution

    def forward(self, x, epoch=None, reduction='avg'):
        """Return the negative ELBO of a batch.

        Args:
            x: Input batch; everything after the batch dimension is flattened.
            epoch: Forwarded to the image-saving side effect in `log_prob`.
            reduction: 'sum' to sum over the batch, anything else for the mean.
        """
        # Flatten the input to a single dimension
        x = torch.flatten(x, start_dim=1)
        mu, log_var = self.encoder.encode(x)

        # Sample z based on the mean and log variance
        z = self.encoder.sample(mu, log_var)

        # Pass by keyword: loss() takes `epoch` before `reduction`, and the
        # earlier positional call handed the two over in swapped order.
        return self.loss(x, z, mu, log_var, epoch=epoch, reduction=reduction)

    def loss(self, x, z, mu, log_var, epoch=None, reduction='avg'):
        """Negative ELBO for latents `z` sampled from the encoder.

        Returns:
            A scalar: the sum over the batch when `reduction == 'sum'`,
            otherwise the batch mean.
        """
        # Reconstruction term: log p(x|z) per sample.
        RE = self.log_prob(x, z, epoch)

        # log p(z) - log q(z|x), summed over latent dimensions. This is the
        # negative of the KL estimate, hence it is added to RE below.
        KL = (self.prior.log_prob(z) - self.encoder.log_prob(mu, log_var, z)).sum(-1)

        if reduction == 'sum':
            return -(RE + KL).sum()
        return -(RE + KL).mean()

    def log_prob(self, x, z, epoch=None):
        """Per-sample log-likelihood log p(x|z).

        Side effect: hands a reconstruction of the batch to `save_images` on
        every call (under 'Results/'); `save_images` decides from `epoch`
        whether a file is actually written.
        """
        x_new = self.decoder.decode(z)

        if self.distribution == 'categorical':
            # x_new is (batch, D, num_vals): the reconstruction is the most
            # likely value per pixel, i.e. the argmax index rather than the
            # winning probability.
            x_reconstructed = torch.argmax(x_new, dim=2)
        else:
            # Bernoulli output is already one probability per pixel.
            x_reconstructed = x_new

        b = x_new.shape[0]

        # NOTE(review): this runs for every batch, so on saving epochs the same
        # file is rewritten once per batch.
        name = 'Results/'
        if self.training:
            save_images(name + 'training_reconstructed', x_reconstructed, 10, ceil(b / 10), epoch)
        else:
            save_images(name + 'test_reconstructed', x_reconstructed, 10, ceil(b / 10), epoch)

        if self.distribution == 'categorical':
            log_prob = log_categorical(x, x_new, self.num_vals, reduction='sum', dim=-1).sum(-1)
        elif self.distribution == 'bernoulli':
            log_prob = log_bernoulli(x, x_new, reduction='sum', dim=-1)
        else:
            raise ValueError(
                "Unsupported likelihood %r; expected 'categorical' or 'bernoulli'"
                % (self.distribution,)
            )

        return log_prob

    def _module_device(self):
        """Device of the model's parameters (CPU if it has none)."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def sample(self, batch_size=128):
        """Generate `batch_size` samples by decoding draws from the prior."""
        z = self.prior.sample(batch_size=batch_size)

        # Move the latents to wherever the model lives instead of relying on a
        # notebook-level `device` variable.
        z = z.to(self._module_device())

        return self.decoder.sample(z)

In [8]:
import math
import numbers
import warnings
from enum import Enum

from PIL import Image

from torch import Tensor
from typing import List, Tuple, Any, Optional

try:
    import accimage
except ImportError:
    accimage = None

Reference for the custom `ToTensor` transform defined below (torchvision `transforms.py`): https://github.com/pytorch/vision/blob/a839796328cf4f789c9de5da0b3367b742b5a00c/torchvision/transforms/transforms.py

In [5]:
# CUSTOM TRANSFORM FOR THE DATA WITHOUT NORMALIZING

class ToTensor:
    """Callable transform that converts an image to a tensor via `to_tensor`."""

    def __call__(self, pic):
        """Convert `pic` to a tensor.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        # The image is always turned into a numpy array first, so `to_tensor`
        # receives an ndarray.
        as_array = np.array(pic)
        return to_tensor(as_array)

    def __repr__(self):
        return '{}()'.format(type(self).__name__)

def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a channel-first tensor.

    Values are not rescaled: uint8 data is only cast to the default float
    dtype, every other dtype is returned unchanged.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
    Returns:
        Tensor: Converted image in CHW layout.
    """
    float_dtype = torch.get_default_dtype()

    def _promote_bytes(tensor):
        # backward compatibility: byte tensors become default-float tensors
        if tensor.dtype == torch.uint8:
            return tensor.to(dtype=float_dtype)
        return tensor

    # numpy input: HW gets a trailing channel axis, then HWC -> CHW
    if isinstance(pic, np.ndarray):
        hwc = pic[:, :, None] if pic.ndim == 2 else pic
        chw = torch.from_numpy(hwc.transpose((2, 0, 1))).contiguous()
        return _promote_bytes(chw)

    # accimage input (only when the optional package could be imported)
    if accimage is not None and isinstance(pic, accimage.Image):
        buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(buffer)
        return torch.from_numpy(buffer).to(dtype=float_dtype)

    # PIL input: pick the numpy dtype from the image mode (uint8 by default)
    np_dtype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}.get(pic.mode, np.uint8)
    img = torch.from_numpy(np.array(pic, np_dtype, copy=True))

    if pic.mode == '1':
        img = 255 * img

    # pic.size is (width, height); lay the data out as HWC, then move to CHW
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    return _promote_bytes(img.permute((2, 0, 1)).contiguous())

MNIST FROM: http://yann.lecun.com/exdb/mnist/

In [6]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# Only convert images to tensors with the custom ToTensor defined above.
img_transform = transforms.Compose([ToTensor()])

batch_size = 32

# Training split, reshuffled on every pass.
train_dataset = MNIST(
    root='./data/MNIST',
    train=True,
    download=True,
    transform=img_transform,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split, kept in dataset order.
test_dataset = MNIST(
    root='./data/MNIST',
    train=False,
    download=True,
    transform=img_transform,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [11]:
# DISTRIBUTION FOR THE INPUT DATA
likelihood_type = 'categorical'

# Number of values the decoder predicts per pixel for the chosen likelihood.
if likelihood_type == 'categorical':
    num_vals = 256
elif likelihood_type == 'bernoulli':
    num_vals = 1
else:
    # Fail here instead of leaving `num_vals` undefined until the model is built.
    raise ValueError(
        "Unsupported likelihood_type %r; expected 'categorical' or 'bernoulli'"
        % (likelihood_type,)
    )

prior_distribution = 'standard normal'

In [12]:
# Model size
D = 28 * 28  # input data dimensionality (one value per pixel of a 28x28 image)
H = 512      # hidden layer nodes
L = 10       # latent variables dimensionality

# Optimisation
learning_rate = 1e-4
num_epochs = 100

In [13]:
# Build the model on the CPU first, then move it to the selected device.
model = VAE(D, H, L, likelihood_type, num_vals, prior_distribution)

# Use CUDA only when it was requested and is actually available.
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# ADAM OPTIMIZER (restricted to parameters that require gradients)
optimizer = torch.optim.Adam(
    [param for param in model.parameters() if param.requires_grad],
    lr=learning_rate,
)

In [14]:
# Directory that receives the saved model and all generated images.
results_dir = 'Results/'
# exist_ok=True replaces the separate exists() check: no check-then-create
# race, and the cell can be re-run safely.
os.makedirs(results_dir, exist_ok=True)

# Base file name for the saved model and generated samples.
model_name = 'VAE'

In [15]:
def generate_samples(name, extra_name):
    """Load the saved model and write a 4x4 grid of generated samples.

    Args:
        name: Path prefix; the model is read from `name + '.model'` and the
            image path starts with `name + '_generated' + extra_name`.
        extra_name: Suffix that distinguishes the output file (e.g. the epoch).
    """
    # SECURITY NOTE: torch.load unpickles a complete model object, which can
    # execute arbitrary code; only load files this notebook wrote itself.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, where loading a pickled module needs an explicit
    # weights_only=False. Confirm against the installed version.
    model_best = torch.load(name + '.model', map_location=device)
    model_best = model_best.to(device)
    model_best.eval()

    num_x = 4
    num_y = 4

    # Sampling needs no gradients.
    with torch.no_grad():
        x = model_best.sample(num_x * num_y)

    save_images(name + '_generated' + extra_name, x.cpu(), num_x, num_y)

In [16]:
# Average training loss of every finished epoch.
train_loss_avgs = []
best_loss = float('inf')

print("Training...")

# Path prefix shared by the saved model and the generated sample images.
name = results_dir + model_name

for epoch in range(num_epochs):
    model.train()

    running_loss = 0.0
    num_batches = 0
    for data, labels in train_dataloader:
        data = data.to(device)

        # Call the module (not .forward) so registered hooks keep working.
        loss = model(data, epoch=epoch + 1)

        optimizer.zero_grad()
        # One backward pass per forward pass: retaining the graph is not
        # needed and only keeps the activations alive longer.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    train_loss_avgs.append(running_loss / num_batches)
    print('Epoch [%d / %d] average training loss: %f' % (epoch+1, num_epochs, train_loss_avgs[-1]))

    # Always save after the first epoch, afterwards only on improvement.
    if epoch == 0 or train_loss_avgs[-1] < best_loss:
        print('Model Saved!')
        torch.save(model, name + '.model')
        best_loss = train_loss_avgs[-1]

        # Samples are generated for improvements after the first epoch.
        # NOTE(review): the file suffix uses the 0-based epoch while the
        # printed epoch above is 1-based.
        if epoch > 0:
            generate_samples(name, "_epoch_" + str(epoch))


Training...
Epoch [1 / 100] average training loss: 2408526.381210
Model Saved!
Epoch [2 / 100] average training loss: 1039.382920
Model Saved!
Epoch [3 / 100] average training loss: 977.333445
Model Saved!
Epoch [4 / 100] average training loss: 932.457492
Model Saved!
Epoch [5 / 100] average training loss: 896.440575
Model Saved!
Epoch [6 / 100] average training loss: 870.693872
Model Saved!
Epoch [7 / 100] average training loss: 851.778546
Model Saved!
Epoch [8 / 100] average training loss: 837.236467
Model Saved!
Epoch [9 / 100] average training loss: 825.100642
Model Saved!
Epoch [10 / 100] average training loss: 814.785660
Model Saved!
Epoch [11 / 100] average training loss: 806.026337
Model Saved!
Epoch [12 / 100] average training loss: 798.280636
Model Saved!
Epoch [13 / 100] average training loss: 791.317552
Model Saved!
Epoch [14 / 100] average training loss: 785.011127
Model Saved!
Epoch [15 / 100] average training loss: 779.033226
Model Saved!
Epoch [16 / 100] average trainin

In [17]:
# Evaluate the best checkpoint written during training. The loaded model used
# to be ignored and the in-memory `model` was evaluated instead.
# SECURITY NOTE: torch.load unpickles a complete model object; only load files
# this notebook wrote itself.
model_best = torch.load(name + '.model', map_location=device)
model_best = model_best.to(device)

# set to evaluation mode
model_best.eval()

test_loss, num_batches = 0, 0

# No gradients are needed for evaluation.
with torch.no_grad():
    for data, labels in test_dataloader:
        data = data.to(device)

        loss = model_best(data, reduction='sum')

        test_loss += loss.item()
        # Despite its name, `num_batches` accumulates the number of samples,
        # which is the matching divisor for a loss requested with 'sum'.
        num_batches = num_batches + data.shape[0]

test_loss /= num_batches
print('Average test loss: %f' % (test_loss))

0 loss: tensor(846.8008, device='cuda:0', grad_fn=<NegBackward>)
128 loss: tensor(812.4882, device='cuda:0', grad_fn=<NegBackward>)
256 loss: tensor(844.2125, device='cuda:0', grad_fn=<NegBackward>)
384 loss: tensor(855.6114, device='cuda:0', grad_fn=<NegBackward>)
512 loss: tensor(900.8594, device='cuda:0', grad_fn=<NegBackward>)
640 loss: tensor(861.2792, device='cuda:0', grad_fn=<NegBackward>)
768 loss: tensor(875.0887, device='cuda:0', grad_fn=<NegBackward>)
896 loss: tensor(839.6137, device='cuda:0', grad_fn=<NegBackward>)
1024 loss: tensor(859.0531, device='cuda:0', grad_fn=<NegBackward>)
1152 loss: tensor(866.0262, device='cuda:0', grad_fn=<NegBackward>)
1280 loss: tensor(856.2942, device='cuda:0', grad_fn=<NegBackward>)
1408 loss: tensor(900.4405, device='cuda:0', grad_fn=<NegBackward>)
1536 loss: tensor(848.0768, device='cuda:0', grad_fn=<NegBackward>)
1664 loss: tensor(842.4064, device='cuda:0', grad_fn=<NegBackward>)
1792 loss: tensor(849.1746, device='cuda:0', grad_fn=<NegB

In [19]:
# REAL SAMPLES
# Grid layout for the reference images.
num_x, num_y = 4, 4

# First batch of the test loader; save_images draws at most one image per
# grid cell.
test_data, test_labels = next(iter(test_dataloader))

save_images(
    name + 'test_real_',
    test_data,
    num_x,
    num_y,
)

In [None]:
# Training curve: one point per epoch.
fig, ax = plt.subplots()
ax.plot(train_loss_avgs)
ax.set_xlabel('Epochs')
ax.set_ylabel('Negative Log Likelihood Loss')
plt.show()

# Raw per-epoch values for reference.
print(train_loss_avgs)