<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Autoencoding_Variational_Bayes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on *Auto-Encoding Variational Bayes* by Diederick P Kigma and Max Welling (Machine Learning Group, Universiteit van Amsterdam)

In [0]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision 

In [0]:
class VAE(nn.Module):
  def __init__(self, input_size, hidden_units, N_z):
    super(VAE, self).__init__()
    
    self.fc1 = nn.Linear(input_size, hidden_units)
    self.fc21 = nn.Linear(hidden_units, N_z)
    self.fc22 = nn.Linear(hidden_units, N_z)
    self.fc3 = nn.Linear(N_z, hidden_units)
    self.fc4 = nn.Linear(hidden_units, input_size)
    
    self.input_size = input_size
  
  def encode(self, x):
    """
    According to Appendix C.2
    """
    h_e  = F.tanh(self.fc1(x.view(-1,self.input_size)))
    mu = self.fc21(h_e)
    logvar = self.fc22(h_e)
    
    return mu, logvar
  
  def decode(self, z):
    """
    According to Appendix C.1
    """
    h_d = F.tanh(self.fc3(z))
    
    return F.sigmoid(self.fc4(h_d))
  
  def forward(self, x):
    mu, logvar = self.encode(x)
    z = self.__reparameterize(mu, logvar)
    
    return self.decode(z), mu, logvar
  
  def __reparameterize(self, mu, logvar):
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)
    
    return mu + std * eps
    

In [0]:
def loss_function(mu, logvar, y, x):
  """
  KL according to Appendix B
  """
  KL = torch.sum(1 + logvar - mu**2 - torch.exp(logvar)) / 2
  
  RE = F.binary_cross_entropy(y, x, reduction = 'sum')
  
  elbo = KL - RE
  loss = -1 * elbo
  
  return loss

In [0]:
def train(num_epochs):
  
  for epoch in range(num_epochs):
    
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader):
      inputs, labels = data
      
      # zero the parameter gradients
      optimizer.zero_grad()

      # forward + backward + optimize
      y, mu, logvar = model(inputs)
      loss = loss_function(mu, logvar, y, inputs)
      loss.backward()
      optimizer.step()
      
      # print statistics
      running_loss += loss.item()
      if batch_idx % 50 == 49:    # print every 50 mini-batches
        print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 50))
        running_loss = 0.0
            
  print('Finished Training')          
      

In [7]:
batch_size = 100
"""
https://nextjournal.com/gkoehler/pytorch-mnist
Remove the normalization to create Bernoulli data
"""
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor()
                             ])),
  batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor()
                             ])),
  batch_size=batch_size, shuffle=True)

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /files/MNIST/raw/train-images-idx3-ubyte.gz


100%|█████████▉| 9912320/9912422 [00:13<00:00, 592787.95it/s]

Extracting /files/MNIST/raw/train-images-idx3-ubyte.gz to /files/MNIST/raw



0it [00:00, ?it/s][A
  0%|          | 0/28881 [00:00<?, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /files/MNIST/raw/train-labels-idx1-ubyte.gz



 57%|█████▋    | 16384/28881 [00:00<00:00, 109639.61it/s][A
32768it [00:00, 87032.84it/s]                            [A
0it [00:00, ?it/s][A
  0%|          | 0/1648877 [00:00<?, ?it/s][A

Extracting /files/MNIST/raw/train-labels-idx1-ubyte.gz to /files/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /files/MNIST/raw/t10k-images-idx3-ubyte.gz



  2%|▏         | 32768/1648877 [00:00<00:08, 185774.76it/s][A
  4%|▍         | 73728/1648877 [00:00<00:07, 206731.49it/s][A
  7%|▋         | 122880/1648877 [00:00<00:06, 234017.49it/s][A
 10%|█         | 172032/1648877 [00:00<00:05, 257342.56it/s][A
 12%|█▏        | 196608/1648877 [00:00<00:06, 222266.73it/s][A
 16%|█▌        | 262144/1648877 [00:01<00:06, 223766.96it/s][A
 19%|█▉        | 319488/1648877 [00:01<00:05, 257305.29it/s][A
 21%|██▏       | 352256/1648877 [00:01<00:06, 211921.83it/s][A
 23%|██▎       | 385024/1648877 [00:01<00:05, 216473.13it/s][A
 25%|██▍       | 409600/1648877 [00:01<00:06, 200411.86it/s][A
 26%|██▋       | 434176/1648877 [00:02<00:06, 190408.83it/s][A
 28%|██▊       | 458752/1648877 [00:02<00:06, 184364.70it/s][A
 30%|██▉       | 491520/1648877 [00:02<00:05, 195549.20it/s][A
 31%|███▏      | 516096/1648877 [00:02<00:06, 187058.10it/s][A
 33%|███▎      | 540672/1648877 [00:02<00:06, 180065.23it/s][A
 34%|███▍      | 565248/1648877 [00:02<00

Extracting /files/MNIST/raw/t10k-images-idx3-ubyte.gz to /files/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw/t10k-labels-idx1-ubyte.gz




  0%|          | 0/4542 [00:00<?, ?it/s][A[A

8192it [00:00, 53752.18it/s]            [A[A

Extracting /files/MNIST/raw/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw
Processing...
Done!


In [25]:
len(train_loader)

600

In [0]:
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [36]:
example_data[0].size()

torch.Size([1, 28, 28])

In [71]:
# input_size = 28 * 28 = 784
model = VAE(784, 500, 20)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
train(3)

  import sys


RuntimeError: ignored