# Introduction to graphs and tf.function 

In [1]:
import tensorflow as tf
import timeit
from datetime import datetime

2024-11-19 14:16:30.691221: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Taking advantage of graphs

In [7]:
# Define a Python function
def a_regular_function(x, y, b):
    x = tf.matmul(x, y)
    x = x + b
    return x

# The Python type of `a_function_that_uses_a_graph` will now be a
# `PolymorphicFunction`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()

# Call a `tf.function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

In [8]:
def inner_function(x, y, b):
    x = tf.matmul(x, y)
    x = x + b
    return x

# Using the `tf.function` decorator makes `outer_function` into a
# `PolymorphicFunction`.
@tf.function
def outer_function(x):
    y = tf.constant([[2.0], [3.0]])
    b = tf.constant(4.0)
    return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()

array([[12.]], dtype=float32)

## Converting Python functions to graphs

In [10]:
def simple_relu(x):
    if tf.greater(x, 0):
        return x
    else:
        return 0

# Using `tf.function` makes `tf_simple_relu` a `PolymorphicFunction` that wraps
# `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())

First branch, with graph: 1
Second branch, with graph: 0


In [11]:
# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))

def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            do_return, retval_ = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body

In [12]:
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())

node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_63"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "else_branch"
    

## Polymorphism: one `tf.function`, many graphs

In [13]:
@tf.function
def my_relu(x):
    return tf.maximum(0.0, x)

# `my_relu` creates new graphs as it observes different input types.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))

tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)


In [14]:
# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Input type matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Input type matches `tf.constant([3., -3.])`.

tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)


In [15]:
print(my_relu.pretty_printed_concrete_signatures())

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): List[Literal[1], Literal[-1]]
Output Type:
  TensorSpec(shape=(2,), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(2,), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(2,), dtype=tf.float32, name=None)
Captures:
  None


## Using `tf.function`

### Graph execution vs. eager execution

In [20]:
@tf.function
def get_MSE(y_true, y_pred):
    sq_diff = tf.pow(y_true - y_pred, 2)
    return tf.reduce_mean(sq_diff)

In [21]:
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)

tf.Tensor([6 3 4 4 7], shape=(5,), dtype=int32)
tf.Tensor([6 1 3 1 5], shape=(5,), dtype=int32)


In [22]:
get_MSE(y_true, y_pred)

<tf.Tensor: shape=(), dtype=int32, numpy=3>

In [23]:
tf.config.run_functions_eagerly(True)

In [24]:
get_MSE(y_true, y_pred)

<tf.Tensor: shape=(), dtype=int32, numpy=3>

In [25]:
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

In [26]:
@tf.function
def get_MSE(y_true, y_pred):
    print("Calculating MSE!")
    sq_diff = tf.pow(y_true - y_pred, 2)
    return tf.reduce_mean(sq_diff)

In [27]:
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)

Calculating MSE!


In [28]:
# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)

In [29]:
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)

Calculating MSE!
Calculating MSE!
Calculating MSE!


In [30]:
tf.config.run_functions_eagerly(False)

In [31]:
def unused_return_eager(x):
    # Get index 1 will fail when `len(x) == 1`
    tf.gather(x, [1]) # unused 
    return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')

InvalidArgumentError: {{function_node __wrapped__GatherV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} indices[0] = 1 is not in [0, 1) [Op:GatherV2]


2024-11-19 14:39:09.629002: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: indices[0] = 1 is not in [0, 1)


In [32]:
@tf.function
def unused_return_graph(x):
    tf.gather(x, [1]) # unused
    return x

# Only needed operations are run during graph execution. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))

tf.Tensor([0.], shape=(1,), dtype=float32)


## Seeing the speed-up

In [35]:
x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)
x

<tf.Tensor: shape=(10, 10), dtype=int32, numpy=
array([[-1,  1,  1,  0,  1,  0,  1,  0,  1, -1],
       [-1, -1,  1, -1, -1, -1,  1,  1, -1,  1],
       [ 1, -1,  1,  1,  1, -1, -1, -1,  0,  0],
       [ 1, -1,  1,  0,  1,  1,  0,  1,  1, -1],
       [-1, -1, -1,  1,  0,  0,  0, -1,  0,  1],
       [-1,  0,  0,  1,  0,  0, -1,  1,  1, -1],
       [ 0,  0, -1,  0, -1,  0, -1,  0, -1, -1],
       [ 1, -1,  1,  1, -1,  0,  1, -1, -1,  1],
       [ 0, -1, -1,  1,  0,  1,  0,  0,  1, -1],
       [ 0,  0,  1,  1,  1,  1, -1, -1,  0,  1]], dtype=int32)>

In [37]:
def power(x, y):
    result = tf.eye(10, dtype=tf.dtypes.int32)
    for _ in range(y):
        result = tf.matmul(x, result)
    return result

In [38]:
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000), "seconds")

Eager execution: 5.599606186005985 seconds


In [39]:
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000), "seconds")

Graph execution: 0.5835083819983993 seconds


## When is a `tf.function` tracing?

In [40]:
@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)


In [41]:
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)
