Skip to content

Commit 32d7bd3

Browse files
cantoniostensorflower-gardener
authored andcommitted
Fix MaxPool crash on GPU for invalid filter size.
If the filter size exceeds the input size by one for `VALID` padding, return an empty tensor. This is consistent with XLA. PiperOrigin-RevId: 462684864
1 parent 86a827c commit 32d7bd3

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

Diff for: tensorflow/core/kernels/maxpooling_op.cc

+7
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,13 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
12681268
ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
12691269
params.out_width, params.depth);
12701270

1271+
// Degenerate pooling output should return an empty tensor.
1272+
if (out_shape.num_elements() == 0) {
1273+
Tensor* output = nullptr;
1274+
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
1275+
return;
1276+
}
1277+
12711278
// Assuming qint8 <--> NCHW_VECT_C (int8x4) here.
12721279
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
12731280
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),

Diff for: tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py

+12
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,18 @@ def testMaxPoolEmptyInput(self, **kwargs):
772772
expected=[],
773773
**kwargs)
774774

775+
@parameterized.parameters(
776+
GetTestConfigsDicts(nn_ops.max_pool, gen_nn_ops.max_pool_v2))
777+
@test_util.run_deprecated_v1
778+
def testMaxPoolInvalidFilterSize(self, **kwargs):
779+
with self.cached_session(use_gpu=test.is_gpu_available()):
780+
t = constant_op.constant(1.0, shape=[1, 1, 1, 1])
781+
with self.assertRaisesRegex(
782+
(errors_impl.InvalidArgumentError, ValueError),
783+
"Negative dimension size"):
784+
t = self.evaluate(
785+
nn_ops.max_pool(t, ksize=[1, 1, 2, 1], strides=1, padding="VALID"))
786+
775787
# Tests for DepthwiseMaxPooling on CPU only.
776788
@parameterized.parameters(
777789
GetTestConfigsDicts(

0 commit comments

Comments
 (0)