Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sorted builtin support for autograph #36812

Merged
merged 18 commits into from
Feb 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorflow/python/autograph/operators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ py_test(
"//tensorflow/python/autograph/core",
# TODO(b/145618471): Remove this transitive dependency.
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/ops/parallel_for:control_flow_ops",
mdanatg marked this conversation as resolved.
Show resolved Hide resolved
"//tensorflow/python/ops/signal",
],
)

Expand Down
48 changes: 47 additions & 1 deletion tensorflow/python/autograph/operators/py_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest

Expand All @@ -47,6 +49,10 @@
input_lib = lazy_loader.LazyLoader(
'input_lib', globals(),
'tensorflow.python.distribute.input_lib')
parallel_ops = lazy_loader.LazyLoader(
'parallel_ops', globals(),
'tensorflow.python.ops.parallel_for.control_flow_ops'
)


UNSPECIFIED = object()
Expand Down Expand Up @@ -461,8 +467,47 @@ def _py_all(iterable):
return all(iterable)


def sorted_(iterable, key=UNSPECIFIED, reverse=UNSPECIFIED):
if tensor_util.is_tensor(iterable):
return _tf_sorted(iterable, key, reverse)
return _py_sorted(iterable, key, reverse)


def _tf_sorted(iterable, key, reverse):
"""Overload of sorted_ for Tensor iterable."""
if reverse is UNSPECIFIED:
direction = 'ASCENDING'
else:
direction = 'DESCENDING'
if key is not UNSPECIFIED:
mapped = parallel_ops.vectorized_map(key, iterable)
if mapped.shape.rank is not None and mapped.shape.rank != 1:
raise ValueError('sort only supports only 1D tensors')
with ops.control_dependencies(
Copy link

@mdanatg mdanatg Feb 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: If the rank is static, you could have an extra check for an early warning:

if mapped.shape.rank is not None and mapped.shape.rank != 1:
  raise ValueError('sort only supports only 1D tensors')

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mdanatg, the PR has been updated.

[check_ops.assert_rank_v2(
mapped, 1, 'sort only supports only 1D tensors')]):
order = sort_ops.argsort(mapped, direction=direction)
return array_ops.gather_v2(iterable, order)
if iterable.shape.rank is not None and iterable.shape.rank != 1:
raise ValueError('sort only supports only 1D tensors')
with ops.control_dependencies(
[check_ops.assert_rank_v2(
iterable, 1, 'sort only supports only 1D tensors')]):
return sort_ops.sort(iterable, direction=direction)


def _py_sorted(iterable, key, reverse):
if key is not UNSPECIFIED and reverse is UNSPECIFIED:
return sorted(iterable, key=key)
if key is UNSPECIFIED and reverse is not UNSPECIFIED:
return sorted(iterable, reverse=reverse)
if key is not UNSPECIFIED and reverse is not UNSPECIFIED:
return sorted(iterable, key=key, reverse=reverse)
return sorted(iterable)


SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map,
filter, any, all)
filter, any, all, sorted)

if six.PY2:
SUPPORTED_BUILTINS += (xrange,)
Expand All @@ -482,4 +527,5 @@ def _py_all(iterable):
'filter': filter_,
'any': any_,
'all': all_,
'sorted': sorted_,
}
35 changes: 35 additions & 0 deletions tensorflow/python/autograph/operators/py_builtins_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


Expand Down Expand Up @@ -373,6 +374,40 @@ def test_all_dataset(self):
with self.assertRaises(ValueError):
py_builtins.all_(dataset_mixed)

def test_sorted(self):
self.assertListEqual(py_builtins.sorted_([2, 3, 1]), [1, 2, 3])
self.assertListEqual(
py_builtins.sorted_([2, 3, 1], key=lambda x: -x), [3, 2, 1])
self.assertListEqual(
py_builtins.sorted_([2, 3, 1], reverse=True), [3, 2, 1])
self.assertListEqual(
py_builtins.sorted_([2, 3, 1], key=lambda x: -x, reverse=True),
[1, 2, 3])
self.assertAllEqual(
py_builtins.sorted_([[4, 3], [2, 1]], key=lambda x: sum(x)),
[[2, 1], [4, 3]])

def test_sorted_tensor(self):
iterable_1 = constant_op.constant([2, 3, 1])
self.assertListEqual(list(self.evaluate(
py_builtins.sorted_(iterable_1))), [1, 2, 3])
self.assertListEqual(list(self.evaluate(
py_builtins.sorted_(iterable_1, key=lambda x: -x))), [3, 2, 1])
self.assertListEqual(list(self.evaluate(
py_builtins.sorted_(iterable_1, reverse=True))), [3, 2, 1])
self.assertListEqual(list(self.evaluate(
py_builtins.sorted_(iterable_1, key=lambda x: -x, reverse=True))),
[1, 2, 3])

iterable_2 = constant_op.constant([[4, 3], [2, 1]])
with self.assertRaises(ValueError):
py_builtins.sorted_(iterable_2)
with self.assertRaises(ValueError):
py_builtins.sorted_(iterable_2, key=lambda x: -x)
self.assertAllEqual(list(self.evaluate(
py_builtins.sorted_(iterable_2, key=lambda x: math_ops.reduce_sum(x)))),
[[2, 1], [4, 3]])


if __name__ == '__main__':
test.main()