Skip to content

Commit

Permalink
Improve upgrade script to handle list comprehensions as arguments. (#…
Browse files Browse the repository at this point in the history
…7229)

Python's ast module does not return the correct location, so we
have to do our best to scan backwards to find where the [ token
that truly started the list comprehension occurs.
  • Loading branch information
aselle authored and rmlarsen committed Feb 3, 2017
1 parent ee770d9 commit 114a462
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 10 deletions.
103 changes: 93 additions & 10 deletions tensorflow/tools/compatibility/tf_upgrade.py
Expand Up @@ -347,6 +347,62 @@ def _get_attribute_full_path(self, node):
items.append(curr.id)
return ".".join(reversed(items))

def _find_true_position(self, node):
  """Return the correct line number and column offset for a given node.

  This is necessary mainly because, on Python versions before 3.8, a
  ListComp node reports the location of the first token *after* the opening
  '[' rather than the location of the '[' itself.

  Args:
    node: Node for which we wish to know the lineno and col_offset.

  Returns:
    A `(lineno, col_offset)` tuple using the same conventions as the ast
    node attributes, or `(None, None)` if the opening '[' of a list
    comprehension could not be located.
  """
  import re
  import sys

  if not isinstance(node, ast.ListComp):
    # Most other nodes return proper locations (`with` notably does not, but
    # it is not possible to use that in an argument).
    return node.lineno, node.col_offset

  if sys.version_info >= (3, 8):
    # Since Python 3.8 the ast module reports the position of the opening
    # '[' for list comprehensions, so no correction is needed. Scanning
    # backwards from there would find no bracket and wrongly report failure.
    return node.lineno, node.col_offset

  # Older interpreters: ast.ListComp returns the col_offset of the first
  # token after the '[' token. Work around this by explicitly finding the
  # real start of the list comprehension.
  find_open = re.compile(r"^\s*(\[).*$")
  find_string_chars = re.compile("['\"]")

  line = node.lineno
  col = node.col_offset
  while True:
    # Reverse the text before `col` so that "only whitespace, then '['" can
    # be matched from the start of the string.
    text = self._lines[line - 1]
    reversed_preceding_text = text[:col][::-1]
    m = find_open.match(reversed_preceding_text)
    if m:
      # m.start(1) counts characters backwards from col - 1.
      return line, col - m.start(1) - 1

    if reversed_preceding_text and not reversed_preceding_text.isspace():
      # Something other than whitespace precedes the token: give up.
      return None, None

    # Nothing but whitespace on this line; continue on the previous line.
    line -= 1
    if line < 1:
      # Ran off the top of the file. Without this guard the index would wrap
      # around to the last line.
      return None, None
    prev_line = self._lines[line - 1]
    # TODO(aselle):
    # this is poor comment detection, but it is good enough for cases where
    # the comment does not contain string literal starting/ending
    # characters. If ast gave us start and end locations of the ast nodes
    # rather than just start, we could use string literal node ranges to
    # filter out spurious #'s that appear in string literals.
    comment_start = prev_line.find("#")
    if comment_start == -1:
      # Scan from the end of the line, excluding the line terminator if
      # there is one.
      col = len(prev_line.rstrip("\r\n"))
    elif find_string_chars.search(prev_line[comment_start:]) is None:
      col = comment_start
    else:
      return None, None


def visit_Call(self, node): # pylint: disable=invalid-name
"""Handle visiting a call node in the AST.
Expand Down Expand Up @@ -376,26 +432,53 @@ def visit_Call(self, node): # pylint: disable=invalid-name
if full_name in function_reorders:
reordered = function_reorders[full_name]
for idx, arg in enumerate(node.args):
keyword_arg = reordered[idx]
if (full_name in function_keyword_renames and
keyword_arg in function_keyword_renames[full_name]):
keyword_arg = function_keyword_renames[full_name][keyword_arg]
self._file_edit.add("Added keyword %r to reordered function %r"
% (reordered[idx], full_name), arg.lineno,
arg.col_offset, "", keyword_arg + "=")
lineno, col_offset = self._find_true_position(arg)
if lineno is None or col_offset is None:
self._file_edit.add(
"Failed to add keyword %r to reordered function %r"
% (reordered[idx], full_name), arg.lineno, arg.col_offset,
"", "",
error="A necessary keyword argument failed to be inserted.")
else:
keyword_arg = reordered[idx]
if (full_name in function_keyword_renames and
keyword_arg in function_keyword_renames[full_name]):
keyword_arg = function_keyword_renames[full_name][keyword_arg]
self._file_edit.add("Added keyword %r to reordered function %r"
% (reordered[idx], full_name), lineno,
col_offset, "", keyword_arg + "=")

# Examine each keyword argument and convert it to the final renamed form
renamed_keywords = ({} if full_name not in function_keyword_renames else
function_keyword_renames[full_name])
for keyword in node.keywords:
argkey = keyword.arg
argval = keyword.value

if argkey in renamed_keywords:
self._file_edit.add("Renamed keyword argument from %r to %r" %
argval_lineno, argval_col_offset = self._find_true_position(argval)
if (argval_lineno is not None and argval_col_offset is not None):
# TODO(aselle): We should scan backward to find the start of the
# keyword key. Unfortunately ast does not give you the location of
# keyword keys, so we are forced to infer it from the keyword arg
# value.
key_start = argval_col_offset - len(argkey) - 1
key_end = key_start + len(argkey) + 1
if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
self._file_edit.add("Renamed keyword argument from %r to %r" %
(argkey, renamed_keywords[argkey]),
argval.lineno,
argval.col_offset - len(argkey) - 1,
argval_lineno,
argval_col_offset - len(argkey) - 1,
argkey + "=", renamed_keywords[argkey] + "=")
continue
self._file_edit.add(
"Failed to rename keyword argument from %r to %r" %
(argkey, renamed_keywords[argkey]),
argval.lineno,
argval.col_offset - len(argkey) - 1,
"", "",
error="Failed to find keyword lexographically. Fix manually.")

ast.NodeVisitor.generic_visit(self, node)

def visit_Attribute(self, node): # pylint: disable=invalid-name
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/tools/compatibility/tf_upgrade_test.py
Expand Up @@ -113,6 +113,19 @@ def testReverse(self):
self.assertEqual(new_text, new_text)
self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])

def testListComprehension(self):
  """Checks keyword insertion in front of list-comprehension arguments.

  Every case passes a `tf.concat` call whose second positional argument is a
  list comprehension through `self._upgrade` and compares the rewritten text
  with the expected output. The cases vary the whitespace and line breaks
  around the opening '['.
  """
  cases = (
      # Whitespace (space and tab) between the comma and the '['.
      ("tf.concat(0, \t[x for x in y])\n",
       "tf.concat(axis=0, \tvalues=[x for x in y])\n"),
      # No whitespace at all.
      ("tf.concat(0,[x for x in y])\n",
       "tf.concat(axis=0,values=[x for x in y])\n"),
      # Line break directly after the '['.
      ("tf.concat(0,[\nx for x in y])\n",
       "tf.concat(axis=0,values=[\nx for x in y])\n"),
      # Line break followed by indentation after the '['.
      ("tf.concat(0,[\n \tx for x in y])\n",
       "tf.concat(axis=0,values=[\n \tx for x in y])\n"),
  )
  for source, expected in cases:
    # Only the rewritten text (fourth element) is checked.
    _, _, _, rewritten = self._upgrade(source)
    self.assertEqual(rewritten, expected)

# TODO(aselle): Explicitly not testing command line interface and process_tree
# for now, since this is a one off utility.

Expand Down

0 comments on commit 114a462

Please sign in to comment.