In [1]:
# standard imports
import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# own imports
import sys
sys.path.append("../../ml-library/")

from models import VariationalAutoencoder
from layers import GaussianSample
from data import get_mnist

In [3]:
# Select the compute device: the first CUDA GPU when torch reports one,
# otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

if device != "cpu":
    # Only touch the CUDA cache when a GPU was actually selected.
    torch.cuda.empty_cache()

print(f"Using device {device}")

Using device cuda:0


# Model Definition

In [4]:
# Model Parameters
# Layer sizes handed to the VariationalAutoencoder constructor below.
x_dim = 28 * 28       # input width: 784, later viewed as 28x28 images
z_dim = 32            # width of the latent code
h_dims = [256, 128]   # hidden layer widths, in encoder order

In [5]:
# Build the VAE from the sizes configured above and move it to the selected
# device in one expression (Module.to returns the module itself).
model = VariationalAutoencoder([x_dim, z_dim, h_dims]).to(device)

print(model)

# To double-check where the weights ended up, inspect
# `next(model.parameters()).is_cuda` in a scratch cell.

  init.xavier_normal(m.weight.data)
VariationalAutoencoder(
  (encoder): Encoder(
    (hidden): ModuleList(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=128, out_features=32, bias=True)
      (log_var): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (decoder): Decoder(
    (hidden): ModuleList(
      (0): Linear(in_features=32, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
    )
    (reconstruction): Linear(in_features=256, out_features=784, bias=True)
    (activation): Sigmoid()
  )
)


# Training

The data loaders used below are created by `get_mnist`, which uses functionality from /data/gen_mnist.py.

In [None]:
# Fetch the train and test data objects from the project's MNIST helper,
# requesting a batch size of 64.
# NOTE(review): the training loop iterates `train` and unpacks each item as
# (inputs, labels); presumably `test` has the same layout — confirm in get_mnist.
train, test = get_mnist(batch_size=64)

In [None]:
# import loss function
from losses import binary_cross_entropy as loss_function

In [None]:
# Define the optimizer: Adam over every model parameter with a learning rate
# of 3e-4. The betas are spelled out explicitly.
# NOTE(review): (0.9, 0.999) should match torch's documented Adam defaults —
# verify against the installed version before dropping the argument.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))

In [None]:
from torch.autograd import Variable
# Training Loop
# Minimizes the negative ELBO, where the ELBO is the reconstruction
# log-likelihood minus the KL term read from `model.kld`.
# NOTE(review): `model.kld` is presumably refreshed by each forward pass —
# confirm in VariationalAutoencoder.
n_epochs = 50    # total passes over the training data
log_every = 10   # report the mean loss every this many epochs

for epoch in range(n_epochs):
    model.train()
    total_loss = 0.0

    # Labels are not needed for this unsupervised objective.
    for u, _ in train:
        # Tensors take part in autograd directly; the Variable wrapper is a
        # deprecated no-op, so the batch is simply moved to the device.
        u = u.to(device)

        # Clear gradients before the backward pass so that leftovers from an
        # interrupted earlier run cannot leak into this update.
        optimizer.zero_grad()

        reconstruction = model(u)

        # loss_function is a cross-entropy (a cost), so its negation is used
        # as the likelihood term.
        likelihood = -loss_function(reconstruction, u)
        elbo = likelihood - model.kld

        # Average over the batch and flip the sign: optimizers minimize.
        L = -torch.mean(elbo)

        L.backward()
        optimizer.step()

        # .item() extracts the Python float without going through `.data`.
        total_loss += L.item()

    if epoch % log_every == 0:
        # len(train) is used as the number of batches to average over.
        print(f"Epoch: {epoch}\tL: {total_loss/len(train):.2f}")

In [None]:
# Draw 16 latent vectors from a standard normal and pass them through
# model.sample to obtain generated outputs.
model.eval()

# No gradients are needed for generation, so autograd tracking is disabled.
with torch.no_grad():
    # Use z_dim instead of a literal 32 so sampling stays in sync with the
    # latent size the model was built with.
    z_prior = torch.randn(16, z_dim).to(device)

    x_mu = model.sample(z_prior)

In [None]:
# Show the generated samples on a 4x4 grid, one 28x28 image per axis.
f, axarr = plt.subplots(4, 4, figsize=(18, 12))

# detach() is the supported way to drop the autograd graph (instead of .data);
# the tensor is then reshaped to images and copied to host memory for matplotlib.
samples = x_mu.detach().view(-1, 28, 28).cpu().numpy()

# Pair axes with images directly; zip stops at the shorter of the two, so a
# batch smaller than the grid does not raise an IndexError.
for ax, image in zip(axarr.flat, samples):
    ax.imshow(image)
    ax.axis("off")