Skip to content
Permalink
Browse files Browse the repository at this point in the history
Make SparseFillEmptyRows validate that the length of values must be…
… equal to the number of index tuples.

PiperOrigin-RevId: 399969549
Change-Id: I3c2f2ca1c1d2cc88bb5951c6958b38c16e9436c8
  • Loading branch information
penpornk authored and tensorflower-gardener committed Sep 30, 2021
1 parent 421fba8 commit 67bfd9f
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
Expand Up @@ -24,11 +24,13 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
Expand Down Expand Up @@ -222,6 +224,12 @@ void SparseFillEmptyRowsOpImpl(OpKernelContext* context,
errors::InvalidArgument("values must be a vector, saw: ",
values_t.shape().DebugString()),
done);
OP_REQUIRES_ASYNC(
context, indices_t.dim_size(0) == values_t.dim_size(0),
errors::InvalidArgument("The length of `values` (", values_t.dim_size(0),
") must match the first dimension of `indices` (",
indices_t.dim_size(0), ")."),
done);
OP_REQUIRES_ASYNC(
context, TensorShapeUtils::IsScalar(default_value_t.shape()),
errors::InvalidArgument("default_value must be a scalar, saw: ",
Expand Down

0 comments on commit 67bfd9f

Please sign in to comment.