Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix RaggedBincount Segmentation Fault from the Splits arg
PiperOrigin-RevId: 461715364
  • Loading branch information
tensorflower-gardener committed Jul 18, 2022
1 parent a7902da commit 7a4591f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/kernels/bincount_op.cc
Expand Up @@ -493,6 +493,9 @@ class RaggedBincountOp : public OpKernel {
int num_values = values.size();
int batch_idx = 0;

OP_REQUIRES(ctx, splits.size() > 0,
errors::InvalidArgument("Splits must be non-empty"));

OP_REQUIRES(ctx, splits(0) == 0,
errors::InvalidArgument("Splits must start with 0, not with ",
splits(0)));
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/kernel_tests/math_ops/bincount_op_test.py
Expand Up @@ -734,6 +734,18 @@ def test_size_is_not_scalar(self): # b/206619828
binary_output=False,
name=None))

@test_util.run_in_graph_and_eager_modes
def test_splits_empty(self): # b/238450914
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Splits must be non-empty"):
self.evaluate(
gen_math_ops.ragged_bincount(
splits=[], # Invalid splits
values=[1],
size=1,
weights=[1],
binary_output=False,
name=None))

if __name__ == "__main__":
googletest.main()

0 comments on commit 7a4591f

Please sign in to comment.