Skip to content

Latest commit

 

History

History
661 lines (528 loc) · 25.5 KB

File metadata and controls

661 lines (528 loc) · 25.5 KB

8.1 命令式和符号式混合编程

本书到目前为止一直都在使用命令式编程,它使用编程语句改变程序状态。考虑下面这段简单的命令式程序。

def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

fancy_func(1, 2, 3, 4) # 10

和我们预期的一样,在运行语句e = add(a, b)时,Python会做加法运算并将结果存储在变量e中,从而令程序的状态发生改变。类似地,后面的两条语句f = add(c, d)g = add(e, f)会依次做加法运算并存储变量。

虽然使用命令式编程很方便,但它的运行可能很慢。一方面,即使fancy_func函数中的add是被重复调用的函数,Python也会逐一执行这3条函数调用语句。另一方面,我们需要保存变量ef的值直到fancy_func中所有语句执行结束。这是因为在执行e = add(a, b)f = add(c, d)这2条语句之后我们并不知道变量ef是否会被程序的其他部分使用。

与命令式编程不同,符号式编程通常在计算流程完全定义好后才被执行。多个深度学习框架,如Theano和TensorFlow,都使用了符号式编程。通常,符号式编程的程序需要下面3个步骤:

  1. 定义计算流程;
  2. 把计算流程编译成可执行的程序;
  3. 给定输入,调用编译好的程序执行。

下面我们用符号式编程重新实现本节开头给出的命令式编程代码。

def add_str():
    return '''
def add(a, b):
    return a + b
'''

def fancy_func_str():
    return '''
def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g
'''

def evoke_str():
    return add_str() + fancy_func_str() + '''
print(fancy_func(1, 2, 3, 4))
'''

prog = evoke_str()
print(prog)
y = compile(prog, '', 'exec')
exec(y)

输出:

def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

print(fancy_func(1, 2, 3, 4))

10

以上定义的3个函数都仅以字符串的形式返回计算流程。最后,我们通过compile函数编译完整的计算流程并运行。由于在编译时系统能够完整地获取整个程序,因此有更多空间优化计算。例如,编译的时候可以将程序改写成print((1 + 2) + (3 + 4)),甚至直接改写成print(10)。这样不仅减少了函数调用,还节省了内存。

对比这两种编程方式,我们可以看到以下两点。

  • 命令式编程更方便。当我们在Python里使用命令式编程时,大部分代码编写起来都很直观。同时,命令式编程更容易调试。这是因为我们可以很方便地获取并打印所有的中间变量值,或者使用Python的调试工具。

  • 符号式编程更高效并更容易移植。一方面,在编译的时候系统容易做更多优化;另一方面,符号式编程可以将程序变成一个与Python无关的格式,从而可以使程序在非Python环境下运行,以避开Python解释器的性能问题。

8.1.1 混合式编程取两者之长

大部分深度学习框架在命令式编程和符号式编程之间二选一。例如,Theano和受其启发的后来者TensorFlow1.x使用了符号式编程,Chainer和它的追随者PyTorch使用了命令式编程。开发人员在设计Tensorflow2.x时思考了这个问题:有没有可能既得到命令式编程的好处,又享受符号式编程的优势?开发者们认为,用户应该用纯命令式编程进行开发和调试;当需要产品级别的计算性能和部署时,用户可以将大部分命令式程序转换成符号式程序来运行。Tensorflow通过提供静态图转换器tf.function,实现对两种编程方式的支持。在 Tensorflow

在不使用静态图转换器tf.function时,用户编写的python函数默认会采用命令式编程逐行执行,符合python编程的直觉,便于调试,但因为框架不能获得完整的静态运算图,不能进行优化,且动态图不能序列化。官方由此推出了静态图转换器tf.function,其作用在python_function后会将这个函数"编译"成一个运算图,接受input_tensors为输入并用图执行的方式计算结果,可以加速函数执行,并且可以被序列化后供任何其他语言(如C++和Java)调用,这样用户只需要在使用动态图运算编写和测试完毕函数之后使用tf.function装饰一下就能够获得静态图的所有优点。

8.1.2 tf.function 的使用

8.1.2.1 基础

tf.function可以定义一个Tensorflow操作,既可以命令式的执行它,

@tf.function
def add(a, b):
    return a+b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

也可以在图中执行它,并求其梯度,

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    result = add(v, 1.0)
tape.gradient(result, v)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

也可以定义嵌套定义(当然,在实际使用中,可以直接在顶层定义,会自动对子图进行转换),

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

8.1.2.2 追踪与多态

Python 的动态类型意味着用户可以传递多种类型的参数,这也许会导致函数产生不同的行为。

在另一方面,Tensorflow的静态图需要确定的数据类型和维度。tf.function 通过retracing函数调用来弥补这一差距,并在必要的时候产生正确的计算图。tf.function大多数微妙的行为都产生自retracing行为。

我们可以用不同的参数调用同一函数来观察retracing行为:

# Functions are polymorphic

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

如果希望控制 tracing 行为,可以用如下方式操作:

  • 创建新的 tf.function。分离 tf.fucntion 对象,保证没有共享的计算图引用。
  • 使用 get_concrete_function 方法,得到特定的计算图。
  • 声明 input_signature 当调用 tf.function 时,仅跟踪与输入签名一致的调用。
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))

Obtaining concrete trace Tracing with Tensor("a:0", dtype=string) Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string) Using a concrete trace with incompatible types will throw an error Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "", line 8, in assert_raises yield File "", line 8, in double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_87 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_87]

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-9-9939c82c1507>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

8.1.2.3 追踪触发的时机

多态函数 tf.function 会缓存之前追踪行为触发生成过的具体函数。缓存的键由传入的参数确定,对于 tf.Tensor 参数而言,是其维度和类型,而对于 Python 元语,是其值。对于其它 Python类型,使用对象 id,即对每个不同的类实例都会触发独立的追踪行为,并生成相应的静态图。

8.1.2.4 输入参数的选择 Python or Tensor

通常,Python 参数被用作超参数,如 num_layers=10training=True以及nonlinearity='relu'。此时,Python 参数的改变触发 retrace 行为来构建新的计算图是合理的。然而,在另一些情况下,Python 参数并不改变计算图,是不需要触发 retrace 重新构建计算图的。例如,在训练过程中控制步数,AutoGraph 会自动动态展开,因此传入不同的步数,其生成图是一致的,这时如果触发多个 trace 生成同样的计算图,是很低效的。

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)
Tracing with num_steps = 10
Tracing with num_steps = 20

一种简单的绕过方式是,将参数转换为 Tensor,这样不改变 shape,就不会触发生成计算图。

train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

8.1.2.5 tf.function 的附带效应

通常,Python 的附带效应(print 或改变对象)仅发生在 tracing 行为中。何时应该触发 tf.function 的附带效应呢?

经验上推荐仅使用附带效应来 debug trace 行为。其它情况则建议使用 Tensorflow 运算如 tf.Variable.assigntf.printtf.summary来在 Tensorflow runtime 保证代码跟踪和执行。帮助调试的最佳实践是使用函数式风格。

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

如果想要在每次调用 tf.function 时执行 Python 代码,tf.py_function 提供了这种方式的支持。使用 tf.py_function 的一个缺点是性能,不能很好的工作在分布式环境下(多GPU/TPU)。由于 tf.py_function 需要被可微的连入计算图,输入和输出都会被转换为 tf.Tensor

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1
Python side effect
Python side effect
Python side effect

8.1.2.6 注意 Python 的状态

许多 Python 特性,如生成器和迭代器,依赖于 Python 运行时跟踪其状态。通常,这些构件在动态图模式下工作正常,但由于 tracing 行为,在 tf.function 内会发生预期外行为。

例如,迭代器状态被视为一种 Python 附带效应,因此只在 tracing 时触发一次。

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

如果一个迭代器生成和消费完全在 tf.function 中,它会正确地工作,但会产生巨大地计算图,这也许不符合你的预期,更严重的是,对于以 Python List 表示的内存中大型数据集,相应的大型计算图也并不能带来性能提升。

如果想要在 Python 数据集上迭代,最安全的方式是使用 tf.data.Dataset 封装,并以 for x in y 方式遍历。AutoGraph 对 for 循环中 y 是一个 tf.Tensortf.data.Dataset 有特殊支持。

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph

当封装 Python/Numpy 数据为 tf.data.Dataset 时,要注意 tf.data.Dataset.from_generatortf.data.Dataset.from_tensors 的区别。前者保持数据在 Python runtime 中,通过 tf.py_function 取数,会造成一定的性能瓶颈。而后者会复制一份到 Tensorflow runtime 成为计算图的一个 tf.constant() 节点,会占用更多的内存。

从文件中读数据,如 TFRecordDataset/CsvDataset 等,是读取数据最高效的方式,Tensorflow 可以自行管理异步读取和预加载数据,而不需要 Python 的参与。

8.1.2.7 自动控制依赖

作为编程模型,tf.function 一个非常吸引人的特性是,在通常的数据流图之上,为运行时环境提供了更多关于代码预期行为的信息。

例如,当写代码是多次读写同一变量,数据流图也许不会编码原本预期的操作顺序。在 tf.function 中,则会依照在 Python 代码中的声明顺序消除执行顺序的歧义。这使得 tf.function 支持状态操作,可以复制动态图模式的语义。

这意味着不再需要手动添加控制依赖,而可以交由 tf.function 自动添加控制依赖。

# Automatic control dependencies

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0

<tf.Tensor: shape=(), dtype=float32, numpy=10.0>

8.1.2.8 变量

变量也可能使动态图模式的执行结果和静态图模式产生差异。当每次调用创建一个新变量时,由于 tracing 语义,tf.function 会在每次调用时重用同一变量,但动态图时会在每次调用时新建相应的变量。为了避免类似的问题,tf.function 将在监测到危险的变量创建行为时抛出异常。

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
<ipython-input-17-73e410646579>:3 f  *
    v = tf.Variable(1.0)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
    return cls._variable_v2_call(*args, **kwargs)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
    shape=shape)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:65 getter
    return captured_getter(captured_previous, **kwargs)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py:502 invalid_creator_scope
    "tf.function-decorated function tried to create "

ValueError: tf.function-decorated function tried to create variables on non-first call.

而无歧义的代码就不会触发相应的行为

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

只要可以证明tf.function中的变量只在函数初次执行时被创建,也是可以通过检查的。

class C:
  pass

obj = C()
obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

8.1.3 AutoGraph

AutoGraph 集成在 tf.function 中,用来重写依赖与张量的条件分支和循环,以支持计算图动态改变结构。

tf.condtf.while_loop 依旧与 tf.function 兼容,但以命令式写控制流更简单直观。

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.654489756 0.378447413 0.843570352 0.706569076 0.899703264]
[0.57468468 0.361358374 0.687695503 0.608520865 0.716153383]
[0.518791437 0.346409947 0.596499443 0.543085039 0.614520967]
[0.476766676 0.333187848 0.534554 0.495319664 0.54730171]
[0.443650424 0.321382284 0.488854468 0.458428353 0.498495191]
[0.416665703 0.310756236 0.453306764 0.428802431 0.460932881]
[0.394117773 0.301124901 0.424613386 0.40432 0.430844247]
[0.374904692 0.292341709 0.400809854 0.383639246 0.406026632]
[0.358274311 0.284288675 0.380641699 0.36586377 0.385093749]
[0.343693078 0.276869684 0.36326462 0.3503685 0.367122918]
[0.330770403 0.270005435 0.348086357 0.336702287 0.351472586]
[0.319212824 0.263629884 0.334677339 0.324530154 0.337680846]
[0.308794975 0.257687539 0.322717279 0.313597322 0.325405359]
[0.299340427 0.252131343 0.31196183 0.303706169 0.314386249]
[0.290708899 0.246921107 0.302220792 0.294700593 0.30442214]
[0.282787144 0.242022276 0.293343604 0.286455452 0.295354247]
[0.275482684 0.237404957 0.285209358 0.278869152 0.287055373]
[0.268719077 0.233043134 0.277719557 0.271858126 0.279422313]
[0.262432516 0.228914022 0.27079317 0.265352964 0.272370338]
[0.256569326 0.22499761 0.264362723 0.259295493 0.265829057]
[0.25108391 0.221276194 0.258371592 0.25363645 0.259739518]
[0.245937288 0.217734098 0.252771795 0.248333916 0.254051864]
[0.241095871 0.214357331 0.247522414 0.243351877 0.248723671]
[0.236530572 0.211133406 0.242588282 0.238659233 0.24371852]
[0.23221606 0.2080511 0.237939 0.234228939 0.239004955]
[0.228130147 0.205100328 0.233548105 0.230037391 0.234555662]
[0.224253282 0.202271983 0.229392484 0.226063833 0.230346799]
[0.220568195 0.199557811 0.225451797 0.222289979 0.226357415]
[0.217059553 0.196950331 0.221708104 0.218699604 0.222569034]
[0.213713691 0.194442704 0.21814549 0.215278283 0.218965292]
[0.210518375 0.192028716 0.214749783 0.212013125 0.215531647]
[0.207462624 0.189702675 0.211508334 0.208892584 0.21225509]
[0.204536542 0.18745935 0.208409771 0.205906287 0.209123984]
[0.201731205 0.185293958 0.205443889 0.203044847 0.206127867]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.19903852, 0.18320207, 0.20260146, 0.20029978, 0.20325728],
      dtype=float32)>

如下代码可以观察 autograph 生成的代码

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))
def tf__f(x):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:

    def get_state():
      return ()

    def set_state(_):
      pass

    def loop_body(x):
      ag__.converted_call(tf.print, (x,), None, fscope)
      x = ag__.converted_call(tf.tanh, (x,), None, fscope)
      return x,

    def loop_test(x):
      return ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1
    x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())
    do_return = True
    retval_ = fscope.mark_return_value(x)
  do_return,
  return ag__.retval(retval_)

8.1.3.1 AutoGraph:条件分支

AutoGraphif 语句转换为等效的 tf.cond 调用。这一替换发生在条件变量为张量时,除此之外的条件,在 tracing 时确定。

test_tf_cond 函数用来检查函数中是否使用了 tf.cond

def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'cond' for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) executes normally.".format(
        f.__name__, ', '.join(map(str, args))))

  print("  result: ",f(*args).numpy())

当参数为 python True 时,正常地执行条件:

@tf.function
def dropout(x, training=True):
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)
dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), True) executes normally.
  result:  [2. 2. 2. 2. 0. 0. 2. 0. 0. 2.]

但传递一个张量则会使 python if 替换为 tf.cond

test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))
dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), tf.Tensor(True, shape=(), dtype=bool)) uses tf.cond.
  result:  [0. 2. 2. 2. 0. 0. 0. 0. 2. 2.]

8.1.3.2 AutoGraph 与循环

AutoGraph 有一些转换循环的简单规则。

  • for:迭代器是张量时转换
  • while:循环条件与张量有关时转换

如果一个循环被转换,它将由 tf.while 动态展开,或在 for x in tf.data.Dataset 情况下,将循环转换为 tf.data.Dataset.reduce

如果循环没有被转换,则静态展开。

test_dynamically_unrolled(f, *args)

def test_dynamically_unrolled(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'while' for node in g.as_graph_def().node):
    print("{}({}) uses tf.while_loop.".format(
        f.__name__, ', '.join(map(str, args))))
  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
    print("{}({}) uses tf.data.Dataset.reduce.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) gets unrolled.".format(
        f.__name__, ', '.join(map(str, args))))

For 循环

tf.function 的静态展开

@tf.function
def for_in_range():
  x = 0
  for i in range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_range)
for_in_range() gets unrolled.
@tf.function
def for_in_tfrange():
  x = tf.constant(0, dtype=tf.int32)
  for i in tf.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfrange)
for_in_tfrange() uses tf.while_loop.
@tf.function
def for_in_tfdataset():
  x = tf.constant(0, dtype=tf.int64)
  for i in tf.data.Dataset.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfdataset)
for_in_tfdataset() uses tf.data.Dataset.reduce.

While 循环

@tf.function
def while_py_cond():
  x = 5
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_py_cond)
while_py_cond() gets unrolled.
@tf.function
def while_tf_cond():
  x = tf.constant(5)
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_tf_cond)
while_tf_cond() uses tf.while_loop.

感兴趣的可以去看原文 Better performance with tf.function Functions, not Sessions AutoGraph Reference tf.Module Tensorflow2.0 学习笔记之静态图转换器 Tensorflow2.0 学习笔记之状态容器