##### Copyright 2020 The TensorFlow Authors.


In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Better performance with tf.function

<table class="tfo-notebook-buttons" align="left">
  <td>     <a target="_blank" href="https://www.tensorflow.org/guide/function"><img src="https://www.tensorflow.org/images/tf_logo_32px.png">View on TensorFlow.org</a>
</td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/function.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a>
  </td>
  <td>     <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/guide/function.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a>
</td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/function.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png">Download notebook</a>
  </td>
</table>

In TensorFlow 2, [eager execution](eager.ipynb) is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use `SavedModel`.

This guide will help you conceptualize how `tf.function` works under the hood, so you can use it effectively.

The main takeaways and recommendations are:

- Debug in eager mode, then decorate with `@tf.function`.
- Don't rely on Python side effects like object mutation or list appends.
- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.


## Setup

In [None]:
# Update TensorFlow, as this notebook requires version 2.9 or later
# The version specifier is quoted so the shell does not treat `>` as a redirect.
!pip install -q -U "tensorflow>=2.9.0"
import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

In [None]:
import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

## Basics

### Usage

A `Function` you define (for example, by applying the `@tf.function` decorator) is just like a core TensorFlow operation: you can execute it eagerly, you can compute gradients, and so on.

In [None]:
@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

In [None]:
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)

You can use `Function`s inside other `Function`s.

In [None]:
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

`Function`s can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.


In [None]:
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")


### Tracing

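This section exposes how `Function` works under the hood, including implementation details *which may change in the future*. However, once you understand why and when tracing happens, it is much easier to use `tf.function` effectively!

#### What is "tracing"?

A `Function` runs your program in a [TensorFlow Graph](https://www.tensorflow.org/guide/intro_to_graphs#what_are_graphs). However, a `tf.Graph` cannot represent everything that you would write in an eager TensorFlow program. For instance, Python supports polymorphism, but `tf.Graph` requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object. None of these things can run in a `tf.Graph`.

`Function` bridges this gap by separating your code into two stages:

1. In the first stage, referred to as "**tracing**", `Function` creates a new `tf.Graph`. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are *deferred*: they are captured by the `tf.Graph` and not run.

2. In the second stage, a `tf.Graph` which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.

Depending on its inputs, `Function` will not always run the first stage when it is called. See ["Rules of tracing"](#rules_of_tracing) below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.

When `Function` does decide to trace, the tracing stage is immediately followed by the second stage, so calling the `Function` both creates and runs the `tf.Graph`. Later you will see how you can run only the tracing stage with [`get_concrete_function`](#obtaining_concrete_functions).

When you pass arguments of different types into a `Function`, both stages are run: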


In [None]:
@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()


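Note that if you repeatedly call a `Function` with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical.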

In [None]:
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))

You can use `pretty_printed_concrete_signatures()` to see all of the available traces:

In [None]:
print(double.pretty_printed_concrete_signatures())

So far, you've seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:

- A `tf.Graph` is the raw, language-agnostic, portable representation of a TensorFlow computation.
- A `ConcreteFunction` wraps a `tf.Graph`.
- A `Function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.
- `tf.function` wraps a Python function, returning a `Function` object.
- **Tracing** creates a `tf.Graph` and wraps it in a `ConcreteFunction`, also known as a **trace**.
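
The relationship between these terms can be sketched in a few lines. This is a minimal illustration, not a description of TensorFlow internals; the names `triple` and `triple_fn` are made up for this example.

```python
import tensorflow as tf

def triple(x):
  # A plain Python function (hypothetical example).
  return 3 * x

# `tf.function` wraps the Python function and returns a `Function`.
triple_fn = tf.function(triple)

# Tracing for a float32 scalar produces a `ConcreteFunction`...
concrete = triple_fn.get_concrete_function(tf.TensorSpec([], tf.float32))

# ...which wraps the `tf.Graph` captured during tracing.
graph = concrete.graph
print("Wraps a tf.Graph:", isinstance(graph, tf.Graph))
print("Result:", concrete(tf.constant(2.0)))
```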


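#### Rules of tracing

When called, a `Function` matches the call arguments to existing `ConcreteFunction`s using the `tf.types.experimental.TraceType` of each argument. If a matching `ConcreteFunction` is found, the call is dispatched to it. If no match is found, a new `ConcreteFunction` is traced.

If multiple matches are found, the most specific signature is chosen. Matching is done by [subtyping](https://en.wikipedia.org/wiki/Subtyping), much like normal function calls in C++ or Java. For example, `TensorShape([1, 2])` is a subtype of `TensorShape([None, None])`, so a call to the tf.function with `TensorShape([1, 2])` can be dispatched to the `ConcreteFunction` produced with `TensorShape([None, None])`. However, if a `ConcreteFunction` with `TensorShape([1, None])` also exists, it will be prioritized since it is more specific.

The `TraceType` is determined from input arguments as follows:

- For `Tensor`, the type is parameterized by the `Tensor`'s `dtype` and `shape`; ranked shapes are a subtype of unranked shapes; fixed dimensions are a subtype of unknown dimensions.
- For `Variable`, the type is similar to `Tensor`, but also includes a unique resource ID of the variable, which is necessary to correctly wire control dependencies.
- For Python primitive values, the type corresponds to the **value** itself. For example, the `TraceType` of the value `3` is `LiteralTraceType<3>`, not `int`.
- For Python ordered containers such as `list` and `tuple`, the type is parameterized by the types of their elements; for example, the type of `[1, 2]` is `ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>` and the type of `[2, 1]` is `ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>`, which is different.
- For Python mappings such as `dict`, the type is also a mapping from the same keys, but to the types of the values instead of the actual values. For example, the type of `{1: 2, 3: 4}` is `MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>`. However, unlike ordered containers, `{1: 2, 3: 4}` and `{3: 4, 1: 2}` have equivalent types.
- For Python objects which implement the `__tf_tracing_type__` method, the type is whatever that method returns.
- For any other Python objects, the type is a generic `TraceType`, which uses the object's Python equality and hashing for matching. (Note: it relies on a [weakref](https://docs.python.org/3/library/weakref.html) to the object and hence only works as long as the object is in scope and not deleted.)


Note: `TraceType` is based on the `Function` input parameters, so changes to global and [free variables](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) alone will not create a new trace. See the "Depending on Python global and free variables" section of this guide for recommended practices when dealing with Python global and free variables.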
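
The difference between value-based types for Python primitives and dtype/shape-based types for tensors can be observed by recording when tracing happens. The following is a small sketch; `scale` and `trace_log` are hypothetical names, and the list append is used purely as a tracing probe (it is a Python side effect, so it only runs while tracing).

```python
import tensorflow as tf

trace_log = []  # Records the Python value of `factor` seen at each trace.

@tf.function
def scale(x, factor):
  trace_log.append(factor)  # Runs only during tracing.
  return x * factor

x = tf.constant([1.0, 2.0])
scale(x, 2)                          # New trace: first call.
scale(x, 2)                          # Same tensor type, same literal: reused.
scale(x, 3)                          # Different literal value: new trace.
scale(tf.constant([5.0, 6.0]), 3)    # Same dtype and shape, same literal: reused.

print("Traced for factors:", trace_log)
```

Only two traces are created, because the tensor arguments are matched by dtype and shape while the Python integer is matched by value.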

### Controlling retracing

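Retracing, which is when your `Function` creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your `Function` retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use `tf.function`.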

To control the tracing behavior, you can use the following techniques:

#### Pass a fixed `input_signature` to `tf.function`

In [None]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))


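#### Use unknown dimensions for flexibility

Since TensorFlow matches tensors based on their shape, using a `None` dimension as a wildcard allows `Function`s to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different lengths, or images of different sizes for each batch (see the [Transformer](../tutorials/text/transformer.ipynb) and [Deep Dream](../tutorials/generative/deepdream.ipynb) tutorials for examples).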

In [None]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))


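#### Pass tensors instead of Python literals

Often, Python arguments are used to control hyperparameters and graph construction, for example `num_layers=10`, `training=True`, or `nonlinearity='relu'`. So, if the Python argument changes, it makes sense that you would have to retrace the graph.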

However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.

In [None]:
def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

If you need to force retracing, create a new `Function`. Separate `Function` objects are guaranteed not to share traces.

In [None]:
def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()

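#### Use the tracing protocol

Where possible, you should prefer converting the Python type into a `tf.experimental.ExtensionType` instead. Moreover, the `TraceType` of an `ExtensionType` is the `tf.TypeSpec` associated with it. Therefore, if needed, you can simply override the default `tf.TypeSpec` to take control of an `ExtensionType`'s `Tracing Protocol`. Refer to the *Customizing the ExtensionType's TypeSpec* section in the [Extension types](extension_type.ipynb) guide for details.

Otherwise, for direct control over when `Function` should retrace in regards to a particular Python type, you can implement the `Tracing Protocol` for it yourself.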

In [None]:
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
  return fruit_a.flavor + fruit_b.flavor

class Fruit:
  flavor = tf.constant([0, 0])

class Apple(Fruit):
  flavor = tf.constant([1, 2])

class Mango(Fruit):
  flavor = tf.constant([3, 4])

# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced), but it fails to
# match the second function call since the first pair of Apple() and Mango()
# has gone out of scope by then and been deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again

# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.

class FruitTraceType(tf.types.experimental.TraceType):
  def __init__(self, fruit_type):
    self.fruit_type = fruit_type

  def is_subtype_of(self, other):
    return (type(other) is FruitTraceType and
            self.fruit_type is other.fruit_type)

  def most_specific_common_supertype(self, others):
    return self if all(self == other for other in others) else None

  def __eq__(self, other):
    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type
  
  def __hash__(self):
    return hash(self.fruit_type)

class FruitWithTraceType:

  def __tf_tracing_type__(self, context):
    return FruitTraceType(type(self))

class AppleWithTraceType(FruitWithTraceType):
  flavor = tf.constant([1, 2])

class MangoWithTraceType(FruitWithTraceType):
  flavor = tf.constant([3, 4])

# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function

### Obtaining concrete functions

Every time a function is traced, a new concrete function is created. You can directly obtain a concrete function by using `get_concrete_function`.


In [None]:
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))


In [None]:
# You can also call get_concrete_function on a tf.TensorSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))

Printing a `ConcreteFunction` displays a summary of its input arguments (with types) and its output type.

In [None]:
print(double_strings)

You can also directly retrieve a concrete function's signature.

In [None]:
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)

Using a concrete trace with incompatible types will throw an error:

In [None]:
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))

You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.

In [None]:
@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)

In [None]:
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)

### Obtaining graphs

Each concrete function is a callable wrapper around a `tf.Graph`. Although retrieving the actual `tf.Graph` object is not something you'll normally need to do, you can obtain it easily from any concrete function.

In [None]:
graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')


### Debugging

In general, debugging code is easier in eager mode than inside `tf.function`. You should ensure that your code executes error-free in eager mode before decorating with `tf.function`. To assist in the debugging process, you can call `tf.config.run_functions_eagerly(True)` to globally disable `tf.function`, and `tf.config.run_functions_eagerly(False)` to re-enable it.

When tracking down issues that only appear within `tf.function`, here are some tips:

- Plain old Python `print` calls only execute during tracing, helping you track down when your function gets (re)traced.
- `tf.print` calls will execute every time, and can help you track down intermediate values during execution.
- `tf.debugging.enable_check_numerics` is an easy way to track down where NaNs and Inf are created.
- `pdb` (the [Python debugger](https://docs.python.org/3/library/pdb.html)) can help you understand what is going on during tracing. (Caveat: `pdb` will drop you into AutoGraph-transformed source code.)
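
As a small sketch of the global switch mentioned above, the snippet below counts how often the Python body of a `Function` runs with and without `tf.config.run_functions_eagerly`. The names `add_one` and `python_runs` are hypothetical and exist only for this illustration.

```python
import tensorflow as tf

python_runs = []  # One entry per execution of the Python body.

@tf.function
def add_one(x):
  python_runs.append("ran")  # Python side effect, used here as a probe.
  return x + 1

# As a graph: the Python body runs once (during tracing) for two calls.
add_one(tf.constant(1))
add_one(tf.constant(2))
runs_as_graph = len(python_runs)

# With tf.function globally disabled, the body runs on every call,
# so regular Python debugging tools see each invocation.
tf.config.run_functions_eagerly(True)
try:
  add_one(tf.constant(3))
  add_one(tf.constant(4))
finally:
  tf.config.run_functions_eagerly(False)  # Restore graph execution.
runs_total = len(python_runs)

print("Python body runs as a graph:", runs_as_graph)
print("Python body runs in total:", runs_total)
```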

## AutoGraph transformations

AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, `while`.

TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python.

In [None]:
# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))

If you're curious, you can inspect the code AutoGraph generates.

In [None]:
print(tf.autograph.to_code(f.python_function))

### Conditionals

AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

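`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; check out [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more information.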

In [None]:
@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))

See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements) for additional restrictions on AutoGraph-converted if statements.
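
To make the distinction between a Python conditional and a Tensor conditional concrete, the sketch below traces the same function once with a Python `bool` and once with a boolean tensor, then checks whether a conditional op ended up in the graph. The helper `has_cond_op` and the function `clip_negative` are hypothetical, and the op type names `If` and `StatelessIf` are an assumption about how current TensorFlow 2 releases represent `tf.cond` in a function graph.

```python
import tensorflow as tf

@tf.function
def clip_negative(x, enabled):
  # Python bool: resolved while tracing. Tensor: converted to tf.cond.
  if enabled:
    x = tf.maximum(x, 0.0)
  return x

def has_cond_op(concrete_fn):
  # Assumption: tf.cond shows up as an `If` or `StatelessIf` op.
  return any(op.type in ("If", "StatelessIf")
             for op in concrete_fn.graph.get_operations())

scalar = tf.TensorSpec([], tf.float32)
python_cond = clip_negative.get_concrete_function(scalar, True)
tensor_cond = clip_negative.get_concrete_function(
    scalar, tf.TensorSpec([], tf.bool))

print("Python bool condition has a cond op:", has_cond_op(python_cond))
print("Tensor condition has a cond op:", has_cond_op(tensor_cond))
print(clip_negative(tf.constant(-1.0), tf.constant(False)))
print(clip_negative(tf.constant(-1.0), tf.constant(True)))
```

With a Python `bool`, only the taken branch is recorded, so a different value needs a new trace. With a tensor, a single trace handles both values at execution time.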

### Loops

AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow looping ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.

This substitution is made in the following situations:

- `for x in y`: if `y` is a Tensor, convert to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops are generated.
- `while <condition>`: if `<condition>` is a Tensor, convert to `tf.while_loop`.

A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time.  The loop body only appears once in the generated `tf.Graph`.

See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements) for additional restrictions on AutoGraph-converted `for` and `while` statements.
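
The effect of the two kinds of loops on the graph can be checked by listing the op types in each trace. This is a sketch under stated assumptions: `python_loop`, `tensor_loop`, and `op_types` are hypothetical names, and the op type names (`AddV2` for addition, `While` or `StatelessWhile` for `tf.while_loop`) reflect TensorFlow 2 graphs as assumed here rather than a documented guarantee.

```python
import tensorflow as tf

@tf.function
def python_loop(x):
  for _ in range(10):      # Python loop: unrolled during tracing.
    x = x + 1.0
  return x

@tf.function
def tensor_loop(x):
  for _ in tf.range(10):   # Tensor loop: converted to a TensorFlow loop.
    x = x + 1.0
  return x

def op_types(fn):
  # Op types in the top-level graph of the float32-scalar trace.
  graph = fn.get_concrete_function(tf.TensorSpec([], tf.float32)).graph
  return [op.type for op in graph.get_operations()]

WHILE_OPS = ("While", "StatelessWhile")  # Assumed op type names.
python_ops = op_types(python_loop)
tensor_ops = op_types(tensor_loop)

print("Additions in the unrolled graph:", python_ops.count("AddV2"))
print("Loop op in the unrolled graph:", any(t in WHILE_OPS for t in python_ops))
print("Loop op in the tensor-loop graph:", any(t in WHILE_OPS for t in tensor_ops))
```

Both functions compute the same value; they differ only in whether the iterations are copied into the graph or expressed as a single loop op.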

#### Looping over Python data

A common pitfall is to loop over Python/NumPy data within a `tf.function`. This loop will execute during the tracing process, adding a copy of your model to the `tf.Graph` for each iteration of the loop.

If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop.

In [None]:
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensors`. The former will keep the data in Python and fetch it via `tf.py_function`, which can have performance implications, whereas the latter will bundle a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.
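
For small in-memory data, a minimal sketch of the tensor-backed approach might look as follows. It uses `tf.data.Dataset.from_tensor_slices`, which, like `from_tensors`, embeds the data in the graph but yields one element per row. The function name `total_abs_diff` and the sample values are made up for this example.

```python
import tensorflow as tf

@tf.function
def total_abs_diff(dataset):
  loss = tf.constant(0)
  for x, y in dataset:       # Iterating a Dataset stays inside the graph.
    loss += tf.abs(y - x)    # Some dummy computation.
  return loss

xs = [1, 2, 3]
ys = [2, 4, 6]
# The data is copied into the graph as constants, so this is only
# appropriate when it comfortably fits in memory.
dataset = tf.data.Dataset.from_tensor_slices((xs, ys))

total = total_abs_diff(dataset)
print(total)
```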

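Reading data from files via `TFRecordDataset`, `CsvDataset`, etc. is the most effective way to consume data, because TensorFlow itself can then manage the asynchronous loading and prefetching of data without having to involve Python. To learn more, see the [`tf.data`: Build TensorFlow input pipelines](../../guide/data) guide.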

#### Accumulating values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop.

In [None]:
batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))

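## Limitations

TensorFlow `Function` has a few limitations by design that you should be aware of when converting a Python function to a `Function`.

### Executing Python side effects

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a `Function`, sometimes executing twice or not at all. They only happen the first time you call a `Function` with a set of inputs. Afterwards, the traced `tf.Graph` is reexecuted, without executing the Python code.

The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like `tf.data`, `tf.print`, `tf.summary`, `tf.Variable.assign`, and `tf.TensorArray` are the best way to ensure your code will be executed by the TensorFlow runtime with each call.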

In [None]:
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)


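If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawbacks of `tf.py_function` are that it is not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors.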
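
A minimal sketch of this exit hatch is shown below. The names `record`, `py_calls`, and `log_and_increment` are hypothetical; the point is only that the Python function wrapped by `tf.py_function` runs on every call, not just during tracing.

```python
import tensorflow as tf

py_calls = []  # Filled by plain Python code on every invocation.

def record(x):
  # Inside tf.py_function, `x` is an eager tensor, so .numpy() works.
  py_calls.append(int(x.numpy()))
  return x

@tf.function
def log_and_increment(x):
  # Inputs and outputs are tensors; `Tout` declares the output dtype.
  y = tf.py_function(record, inp=[x], Tout=tf.int32)
  return y + 1

first = log_and_increment(tf.constant(1))
second = log_and_increment(tf.constant(2))  # Reuses the trace.
print("Python saw:", py_calls)
print("Results:", int(first), int(second))
```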

#### Changing Python global and free variables

Changing Python global and [free variables](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) counts as a Python side effect, so it only happens during tracing.


In [None]:
external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1

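Sometimes unexpected behaviors are very hard to notice. In the example below, the `counter` is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the `tf.function` is used, the `assign_add` is recorded unconditionally in the underlying graph. Therefore, `v` increases by 1 every time the `tf.function` is called. This issue is common among users who try to migrate their graph-mode TensorFlow code to TensorFlow 2 using `tf.function` decorators, when Python side effects (the `counter` in the example) are used to determine which ops to run (`assign_add` in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (for example, if the guarded operation is very costly).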

In [None]:
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3

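A workaround to achieve the expected behavior is to use [`tf.init_scope`](https://www.tensorflow.org/api_docs/python/tf/init_scope) to lift the operations outside of the function graph. This ensures that the variable increment is only done once, during tracing time. Note that `init_scope` has other side effects, including cleared control flow and gradient tape. Sometimes the usage of `init_scope` can become too complex to manage realistically.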

In [None]:
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1

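In summary, as a rule of thumb, you should avoid mutating Python objects such as integers, or containers like lists, that live outside the `Function`. Instead, use arguments and TF objects. For example, the section ["Accumulating values in a loop"](#accumulating_values_in_a_loop) has one example of how list-like operations can be implemented.

You can, in some cases, capture and manipulate state if it is a [`tf.Variable`](https://www.tensorflow.org/guide/variable). This is how the weights of Keras models are updated with repeated calls to the same `ConcreteFunction`.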
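
As one possible way to apply this advice to the guarded increment discussed earlier, the guard itself can be kept in a `tf.Variable`, so that the check happens at execution time rather than at tracing time. This is a sketch, not the only solution; the attribute name `incremented` is hypothetical.

```python
import tensorflow as tf

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    # Hypothetical guard flag, stored as TensorFlow state instead of
    # a Python integer.
    self.incremented = tf.Variable(False)

  @tf.function
  def __call__(self):
    # The condition is a Tensor, so AutoGraph converts this `if`
    # into a graph conditional that is evaluated on every call.
    if tf.logical_not(self.incremented):
      self.v.assign_add(1)
      self.incremented.assign(True)
    return self.v.read_value()

m = Model()
results = [int(m()) for _ in range(3)]
print(results)
```

Because both the guard and the counter are TensorFlow state, the increment happens exactly once even though the same trace is executed three times.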

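#### Using Python iterators and generators

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing.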

In [None]:
@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)


Just like how TensorFlow has a specialized `tf.TensorArray` for list constructs, it has a specialized `tf.data.Iterator` for iteration constructs. See the section on [AutoGraph transformations](#autograph_transformations) for an overview. Also, the [`tf.data`](https://www.tensorflow.org/guide/data) API can help implement generator patterns:


In [None]:
@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)

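### All outputs of a tf.function must be return values

With the exception of `tf.Variable`s, a tf.function must return all its outputs. Attempting to directly access any tensors from a function without going through return values causes "leaks".

For example, the function below "leaks" the tensor `a` through the Python global `x`: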

In [None]:
x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

This is true even if the leaked value is also returned:

In [None]:
@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))

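Usually, leaks such as these occur when you use Python statements or data structures. In addition to leaking inaccessible tensors, such statements are also likely wrong because they count as Python side effects, and are not guaranteed to execute at every function call.

Common ways to leak local tensors also include mutating an external Python collection, or an object: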

In [None]:
class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor
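
By contrast, a non-leaking version passes the tensor out through the return value and lets the caller store it. The snippet below is a small sketch; `returns_tensor` and `collected` are hypothetical names.

```python
import tensorflow as tf

collected = []

@tf.function
def returns_tensor():
  a = tf.constant(1)
  return a  # Good - the tensor leaves the function as a return value.

# The caller, running eagerly, decides where to store the result.
collected.append(returns_tensor())
print(collected[0].numpy())  # Good - a concrete value is available here.
```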

### Recursive tf.functions are not supported

Recursive `Function`s are not supported and could cause infinite loops. For example:

In [None]:
@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.

Even if a recursive `Function` seems to work, the Python function will be traced multiple times, which could have performance implications. For example:

In [None]:
@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
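
Where the recursion can be rewritten as iteration, a Tensor-conditioned `while` loop avoids both problems. The sketch below counts down from `n` and is not equivalent to `recursive_fn` above; `countdown_steps` and `trace_log` are hypothetical names used only for illustration.

```python
import tensorflow as tf

trace_log = []  # Python side effect used as a tracing probe.

@tf.function
def countdown_steps(n):
  trace_log.append("traced")
  steps = tf.constant(0)
  while n > 0:   # Tensor condition: converted to a TensorFlow loop.
    n -= 1
    steps += 1
  return steps

five = countdown_steps(tf.constant(5))
seven = countdown_steps(tf.constant(7))  # Same input type: no retrace.
print(int(five), int(seven), "traces:", len(trace_log))
```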

## Known Issues

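If your `Function` is not evaluating correctly, the error may be explained by these known issues, which are planned to be fixed in the future.

### Depending on Python global and free variables

`Function` creates a new `ConcreteFunction` when called with a new value of a Python argument. However, it does not do that for the Python closure, globals, or nonlocals of that `Function`. If their value changes in between calls to the `Function`, the `Function` will still use the values they had when it was traced. This is different from how regular Python functions work.

For that reason, you should follow a functional programming style that uses arguments instead of closing over outer names.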

In [None]:
@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))

In [None]:
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))

Another way to update a global value is to make it a `tf.Variable` and use the `Variable.assign` method instead.


In [None]:
@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())


In [None]:
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())

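#### Depending on Python objects

The recommendation to pass Python objects as arguments into `tf.function` has a number of known issues that are expected to be fixed in the future. In general, you can rely on consistent tracing if you use a Python primitive or a `tf.nest`-compatible structure as an argument, or if you pass a *different* instance of an object into a `Function`. However, `Function` will *not* create a new trace when you pass **the same object and only change its attributes**.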

In [None]:
class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))

In [None]:
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(

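Using the same `Function` to evaluate the updated instance of the model will be buggy, since the updated model has the [same cache key](#rules_of_tracing) as the original model.

For that reason, you are recommended to write your `Function` to avoid depending on mutable object attributes, or to create new objects.

If that is not possible, one workaround is to make a new `Function` each time you modify your object to force retracing: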

In [None]:
def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))  

In [None]:
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.

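As [retracing can be expensive](https://www.tensorflow.org/guide/intro_to_graphs#tracing_and_performance), you can use `tf.Variable`s as object attributes, which can be mutated (but not replaced, careful!) for a similar effect without needing a retrace.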


In [None]:
class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))


In [None]:
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!

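### Creating tf.Variables

`Function` only supports singleton `tf.Variable`s that are created once on the first call and reused across subsequent function calls. The code snippet below would create a new `tf.Variable` in every function call, which results in a `ValueError` exception.

Example: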

In [None]:
@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)

A common pattern used to work around this limitation is to start with a Python `None` value, then conditionally create the `tf.Variable` if the value is `None`:

In [None]:
class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())

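#### Using with multiple Keras optimizers

You may encounter `ValueError: tf.function only supports singleton tf.Variables created on the first call.` when using more than one Keras optimizer with a `tf.function`. This error occurs because optimizers internally create `tf.Variable`s when they apply gradients for the first time.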

In [None]:
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
 
@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)

If you need to change the optimizer during training, a workaround is to create a new `Function` for each optimizer and call the [`ConcreteFunction`](#obtaining_concrete_functions) directly.

In [None]:
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

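#### Using with multiple Keras models

You may also encounter `ValueError: tf.function only supports singleton tf.Variables created on the first call.` when passing different model instances to the same `Function`.

This error occurs because Keras models (which [do not have their input shape defined](https://www.tensorflow.org/guide/keras/custom_layers_and_models#best_practice_deferring_weight_creation_until_the_shape_of_the_inputs_is_known)) and Keras layers create `tf.Variable`s when they are first called. You may be attempting to initialize those variables inside a `Function` that has already been called. To avoid this error, try calling `model.build(input_shape)` to initialize all the weights before training the model.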


## Further reading

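To learn how to export and load a `Function`, see the [SavedModel guide](https://render.githubusercontent.com/guide/saved_model). To learn more about the graph optimizations that are performed after tracing, see the [Grappler guide](https://render.githubusercontent.com/guide/graph_optimization). To learn how to optimize your data pipeline and profile your model, see the [Profiler guide](https://render.githubusercontent.com/guide/profiler.md).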