In [1]:
# standard imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
%matplotlib inline

In [2]:
# own imports/
import sys
sys.path.append("../")

In [3]:
# set device
if torch.cuda.is_available():
    device = "cuda:0"
    torch.cuda.empty_cache()
else: 
    device = "cpu"

print(f"Using device {device}")

Using device cuda:0


# Get Data

In [12]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from operator import __or__
from functools import reduce
from torch.utils.data.sampler import SubsetRandomSampler

In [13]:
# Define a sampler
def get_mnist_sampler(labels, n=None, n_labels=10):
    # Only choose classes in n_labels
    classes = np.arange(n_labels)
    (indices,) = np.where(reduce(__or__, [labels == i for i in classes]))

    # Ensure uniform distribution of labels
    np.random.shuffle(indices)
    indices = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[:n] for i in range(n_labels)])

    indices = torch.from_numpy(indices)
    sampler = SubsetRandomSampler(indices)
    return sampler

# Define flatten transform as a 'lambda' func
def tmp_lambda_func(x):
    return torch.flatten(x)

In [14]:
# define stuff
batch_size = 64

# Transform of data
flatten_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(tmp_lambda_func),
])

# Download train and test data
from torchvision.datasets import MNIST
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz
train_data = MNIST(root='./', train=True, download=True, transform=flatten_transform)
val_data = MNIST(root='./', train=False, download=True, transform=flatten_transform)

# loaders, which perform the actual work
train_loader = DataLoader(
    train_data, 
    batch_size=batch_size,
    num_workers=2,
    pin_memory=cuda,
    sampler=get_mnist_sampler(train_data.targets)
)
test_loader = DataLoader(
    val_data, 
    batch_size=batch_size, 
    num_workers=2,
    pin_memory=cuda,
    sampler=get_mnist_sampler(val_data.targets)
)

'wget' is not recognized as an internal or external command,
operable program or batch file.
tar: Error opening archive: Failed to open 'MNIST.tar.gz'

0it [00:00, ?it/s][ADownloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST\raw\train-images-idx3-ubyte.gz


HTTPError: HTTP Error 503: Service Unavailable

# Define Model

In [4]:
# bottom to top model dimensions
x_dim = 784
z_dim = [32, 16, 8]
h_dim = [256, 128, 64]

In [5]:
from models import LadderVAE
model = LadderVAE([x_dim, z_dim, h_dim])
model.to(device)
print(model)

LadderVAE(
  (encoder): ModuleList(
    (0): Encoder(
      (linear): Linear(in_features=784, out_features=256, bias=True)
      (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (sample): GaussianSample(
        (mu): Linear(in_features=256, out_features=32, bias=True)
        (log_var): Linear(in_features=256, out_features=32, bias=True)
      )
    )
    (1): Encoder(
      (linear): Linear(in_features=256, out_features=128, bias=True)
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (sample): GaussianSample(
        (mu): Linear(in_features=128, out_features=16, bias=True)
        (log_var): Linear(in_features=128, out_features=16, bias=True)
      )
    )
    (2): Encoder(
      (linear): Linear(in_features=128, out_features=64, bias=True)
      (batch_norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (sample): GaussianSample(
    

# Training the model

In [6]:
# Define hyper parameters
learning_rate = 3e-4
epochs = 10

In [11]:
# get data
# from data import get_mnist
# train, test = get_mnist(location="./", batch_size=64) 

In [None]:
# define optimizer and linear warm-up constant
from utils import DeterministicWarmup, bce_loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
gamma = DeterministicWarmup(n=50, t_max=1)  

In [None]:
# training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for (u, _) in train:
        u = Variable(u).to(device)

        reconstruction = model(u)
        
        likelihood = -bce_loss(reconstruction, u)
        elbo = likelihood - next(gamma) * model.kld
        
        L = -torch.mean(elbo)

        L.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += L.data.item()

    m = len(train)

    if epoch % 10 == 0:
        print(f"Epoch: {epoch+1}\tL: {total_loss/m:.2f}")

# Sampling from Generative model

In [None]:
# sample
model.eval()
x_mu = model.sample(Variable(torch.randn(16, 8)).to(device))

# Plot
f, axarr = plt.subplots(2, 8, figsize=(18, 6))
samples = x_mu.data.view(-1, 28, 28).cpu().numpy()
for i, ax in enumerate(axarr.flat):
    ax.imshow(samples[i])
    ax.axis("off")