<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Autoencoding_Variational_Bayes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on *Auto-Encoding Variational Bayes* by Diederik P. Kingma and Max Welling (Machine Learning Group, Universiteit van Amsterdam).

In [0]:
import torch
import torch.nn as nn

In [0]:
class VAE(nn.Module):
    """Variational autoencoder with a one-hidden-layer MLP encoder and decoder.

    Based on Kingma & Welling, *Auto-Encoding Variational Bayes*. The encoder
    maps ``x`` to the mean and log-variance of a diagonal Gaussian over the
    latent code. A latent ``z`` is then drawn with the reparameterization
    trick, and the decoder maps ``z`` back to input space.

    Args:
        input_size: Number of features in each input vector (last dimension).
        hidden_units: Width of the hidden layer in both encoder and decoder.
        N_z: Dimensionality of the latent code ``z``.
    """

    def __init__(self, input_size: int, hidden_units: int, N_z: int):
        super().__init__()

        # Attribute names are kept as-is so existing state_dict checkpoints load.
        self.fc1 = nn.Linear(input_size, hidden_units)  # encoder hidden layer
        self.fc21 = nn.Linear(hidden_units, N_z)  # encoder head: mu
        self.fc22 = nn.Linear(hidden_units, N_z)  # encoder head: log-variance
        self.fc3 = nn.Linear(N_z, hidden_units)  # decoder hidden layer
        self.fc4 = nn.Linear(hidden_units, input_size)  # decoder output layer
        self.tanh = nn.Tanh()

    def encode(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Map ``x`` to the parameters of the approximate posterior.

        Returns:
            ``(mu, logvar)``. Each has the same leading dimensions as ``x`` and
            a last dimension of size ``N_z``.
        """
        h_e = self.tanh(self.fc1(x))
        mu = self.fc21(h_e)
        logvar = self.fc22(h_e)
        return mu, logvar

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent code ``z`` back to input space.

        The final tanh bounds the reconstruction to (-1, 1).
        NOTE(review): this presumably assumes inputs are scaled to [-1, 1].
        Confirm against the data pipeline.
        """
        h_d = self.tanh(self.fc3(z))
        return self.tanh(self.fc4(h_d))

    def forward(self, x: torch.Tensor, *, return_latent_params: bool = False):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: Input batch whose last dimension equals ``input_size``.
            return_latent_params: If True, also return the encoder's ``mu``
                and ``logvar``. The caller can then compute the KL term of the
                training objective from the same stochastic forward pass.

        Returns:
            The reconstruction. When ``return_latent_params`` is True, returns
            ``(reconstruction, mu, logvar)`` instead.
        """
        mu, logvar = self.encode(x)
        z = self.__reparameterize(mu, logvar)
        recon = self.decode(z)
        if return_latent_params:
            return recon, mu, logvar
        return recon

    def __reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Sampling this way keeps ``z`` differentiable with respect to ``mu``
        and ``logvar``. Noise is drawn on every call, including in eval mode.
        """
        std = torch.exp(logvar / 2)  # sigma = sqrt(exp(log sigma^2))
        eps = torch.randn_like(std)
        return mu + std * eps
    