##### Copyright 2020 The TensorFlow Authors.


In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Better performance with tf.function


In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

To get performant and portable models, use `tf.function` to make graphs out of your programs. However, there are pitfalls to be wary of: `tf.function` is not a magical make-it-faster bullet!

This document will help you conceptualize what `tf.function` is doing under the hood, so that you can master its use.

The main takeaways and recommendations are:

- Debug in Eager mode, then decorate with `@tf.function`.
- Don't rely on Python side effects like object mutation or list appends.
- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.
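As a minimal sketch of the last point, the hypothetical `add_noise` function below calls NumPy inside a `tf.function`. The NumPy call runs once, during tracing, and its result is embedded in the graph as a constant, so later calls that reuse the trace see the same "random" value.

```python
import numpy as np
import tensorflow as tf

@tf.function
def add_noise(x):
  # np.random.rand() is plain NumPy, so it only runs while tracing.
  # Its result is captured in the graph as a constant.
  noise = np.random.rand()
  return x + noise

first = add_noise(tf.constant(0.0))
second = add_noise(tf.constant(0.0))
print(first.numpy() == second.numpy())  # True: the trace is reused
```

If fresh randomness is needed on every call, a TensorFlow op such as `tf.random.uniform` is the usual replacement.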


## Setup

In [0]:
import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

In [0]:
import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

## Basics

A `tf.function` you define is just like a core TensorFlow operation: you can execute it eagerly, you can compute gradients, and so on.

In [0]:
@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

In [0]:
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)

You can use functions inside functions.

In [0]:
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

Functions can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.
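For the "many small ops" case, a rough timing sketch might look like the following. The function name `many_small_ops` and the loop count are illustrative assumptions, and the measured times depend on hardware and TensorFlow version, so no particular speedup is implied.

```python
import timeit
import tensorflow as tf

def many_small_ops(x):
  # 100 tiny additions; in eager mode each one is dispatched separately.
  for _ in range(100):
    x = x + 1.0
  return x

graph_fn = tf.function(many_small_ops)
x = tf.constant(0.0)
graph_fn(x)  # warm up: the first call pays the tracing cost

print("Eager:", timeit.timeit(lambda: many_small_ops(x), number=100))
print("Function:", timeit.timeit(lambda: graph_fn(x), number=100))
```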


In [0]:
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")


## Debugging

In general, debugging code is easier in Eager mode than inside a `tf.function`. You should ensure that your code executes error-free in Eager mode before decorating with `tf.function`. To assist in the debugging process, you can call `tf.config.run_functions_eagerly(True)` to globally disable `tf.function`, and `tf.config.run_functions_eagerly(False)` to reenable it.

When tracking down issues that only appear within `tf.function`, here are some tips:
- Plain old Python `print` calls only execute during tracing, helping you track down when your functions get (re)traced.
- `tf.print` calls will execute every time, and can help you track down intermediate values during execution.
- `tf.debugging.enable_check_numerics` is an easy way to track down where NaNs and Infs are created.
- `pdb` can help you understand what's going on during tracing. (Caveat: `pdb` will drop you into AutoGraph-transformed source code.)
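The sketch below illustrates the first tip together with the global switch. The list `trace_log` and the function `square` are hypothetical helpers, and the example assumes a TensorFlow release that provides `tf.config.run_functions_eagerly`.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: grows whenever the Python body runs

@tf.function
def square(x):
  trace_log.append("ran Python body")  # Python side effect
  return x * x

square(tf.constant(2.0))
square(tf.constant(3.0))  # same dtype and shape: the trace is reused
print(len(trace_log))  # 1

# With functions forced to run eagerly, the Python body runs on every call.
tf.config.run_functions_eagerly(True)
square(tf.constant(4.0))
tf.config.run_functions_eagerly(False)
print(len(trace_log))  # 2
```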

## Tracing and polymorphism

Python's dynamic typing means that you can call functions with a variety of argument types, and Python will do something different in each scenario.

On the other hand, TensorFlow graphs require static dtypes and shape dimensions. `tf.function` bridges this gap by retracing the function when necessary to generate the correct graphs. Most of the subtlety of `tf.function` usage stems from this retracing behavior.

You can call a function with arguments of different types to see what is happening.

In [0]:
# Functions are polymorphic

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()


To control the tracing behavior, you can use the following techniques:

Create a new `tf.function`. Separate `tf.function` objects are guaranteed not to share traces.

In [0]:
def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()

Use the `get_concrete_function` method to get a specific trace.


In [0]:
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))

Specify `input_signature` in `tf.function` to limit tracing.

In [0]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))


## When to retrace?

A polymorphic `tf.function` keeps a cache of concrete functions generated by tracing. The cache keys are effectively tuples of keys generated from the function args and kwargs. The key generated for a `tf.Tensor` argument is its shape and dtype. The key generated for a Python primitive is its value. For all other Python types, the keys are based on the object `id()` so that methods are traced independently for each instance of a class. In the future, TensorFlow may add more sophisticated caching for Python objects that can be safely converted to tensors.

See [Concrete functions](../../guide/concrete_function.ipynb) for more details.
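To make the caching rules concrete, here is a small sketch that counts traces. The function `twice` and the list `traced_with` are illustrative names, not part of the TensorFlow API.

```python
import tensorflow as tf

traced_with = []  # hypothetical helper: one entry per trace

@tf.function
def twice(x):
  traced_with.append(x)  # only runs while tracing
  return tf.add(x, x)

twice(tf.constant(1))       # new trace: int32 scalar
twice(tf.constant(2))       # same dtype and shape: cached trace
twice(tf.constant(1.0))     # new trace: float32 scalar
twice(tf.constant([1, 2]))  # new trace: int32 vector
twice(3)                    # new trace: Python value 3
twice(4)                    # new trace: Python value 4
twice(3)                    # same Python value: cached trace

print(len(traced_with))  # 5
```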


## Python or Tensor args?

Often, Python arguments are used to control hyperparameters and graph construction, for example, `num_layers=10` or `training=True` or `nonlinearity='relu'`. So if the Python argument changes, it makes sense that you'd have to retrace the graph.

However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so this is a bit inefficient.

In [0]:
def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)


The simple workaround here is to cast your arguments to Tensors if they do not affect the shape of the generated graph.

In [0]:
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

## Side effects in `tf.function`

In general, Python side effects (like printing or mutating objects) only happen during tracing. So how can you reliably trigger side effects from `tf.function`?

The general rule of thumb is to only use Python side effects to debug your traces. Otherwise, TensorFlow ops like `tf.Variable.assign`, `tf.print`, and `tf.summary` are the best way to ensure your code will be traced and executed by the TensorFlow runtime with each call. In general, using a functional style will yield the best results.
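As a small illustration of this rule, the sketch below counts calls in two ways: with a Python list and with a `tf.Variable`. The names `python_calls`, `tf_calls`, and `step` are hypothetical.

```python
import tensorflow as tf

python_calls = []            # Python state: updated only while tracing
tf_calls = tf.Variable(0)    # TensorFlow state: updated on every execution

@tf.function
def step():
  python_calls.append(1)     # Python side effect
  tf_calls.assign_add(1)     # TensorFlow op, part of the graph

for _ in range(3):
  step()

print(len(python_calls))  # 1
print(tf_calls.numpy())   # 3
```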

In [0]:
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)


If you would like to execute Python code during each invocation of a `tf.function`, `tf.py_function` is an escape hatch. The drawback of `tf.py_function` is that it's not portable or particularly performant, nor does it work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph for differentiability, it casts all inputs/outputs to tensors.

In [0]:
external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1


## Beware of Python state

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in Eager mode, many unexpected things can happen inside a `tf.function` due to tracing behavior.

To give one example, advancing iterator state is a Python side effect and therefore only happens during tracing.

In [0]:
external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
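One possible way to avoid this pitfall is to keep the iterator state in TensorFlow rather than in Python. The sketch below assumes a TensorFlow 2 release in which an iterator created from a `tf.data.Dataset` can be passed to a `tf.function`; the names `running_total` and `consume_next` are illustrative.

```python
import tensorflow as tf

# Dataset.range yields int64 values, so the variable uses the same dtype.
running_total = tf.Variable(0, dtype=tf.int64)

@tf.function
def consume_next(iterator):
  # next() on a tf.data iterator becomes a graph op, so the iterator
  # advances on every execution rather than only during tracing.
  return running_total.assign_add(next(iterator))

dataset_iterator = iter(tf.data.Dataset.range(4))
totals = [int(consume_next(dataset_iterator)) for _ in range(3)]
print(totals)
```

Each call should consume a fresh element, so the running total grows by the next dataset value instead of reusing the first one.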


## Variables

`tf.function` preserves the intended execution order of your code, which makes creating and using variables straightforward. There is one very important caveat, though: with variables, it's possible to write code that behaves differently in eager mode and graph mode.

Specifically, this will happen when you create a new Variable with each call. Due to tracing semantics, `tf.function` will reuse the same variable each call, but eager mode will create a new variable with each call. To guard against this mistake, `tf.function` will raise an error if it detects dangerous variable creation behavior.

In [0]:
@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)

Non-ambiguous code is ok, though.

In [0]:
v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0


You can also create variables inside a `tf.function` as long as TensorFlow can prove that those variables are created only the first time the function is executed.

In [0]:
class C:
  pass

obj = C()
obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0

Variable initializers can depend on function arguments and on values of other
variables. We can figure out the right initialization order using the same
method we use to generate control dependencies.

In [0]:
state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))


## AutoGraph Transformations

AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python Eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, and `while`.

TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python.

In [0]:
# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))

If you're curious, you can inspect the code AutoGraph generates.

In [0]:
print(tf.autograph.to_code(f.python_function))

### Conditionals

AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more.
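The difference between the two kinds of conditional can be sketched with one function that uses both. The function `scale` and its `use_relu` flag are hypothetical names chosen for illustration.

```python
import tensorflow as tf

@tf.function
def scale(x, use_relu):
  # use_relu is a Python bool, so this branch is resolved during tracing.
  if use_relu:
    x = tf.nn.relu(x)
  # The condition is a Tensor, so AutoGraph converts this to tf.cond
  # and the branch is selected at execution time.
  if tf.reduce_sum(x) > 0:
    x = x * 2.0
  return x

print(scale(tf.constant([-1.0, 3.0]), use_relu=True).numpy())    # values 0, 6
print(scale(tf.constant([-1.0, 3.0]), use_relu=False).numpy())   # values -2, 6
print(scale(tf.constant([-1.0, -3.0]), use_relu=False).numpy())  # values -1, -3
```

The second and third calls share one trace (same tensor shape and dtype, same Python value for `use_relu`), yet take different branches of the Tensor-valued condition.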

In [0]:
@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))

See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements) for additional restrictions on AutoGraph-converted if statements.

### Loops

AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow looping ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.

This substitution is made in the following situations:

- `for x in y`: if `y` is a Tensor, convert to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops is generated.
- `while <condition>`: if `<condition>` is a Tensor, convert to `tf.while_loop`.

A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time.  The loop body only appears once in the generated `tf.Graph`.

See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements) for additional restrictions on AutoGraph-converted `for` and `while` statements.
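One way to see the difference between the two kinds of loop is to inspect the op types in each traced graph. This is a rough sketch: `op_types` is a hypothetical helper, and the exact op names (for example `AddV2` or `While`) are implementation details that may vary between TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def python_loop(x):
  for _ in range(10):      # Python loop: unrolled during tracing
    x = x + 1
  return x

@tf.function
def tensor_loop(x):
  for _ in tf.range(10):   # Tensor loop: converted to a TensorFlow loop op
    x = x + 1
  return x

def op_types(fn, *args):
  # Hypothetical helper: list the op types in the traced graph.
  graph = fn.get_concrete_function(*args).graph
  return [op.type for op in graph.get_operations()]

x = tf.constant(0)
python_ops = op_types(python_loop, x)
tensor_ops = op_types(tensor_loop, x)

# The unrolled graph holds one addition per iteration, while the
# converted loop appears as a single while-style op.
num_adds = sum(t.startswith("Add") for t in python_ops)
has_while = any(t in ("While", "StatelessWhile") for t in tensor_ops)
print(num_adds, has_while)
```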

#### Looping over Python data

A common pitfall is to loop over Python/NumPy data within a `tf.function`. This loop will execute during the tracing process, adding a copy of your model to the `tf.Graph` for each iteration of the loop.

If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop.

In [0]:
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensors`. The former will keep the data in Python and fetch it via `tf.py_function`, which can have performance implications, whereas the latter will bundle a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.
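The two constructors also differ in how they split the data into elements, which the following sketch shows on a small NumPy array. The variable names are illustrative, and the positional `output_types` argument follows the style used earlier in this guide.

```python
import numpy as np
import tensorflow as tf

data = np.arange(6, dtype=np.int32).reshape(3, 2)

# from_generator iterates over the array in Python, one row per element.
generated = tf.data.Dataset.from_generator(lambda: data, tf.int32)

# from_tensors embeds the whole array as a single dataset element.
bundled = tf.data.Dataset.from_tensors(data)

print(len(list(generated)))  # 3
print(len(list(bundled)))    # 1
```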

Reading data from files via TFRecordDataset/CsvDataset/etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the [tf.data guide](../../guide/data).

#### Accumulating values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop.

In [0]:
batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))

## Further reading

To learn more about graph optimizations that are performed after tracing a `tf.function`, see the [Grappler guide](../../guide/graph_optimization). To learn how to optimize your data pipeline and profile your model, see the [Profiler guide](../../guide/profiler.md).