In [1]:
import logging

import probtorch
import torch

import combinators
import filtering
import hmm
import utils

# Seed torch's global RNG so that torch-based sampling (for example the
# torch.rand initialisation later in this notebook) is repeatable under
# Restart Kernel -> Run All.
# NOTE(review): if the helper modules draw from other RNGs (numpy, random),
# those need seeding as well — confirm.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging; the per-epoch ELBO lines printed by the
# training cell appear in this "date time message" format.
logging.basicConfig(
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)

In [2]:
# Trace handed to the generative model calls below; its named entries are
# read back later through `generative.unwrap(...)`.
# NOTE(review): the argument 1 is presumably the particle count — confirm
# against combinators.ParticleTrace.
generative = combinators.ParticleTrace(1)

In [3]:
# Hyperparameters for the data-generating HMM: 'mu' locations at
# 0, 2, 4, 6, 8 with scale 0.25, 'sigma' locations at 1 with scale 0.25.
hmm_params = utils.vardict({
    'mu': {
        'loc': 2 * torch.arange(5, dtype=torch.float),
        'scale': 0.25 * torch.ones(5),
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': 0.25 * torch.ones(5),
    },
})
# Six all-ones concentration vectors, Pi_0 .. Pi_5.
# NOTE(review): presumably one initial distribution plus one transition row
# per state — confirm in hmm.init_hmm.
for k in range(6):
    hmm_params['Pi_%d' % k] = {'concentration': torch.ones(5)}

In [4]:
# Wrap hmm.init_hmm as a combinators.Model, passing hmm_params through the
# `hyper=` keyword (the inference section later passes its dict through
# `trainable=` instead).
init_hmm = combinators.Model(f=hmm.init_hmm, hyper=hmm_params)

In [5]:
# Run the initialiser under the generative trace. The five returned values
# are passed on to the step sequence built in the next cell.
# NOTE(review): presumably initial state, emission parameters and
# transition / initial distributions — confirm in hmm.init_hmm.
z0, mu, sigma, pi, pi0 = init_hmm(trace=generative)

In [6]:
# Wrap the single-step function as a Model, then build a sequence model from
# it, seeded with the values returned by the initialiser.
# NOTE(review): 50 is presumably the number of time steps; the same literal
# reappears as T=50 in later cells — confirm and consider naming it.
hmm_step = combinators.Model(f=hmm.hmm_step)
hmm_run = combinators.Model.sequence(hmm_step, 50, z0, mu, sigma, pi, pi0)

In [7]:
# Run the sequence under the generative trace. This rebinds mu, sigma, pi
# and pi0 to the values returned by the sequence run.
z_last, mu, sigma, pi, pi0 = hmm_run(trace=generative)

In [8]:
# Select the trace entries whose name contains 'X_' to serve as the data
# passed to the inference trace below.
# NOTE(review): presumably these are the per-step observations — confirm
# the naming scheme in hmm.hmm_step.
data = generative.unwrap(lambda k, rv: 'X_' in k)

In [9]:
# Particle count for the inference run below.
# NOTE(review): presumably corresponds to the first argument of
# combinators.GuidedTrace — confirm.
num_particles = 100

In [10]:
# Inference trace conditioned on the extracted data. The size comes from
# `num_particles` (defined in the previous cell) instead of a repeated
# literal 100, so changing that setting takes effect here.
# NOTE(review): assumes the first GuidedTrace argument is the particle
# count — confirm against combinators.GuidedTrace.
inference = combinators.GuidedTrace(num_particles, data=data)

In [11]:
# Parameters for the inference-side model. This rebinds `hmm_params`,
# replacing the dict built for the generative model earlier in the notebook.
hmm_params = utils.vardict({
    'mu': {
        'loc': 2 * torch.arange(5, dtype=torch.float),
        'scale': 0.25 * torch.ones(5),
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': 0.25 * torch.ones(5),
    },
})
for k in range(6):
    # Build the concentration vector before storing it, so the values do not
    # depend on `hmm_params[...]` handing back a live reference for in-place
    # index writes. A fresh tensor per key means no two entries share
    # storage.
    hmm_params['Pi_%d' % k] = {
        'concentration': torch.tensor([1.0, 1.25, 1.5, 1.25, 1.0]),
    }

In [12]:
# Rebind init_hmm: same initialiser function, but the parameter dict is
# passed through `trainable=` rather than `hyper=`.
init_hmm = combinators.Model(f=hmm.init_hmm, trainable=hmm_params)

In [13]:
# Run the initialiser under the inference trace. This overwrites the
# z0 / mu / sigma / pi / pi0 values left over from the generative run.
z0, mu, sigma, pi, pi0 = init_hmm(trace=inference)

In [14]:
# Rebind hmm_step to the object returned by hmm.forward_backward_filter_hmm
# and rebuild the sequence around it. The backward_pass() and
# smoothed_posterior() calls in later cells are made on this same object.
# NOTE(review): 50 is presumably the number of time steps — confirm.
hmm_step = hmm.forward_backward_filter_hmm(mu, sigma, pi, pi0)
hmm_run = combinators.Model.sequence(hmm_step, 50, z0, mu, sigma, pi, pi0)

In [15]:
# Run the sequence under the inference trace.
# NOTE(review): presumably this is the forward (filtering) pass that the
# backward pass in the next cell relies on — confirm.
z_last, mu, sigma, pi, pi0 = hmm_run(trace=inference)

In [16]:
# Backward pass on the filter object; T=50 matches the literal 50 given to
# Model.sequence above.
hmm_step.backward_pass(T=50)

In [17]:
# Display one element of the smoothed posterior. The [0][3] indexing picks
# into the returned structure, whose layout is not visible here.
# NOTE(review): the displayed values are all negative, which is consistent
# with log-probabilities — confirm.
hmm_step.smoothed_posterior(T=50)[0][3]

tensor([[ -1.7469,  -0.4443,  -1.6910, -11.2762,  -9.5778],
        [ -9.0224, -10.9474,  -3.3158,  -1.3226,  -0.3608],
        [ -4.7041,  -1.5847,  -1.5071,  -3.8322,  -0.6112],
        [-13.7406, -28.4533,  -4.0791,  -1.4319,  -0.2954],
        [ -7.8834,  -9.2057,  -3.3198,  -2.4410,  -0.1321],
        [-10.0298, -10.5242,  -3.6115,  -0.8148,  -0.6345],
        [ -0.1699, -13.7767,  -1.8563, -28.7582, -27.0582],
        [-12.6489, -21.7327,  -5.3618,  -0.4365,  -1.0527],
        [ -4.1647,  -1.2645,  -1.2730,  -2.7250,  -1.0312],
        [ -2.1749,  -0.3198,  -1.8394,  -9.0874,  -6.8340],
        [-15.4810, -44.1185,  -3.2923,  -0.2219,  -1.8212],
        [ -2.3406,  -0.2200,  -2.2917, -11.1135,  -9.0637],
        [-15.1672, -52.5688,  -1.3917,  -0.4804,  -2.0186],
        [-13.3694, -28.5494,  -4.4696,  -0.9685,  -0.4961],
        [ -4.8946,  -0.0493,  -3.2322,  -7.6875,  -7.2590],
        [ -0.0991, -16.6378,  -2.3609, -28.6136, -28.6910],
        [ -4.1786,  -0.1260,  -2.4557,  

In [18]:
# Starting values for the variational fit: random 'mu' locations drawn
# uniformly from [0, 10), unit scales everywhere, all-ones concentrations.
init_hmm_params = utils.vardict({
    'mu': {'loc': 10 * torch.rand(5), 'scale': torch.ones(5)},
    'sigma': {'loc': torch.ones(5), 'scale': torch.ones(5)},
})
# One all-ones concentration vector for each of Pi_0 .. Pi_5.
for k in range(6):
    init_hmm_params['Pi_%d' % k] = {'concentration': torch.ones(5)}

In [19]:
# Rebind init_hmm once more, this time with init_hmm_params as the trainable
# parameters. This is the model handed to the variational fit below.
init_hmm = combinators.Model(f=hmm.init_hmm, trainable=init_hmm_params)

In [20]:
def hmm_step_builder(z0, mu, sigma, pi, pi0):
    """Build a forward-backward filter step from the given model values.

    Takes the same five values that ``init_hmm`` is unpacked into elsewhere
    in this notebook. ``z0`` is accepted but not used; only ``mu``,
    ``sigma``, ``pi`` and ``pi0`` are forwarded, in that order.

    Returns:
        Whatever ``hmm.forward_backward_filter_hmm`` returns for those four
        arguments.
    """
    filter_step = hmm.forward_backward_filter_hmm(mu, sigma, pi, pi0)
    return filter_step

In [21]:
# Fit the variational parameters. This rebinds `inference`, replacing the
# GuidedTrace created earlier in the notebook.
# NOTE(review): the positional 500 and 50 are unnamed here; 50 presumably
# matches the sequence length used above, and 500 is either the particle
# count or the epoch count — confirm against
# filtering.variational_forward_backward and consider naming them.
inference, variational_params = filtering.variational_forward_backward(
    init_hmm,
    hmm_step_builder,
    500,
    50,
    data,
    use_cuda=False,
    lr=1e-2,
)

08/03/2018 12:49:41 Variational forward-backward ELBO=-8.59412500e+03 at epoch 1
08/03/2018 12:49:42 Variational forward-backward ELBO=-9.33105273e+03 at epoch 2
08/03/2018 12:49:43 Variational forward-backward ELBO=-7.86899121e+03 at epoch 3
08/03/2018 12:49:44 Variational forward-backward ELBO=-8.08030811e+03 at epoch 4
08/03/2018 12:49:45 Variational forward-backward ELBO=-8.03441846e+03 at epoch 5
08/03/2018 12:49:45 Variational forward-backward ELBO=-7.69200439e+03 at epoch 6
08/03/2018 12:49:46 Variational forward-backward ELBO=-7.70211523e+03 at epoch 7
08/03/2018 12:49:47 Variational forward-backward ELBO=-7.51158301e+03 at epoch 8
08/03/2018 12:49:48 Variational forward-backward ELBO=-8.90505273e+03 at epoch 9
08/03/2018 12:49:48 Variational forward-backward ELBO=-7.76555371e+03 at epoch 10
08/03/2018 12:49:49 Variational forward-backward ELBO=-7.94686328e+03 at epoch 11
08/03/2018 12:49:50 Variational forward-backward ELBO=-7.77283740e+03 at epoch 12
08/03/2018 12:49:51 Varia

In [22]:
# Display the fitted parameters.
# NOTE(review): the output includes negative entries for some
# '*__concentration' and '*__scale' parameters; presumably the model maps
# them to positive values before use — confirm in hmm.init_hmm.
variational_params

"{'Pi_0__concentration': 'Parameter containing:\ntensor([0.8398, 1.2257, 0.8225, 1.9910, 0.1388], requires_grad=True)', 'Pi_1__concentration': 'Parameter containing:\ntensor([ 0.3748, -0.1951,  0.3709,  2.7194,  1.0785], requires_grad=True)', 'Pi_2__concentration': 'Parameter containing:\ntensor([0.9925, 0.0292, 1.0006, 2.5003, 0.8791], requires_grad=True)', 'Pi_3__concentration': 'Parameter containing:\ntensor([ 0.0487, -0.3510,  0.7039,  2.8934,  0.9465], requires_grad=True)', 'Pi_4__concentration': 'Parameter containing:\ntensor([0.6905, 0.6980, 1.1985, 2.0838, 1.6576], requires_grad=True)', 'Pi_5__concentration': 'Parameter containing:\ntensor([ 0.0847, -0.1011,  0.1516,  2.6322,  1.3133], requires_grad=True)', 'mu__loc': 'Parameter containing:\ntensor([2.2295, 1.2286, 3.2011, 7.6591, 4.7705], requires_grad=True)', 'mu__scale': 'Parameter containing:\ntensor([ 0.8468,  1.1927,  0.6511, -0.0554,  0.7202], requires_grad=True)', 'sigma__loc': 'Parameter containing:\ntensor([1.9250, 2.