diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index 55c9db22ddf527..f7bb42f9c52b7d 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -258,6 +258,8 @@ class StagePeekOp : public OpKernel { core::ScopedUnref scope(buf); Buffer::Tuple tuple; + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(0).shape()), + errors::InvalidArgument("index must be scalar")); std::size_t index = ctx->input(0).scalar()(); OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple)); diff --git a/tensorflow/python/kernel_tests/data_structures/stage_op_test.py b/tensorflow/python/kernel_tests/data_structures/stage_op_test.py index c720155f3b6c90..d12624f1065928 100644 --- a/tensorflow/python/kernel_tests/data_structures/stage_op_test.py +++ b/tensorflow/python/kernel_tests/data_structures/stage_op_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -134,6 +135,16 @@ def testPeek(self): for i in range(10): self.assertTrue(sess.run(peek, feed_dict={p: i}) == [i]) + def testPeekBadIndex(self): + stager = data_flow_ops.StagingArea([ + dtypes.int32, + ], shapes=[[10]]) + stager.put([array_ops.zeros([10], dtype=dtypes.int32)]) + + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + 'must be scalar'): + self.evaluate(stager.peek([])) + @test_util.run_deprecated_v1 def testSizeAndClear(self): with ops.Graph().as_default() as G: diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 0076e54833de6c..42fd28c8cc1e60 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -1737,7 +1737,7 @@ def _check_put_dtypes(self, vals, indices=None): # Sanity check number of values if not len(vals) <= len(self._dtypes): - raise ValueError(f"Unexpected number of inputs {len(vals)} vs" + raise ValueError(f"Unexpected number of inputs {len(vals)} vs " f"{len(self._dtypes)}") tensors = []