Skip to content
Permalink
Browse files Browse the repository at this point in the history
Check for element_shape in TensorListFromTensor
PiperOrigin-RevId: 462468167
  • Loading branch information
tensorflower-gardener committed Jul 21, 2022
1 parent 2e80c0b commit 3db59a0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/list_kernels.h
Expand Up @@ -769,6 +769,11 @@ class TensorListFromTensor : public OpKernel {
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
PartialTensorShape element_shape;
OP_REQUIRES(
c, !TensorShapeUtils::IsMatrixOrHigher(c->input(1).shape()),
errors::InvalidArgument(
"TensorListFromTensor: element_shape must be at most rank 1 but ",
"has the shape of ", c->input(1).shape().DebugString()));
OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
TensorList output_list;
const Tensor& t = c->input(0);
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/python/kernel_tests/data_structures/list_ops_test.py
Expand Up @@ -584,6 +584,17 @@ def testTensorListFromTensor(self):
self.assertAllEqual(e, 1.0)
self.assertAllEqual(list_ops.tensor_list_length(l), 0)

def testTensorListFromTensorFailsWhenElementShapeIsNotVector(self):
t = constant_op.constant([1.0, 2.0])
# In Eager mode, InvalidArgumentError is generated by the Compute function.
# In graph mode, ValueError is generated by the shape function.
with self.assertRaisesRegex(
(errors.InvalidArgumentError, ValueError),
"must be at most rank 1"):
# Wrong element_shape. Should be at most rank 1.
l = list_ops.tensor_list_from_tensor(t, element_shape=[[1]])
self.evaluate(l)

@test_util.run_gpu_only
def testFromTensorGPU(self):
with context.device("gpu:0"):
Expand Down

0 comments on commit 3db59a0

Please sign in to comment.