# Setup Notebook

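Install dependencies and import the standard libraries used throughout this notebook, which trains a variational autoencoder (VAE) on binarized MNIST.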

In [None]:
# Install weights and biases
!pip install wandb > /dev/null

In [None]:
import sys, os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs so that Restart & Run All reproduces the same results. torch drives
# the Bernoulli binarization and the train/val split (and presumably the model's
# weight initialisation). numpy is seeded in case the imported helpers draw from it.
# Note: this does not force deterministic cuDNN kernels.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

Set up the `ml-library` repository from GitHub, which provides the VAE model and its training loop.

In [None]:
# setup repo from GitHub
if os.path.exists('./ml-library'):
    !rm -r ml-library > /dev/null
!git clone https://github.com/simonamtoft/ml-library > /dev/null
sys.path.append('ml-library')

from models import VariationalAutoencoder
from training import train_vae

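Download MNIST and binarize it on the fly. Each pixel intensity in $[0, 1]$ is used as the probability of a Bernoulli sample, so a new binary image is drawn every time an example is loaded.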

In [None]:
def tmp_lambda(x):
    """Binarize an image tensor by sampling each pixel from Bernoulli(intensity).

    `transforms.ToTensor` scales pixel intensities to [0, 1], so each value is used
    directly as a success probability. The transform runs every time an example is
    loaded, so a fresh binary image is drawn on each access.
    """
    # A named function rather than an inline lambda — presumably so the transform
    # stays picklable for DataLoader worker processes; TODO confirm.
    return torch.bernoulli(x)


transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(tmp_lambda)])

# Build the training split first, then the test split. Each is downloaded into the
# current directory if it is not already present.
train_data, test_data = (
    MNIST('./', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

# Variational Autoencoder

In [None]:
# All tunable hyperparameters in one place; the whole dict is also passed to
# `train_vae` further down.
config = dict(
    batch_size=64,               # mini-batch size used by both data loaders
    epochs=250,                  # presumably the number of training epochs
    lr=3e-4,                     # presumably the optimiser learning rate
    h_dim=[512, 256, 256, 256],  # hidden-layer sizes handed to VariationalAutoencoder
    z_dim=128,                   # latent size handed to VariationalAutoencoder
    as_beta=False,               # NOTE(review): flag read by train_vae; meaning not visible here — confirm
)

## Split the Data and Create Data Loaders

In [None]:
# Hold out part of the official training set for validation. A dedicated, fixed-seed
# generator makes the split identical on every run, regardless of how much global
# RNG state earlier cells have consumed. Sizes are derived from the dataset length
# rather than hard-coded; for MNIST (60,000 training images) this gives 50,000 / 10,000.
N_VAL = 10_000
SPLIT_SEED = 42
train_set, val_set = torch.utils.data.random_split(
    train_data,
    [len(train_data) - N_VAL, N_VAL],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Extra loader options are only used when CUDA is available: pinned memory speeds up
# host-to-GPU copies, and 4 worker processes load batches in parallel.
kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}


def make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader that uses the configured batch size and `kwargs`."""
    return DataLoader(dataset, batch_size=config['batch_size'], shuffle=shuffle, **kwargs)


# Reshuffle the training data every epoch. Keep the validation order fixed: the
# average validation loss does not depend on batch order, and a fixed order makes any
# per-batch inspection comparable across epochs.
train_loader = make_loader(train_set, shuffle=True)
val_loader = make_loader(val_set, shuffle=False)

In [None]:
# Take one training batch, both to visualise it and to infer the input size below.
images, labels = next(iter(train_loader))

# Show up to 64 binarized digits in a 4 x 16 grid. Hide the axes on every panel,
# then zip the panels with the batch so a batch smaller than 64 leaves panels blank
# instead of raising an IndexError.
fig, axarr = plt.subplots(4, 16, figsize=(16, 4))
for ax in axarr.flat:
    ax.axis('off')
for ax, image in zip(axarr.flat, images):
    ax.imshow(image.view(28, 28), cmap="binary_r")

fig.suptitle('MNIST handwritten digits')
plt.show()

In [None]:
# Number of input features per example: the product of all non-batch dimensions
# (channels x height x width), i.e. 1 x 28 x 28 = 784 for MNIST. Counting channels
# keeps this correct for multi-channel images, assuming the model consumes
# flattened inputs as its first layer size suggests — TODO confirm in ml-library.
x_dim = images[0].numel()
x_dim

## Train and Inspect Results

In [None]:
# Authenticate with Weights & Biases. Uncomment and run once per environment if
# wandb is not already logged in. wandb is installed at the top of the notebook but
# not imported here, so presumably the cloned training code is what needs it — TODO confirm.
# !wandb login

In [None]:
# Architecture hyperparameters pulled out of the shared config: the hidden-layer
# sizes and the latent dimensionality passed to the model constructor below.
h_dims, z_dim = config['h_dim'], config['z_dim']

In [None]:
# Build the VAE from its layer specification [input size, hidden sizes, latent size]
# and move it to the selected device.
layer_spec = [x_dim, h_dims, z_dim]
model = VariationalAutoencoder(layer_spec)
model = model.to(device)

# Train with the loaders and hyperparameters defined above. The last argument is
# presumably a run/model name used for logging — TODO confirm in ml-library.
train_vae(model, config, train_loader, val_loader, 'vae')