Type annotations to improve the tracing process of tf.function #40901

Merged
3 changes: 2 additions & 1 deletion RELEASE.md
@@ -30,7 +30,8 @@
* `tf.keras`:
* <ADD RELEASE NOTES HERE>
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* Added `experimental_follow_type_hints` argument. When True, the function may use type
annotations to optimize tracing performance.
* `tf.lite`:
* <ADD RELEASE NOTES HERE>
* `tf.random`:
41 changes: 36 additions & 5 deletions tensorflow/python/eager/def_function.py
@@ -456,7 +456,8 @@ def __init__(self,
experimental_implements=None,
experimental_autograph_options=None,
experimental_relax_shapes=False,
experimental_compile=None):
experimental_compile=None,
experimental_follow_type_hints=False):
"""Initializes a `Function`.

Args:
@@ -512,14 +513,17 @@ def embedding_matmul(a, b):
executor). Set this value to `False` when directly running a
multi-device function on TPUs (e.g. two TPU cores, one TPU core and its
host CPU).
experimental_follow_type_hints: See the documentation for `tf.function`.

Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
argspec has keyword arguments.
"""
self._lock = threading.Lock()
self._python_function = python_function
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
python_function, input_signature)
python_function, input_signature,
experimental_follow_type_hints=experimental_follow_type_hints)
self._implements = experimental_implements
# If `True`, the function uses the rendezvous of the parent. This is only
# needed to support code where raw send/recv operations are inserted and
@@ -529,6 +533,7 @@ def embedding_matmul(a, b):
self._experimental_autograph_options = experimental_autograph_options
self._experimental_relax_shapes = experimental_relax_shapes
self._experimental_compile = experimental_compile
self._experimental_follow_type_hints = experimental_follow_type_hints
self._created_variables = None # GUARDED_BY(self._lock)
self._stateful_fn = None # GUARDED_BY(self._lock)
self._stateless_fn = None # GUARDED_BY(self._lock)
@@ -658,6 +663,7 @@ def _defun(self, fn):
autograph=self._autograph,
experimental_autograph_options=self._experimental_autograph_options,
experimental_compile=self._experimental_compile,
experimental_follow_type_hints=self._experimental_follow_type_hints,
experimental_relax_shapes=self._experimental_relax_shapes)

def _initialize(self, args, kwds, add_initializers_to=None):
@@ -716,7 +722,8 @@ def _clone(self, python_function):
experimental_implements=self._implements,
experimental_autograph_options=self._experimental_autograph_options,
experimental_relax_shapes=self._experimental_relax_shapes,
experimental_compile=self._experimental_compile)
experimental_compile=self._experimental_compile,
experimental_follow_type_hints=self._experimental_follow_type_hints)

if self._shared_rendezvous:
f._shared_rendezvous = self._shared_rendezvous # pylint: disable=protected-access
@@ -1203,7 +1210,8 @@ def function(func=None,
experimental_implements=None,
experimental_autograph_options=None,
experimental_relax_shapes=False,
experimental_compile=None):
experimental_compile=None,
experimental_follow_type_hints=False):
"""Compiles a function into a callable TensorFlow graph.

`tf.function` constructs a callable that executes a TensorFlow graph
@@ -1366,6 +1374,24 @@ def function(func=None,
In general, it is recommended to create stateful objects like `tf.Variable`
outside of `tf.function` and pass them as arguments.

_Using type annotations to improve performance_

`experimental_follow_type_hints` can be used along with type annotations to
improve performance by reducing the number of expensive graph retracings.
For example, an argument annotated with `tf.Tensor` is converted to a Tensor
even when the input is a non-Tensor value.

>>> @tf.function(
... experimental_follow_type_hints=True)
... def f(x: tf.Tensor):
... print('Tracing!')
... tf.print('Executing')
>>> f(1)
Tracing!
Executing
>>> f(2)
Executing
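
For comparison, without `experimental_follow_type_hints` the same calls would
each trigger a trace, since distinct Python scalar values become part of the
trace cache key:

>>> @tf.function
... def g(x):
...   print('Tracing!')
...   tf.print('Executing')
>>> g(1)
Tracing!
Executing
>>> g(2)
Tracing!
Executing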

Args:
func: the function to be compiled. If `func` is None, `tf.function` returns
a decorator that can be invoked with a single argument - `func`. In other
@@ -1406,6 +1432,10 @@ def function(func=None,
experimental_compile: If True, the function is always compiled by
[XLA](https://www.tensorflow.org/xla). XLA may be more efficient in some
cases (e.g. TPU, XLA_GPU, dense tensor computations).
experimental_follow_type_hints: When True, the function may use type
annotations to optimize tracing performance. For example, arguments
annotated with `tf.Tensor` will automatically be converted to a Tensor.

Returns:
If `func` is not None, returns a callable that will execute the compiled
@@ -1436,7 +1466,8 @@ def decorated(inner_function):
experimental_autograph_options=experimental_autograph_options,
experimental_relax_shapes=experimental_relax_shapes,
experimental_compile=experimental_compile,
experimental_implements=experimental_implements))
experimental_implements=experimental_implements,
experimental_follow_type_hints=experimental_follow_type_hints))

# This code path is for the `foo = tf.function(foo, ...)` use case
if func is not None:
63 changes: 59 additions & 4 deletions tensorflow/python/eager/function.py
@@ -2313,7 +2313,8 @@ class FunctionSpec(object):

@staticmethod
def from_function_and_signature(python_function, input_signature,
is_pure=False):
is_pure=False,
experimental_follow_type_hints=False):
"""Create a FunctionSpec instance given a python function and signature.

Args:
@@ -2398,13 +2399,17 @@ def from_function_and_signature(python_function, input_signature,
name = getattr(python_function, "__name__", "f")

return FunctionSpec(
fullargspec, is_method, input_signature, is_pure=is_pure, name=name)
fullargspec, is_method, input_signature,
is_pure=is_pure,
experimental_follow_type_hints=experimental_follow_type_hints,
name=name)

def __init__(self,
fullargspec,
is_method,
input_signature,
is_pure=False,
experimental_follow_type_hints=False,
name=None):
"""Constructs a FunctionSpec describing a python function.

@@ -2419,6 +2424,7 @@ def __init__(self,
self._fullargspec = fullargspec
self._is_method = is_method
self._is_pure = is_pure
self._experimental_follow_type_hints = experimental_follow_type_hints

# TODO(edloper): Include name when serializing for SavedModel?
self._name = name or "f"
@@ -2487,6 +2493,10 @@ def flat_input_signature(self):
def is_pure(self):
return self._is_pure

@property
def experimental_follow_type_hints(self):
return self._experimental_follow_type_hints

@property
def arg_names(self):
return self._arg_names
@@ -2525,6 +2535,43 @@ def _convert_variables_to_tensors(self, args, kwargs):
kwargs = {kw: ops.convert_to_tensor(x) for kw, x in kwargs.items()}
return tuple(args), kwargs

def _convert_annotated_args_to_tensors(self, args, kwargs):
if self.input_signature is not None:
return args, kwargs

args = list(args)
for i, arg in enumerate(args):
Review comment:

Optional, if you get the chance: it would be ideal to disentangle these blocks and separate them into stages:

  1. one set of loops to find type annotations, if any, resulting in an arg_annos and a kwarg_annos
  2. another set of loops to convert arguments as necessary

That way the logic is easier to follow and verify, and you avoid duplicating the "if annotation == ops.Tensor: convert" logic, which is expected to become more complex over time. (A sketch of this split appears after the method below.)

# See https://docs.python.org/3/library/inspect.html#inspect.getfullargspec for details on fullargspec
if i < len(self._fullargspec.args):
arg_annotation = self._fullargspec.annotations.get(
self._fullargspec.args[i])
# TODO(rahulkamat): Once TensorLike is ready, change the following conditional statements
# to check if the input arg is annotated with TensorLike
if arg_annotation == ops.Tensor:
args[i] = ops.convert_to_tensor(arg)
else:
varargs_annotation = self._fullargspec.annotations.get(
self._fullargspec.varargs)
if varargs_annotation == ops.Tensor:
args[i] = ops.convert_to_tensor(arg)

for kw, v in kwargs.items():
if kw in self._fullargspec.kwonlyargs:
kwonlyarg_annotation = self._fullargspec.annotations.get(kw)
if kwonlyarg_annotation == ops.Tensor:
kwargs[kw] = ops.convert_to_tensor(v)
elif self._fullargspec.varkw is not None:
varkw_annotation = self._fullargspec.annotations.get(
self._fullargspec.varkw)
if kw in self._fullargspec.args:
arg_annotation = self._fullargspec.annotations.get(kw)
if arg_annotation == ops.Tensor:
kwargs[kw] = ops.convert_to_tensor(v)
elif varkw_annotation == ops.Tensor:
kwargs[kw] = ops.convert_to_tensor(v)

return tuple(args), kwargs
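
A rough sketch of the two-stage split suggested in the review comment above
(the _collect_annotations helper is hypothetical and not part of this PR; it
reuses the same fullargspec fields and ops.convert_to_tensor call as the
method above):

def _collect_annotations(self, args, kwargs):
  # Stage 1: resolve the effective annotation for every argument.
  annos = self._fullargspec.annotations
  arg_annos = []
  for i in range(len(args)):
    if i < len(self._fullargspec.args):
      arg_annos.append(annos.get(self._fullargspec.args[i]))
    else:
      # Extra positional arguments fall under *varargs.
      arg_annos.append(annos.get(self._fullargspec.varargs))
  kwarg_annos = {}
  for kw in kwargs:
    if kw in self._fullargspec.args or kw in self._fullargspec.kwonlyargs:
      kwarg_annos[kw] = annos.get(kw)
    else:
      # Unrecognized keywords fall under **varkw.
      kwarg_annos[kw] = annos.get(self._fullargspec.varkw)
  return arg_annos, kwarg_annos

def _convert_annotated_args_to_tensors(self, args, kwargs):
  if self.input_signature is not None:
    return args, kwargs
  # Stage 2: a single "annotated as Tensor -> convert" rule for all inputs.
  arg_annos, kwarg_annos = self._collect_annotations(args, kwargs)
  args = tuple(ops.convert_to_tensor(a) if anno == ops.Tensor else a
               for a, anno in zip(args, arg_annos))
  kwargs = {kw: ops.convert_to_tensor(v)
               if kwarg_annos[kw] == ops.Tensor else v
            for kw, v in kwargs.items()}
  return args, kwargs

This keeps the conversion rule in one place, so extending it (for example to
a future TensorLike check, per the TODO above) would only touch stage 2.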

def canonicalize_function_inputs(self, *args, **kwargs):
"""Canonicalizes `args` and `kwargs`.

@@ -2557,6 +2604,8 @@ def canonicalize_function_inputs(self, *args, **kwargs):
"""
if self._is_pure:
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
if self._experimental_follow_type_hints:
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
if self._input_signature is not None:
if len(args) > len(self._input_signature):
raise TypeError("{} takes {} positional arguments (as specified by the "
@@ -2788,7 +2837,8 @@ def __init__(self,
autograph_options=None,
experimental_relax_shapes=False,
capture_by_value=None,
experimental_compile=None):
experimental_compile=None,
experimental_follow_type_hints=False):
"""Initializes a `Function`.

Args:
@@ -2812,6 +2862,7 @@ def __init__(self,
default to False.
experimental_compile: Force-compile the function with XLA, cf.
def_function.Function doc on experimental_compile.
experimental_follow_type_hints: See the documentation for `tf.function`.

Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -2820,7 +2871,8 @@ def __init__(self,
self._python_function = python_function
pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes
self._function_spec = FunctionSpec.from_function_and_signature(
python_function, input_signature, is_pure=pure_function)
python_function, input_signature, is_pure=pure_function,
experimental_follow_type_hints=experimental_follow_type_hints)
self._name = name
self._autograph = autograph
self._autograph_options = autograph_options
@@ -2836,6 +2888,7 @@ def __init__(self,
# functions for each instance.
self._descriptor_cache = weakref.WeakKeyDictionary()
self._experimental_compile = experimental_compile
self._experimental_follow_type_hints = experimental_follow_type_hints

def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
@@ -3612,6 +3665,7 @@ def defun_with_attributes(func=None,
autograph=True,
experimental_autograph_options=None,
experimental_compile=None,
experimental_follow_type_hints=False,
experimental_relax_shapes=False):
"""Compiles a Python function into a callable TensorFlow graph.

@@ -3661,6 +3715,7 @@ def decorated(function):
autograph=autograph,
autograph_options=experimental_autograph_options,
experimental_compile=experimental_compile,
experimental_follow_type_hints=experimental_follow_type_hints,
experimental_relax_shapes=experimental_relax_shapes))

# This code path is for the `foo = tfe.defun(foo, ...)` use case