In [1]:
import logging

import probtorch
import torch

import combinators
import foldable
import gmm
import hmm
import importance
import mcmc
import traces
import utils

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

# Both the synthetic-data run and the SMC inference below are sampling-based
# (they presumably draw from torch's global RNG via probtorch/torch
# distributions). Seeding here makes Restart & Run All reproduce the same
# data and the same accuracy printout.
SEED = 0
torch.manual_seed(SEED)

In [2]:
# Hyperparameters for the model that generates the synthetic data.
# 'mu' and 'sigma' each hold a 5-element loc/scale pair (mu locs are
# 0, 2, 4, 6, 8; sigma locs are all 1), both with scale 0.25.
# 'Pi' and 'Pi_1' .. 'Pi_5' each get an all-ones 'concentration' vector --
# presumably flat Dirichlet priors over the initial-state and per-state
# transition distributions (TODO confirm against hmm.py).
mu_prior = {
    'loc': torch.arange(5, dtype=torch.float) * 2,
    'scale': torch.ones(5) * 0.25,
}
sigma_prior = {
    'loc': torch.ones(5),
    'scale': torch.ones(5) * 0.25,
}
hmm_params = utils.vardict({'mu': mu_prior, 'sigma': sigma_prior})
for pi_key in ['Pi'] + ['Pi_%d' % k for k in range(1, 6)]:
    hmm_params[pi_key] = {'concentration': torch.ones(5)}

In [3]:
# Generative model for the synthetic data: a batch of one, with fixed
# (non-trainable) priors from `hmm_params`. The HMM initialiser is composed
# with the GMM initialiser, and the HMM transition step is folded over a
# 50-element generator (presumably one step per element).
init_gmm = gmm.InitGmm(params=hmm_params, trainable=False, batch_shape=(1,))
init_hmm = combinators.Composition(
    hmm.InitHmm(params=hmm_params, trainable=False, batch_shape=(1,)),
    init_gmm,
)
hmm_foldable = foldable.Foldable(hmm.HmmStep(), initializer=init_hmm)
hmm_run = foldable.Reduce(hmm_foldable, generator=lambda: range(50))

In [4]:
# Run the generative model once. The first return value unpacks into
# z_last, mu, sigma and pi; `generative` records the sampled variables
# (queried below via `.variables()`); the third return value is unused.
final_values, generative, _ = hmm_run()
z_last, mu, sigma, pi = final_values

In [5]:
# Observations for inference: the value of every generative-trace variable
# whose name contains 'X_', keyed by that name.
data = {
    name: rv.value
    for name, rv in generative.variables()
    if 'X_' in name
}

In [6]:
# Number of SMC particles; used as the batch size of every inference module.
num_particles: int = 250

In [7]:
# Priors for the inference model. They have the same keys and locs as
# `hmm_params`, but every scale is 1 instead of 0.25 -- presumably broader
# priors than the ones the synthetic data were drawn from.
smc_mu_prior = {
    'loc': torch.arange(5, dtype=torch.float) * 2,
    'scale': torch.ones(5),
}
smc_sigma_prior = {
    'loc': torch.ones(5),
    'scale': torch.ones(5),
}
smc_hmm_params = utils.vardict({'mu': smc_mu_prior, 'sigma': smc_sigma_prior})
for pi_key in ['Pi'] + ['Pi_%d' % k for k in range(1, 6)]:
    smc_hmm_params[pi_key] = {'concentration': torch.ones(5)}

In [8]:
# Inference-model initialisers: the same GMM + HMM composition as the
# generative model, but batched over `num_particles`, with trainable=True,
# and built from `smc_hmm_params`.
# NOTE(review): this rebinds `init_gmm` / `init_hmm` from the generative
# cell; `hmm_foldable` still holds the old objects, but distinct names would
# make out-of-order re-runs less confusing.
particle_batch = (num_particles,)
init_gmm = gmm.InitGmm(params=smc_hmm_params, trainable=True, batch_shape=particle_batch)
init_hmm = combinators.Composition(
    hmm.InitHmm(params=smc_hmm_params, trainable=True, batch_shape=particle_batch),
    init_gmm,
)

In [9]:
# HMM transition step for the inference model, batched over the particles
# (the generative model's step was built with no batch_shape argument).
step_batch_shape = (num_particles,)
hmm_step = hmm.HmmStep(batch_shape=step_batch_shape)

In [10]:
# Resample-move SMC over the same 50-element generator as the generative
# run, starting from `init_hmm` and using moves=5 (presumably MCMC move
# steps applied after each resampling; TODO confirm in mcmc.py).
rmsmc_hmm = mcmc.reduce_resample_move_smc(
    hmm_step,
    lambda: range(50),
    initializer=init_hmm,
    moves=5,
)

In [11]:
# Run inference on the observations; `inference` is the resulting trace
# (queried below via `.variables()`), and the third return value is unused.
smc_result = rmsmc_hmm(data=data)
samples, inference, _ = smc_result

In [12]:
# Ground-truth latent states: every generative-trace variable whose name
# contains 'Z_', mapped to its sampled value.
latent_states = dict(
    (name, rv.value) for name, rv in generative.variables() if 'Z_' in name
)

In [13]:
# Inferred latent states: every inference-trace variable whose name
# contains 'Z_', mapped to its value.
inferred_latent_states = dict(
    (name, rv.value) for name, rv in inference.variables() if 'Z_' in name
)

In [14]:
# Per-time-step accuracy: the percentage of particles whose inferred latent
# state equals the ground-truth state from the generative run. Keys are
# 1-based ('Z_1' .. 'Z_50') to match the 50-element generators passed to
# `hmm_run` and `rmsmc_hmm`, so iterate over them directly rather than
# shifting the loop variable inside the body.
for t in range(1, 51):
    key = 'Z_%d' % t
    # The ground truth comes from a batch of one, so it presumably broadcasts
    # against each particle's inferred state.
    accuracy = (inferred_latent_states[key] == latent_states[key]).to(dtype=torch.float).mean()
    # Scale in tensor arithmetic (keeps the printed digits unchanged), then
    # convert the 0-d tensor to a Python float for %-formatting.
    print('RM-SMC percent accuracy at time %d: %f' % (t, (accuracy * 100).item()))

RM-SMC percent accuracy at time 1: 68.000000
RM-SMC percent accuracy at time 2: 35.600002
RM-SMC percent accuracy at time 3: 11.200001
RM-SMC percent accuracy at time 4: 20.000000
RM-SMC percent accuracy at time 5: 70.800003
RM-SMC percent accuracy at time 6: 39.599998
RM-SMC percent accuracy at time 7: 35.600002
RM-SMC percent accuracy at time 8: 18.000000
RM-SMC percent accuracy at time 9: 8.800000
RM-SMC percent accuracy at time 10: 6.800000
RM-SMC percent accuracy at time 11: 40.400002
RM-SMC percent accuracy at time 12: 76.000000
RM-SMC percent accuracy at time 13: 24.000000
RM-SMC percent accuracy at time 14: 51.599998
RM-SMC percent accuracy at time 15: 18.799999
RM-SMC percent accuracy at time 16: 7.600000
RM-SMC percent accuracy at time 17: 8.000000
RM-SMC percent accuracy at time 18: 4.400000
RM-SMC percent accuracy at time 19: 23.600000
RM-SMC percent accuracy at time 20: 91.599998
RM-SMC percent accuracy at time 21: 84.000000
RM-SMC percent accuracy at time 22: 11.599999
RM