In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from utils import make_dataloaders
from models.vaes import Base, VAE, IWAE, AIS_VAE, ULA_VAE, Stacked_VAE, VAE_with_flows, repeat_data
from models.samplers import HMC, MALA, ULA, run_chain
import yaml
import numpy as np
from scipy.stats import norm
import copy
from tqdm.auto import tqdm

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

from inspect import signature

%matplotlib widget

# Palette lookup: integer index -> matplotlib colour name (indices 0..5).
colors = dict(enumerate(['blue', 'red', 'green', 'yellow', 'black', 'orange']))

In [2]:
# Device every model and batch in this notebook is moved to. Currently pinned to CPU;
# the disabled alternative was:
#   torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = "cpu"

In [3]:
# Binarized MNIST loaders with batches of 100 for both training and validation.
train_loader, val_loader = make_dataloaders(dataset='mnist', batch_size=100, val_batch_size=100, binarize=True)



In [4]:
# Take the first validation batch (None if the loader yields nothing); it is reused
# as the single fixed batch for every experiment below.
batch = next(iter(val_loader), None)

In [5]:
def load_model(version):
    """Rebuild a model from a Lightning run directory and load its checkpoint weights.

    Reads ``lightning_logs/default/version_{version}/hparams.yaml``, then tries each
    candidate class (VAE, IWAE, ULA_VAE, AIS_VAE) in order with those hyperparameters
    and returns the first one that both constructs and accepts the checkpoint's
    ``state_dict``.

    Args:
        version: run number used in the ``version_{version}`` directory name.

    Returns:
        The loaded model, moved to the notebook-level ``device``.

    Raises:
        FileNotFoundError: if the checkpoints directory contains no files.
        RuntimeError: if none of the candidate classes could be built and loaded.
    """
    run_dir = f'lightning_logs/default/version_{version}'
    with open(f'{run_dir}/hparams.yaml') as file:
        # NOTE(review): FullLoader resolves Python names stored in the YAML (e.g. the
        # activation class); only load hparams files from trusted runs.
        hparams = yaml.load(file, Loader=yaml.FullLoader)
    print(hparams)

    path = f'{run_dir}/checkpoints/'
    # Sort so the chosen checkpoint is deterministic when several files exist.
    file_names = sorted(os.listdir(path))
    if not file_names:
        raise FileNotFoundError(f'no checkpoint files found in {path}')
    file_name = file_names[0]
    # map_location lets a checkpoint saved on a GPU be restored on `device` (e.g. CPU).
    checkpoint = torch.load(f'{path}{file_name}', map_location=device)

    failures = []
    for current_model in [VAE, IWAE, ULA_VAE, AIS_VAE]:
        try:
            model = current_model(**hparams).to(device)
            model.load_state_dict(checkpoint['state_dict'])
        except Exception as err:
            # A mismatching class is expected here; remember why so that a total
            # failure can be diagnosed instead of silently yielding None.
            failures.append(f'{current_model.__name__}: {type(err).__name__}: {err}')
        else:
            print(f'loaded {model.name}')
            return model

    raise RuntimeError(
        f'could not load version {version} with any candidate class:\n' + '\n'.join(failures)
    )

In [8]:
# Lightning run whose checkpoint provides the pretrained encoder/decoder used below.
# Despite the variable name, load_model returns whichever candidate class loads
# successfully (the recorded output below reports 'loaded ULA_VAE').
# NOTE(review): the trailing 794 is presumably a previously used run number - confirm.
version = 904 #794
iwae = load_model(version=version)

{'K': 10, 'acceptance_rate_target': 0.9, 'act_func': <class 'torch.nn.modules.activation.GELU'>, 'annealing_scheme': 'all_learnable', 'dataset': 'celeba', 'grad_clip_val': 50.0, 'grad_skip_val': 0.0, 'hidden_dim': 128, 'learnable_transitions': False, 'name': 'ULA_VAE', 'net_type': 'conv', 'num_samples': 1, 'shape': 64, 'sigma': 1.0, 'specific_likelihood': None, 'step_size': 0.01, 'ula_skip_threshold': 0.1, 'use_cloned_decoder': False, 'use_reverse_kernel': False, 'use_score_matching': False, 'variance_sensitive_step': True}
loaded ULA_VAE


In [7]:
def get_transitions_output(model, z, mu, logvar, x):
    """Run the model's transitions once and differentiate its loss w.r.t. the decoder bias.

    Args:
        model: model exposing ``num_samples``, ``run_transitions``, ``loss_function``
            and ``decoder_net.net[0].bias``.
        z, mu, logvar: latent sample and encoder statistics passed to ``run_transitions``.
        x: data batch; repeated ``model.num_samples`` times via ``repeat_data``.

    Returns:
        ``(output, grad)`` where ``output`` is whatever ``run_transitions`` returned
        (index 1 is used as the summed log-weights, index 2 as the summed log-alphas)
        and ``grad`` is the first 50 entries of d(loss)/d(bias of the first decoder layer).
    """
    x = repeat_data(x, model.num_samples)
    output = model.run_transitions(z=z, x=x, mu=mu, logvar=logvar)

    # Decide from the actual parameters (not from counting commas in the printed
    # signature) whether this loss also takes the summed log-alphas.
    loss_params = signature(model.loss_function).parameters
    takes_alphas = 'sum_log_alphas' in loss_params or any(
        param.kind is param.VAR_KEYWORD for param in loss_params.values()
    )
    loss_kwargs = {'sum_log_weights': output[1]}
    if takes_alphas:
        loss_kwargs['sum_log_alphas'] = output[2]
    loss = model.loss_function(**loss_kwargs)

    grad = torch.autograd.grad(loss, model.decoder_net.net[0].bias)[0][:50]
    return output, grad

In [8]:
# Scale of the Gaussian likelihood: passed as `sigma` to every comparison model and
# used as sigma**2 on the diagonal of the analytic covariance C.
sigma = 0.1

In [9]:
# Analytic Gaussian N(model_mu, C) built from the first decoder layer:
# C = W W^T + sigma^2 I over 784 pixels.
# NOTE(review): this presumably treats the decoder as a single linear layer with a
# standard-normal prior on the latents - confirm against the decoder definition.
model_W = iwae.decoder_net.net[0].weight.data
# Bias as a column vector; gradient tracking is switched on so that
# get_true_loglikelihood can differentiate with respect to it.
model_mu = iwae.decoder_net.net[0].bias.data.unsqueeze(-1).requires_grad_(True)

C = torch.matmul(model_W, model_W.T) + torch.eye(784, device=device) * (sigma**2)
C_inv = C.inverse()
logdetC = C.logdet()

# Data-independent part of -2 * log N(x; mu, C): d * log(2*pi) + log|C|.
first_term = np.log(2 * np.pi) * 784 + logdetC

def get_true_loglikelihood(x):
    """Exact Gaussian log-density of every sample under N(model_mu, C), plus its gradient.

    Uses the notebook-level ``model_mu`` (column vector), ``C_inv`` and ``first_term``.

    Args:
        x: batch of samples; each sample is flattened to one row and must have as many
            elements as ``model_mu`` has rows.

    Returns:
        ``(true_loglikelihood, grad_true)``: a 1-D tensor with one log-likelihood per
        sample, and the first 50 rows of the gradient of their sum w.r.t. ``model_mu``.
    """
    # One centred row per sample; model_mu is a column, so transpose it to broadcast.
    centered = x.reshape(x.shape[0], -1) - model_mu.T
    # (x - mu)^T C^{-1} (x - mu) for every row at once. This equals trace(C_inv @ S)
    # with S the outer product of the centred sample, without building a d x d matrix
    # per sample.
    quad_form = ((centered @ C_inv) * centered).sum(dim=1)
    true_loglikelihood = -0.5 * (first_term + quad_form)
    grad_true = torch.autograd.grad(true_loglikelihood.sum(), model_mu)[0][:50]
    return true_loglikelihood, grad_true

In [10]:
class ULA_VAE_reverse(ULA_VAE):
    """ULA_VAE variant whose optimizer only updates the reverse kernels."""

    def configure_optimizers(self):
        """Return Adam (lr 1e-3) over ``reverse_kernels`` with a step-decay schedule.

        The learning rate is multiplied by 0.25 at epochs 10, 20 and 30.
        """
        adam = torch.optim.Adam(params=self.reverse_kernels.parameters(), lr=1e-3)
        lr_decay = torch.optim.lr_scheduler.MultiStepLR(
            adam,
            milestones=[10, 20, 30],
            gamma=0.25,
        )
        return [adam], [lr_decay]

In [11]:
# ---------------------------------------------------------------------------
# Comparison models. All of them are fully-connected MNIST models that reuse the
# pretrained encoder/decoder from `iwae`; they differ only in the sampler class,
# the number of transitions K, the number of samples and the reverse kernel.
# ---------------------------------------------------------------------------

# Settings shared by every model below.
_SHARED_KWARGS = dict(
    shape=28, act_func=nn.LeakyReLU, hidden_dim=iwae.hidden_dim, net_type='fc', dataset='mnist',
    step_size=0.003, learnable_transitions=False, grad_skip_val=0., grad_clip_val=0.,
    use_cloned_decoder=False, variance_sensitive_step=False, annealing_scheme='linear',
    specific_likelihood='gaussian', sigma=sigma,
)

# ULA-specific settings (all ULA models use a single sample).
_ULA_KWARGS = dict(
    _SHARED_KWARGS, num_samples=1, use_transforms=False, return_pre_alphas=True,
    use_score_matching=False, ula_skip_threshold=0.1, acceptance_rate_target=0.9,
)

# AIS-specific settings (num_samples is chosen per model).
_AIS_KWARGS = dict(
    _SHARED_KWARGS, use_barker=False, use_alpha_annealing=True, acceptance_rate_target=0.8,
)


def _build_with_pretrained_nets(model_cls, **kwargs):
    """Build `model_cls(**kwargs)` on `device` and give it copies of iwae's networks.

    The decoder copy has gradients enabled on all parameters (needed for the decoder
    gradient estimates), the encoder is copied as-is, and step-size updates are disabled.
    """
    model = model_cls(**kwargs).to(device)
    model.decoder_net = copy.deepcopy(iwae.decoder_net)
    for param in model.decoder_net.parameters():
        param.requires_grad_(True)
    model.encoder_net = copy.deepcopy(iwae.encoder_net)
    model.use_stepsize_update = False
    return model


# ----- ULA_VAE, no reverse kernel ----- #
ula_5 = _build_with_pretrained_nets(ULA_VAE, K=5, **_ULA_KWARGS)
ula_10 = _build_with_pretrained_nets(ULA_VAE, K=10, **_ULA_KWARGS)

# ----- ULA_VAE with learned reverse kernels ----- #
# NOTE(review): these are named *_5_r / *_10_r but use K=1 and K=2 - confirm whether
# that is intentional or a leftover from a quick run.
ula_5_r = _build_with_pretrained_nets(ULA_VAE_reverse, K=1, use_reverse_kernel=True, **_ULA_KWARGS)
ula_10_r = _build_with_pretrained_nets(ULA_VAE_reverse, K=2, use_reverse_kernel=True, **_ULA_KWARGS)

# ----- AIS_VAE, single sample ----- #
ais_5 = _build_with_pretrained_nets(AIS_VAE, num_samples=1, K=5, **_AIS_KWARGS)
ais_10 = _build_with_pretrained_nets(AIS_VAE, num_samples=1, K=10, **_AIS_KWARGS)

# ----- AIS_VAE, multisample ----- #
# NOTE(review): the name suggests 3 samples but num_samples is 10 - confirm.
ais_5_3 = _build_with_pretrained_nets(AIS_VAE, num_samples=10, K=5, **_AIS_KWARGS)

In [12]:
def run_exp(model, n = 200):
    """Repeat the model's log-weight / gradient estimate `n` times on the fixed batch.

    The latents are encoded once with ``iwae`` and reused for every repetition, so the
    spread across repetitions comes from the model's transitions only.

    Args:
        model: model to evaluate; must work with ``get_transitions_output``.
        n: number of repetitions.

    Returns:
        A 4-tuple of
        - numpy array of log-weights, one column per repetition (concatenated on dim 1),
        - mean exact log-likelihood of the batch (float),
        - numpy array stacking the per-repetition decoder-bias gradient estimates,
        - numpy array with the exact gradient from ``get_true_loglikelihood``.
    """
    x, _ = batch
    x = x.to(device)
    z, mu, logvar = iwae.enc_rep(x, model.num_samples) # <- latents are fixed

    true_loglikelihood_, grad_true = get_true_loglikelihood(x)
    true_loglikelihood_mean = np.mean(true_loglikelihood_.cpu().detach().numpy())

    # Collect per-repetition results in lists and concatenate once at the end, instead
    # of re-copying an ever-growing tensor on every iteration.
    log_weight_columns = []
    model_g = []
    for _step in tqdm(range(n)):
        output, grad_model = get_transitions_output(model, z, mu, logvar, x)
        # output[1] holds the summed log-weights; detach so no graph is kept alive.
        log_weight_columns.append(output[1].detach()[..., None])
        model_g.append(grad_model.cpu().detach().numpy())

    if log_weight_columns:
        model_w = torch.cat(log_weight_columns, dim=1)
    else:
        # n == 0: keep returning an empty float32 array.
        model_w = torch.tensor([], device=device, dtype=torch.float32)

    return model_w.cpu().detach().numpy(), true_loglikelihood_mean, np.array(model_g), grad_true.cpu().detach().numpy()

In [13]:
def run_trainer(model, max_epochs=10):
    """Fit `model` with PyTorch Lightning on the notebook's train/validation loaders.

    Logs to TensorBoard under ``lightning_logs/``. Uses one GPU when CUDA is available
    and falls back to CPU otherwise, so the cell also runs on a CPU-only machine.

    Args:
        model: LightningModule to train (modified in place by the trainer).
        max_epochs: number of training epochs.
    """
    tb_logger = pl_loggers.TensorBoardLogger('lightning_logs/')
    gpus = 1 if torch.cuda.is_available() else 0
    trainer = pl.Trainer(logger=tb_logger, fast_dev_run=False, max_epochs=max_epochs, automatic_optimization=True, gpus=gpus)
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

In [14]:
# Display the weights of the first layer of the first reverse kernel
# (presumably as a reference point before the training cell below - confirm).
ula_5_r.reverse_kernels[0].net[0].weight

Parameter containing:
tensor([[-0.0032, -0.0102, -0.0502,  ..., -0.0091, -0.0220,  0.0287],
        [ 0.0638, -0.0587,  0.0589,  ..., -0.0230, -0.0417, -0.0492],
        [-0.0159,  0.0059, -0.0290,  ...,  0.0004, -0.0329, -0.0417],
        ...,
        [ 0.0549,  0.0054, -0.0362,  ...,  0.0616,  0.0196, -0.0702],
        [ 0.0133,  0.0698, -0.0303,  ...,  0.0267,  0.0336, -0.0071],
        [-0.0444,  0.0682, -0.0008,  ..., -0.0559,  0.0489,  0.0070]],
       device='cuda:1', requires_grad=True)

In [15]:
# Train ula_5_r; its configure_optimizers only optimises the reverse kernels.
run_trainer(ula_5_r)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name            | Type             | Params
-----------------------------------------------------
0 | encoder_net     | FC_encoder_mnist | 157 K 
1 | decoder_net     | FC_decoder_mnist | 79.2 K
2 | transitions_nll | ModuleList       | 8     
3 | transitions     | ModuleList       | 1     
4 | reverse_kernels | ModuleList       | 80.4 K
-----------------------------------------------------
316 K     Trainable params
9         Non-trainable params
316 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [16]:
# Move the trained model back to `device`, where run_exp puts the batch
# (run_trainer trains with gpus=1, so the model presumably ends up on a GPU).
ula_5_r = ula_5_r.to(device)

In [17]:
# Train ula_10_r; its configure_optimizers only optimises the reverse kernels.
run_trainer(ula_10_r)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name            | Type             | Params
-----------------------------------------------------
0 | encoder_net     | FC_encoder_mnist | 157 K 
1 | decoder_net     | FC_decoder_mnist | 79.2 K
2 | transitions_nll | ModuleList       | 8     
3 | transitions     | ModuleList       | 2     
4 | reverse_kernels | ModuleList       | 160 K 
-----------------------------------------------------
396 K     Trainable params
10        Non-trainable params
396 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [18]:
# Move the trained model back to `device`, where run_exp puts the batch
# (run_trainer trains with gpus=1, so the model presumably ends up on a GPU).
ula_10_r = ula_10_r.to(device)

In [19]:
# 200 repeated estimates (run_exp default) for ULA with K=5, no reverse kernel.
output_ula_5 = run_exp(ula_5)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [20]:
# 200 repeated estimates (run_exp default) for ULA with K=10, no reverse kernel.
output_ula_10 = run_exp(ula_10)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [21]:
# 200 repeated estimates (run_exp default) for the reverse-kernel ULA model ula_5_r.
output_ula_5_r = run_exp(ula_5_r)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [22]:
# 200 repeated estimates (run_exp default) for the reverse-kernel ULA model ula_10_r.
output_ula_10_r = run_exp(ula_10_r)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [23]:
# 200 repeated estimates (run_exp default) for AIS with K=5, single sample.
output_ais_5 = run_exp(ais_5)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [24]:
# 200 repeated estimates (run_exp default) for AIS with K=10, single sample.
output_ais_10 = run_exp(ais_10)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [25]:
# 200 repeated estimates (run_exp default) for the multisample AIS model (K=5, num_samples=10).
output_ais_5_3 = run_exp(ais_5_3)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [26]:
def plot_beautiful_boxplots(list_of_things_to_plot, list_of_titles, title_file=None, grad=False):
    """Draw one box (outliers hidden) per entry and optionally save the figure as PDF.

    Args:
        list_of_things_to_plot: sequence of 1-D samples, one per box.
        list_of_titles: x-axis label for each box, in the same order.
        title_file: if given, the figure is written to ``title_file + '.pdf'``.
        grad: if True, the y-axis ticks are hidden.
    """
    # Drop the previous figure so repeated calls do not accumulate open figures.
    plt.close()
    plt.figure(figsize = (5, 3), dpi=200)

    sns.boxplot(data=list_of_things_to_plot, showfliers=False)

    if grad:
        plt.yticks([])

    plt.xticks(range(len(list_of_titles)), list_of_titles, rotation=7, fontsize=6)
    plt.tight_layout()
    # Save before showing: on non-interactive/inline backends the figure can be
    # discarded by show(), which would leave an empty PDF.
    if title_file is not None:
        plt.savefig(title_file +'.pdf')
    plt.show()

In [27]:
# Results of all seven models, in the same order as list_of_titles.
outputs = [output_ula_5, output_ula_10, output_ula_5_r, output_ula_10_r, output_ais_5, output_ais_10, output_ais_5_3]
# Variant without the reverse-kernel models:
# outputs = [output_ula_5, output_ula_10, output_ais_5, output_ais_10, output_ais_5_3]

In [28]:
# Box labels, one per entry of `outputs` and in the same order.
# Variant without the reverse-kernel models:
# list_of_titles = ['L-MCVAE, K=5', 'L-MCVAE, K=10', 'A-MCVAE, K=5', 'A-MCVAE, K=10', 'A-MCVAE, K=5, reduced variance',]# 'L-MCVAE-5-r', 'L-MCVAE-10-r']
list_of_titles = ['L-MCVAE, K=5', 'L-MCVAE, K=10', 'L-MCVAE-5-r', 'L-MCVAE-10-r', 'A-MCVAE, K=5', 'A-MCVAE, K=10', 'A-MCVAE, K=5, reduced variance',]#

In [30]:
# Estimation error per repetition: log-weights averaged over axis 0 minus the mean
# exact log-likelihood (l[1]); the figure is saved as est.pdf.
# NOTE(review): axis 0 is presumably the batch axis - confirm the log-weight shape.
plot_beautiful_boxplots([l[0].mean(0) - l[1] for l in outputs], list_of_titles, 'est')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Spread of the gradient estimate for one of the 50 stored decoder-bias coordinates;
# y ticks are hidden (grad=True) and the figure is saved as grad_est.pdf.
coord = 18
plot_beautiful_boxplots([l[2][:, coord] for l in outputs], list_of_titles, 'grad_est', grad=True)

In [None]:
# Shape of a single gradient estimate (first model, repetition index 10).
outputs[0][2][10].shape

In [None]:
# 19

In [None]:
# Re-declares the same `outputs` / `list_of_titles` pair as the earlier cells
# (same seven models, same order).
outputs = [output_ula_5, output_ula_10, output_ula_5_r, output_ula_10_r, output_ais_5, output_ais_10, output_ais_5_3]
list_of_titles = ['L-MCVAE, K=5', 'L-MCVAE, K=10', 'L-MCVAE-5-r', 'L-MCVAE-10-r', 'A-MCVAE, K=5', 'A-MCVAE, K=10', 'A-MCVAE, K=5, reduced variance']

In [None]:
# Same estimation-error plot as above (log-weights averaged over axis 0 minus the
# mean exact log-likelihood); overwrites est.pdf.
plot_beautiful_boxplots([l[0].mean(0) - l[1] for l in outputs], list_of_titles, 'est')