In [2]:
import os
from collections import defaultdict
import torch
import numpy as np
import scipy.stats
from torch.distributions import constraints
from matplotlib import pyplot
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete
data = torch.tensor([0., 1., 10., 11., 12.])

K = 2  # Fixed number of components.

@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

In [3]:
model(data)


In [4]:
def init_loc_fn(site):
    if site["name"] == "weights":
        # Initialize weights to uniform.
        return torch.ones(K) / K
    if site["name"] == "scale":
        return (data.var() / 2).sqrt()
    if site["name"] == "locs":
        return data[torch.multinomial(torch.ones(len(data)) / len(data), K)]
    raise ValueError(site["name"])

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']),
                             init_loc_fn=init_loc_fn)
    #svi = SVI(model, global_guide, optim, loss=elbo)
    return global_guide#, #svi.loss(model, global_guide, data)

# Choose the best among 100 random initializations.
# loss, seed = min((initialize(seed), seed) for seed in range(100))
# initialize(seed)
# print('seed = {}, initial_loss = {}'.format(seed, loss))

In [5]:
initialize(42)

AutoDelta()

In [6]:
global_guide.__dict__

{'init_loc_fn': <function __main__.init_loc_fn(site)>,
 '_pyro_name': 'AutoDelta',
 '_pyro_context': <pyro.nn.module._Context at 0x7f76e8050fd0>,
 '_pyro_params': OrderedDict(),
 '_pyro_samples': OrderedDict(),
 'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_modules': OrderedDict(),
 'master': None,
 '_model': (_bound_partial(functools.partial(<function _context_wrap at 0x7f7687a07ee0>, <pyro.infer.autoguide.initialization.InitMessenger object at 0x7f76e80509a0>, _bound_partial(functools.partial(<function _context_wrap at 0x7f7687a07ee0>, <pyro.poutine.block_messenger.BlockMessenger object at 0x7f76a7f580d0>, _bound_partial(functools.partial(<function _context_wrap at 0x7f7687a07ee0>, <pyro.poutine.infer_config_messenger.InferCo

In [7]:
map_estimates = global_guide(data)
weights = map_estimates['weights']
locs = map_estimates['locs']
scale = map_estimates['scale']
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
print('scale = {}'.format(scale.data.numpy()))


weights = [0.5 0.5]
locs = [11.  1.]
scale = 4.104875087738037


In [9]:
guide_trace = poutine.trace(global_guide).get_trace()#*args, **kwargs)
model_trace = poutine.trace(
    poutine.replay(model, trace=guide_trace)).get_trace(data)#*args, **kwargs)

In [15]:
model_trace.values()

AttributeError: 'Trace' object has no attribute 'values'