In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import ssm
import utils

In [2]:
# Notebook-wide configuration: RNG seeding and log formatting.

# Seed torch's global RNG before any sampling happens so that the stochastic
# cells below give the same numbers under Restart Kernel -> Run All.
# NOTE(review): this assumes the project modules (smc / ssm / combinators) draw
# from torch's global generator — confirm.
SEED = 0
torch.manual_seed(SEED)

# Timestamped, INFO-level logging; the per-epoch ELBO lines printed during
# inference are emitted in this format.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
# Trace object for the generative (ground-truth) rollout: it is attached to
# `ssm_step` via `condition(trace=...)`, passed to both inference calls, and
# compared against the inferred traces in the MSE cells.
# NOTE(review): the meaning of the argument `1` is not visible in this file
# (presumably the number of particles) — confirm in smc.ParticleTrace.
generative = smc.ParticleTrace(1)

In [4]:
# Starting values that get threaded through `ssm_step` in the rollout loop;
# each one is a 1x1 tensor.
mu = torch.zeros(1, 1)
sigma = 0.5 * torch.ones(1, 1)
delta = torch.ones(1, 1)

# A single row with 50 + 1 slots: slot 0 is set to 0, and every other slot
# holds -1 until the rollout loop overwrites it.
zs = -torch.ones(1, 50 + 1)
zs[:, 0] = 0

In [5]:
# Wrap the project's single-step SSM function in a combinators.Model and attach
# the `generative` trace to it, with no observations supplied.
# NOTE(review): presumably this makes every call record its Z_t / X_t draws
# into `generative` (the trace displayed further down holds such entries) —
# confirm in combinators.Model.condition.
ssm_step = combinators.Model(ssm.ssm_step)
ssm_step.condition(trace=generative, observations=None)

In [6]:
# Roll the model forward one step at a time: each call receives the current
# latent column together with mu / sigma / delta, and whatever it returns is
# stored as the next latent column and the next step's mu / sigma / delta.
# Caution: zs, mu, sigma and delta are overwritten here, so re-running this
# cell continues from the mutated values instead of the initial ones.
for t in range(zs.size(1) - 1):
    step_inputs = (zs[:, t], mu, sigma, delta)
    zs[:, t + 1], mu, sigma, delta = ssm_step(step_inputs, t)

In [7]:
# Inspect the generative trace after the rollout (bare last expression, so the
# notebook renders its repr).
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [8]:
# Particle count, passed as the first argument to both smc.variational_smc and
# smc.particle_mh below.
num_particles = 100

In [9]:
# Build the SMC runner from the SSM step and retrace functions; the same runner
# object is handed to both inference calls below.
smc_runner = smc.smc(ssm.ssm_step, ssm.ssm_retrace)

In [10]:
# Every SSM parameter starts from the same settings: loc 0 and scale 0.25.
# The tensors are created inside the comprehension so that each parameter gets
# its own fresh `loc` / `scale` objects rather than sharing them.
ssm_params = {
    name: {'loc': torch.tensor(0.), 'scale': torch.tensor(0.25)}
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Run variational SMC against the generative trace; the two results are bound
# to `inference` and `init_ssm_params`.
# NOTE(review): the positional 500 and 50 are unexplained here — presumably the
# number of epochs and the number of time steps (zs has 50 + 1 columns);
# confirm against smc.variational_smc and consider naming them.
inference, init_ssm_params = smc.variational_smc(
    num_particles,
    ssm.init_ssm,
    smc_runner,
    500,
    50,
    ssm_params,
    generative,
)

07/17/2018 12:52:12 Variational SMC ELBO=-2.83866562e+04 at epoch 1
07/17/2018 12:52:12 Variational SMC ELBO=-3.21914766e+04 at epoch 2
07/17/2018 12:52:13 Variational SMC ELBO=-2.04598613e+04 at epoch 3
07/17/2018 12:52:13 Variational SMC ELBO=-2.21398398e+04 at epoch 4
07/17/2018 12:52:14 Variational SMC ELBO=-2.17681758e+04 at epoch 5
07/17/2018 12:52:14 Variational SMC ELBO=-2.41241465e+04 at epoch 6
07/17/2018 12:52:14 Variational SMC ELBO=-1.91457949e+04 at epoch 7
07/17/2018 12:52:15 Variational SMC ELBO=-2.03032891e+04 at epoch 8
07/17/2018 12:52:15 Variational SMC ELBO=-2.20991426e+04 at epoch 9
07/17/2018 12:52:16 Variational SMC ELBO=-2.40878984e+04 at epoch 10
07/17/2018 12:52:16 Variational SMC ELBO=-2.33318848e+04 at epoch 11
07/17/2018 12:52:17 Variational SMC ELBO=-1.55028906e+04 at epoch 12
07/17/2018 12:52:17 Variational SMC ELBO=-2.79902324e+04 at epoch 13
07/17/2018 12:52:17 Variational SMC ELBO=-1.60151699e+04 at epoch 14
07/17/2018 12:52:18 Variational SMC ELBO=-2

In [12]:
# Per-time-step mean squared error between the inferred latents and the
# generative trace, for every Z_t with t >= 1.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    z_inferred = inference[key].value
    # NOTE(review): utils.optional_to presumably converts the generative entry
    # to match `z_inferred` (type/device) — confirm in utils.
    z_generated = utils.optional_to(generative[key], z_inferred)
    accuracy = ((z_inferred - z_generated) ** 2).mean()  # an MSE, despite the name
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 2.357658
SMC MSE at time 2: 1.318651
SMC MSE at time 3: 2.232557
SMC MSE at time 4: 0.565079
SMC MSE at time 5: 9.545084
SMC MSE at time 6: 3.357962
SMC MSE at time 7: 0.013659
SMC MSE at time 8: 0.066781
SMC MSE at time 9: 0.004033
SMC MSE at time 10: 0.746017
SMC MSE at time 11: 8.520008
SMC MSE at time 12: 1.081330
SMC MSE at time 13: 4.623265
SMC MSE at time 14: 1.169900
SMC MSE at time 15: 10.703228
SMC MSE at time 16: 0.008256
SMC MSE at time 17: 11.667866
SMC MSE at time 18: 2.531841
SMC MSE at time 19: 0.112927
SMC MSE at time 20: 0.532819
SMC MSE at time 21: 4.460123
SMC MSE at time 22: 0.353634
SMC MSE at time 23: 0.044910
SMC MSE at time 24: 3.622287
SMC MSE at time 25: 5.798809
SMC MSE at time 26: 2.225645
SMC MSE at time 27: 0.008085
SMC MSE at time 28: 1.237315
SMC MSE at time 29: 0.115244
SMC MSE at time 30: 6.887531
SMC MSE at time 31: 0.152500
SMC MSE at time 32: 0.237024
SMC MSE at time 33: 0.058175
SMC MSE at time 34: 0.083359
SMC MSE at time 35: 0

In [13]:
# Run particle Metropolis-Hastings with the same runner, initial parameters and
# generative trace as the variational run. This rebinds `inference`, so later
# cells see the particle-MH result.
# NOTE(review): 500 and 50 are the same unexplained positional constants used
# in the variational_smc call — confirm their meaning in smc.particle_mh.
samples, elbos, inference = smc.particle_mh(
    num_particles,
    ssm.init_ssm,
    smc_runner,
    500,
    50,
    ssm_params,
    generative,
)

In [14]:
# Inspect the trace returned by smc.particle_mh (bare last expression, so the
# notebook renders its repr).
inference

Trace{'mu': Normal([torch.FloatTensor of size 100]), 'sigma': Normal([torch.FloatTensor of size 100]), 'delta': Normal([torch.FloatTensor of size 100]), 'Z_0': Normal([torch.FloatTensor of size 100]), 'Z_1': Normal([torch.FloatTensor of size 100]), 'X_1': Normal([torch.FloatTensor of size 100x1x1]), 'Z_2': Normal([torch.FloatTensor of size 100]), 'X_2': Normal([torch.FloatTensor of size 100x1x1]), 'Z_3': Normal([torch.FloatTensor of size 100]), 'X_3': Normal([torch.FloatTensor of size 100x1x1]), 'Z_4': Normal([torch.FloatTensor of size 100]), 'X_4': Normal([torch.FloatTensor of size 100x1x1]), 'Z_5': Normal([torch.FloatTensor of size 100]), 'X_5': Normal([torch.FloatTensor of size 100x1x1]), 'Z_6': Normal([torch.FloatTensor of size 100]), 'X_6': Normal([torch.FloatTensor of size 100x1x1]), 'Z_7': Normal([torch.FloatTensor of size 100]), 'X_7': Normal([torch.FloatTensor of size 100x1x1]), 'Z_8': Normal([torch.FloatTensor of size 100]), 'X_8': Normal([torch.FloatTensor of size 100x1x1]),

In [15]:
# Display the generative trace again, for side-by-side comparison with the
# inferred trace shown in the previous cell.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [16]:
# Per-time-step mean squared error between the inferred latents and the
# generative trace, for every Z_t with t >= 1.
# `inference` was rebound by the particle_mh cell, so these numbers describe
# that run even though the printed label still reads "SMC MSE".
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    z_inferred = inference[key].value
    # NOTE(review): utils.optional_to presumably converts the generative entry
    # to match `z_inferred` (type/device) — confirm in utils.
    z_generated = utils.optional_to(generative[key], z_inferred)
    accuracy = ((z_inferred - z_generated) ** 2).mean()  # an MSE, despite the name
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 1.266341
SMC MSE at time 2: 3.044266
SMC MSE at time 3: 3.493159
SMC MSE at time 4: 0.506778
SMC MSE at time 5: 0.873811
SMC MSE at time 6: 4.163754
SMC MSE at time 7: 0.907463
SMC MSE at time 8: 0.818031
SMC MSE at time 9: 0.295777
SMC MSE at time 10: 0.302492
SMC MSE at time 11: 4.485109
SMC MSE at time 12: 0.797736
SMC MSE at time 13: 0.314461
SMC MSE at time 14: 1.197501
SMC MSE at time 15: 2.886729
SMC MSE at time 16: 0.171307
SMC MSE at time 17: 2.181560
SMC MSE at time 18: 1.526894
SMC MSE at time 19: 1.969720
SMC MSE at time 20: 4.990408
SMC MSE at time 21: 2.913910
SMC MSE at time 22: 0.421794
SMC MSE at time 23: 0.272783
SMC MSE at time 24: 0.461016
SMC MSE at time 25: 3.441230
SMC MSE at time 26: 1.112032
SMC MSE at time 27: 0.486193
SMC MSE at time 28: 0.875669
SMC MSE at time 29: 0.443883
SMC MSE at time 30: 2.926815
SMC MSE at time 31: 0.460008
SMC MSE at time 32: 0.336158
SMC MSE at time 33: 1.379226
SMC MSE at time 34: 0.096752
SMC MSE at time 35: 0.2