Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add a check for pad width to be a non-negative value.
PiperOrigin-RevId: 413275853
Change-Id: I261a8db9dabf5ce48a806a9e58129080c9fac619
  • Loading branch information
ishark authored and tensorflower-gardener committed Dec 1, 2021
1 parent 8384494 commit f68fdab
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
10 changes: 10 additions & 0 deletions tensorflow/core/kernels/string_ngrams_op.cc
Expand Up @@ -152,6 +152,16 @@ class StringNGramsOp : public tensorflow::OpKernel {
// We don't have to worry about dynamic padding sizes here: if padding
// was dynamic, every sequence would have had sufficient padding to
// generate at least one ngram.

// If we reach here, pad_width_ must be >= 0. A pad_width_ of -1, which
// indicates that max(ngram_widths) - 1 should be used as the pad width,
// cannot be honored here since the ngram widths are not known.
OP_REQUIRES(
context, pad_width_ >= 0,
errors::InvalidArgument("Pad width should be >= 0 when "
"preserve_short_sequences is True and "
"ngram_widths are not provided, got ",
pad_width_));
int ngram_width = data_length + 2 * pad_width_;
auto output_start = &ngrams_data[output_start_idx];
int num_ngrams = 1;
Expand Down
25 changes: 22 additions & 3 deletions tensorflow/python/ops/raw_ops_test.py
Expand Up @@ -28,7 +28,6 @@


@test_util.run_all_in_graph_and_eager_modes
@test_util.disable_tfrt
class RawOpsTest(test.TestCase, parameterized.TestCase):

def testSimple(self):
Expand Down Expand Up @@ -63,8 +62,9 @@ def testDefaults(self):
@parameterized.parameters([[0, 8]], [[-1, 6]])
def testStringNGramsBadDataSplits(self, splits):
data = ["aa", "bb", "cc", "dd", "ee", "ff"]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Invalid split value"):
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r"Invalid split value|First split value must be 0"):
self.evaluate(
gen_string_ops.string_n_grams(
data=data,
Expand All @@ -76,6 +76,25 @@ def testStringNGramsBadDataSplits(self, splits):
pad_width=0,
preserve_short_sequences=False))

def testStringSplit(self):
  """StringNGrams rejects a negative pad_width when short sequences are kept.

  The op is invoked with an empty `ngram_widths`, `pad_width=-5` and
  `preserve_short_sequences=True`; evaluating it is expected to raise
  InvalidArgumentError with a message about the pad width.
  """
  # Collect every op argument in one place so the failing combination
  # (negative pad width + no ngram widths + preserved short sequences)
  # is visible at a glance.
  ngram_args = dict(
      data=["123456"],
      data_splits=[0, 1],
      separator="a" * 15,
      ngram_widths=[],
      left_pad="",
      right_pad="",
      pad_width=-5,
      preserve_short_sequences=True)
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "Pad width should be >= 0"):
    self.evaluate(gen_string_ops.string_n_grams(**ngram_args))

def testGetSessionHandle(self):
if context.executing_eagerly():
with self.assertRaisesRegex(
Expand Down

0 comments on commit f68fdab

Please sign in to comment.