From c122773d9c32bbe7902d95650159e90d5457bf54 Mon Sep 17 00:00:00 2001 From: Sagun Bajra Date: Thu, 14 Jul 2022 10:16:55 -0700 Subject: [PATCH] Fix security vulnerability with UnbatchGradKernel PiperOrigin-RevId: 460992964 --- tensorflow/core/kernels/batch_kernels.cc | 10 +++++ tensorflow/python/ops/batch_ops_test.py | 53 ++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 8f0be9d0a3506d..7859bbe77e1009 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -879,8 +879,13 @@ class UnbatchGradResource : public ResourceBase { const Tensor& data_t = context->input(0); const Tensor& batch_index_t = context->input(1); const Tensor& grad_t = context->input(2); + const Tensor& batch_key_t = context->input(3); mutex_lock ml(mu_); + if (batch_key_t.NumElements() != 1) { + return errors::InvalidArgument("Expected `id` to be scalar. Received ", + batch_key_t.DebugString()); + } const int64_t batch_key = context->input(3).scalar()(); // Mark our tensor as available. @@ -896,6 +901,11 @@ class UnbatchGradResource : public ResourceBase { "batch_index is empty while the tensor isn't."); } std::unordered_set missing_tensors; + if (batch_index_t.NumElements() != batch_index_t.dim_size(0) * 3) { + return errors::InvalidArgument( + "batch_index should contain ", batch_index_t.dim_size(0) * 3, + " elements. Received ", batch_index_t.NumElements()); + } const auto batch_index = batch_index_t.shaped({batch_index_t.dim_size(0), 3}); for (int i = 0; i < batch_index_t.dim_size(0); ++i) { diff --git a/tensorflow/python/ops/batch_ops_test.py b/tensorflow/python/ops/batch_ops_test.py index c29a30600c549d..4e0ffab751b36a 100644 --- a/tensorflow/python/ops/batch_ops_test.py +++ b/tensorflow/python/ops/batch_ops_test.py @@ -24,7 +24,9 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -34,6 +36,7 @@ from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gen_functional_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import variables @@ -561,6 +564,56 @@ def worker(): # The thread's call should hit the timeout, and thus get 0 results. self.assertEqual(len(thread_results), 0) + def testUnbatchGradInvalidId(self): + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate( + gen_batch_ops.unbatch_grad( + original_input=constant_op.constant([1]), + batch_index=constant_op.constant([ + [0, 0, 0], + ], dtype=dtypes.int64), + grad=constant_op.constant([ + 1, + ]), + id=constant_op.constant([ + 1, + 1, + ], dtype=dtypes.int64))) + + def testUnbatchGradInvalidBatchId(self): + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate( + gen_batch_ops.unbatch_grad( + original_input=constant_op.constant([1]), + batch_index=constant_op.constant([ + [0, 0], + ], dtype=dtypes.int64), + grad=constant_op.constant([ + 1, + ]), + id=constant_op.constant([ + 1, + ], dtype=dtypes.int64))) + + def testUnbatchGradInvalidArgs(self): + original_input = random_ops.random_uniform( + shape=(3, 1), dtype=dtypes.float64, maxval=None) + batch_index = random_ops.random_uniform( + shape=(3, 1), dtype=dtypes.int64, maxval=65536) + grad = random_ops.random_uniform( + shape=(3, 1), dtype=dtypes.float64, maxval=None) + batch_id = random_ops.random_uniform( + shape=(3, 1), dtype=dtypes.int64, maxval=65536) + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate( + gen_batch_ops.unbatch_grad( + original_input=original_input, + batch_index=batch_index, + grad=grad, + id=batch_id, + container="", + shared_name="", + name="")) if __name__ == "__main__": test.main()