Skip to content
Permalink
Browse files Browse the repository at this point in the history
Merge pull request #51975 from yongtang:51936-max_pool3d
PiperOrigin-RevId: 401245519
Change-Id: I67d2cbb0e21729b94186ca9bf82450ff93132ff2
  • Loading branch information
tensorflower-gardener committed Oct 6, 2021
2 parents a2bbfe8 + 647ae1f commit 12b1ff8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/pooling_ops_3d.cc
Expand Up @@ -141,6 +141,11 @@ class Pooling3DOp : public UnaryOp<T> {
OP_REQUIRES(context, ksize_.size() == 5,
errors::InvalidArgument("Sliding window ksize field must "
"specify 5 dimensions"));
bool non_negative =
std::all_of(ksize_.begin(), ksize_.end(), [](int k) { return k > 0; });
OP_REQUIRES(context, non_negative,
errors::InvalidArgument("Sliding window ksize field must "
"have non-negative dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
OP_REQUIRES(context, stride_.size() == 5,
errors::InvalidArgument("Sliding window stride field must "
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/kernel_tests/pooling_ops_3d_test.py
Expand Up @@ -17,6 +17,7 @@
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
Expand Down Expand Up @@ -501,6 +502,19 @@ def testAvgPoolGradSamePadding3_1_3d(self):
strides=(1, 1, 1),
padding="SAME")

def testMaxPool3DZeroPoolSize(self):
# Test case for GitHub issue 51936.
for f in [nn_ops.max_pool3d, nn_ops.avg_pool3d]:
with self.session():
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
input_sizes = [3, 4, 10, 11, 12]

input_data = 1.
input_tensor = constant_op.constant(
input_data, shape=input_sizes, name="input")
pool_3d = f(input_tensor, ksize=[2, 2, 0], strides=1, padding="VALID")
self.evaluate(pool_3d)


if __name__ == "__main__":
test.main()

0 comments on commit 12b1ff8

Please sign in to comment.