Skip to content

Commit

Permalink
Fix #3245 bug in poutine.trace(graph_type="dense") (#3247)
Browse files Browse the repository at this point in the history
* Fix #3245 bug in poutine.trace(graph_type="dense")

* Update test assertion
  • Loading branch information
fritzo committed Jul 28, 2023
1 parent 47bee49 commit a59f6a2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyro/poutine/trace_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def identify_dense_edges(trace):
continue
if node["type"] == "sample":
for past_name, past_node in trace.nodes.items():
if site_is_subsample(node):
if site_is_subsample(past_node):
continue
if past_node["type"] == "sample":
if past_name == name:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_tracegraph_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def guide():
model_trace = pyro.poutine.trace(
pyro.poutine.replay(model, trace=guide_trace), graph_type="dense"
).get_trace()
assert len(list(model_trace.edges)) == 27
assert len(list(model_trace.edges)) == 9
assert len(model_trace.nodes) == 16
assert len(list(guide_trace.edges)) == 0
assert len(guide_trace.nodes) == 9
Expand Down
14 changes: 14 additions & 0 deletions tests/poutine/test_poutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,20 @@ def model():
assert len(_DIM_ALLOCATOR._stack) == 0, "stack was not cleaned on error"


@pytest.mark.parametrize(
"graph_type, expected", [("flat", set()), ("dense", {"x", "y"})]
)
def test_trace_plate(graph_type: str, expected: set):
def model():
with pyro.plate("plate", 2):
x = pyro.sample("x", dist.Normal(0, 1))
pyro.sample("y", dist.Normal(x, 1))

trace = poutine.trace(model, graph_type=graph_type).get_trace()
nodes = set().union(*trace._succ.values(), *trace._pred.values())
assert nodes == expected


def test_decorator_interface_primitives():
@poutine.trace
def model():
Expand Down

0 comments on commit a59f6a2

Please sign in to comment.