Skip to content

Commit

Permalink
Fix tf.raw_ops.StagePeek vulnerability with invalid index.
Browse files Browse the repository at this point in the history
Check that input is actually a scalar before treating it as such.

PiperOrigin-RevId: 445524908
  • Loading branch information
poulsbo authored and tensorflower-gardener committed Apr 29, 2022
1 parent 98c0222 commit cebe3c4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/stage_op.cc
Expand Up @@ -258,6 +258,8 @@ class StagePeekOp : public OpKernel {
core::ScopedUnref scope(buf);
Buffer::Tuple tuple;

OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(0).shape()),
errors::InvalidArgument("index must be scalar"));
std::size_t index = ctx->input(0).scalar<int>()();

OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple));
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/python/kernel_tests/data_structures/stage_op_test.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
Expand Down Expand Up @@ -134,6 +135,16 @@ def testPeek(self):
for i in range(10):
self.assertTrue(sess.run(peek, feed_dict={p: i}) == [i])

def testPeekBadIndex(self):
stager = data_flow_ops.StagingArea([
dtypes.int32,
], shapes=[[10]])
stager.put([array_ops.zeros([10], dtype=dtypes.int32)])

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
'must be scalar'):
self.evaluate(stager.peek([]))

@test_util.run_deprecated_v1
def testSizeAndClear(self):
with ops.Graph().as_default() as G:
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/ops/data_flow_ops.py
Expand Up @@ -1737,7 +1737,7 @@ def _check_put_dtypes(self, vals, indices=None):

# Sanity check number of values
if not len(vals) <= len(self._dtypes):
raise ValueError(f"Unexpected number of inputs {len(vals)} vs"
raise ValueError(f"Unexpected number of inputs {len(vals)} vs "
f"{len(self._dtypes)}")

tensors = []
Expand Down

0 comments on commit cebe3c4

Please sign in to comment.