From bd8f3b8d71543d8d7c84d77c075e1afa07b6c6d9 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Thu, 15 Apr 2021 13:28:49 -0700 Subject: [PATCH] Fix `tf.raw_ops.RaggedTensorToTensor` failing CHECK. PiperOrigin-RevId: 368706628 Change-Id: I5c9ea4833f38835ee183ca50d63251dc89c9f3bc --- .../kernels/ragged_tensor_to_tensor_op.cc | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc index ca9e1836c82127..6f5c07e25c3da0 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc @@ -208,7 +208,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { } void CalculateOutputIndexRowSplit( - const RowPartitionTensor& row_split, + OpKernelContext* context, const RowPartitionTensor& row_split, const vector& parent_output_index, INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size, vector* result) { @@ -233,7 +233,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel { } } if (row_split_size > 0) { - DCHECK_EQ(result->size(), row_split(row_split_size - 1)); + OP_REQUIRES(context, result->size() == row_split(row_split_size - 1), + errors::InvalidArgument("Invalid row split size.")); } } @@ -259,7 +260,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { // result[7] = -1 because parent_output_index[value_rowids[6]] == -1 // result[8] = parent_output_index[value_rowids[7]] void CalculateOutputIndexValueRowID( - const RowPartitionTensor& value_rowids, + OpKernelContext* context, const RowPartitionTensor& value_rowids, const vector& parent_output_index, INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size, vector* result) { @@ -293,7 +294,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel { } result->push_back(current_output_index); } - DCHECK_EQ(result->size(), value_rowids.size()); + OP_REQUIRES(context, result->size() == value_rowids.size(), + errors::InvalidArgument("Invalid row ids.")); } Status CalculateOutputIndex(OpKernelContext* context, int dimension, @@ -307,13 +309,13 @@ class RaggedTensorToTensorBaseOp : public OpKernel { switch (partition_type) { case RowPartitionType::VALUE_ROWIDS: CalculateOutputIndexValueRowID( - row_partition_tensor, parent_output_index, output_index_multiplier, - output_size, result); + context, row_partition_tensor, parent_output_index, + output_index_multiplier, output_size, result); return tensorflow::Status::OK(); case RowPartitionType::ROW_SPLITS: - CalculateOutputIndexRowSplit(row_partition_tensor, parent_output_index, - output_index_multiplier, output_size, - result); + CalculateOutputIndexRowSplit( + context, row_partition_tensor, parent_output_index, + output_index_multiplier, output_size, result); return tensorflow::Status::OK(); default: return errors::InvalidArgument(