Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix potential divide by zero error when executing FractionalMaxPool, …
…when pooling ratio is higher than input size for a particular dimension.

PiperOrigin-RevId: 412151722
Change-Id: I06e57cbb8eca43816eff79eac264fa7aae8f7163
  • Loading branch information
ishark authored and tensorflower-gardener committed Nov 25, 2021
1 parent 222a7e8 commit ba4e8ac
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/fractional_max_pool_op.cc
Expand Up @@ -83,6 +83,13 @@ class FractionalMaxPoolOp : public OpKernel {
std::vector<int> output_size(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
input_size[i] = tensor_in.dim_size(i);

OP_REQUIRES(
context, input_size[i] >= pooling_ratio_[i],
errors::InvalidArgument("Pooling ratio is higher than input "
"dimension size for dimension ",
i, ". Input dim size: ", input_size[i],
" pooling ratio: ", pooling_ratio_[i]));
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
Expand Down
Expand Up @@ -20,6 +20,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
Expand Down Expand Up @@ -319,6 +320,24 @@ def testDeterminismExceptionThrowing(self):
nn_ops.fractional_max_pool(
rand_mat, [1, 1.5, 1.5, 1], seed=1, seed2=1, deterministic=True)

def testPoolingRatio(self):
with self.cached_session() as _:
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r"Pooling ratio is higher than input dimension size for dimension 1.*"
):
result = nn_ops.gen_nn_ops.fractional_max_pool(
value=constant_op.constant(
value=[[[[1, 4, 2, 3]]]], dtype=dtypes.int64),
pooling_ratio=[1.0, 1.44, 1.73, 1.0],
pseudo_random=False,
overlapping=False,
deterministic=False,
seed=0,
seed2=0,
name=None)
self.evaluate(result)


class FractionalMaxPoolGradTest(test.TestCase):
"""Tests for FractionalMaxPoolGrad.
Expand Down

0 comments on commit ba4e8ac

Please sign in to comment.