In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

from models.vaes import Base, VAE, AIS_VAE, ULA_VAE, VAE_with_flows

import numpy as np
import os
import pickle

import matplotlib.pyplot as plt # inline
%matplotlib widget 

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
def generate_toy_data(N, d):
        ######### Problem Specification

        ###### Data generation parameters
        prior_mu_z = np.zeros(d, dtype=np.float32)    # Prior mean
        prior_sigma_z = np.eye(d, dtype=np.float32)   # Prior covariance matrix
        
        ###### True model parameters:
        # True offset
        num_range = np.arange(-(d-1)/2, (d+1)/2, dtype=np.float32)
        t_delta =  num_range / 5 + 10. #### !!!
        
        # True scale
        if d == 1:
            t_sigma = np.ones(1)
        else: 
            # Allow sigma to range from 0.1 to 1
            t_sigma = 36/(10*(d-1)**2) * num_range**2 + 0.1
            
            
        ######### Generate Training Data & Save - One for each test
        z = np.random.multivariate_normal(prior_mu_z, prior_sigma_z)
        x = np.random.multivariate_normal(z + t_delta, np.diag(t_sigma**2), size=N)
        
        return x, z, t_delta, t_sigma

#         # Folder should have already been created in the initializations
#         data_path = os.path.join('data/toy_hvae/', f'train_data_{str(d)}_{current_iter}.p')
#         pickle.dump(x, open(data_path, 'wb')) 

In [4]:
def replace_enc_dec(model, d, name='our'):
    if name == 'our':
        model.encoder_net = ToyEncoder(d=d)
    elif name == 'vb':
        model.encoder_net = ToyEncoder_VB(d=d)
    model.decoder_net = ToyDecoder(d=d)
    model = model.to(device)
    return model

In [5]:
class ToyDataset(Dataset):
    def __init__(self, data):
        super(ToyDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        sample = torch.tensor(self.data[item], dtype=torch.float32, device=device)
        return sample, -1.

In [24]:
class ToyEncoder(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.aux = nn.Parameter(torch.tensor(0., dtype=torch.float32))
        self.var_mean = torch.zeros(d, dtype=torch.float32, device=device)
        self.log_sigma_z = torch.zeros(d, dtype=torch.float32, device=device)
        
    def forward(self, x):
        return self.aux * 0. + self.var_mean, self.aux * 0. + self.log_sigma_z
    
class ToyEncoder_VB(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.var_mean = nn.Parameter(torch.zeros(d, dtype=torch.float32))
        self.log_sigma_z = nn.Parameter(torch.zeros(d, dtype=torch.float32))
        
    def forward(self, x):
        return self.var_mean, self.log_sigma_z
    
    
######## ENCODERS ABOVE    

    
class ToyDecoder(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(d, dtype=torch.float32))    # Prior mean
        self.log_sigma = nn.Parameter(3 * torch.ones(d, dtype=torch.float32))

######### The only true decoder



class Toy(Base):
    def joint_logdensity(self, use_true_decoder=None):
        def density(z, x):
#             var_x = torch.exp(2 * self.decoder_net.log_sigma)
#             grad_U = z + N_train * (z + self.delta - x) / var_x
#             return -grad_U
            log_Pr = torch.distributions.Normal(loc=torch.tensor(0., device=x.device, dtype=torch.float32),
                                    scale=torch.tensor(1., device=x.device, dtype=torch.float32)).log_prob(z).sum(-1)
            return N_train * torch.distributions.Normal(loc=self.decoder_net.delta[None] + z,
                                              scale=torch.exp(self.decoder_net.log_sigma[None]) * torch.ones_like(z)).log_prob(x).sum(-1) + log_Pr

        return density
    
    def encode(self, x):
        # We treat the first half of output as mu, and the rest as logvar
        mu, logvar = self.encoder_net(x)
        mu = mu.repeat(x.shape[0], 1)
        logvar = logvar.repeat(x.shape[0], 1)
        return mu, logvar
    
    def evaluate_nll(self, batch, beta):
        return torch.mean(torch.tensor(100500.))
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)
    

class VAE_Toy(VAE, Toy):    
    def step(self, batch):
        x_batch, _ = batch
        
        var_inv_vec = torch.exp(-2 * self.decoder_net.log_sigma)
        
        y_sig_y = torch.sum((x_batch - self.decoder_net.delta) ** 2 * var_inv_vec)
        y_sig_mu = torch.sum((x_batch - self.decoder_net.delta) * var_inv_vec * self.encoder_net.var_mean)
        var_Z_over_var_X = torch.sum(torch.exp(2*self.encoder_net.log_sigma_z) * var_inv_vec)
        
        mu_sig_mu = torch.sum(self.encoder_net.var_mean ** 2 * var_inv_vec)
        mu_T_mu = torch.sum(self.encoder_net.var_mean ** 2)
        
        Nd2_log2pi = N_train * x_batch.shape[1] / 2 * np.log(2 * np.pi)
        
        
        loss = - (- Nd2_log2pi + torch.sum(self.encoder_net.log_sigma_z)
                - N_train * torch.sum(self.decoder_net.log_sigma) - 1 / 2 * y_sig_y
                + y_sig_mu - N_train / 2 * (var_Z_over_var_X + mu_sig_mu)
                - 1 / 2 * torch.sum(torch.exp(2 * self.encoder_net.log_sigma_z)) - 1 / 2 * mu_T_mu
                - x_batch.shape[1] / 2
                )
        return loss, x_batch, x_batch
    
    
class VAE_with_flows_Toy(VAE_with_flows, Toy):
    def forward(self, z):
        return None
    
#     def loss_function(self, recon_x, x, mu, logvar, z, z_transformed, log_jac):

#         batch_size = mu.shape[0] // self.num_samples
#         log_likelihood = N_train * torch.distributions.Normal(loc=self.decoder_net.delta[None] + z_transformed,
#                                               scale=torch.exp(self.decoder_net.log_sigma[None]) * torch.ones_like(z_transformed)).log_prob(x).sum(-1).mean()
#         log_Q = torch.mean(torch.distributions.Normal(loc=mu, scale=torch.exp(0.5 * logvar)).log_prob(z).view(
#             (self.num_samples, batch_size, -1)).sum(-1) - log_jac.view((self.num_samples, -1)), dim=0)
#         log_Pr = (-0.5 * z_transformed ** 2).view(
#             (self.num_samples, batch_size, -1)).mean(0).sum(-1)
#         KLD = torch.mean(log_Q - log_Pr)
#         loss = -log_likelihood + KLD
#         return loss
    
    def loss_function(self, recon_x, x, mu, logvar, z, z_transformed, log_jac):
        var_inv_vec = torch.exp(-2 * self.decoder_net.log_sigma)
        y_sig_y = torch.sum((x - self.decoder_net.delta)**2 * var_inv_vec)
        y_bar_sig_z = torch.sum((x_bar - self.decoder_net.delta) * var_inv_vec * z_transformed, dim=1)
        z_sig_z = torch.sum(z_transformed**2 * var_inv_vec, dim=1)
        z_T_z = torch.sum(z_transformed * z_transformed, dim=1)
        
        Nd2_log2pi = N_train * x.shape[1] / 2*np.log(2*np.pi)
        
        loss = -(- Nd2_log2pi - N_train*torch.sum(self.decoder_net.log_sigma)
            - y_sig_y + N_train*torch.mean(y_bar_sig_z)
            - N_train/2.*torch.mean(z_sig_z) 
            - 1./2*torch.mean(z_T_z) + log_jac.mean())

        return loss
    
    
class ULA_VAE_Toy(ULA_VAE, Toy):
    pass
    
class AIS_VAE_Toy(AIS_VAE, Toy):
    pass

In [25]:
N_train = 10000
dim = 300

In [26]:
data, true_z, true_delta, true_sigma = generate_toy_data(N=int(1.2 * N_train), d=dim)

In [27]:
train_dataset = ToyDataset(data[:N_train])
val_dataset = ToyDataset(data[N_train:])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

x_bar = torch.tensor(data[:N_train].mean(0), device=device, dtype=torch.float32)

In [28]:
# ----- ULA_VAE ----- #
ula_vae = ULA_VAE_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=dim,
            net_type='conv', dataset='toy',
            step_size=0.01, K=1, use_transforms=False, learnable_transitions=False, return_pre_alphas=True, use_score_matching=False,
                      ula_skip_threshold=0.1, grad_skip_val=0., grad_clip_val=0., use_cloned_decoder=False, variance_sensitive_step=True,
                     acceptance_rate_target=0.9, annealing_scheme='linear')
ula_vae = replace_enc_dec(ula_vae, dim, 'our')
ula_vae.name = 'ULA_VAE'

# ----- AIS_VAE ----- #
ais_vae = AIS_VAE_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=dim,
            net_type='conv', dataset='toy',
            step_size=0.01, K=1, use_barker=False, learnable_transitions=False, use_alpha_annealing=False, grad_skip_val=0.,
                      grad_clip_val=50., use_cloned_decoder=False, variance_sensitive_step=True,
                     acceptance_rate_target=0.75, annealing_scheme='linear')
ais_vae = replace_enc_dec(ais_vae, dim, 'our')
ais_vae.name = 'AIS_VAE'

# ----- Flow_VAE ----- #
flow_vae = VAE_with_flows_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=dim,
            net_type='conv', dataset='toy',
            flow_type='RNVP', num_flows=2, need_permute=True)
flow_vae = replace_enc_dec(flow_vae, dim, 'our')
flow_vae.name = 'Flow_VAE'

In [29]:
# ----- VAE ------ #
vae = VAE_Toy(shape=28, act_func=nn.LeakyReLU,
            num_samples=1, hidden_dim=dim,
            net_type='conv', dataset='toy')
vae = replace_enc_dec(vae, dim, 'vb')
vae.name = 'VAE'

In [30]:
def run_trainer(model):
    tb_logger = pl_loggers.TensorBoardLogger('lightning_logs/')
    trainer = pl.Trainer(logger=tb_logger, fast_dev_run=False, max_epochs=51, automatic_optimization=True, track_grad_norm=2)
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

In [31]:
run_trainer(flow_vae)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name            | Type       | Params
-----------------------------------------------
0 | encoder_net     | ToyEncoder | 1     
1 | decoder_net     | ToyDecoder | 600   
2 | transitions_nll | ModuleList | 8     
3 | Flow            | NormFlow   | 541 K 
-----------------------------------------------
542 K     Trainable params
8         Non-trainable params
542 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [13]:
run_trainer(vae)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name            | Type          | Params
--------------------------------------------------
0 | encoder_net     | ToyEncoder_VB | 600   
1 | decoder_net     | ToyDecoder    | 600   
2 | transitions_nll | ModuleList    | 8     
--------------------------------------------------
1.2 K     Trainable params
8         Non-trainable params
1.2 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [14]:
run_trainer(ais_vae)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name            | Type       | Params
-----------------------------------------------
0 | encoder_net     | ToyEncoder | 1     
1 | decoder_net     | ToyDecoder | 600   
2 | transitions_nll | ModuleList | 8     
3 | transitions     | ModuleList | 1     
-----------------------------------------------
601       Trainable params
9         Non-trainable params
610       Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [15]:
run_trainer(ula_vae)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name            | Type       | Params
-----------------------------------------------
0 | encoder_net     | ToyEncoder | 1     
1 | decoder_net     | ToyDecoder | 4     
2 | transitions_nll | ModuleList | 8     
3 | transitions     | ModuleList | 1     
-----------------------------------------------
5         Trainable params
9         Non-trainable params
14        Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [14]:
def estimate_model(model, t_delta, t_sigma):
    delta, sigma = model.decoder_net.delta.cpu().detach().numpy(), torch.exp(model.decoder_net.log_sigma).cpu().detach().numpy()
    delta_diff = np.sum((delta - t_delta)**2)
    sigma_diff = np.sum((sigma - t_sigma)**2)
    return delta_diff, sigma_diff

In [15]:
estimate_model(ula_vae, true_delta, true_sigma)

(200.02, 728.51544)

In [16]:
estimate_model(ais_vae, true_delta, true_sigma)

(26903.926, 27936.549)

In [17]:
estimate_model(vae, true_delta, true_sigma)

(119999.0, 116254.3)

In [15]:
estimate_model(flow_vae, true_delta, true_sigma)

(461646.75, 70.37191)

In [23]:
def plot_result(model):
    plt.close()
    plt.figure()
    plt.scatter(data[:, 0], data[:, 1])
    plt.scatter(true_delta[0], true_delta[1], color='red', s=100)
    plt.scatter(model.decoder_net.delta[0].cpu().detach().numpy(), model.decoder_net.delta[1].cpu().detach().numpy(), color='green', s=100)
    plt.axis('equal')
    plt.show();

In [24]:
plot_result(flow_vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
plot_result(vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
true_z

array([-1.21868733, -1.0987256 ])

In [20]:
plot_result(ais_vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [31]:
plot_result(ula_vae)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
true_delta.shape

(2,)