In [1]:
# Move up two directories to the repository root (the captured output below shows the
# resulting working directory), presumably so that the `combinators` and `examples.*`
# packages imported in the next cell resolve from this checkout.
%cd ../..

/home/eli/AnacondaProjects/combinators


In [2]:
import logging

import probtorch
import torch

import combinators
import combinators.model as model
import combinators.model.foldable as foldable
from combinators.inference import importance, mcmc
import combinators.utils as utils

import examples.gmm.gmm as gmm
import examples.hmm.hmm as hmm

# Timestamped, INFO-level logging for the whole notebook (presumably so progress messages
# emitted by the inference code are visible here).
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

# Both the ground-truth trajectory and the RM-SMC run below are stochastic. Seed torch's
# global RNG (which torch-based sampling presumably draws from) so that Restart & Run All
# reproduces the printed accuracies. The trailing ';' suppresses the Generator repr.
SEED = 0
torch.manual_seed(SEED);

In [3]:
# Ground-truth hyperparameters used to simulate data. Every tensor below is indexed by
# discrete state, so the state count is named once instead of repeated as a literal.
NUM_STATES = 5

# 'mu' and 'sigma' each carry a length-NUM_STATES (loc, scale) pair; mu locations are
# 0, 2, 4, ... and all scales are tight (0.25).
hmm_params = utils.vardict({
    'mu': {
        'loc': torch.arange(NUM_STATES, dtype=torch.float) * 2,
        'scale': torch.ones(NUM_STATES) * 0.25,
    },
    'sigma': {
        'loc': torch.ones(NUM_STATES),
        'scale': torch.ones(NUM_STATES) * 0.25,
    }
})
# One 'Pi' entry plus 'Pi_1' ... 'Pi_<NUM_STATES>', each with an all-ones concentration of
# length NUM_STATES -- presumably the initial-state prior and one transition-row prior per
# state (confirm against examples/hmm/hmm.py).
for k in range(NUM_STATES + 1):
    key = 'Pi_%d' % k if k else 'Pi'
    hmm_params[key] = {'concentration': torch.ones(NUM_STATES)}

In [4]:
# Ground-truth generative model: non-trainable initializers built from `hmm_params` with a
# single batch element, composed and then reduced over the 50 elements of range(50)
# (presumably one HmmStep per element).
init_gmm = gmm.InitGmm(params=hmm_params, trainable=False, batch_shape=(1,))
init_hmm = model.Composition(
    hmm.InitHmm(params=hmm_params, trainable=False, batch_shape=(1,)),
    init_gmm,
)
hmm_foldable = foldable.Foldable(hmm.HmmStep(), initializer=init_hmm)
hmm_run = foldable.Reduce(hmm_foldable, generator=lambda: range(50))

In [5]:
# Run the ground-truth model once. `generative` exposes `.variables()`, which later cells
# use to pull out the X_ observations and the Z_ latent states. The first element unpacks
# into four values (by name, presumably the final latent state and the sampled mu/sigma/pi).
(z_last, mu, sigma, pi), generative, _ = hmm_run()

In [6]:
# Observations to condition on: the value of every ground-truth variable whose name
# contains 'X_'.
data = {
    name: node.value
    for name, node in generative.variables()
    if 'X_' in name
}

In [7]:
# Number of SMC particles; used below as the batch shape of the inference initializers and step.
num_particles = 250

In [8]:
# Hyperparameters for the inference model. The locations and 'Pi' concentrations are the
# same as the ground-truth `hmm_params`, but every scale is 1 instead of 0.25, presumably a
# vaguer prior. The state count (5) must agree with the ground-truth cell.
smc_hmm_params = utils.vardict({
    'mu': {'loc': 2 * torch.arange(5, dtype=torch.float), 'scale': torch.ones(5)},
    'sigma': {'loc': torch.ones(5), 'scale': torch.ones(5)},
})
for k in range(6):
    # k == 0 -> 'Pi'; k == 1..5 -> 'Pi_1'..'Pi_5'.
    key = 'Pi' if k == 0 else 'Pi_%d' % k
    smc_hmm_params[key] = {'concentration': torch.ones(5)}

In [9]:
# Inference-side initializers: same structure as the ground-truth model, but trainable,
# built from `smc_hmm_params`, and batched over `num_particles`.
init_gmm = gmm.InitGmm(params=smc_hmm_params, trainable=True, batch_shape=(num_particles,))
init_hmm = model.Composition(
    hmm.InitHmm(params=smc_hmm_params, trainable=True, batch_shape=(num_particles,)),
    init_gmm,
)

In [10]:
# One HMM time step for the inference model, batched over the particles (the ground-truth
# model used an unbatched `hmm.HmmStep()`).
hmm_step = hmm.HmmStep(batch_shape=(num_particles,))

In [11]:
# Resample-move SMC over the same range(50) step generator as the ground-truth model,
# using the particle-batched step and initializer. `moves=50` is presumably the number of
# MCMC rejuvenation moves per step (confirm in combinators.inference.mcmc).
rmsmc_hmm = mcmc.reduce_resample_move_smc(
    hmm_step,
    lambda: range(50),
    initializer=init_hmm,
    moves=50,
)

In [12]:
# Run RM-SMC conditioned on the simulated observations. `inference.variables()` is queried
# below for the inferred Z_ latent states.
samples, inference, _ = rmsmc_hmm(data=data)

In [13]:
# Ground-truth latent states: values of the generative variables whose names contain 'Z_'.
latent_states = {
    name: node.value
    for name, node in generative.variables()
    if 'Z_' in name
}

In [14]:
# Inferred latent states: values of the RM-SMC variables whose names contain 'Z_'.
inferred_latent_states = {
    name: node.value
    for name, node in inference.variables()
    if 'Z_' in name
}

In [15]:
# Per-time-step accuracy of RM-SMC at recovering the ground-truth latent states.
# The truth was sampled with batch_shape (1,) and the inference with (num_particles,), so
# the elementwise `==` presumably broadcasts, and the mean is the fraction of particles
# whose Z_t equals the true state (TODO confirm shapes).
# Visits Z_1 ... Z_50 directly instead of reassigning the loop variable; the upper bound
# must track the range(50) step generators used for both models above.
for t in range(1, 51):
    key = 'Z_%d' % t
    accuracy = (inferred_latent_states[key] == latent_states[key]).to(dtype=torch.float).mean()
    # Scale the float32 tensor before formatting so the printed digits match the recorded output.
    print('RM-SMC percent accuracy at time %d: %f' % (t, accuracy * 100))

RM-SMC percent accuracy at time 1: 21.600000
RM-SMC percent accuracy at time 2: 3.200000
RM-SMC percent accuracy at time 3: 58.399998
RM-SMC percent accuracy at time 4: 33.599998
RM-SMC percent accuracy at time 5: 34.799999
RM-SMC percent accuracy at time 6: 50.000000
RM-SMC percent accuracy at time 7: 62.400002
RM-SMC percent accuracy at time 8: 34.000000
RM-SMC percent accuracy at time 9: 3.200000
RM-SMC percent accuracy at time 10: 28.400002
RM-SMC percent accuracy at time 11: 53.600002
RM-SMC percent accuracy at time 12: 39.200001
RM-SMC percent accuracy at time 13: 8.400000
RM-SMC percent accuracy at time 14: 48.000000
RM-SMC percent accuracy at time 15: 0.000000
RM-SMC percent accuracy at time 16: 86.000000
RM-SMC percent accuracy at time 17: 10.400001
RM-SMC percent accuracy at time 18: 63.200005
RM-SMC percent accuracy at time 19: 84.400002
RM-SMC percent accuracy at time 20: 98.000000
RM-SMC percent accuracy at time 21: 64.400002
RM-SMC percent accuracy at time 22: 53.600002
R