### Load Data and Train the APG Sampler (or Load a Trained One)

In [None]:
%matplotlib inline
import os
import torch
import numpy as np
from experiments.apgs_bshape.models import init_models
from experiments.apgs_bshape.affine_transformer import Affine_Transformer
from experiments.apgs_bshape.main import test_gibbs_sweep, train_apg

# ---- Data settings (all forwarded to train_apg / test_gibbs_sweep) ----
data_dir = './dataset/'
timesteps = 10
frame_pixels = 96   # presumably the frame side length in pixels — confirm in train_apg
shape_pixels = 28   # presumably the object-template side length in pixels — confirm
num_objects = 3

# ---- Optimisation / inference settings ----
device = 'cuda:1'   # NOTE(review): hard-coded second GPU; adjust on single-GPU machines
num_epochs = 1000
lr = 2e-4
batch_size = 5
budget = 15
num_sweeps = 3

# ---- Network sizes and latent dimensions ----
num_hidden_digit = 400
num_hidden_coor = 400
z_where_dim = 2
z_what_dim = 10

# Keyword arguments common to both train_apg and test_gibbs_sweep, so the two
# entry points are always called with the same configuration.
common_kwargs = dict(
    budget=budget,
    num_sweeps=num_sweeps,
    timesteps=timesteps,
    data_dir=data_dir,
    frame_pixels=frame_pixels,
    shape_pixels=shape_pixels,
    num_hidden_digit=num_hidden_digit,
    num_hidden_coor=num_hidden_coor,
    z_where_dim=z_where_dim,
    z_what_dim=z_what_dim,
    num_objects=num_objects,
    device=device,
)

out, frames = train_apg(
    num_epochs=num_epochs,
    lr=lr,
    batch_size=batch_size,
    **common_kwargs,
)

# To run test Gibbs sweeps instead of training, replace the call above with:
# out, frames = test_gibbs_sweep(**common_kwargs)

Training forapg-bshape-num_objects=3-num_sweeps=3-num_samples=5


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch=1, Group=1, ess: 1.0007,  log_p: -2033014.0000,  loss: 2936898.0000
Epoch=2, Group=1, ess: 1.0153,  log_p: -2088158.0000,  loss: 2866417.0000
Epoch=3, Group=1, ess: 1.0000,  log_p: -2042181.7500,  loss: 2946927.0000
Epoch=4, Group=1, ess: 1.0000,  log_p: -2197869.5000,  loss: 2997044.0000
Epoch=5, Group=1, ess: 1.0000,  log_p: -2179428.2500,  loss: 2993931.0000
Epoch=6, Group=1, ess: 1.0055,  log_p: -2128709.7500,  loss: 2930702.0000
Epoch=7, Group=1, ess: 1.0000,  log_p: -2061493.6250,  loss: 2917458.0000
Epoch=8, Group=1, ess: 1.0004,  log_p: -2145257.5000,  loss: 2981596.0000
Epoch=9, Group=1, ess: 1.0000,  log_p: -2120626.0000,  loss: 2926628.5000
Epoch=10, Group=1, ess: 1.0000,  log_p: -2021096.5000,  loss: 2907041.0000
Epoch=11, Group=1, ess: 1.0786,  log_p: -2070946.7500,  loss: 2913732.2500
Epoch=12, Group=1, ess: 1.0000,  log_p: -1975701.3750,  loss: 2866086.2500
Epoch=13, Group=1, ess: 1.0000,  log_p: -1985607.5000,  loss: 2800021.7500
Epoch=14, Group=1, ess: 1.0000,  l

In [None]:
from combinators import debug
if debug.runtime() == 'jupyter':
    from tqdm.notebook import trange, tqdm
else:
    from tqdm import trange, tqdm
from tqdm.contrib import tenumerate


In [None]:
def get_samples(out, sweeps, T):
    """Pull reconstruction probabilities and stacked z_where values out of a trace.

    Args:
        out: object exposing a ``trace`` mapping that contains a ``'recon'``
            node (read through ``.dist.probs``) and one ``'z_where_{t}_{sweeps}'``
            node per timestep ``t`` (read through ``.value``).
        sweeps: sweep index embedded in the z_where trace keys.
        T: number of timesteps to collect (``t`` runs over ``0 .. T-1``).

    Returns:
        ``(recon, z_where)`` as detached CPU tensors. ``z_where`` holds the T
        per-timestep values stacked along a new dimension 2.
    """
    trace = out.trace
    recon = trace['recon'].dist.probs
    # Stacking on dim 2 is equivalent to unsqueeze(2) on each value then cat on dim 2.
    z_where = torch.stack(
        [trace['z_where_%d_%d' % (step, sweeps)].value for step in range(T)],
        dim=2,
    )
    return recon.detach().cpu(), z_where.detach().cpu()

In [None]:
# Extract reconstructions (`rs`) and z_where values (`ws`) at sweep index `num_sweeps`.
# NOTE(review): get_samples looks up 'z_where_{t}_{num_sweeps}' trace keys, so this
# relies on num_sweeps still being 3 (the value used for training above). Later cells
# rebind num_sweeps, so re-running this cell after them may raise a KeyError.
rs, ws = get_samples(out, num_sweeps, timesteps)

### Visualize Samples

In [None]:
# Visualize the input `frames` together with the reconstructions `rs` and z_where values `ws`.
from experiments.apgs_bshape.evaluation import viz_samples
viz_samples(frames, rs, ws, num_sweeps, num_objects, shape_pixels, fs=1)

### Computing the Log Joint Density Across All Methods

In [None]:
# Evaluate the log joint density for every inference method via density_all_instances.
# NOTE(review): `models`, `AT`, `data_paths` and `CUDA` are not defined in any earlier
# cell of this notebook (only the Affine_Transformer class is imported), so this cell
# cannot run as-is; they must be created first.
# NOTE(review): this imports from `apgs.bshape.evaluation`, while the visualization cell
# imports from `experiments.apgs_bshape.evaluation` — confirm which package path is current.
from apgs.bshape.evaluation import density_all_instances
from random import shuffle  # NOTE(review): unused in this cell
# Rebinds the module-level num_sweeps (previously 3); the new value persists.
sample_size, num_sweeps = 20, 5
# Presumably leapfrog step size, leapfrog step counts and a BPG factor for the
# baselines — confirm against density_all_instances.
lf_step_size, lf_num_steps, bpg_factor = 5e-5, [100], 1
density_all_instances(models, AT, data_paths, sample_size, num_objects, z_where_dim, z_what_dim, num_sweeps, lf_step_size, lf_num_steps, bpg_factor, CUDA, device)

### Computational Budget Analysis

In [None]:
# Compare the 'decomposed' and 'joint' blocking strategies under a fixed computational
# budget: each configuration pairs a sweep count with sample_size = budget // num_sweeps.
# NOTE(review): `models` and `CUDA` are not defined in earlier cells of this notebook;
# they must be created before this cell can run.
from apgs.bshape.evaluation import budget_analysis, plot_budget_analyais_results
data = torch.from_numpy(
    np.load(os.path.join(data_dir, '%dobjects' % num_objects, 'test', 'ob-1.npy'))
).float()
budget = 1000
num_sweeps = np.array([1, 5, 10, 20, 25])
# Derive the sample sizes from `budget` (previously a duplicated literal 1000, so editing
# `budget` had no effect). Integer division yields integer sample counts; every sweep
# count above divides the budget exactly, so no samples are lost to truncation.
sample_sizes = budget // num_sweeps
blocks = ['decomposed', 'joint']
df = budget_analysis(models, blocks, num_sweeps, sample_sizes, data, num_objects, CUDA, device)
plot_budget_analyais_results(df)

### Comparison with Baselines

In [None]:
# Compare the APG sampler against baseline methods with density_convergence and plot the results.
# NOTE(review): `models`, `AT` and `CUDA` are not defined in earlier cells of this
# notebook; they must be created before this cell can run.
from apgs.bshape.evaluation import density_convergence, plot_convergence
seed = 1  # NOTE(review): not used anywhere in this cell — no RNG is actually seeded.
data = torch.from_numpy(np.load(data_dir + '%dobjects/test/ob-1.npy' % num_objects)).float()
# Rebinds the module-level num_sweeps again (previously set by an earlier cell).
sample_size, num_sweeps, num_runs = 100, 30, 3
# Presumably leapfrog step size, leapfrog step counts to compare, and a BPG factor —
# confirm against density_convergence.
lf_step_size, lf_num_steps, bpg_factor = 1e-1, [1, 5, 10], 100
densities = density_convergence(models, AT, data, sample_size, num_objects, num_runs, num_sweeps, lf_step_size, lf_num_steps, bpg_factor, CUDA, device)
plot_convergence(densities)