Skip to content
Permalink
Browse files

Update graph_analsysis

  • Loading branch information...
colah committed Jun 12, 2019
1 parent 67f2add commit ec6d027b3c32e2af5d49185cb31ccb3721c2b29f
Showing with 196 additions and 6 deletions.
  1. +55 −6 lucid/misc/graph_analysis/overlay_graph.py
  2. +141 −0 lucid/misc/graph_analysis/parse_overlay.py
@@ -33,7 +33,11 @@ def __init__(self, name, overlay_graph):
self.name = name
self.overlay_graph = overlay_graph
self.tf_graph = overlay_graph.tf_graph
self.tf_node = self.tf_graph.get_tensor_by_name(name)
try:
self.tf_node = self.tf_graph.get_tensor_by_name(name)
except:
self.tf_node = None
self.sub_structure = None

@staticmethod
def as_name(node):
@@ -75,6 +79,27 @@ def gcd(self):
def lcm(self):
return self.overlay_graph.lcm(self.consumers)


class OverlayStructure():
"""Represents a sub-structure of a OverlayGraph.
Often, we want to find structures within a graph, such as branches and
sequences, to assist with graph layout for users.
An OverlayStructure represents such a structure. It is typically used
in conjunction with OverlayGraph.collapse_structures() to parse a graph.
"""

def __init__(self, structure_type, structure):
self.structure_type = structure_type
self.structure = structure # A dictionary
self.children = sum([component if isinstance(component, (list, tuple)) else [component]
for component in structure.values()], [])

def __contains__(self, item):
return OverlayNode.as_name(item) in [n.name for n in self.children]


class OverlayGraph():
"""A subgraph of a TensorFlow computational graph.
@@ -86,7 +111,7 @@ class OverlayGraph():
edges correspond to paths through the original graph.
"""

def __init__(self, tf_graph, nodes=None, no_pass_through=None):
def __init__(self, tf_graph, nodes=None, no_pass_through=None, prev_overlay=None):
self.tf_graph = tf_graph

if nodes is None:
@@ -99,6 +124,7 @@ def __init__(self, tf_graph, nodes=None, no_pass_through=None):
self.no_pass_through = [] if no_pass_through is None else no_pass_through
self.node_to_consumers = defaultdict(lambda: set())
self.node_to_inputs = defaultdict(lambda: set())
self.prev_overlay = prev_overlay

for node in self.nodes:
for inp in self._get_overlay_inputs(node):
@@ -129,9 +155,13 @@ def get_tf_node(self, node):
return self.tf_graph.get_tensor_by_name(name)

def _get_overlay_inputs(self, node):
node = self.get_tf_node(node)
if self.prev_overlay:
raw_inps = self.prev_overlay[node].inputs
else:
raw_inps = self.get_tf_node(node).op.inputs

overlay_inps = []
for inp in node.op.inputs:
for inp in raw_inps:
if inp in self:
overlay_inps.append(self[inp])
elif not node.name in self.no_pass_through:
@@ -156,8 +186,8 @@ def graphviz(self, groups=None):
print(" ", '"' + inp.name + '"', " -> ", '"' + (node.name) + '"')
print("}")


def filter(self, keep_nodes, pass_through=True):
keep_nodes = [self[n].name for n in keep_nodes]
old_nodes = set(self.name_map.keys())
new_nodes = set(keep_nodes)
no_pass_through = set(self.no_pass_through)
@@ -166,7 +196,10 @@ def filter(self, keep_nodes, pass_through=True):
no_pass_through += old_nodes - new_nodes

keep_nodes = [node for node in self.name_map if node in keep_nodes]
return OverlayGraph(self.tf_graph, keep_nodes, no_pass_through)
new_overlay = OverlayGraph(self.tf_graph, keep_nodes, no_pass_through, prev_overlay=self)
for node in new_overlay.nodes:
node.sub_structure = self[node].sub_structure
return new_overlay

def gcd(self, branches):
"""Greatest common divisor (ie. input) of several nodes."""
@@ -181,3 +214,19 @@ def lcm(self, branches):
branch_nodes = [set([node]) | node.extended_consumers for node in branches]
branch_shared = set.intersection(*branch_nodes)
return min(branch_shared, key=lambda n: self.nodes.index(n))

def sorted(self, items):
return sorted(items, key=lambda n: self.nodes.index(self[n]))

def collapse_structures(self, structure_map):

keep_nodes = [node.name for node in self.nodes
if not any(node in structure.children for structure in structure_map.values())
or node in structure_map]

new_overlay = self.filter(keep_nodes)

for node in structure_map:
new_overlay[node].sub_structure = structure_map[node]

return new_overlay
@@ -0,0 +1,141 @@
from lucid.misc.graph_analysis.overlay_graph import OverlayGraph, OverlayNode, OverlayStructure

def collapse_sequences(overlay):
"""Detect and collapse sequences of nodes in an overlay."""
sequences = []
for node in overlay.nodes:
if any([node in seq for seq in sequences]): continue
seq = [node]
while len(node.consumers) == 1 and len(list(node.consumers)[0].inputs) == 1:
node = list(node.consumers)[0]
seq.append(node)
if len(seq) > 1:
sequences.append(seq)

structure_map = {}
for seq in sequences:
structure_map[seq[-1]] = OverlayStructure("Sequence", {"sequence": seq})

return overlay.collapse_structures(structure_map)


def collapse_branches(overlay):
"""Detect and collapse brances of nodes in an overlay."""
structure_map = {}

for node in overlay.nodes:
if len(node.inputs) <= 1: continue
gcd = node.gcd
if all(inp == gcd or inp.inputs == set([gcd]) for inp in node.inputs):
branches = [inp if inp != gcd else None
for inp in overlay.sorted(node.inputs)]
structure_map[node] = OverlayStructure("HeadBranch", {"branches" : branches, "head": node})

for node in overlay.nodes:
if len(node.consumers) <= 1: continue
if all(len(out.consumers) == 0 for out in node.consumers):
branches = overlay.sorted(node.consumers)
max_node = overlay.sorted(branches)[-1]
structure_map[max_node] = OverlayStructure("TailBranch", {"branches" : branches, "tail": node})

return overlay.collapse_structures(structure_map)


def parse_structure(node):
"""Turn a collapsed node in an OverlayGraph into a heirchaical grpah structure."""
if node is None:
return None

structure = node.sub_structure

if structure is None:
return node.name
elif structure.structure_type == "Sequence":
return {"Sequence" : [parse_structure(n) for n in structure.structure["sequence"]]}
elif structure.structure_type == "HeadBranch":
return {"Sequence" : [
{"Branch" : [parse_structure(n) for n in structure.structure["branches"]] },
parse_structure(structure.structure["head"])
]}
elif structure.structure_type == "TailBranch":
return {"Sequence" : [
parse_structure(structure.structure["tail"]),
{"Branch" : [parse_structure(n) for n in structure.structure["branches"]] },
]}
else:
data = {}
for k in structure.structure:
if isinstance(structure.structure[k], list):
data[k] = [parse_structure(n) for n in structure.structure[k]]
else:
data[k] = parse_structure(structure.structure[k])

return {structure.structure_type : data}


def flatten_sequences(structure):
"""Flatten nested sequences into a single sequence."""
if isinstance(structure, str) or structure is None:
return structure
else:
structure = structure.copy()
for k in structure:
structure[k] = [flatten_sequences(sub) for sub in structure[k]]

if "Sequence" in structure:
new_seq = []
for sub in structure["Sequence"]:
if isinstance(sub, dict) and "Sequence" in sub:
new_seq += sub["Sequence"]
else:
new_seq.append(sub)
structure["Sequence"] = new_seq
return structure


def parse_overlay(overlay):
new_overlay = overlay
prev_len = len(overlay.nodes)

collapsers = [collapse_sequences, collapse_branches]

while True:
new_overlay = collapse_branches(collapse_sequences(new_overlay))
if not len(new_overlay.nodes) < prev_len:
break
prev_len = len(new_overlay.nodes)

if len(new_overlay.nodes) != 1: return None


return flatten_sequences(parse_structure(new_overlay.nodes[0]))


def _namify(arr):
return [x.name for x in arr]

def toplevel_group_data(overlay):
pres = overlay.nodes[-1]
tops = [pres]
while pres.inputs:
pres = pres.gcd
tops.append(pres)
tops = tops[::-1]

groups = {}

for top in tops:
if top.op in ["Concat", "ConcatV2"]:
groups[top.name] = {
"immediate" : _namify(overlay.sorted(top.inputs)),
"all" : _namify(overlay.sorted(top.extended_inputs - top.gcd.extended_inputs - set([top.gcd]) | set([top]))),
"direction" : "backward"
}
if len(top.consumers) > 1 and all(out.op == "Split" for out in top.consumers):
groups[top.name] = {
"immediate" : _namify(overlay.sorted(top.consumers)),
"all" : _namify(overlay.sorted(top.extended_consumers - top.lcm.extended_consumers - set([top.lcm]) | set([top]))),
"direction" : "forward"
}

return groups

0 comments on commit ec6d027

Please sign in to comment.
You can’t perform that action at this time.