Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a pyro.factor primitive #2022

Merged
merged 5 commits into from Aug 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Expand Up @@ -189,6 +189,13 @@ SpanningTree
:undoc-members:
:show-inheritance:

Unit
----
.. autoclass:: pyro.distributions.Unit
:members:
:undoc-members:
:show-inheritance:

VonMises
--------
.. autoclass:: pyro.distributions.VonMises
Expand Down
1 change: 1 addition & 0 deletions docs/source/primitives.rst
Expand Up @@ -5,6 +5,7 @@ Primitives
.. autofunction:: pyro.param
.. autofunction:: pyro.module
.. autofunction:: pyro.random_module
.. autofunction:: pyro.factor

.. autoclass:: pyro.plate
:members:
Expand Down
4 changes: 2 additions & 2 deletions examples/rsa/generics.py
Expand Up @@ -16,7 +16,7 @@
import pyro.distributions as dist
import pyro.poutine as poutine

from search_inference import factor, HashingMarginal, memoize, Search
from search_inference import HashingMarginal, memoize, Search

torch.set_default_dtype(torch.float64) # double precision for numerical stability

Expand Down Expand Up @@ -91,7 +91,7 @@ def meaning(utterance, state, threshold):
def listener0(utterance, threshold, prior):
    """Literal listener: sample a state from the prior and condition on the
    utterance being literally true under that state."""
    state = pyro.sample("state", prior)
    is_true = meaning(utterance, state, threshold)
    # A large negative log-weight effectively vetoes states where the
    # utterance is false.
    log_weight = 0. if is_true else -99999.
    pyro.factor("listener0_true", log_weight)
    return state


Expand Down
4 changes: 2 additions & 2 deletions examples/rsa/hyperbole.py
Expand Up @@ -13,7 +13,7 @@
import pyro.distributions as dist
import pyro.poutine as poutine

from search_inference import factor, HashingMarginal, memoize, Search
from search_inference import HashingMarginal, memoize, Search

torch.set_default_dtype(torch.float64) # double precision for numerical stability

Expand Down Expand Up @@ -97,7 +97,7 @@ def utterance_prior():
def literal_listener(utterance, qud):
    """Literal listener: sample a full state (price and valence), require the
    utterance to be literally true of the price, then answer the QUD."""
    price = price_prior()
    valence = valence_prior(price)
    state = State(price=price, valence=valence)
    literally_true = meaning(utterance, price)
    # Near -inf log-weight rejects states where the utterance is false.
    pyro.factor("literal_meaning", 0. if literally_true else -999999.)
    return qud_fns[qud](state)


Expand Down
10 changes: 0 additions & 10 deletions examples/rsa/search_inference.py
Expand Up @@ -10,7 +10,6 @@
import queue
import functools

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.abstract_infer import TracePosterior
Expand All @@ -23,15 +22,6 @@ def memoize(fn=None, **kwargs):
return functools.lru_cache(**kwargs)(fn)


def factor(name, value):
    """
    Like factor in webPPL, adds a scalar weight to the log-probability of the trace.

    :param str name: name of the auxiliary sample site
    :param value: log-weight to add (tensor or Python number)

    NOTE(review): this implements the factor via an observed
    ``Bernoulli(logits=value)`` site, whose log_prob at 1 is
    ``log(sigmoid(value))``, not ``value`` itself — so the added weight only
    approximates ``value`` (closely for large positive values). The PR
    replaces this helper with the exact ``pyro.factor`` primitive.
    """
    value = value if torch.is_tensor(value) else torch.tensor(value)
    d = dist.Bernoulli(logits=value)
    pyro.sample(name, d, obs=torch.ones(value.size()))


class HashingMarginal(dist.Distribution):
"""
:param trace_dist: a TracePosterior instance representing a Monte Carlo posterior
Expand Down
10 changes: 5 additions & 5 deletions examples/rsa/semantic_parsing.py
Expand Up @@ -12,7 +12,7 @@
import pyro
import pyro.distributions as dist

from search_inference import HashingMarginal, BestFirstSearch, factor, memoize
from search_inference import HashingMarginal, BestFirstSearch, memoize

torch.set_default_dtype(torch.float64)

Expand Down Expand Up @@ -184,10 +184,10 @@ def world_prior(num_objs, meaning_fn):
for i in range(num_objs):
world.append(Obj("obj_{}".format(i)))
new_factor = heuristic(meaning_fn(world))
factor("factor_{}".format(i), new_factor - prev_factor)
pyro.factor("factor_{}".format(i), new_factor - prev_factor)
prev_factor = new_factor

factor("factor_{}".format(num_objs), prev_factor * -1)
pyro.factor("factor_{}".format(num_objs), prev_factor * -1)
return tuple(world)


Expand Down Expand Up @@ -276,7 +276,7 @@ def meaning(utterance):
def literal_listener(utterance):
    """Literal listener: sample a two-object world and weight it by how well
    the utterance's meaning holds in that world."""
    meaning_fn = meaning(utterance)
    sampled_world = world_prior(2, meaning_fn)
    constraint = 1000 * heuristic(meaning_fn(sampled_world))
    pyro.factor("world_constraint", constraint)
    return sampled_world


Expand Down Expand Up @@ -306,7 +306,7 @@ def rsa_listener(utterance, qud):
def literal_listener_raw(utterance, qud):
    """Literal listener over three-object worlds that projects the sampled
    world through the question-under-discussion function."""
    meaning_fn = meaning(utterance)
    sampled_world = world_prior(3, meaning_fn)
    constraint = 1000 * heuristic(meaning_fn(sampled_world))
    pyro.factor("world_constraint", constraint)
    return qud(sampled_world)


Expand Down
5 changes: 3 additions & 2 deletions pyro/__init__.py
@@ -1,8 +1,8 @@
import pyro.poutine as poutine
from pyro.logger import log
from pyro.poutine import condition, do, markov
from pyro.primitives import (clear_param_store, enable_validation, get_param_store, iarange, irange, module, param,
plate, random_module, sample, validation_enabled)
from pyro.primitives import (clear_param_store, enable_validation, factor, get_param_store, iarange, irange, module,
param, plate, random_module, sample, validation_enabled)
from pyro.util import set_rng_seed

version_prefix = '0.4.1'
Expand All @@ -19,6 +19,7 @@
"condition",
"do",
"enable_validation",
"factor",
"get_param_store",
"iarange",
"irange",
Expand Down
5 changes: 1 addition & 4 deletions pyro/contrib/gp/models/sgpr.py
Expand Up @@ -146,10 +146,7 @@ def model(self):
return f_loc, f_var
else:
if self.approx == "VFE":
# inject trace_term to model's log_prob
pyro.sample("trace_term",
dist.Delta(v=trace_term.new_tensor(0.), log_density=-trace_term / 2.),
obs=trace_term.new_tensor(0.))
pyro.factor("trace_term", -trace_term / 2.)

return pyro.sample("y",
dist.LowRankMultivariateNormal(f_loc, W, D)
Expand Down
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
Expand Up @@ -20,6 +20,7 @@
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.torch_transform import TransformModule
from pyro.distributions.util import enable_validation, is_validation_enabled, validation_enabled
from pyro.distributions.unit import Unit
from pyro.distributions.von_mises import VonMises
from pyro.distributions.von_mises_3d import VonMises3D
from pyro.distributions.zero_inflated_poisson import ZeroInflatedPoisson
Expand Down Expand Up @@ -52,6 +53,7 @@
"SpanningTree",
"TorchDistribution",
"TransformModule",
"Unit",
"VonMises",
"VonMises3D",
"ZeroInflatedPoisson",
Expand Down
38 changes: 38 additions & 0 deletions pyro/distributions/unit.py
@@ -0,0 +1,38 @@
import torch
from torch.distributions import constraints

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


class Unit(TorchDistribution):
    """
    Trivial nonnormalized distribution representing the unit type.

    The unit type has a single value with no data, i.e. ``value.numel() == 0``.

    This is used for :func:`pyro.factor` statements.
    """
    arg_constraints = {'log_factor': constraints.real}
    support = constraints.real

    def __init__(self, log_factor, validate_args=None):
        """
        :param log_factor: a possibly batched log-probability factor;
            Python numbers and scalars are converted to tensors.
        :param validate_args: standard torch.distributions validation flag.
        """
        log_factor = torch.as_tensor(log_factor)
        # Each element of log_factor is one batch; the event is empty.
        batch_shape = log_factor.shape
        event_shape = torch.Size((0,))  # This satisfies .numel() == 0.
        self.log_factor = log_factor
        super(Unit, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Follows the torch.distributions expand() protocol: reuse _instance
        # when provided, broadcast log_factor to the new batch shape, and
        # skip re-validation on the expanded copy.
        new = self._get_checked_instance(Unit, _instance)
        new.log_factor = self.log_factor.expand(batch_shape)
        super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        # The unit value carries no data: an empty tensor whose trailing
        # event dimension has size 0.
        return self.log_factor.new_empty(sample_shape + self.shape())

    def log_prob(self, value):
        # value.shape[:-1] strips the empty event dimension; log_prob is just
        # log_factor broadcast against value's batch shape.
        shape = broadcast_shape(self.batch_shape, value.shape[:-1])
        return self.log_factor.expand(shape)
13 changes: 13 additions & 0 deletions pyro/primitives.py
Expand Up @@ -111,6 +111,19 @@ def sample(name, fn, *args, **kwargs):
return msg["value"]


def factor(name, log_factor):
    """
    Factor statement to add an arbitrary log probability factor to a
    probabilistic model.

    Implemented as an observed sample from a trivial :class:`~pyro.distributions.Unit`
    distribution, whose log_prob is exactly ``log_factor``.

    :param str name: Name of the trivial sample
    :param torch.Tensor log_factor: A possibly batched log probability factor.
    """
    unit_dist = dist.Unit(log_factor)
    unit_value = unit_dist.sample()
    sample(name, unit_dist, obs=unit_value)


class plate(PlateMessenger):
"""
Construct for conditionally independent sequences of variables.
Expand Down
29 changes: 29 additions & 0 deletions tests/distributions/test_unit.py
@@ -0,0 +1,29 @@
import pytest
import torch

import pyro.distributions as dist
from tests.common import assert_equal


@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)])
def test_shapes(batch_shape):
    """Samples have a trailing event dim of size 0 and log_prob equals log_factor."""
    log_factor = torch.randn(batch_shape)

    unit = dist.Unit(log_factor=log_factor)
    value = unit.sample()
    assert value.shape == batch_shape + (0,)
    assert (unit.log_prob(value) == log_factor).all()


@pytest.mark.parametrize('sample_shape', [(), (4,), (3, 2)])
@pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)])
def test_expand(sample_shape, batch_shape):
    """expand() broadcasts log_factor, and log_prob broadcasts over values."""
    log_factor = torch.randn(batch_shape)
    base = dist.Unit(log_factor)
    base_value = base.sample()

    expanded = base.expand(sample_shape + batch_shape)
    assert expanded.batch_shape == sample_shape + batch_shape
    expanded_value = expanded.sample()
    assert expanded_value.shape == sample_shape + batch_shape + (0,)
    # Cross-evaluation: each distribution broadcasts against the other's value.
    assert_equal(base.log_prob(expanded_value), expanded.log_prob(base_value))
34 changes: 32 additions & 2 deletions tests/infer/test_autoguide.py
Expand Up @@ -8,13 +8,13 @@
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
from pyro.infer.autoguide import (AutoCallable, AutoDelta, AutoDiagonalNormal, AutoDiscreteParallel, AutoGuideList,
AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal,
AutoMultivariateNormal, init_to_feasible, init_to_mean, init_to_median,
init_to_sample)
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
from pyro.optim import Adam
from tests.common import assert_equal
from tests.common import assert_close, assert_equal


@pytest.mark.parametrize("auto_class", [
Expand Down Expand Up @@ -43,6 +43,36 @@ def model():
assert guide_trace.nodes['z']['log_prob_sum'].item() == 0.0


@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@pytest.mark.parametrize("auto_class", [
    AutoDelta,
    AutoDiagonalNormal,
    AutoMultivariateNormal,
    AutoLowRankMultivariateNormal,
    AutoIAFNormal,
    AutoLaplaceApproximation,
])
def test_factor(auto_class, Elbo):
    # Check that pyro.factor shifts the ELBO by the expected constant:
    # "f1" contributes log_factor once and "f2" three more times (replicated
    # by the size-3 plate), for a total ELBO shift of 4 * log_factor.

    def model(log_factor):
        pyro.sample("z1", dist.Normal(0.0, 1.0))
        pyro.factor("f1", log_factor)
        pyro.sample("z2", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1))
        with pyro.plate("plate", 3):
            pyro.factor("f2", log_factor)
            pyro.sample("z3", dist.Normal(torch.zeros(3), torch.ones(3)))

    guide = auto_class(model)
    elbo = Elbo(strict_enumeration_warning=False)
    elbo.loss(model, guide, torch.tensor(0.))  # initialize param store

    # Fixing the RNG seed makes both loss evaluations draw identical guide
    # samples, so the two losses differ only by the factor terms.
    pyro.set_rng_seed(123)
    loss_5 = elbo.loss(model, guide, torch.tensor(5.))
    pyro.set_rng_seed(123)
    loss_4 = elbo.loss(model, guide, torch.tensor(4.))
    # loss = -ELBO, so the difference is -(5 - 4) * (1 + 3) = -4.
    assert_close(loss_5 - loss_4, -1 - 3)


@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@pytest.mark.parametrize("init_loc_fn", [
init_to_feasible,
Expand Down
42 changes: 25 additions & 17 deletions tutorial/source/RSA-hyperbole.ipynb

Large diffs are not rendered by default.

36 changes: 19 additions & 17 deletions tutorial/source/RSA-implicature.ipynb

Large diffs are not rendered by default.