# Example of merging gradient nodes

## Init

In [5]:
import sys, os, math, random
#sys.path.append('/Users/yaroslav/openai.git/pixel-cnn-private')

os.environ["CUDA_VISIBLE_DEVICES"]=""

import tensorflow as tf
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

def create_session():
    """Create and return a `tf.InteractiveSession` for this notebook.

    The session config turns on device-placement logging and sets the
    graph optimizer level to `tf.OptimizerOptions.L0`.
    """
    optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)
    graph_options = tf.GraphOptions(optimizer_options=optimizer_options)
    session_config = tf.ConfigProto(log_device_placement=True,
                                    graph_options=graph_options)
    return tf.InteractiveSession(config=session_config)
    
import tensorflow.contrib.graph_editor as ge
from toposort import toposort

## Graph visualizer

In [6]:

# from https://github.com/yaroslavvb/notebooks/blob/master/simple_rewiring.ipynb
# Widen the notebook page: inject a CSS rule that forces the `.container`
# element to 100% width.
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# NOTE(review): `display` and `HTML` are imported a second time here, this
# time from IPython.display; the later import is the one in effect below.
from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def.

    Copies every node of `graph_def` into a new `tf.GraphDef`. For nodes
    whose op is 'Const' and whose serialized `tensor_content` is longer than
    `max_const_size` bytes, the content is replaced by a short placeholder
    of the form ``<stripped N bytes>``. The input `graph_def` is not modified.

    Args:
        graph_def: GraphDef-like object exposing an iterable `node` field.
        max_const_size: largest `tensor_content` length (in bytes) that is
            kept verbatim.

    Returns:
        A new `tf.GraphDef` with oversized constants replaced by placeholders.
    """
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                # tensor_content is a protobuf `bytes` field: on Python 3 it
                # rejects `str`, so the placeholder must be encoded first.
                placeholder = "<stripped %d bytes>" % size
                tensor.tensor_content = placeholder.encode("utf-8")
    return strip_def

def show_graph(graph_def=None, width=1200, height=800, max_const_size=32, ungroup_gradients=False):
    """Visualize TensorFlow graph.

    Renders the graph inline in the notebook by embedding the TensorBoard
    `tf-graph-basic` component inside an iframe.

    Args:
        graph_def: a GraphDef, or any object with an `as_graph_def()` method
            (it is converted first). When falsy, the current default graph
            is used.
        width: iframe width in pixels.
        height: iframe height in pixels. The inner graph container is fixed
            at 600px regardless of this value.
        max_const_size: forwarded to `strip_consts`; larger constants are
            replaced by a placeholder before rendering.
        ungroup_gradients: when True, every `"gradients/` name prefix in the
            text form of the graph is rewritten to `"b_`, so those nodes are
            no longer displayed under a single "gradients" name scope.
    """
    if not graph_def:
        graph_def = tf.get_default_graph().as_graph_def()

    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    data = str(strip_def)
    if ungroup_gradients:
        data = data.replace('"gradients/', '"b_')
    # Doubled braces are literal braces for str.format; a random element id
    # keeps several rendered graphs on one page from colliding.
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(data), id='graph'+str(np.random.rand()))

    # The markup is placed in the iframe's srcdoc attribute, so its double
    # quotes are replaced by &quot; entities.
    iframe = """
        <iframe seamless style="width:{}px;height:{}px;border:0" srcdoc="{}"></iframe>
    """.format(width, height, code.replace('"', '&quot;'))
    display(HTML(iframe))

## Create linear graph

In [53]:
# Clear the default graph before (re)building the chain.
tf.reset_default_graph()
# Each node is a float32 vector of node_mbs * 250000 elements
# (250000 * 4 bytes = 1 MB per unit of node_mbs).
node_mbs = 1
# Number of nodes in the chain a0 .. a{length-1}.
length = 4

dtype = np.float32
n = node_mbs * 250000
a0_ = tf.ones((n,), dtype=dtype)
a0 = tf.Variable(a0_, name="a0")
a = a0
# Build the linear chain a1 = tanh(a0), a2 = tanh(a1), ...; after the loop
# `a` refers to the last node of the chain.
for i in range(1, length):
    name = "a"+str(i)
    a = tf.tanh(a, name=name)

# Gradient of the last node with respect to a0; this adds the
# "gradients/..." ops to the graph.
grad = tf.gradients([a], [a0])[0]
# NOTE(review): the session is created here but nothing in the visible cells
# runs it or initializes a0.
sess = create_session()

In [54]:
# Render the default graph; ungroup_gradients=True rewrites the "gradients/"
# name prefix to "b_" in the displayed graph text.
show_graph(ungroup_gradients=True)

## Common functions

In [55]:
# Handle to the default graph; passed as graph=g to the graph-editor calls below.
g = tf.get_default_graph()

In [56]:
g.get_operations()

[<tf.Operation 'ones' type=Const>,
 <tf.Operation 'a0' type=VariableV2>,
 <tf.Operation 'a0/Assign' type=Assign>,
 <tf.Operation 'a0/read' type=Identity>,
 <tf.Operation 'a1' type=Tanh>,
 <tf.Operation 'a2' type=Tanh>,
 <tf.Operation 'a3' type=Tanh>,
 <tf.Operation 'gradients/Shape' type=Const>,
 <tf.Operation 'gradients/Const' type=Const>,
 <tf.Operation 'gradients/Fill' type=Fill>,
 <tf.Operation 'gradients/a3_grad/TanhGrad' type=TanhGrad>,
 <tf.Operation 'gradients/a2_grad/TanhGrad' type=TanhGrad>,
 <tf.Operation 'gradients/a1_grad/TanhGrad' type=TanhGrad>]

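Each operation has inputs (tensors flowing in) and outputs (tensors flowing out). All outputs of an operation are computed/allocated together, so the operation is the core unit of the graph. Various tools (like toposort, networkx) expect the graph in dictionary form; the get_graph() utility below converts the default graph to the form {node: children}. Note that toposort interprets the dictionary values as dependencies, so with the {node: children} form the sorted result lists consumers first, i.e. it is in reverse execution order.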

In [12]:
# computation flows from parents to children

def parents(op):
    """Return the set of operations that produce the input tensors of `op`."""
    producers = set()
    for tensor in op.inputs:
        producers.add(tensor.op)
    return producers
  
def children(op):
    """Return the set of operations that consume any output tensor of `op`."""
    consumers = set()
    for tensor in op.outputs:
        consumers.update(tensor.consumers())
    return consumers

def get_graph():
    """Build the dictionary form of the current default TensorFlow graph.

    Returns:
        A dict {op: set of child ops} with one entry per operation in the
        default graph, where the children of an op are given by `children`.
    """
    graph = {}
    for op in tf.get_default_graph().get_operations():
        graph[op] = children(op)
    return graph

In [57]:
# get_graph() maps each op to its children, so the sorted result starts with
# ops that have no consumers and ends with the a0 variable.
list(toposort(get_graph()))

[{<tf.Operation 'a0/Assign' type=Assign>,
  <tf.Operation 'gradients/a1_grad/TanhGrad' type=TanhGrad>},
 {<tf.Operation 'ones' type=Const>,
  <tf.Operation 'gradients/a2_grad/TanhGrad' type=TanhGrad>},
 {<tf.Operation 'gradients/a3_grad/TanhGrad' type=TanhGrad>},
 {<tf.Operation 'gradients/Fill' type=Fill>, <tf.Operation 'a3' type=Tanh>},
 {<tf.Operation 'gradients/Const' type=Const>,
  <tf.Operation 'a2' type=Tanh>,
  <tf.Operation 'gradients/Shape' type=Const>},
 {<tf.Operation 'a1' type=Tanh>},
 {<tf.Operation 'a0/read' type=Identity>},
 {<tf.Operation 'a0' type=VariableV2>}]

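Graph editor has some utilities to select ops using a regular expression. The pattern is matched against op names, not op types: "Tanh" below selects only the TanhGrad ops (whose names contain "Tanh") and not a1, a2, a3, even though their type is Tanh.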

In [58]:
ge.select_ops("Tanh", graph=g)

[<tf.Operation 'gradients/a3_grad/TanhGrad' type=TanhGrad>,
 <tf.Operation 'gradients/a2_grad/TanhGrad' type=TanhGrad>,
 <tf.Operation 'gradients/a1_grad/TanhGrad' type=TanhGrad>]

In [59]:
ge.select_ops("a1", graph=g)

[<tf.Operation 'a1' type=Tanh>,
 <tf.Operation 'gradients/a1_grad/TanhGrad' type=TanhGrad>]

Select backprop nodes for a2, a3 and fuse them together

In [60]:
# The regex alternation picks the ops whose names contain a3_grad or a2_grad.
ops_to_fuse=ge.select_ops("a3_grad|a2_grad", graph=g)

In [61]:
ops_to_fuse

[<tf.Operation 'gradients/a3_grad/TanhGrad' type=TanhGrad>,
 <tf.Operation 'gradients/a2_grad/TanhGrad' type=TanhGrad>]

Create the part of the graph that'll be rewired in place of those ops

In [62]:
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.framework import function
from tensorflow.python.ops import functional_ops

# Fused backward step for two stacked tanh layers.
#   val1: output of the first tanh layer (input to the second one)
#   bp:   gradient flowing in from above the second tanh layer
# The second layer's output is recomputed from val1 rather than taken as an
# input, then the two TanhGrad ops are applied back to back.
@function.Defun(tf.float32, tf.float32, func_name="tanh_grad2")
def tanh_grad2(val1, bp):
    recomputed = tf.tanh(val1)
    upper_grad = gen_math_ops._tanh_grad(recomputed, bp)
    return gen_math_ops._tanh_grad(val1, upper_grad)

In [64]:
list(ops_to_fuse[0].inputs)

[<tf.Tensor 'a3:0' shape=(250000,) dtype=float32>,
 <tf.Tensor 'gradients/Fill:0' shape=(250000,) dtype=float32>]

In [65]:
list(ops_to_fuse[1].inputs)

[<tf.Tensor 'a2:0' shape=(250000,) dtype=float32>,
 <tf.Tensor 'gradients/a3_grad/TanhGrad:0' shape=(250000,) dtype=float32>]

Create new node with correct inputs

In [66]:
# tanh_grad2 recomputes tanh(val1) internally, so val1 must be the earlier
# activation a2 (first input of the a2_grad op), not a3. bp is the incoming
# gradient (second input of the a3_grad op). The new node then computes
# TanhGrad(a2, TanhGrad(tanh(a2), bp)), which matches the two ops it replaces.
new_node = tanh_grad2(ops_to_fuse[1].inputs[0], ops_to_fuse[0].inputs[1])

Use reroute to connect outputs

In [51]:
ops_to_fuse[1].outputs

[<tf.Tensor 'gradients/a1_grad/TanhGrad:0' shape=(250000,) dtype=float32>]

In [68]:
# Per the graph_editor docs, reroute_a2b_ts(ts0, ts1) makes the consumers of
# ts1 read from ts0: ops that consumed the last fused op's output now consume
# new_node. The returned value is the number of rerouted tensors.
ge.reroute_a2b_ts(new_node, ops_to_fuse[1].outputs)

1

In [69]:
# Render the graph again to inspect the result of the rewiring above.
show_graph(ungroup_gradients=True)