In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import ssm
import utils

In [2]:
# Notebook-wide configuration: timestamped INFO logging (the training loops below log
# their progress) and a fixed RNG seed so the synthetic data and inference runs are
# reproducible under Restart & Run All.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
SEED = 0
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in the cell output

In [3]:
# Trace that records the ground-truth random choices (the Z_t / X_t sites shown below)
# of the synthetic run. NOTE(review): the argument 1 is presumably the particle count
# (the displayed trace holds 1x1 tensors) — confirm against combinators.ParticleTrace.
generative = combinators.ParticleTrace(1)

In [4]:
# Initial SSM parameters, which are threaded through `ssm_step` in the simulation loop.
# Also a buffer for the latent trajectory: column 0 holds the initial state (0) and
# columns 1..NUM_TIMESTEPS are -1 placeholders overwritten by that loop.
# The horizon is named once here instead of being a bare 50.
NUM_TIMESTEPS = 50
mu = torch.zeros(1, 1)
sigma = torch.ones(1, 1) / 2
delta = torch.ones(1, 1)
zs = torch.full((1, NUM_TIMESTEPS + 1), -1.0)
zs[:, 0] = 0

In [5]:
# Wrap the raw transition function in a combinators.Model; the wrapped step is what
# the simulation loop below calls with a `trace=` keyword.
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
# Roll the model forward one column of `zs` at a time. Each call receives the previous
# latent value plus (mu, sigma, delta) and records its choices into `generative` via
# `trace=`. It returns the next latent value and the parameters for the next step.
for t in range(zs.shape[1] - 1):
    z_next, mu, sigma, delta = ssm_step((zs[:, t], mu, sigma, delta), t, trace=generative)
    zs[:, t + 1] = z_next

In [7]:
# Inspect the recorded ground-truth trace (sites Z_1, X_1, Z_2, X_2, ...; presumably
# latents and observations respectively).
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [8]:
# Particle count shared by the variational-SMC and particle-MH runs below.
num_particles = 100

In [9]:
# Build the SMC sampler over the raw transition function. The horizon is derived from
# the simulated data (`zs` holds the initial state plus one column per step) rather
# than a second hard-coded 50, so the two cannot drift apart.
# NOTE(review): assumes smc.smc's second argument is the number of time steps — confirm.
smc_runner = smc.smc(ssm.ssm_step, zs.shape[1] - 1)

In [10]:
# Hyperparameters for the three SSM parameters. Each of mu, sigma and delta gets the
# same scalar loc=0 / scale=0.25 pair. A fresh tensor is built for every entry, so no
# two entries share a tensor object.
# NOTE(review): how ssm.init_ssm consumes loc/scale (e.g. as a Normal prior, and whether
# a loc of 0 is sensible for 'sigma') is not visible here — confirm.
ssm_params = {
    name: {'loc': torch.tensor(0.), 'scale': torch.tensor(0.25)}
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Initial-state model built from `ssm_params`, passed as `trainable=` (presumably so
# variational SMC can optimize them — confirm in combinators.Model).
init_ssm = combinators.Model(ssm.init_ssm, trainable=ssm_params)

In [12]:
# Variational SMC against the ground-truth trace. The 500 appears to be the epoch count
# (the log below reports "at epoch N"). CUDA is requested only when a GPU is actually
# available, so the notebook also runs on CPU-only machines instead of failing here.
inference, init_ssm_params = smc.variational_smc(num_particles, init_ssm, smc_runner, 500, generative,
                                                 use_cuda=torch.cuda.is_available(), lr=1e-2)

07/30/2018 13:44:41 Variational SMC ELBO=-1.55196641e+04 at epoch 1
07/30/2018 13:44:42 Variational SMC ELBO=-2.02935156e+04 at epoch 2
07/30/2018 13:44:42 Variational SMC ELBO=-1.89200293e+04 at epoch 3
07/30/2018 13:44:43 Variational SMC ELBO=-1.59382666e+04 at epoch 4
07/30/2018 13:44:43 Variational SMC ELBO=-1.76801582e+04 at epoch 5
07/30/2018 13:44:44 Variational SMC ELBO=-1.56786660e+04 at epoch 6
07/30/2018 13:44:45 Variational SMC ELBO=-2.04012734e+04 at epoch 7
07/30/2018 13:44:45 Variational SMC ELBO=-1.94632246e+04 at epoch 8
07/30/2018 13:44:46 Variational SMC ELBO=-1.87675000e+04 at epoch 9
07/30/2018 13:44:47 Variational SMC ELBO=-1.79929961e+04 at epoch 10
07/30/2018 13:44:47 Variational SMC ELBO=-1.65347832e+04 at epoch 11
07/30/2018 13:44:48 Variational SMC ELBO=-2.06208848e+04 at epoch 12
07/30/2018 13:44:48 Variational SMC ELBO=-2.14268047e+04 at epoch 13
07/30/2018 13:44:49 Variational SMC ELBO=-2.09204023e+04 at epoch 14
07/30/2018 13:44:50 Variational SMC ELBO=-2

In [13]:
# Per-step mean squared error between the inferred latent values and the ground-truth
# latents recorded in `generative`. `utils.optional_to` presumably moves the ground-truth
# value onto the inferred tensor's device/type — TODO confirm.
for t in range(1, zs.shape[1]):
    site = 'Z_%d' % t
    inferred = inference[site].value
    truth = utils.optional_to(generative[site], inferred)
    mse = ((inferred - truth) ** 2).mean()
    print('SMC MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 3.219168
SMC MSE at time 2: 2.368361
SMC MSE at time 3: 0.075128
SMC MSE at time 4: 1.107255
SMC MSE at time 5: 0.011367
SMC MSE at time 6: 0.000292
SMC MSE at time 7: 4.235059
SMC MSE at time 8: 0.002446
SMC MSE at time 9: 1.621157
SMC MSE at time 10: 6.134621
SMC MSE at time 11: 15.676647
SMC MSE at time 12: 0.528747
SMC MSE at time 13: 1.350168
SMC MSE at time 14: 6.158684
SMC MSE at time 15: 1.651449
SMC MSE at time 16: 0.019419
SMC MSE at time 17: 0.066842
SMC MSE at time 18: 2.212059
SMC MSE at time 19: 0.321106
SMC MSE at time 20: 0.273648
SMC MSE at time 21: 1.764631
SMC MSE at time 22: 0.459933
SMC MSE at time 23: 2.338941
SMC MSE at time 24: 0.262827
SMC MSE at time 25: 4.732226
SMC MSE at time 26: 0.168336
SMC MSE at time 27: 2.857367
SMC MSE at time 28: 1.018824
SMC MSE at time 29: 0.861422
SMC MSE at time 30: 2.583290
SMC MSE at time 31: 0.084618
SMC MSE at time 32: 0.248405
SMC MSE at time 33: 1.526667
SMC MSE at time 34: 2.795800
SMC MSE at time 35: 4.

In [15]:
# Particle MH run with `ssm_params` passed as `hyper=` this time (not `trainable=`).
# NOTE(review): this rebinds `inference`, discarding the variational-SMC trace from the
# earlier cell; re-running the earlier MSE cell afterwards reports particle-MH results.
samples, elbos, inference = smc.particle_mh(
    num_particles,
    combinators.Model(ssm.init_ssm, hyper=ssm_params),
    smc_runner,
    500,
    generative,
)

In [16]:
# Inspect the particle-MH trace: parameter sites mu/sigma/delta, Z_0, then the Z_t/X_t sites.
inference

Trace{'mu': Normal([torch.cuda.FloatTensor of size 100]), 'sigma': Normal([torch.cuda.FloatTensor of size 100]), 'delta': Normal([torch.cuda.FloatTensor of size 100]), 'Z_0': Normal([torch.cuda.FloatTensor of size 100]), 'Z_1': Normal([torch.cuda.FloatTensor of size 100]), 'X_1': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_2': Normal([torch.cuda.FloatTensor of size 100]), 'X_2': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_3': Normal([torch.cuda.FloatTensor of size 100]), 'X_3': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_4': Normal([torch.cuda.FloatTensor of size 100]), 'X_4': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_5': Normal([torch.cuda.FloatTensor of size 100]), 'X_5': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_6': Normal([torch.cuda.FloatTensor of size 100]), 'X_6': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_7': Normal([torch.cuda.FloatTensor of size 100]), 'X_7': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_8': No

In [17]:
# Ground-truth trace again, for side-by-side comparison with the particle-MH trace above.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [18]:
# Per-step mean squared error of the particle-MH latents against the ground truth.
# `inference` was rebound by the particle-MH cell, so these numbers are not the
# variational-SMC errors reported earlier. Label them accordingly, so the two tables
# can be told apart.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    mse = ((inference[key].value - utils.optional_to(generative[key], inference[key].value)) ** 2).mean()
    print('Particle MH MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 1.249012
SMC MSE at time 2: 0.313053
SMC MSE at time 3: 2.011073
SMC MSE at time 4: 0.148668
SMC MSE at time 5: 0.015371
SMC MSE at time 6: 1.747577
SMC MSE at time 7: 0.710250
SMC MSE at time 8: 0.090786
SMC MSE at time 9: 1.025138
SMC MSE at time 10: 1.528336
SMC MSE at time 11: 1.957589
SMC MSE at time 12: 1.375567
SMC MSE at time 13: 0.001321
SMC MSE at time 14: 2.203247
SMC MSE at time 15: 9.468381
SMC MSE at time 16: 3.271747
SMC MSE at time 17: 10.362039
SMC MSE at time 18: 0.087272
SMC MSE at time 19: 8.961942
SMC MSE at time 20: 0.692879
SMC MSE at time 21: 4.403110
SMC MSE at time 22: 4.215178
SMC MSE at time 23: 0.801467
SMC MSE at time 24: 0.949349
SMC MSE at time 25: 0.544910
SMC MSE at time 26: 0.295890
SMC MSE at time 27: 0.151998
SMC MSE at time 28: 0.530081
SMC MSE at time 29: 0.151404
SMC MSE at time 30: 0.160485
SMC MSE at time 31: 0.500091
SMC MSE at time 32: 2.067556
SMC MSE at time 33: 2.197637
SMC MSE at time 34: 0.646288
SMC MSE at time 35: 4.

In [19]:
# Per-iteration values returned by particle MH.
# NOTE(review): the displayed values stay fixed for long stretches (e.g. -14112.8770
# from the 8th entry on), which may mean proposals are rarely accepted — worth checking
# the acceptance rate.
elbos

tensor([-21524.1953, -21524.1953, -21524.1953, -21524.1953, -19719.4258,
        -17313.6152, -17313.6152, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -14112.8770, -14112.8770,
        -14112.8770, -14112.8770, -14112.8770, -141