Skip to content

Commit

Permalink
Check for element_shape in TensorListScatter and TensorListScatterV2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 462460346
  • Loading branch information
tensorflower-gardener authored and tensorflow-jenkins committed Aug 19, 2022
1 parent 98fbc78 commit 8e2837d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/list_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,11 @@ class TensorListScatter : public OpKernel {
OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
Tensor indices = c->input(1);
PartialTensorShape element_shape;
OP_REQUIRES(
c, !TensorShapeUtils::IsMatrixOrHigher(c->input(2).shape()),
errors::InvalidArgument(
"TensorListScatter: element_shape must be at most rank 1 but has ",
"the shape of ", c->input(2).shape().DebugString()));
OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
// TensorListScatterV2 passes the num_elements input, TensorListScatter does
// not.
Expand Down
24 changes: 24 additions & 0 deletions tensorflow/python/kernel_tests/data_structures/list_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,30 @@ def testScatterOutputListSizeWithNumElementsSpecified(self):
# TensorListScatter should return a list with size num_elements.
self.assertAllEqual(list_ops.tensor_list_length(l), 5)

def testScatterFailsWhenElementShapeIsNotVector(self):
c0 = constant_op.constant([1.0, 2.0])
# In Eager mode, InvalidArgumentError is generated by the Compute function.
# In graph mode, ValueError is generated by the shape function.
with self.assertRaisesRegex(
(errors.InvalidArgumentError, ValueError),
"must be at most rank 1"):
l = gen_list_ops.tensor_list_scatter(
# Wrong element_shape. Should be at most rank 1.
c0, [1, 3], element_shape=[[1]])
self.evaluate(l)

def testScatterV2FailsWhenElementShapeIsNotVector(self):
c0 = constant_op.constant([1.0, 2.0])
# In Eager mode, InvalidArgumentError is generated by the Compute function.
# In graph mode, ValueError is generated by the shape function.
with self.assertRaisesRegex(
(errors.InvalidArgumentError, ValueError),
"must be at most rank 1"):
l = gen_list_ops.tensor_list_scatter_v2(
# Wrong element_shape. Should be at most rank 1.
c0, [1, 3], element_shape=[[1]], num_elements=2)
self.evaluate(l)

def testScatterFailsWhenIndexLargerThanNumElements(self):
c0 = constant_op.constant([1.0, 2.0])
with self.assertRaisesRegex(
Expand Down

0 comments on commit 8e2837d

Please sign in to comment.