Skip to content

Commit

Permalink
Working memory tutorial (#2028)
Browse files Browse the repository at this point in the history
  • Loading branch information
ae-foster authored and fritzo committed Sep 17, 2019
1 parent 55afba3 commit 64e3b1a
Show file tree
Hide file tree
Showing 5 changed files with 891 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pyro.logger import log
from pyro.poutine import condition, do, markov
from pyro.primitives import (clear_param_store, enable_validation, factor, get_param_store, iarange, irange, module,
param, plate, random_module, sample, validation_enabled)
param, plate, plate_stack, random_module, sample, validation_enabled)
from pyro.util import set_rng_seed

version_prefix = '0.4.1'
Expand All @@ -29,6 +29,7 @@
"param",
"plate",
"plate",
"plate_stack",
"poutine",
"random_module",
"sample",
Expand Down
4 changes: 4 additions & 0 deletions pyro/contrib/oed/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,10 @@ def opt_eig_ape_loss(design, loss_fn, num_samples, num_steps, optim, return_hist
if return_history:
history.append(loss)
optim(params)
try:
optim.step()
except AttributeError:
pass

_, loss = loss_fn(final_design, final_num_samples, evaluation=True)
if return_history:
Expand Down
21 changes: 20 additions & 1 deletion pyro/primitives.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from contextlib import contextmanager, ExitStack
from inspect import isclass

import pyro.distributions as dist
Expand Down Expand Up @@ -243,6 +243,25 @@ def __init__(self, *args, **kwargs):
super(irange, self).__init__(*args, **kwargs)


@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
"""
Create a contiguous stack of :class:`plate` s with dimensions::
rightmost_dim - len(sizes), ..., rightmost_dim
:param str prefix: Name prefix for plates.
:param iterable sizes: An iterable of plate sizes.
:param int rightmost_dim: The rightmost dim, counting from the right.
"""
assert rightmost_dim < 0
with ExitStack() as stack:
for i, size in enumerate(reversed(sizes)):
plate_i = plate("{}_{}".format(prefix, i), size, dim=rightmost_dim - i)
stack.enter_context(plate_i)
yield


def module(name, nn_module, update_module_params=False):
"""
Takes a torch.nn.Module and registers its parameters with the ParamStore.
Expand Down
54 changes: 54 additions & 0 deletions tests/infer/test_valid_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,60 @@ def guide():
assert_ok(model, guide, Elbo())


@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)])
def test_plate_stack_ok(Elbo, sizes):

def model():
p = torch.tensor(0.5)
with pyro.plate_stack("plate_stack", sizes):
pyro.sample("x", dist.Bernoulli(p))

def guide():
p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
with pyro.plate_stack("plate_stack", sizes):
pyro.sample("x", dist.Bernoulli(p))

if Elbo is TraceEnum_ELBO:
guide = config_enumerate(guide)

assert_ok(model, guide, Elbo())


@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)])
def test_plate_stack_and_plate_ok(Elbo, sizes):

def model():
p = torch.tensor(0.5)
with pyro.plate_stack("plate_stack", sizes):
with pyro.plate("plate", 7):
pyro.sample("x", dist.Bernoulli(p))

def guide():
p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
with pyro.plate_stack("plate_stack", sizes):
with pyro.plate("plate", 7):
pyro.sample("x", dist.Bernoulli(p))

if Elbo is TraceEnum_ELBO:
guide = config_enumerate(guide)

assert_ok(model, guide, Elbo())


@pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)])
def test_plate_stack_sizes(sizes):

def model():
p = 0.5 * torch.ones(3)
with pyro.plate_stack("plate_stack", sizes):
x = pyro.sample("x", dist.Bernoulli(p).to_event(1))
assert x.shape == sizes + (3,)

model()


@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_nested_plate_plate_ok(Elbo):

Expand Down
811 changes: 811 additions & 0 deletions tutorial/source/working_memory.ipynb

Large diffs are not rendered by default.

0 comments on commit 64e3b1a

Please sign in to comment.