Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add inputs check for AvgPoolGrad
PiperOrigin-RevId: 488975844
  • Loading branch information
vufg authored and tensorflower-gardener committed Nov 16, 2022
1 parent 171852c commit ddaac2b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tensorflow/core/kernels/avgpooling_op.cc
Expand Up @@ -342,6 +342,19 @@ class AvgPoolingGradOp : public OpKernel {
const T* out_backprop_ptr = out_backprop.flat<T>().data();
T* input_backprop_ptr = output->flat<T>().data();

for (int64_t r = 0; r < out_backprop_rows; ++r) {
int rindex, rsize;
OP_REQUIRES_OK(context,
GetBroadcastSize(r, in_rows, window_rows, row_stride,
pad_rows, &rindex, &rsize));
for (int64_t c = 0; c < out_backprop_cols; ++c) {
int cindex, csize;
OP_REQUIRES_OK(context,
GetBroadcastSize(c, in_cols, window_cols, col_stride,
pad_cols, &cindex, &csize));
}
}

auto shard = [context, out_backprop_ptr, input_backprop_ptr,
out_backprop_rows, out_backprop_cols, out_backprop_depth,
in_rows, in_cols, window_rows, window_cols, row_stride,
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py
Expand Up @@ -2510,6 +2510,21 @@ def testAvgPoolGradInvalidInputShapeRaiseError(self):
data_format="NHWC")
self.evaluate(t)

def testAvgPoolGradInvalidStrideRaiseErrorProperly(self):
with self.assertRaises(errors_impl.InvalidArgumentError):
with self.cached_session():
orig_input_shape = [11, 9, 78, 9]
grad = constant_op.constant(
0.1, shape=[16, 16, 16, 16], dtype=dtypes.float64)
t = gen_nn_ops.AvgPoolGrad(
orig_input_shape=orig_input_shape,
grad=grad,
ksize=[1, 40, 128, 1],
strides=[1, 128, 128, 30],
padding="SAME",
data_format="NHWC")
self.evaluate(t)


def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):

Expand Down

0 comments on commit ddaac2b

Please sign in to comment.