Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix security vulnerability with FractionalMax(AVG)Pool with illegal p…
…ooling_ratio

PiperOrigin-RevId: 501651261
  • Loading branch information
vufg authored and tensorflower-gardener committed Jan 12, 2023
1 parent 0cbca6a commit ee50d1e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/fractional_avg_pool_op.cc
Expand Up @@ -51,7 +51,7 @@ class FractionalAvgPoolOp : public OpKernel {
pooling_ratio_[i]));
}
OP_REQUIRES(
context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
context, pooling_ratio_[0] == 1 && pooling_ratio_[3] == 1,
errors::Unimplemented("Fractional average pooling is not yet "
"supported on the batch nor channel dimension."));
OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/fractional_max_pool_op.cc
Expand Up @@ -53,7 +53,7 @@ class FractionalMaxPoolOp : public OpKernel {
}

OP_REQUIRES(
context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
context, pooling_ratio_[0] == 1 && pooling_ratio_[3] == 1,
errors::Unimplemented("Fractional max pooling is not yet "
"supported on the batch nor channel dimension."));

Expand Down
Expand Up @@ -351,7 +351,7 @@ def testPoolingRatioHasMoreDimThanInput(self):
name=None)
self.evaluate(result)

def testPoolingRatioValueOutOfRange(self):
def testPoolingRatioIllegalSmallValue(self):
with self.cached_session() as _:
# Whether turn on `TF2_BEHAVIOR` generates different error messages
with self.assertRaisesRegex(
Expand All @@ -368,6 +368,16 @@ def testPoolingRatioValueOutOfRange(self):
)
self.evaluate(result)

def testPoolingIllegalRatioForBatch(self):
with self.cached_session() as _:
with self.assertRaises(errors.UnimplementedError):
result = nn_ops.gen_nn_ops.fractional_avg_pool(
np.zeros([3, 30, 50, 3]),
[2, 3, 1.5, 1],
True,
True)
self.evaluate(result)


class FractionalAvgPoolGradTest(test.TestCase):
"""Tests for FractionalAvgPoolGrad.
Expand Down
Expand Up @@ -338,7 +338,7 @@ def testPoolingRatioHasMoreDimThanInput(self):
name=None)
self.evaluate(result)

def testPoolingRatioValueOutOfRange(self):
def testPoolingRatioIllegalSmallValue(self):
with self.cached_session() as _:
# Whether turn on `TF2_BEHAVIOR` generates different error messages
with self.assertRaisesRegex(
Expand All @@ -355,6 +355,16 @@ def testPoolingRatioValueOutOfRange(self):
)
self.evaluate(result)

def testPoolingIllegalRatioForBatch(self):
with self.cached_session() as _:
with self.assertRaises(errors.UnimplementedError):
result = nn_ops.fractional_max_pool(
np.zeros([3, 30, 50, 3]),
[2, 3, 1.5, 1],
True,
True)
self.evaluate(result)


class FractionalMaxPoolGradTest(test.TestCase):
"""Tests for FractionalMaxPoolGrad.
Expand Down

0 comments on commit ee50d1e

Please sign in to comment.