TODO(b/138297412): This colab retains some useful code snippets and demonstrations that used to be in the tf.function/AutoGraph customization tutorial, and should be rolled into the existing docs as part of a broader markdown->colab conversion.

In [0]:
import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

In [0]:
import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class:
    # The expected error was raised: report it with a short traceback.
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

## Using AutoGraph

[AutoGraph](https://www.tensorflow.org/guide/function) is fully integrated with `tf.function`. It rewrites conditionals and loops that depend on tensors so that they run dynamically in the graph.

`tf.cond` and `tf.while_loop` continue to work with `tf.function`, but code with control flow is often easier to write and understand when written in imperative style.
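
As a minimal sketch of that difference, the two functions below compute an absolute value, once with a plain Python `if` that AutoGraph converts and once with an explicit `tf.cond`. The function names are hypothetical and chosen only for illustration.

```python
import tensorflow as tf

@tf.function
def abs_imperative(x):
  # Plain Python `if`; AutoGraph converts it because `x < 0` is a tensor.
  if x < 0:
    x = -x
  return x

@tf.function
def abs_functional(x):
  # The same logic written directly against `tf.cond`.
  return tf.cond(x < 0, lambda: -x, lambda: x)

imperative_result = abs_imperative(tf.constant(-3.0))
functional_result = abs_functional(tf.constant(-3.0))
print(imperative_result.numpy(), functional_result.numpy())
```

Both versions should produce the same value; the imperative one simply reads like ordinary Python.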

## AutoGraph: Conditionals

AutoGraph will convert `if` statements into the equivalent `tf.cond` calls.

This substitution is made if the condition is a Tensor. Otherwise, the conditional is executed during tracing.

Here is a function that checks if the resulting graph uses `tf.cond`:

In [0]:
def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'cond' for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) executes normally.".format(
        f.__name__, ', '.join(map(str, args))))

  print("  result: ", f(*args).numpy())


Passing a Python `True` executes the conditional normally:

In [0]:
@tf.function
def dropout(x, training=True):
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

In [0]:
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)

But passing a tensor replaces the Python `if` with a `tf.cond`:

In [0]:
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))

`tf.cond` has a number of subtleties.

It works by tracing both sides of the conditional and then choosing the appropriate branch at runtime, depending on the condition. Tracing both sides can result in unexpected execution of Python code.

In [0]:
@tf.function
def f(x):
  if x > 0:
    x = x + 1.
    print("Tracing `then` branch")
  else:
    x = x - 1.
    print("Tracing `else` branch")
  return x

In [0]:
f(-1.0).numpy()

In [0]:
f(1.0).numpy()

In [0]:
f(tf.constant(1.0)).numpy()
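
To make the tracing behavior easier to inspect than with `print`, here is a small sketch that records which branches run Python code. The list `traced_branches` and the function `shift` are hypothetical names used only for this illustration.

```python
import tensorflow as tf

traced_branches = []  # Filled in by Python code that runs during tracing.

@tf.function
def shift(x):
  if x > 0:
    traced_branches.append("then")
    x = x + 1.
  else:
    traced_branches.append("else")
    x = x - 1.
  return x

positive = shift(tf.constant(1.0))
negative = shift(tf.constant(-1.0))  # Same input signature, so the trace is reused.
print(sorted(traced_branches))
print(positive.numpy(), negative.numpy())
```

With a tensor argument, both branches are recorded during the single trace, even though each call only executes one branch of the resulting graph.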

`tf.cond` also requires that if one branch creates a tensor used downstream, the other branch must create that tensor as well.

In [0]:
@tf.function
def f():
  if tf.constant(True):
    x = tf.ones([3, 3])
  return x

# Throws an error because both branches need to define `x`.
with assert_raises(ValueError):
  f()
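
One possible fix, sketched here under the assumption that a zero tensor is an acceptable default, is to define `x` before the conditional so that it exists on both paths. The name `ones_if` is hypothetical.

```python
import tensorflow as tf

@tf.function
def ones_if(flag):
  # Define `x` before the conditional so it exists on both paths.
  x = tf.zeros([3, 3])
  if flag:
    x = tf.ones([3, 3])
  return x

when_true = ones_if(tf.constant(True))
when_false = ones_if(tf.constant(False))
print(when_true.numpy())
print(when_false.numpy())
```

Adding an explicit `else` branch that assigns `x` would work equally well.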

If you want to be sure that a particular section of control flow is never converted by AutoGraph, explicitly convert the object to a Python type, so that an error is raised instead:

In [0]:
@tf.function
def f(x, y):
  if bool(x):
    y = y + 1.
    print("Tracing `then` branch")
  else:
    y = y - 1.
    print("Tracing `else` branch")
  return y

In [0]:
f(True, 0).numpy()

In [0]:
f(False, 0).numpy()

In [0]:
with assert_raises(TypeError):
  f(tf.constant(True), 0.0)

## AutoGraph: Loops

AutoGraph has a few simple rules for converting loops.

- `for`: Convert if the iterable is a tensor
- `while`: Convert if the while condition depends on a tensor

If a loop is converted, it will be dynamically unrolled with `tf.while_loop`, or in the special case of a `for x in tf.data.Dataset`, transformed into `tf.data.Dataset.reduce`.

If a loop is _not_ converted, it will be statically unrolled: the Python loop runs during tracing, and every iteration adds its own ops to the graph.

In [0]:
def test_dynamically_unrolled(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'while' for node in g.as_graph_def().node):
    print("{}({}) uses tf.while_loop.".format(
        f.__name__, ', '.join(map(str, args))))
  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
    print("{}({}) uses tf.data.Dataset.reduce.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) gets unrolled.".format(
        f.__name__, ', '.join(map(str, args))))
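
To see what static and dynamic unrolling mean for the graph itself, the following sketch compares graph sizes. The helpers `static_sum`, `dynamic_sum`, and `graph_nodes` are hypothetical names; the op names checked at the end are an assumption about how TensorFlow 2 labels its loop op.

```python
import tensorflow as tf

@tf.function
def static_sum(n):
  x = tf.constant(0)
  for i in range(n):  # `n` is a Python int, so this loop runs while tracing.
    x += i
  return x

@tf.function
def dynamic_sum(n):
  x = tf.constant(0)
  for i in tf.range(n):  # Tensor iterable, so AutoGraph emits a loop op.
    x += i
  return x

def graph_nodes(fn, n):
  # Nodes of the top-level graph traced for the Python argument `n`.
  return fn.get_concrete_function(n).graph.as_graph_def().node

static_small = len(graph_nodes(static_sum, 5))
static_large = len(graph_nodes(static_sum, 50))
dynamic_ops = {node.op for node in graph_nodes(dynamic_sum, 50)}
print(static_small, static_large)
print(dynamic_ops & {"While", "StatelessWhile"})
```

The statically unrolled graph grows with the number of iterations, while the dynamic version contains a loop op regardless of how many iterations run.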


### For loops

Here is a `tf.function` that demonstrates static unrolling:

In [0]:
@tf.function
def for_in_range():
  x = 0
  for i in range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_range)

In [0]:
@tf.function
def for_in_tfrange():
  x = tf.constant(0, dtype=tf.int32)
  for i in tf.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfrange)

In [0]:
@tf.function
def for_in_tfdataset():
  x = tf.constant(0, dtype=tf.int64)
  for i in tf.data.Dataset.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfdataset)
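
For comparison, the loop over the dataset can be written directly as a reduction. This is only a sketch of the equivalent explicit call, not the exact code AutoGraph generates; the lambda argument names are illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# Explicit form of the loop above: start from 0 and add each element.
total = dataset.reduce(
    tf.constant(0, dtype=tf.int64),
    lambda running_sum, element: running_sum + element)
print(total.numpy())
```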

In [0]:
@tf.function
def while_py_cond():
  x = 5
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_py_cond)

In [0]:
@tf.function
def while_tf_cond():
  x = tf.constant(5)
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_tf_cond)
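
For reference, a hand-written counterpart of the tensor-conditioned `while` loop might look like the sketch below. The function name `while_explicit` is hypothetical, and the code AutoGraph actually generates carries extra bookkeeping that is omitted here.

```python
import tensorflow as tf

@tf.function
def while_explicit():
  x = tf.constant(5)
  # `cond` decides whether to keep looping; `body` returns the updated loop variables.
  (x,) = tf.while_loop(
      cond=lambda v: v > 0,
      body=lambda v: (v - 1,),
      loop_vars=(x,))
  return x

explicit_result = while_explicit()
print(explicit_result.numpy())
```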

If you have a `break` or early `return` clause that depends on a tensor, the top-level condition or iterable should also be a tensor.

Compare the following examples:

In [0]:
@tf.function
def while_py_true_py_break(x):
  while True:  # py true
    if x == 0: # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_py_true_py_break, 5)

In [0]:
@tf.function
def buggy_while_py_true_tf_break(x):
  while True:   # py true
    if tf.equal(x, 0): # tf break
      break
    x -= 1
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)

In [0]:
@tf.function
def while_tf_true_tf_break(x):
  while tf.constant(True): # tf true
    if x == 0:  # tf break: x is a tensor inside the tf loop
      break
    x -= 1
  return x

test_dynamically_unrolled(while_tf_true_tf_break, 5)

In [0]:
@tf.function
def buggy_py_for_tf_break():
  x = 0
  for i in range(5):  # py for
    if tf.equal(i, 3): # tf break
      break
    x += i
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_py_for_tf_break)

In [0]:
@tf.function
def tf_for_py_break():
  x = 0
  for i in tf.range(5): # tf for
    if i == 3:  # py break
      break
    x += i
  return x

test_dynamically_unrolled(tf_for_py_break)

In order to accumulate results from a dynamically unrolled loop, you'll want to use `tf.TensorArray`.


In [0]:
batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
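
A smaller sketch of the same accumulation pattern, stripped of the RNN details, is shown below. The function name `squares` is hypothetical.

```python
import tensorflow as tf

@tf.function
def squares(n):
  results = tf.TensorArray(tf.int32, size=n)
  for i in tf.range(n):
    # `write` returns a new TensorArray; reassign it on every iteration.
    results = results.write(i, i * i)
  return results.stack()

squares_result = squares(tf.constant(4))
print(squares_result.numpy())
```

The key detail is reassigning the result of `write`, since the `tf.TensorArray` is threaded through the loop as a loop variable.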

### Gotchas

Like `tf.cond`, `tf.while_loop` comes with a number of subtleties.


#### Zero iterations

Because a loop can execute zero times, all tensors used downstream of the `tf.while_loop` must be initialized before the loop.

Here is an example of incorrect code:

In [0]:
@tf.function
def buggy_loop_var_uninitialized():
  for i in tf.range(3):
    x = i
  return x

with assert_raises(ValueError):
  buggy_loop_var_uninitialized()

And the correct version:

In [0]:
@tf.function
def f():
  x = tf.constant(0)
  for i in tf.range(3):
    x = i
  return x

f()
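
The initial value matters because it is what comes back when the loop body never runs. A short sketch, using a hypothetical function `last_index` and an arbitrary sentinel of `-1`:

```python
import tensorflow as tf

@tf.function
def last_index(n):
  x = tf.constant(-1)  # Value returned if the loop body never runs.
  for i in tf.range(n):
    x = i
  return x

no_iterations = last_index(tf.constant(0))
three_iterations = last_index(tf.constant(3))
print(no_iterations.numpy(), three_iterations.numpy())
```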

#### Consistent shapes and types

The shapes and dtypes of all loop variables must stay consistent across iterations.

Here is an incorrect example that attempts to change a tensor's type:

In [0]:
@tf.function
def buggy_loop_type_changes():
  x = tf.constant(0, dtype=tf.float32)
  for i in tf.range(3): # Yields tensors of type tf.int32...
    x = i
  return x

with assert_raises(TypeError):
  buggy_loop_type_changes()
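
One way to keep the dtype stable, assuming a `tf.float32` result is what is wanted, is to cast the loop index before assigning it. The function name `loop_with_cast` is hypothetical.

```python
import tensorflow as tf

@tf.function
def loop_with_cast():
  x = tf.constant(0, dtype=tf.float32)
  for i in tf.range(3):
    # Cast the int32 index so `x` stays float32 on every iteration.
    x = tf.cast(i, tf.float32)
  return x

cast_result = loop_with_cast()
print(cast_result.numpy(), cast_result.dtype)
```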

Here is an incorrect example that attempts to change a tensor's shape while iterating:

In [0]:
@tf.function
def buggy_concat():
  x = tf.ones([0, 10])
  for i in tf.range(5):
    x = tf.concat([x, tf.ones([1, 10])], axis=0)
  return x

with assert_raises(ValueError):
  buggy_concat()

A version that works keeps `x` at a fixed `[5, 10]` shape: each iteration rebuilds it from the first `i` rows, a new row of ones, and zero padding.

In [0]:
@tf.function
def concat_with_padding():
  x = tf.zeros([5, 10])
  for i in tf.range(5):
    x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)
    x.set_shape([5, 10])
  return x

concat_with_padding()
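
If the tensor really needs to grow, another option is to relax the shape check with `tf.autograph.experimental.set_loop_options`. The sketch below assumes that only the first dimension changes; the function name `concat_with_shape_invariant` is hypothetical.

```python
import tensorflow as tf

@tf.function
def concat_with_shape_invariant():
  x = tf.ones([0, 10])
  for i in tf.range(5):
    # Declare that the first dimension of `x` may change between iterations.
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(x, tf.TensorShape([None, 10]))])
    x = tf.concat([x, tf.ones([1, 10])], axis=0)
  return x

grown = concat_with_shape_invariant()
print(grown.shape)
```

The trade-off is that the static shape of `x` inside the graph becomes partially unknown, whereas the padding approach keeps a fully defined shape.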
