In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from tqdm.auto import tqdm

# own imports
sys.path.append("../../ml-library/")

from models import DRAW
from layers import BaseAttention  # compute_filterbank

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
if cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Load Binarized MNIST Data

In [None]:
# Number of images per mini-batch; passed to the DataLoader as `batch_size`
BATCH_SIZE = 64

In [None]:
# Load the data
# Per-image preprocessing: convert to a tensor, then replace every entry by a
# Bernoulli draw that uses the entry itself as the success probability.
binarize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda pixels: torch.bernoulli(pixels)),
])

# download=False: nothing is fetched here, the dataset is read from './MNIST'.
mnist_data = MNIST('./MNIST', download=False, transform=binarize)

# Worker processes and pinned memory are only requested when CUDA is available.
if cuda:
    kwargs = {'num_workers': 2, 'pin_memory': True}
else:
    kwargs = {}

data_loader = DataLoader(mnist_data, batch_size=BATCH_SIZE, shuffle=True, **kwargs)

In [None]:
# plot some examples
# Number of digits to display. `N` was not defined anywhere in the notebook,
# so it is set here; adjust freely.
N = 8

# Pull one batch with the built-in next(); the iterator object has no
# Python-2 style `.next()` method on current PyTorch releases.
data_iter = iter(data_loader)
images, labels = next(data_iter)

# Never ask for more panels than the batch actually contains.
n_shown = min(N, images.size(0))

# squeeze=False always returns a 2-D array of axes, so a single panel needs
# no special-casing.
f, ax = plt.subplots(1, n_shown, figsize=(3*n_shown, 6), squeeze=False)
for i, axis in enumerate(ax[0]):
    # images[i, 0] selects index 0 along the second dimension of sample i
    axis.imshow(images[i, 0], cmap='gray')
    axis.set_title(f'Label: {labels[i]}')
    axis.axis('off')
plt.show()

In [None]:
# 

In [None]:
def crop_img(img, xr, yr):
    """Return the rectangular window of ``img`` selected by two index ranges.

    Args:
        img: Array supporting two-axis slicing (first axis indexed by ``yr``,
            second axis by ``xr``).
        xr: Pair ``(start, stop)`` applied to the second axis.
        yr: Pair ``(start, stop)`` applied to the first axis.

    Returns:
        ``img[yr[0]:yr[1], xr[0]:xr[1]]`` (stop indices are exclusive).
    """
    rows = slice(yr[0], yr[1])
    cols = slice(xr[0], xr[1])
    return img[rows, cols]

def filter_img(img, F_X, F_Y, log_gamma):
    """Apply a pair of filter matrices to an image and scale the result.

    Computes ``exp(log_gamma) * F_Y @ img @ F_X.T``.

    Args:
        img: 2-D array to be filtered.
        F_X: Matrix applied (transposed) on the right of ``img``.
            Presumably shaped (n_filters, width) -- TODO confirm with caller.
        F_Y: Matrix applied on the left of ``img``.
            Presumably shaped (n_filters, height) -- TODO confirm with caller.
        log_gamma: Logarithm of the scalar intensity applied to the result.

    Returns:
        The filtered, scaled array.
    """
    # The original body referenced an undefined name `log_g`, which raised a
    # NameError on every call; the parameter is `log_gamma`.
    return np.exp(log_gamma) * F_Y @ img @ F_X.T

# Train DRAW Model

In [None]:
# create model and optimizer
# x_dim=784 matches the flattened images fed to the model in the training loop
# (presumably 28*28 MNIST digits -- the plotting cells reshape rows to (28, 28)).
model = DRAW(x_dim=784, h_dim=256, z_dim=16, T=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.999))
# Multiply the learning rate by 0.5 after every 10 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# loss
# reduction='none' keeps one value per element so the caller can choose how
# to sum/average. `nn` is torch.nn and must be imported in the import cell.
loss = nn.BCELoss(reduction='none').to(device)

In [None]:
# Number of passes over `data_loader` performed by the training loop
EPOCHS = 100

In [None]:
# training loop
for epoch in range(EPOCHS):
    model.train()
    for x, _ in tqdm(data_loader):
        batch_size = x.size(0)

        # Flatten every image to a vector and move the batch to the device.
        x = x.view(batch_size, -1).to(device)

        # The model returns raw outputs plus a KL term; the sigmoid maps the
        # outputs into (0, 1) as required by BCELoss.
        x_hat, kld = model(x)
        x_hat = torch.sigmoid(x_hat)

        # Sum both terms over dim 1 to get one value per sample, then average
        # over the batch. This sum is the quantity being minimised.
        reconstruction = loss(x_hat, x).sum(1)
        kl = kld.sum(1)
        elbo = torch.mean(reconstruction + kl)

        # Clear stale gradients before accumulating the new ones.
        optimizer.zero_grad()
        elbo.backward()
        optimizer.step()

    scheduler.step()

    # TODO: evaluate on a held-out test set. The original cell had an empty
    # `if epoch % 1 == 0:` stub here, which made the whole cell a syntax
    # error; no test loader exists in this notebook yet.
    model.eval()
    with torch.no_grad():
        # print train loss and save sample + recon
        # Also save after the final epoch: the plotting cells read the
        # "-99" files, which `epoch % 10 == 0` alone never produces.
        is_last_epoch = epoch == EPOCHS - 1
        if epoch % 10 == 0 or is_last_epoch:
            print("\nLoss at step {}: {}".format(epoch, elbo.item()))
            x_sample = model.sample()
            # x_hat is the reconstruction of the last training batch.
            save_image(x_hat, "reconstruction-{}.jpg".format(epoch))
            save_image(x_sample, "sample-{}.jpg".format(epoch))

# Plot Samples

In [None]:
from matplotlib import image

In [None]:
# Read the saved reconstruction image and keep a single colour channel.
# Each row of the image is treated as one flattened 28x28 digit.
img = image.imread('reconstruction-99.jpg')[:, :, 0]

f, ax = plt.subplots(4, 8, figsize=(16,8))
# Walking the transposed axes grid visits panels column by column, so panel
# n in that order (column * 4 + row) displays image row n.
for n, panel in enumerate(ax.T.flat):
    panel.imshow(img[n, :].reshape(28, 28), cmap='gray')
    panel.axis('off')
plt.show()

In [None]:
# Read the saved sample image and keep a single colour channel.
# Each row of the image is treated as one flattened 28x28 digit.
img = image.imread('sample-99.jpg')[:, :, 0]

f, ax = plt.subplots(4, 8, figsize=(16,8))
# Fill the 4x8 grid column by column: column c, row r shows image row c*4 + r.
for c, column_axes in enumerate(ax.T):
    for r, panel in enumerate(column_axes):
        digit = img[c*4 + r, :].reshape(28, 28)
        panel.imshow(digit, cmap='gray')
        panel.axis('off')
plt.show()