TensorFlow 2.0: Functions, not Sessions.

Status Proposed
Updated 2018-10-02


This document presents a proposal to make TensorFlow more "Pythonic" in 2.0. In five bullet points, the proposal is to:

  • Encourage the encapsulation of graph computation as Python functions
    (where the graph is executed when the function is invoked, instead of via Session)
  • Align "state" in the TensorFlow runtime (e.g., resource tensors like those that back tf.Variable objects) with state in the Python program (e.g., Python objects corresponding to the runtime state with lifetimes attached to each other).
  • Make it easy to export these encapsulations to a GraphDef+Checkpoint and/or SavedModel.
  • Enable eager execution by default.
  • Provide a path for incorporating existing code that uses the 1.x APIs to construct TensorFlow graphs as functions in TensorFlow 2.x programs.

This document makes liberal use of sample code to describe the end-user effect of the proposed changes.

(We say "encourage" instead of "require" since removing the Session API from the Python frontend within a year may be an unrealistic aspiration, particularly given its use in Estimators and the use of MonitoredSession and hooks. The Session API may have to stick around in tf.compat.v1.)

Design Proposal

Basic idea: Python functions as Graphs

Today, the TensorFlow graph defines the union of all computation that the author of the graph may be interested in. The actual computation to execute is defined by the arguments to tf.Session.run(). Once this subgraph is defined, the runtime can optimize and execute it. For example, consider the following:

import tensorflow as tf

x = tf.placeholder(tf.float32)
y = tf.square(x)
z = tf.add(x, y)

sess = tf.Session()

z0 = sess.run([z], feed_dict={x: 2.})        # 6.0
z1 = sess.run([z], feed_dict={x: 2., y: 2.}) # 4.0

Though there is one tf.Graph object the user is interacting with (tf.get_default_graph()), the two sess.run calls are executing different programs (indeed the runtime ends up with two separate Graph objects in C++, one for each program), equivalent to:

def compute_z0(x):
  return tf.add(x, tf.square(x))

def compute_z1(x, y):
  return tf.add(x,  y)
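
To make the "two different programs" point concrete, here is a minimal pure-Python sketch (no TensorFlow involved) of how feeds and fetches select a subgraph. The GRAPH dictionary, the OPS table and the run helper are hypothetical stand-ins for a tf.Graph and Session.run, introduced only for illustration.

```python
# Toy model of a dataflow graph: node name -> (operation, input node names).
# GRAPH, OPS and run() are illustrative stand-ins, not TensorFlow APIs.
GRAPH = {
    "x": ("placeholder", []),
    "y": ("square", ["x"]),
    "z": ("add", ["x", "y"]),
}

OPS = {
    "square": lambda a: a * a,
    "add": lambda a, b: a + b,
}


def run(fetch, feed_dict):
    """Evaluate `fetch`, visiting only the nodes it needs.

    Returns (value, executed), where `executed` lists the ops that ran.
    """
    executed = []

    def evaluate(name):
        if name in feed_dict:  # A fed value cuts off everything upstream.
            return feed_dict[name]
        op, inputs = GRAPH[name]
        if op == "placeholder":
            raise ValueError("placeholder %r was not fed" % name)
        args = [evaluate(i) for i in inputs]
        executed.append(name)
        return OPS[op](*args)

    return evaluate(fetch), executed


z0, ran0 = run("z", {"x": 2.0})            # square and add both run
z1, ran1 = run("z", {"x": 2.0, "y": 2.0})  # feeding y skips the square
```

The same fetch executes a different set of operations depending on the feeds, which is why the two calls are best thought of as two separate functions.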

The core proposal of this document is the alignment between computation expressed in Python and the computation executed by the runtime. Instead of defining a graph and then selecting the subgraph to execute at sess.run() time, the exact computation of interest is encapsulated in a Python callable. For example, the program above that uses sess.run() to compute z0 and z1 can be written as:

import tensorflow as tf

@tf.function
def compute_z1(x, y):
  return tf.add(x, y)

@tf.function
def compute_z0(x):
  return compute_z1(x, tf.square(x))

z0 = compute_z0(2.)
z1 = compute_z1(2., 2.)

Where tf.function is a decorator that "defines a TensorFlow function". A "TensorFlow function" defines a computation as a graph of TensorFlow operations, with named arguments and explicit return values. Users define the function they want TensorFlow to "accelerate" as a Python function and integrate it into their Python program like any other Python function call.

Having the Python function correspond to what the runtime will execute reduces conceptual complexity in translating between the two domains. It also affords an opportunity to provide more helpful stacktraces on errors. More advanced features available today (e.g., carving sub-graphs, feeding intermediate values) will still be possible (discussed later), though most users should not need to think in terms of graphs, feeds, and fetches. The constructed graph also provides a natural point for accelerators/acceleration libraries (NVIDIA TensorRT, Google Cloud TPUs etc.) to hook in for rewrites.

function: A brief specification

tf.function constructs a TensorFlow graph by "tracing" the TensorFlow operations executed by the Python function. Specifically:

  • f is a Python function that returns zero or more Tensors
  • tf.function(f) returns a Python callable, F
  • When F is invoked it:
    1. Potentially casts inputs to tensors if an input signature was specified (see the "Input Signatures" section below).
    2. Determines a "trace_cache_key" (based on the types and/or values of the arguments).
    3. Invokes f to create a TensorFlow graph, G, every time a new trace_cache_key is encountered. If the trace_cache_key has been seen before, it looks up G from a cache.
    4. Executes the graph defined by G, feeding each argument as a value of the corresponding node in the graph, and returns a tuple of Tensors (or list of Tensors).
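
The call sequence above can be sketched in a few lines of plain Python. This is a toy model under simplifying assumptions, not the proposed implementation: toy_function and trace_count are hypothetical names, the cache key is just the tuple of argument types, and "tracing" is modelled by storing the Python function itself rather than building a graph.

```python
import functools


def toy_function(f):
    """Toy stand-in for the proposed decorator: trace once per cache key."""
    cache = {}

    @functools.wraps(f)
    def F(*args):
        # Step 2: compute a cache key. Here it is only the argument types;
        # the proposal also uses shapes, Python values and the context.
        key = tuple(type(a) for a in args)
        # Step 3: "trace" only when the key has not been seen before.
        if key not in cache:
            F.trace_count += 1
            cache[key] = f  # A real trace would build and store a graph G.
        # Step 4: execute the cached computation.
        return cache[key](*args)

    F.trace_count = 0
    return F


@toy_function
def square(x):
    return x * x


results = [square(2), square(3), square(2.0)]
```

The two integer calls share one trace, and the float call triggers a second one, so square.trace_count ends up at 2.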

Referencing state: Variables, tables etc.

A tf.function-decorated Python function encapsulates a graph and its execution. The Python function may reference stateful objects (i.e., state backed by DT_RESOURCE tensors in the runtime, e.g., tf.Variable) by referencing the corresponding Python object, and these will be captured as implicit inputs to the function.

Comparing TensorFlow code today with how we propose it looks in 2.x:

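TensorFlow 1.x:

W = tf.Variable(
  tf.glorot_uniform_initializer()(
    (10, 10)))
b = tf.Variable(tf.zeros(10))
c = tf.Variable(0)

x = tf.placeholder(tf.float32)
ctr = c.assign_add(1)
with tf.control_dependencies([ctr]):
  y = tf.matmul(x, W) + b
init = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init)
  print(sess.run(y, feed_dict={x: make_input_value()}))
  assert int(sess.run(c)) == 1

TensorFlow 2.0:

W = tf.Variable(
  tf.glorot_uniform_initializer()(
    (10, 10)))
b = tf.Variable(tf.zeros(10))
c = tf.Variable(0)

@tf.function
def f(x):
  c.assign_add(1)
  return tf.matmul(x, W) + b

print(f(make_input_value()))
assert int(c) == 1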

Worthy of note here - in TensorFlow 1.x, the memory underlying the variables W and b in the runtime lives for the lifetime of the Session - unrelated to the lifetime of the Python objects. In 2.x, the lifetime of the Python objects and the runtime state are tied together.

Program-order semantics / Control dependencies

In TensorFlow graphs today, control dependencies are sometimes needed to ensure correct evaluation order. For example, consider the following:

v = tf.Variable(1.0)
init_op = tf.global_variables_initializer()
assign_op = v.assign(2.0)
read = v.read_value()

with tf.Session() as sess:
  sess.run(init_op)
  val = sess.run(read)
  print(val) # Will print 1.0, the assign is ignored
  val = sess.run([read, assign_op])[0]
  print(val)   # Non-deterministically prints 1.0 or 2.0

The output here is not deterministic, since val may evaluate to either 1.0 or 2.0 depending on whether the runtime happened to execute assign_op before read or not. tf.control_dependencies is a mechanism provided to add annotations at graph construction time to influence graph execution. The TensorFlow user, a Python programmer, is thus forced to think about two execution models: TensorFlow graphs and the Python interpreter. To eliminate this cognitive load, tf.function will automatically insert control dependencies to ensure that (1) operations that produce or consume a given DT_RESOURCE tensor and (2) operations that are marked stateful (REGISTER_OP(...).SetIsStateful()) follow graph construction order. Thus:

v = tf.Variable(1.0)
@tf.function
def f():
  v.assign(2.0)
  return v.read_value()

print(f()) # Always prints 2.0.

Note that the intention here is to avoid observable differences from program order. For example:

a = tf.Variable(1.0)
b = tf.Variable(1.0)
@tf.function
def f():
  a.assign(2.0)
  b.assign(3.0)
  return a + b
print(f())

This will always print 5.0 since the assignments will occur before the read. However, there is no guaranteed ordering between the assignment of a and b (as any difference in that is not observable).
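
The program-order rule can be illustrated with a small pure-Python sketch. It is a simplification made for illustration and is not the AutomaticControlDependencies implementation: add_control_dependencies is a hypothetical helper that chains each stateful op to the previous op that touched the same resource, and leaves ops on different resources unordered.

```python
def add_control_dependencies(ops):
    """ops: list of (op_name, resource) pairs in program order.

    Returns {op_name: previous op on the same resource, or None}.
    """
    last_op_for_resource = {}
    deps = {}
    for op_name, resource in ops:
        # Depend on whatever touched this resource most recently.
        deps[op_name] = last_op_for_resource.get(resource)
        last_op_for_resource[resource] = op_name
    return deps


program = [
    ("assign_a", "a"),
    ("assign_b", "b"),
    ("read_a", "a"),
    ("read_b", "b"),
]
deps = add_control_dependencies(program)
```

Each read is ordered after the assignment to the same variable, while assign_a and assign_b have no edge between them, matching the "no observable difference" intent.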

A preview of this is implemented in tf.contrib.eager.defun today (using AutomaticControlDependencies).

Functions that create state

In the above code, no tf.Variable objects are created inside a tf.function-decorated function. This makes it clear that the code will have the same semantics once wrapped.

Note that if the function naturally creates state only on the first trace, all is well:

v = None

@tf.function
def f(x):
  global v
  if v is None:
    v = tf.Variable(1.0)
  return tf.cast(x, tf.float32) + v

f(tf.constant(1, dtype=tf.float32)) # Creates the variable, returns 2.0
f(tf.constant(2, dtype=tf.int32))   # Reuses the variable, returns 3.0

To support this, tf.function imposes some requirements on the decorated function:

  1. State (like tf.Variable objects) is only created the first time the function f is called.
    How that is accomplished is left up to the implementation of f.
    If any variables are created in the first execution of f, then @tf.function will trace f again the second time it is invoked in order to record the behavior that will be used from then on. No variables may be created during that second trace, or any other trace after that (due to different dtypes, shapes, or non-tensor arguments).
  2. The caller must make sure that any variable referenced by the function still exists whenever the function is evaluated.
    @tf.function itself will keep only weak references to these created variables. Thus, if the referenced state does not exist when the decorated function is invoked, an exception will be raised.
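
The weak-reference requirement can be sketched in plain Python with the standard weakref module. ToyVariable and ToyFunction are hypothetical stand-ins for tf.Variable and the object returned by the decorator; the real error type and message are not specified by this document.

```python
import gc
import weakref


class ToyVariable(object):
    """Hypothetical stand-in for tf.Variable."""

    def __init__(self, value):
        self.value = value


class ToyFunction(object):
    """Keeps only a weak reference to the state its computation uses."""

    def __init__(self, variable):
        self._variable_ref = weakref.ref(variable)

    def __call__(self, x):
        variable = self._variable_ref()
        if variable is None:
            raise RuntimeError("referenced variable no longer exists")
        return x + variable.value


v = ToyVariable(1.0)
f = ToyFunction(v)
first = f(1.0)

del v         # Drop the only strong reference to the variable.
gc.collect()  # Make collection explicit for non-refcounting interpreters.
try:
    f(1.0)
    outcome = "no error"
except RuntimeError:
    outcome = "error"
```

Because the function does not keep the variable alive, it is the caller who decides the variable's lifetime, and a call after the variable is gone fails loudly instead of silently using stale state.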

In the future we may want to allow for function-local tf.Variables, which are created and destroyed each time the decorated function is invoked.

Trace Caches

Every argument to a tf.function-decorated Python function (F) must be either:

  • A Tensor object (NumPy ndarrays are converted to the equivalent Tensor), or
  • A list of Tensor objects, or
  • An arbitrary Python value.

(There seems to be some interest expressed in supporting structured inputs using nest.flatten and nest.pack_sequence_as. This will be considered as follow-up work.)

Every time F is invoked in the Python program, a trace_cache_key is computed as a function of:

  1. The element datatype and shape of every Tensor argument
  2. For every list-of-Tensor argument, the length of the list and the (dtype, shape) of each element in it
  3. The concrete value of every Python object argument that is neither a Tensor nor a list of Tensors
  4. The "context" in which F is invoked (e.g., the device prescribed by the tf.device() scope in which F is invoked).

This key is used to determine if a new graph needs to be created or if a previously created graph can be invoked.
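
A minimal sketch of such a key computation, in plain Python, is shown below. ToyTensor and trace_cache_key are hypothetical names used only for illustration; the sketch assumes that non-Tensor arguments are hashable and models the calling context as a single device string.

```python
import collections

# Hypothetical stand-in for a Tensor: only dtype and shape matter for the key.
ToyTensor = collections.namedtuple("ToyTensor", ["dtype", "shape"])


def trace_cache_key(args, device=None):
    """Builds a hashable key from the arguments and the calling context."""
    parts = []
    for arg in args:
        if isinstance(arg, ToyTensor):
            parts.append(("tensor", arg.dtype, tuple(arg.shape)))
        elif isinstance(arg, list) and all(isinstance(a, ToyTensor) for a in arg):
            # The tuple length encodes the list length.
            parts.append(("list", tuple((a.dtype, tuple(a.shape)) for a in arg)))
        else:
            parts.append(("value", arg))  # Assumed to be hashable.
    return (tuple(parts), device)


scalar = ToyTensor("float32", ())
vector = ToyTensor("float32", (2,))

same = trace_cache_key([scalar, True]) == trace_cache_key([ToyTensor("float32", ()), True])
by_shape = trace_cache_key([scalar]) != trace_cache_key([vector])
by_value = trace_cache_key([scalar, True]) != trace_cache_key([scalar, False])
by_device = trace_cache_key([scalar], "CPU:0") != trace_cache_key([scalar], "GPU:0")
```

Two tensors with equal dtype and shape map to the same key regardless of their values, while a change of shape, Python value or device yields a new key and hence a new trace.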

Since new graphs are traced when new input signatures are encountered, a function can encapsulate multiple graphs. For example, consider the following:

@tf.function
def f(x):
  return tf.square(x)

f(tf.constant(1, dtype=tf.int32))
f(tf.constant(1.0, dtype=tf.float32))

There are two graphs created here: one which corresponds to the Square operation applied to DT_INT32 tensors, and one with the Square operation applied to DT_FLOAT tensors. The object returned by tf.function encapsulates multiple graphs (lazily generated based on the type and shape of input arguments), multiplexing between them in __call__.

Note the use of tf.constant to ensure that the argument is a Tensor. If the argument were a Python value, then additional graphs would be traced for each such value. For example, the following two calls will result in two additional graphs being traced:

f(1.0)
f(2.0)

Where arguments are not Tensors, the "value" of the argument is used to compute the trace_cache_key. For example:

@tf.function
def f(x, use_multiply):
  return tf.multiply(x, x) if use_multiply else tf.square(x)

f(tf.constant(2.0), True)
f(tf.constant(2.0), False)

will result in 2 graphs being created, since the two calls result in two different cache keys because the value of the Python object (the second argument) changes between the two.

Note that the "type" of Tensor inputs to the function also incorporates the shape. For example:

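@tf.function
def f(x): return tf.add(x, 1.)
f(tf.constant([2.0]))       # Traces a graph for a float32 vector with 1 element
f(tf.constant([2.0, 3.0]))  # Traces a graph for a float32 vector with 2 elements
f(tf.constant([4.0, 5.0]))  # Reuses the 2-element graph
f(tf.constant([[2.0]]))     # Traces a graph for a float32 1x1 matrix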

will result in 3 graphs being created:

  1. One for when the first argument is a tf.float32 vector with 1 element
  2. One for when the first argument is a tf.float32 vector with 2 elements
  3. One for when the first argument is a tf.float32 1x1 matrix

The trace_cache_key also incorporates the "context" in which the call was made. For example:

@tf.function
def f(x): return tf.add(x, 1.)

with tf.device("/device:CPU:0"):
  f(tf.constant(1.0))
with tf.device("/device:GPU:0"):
  f(tf.constant(1.0))

This will create 2 graphs, one where the operations are pinned to the CPU device and one where they are pinned to the GPU device.

CAUTION: Too many traces

Since new traces are generated on demand, the object returned by function may hold on to more resources than the user may realize. Possible mitigations:

  • Garbage collect the graphs when the weak reference to any component of the trace_cache_key is no longer alive.
  • Use input signatures to prevent unnecessary retraces (see "Input Signatures" section below)
  • Raise / log an error when the ratio of traces to calls is greater than some threshold (e.g., if every 2 calls to a tf.function-decorated function generate a new graph).
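
The last mitigation, flagging a function whose calls trigger new traces too often, could look roughly like the following pure-Python sketch. TraceMonitor, max_trace_ratio and min_calls are hypothetical names and defaults chosen for illustration; the proposal does not fix a threshold.

```python
class TraceMonitor(object):
    """Counts calls and traces; flags a function that retraces too often."""

    def __init__(self, max_trace_ratio=0.5, min_calls=4):
        self.calls = 0
        self.traces = 0
        self.max_trace_ratio = max_trace_ratio
        self.min_calls = min_calls

    def record(self, traced):
        """Record one call; `traced` says whether it created a new graph."""
        self.calls += 1
        if traced:
            self.traces += 1

    def too_many_traces(self):
        if self.calls < self.min_calls:
            return False  # Too few calls to judge.
        return float(self.traces) / self.calls >= self.max_trace_ratio


healthy = TraceMonitor()
for traced in [True, False, False, False, False, False]:
    healthy.record(traced)

noisy = TraceMonitor()
for traced in [True, False, True, False, True, False]:
    noisy.record(traced)
```

Here healthy traces once in six calls and is left alone, while noisy traces on every second call and would be reported.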

CAUTION: Mutable non-Tensor arguments

The trace_cache_key includes the Python object for non-Tensor arguments. Mutations of these arguments might not be detected. For example:

class Params(object):
  multiply = True

p = Params()
@tf.function
def f(x, y):
  return tf.multiply(x, 2.) if y.multiply else tf.add(x, 2.)

f(3., p) # Returns 6.0
p.multiply = False
f(3., p)  # Mutations to `p` may not trigger a retrace, so might still return 6.0
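
The pitfall can be reproduced without TensorFlow. In the sketch below, toy_call is a hypothetical helper that caches the branch chosen at "trace" time, keyed on the identity of the non-Tensor argument, which is one plausible way such a cache could behave.

```python
class Params(object):
    multiply = True


traces = {}


def toy_call(x, params):
    """Caches the branch chosen at trace time, keyed on object identity."""
    key = id(params)
    if key not in traces:
        # "Tracing": the Python condition is evaluated once and baked in.
        if params.multiply:
            traces[key] = lambda v: v * 2.0
        else:
            traces[key] = lambda v: v + 2.0
    return traces[key](x)


p = Params()
before = toy_call(3.0, p)  # Traces the multiply branch.
p.multiply = False
after = toy_call(3.0, p)   # Same object, same key: the stale trace is reused.
```

Both calls return 6.0 even though the second one would return 5.0 if the mutation were noticed, which is why immutable (or at least unmodified) Python arguments are the safe choice.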

Input Signatures

Tracing the decorated function to create a new graph on each input shape is a conservative choice. Often the same graph suffices for Tensors of multiple shapes. As a trivial example, consider:

@tf.function
def f(x): return tf.add(x, 1.)

f(tf.constant(1.0)) # Scalar argument
f(tf.constant([1.0, 2.0])) # Vector argument
f(tf.constant([[3.0]])) # Matrix

This snippet would result in 3 graphs being traced. An "input signature" can be explicitly specified to control the trace_cache_key computation based on the type and shape of Tensor (and list of Tensor) arguments to f.

For example:

@tf.function(input_signature=((tf.float32, [None]),))
def f(x): return tf.add(x, 1.)

f(tf.constant([2.0]))      # Returns [3.0]
f(tf.constant([2.0, 3.0])) # Matches the input signature as [None]
                           # matches the actual shape [2]
f(tf.constant([[2.0]]))    # Raises an error as the arguments don't match the
                           # input signature.
f(tf.constant([2], dtype=tf.int32)) # Raises an error as the dtype of the argument
                                    # does not match the input signature

# f is backed by a single Graph since the input signature specification allowed
# for the same graph to be used when the input shape is (1,) or (2,).

An "input signature" specifies a pattern for each of the arguments that may be accepted by the function-decorated function. Specifically:

  • For a Tensor argument, it specifies a (dtype, shape pattern).
    For example:
    • (tf.float32, [None]) means the argument must be a float32 vector (with any number of elements).
    • (tf.int32, []) means that the argument must be an int32 scalar.

      In this case, non-Tensor Python values provided at call time are automatically converted (using tf.convert_to_tensor) to a Tensor matching this signature.
  • For a list of Tensor objects, it specifies an optional list length and the signature for elements in the list (i.e., the dtype and shape pattern for all elements in the list).
  • For non-Tensor arguments: tf.PYTHON_VALUE

When an input signature is specified, new graphs are traced only when the value of a Python argument or the context in which the function is invoked changes. If this is considered to be too restrictive, a possible future extension would be to annotate the signature of an argument so that new traces can be created. For example:

@tf.function(input_signature=((tf.TRACE_ON_NEW_VALUE, [None]),))
def f(x): return tf.square(x)

f(tf.constant([2.0])) # Returns [4.0]
f(tf.constant([2, 2], dtype=tf.int32)) # Returns [4, 4] after tracing a new graph

API for function

We've introduced a single new symbol: function that consumes a Python function and returns a callable Python object. The precise API of the object is being iterated on in go/tf-2.0-function-api, but at a high level it will have methods to:

  • List out all captured state (tf.Variable objects, other DT_RESOURCE tensors used by the computation and provided as implicit inputs).
  • Access the tf.Graph that corresponds to the graph executed by the __call__ method of the object.
  • Execute the function with custom RunOptions and retrieve RunMetadata.


If a member function of a class does not create variables, it may be decorated with @tf.function and it will work:

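class ScalarModel(object):
  def __init__(self):
    self.v = tf.Variable(0)

  @tf.function
  def increment(self, amount):
    self.v.assign_add(amount)

model1 = ScalarModel()
model1.increment(tf.constant(3))
assert int(model1.v) == 3
model1.increment(tf.constant(4))
assert int(model1.v) == 7
model2 = ScalarModel()
model2.increment(tf.constant(5))
assert int(model2.v) == 5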

This works since increment() has self as a non-tensor argument, and a new trace will be created for each value of self. However, if variables are created in a method, we want to allow a new set of variables for every instantiation of self.

class AnyShapeModel(object):
  def __init__(self):
    self.v = None

  @tf.function
  def increment(self, amount):
    if self.v is None:
      self.v = tf.Variable(tf.zeros_like(amount))
    self.v.assign_add(amount)

model1 = AnyShapeModel()
model1.increment(tf.constant(3))
assert int(model1.v) == 3
model1.increment(tf.constant(4))
assert int(model1.v) == 7
model2 = AnyShapeModel()
model2.increment(tf.constant([4, 5]))
assert (model2.v.numpy() == [4, 5]).all()

The semantics here are that each new instance is allowed to create variables in each @tf.function once.

In addition, as long as all variable creation/initialization happens while we are tracing, we should be able to support exporting the initialization graph when exporting a SavedModel or MetaGraphDef.

Transitioning from 1.x

The definition of tf.function above is careful to check that invoking a decorated Python function would have the same behavior as invoking an undecorated function. This is to guard against it being passed code from TensorFlow v1.x that expects to only be called once (and relies on things like graph collections to track which variables are created), for example:

def f(x, do_add):
  v = tf.Variable(5.0)
  if do_add:
    v.assign_add(x)
  else:
    v.assign_sub(x)
  return v

For this case, we use a different API, tf.compat.v1.wrap_function, that treats any created variables as static local state:

f_add = tf.compat.v1.wrap_function(f, tf.TensorSpec(tf.float32, ()), True)

assert float(f_add(1.0)) == 6.0
assert float(f_add(1.0)) == 7.0

# Can call tf.compat.v1.wrap_function again to get a new trace, a new set
# of variables, and possibly different non-template arguments.
f_sub = tf.compat.v1.wrap_function(f, tf.TensorSpec(tf.float32, ()), False)

assert float(f_sub(1.0)) == 4.0
assert float(f_sub(1.0)) == 3.0

Note these differences from tf.function:

  • Only ever traces f() once (per call to tf.compat.v1.wrap_function).
  • The complete input tensor signature (via tf.TensorSpec calls) and the values of all non-tensor arguments must be specified when wrapping the function. Note: we may want a tf.tensor_like(x) convenience function that returns tf.TensorSpec(x.dtype, x.shape).
  • Will include extra TF v1.x compatibility features like collections, and access v1.x APIs like tf.compat.v1.get_variable()
  • Will not automatically insert control dependencies to maintain program order across stateful operations/state accesses.
  • May only use a function or Python constant to initialize variables, no tensors. This is a technical limitation, required by the fact that we need some way of disentangling the initializers for variables from the other operations from the function.
  • Keeps strong references to variables created in f, weak references to other variables accessed by f. This is to match the v1.x graph behavior that variables have the lifetime of the graph they are created in, and can generally be accessed through graph collections. Some common patterns of writing v1.x code don't leave any references to those variables around. Keeping references to those variables extends their lifetime to match that of the object returned by tf.compat.v1.wrap_function.
  • Typically won't be used as a decorator. Calling tf.compat.v1.wrap_function takes some arguments, traces the function, and creates an object with state. The lifetime of the return value should be tracked explicitly by saving it in a variable.

Treating state (like tf.Variable) as static local does mean that the behavior of a Python function wrapped with tf.compat.v1.wrap_function differs from that of the unwrapped one. In the above example, f(1.0, True) will always return 6.0 (as a scalar Tensor), while each call to f_add(1.0) will return a different value. We propose this separate tf.compat.v1.wrap_function endpoint specifically to make it easy to migrate TensorFlow 1.x libraries to TensorFlow 2.0. The behavior of 2.0 tf.function is restricted to cases where we can say that the behavior will match.
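
The "static local" behaviour can be mimicked in plain Python. In this sketch toy_wrap_function is a hypothetical stand-in for tf.compat.v1.wrap_function: it runs the wrapped function exactly once, keeps a strong reference to the state created during that run, and replays the recorded step on every call. A dictionary entry plays the role of tf.Variable(5.0).

```python
def toy_wrap_function(f, *python_args):
    """Runs f once, keeps the state it created, and returns a callable.

    f must return (state, step), where `state` is created during the single
    "trace" and `step(x)` is the computation replayed on each call.
    """
    state, step = f(*python_args)  # f runs exactly once.

    def wrapped(x):
        return step(x)

    wrapped.state = state  # Strong reference: state lives as long as `wrapped`.
    return wrapped


def f(do_add):
    state = {"v": 5.0}  # Stand-in for tf.Variable(5.0).

    def step(x):
        state["v"] += x if do_add else -x
        return state["v"]

    return state, step


f_add = toy_wrap_function(f, True)
f_sub = toy_wrap_function(f, False)
add_results = [f_add(1.0), f_add(1.0)]
sub_results = [f_sub(1.0), f_sub(1.0)]
```

Each wrapper owns its own copy of the state, so repeated calls accumulate (6.0 then 7.0 for f_add, 4.0 then 3.0 for f_sub), mirroring the assertions in the wrap_function example above.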

We recognize that code written for TensorFlow 1.x commonly does not encapsulate state in Python objects, instead adding to hidden (graph-)global collections. We will support code that accesses collections inside a tf.compat.v1.wrap_function, though those collections will be local to a single trace.

With the tf.compat.v1.wrap_function proposed above, most graph construction library functions written against TensorFlow 1.x can be incorporated into TensorFlow 2.x programs.

def f(x):
  W = tf.compat.v1.get_variable(name="weight", shape=[10, 10])
  b = tf.compat.v1.get_variable(name="bias", shape=[10])
  c = tf.Variable(0, dtype=tf.int32, name="counter")
  with tf.control_dependencies([c.assign_add(1)]):
    return tf.matmul(x, W) + b
f = tf.compat.v1.wrap_function(f, tf.placeholder(tf.float32, None))
assert len(f.variables) == 3
assert f.variables[0].name == "weight"

In this case, the object returned by tf.compat.v1.wrap_function owns the state created within f, and the __call__ method on it invokes the corresponding computation.

Long story short, tf.compat.v1.wrap_function helps in incorporating graph construction code written against TensorFlow 1.x into TensorFlow 2.x programs. wrap_function constructs the same object as a tf.function-decorated function, which provides the conceptual equivalent of graph construction and Session.run().

Serialization: Exporting SavedModel/GraphDefs

So far we've only considered Python programs. One of the key features of TensorFlow is the ability to integrate models created (and possibly trained) in a Python program into an application written in another programming language and/or platform (e.g., servers, mobile phones, self-driving cars). This ability will of course remain, with a smoother path to exporting models.

In TensorFlow 1.x, "saving a model" could mean one of three things:

  1. Saving parameter values, but not the computation:
    A "checkpoint" containing the values of all model parameters.
    (tf.train.Saver / tf.train.Checkpoint)
    Restoring this model required that the restoring program duplicate the Python code to construct the graph with the same model parameters.
  2. Saving the computation graph, but not the parameter values:
    The computation is represented by a GraphDef that can be exported by calls to tf.Graph.as_graph_def(), or tf.train.export_meta_graph(), and reconstructed by calls to tf.import_graph_def() / tf.train.import_meta_graph().
    Note that the parameter (tf.Variable) values are not saved, but their initializers are.
  3. Saving both the computation and the parameter values:
    The two packaged together in a SavedModel.
    At a high level, the SavedModel format packages the MetaGraphDef, checkpoint, and a signature (names of input and output tensors).
    (tf.saved_model.simple_save / tf.saved_model.builder.SavedModelBuilder)
    This is the format preferred for exporting for serving via TensorFlow Serving or to other languages (e.g., SavedModelBundle.load() in Java, LoadSavedModel in Go)

The objects created by function encapsulate (1) the computation expressed as a GraphDef, (2) the state used by it. Thus, these objects are naturally suited for import/export in any of the above formats, using something like the following:

Save only the parameters, not the computation

TensorFlow 1.x:

W = tf.get_variable(
  "weights", shape=[10, 10])

# Presumably the train_op is
# a little fancier
train_op = W.assign_add(1.)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  saver.save(sess, "/tmp/checkpoint/")

with tf.Session() as sess:
  saver.restore(sess, "/tmp/checkpoint/")

TensorFlow 2.x:

W = tf.Variable(
  tf.glorot_uniform_initializer()(
    (10, 10)))

@tf.function
def train():
  W.assign_add(1.)

train()
ckpt = tf.train.Checkpoint(W=W)
ckpt.save("/tmp/checkpoint")

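Exporting/Importing GraphDefs

TensorFlow 1.x:

W = tf.get_variable("weights", shape=[10, 10])
x = tf.placeholder(
  tf.float32, shape=(None, 10))
y = tf.matmul(x, W)

graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
# The serialized GraphDef is binary, so open the file in binary mode.
with open("/tmp/graph.pb", "wb") as f:
  f.write(graph_def.SerializeToString())

graph_def = tf.GraphDef()
with open("/tmp/graph.pb", "rb") as f:
  graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def)

TensorFlow 2.x:

W = tf.Variable(
  tf.glorot_uniform_initializer()(
    (10, 10)))

@tf.function
def f(x):
  return tf.matmul(x, W)

# Retrieve the object corresponding to
# a particular input signature:
graph = f.graph_function(
  (tf.float32, (None, 10))).graph
graph_def = graph.as_graph_def()

with open("/tmp/graph.pb", "wb") as f:
  f.write(graph_def.SerializeToString())
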
Exporting/Importing SavedModels

TensorFlow 1.x:

def save_model():
  W = tf.get_variable("weights",
                      shape=[10, 10])
  x = tf.placeholder(
    tf.float32, shape=(None, 10))
  y = tf.matmul(x, W)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(
      sess,
      "/tmp/model",
      inputs={"x": x},
      outputs={"y": y})

def load_model():
  sess = tf.Session()
  with sess.as_default():
    inputs, outputs = tf.saved_model.simple_load(sess, "/tmp/model")
  return inputs, outputs, sess

TensorFlow 2.x (to be worked on, but something along the lines of):

class Model(tf.train.Checkpointable):
  def __init__(self):
    self.W = tf.Variable(...)

  @tf.function
  def f(self, x):
    return tf.matmul(x, self.W)

m = Model()

tf.saved_model.export("/tmp/model", m)

m =

Derived/Related Graphs

One reservation expressed by TensorFlow graph/session enthusiasts today is that they value the ability to write generic analysis/inspection tooling on graphs without needing to understand or modify the Python code that constructed the graph. To put it differently, some find it easier to navigate the GraphDef program than the Python program.

This ability will be maintained. function-decorated Python functions have an associated graph, and new functions can be created by specifying the sub-graph of interest. For example:

Carving out a subgraph

TensorFlow 1.x:

def build_graph():
  x = tf.placeholder(tf.float32)
  y = tf.square(x)
  z = tf.square(y)

build_graph()
with tf.Session() as sess:
  sess.run("Square_1:0",
   feed_dict={"Square:0": 2.0})  # 4.0

TensorFlow 2.x:

@tf.function
def f(x):
  return tf.square(tf.square(x))

# tf.Graph corresponding to "x"
# being a float32 tensor with unknown
# shape
graph = f.graph_function(
  (tf.float32, None)).graph

f2 = tf.NewGraphFunction(
  graph,
  inputs=["Square:0"],
  outputs=["Square_1:0"])
# The above may optionally take a
# "prune" argument to allow for
# pruning stateful operations in
# `graph` that are not in the path
# from inputs to outputs.
f2(2.0) # 4.0

Extending a graph

TensorFlow 1.x:

def build_graph():
  x = tf.placeholder(tf.float32)
  y = tf.square(x)
  return y

y = build_graph()
z = tf.square(y)

with tf.Session() as sess:
  # Line below will return 16.0
  sess.run(z, feed_dict={"Placeholder:0": 2.0})

TensorFlow 2.x:

@tf.function
def f(x):
  return tf.square(x)

@tf.function
def g(x):
  return tf.square(f(x))

g(2.0) # 16.0
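
To illustrate what "creating a new function from a sub-graph" means, here is a pure-Python sketch of building a callable from chosen input and output nodes. new_graph_function and the GRAPH dictionary are hypothetical stand-ins; the node names merely mimic TensorFlow's default naming and only a single "square" operation is modelled.

```python
# Toy graph for f(x) = square(square(x)).
GRAPH = {
    "x": ("input", []),
    "Square": ("square", ["x"]),
    "Square_1": ("square", ["Square"]),
}


def new_graph_function(graph, inputs, outputs):
    """Returns a callable computing `outputs` from `inputs`, pruning the rest."""

    def evaluate(name, values):
        if name in values:  # Supplied input or already computed node.
            return values[name]
        op, deps = graph[name]
        if op != "square":
            raise ValueError("%r is not reachable from the given inputs" % name)
        values[name] = evaluate(deps[0], values) ** 2
        return values[name]

    def call(*args):
        values = dict(zip(inputs, args))
        results = [evaluate(name, values) for name in outputs]
        return results[0] if len(results) == 1 else tuple(results)

    return call


f = new_graph_function(GRAPH, ["x"], ["Square_1"])
f2 = new_graph_function(GRAPH, ["Square"], ["Square_1"])
full = f(2.0)     # square(square(2.0))
carved = f2(2.0)  # square(2.0), starting from the intermediate value
```

Using the intermediate node as the input of f2 removes the first square from the computation, which is the function-based counterpart of feeding an intermediate tensor.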

Distributed Execution

At the lowest level of the API, distributed execution continues to work with tf.device annotations, where the device name can reference remote devices as well, just like they do today.

The DistributionStrategy API, typically aimed at synchronous training, will continue to be the method of choice (and the API can be used inside a function). Other APIs such as go/tf-replicator will also be usable.

The author realizes that this section can do with more detail. However, to keep this document more focused, these details will be discussed separately. In particular, usage of MonitoredSession and session hooks today needs additional thought.

function-ing Python control flow

tf.function decorates a graph construction function and transparently recreates graphs if needed. However, this means that a function with data-dependent control flow, which executes fine with eager execution enabled, will fail once it is decorated with tf.function. For example:

def f(x, y):
  if tf.equal(y, 0.0):
    return y
  return x / y

x = tf.constant(2.0)
y = tf.constant(2.0)

f(x, y) # Will be 1.0

df = tf.function(f)
df(x, y)  # Will raise an error complaining about the data-dependent control flow

To fix this, one would have to use the graph construction APIs for control flow (tf.cond, tf.while_loop):

def f(x, y):
  return tf.cond(tf.equal(y, 0.0), lambda: y, lambda: x/y)

x = tf.constant(2.0)
y = tf.constant(2.0)

f(x, y) # Will be 1.0

df = tf.function(f)
df(x, y)  # Will be 1.0
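
Why a Python if fails under tracing while a cond-style API works can be shown with a toy symbolic value. Symbolic and toy_cond below are hypothetical stand-ins for a symbolic tensor and tf.cond, and the branches return strings instead of tensors to keep the sketch self-contained.

```python
class Symbolic(object):
    """Placeholder for a value that is unknown while tracing."""

    def __init__(self, name):
        self.name = name

    def __bool__(self):
        raise TypeError("cannot use symbolic value %r as a Python bool" % self.name)

    __nonzero__ = __bool__  # Python 2 spelling of __bool__.


def toy_cond(pred, true_fn, false_fn):
    """Records both branches instead of choosing one at trace time."""
    return ("cond", pred.name, true_fn(), false_fn())


def f_python_if(pred):
    if pred:  # Needs a concrete value, so it fails while tracing.
        return "y"
    return "x / y"


def f_cond(pred):
    return toy_cond(pred, lambda: "y", lambda: "x / y")


pred = Symbolic("y == 0")
try:
    f_python_if(pred)
    python_if_traced = True
except TypeError:
    python_if_traced = False

traced = f_cond(pred)
```

The Python if cannot be traced because the predicate has no value yet, whereas the cond-style version records both branches and defers the choice to execution time.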

This situation can be improved with the help of autograph to allow expression of control flow in Python. Whether autograph will be enabled by default or not is still under debate, but the option will be there as a flag on function. For example:

df = tf.function(autograph=True)(f)
df(x, y) # Will be 1.0


The summary writing operations (tb.summary.scalar, tb.summary.image etc.) can be naturally placed in the graph by using them in a function-decorated function. These operations require two "external" inputs - the summary writer resource and the condition, that will be picked up from the context (e.g., tb.summary.create_file_writer and tb.summary.record_summary_every_n_global_steps). When defining the graph, these inputs are converted to placeholders, which are then resolved at function invocation time. Thus, something like this:

writer = tf.contrib.summary.create_file_writer('/tmp/test')
with writer.as_default(), tf.contrib.summary.always_record_summaries():
  f()
with writer.as_default(), tf.contrib.summary.never_record_summaries():
  f()

This will write one summary to writer, whether f is defined as:

def f():
  tb.summary.scalar("loss", compute_loss())

or as the wrapped version:

f = tf.contrib.eager.defun(f)

(NOTE: As of August 2018, this is not the case, but it will be. See b/112269952).

Note that the runtime is free to prune away the summary writing operations when the function is invoked in a context where there is no summary writer resource or the condition is false.

What does that have to do with eager execution?

So far this proposal has dealt with the encapsulation of TensorFlow graphs in Python functions with the intention of making it easier to integrate TensorFlow-accelerated computation in Python programs.

Additionally, this proposal suggests enabling eager execution by default in TensorFlow 2.0. Keeping function in mind, this basically means:

  • Inside the context of defining a TensorFlow function (i.e., within a function decorated function) tf.Tensor objects created refer to symbolic tensors.
  • Outside this context, tf.Tensor objects created are backed by concrete values and TensorFlow API. The underlying memory of the tensor can be backed by any device (i.e., CPU/GPU) and is not restricted to host-memory (like numpy arrays).

See the docstring for tf.contrib.eager.defun - the evolving playground for the implementation of the proposal in this document. The basic takeaway is that:

  • For users that embrace symbolic tensors and graphs, continue doing so with your code placed inside a function decorated Python function.
  • We believe most users (new ones in particular) will find it more convenient to deal with Tensor objects backed by concrete values and then selectively "compiling" portions of their Python program into TensorFlow graphs rather than being exposed to graph metaprogramming in Python upfront. In spirit, this is similar to Swift4TensorFlow with the obvious glaring difference that graph program extraction here is manually specified (with the function decoration).

NOTE: In TensorFlow 1.x, eager execution is enabled by tf.enable_eager_execution(). Once invoked, all public API endpoints that consume or produce symbolic Tensor objects begin to produce and consume Tensor objects that are backed by a concrete value. See the "Research and Experimentation" section at for an introduction.

A few things of note

  • This change only applies to the TensorFlow Python frontend
    • TensorFlow.js is already "eager by default".
    • Swift4TensorFlow has similar design goals, doing away with the define-then-run style of TensorFlow graphs.
    • Most other language bindings (Java, C++, Go, and others) primarily target deployment of already-defined models in applications. While an imperative style might help simplify model development and training in these languages, doing so is explicitly out of scope for TensorFlow 2.0. The notion of graphs and sessions will remain in them, as well as in the stable C API. In these APIs, the lifetime of program state (like variables) will continue to be tied to the lifetime of the Session.
  • Users of Estimator will see no change
    • Canned Estimators are black boxes that create and train models. Enabling eager execution will have no effect on their usage. This is true today.
    • The model_fn of a regular (non-canned) Estimator will remain as a graph construction function.
  • SavedModel will continue to be the format encouraged for exporting trained models
    • Crudely speaking, a SavedModel encapsulates a Graph, a checkpoint of variable values, and some metadata like signature information (names of input and output tensors).
    • A path will be provided to easily export models in this format (e.g., via There may be instances where converting the Python code to a graph is not trivial (e.g., it uses the subset of Python that autograph does not support), in which case, exporting to a SavedModel (and thus a Graph) will fail.

Alternatives Considered

Creating state inside a function

How state (DT_RESOURCE tensors) created inside a function should be handled is actively being debated. Options include:

  1. "Lifting" state out as a static local function variable
  2. Mimic the undecorated code - creating and destroying variables on each call.

"Static-local" state

tf.contrib.eager.function today treats state as function-static variables, which allows for code like:

def f(x):
  # State is created inside the function body.
  v = tf.Variable(1, dtype=x.dtype)
  v.assign_add(x)
  return v

df = tf.contrib.eager.defun(f)
# tf.function(f) proposed in this document will raise an exception on first use
x = tf.constant(1, dtype=tf.float32)
print(df(x))  # 2.0
print(df(x))  # 3.0

However, the one major issue with this approach is that it behaves differently from how an undecorated function would:

print(f(1.0), df(1.0))  # 2.0, 2.0
print(f(1.0), df(1.0))  # 2.0, 3.0
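The difference can be reproduced without TensorFlow. The following is a plain-Python analogy only, under the assumption that a mutable cell is a fair stand-in for a variable: `Var` and `make_static_local` are hypothetical names, and nothing here reflects how defun is actually implemented.

```python
class Var:
    """Toy mutable cell standing in for tf.Variable."""

    def __init__(self, value):
        self.value = value

    def assign_add(self, delta):
        self.value += delta
        return self.value


def f(x):
    # Undecorated behaviour: a fresh variable on every call.
    v = Var(1.0)
    return v.assign_add(x)


def make_static_local():
    # "Static-local" behaviour: the variable is lifted out of the body
    # and created on the first call only.
    lifted = {}

    def df(x):
        if "v" not in lifted:
            lifted["v"] = Var(1.0)
        return lifted["v"].assign_add(x)

    return df


df = make_static_local()

undecorated = [f(1.0), f(1.0)]     # state is recreated each time
static_local = [df(1.0), df(1.0)]  # state persists across calls
print(undecorated, static_local)
```

The two lists diverge on the second call, which is exactly the mismatch between `f` and `df` described above.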

To be conservative, we propose some restrictions on function, such as:

  1. State is created only once, i.e., function will fail if calling f a second time results in new state being created.
  2. function decorated functions can only produce Tensor return values.
  3. If you want to convert TF v1.x code like f above, you may use tf.compat.v1.wrap_function which guarantees it will only trace f once.
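Restriction 1 can be sketched as a toy check in plain Python. This is an illustration under simplifying assumptions, not the proposed implementation: `Var` and `create_state_once` are hypothetical names, and the check here fires on the second Python-level call, whereas the proposal describes an exception on first use.

```python
class Var:
    """Toy variable that counts how many instances have been created."""

    created = 0

    def __init__(self, value):
        Var.created += 1
        self.value = value

    def assign_add(self, delta):
        self.value += delta
        return self.value


def create_state_once(python_fn):
    """Hypothetical decorator: reject state created after the first call."""
    has_run = False

    def wrapper(*args):
        nonlocal has_run
        before = Var.created
        result = python_fn(*args)
        if has_run and Var.created > before:
            raise ValueError("new state was created after the first call")
        has_run = True
        return result

    return wrapper


@create_state_once
def bad(x):
    v = Var(1.0)  # creates new state on every call
    return v.assign_add(x)


state = {}


@create_state_once
def good(x):
    if "v" not in state:  # state is created on the first call only
        state["v"] = Var(1.0)
    return state["v"].assign_add(x)


good_results = [good(1.0), good(1.0)]
bad_first = bad(1.0)
try:
    bad(1.0)
    bad_failed = False
except ValueError:
    bad_failed = True
print(good_results, bad_first, bad_failed)
```

Guarding the creation, as `good` does, keeps the function within the restriction while preserving the accumulated value across calls.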

Function-local state

Another option would be to match typical Python functions, where state is created and destroyed during the call to the function. So:

def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

df = tf.function(f)

assert f(1.0) == df(1.0) # Both will be 2.0
assert f(1.0) == df(1.0) # Still 2.0, since 'v' would be recreated.

This seems like an avenue definitely worth pursuing, but requires careful consideration of some additional design points such as escape analysis of return values (e.g. the lifetime of tf.Variable objects that are returned from a decorated function).

For now, we propose that function continue with the restricted abilities proposed in this document and a "maintain Python semantics" decorator be investigated independently.

Open Questions/Ideas

  • Naming:
    • tf.compat.v1.wrap_function or tf.compat.v1.defun or tf.compat.v1.function or tf.compat.v1.wrap_graph_as_function?
  • Signatures in Python 3? (From ngc92)
  • Supporting structured inputs:
    As proposed, arguments to function must be either Tensor objects, or objects that can be converted to a Tensor (tf.convert_to_tensor), or opaque Python objects.

    Perhaps we can support nested structures of Tensors (using nest.flatten and nest.pack_sequence_as), or even arbitrary Python objects?

    If this is supported, then specifying an input_signature may become cumbersome, but perhaps we can have a function(infer_signature_from_first_call=True) to make that easier.
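To make the nested-structure idea concrete, here is a minimal pure-Python sketch of flattening a structure into a list of leaves and packing a flat list back into the same shape. The helpers `toy_flatten`, `toy_pack_sequence_as`, `flat_kernel`, and `structured_call` are hypothetical and are not the nest module's implementation; sorting dictionary keys is a choice made in this sketch to get a deterministic order.

```python
def toy_flatten(structure):
    """Return the leaves of nested dicts/lists/tuples in a fixed order."""
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure)
                for leaf in toy_flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in toy_flatten(item)]
    return [structure]


def toy_pack_sequence_as(structure, flat):
    """Rebuild `structure`'s shape using the values from `flat`, in order."""
    leaves = iter(flat)

    def build(node):
        if isinstance(node, dict):
            return {key: build(node[key]) for key in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        return next(leaves)

    return build(structure)


def flat_kernel(bias, a, b):
    # Stand-in for a traced function that only accepts flat arguments.
    return bias + a * b


def structured_call(inputs):
    # Flatten the nested input, then call the flat function.
    return flat_kernel(*toy_flatten(inputs))


inputs = {"pair": (2.0, 3.0), "bias": 1.0}
flat = toy_flatten(inputs)
repacked = toy_pack_sequence_as(inputs, [10.0, 20.0, 30.0])
result = structured_call(inputs)
print(flat, repacked, result)
```

A wrapper along these lines would let callers pass structured arguments while the underlying function keeps a flat list of inputs. This sketch handles only plain dicts, lists, and tuples.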