Skip to content

Commit 03a659d

Browse files
Fix security vulnerability with FractionalAvgPoolGrad
PiperOrigin-RevId: 462292194
1 parent d631c03 commit 03a659d

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

Diff for: tensorflow/core/kernels/fractional_avg_pool_op.cc

+28-2
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15+
1516
#define EIGEN_USE_THREADS
1617

1718
#include <algorithm>
1819
#include <cmath>
1920
#include <random>
2021
#include <vector>
2122

22-
#include "tensorflow/core/kernels/fractional_pool_common.h"
23-
2423
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2524
#include "tensorflow/core/framework/numeric_op.h"
2625
#include "tensorflow/core/framework/op_kernel.h"
26+
#include "tensorflow/core/kernels/fractional_pool_common.h"
2727
#include "tensorflow/core/lib/random/random.h"
2828
#include "tensorflow/core/platform/logging.h"
2929
#include "tensorflow/core/platform/mutex.h"
3030
#include "tensorflow/core/util/guarded_philox_random.h"
31+
#include "tensorflow/core/util/overflow.h"
3132

3233
namespace tensorflow {
3334
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -241,7 +242,32 @@ class FractionalAvgPoolGradOp : public OpKernel {
241242
orig_input_tensor_shape.NumElements() == 4,
242243
errors::InvalidArgument("original input tensor shape must be"
243244
"1-dimensional and 4 elements"));
245+
int64_t num_elements = 1;
246+
for (int i = 0; i < orig_input_tensor_shape.dims(); i++) {
247+
OP_REQUIRES(context, orig_input_tensor_shape.dim_size(i) > 0,
248+
errors::InvalidArgument(
249+
"orig_input_tensor_shape must be positive, got: ",
250+
orig_input_tensor_shape.dim_size(i)));
251+
num_elements = MultiplyWithoutOverflow(
252+
num_elements, orig_input_tensor_shape.dim_size(i));
253+
OP_REQUIRES(
254+
context, num_elements > 0,
255+
errors::InvalidArgument(
256+
"The total elements specified by orig_input_tensor_shape",
257+
" is too large. Encountered overflow after multiplying ",
258+
orig_input_tensor_shape.dim_size(i), ", result: ", num_elements));
259+
}
260+
244261
const Tensor& out_backprop = context->input(1);
262+
OP_REQUIRES(context, out_backprop.dims() == 4,
263+
errors::InvalidArgument("out_backprop must be 4-dimensional"));
264+
for (int i = 0; i < out_backprop.dims(); i++) {
265+
OP_REQUIRES(context, out_backprop.dim_size(i) > 0,
266+
errors::InvalidArgument(
267+
"out_backprop must be positive for all dimension, got:",
268+
out_backprop.dim_size(i)));
269+
}
270+
245271
const Tensor& row_seq_tensor = context->input(2);
246272
const Tensor& col_seq_tensor = context->input(3);
247273

Diff for: tensorflow/python/kernel_tests/nn_ops/fractional_avg_pool_op_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,27 @@ def testLargePoolingRatioThroughGradientError(self):
541541
delta=1e-2)
542542
self.assertLess(gradient_error, error_margin)
543543

544+
def testInvalidSeqRaiseErrorForFractionalAvgPoolGrad(self):
545+
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
546+
with self.cached_session() as _:
547+
overlapping = True
548+
orig_input_tensor_shape = constant_op.constant(
549+
-1879048192, shape=[4], dtype=dtypes.int64)
550+
out_backprop = constant_op.constant([],
551+
shape=[0, 0, 0, 0],
552+
dtype=dtypes.float64)
553+
row_pooling_sequence = constant_op.constant(
554+
1, shape=[4], dtype=dtypes.int64)
555+
col_pooling_sequence = constant_op.constant(
556+
1, shape=[4], dtype=dtypes.int64)
557+
t = gen_nn_ops.fractional_avg_pool_grad(
558+
orig_input_tensor_shape=orig_input_tensor_shape,
559+
out_backprop=out_backprop,
560+
row_pooling_sequence=row_pooling_sequence,
561+
col_pooling_sequence=col_pooling_sequence,
562+
overlapping=overlapping)
563+
self.evaluate(t)
564+
544565

545566
if __name__ == "__main__":
546567
test.main()

0 commit comments

Comments
 (0)