In [1]:
import logging

import probtorch
import torch

import combinators
import filtering
import hmm
import utils

# Seed torch's global RNG so Restart & Run All reproduces the same draws.
# The variational initialisation below uses torch.rand, and the particle
# traces presumably sample through torch as well (NOTE(review): confirm for
# probtorch/combinators).
SEED = 0
torch.manual_seed(SEED)

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [2]:
# Trace for drawing data from the generative model. The argument 1 is
# presumably the particle count (cf. ParticleTrace(100) for inference below).
generative = combinators.ParticleTrace(
    1,
)

In [3]:
# Hyperparameters for generating data. Every tensor has shape (1, 5): five
# entries, presumably one per hidden state, plus a leading singleton dim.
# NOTE(review): confirm the leading dim is meant to match ParticleTrace(1).
hmm_params = utils.vardict({
    'mu': {
        'loc': 2 * torch.arange(5, dtype=torch.float).unsqueeze(0),
        'scale': 0.25 * torch.ones(1, 5),
    },
    'sigma': {
        'loc': torch.ones(1, 5),
        'scale': 0.25 * torch.ones(1, 5),
    },
})
# Six Pi_k entries for five states. Presumably Pi_0 is the initial
# distribution and Pi_1..Pi_5 are transition rows; confirm in hmm.init_hmm.
for k in range(6):
    hmm_params['Pi_{}'.format(k)] = {'concentration': torch.ones(1, 5)}

In [4]:
# Wrap hmm.init_hmm as a combinators Model, passing the generative
# hyperparameters as theta (the variational version below uses phi instead).
init_hmm = combinators.Model(theta=hmm_params, f=hmm.init_hmm)

In [5]:
# Run the initializer against the generative trace. It yields five values:
# the initial state z0 and the parameters mu, sigma, pi, pi0.
(z0, mu, sigma, pi, pi0) = init_hmm(trace=generative)

In [6]:
hmm_step = combinators.Model(f=hmm.hmm_step)
# Sequence the step model 50 times (presumably T = 50 time steps), seeded with
# z0 and the sampled parameters.
hmm_run = combinators.Model.sequence(
    hmm_step, 50,
    z0, mu, sigma, pi, pi0,
)

In [7]:
# Execute the sequenced model on the generative trace. mu, sigma, pi and pi0
# are rebound to the values it returns.
(z_last, mu, sigma, pi, pi0) = hmm_run(trace=generative)

In [8]:
# Squeeze the generative trace (built with ParticleTrace(1)), presumably so it
# can be passed as `guide=` to the inference trace below.
# NOTE(review): this rebinds `generative`, so re-running the cell squeezes an
# already-squeezed trace; confirm that is harmless.
generative = generative.squeeze()

In [9]:
# Particle count for the inference trace. Keep this in sync with the
# ParticleTrace construction in the next cell.
num_particles = 100

In [10]:
# Inference trace with `num_particles` particles. Using the named constant from
# the previous cell, rather than repeating the literal 100, means changing the
# particle count in one place actually takes effect.
inference = combinators.ParticleTrace(num_particles)

In [11]:
# Hyperparameters for inference. mu/sigma have the same values as the
# generative cell but shape (5,), without the leading singleton dimension.
hmm_params = utils.vardict({
    'mu': {
        'loc': 2 * torch.arange(5, dtype=torch.float),
        'scale': 0.25 * torch.ones(5),
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': 0.25 * torch.ones(5),
    },
})
for k in range(6):
    # Concentration (presumably Dirichlet) of ones with entries 1..3 raised to
    # 1.25, 1.5, 1.25. A fresh tensor is built per key so the Pi_k entries
    # never alias one another.
    hmm_params['Pi_{}'.format(k)] = {
        'concentration': torch.tensor([1.0, 1.25, 1.5, 1.25, 1.0]),
    }

In [12]:
# Rebuild the initializer Model with the inference hyperparameters as theta.
init_hmm = combinators.Model(theta=hmm_params, f=hmm.init_hmm)

In [13]:
# Run the initializer on the inference trace, with the (squeezed) generative
# trace supplied as guide.
(z0, mu, sigma, pi, pi0) = init_hmm(guide=generative, trace=inference)

In [14]:
# Build the forward-backward filtering step from the current parameters, then
# sequence it 50 times (presumably T = 50, matching the generative run).
hmm_step = hmm.forward_backward_filter_hmm(mu, sigma, pi, pi0)
hmm_run = combinators.Model.sequence(
    hmm_step, 50,
    z0, mu, sigma, pi, pi0,
)

In [15]:
# Execute the forward (filtering) sweep on the inference trace, guided by the
# generative trace.
(z_last, mu, sigma, pi, pi0) = hmm_run(guide=generative, trace=inference)

In [16]:
# Run the backward pass over T = 50 steps on the filter built above. The cell
# displays nothing, so it presumably updates hmm_step in place and returns None.
hmm_step.backward_pass(T=50)

In [17]:
# Inspect the smoothed posterior for T = 50 steps, indexing [0][3] into it.
# Each displayed row has 5 non-positive entries that look like normalized
# log-probabilities over the states.
# NOTE(review): confirm what indices 0 and 3 select, and consider slicing the
# display (e.g. [:5]) instead of dumping every row.
hmm_step.smoothed_posterior(T=50)[0][3]

tensor([[-12.9969,  -3.1063,  -0.1206,  -2.7341,  -5.5497],
        [ -2.1205,  -0.2737,  -2.1260,  -8.9432, -10.5090],
        [-50.2509, -27.3488,  -7.0636,  -0.8794,  -0.5376],
        [-11.0482,  -1.9790,  -0.1903,  -3.3782,  -6.9844],
        [ -1.2523,  -0.3636,  -3.9658, -12.2099, -12.1373],
        [-19.5723,  -8.4239,  -0.2899,  -1.4616,  -3.9335],
        [-90.6679, -57.9821, -20.1242,  -5.1012,  -0.0061],
        [ -3.7514,  -0.1935,  -1.8823,  -8.5910, -11.5007],
        [ -0.2600,  -1.5039,  -5.0056, -14.1368, -13.5088],
        [-25.7981, -11.1409,  -1.4544,  -0.4806,  -1.9103],
        [-18.2780,  -6.7808,  -0.6653,  -0.7676,  -3.8813],
        [-43.4122, -23.4982,  -6.4841,  -0.3036,  -1.3458],
        [-40.3472, -21.7063,  -5.3373,  -0.0488,  -3.1511],
        [ -0.0117,  -4.4783,  -8.1964, -17.1969, -16.1903],
        [-65.0331, -40.7350, -11.9073,  -1.9454,  -0.1542],
        [ -0.0809,  -2.5763,  -6.4049, -16.2521, -14.6330],
        [-15.7214,  -5.5963,  -0.2339,  

In [18]:
# Starting values for the variational parameters: mu locations drawn uniformly
# from [0, 10) via torch.rand, unit scales, and all-ones concentrations for
# every Pi_k.
init_hmm_params = utils.vardict({
    'mu': {
        'loc': 10 * torch.rand(5),
        'scale': torch.ones(5),
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': torch.ones(5),
    },
})
for k in range(6):
    init_hmm_params['Pi_{}'.format(k)] = {'concentration': torch.ones(5)}

In [19]:
# Initializer Model whose parameters are passed as phi. These are the
# parameters reported back as `variational_params` after fitting.
init_hmm = combinators.Model(phi=init_hmm_params, f=hmm.init_hmm)

In [20]:
def hmm_step_builder(z0, mu, sigma, pi, pi0):
    """Build a forward-backward filtering step for the given HMM parameters.

    ``z0`` is accepted but unused. NOTE(review): the five-argument signature is
    presumably dictated by how filtering.variational_forward_backward invokes
    the builder.
    """
    step_args = (mu, sigma, pi, pi0)
    return hmm.forward_backward_filter_hmm(*step_args)

In [21]:
# Fit the variational parameters against the generative trace.
# NOTE(review): 500 and 50 are positional. They are presumably the number of
# epochs (the log counts epochs) and the sequence length T; confirm against
# filtering.variational_forward_backward's signature.
inference, variational_params = filtering.variational_forward_backward(
    init_hmm,
    hmm_step_builder,
    500,
    50,
    generative,
    lr=1e-2,
    use_cuda=False,
)

07/30/2018 13:28:24 Variational forward-backward ELBO=-8.80669434e+03 at epoch 1
07/30/2018 13:28:24 Variational forward-backward ELBO=-8.42048535e+03 at epoch 2
07/30/2018 13:28:25 Variational forward-backward ELBO=-8.47251660e+03 at epoch 3
07/30/2018 13:28:25 Variational forward-backward ELBO=-7.17532129e+03 at epoch 4
07/30/2018 13:28:25 Variational forward-backward ELBO=-7.04814160e+03 at epoch 5
07/30/2018 13:28:26 Variational forward-backward ELBO=-7.30672119e+03 at epoch 6
07/30/2018 13:28:26 Variational forward-backward ELBO=-8.46692285e+03 at epoch 7
07/30/2018 13:28:26 Variational forward-backward ELBO=-7.34596777e+03 at epoch 8
07/30/2018 13:28:26 Variational forward-backward ELBO=-8.52429883e+03 at epoch 9
07/30/2018 13:28:27 Variational forward-backward ELBO=-7.61517188e+03 at epoch 10
07/30/2018 13:28:27 Variational forward-backward ELBO=-7.01555957e+03 at epoch 11
07/30/2018 13:28:27 Variational forward-backward ELBO=-7.32445752e+03 at epoch 12
07/30/2018 13:28:27 Varia

In [22]:
# Display the fitted variational parameters.
# NOTE(review): the saved output shows negative mu__scale entries (-0.5370,
# -0.3582). If these are Normal scales they must be positive, so they appear to
# be optimized unconstrained. Consider a positive parameterization (softplus or
# log-scale) where filtering/hmm consume them.
variational_params

"{'Pi_0__concentration': 'Parameter containing:\ntensor([1.3087, 1.3544, 1.4905, 1.0398, 0.5154], requires_grad=True)', 'Pi_1__concentration': 'Parameter containing:\ntensor([1.8277, 0.5742, 1.0393, 0.6323, 2.0268], requires_grad=True)', 'Pi_2__concentration': 'Parameter containing:\ntensor([0.8938, 1.3899, 0.9828, 0.6790, 2.1441], requires_grad=True)', 'Pi_3__concentration': 'Parameter containing:\ntensor([1.3847, 1.0120, 2.4341, 1.1004, 0.9377], requires_grad=True)', 'Pi_4__concentration': 'Parameter containing:\ntensor([0.7632, 0.5676, 1.5479, 1.4903, 1.9788], requires_grad=True)', 'Pi_5__concentration': 'Parameter containing:\ntensor([0.7768, 0.6622, 1.5059, 0.7111, 2.4951], requires_grad=True)', 'mu__loc': 'Parameter containing:\ntensor([ 3.3627,  3.2764, -1.0115,  2.6428,  7.6379], requires_grad=True)', 'mu__scale': 'Parameter containing:\ntensor([ 1.4147,  1.3785, -0.5370,  1.6142, -0.3582], requires_grad=True)', 'sigma__loc': 'Parameter containing:\ntensor([2.0598, 1.7386, 1.16