In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from utils import make_dataloaders
from models.vaes import Base, VAE, IWAE, AIS_VAE, ULA_VAE, Stacked_VAE, VAE_with_flows, repeat_data
from models.samplers import HMC, MALA, ULA, run_chain
import yaml
import numpy as np
from scipy.stats import norm
import copy
from tqdm.auto import tqdm

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

# Interactive matplotlib backend for Jupyter (the `widget` backend is provided by ipympl).
%matplotlib widget

# Plot colour for each integer index 0..5 (position in the list is the key).
colors = dict(enumerate([
    'blue',
    'red',
    'green',
    'yellow',
    'black',
    'orange',
]))

In [2]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Train/validation loaders from the project helper: MNIST with binarize=True,
# 100 examples per batch for both splits.
train_loader, val_loader = make_dataloaders(dataset='mnist', batch_size=100, val_batch_size=100, binarize=True)

In [None]:
def load_model(version):
    """Rebuild a trained model from a PyTorch Lightning log directory.

    Reads the hyperparameters from
    ``lightning_logs/default/version_{version}/hparams.yaml``, loads the first
    file (in sorted order) from the sibling ``checkpoints/`` directory, and
    tries each of ``VAE``, ``IWAE``, ``ULA_VAE`` and ``AIS_VAE`` in that order.
    The first class that can be constructed from the hyperparameters and
    accepts the checkpoint's ``state_dict`` is returned.

    Args:
        version: Lightning run number used to build the log directory path.

    Returns:
        The loaded model, moved to the module-level ``device``.

    Raises:
        FileNotFoundError: if the hparams file or checkpoint directory is
            missing, or the checkpoint directory is empty.
        RuntimeError: if none of the candidate classes could be built and
            loaded; the message lists the failure for every class.
    """
    run_dir = f'lightning_logs/default/version_{version}'

    with open(f'{run_dir}/hparams.yaml') as file:
        # NOTE(review): yaml.load with FullLoader is more permissive than
        # yaml.safe_load; only point this at hparams files you trust.
        hparams = yaml.load(file, Loader=yaml.FullLoader)
    print(hparams)

    path = f'{run_dir}/checkpoints/'
    # Sort so the chosen checkpoint does not depend on filesystem listing order.
    file_names = sorted(os.listdir(path))
    if not file_names:
        raise FileNotFoundError(f'no checkpoint files found in {path}')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; use trusted checkpoints only.
    checkpoint = torch.load(f'{path}{file_names[0]}', map_location=device)

    failures = []
    for current_model in [VAE, IWAE, ULA_VAE, AIS_VAE]:
        try:
            model = current_model(**hparams).to(device)
            model.load_state_dict(checkpoint['state_dict'])
        except Exception as error:
            # A mismatch is expected for the wrong classes; remember why each
            # one failed so a total failure is diagnosable.
            failures.append(f'{current_model.__name__}: {error!r}')
            continue
        print(f'loaded {model.name}')
        return model

    raise RuntimeError(
        f'could not load version {version} as any of VAE, IWAE, ULA_VAE, AIS_VAE:\n'
        + '\n'.join(failures)
    )

In [None]:
# Lightning run to evaluate; load_model reads lightning_logs/default/version_{version}/.
version = 794
# NOTE(review): the variable is called `iwae`, but load_model returns whichever of
# VAE / IWAE / ULA_VAE / AIS_VAE loads first; check the printed "loaded ..." line.
iwae = load_model(version=version)

In [6]:
def get_transitions_output(model, z, mu, logvar, x):
    """Run ``model.run_transitions`` on the given latents with autograd disabled.

    ``x`` is first expanded with ``repeat_data(x, model.num_samples)``; ``z``,
    ``mu`` and ``logvar`` are forwarded unchanged. Whatever
    ``run_transitions`` returns is handed back as-is.
    """
    with torch.no_grad():
        tiled_x = repeat_data(x, model.num_samples)
        return model.run_transitions(z=z, x=tiled_x, mu=mu, logvar=logvar)

In [7]:
# Noise scale shared by the whole experiment: enters the closed-form covariance as
# sigma**2 on the diagonal and is passed as `sigma` to both ULA_VAE and AIS_VAE.
sigma = 0.1

In [8]:
# Weight and bias of the decoder's first layer; the bias is kept as a column vector.
# NOTE(review): the closed form below assumes the decoder is a single linear map
# onto 784 pixels -- confirm against the decoder definition.
model_W = iwae.decoder_net.net[0].weight.data
model_mu = iwae.decoder_net.net[0].bias.data.unsqueeze(-1)

# Marginal covariance C = W W^T + sigma^2 I, with its log-determinant and inverse.
C = model_W @ model_W.T + torch.eye(784, device=device) * sigma ** 2
logdetC = C.logdet()
C_inv = C.inverse()

# Part of the Gaussian log-density that does not depend on x: d*log(2*pi) + log|C|.
first_term = 784 * np.log(2 * np.pi) + logdetC

def get_true_loglikelihood(x):
    """Gaussian log-density of every example in ``x`` under mean ``model_mu`` and covariance ``C``.

    Computes ``-0.5 * (first_term + d^T C_inv d)`` with ``d = x_i - model_mu``
    for each example, using the module-level ``model_mu``, ``C_inv`` and
    ``first_term``.

    The quadratic form ``d^T C_inv d`` equals ``trace(C_inv @ d d^T)``, but is
    evaluated for the whole batch at once without building a 784x784 outer
    product (and a 784^3 matrix product) per example.

    Args:
        x: tensor whose first dimension is the batch and whose remaining
            dimensions hold 784 values per example.

    Returns:
        1-D float32 tensor on ``device`` with one log-likelihood per example.
    """
    batch_size = x.shape[0]
    # (batch, 784) residuals; model_mu is stored as a (784, 1) column.
    diff = x.reshape(batch_size, 784) - model_mu.reshape(1, 784)
    # Row-wise d^T C_inv d.
    quad_form = ((diff @ C_inv) * diff).sum(dim=1)
    true_loglikelihood = -0.5 * (first_term + quad_form)
    return true_loglikelihood.to(device=device, dtype=torch.float32)

In [9]:
# Keyword arguments that the ULA and AIS models are built with in common.
shared_kwargs = dict(
    shape=28,
    act_func=nn.LeakyReLU,
    num_samples=1,
    hidden_dim=iwae.hidden_dim,
    net_type='fc',
    dataset='mnist',
    step_size=0.01,
    K=10,
    learnable_transitions=False,
    grad_skip_val=0.,
    grad_clip_val=0.,
    use_cloned_decoder=False,
    variance_sensitive_step=False,
    annealing_scheme='linear',
    specific_likelihood='gaussian',
    sigma=sigma,
)

# ----- ULA_VAE ----- #
ula_vae = ULA_VAE(
    use_transforms=False,
    return_pre_alphas=True,
    use_score_matching=False,
    ula_skip_threshold=0.1,
    acceptance_rate_target=0.9,
    **shared_kwargs,
).to(device)
# Give the sampler its own copy of the trained decoder.
ula_vae.decoder_net = copy.deepcopy(iwae.decoder_net)

# ----- AIS_VAE ----- #
ais_vae = AIS_VAE(
    use_barker=False,
    use_alpha_annealing=True,
    acceptance_rate_target=0.8,
    **shared_kwargs,
).to(device)
# Same decoder weights as above, again as an independent copy.
ais_vae.decoder_net = copy.deepcopy(iwae.decoder_net)

In [10]:
# Number of repeated sampler runs per example; also the divisor (as log(n)) when
# the per-run values are averaged with logsumexp.
n = 100

In [11]:
# Take a single training batch; every repetition below reuses it.
batch = next(iter(train_loader))
x, _ = batch
x = x.to(device)
z, mu, logvar = iwae.enc_rep(x, 1) # <- latents are fixed

# Exact reference value for each example in the batch.
true_loglikelihood = get_true_loglikelihood(x)

# One entry per run; concatenated once at the end instead of growing a tensor
# (and re-copying all previous runs) on every iteration.
ula_runs = []
ais_runs = []
for i in tqdm(range(n)):
    # Element [1] of the transition output is used as the per-example log-weight.
    ula_log_w = get_transitions_output(ula_vae, z, mu, logvar, x)[1]
    ais_log_w = get_transitions_output(ais_vae, z, mu, logvar, x)[1]

    ula_runs.append(ula_log_w[..., None])
    ais_runs.append(ais_log_w[..., None])

# Despite the names, these hold the log-weights, one column per run (dim=1).
ula_w = torch.cat(ula_runs, dim=1)
ais_w = torch.cat(ais_runs, dim=1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))




In [12]:
# Display the closed-form log-likelihoods computed for the sampled batch.
true_loglikelihood

tensor([-2452.1934, -6029.5659, -1680.9746, -3616.5933, -3597.8057, -3371.4492,
        -6793.7598, -3445.1914, -1711.1777, -4196.6079, -2745.9248, -4487.2944,
        -2636.4326, -2560.2698, -2440.2881, -2254.8745, -4458.2861, -3561.3682,
        -4743.0518, -3855.7749, -3778.6172, -3431.2993, -6701.0513, -3076.9558,
        -2926.0525, -2452.2312, -3208.3110, -2763.3035, -2491.3765, -5707.7031,
        -3896.1113, -5264.2686, -4555.1895, -4465.1865, -4051.3618, -2939.1489,
        -6437.1318, -3729.7852, -1889.5876, -5795.1494, -2217.9404, -4862.6978,
        -3329.8687, -3622.3506, -5568.6387, -2771.0420, -5163.8887, -3137.0747,
        -3684.6172, -1860.2253, -3312.0479, -2050.1240, -3070.3494, -6949.9531,
        -3030.4233, -3680.4229, -3263.5144, -2347.2092, -2584.5767, -4391.4863,
        -5637.6333,  -932.4315, -6219.0801, -6480.4229, -2314.8040, -6265.3525,
        -5660.7534, -2791.5156, -2552.9812, -2668.9629, -2486.2610, -4123.9375,
        -3069.2251, -4171.0742, -2596.88

In [13]:
# Log of the mean over the n runs of exp(ula_w) for each example (logsumexp minus log n).
# NOTE(review): the recorded output is on the order of -1e6, while true_loglikelihood
# is on the order of -1e3 -- verify the ULA log-weights before trusting this estimate.
torch.logsumexp(ula_w, dim=1) - np.log(n)

tensor([ -4051245.2500,  -4142501.2500,  -1728860.7500,  -4609149.0000,
         -3132180.5000,  -2071665.1250,  -5778044.0000,  -4034548.5000,
         -2453695.7500,  -4364711.5000,  -2941454.0000,  -2718607.5000,
         -5033194.5000,  -2583134.0000,  -2431394.0000,  -4131238.0000,
         -3396202.5000,  -2930794.5000,  -7239304.5000,  -2971218.5000,
         -3666390.0000,  -3923755.0000,  -5295299.5000,  -3744014.5000,
         -2780702.5000,  -7811330.5000, -11288383.0000,  -4978382.5000,
         -3855796.0000, -10961722.0000,  -5369453.5000,  -1731265.1250,
         -2615478.0000,  -6848207.5000,  -3184644.5000,  -3671973.0000,
         -4650130.0000,  -3262731.7500,  -5456484.5000,  -3052228.5000,
         -5707489.5000,  -4971593.5000,  -3142734.5000,  -4246878.5000,
         -3075003.0000,  -3660056.0000,  -4647316.0000,  -2765877.5000,
         -2984320.2500,  -5108746.0000,  -2510135.2500,  -3029675.5000,
         -4939290.5000,  -5840400.5000,  -4394059.5000,  -278125