Skip to content
Permalink
Browse files Browse the repository at this point in the history
Validate inputs of FractionalAvgPoolGrad.
PiperOrigin-RevId: 372420640
Change-Id: Icc583928e6cdc3062e12498e4d2337a8fe3da016
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed May 6, 2021
1 parent dcba796 commit 12c727c
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tensorflow/core/kernels/fractional_avg_pool_op.cc
Expand Up @@ -250,6 +250,19 @@ class FractionalAvgPoolGradOp : public OpKernel {
const int64 out_cols = out_backprop.dim_size(2);
const int64 out_depth = out_backprop.dim_size(3);

OP_REQUIRES(context, row_seq_tensor.NumElements() > out_rows,
errors::InvalidArgument("Given out_backprop shape ",
out_backprop.shape().DebugString(),
", row_seq_tensor must have at least ",
out_rows + 1, " elements, but got ",
row_seq_tensor.NumElements()));
OP_REQUIRES(context, col_seq_tensor.NumElements() > out_cols,
errors::InvalidArgument("Given out_backprop shape ",
out_backprop.shape().DebugString(),
", col_seq_tensor must have at least ",
out_cols + 1, " elements, but got ",
col_seq_tensor.NumElements()));

auto row_seq_tensor_flat = row_seq_tensor.flat<int64>();
auto col_seq_tensor_flat = col_seq_tensor.flat<int64>();
auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat<int64>();
Expand Down

0 comments on commit 12c727c

Please sign in to comment.