# "Auto-Encoding Variational Bayes" paper implementation - https://arxiv.org/pdf/1312.6114.pdf

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

import numpy as np
from matplotlib import pyplot as plt

# Converters between PIL images and tensors (one for each direction).
to_image = transforms.ToPILImage()
to_tensor = transforms.ToTensor()

# MNIST rooted at ./data (download requested), with every sample converted
# to a tensor, served in shuffled mini-batches of 128.
train_set = MNIST('./data', transform=to_tensor, download=True)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)


In [3]:
class VAE(nn.Module):
    """Variational auto-encoder with a single hidden layer on each side.

    The encoder maps a flattened input of size ``n_input`` to a hidden
    representation, from which two linear heads predict the mean and the
    log-variance of a diagonal Gaussian over the ``n_latent``-dimensional
    latent code. The decoder maps a latent code back to ``n_input`` values
    squashed by a sigmoid.

    Note on naming: the ``log_sigma`` tensors used throughout are treated as
    log-variances (``log(sigma ** 2)``), since the standard deviation is
    recovered as ``exp(log_sigma / 2)``.

    Args:
        n_input: Number of features of one flattened input sample.
        n_hidden: Width of the hidden layer in the encoder and the decoder.
        n_latent: Dimensionality of the latent code.
    """

    def __init__(self, n_input=784, n_hidden=256, n_latent=2):
        super().__init__()
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_latent = n_latent
        self.encoder = nn.Sequential(
            nn.Linear(self.n_input, self.n_hidden),
            nn.ReLU(),
        )
        # Two heads on top of the shared encoder: posterior mean and
        # posterior log-variance.
        self.z_mean = nn.Linear(self.n_hidden, self.n_latent)
        self.z_log_sigma = nn.Linear(self.n_hidden, self.n_latent)
        self.decoder = nn.Sequential(
            nn.Linear(self.n_latent, self.n_hidden),
            nn.ReLU(),
            nn.Linear(self.n_hidden, self.n_input),
            nn.Sigmoid()
        )

    def sample(self, mu, log_sigma):
        """Draw a latent code with the reparameterization trick.

        Args:
            mu: Mean of the approximate posterior.
            log_sigma: Log-variance of the approximate posterior, same shape
                as ``mu``.

        Returns:
            ``mu + epsilon * sigma`` with ``epsilon ~ N(0, I)``, which stays
            differentiable with respect to ``mu`` and ``log_sigma``.
        """
        sigma = torch.exp(log_sigma / 2)
        # The noise must be standard normal: randn_like, not rand_like.
        # rand_like is uniform on [0, 1), which would bias every sample
        # above mu and break the Gaussian assumption behind the KL term.
        epsilon = torch.randn_like(sigma)
        return mu + epsilon * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Input batch; it is flattened to ``(-1, n_input)``.

        Returns:
            Tuple ``(y, mu, log_sigma)``: the reconstruction of shape
            ``(batch, n_input)`` and the posterior mean and log-variance of
            shape ``(batch, n_latent)``.
        """
        x = x.view(-1, self.n_input)
        x = self.encoder(x)
        mu = self.z_mean(x)
        log_sigma = self.z_log_sigma(x)
        z = self.sample(mu, log_sigma)
        y = self.decoder(z)
        return y, mu, log_sigma


In [4]:
def calculate_loss(y, x, mu, log_sigma):
    """Return the VAE loss (KL divergence + reconstruction), summed over the batch.

    Args:
        y: Reconstruction with values in [0, 1], as required by BCELoss.
        x: Target with the same number of elements as ``y``; it is reshaped
            to ``y``'s shape before comparison.
        mu: Mean of the approximate posterior.
        log_sigma: Log-variance of the approximate posterior (it enters the
            KL term through ``log_sigma`` and ``exp(log_sigma)``).

    Returns:
        Scalar tensor: the KL divergence between the diagonal Gaussian
        posterior and a standard normal prior, plus the summed binary
        cross-entropy between ``y`` and ``x``.
    """
    # VAE.forward returns a flattened (batch, n_input) reconstruction, while
    # an un-flattened image batch has more dimensions. BCELoss rejects
    # mismatched shapes, so align the target; this is a no-op when the
    # shapes already agree.
    x = x.reshape(y.shape)
    kl_div = -0.5 * (1 + log_sigma - mu.pow(2) - log_sigma.exp()).sum()
    recon_loss = nn.BCELoss(reduction='sum')(y, x)
    return kl_div + recon_loss