Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix the segfault in tf.raw_ops.SparseCountSparseOutput.
PiperOrigin-RevId: 369264941
Change-Id: I23a96a15b8370c01ee21ba3841e1c7dcbf55e93d
  • Loading branch information
Amit Patankar authored and tensorflower-gardener committed Apr 19, 2021
1 parent 8f7b60e commit c57c0b9
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tensorflow/core/kernels/count_ops.cc
Expand Up @@ -197,9 +197,17 @@ class SparseCount : public OpKernel {
"The shape argument requires at least one element."));

bool is_1d = shape.NumElements() == 1;
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
auto shape_vector = shape.flat<int64>();
int num_batches = is_1d ? 1 : shape_vector(0);
int num_values = values.NumElements();

for (int b = 0; b < shape_vector.size(); b++) {
OP_REQUIRES(context, shape_vector(b) >= 0,
errors::InvalidArgument(
"Elements in dense_shape must be >= 0. Instead got:",
shape.DebugString()));
}

OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
errors::InvalidArgument(
"Number of values must match first dimension of indices.",
Expand Down

0 comments on commit c57c0b9

Please sign in to comment.