Skip to content

Commit bce3717

Browse files
cantoniostensorflower-gardener
authored andcommitted
Fix empty inputs for Upper/LowerBound.
For upper/lower-bound searches via `tf.searchsorted`, if the sorted input is empty, the previous code resulted in a `nullptr` dereference. For emtpy inputs, any sorted search should return a value of 0, meaning that a value would be inserted into the first slot of the array. PiperOrigin-RevId: 460971165
1 parent 1bec956 commit bce3717

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

Diff for: tensorflow/core/kernels/searchsorted_op.cc

+17
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include "tensorflow/core/framework/register_types.h"
2323
#include "tensorflow/core/framework/tensor.h"
2424
#include "tensorflow/core/framework/tensor_shape.h"
25+
#include "tensorflow/core/kernels/fill_functor.h"
2526
#include "tensorflow/core/lib/core/bits.h"
2627
#include "tensorflow/core/platform/logging.h"
2728
#include "tensorflow/core/platform/threadpool.h"
@@ -129,6 +130,14 @@ class UpperBoundOp : public OpKernel {
129130
auto output = output_t->template flat<OutType>();
130131
const auto sorted_inputs = sorted_inputs_t.template flat<T>();
131132
const auto values = values_t.template flat<T>();
133+
134+
// For empty inputs, all values will be placed at the zeroth position.
135+
if (sorted_inputs.size() == 0) {
136+
functor::SetZeroFunctor<Device, OutType> set_zero;
137+
set_zero(ctx->eigen_device<Device>(), output);
138+
return;
139+
}
140+
132141
OP_REQUIRES_OK(
133142
ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
134143
ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
@@ -174,6 +183,14 @@ class LowerBoundOp : public OpKernel {
174183
auto output = output_t->template flat<OutType>();
175184
const auto sorted_inputs = sorted_inputs_t.template flat<T>();
176185
const auto values = values_t.template flat<T>();
186+
187+
// For empty inputs, all values will be placed at the zeroth position.
188+
if (sorted_inputs.size() == 0) {
189+
functor::SetZeroFunctor<Device, OutType> set_zero;
190+
set_zero(ctx->eigen_device<Device>(), output);
191+
return;
192+
}
193+
177194
OP_REQUIRES_OK(
178195
ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
179196
ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),

Diff for: tensorflow/python/kernel_tests/array_ops/array_ops_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,17 @@ def testZeroValueSize(self):
20602060
side=side,
20612061
out_type=dtype), array_ops.zeros([2, 0], dtype))
20622062

2063+
def testZeroInputSize(self):
2064+
dtype = dtypes.int32
2065+
for side in ("left", "right"):
2066+
with self.subTest(side=side):
2067+
self.assertAllEqual(
2068+
array_ops.searchsorted(
2069+
array_ops.ones([2, 0]),
2070+
array_ops.ones([2, 3]),
2071+
side=side,
2072+
out_type=dtype), array_ops.zeros([2, 3], dtype))
2073+
20632074
def testInt64(self):
20642075

20652076
@def_function.function

0 commit comments

Comments
 (0)