In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import hmm
import utils

# Timestamped INFO-level logging so the progress messages printed during the
# variational SMC fit below show up in the cell output.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

# Seed torch's global RNG so that the sampled generative trace, the SMC particles
# and the resulting ELBO / accuracy numbers are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

In [2]:
# Single-particle trace for the generative model. The model cells below are
# conditioned on it, and it is later squeezed and used both as the observations
# for inference and as the ground truth for the accuracy check.
generative = smc.ParticleTrace(1)

In [3]:
# Hyperparameters passed as `theta` to hmm.init_hmm for the generative model that
# produces the synthetic data. Every tensor carries a leading singleton dimension,
# presumably the particle axis of the 1-particle generative trace -- TODO confirm.
num_states = 5
hmm_params = utils.vardict({
    'mu': {
        # Per-state location hyperparameters spaced 2 apart: 0, 2, 4, 6, 8.
        'loc': torch.arange(num_states, dtype=torch.float).unsqueeze(0) * 2,
        'scale': torch.ones(1, num_states) * 0.25,
    },
    'sigma': {
        'loc': torch.ones(1, num_states),
        'scale': torch.ones(1, num_states) * 0.25,
    }
})
# num_states + 1 concentration vectors Pi_0 .. Pi_{num_states}. The count is derived
# from num_states so the two cannot drift apart when the state count changes.
# NOTE(review): presumably Pi_0 is the initial-state distribution and Pi_k (k >= 1)
# the transition row of state k -- confirm against hmm.init_hmm.
for k in range(num_states + 1):
    hmm_params['Pi_%d' % k] = {'concentration': torch.ones(1, num_states)}

In [4]:
# Wrap hmm.init_hmm as a combinators.Model using the generative hyperparameters,
# then condition it on the generative trace (presumably so the values it samples
# are recorded there -- TODO confirm against combinators.Model.condition).
init_hmm = combinators.Model(f=hmm.init_hmm, theta=hmm_params)
init_hmm.condition(generative)

In [5]:
# Run the initialization model for a horizon of T=50 (the same length as the
# 50-step run below). Presumably returns the initial state and the sampled
# emission / transition parameters -- TODO confirm against hmm.init_hmm.
z0, mu, sigma, pi = init_hmm(T=50)

In [6]:
# One HMM transition wrapped as a Model, and a runner built with
# combinators.sequence over that step. The call below passes 50 as the first
# argument, presumably the number of steps -- TODO confirm.
hmm_step = combinators.Model(f=hmm.hmm_step)
hmm_run = combinators.Model.partial(combinators.sequence, hmm_step)

In [7]:
# Condition the runner on the same generative trace that init_hmm was conditioned on.
hmm_run.condition(generative)

In [8]:
# Unroll the HMM from z0 with the parameters returned by init_hmm.
# NOTE(review): the 50 here should stay equal to the T passed to init_hmm.
z_last, mu, sigma, pi = hmm_run(50, z0, mu, sigma, pi)

In [9]:
# Drop singleton dimensions (presumably the 1-particle axis) so the trace can be
# passed as observations to the 100-particle inference below.
# NOTE(review): this rebinds `generative` in place; prefer a new name so the cell
# is idempotent on re-run.
generative = generative.squeeze()

In [10]:
# Number of particles used by the SMC inference trace and by variational SMC.
num_particles = 100

In [11]:
# Initial hyperparameters for the inference-side HMM, passed to variational SMC
# and replaced by the fitted values it returns. Unlike the generative
# hyperparameters, these have no leading singleton dimension and use unit scales.
num_states = 5
smc_hmm_params = utils.vardict({
    'mu': {
        # Per-state location hyperparameters spaced 2 apart: 0, 2, 4, 6, 8.
        'loc': torch.arange(num_states, dtype=torch.float) * 2,
        'scale': torch.ones(num_states),
    },
    'sigma': {
        'loc': torch.ones(num_states),
        'scale': torch.ones(num_states),
    }
})
# num_states + 1 concentration vectors Pi_0 .. Pi_{num_states}, derived from
# num_states so the count cannot drift from the number of states.
# NOTE(review): presumably Pi_0 is the initial-state distribution and Pi_k (k >= 1)
# the transition row of state k -- confirm against hmm.init_hmm.
for k in range(num_states + 1):
    smc_hmm_params['Pi_%d' % k] = {'concentration': torch.ones(num_states)}

In [12]:
# Build the SMC step from the HMM transition (hmm.hmm_step) and its retrace
# function (hmm.hmm_retrace), with reparameterized=False.
smc_hmm = smc.smc(hmm.hmm_step, hmm.hmm_retrace, reparameterized=False)

In [13]:
# Inference-side initialization model and an SMC trace with num_particles
# particles. Both the init model and the SMC step are conditioned on the
# squeezed generative trace as observations.
# NOTE(review): `inference_init` is not referenced in any later cell, and
# `inference` is rebound by the variational_smc call in the next cell. Confirm
# whether these condition() calls still have an effect, or remove this cell.
inference_init = combinators.Model(hmm.init_hmm, smc_hmm_params, {})

inference = smc.ParticleTrace(num_particles)
inference_init.condition(trace=inference, observations=generative)
smc_hmm.condition(trace=inference, observations=generative)

In [14]:
# Fit the inference hyperparameters with variational SMC. Positional arguments keep
# their original order. The 500 appears to be the number of training epochs (the
# log reports ELBO per epoch) and 50 the number of time steps -- TODO confirm
# against smc.variational_smc's signature.
# NOTE(review): this rebinds both `inference` and `smc_hmm_params`.
inference, smc_hmm_params = smc.variational_smc(
    num_particles,
    hmm.init_hmm,
    smc_hmm,
    500,
    50,
    smc_hmm_params,
    generative,
    use_cuda=False,
    marginal_model='hmm_step',
    lr=1e-2,
)

07/12/2018 14:37:16 Variational SMC ELBO=-1.26154678e+02 at epoch 1
07/12/2018 14:37:16 Variational SMC ELBO=-1.28316788e+02 at epoch 2
07/12/2018 14:37:17 Variational SMC ELBO=-1.25683502e+02 at epoch 3
07/12/2018 14:37:17 Variational SMC ELBO=-1.22453186e+02 at epoch 4
07/12/2018 14:37:17 Variational SMC ELBO=-1.31493835e+02 at epoch 5
07/12/2018 14:37:17 Variational SMC ELBO=-1.26792732e+02 at epoch 6
07/12/2018 14:37:18 Variational SMC ELBO=-1.26768860e+02 at epoch 7
07/12/2018 14:37:18 Variational SMC ELBO=-1.24953232e+02 at epoch 8
07/12/2018 14:37:18 Variational SMC ELBO=-1.21078339e+02 at epoch 9
07/12/2018 14:37:19 Variational SMC ELBO=-1.23843742e+02 at epoch 10
07/12/2018 14:37:19 Variational SMC ELBO=-1.16847771e+02 at epoch 11
07/12/2018 14:37:19 Variational SMC ELBO=-1.26547142e+02 at epoch 12
07/12/2018 14:37:19 Variational SMC ELBO=-1.21925575e+02 at epoch 13
07/12/2018 14:37:20 Variational SMC ELBO=-1.29930542e+02 at epoch 14
07/12/2018 14:37:20 Variational SMC ELBO=-1

In [15]:
# Per-time-step SMC accuracy: the number of particles whose sampled Z_t equals the
# value in the generative (ground-truth) trace, divided by the particle count.
# Assumes `.value` has one entry per particle -- TODO confirm.
num_steps = 50  # must equal the T / step count used to generate the data above
for t in range(1, num_steps + 1):  # states are keyed Z_1 .. Z_{num_steps}
    key = 'Z_%d' % t
    matches = inference[key].value == generative[key].value
    accuracy = matches.sum().to(dtype=torch.float) / inference.num_particles
    print('SMC accuracy at time %d: %f' % (t, accuracy))

SMC accuracy at time 1: 0.120000
SMC accuracy at time 2: 1.000000
SMC accuracy at time 3: 0.120000
SMC accuracy at time 4: 0.000000
SMC accuracy at time 5: 0.000000
SMC accuracy at time 6: 0.000000
SMC accuracy at time 7: 1.000000
SMC accuracy at time 8: 0.000000
SMC accuracy at time 9: 0.000000
SMC accuracy at time 10: 0.120000
SMC accuracy at time 11: 1.000000
SMC accuracy at time 12: 0.120000
SMC accuracy at time 13: 1.000000
SMC accuracy at time 14: 0.880000
SMC accuracy at time 15: 1.000000
SMC accuracy at time 16: 0.000000
SMC accuracy at time 17: 1.000000
SMC accuracy at time 18: 0.000000
SMC accuracy at time 19: 0.000000
SMC accuracy at time 20: 0.280000
SMC accuracy at time 21: 0.120000
SMC accuracy at time 22: 1.000000
SMC accuracy at time 23: 0.120000
SMC accuracy at time 24: 0.120000
SMC accuracy at time 25: 0.000000
SMC accuracy at time 26: 0.000000
SMC accuracy at time 27: 0.000000
SMC accuracy at time 28: 0.280000
SMC accuracy at time 29: 0.670000
SMC accuracy at time 30

In [16]:
# Display the fitted hyperparameters returned by variational SMC.
# NOTE(review): in the recorded output several `*__scale` and `Pi_*__concentration`
# entries are negative. That is invalid if they are used directly as a Normal scale
# or Dirichlet concentration. Confirm that hmm/smc map them through a positivity
# transform before building distributions.
smc_hmm_params

"{'Pi_0__concentration': 'Parameter containing:\ntensor([-0.5558,  1.9282,  3.7532, -0.0255,  0.7340])', 'Pi_1__concentration': 'Parameter containing:\ntensor([ 1.7620,  1.6556,  0.6515,  2.3116,  0.2630])', 'Pi_2__concentration': 'Parameter containing:\ntensor([ 1.7122,  1.4136,  0.5896,  2.0801,  1.0699])', 'Pi_3__concentration': 'Parameter containing:\ntensor([ 1.3979,  1.3487,  0.2671,  1.6443,  1.8489])', 'Pi_4__concentration': 'Parameter containing:\ntensor([ 2.2332,  2.1674,  0.7431,  1.3032,  0.3250])', 'Pi_5__concentration': 'Parameter containing:\ntensor([ 1.8072,  1.6749,  0.7768,  1.4101,  0.5392])', 'mu__loc': 'Parameter containing:\ntensor([ 0.3402,  1.3458,  3.1425,  5.7605,  6.5793])', 'mu__scale': 'Parameter containing:\ntensor([-0.6683, -0.7022,  0.8650, -1.9081, -0.6294])', 'sigma__loc': 'Parameter containing:\ntensor([ 1.2581,  1.4520,  2.1297, -0.5507,  0.4541])', 'sigma__scale': 'Parameter containing:\ntensor([-1.1175, -0.5011, -0.4390, -1.3168, -1.3055])'}"