In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from utils import make_dataloaders
from models.vaes import Base, VAE, IWAE, AIS_VAE, ULA_VAE, Stacked_VAE, VAE_with_flows, repeat_data
from models.samplers import HMC, MALA, ULA, run_chain
import yaml
import numpy as np
from scipy.stats import norm
import copy
from tqdm.auto import tqdm

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

from inspect import signature

%matplotlib widget

# Fixed palette: integer index -> matplotlib colour name.
colors = dict(enumerate(['blue', 'red', 'green', 'yellow', 'black', 'orange']))

In [None]:
# Prefer the second GPU when the host has one (the original choice), fall back
# to the first GPU on single-GPU hosts where ordinal 1 does not exist, and to
# the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# MNIST loaders (binarize=True), 100 images per batch for both splits.
train_loader, val_loader = make_dataloaders(
    dataset='mnist',
    batch_size=100,
    val_batch_size=100,
    binarize=True,
)

In [4]:
# Keep the first validation batch as the single evaluation batch used by
# `run_exp`; stays None if the loader yields nothing.
batch = next(iter(val_loader), None)

In [5]:
def load_model(version):
    """Rebuild a trained model from a Lightning log directory.

    Reads ``lightning_logs/default/version_{version}/hparams.yaml``, loads one
    checkpoint from the sibling ``checkpoints/`` directory, and tries VAE,
    IWAE, ULA_VAE and AIS_VAE in turn until one can be constructed from the
    hyper-parameters and accepts the checkpoint's ``state_dict``.

    Args:
        version: suffix of the ``version_*`` run folder to load.

    Returns:
        The first candidate model that loads successfully, moved to ``device``.

    Raises:
        FileNotFoundError: if the checkpoint directory contains no files.
        RuntimeError: if none of the candidate classes matches the checkpoint.
    """
    run_dir = f'lightning_logs/default/version_{version}'
    with open(f'{run_dir}/hparams.yaml') as file:
        # NOTE(review): FullLoader is kept because the stored hparams contain
        # Python objects (the printed dict includes an activation class).
        # Only load hparams files from trusted runs.
        hparams = yaml.load(file, Loader=yaml.FullLoader)
    print(hparams)

    path = f'{run_dir}/checkpoints/'
    # os.listdir returns entries in arbitrary order; sort so the pick is reproducible.
    checkpoint_files = sorted(os.listdir(path))
    if not checkpoint_files:
        raise FileNotFoundError(f'no checkpoint file found in {path}')
    # map_location lets a checkpoint saved on another device load onto `device`.
    checkpoint = torch.load(f'{path}{checkpoint_files[0]}', map_location=device)

    failures = {}
    for current_model in [VAE, IWAE, ULA_VAE, AIS_VAE]:
        try:
            model = current_model(**hparams).to(device)
            model.load_state_dict(checkpoint['state_dict'])
        except Exception as err:
            # Expected for every class that does not match the checkpoint;
            # remember the reason so a total failure can be diagnosed.
            failures[current_model.__name__] = repr(err)
        else:
            print(f'loaded {model.name}')
            return model

    raise RuntimeError(
        f'version {version}: no candidate class could load the checkpoint: {failures}'
    )

In [6]:
# Lightning run whose checkpoint supplies the reference encoder/decoder.
version = 794
iwae = load_model(version)

{'act_func': <class 'torch.nn.modules.activation.GELU'>, 'dataset': 'mnist', 'hidden_dim': 100, 'name': 'IWAE', 'net_type': 'fc', 'num_samples': 50, 'shape': 28, 'sigma': 0.1, 'specific_likelihood': 'gaussian'}
loaded IWAE


In [7]:
def get_transitions_output(model, z, mu, logvar, x):
    """Run the model's transitions once and differentiate the resulting loss.

    Args:
        model: object exposing ``num_samples``, ``run_transitions``,
            ``loss_function`` and ``decoder_net.net[0].bias``.
        z: forwarded unchanged to ``model.run_transitions``.
        mu: forwarded unchanged to ``model.run_transitions``.
        logvar: forwarded unchanged to ``model.run_transitions``.
        x: data batch; expanded with ``repeat_data(x, model.num_samples)``
            before being passed on.

    Returns:
        ``(output, grad)``: ``output`` is whatever ``run_transitions``
        returned; ``grad`` is the first 50 entries of the gradient of the loss
        with respect to ``model.decoder_net.net[0].bias``.
    """
    x = repeat_data(x, model.num_samples)
    output = model.run_transitions(z=z, x=x, mu=mu, logvar=logvar)

    # Loss functions with more than one parameter also receive output[2] as
    # `sum_log_alphas`. Count real parameters through the Signature object:
    # splitting its string form on commas miscounts when an annotation or a
    # default value itself contains a comma.
    takes_log_alphas = len(signature(model.loss_function).parameters) > 1
    if takes_log_alphas:
        loss = model.loss_function(sum_log_alphas=output[2], sum_log_weights=output[1])
    else:
        loss = model.loss_function(sum_log_weights=output[1])

    grad = torch.autograd.grad(loss, model.decoder_net.net[0].bias)[0][:50]
    return output, grad

In [8]:
# Noise scale shared by the whole notebook: passed as `sigma` to every model
# built below and used as sigma**2 * I in the closed-form covariance.
# Same value as the `sigma` entry in the hparams printed by `load_model`.
sigma = 0.1

In [9]:
# Closed-form Gaussian built from the first decoder layer only: covariance
# C = W W^T + sigma^2 I and mean given by that layer's bias.
# NOTE(review): presumably the decoder is a single linear layer, so this is
# the exact marginal of the model -- confirm against FC_decoder_mnist.
first_decoder_layer = iwae.decoder_net.net[0]
model_W = first_decoder_layer.weight.data
model_mu = first_decoder_layer.bias.data[..., None]  # trailing axis added -> column vector
model_mu.requires_grad_(True)  # needed so the exact gradient w.r.t. the bias can be taken

C = model_W @ model_W.T + sigma ** 2 * torch.eye(784, device=device)
C_inv = torch.inverse(C)
logdetC = torch.logdet(C)

# Data-independent part of -2 * log N(x; mu, C).
first_term = 784 * np.log(2 * np.pi) + logdetC

def get_true_loglikelihood(x):
    """Exact Gaussian log-likelihood of each sample and its gradient w.r.t. the mean.

    Evaluates ``-0.5 * (first_term + (x - mu)^T C_inv (x - mu))`` per sample,
    using the module-level ``model_mu``, ``C_inv`` and ``first_term``.

    The quadratic form is computed directly for the whole batch. It equals
    ``trace(C_inv @ (x - mu)(x - mu)^T)`` but avoids building a 784x784 outer
    product and a full matrix product for every sample.

    Args:
        x: batch whose samples each hold 784 values (any trailing layout).

    Returns:
        ``(true_loglikelihood, grad_true)``: a 1-D tensor with one value per
        sample, and the first 50 rows of the gradient of the summed
        log-likelihood with respect to ``model_mu``.
    """
    # model_mu is a (784, 1) column; transpose it to broadcast over the batch.
    diff = x.reshape(x.shape[0], 784) - model_mu.T
    quad_form = ((diff @ C_inv) * diff).sum(dim=1)
    true_loglikelihood = -0.5 * (first_term + quad_form)
    grad_true = torch.autograd.grad(true_loglikelihood.sum(), model_mu)[0][:50]
    return true_loglikelihood, grad_true

In [10]:
class ULA_VAE_reverse(ULA_VAE):
    """ULA_VAE whose optimizer only receives the reverse-kernel parameters."""

    def configure_optimizers(self):
        """Return Adam (lr=1e-3) over ``self.reverse_kernels`` plus a MultiStepLR schedule.

        The schedule multiplies the learning rate by 0.25 at milestones 10, 20 and 30.
        """
        reverse_params = self.reverse_kernels.parameters()
        adam = torch.optim.Adam(reverse_params, lr=1e-3)
        lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
            adam,
            milestones=[10, 20, 30],
            gamma=0.25,
        )
        return [adam], [lr_schedule]

In [11]:
# All seven models share most constructor arguments and the same post-processing,
# so both are defined once here instead of being copy-pasted per model.

def _clone_from_iwae(model_cls, **kwargs):
    """Build ``model_cls(**kwargs)`` on ``device`` and give it copies of the IWAE networks.

    The decoder and encoder are deep copies of ``iwae``'s, the copied decoder
    parameters are marked as requiring grad, and ``use_stepsize_update`` is
    switched off on the new model.
    """
    model = model_cls(**kwargs).to(device)
    model.decoder_net = copy.deepcopy(iwae.decoder_net)
    for param in model.decoder_net.parameters():
        param.requires_grad_(True)
    model.encoder_net = copy.deepcopy(iwae.encoder_net)
    model.use_stepsize_update = False
    return model


# Arguments common to every model below.
_shared_kwargs = dict(
    shape=28, act_func=nn.LeakyReLU, hidden_dim=iwae.hidden_dim, net_type='fc', dataset='mnist',
    step_size=0.003, learnable_transitions=False, grad_skip_val=0., grad_clip_val=0.,
    use_cloned_decoder=False, variance_sensitive_step=False, annealing_scheme='linear',
    specific_likelihood='gaussian', sigma=sigma,
)
# Arguments common to the ULA_VAE variants.
_ula_kwargs = dict(
    _shared_kwargs, num_samples=1, use_transforms=False, return_pre_alphas=True,
    use_score_matching=False, ula_skip_threshold=0.1, acceptance_rate_target=0.9,
)
# Arguments common to the AIS_VAE variants.
_ais_kwargs = dict(
    _shared_kwargs, use_barker=False, use_alpha_annealing=True, acceptance_rate_target=0.8,
)

# Construction order is kept as in the original cell so that random
# initialisation consumes the RNG in the same sequence.

# ----- ULA_VAE, no reverse kernel ----- #
ula_5 = _clone_from_iwae(ULA_VAE, K=5, **_ula_kwargs)
ula_10 = _clone_from_iwae(ULA_VAE, K=10, **_ula_kwargs)

# ----- ULA_VAE with learned reverse kernels ----- #
# NOTE(review): these use K=1 and K=2 although the variable names and the plot
# labels ('L-MCVAE-5-r', 'L-MCVAE-10-r') suggest K=5 and K=10. Values are kept
# as found; confirm which is intended.
ula_5_r = _clone_from_iwae(ULA_VAE_reverse, K=1, use_reverse_kernel=True, **_ula_kwargs)
ula_10_r = _clone_from_iwae(ULA_VAE_reverse, K=2, use_reverse_kernel=True, **_ula_kwargs)

# ----- AIS_VAE, single sample ----- #
ais_5 = _clone_from_iwae(AIS_VAE, num_samples=1, K=5, **_ais_kwargs)
ais_10 = _clone_from_iwae(AIS_VAE, num_samples=1, K=10, **_ais_kwargs)

# ----- AIS_VAE, multisample ----- #
# NOTE(review): the name says "_3" but num_samples is 10; value kept as found.
ais_5_3 = _clone_from_iwae(AIS_VAE, num_samples=10, K=5, **_ais_kwargs)

In [12]:
def run_exp(model, n = 200):
    """Repeat the model's transitions ``n`` times on the fixed validation batch.

    The latents come from a single ``iwae.enc_rep`` call and are reused for
    every repetition, so the variation between repetitions comes from the
    transitions themselves.

    Args:
        model: model to evaluate; passed to ``get_transitions_output`` and
            read for ``num_samples``.
        n: number of repetitions.

    Returns:
        A 4-tuple of
        - numpy array holding ``output[1]`` of every repetition, concatenated
          along axis 1 (one column per repetition),
        - mean of the exact log-likelihood over the batch,
        - numpy array stacking the model gradient of every repetition,
        - numpy array with the exact gradient from ``get_true_loglikelihood``.
    """
    x, _ = batch
    x = x.to(device)
    z, mu, logvar = iwae.enc_rep(x, model.num_samples)  # latents are fixed across repetitions

    true_loglikelihood_, grad_true = get_true_loglikelihood(x)
    true_loglikelihood_mean = np.mean(true_loglikelihood_.cpu().detach().numpy())

    # Collect one column per repetition and concatenate once at the end,
    # instead of re-allocating a growing tensor on every iteration.
    log_weight_columns = []
    model_g = []
    for _step in tqdm(range(n)):
        output, grad_model = get_transitions_output(model, z, mu, logvar, x)
        # detach() drops the autograd graph so it is not kept alive across repetitions
        log_weight_columns.append(output[1].detach()[..., None])
        model_g.append(grad_model.cpu().detach().numpy())

    if log_weight_columns:
        model_w = torch.cat(log_weight_columns, dim=1)
    else:
        # n == 0: same empty result the original accumulator started from
        model_w = torch.tensor([], device=device, dtype=torch.float32)

    return model_w.cpu().detach().numpy(), true_loglikelihood_mean, np.array(model_g), grad_true.cpu().detach().numpy()

In [13]:
def run_trainer(model):
    """Fit ``model`` for 5 epochs on the module-level loaders, logging to TensorBoard.

    Logs are written under 'lightning_logs/'; the Trainer is configured with
    ``gpus=1`` and automatic optimization.
    """
    trainer_options = dict(
        logger=pl_loggers.TensorBoardLogger('lightning_logs/'),
        fast_dev_run=False,
        max_epochs=5,
        automatic_optimization=True,
        gpus=1,
    )
    pl.Trainer(**trainer_options).fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )

In [14]:
# Weights of the first reverse kernel before any training, shown so they can be
# compared with the values printed after `run_trainer`.
first_reverse_kernel = ula_5_r.reverse_kernels[0]
first_reverse_kernel.net[0].weight

Parameter containing:
tensor([[-0.0670,  0.0324,  0.0665,  ..., -0.0644,  0.0245, -0.0544],
        [-0.0212, -0.0665,  0.0074,  ..., -0.0375, -0.0210,  0.0423],
        [-0.0325, -0.0565,  0.0307,  ...,  0.0049,  0.0565,  0.0462],
        ...,
        [-0.0529,  0.0007,  0.0208,  ..., -0.0108,  0.0630, -0.0531],
        [-0.0366,  0.0108,  0.0113,  ...,  0.0618,  0.0569, -0.0484],
        [ 0.0210, -0.0170,  0.0544,  ..., -0.0077, -0.0486,  0.0012]],
       device='cuda:1', requires_grad=True)

In [None]:
# Train only the reverse kernels of ula_5_r (see ULA_VAE_reverse.configure_optimizers).
# NOTE(review): the recorded output of this cell is an interactive ipdb session
# caused by a `pdb.set_trace()` left in models/samplers.py (make_transition);
# remove that breakpoint before re-running the notebook unattended.
run_trainer(model=ula_5_r)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name            | Type             | Params
-----------------------------------------------------
0 | encoder_net     | FC_encoder_mnist | 157 K 
1 | decoder_net     | FC_decoder_mnist | 79.2 K
2 | transitions_nll | ModuleList       | 8     
3 | transitions     | ModuleList       | 1     
4 | reverse_kernels | ModuleList       | 80.4 K
-----------------------------------------------------
316 K     Trainable params
9         Non-trainable params
316 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  l


[1;32m    290 [0m            [0mmu[0m[0;34m,[0m [0mlogvar[0m [0;34m=[0m [0mreverse_kernel[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m[[0m[0mz_upd[0m[0;34m,[0m [0mmu_amortize[0m[0;34m][0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    291 [0m            proposal_density_numerator = torch.distributions.Normal(loc=mu, scale=torch.exp(0.5 * logvar)).log_prob(
[1;32m    292 [0m                z).sum(1)
[1;32m    293 [0m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[1;32m    294 [0m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m    296 [0m[0;34m

ipdb>  proposal_density_numerator


tensor([-95.3776, -95.8059, -94.4750, -94.1509, -94.9008, -95.1708, -93.9506,
        -95.0646, -93.0227, -94.1119, -95.3042, -94.0263, -96.5608, -94.9266,
        -93.3799, -94.7214, -93.8542, -94.9650, -93.8211, -93.8972, -95.5769,
        -96.5783, -94.1705, -93.8395, -94.9955, -96.5774, -93.3102, -97.7786,
        -97.4422, -94.1206, -94.7868, -94.9281, -94.6953, -94.4966, -95.6322,
        -94.2076, -96.3949, -96.3103, -94.7338, -95.2991, -93.6644, -94.9886,
        -93.5534, -94.1277, -93.5960, -95.2754, -94.1880, -94.9850, -93.7809,
        -95.4933, -93.5445, -97.5387, -95.2633, -94.2864, -94.2159, -96.4697,
        -96.9878, -95.2658, -96.5847, -93.5586, -95.9307, -93.9657, -95.1893,
        -96.8474, -94.4391, -93.8481, -95.2231, -94.0024, -96.0819, -96.4588,
        -94.7752, -93.9633, -93.5512, -95.9609, -94.8333, -95.8542, -95.3843,
        -94.1602, -94.9772, -95.4344, -94.1814, -95.8854, -96.0177, -95.2733,
        -94.5953, -93.6507, -94.4008, -95.3690, -95.6521, -94.82

ipdb>  n


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(297)[0;36mmake_transition[0;34m()[0m
[0;32m    295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m--> 297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    298 [0;31m[0;34m[0m[0m
[0m[0;32m    299 [0;31m        [0;31m###[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  proposal_density_denominator


tensor([-146.4310, -140.6102, -139.3007, -141.9236, -143.5369, -135.6804,
        -141.8717, -143.9990, -134.0378, -141.8576, -147.4485, -157.5022,
        -147.4998, -142.9543, -155.0458, -143.4251, -154.0109, -139.0259,
        -138.4986, -141.6771, -144.8130, -139.0892, -134.6164, -157.0891,
        -149.0808, -163.8890, -135.5802, -133.1441, -132.5892, -146.1041,
        -143.8495, -140.1778, -148.4436, -140.0159, -146.3909, -141.1794,
        -139.1979, -135.1289, -148.4689, -134.5874, -138.1894, -132.1818,
        -138.9670, -151.2648, -136.1422, -154.1900, -140.7306, -141.6888,
        -133.0919, -135.2489, -128.4404, -135.6926, -151.4668, -134.9214,
        -138.6302, -138.7498, -142.9134, -139.3253, -136.8326, -139.7886,
        -137.5486, -145.8272, -144.9510, -130.8595, -153.5489, -132.1464,
        -138.6730, -142.0182, -145.7953, -131.9983, -134.0290, -137.4797,
        -150.0326, -140.0414, -156.3582, -151.1036, -147.0923, -142.0720,
        -153.4742, -136.0888, -139.944

ipdb>  c


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c




HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  l


[1;32m    290 [0m            [0mmu[0m[0;34m,[0m [0mlogvar[0m [0;34m=[0m [0mreverse_kernel[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m[[0m[0mz_upd[0m[0;34m,[0m [0mmu_amortize[0m[0;34m][0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    291 [0m            proposal_density_numerator = torch.distributions.Normal(loc=mu, scale=torch.exp(0.5 * logvar)).log_prob(
[1;32m    292 [0m                z).sum(1)
[1;32m    293 [0m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[1;32m    294 [0m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m    296 [0m[0;34m

ipdb>  proposal_density_numerator


tensor([-94.5317, -94.7253, -96.6883, -96.0809, -94.2829, -95.5246, -93.6617,
        -93.8099, -94.4432, -95.2868, -95.8652, -94.8354, -94.4960, -95.6355,
        -95.0625, -94.9936, -95.6617, -94.3483, -93.5368, -96.9681, -93.8587,
        -94.1969, -94.2508, -94.4562, -95.0666, -96.0399, -97.1531, -94.8723,
        -95.0929, -95.8054, -94.9713, -94.9028, -94.4580, -96.7896, -95.6754,
        -95.0082, -96.0695, -96.5562, -95.8982, -93.6395, -96.5245, -95.7232,
        -95.7653, -96.4272, -96.8164, -95.0334, -93.9237, -97.4513, -94.5016,
        -94.2400, -95.3942, -95.1177, -93.3016, -95.6268, -96.0296, -95.5117,
        -96.5693, -95.0053, -95.2332, -95.5879, -95.4145, -94.6781, -93.8831,
        -95.5221, -96.9571, -95.0927, -96.3691, -94.5027, -95.5555, -94.9506,
        -95.7418, -94.0392, -96.3006, -94.8357, -93.9037, -93.9412, -94.5750,
        -96.8072, -94.0914, -95.4753, -94.3898, -94.4171, -95.3787, -95.6731,
        -95.1285, -96.3996, -95.1335, -94.8360, -98.9753, -94.36

ipdb>  n


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(297)[0;36mmake_transition[0;34m()[0m
[0;32m    295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m--> 297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    298 [0;31m[0;34m[0m[0m
[0m[0;32m    299 [0;31m        [0;31m###[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  proposal_density_denominator


tensor([-132.6566, -138.1509, -136.4125, -147.7957, -141.0547, -143.9251,
        -150.3295, -138.8692, -139.5088, -157.0158, -142.9257, -134.1563,
        -144.2855, -144.2211, -135.5975, -141.2521, -153.8958, -131.0676,
        -138.4034, -127.4715, -154.0866, -147.9740, -149.1210, -140.4524,
        -137.2724, -143.9175, -135.7776, -137.3575, -144.6543, -141.9239,
        -139.6571, -145.7996, -151.4924, -137.8004, -148.6528, -140.2309,
        -145.1800, -132.9691, -149.0033, -144.9300, -140.0458, -147.3461,
        -154.6968, -151.4992, -145.9287, -142.9902, -154.4283, -140.9539,
        -134.7433, -144.0301, -133.8429, -150.0531, -141.6931, -143.7299,
        -147.4802, -142.3478, -160.2601, -150.2877, -136.1246, -142.4634,
        -146.7168, -133.0498, -136.8847, -145.8396, -153.1057, -132.4257,
        -144.2570, -148.9386, -132.9288, -148.3356, -150.7859, -147.7441,
        -141.5187, -148.0009, -137.7659, -142.4707, -153.6358, -132.9777,
        -149.7246, -141.9248, -137.133

ipdb>  c


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(295)[0;36mmake_transition[0;34m()[0m
[0;32m    293 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    294 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m    297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  l


[1;32m    290 [0m            [0mmu[0m[0;34m,[0m [0mlogvar[0m [0;34m=[0m [0mreverse_kernel[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m[[0m[0mz_upd[0m[0;34m,[0m [0mmu_amortize[0m[0;34m][0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    291 [0m            proposal_density_numerator = torch.distributions.Normal(loc=mu, scale=torch.exp(0.5 * logvar)).log_prob(
[1;32m    292 [0m                z).sum(1)
[1;32m    293 [0m            [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[1;32m    294 [0m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;32m--> 295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m    296 [0m[0;34m

ipdb>  proposal_density_numerator


tensor([-90.1474, -89.8752, -89.8598, -90.0304, -89.6248, -89.6380, -89.8418,
        -89.8358, -89.1677, -91.0337, -90.4426, -90.3195, -90.8715, -89.2637,
        -90.0070, -90.3618, -90.3828, -90.1516, -89.6788, -90.2895, -91.1151,
        -89.7854, -90.7627, -90.9568, -89.9686, -89.6017, -90.5007, -89.2577,
        -90.1455, -90.5593, -90.0968, -88.8921, -89.8700, -90.6064, -90.6168,
        -89.4250, -90.8374, -90.6886, -90.4815, -89.9760, -90.4556, -89.7790,
        -91.6048, -90.2273, -89.6544, -90.3576, -90.5672, -90.6879, -90.6878,
        -90.0638, -90.2425, -90.3293, -89.9753, -90.7523, -90.1524, -89.1491,
        -89.7670, -90.4344, -90.4239, -89.0687, -90.8441, -91.2478, -90.3467,
        -90.4700, -90.3797, -90.7043, -90.0969, -90.1593, -88.6058, -90.5237,
        -88.9268, -89.5062, -90.5376, -89.9160, -91.1390, -90.3023, -90.0065,
        -90.4652, -89.7879, -91.2493, -91.1123, -90.3531, -89.1107, -89.8702,
        -90.3288, -90.4971, -89.5395, -89.9662, -89.9535, -89.04

ipdb>  proposal_density_denominator


*** NameError: name 'proposal_density_denominator' is not defined


ipdb>  n


> [0;32m/home/nkotelevskii/github/metflow/models/samplers.py[0m(297)[0;36mmake_transition[0;34m()[0m
[0;32m    295 [0;31m        [0mproposal_density_denominator[0m [0;34m=[0m [0mstd_normal[0m[0;34m.[0m[0mlog_prob[0m[0;34m([0m[0meps[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    296 [0;31m[0;34m[0m[0m
[0m[0;32m--> 297 [0;31m        [0mz_new[0m [0;34m=[0m [0mz_upd[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    298 [0;31m[0;34m[0m[0m
[0m[0;32m    299 [0;31m        [0;31m###[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  proposal_density_denominator


tensor([-137.1668, -149.8611, -151.2194, -140.2371, -151.1601, -142.2115,
        -129.7785, -133.8708, -139.5077, -143.1925, -147.3088, -153.4055,
        -134.9669, -137.3990, -145.6248, -144.7758, -144.2202, -138.5620,
        -131.7334, -141.7127, -146.2631, -143.2713, -137.4313, -143.3028,
        -166.4186, -156.6298, -129.3212, -141.9250, -154.9762, -135.1765,
        -142.4826, -134.7117, -142.1162, -130.6470, -145.0455, -130.4093,
        -134.4709, -134.4338, -142.5820, -149.8792, -134.5950, -140.5591,
        -133.3982, -149.4653, -141.7820, -143.8646, -159.0844, -138.1960,
        -133.7022, -144.3874, -141.9352, -137.8752, -135.7475, -131.7495,
        -133.1917, -143.2497, -129.4335, -139.3657, -131.5886, -140.1616,
        -144.5827, -139.6280, -143.9297, -153.1750, -141.5427, -152.7619,
        -147.4227, -147.0909, -142.4054, -144.4585, -142.6401, -136.9293,
        -145.1759, -135.9094, -144.1416, -133.5063, -139.6616, -142.2040,
        -155.2632, -143.6912, -136.375

In [16]:
# Reverse-kernel parameters after training.
list(ula_5_r.reverse_kernels.parameters())

[Parameter containing:
 tensor([[ 0.0053, -0.0491,  0.0285,  ...,  0.0390,  0.0034,  0.0400],
         [-0.0881,  0.0659,  0.0072,  ..., -0.0059,  0.0004, -0.1810],
         [ 0.0620,  0.0009, -0.0101,  ..., -0.1145,  0.2492,  0.1068],
         ...,
         [ 0.0322, -0.0928,  0.0628,  ...,  0.0707,  0.1588,  0.0102],
         [ 0.0459, -0.0771,  0.0069,  ...,  0.1253, -0.0568, -0.0779],
         [-0.0007, -0.0386, -0.0108,  ..., -0.2071,  0.0102,  0.1555]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.2657, -0.2468,  0.1401, -0.2648, -0.2313, -0.2938,  0.1642,  0.2663,
         -0.3005,  0.2815, -0.2469,  0.2552,  0.2399, -0.2619,  0.2035,  0.2566,
          0.2356, -0.2278, -0.1630, -0.2653,  0.2357, -0.2679, -0.2331, -0.2440,
          0.1170, -0.2291,  0.2686, -0.2909,  0.2376,  0.1965, -0.2018, -0.1411,
         -0.2472,  0.1608,  0.2480,  0.1655, -0.2181,  0.1919,  0.2273, -0.2739,
          0.2601, -0.2629,  0.1592,  0.1845,  0.1889,  0.2549,  0.2065,  0.1630

In [17]:
# The parameters printed after training no longer carry a `device=` tag
# (presumably the Trainer left the module on the CPU), so move it back.
ula_5_r = ula_5_r.to(device=device)

In [18]:
# Train only the reverse kernels of ula_10_r.
run_trainer(model=ula_10_r)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name            | Type             | Params
-----------------------------------------------------
0 | encoder_net     | FC_encoder_mnist | 157 K 
1 | decoder_net     | FC_decoder_mnist | 79.2 K
2 | transitions_nll | ModuleList       | 8     
3 | transitions     | ModuleList       | 2     
4 | reverse_kernels | ModuleList       | 160 K 
-----------------------------------------------------
396 K     Trainable params
10        Non-trainable params
396 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [19]:
# Move the trained model back onto the notebook's device before evaluation.
ula_10_r = ula_10_r.to(device=device)

In [20]:
# 200 repetitions for ula_5.
output_ula_5 = run_exp(model=ula_5, n=200)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [21]:
# 200 repetitions for ula_10.
output_ula_10 = run_exp(model=ula_10, n=200)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [22]:
# 200 repetitions for ula_5_r (reverse kernels trained above).
output_ula_5_r = run_exp(model=ula_5_r, n=200)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [23]:
# 200 repetitions for ula_10_r (reverse kernels trained above).
output_ula_10_r = run_exp(model=ula_10_r, n=200)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [24]:
# 200 repetitions for ais_5.
output_ais_5 = run_exp(model=ais_5, n=200)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [25]:
# 200 repetitions for ais_10.
output_ais_10 = run_exp(model=ais_10, n=200)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [26]:
# 200 repetitions for the multisample AIS model ais_5_3.
output_ais_5_3 = run_exp(model=ais_5_3, n=200)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [27]:
def plot_beautiful_boxplots(list_of_things_to_plot, list_of_titles, title_file=None, grad=False):
    """Draw one box per entry of ``list_of_things_to_plot`` and optionally save the figure.

    Args:
        list_of_things_to_plot: passed as ``data`` to ``seaborn.boxplot``.
        list_of_titles: x tick labels, one per box.
        title_file: when not None, the figure is written to ``title_file + '.pdf'``.
        grad: when True, the y tick marks are removed.
    """
    plt.close()  # drop the previously active figure before opening a new one
    fig, ax = plt.subplots(figsize=(5, 3), dpi=200)

    sns.boxplot(data=list_of_things_to_plot, showfliers=False, ax=ax)

    if grad:
        ax.set_yticks([])

    ax.set_xticks(range(len(list_of_titles)))
    ax.set_xticklabels(list_of_titles, rotation=7, fontsize=6)
    fig.tight_layout()

    # Save before showing: on non-interactive backends plt.show() can release
    # the figure, so a later savefig would write an empty canvas.
    if title_file is not None:
        fig.savefig(title_file + '.pdf')
    plt.show()

In [28]:
# Results in plotting order; must stay aligned with `list_of_titles`.
# (Drop the two `_r` entries to plot only the models without reverse kernels.)
outputs = [
    output_ula_5,
    output_ula_10,
    output_ula_5_r,
    output_ula_10_r,
    output_ais_5,
    output_ais_10,
    output_ais_5_3,
]

In [29]:
# Tick labels, one per entry of `outputs` and in the same order.
# (Drop the two '-r' labels when plotting only the models without reverse kernels.)
list_of_titles = [
    'L-MCVAE, K=5',
    'L-MCVAE, K=10',
    'L-MCVAE-5-r',
    'L-MCVAE-10-r',
    'A-MCVAE, K=5',
    'A-MCVAE, K=10',
    'A-MCVAE, K=5, reduced variance',
]

In [30]:
# Per-repetition estimate (mean over axis 0) minus the exact mean log-likelihood.
estimate_errors = [log_w.mean(0) - true_ll_mean for log_w, true_ll_mean, _, _ in outputs]
plot_beautiful_boxplots(estimate_errors, list_of_titles, title_file='est')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Spread of one gradient coordinate across repetitions, for every model.
coord = 18
grad_samples = [model_grads[:, coord] for _, _, model_grads, _ in outputs]
plot_beautiful_boxplots(grad_samples, list_of_titles, title_file='grad_est', grad=True)

In [None]:
# Shape of a single stored gradient (first model, repetition 10).
first_model_grads = outputs[0][2]
first_model_grads[10].shape

In [None]:
# 19

In [None]:
# Same `outputs` / `list_of_titles` as defined in the earlier cells, repeated here.
outputs = [
    output_ula_5,
    output_ula_10,
    output_ula_5_r,
    output_ula_10_r,
    output_ais_5,
    output_ais_10,
    output_ais_5_3,
]
list_of_titles = [
    'L-MCVAE, K=5',
    'L-MCVAE, K=10',
    'L-MCVAE-5-r',
    'L-MCVAE-10-r',
    'A-MCVAE, K=5',
    'A-MCVAE, K=10',
    'A-MCVAE, K=5, reduced variance',
]

In [None]:
# Per-repetition estimate (mean over axis 0) minus the exact mean log-likelihood.
estimate_errors = [log_w.mean(0) - true_ll_mean for log_w, true_ll_mean, _, _ in outputs]
plot_beautiful_boxplots(estimate_errors, list_of_titles, title_file='est')