In [None]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim

from torch.autograd import Variable

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from torchvision import datasets, transforms

%matplotlib inline

In [None]:
# True when a CUDA device is usable; later cells use this flag to decide
# whether models and tensors are moved to the GPU.
use_cuda = torch.cuda.is_available()

In [None]:
# Seed PyTorch's global RNG so that weight initialisation, data shuffling and
# the sampled latent noise are repeatable across "Restart & Run All".
seed = 0
torch.manual_seed(seed)

batch_size = 16        # samples per minibatch handed to the DataLoader
noise_size = 100       # dimensionality of the latent code (encoder output / decoder input)
hidden_size = 128      # width of the single hidden layer in encoder and decoder
learning_rate = 1e-3   # step size passed to the Adam optimizer

In [None]:
# Training split of MNIST, downloaded on first use; ToTensor converts each
# image into a float tensor before batching.
mnist_train = datasets.MNIST(
    'data/mnist/raw/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled minibatch iterator consumed by the training loop.
mnist = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [None]:
# Read the image geometry from one transformed sample instead of the
# deprecated `dataset.train_data` alias (which warns on current torchvision).
# ToTensor yields a (channels, height, width) tensor, so the last two
# dimensions are the spatial ones.
sample_image, _ = mnist.dataset[0]
image_height, image_width = sample_image.shape[-2:]

# Number of pixels per image; this is the length of the flattened input vector.
image_size = image_height * image_width
print(image_size)

In [None]:
def initialize_weights(m):
    """Initialise a fully-connected layer in place.

    Intended to be passed to ``nn.Module.apply``. Linear layers (including
    subclasses of ``nn.Linear``) get Xavier-normal weights and a zero bias;
    every other module type is left untouched.

    Args:
        m: the module visited by ``apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # Layers built with ``bias=False`` have ``bias is None``; skip them
        # instead of crashing inside ``constant_``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

In [None]:
class Encoder(nn.Module):
    """Recognition network: flattened image -> latent mean and log-variance.

    A single ReLU hidden layer feeds two linear heads of width ``noise_size``.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        # Shared hidden representation.
        self.h1 = nn.Linear(image_size, hidden_size)
        self.g1 = nn.ReLU()

        # Output heads; ``h_var`` produces the value that ``forward`` returns
        # as the log-variance.
        self.h_mu = nn.Linear(hidden_size, noise_size)
        self.h_var = nn.Linear(hidden_size, noise_size)

        # Xavier/zero initialisation for every Linear submodule.
        self.apply(initialize_weights)

    def forward(self, A0):
        """Return ``(Z_mu, Z_logvar)`` for a batch of flattened images ``A0``."""
        hidden = self.g1(self.h1(A0))
        return self.h_mu(hidden), self.h_var(hidden)

In [None]:
class Decoder(nn.Module):
    """Generative network: latent code -> flattened image with values in (0, 1).

    One ReLU hidden layer followed by a sigmoid output layer of width
    ``image_size``.
    """

    def __init__(self):
        super(Decoder, self).__init__()

        # Hidden layer.
        self.h1 = nn.Linear(noise_size, hidden_size)
        self.g1 = nn.ReLU()

        # Output layer; the sigmoid keeps every pixel in (0, 1).
        self.h2 = nn.Linear(hidden_size, image_size)
        self.g2 = nn.Sigmoid()

        # Xavier/zero initialisation for every Linear submodule.
        self.apply(initialize_weights)

    def forward(self, A0):
        """Decode a batch of latent codes ``A0`` into flattened images."""
        hidden = self.g1(self.h1(A0))
        return self.g2(self.h2(hidden))

In [None]:
E = Encoder()
D = Decoder()

# Put the networks on their final device *before* building the optimizer, as
# the torch.optim documentation recommends. Calling `.cuda()` again on a
# module that is already on the GPU is a harmless no-op.
if use_cuda:
    E = E.cuda()
    D = D.cuda()

# Encoder and decoder are trained jointly, so one optimizer sees both
# parameter sets.
parameters = list(E.parameters()) + list(D.parameters())

solver = optim.Adam(parameters, lr=learning_rate)

In [None]:
# Move both networks onto the GPU when one is available.
if use_cuda:
    E, D = E.cuda(), D.cuda()

In [None]:
def reparametrization(mu, logvar):
    """Draw a latent sample with the reparametrization trick.

    Computes ``mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``, so the
    sample stays differentiable with respect to ``mu`` and ``logvar``.

    Args:
        mu: tensor of latent means.
        logvar: tensor of latent log-variances.

    Returns:
        A tensor of latent samples, one per element of ``logvar``.
    """
    std = torch.exp(logvar / 2)
    # Sample noise with the same shape, dtype and device as ``std`` instead of
    # a fixed (batch_size, noise_size) CPU tensor, so batches of any size
    # (e.g. a final partial batch) work and no manual `.cuda()` is needed.
    gaussian_noise = torch.randn_like(std)
    return mu + std * gaussian_noise

In [None]:
# Number of full passes over the training set.
epochs = 10

for epoch in range(epochs):
    for X, _ in mnist:  # labels are unused: the VAE is trained unsupervised
        solver.zero_grad()

        # Flatten every image into a vector of length image_size.
        X = X.view(-1, image_size)
        if use_cuda:
            X = X.cuda()

        # Encode, sample a latent code, decode.
        Z_mu, Z_logvar = E(X)
        Z = reparametrization(Z_mu, Z_logvar)
        X_r = D(Z)

        # Reconstruction term: summed binary cross-entropy averaged over the
        # samples actually in this batch (not the nominal batch_size, which
        # would be wrong for a partial batch).
        r_loss = F.binary_cross_entropy(X_r, X, reduction="sum") / X.size(0)
        # Closed-form KL divergence between N(mu, diag(exp(logvar))) and
        # N(0, I), summed over latent dimensions and averaged over the batch.
        kl_loss = torch.mean(0.5 * torch.sum(torch.exp(Z_logvar) + Z_mu ** 2 - 1. - Z_logvar, 1))
        loss = r_loss + kl_loss

        loss.backward()
        solver.step()

    # Report the loss of the last minibatch of the epoch; `.item()` returns a
    # Python float regardless of the tensor's device.
    print('Epoch {}; Loss: {}'.format(epoch + 1, loss.item()))