Skip to content

Commit

Permalink
Merge pull request #49289 from geetachavan1/cherrypicks_6HKVX
Browse files Browse the repository at this point in the history
CherryPick:2.3:PR #46974: Fix crash of tf.strings.substr when pos and len have different shapes
  • Loading branch information
mihaimaruseac committed May 19, 2021
2 parents f331237 + 38e9840 commit 73aa994
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/substr_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class SubstrOp : public OpKernel {
const Tensor& len_tensor = context->input(2);
const TensorShape& input_shape = input_tensor.shape();
const TensorShape& pos_shape = pos_tensor.shape();
const TensorShape& len_shape = len_tensor.shape();
OP_REQUIRES(context, (pos_shape == len_shape),
errors::InvalidArgument(
"pos and len should have the same shape, got: ",
pos_shape.DebugString(), " vs. ", len_shape.DebugString()));

bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);

Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/kernel_tests/substr_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,15 @@ def testInvalidUnit(self):
with self.assertRaises(ValueError):
string_ops.substr(b"test", 3, 1, unit="UTF8")

def testInvalidPos(self):
# Test case for GitHub issue 46900.
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
x = string_ops.substr(b"abc", len=1, pos=[1, -1])
self.evaluate(x)

with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
x = string_ops.substr(b"abc", len=1, pos=[1, 2])
self.evaluate(x)

if __name__ == "__main__":
test.main()

0 comments on commit 73aa994

Please sign in to comment.