<a href="https://colab.research.google.com/github/pelinsuciftcioglu/VAE/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Variational Auto-Encoder (VAE)**

In [None]:
import torch
import torch.nn as nn
import torch.optim
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Numerical floor/ceiling margin used when clamping probabilities before a log.
EPS = 1e-5
# pi as a 0-dimensional float64 tensor.
PI = torch.tensor(np.pi, dtype=torch.float64)

def log_categorical(x, p, num_classes, reduction=None, dim=None):
    """Log-probability of integer labels under a categorical distribution.

    Args:
        x: Tensor of class indices; it is cast with ``.long()`` and one-hot
            encoded, which appends a trailing axis of size ``num_classes``.
        p: Tensor of class probabilities; must broadcast against the one-hot
            encoding of ``x``. Values are clamped to ``[EPS, 1 - EPS]`` before
            the logarithm so that ``log(0)`` never occurs.
        num_classes: Number of categories used for the one-hot encoding.
        reduction: ``'avg'`` to average over ``dim``, ``'sum'`` to sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: Dimension(s) handed to the reduction.

    Returns:
        The (optionally reduced) tensor ``one_hot(x) * log(p)``.
    """
    # The file has no ``F`` alias for torch.nn.functional; go through the
    # ``nn`` alias that is imported at the top of the file instead.
    x_one_hot = nn.functional.one_hot(x.long(), num_classes=num_classes)
    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

def log_bernoulli(x, x_new, reduction=None, dim=None):
    """Element-wise Bernoulli log-likelihood of ``x`` given probabilities ``x_new``.

    Args:
        x: Target tensor (the observed values).
        x_new: Tensor of Bernoulli probabilities; clamped to
            ``[EPS, 1 - EPS]`` so both logarithms stay finite.
        reduction: ``'avg'`` to average over ``dim``, ``'sum'`` to sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: Dimension(s) handed to the reduction.

    Returns:
        ``x * log(x_new) + (1 - x) * log(1 - x_new)``, optionally reduced.
    """
    probs = torch.clamp(x_new, EPS, 1. - EPS)
    log_lik = x * torch.log(probs) + (1. - x) * torch.log(1. - probs)

    if reduction == 'sum':
        return torch.sum(log_lik, dim)
    if reduction == 'avg':
        return torch.mean(log_lik, dim)
    return log_lik

In [None]:
class Encoder(nn.Module):
    """Gaussian encoder q(z|x) with a diagonal covariance.

    A two-layer MLP maps a flattened input of size ``D`` to ``2 * L`` outputs,
    which are split in half into the mean and the log-variance of the
    approximate posterior over an ``L``-dimensional latent.

    Args:
        D: Size of the flattened input.
        H: Width of the hidden layer.
        L: Dimensionality of the latent variable.
    """

    def __init__(self, D, H, L):
        super(Encoder, self).__init__()

        self.encoder_net = nn.Sequential(nn.Linear(D, H), nn.LeakyReLU(), nn.Linear(H, 2 * L))

    def encode(self, x):
        """Return ``(mu, log_var)`` for ``x``; thin wrapper around ``forward``."""
        mu, log_var = self.forward(x)
        return mu, log_var

    def forward(self, x):
        """Flatten ``x`` per sample and compute the posterior parameters.

        Returns:
            Tuple ``(mu, log_var)``, each the half of the network output
            obtained by chunking along dimension 1.
        """
        x = torch.flatten(x, start_dim=1)
        h = self.encoder_net(x)
        mu, log_var = torch.chunk(h, 2, dim=1)

        return mu, log_var

    def sample(self, mean, log_var):
        """Draw ``z ~ N(mean, exp(log_var))`` with the reparameterization trick."""
        std = torch.exp(0.5 * log_var)
        # Sample epsilon ~ N(0, I)
        eps = torch.randn_like(std)
        # Reparameterization trick: the noise is external, so gradients flow
        # through ``mean`` and ``std``.
        z = mean + eps * std

        return z

    def log_prob(self, mu, log_var, z):
        """Element-wise log-density of ``z`` under ``N(mu, exp(log_var))``.

        The result is not reduced over the latent dimension; callers sum it
        themselves.
        """
        # Plain Python float so the constant never changes the tensor dtype.
        log_2pi = float(np.log(2.0 * np.pi))
        return -0.5 * (log_2pi + log_var + torch.exp(-log_var) * (z - mu) ** 2)

    # Alias under the name of the original unfinished stub.
    log_p = log_prob

In [None]:
class Decoder(nn.Module):
    """MLP decoder producing the parameters of p(x|z).

    Args:
        D: Number of observed dimensions.
        H: Width of the hidden layer.
        L: Dimensionality of the latent variable.
        distribution: ``'categorical'`` or ``'bernoulli'``; selects how the
            network output is turned into probabilities.
        num_vals: Number of values per observed dimension; the last linear
            layer has ``D * num_vals`` outputs.
    """

    def __init__(self, D, H, L, distribution, num_vals):
        super(Decoder, self).__init__()
        self.D = D
        self.num_vals = num_vals
        # ``forward`` branches on this, so it has to be kept on the instance.
        self.distribution = distribution

        self.decoder_net = nn.Sequential(nn.Linear(L, H), nn.LeakyReLU(),
                                         nn.Linear(H, D * num_vals))

    def decode(self, z):
        """Return the likelihood parameters for ``z``; wrapper around ``forward``."""
        x_new = self.forward(z)

        return x_new

    def forward(self, z):
        """Map latent codes to probabilities.

        Returns:
            For ``'categorical'``: a ``(batch, D, num_vals)`` tensor that is
            softmax-normalized over the last axis. For ``'bernoulli'``: the
            element-wise sigmoid of the network output.

        Raises:
            ValueError: If ``self.distribution`` is neither of the two
                supported names.
        """
        x_new = self.decoder_net(z)

        if self.distribution == 'categorical':
            batch = x_new.shape[0]
            # ``reshape`` is not in-place: the result must be kept, otherwise
            # the softmax below would run on the flat 2-D output.
            x_new = x_new.reshape(batch, self.D, self.num_vals)
            return torch.softmax(x_new, dim=-1)

        if self.distribution == 'bernoulli':
            return torch.sigmoid(x_new)

        raise ValueError(
            "Unsupported distribution: {!r}".format(self.distribution))

In [None]:
class Prior(nn.Module):
    """Standard-normal prior p(z) = N(0, I) over an ``L``-dimensional latent.

    Args:
        L: Dimensionality of the latent variable.
        prior_distribution: Stored on the instance for reference only; both
            ``sample`` and ``log_prob`` always use the standard normal.
    """

    def __init__(self, L, prior_distribution):
        super(Prior, self).__init__()
        self.L = L
        self.prior_distribution = prior_distribution

    def sample(self, batch_size):
        """Draw ``batch_size`` latent vectors, returned as ``(batch_size, L)``."""
        z = torch.randn((batch_size, self.L))
        return z

    def log_prob(self, z):
        """Element-wise log-density of ``z`` under N(0, 1).

        The result has the same shape as ``z``; it is not summed over the
        latent dimension.
        """
        # Plain Python float so the constant never changes the tensor dtype.
        log_2pi = float(np.log(2.0 * np.pi))
        return -0.5 * (log_2pi + z ** 2)

In [None]:
class VAE(nn.Module):
    """Variational auto-encoder tying together encoder, decoder and prior.

    Args:
        D: Number of observed dimensions (size of the flattened input).
        H: Hidden-layer width of encoder and decoder.
        L: Dimensionality of the latent variable.
        num_vals: Number of values per observed dimension.
        distribution: Likelihood type, ``'categorical'`` or ``'bernoulli'``.
    """

    def __init__(self, D, H, L, num_vals, distribution):
        super(VAE, self).__init__()
        self.encoder = Encoder(D, H, L)
        self.decoder = Decoder(D, H, L, distribution, num_vals)
        # ``loss`` evaluates the prior density, so a prior must exist.
        self.prior = Prior(L, prior_distribution='standard_normal')
        self.num_vals = num_vals
        self.distribution = distribution

    def forward(self, x, reduction='avg'):
        """Encode ``x``, sample a latent code and return the training loss.

        Returns:
            The negative ELBO, reduced over the batch as in ``loss``.
        """
        mu, log_var = self.encoder.encode(x)
        z = self.encoder.sample(mu, log_var)

        return self.loss(x, z, mu, log_var, reduction)

    def loss(self, x, z, mu, log_var, reduction='avg'):
        """Negative ELBO for one sampled latent code per input.

        Args:
            x: Observations.
            z: Latent codes sampled from the encoder for ``x``.
            mu: Posterior means produced by the encoder.
            log_var: Posterior log-variances produced by the encoder.
            reduction: ``'sum'`` sums over the batch; any other value takes
                the batch mean.
        """
        # Reconstruction term: log p(x|z), one value per sample.
        RE = self.log_prob(x, z)

        # Single-sample estimate of -KL(q(z|x) || p(z)):
        # log p(z) - log q(z|x), summed over the latent dimension.
        KL = (self.prior.log_prob(z) - self.encoder.log_prob(mu, log_var, z)).sum(-1)

        # RE + KL is the ELBO estimate; negate it to obtain a loss.
        if reduction == 'sum':
            return -(RE + KL).sum()
        else:
            return -(RE + KL).mean()

    def log_prob(self, x, z):
        """Log-likelihood log p(x|z), summed over all observed dimensions.

        Raises:
            ValueError: If ``self.distribution`` is not a supported name.
        """
        # The likelihood parameters come from the latent code, not from x.
        x_new = self.decoder.decode(z)
        # The encoder flattens its input per sample; do the same here so the
        # targets line up with the decoder output.
        x = torch.flatten(x, start_dim=1)

        if self.distribution == 'categorical':
            return log_categorical(x, x_new, num_classes=self.num_vals,
                                   reduction='sum', dim=-1).sum(-1)

        if self.distribution == 'bernoulli':
            return log_bernoulli(x, x_new, reduction='sum', dim=-1)

        raise ValueError(
            "Unsupported distribution: {!r}".format(self.distribution))


In [None]:
# Model sizes: observed dimensions, hidden width, latent dimensions.
D, H, L = 64, 128, 16

# Likelihood used by the decoder.
likelihood_type = 'categorical'

# Number of values per observed dimension for the chosen likelihood.
if likelihood_type == 'bernoulli':
    num_vals = 1
elif likelihood_type == 'categorical':
    num_vals = 17

In [None]:
# VAE.__init__ takes (D, H, L, num_vals, distribution); pass the last two by
# keyword so the likelihood name and the value count cannot be swapped.
model = VAE(D, H, L, num_vals=num_vals, distribution=likelihood_type)