From 4dddb2fd0b01cdd196101afbba6518658a2c9e07 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Wed, 20 Oct 2021 14:53:58 -0700 Subject: [PATCH] Fix segfault in pools on empty shapes when certain dimension were very large. Pooling ops multiply certain components of the input shape, e.g. by multiplying input.shape[1] * input.shape[2] * input.shape[3]. This multiplication could overflow an int64 value if shape[0] was 0 but shape[1], shape[2], and shape[3] were very large, e.g. by passing an input with shape (0, 2**25, 2**25, 2**25). PiperOrigin-RevId: 404644978 Change-Id: Ic79f89c970357ca2962b1f231449066db9403146 --- tensorflow/core/kernels/pooling_ops_common.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 1a41a5adca5d29..8890e24a32ad97 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -189,6 +189,9 @@ class MaxPoolingOp : public OpKernel { void SpatialMaxPool(OpKernelContext* context, Tensor* output, const Tensor& tensor_in, const PoolParameters& params, const Padding& padding) { + if (output->NumElements() == 0) { + return; + } // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an // EigenMatrix version that is currently faster than Eigen's // Spatial MaxPooling implementation. @@ -443,6 +446,9 @@ class MaxPoolingV2Op : public OpKernel { void SpatialMaxPool(OpKernelContext* context, Tensor* output, const Tensor& tensor_in, const PoolParameters& params, const Padding& padding) { + if (output->NumElements() == 0) { + return; + } // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an // EigenMatrix version that is currently faster than Eigen's // Spatial MaxPooling implementation. @@ -561,6 +567,9 @@ template void SpatialAvgPool(OpKernelContext* context, Tensor* output, const Tensor& input, const PoolParameters& params, const Padding& padding) { + if (output->NumElements() == 0) { + return; + } typedef Eigen::Map> ConstEigenMatrixMap; typedef Eigen::Map>