# Gaussian Mixture Model

In [50]:
import torch
import numpy as np
from torch.nn import Softmax
from torch.autograd import Variable

import pyro
from pyro.infer.kl_qp import KL_QP
from pyro.distributions import Categorical, DiagNormal, dirichlet, Gamma

Let's start by modeling univariate real-valued data using a mixture of a fixed number of Gaussians with unknown means and shared unknown variance.

In [51]:
K = 2  # Fixed number of components.
softmax = Softmax()
use_dirichlet_prior = True  # See 

def model(data):
    # Global parameters.
    if use_dirichlet_prior:
        ps = pyro.sample('ps', dirichlet, alpha=Variable(torch.ones(K) * 0.5))
    else:
        # FIXME I would put a Dirichlet prior here, but Dirichlet is buggy.
        pre_ps = pyro.param('pre_ps', Variable(torch.zeros(K))) 
        ps = softmax(pre_ps.unsqueeze(0)).squeeze(0)  # Ugly.

    # Per-component parameters.
    mu = pyro.sample('mu', DiagNormal(Variable(torch.normal(torch.zeros(K), torch.ones(K))),
                                      Variable(torch.ones(K))))
    sigma = pyro.sample('sigma', Gamma(torch.ones(1) * 0.5, torch.ones(1) * 0.5))

    # Per-datum parameters.
    def local_model(i, datum):
        z = pyro.sample('z_{}'.format(i), Categorical(ps))
        pyro.observe('x_{}'.format(i), Normal, mu[z], sigma)
        return z

    return pyro.map_data('obs', data, local_model)

def guide(data):
    mu_init = torch.normal(torch.zeros(K), torch.ones(K))
    mu = pyro.param('mu', Variable(mu_init))
    sigma = Variable(torch.ones(K))
    p_init = torch.ones(K) / K

    def local_guide(i, datum):
        p = pyro.param('p_{}'.format(i), Variable(p_init))
        softmax_p = softmax(p.unsqueeze(0)).squeeze(0)  # Ugly.
        z = pyro.sample('z_{}'.format(i), Categorical(softmax_p))
        return z

    return pyro.map_data('obs', data, local_guide)

In [52]:
optim_fct = pyro.optim(torch.optim.Adam, {'lr': 0.001})
inference = KL_QP(model, guide, optim_fct)

In [53]:
data = torch.Tensor([0, 1, 2, 5, 6, 7])

In [54]:
%pdb off
inference(data)

Automatic pdb calling has been turned OFF


AssertionError: site mu must be sample in guide_trace