Skip to content

Commit

Permalink
Account for the existence of None values in lists of AST nodes. This …
Browse files Browse the repository at this point in the history
…can happen in Python 3 when keyword-only arguments are used. Fixes #28725.

PiperOrigin-RevId: 253034028
  • Loading branch information
Dan Moldovan authored and tensorflower-gardener committed Jun 13, 2019
1 parent 888a58e commit 22ba2eb
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 4 deletions.
18 changes: 18 additions & 0 deletions tensorflow/python/autograph/impl/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ tf_py_test(
],
)

py_test(
name = "api_py3_test",
srcs = ["api_py3_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = [
"no_oss_py2",
"no_pip",
"nopip",
],
deps = [
":impl",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/utils",
"//third_party/py/numpy",
],
)

tf_py_test(
name = "conversion_test",
srcs = ["conversion_test.py"],
Expand Down
45 changes: 45 additions & 0 deletions tensorflow/python/autograph/impl/api_py3_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for api module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test


class ApiTest(test.TestCase):

def test_converted_call_kwonly_args(self):

def test_fn(*, a):
return a

x = api.converted_call(test_fn, None,
converter.ConversionOptions(recursive=True),
(), {'a': constant_op.constant(-1)})
self.assertEqual(-1, self.evaluate(x))


if __name__ == '__main__':
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
test.main()
7 changes: 5 additions & 2 deletions tensorflow/python/autograph/pyct/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def parallel_walk(node, other):
n = node_stack.pop()
o = other_stack.pop()

if (not isinstance(n, (ast.AST, gast.AST, str)) or
not isinstance(o, (ast.AST, gast.AST, str)) or
if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or
(not isinstance(o, (ast.AST, gast.AST, str)) and n is not None) or
n.__class__.__name__ != o.__class__.__name__):
raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
n, n.__class__.__name__, o, o.__class__.__name__))
Expand All @@ -294,6 +294,9 @@ def parallel_walk(node, other):
if isinstance(n, str):
assert isinstance(o, str), 'The check above should have ensured this'
continue
if n is None:
assert o is None, 'The check above should have ensured this'
continue

for f in n._fields:
n_child = getattr(n, f, None)
Expand Down
10 changes: 8 additions & 2 deletions tensorflow/python/autograph/pyct/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def generic_visit(self, node, name=None):
self._print('%s%s=[' % (self._indent(), self._field(f)))
self.indent_lvl += 1
for n in v:
self.generic_visit(n)
if n is not None:
self.generic_visit(n)
else:
self._print('%sNone' % (self._indent()))
self.indent_lvl -= 1
self._print('%s]' % (self._indent()))
else:
Expand All @@ -101,7 +104,10 @@ def generic_visit(self, node, name=None):
self._print('%s%s=(' % (self._indent(), self._field(f)))
self.indent_lvl += 1
for n in v:
self.generic_visit(n)
if n is not None:
self.generic_visit(n)
else:
self._print('%sNone' % (self._indent()))
self.indent_lvl -= 1
self._print('%s)' % (self._indent()))
else:
Expand Down

0 comments on commit 22ba2eb

Please sign in to comment.