In [1]:
import logging

import probtorch
import torch

import combinators
import foldable
import hmm
import importance
import mcmc
import trace_tries
import utils

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [2]:
hmm_params = utils.vardict({
    'mu': {
        'loc': torch.arange(5, dtype=torch.float) * 2,
        'scale': torch.ones(5) * 0.25,
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': torch.ones(5) * 0.25,
    }
})
for k in range(6):
    hmm_params['Pi_%d' % k] = {'concentration': torch.ones(5)}

In [3]:
init_hmm = combinators.HyperPopulation(hmm.init_hmm, (1,), hyper=hmm_params)
hmm_step = foldable.Foldable(combinators.PrimitiveCall(hmm.hmm_step), initializer=init_hmm)
hmm_run = foldable.Reduce(hmm_step, generator=lambda: range(50))

In [4]:
(z_last, mu, sigma, pi, pi0), generative = hmm_run()

In [5]:
data = {k.rpartition('/')[-1]: rv.value for (k, rv) in generative.filter(lambda k, rv: 'X_' in k)}

In [6]:
num_particles = 250

In [7]:
smc_hmm_params = utils.vardict({
    'mu': {
        'loc': torch.arange(5, dtype=torch.float) * 2,
        'scale': torch.ones(5),
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': torch.ones(5),
    }
})
for k in range(6):
    smc_hmm_params['Pi_%d' % k] = {'concentration': torch.ones(5)}

In [8]:
init_hmm = combinators.HyperPopulation(hmm.init_hmm, (num_particles,), trainable=smc_hmm_params)

In [9]:
rmsmc_hmm = mcmc.reduce_resample_move_smc(combinators.PrimitiveCall(hmm.hmm_step), (num_particles,), lambda: range(50), initializer=init_hmm, moves=5)

In [10]:
samples, inference = rmsmc_hmm(data=data, trace=trace_tries.HierarchicalTrace())

In [11]:
for t in range(50):
    t += 1
    key = list(filter(lambda key: ('Z_%d' % t) in key, generative.keys()))[0]
    accuracy = (inference[key].value == generative[key].value).to(dtype=torch.float).mean()
    print('SMC accuracy at time %d: %f' % (t, accuracy))

SMC accuracy at time 1: 0.124000
SMC accuracy at time 2: 0.188000
SMC accuracy at time 3: 0.912000
SMC accuracy at time 4: 0.592000
SMC accuracy at time 5: 0.564000
SMC accuracy at time 6: 0.352000
SMC accuracy at time 7: 0.560000
SMC accuracy at time 8: 0.416000
SMC accuracy at time 9: 0.620000
SMC accuracy at time 10: 0.340000
SMC accuracy at time 11: 0.292000
SMC accuracy at time 12: 0.572000
SMC accuracy at time 13: 0.368000
SMC accuracy at time 14: 0.580000
SMC accuracy at time 15: 0.372000
SMC accuracy at time 16: 0.488000
SMC accuracy at time 17: 0.388000
SMC accuracy at time 18: 0.356000
SMC accuracy at time 19: 0.508000
SMC accuracy at time 20: 0.056000
SMC accuracy at time 21: 0.460000
SMC accuracy at time 22: 0.484000
SMC accuracy at time 23: 0.608000
SMC accuracy at time 24: 0.264000
SMC accuracy at time 25: 0.548000
SMC accuracy at time 26: 0.620000
SMC accuracy at time 27: 0.564000
SMC accuracy at time 28: 0.188000
SMC accuracy at time 29: 0.256000
SMC accuracy at time 30