In [1]:
import logging

import probtorch
import torch

import combinators
import smc
import hmm
import utils

# Seed torch's global RNG so the simulated ground truth and the SMC results
# below are reproducible under Restart Kernel -> Run All.
# NOTE(review): if combinators / probtorch also draw from numpy or `random`,
# seed those generators here as well.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging for progress messages emitted during inference.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [2]:
# Trace that will record the sampled "ground-truth" run; it is later compared
# against the SMC inference trace. The argument 1 is presumably the particle
# count -- TODO confirm against smc.ParticleTrace.
generative = smc.ParticleTrace(1)

In [3]:
# Prior hyperparameters for the generative (ground-truth) HMM.
# Every tensor here has shape (1, 5): a leading singleton dim and 5 entries.
hmm_params = utils.vardict({
    'mu': {
        # 0, 2, 4, 6, 8
        'loc': torch.arange(0, 10, 2, dtype=torch.float).unsqueeze(0),
        'scale': torch.full((1, 5), 0.25),
    },
    'sigma': {
        'loc': torch.ones(1, 5),
        'scale': torch.full((1, 5), 0.25),
    },
})
# Uniform concentration vectors for keys Pi_0 .. Pi_5 (presumably Dirichlet
# priors over initial/transition distributions -- TODO confirm in hmm.init_hmm).
for k in range(6):
    pi_key = 'Pi_%d' % k
    hmm_params[pi_key] = {'concentration': torch.ones(1, 5)}

In [4]:
# Wrap hmm.init_hmm as a model parameterized by the ground-truth hyperparameters
# (passed as `theta`) and attach it to the `generative` trace.
# NOTE(review): presumably `condition` makes the model record its random
# choices into that trace -- confirm against combinators.Model.condition.
init_hmm = combinators.Model(f=hmm.init_hmm, theta=hmm_params)
init_hmm.condition(generative)

In [5]:
# Run the initializer for a sequence of length T=50. The five outputs are
# presumably the initial latent state, emission params (mu, sigma), transition
# params (pi) and initial-state distribution (pi0) -- TODO confirm in hmm.init_hmm.
z0, mu, sigma, pi, pi0 = init_hmm(T=50)

In [6]:
# Build a model that presumably chains hmm_step 50 times, seeded with the values
# returned by init_hmm.
# NOTE(review): the literal 50 duplicates T=50 in the init_hmm call; keep them
# in sync (ideally via a shared constant).
hmm_step = combinators.Model(f=hmm.hmm_step)
hmm_run = combinators.Model.sequence(hmm_step, 50, z0, mu, sigma, pi, pi0)

In [7]:
# Attach the unrolled step model to the same `generative` trace used by init_hmm.
hmm_run.condition(generative)

In [8]:
# Execute the simulation. Note: this rebinds mu, sigma, pi and pi0, replacing
# the values unpacked from init_hmm in In [5].
z_last, mu, sigma, pi, pi0 = hmm_run()

In [9]:
# Drop singleton dimensions from the recorded ground truth (presumably the
# size-1 particle/batch dim). NOTE(review): rebinding `generative` in place
# means later cells see the squeezed trace and earlier cells are not re-runnable
# in isolation.
generative = generative.squeeze()

In [10]:
# Number of particles passed to smc.particle_mh for inference.
num_particles = 100

In [11]:
# Prior hyperparameters for the inference-side HMM (handed to the model as
# `phi`). Tensors are 1-D with 5 entries and both scales are 1, i.e. a broader
# prior than the generative one.
smc_hmm_params = utils.vardict({
    'mu': {
        # 0, 2, 4, 6, 8
        'loc': torch.arange(0, 10, 2, dtype=torch.float),
        'scale': torch.ones(5),
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': torch.ones(5),
    },
})
# Uniform concentration vectors for keys Pi_0 .. Pi_5.
for k in range(6):
    pi_key = 'Pi_%d' % k
    smc_hmm_params[pi_key] = {'concentration': torch.ones(5)}

In [12]:
# Build the SMC routine from the per-step model and its retrace counterpart
# (see smc.smc for how the two are combined).
smc_hmm = smc.smc(hmm.hmm_step, hmm.hmm_retrace)

In [13]:
# Inference-side initializer, parameterized by smc_hmm_params via `phi`
# (the generative initializer used `theta`).
# NOTE(review): this rebinds `init_hmm`, shadowing the generative model from In [4].
init_hmm = combinators.Model(hmm.init_hmm, phi=smc_hmm_params)

In [14]:
# Run particle MH against the squeezed ground-truth trace on CPU.
# NOTE(review): the positional 1000 and 50 are presumably the number of MH
# iterations and the sequence length T -- TODO confirm against smc.particle_mh
# and consider naming them as constants.
samples, elbos, inference = smc.particle_mh(num_particles, init_hmm, smc_hmm, 1000, 50, generative, use_cuda=False)

In [15]:
# Per-timestep accuracy: fraction of inference particles whose latent Z_t equals
# the ground-truth Z_t recorded in `generative`. Time steps are 1-indexed
# (Z_1 .. Z_50), so iterate that range directly instead of mutating the loop
# variable.
for t in range(1, 51):
    key = 'Z_%d' % t
    matches = inference[key].value == generative[key].value
    accuracy = matches.sum().to(dtype=torch.float) / inference.num_particles
    # .item() converts the 0-dim tensor to a Python float for %-formatting.
    print('SMC accuracy at time %d: %f' % (t, accuracy.item()))

SMC accuracy at time 1: 0.000000
SMC accuracy at time 2: 0.000000
SMC accuracy at time 3: 0.210000
SMC accuracy at time 4: 0.000000
SMC accuracy at time 5: 0.530000
SMC accuracy at time 6: 0.210000
SMC accuracy at time 7: 0.470000
SMC accuracy at time 8: 0.000000
SMC accuracy at time 9: 1.000000
SMC accuracy at time 10: 0.740000
SMC accuracy at time 11: 0.210000
SMC accuracy at time 12: 0.440000
SMC accuracy at time 13: 0.970000
SMC accuracy at time 14: 0.210000
SMC accuracy at time 15: 0.740000
SMC accuracy at time 16: 0.000000
SMC accuracy at time 17: 0.210000
SMC accuracy at time 18: 0.530000
SMC accuracy at time 19: 0.740000
SMC accuracy at time 20: 0.000000
SMC accuracy at time 21: 0.210000
SMC accuracy at time 22: 0.000000
SMC accuracy at time 23: 0.210000
SMC accuracy at time 24: 0.030000
SMC accuracy at time 25: 0.210000
SMC accuracy at time 26: 0.210000
SMC accuracy at time 27: 0.210000
SMC accuracy at time 28: 1.000000
SMC accuracy at time 29: 0.210000
SMC accuracy at time 30

In [16]:
# Display the ELBO tensor returned by particle_mh.
elbos

tensor([-133.1254, -133.1254, -133.0614, -133.0614, -132.2717, -132.2717,
        -132.2717, -132.2717, -131.5890, -131.5083, -131.5083, -131.5083,
        -131.5083, -131.5083, -131.5083, -131.5083, -132.2537, -132.2537,
        -130.7164, -130.7164, -130.7164, -130.7164, -130.7164, -130.7164,
        -130.7164, -130.7164, -130.7164, -130.7164, -130.7164, -130.7164,
        -130.7164, -130.7164, -130.7164, -130.7164, -129.8317, -129.8317,
        -129.8317, -129.8317, -129.8317, -129.8317, -129.8317, -129.8317,
        -130.3419, -130.8334, -130.8334, -130.8334, -131.5754, -131.5754,
        -131.5754, -131.5754, -131.5754, -131.5754, -131.5754, -131.5754,
        -131.5754, -131.5754, -131.5754, -131.5754, -134.0341, -134.0341,
        -131.3505, -131.3505, -131.3505, -131.3505, -131.3505, -131.3505,
        -131.3505, -131.3505, -131.3505, -131.3505, -132.7544, -133.9828,
        -133.8160, -132.7620, -129.3192, -129.3192, -129.3192, -129.3192,
        -129.3192, -129.3192, -129.319