Skip to content

Commit

Permalink
Merge a51bc5e into 8317281
Browse files Browse the repository at this point in the history
  • Loading branch information
rpgoldman committed May 24, 2019
2 parents 8317281 + a51bc5e commit abf4916
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pymc3/data.py
Expand Up @@ -390,7 +390,7 @@ def align_minibatches(batches=None):

class Data:
"""Data container class that wraps the theano SharedVariable class
and let the model be aware of its inputs and outputs.
and lets the model be aware of its inputs and outputs.
Parameters
----------
Expand Down
46 changes: 31 additions & 15 deletions pymc3/model_graph.py
@@ -1,17 +1,16 @@
from collections import deque
from typing import Iterator, Optional, MutableSet
from typing import Dict, Iterator, Set, Optional

VarName = str

from theano.gof.graph import stack_search
from theano.compile import SharedVariable
from theano.tensor import Tensor

from .util import get_default_varnames
from .model import ObservedRV
import pymc3 as pm

# this is a placeholder for a better characterization of the type
# of variables in a model.
RV = Tensor


class ModelGraph:
def __init__(self, model):
Expand All @@ -30,16 +29,16 @@ def get_deterministics(self, var):
deterministics.append(v)
return deterministics

def _get_ancestors(self, var, func) -> MutableSet[RV]:
def _get_ancestors(self, var: Tensor, func) -> Set[Tensor]:
"""Get all ancestors of a function, doing some accounting for deterministics.
"""

# this contains all of the variables in the model EXCEPT var...
vars = set(self.var_list)
vars.remove(var)

blockers = set()
retval = set()
blockers = set() # type: Set[Tensor]
retval = set() # type: Set[Tensor]
def _expand(node) -> Optional[Iterator[Tensor]]:
if node in blockers:
return None
Expand All @@ -58,9 +57,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]:
mode='bfs')
return retval

def _filter_parents(self, var, parents):
def _filter_parents(self, var, parents) -> Set[VarName]:
"""Get direct parents of a var, as strings"""
keep = set()
keep = set() # type: Set[VarName]
for p in parents:
if p == var:
continue
Expand All @@ -73,7 +72,7 @@ def _filter_parents(self, var, parents):
raise AssertionError('Do not know what to do with {}'.format(str(p)))
return keep

def get_parents(self, var):
def get_parents(self, var: Tensor) -> Set[VarName]:
"""Get the named nodes that are direct inputs to the var"""
if hasattr(var, 'transformed'):
func = var.transformed.logpt
Expand All @@ -85,11 +84,26 @@ def get_parents(self, var):
parents = self._get_ancestors(var, func)
return self._filter_parents(var, parents)

def make_compute_graph(self):
def make_compute_graph(self) -> Dict[str, Set[VarName]]:
"""Get map of var_name -> set(input var names) for the model"""
input_map = {}
input_map = {} # type: Dict[str, Set[VarName]]
def update_input_map(key: str, val: Set[VarName]):
if key in input_map:
input_map[key] = input_map[key].union(val)
else:
input_map[key] = val

for var_name in self.var_names:
input_map[var_name] = self.get_parents(self.model[var_name])
var = self.model[var_name]
update_input_map(var_name, self.get_parents(var))
if isinstance(var, ObservedRV):
try:
obs_name = var.observations.name
if obs_name:
input_map[var_name] = input_map[var_name].difference(set([obs_name]))
update_input_map(obs_name, set([var_name]))
except AttributeError:
pass
return input_map

def _make_node(self, var_name, graph):
Expand All @@ -107,9 +121,11 @@ def _make_node(self, var_name, graph):
# Get name for node
if hasattr(v, 'distribution'):
distribution = v.distribution.__class__.__name__
elif isinstance(v, SharedVariable):
distribution = 'Data'
else:
distribution = 'Deterministic'
attrs['shape'] = 'box'
attrs['shape'] = 'box'

graph.node(var_name.replace(':', '&'),
'{var_name} ~ {distribution}'.format(var_name=var_name, distribution=distribution),
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_data_container.py
Expand Up @@ -101,5 +101,5 @@ def test_model_to_graphviz_for_model_with_data_container(self):
pm.sample(1000, init=None, tune=1000, chains=1)

g = pm.model_to_graphviz(model)
text = 'x [label="x ~ Deterministic" shape=box style=filled]'
text = 'x [label="x ~ Data" shape=box style=filled]'
assert text in g.source

0 comments on commit abf4916

Please sign in to comment.