In [None]:
# Notebook setup: render matplotlib figures inline and turn on IPython's
# autoreload extension in mode 2, so that edits to imported modules are picked
# up without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import sys
import numpy as np
import torch
from matplotlib import pyplot as plt
import pandas as pd

In [None]:
import gc

In [None]:
# Put the project checkout on the import path (at most once) so that the
# `src` package imported right below can be resolved.
module_path = os.path.abspath('/users/dli44/tool-presence')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
# Drive the project's command-line parser from inside the notebook by handing
# it an explicit argument list, then build the datasets/dataloaders from the
# parsed options with augmentation switched off.
cli_overrides = [
    '--root=/users/dli44/tool-presence/',
    '--data-dir=data/youtube_data/',
    '--image-size=64',
    '--loss-function=mmd',
    '--z-dim=10',
]
parser = utils.setup_argparse()
args = parser.parse_args(args=cli_overrides)
datasets, dataloaders = utils.setup_data(args, augmentation=False)

In [None]:
# Hyperparameter grid of the trained checkpoints: latent sizes, plus the
# weighting values used for the ELBO (beta) and MMD (lambda) runs.
zs = [2, 5, 10, 20, 40, 80]
betas = [1.0, 2.0, 5.0, 10.0, 20.0, 40.0]
lambdas = [1.0, 5.0, 10.0, 20.0, 100.0, 500.0]

# Checkpoint paths as nested lists indexed [zdim index][weight index].
# The MMD files reuse the `final_beta_` filename prefix but are keyed by lambda.
mmd_model_paths = [
    ['mmd/weights/final_beta_{}_zdim_{}_epoch_80.torch'.format(weight, latent_dim)
     for weight in lambdas]
    for latent_dim in zs
]
elbo_model_paths = [
    ['elbo/weights/final_beta_{}_zdim_{}_epoch_80.torch'.format(weight, latent_dim)
     for weight in betas]
    for latent_dim in zs
]

In [None]:
# Evaluate every MMD-trained checkpoint on the validation split.
# One row is appended per checkpoint, in the order of `mmd_model_paths`:
#   [recon_loss / n, mmd / n, kl / n, nanmean(logpx)]
mmd_table = []
n = len(dataloaders['val'].dataset)  # number of validation samples

for i, zdim in enumerate(zs):
    for j, path in enumerate(mmd_model_paths[i]):
        model = m.VAE(image_channels=args.image_channels,
                      image_size=args.image_size,
                      h_dim1=1024,
                      h_dim2=128,
                      zdim=zdim).to(c.device)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded onto whatever device this session runs on.
        model.load_state_dict(torch.load(os.path.join(args.root, path),
                                         map_location=c.device))
        # Switch to inference mode before *any* validation pass, including the
        # log p(x) estimate below, so no evaluation runs in training mode.
        model.eval()
        print(path)

        logpx = utils.estimate_logpx(dataloaders['val'], model, args, 128)

        # Accumulate reconstruction loss, MMD and KL over the validation set.
        recon_loss, mmd_div, kl_div = 0, 0, 0
        with torch.no_grad():
            for data, _ in dataloaders['val']:
                data = data.to(c.device)
                recon_batch, z, mu, logvar = model(data)
                # NOTE(review): these checkpoints are indexed by `lambdas`, yet
                # the weight passed here comes from `betas`. Only the individual
                # loss terms are kept below, so this may be harmless -- confirm
                # against m.mmd_loss / m.vae_loss.
                loss_params = {'recon': recon_batch,
                               'x': data,
                               'z': z,
                               'mu': mu,
                               'logvar': logvar,
                               'batch_size': args.batch_size,
                               'input_size': args.image_size,
                               'zdim': zdim,
                               'beta': betas[j]}
                _, mmd, rl = m.mmd_loss(**loss_params)
                _, _, kld = m.vae_loss(**loss_params)
                recon_loss += rl.item()
                mmd_div += mmd.item()
                kl_div += kld.item()

        mmd_table.append([recon_loss / n, mmd_div / n, kl_div / n, np.nanmean(logpx)])

        # Free GPU memory before the next checkpoint: drop the model, run the
        # garbage collector, then hand the now-unused cached blocks back.
        del model
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
import pickle

# Replace `mmd_table` with the copy stored in mmd_table.pkl.
# NOTE(review): this overwrites any `mmd_table` built earlier in the session,
# and no cell visible in this notebook writes mmd_table.pkl -- confirm where
# the file comes from. pickle.load can execute arbitrary code, so only load
# files you trust.
with open('mmd_table.pkl', 'rb') as table_file:
    mmd_table = pickle.load(table_file)

In [None]:
# Evaluate every ELBO-trained checkpoint on the validation split.
# One row is appended per checkpoint, in the order of `elbo_model_paths`:
#   [recon_loss / n, mmd / n, kl / n, nanmean(logpx)]
elbo_table = []
n = len(dataloaders['val'].dataset)  # number of validation samples

for i, zdim in enumerate(zs):
    for j, path in enumerate(elbo_model_paths[i]):
        model = m.VAE(image_channels=args.image_channels,
                      image_size=args.image_size,
                      h_dim1=1024,
                      h_dim2=128,
                      zdim=zdim).to(c.device)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded onto whatever device this session runs on.
        model.load_state_dict(torch.load(os.path.join(args.root, path),
                                         map_location=c.device))
        # Switch to inference mode before *any* validation pass, including the
        # log p(x) estimate below, so no evaluation runs in training mode.
        model.eval()
        print(path)

        logpx = utils.estimate_logpx(dataloaders['val'], model, args, 128)

        # Accumulate reconstruction loss, MMD and KL over the validation set.
        recon_loss, mmd_div, kl_div = 0, 0, 0
        with torch.no_grad():
            for data, _ in dataloaders['val']:
                data = data.to(c.device)
                recon_batch, z, mu, logvar = model(data)
                loss_params = {'recon': recon_batch,
                               'x': data,
                               'z': z,
                               'mu': mu,
                               'logvar': logvar,
                               'batch_size': args.batch_size,
                               'input_size': args.image_size,
                               'zdim': zdim,
                               'beta': betas[j]}
                _, mmd, rl = m.mmd_loss(**loss_params)
                _, _, kld = m.vae_loss(**loss_params)
                recon_loss += rl.item()
                mmd_div += mmd.item()
                kl_div += kld.item()

        elbo_table.append([recon_loss / n, mmd_div / n, kl_div / n, np.nanmean(logpx)])

        # Free GPU memory before the next checkpoint: drop the model, run the
        # garbage collector, then hand the now-unused cached blocks back.
        del model
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
def sci_notation(number, sig_fig=2):
    r"""Format ``number`` as a LaTeX math string in scientific notation.

    Args:
        number: Value to format; anything that accepts the ``e`` format spec.
        sig_fig: Number of digits kept after the mantissa's decimal point
            (the mantissa therefore shows ``sig_fig + 1`` significant digits).

    Returns:
        A string such as ``$1.23\times 10^{4}$``. Non-finite inputs have no
        exponent part and are returned as ``NaN``, ``$\infty$`` or
        ``$-\infty$`` instead of raising.
    """
    formatted = "{0:.{1:d}e}".format(number, sig_fig)
    mantissa, marker, exponent = formatted.partition("e")
    if not marker:
        # 'nan', 'inf' and '-inf' are formatted without an exponent, so there
        # is nothing to split on; map them to a printable LaTeX token.
        non_finite = {"nan": "NaN", "inf": "$\\infty$", "-inf": "$-\\infty$"}
        return non_finite.get(formatted, formatted)
    # int() drops the explicit "+" sign and the leading zeros of the exponent.
    return "$" + mantissa + "\\times 10^{" + str(int(exponent)) + "}$"

In [None]:
# Spot-check the formatter on column 0 of ELBO row 5, scaled by the number of
# values per image. `pixels` is computed here because the only other cell that
# defines it comes later in the notebook, so a top-to-bottom run would
# otherwise stop with a NameError.
pixels = args.image_size * args.image_size * args.image_channels
sci_notation(elbo_table[5][0] * pixels)

In [None]:
# Render the ELBO results as LaTeX table rows: a header line, then one row per
# (zdim, beta) checkpoint, with a \midrule after each zdim group.
pixels = args.image_size * args.image_size * args.image_channels
table_lines = ["$z$ & $\\beta$ & Reconstruction Loss & MMD Distance & KL-Divergence & $\\log(p(x))$ (est.) & bits/pixel\\\\\\midrule"]
for i, zdim in enumerate(zs):
    for j, beta in enumerate(betas):
        # elbo_table is flat, with one row per checkpoint grouped by zdim.
        row = elbo_table[i * len(betas) + j]
        table_lines.append("{} & {} & {} & {} & {} & {} & {:.2f}\\\\".format(
            zdim,
            int(beta),
            sci_notation(row[0] * pixels),
            sci_notation(row[1]),
            sci_notation(row[2]),
            sci_notation(row[3]),
            # NOTE(review): log(10)/log(2) converts base-10 logs to bits. If
            # estimate_logpx returns natural-log values the factor should be
            # 1/log(2) instead -- confirm.
            -row[3] / pixels * np.log(10) / np.log(2)))
    table_lines.append('\\midrule')

# The export cell that follows writes a variable named `s` to out.txt, but
# nothing else in the visible notebook assigns it. The table text is stored
# under that name -- presumably the intended content; confirm.
s = "\n".join(table_lines)
print(s)

In [None]:
# Write the text held in `s` to out.txt, replacing any existing file.
# `s` is not assigned in this cell and must already exist in the kernel.
with open('out.txt', 'w') as out_file:
    out_file.write(s)