In [1]:
import logging

import probtorch
import torch

import combinators
import importance
import mcmc
import smc
import ssm
import utils

In [2]:
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [3]:
generative = combinators.ParticleTrace(1)

In [4]:
mu = torch.zeros(1)
sigma = torch.ones(1) / 2
delta = torch.ones(1)
zs = torch.ones(50+1) * -1
zs[0] = 0

In [5]:
ssm_step = combinators.Model(ssm.ssm_step)

In [6]:
for t in range(zs.shape[0] - 1):
    zs[t+1], mu, sigma, delta = ssm_step((zs[t], mu, sigma, delta), t, trace=generative)

In [7]:
data = generative.unwrap(lambda k, rv: 'X_' in k)

In [8]:
num_particles = 100

In [9]:
ssm_params = {
    'mu': {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    },
    'sigma': {
        'loc': torch.tensor(1.),
        'scale': torch.tensor(0.25),
    },
    'delta': {
        'loc': torch.tensor(0.),
        'scale': torch.tensor(0.25),
    }
}

In [10]:
init_ssm_generative = combinators.Model(ssm.init_ssm, hyper=ssm_params)
smc_ssm = smc.SequentialMonteCarlo(ssm.ssm_step, 50, initializer=init_ssm_generative)

In [11]:
init_ssm_proposal = combinators.Model(ssm.init_ssm, trainable=ssm_params)
ssm_proposal = combinators.Model.compose(combinators.Model.sequence(ssm.ssm_step, 50),
                                         init_ssm_proposal, intermediate_name='initializer')

In [12]:
inference, init_ssm_params = smc.variational_smc(num_particles, smc_ssm, ssm_proposal, 500, data,
                                                 use_cuda=True, lr=1e-1)

08/21/2018 18:02:41 Variational SMC ELBO=-5.72067383e+02 at epoch 1
08/21/2018 18:02:42 Variational SMC ELBO=-4.84151733e+02 at epoch 2
08/21/2018 18:02:43 Variational SMC ELBO=-4.85258026e+02 at epoch 3
08/21/2018 18:02:43 Variational SMC ELBO=-5.02256561e+02 at epoch 4
08/21/2018 18:02:44 Variational SMC ELBO=-4.26542755e+02 at epoch 5
08/21/2018 18:02:45 Variational SMC ELBO=-4.95528656e+02 at epoch 6
08/21/2018 18:02:46 Variational SMC ELBO=-5.07813538e+02 at epoch 7
08/21/2018 18:02:46 Variational SMC ELBO=-4.88988953e+02 at epoch 8
08/21/2018 18:02:47 Variational SMC ELBO=-4.35473633e+02 at epoch 9
08/21/2018 18:02:48 Variational SMC ELBO=-4.92395996e+02 at epoch 10
08/21/2018 18:02:49 Variational SMC ELBO=-5.43072693e+02 at epoch 11
08/21/2018 18:02:49 Variational SMC ELBO=-4.58492249e+02 at epoch 12
08/21/2018 18:02:50 Variational SMC ELBO=-4.63436829e+02 at epoch 13
08/21/2018 18:02:51 Variational SMC ELBO=-4.30039185e+02 at epoch 14
08/21/2018 18:02:52 Variational SMC ELBO=-4

In [13]:
for t in range(1, zs.shape[0]):
    key = 'Z_%d' % t
    accuracy = ((inference[key].value - utils.optional_to(generative[key], inference[key].value))**2).mean()
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 0.517210
SMC MSE at time 2: 0.143449
SMC MSE at time 3: 2.499226
SMC MSE at time 4: 0.597196
SMC MSE at time 5: 1.823441
SMC MSE at time 6: 0.139090
SMC MSE at time 7: 0.005323
SMC MSE at time 8: 0.033775
SMC MSE at time 9: 2.330405
SMC MSE at time 10: 0.042187
SMC MSE at time 11: 1.795071
SMC MSE at time 12: 3.369781
SMC MSE at time 13: 0.057659
SMC MSE at time 14: 0.531522
SMC MSE at time 15: 0.034264
SMC MSE at time 16: 0.805980
SMC MSE at time 17: 0.055468
SMC MSE at time 18: 2.651037
SMC MSE at time 19: 1.863771
SMC MSE at time 20: 1.231140
SMC MSE at time 21: 0.690294
SMC MSE at time 22: 1.225796
SMC MSE at time 23: 0.162681
SMC MSE at time 24: 0.481425
SMC MSE at time 25: 1.141840
SMC MSE at time 26: 1.025328
SMC MSE at time 27: 0.028042
SMC MSE at time 28: 0.619361
SMC MSE at time 29: 0.533066
SMC MSE at time 30: 1.881603
SMC MSE at time 31: 0.011903
SMC MSE at time 32: 0.511214
SMC MSE at time 33: 1.925137
SMC MSE at time 34: 0.109951
SMC MSE at time 35: 0.0

In [14]:
particle_mh = mcmc.IndependentMH(smc_ssm, ssm_proposal, 1000)

In [15]:
samples, elbos = particle_mh(trace=importance.ResamplerTrace(num_particles, data=data))
inference = particle_mh.trace

In [16]:
inference

Trace{'mu': Normal([torch.FloatTensor of size 100]), 'sigma': Normal([torch.FloatTensor of size 100]), 'delta': Normal([torch.FloatTensor of size 100]), 'Z_0': Normal([torch.FloatTensor of size 100]), 'Z_1': Normal([torch.FloatTensor of size 100]), 'X_1': Normal([torch.FloatTensor of size 100]), 'Z_2': Normal([torch.FloatTensor of size 100]), 'X_2': Normal([torch.FloatTensor of size 100]), 'Z_3': Normal([torch.FloatTensor of size 100]), 'X_3': Normal([torch.FloatTensor of size 100]), 'Z_4': Normal([torch.FloatTensor of size 100]), 'X_4': Normal([torch.FloatTensor of size 100]), 'Z_5': Normal([torch.FloatTensor of size 100]), 'X_5': Normal([torch.FloatTensor of size 100]), 'Z_6': Normal([torch.FloatTensor of size 100]), 'X_6': Normal([torch.FloatTensor of size 100]), 'Z_7': Normal([torch.FloatTensor of size 100]), 'X_7': Normal([torch.FloatTensor of size 100]), 'Z_8': Normal([torch.FloatTensor of size 100]), 'X_8': Normal([torch.FloatTensor of size 100]), 'Z_9': Normal([torch.FloatTenso

In [17]:
generative

Trace{'Z_1': Normal([torch.FloatTensor of size 1]), 'X_1': Normal([torch.FloatTensor of size 1]), 'Z_2': Normal([torch.FloatTensor of size 1]), 'X_2': Normal([torch.FloatTensor of size 1]), 'Z_3': Normal([torch.FloatTensor of size 1]), 'X_3': Normal([torch.FloatTensor of size 1]), 'Z_4': Normal([torch.FloatTensor of size 1]), 'X_4': Normal([torch.FloatTensor of size 1]), 'Z_5': Normal([torch.FloatTensor of size 1]), 'X_5': Normal([torch.FloatTensor of size 1]), 'Z_6': Normal([torch.FloatTensor of size 1]), 'X_6': Normal([torch.FloatTensor of size 1]), 'Z_7': Normal([torch.FloatTensor of size 1]), 'X_7': Normal([torch.FloatTensor of size 1]), 'Z_8': Normal([torch.FloatTensor of size 1]), 'X_8': Normal([torch.FloatTensor of size 1]), 'Z_9': Normal([torch.FloatTensor of size 1]), 'X_9': Normal([torch.FloatTensor of size 1]), 'Z_10': Normal([torch.FloatTensor of size 1]), 'X_10': Normal([torch.FloatTensor of size 1]), 'Z_11': Normal([torch.FloatTensor of size 1]), 'X_11': Normal([torch.Flo

In [18]:
for t in range(1, zs.shape[0]):
    key = 'Z_%d' % t
    accuracy = ((inference[key].value - utils.optional_to(generative[key], inference[key].value))**2).mean()
    print('SMC MSE at time %d: %f' % (t, accuracy))

SMC MSE at time 1: 0.000446
SMC MSE at time 2: 0.400163
SMC MSE at time 3: 0.004072
SMC MSE at time 4: 0.255581
SMC MSE at time 5: 0.037414
SMC MSE at time 6: 1.605237
SMC MSE at time 7: 1.449987
SMC MSE at time 8: 2.350627
SMC MSE at time 9: 2.768738
SMC MSE at time 10: 1.216853
SMC MSE at time 11: 0.019983
SMC MSE at time 12: 0.680732
SMC MSE at time 13: 0.800314
SMC MSE at time 14: 0.185998
SMC MSE at time 15: 0.000620
SMC MSE at time 16: 0.682279
SMC MSE at time 17: 0.494669
SMC MSE at time 18: 4.338040
SMC MSE at time 19: 4.943458
SMC MSE at time 20: 2.130695
SMC MSE at time 21: 0.040848
SMC MSE at time 22: 0.218536
SMC MSE at time 23: 0.999619
SMC MSE at time 24: 1.442287
SMC MSE at time 25: 0.000011
SMC MSE at time 26: 0.203298
SMC MSE at time 27: 0.102649
SMC MSE at time 28: 0.013707
SMC MSE at time 29: 0.000954
SMC MSE at time 30: 0.016303
SMC MSE at time 31: 0.057253
SMC MSE at time 32: 0.102559
SMC MSE at time 33: 0.353958
SMC MSE at time 34: 0.110133
SMC MSE at time 35: 0.2

In [19]:
elbos

tensor([-257.6532, -251.7312, -251.7312, -251.7312, -251.7312, -247.1214,
        -247.1214, -247.1214, -247.1214, -246.7860, -246.7860, -246.7860,
        -246.7860, -246.7860, -246.7860, -246.7860, -246.7860, -246.7860,
        -246.7860, -246.7860, -246.7860, -246.7860, -246.7860, -246.7860,
        -246.7860, -246.7860, -246.7860, -246.7860, -246.7860, -246.7860,
        -246.7860, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.4556, -231.4556, -231.4556, -231.4556, -231.4556,
        -231.4556, -231.3249, -231.324