Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add missing validation to RaggedTensorToSparse.
There needs to be a check that the splits allow for valid ragged tensors.

PiperOrigin-RevId: 387712169
Change-Id: I2499175324b82b65d159a260c7f83b98ceb5cc7d
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 30, 2021
1 parent 0f387ff commit 1071f55
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

Expand All @@ -38,7 +39,8 @@ class RaggedTensorToSparseOp : public OpKernel {
OP_REQUIRES_OK(
context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
const int rt_nested_splits_len = rt_nested_splits_in.size();
DCHECK_GT(rt_nested_splits_len, 0); // Enforced by REGISTER_OP.
OP_REQUIRES(context, rt_nested_splits_len > 0,
errors::InvalidArgument("rt_nested_splits must be non empty"));
std::vector<ConstFlatSplits> rt_nested_splits;
rt_nested_splits.reserve(rt_nested_splits_len);
for (int i = 0; i < rt_nested_splits_len; ++i) {
Expand Down Expand Up @@ -162,6 +164,14 @@ class RaggedTensorToSparseOp : public OpKernel {
if (rt_nested_splits[i](0) != 0) {
return InvalidArgument("First value of ragged splits must be 0.");
}
for (int j = 1; j < rt_nested_splits[i].size(); ++j) {
if (rt_nested_splits[i](j) < rt_nested_splits[i](j - 1)) {
return InvalidArgument(
"Ragged splits should be non decreasing, but we got ",
rt_nested_splits[i](j - 1), " followed by ",
rt_nested_splits[i](j));
}
}
if (i > 0) {
SPLITS_TYPE last_split =
rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
Expand Down

0 comments on commit 1071f55

Please sign in to comment.