Skip to content

Commit

Permalink
Merge pull request #53899 from tensorflow/cherrypick-f68fdab93fb7f4dd…
Browse files Browse the repository at this point in the history
…b4eb438c8fe052753c9413e8-on-r2.6

Add a check for pad width to be a positive value.
  • Loading branch information
mihaimaruseac committed Jan 24, 2022
2 parents 5b0eac3 + c61b184 commit 4c0d509
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
10 changes: 10 additions & 0 deletions tensorflow/core/kernels/string_ngrams_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ class StringNGramsOp : public tensorflow::OpKernel {
// We don't have to worry about dynamic padding sizes here: if padding
// was dynamic, every sequence would have had sufficient padding to
// generate at least one ngram.

// If we reach here, pad_width_ must be >= 0. A value of pad_width_ == -1
// (meaning "use max(ngram_widths) - 1") cannot be honored here because
// ngram_widths is not known.
OP_REQUIRES(
context, pad_width_ >= 0,
errors::InvalidArgument("Pad width should be >= 0 when "
"preserve_short_sequences is True and "
"ngram_widths are not provided, got ",
pad_width_));
int ngram_width = data_length + 2 * pad_width_;
auto output_start = &ngrams_data[output_start_idx];
int num_ngrams = 1;
Expand Down
25 changes: 22 additions & 3 deletions tensorflow/python/ops/raw_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@


@test_util.run_all_in_graph_and_eager_modes
@test_util.disable_tfrt
class RawOpsTest(test.TestCase, parameterized.TestCase):

def testSimple(self):
Expand Down Expand Up @@ -67,8 +66,9 @@ def testDefaults(self):
@parameterized.parameters([[0, 8]], [[-1, 6]])
def testStringNGramsBadDataSplits(self, splits):
data = ["aa", "bb", "cc", "dd", "ee", "ff"]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Invalid split value"):
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r"Invalid split value|First split value must be 0"):
self.evaluate(
gen_string_ops.string_n_grams(
data=data,
Expand All @@ -80,6 +80,25 @@ def testStringNGramsBadDataSplits(self, splits):
pad_width=0,
preserve_short_sequences=False))

def testStringSplit(self):
  # NOTE: despite its name, this test calls the StringNGrams raw op. It passes
  # a negative pad_width together with preserve_short_sequences=True and an
  # empty ngram_widths list, and expects the op to fail with an
  # InvalidArgumentError mentioning the pad width.
  ngram_kwargs = dict(
      data=["123456"],
      data_splits=[0, 1],
      separator="a" * 15,
      ngram_widths=[],
      left_pad="",
      right_pad="",
      pad_width=-5,
      preserve_short_sequences=True)
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "Pad width should be >= 0"):
    # Build and evaluate the op inside the context so that errors raised at
    # either op construction or evaluation time are captured.
    self.evaluate(gen_string_ops.string_n_grams(**ngram_kwargs))

def testGetSessionHandle(self):
if context.executing_eagerly():
with self.assertRaisesRegex(
Expand Down

0 comments on commit 4c0d509

Please sign in to comment.