Commit b7c2424

Support nested defuns on TPU

PiperOrigin-RevId: 209239670

iganichev authored and tensorflower-gardener committed Aug 18, 2018
1 parent fd1957d commit b7c2424

Showing 5 changed files with 38 additions and 53 deletions.
1 change: 0 additions & 1 deletion tensorflow/compiler/tests/eager_test.py
@@ -443,7 +443,6 @@ def f(x, y):
     self.assertAllEqual((2, 3, 4), dz.shape.as_list())
 
   def testNestedDefun(self):
-    self.skipTest('Nested defuns do not work on TPU at the moment')
     with self.test_scope():
 
       @function.defun
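The re-enabled test exercises a "nested defun": one traced function calling another. A minimal sketch of the pattern outside the test harness (hedged; the real test body is elided above, and `inner`/`outer` are illustrative names):

import tensorflow as tf
from tensorflow.python.eager import function

tf.enable_eager_execution()

@function.defun
def inner(x):
  # Traced into its own graph function.
  return x * x

@function.defun
def outer(x):
  # Calling one defun from inside another is the nesting that
  # previously had to be skipped on TPU.
  return inner(x) + 1.

print(outer(tf.constant(3.)))  # tf.Tensor with value 10.0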
@@ -218,7 +218,7 @@ def _force_device_sync(self):
     tf.constant(1.).cpu()
 
   def _benchmark_eager_apply(self, label, device_and_format, defun=False,
-                             execution_mode=None, compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks,
@@ -228,7 +228,7 @@ def _benchmark_eager_apply(self, label, device_and_format, defun=False,
                                 weight_decay=1e-4, dropout_rate=0,
                                 pool_initial=True, include_top=True)
       if defun:
-        model.call = tfe.defun(model.call, compiled=compiled)
+        model.call = tfe.defun(model.call)
       batch_size = 64
       num_burn = 5
       num_iters = 30
@@ -264,8 +264,7 @@ def _benchmark_eager_train(self,
                              make_iterator,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       for batch_size in self._train_batch_sizes():
@@ -279,8 +278,8 @@ def _benchmark_eager_train(self,
         optimizer = tf.train.GradientDescentOptimizer(0.1)
         apply_grads = apply_gradients
         if defun:
-          model.call = tfe.defun(model.call, compiled=compiled)
-          apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+          model.call = tfe.defun(model.call)
+          apply_grads = tfe.defun(apply_gradients)
 
         num_burn = 3
         num_iters = 10
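This benchmark change repeats identically in the resnet50 and revnet benchmarks below: the `compiled` flag goes away and the wrapping becomes unconditional, since XLA use is now inferred from the device the function executes on. A self-contained sketch with a toy Keras model standing in for the benchmarked networks (an illustration, not code from this commit):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

# No `compiled=` argument: on a TPU the traced call is XLA-compiled
# automatically; on CPU/GPU it runs as an ordinary graph function.
model.call = tfe.defun(model.call)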
@@ -216,12 +216,12 @@ def _force_device_sync(self):
     tf.constant(1.).cpu()
 
   def _benchmark_eager_apply(self, label, device_and_format, defun=False,
-                             execution_mode=None, compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       model = resnet50.ResNet50(data_format)
       if defun:
-        model.call = tfe.defun(model.call, compiled=compiled)
+        model.call = tfe.defun(model.call)
       batch_size = 64
       num_burn = 5
       num_iters = 30
@@ -257,8 +257,7 @@ def _benchmark_eager_train(self,
                              make_iterator,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       for batch_size in self._train_batch_sizes():
@@ -267,8 +266,8 @@ def _benchmark_eager_train(self,
         optimizer = tf.train.GradientDescentOptimizer(0.1)
         apply_grads = apply_gradients
         if defun:
-          model.call = tfe.defun(model.call, compiled=compiled)
-          apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+          model.call = tfe.defun(model.call)
+          apply_grads = tfe.defun(apply_gradients)
 
         num_burn = 3
         num_iters = 10
@@ -226,14 +226,13 @@ def _benchmark_eager_apply(self,
                              label,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     config = config_.get_hparams_imagenet_56()
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       model = revnet.RevNet(config=config)
       if defun:
-        model.call = tfe.defun(model.call, compiled=compiled)
+        model.call = tfe.defun(model.call)
       batch_size = 64
       num_burn = 5
       num_iters = 10
@@ -271,8 +270,7 @@ def _benchmark_eager_train(self,
                              make_iterator,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     config = config_.get_hparams_imagenet_56()
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
60 changes: 25 additions & 35 deletions tensorflow/python/eager/function.py
@@ -26,7 +26,6 @@
 import numpy as np
 import six
 
-from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
@@ -102,7 +101,7 @@ class CapturingGraph(ops.Graph):
   The entries are in the order they were captured.
   """
 
-  def __init__(self):
+  def __init__(self, graph=None):
     super(CapturingGraph, self).__init__()
 
     self.captures = collections.OrderedDict()
@@ -113,6 +112,13 @@ def __init__(self):
     # for resource tensors.
     self._last_op_using_resource_tensor = {}
 
+    if context.executing_eagerly():
+      self._xla_compile = (context.context().device_spec.device_type == "TPU")
+    elif graph is not None:
+      self._xla_compile = getattr(graph, "_xla_compile", False)
+    else:
+      self._xla_compile = False
+
   # TODO(apassos) remove once the C API is used by default.
   def _use_c_api_hack(self):
     return True
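In words: under eager execution, compile via XLA exactly when the current device is a TPU; when this graph is built from another graph, inherit that graph's decision (this inheritance is what makes nested defuns work); otherwise default to off. A standalone restatement of the rule (a hypothetical helper, not part of the commit):

def should_xla_compile(executing_eagerly, device_type, outer_graph=None):
  """Mirrors the _xla_compile choice made in CapturingGraph.__init__."""
  if executing_eagerly:
    return device_type == "TPU"
  if outer_graph is not None:
    return getattr(outer_graph, "_xla_compile", False)
  return False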
@@ -207,7 +213,7 @@ def __init__(self, name, graph=None):
       graph: if specified, this FuncGraph will inherit its graph key,
         collections, and seed from `graph`.
     """
-    super(FuncGraph, self).__init__()
+    super(FuncGraph, self).__init__(graph=graph)
 
     self.name = name
     self.inputs = []
@@ -267,9 +273,6 @@ def _register(fn):
   context.context().add_function(fn)
 
 
-_xla_compile_attr = "_XlaCompile"
-
-
 # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
 # so it doesn't have the definition-generating logic and is just a container for
 # an already-defined function.
@@ -311,7 +314,6 @@ def __init__(self, name, graph, operations, inputs, outputs, attrs):
       # It might be worth creating a convenient way to re-use status.
       pywrap_tensorflow.TF_FunctionSetAttrValueProto(
           fn, compat.as_str(name), serialized)
-    self._xla_compile = _xla_compile_attr in attrs
 
     # TODO(apassos) avoid creating a FunctionDef (specially to grab the
     # signature, but also in general it's nice not to depend on it.
@@ -365,10 +367,7 @@ def call(self, ctx, args, output_shapes):
 
     executing_eagerly = ctx.executing_eagerly()
 
-    xla_compile = self._xla_compile or (executing_eagerly and
-                                        ctx.device_spec.device_type == "TPU")
-
-    if xla_compile:
+    if self._graph._xla_compile:  # pylint: disable=protected-access
       # XLA compilation relies upon a custom kernel creator to run functions.
       signature = self.signature
       if executing_eagerly:
@@ -471,6 +470,10 @@ def __init__(self,
       attrs: (optional) dict mapping names of attributes to their AttrValue
         values. Attributes in `attrs` will be included in this function's
         definition.
+
+    Raises:
+      ValueError: If number of input_placeholders is not equal to the number
+        of function inputs.
     """
     self._attrs = attrs or {}
     defined_function = _EagerDefinedFunction(
@@ -549,6 +552,7 @@ def _construct_backprop_function(self):
     operations = tuple(op for op in backwards_graph.get_operations()
                        if op not in ignored_ops)
 
+    backwards_graph._xla_compile = self._graph._xla_compile  # pylint: disable=protected-access
     self._backward_function = GraphModeFunction(
         backwards_graph.name,
         backwards_graph.inputs,
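The single added line keeps the backward pass on the same execution path as the forward pass: the backward function is traced into its own graph, whose device and graph context at trace time need not match the forward function's, so the flag is copied explicitly. A sketch of the invariant being preserved (names hypothetical):

forward_graph = CapturingGraph()    # e.g. _xla_compile == True when traced on TPU
backwards_graph = CapturingGraph()  # decided from its own trace-time context
backwards_graph._xla_compile = forward_graph._xla_compile  # force consistency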
@@ -729,14 +733,12 @@ def _get_defun_inputs_from_args(args):
   return nest.pack_sequence_as(args, function_inputs)
 
 
-def _trace_and_define_function(name, python_func, compiled, args, kwds,
-                               signature=None):
+def _trace_and_define_function(name, python_func, args, kwds, signature=None):
   """Defines and returns graph-mode version of `python_func`.
 
   Args:
     name: an identifier for the function.
     python_func: the Python function to trace.
-    compiled: whether the graph function should be compiled through XLA.
     args: the positional args with which the Python function should be called;
       ignored if a signature is provided.
     kwds: the keyword args with which the Python function should be called;
@@ -854,8 +856,6 @@ def check_mutation(n1, n2):
   _register(f._c_func.func)  # pylint: disable=protected-access
 
   attrs = {}
-  if compiled:
-    attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
 
   return GraphModeFunction(
       func_graph.name, func_graph.inputs, func_graph.captures.keys(),
@@ -924,8 +924,7 @@ class _PolymorphicFunction(object):
   def __init__(self,
                python_function,
                name,
-               input_signature=None,
-               compiled=False):
+               input_signature=None):
     """Initializes a polymorphic function.
 
     Args:
@@ -934,7 +933,6 @@ def __init__(self,
       input_signature: a possibly nested sequence of `TensorSpec` objects
         specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
-      compiled: if True, the framework will attempt to compile func with XLA.
 
     Raises:
       ValueError: if `input_signature` is not None and the `python_function`'s
@@ -953,7 +951,6 @@ def __init__(self,
     self._args_to_prepend = tuple()
     self._kwds_to_include = {}
     self._name = name
-    self._compiled = compiled
     self._arguments_to_functions = {}
     self._variables = []
 
@@ -1119,7 +1116,7 @@ def _maybe_define_function(self, *args, **kwds):
 
     if graph_function is None:
       graph_function = _trace_and_define_function(
-          self._name, self._python_function, self._compiled, args, kwds,
+          self._name, self._python_function, args, kwds,
           self._input_signature)
       self._variables.extend(
           [v for v in graph_function.variables if v not in self._variables])
@@ -1141,10 +1138,7 @@ def variables(self):
     return self._variables
 
 
-# TODO(akshayka): Remove the `compiled` flag and create a separate
-# API for xla compilation (`defun` is already complicated enough
-# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, input_signature=None, compiled=False):
+def defun(func=None, input_signature=None):
   """Compiles a Python function into a callable TensorFlow graph.
 
   `defun` (short for "define function") trace-compiles a Python function
@@ -1435,9 +1429,10 @@ def fn():
     func: function to be compiled. If `func` is None, returns a
       decorator that can be invoked with a single argument - `func`. The
       end result is equivalent to providing all the arguments up front.
-      In other words, defun(compiled=True)(func) is equivalent to
-      defun(func, compiled=True). The former allows the following use case:
-        @tf.contrib.eager.defun(compiled=True)
+      In other words, defun(input_signature=...)(func) is equivalent to
+      defun(func, input_signature=...). The former allows
+      the following use case:
+        @tf.contrib.eager.defun(input_signature=...)
       def foo(...):
         ...
@@ -1448,11 +1443,6 @@ def foo(...):
       signature is specified, every input to `func` must be a `Tensor`, and
       `func` cannot accept `**kwargs`.
-    compiled: If True, an attempt to compile `func` with XLA will be made.
-      If it fails, function will be run normally. Experimental. Currently
-      supported only for execution on TPUs. For the vast majority of users,
-      this argument should be False.
-
   Returns:
     If `func` is not None, returns a callable that will execute the compiled
     function (and return zero or more `tf.Tensor` objects).
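For reference, the `input_signature` usage the rewritten docstring now points at looks roughly like this (a sketch; it assumes `TensorSpec` is exposed as `tf.contrib.eager.TensorSpec` in this era of the API):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

@tfe.defun(input_signature=[tfe.TensorSpec(shape=[None], dtype=tf.float32)])
def scale(x):
  # One concrete function is traced for this signature, rather than one
  # per distinct concrete shape seen at call time.
  return 2. * x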
@@ -1468,7 +1458,7 @@ def decorated(function):
     return tf_decorator.make_decorator(
         function,
         _PolymorphicFunction(
-            function, name, input_signature=input_signature, compiled=compiled))
+            function, name, input_signature=input_signature))
 
   # This code path is for the `foo = tfe.defun(foo, ...)` use case
   if func is not None:
@@ -1526,7 +1516,7 @@ def g(x, y):
   and which can be called directly the way a `@defun` wrapped function
   can.
   """
-  return _trace_and_define_function(func.__name__, func, False, args, kwds)
+  return _trace_and_define_function(func.__name__, func, args, kwds)
 
 
 class AutomaticControlDependencies(object):
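Taken together, the user-facing change: callers no longer opt into XLA with a flag; device placement decides, and the decision now propagates into nested defuns. A before/after sketch (hedged; the TPU device string is an assumption):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def f(x):
  return x + 1.

# Before this commit: f = tfe.defun(f, compiled=True)   # explicit opt-in
f = tfe.defun(f)  # after: no flag

with tf.device("/device:TPU:0"):  # assumed TPU device string
  y = f(tf.constant(1.))  # dispatched through the XLA kernel creator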
