diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index ff097f80f1f258..3d21fb5864c22a 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -443,7 +443,6 @@ def f(x, y):
     self.assertAllEqual((2, 3, 4), dz.shape.as_list())
 
   def testNestedDefun(self):
-    self.skipTest('Nested defuns do not work on TPU at the moment')
     with self.test_scope():
 
       @function.defun
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
index 0736ed02b74372..e5058bfd9480e2 100644
--- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
@@ -218,7 +218,7 @@ def _force_device_sync(self):
     tf.constant(1.).cpu()
 
   def _benchmark_eager_apply(self, label, device_and_format, defun=False,
-                             execution_mode=None, compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks,
@@ -228,7 +228,7 @@ def _benchmark_eager_apply(self, label, device_and_format, defun=False,
                                 weight_decay=1e-4, dropout_rate=0,
                                 pool_initial=True, include_top=True)
       if defun:
-        model.call = tfe.defun(model.call, compiled=compiled)
+        model.call = tfe.defun(model.call)
       batch_size = 64
       num_burn = 5
       num_iters = 30
@@ -264,8 +264,7 @@ def _benchmark_eager_train(self,
                              make_iterator,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       for batch_size in self._train_batch_sizes():
@@ -279,8 +278,8 @@ def _benchmark_eager_train(self,
         optimizer = tf.train.GradientDescentOptimizer(0.1)
         apply_grads = apply_gradients
         if defun:
-          model.call = tfe.defun(model.call, compiled=compiled)
-          apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+          model.call = tfe.defun(model.call)
+          apply_grads = tfe.defun(apply_gradients)
 
         num_burn = 3
         num_iters = 10
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 07d8788882c2d8..d265169b5eff68 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -216,12 +216,12 @@ def _force_device_sync(self):
     tf.constant(1.).cpu()
 
   def _benchmark_eager_apply(self, label, device_and_format, defun=False,
-                             execution_mode=None, compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       model = resnet50.ResNet50(data_format)
       if defun:
-        model.call = tfe.defun(model.call, compiled=compiled)
+        model.call = tfe.defun(model.call)
       batch_size = 64
       num_burn = 5
       num_iters = 30
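Taken together, the benchmark changes above reduce every call site to the plain one-argument form of `tfe.defun`. A minimal sketch of the resulting pattern, using a hypothetical toy model rather than the DenseNet/ResNet50 benchmarks (only `tfe.defun`'s signature is taken from this diff):

import tensorflow as tf
import tensorflow.contrib.eager as tfe  # TF 1.x contrib API, as in these tests

tf.enable_eager_execution()

class TinyModel(tf.keras.Model):
  """Hypothetical stand-in for the benchmark models."""

  def __init__(self):
    super(TinyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(10)

  def call(self, x):
    return self.dense(x)

model = TinyModel()
# Before this change: model.call = tfe.defun(model.call, compiled=compiled)
# After it, the flag is gone; whether XLA kicks in now follows the device.
model.call = tfe.defun(model.call)
print(model.call(tf.zeros([2, 5])).shape)  # (2, 10)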
@@ -257,8 +257,7 @@ def _benchmark_eager_train(self,
                              make_iterator,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       for batch_size in self._train_batch_sizes():
@@ -267,8 +266,8 @@ def _benchmark_eager_train(self,
         optimizer = tf.train.GradientDescentOptimizer(0.1)
         apply_grads = apply_gradients
         if defun:
-          model.call = tfe.defun(model.call, compiled=compiled)
-          apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+          model.call = tfe.defun(model.call)
+          apply_grads = tfe.defun(apply_gradients)
 
         num_burn = 3
         num_iters = 10
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 84b2ddf0de0739..6a921e19978fdf 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -226,14 +226,13 @@ def _benchmark_eager_apply(self,
                              label,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     config = config_.get_hparams_imagenet_56()
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
       model = revnet.RevNet(config=config)
       if defun:
-        model.call = tfe.defun(model.call, compiled=compiled)
+        model.call = tfe.defun(model.call)
       batch_size = 64
       num_burn = 5
       num_iters = 10
@@ -271,8 +270,7 @@ def _benchmark_eager_train(self,
                              make_iterator,
                              device_and_format,
                              defun=False,
-                             execution_mode=None,
-                             compiled=False):
+                             execution_mode=None):
     config = config_.get_hparams_imagenet_56()
     with tfe.execution_mode(execution_mode):
       device, data_format = device_and_format
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 189eb800768037..5cfa8895acc4ff 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -26,7 +26,6 @@
 import numpy as np
 import six
 
-from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
@@ -102,7 +101,7 @@ class CapturingGraph(ops.Graph):
   The entries are in the order they were captured.
   """
 
-  def __init__(self):
+  def __init__(self, graph=None):
     super(CapturingGraph, self).__init__()
 
     self.captures = collections.OrderedDict()
@@ -113,6 +112,13 @@ def __init__(self):
     # for resource tensors.
     self._last_op_using_resource_tensor = {}
 
+    if context.executing_eagerly():
+      self._xla_compile = (context.context().device_spec.device_type == "TPU")
+    elif graph is not None:
+      self._xla_compile = getattr(graph, "_xla_compile", False)
+    else:
+      self._xla_compile = False
+
   # TODO(apassos) remove once the C API is used by default.
   def _use_c_api_hack(self):
     return True
@@ -207,7 +213,7 @@ def __init__(self, name, graph=None):
       graph: if specified, this FuncGraph will inherit its graph key,
         collections, and seed from `graph`.
     """
-    super(FuncGraph, self).__init__()
+    super(FuncGraph, self).__init__(graph=graph)
 
     self.name = name
     self.inputs = []
@@ -267,9 +273,6 @@ def _register(fn):
   context.context().add_function(fn)
 
 
-_xla_compile_attr = "_XlaCompile"
-
-
 # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
 # so it doesn't have the definition-generating logic and is just a container for
 # an already-defined function.
@@ -311,7 +314,6 @@ def __init__(self, name, graph, operations, inputs, outputs, attrs):
       # It might be worth creating a convenient way to re-use status.
       pywrap_tensorflow.TF_FunctionSetAttrValueProto(
           fn, compat.as_str(name), serialized)
-    self._xla_compile = _xla_compile_attr in attrs
 
     # TODO(apassos) avoid creating a FunctionDef (specially to grab the
     # signature, but also in general it's nice not to depend on it.
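The new constructor logic above is the heart of the change: instead of a per-function `compiled` attribute, the XLA decision is computed once per graph and inherited by nested function graphs. A standalone sketch of that three-way rule (the helper name and arguments are illustrative; only `_xla_compile` comes from the diff):

def infer_xla_compile(executing_eagerly, device_type, parent_graph=None):
  """Mirrors the three branches added to CapturingGraph.__init__."""
  if executing_eagerly:
    # Eagerly-created function graphs compile through XLA only on TPU.
    return device_type == "TPU"
  if parent_graph is not None:
    # A nested FuncGraph inherits the flag from its outer graph, which is
    # what lets testNestedDefun run on TPU again.
    return getattr(parent_graph, "_xla_compile", False)
  return False

assert infer_xla_compile(True, "TPU")
assert not infer_xla_compile(True, "CPU")
assert not infer_xla_compile(False, "CPU")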
@@ -365,10 +367,7 @@ def call(self, ctx, args, output_shapes):
 
     executing_eagerly = ctx.executing_eagerly()
 
-    xla_compile = self._xla_compile or (executing_eagerly and
-                                        ctx.device_spec.device_type == "TPU")
-
-    if xla_compile:
+    if self._graph._xla_compile:  # pylint: disable=protected-access
       # XLA compilation relies upon a custom kernel creator to run functions.
       signature = self.signature
       if executing_eagerly:
@@ -471,6 +470,10 @@ def __init__(self,
       attrs: (optional) dict mapping names of attributes to their AttrValue
         values. Attributes in `attrs` will be included in this function's
         definition.
+
+    Raises:
+      ValueError: If number of input_placeholders is not equal to the number
+        of function inputs.
     """
     self._attrs = attrs or {}
     defined_function = _EagerDefinedFunction(
@@ -549,6 +552,7 @@ def _construct_backprop_function(self):
     operations = tuple(op for op in backwards_graph.get_operations()
                        if op not in ignored_ops)
 
+    backwards_graph._xla_compile = self._graph._xla_compile  # pylint: disable=protected-access
     self._backward_function = GraphModeFunction(
         backwards_graph.name,
         backwards_graph.inputs,
@@ -729,14 +733,12 @@ def _get_defun_inputs_from_args(args):
   return nest.pack_sequence_as(args, function_inputs)
 
 
-def _trace_and_define_function(name, python_func, compiled, args, kwds,
-                               signature=None):
+def _trace_and_define_function(name, python_func, args, kwds, signature=None):
   """Defines and returns graph-mode version of `python_func`.
 
   Args:
     name: an identifier for the function.
     python_func: the Python function to trace.
-    compiled: whether the graph function should be compiled through XLA.
     args: the positional args with which the Python function should be called;
       ignored if a signature is provided.
     kwds: the keyword args with which the Python function should be called;
@@ -854,8 +856,6 @@ def check_mutation(n1, n2):
   _register(f._c_func.func)  # pylint: disable=protected-access
 
   attrs = {}
-  if compiled:
-    attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
 
   return GraphModeFunction(
       func_graph.name, func_graph.inputs, func_graph.captures.keys(),
@@ -924,8 +924,7 @@ class _PolymorphicFunction(object):
   def __init__(self,
                python_function,
                name,
-               input_signature=None,
-               compiled=False):
+               input_signature=None):
     """Initializes a polymorphic function.
 
     Args:
@@ -934,7 +933,6 @@ def __init__(self,
       input_signature: a possibly nested sequence of `TensorSpec` objects
         specifying the input signature of this function. If `None`, a separate
         function is instantiated for each inferred input signature.
-      compiled: if True, the framework will attempt to compile func with XLA.
 
     Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
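Because `_construct_backprop_function` now stamps `_xla_compile` onto the backwards graph, a defun'd forward pass and its gradient are compiled consistently on TPU. A hedged usage sketch (toy function; the tape API shown is the standard TF 1.x eager one, not new to this change):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

@tfe.defun
def square(x):
  return x * x

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = square(x)
# The backward function traced for `square` now carries the same
# _xla_compile setting as the forward graph.
print(tape.gradient(y, x))  # 6.0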
@@ -953,7 +951,6 @@ def __init__(self,
     self._args_to_prepend = tuple()
     self._kwds_to_include = {}
     self._name = name
-    self._compiled = compiled
     self._arguments_to_functions = {}
     self._variables = []
 
@@ -1119,7 +1116,7 @@ def _maybe_define_function(self, *args, **kwds):
 
     if graph_function is None:
       graph_function = _trace_and_define_function(
-          self._name, self._python_function, self._compiled, args, kwds,
+          self._name, self._python_function, args, kwds,
           self._input_signature)
       self._variables.extend(
           [v for v in graph_function.variables if v not in self._variables])
@@ -1141,10 +1138,7 @@ def variables(self):
     return self._variables
 
 
-# TODO(akshayka): Remove the `compiled` flag and create a separate
-# API for xla compilation (`defun` is already complicated enough
-# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, input_signature=None, compiled=False):
+def defun(func=None, input_signature=None):
   """Compiles a Python function into a callable TensorFlow graph.
 
   `defun` (short for "define function") trace-compiles a Python function
@@ -1435,9 +1429,10 @@ def fn():
     func: function to be compiled. If `func` is None, returns a
       decorator that can be invoked with a single argument - `func`. The
       end result is equivalent to providing all the arguments up front.
-      In other words, defun(compiled=True)(func) is equivalent to
-      defun(func, compiled=True). The former allows the following use case:
-        @tf.contrib.eager.defun(compiled=True)
+      In other words, defun(input_signature=...)(func) is equivalent to
+      defun(func, input_signature=...). The former allows
+      the following use case:
+        @tf.contrib.eager.defun(input_signature=...)
         def foo(...):
           ...
 
@@ -1448,11 +1443,6 @@ def foo(...):
       signature is specified, every input to `func` must be a `Tensor`, and
       `func` cannot accept `**kwargs`.
 
-    compiled: If True, an attempt to compile `func` with XLA will be made.
-      If it fails, function will be run normally. Experimental. Currently
-      supported only for execution on TPUs. For the vast majority of users,
-      this argument should be False.
-
   Returns:
     If `func` is not None, returns a callable that will execute the compiled
     function (and return zero or more `tf.Tensor` objects).
@@ -1468,7 +1458,7 @@ def decorated(function):
     return tf_decorator.make_decorator(
         function,
         _PolymorphicFunction(
-            function, name, input_signature=input_signature, compiled=compiled))
+            function, name, input_signature=input_signature))
 
   # This code path is for the `foo = tfe.defun(foo, ...)` use case
   if func is not None:
@@ -1526,7 +1516,7 @@ def g(x, y):
   and which can be called directly the way a `@defun` wrapped function
   can.
   """
-  return _trace_and_define_function(func.__name__, func, False, args, kwds)
+  return _trace_and_define_function(func.__name__, func, args, kwds)
 
 
 class AutomaticControlDependencies(object):
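With `compiled` gone, the docstring's decorator-equivalence example now uses `input_signature` instead. A sketch of the two equivalent spellings it describes (assuming the contrib-era `tfe.TensorSpec` export; treat that import path as an assumption):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

spec = [tfe.TensorSpec(shape=[None], dtype=tf.float32)]  # assumed export path

# Decorator form: defun(input_signature=...)(func)
@tfe.defun(input_signature=spec)
def foo(x):
  return 2.0 * x

# Call form: defun(func, input_signature=...) -- equivalent, per the docstring.
def bar(x):
  return 2.0 * x
bar = tfe.defun(bar, input_signature=spec)

Callers migrating off `compiled=True` rely on device placement instead: running the defun'd function under a TPU device now triggers XLA compilation automatically.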