Skip to content

Commit ddaac2b

Browse files
panzhufengtensorflower-gardener
authored andcommitted
Add inputs check for AvgPoolGrad
PiperOrigin-RevId: 488975844
1 parent 171852c commit ddaac2b

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

Diff for: tensorflow/core/kernels/avgpooling_op.cc

+13
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,19 @@ class AvgPoolingGradOp : public OpKernel {
342342
const T* out_backprop_ptr = out_backprop.flat<T>().data();
343343
T* input_backprop_ptr = output->flat<T>().data();
344344

345+
for (int64_t r = 0; r < out_backprop_rows; ++r) {
346+
int rindex, rsize;
347+
OP_REQUIRES_OK(context,
348+
GetBroadcastSize(r, in_rows, window_rows, row_stride,
349+
pad_rows, &rindex, &rsize));
350+
for (int64_t c = 0; c < out_backprop_cols; ++c) {
351+
int cindex, csize;
352+
OP_REQUIRES_OK(context,
353+
GetBroadcastSize(c, in_cols, window_cols, col_stride,
354+
pad_cols, &cindex, &csize));
355+
}
356+
}
357+
345358
auto shard = [context, out_backprop_ptr, input_backprop_ptr,
346359
out_backprop_rows, out_backprop_cols, out_backprop_depth,
347360
in_rows, in_cols, window_rows, window_cols, row_stride,

Diff for: tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -2510,6 +2510,21 @@ def testAvgPoolGradInvalidInputShapeRaiseError(self):
25102510
data_format="NHWC")
25112511
self.evaluate(t)
25122512

2513+
def testAvgPoolGradInvalidStrideRaiseErrorProperly(self):
2514+
with self.assertRaises(errors_impl.InvalidArgumentError):
2515+
with self.cached_session():
2516+
orig_input_shape = [11, 9, 78, 9]
2517+
grad = constant_op.constant(
2518+
0.1, shape=[16, 16, 16, 16], dtype=dtypes.float64)
2519+
t = gen_nn_ops.AvgPoolGrad(
2520+
orig_input_shape=orig_input_shape,
2521+
grad=grad,
2522+
ksize=[1, 40, 128, 1],
2523+
strides=[1, 128, 128, 30],
2524+
padding="SAME",
2525+
data_format="NHWC")
2526+
self.evaluate(t)
2527+
25132528

25142529
def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
25152530

0 commit comments

Comments
 (0)