In [1]:
import sys
import os
# Make the top-level package (two directories up from this notebook's working
# directory) importable. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
_package_root = os.path.abspath('../..')
if _package_root not in sys.path:
    sys.path.append(_package_root)

In [2]:
import torch
from comparison.examples.mnist.vae_mnist import VAE_MNIST
# from comparison.loss import ELBO, IWAE
from comparison.metric import IWAE_64, log_px
from comparison.metric import IWAE_metric, CIWAE_metric, PIWAE_metric
from tqdm.notebook import tqdm

In [3]:
# Device used for the model and for every batch moved with `.to(device)`.
# 'cpu' runs on any machine; switch to the 'cuda' line below on a machine
# with a supported GPU.

device = "cpu"
# device = "cuda"

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Shared settings for both splits, kept in one place so they cannot drift apart.
BATCH_SIZE = 128
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
])

# MNIST is downloaded into ./_mnist on first use.
train_dataset = datasets.MNIST('./_mnist', train=True, download=True,
                               transform=mnist_transform)

test_dataset = datasets.MNIST('./_mnist', train=False, download=True,
                              transform=mnist_transform)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps anything that
# iterates the test set (metrics, example plots) reproducible between runs.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
def train_epoch(model, optim, dataloader, loss_function, label):
    """Run one optimisation pass over ``dataloader``.

    Every batch is flattened to rows of 28*28 values and moved to the global
    ``device``. ``loss_function(model, batch)`` is treated as an objective to
    maximise: its negation is back-propagated and ``optim`` takes one step
    per batch.

    Args:
        model: module being trained; put into train mode first.
        optim: optimiser over the model's parameters.
        dataloader: iterable of ``(images, labels)`` pairs; labels are ignored.
        loss_function: callable ``(model, batch) -> objective``.
        label: value shown in the progress-bar description (the epoch number).

    Returns:
        None.
    """
    model.train()

    progress = tqdm(dataloader, desc="Training Epoch №%s" % label, leave=False)
    for batch, _labels in progress:
        flat_batch = batch.view(-1, 28*28).to(device)
        objective = loss_function(model, flat_batch)

        optim.zero_grad()
        # Minimising the negated objective maximises the objective itself.
        (-objective).backward()
        optim.step()


In [11]:
# this is kinda slow and if someone who's better at pytorch
# than me fancies speeding it up would be very helpful

def test(model, test_loader, label):
    """Accumulate evaluation metrics over every batch of ``test_loader``.

    The model is switched to evaluation mode for the duration of the pass
    (so mode-dependent layers such as dropout or batch norm behave as they
    should at test time) and its previous train/eval mode is restored
    afterwards, even if an error occurs. No gradients are tracked.

    Args:
        model: module to evaluate.
        test_loader: iterable of ``(images, labels)`` pairs; labels are ignored.
        label: value shown in the progress-bar description (the epoch number).

    Returns:
        dict with keys ``"IWAE-64"``, ``"log(px)"`` and ``"-KL"``. Each value
        is a Python float: the running total of ``.sum()`` of that metric
        over all batches (a total over the loader, not a per-sample mean).
    """
    test_scores = {
        "IWAE-64": 0.0,
        "log(px)": 0.0,
        "-KL"    : 0.0
    }

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, _ in tqdm(test_loader, desc="Test Epoch №%s" % label, leave=False):
                imgs = imgs.view(-1, 28*28).to(device)
                IWAE_64_loss = IWAE_64(model, imgs)
                logpx_loss = log_px(model, imgs)
                # NOTE(review): this is log_px minus IWAE_64 but is stored
                # under "-KL" -- confirm the intended sign against the
                # definitions in comparison.metric.
                negKL_loss = logpx_loss - IWAE_64_loss

                # .item() turns each per-batch total into a plain Python
                # float, so the accumulators never hold on to tensors.
                test_scores["IWAE-64"] += IWAE_64_loss.sum().item()
                test_scores["log(px)"] += logpx_loss.sum().item()
                test_scores["-KL"] += negKL_loss.sum().item()
    finally:
        # Leave the model in whichever mode the caller had it in.
        model.train(was_training)

    return test_scores
        

In [7]:
def train_and_evaluate(model, optim, train_loader, test_loader, loss_function, no_epochs=3):
    """Alternate training and evaluation for ``no_epochs`` epochs.

    For epochs numbered from 1, ``train_epoch`` is run on ``train_loader``
    and then ``test`` is run on ``test_loader``; the three scores returned by
    ``test`` are appended to per-metric histories.

    Args:
        model: module to train and evaluate.
        optim: optimiser over the model's parameters.
        train_loader: data passed to ``train_epoch``.
        test_loader: data passed to ``test``.
        loss_function: objective passed to ``train_epoch``.
        no_epochs: number of train/evaluate rounds (default 3).

    Returns:
        dict mapping ``"IWAE-64"``, ``"log(px)"`` and ``"-KL"`` to lists with
        one entry per completed epoch.
    """
    model.train()

    metric_names = ("IWAE-64", "log(px)", "-KL")
    history = {name: [] for name in metric_names}

    for epoch_index in range(no_epochs):
        epoch_number = epoch_index + 1  # epochs are labelled starting at 1
        train_epoch(model, optim, train_loader, loss_function, epoch_number)
        epoch_scores = test(model, test_loader, epoch_number)

        for name in metric_names:
            history[name].append(epoch_scores[name])

    return history

In [8]:
# Build the VAE, move it to the selected device, then create an Adam
# optimiser over its parameters with learning rate 1e-3.
model = VAE_MNIST()
model = model.to(device)
optim = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [12]:
# Objective wrappers with the (model, xs) signature that train_and_evaluate
# passes through to train_epoch.
def iwae64(model, xs):
    """IWAE_metric with M=1 and K=64."""
    return IWAE_metric(model, xs, M=1, K=64)


def ciwae05(model, xs):
    """CIWAE_metric with beta=0.5."""
    return CIWAE_metric(model, xs, beta=0.5)
#etc etc

# Train for a single epoch on the IWAE-64 objective and collect test metrics.
iwae64_results = train_and_evaluate(
    model, optim, train_loader, test_loader, iwae64, no_epochs=1
)
#ciwae05_results = train_and_evaluate(model, optim, train_loader, test_loader, ciwae05, no_epochs=1)

print(iwae64_results)

Training Epoch №1:   0%|          | 0/469 [00:00<?, ?it/s]

Test Epoch №1:   0%|          | 0/79 [00:00<?, ?it/s]

tensor(38652.7188)
tensor(36829.7617)
tensor(34817.4219)
tensor(38157.1602)
tensor(34745.5430)
tensor(33560.7773)


In [None]:
from comparison.plot import plot_smoothed

In [None]:
# NOTE(review): `losses` is not defined in any cell visible in this notebook
# (train_epoch does not record or return its per-batch losses), so unless it
# is defined elsewhere this line raises NameError on a fresh kernel.
# TODO: define `losses` before this cell or remove the cell.
plot_smoothed(-losses, sigma=10, fit_sigma=True)

In [None]:
import matplotlib.pyplot as plt

# Show original / reconstructed pairs for the first few images of ONE test
# batch. Iterating over every batch would redraw the same 10x2 grid once per
# batch, repeating the reconstruction work and stacking images on the same axes.
N_EXAMPLES = 10

imgs, _ = next(iter(test_loader))
n_show = min(N_EXAMPLES, imgs.shape[0])  # a batch may hold fewer images

# squeeze=False keeps `axes` two-dimensional even when n_show == 1.
fig, axes = plt.subplots(n_show, 2, figsize=(4, 26), squeeze=False)

with torch.no_grad():  # plotting only; no gradients are needed
    for i in range(n_show):
        img = imgs[i].view(28 * 28).to(device)
        recon = model.reconstruct(img)

        axes[i, 0].imshow(img.view(28, 28).detach().cpu())
        axes[i, 0].set_title("Original")
        axes[i, 1].imshow(recon.view(28, 28).detach().cpu())
        axes[i, 1].set_title("Reconstructed")

plt.show()