In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import ssm
import utils

In [2]:
# Timestamped INFO-level logging so the per-epoch ELBO messages emitted during
# inference show up in the cell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

In [3]:
# Trace that records the simulated ground-truth run (Z_t / X_t sites); it is
# later passed as the final argument to both inference routines.
# NOTE(review): the argument 1 is presumably the particle count — confirm in smc.ParticleTrace.
generative = smc.ParticleTrace(1)

In [4]:
# Initial parameter values for the simulation, each a 1x1 float tensor.
mu = torch.zeros(1, 1)
sigma = torch.full((1, 1), 0.5)
delta = torch.ones(1, 1)
# Buffer for the latent trajectory: column 0 is the initial state z_0 = 0 and
# the 50 remaining columns are pre-filled with -1 until the simulation loop
# overwrites them.
zs = torch.full((1, 50 + 1), -1.0)
zs[:, 0] = 0.0

In [5]:
# Wrap the single-step SSM function as a combinators.Model and condition it on
# the `generative` trace, so the simulation loop below records its sites there
# (the displayed trace afterwards holds Z_t / X_t entries).
# NOTE(review): observations=None presumably means X_t is sampled rather than
# conditioned on data — confirm in combinators.Model.condition.
ssm_step = combinators.Model(ssm.ssm_step)
ssm_step.condition(trace=generative, observations=None)

In [6]:
# Simulate the ground truth: starting from zs[:, 0], call ssm_step once per time
# step, write each returned latent into the next column of zs, and carry the
# returned (mu, sigma, delta) into the following step.
# NOTE(review): not idempotent — re-running rebinds mu/sigma/delta and presumably
# writes into `generative` again; re-run the notebook from the top instead.
for t in range(zs.shape[1] - 1):
    (zs[:, t + 1], mu, sigma, delta) = ssm_step(
        (zs[:, t], mu, sigma, delta),
        t,
    )

In [7]:
# Inspect the simulated trace (one Z_t and one X_t site per step).
# NOTE(review): the repr is very long and gets truncated; a summary would read better.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [8]:
# Number of particles used by both the variational-SMC and particle-MH runs below.
num_particles = 100

In [9]:
# SMC sampler built from the SSM step function and ssm.ssm_retrace; the same
# runner is passed to both the variational-SMC and particle-MH calls below.
# NOTE(review): ssm_retrace is presumably the step variant used when replaying
# an existing trace — confirm in the ssm module.
smc_runner = smc.smc(ssm.ssm_step, ssm.ssm_retrace)

In [10]:
# loc/scale pairs for the mu, sigma and delta sites (the inference trace shows
# these as Normal sites). All three start at loc=0, scale=0.25. Each entry gets
# its own fresh scalar tensors.
# NOTE(review): a Normal over `sigma` can yield negative values — confirm that
# ssm.init_ssm constrains it.
ssm_params = {
    name: {'loc': torch.tensor(0.), 'scale': torch.tensor(0.25)}
    for name in ('mu', 'sigma', 'delta')
}

In [11]:
# Parameter-initialisation model for the SSM. ssm_params is supplied as phi
# (presumably the trainable variational parameters — confirm in
# combinators.Model) and theta is left empty.
init_ssm = combinators.Model(ssm.init_ssm, phi=ssm_params, theta={})

In [12]:
# Fit with variational SMC against the simulated data in `generative`, logging
# the ELBO once per epoch.
# NOTE(review): from the names, the results are presumably the inference trace
# and the fitted init_ssm parameters; init_ssm_params is not used later here.
# NOTE(review): 500 and 50 are magic numbers. 500 is presumably the number of
# epochs (the log counts epochs) and 50 presumably the number of time steps,
# matching the 50 steps simulated into zs. Confirm against
# smc.variational_smc and consider deriving 50 from zs.shape[1] - 1.
inference, init_ssm_params = smc.variational_smc(num_particles, init_ssm, smc_runner, 500, 50, generative)

07/25/2018 18:41:57 Variational SMC ELBO=-1.69804609e+04 at epoch 1
07/25/2018 18:41:57 Variational SMC ELBO=-1.88850312e+04 at epoch 2
07/25/2018 18:41:58 Variational SMC ELBO=-1.83490430e+04 at epoch 3
07/25/2018 18:41:58 Variational SMC ELBO=-1.43030254e+04 at epoch 4
07/25/2018 18:41:59 Variational SMC ELBO=-2.37473242e+04 at epoch 5
07/25/2018 18:41:59 Variational SMC ELBO=-1.91949199e+04 at epoch 6
07/25/2018 18:42:00 Variational SMC ELBO=-1.43476172e+04 at epoch 7
07/25/2018 18:42:00 Variational SMC ELBO=-1.88801465e+04 at epoch 8
07/25/2018 18:42:01 Variational SMC ELBO=-1.85902441e+04 at epoch 9
07/25/2018 18:42:01 Variational SMC ELBO=-1.91123086e+04 at epoch 10
07/25/2018 18:42:02 Variational SMC ELBO=-2.04263750e+04 at epoch 11
07/25/2018 18:42:02 Variational SMC ELBO=-3.33777305e+04 at epoch 12
07/25/2018 18:42:02 Variational SMC ELBO=-2.24045059e+04 at epoch 13
07/25/2018 18:42:03 Variational SMC ELBO=-1.73942578e+04 at epoch 14
07/25/2018 18:42:03 Variational SMC ELBO=-1

In [13]:
# Per-time-step mean squared error between the inferred latent Z_t values and
# the simulated ground truth (lower is better; it is not an accuracy).
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    inferred = inference[key].value
    # optional_to presumably moves the ground-truth value onto the same device as
    # the inferred value (inference runs on CUDA, generative is on the CPU) — confirm in utils.
    truth = utils.optional_to(generative[key], inferred)
    mse = ((inferred - truth) ** 2).mean()
    print('SMC MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 2.978261
SMC MSE at time 2: 2.260783
SMC MSE at time 3: 0.148717
SMC MSE at time 4: 5.068857
SMC MSE at time 5: 0.153612
SMC MSE at time 6: 1.520075
SMC MSE at time 7: 0.117425
SMC MSE at time 8: 5.878562
SMC MSE at time 9: 1.607564
SMC MSE at time 10: 0.522438
SMC MSE at time 11: 0.743277
SMC MSE at time 12: 0.929200
SMC MSE at time 13: 0.010730
SMC MSE at time 14: 1.053820
SMC MSE at time 15: 2.670548
SMC MSE at time 16: 3.011329
SMC MSE at time 17: 0.022915
SMC MSE at time 18: 0.028701
SMC MSE at time 19: 0.088042
SMC MSE at time 20: 1.895900
SMC MSE at time 21: 0.000210
SMC MSE at time 22: 0.342852
SMC MSE at time 23: 0.341864
SMC MSE at time 24: 0.114657
SMC MSE at time 25: 4.092432
SMC MSE at time 26: 4.334611
SMC MSE at time 27: 3.089882
SMC MSE at time 28: 0.151114
SMC MSE at time 29: 1.802423
SMC MSE at time 30: 2.181132
SMC MSE at time 31: 1.063315
SMC MSE at time 32: 0.610774
SMC MSE at time 33: 0.932879
SMC MSE at time 34: 0.250010
SMC MSE at time 35: 0.2

In [15]:
# Particle Metropolis-Hastings (per the function name) over the SSM, using the
# same SMC runner, particle count and simulated data as the variational run.
# NOTE(review): this rebinds `inference`. Re-running the variational-SMC MSE
# cell after this point would report particle-MH results instead.
# NOTE(review): ssm_params is passed here as theta, whereas the variational run
# used it as phi, and init_ssm_params from that run is not used. Confirm this is intended.
# NOTE(review): the execution count jumps from 13 to 15, so a cell was deleted
# or run out of order. Verify with Restart & Run All.
samples, elbos, inference = smc.particle_mh(num_particles, combinators.Model(ssm.init_ssm, theta=ssm_params), smc_runner, 500, 50, generative)

In [16]:
# Inspect the particle-MH inference trace (mu/sigma/delta sites plus Z_t/X_t sites).
inference

Trace{'mu': Normal([torch.cuda.FloatTensor of size 100]), 'sigma': Normal([torch.cuda.FloatTensor of size 100]), 'delta': Normal([torch.cuda.FloatTensor of size 100]), 'Z_0': Normal([torch.cuda.FloatTensor of size 100]), 'Z_1': Normal([torch.cuda.FloatTensor of size 100]), 'X_1': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_2': Normal([torch.cuda.FloatTensor of size 100]), 'X_2': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_3': Normal([torch.cuda.FloatTensor of size 100]), 'X_3': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_4': Normal([torch.cuda.FloatTensor of size 100]), 'X_4': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_5': Normal([torch.cuda.FloatTensor of size 100]), 'X_5': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_6': Normal([torch.cuda.FloatTensor of size 100]), 'X_6': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_7': Normal([torch.cuda.FloatTensor of size 100]), 'X_7': Normal([torch.cuda.FloatTensor of size 100x1x1]), 'Z_8': No

In [17]:
# Re-display the simulated trace for comparison with the inference trace.
# NOTE(review): the visible output matches the earlier display of `generative`; this cell could be removed.
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1x1]), 'X_1': Normal([torch.FloatTensor of size 1x1]), 'Z_2': Normal([torch.FloatTensor of size 1x1]), 'X_2': Normal([torch.FloatTensor of size 1x1]), 'Z_3': Normal([torch.FloatTensor of size 1x1]), 'X_3': Normal([torch.FloatTensor of size 1x1]), 'Z_4': Normal([torch.FloatTensor of size 1x1]), 'X_4': Normal([torch.FloatTensor of size 1x1]), 'Z_5': Normal([torch.FloatTensor of size 1x1]), 'X_5': Normal([torch.FloatTensor of size 1x1]), 'Z_6': Normal([torch.FloatTensor of size 1x1]), 'X_6': Normal([torch.FloatTensor of size 1x1]), 'Z_7': Normal([torch.FloatTensor of size 1x1]), 'X_7': Normal([torch.FloatTensor of size 1x1]), 'Z_8': Normal([torch.FloatTensor of size 1x1]), 'X_8': Normal([torch.FloatTensor of size 1x1]), 'Z_9': Normal([torch.FloatTensor of size 1x1]), 'X_9': Normal([torch.FloatTensor of size 1x1]), 'Z_10': Normal([torch.FloatTensor of size 1x1]), 'X_10': Normal([torch.FloatTensor of size 1x1]), 'Z_11': Normal([torch.FloatTenso

In [18]:
# Per-time-step mean squared error between the particle-MH latent Z_t values
# and the simulated ground truth (lower is better; it is not an accuracy).
# The "Particle MH" label keeps this table distinct from the variational-SMC
# results, which an earlier cell prints with the same loop shape.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    inferred = inference[key].value
    # optional_to presumably moves the ground-truth value onto the same device as
    # the inferred value (inference runs on CUDA, generative is on the CPU) — confirm in utils.
    truth = utils.optional_to(generative[key], inferred)
    mse = ((inferred - truth) ** 2).mean().item()
    print('Particle MH MSE at time %d: %f' % (t, mse))

SMC MSE at time 1: 0.215998
SMC MSE at time 2: 0.000834
SMC MSE at time 3: 4.298633
SMC MSE at time 4: 1.489324
SMC MSE at time 5: 0.270435
SMC MSE at time 6: 0.015515
SMC MSE at time 7: 0.021044
SMC MSE at time 8: 1.557909
SMC MSE at time 9: 0.254637
SMC MSE at time 10: 0.773414
SMC MSE at time 11: 0.687131
SMC MSE at time 12: 0.061820
SMC MSE at time 13: 0.130021
SMC MSE at time 14: 1.122403
SMC MSE at time 15: 1.248771
SMC MSE at time 16: 0.161322
SMC MSE at time 17: 0.349231
SMC MSE at time 18: 0.032719
SMC MSE at time 19: 0.007499
SMC MSE at time 20: 0.120678
SMC MSE at time 21: 0.468430
SMC MSE at time 22: 2.317103
SMC MSE at time 23: 0.032233
SMC MSE at time 24: 0.143272
SMC MSE at time 25: 1.271073
SMC MSE at time 26: 1.603632
SMC MSE at time 27: 0.008897
SMC MSE at time 28: 2.923405
SMC MSE at time 29: 0.107455
SMC MSE at time 30: 2.128040
SMC MSE at time 31: 1.723192
SMC MSE at time 32: 3.245904
SMC MSE at time 33: 0.029436
SMC MSE at time 34: 0.299565
SMC MSE at time 35: 0.9