Skip to content

Commit

Permalink
Add sparse tensor validation to SparseBincountOp.
Browse files Browse the repository at this point in the history
Addresses a security issue.

PiperOrigin-RevId: 460573866
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 12, 2022
1 parent bbefbe0 commit 40adbe4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/BUILD
Expand Up @@ -4421,6 +4421,7 @@ tf_kernel_library(
deps = [
":fill_functor",
":gpu_prim_hdrs",
":sparse_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
Expand Down Expand Up @@ -5007,6 +5008,7 @@ cc_library(
SPARSE_DEPS = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
":sparse_utils",
]

tf_kernel_library(
Expand Down Expand Up @@ -6480,6 +6482,7 @@ filegroup(
"sparse_reorder_op.h",
"sparse_slice_op.h",
"sparse_tensor_dense_matmul_op.h",
"sparse_utils.h",
"string_util.h",
"string_to_hash_bucket_op.h",
"string_to_hash_bucket_fast_op.h",
Expand Down Expand Up @@ -6718,6 +6721,7 @@ filegroup(
"random_ops_util.h",
"random_poisson_op.cc",
"shuffle_common.h",
"sparse_utils.cc",
"random_shuffle_op.cc",
"reduce_join_op.cc",
"reduction_ops_all.cc",
Expand Down
13 changes: 9 additions & 4 deletions tensorflow/core/kernels/bincount_op.cc
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bincount_op.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
Expand Down Expand Up @@ -369,7 +370,8 @@ class SparseBincountOp : public OpKernel {

void Compute(OpKernelContext* ctx) override {
const Tensor& indices = ctx->input(0);
const auto values = ctx->input(1).flat<Tidx>();
const Tensor& values = ctx->input(1);
const auto values_flat = values.flat<Tidx>();
const Tensor& dense_shape = ctx->input(2);
const Tensor& size_t = ctx->input(3);
const auto weights = ctx->input(4).flat<T>();
Expand All @@ -382,6 +384,9 @@ class SparseBincountOp : public OpKernel {
OP_REQUIRES(
ctx, size >= 0,
errors::InvalidArgument("size (", size, ") must be non-negative"));
OP_REQUIRES_OK(
ctx, sparse_utils::ValidateSparseTensor<int64_t>(
indices, values, dense_shape, /*validate_indices=*/true));

bool is_1d = dense_shape.NumElements() == 1;

Expand All @@ -394,11 +399,11 @@ class SparseBincountOp : public OpKernel {
if (binary_output_) {
OP_REQUIRES_OK(ctx,
functor::BincountFunctor<Device, Tidx, T, true>::Compute(
ctx, values, weights, out, size));
ctx, values_flat, weights, out, size));
} else {
OP_REQUIRES_OK(
ctx, functor::BincountFunctor<Device, Tidx, T, false>::Compute(
ctx, values, weights, out, size));
ctx, values_flat, weights, out, size));
}
} else {
const auto shape = dense_shape.flat<int64_t>();
Expand All @@ -410,7 +415,7 @@ class SparseBincountOp : public OpKernel {
const auto indices_mat = indices.matrix<int64_t>();
for (int64_t i = 0; i < indices_mat.dimension(0); ++i) {
const int64_t batch = indices_mat(i, 0);
const Tidx bin = values(i);
const Tidx bin = values_flat(i);
OP_REQUIRES(
ctx, batch < out.dimension(0),
errors::InvalidArgument("Index out of bound. `batch` (", batch,
Expand Down
29 changes: 25 additions & 4 deletions tensorflow/python/kernel_tests/math_ops/bincount_op_test.py
Expand Up @@ -366,7 +366,7 @@ def test_sparse_bincount_all_count(self, dtype):
num_rows = 128
size = 1000
n_elems = 4096
inp_indices = np.random.randint(0, num_rows, (n_elems,))
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)

np_out = np.bincount(inp_vals, minlength=size)
Expand All @@ -390,7 +390,7 @@ def test_sparse_bincount_all_count_with_weights(self, dtype):
num_rows = 128
size = 1000
n_elems = 4096
inp_indices = np.random.randint(0, num_rows, (n_elems,))
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)
inp_weight = np.random.random((n_elems,))

Expand All @@ -415,7 +415,7 @@ def test_sparse_bincount_all_binary(self, dtype):
num_rows = 128
size = 10
n_elems = 4096
inp_indices = np.random.randint(0, num_rows, (n_elems,))
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)

np_out = np.ones((size,))
Expand All @@ -440,7 +440,7 @@ def test_sparse_bincount_all_binary_weights(self, dtype):
num_rows = 128
size = 10
n_elems = 4096
inp_indices = np.random.randint(0, num_rows, (n_elems,))
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)
inp_weight = np.random.random((n_elems,))

Expand Down Expand Up @@ -532,6 +532,27 @@ def test_size_is_not_scalar(self): # b/206619828
weights=[0, 0],
binary_output=False))

def test_sparse_bincount_input_validation(self):
np.random.seed(42)
num_rows = 128
size = 1000
n_elems = 4096
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
inp_vals = np.random.randint(0, size, (n_elems,))

# Insert negative index.
inp_indices[10, 0] = -2

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"out of bounds"):
self.evaluate(
gen_math_ops.sparse_bincount(
indices=inp_indices,
values=inp_vals,
dense_shape=[num_rows],
size=size,
weights=[]))


class RaggedBincountOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
Expand Down

0 comments on commit 40adbe4

Please sign in to comment.