In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from utils import make_dataloaders
from models.vaes import Base, VAE, IWAE, AIS_VAE, ULA_VAE, Stacked_VAE, VAE_with_flows, repeat_data
from models.samplers import HMC, MALA, ULA, run_chain
import yaml
import numpy as np
from scipy.stats import norm
import copy
from tqdm.auto import tqdm

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

from inspect import signature

%matplotlib widget

# Plot colour for each integer index 0..5 (position in the list is the key).
colors = dict(enumerate(['blue', 'red', 'green', 'yellow', 'black', 'orange']))

In [2]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# MNIST train/validation loaders, 100 samples per batch for both splits.
train_loader, val_loader = make_dataloaders(
    dataset='mnist',
    batch_size=100,
    val_batch_size=100,
    binarize=True,
)



In [4]:
def load_model(version):
    """Rebuild a trained model from a PyTorch Lightning log directory.

    Reads ``lightning_logs/default/version_{version}/hparams.yaml`` and the
    first checkpoint (in sorted filename order) from the sibling
    ``checkpoints/`` directory, then tries each candidate class in turn and
    returns the first one that both accepts the hyperparameters and loads the
    checkpoint's ``state_dict``.

    Args:
        version: Lightning run number (the ``version_<N>`` directory suffix).

    Returns:
        The restored model, moved to the notebook-level ``device``.

    Raises:
        FileNotFoundError: the hparams file or checkpoint directory is missing,
            or the checkpoint directory is empty.
        RuntimeError: none of the candidate classes could be built and loaded;
            the message lists the failure for each class.
    """
    run_dir = f'lightning_logs/default/version_{version}'

    # NOTE(review): FullLoader is kept because the stored hparams contain
    # Python class references (e.g. `act_func`); only load files you trust.
    with open(f'{run_dir}/hparams.yaml') as file:
        hparams = yaml.load(file, Loader=yaml.FullLoader)
        print(hparams)

    path = f'{run_dir}/checkpoints/'
    # os.listdir order is arbitrary; sort so the choice is deterministic.
    file_names = sorted(os.listdir(path))
    if not file_names:
        raise FileNotFoundError(f'no checkpoint files found in {path}')
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    checkpoint = torch.load(f'{path}{file_names[0]}', map_location=device)

    failures = []
    for current_model in [VAE, IWAE, ULA_VAE, AIS_VAE]:
        try:
            model = current_model(**hparams).to(device)
            model.load_state_dict(checkpoint['state_dict'])
        except Exception as err:
            # A mismatch is expected for the wrong classes; remember why so a
            # total failure can be diagnosed instead of silently returning None.
            failures.append(f'{current_model.__name__}: {type(err).__name__}: {err}')
        else:
            print(f'loaded {model.name}')
            return model

    raise RuntimeError(
        f'could not load version {version} with any candidate model class:\n'
        + '\n'.join(failures)
    )

In [5]:
# Lightning run whose checkpoint is restored as the reference model.
version = 794
iwae = load_model(version)

{'act_func': <class 'torch.nn.modules.activation.GELU'>, 'dataset': 'mnist', 'hidden_dim': 100, 'name': 'IWAE', 'net_type': 'fc', 'num_samples': 50, 'shape': 28, 'sigma': 0.1, 'specific_likelihood': 'gaussian'}
loaded IWAE


In [6]:
def get_transitions_output(model, z, mu, logvar, x):
    """Run the model's transitions once and return the result plus a loss gradient.

    Args:
        model: object exposing ``num_samples``, ``run_transitions``,
            ``loss_function`` and ``decoder_net.net[0].bias``.
        z, mu, logvar: passed unchanged to ``model.run_transitions``.
        x: data batch; repeated ``model.num_samples`` times via ``repeat_data``
            before the transitions are run.

    Returns:
        ``(output, grad)`` where ``output`` is whatever ``run_transitions``
        returned and ``grad`` holds the first 10 entries of the gradient of the
        loss with respect to the bias of the decoder's first layer.
    """
    x = repeat_data(x, model.num_samples)
    output = model.run_transitions(z=z, x=x, mu=mu, logvar=logvar)

    # output[1] is fed to the loss as sum_log_weights; output[2] is only passed
    # when the loss actually declares a `sum_log_alphas` parameter. Checking the
    # parameter by name is robust to defaults/annotations containing commas.
    loss_kwargs = {'sum_log_weights': output[1]}
    if 'sum_log_alphas' in signature(model.loss_function).parameters:
        loss_kwargs['sum_log_alphas'] = output[2]
    loss = model.loss_function(**loss_kwargs)

    grad = torch.autograd.grad(loss, model.decoder_net.net[0].bias)[0][:10]
    return output, grad

In [7]:
# Noise scale: enters the closed-form covariance as sigma**2 * I and is passed
# as `sigma` to the ULA_VAE / AIS_VAE constructors.
# NOTE(review): should match the `sigma` hyperparameter of the loaded checkpoint
# (its printed hparams show 0.1) — re-check when `version` changes.
sigma = 0.1

In [8]:
# Quantities for the closed-form Gaussian log-likelihood used as ground truth.
# NOTE(review): this treats the decoder as the single affine map W z + mu given
# by decoder_net.net[0] — confirm the decoder has no further layers.
model_W = iwae.decoder_net.net[0].weight.data
model_mu = iwae.decoder_net.net[0].bias.data[..., None]  # column vector

# Output dimension is read from the weight instead of hard-coding 784.
data_dim = model_W.shape[0]

# Marginal covariance C = W W^T + sigma^2 I, with its inverse and log-determinant.
C = (model_W @ model_W.T) + (sigma**2) * torch.eye(data_dim, device=model_W.device, dtype=model_W.dtype)
C_inv = torch.inverse(C)
logdetC = torch.logdet(C)

# Data-independent part of -2 * log N(x; mu, C).
first_term = data_dim * np.log(2 * np.pi) + logdetC

def get_true_loglikelihood(x):
    """Closed-form Gaussian log-likelihood of every sample in a batch.

    Evaluates ``-0.5 * (first_term + (x - mu)^T C_inv (x - mu))`` per sample,
    using the notebook-level ``model_mu``, ``C_inv`` and ``first_term``.

    Args:
        x: batch tensor whose first axis indexes samples; each sample must
            contain exactly ``model_mu.shape[0]`` elements.

    Returns:
        1-D float32 tensor with one log-likelihood per sample.
    """
    dim = model_mu.shape[0]
    centered = x.reshape(x.shape[0], dim) - model_mu.reshape(1, dim)
    # trace(C_inv @ d d^T) equals d^T C_inv d, so the quadratic form is computed
    # directly for the whole batch instead of building a dim x dim outer product
    # and a full matrix product for every sample.
    quad_form = ((centered @ C_inv) * centered).sum(dim=1)
    return (-0.5 * (first_term + quad_form)).to(dtype=torch.float32)

In [None]:
def _adopt_networks(target, source):
    """Replace `target`'s decoder/encoder with deep copies of `source`'s.

    Gradients are explicitly enabled on the copied decoder's parameters.
    """
    target.decoder_net = copy.deepcopy(source.decoder_net)
    for param in target.decoder_net.parameters():
        param.requires_grad_(True)
    target.encoder_net = copy.deepcopy(source.encoder_net)


# Constructor arguments common to both models.
_shared_kwargs = dict(
    shape=28,
    act_func=nn.LeakyReLU,
    num_samples=1,
    hidden_dim=iwae.hidden_dim,
    net_type='fc',
    dataset='mnist',
    K=10,
    learnable_transitions=False,
    grad_skip_val=0.,
    grad_clip_val=0.,
    use_cloned_decoder=False,
    variance_sensitive_step=False,
    annealing_scheme='linear',
    specific_likelihood='gaussian',
    sigma=sigma,
)

# ----- ULA_VAE ----- #
ula_vae = ULA_VAE(
    step_size=0.001,
    use_transforms=False,
    return_pre_alphas=True,
    use_score_matching=False,
    ula_skip_threshold=0.1,
    acceptance_rate_target=0.9,
    **_shared_kwargs,
).to(device)
_adopt_networks(ula_vae, iwae)

# ----- AIS_VAE ----- #
ais_vae = AIS_VAE(
    step_size=0.01,
    use_barker=False,
    use_alpha_annealing=True,
    acceptance_rate_target=0.8,
    **_shared_kwargs,
).to(device)
_adopt_networks(ais_vae, iwae)

In [None]:
# Number of independent transition runs per batch and per sampler; their
# log-weights are combined as logsumexp(...) - log(n) in the evaluation loop.
n = 10

In [None]:
def _as_numpy(tensor):
    """Move a tensor to the CPU, detach it and convert it to a NumPy array."""
    return tensor.cpu().detach().numpy()


def _log_mean_exp(log_w_columns):
    """Concatenate per-run log-weight columns along dim 1 and return logsumexp - log(n)."""
    with torch.no_grad():
        log_w = torch.cat(log_w_columns, dim=1)
    return _as_numpy(torch.logsumexp(log_w, dim=1) - np.log(n))


def _l2(bias, second_moment, first_moment):
    """mean(bias)^2 + mean(second moment) - mean(first moment)^2 + sigma^2."""
    return np.mean(bias)**2 + np.mean(second_moment) - np.mean(first_moment)**2 + sigma**2


# One entry per evaluated validation batch.
bias_ula, Esqr_ula, E_ula = [], [], []
bias_ais, Esqr_ais, E_ais = [], [], []
grad_samples_ula, grad_samples_ais = [], []

for j, (x, _) in enumerate(val_loader, start=1):
    x = x.to(device)
    z, mu, logvar = iwae.enc_rep(x, 1)  # <- latents are fixed

    true_loglikelihood = _as_numpy(get_true_loglikelihood(x))

    # Per-run log-weight columns and gradient slices for this batch.
    ula_columns, ais_columns = [], []
    ula_g, ais_g = [], []

    for _run in tqdm(range(n)):
        ula_out, grad_ula = get_transitions_output(ula_vae, z, mu, logvar, x)
        ais_out, grad_ais = get_transitions_output(ais_vae, z, mu, logvar, x)

        with torch.no_grad():
            # Element 1 of the transitions output holds the log-weights.
            ula_columns.append(ula_out[1][..., None])
            ais_columns.append(ais_out[1][..., None])

        ula_g.append(_as_numpy(grad_ula))
        ais_g.append(_as_numpy(grad_ais))

    log_p_ais = _log_mean_exp(ais_columns)
    log_p_ula = _log_mean_exp(ula_columns)

    bias_ula.append(np.mean(log_p_ula - true_loglikelihood))
    Esqr_ula.append(np.mean(log_p_ula**2))
    E_ula.append(np.mean(log_p_ula))
    grad_samples_ula.append(ula_g)

    bias_ais.append(np.mean(log_p_ais - true_loglikelihood))
    Esqr_ais.append(np.mean(log_p_ais**2))
    E_ais.append(np.mean(log_p_ais))
    grad_samples_ais.append(ais_g)

    # Only the first 5 validation batches are evaluated.
    if j == 5:
        break


l2_ais = _l2(bias_ais, Esqr_ais, E_ais)
l2_ula = _l2(bias_ula, Esqr_ula, E_ula)

# Standard deviation over every collected gradient entry (all batches, runs, coordinates).
grad_elbo_ula = np.std(np.array(grad_samples_ula))
grad_elbo_ais = np.std(np.array(grad_samples_ais))

In [None]:
# AIS: mean(bias)^2 + mean(log_p^2) - mean(log_p)^2 + sigma^2 over the evaluated batches.
l2_ais

In [None]:
# ULA: mean(bias)^2 + mean(log_p^2) - mean(log_p)^2 + sigma^2 over the evaluated batches.
l2_ula

In [None]:
# ULA: standard deviation of the collected loss-gradient entries (first 10 bias coordinates per run).
grad_elbo_ula

In [None]:
# AIS: standard deviation of the collected loss-gradient entries (first 10 bias coordinates per run).
grad_elbo_ais