diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 5a5a2c95ddec16..87db74d8dad078 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -174,12 +174,11 @@ def visit_Call(self, node): keywords=ast_util.keywords_to_dict(normal_keywords)) template = """ - ag__.converted_call(func, options, args, kwargs, function_ctx) + ag__.converted_call(func, args, kwargs, function_ctx) """ new_call = templates.replace_as_expression( template, func=func, - options=parser.parse_expression(function_context_name + '.callopts'), args=args, kwargs=kwargs, function_ctx=function_context_name) diff --git a/tensorflow/python/autograph/converters/function_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py index 52bd701b79083a..d0ba142545a35b 100644 --- a/tensorflow/python/autograph/converters/function_scopes.py +++ b/tensorflow/python/autograph/converters/function_scopes.py @@ -56,7 +56,7 @@ def visit_Lambda(self, node): return node scope = anno.getanno(node, anno.Static.SCOPE) - function_context_name = self.ctx.namer.new_symbol('lambda_scope', + function_context_name = self.ctx.namer.new_symbol('lscope', scope.referenced) self.state[_Function].context_name = function_context_name anno.setanno(node, 'function_context_name', function_context_name) @@ -79,8 +79,8 @@ def visit_FunctionDef(self, node): self.state[_Function].enter() scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - function_context_name = self.ctx.namer.new_symbol( - '{}_scope'.format(node.name), scope.referenced) + function_context_name = self.ctx.namer.new_symbol('fscope', + scope.referenced) self.state[_Function].context_name = function_context_name anno.setanno(node, 'function_context_name', function_context_name) diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 7560b436ef5cb9..a5533188b45815 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -57,7 +57,8 @@ def compiled(self, node, namespace, symbols=()): self.dynamic_calls = [] # See api.converted_call - def converted_call(f, unused_opts, args, kwargs, unused_function_ctx): + def converted_call( + f, args, kwargs, unused_opts=None, unused_function_ctx=None): """Mock version of api.converted_call.""" self.dynamic_calls.append((args, kwargs)) if kwargs is None: diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 8a6ea7e08c39fd..5754249bc917ad 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -231,7 +231,7 @@ def wrapper(*args, **kwargs): user_requested=user_requested, optional_features=optional_features) try: - return converted_call(f, options, args, kwargs) + return converted_call(f, args, kwargs, options=options) except Exception as e: # pylint:disable=broad-except if hasattr(e, 'ag_error_metadata'): raise e.ag_error_metadata.to_exception(e) @@ -368,18 +368,27 @@ def _errors_are_normally_possible(entity, error): return False -def converted_call(f, options, args, kwargs, caller_fn_scope=None): +def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): """Compiles a function call inline. For internal use only. + Note: The argument list is optimized for readability of generated code, which + may looks something like this: + + ag__.converted_call(f, (arg1, arg2), None, fscope) + ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope) + ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope) + Args: f: The function to convert. - options: converter.ConversionOptions args: Tuple, the original positional arguments of f - kwargs: Dict, the original keyword arguments of f + kwargs: Optional[Dict], the original keyword arguments of f caller_fn_scope: Optional[function_wrappers.FunctionScope], the function scope of the converted function in which this call was originally made. + options: Optional[converter.ConversionOptions], conversion options. If not + specified, the value of caller_fn_scope.callopts is used. Either options + or caller_fn_scope must be present. Returns: Any, the result of executing a possibly-converted `f` with the given @@ -388,6 +397,11 @@ def converted_call(f, options, args, kwargs, caller_fn_scope=None): logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args, kwargs) + if options is None: + if caller_fn_scope is None: + raise ValueError('either caller_fn_scope or options must have a value') + options = caller_fn_scope.callopts + if conversion.check_cached_unconverted(f, options): return _call_unconverted(f, args, kwargs, options, False) diff --git a/tensorflow/python/autograph/impl/api_py3_test.py b/tensorflow/python/autograph/impl/api_py3_test.py index d1ae2152bd2b24..9f8a4b3f31ddd0 100644 --- a/tensorflow/python/autograph/impl/api_py3_test.py +++ b/tensorflow/python/autograph/impl/api_py3_test.py @@ -26,6 +26,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.platform import test +DEFAULT_RECURSIVE = converter.ConversionOptions(recursive=True) + class ApiTest(test.TestCase): @@ -34,8 +36,8 @@ def test_converted_call_kwonly_args(self): def test_fn(*, a): return a - x = api.converted_call(test_fn, converter.ConversionOptions(recursive=True), - (), {'a': constant_op.constant(-1)}) + x = api.converted_call( + test_fn, (), {'a': constant_op.constant(-1)}, options=DEFAULT_RECURSIVE) self.assertEqual(-1, self.evaluate(x)) def test_super_with_no_arg(self): @@ -54,8 +56,7 @@ def plus_three(self, x): def no_arg(self, x): return super().plus_three(x) - tc = api.converted_call(TestSubclass, - converter.ConversionOptions(recursive=True), (), {}) + tc = api.converted_call(TestSubclass, (), {}, options=DEFAULT_RECURSIVE) self.assertEqual(5, tc.no_arg(2)) diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index 1a3ed4ffc6bfde..cccbb54b810c77 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -53,6 +53,8 @@ global_n = 2 +DEFAULT_RECURSIVE = converter.ConversionOptions(recursive=True) + class TestResource(object): @@ -207,9 +209,8 @@ def called_member(self, a): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.converted_call(self.called_member, - converter.ConversionOptions(recursive=True), - (a,), {}) + x //= api.converted_call( + self.called_member, (a,), None, options=DEFAULT_RECURSIVE) return x tc = TestClass() @@ -219,13 +220,13 @@ def test_method(self, x, s, a): self.assertListEqual([0, 1], self.evaluate(x).tolist()) def test_converted_call_builtin(self): - x = api.converted_call(range, converter.ConversionOptions(recursive=True), - (3,), {}) + x = api.converted_call(range, (3,), None, options=DEFAULT_RECURSIVE) self.assertEqual((0, 1, 2), tuple(x)) - x = api.converted_call(re.compile, - converter.ConversionOptions(recursive=True), - ('mnas_v4_a.*\\/.*(weights|kernel):0$',), {}) + x = api.converted_call( + re.compile, ('mnas_v4_a.*\\/.*(weights|kernel):0$',), + None, + options=DEFAULT_RECURSIVE) self.assertIsNotNone(x.match('mnas_v4_a/weights:0')) def test_converted_call_function(self): @@ -235,8 +236,8 @@ def test_fn(x): return -x return x - x = api.converted_call(test_fn, converter.ConversionOptions(recursive=True), - (constant_op.constant(-1),), {}) + x = api.converted_call( + test_fn, (constant_op.constant(-1),), None, options=DEFAULT_RECURSIVE) self.assertEqual(1, self.evaluate(x)) @test_util.run_v1_only('b/120545219') @@ -249,15 +250,17 @@ def test_fn(x, y, z): x = api.converted_call( functools.partial(test_fn, constant_op.constant(-1), z=-3), - converter.ConversionOptions(recursive=True), - (constant_op.constant(-2),), {}) + (constant_op.constant(-2),), + None, + options=DEFAULT_RECURSIVE) self.assertEqual((1, 2, 3), self.evaluate(x)) x = api.converted_call( functools.partial( functools.partial(test_fn, constant_op.constant(-1)), z=-3), - converter.ConversionOptions(recursive=True), - (constant_op.constant(-2),), {}) + (constant_op.constant(-2),), + None, + options=DEFAULT_RECURSIVE) self.assertEqual((1, 2, 3), self.evaluate(x)) def test_converted_call_method(self): @@ -273,8 +276,7 @@ def test_method(self): return self.x tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, - converter.ConversionOptions(recursive=True), (), {}) + x = api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(1, self.evaluate(x)) def test_converted_call_synthetic_method(self): @@ -292,8 +294,7 @@ def test_function(self): tc = TestClass(constant_op.constant(-1)) test_method = types.MethodType(test_function, tc) - x = api.converted_call(test_method, - converter.ConversionOptions(recursive=True), (), {}) + x = api.converted_call(test_method, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(1, self.evaluate(x)) def test_converted_call_method_wrapper(self): @@ -306,9 +307,8 @@ def foo(self): tc = TestClass() # `method.__get__()` returns a so-called method-wrapper. - wrapper = api.converted_call(tc.foo.__get__, - converter.ConversionOptions(recursive=True), - (tc,), {}) + wrapper = api.converted_call( + tc.foo.__get__, (tc,), None, options=DEFAULT_RECURSIVE) self.assertEqual(wrapper, tc.foo) def test_converted_call_method_as_object_attribute(self): @@ -331,8 +331,8 @@ def __init__(self, another_obj_method): obj = AnotherClass() tc = TestClass(obj.method) - x = api.converted_call(tc.another_obj_method, - converter.ConversionOptions(recursive=True), (), {}) + x = api.converted_call( + tc.another_obj_method, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(self.evaluate(x), 2) def test_converted_call_method_converts_recursively(self): @@ -351,8 +351,7 @@ def test_method(self): return self.other_method() tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, - converter.ConversionOptions(recursive=True), (), {}) + x = api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(1, self.evaluate(x)) def test_converted_call_method_by_class(self): @@ -368,9 +367,8 @@ def test_method(self): return self.x tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(TestClass.test_method, - converter.ConversionOptions(recursive=True), (tc,), - {}) + x = api.converted_call( + TestClass.test_method, (tc,), None, options=DEFAULT_RECURSIVE) self.assertEqual(1, self.evaluate(x)) def test_converted_call_callable_object(self): @@ -386,8 +384,7 @@ def __call__(self): return self.x tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc, converter.ConversionOptions(recursive=True), (), - {}) + x = api.converted_call(tc, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(1, self.evaluate(x)) def test_converted_call_callable_metaclass(self): @@ -405,8 +402,7 @@ def __call__(cls): # This functools.partial will hide the class form the constructor # check. Not ideal. See b/120224672. tc = functools.partial(tc) - converted_tc = api.converted_call( - tc, converter.ConversionOptions(recursive=True), (), {}) + converted_tc = api.converted_call(tc, (), None, options=DEFAULT_RECURSIVE) self.assertIsInstance(converted_tc, TestMetaclass) self.assertEqual(1, self.evaluate(converted_tc.x)) @@ -423,9 +419,8 @@ def test_method(self): return -self.x return self.x - tc = api.converted_call(TestClass, - converter.ConversionOptions(recursive=True), - (constant_op.constant(-1),), {}) + tc = api.converted_call( + TestClass, (constant_op.constant(-1),), None, options=DEFAULT_RECURSIVE) # tc is still a TestClass - constructors are whitelisted. # TODO(b/124016764): Support this use case. # The error below is specific to the `if` statement not being converted. @@ -447,8 +442,7 @@ def test_method(self): tc = TestClass(constant_op.constant(-1)) # The error below is specific to the `if` statement not being converted. with self.assertRaisesRegex(NotImplementedError, 'Mangled names'): - api.converted_call(tc.test_method, - converter.ConversionOptions(recursive=True), (), {}) + api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE) tc.test_method() def test_converted_call_already_converted(self): @@ -456,15 +450,16 @@ def test_converted_call_already_converted(self): def f(x): return x == 0 - x = api.converted_call(f, converter.ConversionOptions(recursive=True), - (constant_op.constant(0),), {}) + x = api.converted_call( + f, (constant_op.constant(0),), None, options=DEFAULT_RECURSIVE) self.assertTrue(self.evaluate(x)) converted_f = api.to_graph( f, experimental_optional_features=converter.Feature.ALL) - x = api.converted_call(converted_f, - converter.ConversionOptions(recursive=True), - (constant_op.constant(0),), {}) + x = api.converted_call( + converted_f, (constant_op.constant(0),), + None, + options=DEFAULT_RECURSIVE) self.assertTrue(self.evaluate(x)) def test_converted_call_then_already_converted_dynamic(self): @@ -479,8 +474,8 @@ def g(x): def f(g, x): return g(x) - x = api.converted_call(f, converter.ConversionOptions(recursive=True), - (g, constant_op.constant(1)), {}) + x = api.converted_call( + f, (g, constant_op.constant(1)), None, options=DEFAULT_RECURSIVE) self.assertEqual(self.evaluate(x), 1) def test_converted_call_forced_when_explicitly_whitelisted(self): @@ -489,16 +484,13 @@ def test_converted_call_forced_when_explicitly_whitelisted(self): def f(x): return x + 1 - x = api.converted_call( - f, converter.ConversionOptions(recursive=True, user_requested=True), - (constant_op.constant(0),), {}) + opts = converter.ConversionOptions(recursive=True, user_requested=True) + x = api.converted_call(f, (constant_op.constant(0),), None, options=opts) self.assertTrue(self.evaluate(x)) converted_f = api.to_graph( f, experimental_optional_features=converter.Feature.ALL) - x = api.converted_call(converted_f, - converter.ConversionOptions(recursive=True), (0,), - {}) + x = api.converted_call(converted_f, (0,), None, options=DEFAULT_RECURSIVE) self.assertEqual(x, 1) @test_util.run_deprecated_v1 @@ -511,10 +503,11 @@ def f(x): # f should not be converted, causing len to error out. with self.assertRaisesRegexp(Exception, 'len is not well defined'): - api.converted_call(f, opts, (constant_op.constant([0]),), {}) + api.converted_call(f, (constant_op.constant([0]),), None, options=opts) # len on the other hand should work fine. - x = api.converted_call(len, opts, (constant_op.constant([0]),), {}) + x = api.converted_call( + len, (constant_op.constant([0]),), None, options=opts) # The constant has static shape so the result is a primitive not a Tensor. self.assertEqual(x, 1) @@ -525,38 +518,34 @@ def f(*args): return np.broadcast(args[:1]) opts = converter.ConversionOptions(internal_convert_user_code=False) - - self.assertIsNotNone(api.converted_call(f, opts, (1, 2, 3, 4), None)) + self.assertIsNotNone( + api.converted_call(f, (1, 2, 3, 4), None, options=opts)) def test_converted_call_whitelisted_method(self): - opts = converter.ConversionOptions(recursive=True) - model = sequential.Sequential([core.Dense(2)]) - x = api.converted_call(model.call, opts, (constant_op.constant([[0.0]]),), - {'training': True}) + x = api.converted_call( + model.call, (constant_op.constant([[0.0]]),), {'training': True}, + options=DEFAULT_RECURSIVE) self.evaluate(variables.global_variables_initializer()) self.assertAllEqual([[0.0, 0.0]], self.evaluate(x)) def test_converted_call_whitelisted_method_via_owner(self): - opts = converter.ConversionOptions(recursive=True) - model = sequential.Sequential([core.Dense(2)]) - x = api.converted_call(model.call, opts, (constant_op.constant([[0.0]]),), - {'training': True}) + x = api.converted_call( + model.call, (constant_op.constant([[0.0]]),), {'training': True}, + options=DEFAULT_RECURSIVE) self.evaluate(variables.global_variables_initializer()) self.assertAllEqual([[0.0, 0.0]], self.evaluate(x)) def test_converted_call_numpy(self): - opts = converter.ConversionOptions(recursive=True) - - x = api.converted_call(np.arange, opts, (5,), {}) + x = api.converted_call(np.arange, (5,), None, options=DEFAULT_RECURSIVE) self.assertAllEqual(x, list(range(5))) @@ -566,7 +555,7 @@ def test_converted_call_tf_op_forced(self): opts = converter.ConversionOptions( user_requested=True, optional_features=None) - x = api.converted_call(gen_math_ops.add, opts, (1, 1), {}) + x = api.converted_call(gen_math_ops.add, (1, 1), None, options=opts) self.assertAllEqual(self.evaluate(x), 2) @@ -580,25 +569,25 @@ def foo(x): exec(textwrap.dedent(dynamic_code), temp_mod.__dict__) # pylint:disable=exec-used opts = converter.ConversionOptions(optional_features=None) - x = api.converted_call(temp_mod.foo, opts, (1,), {}) + x = api.converted_call(temp_mod.foo, (1,), None, options=opts) self.assertAllEqual(x, 2) def test_converted_call_namedtuple(self): - opts = converter.ConversionOptions(recursive=True) - - x = api.converted_call(collections.namedtuple, opts, - ('TestNamedtuple', ('a', 'b')), {}) + x = api.converted_call( + collections.namedtuple, ('TestNamedtuple', ('a', 'b')), + None, + options=DEFAULT_RECURSIVE) self.assertTrue(inspect_utils.isnamedtuple(x)) def test_converted_call_namedtuple_via_collections(self): - opts = converter.ConversionOptions(recursive=True) - - x = api.converted_call(collections.namedtuple, opts, - ('TestNamedtuple', ('a', 'b')), {}) + x = api.converted_call( + collections.namedtuple, ('TestNamedtuple', ('a', 'b')), + None, + options=DEFAULT_RECURSIVE) self.assertTrue(inspect_utils.isnamedtuple(x)) @@ -611,11 +600,11 @@ def test_method(self, x): x //= self.b return x - opts = converter.ConversionOptions(recursive=True) - obj = TestClass(5, 2) - x = api.converted_call(obj.test_method, opts, - (constant_op.constant([2, 4]),), {}) + x = api.converted_call( + obj.test_method, (constant_op.constant([2, 4]),), + None, + options=DEFAULT_RECURSIVE) self.assertAllEqual(self.evaluate(x), [1, 2]) @@ -624,11 +613,9 @@ def test_converted_call_namedtuple_method(self): class TestClass(collections.namedtuple('TestNamedtuple', ('a', 'b'))): pass - opts = converter.ConversionOptions(recursive=True) - obj = TestClass(5, 2) # _asdict is a documented method of namedtuple. - x = api.converted_call(obj._asdict, opts, (), {}) + x = api.converted_call(obj._asdict, (), None, options=DEFAULT_RECURSIVE) self.assertDictEqual(x, {'a': 5, 'b': 2}) @@ -641,29 +628,26 @@ def test_method(self, x): x //= self.b return x - opts = converter.ConversionOptions(recursive=True) - obj = TestClass(5, 2) - x = api.converted_call(TestClass.test_method, opts, - (obj, constant_op.constant([2, 4])), {}) + x = api.converted_call( + TestClass.test_method, (obj, constant_op.constant([2, 4])), + None, + options=DEFAULT_RECURSIVE) self.assertAllEqual(self.evaluate(x), [1, 2]) def test_converted_call_lambda(self): - opts = converter.ConversionOptions(recursive=True) - l = lambda x: x == 0 - x = api.converted_call(l, opts, (constant_op.constant(0),), {}) + x = api.converted_call( + l, (constant_op.constant(0),), None, options=DEFAULT_RECURSIVE) self.evaluate(variables.global_variables_initializer()) self.assertAllEqual(True, self.evaluate(x)) def test_converted_call_defun_object_method(self): - opts = converter.ConversionOptions(recursive=True) - # pylint:disable=method-hidden class TestClass(object): @@ -678,7 +662,7 @@ def prepare(self): tc = TestClass() tc.prepare() - x = api.converted_call(tc.method, opts, (), {}) + x = api.converted_call(tc.method, (), None, options=DEFAULT_RECURSIVE) self.assertAllEqual(1, self.evaluate(x)) @@ -695,8 +679,7 @@ def f(): # Dataset iteration only works inside tf. @def_function.function def graph_fn(): - opts = converter.ConversionOptions(recursive=True) - ds = api.converted_call(f, opts, (), {}) + ds = api.converted_call(f, (), None, options=DEFAULT_RECURSIVE) itr = iter(ds) return next(itr), next(itr), next(itr) @@ -719,8 +702,7 @@ def test_fn(): def f(y): return res.x + y - opts = converter.ConversionOptions(recursive=True) - api.converted_call(f, opts, (1,), {}) + api.converted_call(f, (1,), None, options=DEFAULT_RECURSIVE) self.assertNoMemoryLeaks(test_fn) @@ -736,8 +718,7 @@ def inner_f(): return inner_f - opts = converter.ConversionOptions(recursive=True) - api.converted_call(f, opts, (1,), {})() + api.converted_call(f, (1,), None, options=DEFAULT_RECURSIVE)() self.assertNoMemoryLeaks(test_fn) @@ -869,10 +850,12 @@ def test_fn(): self.assertNotEqual(converted_recursive.ag_module, converted_non_recursive.ag_module) - self.assertRegex(tf_inspect.getsource(converted_recursive), - 'FunctionScope(.*recursive=True.*)') - self.assertRegex(tf_inspect.getsource(converted_non_recursive), - 'FunctionScope(.*recursive=False.*)') + self.assertRegex( + tf_inspect.getsource(converted_recursive), + 'FunctionScope(.*recursive=True.*)') + self.assertRegex( + tf_inspect.getsource(converted_non_recursive), + 'FunctionScope(.*recursive=False.*)') def test_to_graph_preserves_bindings(self): y = 3 @@ -1025,8 +1008,7 @@ def one_arg(self, x): test_base = test_base_unbound.__get__(self, TestSubclass) return test_base.plus_three(x) - tc = api.converted_call(TestSubclass, - converter.ConversionOptions(recursive=True), (), {}) + tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(5, tc.one_arg(2)) @@ -1046,8 +1028,7 @@ def plus_three(self, x): def two_args(self, x): return super(TestSubclass, self).plus_three(x) - tc = api.converted_call(TestSubclass, - converter.ConversionOptions(recursive=True), (), {}) + tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(5, tc.two_args(2)) diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index d375e4afc5f162..d22190d6297225 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -914,11 +914,13 @@ def wrapper(*args, **kwargs): try: return autograph.converted_call( original_func, - autograph.ConversionOptions( + args, + kwargs, + options=autograph.ConversionOptions( recursive=True, optional_features=autograph_options, user_requested=True, - ), args, kwargs) + )) except Exception as e: # pylint:disable=broad-except if hasattr(e, "ag_error_metadata"): raise e.ag_error_metadata.to_exception(e)