From 14c78f9675a29190460d6cd5e10ce38e3be84fad Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Sat, 2 Feb 2019 13:05:11 -0800 Subject: [PATCH] Always convert function calls dynamically. This greatly simplifies the implementation at the cost of less readability of the output. Also included is a cleanup of the tests that had to be updated due to this change. Addresses #25281, #24759. Lastly, the CL enables the automatic fallback on compilation error. PiperOrigin-RevId: 232135777 --- .../autograph/integration_tests/keras_test.py | 22 +- .../python/autograph/converters/call_trees.py | 293 ++---------------- .../autograph/converters/call_trees_test.py | 146 ++------- tensorflow/python/autograph/core/config.py | 19 +- tensorflow/python/autograph/impl/api.py | 221 ++++++++----- .../python/autograph/impl/conversion.py | 47 ++- .../python/autograph/impl/conversion_test.py | 4 +- .../python/autograph/pyct/inspect_utils.py | 9 +- tensorflow/python/autograph/pyct/parser.py | 7 +- .../python/autograph/utils/ag_logging.py | 6 +- 10 files changed, 250 insertions(+), 524 deletions(-) diff --git a/tensorflow/examples/autograph/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py index 3fe33df920d008..72b62f1ad4d709 100644 --- a/tensorflow/examples/autograph/integration_tests/keras_test.py +++ b/tensorflow/examples/autograph/integration_tests/keras_test.py @@ -87,18 +87,16 @@ def test_conditional_attributes_True(self): @test_util.run_deprecated_v1 def test_recursive_true(self): - with self.assertRaisesRegexp(NotImplementedError, - 'Object conversion is not yet supported.'): - with tf.Graph().as_default(): - model = CompoundModel() - model.build(tf.TensorShape((None, 10, 10, 1))) - init = tf.global_variables_initializer() - - with tf.Session() as sess: - self.evaluate(init) - sample_input = tf.random_uniform((1, 10, 10, 1)) - output = model(sample_input) # pylint: disable=not-callable - self.assertEqual(self.evaluate(output).shape, (1, 3)) + with tf.Graph().as_default(): + model = CompoundModel() + model.build(tf.TensorShape((None, 10, 10, 1))) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + self.evaluate(init) + sample_input = tf.random_uniform((1, 10, 10, 1)) + output = model(sample_input) # pylint: disable=not-callable + self.assertEqual(self.evaluate(output).shape, (1, 3)) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 7026a162a28da4..04439ba5162f11 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -22,231 +22,46 @@ from __future__ import division from __future__ import print_function -import collections - import gast from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import ast_util -from tensorflow.python.autograph.pyct import inspect_utils from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import templates -from tensorflow.python.util import tf_inspect - - -class FunctionInfo(collections.namedtuple('FunctionInfo', ('dtype',))): - pass - - -# TODO(mdan): Move this to a separate transformer. -KNOWN_NUMPY_FUNCTIONS = { - ('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'), -} -# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer. - - -class FunctionNamer(object): - """Describes the interface for CallTreeTransformer's namer.""" - - def compiled_function_name(self, - original_fqn, - live_entity=None, - owner_type=None): - """Generate the name corresponding to the compiled version of a function. - - Args: - original_fqn: string or tuple(string) - live_entity: Callable, the actual target function, if known. - owner_type: Optional object. If present, it indicates that the function is - a member of the given type. - Returns: - string, bool - """ - raise NotImplementedError() - - def compiled_class_name(self, original_fqn, live_entity=None): - """Generate the name corresponding to the compiled version of a class. - - Args: - original_fqn: string or tuple(string) - live_entity: The actual target class, if known. - Returns: - string - """ - raise NotImplementedError() - - -# TODO(mdan): Rename to CallsTransformer. +# TODO(mdan): Rename to FunctionCallsTransformer. class CallTreeTransformer(converter.Base): """Transforms the call tree by renaming transformed symbols.""" - def _resolve_decorator_name(self, node): - """Used to resolve decorator info.""" - if isinstance(node, gast.Call): - return self._resolve_decorator_name(node.func) - if isinstance(node, gast.Name): - # TODO(mdan): Add test coverage for this branch. - return self.ctx.info.namespace.get(node.id) - if isinstance(node, gast.Attribute): - parent = self._resolve_decorator_name(node.value) - if parent is not None: - return getattr(parent, node.attr) - return None - raise ValueError(node) - - def _try_resolve_target(self, node): - """Works for methods of objects of known type.""" - if anno.hasanno(node, 'live_val'): - return anno.getanno(node, 'live_val') - if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'): - owner_type = anno.getanno(node, 'type') - if hasattr(owner_type, node.attr): - return getattr(owner_type, node.attr) - else: - # TODO(mdan): We should probably return None here rather than an error. - raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' % - (owner_type, node.attr)) - return None - - def _function_is_compilable(self, target_entity): - """Determines whether an entity can be compiled at all.""" - # TODO(mdan): Expand. - - if target_entity.__module__ is None: - # Functions like builtins and NumPy don't expose a module. - # Those in general should not be compiled. - return False - - if inspect_utils.isbuiltin(target_entity): - return False - - if inspect_utils.isnamedtuple(target_entity): - # namedtuple doesn't expose its source code, making it uncompilable. - return False - - return True - - def _should_compile(self, node, fqn): - """Determines whether an entity should be compiled in the context.""" - # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. - module_name = fqn[0] - for mod in self.ctx.program.uncompiled_modules: - if module_name.startswith(mod[0] + '.'): - return False - - for i in range(1, len(fqn)): - if fqn[:i] in self.ctx.program.uncompiled_modules: - return False - - target_entity = self._try_resolve_target(node.func) - - if target_entity is not None: - - # Currently, lambdas are always converted. - # TODO(mdan): Allow markers of the kind f = ag.do_not_convert(lambda: ...) - if inspect_utils.islambda(target_entity): - return True - - # This may be reached when "calling" a callable attribute of an object. - # For example: - # - # self.fc = tf.keras.layers.Dense() - # self.fc() - # - for mod in self.ctx.program.uncompiled_modules: - if target_entity.__module__.startswith(mod[0] + '.'): - return False - - # Inspect the target function decorators. If any include a @convert - # or @do_not_convert annotation, then they must be called as they are. - # TODO(mdan): This may be quite heavy. Perhaps always dynamically convert? - # To parse and re-analyze each function for every call site could be quite - # wasteful. Maybe we could cache the parsed AST? - try: - target_node, _ = parser.parse_entity(target_entity) - target_node = target_node.body[0] - except TypeError: - # Functions whose source we cannot access are compilable (e.g. wrapped - # to py_func). - return True - - # This attribute is set when the decorator was applied before the - # function was parsed. See api.py. - if hasattr(target_entity, '__ag_compiled'): - return False - - for dec in target_node.decorator_list: - decorator_fn = self._resolve_decorator_name(dec) - if (decorator_fn is not None and - self.ctx.program.options.should_strip(decorator_fn)): - return False - - return True - - def _rename_compilable_function(self, node): - assert anno.hasanno(node.func, 'live_val') - assert anno.hasanno(node.func, 'fqn') - target_entity = anno.getanno(node.func, 'live_val') - target_fqn = anno.getanno(node.func, 'fqn') - - if anno.hasanno(node, 'is_constructor'): - new_name = self.ctx.namer.compiled_class_name( - target_fqn, live_entity=target_entity) - do_rename = True - else: - if anno.hasanno(node.func, 'parent_type'): - owner_type = anno.getanno(node.func, 'parent_type') - else: - # Fallback - not reliable. - owner_type = inspect_utils.getmethodclass(target_entity) - new_name, do_rename = self.ctx.namer.compiled_function_name( - target_fqn, live_entity=target_entity, owner_type=owner_type) + def visit_FunctionDef(self, node): + node.args = self.visit(node.args) + node.body = self.visit_block(node.body) + # TODO(mdan): Is this correct for local functions? + node.decorator_list = [] + if node.returns: + node.returns = self.visit(node.returns) + return node - if do_rename: - if target_entity is not None: - if tf_inspect.ismethod(target_entity): - # The renaming process will transform it into a regular function. - # TODO(mdan): Is this complete? How does it work with nested members? - node.args = [node.func.value] + node.args - node.func = templates.replace_as_expression( - 'func_name', func_name=new_name) + def visit_With(self, node): + # Context manager calls (in node.items) are not converted. + node.body = self.visit_block(node.body) return node - def _wrap_to_py_func_single_return(self, node, dtype): - # TODO(mdan): Properly handle varargs, etc. - template = """ - ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False) - """ - return templates.replace_as_expression( - template, - func=node.func, - dtype=parser.parse_expression(dtype), - args=node.args, - kwargs=ast_util.keywords_to_dict(node.keywords)) + def visit_Call(self, node): + # TODO(mdan): Refactor converted_call as a 'Call' operator. + + # Calls to the internal 'ag__' module are never converted (though their + # arguments might be). + full_name = str(anno.getanno(node.func, anno.Basic.QN, default='')) + if full_name.startswith('ag__.'): + return self.generic_visit(node) + if (full_name == 'print' and + not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): + return self.generic_visit(node) - def _insert_dynamic_conversion(self, node): - """Inlines a dynamic conversion for a dynamic function.""" - # TODO(mdan): Pass information on the statically compiled functions. - # Having access to the statically compiled functions can help avoid - # unnecessary compilation. - # For example, this would lead to function `a` being compiled twice: - # - # def a(): - # v = b - # b() - # def b(): - # a() - # - # This is really a problem with recursive calls, which currently can - # only be gated by a static condition, and should be rare. - # TODO(mdan): It probably makes sense to use dynamic conversion every time. - # Before we could convert all the time though, we'd need a reasonable - # caching mechanism. template = """ ag__.converted_call(func, owner, options, args) """ @@ -256,6 +71,7 @@ def _insert_dynamic_conversion(self, node): else: func = node.func owner = parser.parse_expression('None') + new_call = templates.replace_as_expression( template, func=func, @@ -266,67 +82,8 @@ def _insert_dynamic_conversion(self, node): args=node.args) # TODO(mdan): Improve the template mechanism to better support this. new_call.keywords = node.keywords - return new_call - def visit_FunctionDef(self, node): - node.args = self.visit(node.args) - node.body = self.visit_block(node.body) - node.decorator_list = [] - node.returns = self.visit_block(node.returns) - return node - - def visit_Call(self, node): - if anno.hasanno(node.func, 'live_val'): - target_entity = anno.getanno(node.func, 'live_val') - - if anno.hasanno(node.func, 'fqn'): - target_fqn = anno.getanno(node.func, 'fqn') - else: - target_fqn = None - - if self._function_is_compilable(target_entity): - if self._should_compile(node, target_fqn): - node = self._rename_compilable_function(node) - else: - node = self.generic_visit(node) - return node - - elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS: - # TODO(mdan): Should we replace these with equivalent TF ops instead? - node = self._wrap_to_py_func_single_return( - node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype) - - elif inspect_utils.isbuiltin(target_entity): - # Note: Any builtin that passed the builtins converter is assumed to be - # safe for graph mode. - return node - - elif inspect_utils.isnamedtuple(target_entity): - # Although not compilable, we assume they are safe for graph mode. - node = self.generic_visit(node) - return node - - else: - # TODO(mdan): Instert dynamic conversion here instead. - raise NotImplementedError( - 'py_func with return values (unknown function)') - else: - # Special cases - # TODO(mdan): These need a systematic review - there may be more. - - # 1. super() calls - these are preserved. The class conversion mechanism - # will ensure that they return the correct value. - if ast_util.matches(node, parser.parse_expression('super(_)')): - return node - - # 2. super().method calls - these are preserved as well, when the - # conversion processes the entire class. - if (ast_util.matches(node, parser.parse_expression('super(_)._(_)')) and - self.ctx.info.owner_type is not None): - return node - - node = self._insert_dynamic_conversion(node) - return node + return new_call def transform(node, ctx): diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py index 454d75d755c727..6ee56bf6bef054 100644 --- a/tensorflow/python/autograph/converters/call_trees_test.py +++ b/tensorflow/python/autograph/converters/call_trees_test.py @@ -18,147 +18,49 @@ from __future__ import division from __future__ import print_function -import collections - -import numpy as np - from tensorflow.python.autograph.converters import call_trees from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test class CallTreesTest(converter_testing.TestCase): - def test_basic(self): - - def test_fn_1(_): - raise ValueError('This should not be called in the compiled version.') - - def other_test_fn_1(a): - return a + 1 - - def test_fn_2(a): - return test_fn_1(a) + 1 - - ns = {'test_fn_1': test_fn_1} - node, ctx = self.prepare(test_fn_2, ns) - node = call_trees.transform(node, ctx) - - with self.compiled(node, ns) as result: - new_name, _ = ctx.namer.compiled_function_name(('test_fn_1',)) - setattr(result, new_name, other_test_fn_1) - self.assertEquals(result.test_fn_2(1), 3) - - def test_dynamic_function(self): + def test_normal_function(self): - def test_fn_1(): - raise ValueError('This should be masked by the mock in self.compiled.') - - def test_fn_2(f): + def test_fn(f): return f() + 3 - with self.converted(test_fn_2, call_trees, {}) as result: - # 10 = 7 (from the mock) + 3 (from test_fn_2) - self.assertEquals(10, result.test_fn_2(test_fn_1)) + with self.converted(test_fn, call_trees, {}) as result: + self.assertEquals( + result.test_fn(None), + converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3) + self.assertListEqual(self.dynamic_calls, [()]) - def test_basic_method(self): + def test_class_method(self): class TestClass(object): - def test_fn_1(self, a): - return a + 1 - - def test_fn_2(self, a): - return self.test_fn_1(a) + 1 - - ns = {'TestClass': TestClass} - node, ctx = self.prepare( - TestClass.test_fn_2, - ns, - namer=converter_testing.FakeNoRenameNamer(), - arg_types={'self': (TestClass.__name__, TestClass)}) - node = call_trees.transform(node, ctx) - - with self.compiled(node, ns) as result: - tc = TestClass() - self.assertEquals(3, result.test_fn_2(tc, 1)) - - def test_known_called_lambda(self): - - l = lambda x: x - - def test_fn(a): - return l(a) + def test_method(self, a): + return self.other_method(a) + 1 - ns = {'l': l} - node, ctx = self.prepare(test_fn, ns) - node = call_trees.transform(node, ctx) + tc = TestClass() + with self.converted(TestClass.test_method, call_trees, {}) as result: + self.assertEquals(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1, + result.test_method(tc, 1)) + self.assertListEqual(self.dynamic_calls, [(1,)]) - with self.compiled(node, ns) as result: - self.assertEquals(1, result.test_fn(1)) + def test_object_method(self): - def test_known_called_namedtuple(self): - - nt = collections.namedtuple('TestNamedTuple', ['a']) - - def test_fn(a): - return nt(a) - - ns = {'nt': nt} - node, ctx = self.prepare(test_fn, ns) - node = call_trees.transform(node, ctx) - - with self.compiled(node, ns) as result: - self.assertEquals(nt(1), result.test_fn(1)) - - def test_py_func_known_function(self): - - def test_fn(): - return np.random.binomial(2, 0.5) - - with self.converted(test_fn, call_trees, {'np': np}, - dtypes.int64) as result: - with self.cached_session() as sess: - self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) - self.assertIn(self.evaluate(result.test_fn()), (0, 1, 2)) - - def test_uncompiled_modules(self): - - def test_fn(a): - a = math_ops.multiply(a, constant_op.constant(2)) - a = math_ops.add(a, constant_op.constant(1)) - return a - - ns = {'math_ops': math_ops, 'constant_op': constant_op} - node, ctx = self.prepare( - test_fn, - ns, - arg_types=set(((math_ops.__name__,), (constant_op.__name__,)))) - node = call_trees.transform(node, ctx) - - with self.compiled(node, ns) as result: - with self.cached_session() as sess: - result_tensor = result.test_fn(constant_op.constant(1)) - self.assertEquals(self.evaluate(result_tensor), 3) - - def test_call_to_decorated_function(self): - - def decorator(f): - return f - - @decorator - def called_fn(a): - return a + class TestClass(object): - def test_fn(a): - return called_fn(a) + def test_method(self, a): + return self.other_method(a) + 1 - node, ctx = self.prepare(test_fn, {'called_fn': called_fn}) - node = call_trees.transform(node, ctx) + tc = TestClass() + with self.converted(tc.test_method, call_trees, {}) as result: + self.assertEquals(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1, + result.test_method(tc, 1)) + self.assertListEqual(self.dynamic_calls, [(1,)]) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/core/config.py b/tensorflow/python/autograph/core/config.py index 574f819504e526..5dce3e6deebf8c 100644 --- a/tensorflow/python/autograph/core/config.py +++ b/tensorflow/python/autograph/core/config.py @@ -28,21 +28,16 @@ 'float': float, } -DEFAULT_UNCOMPILED_MODULES = set(( - ('tensorflow',), - (utils.__name__,), - # All of tensorflow's subpackages. Unlike the root tf module, they don't - # have well-known names. Not referring to the module directly to avoid - # circular imports. - ( - utils.__name__[:-len('.python.autograph.utils')],), -)) +def internal_module_name(name): + full_name = utils.__name__ + name_start = full_name.find(name) + name_end = name_start + len(name) + 1 + return full_name[:name_end] -NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) -# TODO(mdan): Also allow controlling the generated names. -# TODO(mdan); Consolidate all internal imports into a single __ag module. +DEFAULT_UNCOMPILED_MODULES = set(((internal_module_name('tensorflow'),),)) + COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', ) diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 7e0b5e79b8cfc9..3ce6896e931567 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -18,7 +18,10 @@ from __future__ import division from __future__ import print_function +import collections +import copy import functools +import pdb import sys from enum import Enum @@ -35,9 +38,9 @@ from tensorflow.python.autograph.pyct import compiler from tensorflow.python.autograph.pyct import errors from tensorflow.python.autograph.pyct import inspect_utils +from tensorflow.python.autograph.utils import ag_logging as logging from tensorflow.python.autograph.utils import py_func from tensorflow.python.framework import tensor_util -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -161,7 +164,9 @@ def py_func_wrapper(*args, **kwargs): def converted_call(f, owner, options, *args, **kwargs): """Compiles a function call inline. For internal use only.""" - logging.vlog(logging.DEBUG, 'Converted call: %s; owner: %s', f, owner) + logging.log(1, + 'Converted call: %s; owner: %s\n args: %s\n kwargs: %s\n', + f, owner, args, kwargs) if owner is not None: if not isinstance(f, str): @@ -180,10 +185,20 @@ def converted_call(f, owner, options, *args, **kwargs): if inspect_utils.isbuiltin(f): return py_builtins.overload_of(f)(*args, **kwargs) - # Don't convert wrapt-decorated functions/methods. - # TODO(b/122265385): Fully support wrapt + # TODO(b/122265385): Remove this bypass. if ('wrapt' in sys.modules and + hasattr(sys.modules['wrapt'], 'FunctionWrapper') and isinstance(f, sys.modules['wrapt'].FunctionWrapper)): + logging.warn( + 'Entity {} appears to be decorated by wrapt, which is not yet supported' + ' by AutoGraph. The function will be called without transformation.' + ' You may however apply AutoGraph before the decorator.'.format(f), 1) + return f(*args, **kwargs) + + # Other built-in modules are permanently whitelisted. + # TODO(mdan): Figure out how to do this consistently for all stdlib modules. + if (f in collections.__dict__.values() or f in pdb.__dict__.values() or + f in copy.__dict__.values()): return f(*args, **kwargs) # TODO(mdan): This needs cleanup. @@ -213,91 +228,118 @@ def converted_call(f, owner, options, *args, **kwargs): if not options.internal_convert_user_code: return f(*args, **kwargs) - # Unwrap functools.partial objects - # TODO(mdan): Consider sharing unwrapping logic with tf_inspect. - while isinstance(f, functools.partial): - args = f.args + args - new_kwargs = {} - if f.keywords is not None: - new_kwargs.update(f.keywords) - new_kwargs.update(kwargs) - kwargs = new_kwargs - f = f.func - - if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): - # Regular functions - target_entity = f - arg_map_target = f - f_self = inspect_utils.getmethodself(f) - - # TODO(b/119246461): This may be more elegantly handled using __get__? - if f_self is not None: - # If this is a method call, it may or may not include self. - # - # Example when self is included: - # converted_call(to_graph(foo.bar), foo) - # - # Example when self is not included: - # super(...).foo(args) - # - if owner is not None and (not args or args[0] is not owner): - effective_args = (owner,) + args - else: - # When the owner is not specified, use the result of - # inspect_utils.getmethodclass. - # TODO(b/119246461): Make sure an owner is always specified. - if not args or args[0] is not f_self: - effective_args = (f_self,) + args + # TODO(mdan): Move this entire block inside to_graph. + try: # Begin of transformation error guards + + # Unwrap functools.partial objects + # TODO(mdan): Consider sharing unwrapping logic with tf_inspect. + while isinstance(f, functools.partial): + args = f.args + args + new_kwargs = {} + if f.keywords is not None: + new_kwargs.update(f.keywords) + new_kwargs.update(kwargs) + kwargs = new_kwargs + f = f.func + + if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): + # Regular functions + target_entity = f + arg_map_target = f + f_self = inspect_utils.getmethodself(f) + + # TODO(b/119246461): This may be more elegantly handled using __get__? + if f_self is not None: + # If this is a method call, it may or may not include self. + # + # Example when self is included: + # converted_call(to_graph(foo.bar), foo) + # + # Example when self is not included: + # super(...).foo(args) + # + if owner is not None and (not args or args[0] is not owner): + effective_args = (owner,) + args else: - effective_args = (f_self,) + args[1:] - partial_types = (f_self,) - else: + # When the owner is not specified, use the result of + # inspect_utils.getmethodclass. + # TODO(b/119246461): Make sure an owner is always specified. + if not args or args[0] is not f_self: + effective_args = (f_self,) + args + else: + effective_args = (f_self,) + args[1:] + partial_types = (f_self,) + else: + effective_args = args + partial_types = () + + elif tf_inspect.isclass(f): + # Constructors + target_entity = f + arg_map_target = f.__init__ effective_args = args partial_types = () - elif tf_inspect.isclass(f): - # Constructors - target_entity = f - arg_map_target = f.__init__ - effective_args = args - partial_types = () - - elif hasattr(f, '__call__') and hasattr(f, '__class__'): - # Callable objects - target_entity = f.__call__ - arg_map_target = f.__call__ - effective_args = (f,) + args - partial_types = (f.__class__,) - - else: - raise NotImplementedError('unknown callable type "%s"' % type(f)) - - arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) - arg_types = {} - for name, arg in arg_values.items(): - arg_class = arg.__class__ - arg_types[name] = (arg_class.__name__, arg_class) - - # When called from within a decorator, this is the only indication that - # the function is a method - it appears that the decorator is applied - # before the method is bound. - if not partial_types: - if 'self' in arg_values: - if tf_inspect.isclass(arg_values['self'].__class__): - partial_types = (arg_values['self'].__class__,) - elif 'cls' in arg_values: - if tf_inspect.isclass(arg_values['cls']): - partial_types = (arg_values['cls'],) - - converted_f = to_graph( - target_entity, - recursive=options.recursive, - arg_values=arg_values, - arg_types=arg_types, - experimental_optional_features=options.optional_features, - experimental_strip_decorators=options.strip_decorators, - experimental_verbose=options.verbose, - experimental_partial_types=partial_types) + elif hasattr(f, '__call__') and hasattr(f, '__class__'): + # Callable objects + target_entity = f.__call__ + arg_map_target = f.__call__ + effective_args = (f,) + args + partial_types = (f.__class__,) + + else: + raise NotImplementedError('unknown callable type "%s"' % type(f)) + + arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) + arg_types = {} + for name, arg in arg_values.items(): + arg_class = arg.__class__ + arg_types[name] = (arg_class.__name__, arg_class) + + # When called from within a decorator, this is the only indication that + # the function is a method - it appears that the decorator is applied + # before the method is bound. + if not partial_types: + if 'self' in arg_values: + if tf_inspect.isclass(arg_values['self'].__class__): + partial_types = (arg_values['self'].__class__,) + elif 'cls' in arg_values: + if tf_inspect.isclass(arg_values['cls']): + partial_types = (arg_values['cls'],) + + logging.log(3, 'Partial types in conversion of %s: %s', target_entity, + partial_types) + + converted_f = to_graph( + target_entity, + recursive=options.recursive, + arg_values=arg_values, + arg_types=arg_types, + experimental_optional_features=options.optional_features, + experimental_strip_decorators=options.strip_decorators, + experimental_verbose=options.verbose, + experimental_partial_types=partial_types) + + if logging.has_verbosity(2): + logging.log(2, 'Defaults of %s : %s', converted_f, + converted_f.__defaults__) + callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs) + formatted_callargs = '\n'.join( + ' {}: {}'.format(k, v) for k, v in callargs.items()) + logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs) + + # TODO(mdan): Reduce this list. + except (errors.AutoGraphError, AssertionError, AttributeError, IndexError, + KeyError, NameError, NotImplementedError, SyntaxError, TypeError, + ValueError, IOError) as e: + logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) + logging.warn( + 'Entity %s could not be transformed and will be staged without change.' + ' Error details can be found in the logs when running with the env' + ' variable AUTOGRAPH_VERBOSITY=5. Please report this to the AutoGraph' + ' team. Cause: %s', e) + + return f(*args, **kwargs) result = converted_f(*effective_args, **kwargs) @@ -442,8 +484,15 @@ def foo(x): compiled_module.__dict__[key] = val compiled = getattr(compiled_module, name) - if tf_inspect.isfunction(entity): + if hasattr(entity, '__defaults__'): + logging.log(3, 'Default args mapping: %s has: %s', entity, + entity.__defaults__) compiled.__defaults__ = entity.__defaults__ + else: + logging.log(3, 'Default args mapping: %s has no __defaults__', entity) + + logging.log(3, 'Namespace of %s includes: %s', compiled, + compiled_module.__dict__.keys()) if hasattr(compiled, '__globals__'): # Remove self to avoid circular references. This will probably only work diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 691210616596e4..32a7d896214f3b 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -20,6 +20,8 @@ import functools import imp +# import types +import unittest import gast @@ -80,28 +82,31 @@ def is_whitelisted_for_graph(o): m = functools else: m = tf_inspect.getmodule(o) - if not hasattr(m, '__name__'): - # Note: typically it's builtins that fall in this category. Builtins will - # be handled by specific code that follows this screening layer. - logging.log(2, '%s is NOT whitelisted: unknown module name', o) - return False - - for prefix, in config.DEFAULT_UNCOMPILED_MODULES: - if m.__name__.startswith(prefix): - logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix) + + if hasattr(m, '__name__'): + # Builtins typically have unnamed modules. + for prefix, in config.DEFAULT_UNCOMPILED_MODULES: + if m.__name__.startswith(prefix): + logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix) + return True + + # Temporary -- whitelist tensorboard modules. + # TODO(b/122731813): Remove. + if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__: + logging.log(2, '%s is whitelisted: name contains "tensorboard"', o) return True if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'): logging.log(2, '%s is whitelisted: already converted', o) return True - if (not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o) and - hasattr(o, '__call__') and hasattr(o, '__class__')): + if hasattr(o, '__call__'): # Callable objects: whitelisted if their __call__ method is. - call_whitelisted = is_whitelisted_for_graph(o.__call__) - if call_whitelisted: + # The type check avoids infinite recursion around the __call__ method + # of function objects. + if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck logging.log(2, '%s is whitelisted: object __call__ whitelisted', o) - return call_whitelisted + return True if tf_inspect.ismethod(o): # Methods of whitelisted classes are also whitelisted, even if they are @@ -121,6 +126,10 @@ def is_whitelisted_for_graph(o): owner_class = inspect_utils.getmethodclass(o) if owner_class is not None: + if issubclass(owner_class, unittest.TestCase): + logging.log(2, '%s is whitelisted: method of TestCase subclass', o) + return True + owner_class = inspect_utils.getdefiningclass(o, owner_class) if is_whitelisted_for_graph(owner_class): logging.log(2, '%s is whitelisted: owner is whitelisted %s', o, @@ -132,9 +141,10 @@ def is_whitelisted_for_graph(o): # because they don't expose source code. But we assume they are safe for # graph mode since they are just containers. if tf_inspect.isclass(o) and len(o.__bases__) > 1: - logging.warn_first_n( - 'Entity {} looks like a namedtuple subclass. If it has any custom' - ' methods, they will not be converted by AutoGraph.'.format(o), 1) + logging.warn( + 'Entity {} looks like a namedtuple subclass. Its constructor will' + ' not be converted by AutoGraph, but if it has any custom methods,' + ' those will be.'.format(o), 1) logging.log(2, '%s is whitelisted: named tuple', o) return True @@ -207,6 +217,9 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): if logging.has_verbosity(2): logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiler.ast_to_source(node)) + if logging.has_verbosity(4): + for n in node: + logging.log(4, 'Compiled AST of %s:\n\n%s\n', o, gast.dump(n)) if program_ctx.options.recursive: while True: diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py index cd893e3ff14eaa..ddda4089fd8946 100644 --- a/tensorflow/python/autograph/impl/conversion_test.py +++ b/tensorflow/python/autograph/impl/conversion_test.py @@ -92,11 +92,9 @@ def f(a): conversion.entity_to_graph(f, program_ctx, None, None) self.assertTrue(f in program_ctx.dependency_cache) - self.assertTrue(g in program_ctx.dependency_cache) + self.assertFalse(g in program_ctx.dependency_cache) f_node = program_ctx.dependency_cache[f][0] - g_node = program_ctx.dependency_cache[g][0] self.assertEqual('tf__f', f_node.name) - self.assertEqual('tf__g', g_node.name) def test_entity_to_graph_class_hierarchy(self): diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index 6d9bc43d34652f..eab01ee9cd613b 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -31,7 +31,7 @@ # These functions test negative for isinstance(*, types.BuiltinFunctionType) # and inspect.isbuiltin, and are generally not visible in globals(). -# TODO(mdan): Find a more generic way to test this - just enumerate __builtin__? +# TODO(mdan): Remove this. SPECIAL_BUILTINS = { 'dict': dict, 'enumerate': enumerate, @@ -42,6 +42,7 @@ 'print': print, 'range': range, 'tuple': tuple, + 'type': type, 'zip': zip } @@ -73,7 +74,7 @@ def isnamedtuple(f): def isbuiltin(f): """Returns True if the argument is a built-in function.""" - if f in SPECIAL_BUILTINS.values(): + if f in six.moves.builtins.__dict__.values(): return True if isinstance(f, types.BuiltinFunctionType): return True @@ -125,6 +126,10 @@ def getqualifiedname(namespace, object_, max_depth=5, visited=None): if visited is None: visited = set() + # Copy the dict to avoid "changed size error" during concurrent invocations. + # TODO(mdan): This is on the hot path. Can we avoid the copy? + namespace = dict(namespace) + for name in namespace: # The value may be referenced by more than one symbol, case in which # any symbol will be fine. If the program contains symbol aliases that diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py index 011d80dd62afea..67bfb55fc7e98c 100644 --- a/tensorflow/python/autograph/pyct/parser.py +++ b/tensorflow/python/autograph/pyct/parser.py @@ -23,6 +23,7 @@ import re import textwrap +import threading import gast import six @@ -30,10 +31,14 @@ from tensorflow.python.util import tf_inspect +_parse_lock = threading.Lock() # Prevents linecache concurrency errors. + + def parse_entity(entity): """Returns the AST of given entity.""" try: - source = tf_inspect.getsource(entity) + with _parse_lock: + source = tf_inspect.getsource(entity) except (IOError, OSError) as e: raise ValueError( 'Unable to locate the source code of {}. Note that functions defined' diff --git a/tensorflow/python/autograph/utils/ag_logging.py b/tensorflow/python/autograph/utils/ag_logging.py index 847000aa25351d..cd737a829037ee 100644 --- a/tensorflow/python/autograph/utils/ag_logging.py +++ b/tensorflow/python/autograph/utils/ag_logging.py @@ -110,7 +110,7 @@ def get_verbosity(): global verbosity_level if verbosity_level is not None: return verbosity_level - return os.getenv(VERBOSITY_VAR_NAME, DEFAULT_VERBOSITY) + return int(os.getenv(VERBOSITY_VAR_NAME, DEFAULT_VERBOSITY)) def has_verbosity(level): @@ -131,5 +131,9 @@ def log(level, msg, *args, **kwargs): print(msg % args) +def warn(msg, *args, **kwargs): + logging.warn(msg, *args, **kwargs) + + def warn_first_n(msg, *args, **kwargs): logging.log_first_n(logging.WARN, msg, *args, **kwargs)