Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix security vulnerability with AvgPool3DGrad.
PiperOrigin-RevId: 461244371
  • Loading branch information
tensorflower-gardener committed Jul 15, 2022
1 parent 2e0578e commit 9178ac9
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 2 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/BUILD
Expand Up @@ -403,6 +403,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core/util:overflow",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_kernel.cc
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

Expand Down Expand Up @@ -443,6 +444,16 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
std::vector<int64_t> dims;
TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));

int64_t num_elements = 1;
for (auto i = dims.begin(); i != dims.end(); ++i) {
num_elements = MultiplyWithoutOverflow(num_elements, *i);
if (num_elements < 0)
return errors::InvalidArgument(
"The total elements specified by orig_input_shape is too large.",
"Encountered overflow after multiplying", *i,
", result: ", num_elements);
}
*shape = TensorShape(dims);
return OkStatus();
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/pooling_ops_3d.cc
Expand Up @@ -523,7 +523,7 @@ class AvgPooling3dGradOp : public OpKernel {
TensorShape output_shape;
auto shape_vec = tensor_in_shape.vec<int32>();
for (int64_t i = 0; i < tensor_in_shape.NumElements(); ++i) {
output_shape.AddDim(shape_vec(i));
OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(shape_vec(i)));
}

Tensor* output;
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/kernel_tests/nn_ops/BUILD
Expand Up @@ -500,6 +500,7 @@ cuda_py_test(
srcs = ["pooling_ops_3d_test.py"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
Expand Down
20 changes: 19 additions & 1 deletion tensorflow/python/kernel_tests/nn_ops/pooling_ops_3d_test.py
Expand Up @@ -18,6 +18,7 @@

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
Expand Down Expand Up @@ -67,7 +68,7 @@ def _VerifyOneTest(self, pool_func, input_sizes, window, strides, padding,
# Initializes the input tensor with array containing incrementing
# numbers from 1.
x = [f * 1.0 for f in range(1, total_size + 1)]
with self.cached_session(use_gpu=use_gpu) as sess:
with self.cached_session(use_gpu=use_gpu):
t = constant_op.constant(x, shape=input_sizes)
window = [1] + list(window) + [1]
strides = [1] + list(strides) + [1]
Expand Down Expand Up @@ -124,6 +125,23 @@ def testAvgPool3dSamePaddingDifferentStrides(self):
padding="SAME",
expected=expected_output)

def testMaxPool3dGrad(self):
with self.assertRaises(
(errors.ResourceExhaustedError, errors.InvalidArgumentError)):
with self.cached_session():
orig_input_shape = constant_op.constant(
1879048192, shape=[5], dtype=dtypes.int32)
grad = constant_op.constant(
1, shape=[1, 3, 2, 4, 2], dtype=dtypes.float32)
t = gen_nn_ops.AvgPool3DGrad(
orig_input_shape=orig_input_shape,
grad=grad,
ksize=[1, 1, 1, 1, 1],
strides=[1, 1, 1, 1, 1],
padding="SAME",
data_format="NDHWC")
self.evaluate(t)

def testMaxPool3dValidPadding(self):
expected_output = [40.0, 41.0, 42.0]
self._VerifyValues(
Expand Down

0 comments on commit 9178ac9

Please sign in to comment.