In [1]:
import logging

import probtorch
import torch

import combinators
import importance
import smc
import ssm
import utils

In [2]:
# Seed torch's RNG before anything is sampled. The synthetic data and both
# inference runs below are stochastic, so without a fixed seed every
# Restart & Run All yields different data, ELBOs and MSEs.
# torch.manual_seed also seeds all CUDA devices.
# NOTE(review): if the project modules also draw from numpy/random, seed those too.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging so progress messages emitted during
# inference (e.g. the "Variational SMC ELBO=..." lines) appear in the notebook.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
# Trace that the simulation loop below passes as `trace=`; afterwards it holds
# the simulated Z_t / X_t random variables (see its displayed repr further down).
# NOTE(review): assumes the constructor argument is a particle count — confirm.
num_generative_particles = 1
generative = combinators.ParticleTrace(num_generative_particles)

In [4]:
# Starting parameters for the simulation, plus a buffer for the latent
# trajectory z_0..z_T. z_0 is 0; the remaining entries hold a -1 placeholder
# that the simulation loop below overwrites.
NUM_STEPS = 50
mu = torch.zeros(1)
sigma = torch.full((1,), 0.5)
delta = torch.ones(1)
zs = torch.full((NUM_STEPS + 1,), -1.0)
zs[0] = 0

In [5]:
# Wrap the raw transition function ssm.ssm_step in a combinators.Model so it
# can be called with a `trace=` keyword, as the simulation loop below does.
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
# Simulate the ground-truth trajectory. Each step receives
# (z_t, mu, sigma, delta) and the step index, and returns the next latent
# together with the parameters to carry forward; its random variables are
# recorded in `generative`.
# NOTE(review): this rebinds the globals mu/sigma/delta and (presumably)
# appends to `generative`, so re-running this cell on its own does not
# reproduce the same data — re-run from the cells above instead.
for t in range(len(zs) - 1):
    carried = (zs[t], mu, sigma, delta)
    zs[t + 1], mu, sigma, delta = ssm_step(carried, t, trace=generative)

In [7]:
# Observed data: select the observation variables (keys 'X_1', 'X_2', ...)
# from the simulated trace; the latents 'Z_t' are what inference must recover.
# NOTE(review): presumably unwrap keeps the entries for which the predicate is true.
# Match on the prefix, not a substring, so a key that merely contains 'X_'
# somewhere is not mistaken for an observation.
data = generative.unwrap(lambda k, rv: k.startswith('X_'))

In [8]:
# Particle count used by both the variational SMC run and the particle-MH run below.
num_particles: int = 100

In [9]:
# SMC runner over the raw transition function. The horizon is taken from the
# simulated trajectory (len(zs) - 1, the same step count the simulation loop
# used) rather than a second hard-coded 50 that could drift out of sync.
# NOTE(review): assumes the second argument is the number of time steps — confirm.
smc_runner = smc.SequentialMonteCarlo(ssm.ssm_step, zs.shape[0] - 1)

In [10]:
# (loc, scale) settings for the three global SSM parameters mu, sigma and
# delta. Used as `trainable=` for init_ssm and as `hyper=` for the particle-MH
# prior below. Every entry gets its own fresh 0-dim tensors (loc=0, scale=0.25),
# so nothing is shared between parameters.
ssm_params = {
    name: {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    }
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Model for ssm.init_ssm whose parameters are initialised from ssm_params and
# marked trainable, so the variational SMC run below can optimise them.
init_ssm = combinators.Model(ssm.init_ssm, trainable=ssm_params)

In [12]:
# Variational SMC: optimise init_ssm's trainable parameters against the
# observed data using num_particles particles, with learning rate 1e-2.
# NOTE(review): assumes the 4th positional argument is the epoch count
# (the log reports "at epoch N") — confirm.
# Request CUDA only when a GPU is present so the notebook also runs on
# CPU-only kernels.
num_epochs = 500
inference, init_ssm_params = smc.variational_smc(num_particles, init_ssm, smc_runner, num_epochs, data,
                                                 use_cuda=torch.cuda.is_available(), lr=1e-2)

08/03/2018 12:49:58 Variational SMC ELBO=-2.77575317e+02 at epoch 1
08/03/2018 12:49:58 Variational SMC ELBO=-2.82259155e+02 at epoch 2
08/03/2018 12:49:59 Variational SMC ELBO=-3.28605377e+02 at epoch 3
08/03/2018 12:50:00 Variational SMC ELBO=-3.02955841e+02 at epoch 4
08/03/2018 12:50:01 Variational SMC ELBO=-2.87140411e+02 at epoch 5
08/03/2018 12:50:01 Variational SMC ELBO=-2.96396667e+02 at epoch 6
08/03/2018 12:50:02 Variational SMC ELBO=-2.68181396e+02 at epoch 7
08/03/2018 12:50:03 Variational SMC ELBO=-2.88672729e+02 at epoch 8
08/03/2018 12:50:04 Variational SMC ELBO=-3.08635162e+02 at epoch 9
08/03/2018 12:50:04 Variational SMC ELBO=-2.95626312e+02 at epoch 10
08/03/2018 12:50:05 Variational SMC ELBO=-3.07722900e+02 at epoch 11
08/03/2018 12:50:06 Variational SMC ELBO=-2.91829834e+02 at epoch 12
08/03/2018 12:50:06 Variational SMC ELBO=-2.84201172e+02 at epoch 13
08/03/2018 12:50:07 Variational SMC ELBO=-2.70907013e+02 at epoch 14
08/03/2018 12:50:08 Variational SMC ELBO=-3

In [13]:
# Per-time-step mean squared error of the variational SMC particles for each
# latent Z_t against the simulated ground truth. Inference entries hold one
# value per particle while the ground-truth entry holds a single value, so the
# squared error broadcasts over particles before averaging.
# (optional_to presumably moves the ground truth to the particles' device/type.)
for t in range(1, len(zs)):
    key = 'Z_%d' % t
    particles = inference[key].value
    truth = utils.optional_to(generative[key], particles)
    mse = ((particles - truth) ** 2).mean()
    print('SMC MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 0.430157
SMC MSE at time 2: 0.830760
SMC MSE at time 3: 1.347285
SMC MSE at time 4: 0.405974
SMC MSE at time 5: 0.723052
SMC MSE at time 6: 0.196357
SMC MSE at time 7: 1.826710
SMC MSE at time 8: 1.961942
SMC MSE at time 9: 0.018973
SMC MSE at time 10: 0.107938
SMC MSE at time 11: 0.081867
SMC MSE at time 12: 0.317769
SMC MSE at time 13: 0.957932
SMC MSE at time 14: 1.196552
SMC MSE at time 15: 0.501348
SMC MSE at time 16: 0.966183
SMC MSE at time 17: 1.888758
SMC MSE at time 18: 0.939833
SMC MSE at time 19: 3.437114
SMC MSE at time 20: 0.623674
SMC MSE at time 21: 0.354927
SMC MSE at time 22: 0.718752
SMC MSE at time 23: 1.780598
SMC MSE at time 24: 0.658953
SMC MSE at time 25: 1.914777
SMC MSE at time 26: 2.762562
SMC MSE at time 27: 0.048602
SMC MSE at time 28: 0.866014
SMC MSE at time 29: 1.399533
SMC MSE at time 30: 3.000289
SMC MSE at time 31: 1.581892
SMC MSE at time 32: 0.312253
SMC MSE at time 33: 2.109411
SMC MSE at time 34: 1.018794
SMC MSE at time 35: 0.9

In [14]:
# Particle MH sampler: an init_ssm model whose parameters are hyperparameters
# (rather than trainable), driven by the same SMC runner and particle count.
# NOTE(review): this uses the original ssm_params, not the init_ssm_params
# learned by variational SMC above — confirm that is intended.
init_ssm_prior = combinators.Model(ssm.init_ssm, hyper=ssm_params)
particle_mh = smc.ParticleMH(init_ssm_prior, smc_runner, num_particles)

In [15]:
# Run particle MH with a resampling trace of num_particles particles that
# holds the observed data, then keep the resulting trace for evaluation.
# NOTE(review): this rebinds `inference`, discarding the variational SMC
# trace; the cells below evaluate the particle-MH trace.
resampler_trace = importance.ResamplerTrace(num_particles, data=data)
samples, elbos = particle_mh(trace=resampler_trace)
inference = particle_mh.trace

In [16]:
# Inspect the particle-MH trace: one entry per random variable, each holding
# num_particles values.
inference

Trace{'mu': Normal([torch.FloatTensor of size 100]), 'sigma': Normal([torch.FloatTensor of size 100]), 'delta': Normal([torch.FloatTensor of size 100]), 'Z_0': Normal([torch.FloatTensor of size 100]), 'Z_1': Normal([torch.FloatTensor of size 100]), 'X_1': Normal([torch.FloatTensor of size 100]), 'Z_2': Normal([torch.FloatTensor of size 100]), 'X_2': Normal([torch.FloatTensor of size 100]), 'Z_3': Normal([torch.FloatTensor of size 100]), 'X_3': Normal([torch.FloatTensor of size 100]), 'Z_4': Normal([torch.FloatTensor of size 100]), 'X_4': Normal([torch.FloatTensor of size 100]), 'Z_5': Normal([torch.FloatTensor of size 100]), 'X_5': Normal([torch.FloatTensor of size 100]), 'Z_6': Normal([torch.FloatTensor of size 100]), 'X_6': Normal([torch.FloatTensor of size 100]), 'Z_7': Normal([torch.FloatTensor of size 100]), 'X_7': Normal([torch.FloatTensor of size 100]), 'Z_8': Normal([torch.FloatTensor of size 100]), 'X_8': Normal([torch.FloatTensor of size 100]), 'Z_9': Normal([torch.FloatTenso

In [17]:
# Inspect the simulated ground-truth trace (single-valued Z_t / X_t entries).
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1]), 'X_1': Normal([torch.FloatTensor of size 1]), 'Z_2': Normal([torch.FloatTensor of size 1]), 'X_2': Normal([torch.FloatTensor of size 1]), 'Z_3': Normal([torch.FloatTensor of size 1]), 'X_3': Normal([torch.FloatTensor of size 1]), 'Z_4': Normal([torch.FloatTensor of size 1]), 'X_4': Normal([torch.FloatTensor of size 1]), 'Z_5': Normal([torch.FloatTensor of size 1]), 'X_5': Normal([torch.FloatTensor of size 1]), 'Z_6': Normal([torch.FloatTensor of size 1]), 'X_6': Normal([torch.FloatTensor of size 1]), 'Z_7': Normal([torch.FloatTensor of size 1]), 'X_7': Normal([torch.FloatTensor of size 1]), 'Z_8': Normal([torch.FloatTensor of size 1]), 'X_8': Normal([torch.FloatTensor of size 1]), 'Z_9': Normal([torch.FloatTensor of size 1]), 'X_9': Normal([torch.FloatTensor of size 1]), 'Z_10': Normal([torch.FloatTensor of size 1]), 'X_10': Normal([torch.FloatTensor of size 1]), 'Z_11': Normal([torch.FloatTensor of size 1]), 'X_11': Normal([torch.Flo

In [18]:
# Per-time-step mean squared error of the particle-MH particles for each
# latent Z_t against the simulated ground truth. Squared errors broadcast the
# single ground-truth value over the particles before averaging.
# The label says "Particle MH" so this table can be told apart from the
# variational SMC table printed by the earlier cell, which used the same
# "SMC MSE" label.
# The loop variable is named `mse` (not `accuracy`) because lower is better.
for t in range(1, len(zs)):
    key = 'Z_%d' % t
    particles = inference[key].value
    truth = utils.optional_to(generative[key], particles)
    mse = ((particles - truth) ** 2).mean()
    print('Particle MH MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 0.151076
SMC MSE at time 2: 0.159158
SMC MSE at time 3: 1.866906
SMC MSE at time 4: 0.775655
SMC MSE at time 5: 0.919981
SMC MSE at time 6: 0.060901
SMC MSE at time 7: 0.022647
SMC MSE at time 8: 0.513106
SMC MSE at time 9: 0.116411
SMC MSE at time 10: 0.675727
SMC MSE at time 11: 0.100060
SMC MSE at time 12: 0.710672
SMC MSE at time 13: 0.540825
SMC MSE at time 14: 0.000906
SMC MSE at time 15: 0.313400
SMC MSE at time 16: 1.269102
SMC MSE at time 17: 0.044479
SMC MSE at time 18: 3.085858
SMC MSE at time 19: 0.737249
SMC MSE at time 20: 2.006048
SMC MSE at time 21: 0.162627
SMC MSE at time 22: 0.115693
SMC MSE at time 23: 0.142863
SMC MSE at time 24: 1.426639
SMC MSE at time 25: 0.075525
SMC MSE at time 26: 0.240287
SMC MSE at time 27: 0.324577
SMC MSE at time 28: 0.721841
SMC MSE at time 29: 0.875001
SMC MSE at time 30: 0.928948
SMC MSE at time 31: 1.673232
SMC MSE at time 32: 0.271159
SMC MSE at time 33: 0.257160
SMC MSE at time 34: 0.344125
SMC MSE at time 35: 0.1

In [19]:
# ELBO values returned by the particle-MH run.
# NOTE(review): presumably one value per MH iteration; runs of repeated values
# would then correspond to rejected proposals — confirm.
elbos

tensor([-317.3787, -317.3787, -299.6674, -298.3973, -290.7737, -290.7737,
        -277.3042, -277.3042, -277.3042, -277.3042, -277.3042, -277.3042,
        -277.3042, -270.0576, -270.0576, -270.0576, -270.0576, -270.0576,
        -270.0576, -270.0576, -270.0576, -270.0576, -270.0576, -270.0576,
        -270.0576, -270.0576, -270.0576, -270.0576, -270.0576, -270.0576,
        -270.0576, -270.0576, -270.0576, -270.0576, -270.0576, -270.0576,
        -270.0576, -270.0576, -270.0576, -270.0576, -270.0576, -270.0576,
        -270.0576, -270.0576, -270.0576, -270.0576, -270.0576, -270.0576,
        -270.0576, -270.0576, -268.4351, -268.4351, -268.4351, -266.5109,
        -266.5109, -266.5109, -266.5109, -266.5109, -266.5109, -266.5109,
        -266.5109, -266.5109, -266.5109, -266.5109, -266.5109, -266.5109,
        -266.5109, -266.5109, -266.5109, -266.5109, -266.5109, -266.5109,
        -266.5109, -266.5109, -266.5109, -266.5109, -266.5109, -266.5109,
        -266.5109, -266.5109, -266.510