In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from utils import make_dataloaders
from models.vaes import Base, VAE, IWAE, AIS_VAE, ULA_VAE, Stacked_VAE, VAE_with_flows, repeat_data
from models.samplers import HMC, MALA, ULA, run_chain
import yaml
import numpy as np
from scipy.stats import norm
import copy
from tqdm.auto import tqdm

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

from inspect import signature

%matplotlib widget

# Plot palette: integer index -> matplotlib colour name.
colors = dict(enumerate(('blue', 'red', 'green', 'yellow', 'black', 'orange')))

In [2]:
# Select the compute device: the second GPU when one exists (the index that
# was hard-coded originally), the first GPU on single-GPU hosts, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Binarized MNIST loaders; both splits use batches of 100.
train_loader, val_loader = make_dataloaders(
    dataset='mnist',
    batch_size=100,
    val_batch_size=100,
    binarize=True,
)



In [4]:
def load_model(version):
    """Rebuild a trained model from a Lightning log directory.

    Reads ``lightning_logs/default/version_{version}/hparams.yaml``, loads one
    checkpoint from the sibling ``checkpoints/`` folder and tries each of
    ``VAE``, ``IWAE``, ``ULA_VAE`` and ``AIS_VAE`` in turn until one can be
    constructed from the hyper-parameters and accepts the checkpoint's
    ``state_dict``.

    Args:
        version: Number of the ``version_*`` run directory to load.

    Returns:
        The first candidate model (moved to ``device``) that loads cleanly.

    Raises:
        FileNotFoundError: If the checkpoint directory contains no files.
        RuntimeError: If none of the candidate classes matches the checkpoint.
    """
    run_dir = f'lightning_logs/default/version_{version}'
    with open(f'{run_dir}/hparams.yaml') as file:
        hparams = yaml.load(file, Loader=yaml.FullLoader)
    print(hparams)

    path = f'{run_dir}/checkpoints/'
    # Sort so the chosen checkpoint does not depend on directory listing order.
    checkpoint_files = sorted(os.listdir(path))
    if not checkpoint_files:
        raise FileNotFoundError(f'no checkpoint files found in {path}')
    # map_location lets a checkpoint saved on another device load on this one.
    checkpoint = torch.load(os.path.join(path, checkpoint_files[0]), map_location=device)

    failures = {}
    for current_model in [VAE, IWAE, ULA_VAE, AIS_VAE]:
        try:
            model = current_model(**hparams).to(device)
            model.load_state_dict(checkpoint['state_dict'])
        except Exception as err:
            # A mismatch is expected for most candidates; remember why each one
            # failed so the final error is actionable.
            failures[current_model.__name__] = err
        else:
            print(f'loaded {model.name}')
            return model

    details = '; '.join(f'{cls_name}: {err!r}' for cls_name, err in failures.items())
    raise RuntimeError(f'could not load version {version} with any model class ({details})')

In [None]:
# Lightning run whose checkpoint supplies the reference model; its encoder and
# decoder are reused by the cells below.
version = 794
iwae = load_model(version=version)

In [None]:
def get_transitions_output(model, z, mu, logvar, x):
    """Run the model's transitions once and differentiate the resulting loss.

    Args:
        model: Object exposing ``num_samples``, ``run_transitions``,
            ``loss_function`` and ``decoder_net.net[0].bias``.
        z: Forwarded to ``model.run_transitions``.
        mu: Forwarded to ``model.run_transitions``.
        logvar: Forwarded to ``model.run_transitions``.
        x: Data batch; repeated ``model.num_samples`` times via ``repeat_data``
            before being forwarded.

    Returns:
        ``(output, grad)`` where ``output`` is whatever ``run_transitions``
        returned and ``grad`` holds the first 50 entries of the gradient of the
        loss with respect to ``model.decoder_net.net[0].bias``.
    """
    x = repeat_data(x, model.num_samples)
    output = model.run_transitions(z=z, x=x, mu=mu, logvar=logvar)

    # Loss functions with more than one parameter also receive output[2] as
    # sum_log_alphas; single-parameter ones only receive output[1].
    takes_log_alphas = len(signature(model.loss_function).parameters) > 1
    if takes_log_alphas:
        loss = model.loss_function(sum_log_alphas=output[2], sum_log_weights=output[1])
    else:
        loss = model.loss_function(sum_log_weights=output[1])

    # Only a 50-entry slice of the bias gradient is kept for the comparison.
    grad = torch.autograd.grad(loss, model.decoder_net.net[0].bias)[0][:50]
    return output, grad

In [None]:
# Passed as `sigma` to the models built below and used as sigma**2 on the
# diagonal of the covariance C.
sigma = 0.1

In [None]:
# Constants for a closed-form Gaussian log-density built from the first
# decoder layer of `iwae`: covariance C = W W^T + sigma^2 I, mean = layer bias.
model_W = iwae.decoder_net.net[0].weight.data
model_mu = iwae.decoder_net.net[0].bias.data.unsqueeze(-1)  # column vector

C = model_W @ model_W.T + sigma ** 2 * torch.eye(784, device=device)
C_inv = C.inverse()
logdetC = C.logdet()

# Data-independent part of -2 * log N(x; mu, C): d * log(2*pi) + log|C|, d = 784.
first_term = 784 * np.log(2 * np.pi) + logdetC

def get_true_loglikelihood(x):
    """Evaluate the closed-form Gaussian log-density for every sample of ``x``.

    Uses the module-level ``model_mu``, ``C_inv`` and ``first_term`` and computes
    ``-0.5 * (first_term + (x - mu)^T C_inv (x - mu))`` per sample.

    Args:
        x: Batch whose samples each hold 784 values (any trailing shape that
            flattens to 784).

    Returns:
        float32 tensor of shape ``(x.shape[0],)`` on ``device``.
    """
    # One centred row per sample; model_mu is a (784, 1) column, so transpose it.
    centred = x.reshape(x.shape[0], 784) - model_mu.T
    # trace(C_inv @ d d^T) equals the quadratic form d^T C_inv d, so evaluate it
    # for the whole batch at once instead of building a 784x784 outer product
    # and a 784x784 matmul per sample.
    quad_form = ((centred @ C_inv) * centred).sum(dim=1)
    return (-0.5 * (first_term + quad_form)).to(device=device, dtype=torch.float32)

In [None]:
class ULA_VAE_reverse(ULA_VAE):
    """ULA_VAE whose optimizer is given only the ``reverse_kernels`` parameters."""

    def configure_optimizers(self):
        """Return Adam (lr=1e-3) over ``self.reverse_kernels`` plus its LR schedule.

        The schedule multiplies the learning rate by 0.25 at epochs 10, 20 and 30.
        """
        adam = torch.optim.Adam(params=self.reverse_kernels.parameters(), lr=1e-3)
        lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
            adam,
            milestones=[10, 20, 30],
            gamma=0.25,
        )
        return [adam], [lr_schedule]

In [None]:
# Number of repeated transition runs per batch in run_exp.
n = 100

In [None]:
# ------------------------------------------------------------------------------------------------------------------------------------------------------------
# Build every sampler-based model under comparison. Each one is constructed
# fresh and then receives deep copies of the encoder/decoder of `iwae`.


def _build_from_pretrained(model_cls, **kwargs):
    """Construct ``model_cls(**kwargs)`` on ``device`` and graft in ``iwae``'s networks.

    The decoder and encoder are deep-copied from ``iwae``, every decoder
    parameter is marked as requiring grad, and ``use_stepsize_update`` is
    switched off on the new model.
    """
    model = model_cls(**kwargs).to(device)
    model.decoder_net = copy.deepcopy(iwae.decoder_net)
    for p in model.decoder_net.parameters():
        p.requires_grad_(True)
    model.encoder_net = copy.deepcopy(iwae.encoder_net)
    model.use_stepsize_update = False
    return model


# Keyword arguments common to every model built in this cell.
_shared_kwargs = dict(
    shape=28, act_func=nn.LeakyReLU, hidden_dim=iwae.hidden_dim, net_type='fc', dataset='mnist',
    learnable_transitions=False, grad_skip_val=0., grad_clip_val=0., use_cloned_decoder=False,
    variance_sensitive_step=False, annealing_scheme='linear', specific_likelihood='gaussian', sigma=sigma,
)
# ULA-specific settings (shared by the plain and the reverse-kernel variants).
_ula_kwargs = dict(
    _shared_kwargs, step_size=0.001, use_transforms=False, return_pre_alphas=True,
    use_score_matching=False, ula_skip_threshold=0.1, acceptance_rate_target=0.9,
)
# AIS-specific settings.
_ais_kwargs = dict(
    _shared_kwargs, step_size=0.01, use_barker=False, use_alpha_annealing=True,
    acceptance_rate_target=0.8,
)

# NO reverse

# ----- ULA_VAE ----- #
ula_5 = _build_from_pretrained(ULA_VAE, num_samples=1, K=5, **_ula_kwargs)
ula_10 = _build_from_pretrained(ULA_VAE, num_samples=1, K=10, **_ula_kwargs)

# -------------------------------------------------------------------------------------------------------------------------------------------------------------

# Reverse

# ----- ULA_VAE ----- #
ula_5_r = _build_from_pretrained(ULA_VAE_reverse, num_samples=1, K=5, use_reverse_kernel=True, **_ula_kwargs)
ula_10_r = _build_from_pretrained(ULA_VAE_reverse, num_samples=1, K=10, use_reverse_kernel=True, **_ula_kwargs)

# -------------------------------------------------------------------------------------------------------------------------------------------------------------

# No multisample

# ----- AIS_VAE ----- #
ais_5 = _build_from_pretrained(AIS_VAE, num_samples=1, K=5, **_ais_kwargs)
ais_10 = _build_from_pretrained(AIS_VAE, num_samples=1, K=10, **_ais_kwargs)

# -------------------------------------------------------------------------------------------------------------------------------------------------------------

# Multisample

# ----- AIS_VAE ----- #
ais_5_3 = _build_from_pretrained(AIS_VAE, num_samples=3, K=5, **_ais_kwargs)

In [None]:
def run_exp(model):
    """Repeat the model's transitions ``n`` times per validation batch.

    For each batch the latents come from ``iwae.enc_rep`` and stay fixed while
    ``get_transitions_output`` is called ``n`` times, collecting ``output[1]``
    and the returned gradient slice of every call.

    NOTE(review): the accumulators are re-created for every batch, so the
    returned values describe only the LAST batch of ``val_loader`` even though
    all batches are processed. Confirm whether a single batch was intended.

    Args:
        model: Model forwarded to ``get_transitions_output``; must expose
            ``num_samples``.

    Returns:
        ``(model_w, true_loglikelihood_mean, model_g)``: the per-call
        ``output[1]`` tensors concatenated along dim 1 (detached), the mean of
        ``get_true_loglikelihood`` over the batch, and a NumPy array stacking
        the per-call gradient slices.
    """
    for batch in tqdm(val_loader):
        x, _ = batch
        x = x.to(device)
        z, mu, logvar = iwae.enc_rep(x, model.num_samples) # <- latents are fixed

        log_w_draws = []
        model_g = []

        true_loglikelihood = get_true_loglikelihood(x).repeat(model.num_samples).cpu().detach().numpy()
        true_loglikelihood_mean = np.mean(true_loglikelihood)
        for _ in range(n):
            output, grad_model = get_transitions_output(model, z, mu, logvar, x)
            # Detach before storing so the autograd graph of each call is freed.
            log_w_draws.append(output[1].detach()[..., None])
            model_g.append(grad_model.cpu().detach().numpy())

        # Concatenate once instead of re-allocating the growing tensor on every
        # iteration (which copies quadratically many elements).
        if log_w_draws:
            model_w = torch.cat(log_w_draws, dim=1)
        else:
            model_w = torch.tensor([], device=device, dtype=torch.float32)

    return model_w, true_loglikelihood_mean, np.array(model_g)

In [None]:
def run_trainer(model):
    """Fit ``model`` with a Lightning trainer on the module-level data loaders.

    Logs to TensorBoard under ``lightning_logs/`` and trains for at most 101
    epochs with ``gpus=1``.
    """
    trainer = pl.Trainer(
        logger=pl_loggers.TensorBoardLogger('lightning_logs/'),
        fast_dev_run=False,
        max_epochs=101,
        automatic_optimization=True,
        gpus=1,
    )
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

In [None]:
# Train the reverse-kernel ULA model built with K=5.
run_trainer(ula_5_r)

In [None]:
# Train the reverse-kernel ULA model built with K=10.
run_trainer(ula_10_r)

In [None]:
# run_exp returns three values; unpacking into two names raised ValueError.
ula_5_l2, ula_5_true_ll, ula_5_grad = run_exp(ula_5)

In [None]:
# run_exp returns three values; unpacking into two names raised ValueError.
ula_10_l2, ula_10_true_ll, ula_10_grad = run_exp(ula_10)

In [None]:
# run_exp returns three values; unpacking into two names raised ValueError.
ais_5_l2, ais_5_true_ll, ais_5_grad = run_exp(ais_5)

In [None]:
# run_exp returns three values; unpacking into two names raised ValueError.
ais_10_l2, ais_10_true_ll, ais_10_grad = run_exp(ais_10)

In [None]:
# run_exp returns three values; unpacking into two names raised ValueError.
# The `ais_5_5_*` names are kept as-is although the model is `ais_5_3`
# (K=5, num_samples=3), because a later cell displays `ais_5_5_l2`.
ais_5_5_l2, ais_5_5_true_ll, ais_5_5_grad = run_exp(ais_5_3)

In [None]:
# Display the first value unpacked from run_exp(ais_5_3).
ais_5_5_l2

In [None]:
# NOTE(review): no visible cell defines a module-level `model`, so this call
# relies on leftover kernel state — confirm which model was meant to be trained.
run_trainer(model)

In [None]:
# Display the stored run_exp results for ula_5.
ula_5_l2, ula_5_grad

In [None]:
# Display the stored run_exp results for ula_10.
ula_10_l2, ula_10_grad

In [None]:
# Display the stored run_exp results for ais_5.
ais_5_l2, ais_5_grad

In [None]:
# Display the stored run_exp results for ais_10.
ais_10_l2, ais_10_grad