Skip to content

Commit

Permalink
Merge pull request #49316 from geetachavan1/cherrypicks_383PR
Browse files Browse the repository at this point in the history
Fix segfaults in `tf.raw_ops.SparseCountSparseOutput`.
  • Loading branch information
mihaimaruseac committed May 20, 2021
2 parents 8fbc6b2 + 9645ca1 commit f369322
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/count_ops.cc
Expand Up @@ -192,6 +192,10 @@ class SparseCount : public OpKernel {
"; values shape: ", values.shape().DebugString()));
}

OP_REQUIRES(context, shape.NumElements() != 0,
errors::InvalidArgument(
"The shape argument requires at least one element."));

bool is_1d = shape.NumElements() == 1;
auto shape_vector = shape.flat<int64>();
int num_batches = is_1d ? 1 : shape_vector(0);
Expand Down Expand Up @@ -220,6 +224,14 @@ class SparseCount : public OpKernel {

for (int idx = 0; idx < num_values; ++idx) {
int batch = is_1d ? 0 : indices_values(idx, 0);
if (batch >= num_batches) {
OP_REQUIRES(context, batch < num_batches,
errors::InvalidArgument(
"Indices value along the first dimension must be ",
"lower than the first index of the shape.", "Got ",
batch, " as batch and ", num_batches,
" as the first dimension of the shape."));
}
const auto& value = values_values(idx);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_output_) {
Expand Down

0 comments on commit f369322

Please sign in to comment.