## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import cv2

%matplotlib inline

In [None]:
def show(img):
    """Display a (C, H, W) image tensor with matplotlib.

    The tensor is detached and copied to the CPU first, so outputs of a model
    on the GPU (or tensors that require grad) can be passed directly. A
    single-channel image is expanded to 3 channels with OpenCV before
    plotting, exactly as before. An image that already has 3 channels is
    shown unchanged, because ``COLOR_GRAY2BGR`` only accepts 1-channel input.

    Args:
        img: image tensor laid out as (channels, height, width).
    """
    # Move channels last (H, W, C), the layout cv2 and plt.imshow expect.
    npimg = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    if npimg.shape[2] == 1:
        npimg = cv2.cvtColor(npimg, cv2.COLOR_GRAY2BGR)
    plt.imshow(npimg, interpolation='nearest')

## Model definition

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with two hidden layers per side.

    Args:
        x_dim: number of elements in one flattened input sample
            (784 for the 28x28 MNIST images used in this notebook).
        h_dim1: width of the outer hidden layer (fc1 in the encoder, fc5 in
            the decoder).
        h_dim2: width of the inner hidden layer (fc2 in the encoder, fc4 in
            the decoder).
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        # Plain int attribute used by forward() to flatten inputs. It is not a
        # parameter or buffer, so state_dict keys (and saved weights) are
        # unaffected.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map flattened inputs of shape (N, x_dim) to (mu, log_var), each (N, z_dim)."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw z = mu + exp(0.5 * log_var) * eps with eps ~ N(0, I).

        Sampling is written in this reparameterized form so that gradients
        flow through mu and log_var.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # return z sample

    def decoder(self, z):
        """Map latent codes (N, z_dim) to outputs (N, x_dim) squashed into [0, 1] by a sigmoid."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        ``x`` is flattened to (-1, x_dim). Any input whose per-sample elements
        total ``x_dim`` is therefore accepted, e.g. (N, 1, 28, 28) when
        x_dim == 784.

        Returns:
            Tuple (reconstruction of shape (N, x_dim), mu, log_var).
        """
        # Use the configured input size rather than a hard-coded 784 so
        # models built with a different x_dim work too.
        mu, log_var = self.encoder(x.view(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [None]:
def loss_function(recon_x, x, mu, log_var):
    """Negative ELBO for the VAE, summed over the whole batch.

    Args:
        recon_x: decoder output of shape (N, D) with values in [0, 1].
        x: target batch. It is reshaped to (N, D) using D = recon_x.size(-1),
            so it may arrive unflattened, e.g. (N, 1, 28, 28) when D == 784.
        mu: encoder means of q(z|x).
        log_var: encoder log-variances of q(z|x).

    Returns:
        Scalar tensor: the summed binary cross-entropy between ``recon_x`` and
        ``x``, plus the closed-form KL divergence KL(N(mu, exp(log_var)) || N(0, I)).
    """
    # Flatten to the reconstruction's feature width instead of a hard-coded
    # 784, so the loss works for any x_dim.
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, recon_x.size(-1)), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

## Setup model and load pretrained weights

In [None]:
# The three models share input and hidden-layer sizes; only the latent size differs.
_shared_dims = dict(x_dim=784, h_dim1=512, h_dim2=256)
vae_2 = VAE(z_dim=2, **_shared_dims)
vae_10 = VAE(z_dim=10, **_shared_dims)
vae_20 = VAE(z_dim=20, **_shared_dims)

In [None]:
# Load every checkpoint onto the CPU. Checkpoints saved from CUDA tensors
# then also load on machines without a GPU. The models are still on the CPU
# here; the following cell moves them to the GPU when one is available.
vae_2.load_state_dict(torch.load("../../weights/mnist_vae_zdim_2.torch", map_location="cpu"))
vae_10.load_state_dict(torch.load("../../weights/mnist_vae_zdim_10.torch", map_location="cpu"))
vae_20.load_state_dict(torch.load("../../weights/mnist_vae_zdim_20.torch", map_location="cpu"))

In [None]:
# Put each model on the GPU if there is one; otherwise leave them on the CPU.
if torch.cuda.is_available():
    for _model in (vae_2, vae_10, vae_20):
        _model.cuda()

## Create Reconstructions

In [None]:
# MNIST test split as tensors; download=False requires the files to already exist under root.
test_dataset = datasets.MNIST(
    root='~/tool-presence/mnist_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=False,
)
# shuffle=False keeps batch order fixed, so the first batch is the same on every run.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
)

In [None]:
def plot_image_rows(images_list, title_list):
    """Plot each row of images as its own titled, single-row figure.

    Args:
        images_list: sequence of rows. Each row is a sequence of images
            indexed as ``img[:, :, 0]``; this assumes channel-last arrays and
            only channel 0 is drawn (in greyscale).
        title_list: titles paired with the rows via ``zip``. Extra entries on
            either side are ignored.
    """
    if len(images_list) == 0:
        return  # nothing to plot
    # Use the widest row so every figure has the same width and no row can
    # overflow its subplot grid.
    cols = max(len(images) for images in images_list)

    def plot_image_row(images, title):
        """Draw one row of images in a new figure titled ``title``."""
        plt.figure(figsize=(cols, 3))
        plt.gcf().suptitle(title)
        for i, img in enumerate(images):
            # Each row gets its own figure, so its grid has exactly one row.
            plt.subplot(1, cols, i + 1)
            plt.imshow(img[:, :, 0], cmap='Greys_r')
            plt.axis('off')

    for images, title in zip(images_list, title_list):
        plot_image_row(images, title)

In [None]:
selected_idx = [5531, 2553, 1432, 4526, 9960, 6860, 6987, 3720, 5003, 9472]

# Notebook display: the image tensor (element 0 of each sample) for every hand-picked index.
[sample[0] for sample in map(test_dataset.__getitem__, selected_idx)]

In [None]:
models = [vae_2, vae_10, vae_20]
zdims = [2, 10, 20]

# Only the first test batch is visualised. Fetch it once instead of running
# every model over the whole test set and discarding the other batches.
data, _ = next(iter(test_loader))
# Send the batch to whatever device the models are on. They are on the GPU
# only if the CUDA cell above moved them; do not assume CUDA is present.
data = data.to(next(models[0].parameters()).device)
n = data.size(0)

with torch.no_grad():  # inference only, so no autograd graph is needed
    for model, zdim in zip(models, zdims):
        recon_batch, mu, logvar = model(data)
        # Originals in the first grid row, reconstructions in the second.
        # view_as(data) follows the actual batch size instead of a hard-coded 64.
        comparison = torch.cat([data, recon_batch.view_as(data)])
        save_image(comparison.cpu(),
                   '../../samples/reconstruction_zdim_{}.png'.format(zdim), nrow=n)

## View Images

In [None]:
# Large canvas so the wide comparison grid (originals above reconstructions) stays legible.
fig, ax = plt.subplots(figsize=(20, 10))
img = mpimg.imread('../../samples/reconstruction_zdim_2.png')
ax.imshow(img);  # trailing ';' hides the AxesImage repr in the notebook output

In [None]:
# Large canvas so the wide comparison grid (originals above reconstructions) stays legible.
fig, ax = plt.subplots(figsize=(20, 10))
img = mpimg.imread('../../samples/reconstruction_zdim_10.png')
ax.imshow(img);  # trailing ';' hides the AxesImage repr in the notebook output

In [None]:
# Large canvas so the wide comparison grid (originals above reconstructions) stays legible.
fig, ax = plt.subplots(figsize=(20, 10))
img = mpimg.imread('../../samples/reconstruction_zdim_20.png')
ax.imshow(img);  # trailing ';' hides the AxesImage repr in the notebook output

## Compute Stats