Skip to content

Commit

Permalink
Internal cleanup: simplify the interface of converted_call, and short…
Browse files Browse the repository at this point in the history
…en the names of function scopes. This in turn allows for much less verbose generated code.

PiperOrigin-RevId: 273773506
  • Loading branch information
Dan Moldovan authored and tensorflower-gardener committed Oct 9, 2019
1 parent 720a3a1 commit 6165ddd
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 122 deletions.
3 changes: 1 addition & 2 deletions tensorflow/python/autograph/converters/call_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,11 @@ def visit_Call(self, node):
keywords=ast_util.keywords_to_dict(normal_keywords))

template = """
ag__.converted_call(func, options, args, kwargs, function_ctx)
ag__.converted_call(func, args, kwargs, function_ctx)
"""
new_call = templates.replace_as_expression(
template,
func=func,
options=parser.parse_expression(function_context_name + '.callopts'),
args=args,
kwargs=kwargs,
function_ctx=function_context_name)
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/python/autograph/converters/function_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def visit_Lambda(self, node):
return node

scope = anno.getanno(node, anno.Static.SCOPE)
function_context_name = self.ctx.namer.new_symbol('lambda_scope',
function_context_name = self.ctx.namer.new_symbol('lscope',
scope.referenced)
self.state[_Function].context_name = function_context_name
anno.setanno(node, 'function_context_name', function_context_name)
Expand All @@ -79,8 +79,8 @@ def visit_FunctionDef(self, node):
self.state[_Function].enter()
scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

function_context_name = self.ctx.namer.new_symbol(
'{}_scope'.format(node.name), scope.referenced)
function_context_name = self.ctx.namer.new_symbol('fscope',
scope.referenced)
self.state[_Function].context_name = function_context_name
anno.setanno(node, 'function_context_name', function_context_name)

Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/autograph/core/converter_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def compiled(self, node, namespace, symbols=()):

self.dynamic_calls = []
# See api.converted_call
def converted_call(f, unused_opts, args, kwargs, unused_function_ctx):
def converted_call(
f, args, kwargs, unused_opts=None, unused_function_ctx=None):
"""Mock version of api.converted_call."""
self.dynamic_calls.append((args, kwargs))
if kwargs is None:
Expand Down
22 changes: 18 additions & 4 deletions tensorflow/python/autograph/impl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def wrapper(*args, **kwargs):
user_requested=user_requested,
optional_features=optional_features)
try:
return converted_call(f, options, args, kwargs)
return converted_call(f, args, kwargs, options=options)
except Exception as e: # pylint:disable=broad-except
if hasattr(e, 'ag_error_metadata'):
raise e.ag_error_metadata.to_exception(e)
Expand Down Expand Up @@ -368,18 +368,27 @@ def _errors_are_normally_possible(entity, error):
return False


def converted_call(f, options, args, kwargs, caller_fn_scope=None):
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
"""Compiles a function call inline.
For internal use only.
Note: The argument list is optimized for readability of generated code, which
may looks something like this:
ag__.converted_call(f, (arg1, arg2), None, fscope)
ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)
Args:
f: The function to convert.
options: converter.ConversionOptions
args: Tuple, the original positional arguments of f
kwargs: Dict, the original keyword arguments of f
kwargs: Optional[Dict], the original keyword arguments of f
caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
scope of the converted function in which this call was originally made.
options: Optional[converter.ConversionOptions], conversion options. If not
specified, the value of caller_fn_scope.callopts is used. Either options
or caller_fn_scope must be present.
Returns:
Any, the result of executing a possibly-converted `f` with the given
Expand All @@ -388,6 +397,11 @@ def converted_call(f, options, args, kwargs, caller_fn_scope=None):
logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
kwargs)

if options is None:
if caller_fn_scope is None:
raise ValueError('either caller_fn_scope or options must have a value')
options = caller_fn_scope.callopts

if conversion.check_cached_unconverted(f, options):
return _call_unconverted(f, args, kwargs, options, False)

Expand Down
9 changes: 5 additions & 4 deletions tensorflow/python/autograph/impl/api_py3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

DEFAULT_RECURSIVE = converter.ConversionOptions(recursive=True)


class ApiTest(test.TestCase):

Expand All @@ -34,8 +36,8 @@ def test_converted_call_kwonly_args(self):
def test_fn(*, a):
return a

x = api.converted_call(test_fn, converter.ConversionOptions(recursive=True),
(), {'a': constant_op.constant(-1)})
x = api.converted_call(
test_fn, (), {'a': constant_op.constant(-1)}, options=DEFAULT_RECURSIVE)
self.assertEqual(-1, self.evaluate(x))

def test_super_with_no_arg(self):
Expand All @@ -54,8 +56,7 @@ def plus_three(self, x):
def no_arg(self, x):
return super().plus_three(x)

tc = api.converted_call(TestSubclass,
converter.ConversionOptions(recursive=True), (), {})
tc = api.converted_call(TestSubclass, (), {}, options=DEFAULT_RECURSIVE)

self.assertEqual(5, tc.no_arg(2))

Expand Down

0 comments on commit 6165ddd

Please sign in to comment.