In [1]:
import logging

import probtorch
import torch

import combinators
import importance
import smc
import ssm
import utils

In [2]:
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
generative = combinators.ParticleTrace(1)

In [4]:
mu = torch.zeros(1)
sigma = torch.ones(1) / 2
delta = torch.ones(1)
zs = torch.ones(50+1) * -1
zs[0] = 0

In [5]:
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
for t in range(zs.shape[0] - 1):
    zs[t+1], mu, sigma, delta = ssm_step((zs[t], mu, sigma, delta), t, trace=generative)

In [7]:
data = generative.unwrap(lambda k, rv: 'X_' in k)

In [8]:
num_particles = 100

In [9]:
smc_runner = smc.SequentialMonteCarlo(ssm.ssm_step, 50)

In [10]:
ssm_params = {
    'mu': {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    },
    'sigma': {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    },
    'delta': {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    }
}

In [11]:
init_ssm_proposal = combinators.Model(ssm.init_ssm, trainable=ssm_params)

In [12]:
init_ssm_generative = combinators.Model(ssm.init_ssm, hyper=ssm_params)

In [13]:
inference, init_ssm_params = smc.variational_smc(num_particles, init_ssm_generative, smc_runner, 500, data,
                                                 init_ssm_proposal, use_cuda=True, lr=1e-1)

08/21/2018 12:53:22 Variational SMC ELBO=-2.86959229e+02 at epoch 1
08/21/2018 12:53:23 Variational SMC ELBO=-2.86898285e+02 at epoch 2
08/21/2018 12:53:24 Variational SMC ELBO=-2.92547546e+02 at epoch 3
08/21/2018 12:53:24 Variational SMC ELBO=-3.09805481e+02 at epoch 4
08/21/2018 12:53:25 Variational SMC ELBO=-3.00864441e+02 at epoch 5
08/21/2018 12:53:26 Variational SMC ELBO=-3.02714294e+02 at epoch 6
08/21/2018 12:53:26 Variational SMC ELBO=-2.88314331e+02 at epoch 7
08/21/2018 12:53:27 Variational SMC ELBO=-2.77015015e+02 at epoch 8
08/21/2018 12:53:27 Variational SMC ELBO=-2.94663055e+02 at epoch 9
08/21/2018 12:53:28 Variational SMC ELBO=-2.91173950e+02 at epoch 10
08/21/2018 12:53:29 Variational SMC ELBO=-3.27092651e+02 at epoch 11
08/21/2018 12:53:29 Variational SMC ELBO=-3.04665985e+02 at epoch 12
08/21/2018 12:53:30 Variational SMC ELBO=-3.12079620e+02 at epoch 13
08/21/2018 12:53:31 Variational SMC ELBO=-2.73465485e+02 at epoch 14
08/21/2018 12:53:31 Variational SMC ELBO=-2

In [14]:
for t in range(1, zs.shape[0]):
    key = 'Z_%d' % t
    accuracy = ((inference[key].value - utils.optional_to(generative[key], inference[key].value))**2).mean()
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 22.759600
SMC MSE at time 2: 10.453484
SMC MSE at time 3: 0.945925
SMC MSE at time 4: 1.109453
SMC MSE at time 5: 0.213952
SMC MSE at time 6: 0.556554
SMC MSE at time 7: 0.215961
SMC MSE at time 8: 0.570421
SMC MSE at time 9: 0.311704
SMC MSE at time 10: 3.853765
SMC MSE at time 11: 7.211798
SMC MSE at time 12: 7.302386
SMC MSE at time 13: 6.036883
SMC MSE at time 14: 0.192786
SMC MSE at time 15: 0.255359
SMC MSE at time 16: 0.091930
SMC MSE at time 17: 1.540784
SMC MSE at time 18: 0.348285
SMC MSE at time 19: 1.093883
SMC MSE at time 20: 1.155147
SMC MSE at time 21: 0.358450
SMC MSE at time 22: 0.318781
SMC MSE at time 23: 4.058233
SMC MSE at time 24: 0.972230
SMC MSE at time 25: 0.055192
SMC MSE at time 26: 0.066936
SMC MSE at time 27: 0.153464
SMC MSE at time 28: 1.231830
SMC MSE at time 29: 4.811998
SMC MSE at time 30: 3.047478
SMC MSE at time 31: 0.011762
SMC MSE at time 32: 4.430334
SMC MSE at time 33: 1.259596
SMC MSE at time 34: 0.266555
SMC MSE at time 35: 0

In [15]:
particle_mh = smc.ParticleMH(combinators.Model(ssm.init_ssm, hyper=ssm_params), smc_runner, num_particles)

In [16]:
samples, elbos = particle_mh(trace=importance.ResamplerTrace(num_particles, data=data))
inference = particle_mh.trace

In [17]:
inference

Trace{'mu': Normal([torch.FloatTensor of size 100]), 'sigma': Normal([torch.FloatTensor of size 100]), 'delta': Normal([torch.FloatTensor of size 100]), 'Z_0': Normal([torch.FloatTensor of size 100]), 'Z_1': Normal([torch.FloatTensor of size 100]), 'X_1': Normal([torch.FloatTensor of size 100]), 'Z_2': Normal([torch.FloatTensor of size 100]), 'X_2': Normal([torch.FloatTensor of size 100]), 'Z_3': Normal([torch.FloatTensor of size 100]), 'X_3': Normal([torch.FloatTensor of size 100]), 'Z_4': Normal([torch.FloatTensor of size 100]), 'X_4': Normal([torch.FloatTensor of size 100]), 'Z_5': Normal([torch.FloatTensor of size 100]), 'X_5': Normal([torch.FloatTensor of size 100]), 'Z_6': Normal([torch.FloatTensor of size 100]), 'X_6': Normal([torch.FloatTensor of size 100]), 'Z_7': Normal([torch.FloatTensor of size 100]), 'X_7': Normal([torch.FloatTensor of size 100]), 'Z_8': Normal([torch.FloatTensor of size 100]), 'X_8': Normal([torch.FloatTensor of size 100]), 'Z_9': Normal([torch.FloatTenso

In [18]:
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1]), 'X_1': Normal([torch.FloatTensor of size 1]), 'Z_2': Normal([torch.FloatTensor of size 1]), 'X_2': Normal([torch.FloatTensor of size 1]), 'Z_3': Normal([torch.FloatTensor of size 1]), 'X_3': Normal([torch.FloatTensor of size 1]), 'Z_4': Normal([torch.FloatTensor of size 1]), 'X_4': Normal([torch.FloatTensor of size 1]), 'Z_5': Normal([torch.FloatTensor of size 1]), 'X_5': Normal([torch.FloatTensor of size 1]), 'Z_6': Normal([torch.FloatTensor of size 1]), 'X_6': Normal([torch.FloatTensor of size 1]), 'Z_7': Normal([torch.FloatTensor of size 1]), 'X_7': Normal([torch.FloatTensor of size 1]), 'Z_8': Normal([torch.FloatTensor of size 1]), 'X_8': Normal([torch.FloatTensor of size 1]), 'Z_9': Normal([torch.FloatTensor of size 1]), 'X_9': Normal([torch.FloatTensor of size 1]), 'Z_10': Normal([torch.FloatTensor of size 1]), 'X_10': Normal([torch.FloatTensor of size 1]), 'Z_11': Normal([torch.FloatTensor of size 1]), 'X_11': Normal([torch.Flo

In [19]:
for t in range(1, zs.shape[0]):
    key = 'Z_%d' % t
    accuracy = ((inference[key].value - utils.optional_to(generative[key], inference[key].value))**2).mean()
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 0.909398
SMC MSE at time 2: 0.020387
SMC MSE at time 3: 2.317432
SMC MSE at time 4: 0.346826
SMC MSE at time 5: 0.045961
SMC MSE at time 6: 0.306583
SMC MSE at time 7: 0.525512
SMC MSE at time 8: 0.034793
SMC MSE at time 9: 0.050570
SMC MSE at time 10: 2.704625
SMC MSE at time 11: 8.628906
SMC MSE at time 12: 5.400190
SMC MSE at time 13: 6.256023
SMC MSE at time 14: 0.110076
SMC MSE at time 15: 0.543156
SMC MSE at time 16: 0.027027
SMC MSE at time 17: 1.933975
SMC MSE at time 18: 1.114132
SMC MSE at time 19: 1.882058
SMC MSE at time 20: 4.439986
SMC MSE at time 21: 0.836910
SMC MSE at time 22: 0.174652
SMC MSE at time 23: 1.288056
SMC MSE at time 24: 0.033283
SMC MSE at time 25: 0.950033
SMC MSE at time 26: 0.882588
SMC MSE at time 27: 1.606020
SMC MSE at time 28: 0.809595
SMC MSE at time 29: 2.311864
SMC MSE at time 30: 2.021975
SMC MSE at time 31: 0.147093
SMC MSE at time 32: 3.453352
SMC MSE at time 33: 0.319565
SMC MSE at time 34: 0.176308
SMC MSE at time 35: 0.3

In [20]:
elbos

tensor([-301.3392, -303.0569, -300.5208, -300.5208, -300.5208, -272.4926,
        -272.4926, -272.4926, -272.4926, -272.4926, -272.4926, -272.4926,
        -272.4926, -272.4926, -272.4926, -272.4926, -272.4926, -272.4926,
        -272.4926, -272.4926, -272.4926, -272.4926, -272.4926, -272.4926,
        -272.4926, -272.4926, -272.4926, -272.4926, -272.4926, -272.4926,
        -272.4926, -272.4926, -272.4926, -272.4926, -272.4926, -272.4926,
        -272.4926, -272.4926, -272.4926, -272.4926, -272.4926, -272.4926,
        -272.4926, -272.4926, -272.4926, -272.4926, -270.7831, -270.7831,
        -270.7831, -270.7831, -270.7831, -270.7831, -263.6644, -263.6644,
        -263.6644, -263.6644, -263.6644, -263.6644, -263.6644, -263.6644,
        -263.6644, -263.6644, -263.6644, -263.6644, -263.6644, -263.6644,
        -263.6644, -263.6644, -263.6644, -263.6644, -263.6644, -263.6644,
        -263.6644, -263.6644, -263.6644, -263.6644, -263.6644, -263.6644,
        -263.6644, -263.6644, -263.664