Skip to content

Commit

Permalink
consolidated histogram code in traceposterior and marginal
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 committed Jul 20, 2017
1 parent 79044e3 commit f054b16
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 102 deletions.
45 changes: 1 addition & 44 deletions pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,8 @@
import pyro.util
import pyro.poutine

from pyro.infer.abstract_infer import AbstractInfer
from pyro.infer.abstract_infer import Marginal, TracePosterior
from pyro.infer.search import Search
from pyro.infer.mh import MH
from pyro.infer.importance import Importance
from pyro.infer.kl_qp import KL_QP


class Marginal(pyro.distributions.Distribution):
"""
Marginal histogram
"""
def __init__(self, trace_dist):
assert isinstance(trace_dist, AbstractInfer), \
"trace_dist must be trace posterior distribution object"
super(Marginal, self).__init__()
self.trace_dist = trace_dist

@pyro.util.memoize
def _dist(self, *args, **kwargs):
"""
Convert a histogram over traces to a histogram over return values
Currently very inefficient...
"""
vs, log_weights = [], []
for tr, log_weight in self.trace_dist._traces(*args, **kwargs):
vs.append(tr["_RETURN"]["value"])
log_weights.append(log_weight)

log_weights = torch.cat(log_weights)
if not isinstance(log_weights, torch.autograd.Variable):
log_weights = torch.autograd.Variable(log_weights)
log_z = pyro.util.log_sum_exp(log_weights)
ps = torch.exp(log_weights - log_z.expand_as(log_weights))

if isinstance(vs[0], (torch.autograd.Variable, torch.Tensor, np.ndarray)):
hist = pyro.util.tensor_histogram(ps, vs)
else:
hist = pyro.util.basic_histogram(ps, vs)
return pyro.distributions.Categorical(ps=hist["ps"], vs=hist["vs"])

def sample(self, *args, **kwargs):
return pyro.poutine.block(self._dist(*args, **kwargs)).sample()

def log_pdf(self, val, *args, **kwargs):
return pyro.poutine.block(self._dist(*args, **kwargs)).log_pdf(val)

def support(self, *args, **kwargs):
return pyro.poutine.block(self._dist(*args, **kwargs)).support()
105 changes: 60 additions & 45 deletions pyro/infer/abstract_infer.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,83 @@
import numpy as np
import torch
from torch.autograd import Variable
import pyro
import pyro.util
from pyro.distributions import Categorical
import pdb


class AbstractInfer(pyro.distributions.Distribution):
class Histogram(pyro.distributions.Distribution):
"""
abstract inference class
TODO documentation
Histogram
"""
def _traces(self, *args, **kwargs):
"""
Virtual method to get unnormalized weighted list of posterior traces
"""
raise NotImplementedError("inference algorithm must implement _traces")

@pyro.util.memoize
def _dist(self, *args, **kwargs):
"""
make trace posterior distribution object with normalized probs
Convert a histogram over traces to a histogram over return values
Currently very inefficient...
"""
trace_hist = self._traces(*args, **kwargs)
traces, log_weights = [], []
for tr, log_weight in trace_hist:
traces.append(tr)
vs, log_weights = [], []
for v, log_weight in self.gen_weighted_samples(*args, **kwargs):
vs.append(v)
log_weights.append(log_weight)
log_ps = torch.cat(log_weights, 0)
log_ps = log_ps - pyro.util.log_sum_exp(log_ps).expand_as(log_ps)
# XXX Categorical not working correctly with non-Tensor vs
return Categorical(ps=torch.exp(log_ps), vs=[traces])

log_weights = torch.cat(log_weights)
if not isinstance(log_weights, torch.autograd.Variable):
log_weights = torch.autograd.Variable(log_weights)
log_z = pyro.util.log_sum_exp(log_weights)
ps = torch.exp(log_weights - log_z.expand_as(log_weights))

if isinstance(vs[0], (torch.autograd.Variable, torch.Tensor, np.ndarray)):
hist = pyro.util.tensor_histogram(ps, vs)
elif isinstance(vs[0], pyro.poutine.Trace):
hist = {"ps": ps, "vs": [vs]}
else:
hist = pyro.util.basic_histogram(ps, vs)
return pyro.distributions.Categorical(ps=hist["ps"], vs=hist["vs"])

def gen_weighted_samples(self, *args, **kwargs):
raise NotImplementedError("gen_weighted_samples is abstract method")

def sample(self, *args, **kwargs):
"""
sample from trace posterior
"""
return self._dist(*args, **kwargs).sample()
return pyro.poutine.block(self._dist)(*args, **kwargs).sample()

def log_pdf(self, val, *args, **kwargs):
return pyro.poutine.block(self._dist)(*args, **kwargs).log_pdf(val)

def support(self, *args, **kwargs):
return pyro.poutine.block(self._dist)(*args, **kwargs).support()


class Marginal(Histogram):
"""
Marginal histogram
"""
def __init__(self, trace_dist):
assert isinstance(trace_dist, TracePosterior), \
"trace_dist must be trace posterior distribution object"
super(Marginal, self).__init__()
self.trace_dist = trace_dist

def gen_weighted_samples(self, *args, **kwargs):
for tr, log_weight in self.trace_dist._traces(*args, **kwargs):
yield (tr["_RETURN"]["value"], log_weight)


class TracePosterior(Histogram):
"""
abstract inference class
TODO documentation
"""
def gen_weighted_samples(self, *args, **kwargs):
for tr, log_weight in self._traces(*args, **kwargs):
yield (tr, log_weight)

def _traces(self, *args, **kwargs):
"""
Use the histogram to score a value
Virtual method to get unnormalized weighted list of posterior traces
"""
return self._dist(*args, **kwargs).log_pdf(val)
raise NotImplementedError("inference algorithm must implement _traces")

# def log_z(self, *args, **kwargs):
# """
Expand All @@ -53,24 +89,3 @@ def log_pdf(self, val, *args, **kwargs):
# for tr, log_w in zip(traces, log_weights):
# log_z = log_z + log_w
# return log_z / len(traces)


def lw_expectation(trace_dist, functional, num_samples):
# running var
accum_so_far = 0.
sum_weight = 0.

# sample from trace_dist
samples = trace_dist.runner(num_samples)

# loop over the sample tuples
for i, rv, cur_score in samples:

# not necessarily efficient torch.exp call x2, fix later
sum_weight += torch.exp(cur_score)

# apply function to return value, multiply by exp(cur_score)
accum_so_far += functional(rv) * torch.exp(cur_score)

#
return accum_so_far / sum_weight
14 changes: 7 additions & 7 deletions pyro/infer/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,27 @@

import pyro
import pyro.poutine as poutine
from pyro.infer import AbstractInfer
from pyro.infer import TracePosterior


# XXX what should be the base class here?
class Importance(AbstractInfer):
class Importance(TracePosterior):
"""
A new implementation of importance sampling
"""
def __init__(self, model, guide=None, samples=10):
def __init__(self, model, guide=None, samples=None):
"""
Constructor
TODO proper docs etc
"""
super(Importance, self).__init__()
self.samples = samples
self.model = model
if samples is None:
samples = 10
if guide is None:
# propose from the prior
guide = poutine.block(model, hide_types=["observe"])
self.samples = samples
self.model = model
self.guide = guide

def _traces(self, *args, **kwargs):
Expand All @@ -34,6 +36,4 @@ def _traces(self, *args, **kwargs):
model_trace = poutine.trace(
poutine.replay(self.model, guide_trace))(*args, **kwargs)
log_weight = model_trace.log_pdf() - guide_trace.log_pdf()
# traces.append((model_trace, log_weight))
yield (model_trace, log_weight)
# return traces
3 changes: 1 addition & 2 deletions pyro/infer/kl_qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from collections import OrderedDict
import pyro
import pyro.poutine as poutine
# from pyro.infer.abstract_infer import AbstractInfer


class KL_QP(object): # AbstractInfer):
class KL_QP(object):
"""
A new, Trace and Poutine-based implementation of SVI
"""
Expand Down
4 changes: 2 additions & 2 deletions pyro/infer/mh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import pyro
import pyro.poutine as poutine
from pyro.distributions import Uniform
from pyro.infer import AbstractInfer
from pyro.infer import TracePosterior


class MH(AbstractInfer):
class MH(TracePosterior):
"""
Initial implementation of MH MCMC
"""
Expand Down
4 changes: 2 additions & 2 deletions pyro/infer/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from queue import Queue

import pyro.poutine as poutine
from pyro.infer import AbstractInfer
from pyro.infer import TracePosterior


class Search(AbstractInfer):
class Search(TracePosterior):
"""
New Trace and Poutine-based implementation of systematic search
"""
Expand Down

0 comments on commit f054b16

Please sign in to comment.