Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix tf.raw_ops.RaggedTensorToTensor failing CHECK.
PiperOrigin-RevId: 368706628
Change-Id: I5c9ea4833f38835ee183ca50d63251dc89c9f3bc
  • Loading branch information
Amit Patankar authored and tensorflower-gardener committed Apr 15, 2021
1 parent 1b4de33 commit b761c9b
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
Expand Up @@ -208,7 +208,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
}

void CalculateOutputIndexRowSplit(
const RowPartitionTensor& row_split,
OpKernelContext* context, const RowPartitionTensor& row_split,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
Expand All @@ -233,7 +233,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
}
}
if (row_split_size > 0) {
DCHECK_EQ(result->size(), row_split(row_split_size - 1));
OP_REQUIRES(context, result->size() == row_split(row_split_size - 1),
errors::InvalidArgument("Invalid row split size."));
}
}

Expand All @@ -259,7 +260,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
// result[7] = -1 because parent_output_index[value_rowids[6]] == -1
// result[8] = parent_output_index[value_rowids[7]]
void CalculateOutputIndexValueRowID(
const RowPartitionTensor& value_rowids,
OpKernelContext* context, const RowPartitionTensor& value_rowids,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
Expand Down Expand Up @@ -293,7 +294,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
}
result->push_back(current_output_index);
}
DCHECK_EQ(result->size(), value_rowids.size());
OP_REQUIRES(context, result->size() == value_rowids.size(),
errors::InvalidArgument("Invalid row ids."));
}

Status CalculateOutputIndex(OpKernelContext* context, int dimension,
Expand All @@ -307,13 +309,13 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
switch (partition_type) {
case RowPartitionType::VALUE_ROWIDS:
CalculateOutputIndexValueRowID(
row_partition_tensor, parent_output_index, output_index_multiplier,
output_size, result);
context, row_partition_tensor, parent_output_index,
output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
case RowPartitionType::ROW_SPLITS:
CalculateOutputIndexRowSplit(row_partition_tensor, parent_output_index,
output_index_multiplier, output_size,
result);
CalculateOutputIndexRowSplit(
context, row_partition_tensor, parent_output_index,
output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
default:
return errors::InvalidArgument(
Expand Down

0 comments on commit b761c9b

Please sign in to comment.