In [1]:
import logging

import probtorch
import torch

import combinators
import filtering
import hmm
import utils

# Seed torch's global RNG so that Restart & Run All repeats the same random
# draws (e.g. the torch.rand initialisation used for the variational parameters
# later in the notebook). Placed before basicConfig so the cell still ends on a
# statement that produces no output.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging; the training loop invoked later reports its
# progress through the logging module.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [2]:
# Trace object handed to `.condition(...)` by the generative models below.
# The argument 1 is presumably the particle count — TODO confirm against
# combinators.ParticleTrace.
generative = combinators.ParticleTrace(1)

In [3]:
# Parameters used by the data-generating HMM. Every tensor here has an explicit
# leading singleton dimension, i.e. shape (1, 5).
hmm_params = utils.vardict({
    'mu': {'loc': 2 * torch.arange(5, dtype=torch.float).unsqueeze(0),
           'scale': torch.full((1, 5), 0.25)},
    'sigma': {'loc': torch.ones(1, 5),
              'scale': torch.full((1, 5), 0.25)},
})
# Six all-ones concentration vectors, stored under the keys Pi_0 .. Pi_5.
for k in range(6):
    hmm_params['Pi_%d' % k] = {'concentration': torch.ones(1, 5)}

In [4]:
# Wrap hmm.init_hmm in a combinators.Model, passing hmm_params as `theta`,
# then attach the `generative` trace via condition().
init_hmm = combinators.Model(f=hmm.init_hmm, theta=hmm_params)
init_hmm.condition(generative)

In [5]:
# Run the initialiser model. The five returned values are forwarded unchanged to
# Model.sequence in the next cell (z0 is presumably the initial latent state —
# TODO confirm against hmm.init_hmm).
z0, mu, sigma, pi, pi0 = init_hmm()

In [6]:
# Wrap the single-step function and build a sequence model from it, seeded with
# the values returned by init_hmm(). The 50 is presumably the number of time
# steps (later cells pass T=50 to the filter) — TODO confirm.
hmm_step = combinators.Model(f=hmm.hmm_step)
hmm_run = combinators.Model.sequence(hmm_step, 50, z0, mu, sigma, pi, pi0)

In [7]:
# Attach the generative trace to the sequence model before running it.
hmm_run.condition(generative)

In [8]:
# Run the sequence model. Note that mu, sigma, pi and pi0 are rebound to the
# values it returns; the first returned value is kept as z_last.
z_last, mu, sigma, pi, pi0 = hmm_run()

In [9]:
# Rebind `generative` to its squeezed form; this is the object later passed as
# the second argument to condition() and to variational_forward_backward.
# Presumably squeeze() drops the singleton particle dimension — TODO confirm.
generative = generative.squeeze()

In [10]:
# Presumably the particle count for the inference ParticleTrace created in the
# next cell — TODO confirm.
num_particles = 100

In [11]:
# Inference trace sized by num_particles (defined in the previous cell) instead
# of a second hard-coded 100, so the particle count is set in exactly one place.
inference = combinators.ParticleTrace(num_particles)

In [12]:
# Parameters for the inference run. This rebinds hmm_params: unlike the earlier
# definition, the tensors here are one-dimensional (length 5).
hmm_params = utils.vardict({
    'mu': {'loc': 2 * torch.arange(5, dtype=torch.float),
           'scale': torch.full((5,), 0.25)},
    'sigma': {'loc': torch.ones(5),
              'scale': torch.full((5,), 0.25)},
})
for k in range(6):
    hmm_params['Pi_%d' % k] = {'concentration': torch.ones(5)}
    # Raise the three middle entries in place, giving [1, 1.25, 1.5, 1.25, 1].
    hmm_params['Pi_%d' % k]['concentration'][1:4] = torch.tensor([1.25, 1.5, 1.25])

In [13]:
# Rebind init_hmm to a fresh Model built from the current hmm_params, and
# condition it on the inference trace with the squeezed generative trace as the
# second argument (presumably supplying the observations — TODO confirm).
init_hmm = combinators.Model(f=hmm.init_hmm, theta=hmm_params)
init_hmm.condition(inference, generative)

In [14]:
# Run the re-built initialiser; this overwrites z0, mu, sigma, pi and pi0 from
# the generative run with the values used to build the filter below.
z0, mu, sigma, pi, pi0 = init_hmm()

In [15]:
# hmm_step is now the object returned by hmm.forward_backward_filter_hmm (built
# from the values just returned by init_hmm()), not a plain Model wrapper as in
# the generative section. It is sequenced with the same literal 50 as before.
hmm_step = hmm.forward_backward_filter_hmm(mu, sigma, pi, pi0)
hmm_run = combinators.Model.sequence(hmm_step, 50, z0, mu, sigma, pi, pi0)

In [16]:
# Condition the sequence on the inference trace, passing the generative trace
# as the second argument (same call shape as for init_hmm above).
hmm_run.condition(inference, generative)

In [17]:
# Run the conditioned sequence; rebinds z_last, mu, sigma, pi and pi0.
z_last, mu, sigma, pi, pi0 = hmm_run()

In [18]:
# Run the filter's backward pass; T=50 matches the length given to
# Model.sequence when hmm_run was built.
hmm_step.backward_pass(T=50)

In [19]:
# Display element [0][3] of the smoothed posterior for T=50.
# NOTE(review): the recorded output has 5 columns and only non-positive values,
# so these look like log-probabilities over 5 states — confirm.
hmm_step.smoothed_posterior(T=50)[0][3]

tensor([[ -5.9246,  -6.8010,  -0.3467,  -1.2607,  -5.1545],
        [-19.3784, -31.9283, -39.9891, -23.0865,   0.0000],
        [ -0.8615,  -0.5491, -27.0288, -35.9691, -13.9205],
        [-16.9129, -27.7007, -32.4450, -18.0436,   0.0000],
        [ -0.0406,  -3.2254, -59.5305, -68.4895, -18.0930],
        [ -0.4254,  -1.0599, -33.0069, -42.6470, -15.4175],
        [ -0.5433,  -0.8695, -24.2281, -32.7439, -14.2998],
        [ -7.0861, -10.5247,  -3.3157,  -0.4572,  -1.1094],
        [-13.7468, -21.7283, -20.8757,  -9.8428,  -0.0001],
        [ -8.6269, -10.9936,  -5.9003,  -0.8777,  -0.5424],
        [ -7.4117,  -8.4883,  -1.7610,  -0.2384,  -3.2322],
        [ -8.0088, -12.3049,  -6.0280,  -1.0665,  -0.4261],
        [-11.3910, -16.2032, -13.0580,  -5.1070,  -0.0061],
        [ -0.0428,  -3.1723, -51.5473, -60.5633, -17.3712],
        [-23.0908, -39.9444, -53.9401, -33.5756,   0.0000],
        [ -9.8985, -12.9194,  -8.2162,  -2.3720,  -0.0983],
        [ -7.9882, -10.8885,  -5.9318,  

In [20]:
# Starting values for the parameters that the variational routine below is given
# (via `phi`). mu.loc is drawn uniformly from [0, 10); everything else starts
# at 1.
init_hmm_params = utils.vardict({
    'mu': {'loc': 10 * torch.rand(5), 'scale': torch.ones(5)},
    'sigma': {'loc': torch.ones(5), 'scale': torch.ones(5)},
})
# Six all-ones concentration vectors under the keys Pi_0 .. Pi_5.
for k in range(6):
    init_hmm_params['Pi_%d' % k] = {'concentration': torch.ones(5)}

In [21]:
# Rebind init_hmm once more. Here the parameters are passed as `phi`, whereas
# the earlier cells used `theta` (presumably phi marks them as trainable
# variational parameters — TODO confirm against combinators.Model). No
# condition() call is made in this cell; the model is handed to
# filtering.variational_forward_backward instead.
init_hmm = combinators.Model(f=hmm.init_hmm, phi=init_hmm_params)

In [22]:
def hmm_step_builder(z0, mu, sigma, pi, pi0):
    """Build a forward-backward filter step from one set of HMM values.

    Takes the same five values that ``init_hmm()`` returns. ``z0`` is accepted
    but not used; only ``mu``, ``sigma``, ``pi`` and ``pi0`` are forwarded to
    ``hmm.forward_backward_filter_hmm``, whose result is returned.
    """
    filter_step = hmm.forward_backward_filter_hmm(mu, sigma, pi, pi0)
    return filter_step

In [23]:
# Fit the variational parameters on CPU. Rebinds `inference` and yields the
# learned parameters. Presumably 500 is the number of epochs (the log lines read
# "at epoch N") and 50 the sequence length — TODO confirm against
# filtering.variational_forward_backward.
inference, variational_params = filtering.variational_forward_backward(
    init_hmm, hmm_step_builder, 500, 50, generative,
    use_cuda=False, lr=1e-2,
)

07/29/2018 11:51:40 Variational forward-backward ELBO=-8.15286621e+03 at epoch 1
07/29/2018 11:51:41 Variational forward-backward ELBO=-7.92951416e+03 at epoch 2
07/29/2018 11:51:41 Variational forward-backward ELBO=-9.78722070e+03 at epoch 3
07/29/2018 11:51:42 Variational forward-backward ELBO=-8.14984082e+03 at epoch 4
07/29/2018 11:51:43 Variational forward-backward ELBO=-8.12417920e+03 at epoch 5
07/29/2018 11:51:44 Variational forward-backward ELBO=-7.02318115e+03 at epoch 6
07/29/2018 11:51:44 Variational forward-backward ELBO=-1.05525771e+04 at epoch 7
07/29/2018 11:51:45 Variational forward-backward ELBO=-8.74362109e+03 at epoch 8
07/29/2018 11:51:46 Variational forward-backward ELBO=-1.09612852e+04 at epoch 9
07/29/2018 11:51:47 Variational forward-backward ELBO=-8.13174951e+03 at epoch 10
07/29/2018 11:51:48 Variational forward-backward ELBO=-7.95698438e+03 at epoch 11
07/29/2018 11:51:49 Variational forward-backward ELBO=-6.98796436e+03 at epoch 12
07/29/2018 11:51:49 Varia

In [24]:
# Display the learned parameters.
# NOTE(review): the recorded output contains negative entries in some
# Pi_*__concentration tensors; if these feed a Dirichlet directly they must be
# positive, so check whether a positivity constraint is applied elsewhere.
variational_params

"{'Pi_0__concentration': 'Parameter containing:\ntensor([ 0.8742,  0.6998,  1.3063,  1.5975,  0.7711])', 'Pi_1__concentration': 'Parameter containing:\ntensor([ 0.2690,  2.5222,  0.7575,  0.9118, -0.4100])', 'Pi_2__concentration': 'Parameter containing:\ntensor([ 0.6347,  2.2639,  1.2084,  1.0927,  0.6171])', 'Pi_3__concentration': 'Parameter containing:\ntensor([ 0.9780,  1.4456,  1.6574,  0.9587,  0.9340])', 'Pi_4__concentration': 'Parameter containing:\ntensor([ 0.7394,  1.7053,  0.9916,  1.6835,  0.6025])', 'Pi_5__concentration': 'Parameter containing:\ntensor([-0.3361,  2.9504,  0.7905,  0.7360, -0.6260])', 'mu__loc': 'Parameter containing:\ntensor([ 7.0303,  0.7707,  4.7167,  5.2700,  7.7742])', 'mu__scale': 'Parameter containing:\ntensor([ 0.3512,  0.0508,  0.6685,  0.7616,  0.2536])', 'sigma__loc': 'Parameter containing:\ntensor([ 1.1412,  1.4149,  1.4398,  1.8701,  1.2442])', 'sigma__scale': 'Parameter containing:\ntensor([ 1.8466,  1.6415,  1.7397,  1.6081,  1.5097])'}"