Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix security vulnerability with DenseBincountOp
PiperOrigin-RevId: 460826735
  • Loading branch information
sagunb authored and tensorflower-gardener committed Jul 13, 2022
1 parent b8fbc52 commit bf4c143
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/bincount_op.cc
Expand Up @@ -80,6 +80,17 @@ class DenseBincountOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, weights_shape_or.status());

auto weights_shape = weights_shape_or.ValueOrDie();
OP_REQUIRES(ctx,
xla::ShapeUtil::CompatibleIgnoringElementType(weights_shape,
input_shape) ||
(weights_shape.dimensions_size() > 0 &&
weights_shape.dimensions(0) == 0),
errors::InvalidArgument(
"`weights` must be the same shape as `arr` or a length-0 "
"`Tensor`, in which case it acts as all weights equal to "
"1. Received ",
weights_shape.DebugString()));

auto weights_size = weights_shape.dimensions(0);
bool has_weights = false;
if (weights_size) {
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/core/kernels/bincount_op.cc
Expand Up @@ -280,6 +280,14 @@ class DenseBincountOp : public OpKernel {
OP_REQUIRES(ctx, size_t.dims() == 0,
errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_t.dims()));
OP_REQUIRES(ctx,
weights.shape() == data.shape() || weights.NumElements() == 0,
errors::InvalidArgument(
"`weights` must be the same shape as `arr` or a length-0 "
"`Tensor`, in which case it acts as all weights equal to "
"1. Received ",
weights.shape().DebugString()));

Tidx size = size_t.scalar<Tidx>()();
OP_REQUIRES(
ctx, size >= 0,
Expand Down
26 changes: 26 additions & 0 deletions tensorflow/python/kernel_tests/math_ops/bincount_op_test.py
Expand Up @@ -24,6 +24,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bincount_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
Expand Down Expand Up @@ -152,6 +153,31 @@ def test_shape_function(self):
v2 = gen_math_ops.bincount([1, 2, 3, 1, 6, 8], s, [])
self.assertAllEqual(v2.get_shape().as_list(), [None])

@test_util.run_in_graph_and_eager_modes
def test_invalid_inputs(self):
binary_output = True
inp = random_ops.random_uniform(
shape=[10, 10],
minval=-10000,
maxval=10000,
dtype=dtypes.int32,
seed=-2460)
size = random_ops.random_uniform(
shape=[], minval=-10000, maxval=10000, dtype=dtypes.int32, seed=-10000)
weights = random_ops.random_uniform(
shape=[],
minval=-10000,
maxval=10000,
dtype=dtypes.float32,
seed=-10000)
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(
gen_math_ops.dense_bincount(
input=inp,
size=size,
weights=weights,
binary_output=binary_output))


class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):

Expand Down

0 comments on commit bf4c143

Please sign in to comment.