Skip to content

Commit

Permalink
Always convert function calls dynamically. This greatly simplifies th…
Browse files Browse the repository at this point in the history
…e implementation at the cost of less readability of the output.

Also included is a cleanup of the tests that had to be updated due to this change. Addresses #25281, #24759.
Lastly, the CL enables the automatic fallback on compilation error.

PiperOrigin-RevId: 232135777
  • Loading branch information
Dan Moldovan authored and tensorflower-gardener committed Feb 2, 2019
1 parent f22fc30 commit 14c78f9
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 524 deletions.
22 changes: 10 additions & 12 deletions tensorflow/examples/autograph/integration_tests/keras_test.py
Expand Up @@ -87,18 +87,16 @@ def test_conditional_attributes_True(self):

@test_util.run_deprecated_v1
def test_recursive_true(self):
with self.assertRaisesRegexp(NotImplementedError,
'Object conversion is not yet supported.'):
with tf.Graph().as_default():
model = CompoundModel()
model.build(tf.TensorShape((None, 10, 10, 1)))
init = tf.global_variables_initializer()

with tf.Session() as sess:
self.evaluate(init)
sample_input = tf.random_uniform((1, 10, 10, 1))
output = model(sample_input) # pylint: disable=not-callable
self.assertEqual(self.evaluate(output).shape, (1, 3))
with tf.Graph().as_default():
model = CompoundModel()
model.build(tf.TensorShape((None, 10, 10, 1)))
init = tf.global_variables_initializer()

with tf.Session() as sess:
self.evaluate(init)
sample_input = tf.random_uniform((1, 10, 10, 1))
output = model(sample_input) # pylint: disable=not-callable
self.assertEqual(self.evaluate(output).shape, (1, 3))


if __name__ == '__main__':
Expand Down
293 changes: 25 additions & 268 deletions tensorflow/python/autograph/converters/call_trees.py
Expand Up @@ -22,231 +22,46 @@
from __future__ import division
from __future__ import print_function

import collections

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.util import tf_inspect


class FunctionInfo(collections.namedtuple('FunctionInfo', ('dtype',))):
pass


# TODO(mdan): Move this to a separate transformer.
KNOWN_NUMPY_FUNCTIONS = {
('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'),
}


# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer.


class FunctionNamer(object):
"""Describes the interface for CallTreeTransformer's namer."""

def compiled_function_name(self,
original_fqn,
live_entity=None,
owner_type=None):
"""Generate the name corresponding to the compiled version of a function.
Args:
original_fqn: string or tuple(string)
live_entity: Callable, the actual target function, if known.
owner_type: Optional object. If present, it indicates that the function is
a member of the given type.
Returns:
string, bool
"""
raise NotImplementedError()

def compiled_class_name(self, original_fqn, live_entity=None):
"""Generate the name corresponding to the compiled version of a class.
Args:
original_fqn: string or tuple(string)
live_entity: The actual target class, if known.
Returns:
string
"""
raise NotImplementedError()


# TODO(mdan): Rename to CallsTransformer.
# TODO(mdan): Rename to FunctionCallsTransformer.


class CallTreeTransformer(converter.Base):
"""Transforms the call tree by renaming transformed symbols."""

def _resolve_decorator_name(self, node):
"""Used to resolve decorator info."""
if isinstance(node, gast.Call):
return self._resolve_decorator_name(node.func)
if isinstance(node, gast.Name):
# TODO(mdan): Add test coverage for this branch.
return self.ctx.info.namespace.get(node.id)
if isinstance(node, gast.Attribute):
parent = self._resolve_decorator_name(node.value)
if parent is not None:
return getattr(parent, node.attr)
return None
raise ValueError(node)

def _try_resolve_target(self, node):
"""Works for methods of objects of known type."""
if anno.hasanno(node, 'live_val'):
return anno.getanno(node, 'live_val')
if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
owner_type = anno.getanno(node, 'type')
if hasattr(owner_type, node.attr):
return getattr(owner_type, node.attr)
else:
# TODO(mdan): We should probably return None here rather than an error.
raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
(owner_type, node.attr))
return None

def _function_is_compilable(self, target_entity):
"""Determines whether an entity can be compiled at all."""
# TODO(mdan): Expand.

if target_entity.__module__ is None:
# Functions like builtins and NumPy don't expose a module.
# Those in general should not be compiled.
return False

if inspect_utils.isbuiltin(target_entity):
return False

if inspect_utils.isnamedtuple(target_entity):
# namedtuple doesn't expose its source code, making it uncompilable.
return False

return True

def _should_compile(self, node, fqn):
"""Determines whether an entity should be compiled in the context."""
# TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
module_name = fqn[0]
for mod in self.ctx.program.uncompiled_modules:
if module_name.startswith(mod[0] + '.'):
return False

for i in range(1, len(fqn)):
if fqn[:i] in self.ctx.program.uncompiled_modules:
return False

target_entity = self._try_resolve_target(node.func)

if target_entity is not None:

# Currently, lambdas are always converted.
# TODO(mdan): Allow markers of the kind f = ag.do_not_convert(lambda: ...)
if inspect_utils.islambda(target_entity):
return True

# This may be reached when "calling" a callable attribute of an object.
# For example:
#
# self.fc = tf.keras.layers.Dense()
# self.fc()
#
for mod in self.ctx.program.uncompiled_modules:
if target_entity.__module__.startswith(mod[0] + '.'):
return False

# Inspect the target function decorators. If any include a @convert
# or @do_not_convert annotation, then they must be called as they are.
# TODO(mdan): This may be quite heavy. Perhaps always dynamically convert?
# To parse and re-analyze each function for every call site could be quite
# wasteful. Maybe we could cache the parsed AST?
try:
target_node, _ = parser.parse_entity(target_entity)
target_node = target_node.body[0]
except TypeError:
# Functions whose source we cannot access are compilable (e.g. wrapped
# to py_func).
return True

# This attribute is set when the decorator was applied before the
# function was parsed. See api.py.
if hasattr(target_entity, '__ag_compiled'):
return False

for dec in target_node.decorator_list:
decorator_fn = self._resolve_decorator_name(dec)
if (decorator_fn is not None and
self.ctx.program.options.should_strip(decorator_fn)):
return False

return True

def _rename_compilable_function(self, node):
assert anno.hasanno(node.func, 'live_val')
assert anno.hasanno(node.func, 'fqn')
target_entity = anno.getanno(node.func, 'live_val')
target_fqn = anno.getanno(node.func, 'fqn')

if anno.hasanno(node, 'is_constructor'):
new_name = self.ctx.namer.compiled_class_name(
target_fqn, live_entity=target_entity)
do_rename = True
else:
if anno.hasanno(node.func, 'parent_type'):
owner_type = anno.getanno(node.func, 'parent_type')
else:
# Fallback - not reliable.
owner_type = inspect_utils.getmethodclass(target_entity)
new_name, do_rename = self.ctx.namer.compiled_function_name(
target_fqn, live_entity=target_entity, owner_type=owner_type)
def visit_FunctionDef(self, node):
node.args = self.visit(node.args)
node.body = self.visit_block(node.body)
# TODO(mdan): Is this correct for local functions?
node.decorator_list = []
if node.returns:
node.returns = self.visit(node.returns)
return node

if do_rename:
if target_entity is not None:
if tf_inspect.ismethod(target_entity):
# The renaming process will transform it into a regular function.
# TODO(mdan): Is this complete? How does it work with nested members?
node.args = [node.func.value] + node.args
node.func = templates.replace_as_expression(
'func_name', func_name=new_name)
def visit_With(self, node):
# Context manager calls (in node.items) are not converted.
node.body = self.visit_block(node.body)
return node

def _wrap_to_py_func_single_return(self, node, dtype):
# TODO(mdan): Properly handle varargs, etc.
template = """
ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
"""
return templates.replace_as_expression(
template,
func=node.func,
dtype=parser.parse_expression(dtype),
args=node.args,
kwargs=ast_util.keywords_to_dict(node.keywords))
def visit_Call(self, node):
# TODO(mdan): Refactor converted_call as a 'Call' operator.

# Calls to the internal 'ag__' module are never converted (though their
# arguments might be).
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
if full_name.startswith('ag__.'):
return self.generic_visit(node)
if (full_name == 'print' and
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
return self.generic_visit(node)

def _insert_dynamic_conversion(self, node):
"""Inlines a dynamic conversion for a dynamic function."""
# TODO(mdan): Pass information on the statically compiled functions.
# Having access to the statically compiled functions can help avoid
# unnecessary compilation.
# For example, this would lead to function `a` being compiled twice:
#
# def a():
# v = b
# b()
# def b():
# a()
#
# This is really a problem with recursive calls, which currently can
# only be gated by a static condition, and should be rare.
# TODO(mdan): It probably makes sense to use dynamic conversion every time.
# Before we could convert all the time though, we'd need a reasonable
# caching mechanism.
template = """
ag__.converted_call(func, owner, options, args)
"""
Expand All @@ -256,6 +71,7 @@ def _insert_dynamic_conversion(self, node):
else:
func = node.func
owner = parser.parse_expression('None')

new_call = templates.replace_as_expression(
template,
func=func,
Expand All @@ -266,67 +82,8 @@ def _insert_dynamic_conversion(self, node):
args=node.args)
# TODO(mdan): Improve the template mechanism to better support this.
new_call.keywords = node.keywords
return new_call

def visit_FunctionDef(self, node):
node.args = self.visit(node.args)
node.body = self.visit_block(node.body)
node.decorator_list = []
node.returns = self.visit_block(node.returns)
return node

def visit_Call(self, node):
if anno.hasanno(node.func, 'live_val'):
target_entity = anno.getanno(node.func, 'live_val')

if anno.hasanno(node.func, 'fqn'):
target_fqn = anno.getanno(node.func, 'fqn')
else:
target_fqn = None

if self._function_is_compilable(target_entity):
if self._should_compile(node, target_fqn):
node = self._rename_compilable_function(node)
else:
node = self.generic_visit(node)
return node

elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
# TODO(mdan): Should we replace these with equivalent TF ops instead?
node = self._wrap_to_py_func_single_return(
node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

elif inspect_utils.isbuiltin(target_entity):
# Note: Any builtin that passed the builtins converter is assumed to be
# safe for graph mode.
return node

elif inspect_utils.isnamedtuple(target_entity):
# Although not compilable, we assume they are safe for graph mode.
node = self.generic_visit(node)
return node

else:
# TODO(mdan): Instert dynamic conversion here instead.
raise NotImplementedError(
'py_func with return values (unknown function)')
else:
# Special cases
# TODO(mdan): These need a systematic review - there may be more.

# 1. super() calls - these are preserved. The class conversion mechanism
# will ensure that they return the correct value.
if ast_util.matches(node, parser.parse_expression('super(_)')):
return node

# 2. super().method calls - these are preserved as well, when the
# conversion processes the entire class.
if (ast_util.matches(node, parser.parse_expression('super(_)._(_)')) and
self.ctx.info.owner_type is not None):
return node

node = self._insert_dynamic_conversion(node)
return node
return new_call


def transform(node, ctx):
Expand Down

0 comments on commit 14c78f9

Please sign in to comment.