In [1]:
import logging

import probtorch
import torch

import combinators
import filtering
import hmm
import utils

# Seed torch's global RNG so that the torch.rand initialisation in In[20] and
# (presumably) the torch-based sampling inside init_hmm / hmm_run reproduce
# the same numbers under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Timestamped INFO-level logging; the per-epoch ELBO messages from
# filtering.variational_forward_backward are emitted through it.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

In [2]:
# Trace for the generative (data-simulating) pass; init_hmm and hmm_run are
# conditioned on it below. NOTE(review): the argument 1 is presumably the
# number of particles (cf. GraphingTrace(100) for inference) — confirm.
generative = combinators.GraphingTrace(1)

In [3]:
# Generative-model hyperparameters. Every tensor carries a leading singleton
# dimension (shape (1, 5)); presumably this matches the single particle of
# GraphingTrace(1) — confirm.
#   mu:    locations 0, 2, 4, 6, 8 with scale 0.25
#   sigma: location 1 with scale 0.25
hmm_params = utils.vardict({
    'mu': {
        'loc': 2 * torch.arange(5, dtype=torch.float).unsqueeze(0),
        'scale': 0.25 * torch.ones(1, 5),
    },
    'sigma': {
        'loc': torch.ones(1, 5),
        'scale': 0.25 * torch.ones(1, 5),
    }
})

# Flat concentration for each of Pi_0 .. Pi_5. NOTE(review): six tables for
# five states — presumably Pi_0 is the initial distribution and Pi_1..Pi_5 the
# transition rows; confirm in hmm.init_hmm.
for pi_index in range(6):
    hmm_params['Pi_%d' % pi_index] = {'concentration': torch.ones(1, 5)}

In [4]:
# HMM prior built from hmm.init_hmm with hmm_params passed as theta,
# conditioned on the generative trace (presumably so its draws are recorded
# there).
init_hmm = combinators.Model(f=hmm.init_hmm, theta=hmm_params)
init_hmm.condition(generative)

In [5]:
# Sample the HMM's starting quantities: z0 (presumably the initial latent
# state) plus the global parameters mu, sigma, pi, pi0.
# NOTE(review): T=50 presumably is the sequence length; it must agree with the
# 50 passed to Model.sequence in the next cell.
z0, mu, sigma, pi, pi0 = init_hmm(T=50)

In [6]:
# Wrap a single generative HMM transition and unroll it for 50 steps, starting
# from the values sampled by init_hmm (presumably threaded from step to step).
# NOTE(review): the literal 50 duplicates T=50 above; a shared constant would
# keep them in sync.
hmm_step = combinators.Model(f=hmm.hmm_step)
hmm_run = combinators.Model.sequence(hmm_step, 50, z0, mu, sigma, pi, pi0)

In [7]:
# Attach the unrolled generative HMM to the same generative trace as init_hmm.
hmm_run.condition(generative)

In [8]:
# Simulate the full sequence. z_last is presumably the final latent state; the
# globals are returned unchanged by name.
z_last, mu, sigma, pi, pi0 = hmm_run()

In [9]:
# Squeeze the generative trace, presumably dropping the singleton dimension
# introduced by GraphingTrace(1) / the (1, 5)-shaped hyperparameters.
# NOTE(review): this rebinds the same name, so re-running the cell on its own
# applies squeeze() to an already-squeezed trace — verify that is a no-op.
generative = generative.squeeze()

In [10]:
# Number of particles for the inference trace created below.
num_particles = 100

In [11]:
# Inference trace sized by the configured particle count. Reuse num_particles
# rather than repeating the literal 100 so the two cannot drift apart.
inference = combinators.GraphingTrace(num_particles)

In [12]:
# Inference-model hyperparameters. Unlike the generative ones, these tensors
# have no leading singleton dimension (shape (5,)).
hmm_params = utils.vardict({
    'mu': {
        'loc': torch.arange(5, dtype=torch.float) * 2,
        'scale': torch.ones(5) * 0.25,
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': torch.ones(5) * 0.25,
    }
})
# Each Pi_k gets concentration ones(5) with entries 1..3 raised to
# (1.25, 1.5, 1.25). The full tensor is built before it is handed to the
# vardict; the previous version mutated it afterwards through
# hmm_params['Pi_k']['concentration'][i], which only works if vardict's
# lookup returns a live reference to the stored tensor. A fresh tensor is
# created per k, so the tables do not share storage.
# NOTE(review): six tables for five states — presumably Pi_0 is the initial
# distribution and Pi_1..Pi_5 the transition rows; confirm in hmm.init_hmm.
for k in range(6):
    hmm_params['Pi_%d' % k] = {
        'concentration': torch.tensor([1.0, 1.25, 1.5, 1.25, 1.0]),
    }

In [13]:
# Rebuild the HMM prior with the inference hyperparameters and condition it on
# both traces. Presumably it samples into `inference` while referring to the
# observed `generative` trace — confirm in combinators.Model.condition.
init_hmm = combinators.Model(f=hmm.init_hmm, theta=hmm_params)
init_hmm.condition(inference, generative)

In [14]:
# Sample HMM starting values under the inference model.
# NOTE(review): T=50 must match the 50 passed to Model.sequence below.
z0, mu, sigma, pi, pi0 = init_hmm(T=50)

In [15]:
# Replace the generative step with a forward-backward filter built from the
# sampled globals, and unroll it for 50 steps. hmm_step is kept as a named
# object because In[18]/In[19] later call backward_pass / smoothed_posterior
# on it.
hmm_step = hmm.forward_backward_filter_hmm(mu, sigma, pi, pi0)
hmm_run = combinators.Model.sequence(hmm_step, 50, z0, mu, sigma, pi, pi0)

In [16]:
# Condition the unrolled filter on the inference trace and the observed
# generative trace.
hmm_run.condition(inference, generative)

In [17]:
# Run the unrolled filter over all steps (presumably the forward pass; the
# backward pass follows in the next cell).
z_last, mu, sigma, pi, pi0 = hmm_run()

In [18]:
# Backward pass over the same T=50 steps, presumably required before
# smoothed_posterior can be queried.
hmm_step.backward_pass(T=50)

In [19]:
# Inspect one slice of the smoothed posterior.
# NOTE(review): the meaning of [0][3] is not visible here. In the recorded
# output each row exponentiates to ~1 over 5 entries, so rows look like
# per-particle log-probabilities over the 5 states — confirm which
# time step / component the indices select.
hmm_step.smoothed_posterior(T=50)[0][3]

tensor([[-13.8617,  -3.5003,  -0.2583,  -1.6223, -10.0452],
        [ -5.9796,  -0.5193,  -0.9234,  -5.2362, -21.6560],
        [ -0.5050,  -0.9259,  -8.0508, -11.2737, -36.5298],
        [-15.3296,  -3.4595,  -0.3768,  -1.2654,  -7.9666],
        [ -4.1990,  -0.4815,  -1.0170,  -5.2114, -23.6194],
        [-43.4313, -16.3309, -30.5280,  -5.5095,  -0.0041],
        [-74.0819, -28.9385, -97.7721, -16.5926,   0.0000],
        [-10.1751,  -4.6286,  -0.0176,  -4.8758, -15.9855],
        [-34.2206, -11.1558, -17.7689,  -2.8633,  -0.0588],
        [ -0.0513,  -2.9957, -12.8798, -15.1098, -43.7429],
        [-28.2593,  -8.9528, -11.0397,  -2.5651,  -0.0802],
        [-10.9419,  -3.6862,  -0.0412,  -4.1791, -14.6509],
        [ -0.0527,  -2.9686, -21.8265, -20.1739, -55.0382],
        [-22.5018,  -5.2456,  -3.0148,  -0.0916,  -3.4036],
        [-20.7111,  -5.9534,  -5.7071,  -0.0321,  -3.6615],
        [-10.7028,  -2.6181,  -0.3486,  -1.5080, -13.1964],
        [-11.5457,  -2.7852,  -0.1973,  

In [20]:
# Starting values for the variational parameters (passed as phi below):
#   mu:    locations drawn uniformly from [0, 10), unit scale
#   sigma: unit location and scale
#   Pi_k:  flat (all-ones) concentration for each of the six tables
init_hmm_params = utils.vardict({
    'mu': {
        'loc': 10 * torch.rand(5),
        'scale': torch.ones(5),
    },
    'sigma': {
        'loc': torch.ones(5),
        'scale': torch.ones(5),
    }
})
for pi_index in range(6):
    init_hmm_params['Pi_%d' % pi_index] = {'concentration': torch.ones(5)}

In [21]:
# Variational HMM prior: init_hmm_params is passed as phi rather than theta.
# These are presumably the trainable parameters — their trained values appear
# in variational_params below.
init_hmm = combinators.Model(f=hmm.init_hmm, phi=init_hmm_params)

In [22]:
def hmm_step_builder(z0, mu, sigma, pi, pi0):
    """Build the per-step model handed to filtering.variational_forward_backward.

    The parameters mirror the (z0, mu, sigma, pi, pi0) tuple unpacked from
    init_hmm elsewhere in this notebook. Only the global parameters are used
    to construct the forward-backward filter. z0 is ignored; presumably it is
    accepted so the caller can pass init_hmm's outputs positionally.
    """
    filter_args = (mu, sigma, pi, pi0)
    return hmm.forward_backward_filter_hmm(*filter_args)

In [23]:
# Fit the variational HMM against the observed generative trace.
inference, variational_params = filtering.variational_forward_backward(
    init_hmm,
    hmm_step_builder,
    500,  # NOTE(review): presumably the number of optimisation epochs — confirm
    50,   # presumably the sequence length T, matching T=50 used above
    generative,
    use_cuda=False,
    lr=1e-2,
)

07/25/2018 12:47:35 Variational forward-backward ELBO=-7.07705713e+03 at epoch 1
07/25/2018 12:47:35 Variational forward-backward ELBO=-7.64760840e+03 at epoch 2
07/25/2018 12:47:35 Variational forward-backward ELBO=-7.45284424e+03 at epoch 3
07/25/2018 12:47:35 Variational forward-backward ELBO=-6.69325244e+03 at epoch 4
07/25/2018 12:47:35 Variational forward-backward ELBO=-6.88334277e+03 at epoch 5
07/25/2018 12:47:35 Variational forward-backward ELBO=-6.79968994e+03 at epoch 6
07/25/2018 12:47:36 Variational forward-backward ELBO=-6.97438574e+03 at epoch 7
07/25/2018 12:47:36 Variational forward-backward ELBO=-6.88557764e+03 at epoch 8
07/25/2018 12:47:36 Variational forward-backward ELBO=-8.07945068e+03 at epoch 9
07/25/2018 12:47:36 Variational forward-backward ELBO=-6.67037842e+03 at epoch 10
07/25/2018 12:47:36 Variational forward-backward ELBO=-6.89625586e+03 at epoch 11
07/25/2018 12:47:37 Variational forward-backward ELBO=-7.07961865e+03 at epoch 12
07/25/2018 12:47:37 Varia

In [24]:
# Display the trained variational parameters.
# NOTE(review): in the recorded output several entries that should be strictly
# positive are negative (e.g. mu__scale, and the Pi_1 / Pi_3 / Pi_4
# concentrations). If these directly parameterise Normal scales and Dirichlet
# concentrations, the optimiser is not constraining them (e.g. via softplus or
# exp) — check filtering.variational_forward_backward.
variational_params

"{'Pi_0__concentration': 'Parameter containing:\ntensor([ 1.2516,  0.5523,  1.6250,  1.1715,  0.6673])', 'Pi_1__concentration': 'Parameter containing:\ntensor([ 0.1298,  2.2648, -0.0633,  0.6689,  1.9680])', 'Pi_2__concentration': 'Parameter containing:\ntensor([ 1.3682,  1.0834,  1.4090,  1.8055,  1.3244])', 'Pi_3__concentration': 'Parameter containing:\ntensor([-0.3365,  2.4198, -0.5777,  0.0632,  2.2681])', 'Pi_4__concentration': 'Parameter containing:\ntensor([-0.0491,  2.5800, -0.2757, -0.1874,  2.0938])', 'Pi_5__concentration': 'Parameter containing:\ntensor([ 1.2146,  1.4893,  1.4559,  1.3468,  1.0465])', 'mu__loc': 'Parameter containing:\ntensor([ 6.1483,  2.3865,  7.1145,  5.9291,  2.5431])', 'mu__scale': 'Parameter containing:\ntensor([ 0.2777, -0.3736,  0.3986,  0.7898, -0.3342])', 'sigma__loc': 'Parameter containing:\ntensor([ 2.1118,  1.7853,  1.8520,  2.3269,  1.2685])', 'sigma__scale': 'Parameter containing:\ntensor([ 1.4017,  0.7025,  1.6896,  0.9506,  1.3112])'}"