Skip to content

Commit

Permalink
Merge pull request #49664 from geetachavan1/cherrypicks_Y2CSH
Browse files Browse the repository at this point in the history
[CherryPick]Add missing validation, prevent heap OOB
  • Loading branch information
mihaimaruseac committed May 26, 2021
2 parents e86c9d7 + 9272dbe commit f7a80ec
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tensorflow/core/kernels/pooling_ops_3d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -697,11 +697,36 @@ class MaxPooling3dGradGradOp : public OpKernel {

Pool3dParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
if (!context->status().ok()) return; // params is invalid

Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{2}, 0, tensor_out.shape(), &output));

// Given access patterns in LaunchMaxPooling3dGradGradOp, these tensors must
// have elements.
OP_REQUIRES(context, tensor_in.NumElements() > 0,
errors::InvalidArgument("received empty tensor tensor_in: ",
tensor_in.DebugString()));
OP_REQUIRES(context, tensor_out.NumElements() > 0,
errors::InvalidArgument("received empty tensor tensor_out: ",
tensor_out.DebugString()));
OP_REQUIRES(
context, out_grad_backprop.NumElements() > 0,
errors::InvalidArgument("received empty tensor out_grad_backprop: ",
out_grad_backprop.DebugString()));
OP_REQUIRES(context,
tensor_in.NumElements() == out_grad_backprop.NumElements(),
errors::InvalidArgument("tensor_in and out_grad_backprop must "
"have same number of elements, got <",
tensor_in.DebugString(), "> and <",
out_grad_backprop.DebugString(), ">"));
OP_REQUIRES(
context, tensor_out.NumElements() == output->NumElements(),
errors::InvalidArgument(
"tensor_out and output must have same number of elements, got <",
tensor_out.DebugString(), "> and <", output->DebugString(), ">"));

LaunchMaxPooling3dGradGradOp<Device, T>::launch(
context, params, tensor_in, tensor_out, out_grad_backprop, output);
}
Expand Down

0 comments on commit f7a80ec

Please sign in to comment.