From cebe3c45d76357d201c65bdbbf0dbe6e8a63bbdb Mon Sep 17 00:00:00 2001 From: Alan Liu Date: Fri, 29 Apr 2022 15:53:49 -0700 Subject: [PATCH] Fix tf.raw_ops.StagePeek vulnerability with invalid `index`. Check that input is actually a scalar before treating it as such. PiperOrigin-RevId: 445524908 --- tensorflow/core/kernels/stage_op.cc | 2 ++ .../kernel_tests/data_structures/stage_op_test.py | 11 +++++++++++ tensorflow/python/ops/data_flow_ops.py | 2 +- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index 55c9db22ddf527..f7bb42f9c52b7d 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -258,6 +258,8 @@ class StagePeekOp : public OpKernel { core::ScopedUnref scope(buf); Buffer::Tuple tuple; + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(0).shape()), + errors::InvalidArgument("index must be scalar")); std::size_t index = ctx->input(0).scalar()(); OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple)); diff --git a/tensorflow/python/kernel_tests/data_structures/stage_op_test.py b/tensorflow/python/kernel_tests/data_structures/stage_op_test.py index c720155f3b6c90..d12624f1065928 100644 --- a/tensorflow/python/kernel_tests/data_structures/stage_op_test.py +++ b/tensorflow/python/kernel_tests/data_structures/stage_op_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -134,6 +135,16 @@ def testPeek(self): for i in range(10): self.assertTrue(sess.run(peek, feed_dict={p: i}) == [i]) + def testPeekBadIndex(self): + stager = data_flow_ops.StagingArea([ + dtypes.int32, + ], shapes=[[10]]) + stager.put([array_ops.zeros([10], dtype=dtypes.int32)]) + + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + 'must be scalar'): + self.evaluate(stager.peek([])) + @test_util.run_deprecated_v1 def testSizeAndClear(self): with ops.Graph().as_default() as G: diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 0076e54833de6c..42fd28c8cc1e60 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -1737,7 +1737,7 @@ def _check_put_dtypes(self, vals, indices=None): # Sanity check number of values if not len(vals) <= len(self._dtypes): - raise ValueError(f"Unexpected number of inputs {len(vals)} vs" + raise ValueError(f"Unexpected number of inputs {len(vals)} vs " f"{len(self._dtypes)}") tensors = []