In [None]:
import sys
import os
sys.path.append(os.path.abspath('../..')) # include top level package in python path

In [None]:
import torch
from comparison.examples.vae_mnist import VAE_MNIST
from comparison.loss import ELBO, IWAE_loss, CIWAE_loss, PIWAE_loss, DREG_loss
from comparison.metric import IWAE_64, log_px
from comparison.metric import IWAE_metric, CIWAE_metric, PIWAE_metric, sample_ess
from tqdm.notebook import tqdm
import csv




In [None]:
# Select the compute device. CUDA is used when torch reports a usable GPU,
# otherwise we fall back to the CPU so the notebook runs on any machine.
# Override manually (e.g. device = 'cpu') if needed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import time
# Directory holding checkpoints and result tables. File names are appended to
# this string directly, so it has to end with a path separator.
model_dir = './_modeldata/'
# Identifier that prefixes every file belonging to this experiment.
idstr = "bernoulli50"
print(idstr)

In [None]:
import re
import pandas as pd

class MNISTTrainData():
    """File naming scheme and persistence helpers for checkpoints and result tables.

    All files live in ``model_dir`` (which is used as a plain string prefix and
    therefore has to end with a path separator) and are named

    * ``<idstr>-start-model.pt`` for the shared initial weights,
    * ``<idstr>-<loss>-r<run_no>-<epoch>-model.pt`` for checkpoints, where
      ``<epoch>`` is ``e<N>`` or ``final``,
    * ``<idstr>-<loss>-r<run_no>-results.csv`` for metric tables.
    """

    def __init__(self, model_dir, idstr):
        self.model_dir = model_dir
        self.idstr = idstr

    def model_from_filename(self, filename):
        """Parse a checkpoint file name.

        Returns ``(loss, run_no, epoch)`` (``epoch`` is ``None`` for a ``final``
        checkpoint) or ``None`` when ``filename`` is not a checkpoint of this
        experiment.
        """
        # idstr is escaped so that regex metacharacters in it are matched literally;
        # the epoch token is restricted to the two forms epoch_str() can produce so
        # that unrelated files are ignored instead of raising in epoch_from_str().
        model_str = "^%s-(.+)-r([0-9]+)-(final|e[0-9]+)-model\\.pt$" % re.escape(self.idstr)
        match = re.search(model_str, filename)
        if match is None:
            return None
        loss = match.group(1)
        run_no = int(match.group(2))
        epoch = self.epoch_from_str(match.group(3))

        return (loss, run_no, epoch)

    def results_from_filename(self, filename):
        """Parse a results file name into ``(loss, run_no)``; ``None`` if it does not match."""
        results_str = "^%s-(.+)-r([0-9]+)-results\\.csv$" % re.escape(self.idstr)
        match = re.search(results_str, filename)
        if match is None:
            return None
        loss = match.group(1)
        run_no = int(match.group(2))

        return (loss, run_no)

    def get_files(self):
        """List the parsed checkpoints and result tables found in ``model_dir``.

        Returns a pair ``(models, results)`` of lists holding the tuples produced
        by ``model_from_filename`` and ``results_from_filename``. A missing
        directory is treated as empty.
        """
        if not os.path.isdir(self.model_dir):
            return [], []
        filenames = os.listdir(self.model_dir)
        models = [self.model_from_filename(fl) for fl in filenames]
        results = [self.results_from_filename(fl) for fl in filenames]
        return (
            [m for m in models if m is not None],
            [r for r in results if r is not None],
        )

    def load_all_results(self):
        """Load every stored result table.

        Returns a dict mapping loss name to a list of ``(run_no, DataFrame)``
        pairs ordered by run number. The six known losses are always present as
        keys, possibly with an empty list.
        """
        results = {
            'vae': [],
            'iwae': [],
            'ciwae-05': [],
            'miwae-8-8': [],
            'piwae-8-8': [],
            'dreg-1-64': [],
        }

        _, result_files = self.get_files()
        # sorted: os.listdir order is arbitrary, so make the run order deterministic
        for loss, run_no in sorted(result_files):
            data = self.load_results(loss, run_no)
            # setdefault: result files of losses not listed above must not raise KeyError
            results.setdefault(loss, []).append((run_no, data))

        return results

    #Loads the latest models into the test objects
    def load_latest_models(self, tests):
        """Load the most advanced checkpoint of each test into ``t.model``.

        For every test, the checkpoint with the highest epoch whose loss name
        equals ``t.name`` is loaded and ``t.finished_epochs`` is updated. Tests
        without any checkpoint are left untouched.
        """
        model_filenames, _ = self.get_files()

        for t in tests:
            loss = t.name
            candidates = [m for m in model_filenames if m[0] == loss]
            if not candidates:
                # nothing saved yet for this loss: keep the fresh model
                continue
            # A "final" checkpoint is parsed as epoch None; rank it above any
            # numbered epoch instead of comparing None with an int.
            _, run_no, epoch = max(
                candidates,
                key=lambda m: float("inf") if m[2] is None else m[2],
            )
            self.load_model(t.model, loss, run_no, epoch)
            if epoch is None:
                # A final checkpoint means the run is complete; fall back to the
                # test's own epoch budget when it declares one.
                t.finished_epochs = getattr(t, "no_epochs", None)
            else:
                t.finished_epochs = epoch

    def epoch_str(self, epoch = None):
        """Return the file-name token for ``epoch``: ``"final"`` for ``None``, else ``"e<epoch>"``."""
        return "final" if epoch is None else "e" + str(epoch)

    def epoch_from_str(self, s):
        """Inverse of ``epoch_str``; raises ``ValueError`` if ``s`` contains no ``e<N>`` token."""
        if s == "final":
            return None
        else:
            match = re.search("e([0-9]+)", s)
            if not match:
                raise ValueError
            return int(match.group(1))

    def start_prefix(self):
        """Path prefix of the files holding the shared starting state."""
        return self.model_dir + self.idstr + "-start"

    def run_prefix(self, loss, run_no):
        """Path prefix shared by all files of one ``(loss, run_no)`` run."""
        return (
            self.model_dir
            + self.idstr
            + "-"
            + loss
            + "-r"
            + str(run_no)
        )

    def start_model_str(self):
        """Path of the starting-weights checkpoint."""
        return self.start_prefix() + "-model.pt"

    def model_str(self, epoch, *args, **kargs):
        """Path of the checkpoint for ``epoch``; remaining arguments go to ``run_prefix``."""
        return self.run_prefix(*args, **kargs) + "-" + self.epoch_str(epoch)  + "-model.pt"

    def results_str(self, *args, **kargs):
        """Path of the results table; arguments go to ``run_prefix``."""
        return self.run_prefix(*args, **kargs) + "-results.csv"

    def _ensure_dir(self, filename):
        """Create the directory part of ``filename`` if it does not exist yet."""
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def load_start(self, model):
        """Load the starting weights into ``model``."""
        # map_location='cpu': checkpoints written from a GPU model stay loadable on
        # CPU-only machines; load_state_dict copies into the model's own device.
        return model.load_state_dict(torch.load(self.start_model_str(), map_location='cpu'))

    def save_start(self, model):
        """Save ``model``'s weights as the starting state."""
        filename = self.start_model_str()
        self._ensure_dir(filename)
        torch.save(model.state_dict(), filename)

    def load_model(self, model, loss, run_no, epoch):
        """Load the checkpoint of ``(loss, run_no, epoch)`` into ``model``."""
        filename = self.model_str(epoch, loss, run_no)
        return model.load_state_dict(torch.load(filename, map_location='cpu'))

    def save_model(self, model, loss, run_no, epoch):
        """Save ``model``'s weights as the checkpoint of ``(loss, run_no, epoch)``."""
        filename = self.model_str(epoch, loss, run_no)
        self._ensure_dir(filename)
        torch.save(model.state_dict(), filename)

    def load_results(self, loss, run_no):
        """Read the results table of ``(loss, run_no)`` as a DataFrame."""
        filename = self.results_str(loss, run_no)
        return pd.read_csv(filename)

    def load_results_dict(self, loss, run_no):
        """Read the results table as ``{column: list_of_values}``.

        Index columns that pandas wrote on earlier saves (``Unnamed: ...``) are
        dropped. Returns ``None`` when no results file exists for the run.
        """
        if not os.path.exists(self.results_str(loss, run_no)):
            return None
        table = self.load_results(loss, run_no)
        return {
            col: table[col].tolist()
            for col in table.columns
            if not str(col).startswith('Unnamed:')
        }

    def save_results(self, test_metrics, loss, run_no):
        """Write ``test_metrics`` (a DataFrame or a dict of equal-length columns) to CSV."""
        filename = self.results_str(loss, run_no)
        self._ensure_dir(filename)
        pd.DataFrame(test_metrics).to_csv(filename)
            
# Persistence helper bound to this experiment's directory and file-name prefix.
mnist_train_data = MNISTTrainData(model_dir, idstr)

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset

def _binarized_mnist(train):
    """Return one MNIST split (downloaded to ``./_mnist`` if needed).

    Each image goes through ``ToTensor`` followed by ``torch.bernoulli``, the
    same two-step transform for both splits.
    """
    binarize = transforms.Compose([transforms.ToTensor(), torch.bernoulli])
    return datasets.MNIST('./_mnist', train=train, download=True, transform=binarize)


train_dataset = _binarized_mnist(train=True)
test_dataset = _binarized_mnist(train=False)


def static_copy(dataset):
    """Materialize ``dataset`` into an in-memory ``TensorDataset``.

    ``dataset`` must yield ``(x, y)`` items where all ``x`` are tensors of the
    same shape and ``y`` can be packed by ``torch.tensor``. The dataset is
    traversed exactly once, so any per-access transform is applied once per
    item and the stored inputs and labels come from the same traversal.

    Returns a ``TensorDataset(xs, ys)`` with ``xs`` stacked along a new leading
    dimension.
    """
    samples = list(dataset)  # single pass over the (possibly expensive) dataset
    xs = torch.stack([sample[0] for sample in samples], dim=0)
    ys = torch.tensor([sample[1] for sample in samples])
    return TensorDataset(xs, ys)

#train_dataset = static_copy(train_dataset)
#test_dataset = static_copy(test_dataset)


# Training data is served in shuffled mini-batches of 128 images.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# We set a low batch_size for sampling IWAE with K=5000
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)


In [None]:
def train_epoch(model, optims, dataloader, loss_function, label):
    """Train ``model`` for one pass over ``dataloader``.

    Parameters:
        model: module to train; put into training mode.
        optims: list of optimizers. With a single objective only ``optims[0]``
            is used; with a tuple of objectives the i-th optimizer is driven by
            the i-th objective.
        dataloader: yields ``(imgs, labels)`` batches; labels are ignored and
            images are flattened to ``(-1, 28*28)``.
        loss_function: ``loss_function(model, imgs)`` returns either one scalar
            objective or a tuple of scalar objectives. Objectives are
            maximized, i.e. their negation is minimized.
        label: epoch label shown in the progress bar.

    Prints and returns the mean negated objective over the epoch: a scalar
    tensor for a single objective, one entry per objective for a tuple.
    """
    model.train()
    lss = []
    for imgs, _ in tqdm(dataloader, desc="Training Epoch №%s" % label, leave=False):
        imgs = imgs.view(-1, 28*28).to(device)
        losses = loss_function(model, imgs)

        if not isinstance(losses, tuple):
            optim = optims[0]
            loss = -losses
            lss.append(loss.item())
            optim.zero_grad()
            loss.backward()
            optim.step()
        else: #PIWAE
            # record every objective so the epoch summary is not empty (nan)
            lss.append([(-objective).item() for objective in losses])

            # Backpropagate each objective separately and remember only the
            # gradients of the parameters owned by the matching optimizer.
            grads = []
            for idx, (optim, loss) in enumerate(zip(optims, losses)):
                loss = -loss
                for p in model.parameters():
                    p.grad = None
                # the graph is shared between objectives: keep it until the last one
                loss.backward(retain_graph=idx < len(optims) - 1)
                grads.append([p.grad for p in optim.param_groups[0]["params"]])

            # Restore the per-objective gradients before stepping, since each
            # backward pass above cleared the gradients of the previous one.
            for idx, optim in enumerate(optims):
                for g, p in zip(grads[idx], optim.param_groups[0]["params"]):
                    p.grad = g

            for optim in optims:
                optim.step()

    history = torch.Tensor(lss)
    mean_loss = history.mean(dim=0) if history.dim() > 1 else history.mean()
    print(mean_loss)
    return mean_loss



In [None]:
def test_epoch(model, test_loader, label):
    """Evaluate ``model`` on ``test_loader`` and return the mean test metrics.

    Parameters:
        model: module to evaluate. It is switched to evaluation mode for the
            duration of the call and its previous mode is restored afterwards.
        test_loader: yields ``(imgs, labels)`` batches; labels are ignored and
            images are flattened to ``(-1, 28*28)``.
        label: epoch label shown in the progress bar.

    Returns a dict with the dataset means of
        "iwae-64": the values returned by ``IWAE_64``,
        "logpx":   the values returned by ``log_px``,
        "-kl":     ``-(logpx - iwae-64)`` per item.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            #evaluate metrics
            iwae_losses  = []
            logpx_losses = []
            kl_losses    = []

            for imgs, _ in tqdm(test_loader, desc="Test Epoch №%s" % label, leave=False):
                imgs = imgs.view(-1, 28*28).to(device)
                IWAE_64_loss = IWAE_64(model, imgs)
                logpx_loss = log_px(model, imgs)
                negKL_loss = -(logpx_loss - IWAE_64_loss)

                # per-item values are collected so the final mean weights every
                # item equally even if the last batch is smaller
                iwae_losses  += IWAE_64_loss.tolist()
                logpx_losses += logpx_loss.tolist()
                kl_losses    += negKL_loss.tolist()

            test_scores = {
                "iwae-64": torch.tensor(iwae_losses).mean().item(),
                "logpx": torch.tensor(logpx_losses).mean().item(),
                "-kl"    : torch.tensor(kl_losses).mean().item()
            }

            return test_scores
    finally:
        # leave the model in the mode the caller had it in
        model.train(was_training)
        

In [None]:
def train_and_evaluate(
    model, 
    train_loader, 
    test_loader, 
    test,
    run_no,
    restart = False
):
    """Train ``model`` according to ``test`` and record test metrics.

    Training covers epochs ``test.finished_epochs + 1`` to ``test.no_epochs``.
    Metrics are evaluated every ``test.epochs_per_sample`` epochs and after the
    last epoch; model and metrics are snapshotted every second epoch and once
    more at the end.

    Parameters:
        model: module to train (already on the target device).
        train_loader / test_loader: data loaders for training and evaluation.
        test: configuration object providing ``name``, ``optims``, ``loss_fn``,
            ``finished_epochs``, ``no_epochs`` and ``epochs_per_sample``.
        run_no: run number used in the file names.
        restart: when true, previously stored metrics are ignored.

    Returns the collected metrics as a DataFrame with the columns
    ``epoch``, ``iwae-64``, ``logpx`` and ``-kl``.
    """
    test_metrics = None
    if not restart:
        try:
            test_metrics = mnist_train_data.load_results_dict(test.name, run_no)
        except FileNotFoundError:
            # no metrics stored yet for this run: start a fresh table
            test_metrics = None
    if test_metrics is None:
        test_metrics = {
            "epoch"  : [],
            "iwae-64": [],
            "logpx": [],
            "-kl"    : []
        }

    for epoch in range(test.finished_epochs+1, test.no_epochs + 1):
        train_epoch(model, test.optims, train_loader, test.loss_fn, epoch)
        
        # save metrics
        if epoch % test.epochs_per_sample == 0 or epoch == test.no_epochs:
            test_scores = test_epoch(model, test_loader, epoch)
            test_metrics["epoch"].append(epoch)
            test_metrics["iwae-64"].append(test_scores["iwae-64"])
            test_metrics["logpx"].append(test_scores["logpx"])
            test_metrics["-kl"].append(test_scores["-kl"])
            
        # snapshot model and metrics
        if epoch % 2 == 0:
            mnist_train_data.save_model(model, test.name, run_no, epoch)
            mnist_train_data.save_results(pd.DataFrame(test_metrics), test.name, run_no)
    
    mnist_train_data.save_model(model, test.name, run_no, epoch = test.no_epochs)
    # save_results writes via DataFrame.to_csv, so hand it a DataFrame, not the dict
    results = pd.DataFrame(test_metrics)
    mnist_train_data.save_results(results, test.name, run_no)
    return results

In [None]:
def eval_discard(loss_fn, M=1, K=1):
    """Build a two-argument objective from ``loss_fn``.

    The returned callable takes ``(model, xs)``, runs ``model(xs, M, K)`` and
    passes that single result to ``loss_fn``.
    """
    def evaluate(model, xs):
        model_output = model(xs, M, K)
        return loss_fn(model_output)

    return evaluate

# Learning rate passed to every Adam optimizer created by LossTest.
lr=3e-4

class LossTest:
    """One training configuration: an objective, its schedule, and a model with optimizers.

    Attributes set by the constructor: ``name``, ``loss_fn``, ``no_epochs``,
    ``epochs_per_sample``, ``no_runs``, ``piwae``, ``finished_epochs``, plus
    ``model`` and ``optims`` created by ``initialize_model``.
    """

    def __init__(self, name, loss_fn, no_epochs, epochs_per_sample, no_runs = 1, piwae=False, finished_epochs = 0):
        self.name, self.loss_fn = name, loss_fn
        self.no_epochs, self.epochs_per_sample = no_epochs, epochs_per_sample
        self.no_runs = no_runs
        self.piwae = piwae
        self.finished_epochs = finished_epochs
        self.initialize_model()

    def initialize_model(self, model=None):
        """Install ``model`` (a new ``VAE_MNIST`` when omitted) and build fresh Adam optimizers.

        With ``piwae`` set, two optimizers are created, the first over
        ``encode_params()`` and the second over ``decode_params()``; otherwise
        a single optimizer covers all parameters.
        """
        self.model = VAE_MNIST() if model is None else model

        if self.piwae:
            parameter_sets = [self.model.encode_params(), self.model.decode_params()]
        else:
            parameter_sets = [self.model.parameters()]

        self.optims = [torch.optim.Adam(params, lr=lr) for params in parameter_sets]
            


# One LossTest per training objective. Positional arguments are
# (name, loss_fn, no_epochs, epochs_per_sample, no_runs).
# NOTE(review): "iwae" samples metrics every 25 epochs while all other tests
# use 24 — confirm whether this difference is intended.
fig5_tests = [
    LossTest("iwae", eval_discard(IWAE_loss, M=1, K=64), 1500, 25, 1),
    LossTest("ciwae-05", eval_discard(lambda res: CIWAE_loss(res, 0.5), M=1, K=64), 1500, 24, 1),
    LossTest("piwae-8-8", eval_discard(PIWAE_loss, M=8, K=8), 1500, 24, 1, piwae=True),
    LossTest("miwae-8-8", eval_discard(IWAE_loss, M=8, K=8), 1500, 24, 1),
    LossTest("vae", eval_discard(IWAE_loss, M=64, K=1), 1500, 24, 1),
    LossTest("dreg-1-64", eval_discard(DREG_loss, M=1, K=64), 1500, 24, 1)
]        

    
# Resume: load the most recent stored checkpoint into each test's model.
mnist_train_data.load_latest_models(fig5_tests)

In [None]:
# Set to True to discard stored progress and train every test from scratch.
restart = False

for test in fig5_tests:
    print("Training ", test.name)
    for run_no in range(1, test.no_runs+1):
        print("Run ", run_no)
        # fresh model and optimizers for every run
        test.initialize_model()

        if restart:
            # the fresh model has seen no epochs; without this reset the epoch
            # counter of a previously loaded checkpoint would skip training
            test.finished_epochs = 0
        else:
            # only the current test needs its checkpoint reloaded
            # NOTE(review): the checkpoint lookup does not take run_no into
            # account — confirm before using no_runs > 1.
            mnist_train_data.load_latest_models([test])
        model = test.model.to(device)

        test_metrics = train_and_evaluate(
            model, 
            train_loader, 
            test_loader, 
            test,
            run_no,
            restart=restart
        )

        # free device memory before the next test is trained
        model.to('cpu')
        

In [None]:
def compute_ess(model, T=100):
    """Collect the values produced by ``sample_ess`` over the whole test set.

    Every batch of ``test_loader`` is flattened to ``(-1, 28*28)``, moved to
    ``device`` and passed to ``sample_ess(model, batch, T=T)`` without gradient
    tracking. Returns a flat list of Python numbers, one per element that
    ``sample_ess`` returned.
    """
    collected = []

    for batch, _ in tqdm(test_loader, desc="Computing ESS", leave=False):
        with torch.no_grad():
            flat_batch = batch.view(-1, 28*28).to(device)
            batch_values = sample_ess(model, flat_batch, T=T)
            collected.extend(value.item() for value in batch_values)

    return collected

# One list of sample_ess values per test, in the order of fig5_tests.
ess_samples = []

for test in fig5_tests:
     # each model is moved to the compute device first; T is raised to 1000 here
     ess_samples.append(compute_ess(test.model.to(device), 1000))
    

In [None]:
import matplotlib.pyplot as plt
# Create a figure instance
fig = plt.figure()

# Create an axes instance
ax = fig.add_axes([0,0,1,1])

# Create the violin plot: one violin per entry of ess_samples
bp = ax.violinplot(ess_samples, showmedians=True, points=100)

# violinplot places the violins at x = 1..N; ess_samples was filled in the
# order of fig5_tests, so label each violin with the matching test name
violin_names = [t.name for t in fig5_tests][:len(ess_samples)]
ax.set_xticks(range(1, len(violin_names) + 1))
ax.set_xticklabels(violin_names)
ax.set_ylabel('ESS')
ax.set_title('Distribution of sample_ess values on the test set')

ax.set_ylim(0,40e-3)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Bind the model explicitly instead of relying on the `model` variable left
# behind by the training loop (which ends on the last entry of fig5_tests).
model = fig5_tests[-1].model.to(device)

# The subplot grid has 10 rows; fewer are drawn when the batch is smaller.
max_rows = 10

# A single batch is enough for this illustration.
imgs, _ = next(iter(test_loader))
n_shown = min(max_rows, imgs.shape[0])

plt.figure(figsize=(4,26))

# no gradients are needed for plotting reconstructions
with torch.no_grad():
    for i in range(n_shown):
        img = imgs[i].view(28 * 28).to(device)
        recon = model.reconstruct(img)
        # left column: input image
        ax = plt.subplot(10, 2, 2 * i + 1)
        ax.imshow(img.view(28, 28).clone().detach().cpu())
        ax.set_title("Original")
        # right column: model reconstruction
        ax = plt.subplot(10, 2, 2 * i + 2)
        ax.imshow(recon.view(28, 28).clone().detach().cpu())
        ax.set_title("Reconstructed")
plt.show()

In [None]:
# Display every stored result table, grouped by loss name.
mnist_train_data.load_all_results()

In [None]:
def plot_metric(ress: dict, col: str, ax):
    """Plot one metric column against the epoch for every loss in ``ress``.

    Parameters:
        ress: maps loss name to a list of ``(run_no, table)`` pairs, where
            ``table`` supports ``table['epoch']`` and ``table[col]``. Only the
            first run of each loss is drawn; losses without any run are skipped.
        col: name of the metric column to plot.
        ax: matplotlib axes to draw on; a legend with one entry per plotted
            loss is added.
    """
    lines = []
    for loss, runs in ress.items():
        if not runs:
            # no result table stored for this loss: nothing to draw
            continue
        table = runs[0][1]
        line, = ax.plot(table['epoch'], table[col], label=loss)
        lines.append(line)
    ax.legend(handles=lines)

In [None]:
import matplotlib.pyplot as plt
# Read the result tables once and reuse them for all three panels.
all_results = mnist_train_data.load_all_results()

# (results column, y-axis label, x-axis scale) for each panel, top to bottom.
# NOTE(review): the '-kl' panel keeps a linear x-axis as before — confirm intended.
panels = [
    ('iwae-64', 'IWAE-64', 'log'),
    ('logpx', 'log(p(x))', 'log'),
    ('-kl', '-KL', 'linear'),
]

fig, ax = plt.subplots(len(panels))
fig.set_figheight(12)
fig.set_figwidth(7)
for panel_ax, (column, ylabel, xscale) in zip(ax, panels):
    plot_metric(all_results, column, panel_ax)
    panel_ax.set_xscale(xscale)
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_xlim(24,1000)
fig.suptitle('Convergence of evaluation metrics on the test set over time.')
fig.tight_layout()