In [1]:
import logging

import probtorch
import torch

import combinators
import importance
import mcmc
import smc
import ssm
import utils

In [2]:
# Timestamped INFO-level logging so the per-epoch ELBO messages are visible.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

In [3]:
# Ground-truth trace: the simulation loop below passes it as `trace=`, and
# afterwards it holds the Z_t / X_t sites (see Out[18]).
# NOTE(review): the positional `1` is presumably a batch size of one -- the
# repr of this trace shows size-1 tensors; confirm against combinators.
generative = combinators.BroadcastingTrace(1)

In [4]:
# Seed torch's RNG so the simulated data and every downstream inference result
# are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Initial parameter values for the ground-truth simulation; the simulation
# loop rebinds mu, sigma and delta on every step.
mu = torch.zeros(1)
sigma = torch.ones(1) / 2
delta = torch.ones(1)

# Number of simulated transitions; zs holds Z_0 .. Z_num_steps.
num_steps = 50
# Z_0 = 0; the remaining -1 entries are placeholders that the simulation loop
# overwrites with sampled latents.
zs = -torch.ones(num_steps + 1)
zs[0] = 0

In [5]:
# Wrap the transition function so it can be called with a `trace=` keyword,
# as the simulation loop below does.
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
# Simulate the ground truth: each step takes (Z_t, mu, sigma, delta) and time t,
# writes into the `generative` trace, and returns Z_{t+1} together with the
# parameters to carry into the next step.
for t in range(len(zs) - 1):
    state = (zs[t], mu, sigma, delta)
    zs[t + 1], mu, sigma, delta = ssm_step(state, t, trace=generative)

In [7]:
# Observations are the X_t sites of the ground-truth trace. Match on the key
# prefix instead of a substring so no other site whose name merely contains
# "X_" can be swept into the conditioning data.
data = generative.unwrap(lambda k, rv: k.startswith('X_'))

In [8]:
# Particle count shared by variational SMC (In [13]) and the particle-MH
# resampler trace (In [16]); it matches the size-100 tensors shown in Out[17].
num_particles = 100

In [9]:
# Normal (loc, scale) settings for each global SSM parameter. Every entry gets
# its own freshly created scalar tensors; the scale is 0.25 for all of them.
ssm_params = {
    name: {
        'loc': torch.tensor(loc),
        'scale': torch.tensor(0.25),
    }
    for name, loc in (('mu', 0.), ('sigma', 1.), ('delta', 0.))
}

In [10]:
# Target model: SMC over the transition, initialised by init_ssm with
# ssm_params passed as `hyper` (the proposal below passes them as `trainable`).
# The SMC length is taken from the simulated trajectory (zs holds Z_0..Z_T)
# instead of a hard-coded 50, so it cannot drift out of sync with the data.
init_ssm_generative = combinators.Model(ssm.init_ssm, hyper=ssm_params)
smc_ssm = smc.SequentialMonteCarlo(combinators.Model(ssm.ssm_step), zs.shape[0] - 1,
                                   initializer=init_ssm_generative)

In [11]:
# Proposal: the same initializer and transition, but with ssm_params registered
# as trainable. The sequence length is derived from the simulated trajectory
# instead of a hard-coded 50, so it stays consistent with the data.
# NOTE(review): ssm_params is the same dict passed as `hyper=` to the
# generative initializer; if combinators wraps these tensors as parameters
# without copying them, training the proposal may also move the generative
# model's hyperparameters -- confirm.
init_ssm_proposal = combinators.Model(ssm.init_ssm, trainable=ssm_params)
ssm_step_proposal = combinators.Model(ssm.ssm_step)
ssm_proposal = combinators.Model.compose(
    combinators.Model.sequence(ssm_step_proposal, zs.shape[0] - 1),
    init_ssm_proposal,
    intermediate_name='initializer',
)

In [12]:
# Pair the SMC target with the proposal whose initializer has trainable params.
ssm_importance = importance.ImportanceSampler(model=smc_ssm, proposal=ssm_proposal)

In [13]:
# Fit the proposal by variational SMC on CPU with learning rate 0.1.
# NOTE(review): the positional 1000 is presumably the number of epochs (the
# log reports "at epoch N"); the logged ELBO is very noisy -- confirm.
inference, init_ssm_params = smc.variational_smc(
    num_particles,
    ssm_importance,
    1000,
    data,
    lr=1e-1,
    use_cuda=False,
)

09/26/2018 10:52:40 Variational SMC ELBO=-2.35184998e+02 at epoch 1
09/26/2018 10:52:41 Variational SMC ELBO=-2.27124893e+02 at epoch 2
09/26/2018 10:52:41 Variational SMC ELBO=-2.09583496e+02 at epoch 3
09/26/2018 10:52:41 Variational SMC ELBO=-2.34544510e+02 at epoch 4
09/26/2018 10:52:41 Variational SMC ELBO=-2.19242767e+02 at epoch 5
09/26/2018 10:52:41 Variational SMC ELBO=-2.08407776e+02 at epoch 6
09/26/2018 10:52:42 Variational SMC ELBO=-1.86157150e+02 at epoch 7
09/26/2018 10:52:42 Variational SMC ELBO=-1.91073807e+02 at epoch 8
09/26/2018 10:52:42 Variational SMC ELBO=-1.94652908e+02 at epoch 9
09/26/2018 10:52:42 Variational SMC ELBO=-2.22669540e+02 at epoch 10
09/26/2018 10:52:42 Variational SMC ELBO=-2.09851227e+02 at epoch 11
09/26/2018 10:52:43 Variational SMC ELBO=-2.16961624e+02 at epoch 12
09/26/2018 10:52:43 Variational SMC ELBO=-2.44525116e+02 at epoch 13
09/26/2018 10:52:43 Variational SMC ELBO=-1.96299469e+02 at epoch 14
09/26/2018 10:52:43 Variational SMC ELBO=-1

In [14]:
# Per-timestep squared error of the variational-SMC particles against the
# ground-truth latent Z_t, averaged over the particle tensor.
for t in range(1, zs.shape[0]):
    key = 'Z_%d' % t
    estimate = inference[key].value
    truth = utils.optional_to(generative[key], estimate)
    mse = ((estimate - truth) ** 2).mean()
    print('SMC MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 1.006944
SMC MSE at time 2: 0.549598
SMC MSE at time 3: 0.014943
SMC MSE at time 4: 0.079370
SMC MSE at time 5: 0.009843
SMC MSE at time 6: 0.195259
SMC MSE at time 7: 0.003629
SMC MSE at time 8: 0.240854
SMC MSE at time 9: 1.701914
SMC MSE at time 10: 0.621080
SMC MSE at time 11: 0.634806
SMC MSE at time 12: 0.003889
SMC MSE at time 13: 2.259517
SMC MSE at time 14: 1.336889
SMC MSE at time 15: 1.847433
SMC MSE at time 16: 0.757955
SMC MSE at time 17: 0.096018
SMC MSE at time 18: 0.003758
SMC MSE at time 19: 0.019548
SMC MSE at time 20: 0.289756
SMC MSE at time 21: 0.024810
SMC MSE at time 22: 1.061951
SMC MSE at time 23: 3.255645
SMC MSE at time 24: 0.214557
SMC MSE at time 25: 3.353390
SMC MSE at time 26: 0.285931
SMC MSE at time 27: 1.036148
SMC MSE at time 28: 0.271831
SMC MSE at time 29: 2.221883
SMC MSE at time 30: 0.003903
SMC MSE at time 31: 1.275993
SMC MSE at time 32: 0.004608
SMC MSE at time 33: 0.280620
SMC MSE at time 34: 0.889715
SMC MSE at time 35: 0.5

In [15]:
# Particle independent Metropolis-Hastings with the SMC target and the
# sequential proposal.
# NOTE(review): the positional 1000 is presumably the number of MH iterations
# -- confirm, and consider a named constant.
particle_mh = mcmc.IndependentMH(smc_ssm, ssm_proposal, 1000)

In [16]:
# Run particle MH against a resampler trace that carries the observations.
# NOTE(review): this rebinds `inference`, replacing the variational-SMC result
# from In [13]; re-running In [14] afterwards would report particle-MH numbers.
resampler_trace = importance.ResamplerTrace(num_particles, data=data)
samples, elbos = particle_mh(trace=resampler_trace)
inference = particle_mh.trace

In [17]:
# Display the particle-MH trace (rebound in In [16]).
# NOTE(review): this repr is very long and gets truncated; consider showing a
# compact summary instead.
inference

Trace{'mu': Normal([torch.FloatTensor of size 100]), 'sigma': Normal([torch.FloatTensor of size 100]), 'delta': Normal([torch.FloatTensor of size 100]), 'Z_0': Normal([torch.FloatTensor of size 100]), 'Z_1': Normal([torch.FloatTensor of size 100]), 'X_1': Normal([torch.FloatTensor of size 100]), 'Z_2': Normal([torch.FloatTensor of size 100]), 'X_2': Normal([torch.FloatTensor of size 100]), 'Z_3': Normal([torch.FloatTensor of size 100]), 'X_3': Normal([torch.FloatTensor of size 100]), 'Z_4': Normal([torch.FloatTensor of size 100]), 'X_4': Normal([torch.FloatTensor of size 100]), 'Z_5': Normal([torch.FloatTensor of size 100]), 'X_5': Normal([torch.FloatTensor of size 100]), 'Z_6': Normal([torch.FloatTensor of size 100]), 'X_6': Normal([torch.FloatTensor of size 100]), 'Z_7': Normal([torch.FloatTensor of size 100]), 'X_7': Normal([torch.FloatTensor of size 100]), 'Z_8': Normal([torch.FloatTensor of size 100]), 'X_8': Normal([torch.FloatTensor of size 100]), 'Z_9': Normal([torch.FloatTenso

In [18]:
# Display the ground-truth trace produced by the simulation loop.
# NOTE(review): this repr is very long and gets truncated; consider showing a
# compact summary instead.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1]), 'X_1': Normal([torch.FloatTensor of size 1]), 'Z_2': Normal([torch.FloatTensor of size 1]), 'X_2': Normal([torch.FloatTensor of size 1]), 'Z_3': Normal([torch.FloatTensor of size 1]), 'X_3': Normal([torch.FloatTensor of size 1]), 'Z_4': Normal([torch.FloatTensor of size 1]), 'X_4': Normal([torch.FloatTensor of size 1]), 'Z_5': Normal([torch.FloatTensor of size 1]), 'X_5': Normal([torch.FloatTensor of size 1]), 'Z_6': Normal([torch.FloatTensor of size 1]), 'X_6': Normal([torch.FloatTensor of size 1]), 'Z_7': Normal([torch.FloatTensor of size 1]), 'X_7': Normal([torch.FloatTensor of size 1]), 'Z_8': Normal([torch.FloatTensor of size 1]), 'X_8': Normal([torch.FloatTensor of size 1]), 'Z_9': Normal([torch.FloatTensor of size 1]), 'X_9': Normal([torch.FloatTensor of size 1]), 'Z_10': Normal([torch.FloatTensor of size 1]), 'X_10': Normal([torch.FloatTensor of size 1]), 'Z_11': Normal([torch.FloatTensor of size 1]), 'X_11': Normal([torch.Flo

In [19]:
# Per-timestep squared error of the particle-MH trace against the ground-truth
# latent Z_t, averaged over the particle tensor. `inference` here is
# particle_mh.trace (In [16]), so the report is labelled "Particle MH"; the
# copy-pasted "SMC" label made it indistinguishable from the In [14] output.
for t in range(1, zs.shape[0]):
    key = 'Z_%d' % t
    estimate = inference[key].value
    mse = ((estimate - utils.optional_to(generative[key], estimate)) ** 2).mean()
    print('Particle MH MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 2.254701
SMC MSE at time 2: 0.300078
SMC MSE at time 3: 0.368085
SMC MSE at time 4: 1.893311
SMC MSE at time 5: 2.638643
SMC MSE at time 6: 0.098539
SMC MSE at time 7: 0.847202
SMC MSE at time 8: 2.281480
SMC MSE at time 9: 1.685320
SMC MSE at time 10: 4.645953
SMC MSE at time 11: 1.418363
SMC MSE at time 12: 0.315594
SMC MSE at time 13: 5.387441
SMC MSE at time 14: 0.551522
SMC MSE at time 15: 0.086991
SMC MSE at time 16: 0.000006
SMC MSE at time 17: 0.515793
SMC MSE at time 18: 3.139635
SMC MSE at time 19: 0.748002
SMC MSE at time 20: 0.144948
SMC MSE at time 21: 0.826032
SMC MSE at time 22: 1.309672
SMC MSE at time 23: 0.967631
SMC MSE at time 24: 0.169935
SMC MSE at time 25: 0.091448
SMC MSE at time 26: 0.125816
SMC MSE at time 27: 0.014457
SMC MSE at time 28: 0.166079
SMC MSE at time 29: 4.631927
SMC MSE at time 30: 0.081573
SMC MSE at time 31: 0.162347
SMC MSE at time 32: 0.031665
SMC MSE at time 33: 0.602391
SMC MSE at time 34: 3.427154
SMC MSE at time 35: 0.0

In [20]:
# ELBO values returned by the particle-MH run. The long runs of identical values
# presumably correspond to rejected proposals that kept the previous state.
elbos

tensor([-148.3310, -148.3310, -146.1279, -140.7279, -141.6854, -134.1581,
        -134.1581, -134.1581, -134.1581, -134.1581, -134.1581, -134.1581,
        -134.1581, -134.1581, -134.1581, -134.1581, -134.1581, -134.1581,
        -134.1581, -134.1581, -134.1581, -134.1581, -134.1581, -134.1581,
        -134.1581, -134.1581, -134.1581, -134.1581, -134.1581, -134.1581,
        -134.1581, -134.1581, -134.1581, -134.1581, -134.1581, -134.1581,
        -134.1581, -134.1581, -134.1581, -125.6691, -125.6691, -125.6691,
        -125.6691, -125.6691, -125.6691, -125.6691, -125.6691, -125.6691,
        -125.6691, -125.6691, -125.6691, -125.6691, -125.6691, -125.6691,
        -125.6691, -125.6691, -125.6691, -125.6691, -125.6691, -125.6691,
        -125.6691, -125.6691, -125.6691, -125.6691, -125.6691, -125.6691,
        -125.6691, -125.6691, -125.6691, -125.6691, -125.6691, -125.6691,
        -125.6691, -125.6691, -125.6691, -125.6691, -125.6691, -125.6691,
        -125.6691, -125.6691, -125.669