Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversion script update (adding concat_dim to axis rename for tf.concat and tf.pack renames) #7119

Merged
merged 7 commits into from Jan 30, 2017
3 changes: 3 additions & 0 deletions tensorflow/tools/compatibility/README.md
Expand Up @@ -36,6 +36,9 @@ particular, functions that have had reordered arguments like `tf.concat`,
`tf.split` will cause the script to incorrectly add keyword arguments that
mismap arguments.

- This script wouldn't actually reorder arguments. Instead, the script will add
keyword arguments to functions that had their arguments reordered.

- This script is not able to upgrade all functions. One notable example is
`tf.reverse()` which has been changed to take a list of indices rather than
a tensor of bools. If the script detects this, it will report this to stdout
Expand Down
18 changes: 14 additions & 4 deletions tensorflow/tools/compatibility/tf_upgrade.py
Expand Up @@ -95,7 +95,10 @@ def __init__(self):
"tf.split": {
"split_dim": "axis",
"num_split": "num_or_size_splits"
}
},
"tf.concat": {
"concat_dim": "axis"
},
}

# Mapping from function to the new name of the function
Expand Down Expand Up @@ -142,6 +145,8 @@ def __init__(self):
"tf.select": "tf.where",
"tf.complex_abs": "tf.abs",
"tf.batch_matmul": "tf.matmul",
"tf.pack": "tf.stack",
"tf.unpack": "tf.unstack",
}

# Functions that were reordered should be changed to the new keyword args
Expand Down Expand Up @@ -356,16 +361,21 @@ def visit_Call(self, node): # pylint: disable=invalid-name
# Examine any non-keyword argument and make it into a keyword argument
# if reordering required.
function_reorders = self._api_change_spec.function_reorders
function_keyword_renames = (
self._api_change_spec.function_keyword_renames)

if full_name in function_reorders:
reordered = function_reorders[full_name]
for idx, arg in enumerate(node.args):
keyword_arg = reordered[idx]
if (full_name in function_keyword_renames and
keyword_arg in function_keyword_renames[full_name]):
keyword_arg = function_keyword_renames[full_name][keyword_arg]
self._file_edit.add("Added keyword %r to reordered function %r"
% (reordered[idx], full_name), arg.lineno,
arg.col_offset, "", reordered[idx] + "=")
arg.col_offset, "", keyword_arg + "=")

# Examine each keyword argument and convert it to the final renamed form
function_keyword_renames = (
self._api_change_spec.function_keyword_renames)
renamed_keywords = ({} if full_name not in function_keyword_renames else
function_keyword_renames[full_name])
for keyword in node.keywords:
Expand Down
27 changes: 26 additions & 1 deletion tensorflow/tools/compatibility/tf_upgrade_test.py
Expand Up @@ -59,12 +59,37 @@ def testRename(self):
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")

def testRenamePack(self):
text = "tf.pack(a)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.stack(a)\n")
text = "tf.unpack(a)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.unstack(a)\n")

def testReorder(self):
text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n"
self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n"
"tf.split(axis=a, num_or_size_splits=b, value=c)\n")

def testConcatReorderWithKeywordArgs(self):
text = "tf.concat(concat_dim=a, values=b)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
text = "tf.concat(values=b, concat_dim=a)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n")
text = "tf.concat(a, values=b)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")

def testConcatReorderNested(self):
text = "tf.concat(a, tf.concat(c, d))\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(
new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n")

def testKeyword(self):
text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
Expand Down