In [1]:
import logging

import probtorch
import torch

import combinators
import importance
import smc
import ssm
import utils

In [2]:
# Seed torch's RNG so the simulated ground-truth data and the SMC / particle MH
# runs below are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging so progress messages emitted by the library
# (e.g. per-epoch ELBO reports) show up in the notebook.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
# Trace that records the random variables of the simulated ("ground truth") run;
# the simulation loop writes into it via `trace=generative`.
# NOTE(review): the argument 1 is presumably the particle count (recorded values are 1x1) — confirm.
generative = combinators.ParticleTrace(1)

In [4]:
# Initial values of mu, sigma and delta, each a 1x1 tensor; these are threaded
# through (and rebound by) every ssm_step call during simulation.
mu = torch.zeros(1, 1)
sigma = torch.ones(1, 1) / 2
delta = torch.ones(1, 1)

# Number of transitions to simulate. zs holds the initial latent z_0 = 0 followed
# by NUM_STEPS placeholder slots (-1) that the simulation loop overwrites.
NUM_STEPS = 50
zs = -torch.ones(1, NUM_STEPS + 1)
zs[:, 0] = 0

In [5]:
# Wrap the SSM transition function as a combinators Model so it can be called
# with a `trace=` keyword that records the random variables it samples.
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
# Simulate the ground-truth trajectory. Each step maps (z_t, mu, sigma, delta)
# to (z_{t+1}, mu, sigma, delta) and records its random variables in `generative`.
# NOTE(review): not idempotent — mu, sigma and delta are rebound on every
# iteration and `generative` accumulates entries. Re-running this cell alone
# continues from the mutated state, so use Restart & Run All instead.
for t in range(zs.shape[1] - 1):
    zs[:, t+1], mu, sigma, delta = ssm_step((zs[:, t], mu, sigma, delta), t, trace=generative)

In [7]:
# Inspect the recorded ground-truth trace: Z_t and X_t entries
# (presumably latent states and observations — confirm against ssm.ssm_step).
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [8]:
# Particle count shared by variational SMC and particle MH below.
num_particles = 100

In [9]:
# SMC runner over the SSM transition. The horizon is derived from zs (one slot
# per simulated transition after z_0) rather than repeating the literal 50.
# NOTE(review): assumes the second argument is the number of time steps — it
# equalled the 50 simulated transitions; confirm against smc.SequentialMonteCarlo.
smc_runner = smc.SequentialMonteCarlo(ssm.ssm_step, zs.shape[1] - 1)

In [10]:
# loc/scale hyperparameters for each of mu, sigma and delta (presumably the
# parameters of their Normal distributions). Every name gets its own freshly
# built pair of scalar tensors, so no two entries share a tensor object.
ssm_params = {
    name: {'loc': torch.tensor(0.), 'scale': torch.tensor(0.25)}
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Initial-state model whose ssm_params entries are registered as trainable
# (presumably the parameters optimized by variational SMC in the next cell).
init_ssm = combinators.Model(ssm.init_ssm, trainable=ssm_params)

In [12]:
# Run variational SMC against the simulated trace. The log reports one ELBO per
# epoch; NUM_EPOCHS is assumed to be the epoch count (4th positional arg) — confirm.
# CUDA is used only when available, so the notebook also runs on CPU-only machines
# instead of failing on a hard-coded use_cuda=True.
NUM_EPOCHS = 500
inference, init_ssm_params = smc.variational_smc(num_particles, init_ssm, smc_runner, NUM_EPOCHS, generative,
                                                 use_cuda=torch.cuda.is_available(), lr=1e-2)

07/30/2018 21:01:28 Variational SMC ELBO=-1.98209609e+04 at epoch 1
07/30/2018 21:01:28 Variational SMC ELBO=-1.44563291e+04 at epoch 2
07/30/2018 21:01:29 Variational SMC ELBO=-1.90572168e+04 at epoch 3
07/30/2018 21:01:29 Variational SMC ELBO=-1.96702227e+04 at epoch 4
07/30/2018 21:01:30 Variational SMC ELBO=-1.83048125e+04 at epoch 5
07/30/2018 21:01:30 Variational SMC ELBO=-2.40485664e+04 at epoch 6
07/30/2018 21:01:31 Variational SMC ELBO=-2.25600801e+04 at epoch 7
07/30/2018 21:01:32 Variational SMC ELBO=-2.22402617e+04 at epoch 8
07/30/2018 21:01:32 Variational SMC ELBO=-2.18565059e+04 at epoch 9
07/30/2018 21:01:32 Variational SMC ELBO=-1.48391201e+04 at epoch 10
07/30/2018 21:01:33 Variational SMC ELBO=-2.27129863e+04 at epoch 11
07/30/2018 21:01:33 Variational SMC ELBO=-1.89049863e+04 at epoch 12
07/30/2018 21:01:34 Variational SMC ELBO=-2.19309785e+04 at epoch 13
07/30/2018 21:01:34 Variational SMC ELBO=-2.08573691e+04 at epoch 14
07/30/2018 21:01:35 Variational SMC ELBO=-1

In [13]:
# Per-time-step mean squared error between the variational SMC particles and the
# ground-truth latent, for Z_1 .. Z_T.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    estimate = inference[key].value
    # Move the ground-truth value to wherever the particle values live.
    truth = utils.optional_to(generative[key], estimate)
    mse = ((estimate - truth) ** 2).mean()
    print('SMC MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 0.781043
SMC MSE at time 2: 0.013105
SMC MSE at time 3: 1.567353
SMC MSE at time 4: 1.537120
SMC MSE at time 5: 0.003059
SMC MSE at time 6: 3.099310
SMC MSE at time 7: 0.021023
SMC MSE at time 8: 0.000009
SMC MSE at time 9: 3.967583
SMC MSE at time 10: 0.940300
SMC MSE at time 11: 0.002547
SMC MSE at time 12: 0.125580
SMC MSE at time 13: 1.575669
SMC MSE at time 14: 0.009845
SMC MSE at time 15: 0.010262
SMC MSE at time 16: 8.543577
SMC MSE at time 17: 0.157636
SMC MSE at time 18: 7.986557
SMC MSE at time 19: 0.002137
SMC MSE at time 20: 0.746088
SMC MSE at time 21: 0.420019
SMC MSE at time 22: 0.088286
SMC MSE at time 23: 1.559742
SMC MSE at time 24: 5.579687
SMC MSE at time 25: 0.308456
SMC MSE at time 26: 0.986532
SMC MSE at time 27: 1.389776
SMC MSE at time 28: 1.014365
SMC MSE at time 29: 0.123951
SMC MSE at time 30: 4.969187
SMC MSE at time 31: 1.671149
SMC MSE at time 32: 0.311712
SMC MSE at time 33: 1.158751
SMC MSE at time 34: 3.839809
SMC MSE at time 35: 0.6

In [14]:
# Particle Metropolis-Hastings over the initial-state model, reusing the SMC runner.
# NOTE(review): the same ssm_params dict object is also passed as trainable= to
# init_ssm. If variational SMC updates those tensors in place, this uses the
# trained values rather than the original hyperparameters — verify which is intended.
particle_mh = smc.ParticleMH(combinators.Model(ssm.init_ssm, hyper=ssm_params), smc_runner, num_particles)

In [15]:
# Run particle MH with a resampling trace, conditioned on the simulated trace.
samples, elbos = particle_mh(
    trace=importance.ResamplerTrace(num_particles),
    guide=generative,
)
# NOTE: this rebinds `inference`, replacing the variational SMC result.
inference = particle_mh.trace

In [16]:
# Inspect the trace exposed by particle_mh.trace after the MH run.
inference

Trace{'mu': Normal([torch.FloatTensor of size 100]), 'sigma': Normal([torch.FloatTensor of size 100]), 'delta': Normal([torch.FloatTensor of size 100]), 'Z_0': Normal([torch.FloatTensor of size 100]), 'Z_1': Normal([torch.FloatTensor of size 100]), 'X_1': Normal([torch.FloatTensor of size 100x1x1]), 'Z_2': Normal([torch.FloatTensor of size 100]), 'X_2': Normal([torch.FloatTensor of size 100x1x1]), 'Z_3': Normal([torch.FloatTensor of size 100]), 'X_3': Normal([torch.FloatTensor of size 100x1x1]), 'Z_4': Normal([torch.FloatTensor of size 100]), 'X_4': Normal([torch.FloatTensor of size 100x1x1]), 'Z_5': Normal([torch.FloatTensor of size 100]), 'X_5': Normal([torch.FloatTensor of size 100x1x1]), 'Z_6': Normal([torch.FloatTensor of size 100]), 'X_6': Normal([torch.FloatTensor of size 100x1x1]), 'Z_7': Normal([torch.FloatTensor of size 100]), 'X_7': Normal([torch.FloatTensor of size 100x1x1]), 'Z_8': Normal([torch.FloatTensor of size 100]), 'X_8': Normal([torch.FloatTensor of size 100x1x1]),

In [17]:
# Re-display the ground-truth trace next to the MH trace for comparison
# (same content as the earlier display cell).
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [18]:
# Per-time-step mean squared error between the particle MH particles (the trace
# now bound to `inference`) and the ground-truth latent, for Z_1 .. Z_T.
# Labelled "Particle MH" so this report is distinguishable from the variational
# SMC report; the copy-pasted loop previously printed "SMC MSE" here too.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    estimate = inference[key].value
    # Move the ground-truth value to wherever the particle values live.
    truth = utils.optional_to(generative[key], estimate)
    mse = ((estimate - truth) ** 2).mean()
    print('Particle MH MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 0.009432
SMC MSE at time 2: 1.065997
SMC MSE at time 3: 4.582652
SMC MSE at time 4: 0.043449
SMC MSE at time 5: 5.061024
SMC MSE at time 6: 0.084090
SMC MSE at time 7: 1.724977
SMC MSE at time 8: 0.543080
SMC MSE at time 9: 1.615933
SMC MSE at time 10: 0.087068
SMC MSE at time 11: 4.662377
SMC MSE at time 12: 4.264666
SMC MSE at time 13: 5.609253
SMC MSE at time 14: 0.632022
SMC MSE at time 15: 7.029235
SMC MSE at time 16: 3.302767
SMC MSE at time 17: 0.432468
SMC MSE at time 18: 4.122752
SMC MSE at time 19: 0.880092
SMC MSE at time 20: 0.204926
SMC MSE at time 21: 3.503390
SMC MSE at time 22: 2.370379
SMC MSE at time 23: 2.794586
SMC MSE at time 24: 5.043123
SMC MSE at time 25: 1.222393
SMC MSE at time 26: 0.501718
SMC MSE at time 27: 0.277068
SMC MSE at time 28: 1.086754
SMC MSE at time 29: 1.840037
SMC MSE at time 30: 2.081943
SMC MSE at time 31: 0.432089
SMC MSE at time 32: 2.390168
SMC MSE at time 33: 2.136827
SMC MSE at time 34: 1.641999
SMC MSE at time 35: 0.1

In [19]:
# ELBO values returned by particle MH. They are presumably one entry per MH
# iteration, with repeated values corresponding to rejected proposals — confirm
# against smc.ParticleMH.
elbos

tensor([-19514.8047, -19514.8047, -19514.8047, -19514.8047, -17708.9629,
        -17708.9629, -17708.9629, -16402.4395, -16402.4395, -16402.4395,
        -16402.4395, -16402.4395, -16402.4395, -16402.4395, -16402.4395,
        -16402.4395, -16402.4395, -16402.4395, -16402.4395, -16402.4395,
        -16402.4395, -16402.4395, -16402.4395, -16402.4395, -16402.4395,
        -14867.7881, -14867.7881, -14867.7881, -14867.7881, -14867.7881,
        -14867.7881, -14867.7881, -14867.7881, -14867.7881, -14867.7881,
        -14867.7881, -14537.9990, -14537.9990, -14537.9990, -14537.9990,
        -14537.9990, -14537.9990, -14537.9990, -13280.4844, -13280.4844,
        -13280.4844, -13280.4844, -13280.4844, -13280.4844, -13280.4844,
        -13280.4844, -13280.4844, -13280.4844, -13280.4844, -13280.4844,
        -13280.4844, -13280.4844, -13280.4844, -13280.4844, -13280.4844,
        -13280.4844, -13280.4844, -13280.4844, -13280.4844, -13280.4844,
        -13280.4844, -13280.4844, -13280.4844, -132