In [None]:
# standard imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
%matplotlib inline

In [None]:
# Make the project's own modules (`models`, `data`, `utils`, imported further
# down) importable; they presumably live in the notebook's parent directory.
import sys

PARENT_DIR = "../"  # relative to the kernel's current working directory
sys.path.append(PARENT_DIR)

In [None]:
# Configuration: seed torch's RNG and pick the compute device.
# Seeding must happen before the model is built so that weight initialisation,
# latent sampling and any torch-driven batching are reproducible on
# Restart & Run All. (Bit-exact GPU determinism may additionally need cuDNN
# determinism flags -- not set here.)
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

if torch.cuda.is_available():
    device = "cuda:0"
    torch.cuda.empty_cache()
else:
    device = "cpu"

print(f"Using device {device}")

# Define Model

In [None]:
# Layer sizes, each list ordered bottom (next to the data) to top.
x_dim = 28 * 28          # one flattened 28x28 image (samples are reshaped to 28x28 below)
h_dim = [256, 128, 64]   # hidden-layer widths
z_dim = [32, 16, 8]      # latent-layer widths

In [None]:
# Build the Ladder VAE from the bottom-to-top dimensions above and move it to
# the selected device. `nn.Module.to` returns the module itself, so it can be
# chained onto the constructor.
from models import LadderVAE

model = LadderVAE([x_dim, z_dim, h_dim]).to(device)
print(model)

# Training the model

In [None]:
# Optimisation settings: number of passes over the training set, and the
# Adam step size used by the optimizer below.
epochs, learning_rate = 250, 3e-4

In [None]:
# Load MNIST as iterable (train, test) loaders of (image, label) batches of 64.
# Data is stored under the current directory (presumably downloaded there if
# missing -- see `data.get_mnist`).
from data import get_mnist

train, test = get_mnist(batch_size=64, location="./")

In [None]:
# Adam over all model parameters, plus a deterministic linear warm-up schedule
# (`gamma`) whose successive values weight the KL term in the training loop.
from utils import bce_loss, DeterministicWarmup

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
)
gamma = DeterministicWarmup(t_max=1, n=50)

In [None]:
# Training loop: minimise the negative ELBO
#     loss = -mean(likelihood - beta * KL)
# where beta comes from the deterministic warm-up schedule `gamma`.
#
# beta is drawn once per *epoch*. Drawing it once per minibatch (as before)
# with n=50 would, assuming one increment of 1/n per `next()` call, saturate
# the KL weight after only 50 minibatches -- a small fraction of the first
# epoch at batch_size=64 -- which effectively disables warm-up. Per epoch, the
# weight instead ramps up over the first 50 of the `epochs` passes.
n_batches = len(train)  # loop-invariant; hoisted out of the epoch loop

for epoch in range(epochs):
    model.train()
    beta = next(gamma)  # KL weight for this epoch
    total_loss = 0.0

    for x, _ in train:
        # torch.autograd.Variable is deprecated; plain tensors track gradients.
        x = x.to(device)

        # Clear gradients before the backward pass so nothing stale from an
        # interrupted earlier run of this cell can leak into the first step.
        optimizer.zero_grad()

        reconstruction = model(x)

        # Negated reconstruction loss acts as the likelihood term of the ELBO.
        # `model.kld` is read after the forward pass (presumably set by it).
        likelihood = -bce_loss(reconstruction, x)
        elbo = likelihood - beta * model.kld

        loss = -torch.mean(elbo)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss every 10th epoch and always for the last one.
    if epoch % 10 == 0 or epoch == epochs - 1:
        print(f"Epoch: {epoch+1}\tL: {total_loss/n_batches:.2f}")

# Sampling from the generative model

In [None]:
# Draw samples from the generative model: standard-normal noise for the top
# latent layer is decoded with `model.sample`, and the returned `x_mu` values
# are shown as 28x28 grayscale images.
n_samples, n_rows = 16, 2
n_cols = n_samples // n_rows

model.eval()
with torch.no_grad():  # inference only -- don't build an autograd graph
    # The top latent layer is the last entry of the bottom-to-top `z_dim`;
    # derive its width instead of hard-coding it.
    z = torch.randn(n_samples, z_dim[-1], device=device)
    x_mu = model.sample(z)

samples = x_mu.reshape(-1, 28, 28).cpu().numpy()

fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6))
for ax, image in zip(axes.flat, samples):
    ax.imshow(image, cmap="gray")
    ax.axis("off")
plt.show()