Skip to content
Permalink
Browse files Browse the repository at this point in the history
Further validate sparse tensor for SparseCount: indices must be val…
…id within dense shape.

PiperOrigin-RevId: 414888122
Change-Id: I4552bd74c135ecd4bcb5448acc0a3ce9402d8286
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Dec 8, 2021
1 parent 2b7100d commit adbbabd
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tensorflow/core/kernels/count_ops.cc
Expand Up @@ -206,6 +206,23 @@ class SparseCount : public OpKernel {
OP_REQUIRES(context, shape.NumElements() > 0,
errors::InvalidArgument(
"The shape argument requires at least one element."));
// Validate indices: each index must be valid for the corresponding
// dimension. This could be possibly done better.
const auto indices_values = indices.matrix<int64_t>();
const auto shape_vector = shape.vec<int64_t>();
int num_values = values.NumElements(); // same as first dim of indices
int rank = indices.shape().dim_size(1);
for (int i = 0; i < num_values; ++i) {
for (int j = 0; j < rank; ++j) {
OP_REQUIRES(
context,
indices_values(i, j) >= 0 && indices_values(i, j) < shape_vector(j),
errors::InvalidArgument(
"Invalid index value at ", i, ": dimension ", j, " has value ",
indices_values(i, j), " which is not in [0, ", shape_vector(j),
") (as given by dense shape ", shape.DebugString()));
}
}

if (use_weights) {
OP_REQUIRES(
Expand All @@ -217,11 +234,8 @@ class SparseCount : public OpKernel {
}

bool is_1d = shape.NumElements() == 1;
auto shape_vector = shape.flat<int64_t>();
int num_batches = is_1d ? 1 : shape_vector(0);
int num_values = values.NumElements();

const auto indices_values = indices.matrix<int64_t>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();

Expand Down

0 comments on commit adbbabd

Please sign in to comment.