In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import ssm
import utils

In [2]:
# Root-logger setup: INFO level with a "MM/DD/YYYY HH:MM:SS message" layout,
# which is the format of the progress lines printed by the inference cells.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

In [3]:
# Trace object that is handed to `ssm_step.condition(trace=...)` below and later
# used as the reference trace by both inference routines.
# NOTE(review): the argument 1 is presumably a batch/particle count (the recorded
# nodes display as size 1x1) — confirm against combinators.GraphingTrace.
generative = combinators.GraphingTrace(1)

In [4]:
# Seed torch's global RNG so the simulated trajectory and the inference runs
# further down are reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

# Initial values threaded through `ssm_step` alongside the latent state;
# each is a (1, 1) tensor.
mu = torch.zeros(1, 1)
sigma = torch.ones(1, 1) / 2
delta = torch.ones(1, 1)

# Trajectory buffer of shape (1, 51): column 0 holds the initial state (0),
# the remaining 50 columns start at the placeholder -1 and are overwritten
# one by one by the simulation loop.
zs = torch.ones(1, 50+1) * -1
zs[:, 0] = 0

In [5]:
# Wrap the raw transition function ssm.ssm_step as a combinators.Model and
# condition it on the `generative` trace with no guide.
# NOTE(review): presumably, with no guide, calling the model samples fresh values
# and records them in `generative` — confirm against combinators.Model.condition.
ssm_step = combinators.Model(ssm.ssm_step)
ssm_step.condition(trace=generative, guide=None)

In [6]:
# Roll the model forward one step at a time, writing each new latent state into
# the next column of `zs`. `mu`, `sigma` and `delta` are rebound to whatever the
# step returns, so re-running this cell on its own continues from the updated
# values rather than from the initial ones.
num_steps = zs.shape[1] - 1
for t in range(num_steps):
    state = (zs[:, t], mu, sigma, delta)
    zs[:, t+1], mu, sigma, delta = ssm_step(state, t)

In [7]:
# Display the trace after the forward simulation; it lists a Z_t / X_t node pair
# for each simulated step.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [8]:
# Particle count passed to both inference routines below
# (smc.variational_smc and smc.particle_mh).
num_particles = 100

In [9]:
# SMC runner over the raw transition function. The step count is derived from
# the trajectory buffer (one step per column after the initial state) instead of
# repeating the literal 50, so it cannot drift out of sync with `zs` and with the
# loops that iterate over range(zs.shape[1]).
smc_runner = smc.smc(ssm.ssm_step, zs.shape[1] - 1)

In [10]:
# loc/scale settings keyed by parameter name. Every entry starts from the same
# values (loc 0, scale 0.25) but owns its own pair of scalar tensors.
ssm_params = {
    name: {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    }
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Initial-state model: wraps ssm.init_ssm, supplying the loc/scale settings as
# `phi` and leaving `theta` empty.
# NOTE(review): phi presumably denotes proposal/variational parameters and theta
# model parameters — confirm against combinators.Model.
init_ssm = combinators.Model(ssm.init_ssm, phi=ssm_params, theta={})

In [12]:
# Run variational SMC against the generative trace and keep both the resulting
# inference trace and the fitted initial-state parameters.
# NOTE(review): 500 is presumably the number of epochs (the log reports one ELBO
# per epoch) — confirm against smc.variational_smc.
# CUDA is requested only when a device is actually available, so the notebook
# also runs on CPU-only machines instead of depending on a GPU being present.
inference, init_ssm_params = smc.variational_smc(
    num_particles, init_ssm, smc_runner, 500, generative,
    use_cuda=torch.cuda.is_available(), lr=1e-2,
)

07/27/2018 16:45:17 Variational SMC ELBO=-1.86804922e+04 at epoch 1
07/27/2018 16:45:17 Variational SMC ELBO=-1.83878086e+04 at epoch 2
07/27/2018 16:45:17 Variational SMC ELBO=-2.22733105e+04 at epoch 3
07/27/2018 16:45:17 Variational SMC ELBO=-2.29266113e+04 at epoch 4
07/27/2018 16:45:17 Variational SMC ELBO=-1.94069336e+04 at epoch 5
07/27/2018 16:45:18 Variational SMC ELBO=-2.59256719e+04 at epoch 6
07/27/2018 16:45:18 Variational SMC ELBO=-1.89978789e+04 at epoch 7
07/27/2018 16:45:18 Variational SMC ELBO=-1.49116543e+04 at epoch 8
07/27/2018 16:45:18 Variational SMC ELBO=-1.89061621e+04 at epoch 9
07/27/2018 16:45:18 Variational SMC ELBO=-2.06446934e+04 at epoch 10
07/27/2018 16:45:19 Variational SMC ELBO=-2.17913516e+04 at epoch 11
07/27/2018 16:45:19 Variational SMC ELBO=-2.13650059e+04 at epoch 12
07/27/2018 16:45:19 Variational SMC ELBO=-2.56696855e+04 at epoch 13
07/27/2018 16:45:19 Variational SMC ELBO=-1.95097480e+04 at epoch 14
07/27/2018 16:45:20 Variational SMC ELBO=-1

In [13]:
# Per-time-step mean squared error between the inferred latent Z_t and the
# corresponding node of the generative trace. Despite its name, `accuracy` holds
# a squared error, so lower is better.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    inferred = inference[key].value
    # NOTE(review): optional_to presumably converts/moves the generative node so it
    # is compatible with `inferred` (e.g. same device) — confirm in utils.
    reference = utils.optional_to(generative[key], inferred)
    accuracy = ((inferred - reference) ** 2).mean()
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 4.757679
SMC MSE at time 2: 1.152176
SMC MSE at time 3: 0.566984
SMC MSE at time 4: 0.662791
SMC MSE at time 5: 0.050764
SMC MSE at time 6: 0.434361
SMC MSE at time 7: 1.503834
SMC MSE at time 8: 5.288766
SMC MSE at time 9: 1.334679
SMC MSE at time 10: 0.201546
SMC MSE at time 11: 4.027096
SMC MSE at time 12: 1.944739
SMC MSE at time 13: 0.330881
SMC MSE at time 14: 1.681702
SMC MSE at time 15: 0.158697
SMC MSE at time 16: 2.561338
SMC MSE at time 17: 0.153660
SMC MSE at time 18: 1.731791
SMC MSE at time 19: 0.973667
SMC MSE at time 20: 2.217055
SMC MSE at time 21: 1.084507
SMC MSE at time 22: 0.096041
SMC MSE at time 23: 0.028734
SMC MSE at time 24: 1.332212
SMC MSE at time 25: 1.826220
SMC MSE at time 26: 0.919911
SMC MSE at time 27: 0.011625
SMC MSE at time 28: 1.263967
SMC MSE at time 29: 0.880703
SMC MSE at time 30: 1.179494
SMC MSE at time 31: 0.575168
SMC MSE at time 32: 1.659107
SMC MSE at time 33: 0.142253
SMC MSE at time 34: 3.046354
SMC MSE at time 35: 0.6

In [14]:
# Particle Metropolis-Hastings run. A fresh initial-state model is built with the
# loc/scale settings passed as `theta` (the earlier `init_ssm` passed them as
# `phi`). Note that `inference` is rebound here, replacing the trace returned by
# variational SMC.
pmh_init_ssm = combinators.Model(ssm.init_ssm, theta=ssm_params)
samples, elbos, inference = smc.particle_mh(
    num_particles, pmh_init_ssm, smc_runner, 500, generative,
)

In [15]:
# Display the trace returned by smc.particle_mh (the name `inference` was rebound
# by that call, so this is no longer the variational SMC result).
inference

Trace{'mu': Normal([torch.FloatTensor of size 100]), 'sigma': Normal([torch.FloatTensor of size 100]), 'delta': Normal([torch.FloatTensor of size 100]), 'Z_0': Normal([torch.FloatTensor of size 100]), 'Z_1': Normal([torch.FloatTensor of size 100]), 'X_1': Normal([torch.FloatTensor of size 100x1x1]), 'Z_2': Normal([torch.FloatTensor of size 100]), 'X_2': Normal([torch.FloatTensor of size 100x1x1]), 'Z_3': Normal([torch.FloatTensor of size 100]), 'X_3': Normal([torch.FloatTensor of size 100x1x1]), 'Z_4': Normal([torch.FloatTensor of size 100]), 'X_4': Normal([torch.FloatTensor of size 100x1x1]), 'Z_5': Normal([torch.FloatTensor of size 100]), 'X_5': Normal([torch.FloatTensor of size 100x1x1]), 'Z_6': Normal([torch.FloatTensor of size 100]), 'X_6': Normal([torch.FloatTensor of size 100x1x1]), 'Z_7': Normal([torch.FloatTensor of size 100]), 'X_7': Normal([torch.FloatTensor of size 100x1x1]), 'Z_8': Normal([torch.FloatTensor of size 100]), 'X_8': Normal([torch.FloatTensor of size 100x1x1]),

In [16]:
# Display the generative trace again, directly after the inference trace.
# NOTE(review): this repeats the earlier display of `generative`; presumably the
# trace is unchanged by inference — confirm, or drop the duplicate cell.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [17]:
# Same per-time-step squared-error report as for variational SMC, but `inference`
# now refers to the trace returned by smc.particle_mh.
# NOTE(review): the printed label still says "SMC"; consider distinguishing the
# two reports. `accuracy` holds a mean squared error (lower is better).
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    estimate = inference[key].value
    target = utils.optional_to(generative[key], estimate)
    squared_error = (estimate - target) ** 2
    accuracy = squared_error.mean()
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 0.115434
SMC MSE at time 2: 1.573904
SMC MSE at time 3: 0.180315
SMC MSE at time 4: 2.384506
SMC MSE at time 5: 1.039185
SMC MSE at time 6: 2.150662
SMC MSE at time 7: 2.592822
SMC MSE at time 8: 0.960414
SMC MSE at time 9: 0.298991
SMC MSE at time 10: 1.673225
SMC MSE at time 11: 2.282982
SMC MSE at time 12: 0.027656
SMC MSE at time 13: 2.700670
SMC MSE at time 14: 3.691529
SMC MSE at time 15: 0.310591
SMC MSE at time 16: 0.846530
SMC MSE at time 17: 0.765306
SMC MSE at time 18: 1.312931
SMC MSE at time 19: 2.062731
SMC MSE at time 20: 0.231016
SMC MSE at time 21: 0.033182
SMC MSE at time 22: 0.087363
SMC MSE at time 23: 0.195570
SMC MSE at time 24: 1.592415
SMC MSE at time 25: 1.753107
SMC MSE at time 26: 3.770720
SMC MSE at time 27: 0.236609
SMC MSE at time 28: 0.317390
SMC MSE at time 29: 1.064816
SMC MSE at time 30: 4.013926
SMC MSE at time 31: 0.966906
SMC MSE at time 32: 1.439032
SMC MSE at time 33: 0.947906
SMC MSE at time 34: 1.327830
SMC MSE at time 35: 0.0

In [18]:
# Display the ELBO tensor returned by smc.particle_mh.
# NOTE(review): presumably one entry per MH iteration, in which case runs of
# repeated values would correspond to rejected proposals — confirm.
elbos

tensor([-15350.2256, -15350.2256, -15350.2256, -15350.2256, -15350.2256,
        -14092.4580, -14092.4580, -14092.4580, -14092.4580, -14092.4580,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -11811.2852, -11811.2852,
        -11811.2852, -11811.2852, -11811.2852, -118