Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix MaxPool crash on GPU for invalid filter size.
If the filter size exceeds the input size by one for `VALID` padding,
return an empty tensor. This is consistent with XLA.

PiperOrigin-RevId: 462684864
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 22, 2022
1 parent 86a827c commit 32d7bd3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/maxpooling_op.cc
Expand Up @@ -1268,6 +1268,13 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
params.out_width, params.depth);

// Degenerate pooling output should return an empty tensor.
if (out_shape.num_elements() == 0) {
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
return;
}

// Assuming qint8 <--> NCHW_VECT_C (int8x4) here.
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py
Expand Up @@ -772,6 +772,18 @@ def testMaxPoolEmptyInput(self, **kwargs):
expected=[],
**kwargs)

@parameterized.parameters(
GetTestConfigsDicts(nn_ops.max_pool, gen_nn_ops.max_pool_v2))
@test_util.run_deprecated_v1
def testMaxPoolInvalidFilterSize(self, **kwargs):
with self.cached_session(use_gpu=test.is_gpu_available()):
t = constant_op.constant(1.0, shape=[1, 1, 1, 1])
with self.assertRaisesRegex(
(errors_impl.InvalidArgumentError, ValueError),
"Negative dimension size"):
t = self.evaluate(
nn_ops.max_pool(t, ksize=[1, 1, 2, 1], strides=1, padding="VALID"))

# Tests for DepthwiseMaxPooling on CPU only.
@parameterized.parameters(
GetTestConfigsDicts(
Expand Down

0 comments on commit 32d7bd3

Please sign in to comment.