@@ -12,22 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
15+
1516#define EIGEN_USE_THREADS
1617
1718#include < algorithm>
1819#include < cmath>
1920#include < random>
2021#include < vector>
2122
22- #include " tensorflow/core/kernels/fractional_pool_common.h"
23-
2423#include " third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2524#include " tensorflow/core/framework/numeric_op.h"
2625#include " tensorflow/core/framework/op_kernel.h"
26+ #include " tensorflow/core/kernels/fractional_pool_common.h"
2727#include " tensorflow/core/lib/random/random.h"
2828#include " tensorflow/core/platform/logging.h"
2929#include " tensorflow/core/platform/mutex.h"
3030#include " tensorflow/core/util/guarded_philox_random.h"
31+ #include " tensorflow/core/util/overflow.h"
3132
3233namespace tensorflow {
3334typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -241,7 +242,32 @@ class FractionalAvgPoolGradOp : public OpKernel {
241242 orig_input_tensor_shape.NumElements () == 4 ,
242243 errors::InvalidArgument (" original input tensor shape must be"
243244 " 1-dimensional and 4 elements" ));
245+ int64_t num_elements = 1 ;
246+ for (int i = 0 ; i < orig_input_tensor_shape.dims (); i++) {
247+ OP_REQUIRES (context, orig_input_tensor_shape.dim_size (i) > 0 ,
248+ errors::InvalidArgument (
249+ " orig_input_tensor_shape must be positive, got: " ,
250+ orig_input_tensor_shape.dim_size (i)));
251+ num_elements = MultiplyWithoutOverflow (
252+ num_elements, orig_input_tensor_shape.dim_size (i));
253+ OP_REQUIRES (
254+ context, num_elements > 0 ,
255+ errors::InvalidArgument (
256+ " The total elements specified by orig_input_tensor_shape" ,
257+ " is too large. Encountered overflow after multiplying " ,
258+ orig_input_tensor_shape.dim_size (i), " , result: " , num_elements));
259+ }
260+
244261 const Tensor& out_backprop = context->input (1 );
262+ OP_REQUIRES (context, out_backprop.dims () == 4 ,
263+ errors::InvalidArgument (" out_backprop must be 4-dimensional" ));
264+ for (int i = 0 ; i < out_backprop.dims (); i++) {
265+ OP_REQUIRES (context, out_backprop.dim_size (i) > 0 ,
266+ errors::InvalidArgument (
267+ " out_backprop must be positive for all dimension, got:" ,
268+ out_backprop.dim_size (i)));
269+ }
270+
245271 const Tensor& row_seq_tensor = context->input (2 );
246272 const Tensor& col_seq_tensor = context->input (3 );
247273
0 commit comments