Skip to content

Commit

Permalink
Merge pull request #53921 from tensorflow/cherrypick-ba4e8ac4dc2991e3…
Browse files Browse the repository at this point in the history
…50d5cc407f8598c8d4ee70fb-on-r2.7

Fix potential divide by zero error when executing FractionalMaxPool, …
  • Loading branch information
mihaimaruseac committed Jan 24, 2022
2 parents 45c06d5 + 152cf2c commit 14b3f05
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/fractional_max_pool_op.cc
Expand Up @@ -83,6 +83,13 @@ class FractionalMaxPoolOp : public OpKernel {
std::vector<int> output_size(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
input_size[i] = tensor_in.dim_size(i);

OP_REQUIRES(
context, input_size[i] >= pooling_ratio_[i],
errors::InvalidArgument("Pooling ratio is higher than input "
"dimension size for dimension ",
i, ". Input dim size: ", input_size[i],
" pooling ratio: ", pooling_ratio_[i]));
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
Expand Up @@ -24,6 +24,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
Expand Down Expand Up @@ -307,6 +308,22 @@ def testDifferentInputTensorShape(self):
input_b, row_seq, col_seq, overlapping)
self.assertSequenceEqual(expected.shape, actual.shape)

def testDeterminismExceptionThrowing(self):
tensor_shape = (5, 20, 20, 3)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
with test_util.deterministic_ops():
with self.assertRaisesRegex(
ValueError, "requires a non-zero seed to be passed in when "
"determinism is enabled"):
nn_ops.fractional_max_pool_v2(rand_mat, [1, 1.5, 1.5, 1])
nn_ops.fractional_max_pool_v2(rand_mat, [1, 1.5, 1.5, 1], seed=1)

with self.assertRaisesRegex(ValueError,
'requires "seed" and "seed2" to be non-zero'):
nn_ops.fractional_max_pool(rand_mat, [1, 1.5, 1.5, 1])
nn_ops.fractional_max_pool(
rand_mat, [1, 1.5, 1.5, 1], seed=1, seed2=1, deterministic=True)


class FractionalMaxPoolGradTest(test.TestCase):
"""Tests for FractionalMaxPoolGrad.
Expand Down

0 comments on commit 14b3f05

Please sign in to comment.