From 9645ca19de2d4b3bfedf2a2d12aff5ffeac371a5 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Tue, 2 Mar 2021 17:02:03 -0800 Subject: [PATCH] Fix segfaults in `tf.raw_ops.SparseCountSparseOutput`. PiperOrigin-RevId: 360547563 Change-Id: I781c7af4b54a63d867c6e18d43a44d64a5c4e7c9 --- tensorflow/core/kernels/count_ops.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensorflow/core/kernels/count_ops.cc b/tensorflow/core/kernels/count_ops.cc index 087deef0812f00..d6ab68c2c70bd3 100644 --- a/tensorflow/core/kernels/count_ops.cc +++ b/tensorflow/core/kernels/count_ops.cc @@ -192,6 +192,10 @@ class SparseCount : public OpKernel { "; values shape: ", values.shape().DebugString())); } + OP_REQUIRES(context, shape.NumElements() != 0, + errors::InvalidArgument( + "The shape argument requires at least one element.")); + bool is_1d = shape.NumElements() == 1; int num_batches = is_1d ? 1 : shape.flat()(0); int num_values = values.NumElements(); @@ -212,6 +216,14 @@ class SparseCount : public OpKernel { for (int idx = 0; idx < num_values; ++idx) { int batch = is_1d ? 0 : indices_values(idx, 0); + if (batch >= num_batches) { + OP_REQUIRES(context, batch < num_batches, + errors::InvalidArgument( + "Indices value along the first dimension must be ", + "lower than the first index of the shape.", "Got ", + batch, " as batch and ", num_batches, + " as the first dimension of the shape.")); + } const auto& value = values_values(idx); if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) { if (binary_output_) {