Skip to content

Commit

Permalink
Merge pull request #59552 from tensorflow/r2.11-8ae76cf085f
Browse files Browse the repository at this point in the history
r2.11 cherry-pick: 8ae76cf "[Tensorflow] Fix security vulnerability with DenseBincountOp"
  • Loading branch information
mihaimaruseac committed Feb 4, 2023
2 parents d82af7b + 2a6ee54 commit ebe119c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 7 deletions.
16 changes: 16 additions & 0 deletions tensorflow/compiler/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2344,3 +2344,19 @@ tf_xla_py_test(
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "bincount_op_test",
size = "small",
srcs = ["bincount_op_test.py"],
enable_mlir_bridge = False,
python_version = "PY3",
shard_count = 10,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:platform_test",
],
)
40 changes: 40 additions & 0 deletions tensorflow/compiler/tests/bincount_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bincount using the XLA JIT."""
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.platform import googletest


class BincountTest(xla_test.XLATestCase):

def testInputRank0(self):
with self.session():
with self.test_scope():
bincount = gen_math_ops.bincount(arr=6, size=804, weights=[52, 351])

with self.assertRaisesRegex(
errors.InvalidArgumentError,
(
"`weights` must be the same shape as `arr` or a length-0"
" `Tensor`, in which case it acts as all weights equal to 1."
),
):
self.evaluate(bincount)


if __name__ == "__main__":
googletest.main()
17 changes: 10 additions & 7 deletions tensorflow/compiler/tf2xla/kernels/bincount_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,15 @@ class DenseBincountOp : public XlaOpKernel {
StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
OP_REQUIRES_OK(ctx, input_shape_or.status());
auto input_shape = input_shape_or.value();
auto size = input_shape.dimensions(0);

if (!size) {
output = xla::Broadcast(zero, {output_size});
ctx->SetOutput(0, output);
return;
}
auto rank = input_shape.rank();

OP_REQUIRES(ctx, rank <= 2,
errors::InvalidArgument(
"Shape must be at most rank 2 but is rank ", rank));

xla::XlaOp weights = ctx->Input(2);
StatusOr<xla::Shape> weights_shape_or = ctx->builder()->GetShape(weights);

OP_REQUIRES_OK(ctx, weights_shape_or.status());

auto weights_shape = weights_shape_or.value();
Expand All @@ -91,11 +85,20 @@ class DenseBincountOp : public XlaOpKernel {
"1. Received ",
weights_shape.DebugString()));

auto size = input_shape.dimensions(0);

if (!size) {
output = xla::Broadcast(zero, {output_size});
ctx->SetOutput(0, output);
return;
}

auto weights_size = weights_shape.dimensions(0);
bool has_weights = false;
if (weights_size) {
has_weights = true;
}

xla::Shape output_shape = xla::ShapeUtil::MakeShape(dtype, {output_size});
xla::ScatterDimensionNumbers scatter_dnums;
scatter_dnums.set_index_vector_dim(1);
Expand Down

0 comments on commit ebe119c

Please sign in to comment.