In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from models.vaes import Base, VAE, IWAE, AIS_VAE, ULA_VAE, Stacked_VAE, VAE_with_flows
from models.samplers import HMC, MALA, ULA, run_chain
from utils import make_dataloaders
import yaml
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

%matplotlib widget

In [2]:
pl.__version__

'1.1.3'

In [3]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [4]:
use_true = False
problem_name = 'circle' # 'parabola', 'circle'
def generate_dataset(N, eps0=None, d=2, sigma=1.):
    if problem_name == 'parabola':
        z = np.random.randn(N, d)
        x = eps0 * z**2 + np.random.randn(N, d) * sigma
        return x
    elif problem_name == 'circle':
        z = np.random.randn(N, d)
        x = 2 * np.pi * (np.linalg.norm(z, axis=1, keepdims=True) + 2.) + np.random.randn(N, 1) * sigma
        return x
    else:
        raise NotImplementedError

In [5]:
def additional_loss(alpha, beta):
    return 0. #torch.pow(alpha + beta - 2 * np.pi - 2., 2.) + torch.pow(alpha - beta - 2 * np.pi + 2., 2.)

In [6]:
class Toy(Base):   
    def joint_logdensity(self, use_true_decoder=None):
        def density(z, x):
            if (use_true_decoder is not None) and use_true_decoder:
                x_reconst = self(z)
            elif hasattr(self, 'use_cloned_decoder') and self.use_cloned_decoder:
                x_reconst = self.cloned_decoder(z)
            else:
                x_reconst = self(z)
            log_Pr = torch.distributions.Normal(loc=torch.tensor(0., device=x.device, dtype=torch.float32),
                                    scale=torch.tensor(1., device=x.device, dtype=torch.float32)).log_prob(
                    z).sum(-1)
            
            if not use_true: # if we want to use true generative process
                return torch.distributions.Normal(loc=self.decoder_net.alpha * (torch.sqrt(torch.sum(torch.pow(z, 2), dim=1, keepdim=True)) + self.decoder_net.beta),
                                                      scale=self.decoder_net.sigma).log_prob(x).sum(-1) + log_Pr + 0. * self.decoder_net.aux
            else:
                if problem_name == 'parabola':
                    return torch.distributions.Normal(loc=self.decoder_net.eps * torch.abs(z).pow(2.) + 0. * self.decoder_net.aux,
                                                      scale=self.decoder_net.sigma).log_prob(x).sum(-1) + log_Pr
                elif problem_name == 'circle':
                    return torch.distributions.Normal(loc=2 * np.pi * (torch.sqrt(torch.sum(torch.pow(z, 2), dim=1, keepdim=True)) + 2) + 0. * self.decoder_net.aux,
                                                          scale=self.decoder_net.sigma).log_prob(x).sum(-1) + log_Pr
                else:
                    raise NotImplementedError

        return density
    

class VAE_Toy(VAE, Toy):
    def loss_function(self, recon_x, x, mu, logvar):
        batch_size = mu.shape[0] // self.num_samples
        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
                                                 reduction='none').view(
            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
        KLD = -0.5 * torch.mean((1 + logvar - mu.pow(2) - logvar.exp()).view(
            (self.num_samples, -1, self.hidden_dim)).mean(0).sum(-1))
        loss = MSE + KLD
        if not use_true:
            loss = loss + additional_loss(self.decoder_net.alpha, self.decoder_net.beta)
        return loss
    
class IWAE_Toy(IWAE, Toy):
    def loss_function(self, recon_x, x, mu, logvar, z):
        batch_size = mu.shape[0] // self.num_samples
        self.hidden_dim = mu.shape[1]
        log_Q = torch.distributions.Normal(loc=mu,
                                           scale=torch.exp(0.5 * logvar)).log_prob(z).view(
            (self.num_samples, -1, self.hidden_dim)).sum(-1)

        log_Pr = torch.sum((-0.5 * torch.abs(z).pow(2.)).view((self.num_samples, -1, self.hidden_dim)), -1)
        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
                                                 reduction='none').view(
            (self.num_samples, batch_size, -1)).sum(-1)

        log_weight = log_Pr - MSE - log_Q
        log_weight = log_weight - torch.max(log_weight, 0)[0]  # for stability
        weight = torch.exp(log_weight)
        weight = weight / torch.sum(weight, 0)
        weight = weight.detach()
        loss = torch.mean(torch.sum(weight * (-log_Pr + MSE + log_Q), 0))
        if not use_true:
            loss = loss + additional_loss(self.decoder_net.alpha, self.decoder_net.beta)
        return loss
    
class VAE_with_flows_Toy(VAE_with_flows, Toy):
    def loss_function(self, recon_x, x, mu, logvar, z, z_transformed, log_jac):
        batch_size = mu.shape[0] // self.num_samples
        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
                                                 reduction='none').view(
            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
        log_Q = torch.mean(torch.distributions.Normal(loc=mu, scale=torch.exp(0.5 * logvar)).log_prob(z).view(
            (self.num_samples, batch_size, -1)).sum(-1) - log_jac.view((self.num_samples, -1)), dim=0).mean()
        log_Pr = (-0.5 * z_transformed ** 2).view(
            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
        KLD = log_Q - log_Pr
        loss = MSE + KLD
        if not use_true:
            loss = loss + additional_loss(self.decoder_net.alpha, self.decoder_net.beta)
        return loss
    
#     def configure_optimizers(self):
#         inf_params = list(self.encoder_net.parameters()) + list(self.transitions.parameters())
#         ## If we are using cloned decoder to approximate the true one, we add its params to inference optimizer
#         if self.use_cloned_decoder:
#             inf_params += list(self.cloned_decoder.parameters())

#         optimizer = torch.optim.Adam(all_params, lr=1e-3)
#         optimizer = optim.Adam([
#                 {'params': self.decoder_net.parameters(), 'lr': 1e-2},
#                 {'params': inf_params}
#             ], lr=1e-3, weight_decay=0.0001)

        return optimizer
    
class ULA_VAE_Toy(ULA_VAE, Toy):
    def loss_function(self, sum_log_weights):
        loss = super(ULA_VAE_Toy, self).loss_function(sum_log_weights)
        if not use_true:
            loss = loss + additional_loss(self.decoder_net.alpha, self.decoder_net.beta)
        return loss

class AIS_VAE_Toy(AIS_VAE, Toy):
    def loss_function(self, sum_log_alphas, sum_log_weights):
        loss = super(AIS_VAE_Toy, self).loss_function(sum_log_alphas, sum_log_weights)
        if not use_true:
            loss = loss + additional_loss(self.decoder_net.alpha, self.decoder_net.beta)
        return loss

In [7]:
class ToyDataset(Dataset):
    def __init__(self, data):
        super(ToyDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        sample = torch.tensor(self.data[item], dtype=torch.float32, device=device)
        return sample, -1.

In [8]:
# class ToyEncoder(nn.Module):
#     def __init__(self, d):
#         super().__init__()
#         if problem_name == 'parabola':
#             self.net = nn.Sequential(
#                 nn.Linear(d, d),
#                 nn.LeakyReLU(),
#                 nn.Linear(d, 2*d),
#             )
#         elif problem_name == 'circle':
#             self.net = nn.Sequential(
#                 nn.Linear(1, 2*d),
#                 nn.LeakyReLU(),
#                 nn.Linear(2*d, 2*d),
#             )
#         else:
#             raise NotImplementedError

#     def forward(self, x):
#         return self.net(x)

class ToyEncoder(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.aux = nn.Parameter(torch.tensor(0., dtype=torch.float32))
#         self.var_mean = torch.zeros(d, dtype=torch.float32, device=device)
#         self.log_sigma_z = torch.zeros(d, dtype=torch.float32, device=device)
        
    def forward(self, x):
        return torch.zeros(x.shape[0], 2*d, device=x.device, dtype=torch.float32) + self.aux
    
    
class ToyEncoder_VB(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.var_mean = nn.Parameter(torch.zeros(d, dtype=torch.float32))
        self.log_sigma_z = nn.Parameter(torch.zeros(d, dtype=torch.float32))
        
    def forward(self, x):
        return self.var_mean, self.log_sigma_z
    
    
class ToyDecoder(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.aux = nn.Parameter(torch.tensor(0., dtype=torch.float32))
        if problem_name == 'parabola':
            self.net = nn.Sequential(
                nn.Linear(d, d),
                nn.LeakyReLU(),
                nn.Linear(d, d),
                nn.LeakyReLU(),
                nn.Linear(d, d),
                nn.LeakyReLU(),
                nn.Linear(d, d),
                nn.LeakyReLU(),
                nn.Linear(d, d),
            )
        elif problem_name == 'circle':
            self.log_alpha = nn.Parameter(torch.tensor(0, device=device, dtype=torch.float32))
            self.log_beta = nn.Parameter(torch.tensor(0., device=device, dtype=torch.float32))
            self.sigma = sigma
        else:
            raise NotImplementedError
    @property
    def alpha(self,):
        return torch.exp(self.log_alpha)
    
    @property
    def beta(self,):
        return torch.exp(self.log_beta)

    def forward(self, x):
        if problem_name == 'circle':
            return self.alpha * (torch.sqrt(torch.sum(torch.pow(x, 2), dim=1, keepdim=True)) + self.beta) + torch.randn_like(x[:, :1]) * self.sigma + 0. * self.aux
        else:
            return self.net(x) 
    
class TrueDecoder(nn.Module):
    def __init__(self, d, sigma, eps=None):
        super().__init__()
        self.register_buffer('eps', torch.tensor(eps, dtype=torch.float32))
        self.aux = nn.Parameter(torch.tensor(0., dtype=torch.float32))
        self.sigma = sigma

    def forward(self, z):
        if problem_name == 'parabola':
            return self.eps * torch.abs(z).pow(2) + torch.randn_like(z) * self.sigma + 0. * self.aux
        elif problem_name == 'circle':
            return 2 * np.pi * (torch.sqrt(torch.sum(torch.pow(z, 2), dim=1, keepdim=True)) + 2.) + torch.randn_like(z[:, :1]) * self.sigma + 0. * self.aux
        else:
            raise NotImplementedError

In [9]:
N = 10000
if problem_name == 'parabola':
    d = 2
    sigma = 1.
elif problem_name == 'circle':
    d = 20
    sigma = 1.
eps = 2 + np.random.randn(1, d)
X_train = generate_dataset(N=N, eps0=eps, d=d, sigma=sigma)
X_val = generate_dataset(N=N // 100, eps0=eps, d=d, sigma=sigma)

In [10]:
plt.close()
plt.figure()
plt.title('True data')
plt.hist(x=X_train, bins=100)
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
plt.close()
plt.figure()
plt.title('True data')
if d == 2 and problem_name == 'parabola':
    plt.scatter(x=X_train[:, 0], y=X_train[:, 1], alpha=0.25)
else:
    plt.plot(X_train)
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
train_dataset = ToyDataset(data=X_train)
val_dataset = ToyDataset(data=X_val)

In [13]:
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True,)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False)

In [14]:
def replace_enc_dec(model, is_stacked=False):
    if not is_stacked:
        model.encoder_net = ToyEncoder(d=d)
        if use_true:
            model.decoder_net = TrueDecoder(d=d, eps=eps, sigma=sigma)
        else:
            model.decoder_net = ToyDecoder(d=d)
        model = model.to(device)
    else:
        model.proxy_model.encoder_net = ToyEncoder(d=d)
        model.main_model.encoder_net = ToyEncoder(d=d)
        if use_true:
            model.proxy_model.decoder_net = TrueDecoder(d=d, eps=eps, sigma=sigma)
            model.main_model.decoder_net = TrueDecoder(d=d, eps=eps, sigma=sigma)
        else:
            model.proxy_model.decoder_net = ToyDecoder(d=d)
            model.main_model.decoder_net = ToyDecoder(d=d)
        model = model.to(device)
    return model

In [15]:
#----- VAE ------ #
vae = VAE_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=d,
            net_type='conv', dataset='toy')
vae = replace_enc_dec(vae)
if not use_true:
    vae.encoder_net = ToyEncoder_VB(d=d).to(device)

# ----- IWAE ------ #
iwae = IWAE_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=50, hidden_dim=d,
            net_type='conv', dataset='toy')
iwae = replace_enc_dec(iwae)
iwae.name = 'IWAE'
if not use_true:
    iwae.encoder_net = ToyEncoder_VB(d=d).to(device)

# ----- ULA_VAE ----- #
ula_vae = ULA_VAE_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=d,
            net_type='conv', dataset='toy',
            step_size=0.01, K=7, use_transforms=False, learnable_transitions=False, return_pre_alphas=True, use_score_matching=False,
                      ula_skip_threshold=0.1, grad_skip_val=0., grad_clip_val=0., use_cloned_decoder=False, variance_sensitive_step=True,
                     acceptance_rate_target=0.9, annealing_scheme='all_learnable')
ula_vae = replace_enc_dec(ula_vae)
ula_vae.name = 'ULA_VAE'

# ----- AIS_VAE ----- #
ais_vae = AIS_VAE_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=d,
            net_type='conv', dataset='toy',
            step_size=0.01, K=2, use_barker=False, learnable_transitions=False, use_alpha_annealing=True, grad_skip_val=0.,
                      grad_clip_val=0., use_cloned_decoder=False, variance_sensitive_step=True,
                     acceptance_rate_target=0.8, annealing_scheme='all_learnable')
ais_vae = replace_enc_dec(ais_vae)
ais_vae.name = 'AIS_VAE'


# ----- VAE_with_Flows ----- #
flows_vae = VAE_with_flows_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=d,
            net_type='conv', dataset='toy',
            flow_type='RNVP', num_flows=5, need_permute=True)
flows_vae = replace_enc_dec(flows_vae)
flows_vae.name = 'VAE_with_Flows'

# # ----- Stacked_VAE ----- #
# stacked_vae = Stacked_VAE(shape=28, act_func=nn.LeakyReLU,
#             num_samples=50, hidden_dim=d,
#             net_type='conv', dataset='toy',
#             step_size=0.01, K=7, use_barker=False, name='Stacked_VAE')
# stacked_vae.proxy_model = IWAE_Toy(shape=28, act_func=nn.LeakyReLU,
#                             num_samples=50, hidden_dim=d,
#                             net_type='conv', dataset='toy')
# stacked_vae.main_model = AIS_VAE_Toy(shape=28, act_func=nn.LeakyReLU,
#             num_samples=1, hidden_dim=d,
#             net_type='conv', dataset='toy',
#             step_size=0.01, K=7, use_barker=False)

# stacked_vae = replace_enc_dec(stacked_vae, is_stacked=True)
# stacked_vae.main_model.epsilon_target = 0.75

In [16]:
def run_trainer(model):
    tb_logger = pl_loggers.TensorBoardLogger('lightning_logs/')
    trainer = pl.Trainer(logger=tb_logger, fast_dev_run=False, max_epochs=31, automatic_optimization=True, )
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

In [17]:
# run_trainer(stacked_vae)

In [18]:
run_trainer(flows_vae)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name            | Type       | Params
-----------------------------------------------
0 | encoder_net     | ToyEncoder | 1     
1 | decoder_net     | ToyDecoder | 3     
2 | transitions_nll | ModuleList | 8     
3 | Flow            | NormFlow   | 6.3 K 
-----------------------------------------------
6.3 K     Trainable params
8         Non-trainable params
6.3 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [19]:
run_trainer(ais_vae)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name            | Type       | Params
-----------------------------------------------
0 | encoder_net     | ToyEncoder | 1     
1 | decoder_net     | ToyDecoder | 3     
2 | transitions_nll | ModuleList | 8     
3 | transitions     | ModuleList | 2     
-----------------------------------------------
6         Trainable params
10        Non-trainable params
16        Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






In [None]:
run_trainer(ula_vae)

In [None]:
run_trainer(iwae)

In [20]:
run_trainer(vae)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name            | Type          | Params
--------------------------------------------------
0 | encoder_net     | ToyEncoder_VB | 40    
1 | decoder_net     | ToyDecoder    | 3     
2 | transitions_nll | ModuleList    | 8     
--------------------------------------------------
43        Trainable params
8         Non-trainable params
51        Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

AttributeError: 'tuple' object has no attribute 'shape'

In [24]:
def compute_discrepancy(model):
    with torch.no_grad():
        return (model.decoder_net.alpha - 2 * np.pi)**2 + (model.decoder_net.beta - 2)**2

In [25]:
compute_discrepancy(vae)

ModuleAttributeError: 'TrueDecoder' object has no attribute 'alpha'

In [25]:
compute_discrepancy(ais_vae)

tensor(21.2995, device='cuda:0')

In [19]:
vae.decoder_net.alpha

tensor(0.9801, device='cuda:0', grad_fn=<ExpBackward>)

In [20]:
vae.decoder_net.beta

tensor(0.9614, device='cuda:0', grad_fn=<ExpBackward>)

In [21]:
ais_vae.decoder_net.alpha

tensor(2.1026, device='cuda:0', grad_fn=<ExpBackward>)

In [22]:
ais_vae.decoder_net.beta

tensor(3.9550, device='cuda:0', grad_fn=<ExpBackward>)

In [23]:
2 * np.pi / vae.decoder_net.alpha

tensor(1.8519, device='cuda:0', grad_fn=<MulBackward0>)

In [24]:
2 / vae.decoder_net.beta

tensor(0.4960, device='cuda:0', grad_fn=<MulBackward0>)

In [22]:
def plot_generated_data(model, stacked=False):
    z = torch.randn(10000, d, dtype=torch.float32, device=device)
        
    with torch.no_grad():
        if not stacked:
            generated_samples = model(z).cpu().numpy()
        else:
            generated_samples = model.main_model(z).cpu().numpy()

    plt.close()
    plt.figure()
    plt.title('Generated data')
    if (d == 2) and (problem_name == 'parabola'):
        plt.hist(x=generated_samples, bins=100)
    else:
        plt.plot(generated_samples)
    plt.show();
    
    plt.figure()
    plt.title('Generated data')
    if (d == 2) and (problem_name == 'parabola'):
        plt.scatter(x=generated_samples[:, 0], y=generated_samples[:, 1], alpha=0.25)
    else:
        plt.plot(generated_samples)
    plt.show();

In [23]:
# plot_generated_data(stacked_vae, True)

In [24]:
plot_generated_data(flows_vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [25]:
plot_generated_data(ais_vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [45]:
plot_generated_data(ula_vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [46]:
plot_generated_data(iwae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [47]:
plot_generated_data(vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [26]:
n_leapfrogs = 5
step_size = 0.01
n_samples = 10000
hmc = HMC(n_leapfrogs=n_leapfrogs, step_size=step_size, partial_ref=False, use_barker=False).to(device)

In [27]:
# idx = 2
# X_item = torch.tensor(X_val[idx][None], device=device, dtype=torch.float32).repeat(n_samples, 1)
if problem_name == 'parabola':
    X_item = torch.tensor([[5., 5.]], device=device, dtype=torch.float32)
elif problem_name == 'circle':
    X_item = torch.tensor([[2 * np.pi * (2. + 2.)]], device=device, dtype=torch.float32)
else:
    raise NotImplementedError

In [28]:
## Receive true posterior sample:
if problem_name == 'parabola':
    target_density = lambda z, x: torch.distributions.Normal(loc=torch.tensor(eps, device=device) * (z ** 2), scale=sigma).log_prob(X_item.repeat(n_samples, 1)).sum(
        -1) + torch.distributions.Normal(loc=torch.tensor(0., device=device, dtype=torch.float32),
                                         scale=torch.tensor(1., device=device, dtype=torch.float32)).log_prob(z).sum(-1)
elif problem_name == 'circle':
    target_density = lambda z, x: torch.distributions.Normal(loc=2 * np.pi * (torch.sqrt(torch.sum(torch.pow(z, 2), dim=1, keepdim=True)) + 2.), scale=sigma).log_prob(X_item.repeat(n_samples, 1)).sum(
        -1) + torch.distributions.Normal(loc=torch.tensor(0., device=device, dtype=torch.float32),
                                         scale=torch.tensor(1., device=device, dtype=torch.float32)).log_prob(z).sum(-1)
else:
    raise NotImplementedError
true_posterior_samples = run_chain(kernel=hmc, z_init=torch.randn(n_samples, 2, device=device), target=target_density, return_trace=False, n_steps=1, burnin=500).cpu().numpy()

In [29]:
def receive_model_samples(model, X_item, stacked=False):
    with torch.no_grad():
        if not stacked:
            x = X_item
            model_samples, mu, logvar = model.enc_rep(x=x, n_samples=n_samples)
            if model.name in ['ULA_VAE', 'AIS_VAE']:
                model_samples_init = model_samples
                model_samples = model.run_transitions(z=model_samples, x=x.repeat(n_samples, 1), mu=mu, logvar=logvar)[0]
            elif model.name in ['VAE_with_Flows']:
                model_samples = model.Flow(model_samples)[0]
        else:
            x = X_item
            model_samples, mu, logvar = model.main_model.enc_rep(x=x, n_samples=n_samples)
            model_samples = model.main_model.run_transitions(z=model_samples, x=x.repeat(n_samples, 1), mu=mu, logvar=logvar)[0]
    return model_samples.cpu().numpy()

def plot_contours(model, X_item, stacked=False, graph=None):
    x = X_item
    with torch.no_grad():
        if not stacked:
            model_samples, mu, logvar = model.enc_rep(x=x, n_samples=n_samples)
        else:
            model_samples, mu, logvar = model.main_model.enc_rep(x=x, n_samples=n_samples)
        logprobs = torch.distributions.Normal(loc=mu, scale=torch.exp(0.5 * logvar)).log_prob(model_samples).sum(-1)
#     plt.contour(model_samples[:, 0].cpu()[..., None], model_samples[:, 1].cpu()[..., None], logprobs.cpu()[..., None])
    if graph is None:
        plt.scatter(model_samples[:, 0].cpu(), model_samples[:, 1].cpu(),)
    else:
        graph.x = model_samples[:, 0].cpu()
        graph.y = model_samples[:, 1].cpu()
        graph.plot_joint(plt.scatter, marker='x', c='g', s=50, alpha=0.5)
    

def receive_posterior_samples(model, n_samples, stacked=False):
    with torch.no_grad():
        if not stacked:
            model_target_density = lambda z, x: model.joint_logdensity()(z=z, x=X_item.repeat(n_samples, 1))
        else:
            model_target_density = lambda z, x: model.main_model.joint_logdensity()(z=z, x=X_item.repeat(n_samples, 1))
        model_posterior_samples = run_chain(kernel=hmc, z_init=torch.randn(n_samples, 2, device=device), target=model_target_density, return_trace=False, n_steps=1, burnin=500).cpu().numpy()
    return model_posterior_samples

In [73]:
# vae_sample = receive_model_samples(vae, X_item)
# vae_posterior_sample = receive_posterior_samples(vae, n_samples)

# iwae_sample = receive_model_samples(iwae, X_item)
# iwae_posterior_sample = receive_posterior_samples(iwae, n_samples)

# ula_vae_sample = receive_model_samples(ula_vae, X_item)
# ula_vae_posterior_sample = receive_posterior_samples(ula_vae, n_samples)

ais_vae_sample = receive_model_samples(ais_vae, X_item)
ais_vae_posterior_sample = receive_posterior_samples(ais_vae, n_samples)

# flows_vae_sample = receive_model_samples(flows_vae, X_item)
# flows_vae_posterior_sample = receive_posterior_samples(flows_vae, n_samples)

# stacked_vae_sample = receive_model_samples(stacked_vae, X_item, True)
# stacked_vae_posterior_sample = receive_posterior_samples(stacked_vae, True)

In [74]:
# plot
# ========================================
graph = sns.jointplot(x=true_posterior_samples[:, 0], y=true_posterior_samples[:, 1], color='r', alpha=0.5);
# graph = sns.jointplot(x=vae_posterior_sample[:, 0], y=vae_posterior_sample[:, 1], color='r', alpha=0.5);
# graph = sns.jointplot(x=iwae_posterior_sample[:, 0], y=iwae_posterior_sample[:, 1], color='r', alpha=0.5);
# graph = sns.jointplot(x=ula_vae_posterior_sample[:, 0], y=ula_vae_posterior_sample[:, 1], color='r', alpha=0.5);
# graph = sns.jointplot(x=ais_vae_posterior_sample[:, 0], y=ais_vae_posterior_sample[:, 1], color='r', alpha=0.5);
# graph = sns.jointplot(x=flows_vae_posterior_sample[:, 0], y=flows_vae_posterior_sample[:, 1], color='r', alpha=0.5);
# graph = sns.jointplot(x=stacked_vae_posterior_sample[:, 0], y=stacked_vae_posterior_sample[:, 1], color='r', alpha=0.5);


# graph.x = vae_sample[:, 0]
# graph.y = vae_sample[:, 1]
# graph.plot_joint(plt.scatter, marker='x', c='g', s=50, alpha=0.5);
# plot_contours(vae, X_item, graph=graph)

# graph.x = iwae_sample[:, 0]
# graph.y = iwae_sample[:, 1]
# graph.plot_joint(plt.scatter, marker='x', c='b', s=50, alpha=0.5);
# plot_contours(iwae, X_item, graph=graph)

# graph.x = ula_vae_sample[:, 0]
# graph.y = ula_vae_sample[:, 1]
# graph.plot_joint(plt.scatter, marker='x', c='y', s=50, alpha=0.1);
# plot_contours(ula_vae, X_item, graph=graph)

graph.x = ais_vae_sample[:, 0]
graph.y = ais_vae_sample[:, 1]
graph.plot_joint(plt.scatter, marker='x', c='purple', s=50, alpha=0.5);
# plot_contours(ais_vae, X_item, graph=graph)

# graph.x = flows_vae_sample[:, 0]
# graph.y = flows_vae_sample[:, 1]
# graph.plot_joint(plt.scatter, marker='x', c='orange', s=50, alpha=0.5);
# plot_contours(flows_vae, X_item, graph=graph)



# graph.x = stacked_vae_sample[:, 0]
# graph.y = stacked_vae_sample[:, 1]
# graph.plot_joint(plt.scatter, marker='x', c='black', s=50, alpha=0.5);

# plt.xlim(-5., 5.)
# plt.ylim(-5., 5.)
# plt.axis('equal');


# plot_contours(model=stacked_vae, X_item=X_item, stacked=True)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [37]:
def plot_heatmap(samples, x_limits, y_limits, title=None, gamma=0.75, name='default.png'):
    import matplotlib.colors as mcolors
    from scipy.stats import kde
    plt.close()
    plt.figure(figsize=(5, 5), dpi=300)
    
        # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
    x = samples[:, 0]
    y = samples[:, 1]
    nbins = 300
    k = kde.gaussian_kde([x,y], bw_method=0.1)
    xi, yi = np.mgrid[x_limits[0]:x_limits[1]:nbins*1j, y_limits[0]:y_limits[1]:nbins*1j]
    zi = k(np.vstack([xi.flatten(), yi.flatten()]))

    # Make the plot
    plt.pcolormesh(xi, yi, zi.reshape(xi.shape))
    plt.axis('off')
    plt.xlim((x_limits[0], x_limits[1]))
    plt.ylim((y_limits[0], y_limits[1]))
#     plt.axis('equal')
    plt.tight_layout()
    plt.savefig(name, format='png')
    plt.show()

#     # Change color palette
#     plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=plt.cm.Greens_r)
#     plt.show()


#     plt.close()
#     plt.figure(figsize=(5, 5), dpi=300)
#     if title is not None:
#         plt.title(title)
#     plt.hist2d(x=samples[:, 0], y=samples[:, 1], bins=nbins, density=True, norm=mcolors.PowerNorm(gamma), range=[x_limits, y_limits], cmap=plt.cm.Reds)
#     plt.axis('off')
#     plt.xlim((x_limits[0], x_limits[1]))
#     plt.ylim((y_limits[0], y_limits[1]))
# #     plt.axis('equal')
#     plt.tight_layout()
#     plt.savefig(name, format='png')
#     plt.show();

In [38]:
plot_heatmap(true_posterior_samples, (-4, 4), (-4, 4), title=None, gamma=0.5, name='true_posterior.png')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.pcolormesh(xi, yi, zi.reshape(xi.shape))


In [75]:
plot_heatmap(ais_vae_sample, (-4, 4), (-4, 4), title=None, gamma=0.5, name='ais_vae.png')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.pcolormesh(xi, yi, zi.reshape(xi.shape))


In [None]:
# def latent_kl(model, data, num_samples=100):
#     with torch.no_grad():
#         for batch in data:
#             x, _ = batch
#             model_samples, mu, logvar = model.enc_rep(x, num_samples) ## here we have samples directly from encoder
#             x = repeat_data(x, num_samples)
            
#             if model.name in ['ULA_VAE', 'AIS_VAE']:
#                 model_samples = model.run_transitions(z=model_samples, x=x, mu=mu, logvar=logvar)[0]
#             elif model.name in ['VAE_with_Flows']:
#                 model_samples = model.Flow(model_samples)[0]
#         return model_samples.cpu().numpy()

In [5]:
from sklearn.decomposition import PCA

In [34]:
def receive_posterior_samples_pics(model, n_samples, x):
    with torch.no_grad():
        model_target_density = lambda z, x: model.joint_logdensity()(z=z, x=x)
        model_posterior_samples = run_chain(kernel=hmc, z_init=torch.randn(n_samples, model.hidden_dim, device=device), x=x.repeat(n_samples, 1, 1, 1), target=model_target_density, return_trace=False, n_steps=1, burnin=500).cpu().numpy()
    return model_posterior_samples

def latent_projection(model, x, n_chains=1000):
    latent_samples = receive_posterior_samples_pics(model, n_chains, x)
    print(latent_samples.shape)
    pca = PCA(n_components=1)
    resulting_samples = pca.fit_transform(latent_samples)
    print(resulting_samples.shape)
    return resulting_samples

In [90]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import kde
 
# create data
x = np.random.normal(size=500)
y = x * 3 + np.random.normal(size=500)
 
# Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
nbins=300
k = kde.gaussian_kde([x,y])
xi, yi = np.mgrid[x.min():x.max():nbins*1j, y.min():y.max():nbins*1j]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
 
plt.close()
plt.figure()
# Make the plot
plt.pcolormesh(xi, yi, zi.reshape(xi.shape))
plt.show()
 
plt.close()
plt.figure()
# Change color palette
plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=plt.cm.Greens_r)
plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.pcolormesh(xi, yi, zi.reshape(xi.shape))


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=plt.cm.Greens_r)


In [95]:
xi.shape

(300, 300)

In [98]:
ks = [0, 1, 2, 3, 5, 10, 15, 20]
ll = np.array([[-52.25, -54.82, -55.12, -55.69, -56.04, -55.52, -54.71, -55.1],
      [-51.33, -55.4, -55.79, -56.23, -55.94, -58.38, -56.15, -56.17,],
     ])

In [102]:
plt.close()
plt.figure()
plt.plot(ks, ll.mean(0), '--o')
plt.fill_between(x=ks, y1=ll.mean(0) - ll.std(0), y2=ll.mean(0) + ll.std(0), alpha=0.5)
plt.grid()
plt.ylabel('LogLikelihood')
plt.xlabel('K')
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Test transitions

In [37]:
target_distr = torch.distributions.Normal(loc=torch.tensor([10., 10.], device=device, dtype=torch.float32),
                                                     scale=torch.tensor([1., 1.], device=device, dtype=torch.float32))
target_fake = lambda z, x: target_distr.log_prob(z).sum(1)

In [61]:
mala_kernel = HMC(n_leapfrogs=3, step_size=0.1)
samples = run_chain(kernel=mala_kernel, z_init=torch.zeros(1, 2, device=device, dtype=torch.float32), target=target_fake, n_steps=1000, return_trace=True).cpu()

In [62]:
true_samples = target_distr.sample((1000,)).cpu()

In [63]:
plt.close()
plt.figure()
plt.scatter(true_samples[:, 0], true_samples[:, 1], label='True')
plt.scatter(samples[:, 0], samples[:, 1], label='MALA')
plt.legend()
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Block with interpolation. Trained models are required.

In [52]:
# train_loader, val_loader = make_dataloaders(dataset='mnist', batch_size=100, val_batch_size=10, binarize=True)



In [83]:
version = 88
epoch = 100

In [84]:
with open(f'lightning_logs/default/version_{version}/hparams.yaml') as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to Python the dictionary format
    fruits_list = yaml.load(file, Loader=yaml.FullLoader)

    print(fruits_list)
    hparams = fruits_list

{'K': 10, 'acceptance_rate_target': 0.9, 'act_func': <class 'torch.nn.modules.activation.GELU'>, 'beta': None, 'dataset': 'mnist', 'grad_clip_val': 0.0, 'grad_skip_val': 0.0, 'hidden_dim': 64, 'learnable_transitions': False, 'name': 'ULA_VAE', 'net_type': 'conv', 'num_samples': 1, 'shape': 28, 'step_size': 0.01, 'ula_skip_threshold': 0.1, 'use_cloned_decoder': False, 'use_score_matching': False, 'variance_sensitive_step': True}


In [86]:
model = ULA_VAE(**hparams).to(device)
checkpoint = torch.load(f'lightning_logs/default/version_{version}/checkpoints/epoch=99-step=59999.ckpt')
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [76]:
# for batch in train_loader:
#     x, _ = batch
#     break

In [None]:
plt.close()
plt.figure()
obj1 = x[0][None].to(device)
if obj1.shape[1] == 1:
    plt.imshow(obj1[0][0].cpu())
else:
    plt.imshow(obj1[0].permute((1, 2, 0)).cpu())
plt.tight_layout()
plt.show();

In [None]:
model_proj = latent_projection(model, obj1)

In [89]:
plt.close()
plt.figure()
plt.hist(model_proj[:, 0], bins=100)
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [86]:
plt.close()
plt.figure()
obj2 = x[-1][None].to(device)
if obj2.shape[1] == 1:
    plt.imshow(obj2[0][0].cpu())
else:
    plt.imshow(obj2[0].permute((1, 2, 0)).cpu())
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Interpolation functions

In [30]:
def interpolate_annealing(model, obj1, obj2, kernel, T=10, n_steps=10):
    with torch.no_grad():
        all_z = torch.tensor([], dtype=torch.float32, device=obj1.device)
        z_init = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        for t in np.linspace(0., 1., T):
            target_density = lambda t: lambda z, x: (1 - t) * model.joint_density()(z=z, x=obj1) + t * model.joint_density()(z=z, x=obj2)
            current_z = run_chain(kernel=kernel, z_init=z_init, target=target_density(t=t), return_trace=False, n_steps=n_steps)
            all_z = torch.cat([all_z, z_init])
            z_init = current_z
    return all_z

def interpolate_mixture(model, obj1, obj2, kernel, T=10, n_steps=10):
    with torch.no_grad():
        all_z = torch.tensor([], dtype=torch.float32, device=obj1.device)
        z_init = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        for t in np.linspace(0.01, 0.99, T):
            target_density = lambda t: lambda z, x: torch.logsumexp(torch.cat([np.log(1 - t) + model.joint_density()(z=z, x=obj1)[None],
                                                                               np.log(t) + model.joint_density()(z=z, x=obj2)[None]]), dim=0)
            current_z = run_chain(kernel=kernel, z_init=z_init, target=target_density(t=t), return_trace=False, n_steps=n_steps)
            all_z = torch.cat([all_z, z_init])
            z_init = current_z
    return all_z

def interpolate_linear(model, obj1, obj2, T=10):
    with torch.no_grad():
        all_z = torch.tensor([], dtype=torch.float32, device=obj1.device)
        z_1 = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        z_2 = torch.mean(model.enc_rep(obj2)[0], 0)[None]
        for t in np.linspace(0., 1., T):
            current_z = (1 - t) * z_1 + t * z_2
            all_z = torch.cat([all_z, current_z])
    return all_z

def visualize(model, z, shape=(-1, 1, 28, 28)):
    with torch.no_grad():
        x_hat = torch.sigmoid(model(z)).view(shape).cpu()
        plt.close()
        plt.figure()
        if shape[1] == 1:
            grid = torchvision.utils.make_grid(x_hat, nrow=15).mean(0)
            plt.imshow(grid, 'gray')
        else:
            grid = torchvision.utils.make_grid(x_hat, nrow=15)
            plt.imshow(grid.permute((1, 2, 0)))
        plt.tight_layout()
        plt.show();

In [26]:
all_z = interpolate_annealing(model, obj1, obj2, hmc, T=10, n_steps=20)
visualize(model, all_z, (-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [27]:
all_z = interpolate_linear(model, obj1, obj2, T=10)
visualize(model, all_z, (-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [33]:
all_z = interpolate_mixture(model, obj1, obj2, hmc, T=10, n_steps=30)
visualize(model, all_z, (-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …