In [1]:
import logging

import probtorch
import torch

import combinators
import importance
import smc
import ssm
import utils

In [2]:
# The generative simulation, variational SMC and particle MH below are all
# stochastic; seed torch's RNG so Restart & Run All reproduces the same
# simulated trajectory and the same inference results.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging so the per-epoch "Variational SMC ELBO=..."
# messages emitted during training show up in the cell output.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
# Trace that records the ground-truth random variables (the Z_t / X_t entries
# displayed below) drawn by the simulation loop. The argument 1 is presumably
# the particle count for this single simulated trajectory — confirm against
# combinators.ParticleTrace.
generative = combinators.ParticleTrace(1)

In [4]:
# Initial state for the ground-truth simulation.
# num_steps is the number of ssm_step transitions simulated below; zs holds
# z_0..z_{num_steps}, with z_0 fixed at 0 and the rest initialised to -1 as a
# placeholder until the simulation loop overwrites them.
# mu, sigma and delta are the initial values threaded through ssm_step.
num_steps = 50
mu = torch.zeros(1, 1)
sigma = torch.ones(1, 1) / 2
delta = torch.ones(1, 1)
zs = torch.full((1, num_steps + 1), -1.)
zs[:, 0] = 0

In [5]:
# Wrap the SSM transition function in a combinators.Model so it can be called
# with a `trace=` keyword (as the simulation loop does), recording its random
# variables into that trace.
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
# Simulate the ground-truth trajectory: step t maps z_t (plus the current
# mu/sigma/delta) to z_{t+1}, recording its random variables into `generative`.
# NOTE: not idempotent — re-running this cell without first re-running the
# initial-state cell continues from the updated mu/sigma/delta and writes into
# the same `generative` trace again.
for t in range(zs.shape[1] - 1):
    state = (zs[:, t], mu, sigma, delta)
    next_z, mu, sigma, delta = ssm_step(state, t, trace=generative)
    zs[:, t + 1] = next_z

In [7]:
# Inspect the simulated trace: one Z_t / X_t entry per time step.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [8]:
# Particle count shared by the variational SMC and particle MH runs below.
num_particles = 100

In [9]:
# SMC sampler over ssm.ssm_step. The second argument appears to be the number
# of time steps (it equals the 50 transitions simulated above) — confirm
# against smc.smc. It is derived from the simulated trajectory (zs holds
# z_0..z_T, so T = zs.shape[1] - 1) rather than repeating a literal 50, so the
# two cannot drift apart when the simulation length changes.
smc_runner = smc.smc(ssm.ssm_step, zs.shape[1] - 1)

In [10]:
# Hyperparameters for the three SSM parameters mu, sigma and delta, each given
# as a 'loc'/'scale' pair (presumably of a Normal distribution — confirm
# against ssm.init_ssm). The comprehension builds fresh scalar tensors for
# every entry, so no two parameters share a tensor object.
ssm_params = {
    name: {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    }
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Model over ssm.init_ssm with ssm_params passed as `trainable` — presumably
# the parameters that variational SMC optimises in the next cell.
init_ssm = combinators.Model(ssm.init_ssm, trainable=ssm_params)

In [None]:
# Fit init_ssm by variational SMC against the simulated `generative` trace with
# num_particles particles and learning rate 1e-2. The positional 500 is
# presumably the number of epochs (the log reports "at epoch N") — confirm
# against smc.variational_smc.
# Request the GPU only when one is present so the notebook also runs on
# CPU-only machines.
inference, init_ssm_params = smc.variational_smc(num_particles, init_ssm, smc_runner, 500, generative,
                                                 use_cuda=torch.cuda.is_available(), lr=1e-2)

07/30/2018 18:40:23 Variational SMC ELBO=-2.30102441e+04 at epoch 1
07/30/2018 18:40:23 Variational SMC ELBO=-1.57867451e+04 at epoch 2
07/30/2018 18:40:24 Variational SMC ELBO=-1.92311465e+04 at epoch 3
07/30/2018 18:40:24 Variational SMC ELBO=-1.64848809e+04 at epoch 4
07/30/2018 18:40:25 Variational SMC ELBO=-1.52201025e+04 at epoch 5
07/30/2018 18:40:25 Variational SMC ELBO=-1.54481221e+04 at epoch 6
07/30/2018 18:40:26 Variational SMC ELBO=-1.31173076e+04 at epoch 7
07/30/2018 18:40:26 Variational SMC ELBO=-2.30344199e+04 at epoch 8
07/30/2018 18:40:27 Variational SMC ELBO=-1.51401982e+04 at epoch 9
07/30/2018 18:40:27 Variational SMC ELBO=-2.20861562e+04 at epoch 10
07/30/2018 18:40:28 Variational SMC ELBO=-1.99741074e+04 at epoch 11
07/30/2018 18:40:28 Variational SMC ELBO=-2.25494102e+04 at epoch 12
07/30/2018 18:40:29 Variational SMC ELBO=-1.85026855e+04 at epoch 13
07/30/2018 18:40:29 Variational SMC ELBO=-2.10304473e+04 at epoch 14
07/30/2018 18:40:30 Variational SMC ELBO=-1

In [None]:
# Per-time-step mean squared error between the variational-SMC latent values
# and the ground-truth latents recorded in `generative`, averaged over all
# elements of the inferred value (e.g. particles — confirm its shape).
# utils.optional_to presumably moves the ground truth onto the estimate's
# device/dtype.
for t in range(1, zs.shape[1]):
    name = 'Z_%d' % t
    estimate = inference[name].value
    truth = utils.optional_to(generative[name], estimate)
    mse = (estimate - truth).pow(2).mean()
    print('SMC MSE at time %d: %f' % (t, mse))

In [None]:
# Particle MH sampler built from a fresh combinators.Model over ssm.init_ssm,
# with ssm_params passed as `hyper` (presumably fixed rather than trainable),
# together with the shared smc_runner and num_particles.
# NOTE(review): this does not use the init_ssm_params returned by variational
# SMC above — confirm that is intended.
mh_init_model = combinators.Model(ssm.init_ssm, hyper=ssm_params)
particle_mh = smc.ParticleMH(mh_init_model, smc_runner, num_particles)

In [None]:
# Run particle MH with a fresh ResamplerTrace of num_particles particles and the
# simulated `generative` trace passed as guide= (presumably supplying the
# observations — confirm against smc.ParticleMH).
# NOTE: this rebinds `inference` to the particle MH trace, replacing the
# variational SMC result; later cells that read `inference` see the MH trace.
samples, elbos = particle_mh(trace=importance.ResamplerTrace(num_particles), guide=generative)
inference = particle_mh.trace

In [None]:
# Trace produced by the particle MH run.
inference

In [None]:
# Ground-truth trace, for side-by-side comparison with the inferred trace.
generative

In [None]:
# Per-time-step mean squared error between the particle-MH latent values (at
# this point `inference` is particle_mh.trace) and the ground-truth latents
# recorded in `generative`, averaged over all elements of the inferred value.
# Labelled "Particle MH" so this table is distinguishable from the variational
# SMC table printed earlier. utils.optional_to presumably moves the ground
# truth onto the estimate's device/dtype.
for t in range(1, zs.shape[1]):
    name = 'Z_%d' % t
    estimate = inference[name].value
    truth = utils.optional_to(generative[name], estimate)
    mse = (estimate - truth).pow(2).mean()
    print('Particle MH MSE at time %d: %f' % (t, mse))

In [None]:
# Second value returned by the particle MH run (bound as `elbos`).
elbos