In [1]:
import probtorch
import torch

import combinators
import hmm
import smc
import utils

In [2]:
# Trace that will hold one ground-truth run of the HMM (sites Z_t / X_t,
# see the output of In [5]). The argument is presumably the particle count
# (In [7] passes num_particles and gets 5000-row sites) — here a single sample.
generative = smc.ParticleTrace(1)

In [3]:
# Ground-truth HMM parameters used to simulate the synthetic data.
num_states = 5   # K hidden states
num_steps = 50   # T time steps; zs additionally stores the initial state Z_0

# Deterministic cyclic transition matrix: state i always moves to (i + 1) mod K.
pi = torch.zeros(1, num_states, num_states)
for i in range(num_states):
    pi[:, i, (i + 1) % num_states] = 1

# Emission means 0, 1, ..., K-1. The dtype is pinned to float because newer
# PyTorch returns int64 from an integer `arange`, while the emissions recorded
# in the trace are float Normal samples.
mu = torch.arange(num_states, dtype=torch.float).unsqueeze(0)
sigma = torch.ones(1, num_states) / 2

# Latent state trajectory: Z_0 = 0, the remaining T entries are -1
# placeholders (presumably overwritten step by step by hmm.hmm_step).
zs = torch.full((1, num_steps + 1), -1, dtype=torch.long)
zs[:, 0] = 0

In [4]:
# Roll the generative model forward for t = 1..T, recording Z_t and X_t in
# the single-particle `generative` trace (see the output of In [5]).
# The step index returned by hmm_step is discarded instead of being assigned
# back onto the loop variable, so `t` always means "current step".
for t in range(1, zs.shape[1]):
    zs, pi, mu, sigma, _, generative = hmm.hmm_step(zs, pi, mu, sigma, t, generative)

In [5]:
# Show the sampled ground-truth trace (one Z_t / X_t site pair per step).
generative

Trace{'Z_1': Categorical([torch.LongTensor of size 1]), 'X_1': Normal([torch.FloatTensor of size 1]), 'Z_2': Categorical([torch.LongTensor of size 1]), 'X_2': Normal([torch.FloatTensor of size 1]), 'Z_3': Categorical([torch.LongTensor of size 1]), 'X_3': Normal([torch.FloatTensor of size 1]), 'Z_4': Categorical([torch.LongTensor of size 1]), 'X_4': Normal([torch.FloatTensor of size 1]), 'Z_5': Categorical([torch.LongTensor of size 1]), 'X_5': Normal([torch.FloatTensor of size 1]), 'Z_6': Categorical([torch.LongTensor of size 1]), 'X_6': Normal([torch.FloatTensor of size 1]), 'Z_7': Categorical([torch.LongTensor of size 1]), 'X_7': Normal([torch.FloatTensor of size 1]), 'Z_8': Categorical([torch.LongTensor of size 1]), 'X_8': Normal([torch.FloatTensor of size 1]), 'Z_9': Categorical([torch.LongTensor of size 1]), 'X_9': Normal([torch.FloatTensor of size 1]), 'Z_10': Categorical([torch.LongTensor of size 1]), 'X_10': Normal([torch.FloatTensor of size 1]), 'Z_11': Categorical([torch.LongT

In [6]:
# Number of SMC particles; every per-particle tensor in the inference trace
# has this many rows.
num_particles = int(5e3)

In [7]:
# Particle trace for inference; its sites end up with num_particles rows
# (e.g. 5000x5 and 5000 in the output of In [11]).
inference = smc.ParticleTrace(num_particles)

In [8]:
# Build one SMC step from the HMM transition (hmm_step) and hmm_retrace.
# NOTE(review): hmm_retrace is presumably used to re-index/retrace particles
# after resampling — confirm in smc.smc.
smc_step = smc.smc(hmm.hmm_step, hmm.hmm_retrace)

In [9]:
# Draw the initial particle state from the HMM priors into `inference`.
# Only the transition matrix and the state array are kept; the two discarded
# outputs presumably correspond to the '\mu' / '\sigma' prior sites that show
# up in the trace (In [11]) — TODO confirm against hmm.init_hmm.
pi, _, _, zs = hmm.init_hmm(5, 50, inference)
# Instead of the sampled emission parameters, reuse the ones left over from
# the generative run, broadcast to one row per particle (expand is a view).
mu, sigma = mu.expand(num_particles, 5), sigma.expand(num_particles, 5)

In [10]:
# Run SMC for t = 1..T: each smc_step advances the particle state, conditioning
# on the observations stored in `generative` and recording into `inference`.
# The step index returned by smc_step is discarded instead of being assigned
# back onto the loop variable, so `t` always means "current step".
for t in range(1, zs.shape[1]):
    zs, pi, mu, sigma, _, inference = smc_step(
        zs, pi, mu, sigma, t, trace=inference, conditions=generative
    )

In [11]:
# Show the inference trace: the prior sites from init_hmm, then per step a
# Z_t site with one value per particle and the conditioned X_t observation.
inference

Trace{'\\Pi_0': Dirichlet([torch.FloatTensor of size 5000x5]), '\\Pi_1': Dirichlet([torch.FloatTensor of size 5000x5]), '\\Pi_2': Dirichlet([torch.FloatTensor of size 5000x5]), '\\Pi_3': Dirichlet([torch.FloatTensor of size 5000x5]), '\\Pi_4': Dirichlet([torch.FloatTensor of size 5000x5]), '\\Pi_5': Dirichlet([torch.FloatTensor of size 5000x5]), '\\mu': Normal([torch.FloatTensor of size 5000x5]), '\\sigma': LogNormal([torch.FloatTensor of size 5000x5]), 'Z_0': Categorical([torch.LongTensor of size 5000]), 'Z_1': Categorical([torch.LongTensor of size 5000]), 'X_1': Normal([torch.FloatTensor of size 1]), 'Z_2': Categorical([torch.LongTensor of size 5000]), 'X_2': Normal([torch.FloatTensor of size 1]), 'Z_3': Categorical([torch.LongTensor of size 5000]), 'X_3': Normal([torch.FloatTensor of size 1]), 'Z_4': Categorical([torch.LongTensor of size 5000]), 'X_4': Normal([torch.FloatTensor of size 1]), 'Z_5': Categorical([torch.LongTensor of size 5000]), 'X_5': Normal([torch.FloatTensor of size

In [12]:
# Re-display the ground-truth trace after SMC — presumably to confirm that
# passing it as `conditions=` left it unchanged (compare with In [5]).
generative

Trace{'Z_1': Categorical([torch.LongTensor of size 1]), 'X_1': Normal([torch.FloatTensor of size 1]), 'Z_2': Categorical([torch.LongTensor of size 1]), 'X_2': Normal([torch.FloatTensor of size 1]), 'Z_3': Categorical([torch.LongTensor of size 1]), 'X_3': Normal([torch.FloatTensor of size 1]), 'Z_4': Categorical([torch.LongTensor of size 1]), 'X_4': Normal([torch.FloatTensor of size 1]), 'Z_5': Categorical([torch.LongTensor of size 1]), 'X_5': Normal([torch.FloatTensor of size 1]), 'Z_6': Categorical([torch.LongTensor of size 1]), 'X_6': Normal([torch.FloatTensor of size 1]), 'Z_7': Categorical([torch.LongTensor of size 1]), 'X_7': Normal([torch.FloatTensor of size 1]), 'Z_8': Categorical([torch.LongTensor of size 1]), 'X_8': Normal([torch.FloatTensor of size 1]), 'Z_9': Categorical([torch.LongTensor of size 1]), 'X_9': Normal([torch.FloatTensor of size 1]), 'Z_10': Categorical([torch.LongTensor of size 1]), 'X_10': Normal([torch.FloatTensor of size 1]), 'Z_11': Categorical([torch.LongT

In [13]:
# For each step t = 1..T, report the fraction of particles whose sampled
# latent state Z_t equals the true state in the generative trace.
for t in range(1, zs.shape[1]):
    key = 'Z_%d' % t
    # Per-particle values vs the single true value -> broadcast comparison.
    matches = inference[key].value == generative[key].value
    # Normalise by the number of elements actually compared (not an external
    # constant) and force float division so the result is never truncated.
    accuracy = float(matches.sum().item()) / matches.numel()
    print('SMC accuracy at time %d: %f' % (t, accuracy))

SMC accuracy at time 1: 0.329000
SMC accuracy at time 2: 0.851200
SMC accuracy at time 3: 0.887400
SMC accuracy at time 4: 0.894000
SMC accuracy at time 5: 0.984800
SMC accuracy at time 6: 0.420400
SMC accuracy at time 7: 0.765200
SMC accuracy at time 8: 0.543400
SMC accuracy at time 9: 0.978200
SMC accuracy at time 10: 0.998400
SMC accuracy at time 11: 0.570000
SMC accuracy at time 12: 0.234800
SMC accuracy at time 13: 0.796000
SMC accuracy at time 14: 0.992200
SMC accuracy at time 15: 0.877000
SMC accuracy at time 16: 0.182400
SMC accuracy at time 17: 0.920200
SMC accuracy at time 18: 0.784400
SMC accuracy at time 19: 0.968000
SMC accuracy at time 20: 1.000000
SMC accuracy at time 21: 0.243400
SMC accuracy at time 22: 0.426400
SMC accuracy at time 23: 0.217200
SMC accuracy at time 24: 0.588200
SMC accuracy at time 25: 0.880400
SMC accuracy at time 26: 0.802600
SMC accuracy at time 27: 0.600200
SMC accuracy at time 28: 0.060200
SMC accuracy at time 29: 0.841800
SMC accuracy at time 30