In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import ssm
import utils

In [2]:
# Seed torch's RNG so the simulated trajectory and the SMC run below are
# reproducible under Restart & Run All (presumably ssm/smc sample via torch —
# confirm they do not use another RNG).
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging; variational_smc reports its ELBO through this.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
# Trace into which the simulation loop below records the ground-truth samples
# (its displayed contents are Z_1, X_1, Z_2, X_2, ...); later used as the target
# for SMC inference and for the MSE check.
# NOTE(review): the argument 1 is presumably the particle count — confirm
# against smc.ParticleTrace.
generative = smc.ParticleTrace(1)

In [4]:
# Initial state for simulating one synthetic trajectory from the SSM.
# mu, sigma and delta are threaded through ssm.ssm_step, which returns new
# values for them at every step of the simulation loop.
NUM_STEPS = 50  # number of transitions to simulate
mu = torch.zeros(1, 1)
sigma = torch.ones(1, 1) / 2
delta = torch.ones(1, 1)
# Column 0 holds the initial latent z_0 = 0; columns 1..NUM_STEPS start as -1
# placeholders and are overwritten by the simulation loop.
zs = torch.full((1, NUM_STEPS + 1), -1.0)
zs[:, 0] = 0

In [5]:
# Simulate the trajectory one transition at a time: each ssm.ssm_step call takes
# the current latent, the current (mu, sigma, delta), the step index and the
# generative trace, and returns the next latent plus updated values for the rest.
# NOTE: this cell rebinds mu, sigma, delta and generative and writes into zs, so
# it is not idempotent — re-run the two cells above before re-running it.
num_transitions = zs.shape[1] - 1
for step in range(num_transitions):
    z_next, mu, sigma, delta, generative = ssm.ssm_step(
        zs[:, step], mu, sigma, delta, step, generative
    )
    zs[:, step + 1] = z_next

In [6]:
# Display the simulated (ground-truth) trace filled in by the loop above.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [7]:
# Number of particles for variational SMC below (inferred values carry a
# leading dimension of this size).
num_particles = 25

In [8]:
# Build the SMC runner from the SSM's per-step function (ssm_step, also used to
# simulate the data above) and ssm.ssm_retrace.
# NOTE(review): ssm_retrace presumably replays a step against an existing trace
# — confirm in ssm.py.
smc_runner = smc.smc(ssm.ssm_step, ssm.ssm_retrace)

In [9]:
# Initial variational parameters: each of mu, sigma and delta gets its own
# independent loc=0, scale=0.25 tensors (built fresh per key, not shared).
ssm_params = {
    param_name: {'loc': torch.tensor(0.), 'scale': torch.tensor(0.25)}
    for param_name in ('mu', 'sigma', 'delta')
}

In [10]:
# Run variational SMC against the simulated `generative` trace; returns the
# inference trace (displayed below) and init_ssm.
# NOTE(review): the positional 500 and 50 are presumably the number of training
# epochs and time steps — confirm against smc.variational_smc and consider
# naming them as constants.
# NOTE(review): the logged ELBO stays at -1.6094e+02 for every epoch shown,
# which suggests the variational parameters are not being updated — investigate.
inference, init_ssm = smc.variational_smc(
    num_particles,
    ssm.init_ssm,
    smc_runner,
    500,
    50,
    ssm_params,
    generative,
)

07/06/2018 15:05:18 Variational SMC ELBO=-1.6094e+02 at epoch 0
07/06/2018 15:05:18 Variational SMC ELBO=-1.6094e+02 at epoch 1
07/06/2018 15:05:19 Variational SMC ELBO=-1.6094e+02 at epoch 2
07/06/2018 15:05:20 Variational SMC ELBO=-1.6094e+02 at epoch 3
07/06/2018 15:05:21 Variational SMC ELBO=-1.6094e+02 at epoch 4
07/06/2018 15:05:21 Variational SMC ELBO=-1.6094e+02 at epoch 5
07/06/2018 15:05:22 Variational SMC ELBO=-1.6094e+02 at epoch 6
07/06/2018 15:05:23 Variational SMC ELBO=-1.6094e+02 at epoch 7
07/06/2018 15:05:23 Variational SMC ELBO=-1.6094e+02 at epoch 8
07/06/2018 15:05:24 Variational SMC ELBO=-1.6094e+02 at epoch 9
07/06/2018 15:05:25 Variational SMC ELBO=-1.6094e+02 at epoch 10
07/06/2018 15:05:26 Variational SMC ELBO=-1.6094e+02 at epoch 11
07/06/2018 15:05:26 Variational SMC ELBO=-1.6094e+02 at epoch 12
07/06/2018 15:05:27 Variational SMC ELBO=-1.6094e+02 at epoch 13
07/06/2018 15:05:28 Variational SMC ELBO=-1.6094e+02 at epoch 14
07/06/2018 15:05:29 Variational SMC

In [11]:
# Display the inference trace; per the output, latent values have a leading
# dimension of num_particles (25) and live on the GPU in this run.
inference

Trace{'mu': Normal([torch.cuda.FloatTensor of size 25]), 'sigma': Normal([torch.cuda.FloatTensor of size 25]), 'delta': Normal([torch.cuda.FloatTensor of size 25]), 'Z_0': Normal([torch.cuda.FloatTensor of size 25]), 'Z_1': Normal([torch.cuda.FloatTensor of size 25]), 'X_1': Normal([torch.cuda.FloatTensor of size 25x1x1]), 'Z_2': Normal([torch.cuda.FloatTensor of size 25]), 'X_2': Normal([torch.cuda.FloatTensor of size 25x1x1]), 'Z_3': Normal([torch.cuda.FloatTensor of size 25]), 'X_3': Normal([torch.cuda.FloatTensor of size 25x1x1]), 'Z_4': Normal([torch.cuda.FloatTensor of size 25]), 'X_4': Normal([torch.cuda.FloatTensor of size 25x1x1]), 'Z_5': Normal([torch.cuda.FloatTensor of size 25]), 'X_5': Normal([torch.cuda.FloatTensor of size 25x1x1]), 'Z_6': Normal([torch.cuda.FloatTensor of size 25]), 'X_6': Normal([torch.cuda.FloatTensor of size 25x1x1]), 'Z_7': Normal([torch.cuda.FloatTensor of size 25]), 'X_7': Normal([torch.cuda.FloatTensor of size 25x1x1]), 'Z_8': Normal([torch.cuda.F

In [12]:
# Re-display the generative trace after inference (presumably to confirm
# variational_smc left it unchanged).
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [13]:
# Per-time-step mean squared error between the inferred latent samples and the
# simulated ground truth (lower is better). The mean is taken uniformly over all
# particles; no SMC importance weights are used. Z_0 is skipped because the
# generative trace starts at Z_1.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    inferred = inference[key].value
    # optional_to presumably moves the ground truth onto the inferred values'
    # device (CUDA in the run above) — confirm in utils.py.
    truth = utils.optional_to(generative[key], inferred)
    mse = ((inferred - truth) ** 2).mean().item()
    print('SMC MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 3.606209
SMC MSE at time 2: 1.196524
SMC MSE at time 3: 2.222853
SMC MSE at time 4: 0.001626
SMC MSE at time 5: 0.058935
SMC MSE at time 6: 0.000600
SMC MSE at time 7: 1.813691
SMC MSE at time 8: 1.584975
SMC MSE at time 9: 1.818634
SMC MSE at time 10: 0.003384
SMC MSE at time 11: 0.088939
SMC MSE at time 12: 0.047706
SMC MSE at time 13: 0.878667
SMC MSE at time 14: 0.151569
SMC MSE at time 15: 1.270037
SMC MSE at time 16: 0.000190
SMC MSE at time 17: 0.662449
SMC MSE at time 18: 0.883696
SMC MSE at time 19: 0.196870
SMC MSE at time 20: 3.234029
SMC MSE at time 21: 6.630997
SMC MSE at time 22: 0.028301
SMC MSE at time 23: 0.231596
SMC MSE at time 24: 0.273923
SMC MSE at time 25: 2.335908
SMC MSE at time 26: 0.027013
SMC MSE at time 27: 0.080081
SMC MSE at time 28: 4.457159
SMC MSE at time 29: 3.230865
SMC MSE at time 30: 0.107517
SMC MSE at time 31: 5.494481
SMC MSE at time 32: 0.482205
SMC MSE at time 33: 0.451217
SMC MSE at time 34: 0.020088
SMC MSE at time 35: 1.7