In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import ssm
import utils

In [2]:
# Seed torch's RNG so the simulated trajectory and the stochastic inference
# runs below are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging so progress messages emitted during
# inference (e.g. the per-epoch ELBO lines) show up in the notebook.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
# Trace that records the ground-truth run of the generative model.
# NOTE(review): the argument 1 is presumably the particle count (the recorded
# tensors are 1x1) — confirm against combinators.ParticleTrace.
generative = combinators.ParticleTrace(1)

In [4]:
# Starting hyperparameter values for the simulated run, each a 1x1 tensor.
mu = torch.zeros(1, 1)
sigma = torch.full((1, 1), 0.5)
delta = torch.ones(1, 1)
# Latent trajectory buffer: the initial state Z_0 = 0 followed by 50 slots,
# pre-filled with -1, that the generation loop overwrites one column at a time.
zs = -torch.ones(1, 50 + 1)
zs[:, 0] = 0

In [5]:
# Wrap the single-step transition function in a combinators Model so it can be
# called with a `trace=` keyword, as the generation loop below does.
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
# Roll the generative model forward one step at a time. Each call receives the
# current latent plus hyperparameters and the trace to record into. The next
# latent is written into the following column of `zs`.
for t in range(zs.shape[1] - 1):
    state = (zs[:, t], mu, sigma, delta)
    zs[:, t + 1], mu, sigma, delta = ssm_step(state, t, trace=generative)

In [7]:
# Inspect the recorded ground-truth trace: one Z_t / X_t pair per simulated step.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [8]:
# Particle count shared by the variational SMC and particle MH runs below.
num_particles = 100

In [9]:
# SMC sweep over the state-space model. The second argument is presumably the
# number of time steps (originally a hard-coded 50, matching the buffer size).
# Deriving it from `zs` (one column per latent, including Z_0) keeps it in
# sync with the simulated data.
smc_runner = smc.smc(ssm.ssm_step, zs.shape[1] - 1)

In [10]:
# Initial loc/scale values for the distributions over the three global
# hyperparameters. All three start identically, but each entry gets its own
# freshly created tensors.
ssm_params = {
    name: {'loc': torch.tensor(0.), 'scale': torch.tensor(0.25)}
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Initial-state model with `ssm_params` passed as phi and no theta parameters.
# NOTE(review): phi is presumably the parameter set optimised by variational
# SMC in the next cell — confirm against combinators.Model.
init_ssm = combinators.Model(ssm.init_ssm, phi=ssm_params, theta={})

In [12]:
# Variational SMC: fit init_ssm's parameters against the simulated trace for
# 500 epochs (the log reports the ELBO per epoch). Use the GPU only when one
# is available, so the notebook still runs on CPU-only machines.
inference, init_ssm_params = smc.variational_smc(num_particles, init_ssm, smc_runner, 500, generative,
                                                 use_cuda=torch.cuda.is_available(), lr=1e-2)

07/30/2018 13:03:22 Variational SMC ELBO=-1.55656553e+04 at epoch 1
07/30/2018 13:03:22 Variational SMC ELBO=-2.04559180e+04 at epoch 2
07/30/2018 13:03:23 Variational SMC ELBO=-1.90102949e+04 at epoch 3
07/30/2018 13:03:23 Variational SMC ELBO=-1.95146680e+04 at epoch 4
07/30/2018 13:03:24 Variational SMC ELBO=-1.81831621e+04 at epoch 5
07/30/2018 13:03:24 Variational SMC ELBO=-1.69172832e+04 at epoch 6
07/30/2018 13:03:25 Variational SMC ELBO=-2.47568809e+04 at epoch 7
07/30/2018 13:03:25 Variational SMC ELBO=-1.86647422e+04 at epoch 8
07/30/2018 13:03:26 Variational SMC ELBO=-2.12120176e+04 at epoch 9
07/30/2018 13:03:26 Variational SMC ELBO=-1.77020957e+04 at epoch 10
07/30/2018 13:03:27 Variational SMC ELBO=-2.28430059e+04 at epoch 11
07/30/2018 13:03:28 Variational SMC ELBO=-2.14087188e+04 at epoch 12
07/30/2018 13:03:28 Variational SMC ELBO=-1.84517930e+04 at epoch 13
07/30/2018 13:03:29 Variational SMC ELBO=-2.28326152e+04 at epoch 14
07/30/2018 13:03:29 Variational SMC ELBO=-2

In [13]:
# Per-step mean squared error between the inferred particle values for Z_t
# and the ground-truth Z_t. The ground truth is presumably moved onto the
# inferred tensor's device by utils.optional_to.
for t in range(1, zs.shape[1]):
    z_key = 'Z_%d' % t
    z_inferred = inference[z_key].value
    z_true = utils.optional_to(generative[z_key], z_inferred)
    mse = ((z_inferred - z_true) ** 2).mean()
    print('SMC MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 0.099556
SMC MSE at time 2: 0.402681
SMC MSE at time 3: 0.260956
SMC MSE at time 4: 0.037811
SMC MSE at time 5: 1.273722
SMC MSE at time 6: 1.358226
SMC MSE at time 7: 0.019702
SMC MSE at time 8: 8.070623
SMC MSE at time 9: 9.337977
SMC MSE at time 10: 0.047308
SMC MSE at time 11: 2.385378
SMC MSE at time 12: 0.038537
SMC MSE at time 13: 0.103355
SMC MSE at time 14: 0.044454
SMC MSE at time 15: 0.771564
SMC MSE at time 16: 1.472725
SMC MSE at time 17: 0.179746
SMC MSE at time 18: 2.058315
SMC MSE at time 19: 1.453321
SMC MSE at time 20: 3.487162
SMC MSE at time 21: 2.626670
SMC MSE at time 22: 0.535110
SMC MSE at time 23: 4.961892
SMC MSE at time 24: 0.099847
SMC MSE at time 25: 0.460322
SMC MSE at time 26: 0.551805
SMC MSE at time 27: 0.726268
SMC MSE at time 28: 3.315559
SMC MSE at time 29: 0.044659
SMC MSE at time 30: 0.264195
SMC MSE at time 31: 2.385571
SMC MSE at time 32: 1.111139
SMC MSE at time 33: 4.404155
SMC MSE at time 34: 5.811265
SMC MSE at time 35: 0.5

In [14]:
# Particle Metropolis-Hastings (smc.particle_mh) over the initial-state model,
# this time passing ssm_params as theta rather than phi.
# NOTE(review): this rebinds `inference`, replacing the variational SMC trace
# produced earlier in the notebook.
mh_init = combinators.Model(ssm.init_ssm, theta=ssm_params)
samples, elbos, inference = smc.particle_mh(num_particles, mh_init, smc_runner, 500, generative)

In [15]:
# Trace returned by particle MH: mu / sigma / delta plus Z_t and X_t per step,
# each holding one value per particle.
inference

Trace{'mu': Normal([torch.cuda.FloatTensor of size 100]), 'sigma': Normal([torch.cuda.FloatTensor of size 100]), 'delta': Normal([torch.cuda.FloatTensor of size 100]), 'Z_0': Normal([torch.cuda.FloatTensor of size 100]), 'Z_1': Normal([torch.cuda.FloatTensor of size 100]), 'X_1': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_2': Normal([torch.cuda.FloatTensor of size 100]), 'X_2': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_3': Normal([torch.cuda.FloatTensor of size 100]), 'X_3': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_4': Normal([torch.cuda.FloatTensor of size 100]), 'X_4': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_5': Normal([torch.cuda.FloatTensor of size 100]), 'X_5': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_6': Normal([torch.cuda.FloatTensor of size 100]), 'X_6': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_7': Normal([torch.cuda.FloatTensor of size 100]), 'X_7': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_8': No

In [16]:
# Ground-truth trace again, for comparison with the particle MH trace above.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [17]:
# Per-step mean squared error between the particle MH values for Z_t and the
# ground-truth Z_t. Labelled "Particle MH" so this report is not confused
# with the variational SMC report, which previously used the same "SMC MSE"
# prefix.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    z_inferred = inference[key].value
    z_true = utils.optional_to(generative[key], z_inferred)
    mse = ((z_inferred - z_true) ** 2).mean()
    print('Particle MH MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 0.012340
SMC MSE at time 2: 2.300410
SMC MSE at time 3: 5.395426
SMC MSE at time 4: 0.190428
SMC MSE at time 5: 0.936656
SMC MSE at time 6: 0.845242
SMC MSE at time 7: 0.597309
SMC MSE at time 8: 0.069614
SMC MSE at time 9: 2.196369
SMC MSE at time 10: 2.107530
SMC MSE at time 11: 0.164513
SMC MSE at time 12: 4.758962
SMC MSE at time 13: 1.782368
SMC MSE at time 14: 1.720088
SMC MSE at time 15: 0.862607
SMC MSE at time 16: 1.797004
SMC MSE at time 17: 3.160613
SMC MSE at time 18: 1.244672
SMC MSE at time 19: 7.023519
SMC MSE at time 20: 8.287155
SMC MSE at time 21: 2.425328
SMC MSE at time 22: 0.033332
SMC MSE at time 23: 5.374684
SMC MSE at time 24: 0.028393
SMC MSE at time 25: 4.983664
SMC MSE at time 26: 0.863724
SMC MSE at time 27: 0.109640
SMC MSE at time 28: 1.772712
SMC MSE at time 29: 0.992343
SMC MSE at time 30: 4.041850
SMC MSE at time 31: 7.264267
SMC MSE at time 32: 0.233119
SMC MSE at time 33: 9.078816
SMC MSE at time 34: 12.144407
SMC MSE at time 35: 4.

In [18]:
# ELBO values returned by particle MH, one per iteration.
# NOTE(review): long runs of identical values presumably mean rejected
# proposals (the chain keeps its current state) — confirm against
# smc.particle_mh.
elbos

tensor([-21326.7266, -16662.9395, -16662.9395, -16662.9395, -16662.9395,
        -16662.9395, -16662.9395, -13925.1201, -13925.1201, -13925.1201,
        -13925.1201, -13925.1201, -13925.1201, -13925.1201, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -12398.2061, -12398.2061,
        -12398.2061, -12398.2061, -12398.2061, -123