-
Notifications
You must be signed in to change notification settings - Fork 74k
/
sort_ops_test.py
367 lines (321 loc) · 13.6 KB
/
sort_ops_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sorting operators."""
import unittest
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
# Numpy dtypes that the sort tests try as keys (and as values). Each test
# intersects this list with the types supported by the backend under test.
ALL_KEY_TYPES = [
    # Floating point.
    dtypes.bfloat16.as_numpy_dtype,
    np.float16,
    np.float32,
    np.float64,
    # 32-bit integers.
    np.int32,
    np.uint32,
    # 16-bit integers.
    np.int16,
    np.uint16,
    # 8-bit integers.
    np.int8,
    np.uint8,
]
class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
  """Tests for XLA sort, key/value sort, variadic sort, top-k and in-top-k."""

  def _assertOpOutputMatchesExpected(self, op, args, expected):
    """Tests that op(*args) == expected.

    Args:
      op: Callable taking one placeholder per entry of `args` and returning
        either a single tensor or a sequence of tensors.
      args: List of numpy arrays; each one is fed to its own placeholder.
      expected: Sequence of expected values, one per output of `op`.
    """
    with self.session() as session:
      with self.test_scope():
        placeholders = [
            array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
            for arg in args
        ]
        feeds = dict(zip(placeholders, args))
        output = op(*placeholders)
        if isinstance(output, ops.Tensor):
          output = [output]

      results = session.run(output, feeds)
      # zip() below stops at the shorter sequence, so make sure no output (or
      # expectation) is silently left unchecked.
      self.assertEqual(
          len(expected), len(results),
          "Number of op outputs does not match number of expected values.")
      for result, v in zip(results, expected):
        self.assertAllClose(v, result, rtol=1e-3)

  def _shuffled_arange(self, shape, dtype):
    """Returns a random permutation of 0..prod(shape)-1 reshaped to `shape`."""
    x = np.arange(np.prod(shape), dtype=dtype)
    np.random.shuffle(x)
    return x.reshape(shape)

  def _supported_key_types(self):
    """Returns the subset of ALL_KEY_TYPES present in `self.numeric_types`."""
    res = set(ALL_KEY_TYPES).intersection(self.numeric_types)
    self.assertTrue(res, "No supported key types for this backend.")
    return res

  def _skip_without_take_along_axis(self):
    """Skips the calling test when numpy lacks `np.take_along_axis`.

    A feature check is used instead of comparing `np.__version__` as a string:
    string comparison orders versions lexicographically (e.g. "1.9" sorts
    after "1.15"), so it can give the wrong answer.

    Raises:
      unittest.SkipTest: If `np.take_along_axis` is not available.
    """
    if not hasattr(np, "take_along_axis"):
      raise unittest.SkipTest("np.take_along_axis was added in 1.15")

  def testSort(self):
    """Sorts a shuffled 1-D range with xla.sort for every supported type."""
    for dtype in self._supported_key_types():
      x = self._shuffled_arange((101,), dtype)
      self._assertOpOutputMatchesExpected(
          xla.sort, [x], expected=[np.arange(101, dtype=dtype)])

  def testKeyValueSort(self):
    """Sorts keys with xla.key_value_sort and checks values follow the keys."""
    for key_type in self._supported_key_types():
      for value_type in self._supported_key_types():
        if key_type == np.uint8 or value_type == np.uint8:
          # I do not understand why the test fails on uint8. We plan to
          # deprecate xla.key_value_sort in favor of xla.variadic_sort anyway.
          # NOTE(review): `-x` wraps around for unsigned key types before the
          # cast to `value_type`, so `y` need not equal `-arange` in that
          # case; this may be the cause of the failure -- TODO confirm.
          continue
        x = self._shuffled_arange((101,), key_type)
        y = (-x).astype(value_type)
        self._assertOpOutputMatchesExpected(
            xla.key_value_sort, [x, y],
            expected=[
                np.arange(101, dtype=key_type),
                -np.arange(101, dtype=value_type)
            ])

  @parameterized.parameters(0, 1, 2)
  def testVariadicSortDimension(self, dimension):
    """Sorts a single 3-D operand along `dimension` with xla.variadic_sort."""
    shape = (2, 3, 4)
    for key_type in self._supported_key_types():
      x = self._shuffled_arange(shape, key_type)
      expected = np.sort(x, axis=dimension)

      @function.Defun(key_type, key_type)
      def compare_lt(x1, x2):
        return x1 < x2

      def wrap_sort(x):
        return xla.variadic_sort([x],
                                 dimension=dimension,
                                 is_stable=False,
                                 comparator=compare_lt)

      self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])

  def testVariadicSortReverse(self):
    """Sorts a 1-D operand in descending order via a greater-than comparator."""
    shape = (100,)
    for key_type in self._supported_key_types():
      x = self._shuffled_arange(shape, key_type)
      expected = np.sort(x, axis=0)[::-1]

      @function.Defun(key_type, key_type)
      def compare_gt(x1, x2):
        return x1 > x2

      def wrap_sort(x):
        return xla.variadic_sort([x],
                                 dimension=0,
                                 is_stable=False,
                                 comparator=compare_gt)

      self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])

  @parameterized.product(dimension=[0, 1, 2], key_type=ALL_KEY_TYPES)
  def testVariadicSortSeveral(self, dimension, key_type):
    """Sorts three operands together, keyed on the first one only."""
    self._skip_without_take_along_axis()
    if key_type not in self._supported_key_types():
      return
    shape = (2, 3, 4)
    for value_type_1 in self._supported_key_types():
      for value_type_2 in self._supported_key_types():
        inputs = [
            self._shuffled_arange(shape, key_type),
            self._shuffled_arange(shape, value_type_1),
            self._shuffled_arange(shape, value_type_2)
        ]

        # The first array is sorted, and the others are shuffled the same way
        sorted_indices = np.argsort(inputs[0], axis=dimension)
        expected = [
            np.take_along_axis(inp, sorted_indices, axis=dimension)
            for inp in inputs
        ]
        # Sanity check of the numpy reference itself.
        self.assertAllEqual(np.sort(inputs[0], axis=dimension), expected[0])

        @function.Defun(key_type, key_type, value_type_1, value_type_1,
                        value_type_2, value_type_2)
        def compare_lt(x1, x2, y1, y2, z1, z2):
          del y1, y2, z1, z2
          return x1 < x2

        def wrap_sort(*args):
          return xla.variadic_sort(
              args,  # Pass the arguments as a tuple
              comparator=compare_lt,
              dimension=dimension,
              is_stable=False)

        self._assertOpOutputMatchesExpected(
            wrap_sort, inputs, expected=expected)

  @parameterized.parameters(ALL_KEY_TYPES)
  @test_util.disable_mlir_bridge("Not supported yet")
  def testVariadicSortLexicographic(self, key_type_2):
    """Sorts three operands lexicographically on the first two."""
    # Three inputs: the first two are used for lexicographic sort, and the
    # third is just swapped accordingly.
    # The first array will contain only 0 and 1, to test lexicographic order
    self._skip_without_take_along_axis()
    shape = (20,)
    if key_type_2 not in self._supported_key_types():
      return
    for key_type_1 in [np.int16, np.uint16, np.int32, np.uint32]:
      for value_type in self._supported_key_types():
        inputs = [
            # Ensure that some keys in the first input are equal
            np.random.uniform(0, 2, shape).astype(key_type_1),
            self._shuffled_arange(shape, key_type_2),
            self._shuffled_arange(shape, value_type)
        ]
        # The first two arrays are sorted lexicographically, and the third
        # is shuffled the same way. The second key is below 100 (shape is
        # (20,)), so `100 * first + second` orders pairs lexicographically.
        sorted_indices = np.argsort(100 * inputs[0] + inputs[1])
        expected = [
            np.take_along_axis(inp, sorted_indices, axis=0) for inp in inputs
        ]

        @function.Defun(key_type_1, key_type_1, key_type_2, key_type_2,
                        value_type, value_type)
        def compare_lexicographic(x1, x2, y1, y2, z1, z2):
          del z1, z2
          return math_ops.logical_or(
              x1 < x2, math_ops.logical_and(math_ops.equal(x1, x2), y1 < y2))

        def wrap_sort(*args):
          return xla.variadic_sort(
              args,  # Pass the arguments as a tuple
              comparator=compare_lexicographic,
              dimension=0,
              is_stable=False)

        self._assertOpOutputMatchesExpected(
            wrap_sort, inputs, expected=expected)

  @parameterized.product(dimension=[0, 1, 2], key_type=ALL_KEY_TYPES)
  def testVariadicSortSeveralStable(self, dimension, key_type):
    """Checks that a stable sort on all-equal keys leaves operands unchanged."""
    shape = (2, 3, 4)
    if key_type not in self._supported_key_types():
      return
    for value_type_1 in self._supported_key_types():
      for value_type_2 in self._supported_key_types():
        # The first input is all 0s, there should be no changes for
        # stable sort.
        inputs = [
            np.zeros(shape, key_type),
            self._shuffled_arange(shape, value_type_1),
            self._shuffled_arange(shape, value_type_2)
        ]

        @function.Defun(key_type, key_type, value_type_1, value_type_1,
                        value_type_2, value_type_2)
        def compare_lt(x1, x2, y1, y2, z1, z2):
          del y1, y2, z1, z2
          return x1 < x2

        def wrap_sort(*args):
          return xla.variadic_sort(
              args,  # Pass the arguments as a tuple
              comparator=compare_lt,
              dimension=dimension,
              is_stable=True)

        self._assertOpOutputMatchesExpected(wrap_sort, inputs, expected=inputs)

  def testTopK(self):
    """Compares nn_ops.top_k on a shuffled 1-D range with a numpy argsort."""
    supported_types = {
        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
        np.int32, np.uint32, np.int64, np.uint64, np.uint8, np.int8,
    }
    for dtype in supported_types.intersection(self.numeric_types):
      # Use small input size for bfloat16 and float16. Otherwise, we'll get
      # duplicate values after conversion to the narrow type, so the possible
      # resulting index array is no longer unique.
      if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
        array_size = 20
        k_options = [0, 1, 2, 10, 20]
      elif dtype in (dtypes.uint8.as_numpy_dtype, dtypes.int8.as_numpy_dtype):
        # Values 0..110 still fit in an 8-bit integer.
        array_size = 111
        k_options = [0, 1, 2, 10, 20]
      else:
        array_size = 200 * 1000
        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]

      x = np.arange(array_size)
      np.random.shuffle(x)
      for k in k_options:
        # Indices of the k largest values, largest first.
        indices = x.argsort()[::-1][:k]

        def topk(v, k=k):
          return nn_ops.top_k(v, k=k, sorted=True)

        self._assertOpOutputMatchesExpected(
            topk, [x.astype(dtype)],
            expected=[x[indices].astype(dtype), indices])

  @parameterized.named_parameters(
      ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
      ("HalfFloatPrecision", np.float16),
      ("SinglePrecision", np.float32),
      ("DoublePrecision", np.float64),
      ("Int32", np.int32),
      ("UnsignedInt32", np.uint32),
      ("Int64", np.int64),
      ("UnsignedInt64", np.uint64),
  )
  def testTopK2D(self, dtype):
    """Compares row-wise nn_ops.top_k on a 2-D input with numpy."""
    if dtype not in self.numeric_types:
      return
    # Use small input size for bfloat16 and float16. Otherwise, we'll get
    # duplicate values after conversion to the narrow type, so the possible
    # resulting index array is no longer unique.
    if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
      array_size = 10
      k_options = [0, 1, 2, 10]
    else:
      array_size = 200 * 1000
      k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
    batch = 16

    x = np.arange(batch * array_size)
    np.random.shuffle(x)
    x = np.reshape(x, [batch, array_size])
    for k in k_options:
      # Last k columns of the ascending sort, reversed: the k largest entries
      # of each row in descending order.
      indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
      expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]

      def topk(v, k=k):
        return nn_ops.top_k(v, k=k, sorted=True)

      self._assertOpOutputMatchesExpected(
          topk, [x.astype(dtype)],
          expected=[expected.astype(dtype), indices])

  def testTopKZeros(self):
    """Tests that positive and negative zeros sort correctly."""
    supported_types = {
        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64
    }
    for dtype in supported_types.intersection(self.numeric_types):
      with self.session() as sess:
        p = array_ops.placeholder(dtype)
        with self.test_scope():
          topk = nn_ops.top_k(p, k=4)
        results = sess.run(
            topk,
            {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=dtype)})
        self.assertAllEqual(np.array([3., 0., 0., 0.], dtype=dtype), results[0])
        self.assertEqual([3, 0, 2, 6], list(results[1]))

  def testTopKInfinities(self):
    """Tests that positive and negative infinity sort correctly."""
    supported_types = {
        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64
    }
    for dtype in supported_types.intersection(self.numeric_types):
      with self.session() as sess:
        p = array_ops.placeholder(dtype)
        with self.test_scope():
          topk = nn_ops.top_k(p, k=6)
        results = sess.run(topk, {
            p:
                np.array([1, 2, float("inf"), -float("inf"), -1, -2],
                         dtype=dtype)
        })
        self.assertAllEqual(
            np.array([float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
                     dtype=dtype), results[0])
        self.assertEqual([2, 1, 0, 4, 5, 3], list(results[1]))

  @parameterized.named_parameters(
      ("Int32", np.int32),
      # in_top_k targets must be int32 or int64; this case was previously
      # labelled "Int64" but passed np.uint64.
      ("Int64", np.int64),
  )
  def testInTopK(self, dtype):
    """Compares nn_ops.in_top_k with a numpy reference; `dtype` is for targets."""
    if dtype not in self.numeric_types:
      return
    array_size = 200 * 1000
    k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
    batch = 16

    x = np.arange(batch * array_size)
    np.random.shuffle(x)
    x = np.reshape(x, [batch, array_size])
    # One random target column per row.
    y = np.random.randint(0, array_size, size=batch)
    for k in k_options:
      indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
      expected = [y[i] in indices[i] for i in range(batch)]

      def in_topk(predictions, targets, k=k):
        return nn_ops.in_top_k(predictions, targets, k)

      self._assertOpOutputMatchesExpected(
          in_topk,
          [x.astype(np.float32), y.astype(dtype)],
          expected=[expected])
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == "__main__":
  test.main()