Skip to content

Commit

Permalink
add substitute handler (#3125)
Browse files Browse the repository at this point in the history
  • Loading branch information
thisiscam committed Jan 25, 2023
1 parent af3db08 commit 1678ee2
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/pyro.poutine.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ __________________
:undoc-members:
:show-inheritance:

SubstituteMessenger
___________________

.. automodule:: pyro.poutine.substitute_messenger
:members:
:undoc-members:
:show-inheritance:

TraceMessenger
_______________

Expand Down
2 changes: 2 additions & 0 deletions pyro/poutine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
replay,
scale,
seed,
substitute,
trace,
uncondition,
)
Expand Down Expand Up @@ -47,6 +48,7 @@
"queue",
"scale",
"seed",
"substitute",
"trace",
"Trace",
"uncondition",
Expand Down
2 changes: 2 additions & 0 deletions pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .runtime import NonlocalExit
from .scale_messenger import ScaleMessenger
from .seed_messenger import SeedMessenger
from .substitute_messenger import SubstituteMessenger
from .trace_messenger import TraceMessenger
from .uncondition_messenger import UnconditionMessenger

Expand All @@ -97,6 +98,7 @@
SeedMessenger,
TraceMessenger,
UnconditionMessenger,
SubstituteMessenger,
]


Expand Down
85 changes: 85 additions & 0 deletions pyro/poutine/substitute_messenger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings

from pyro import params
from pyro.poutine.messenger import Messenger
from pyro.poutine.util import is_validation_enabled


class SubstituteMessenger(Messenger):
"""
Given a stochastic function with param calls and a set of parameter values,
create a stochastic function where all param calls are substituted with
the fixed values.
data should be a dict of names to values.
Consider the following Pyro program:
>>> def model(x):
... a = pyro.param("a", torch.tensor(0.5))
... x = pyro.sample("x", dist.Bernoulli(probs=a))
... return x
>>> substituted_model = pyro.poutine.substitute(model, data={"s": 0.3})
In this example, site `a` will now have value `0.3`.
:param data: dictionary of values keyed by site names.
:returns: ``fn`` decorated with a :class:`~pyro.poutine.substitute_messenger.SubstituteMessenger`
"""

def __init__(self, data):
"""
:param data: values for the parameters.
Constructor
"""
super().__init__()
self.data = data
self._data_cache = {}

def __enter__(self):
self._data_cache = {}
if is_validation_enabled() and isinstance(self.data, dict):
self._param_hits = set()
self._param_misses = set()
return super().__enter__()

def __exit__(self, *args, **kwargs):
self._data_cache = {}
if is_validation_enabled() and isinstance(self.data, dict):
extra = set(self.data) - self._param_hits
if extra:
warnings.warn(
"pyro.module data did not find params ['{}']. "
"Did you instead mean one of ['{}']?".format(
"', '".join(extra), "', '".join(self._param_misses)
)
)
return super().__exit__(*args, **kwargs)

def _pyro_sample(self, msg):
return None

def _pyro_param(self, msg):
"""
Overrides the `pyro.param` with substituted values.
If the param name does not match the name the keys in `data`,
that param value is unchanged.
"""
name = msg["name"]
param_name = params.user_param_name(name)

if param_name in self.data.keys():
msg["value"] = self.data[param_name]
if is_validation_enabled():
self._param_hits.add(param_name)
else:
if is_validation_enabled():
self._param_misses.add(param_name)
return None

if name in self._data_cache:
# Multiple pyro.param statements with the same
# name. Block the site and fix the value.
msg["value"] = self._data_cache[name]["value"]
else:
self._data_cache[name] = msg
30 changes: 30 additions & 0 deletions tests/poutine/test_poutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,36 @@ def _test_scale_factor(batch_size_outer, batch_size_inner, expected):
_test_scale_factor(2, 1, [2.0] * 2)


class SubstituteHandlerTests(NormalNormalNormalHandlerTestCase):
def test_substitute(self):
data = {"loc1": torch.randn(2)}
tr2 = poutine.trace(poutine.substitute(self.guide, data=data)).get_trace()
assert "loc1" in tr2
assert tr2.nodes["loc1"]["type"] == "param"
assert tr2.nodes["loc1"]["value"] is data["loc1"]

def test_stack_overwrite_behavior(self):
data1 = {"loc1": torch.randn(2)}
data2 = {"loc1": torch.randn(2)}
with poutine.trace() as tr:
cm = poutine.substitute(
poutine.substitute(self.guide, data=data1), data=data2
)
cm()
assert tr.trace.nodes["loc1"]["value"] is data2["loc1"]

def test_stack_success(self):
data1 = {"loc1": torch.randn(2)}
data2 = {"loc2": torch.randn(2)}
tr = poutine.trace(
poutine.substitute(poutine.substitute(self.guide, data=data1), data=data2)
).get_trace()
assert tr.nodes["loc1"]["type"] == "param"
assert tr.nodes["loc1"]["value"] is data1["loc1"]
assert tr.nodes["loc2"]["type"] == "param"
assert tr.nodes["loc2"]["value"] is data2["loc2"]


class ConditionHandlerTests(NormalNormalNormalHandlerTestCase):
def test_condition(self):
data = {"latent2": torch.randn(2)}
Expand Down

0 comments on commit 1678ee2

Please sign in to comment.