Skip to content

Commit

Permalink
Revert "Unsuccessful attempt to graph observations."
Browse files Browse the repository at this point in the history
This reverts commit 05a0580.
  • Loading branch information
rpgoldman committed May 23, 2019
1 parent 05a0580 commit e845c2f
Showing 1 changed file with 8 additions and 22 deletions.
30 changes: 8 additions & 22 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import deque
from typing import Iterator, Optional, MutableSet, FrozenSet, Tuple
from typing import Iterator, Optional, MutableSet

from theano.gof.graph import stack_search
from theano.compile import SharedVariable
Expand Down Expand Up @@ -34,7 +34,7 @@ def _get_ancestors(self, var, func) -> MutableSet[Tensor]:
vars = set(self.var_list)
vars.remove(var)

blockers = set() # type: MutableSet[Tensor]
blockers = set()
retval = set()
def _expand(node) -> Optional[Iterator[Tensor]]:
if node in blockers:
Expand All @@ -54,9 +54,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]:
mode='bfs')
return retval

def _filter_parents(self, var, parents) -> Tuple[FrozenSet[str], FrozenSet[str]]:
def _filter_parents(self, var, parents):
"""Get direct parents of a var, as strings"""
keep = set() # type: MutableSet[str]
keep = set()
for p in parents:
if p == var:
continue
Expand All @@ -67,14 +67,9 @@ def _filter_parents(self, var, parents) -> Tuple[FrozenSet[str], FrozenSet[str]]
keep.add(self.transform_map[p])
else:
raise AssertionError('Do not know what to do with {}'.format(str(p)))
children = frozenset() # type: FrozenSet[str]
try:
children = frozenset(var.observations.name)
except AttributeError:
pass
return frozenset(keep - children), children
return keep

def get_parents(self, var) -> Tuple[FrozenSet[str], FrozenSet[str]]:
def get_parents(self, var):
"""Get the named nodes that are direct inputs to the var"""
if hasattr(var, 'transformed'):
func = var.transformed.logpt
Expand All @@ -88,18 +83,9 @@ def get_parents(self, var) -> Tuple[FrozenSet[str], FrozenSet[str]]:

def make_compute_graph(self):
"""Get map of var_name -> set(input var names) for the model"""
input_map = {} # type: Dict[str, FrozenSet[str]]
input_map = {}
for var_name in self.var_names:
# the "children" here are exclusively observations
parents, children = self.get_parents(self.model[var_name])
input_map[var_name] = parents
for child in children:
if child in input_map:
# slightly awkward union because input map values are not
# mutable
input_map[child] = frozenset(input_map[child] + set(var_name))
else:
input_map[child] = frozenset(var_name)
input_map[var_name] = self.get_parents(self.model[var_name])
return input_map

def _make_node(self, var_name, graph):
Expand Down

0 comments on commit e845c2f

Please sign in to comment.