<a href="https://colab.research.google.com/github/pelinsuciftcioglu/VAE/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Variational Auto-Encoder (VAE)**

VAE implementation inspired by:
- [Tomczak, J. M. (2021). Introduction to Deep Generative Modeling.](https://https://github.com/jmtomczak/intro_dgm)
- [CreativeAI: Deep Learning for Graphics Tutorial Code
](https://github.com/smartgeometry-ucl/dl4g/blob/master/variational_autoencoder.ipynb)





In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.utils
import torch.distributions
import torchvision
from torch import autograd  # TO USE "with autograd.detect_anomaly():"
import numpy as np
import matplotlib.pyplot as plt

import os

use_gpu = True

In [2]:
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.e-5

# DISTRBUTION FOR THE DATA (INPUT)

def log_categorical(x, x_new, num_classes, reduction=None, dim=None):
    x_one_hot = F.one_hot(x.long(), num_classes)
    log_p = x_one_hot * torch.log(torch.clamp(x_new, EPS, 1. - EPS))
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

def log_bernoulli(x, x_new, reduction=None, dim=None):
    x_new = torch.clamp(x_new, EPS, 1. - EPS)
    log_p = x * torch.log(x_new) + (1. - x) * torch.log(1. - x_new)
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

# DISTRIBUTION FOR THE VARIATIONAL INFERENCE

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    D = x.shape[1]
    log_p = -0.5 * D * torch.log(2. * PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p


# PRIOR DISTRIBUTIONS for p(z)

def log_standard_normal(x, reduction=None, dim=None):
    D = x.shape[1]
    log_p = -0.5 * D * torch.log(2. * PI) - 0.5 * x**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

In [3]:
class Encoder(nn.Module):
    def __init__(self, D, H, L):
        super(Encoder, self).__init__()

        self.encoder_net = nn.Sequential(nn.Linear(D, H), nn.LeakyReLU(), nn.Linear(H, H), nn.LeakyReLU(), nn.Linear(H, 2*L))

    def encode(self, x):
      mu, log_var = self.forward(x)
      return mu, log_var

    def forward(self, x):
      h = self.encoder_net(x)
      mu, log_var =  torch.chunk(h, 2, dim=1)

      return mu, log_var

    def sample(self, mu, log_var):
      std = torch.exp(0.5 * log_var)
      # Sample epsilon ~ N(0,I)
      eps = torch.randn_like(std)
      # Reparameterization trick
      z = mu + eps * std

      return z

    def log_prob(self, mu, log_var, z):
      return log_normal_diag(z, mu, log_var)

In [4]:
class Decoder(nn.Module):
    def __init__(self, D, H, L, distribution, num_vals):
        super(Decoder, self).__init__()
        self.D = D
        self.distribution = distribution
        self.num_vals = num_vals

        self.decoder_net = nn.Sequential(nn.Linear(L, H), nn.LeakyReLU(), nn.Linear(H, H), nn.LeakyReLU(), nn.Linear(H, D * num_vals))

    def decode(self, z):
      x_new = self.forward(z)

      return x_new

    def forward(self, z):
      x_new = self.decoder_net(z)
      
      if self.distribution == 'categorical':
        b = x_new.shape[0]
        d = self.D
        x_new = x_new.reshape(b, d, self.num_vals)
        return torch.softmax(x_new, 2)
      
      elif self.distribution == 'bernoulli':
        return torch.sigmoid(x_new)

In [5]:
class Prior(nn.Module):
    def __init__(self, L, prior_distribution):
        super(Prior, self).__init__()
        self.L = L
        self.distribution = prior_distribution

    def sample(self, batch_size):
        z = torch.randn((batch_size, self.L))
        return z

    def log_prob(self, z):
      if (self.distribution == 'standard normal'):
        return log_standard_normal(z)

      # elif self.distribution == ''
        # return log____(z)

In [6]:
class VAE(nn.Module):
  def __init__(self, D, H, L, distribution, num_vals, prior_distribution):
        super(VAE, self).__init__()
        self.encoder = Encoder(D, H, L)
        self.decoder = Decoder(D, H, L, distribution, num_vals)
        self.prior = Prior(L, prior_distribution)


        self.num_vals = num_vals
        self.distribution = distribution
  
  def forward(self, x, reduction='avg'):
        x = torch.flatten(x, start_dim=1)
        mu, log_var = self.encoder.encode(x)
        z = self.encoder.sample(mu, log_var)

        ELBO = self.loss(x, z, mu, log_var, reduction)

        return ELBO
  
  def loss(self, x, z, mu, log_var, reduction='avg'):
    # Reconstruction Error
    RE = self.log_prob(x, z)

    # KL-Divergence
    KL = (self.prior.log_prob(z) - self.encoder.log_prob(mu, log_var, z)).sum(-1)

    if reduction == 'sum':
      return -(RE + KL).sum()
    else:
      return -(RE + KL).mean()


  def log_prob(self, x, z):
    x_new = self.decoder.decode(z)

    if self.distribution == 'categorical':
      log_prob = log_categorical(x, x_new, self.num_vals, reduction='sum', dim=-1).sum(-1)
            
    elif self.distribution == 'bernoulli':
      log_prob = log_bernoulli(x, x_new, reduction='sum', dim=-1)

    return log_prob

  def sample(self, batch_size=128):
    z = self.prior.sample(batch_size=batch_size)
    return self.decoder.sample(z)


In [7]:
import math
import numbers
import warnings
from enum import Enum

from PIL import Image

from torch import Tensor
from typing import List, Tuple, Any, Optional

try:
    import accimage
except ImportError:
    accimage = None

In [8]:
# CUSTOM TRANFORM FOR THE DATA WITHOUT SCALING

class ToTensor:
    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        return to_tensor(np.array(pic))

    def __repr__(self):
        return self.__class__.__name__ + '()'

def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.
    See :class:`~torchvision.transforms.ToTensor` for more details.
    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
    Returns:
        Tensor: Converted image.
    """
    default_float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic).to(dtype=default_float_dtype)

    # handle PIL Image
    mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
    img = torch.from_numpy(
        np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
    )

    if pic.mode == '1':
        img = 255 * img
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype)
    else:
        return img

MNIST FROM: http://yann.lecun.com/exdb/mnist/

In [9]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

img_transform = transforms.Compose([ToTensor()])

batch_size = 64

train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [10]:
# DISTRUBUTION FOR THE INPUT DATA
likelihood_type = 'categorical'

if likelihood_type == 'categorical':
    num_vals = 256
elif likelihood_type == 'bernoulli':
    num_vals = 1

prior_distribution = 'standard normal'

In [11]:
D = 28*28 # INPUT DATA DIMENSIONALITY
H = 100   # HIDDEN LAYER NODES
L = 20    # LATENT VARIABLES DIMENSIONALITY

learning_rate = 1e-3
num_epochs = 100

In [12]:
model = VAE(D, H, L, likelihood_type, num_vals, prior_distribution)

device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
model = model.to(device)

# ADAM OPTIMIZER
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad == True], lr=learning_rate)

In [18]:
results_dir = 'Results/'
if not(os.path.exists(results_dir)):
    os.mkdir(results_dir)

model_name = 'VAE'

In [19]:
def generate_samples(name, data_loader, extra_name):
    
  x = next(iter(data_loader).detach().numpy()

  model_best = torch.load(name + '.model')
  model_best.eval()

  num_x = 4
  num_y = 4
  x = model_best.sample(num_x * num_y)
  x = x.detach().numpy()

  fig, ax = plt.subplots(num_x, num_y)
  for i, ax in enumerate(ax.flatten()):
    plottable_image = np.reshape(x[i], (8, 8))
    ax.imshow(plottable_image, cmap='gray')
    ax.axis('off')

  plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')
  plt.close()



In [20]:
train_loss_avgs = []
best_loss = 1000.

print("Training...")

name = results_dir + model_name

for epoch in range(num_epochs):
  model.train()

  train_loss_avgs.append(0)
  num_batches = 0
  for batch_idx, (data, labels) in enumerate(train_dataloader, 1):
    

      data = data.to(device)


      loss = model.forward(data) 


      optimizer.zero_grad()
      loss.backward(retain_graph=True)    
      optimizer.step()

      
      train_loss_avgs[-1] += loss.item()
      num_batches += 1

  train_loss_avgs[-1] /= num_batches
  print('Epoch [%d / %d] average training loss: %f' % (epoch+1, num_epochs, train_loss_avgs[-1]))

  if epoch == 0:
    print('Model Saved!')
    torch.save(model, name + '.model')

  else:
    if train_loss_avgs[-1] < best_loss:
      print('Model Saved!')
      torch.save(model, name + '.model')
      best_loss = train_loss_avgs[-1]

      generate_samples(name, train_dataloader, "_epoch_" + str(epoch))




Training...
Epoch [1 / 100] average training loss: 27.671579
Model Saved!
Epoch [2 / 100] average training loss: 27.494068
Model Saved!


AttributeError: ignored

In [None]:
model_best = torch.load(trained_dir_name + '.model')

# set to evaluation mode
model.eval()

test_loss_avg, num_batches = 0, 0

for batch_idx, (data, labels) in enumerate(test_dataloader):
    
  data = data.to(device)  

  loss = model.forward(data, reduction = 'sum') 

  test_loss_avgs += loss.item()
  num_batches += 1
    
test_loss_avg /= num_batches
print('Average test loss: %f' % (test_loss_avg))

In [None]:
f = open(result_dir + name + '_test_loss.txt', "w")
f.write(str(test_loss_avg))
f.close()

# REAL SAMPLES

num_x = 4
num_y = 4

x = next(iter(test_loader)).detach().numpy()
fig, ax = plt.subplots(num_x, num_y)
for i, ax in enumerate(ax.flatten()):
    plottable_image = np.reshape(x[i], (8, 8))
    ax.imshow(plottable_image, cmap='gray')
    ax.axis('off')
plt.savefig(trained_dir_name+'_real_images.pdf', bbox_inches='tight')
plt.close()

In [None]:
# plt.ion()

fig = plt.figure()
plt.plot(train_loss_avgs)
plt.xlabel('Epochs')
plt.ylabel('Negative Log Likelihood Loss')
plt.show()