<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Autoencoding_Variational_Bayes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on *Auto-Encoding Variational Bayes* by Diederik P. Kingma and Max Welling (Machine Learning Group, Universiteit van Amsterdam)

In [0]:
import torch
from torch import nn, optim
import torchvision 
import torch.nn.functional as F

In [0]:
class VAE(nn.Module):
  """Fully connected variational autoencoder.

  Encoder (Appendix C.2): a tanh hidden layer followed by two linear heads that
  give the mean and log-variance of the Gaussian posterior q(z|x).
  Decoder (Appendix C.1): a tanh hidden layer followed by a sigmoid output that
  gives per-feature Bernoulli means in (0, 1).

  Args:
    input_size: number of features in a flattened input (784 for 28x28 images).
    hidden_units: width of the single hidden layer in encoder and decoder.
    N_z: dimensionality of the latent code z.
  """

  def __init__(self, input_size, hidden_units, N_z):
    super().__init__()
    # Registration order fixes both the state_dict ordering and which RNG
    # draws each layer's default initialisation consumes.
    # Encoder: shared hidden layer, then mean head (fc21) and log-variance head (fc22).
    self.fc1 = nn.Linear(input_size, hidden_units)
    self.fc21 = nn.Linear(hidden_units, N_z)
    self.fc22 = nn.Linear(hidden_units, N_z)
    # Decoder: latent code -> hidden layer -> reconstruction.
    self.fc3 = nn.Linear(N_z, hidden_units)
    self.fc4 = nn.Linear(hidden_units, input_size)

    self.input_size = input_size

  def encode(self, x):
    """Return (mu, logvar) of q(z|x), each of shape (batch, N_z) (Appendix C.2).

    ``x`` is flattened to (batch, input_size) first, so image-shaped batches work.
    """
    flat = x.view(-1, self.input_size)
    hidden = torch.tanh(self.fc1(flat))
    return self.fc21(hidden), self.fc22(hidden)

  def decode(self, z):
    """Map latent codes to Bernoulli means in (0, 1) (Appendix C.1)."""
    logits = self.fc4(torch.tanh(self.fc3(z)))
    return torch.sigmoid(logits)

  def forward(self, x):
    """Encode, draw one reparameterised sample of z, and decode it.

    Returns:
      (reconstruction, mu, logvar)
    """
    mu, logvar = self.encode(x)
    reconstruction = self.decode(self.__reparameterize(mu, logvar))
    return reconstruction, mu, logvar

  def __reparameterize(self, mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    Writing the sample this way keeps z differentiable with respect to mu and logvar.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + sigma * noise
    

In [0]:
def loss_function(mu, logvar, y, x):
  """Negative ELBO of a Bernoulli-decoder VAE, summed over the mini-batch.

  The KL term uses the closed form for a diagonal Gaussian posterior against a
  standard normal prior (Appendix B):
      -D_KL(q(z|x) || p(z)) = 1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2)
  The reconstruction term is the Bernoulli negative log-likelihood, computed as
  summed binary cross-entropy between the decoder output and the target.

  Args:
    mu: posterior means from the encoder.
    logvar: posterior log-variances, same shape as ``mu``.
    y: decoder output (Bernoulli means in (0, 1)), e.g. the (batch, input_size)
      reconstruction returned by ``VAE.forward``.
    x: targets with values in [0, 1]. Any shape holding the same number of
      elements as ``y`` is accepted, e.g. (batch, 1, 28, 28) images for a
      (batch, 784) output.

  Returns:
    Scalar tensor equal to -ELBO summed over the batch, to be minimised.
  """
  # This is the *negative* KL divergence, i.e. the term that is added to the ELBO.
  neg_kl = torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / 2

  # Match x to y's shape instead of a hard-coded 784 so any input_size works.
  # reshape (unlike view) also accepts non-contiguous targets.
  recon_nll = F.binary_cross_entropy(y, x.reshape(y.shape), reduction='sum')

  elbo = neg_kl - recon_nll
  return -elbo

In [0]:
def train(num_epochs, net=None, opt=None, loader=None, loss_fn=None, log_every=50):
  """Optimise a VAE for ``num_epochs`` passes over a data loader, printing progress.

  Every ``log_every`` mini-batches, the mean per-batch loss over that window is
  printed as ``[epoch, batch] loss: value``. Trailing batches at the end of an
  epoch that do not fill a whole window are not reported.

  Args:
    num_epochs: number of full passes over ``loader``.
    net: model whose call returns ``(reconstruction, mu, logvar)``; defaults to
      the notebook-level ``model``.
    opt: optimiser over ``net``'s parameters; defaults to the global ``optimizer``.
    loader: iterable of ``(inputs, labels)`` batches; defaults to ``train_loader``.
    loss_fn: callable ``(mu, logvar, y, x) -> scalar tensor``; defaults to
      ``loss_function``.
    log_every: logging interval in mini-batches (must be >= 1).

  Raises:
    ValueError: if ``log_every`` is smaller than 1.
  """
  if log_every < 1:
    raise ValueError('log_every must be a positive integer')

  # Fall back to the notebook globals only when no explicit object is given,
  # so existing ``train(n)`` calls keep working unchanged.
  net = model if net is None else net
  opt = optimizer if opt is None else opt
  loader = train_loader if loader is None else loader
  loss_fn = loss_function if loss_fn is None else loss_fn

  # Ensure training-mode behaviour even if the model was previously set to eval().
  net.train()

  for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_idx, (inputs, _labels) in enumerate(loader):
      # zero the parameter gradients, then forward + backward + optimize
      opt.zero_grad()
      y, mu, logvar = net(inputs)
      loss = loss_fn(mu, logvar, y, inputs)
      loss.backward()
      opt.step()

      # print statistics
      running_loss += loss.item()
      if (batch_idx + 1) % log_every == 0:
        print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / log_every))
        running_loss = 0.0

  print('Finished Training')
      

In [0]:
batch_size = 100

# Loader setup adapted from https://nextjournal.com/gkoehler/pytorch-mnist.
# Only ToTensor is applied (no mean/std normalisation), so pixel intensities stay
# in [0, 1] and can be used directly as targets for the Bernoulli decoder's
# binary cross-entropy. They are soft targets: the values are not binarised.
DATA_DIR = '/files/'
mnist_transform = torchvision.transforms.ToTensor()

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST(DATA_DIR, train=True, download=True,
                             transform=mnist_transform),
  batch_size=batch_size, shuffle=True)

# Evaluation gains nothing from shuffling; a fixed order keeps test results repeatable.
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST(DATA_DIR, train=False, download=True,
                             transform=mnist_transform),
  batch_size=batch_size, shuffle=False)

In [25]:
# Number of mini-batches per training epoch (len() of a DataLoader counts batches).
len(train_loader)

600

In [0]:
# Pull one mini-batch to inspect what the loader yields:
# example_data holds the images and example_targets their digit labels.
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [36]:
# Shape of a single image tensor: ToTensor yields (channels, height, width), here one 28x28 channel.
example_data[0].size()

torch.Size([1, 28, 28])

In [118]:
# Seed the global torch RNG so weight initialisation, DataLoader shuffling and
# the reparameterisation noise are repeatable from a fresh kernel.
torch.manual_seed(0)

INPUT_SIZE = 28 * 28   # flattened MNIST image = 784
HIDDEN_UNITS = 500     # width of the encoder/decoder hidden layer
LATENT_DIM = 20        # N_z, dimensionality of z
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3

model = VAE(INPUT_SIZE, HIDDEN_UNITS, LATENT_DIM)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
train(NUM_EPOCHS)

[1,    50] loss: 26383.101
[1,   100] loss: 21038.506
[1,   150] loss: 19653.496
[1,   200] loss: 18316.511
[1,   250] loss: 17852.317
[1,   300] loss: 17086.552
[1,   350] loss: 16694.581
[1,   400] loss: 16436.039
[1,   450] loss: 16014.190
[1,   500] loss: 15929.894
[1,   550] loss: 15642.902
[1,   600] loss: 15360.051
[2,    50] loss: 15224.824
[2,   100] loss: 14919.626
[2,   150] loss: 14835.653
[2,   200] loss: 14595.623
[2,   250] loss: 14546.039
[2,   300] loss: 14407.812
[2,   350] loss: 14269.965
[2,   400] loss: 14138.437
[2,   450] loss: 13887.065
[2,   500] loss: 13855.277
[2,   550] loss: 14032.742
[2,   600] loss: 13937.020
[3,    50] loss: 13700.414
[3,   100] loss: 13669.763
[3,   150] loss: 13465.537
[3,   200] loss: 13620.297
[3,   250] loss: 13330.999
[3,   300] loss: 13465.110
[3,   350] loss: 13211.208
[3,   400] loss: 13256.086
[3,   450] loss: 13245.867
[3,   500] loss: 13128.460
[3,   550] loss: 13283.745
[3,   600] loss: 13129.416
Finished Training
