Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix empty inputs for Upper/LowerBound.
For upper/lower-bound searches via `tf.searchsorted`, if the sorted input is empty,
the previous code resulted in a `nullptr` dereference.  For emtpy inputs, any
sorted search should return a value of 0, meaning that a value would be inserted
into the first slot of the array.

PiperOrigin-RevId: 460971165
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 14, 2022
1 parent 1bec956 commit bce3717
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tensorflow/core/kernels/searchsorted_op.cc
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/threadpool.h"
Expand Down Expand Up @@ -129,6 +130,14 @@ class UpperBoundOp : public OpKernel {
auto output = output_t->template flat<OutType>();
const auto sorted_inputs = sorted_inputs_t.template flat<T>();
const auto values = values_t.template flat<T>();

// For empty inputs, all values will be placed at the zeroth position.
if (sorted_inputs.size() == 0) {
functor::SetZeroFunctor<Device, OutType> set_zero;
set_zero(ctx->eigen_device<Device>(), output);
return;
}

OP_REQUIRES_OK(
ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
Expand Down Expand Up @@ -174,6 +183,14 @@ class LowerBoundOp : public OpKernel {
auto output = output_t->template flat<OutType>();
const auto sorted_inputs = sorted_inputs_t.template flat<T>();
const auto values = values_t.template flat<T>();

// For empty inputs, all values will be placed at the zeroth position.
if (sorted_inputs.size() == 0) {
functor::SetZeroFunctor<Device, OutType> set_zero;
set_zero(ctx->eigen_device<Device>(), output);
return;
}

OP_REQUIRES_OK(
ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/python/kernel_tests/array_ops/array_ops_test.py
Expand Up @@ -2060,6 +2060,17 @@ def testZeroValueSize(self):
side=side,
out_type=dtype), array_ops.zeros([2, 0], dtype))

def testZeroInputSize(self):
dtype = dtypes.int32
for side in ("left", "right"):
with self.subTest(side=side):
self.assertAllEqual(
array_ops.searchsorted(
array_ops.ones([2, 0]),
array_ops.ones([2, 3]),
side=side,
out_type=dtype), array_ops.zeros([2, 3], dtype))

def testInt64(self):

@def_function.function
Expand Down

0 comments on commit bce3717

Please sign in to comment.