Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved render_model #3039

Merged
merged 5 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 70 additions & 10 deletions pyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def _pyro_post_sample(self, msg):
value = detach_provenance(msg["value"])
msg["value"] = ProvenanceTensor(value, provenance)

def _pyro_post_param(self, msg):
if msg["type"] == "param":
provenance = frozenset({msg["name"]}) # track only direct dependencies
value = detach_provenance(msg["value"])
msg["value"] = ProvenanceTensor(value, provenance)


@torch.enable_grad()
def get_dependencies(
Expand Down Expand Up @@ -276,17 +282,34 @@ def model(data):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

sample_sample = {}
sample_param = {}
sample_dist = {}
param_constraint = {}
plate_sample = defaultdict(list)
observed = []

def _get_type_from_frozenname(frozen_name):
return trace.nodes[frozen_name]["type"]

for name, site in trace.nodes.items():
if site["type"] == "param":
param_constraint[name] = site["kwargs"]["constraint"]
fritzo marked this conversation as resolved.
Show resolved Hide resolved

if site["type"] != "sample" or site_is_subsample(site):
continue

sample_sample[name] = [
upstream
for upstream in get_provenance(site["fn"].log_prob(site["value"]))
if upstream != name
if upstream != name and _get_type_from_frozenname(upstream) == "sample"
]

sample_param[name] = [
upstream
for upstream in get_provenance(site["fn"].log_prob(site["value"]))
if upstream != name and _get_type_from_frozenname(upstream) == "param"
]

sample_dist[name] = _get_dist_name(site["fn"])
for frame in site["cond_indep_stack"]:
plate_sample[frame.name].append(name)
Expand All @@ -313,7 +336,9 @@ def _resolve_plate_samples(plate_samples):

return {
"sample_sample": sample_sample,
"sample_param": sample_param,
"sample_dist": sample_dist,
"param_constraint": param_constraint,
"plate_sample": dict(plate_sample),
"observed": observed,
}
Expand All @@ -327,7 +352,9 @@ def _get_dist_name(fn):
return type(fn).__name__


def generate_graph_specification(model_relations: dict) -> dict:
def generate_graph_specification(
model_relations: dict, render_params: bool = False
) -> dict:
"""
Convert model relations into data structure which can be readily
converted into a network.
Expand All @@ -339,6 +366,16 @@ def generate_graph_specification(model_relations: dict) -> dict:
rv for rv in model_relations["sample_sample"] if rv not in plate_rvs
] # RVs which are in no plate

params = set()
fritzo marked this conversation as resolved.
Show resolved Hide resolved
if render_params:
for rv, params_list in model_relations["sample_param"].items():
for param in params_list:
params.add(param)
params = list(params)
fritzo marked this conversation as resolved.
Show resolved Hide resolved
plate_groups[None].extend(params)

# get set of params

# retain node metadata
node_data = {}
for rv in model_relations["sample_sample"]:
Expand All @@ -347,6 +384,14 @@ def generate_graph_specification(model_relations: dict) -> dict:
"distribution": model_relations["sample_dist"][rv],
}

if render_params:
for param, constraint in model_relations["param_constraint"].items():
node_data[param] = {
"is_observed": False,
"constraint": constraint,
"distribution": None,
}

# infer plate structure
# (when the order of plates cannot be determined from subset relations,
# it follows the order in which plates appear in trace)
Expand Down Expand Up @@ -380,6 +425,10 @@ def generate_graph_specification(model_relations: dict) -> dict:
for target, source_list in model_relations["sample_sample"].items():
edge_list.extend([(source, target) for source in source_list])

if render_params:
for target, source_list in model_relations["sample_param"].items():
edge_list.extend([(source, target) for source in source_list])

return {
"plate_groups": plate_groups,
"plate_data": plate_data,
Expand All @@ -389,8 +438,7 @@ def generate_graph_specification(model_relations: dict) -> dict:


def render_graph(
graph_specification: dict,
render_distributions: bool = False,
graph_specification: dict, render_distributions: bool = False
) -> "graphviz.Digraph":
"""
Create a graphviz object given a graph specification.
Expand Down Expand Up @@ -431,9 +479,15 @@ def render_graph(

for rv in rv_list:
color = "grey" if node_data[rv]["is_observed"] else "white"
cur_graph.node(
rv, label=rv, shape="ellipse", style="filled", fillcolor=color
)

# For sample_nodes - ellipse
if node_data[rv]["distribution"]:
shape = "ellipse"

# For param_nodes - No shape
else:
shape = "plain"
cur_graph.node(rv, label=rv, shape=shape, style="filled", fillcolor=color)

# add leaf nodes first
while len(plate_data) >= 1:
Expand All @@ -460,7 +514,11 @@ def render_graph(
dist_label = ""
for rv, data in node_data.items():
rv_dist = data["distribution"]
dist_label += rf"{rv} ~ {rv_dist}\l"
if rv_dist:
dist_label += rf"{rv} ~ {rv_dist}\l"

if "constraint" in data and data["constraint"]:
dist_label += rf"{rv} ∈ {data['constraint']}\l"

graph.node("distribution_description_node", label=dist_label, shape="plaintext")

Expand All @@ -474,6 +532,7 @@ def render_model(
model_kwargs: Optional[dict] = None,
filename: Optional[str] = None,
render_distributions: bool = False,
render_params: bool = False,
) -> "graphviz.Digraph":
"""
Renders a model using `graphviz <https://graphviz.org>`_ .
Expand All @@ -487,12 +546,13 @@ def render_model(
:param model_kwargs: Keyword arguments to pass to the model.
:param str filename: File to save rendered model in.
:param bool render_distributions: Whether to include RV distribution
annotations in the plot.
annotations (and param constraints) in the plot.
:param bool render_params: Whether to show params inthe plot.
:returns: A model graph.
:rtype: graphviz.Digraph
"""
relations = get_model_relations(model, model_args, model_kwargs)
graph_spec = generate_graph_specification(relations)
graph_spec = generate_graph_specification(relations, render_params=render_params)
graph = render_graph(graph_spec, render_distributions=render_distributions)

if filename is not None:
Expand Down
5 changes: 3 additions & 2 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.distributions import constraints, transform_to

import pyro
from pyro.ops.provenance import detach_provenance
from pyro.poutine.runtime import _PYRO_PARAM_STORE


Expand Down Expand Up @@ -533,7 +534,7 @@ def __setattr__(self, name, value):
constraint=constraint,
event_dim=event_dim,
)
constrained_value = pyro.param(fullname)
constrained_value = detach_provenance(pyro.param(fullname))
unconstrained_value = constrained_value.unconstrained()
if not isinstance(unconstrained_value, torch.nn.Parameter):
# Update PyroModule ---> ParamStore (type only; data is preserved).
Expand All @@ -556,7 +557,7 @@ def __setattr__(self, name, value):
value = pyro.param(fullname, value)
if not isinstance(value, torch.nn.Parameter):
# Update PyroModule ---> ParamStore (type only; data is preserved).
value = torch.nn.Parameter(value)
value = torch.nn.Parameter(detach_provenance(value))
_PYRO_PARAM_STORE._params[fullname] = value
_PYRO_PARAM_STORE._param_to_name[value] = fullname
super().__setattr__(name, value)
Expand Down