Skip to content
Permalink
Browse files Browse the repository at this point in the history
Cleanup and remove duplicate validation in SparseCount.
We have valdiation that is duplicated, checking different conditions, in different formats and failing to capture all cases. This should fix all the previous bugs.

PiperOrigin-RevId: 414886981
Change-Id: Ibf0bba0beb057b76d505324bb9487565daf95f01
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Dec 8, 2021
1 parent 6089a29 commit 2b7100d
Showing 1 changed file with 21 additions and 27 deletions.
48 changes: 21 additions & 27 deletions tensorflow/core/kernels/count_ops.cc
Expand Up @@ -185,6 +185,27 @@ class SparseCount : public OpKernel {
errors::InvalidArgument(
"Input indices must be a 2-dimensional tensor. Got: ",
indices.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(values.shape()),
errors::InvalidArgument("Input values must be a vector. Got: ",
values.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()),
errors::InvalidArgument("Input shape must be a vector. Got: ",
shape.shape().DebugString()));
OP_REQUIRES(context,
values.shape().dim_size(0) == indices.shape().dim_size(0),
errors::InvalidArgument(
"Number of values must match first dimension of indices.",
"Got ", values.shape().dim_size(0),
" values, indices shape: ", indices.shape().DebugString()));
OP_REQUIRES(
context, shape.shape().dim_size(0) == indices.shape().dim_size(1),
errors::InvalidArgument(
"Number of dimensions must match second dimension of indices.",
"Got ", shape.shape().dim_size(0),
" dimensions, indices shape: ", indices.shape().DebugString()));
OP_REQUIRES(context, shape.NumElements() > 0,
errors::InvalidArgument(
"The shape argument requires at least one element."));

if (use_weights) {
OP_REQUIRES(
Expand All @@ -195,28 +216,11 @@ class SparseCount : public OpKernel {
"; values shape: ", values.shape().DebugString()));
}

OP_REQUIRES(context, shape.NumElements() != 0,
errors::InvalidArgument(
"The shape argument requires at least one element."));

bool is_1d = shape.NumElements() == 1;
auto shape_vector = shape.flat<int64_t>();
int num_batches = is_1d ? 1 : shape_vector(0);
int num_values = values.NumElements();

for (int b = 0; b < shape_vector.size(); b++) {
OP_REQUIRES(context, shape_vector(b) >= 0,
errors::InvalidArgument(
"Elements in dense_shape must be >= 0. Instead got:",
shape.DebugString()));
}

OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
errors::InvalidArgument(
"Number of values must match first dimension of indices.",
"Got ", num_values,
" values, indices shape: ", indices.shape().DebugString()));

const auto indices_values = indices.matrix<int64_t>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();
Expand All @@ -225,16 +229,6 @@ class SparseCount : public OpKernel {

T max_value = 0;

OP_REQUIRES(context, num_values <= indices.shape().dim_size(0),
errors::InvalidArgument(
"The first dimension of indices must be equal to or "
"greather than number of values. ( ",
indices.shape().dim_size(0), " vs. ", num_values, " )"));
OP_REQUIRES(context, indices.shape().dim_size(1) > 0,
errors::InvalidArgument("The second dimension of indices must "
"be greater than 0. Received: ",
indices.shape().dim_size(1)));

for (int idx = 0; idx < num_values; ++idx) {
int batch = is_1d ? 0 : indices_values(idx, 0);
if (batch >= num_batches) {
Expand Down

0 comments on commit 2b7100d

Please sign in to comment.