Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix check-fail when bincount ops are passed invalid values.
PiperOrigin-RevId: 415063028
Change-Id: I20f8dc09933ddca1111c4efbf9a3a1e863215d02
  • Loading branch information
edloper authored and tensorflower-gardener committed Dec 8, 2021
1 parent 54441d2 commit 7019ce4
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/bincount_op.cc
Expand Up @@ -276,6 +276,9 @@ class DenseBincountOp : public OpKernel {
const Tensor& size_t = ctx->input(1);
const Tensor& weights = ctx->input(2);

OP_REQUIRES(ctx, size_t.dims() == 0,
errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_t.dims()));
Tidx size = size_t.scalar<Tidx>()();
OP_REQUIRES(
ctx, size >= 0,
Expand Down Expand Up @@ -372,6 +375,9 @@ class SparseBincountOp : public OpKernel {
const auto weights = ctx->input(4).flat<T>();
const int64_t weights_size = weights.size();

OP_REQUIRES(ctx, size_t.dims() == 0,
errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_t.dims()));
Tidx size = size_t.scalar<Tidx>()();
OP_REQUIRES(
ctx, size >= 0,
Expand Down Expand Up @@ -462,6 +468,9 @@ class RaggedBincountOp : public OpKernel {
const auto weights = ctx->input(3).flat<T>();
const int64_t weights_size = weights.size();

OP_REQUIRES(ctx, size_t.dims() == 0,
errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_t.dims()));
Tidx size = size_t.scalar<Tidx>()();
OP_REQUIRES(
ctx, size >= 0,
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/core/ops/math_ops.cc
Expand Up @@ -1699,6 +1699,11 @@ REGISTER_OP("Bincount")
return Status::OK();
}

if (size_tensor->dims() != 0) {
return errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_tensor->dims());
}

// Return `[size]` shape if size is known.
int32_t size_val = size_tensor->scalar<int32>()();
if (size_val < 0) {
Expand Down Expand Up @@ -1730,6 +1735,10 @@ REGISTER_OP("DenseBincount")
c->set_output(0, c->UnknownShape());
return Status::OK();
}
if (size_tensor->dims() != 0) {
return errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_tensor->dims());
}

int64_t size_val;
DataType dtype;
Expand Down Expand Up @@ -1771,6 +1780,10 @@ REGISTER_OP("SparseBincount")
c->set_output(0, c->UnknownShape());
return Status::OK();
}
if (size_tensor->dims() != 0) {
return errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_tensor->dims());
}

int64_t size_val;
DataType dtype;
Expand Down
34 changes: 34 additions & 0 deletions tensorflow/python/kernel_tests/math_ops/bincount_op_test.py
Expand Up @@ -344,6 +344,14 @@ def test_invalid_rank(self):
gen_math_ops.dense_bincount(
input=[[[1, 2, 3], [0, 3, 2]]], weights=[], size=10))

@test_util.run_in_graph_and_eager_modes
def test_size_is_not_scalar(self): # b/206619828
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Shape must be rank 0 but is rank 1"):
self.evaluate(
gen_math_ops.dense_bincount(
input=[0], size=[1, 1], weights=[3], binary_output=False))


class SparseBincountOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
Expand Down Expand Up @@ -511,6 +519,19 @@ def test_sparse_bincount_col_reduce_binary(self, dtype):
weights=[],
binary_output=True)))

@test_util.run_in_graph_and_eager_modes
def test_size_is_not_scalar(self): # b/206619828
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Shape must be rank 0 but is rank 1"):
self.evaluate(
gen_math_ops.sparse_bincount(
indices=[[0], [1]],
values=[0, 0],
dense_shape=[1, 1],
size=[1, 1],
weights=[0, 0],
binary_output=False))


class RaggedBincountOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
Expand Down Expand Up @@ -650,6 +671,19 @@ def test_ragged_bincount_binary_np_with_weights(self, dtype):
size=size,
binary_output=True)))

@test_util.run_in_graph_and_eager_modes
def test_size_is_not_scalar(self): # b/206619828
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Shape must be rank 0 but is rank 1"):
self.evaluate(
gen_math_ops.ragged_bincount(
splits=[0, 0, 1],
values=[1],
size=[1, 1],
weights=[0, 0, 0],
binary_output=False,
name=None))


if __name__ == "__main__":
googletest.main()

0 comments on commit 7019ce4

Please sign in to comment.