Skip to content

Commit

Permalink
Merge 9488da8 into 91bc2b3
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning committed Apr 22, 2024
2 parents 91bc2b3 + 9488da8 commit 095b9a0
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 5 deletions.
46 changes: 42 additions & 4 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,14 @@ def _predictive_sequential(
)
collected_trace.append(trace)
collected_samples.append(
{site: trace.nodes[site]["value"] for site in return_site_shapes}
{
site: (
trace.nodes[site]["value"]
if site in trace.nodes
else samples[i][site]
)
for site in return_site_shapes
}
)

return _predictiveResults(
Expand All @@ -84,6 +91,7 @@ def _predictive(
model_args=(),
model_kwargs={},
mask=True,
posterior_deterministic_sites=(),
):
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
Expand Down Expand Up @@ -122,6 +130,9 @@ def _predictive(
elif site not in posterior_samples:
return_site_shapes[site] = site_shape

for site in posterior_deterministic_sites:
return_site_shapes[site] = posterior_samples[site].shape

# handle _RETURN site
if return_sites is not None and "_RETURN" in return_sites:
value = model_trace.nodes["_RETURN"]["value"]
Expand All @@ -143,7 +154,10 @@ def _predictive(
).get_trace(*model_args, **model_kwargs)
predictions = {}
for site, shape in return_site_shapes.items():
value = trace.nodes[site]["value"]
if site in trace.nodes:
value = trace.nodes[site]["value"]
else:
value = reshaped_samples[site]
if site == "_RETURN" and shape is None:
predictions[site] = value
continue
Expand Down Expand Up @@ -179,6 +193,8 @@ class Predictive(torch.nn.Module):
:param bool parallel: predict in parallel by wrapping the existing model
in an outermost `plate` messenger. Note that this requires that the model has
all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`.
:param return_deterministic_guide_sites: include deterministic sites from the guide
in returned samples; this does not affect the returned trace.
"""

def __init__(
Expand All @@ -189,6 +205,7 @@ def __init__(
num_samples=None,
return_sites=(),
parallel=False,
return_deterministic_guide_sites=False,
):
super().__init__()
if posterior_samples is None:
Expand Down Expand Up @@ -231,6 +248,7 @@ def __init__(
self.guide = guide
self.return_sites = return_sites
self.parallel = parallel
self.return_deterministic_guide_sites = return_deterministic_guide_sites

def call(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -262,18 +280,37 @@ def forward(self, *args, **kwargs):
"""
posterior_samples = self.posterior_samples
return_sites = self.return_sites

guide_deterministic_sites = ()

if self.guide is not None:
# return all sites by default if a guide is provided.
return_sites = None if not return_sites else return_sites
posterior_samples = _predictive(
guide_pred_res = _predictive(
self.guide,
posterior_samples,
self.num_samples,
return_sites=None,
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
).samples
)
posterior_samples = guide_pred_res.samples

if self.return_deterministic_guide_sites:
if isinstance(guide_pred_res, Trace):
guide_tr = guide_pred_res.trace
else:
guide_tr = guide_pred_res.trace[0]

guide_deterministic_sites = tuple(
name
for name, site in guide_tr.nodes.items()
if site["type"] == "sample"
if site["infer"].get("_deterministic")
if (return_sites is None or name in return_sites)
)

return _predictive(
self.model,
posterior_samples,
Expand All @@ -282,6 +319,7 @@ def forward(self, *args, **kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
posterior_deterministic_sites=guide_deterministic_sites,
).samples

def get_samples(self, *args, **kwargs):
Expand Down
8 changes: 7 additions & 1 deletion pyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
if site["type"] == "sample"
if site["infer"].get("is_auxiliary")
)
det_vars = set(
name
for name, site in guide_trace.nodes.items()
if site["type"] == "sample"
if site["infer"].get("_deterministic")
)
model_vars = set(
name
for name, site in model_trace.nodes.items()
Expand All @@ -284,7 +290,7 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
warnings.warn(
"Found auxiliary vars in the model: {}".format(aux_vars & model_vars)
)
if not (guide_vars <= model_vars | aux_vars):
if not (guide_vars <= model_vars | aux_vars | det_vars):
warnings.warn(
"Found non-auxiliary vars in guide but not model, "
"consider marking these infer={{'is_auxiliary': True}}:\n{}".format(
Expand Down
38 changes: 38 additions & 0 deletions tests/infer/test_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,44 @@ def model(y=None):
assert_close(actual["x3"].mean(), y, rtol=0.1)


@pytest.mark.parametrize("with_plate", [True, False])
@pytest.mark.parametrize("event_shape", [(), (2,)])
@pytest.mark.parametrize("return_deterministic_guide_sites", [True, False])
def test_deterministic_guide_return(
with_plate, event_shape, return_deterministic_guide_sites
):
def model(y=None):
with pyro.util.optional(pyro.plate("plate", 3), with_plate):
x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())
x2 = pyro.deterministic("x2", x**2, event_dim=len(event_shape))

pyro.deterministic("x3", x2)
return pyro.sample("obs", dist.Normal(x2, 0.1).to_event(), obs=y)

def guide(y=None):
with pyro.util.optional(pyro.plate("plate", 3), with_plate):
x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())

pyro.deterministic("x4", x)

y = torch.tensor(4.0)
svi = SVI(model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
for i in range(100):
svi.step(y)

actual = Predictive(
model,
guide=guide,
num_samples=1000,
return_deterministic_guide_sites=return_deterministic_guide_sites,
)()

if return_deterministic_guide_sites:
assert "x4" in actual
else:
assert "x4" not in actual


def test_get_mask_optimization():
def model():
x = pyro.sample("x", dist.Normal(0, 1))
Expand Down

0 comments on commit 095b9a0

Please sign in to comment.