Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix security vulnerability with FractionalMax(AVG)Pool with illegal p…
…ooling_ratio

PiperOrigin-RevId: 483486453
  • Loading branch information
tensorflower-gardener committed Oct 24, 2022
1 parent d689c19 commit 2165251
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 10 deletions.
14 changes: 11 additions & 3 deletions tensorflow/core/kernels/fractional_avg_pool_op.cc
Expand Up @@ -44,6 +44,12 @@ class FractionalAvgPoolOp : public OpKernel {
OP_REQUIRES(context, pooling_ratio_.size() == 4,
errors::InvalidArgument(
"pooling_ratio field must specify 4 dimensions"));
for (std::size_t i = 0; i < pooling_ratio_.size(); ++i) {
OP_REQUIRES(context, pooling_ratio_[i] >= 1,
errors::InvalidArgument(
"pooling_ratio cannot be smaller than 1, got: ",
pooling_ratio_[i]));
}
OP_REQUIRES(
context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
errors::Unimplemented("Fractional average pooling is not yet "
Expand Down Expand Up @@ -82,9 +88,11 @@ class FractionalAvgPoolOp : public OpKernel {
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
input_size[i] = tensor_in.dim_size(i);
OP_REQUIRES(
context, pooling_ratio_[i] <= input_size[i],
errors::InvalidArgument(
"Pooling ratio cannot be bigger than input tensor dim size."));
context, input_size[i] >= pooling_ratio_[i],
errors::InvalidArgument("Pooling ratio is higher than input "
"dimension size for dimension ",
i, ". Input dim size: ", input_size[i],
" pooling ratio: ", pooling_ratio_[i]));
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/fractional_max_pool_op.cc
Expand Up @@ -45,6 +45,12 @@ class FractionalMaxPoolOp : public OpKernel {
OP_REQUIRES(context, pooling_ratio_.size() == 4,
errors::InvalidArgument("pooling_ratio field must "
"specify 4 dimensions"));
for (std::size_t i = 0; i < pooling_ratio_.size(); ++i) {
OP_REQUIRES(context, pooling_ratio_[i] >= 1,
errors::InvalidArgument(
"pooling_ratio cannot be smaller than 1, got: ",
pooling_ratio_[i]));
}

OP_REQUIRES(
context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/core/ops/nn_ops.cc
Expand Up @@ -63,6 +63,13 @@ Status FractionalPoolShapeFn(InferenceContext* c) {
}
}

for (std::size_t i = 0; i < pooling_ratio.size(); ++i) {
if (pooling_ratio[i] < 1) {
return errors::InvalidArgument(
"pooling_ratio cannot be smaller than 1, got: ", pooling_ratio[i]);
}
}

c->set_output(0, c->MakeShape(output_dims));
c->set_output(1, c->Vector(output_dims[1]));
c->set_output(2, c->Vector(output_dims[2]));
Expand Down
13 changes: 7 additions & 6 deletions tensorflow/core/ops/nn_ops_test.cc
Expand Up @@ -523,7 +523,8 @@ TEST(NNOpsTest, FractionalPool_ShapeFn) {
.Finalize(&op.node_def));
};

set_op(std::vector<float>{2.0f, 1, 1 / 1.5f, 1 / 2.0f});
// pooling_ratio must >= 1.0
set_op(std::vector<float>{2.0f, 1, 1.5f, 4.0f});

// Rank check.
INFER_ERROR("must be rank 4", op, "[?,?,?]");
Expand All @@ -532,11 +533,11 @@ TEST(NNOpsTest, FractionalPool_ShapeFn) {
INFER_OK(op, "?", "[?,?,?,?];[?];[?]");
INFER_OK(op, "[?,?,?,?]", "[?,?,?,?];[?];[?]");

INFER_OK(op, "[10,20,30,40]", "[5,20,45,80];[20];[45]");
INFER_OK(op, "[?,20,30,40]", "[?,20,45,80];[20];[45]");
INFER_OK(op, "[10,?,30,40]", "[5,?,45,80];[?];[45]");
INFER_OK(op, "[10,20,?,40]", "[5,20,?,80];[20];[?]");
INFER_OK(op, "[10,20,30,?]", "[5,20,45,?];[20];[45]");
INFER_OK(op, "[10,20,30,40]", "[5,20,20,10];[20];[20]");
INFER_OK(op, "[?,20,30,40]", "[?,20,20,10];[20];[20]");
INFER_OK(op, "[10,?,30,40]", "[5,?,20,10];[?];[20]");
INFER_OK(op, "[10,20,?,40]", "[5,20,?,10];[20];[?]");
INFER_OK(op, "[10,20,30,?]", "[5,20,20,?];[20];[20]");

// Wrong number of values for pooling_ratio.
set_op(std::vector<float>{.5, 1.0, 1.5});
Expand Down
Expand Up @@ -333,6 +333,41 @@ def testNegativeSeqValuesForGradOp(self):

self.evaluate(z)

def testPoolingRatioHasMoreDimThanInput(self):
with self.cached_session() as _:
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r"Pooling ratio is higher than input dimension size for dimension 1.*"
):
result = nn_ops.gen_nn_ops.fractional_avg_pool(
value=constant_op.constant(
value=[[[[1, 4, 2, 3]]]], dtype=dtypes.int64),
pooling_ratio=[1.0, 1.44, 1.73, 1.0],
pseudo_random=False,
overlapping=False,
deterministic=False,
seed=0,
seed2=0,
name=None)
self.evaluate(result)

def testPoolingRatioValueOutOfRange(self):
with self.cached_session() as _:
# Whether turn on `TF2_BEHAVIOR` generates different error messages
with self.assertRaisesRegex(
(errors.InvalidArgumentError, ValueError),
r"(pooling_ratio cannot be smaller than 1, got: .*)|(is negative)"):
result = nn_ops.gen_nn_ops.fractional_avg_pool(
value=np.zeros([3, 30, 30, 3]),
pooling_ratio=[1, -1, 3, 1],
pseudo_random=False,
overlapping=False,
deterministic=False,
seed=0,
seed2=0,
)
self.evaluate(result)


class FractionalAvgPoolGradTest(test.TestCase):
"""Tests for FractionalAvgPoolGrad.
Expand Down
Expand Up @@ -320,7 +320,7 @@ def testDeterminismExceptionThrowing(self):
nn_ops.fractional_max_pool(
rand_mat, [1, 1.5, 1.5, 1], seed=1, seed2=1, deterministic=True)

def testPoolingRatio(self):
def testPoolingRatioHasMoreDimThanInput(self):
with self.cached_session() as _:
with self.assertRaisesRegex(
errors.InvalidArgumentError,
Expand All @@ -338,6 +338,23 @@ def testPoolingRatio(self):
name=None)
self.evaluate(result)

def testPoolingRatioValueOutOfRange(self):
with self.cached_session() as _:
# Whether turn on `TF2_BEHAVIOR` generates different error messages
with self.assertRaisesRegex(
(errors.InvalidArgumentError, ValueError),
r"(pooling_ratio cannot be smaller than 1, got: .*)|(is negative)"):
result = nn_ops.gen_nn_ops.fractional_max_pool(
value=np.zeros([3, 30, 30, 3]),
pooling_ratio=[1, -1, 3, 1],
pseudo_random=False,
overlapping=False,
deterministic=False,
seed=0,
seed2=0,
)
self.evaluate(result)


class FractionalMaxPoolGradTest(test.TestCase):
"""Tests for FractionalMaxPoolGrad.
Expand Down

0 comments on commit 2165251

Please sign in to comment.