Skip to content

Commit ba424dd

Browse files
Enhance validation of ngram op and handle case of 0 tokens.
PiperOrigin-RevId: 369940178
Change-Id: Ia82f42c09d14efe76e7dc013505b832a42282f0b
1 parent 1cdd4da commit ba424dd

File tree

2 files changed

+75 -11 lines changed

Diff for: tensorflow/core/kernels/string_ngrams_op.cc

+41 -11
Original file line number | Diff line number | Diff line change
@@ -61,16 +61,28 @@ class StringNGramsOp : public tensorflow::OpKernel {
6161
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
6262
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
6363

64-
// Validate that the splits are valid indices into data
64+
// Validate that the splits are valid indices into data, only if there are
65+
// splits specified.
6566
const int input_data_size = data->flat<tstring>().size();
6667
const int splits_vec_size = splits_vec.size();
67-
for (int i = 0; i < splits_vec_size; ++i) {
68-
bool valid_splits = splits_vec(i) >= 0;
69-
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
70-
OP_REQUIRES(
71-
context, valid_splits,
72-
errors::InvalidArgument("Invalid split value ", splits_vec(i),
73-
", must be in [0,", input_data_size, "]"));
68+
if (splits_vec_size > 0) {
69+
int prev_split = splits_vec(0);
70+
OP_REQUIRES(context, prev_split == 0,
71+
errors::InvalidArgument("First split value must be 0, got ",
72+
prev_split));
73+
for (int i = 1; i < splits_vec_size; ++i) {
74+
bool valid_splits = splits_vec(i) >= prev_split;
75+
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
76+
OP_REQUIRES(context, valid_splits,
77+
errors::InvalidArgument(
78+
"Invalid split value ", splits_vec(i), ", must be in [",
79+
prev_split, ", ", input_data_size, "]"));
80+
prev_split = splits_vec(i);
81+
}
82+
OP_REQUIRES(context, prev_split == input_data_size,
83+
errors::InvalidArgument(
84+
"Last split value must be data size. Expected ",
85+
input_data_size, ", got ", prev_split));
7486
}
7587

7688
int num_batch_items = splits_vec.size() - 1;
@@ -174,13 +186,31 @@ class StringNGramsOp : public tensorflow::OpKernel {
174186
ngram->append(left_pad_);
175187
ngram->append(separator_);
176188
}
189+
// Only output first num_tokens - 1 pairs of data and separator
177190
for (int n = 0; n < num_tokens - 1; ++n) {
178191
ngram->append(data[data_start_index + n]);
179192
ngram->append(separator_);
180193
}
181-
ngram->append(data[data_start_index + num_tokens - 1]);
182-
for (int n = 0; n < right_padding; ++n) {
183-
ngram->append(separator_);
194+
// Handle case when there are no tokens or no right padding as these can
195+
// result in consecutive separators.
196+
if (num_tokens > 0) {
197+
// If we have tokens, then output last and then pair each separator with
198+
// the right padding that follows, to ensure ngram ends either with the
199+
// token or with the right pad.
200+
ngram->append(data[data_start_index + num_tokens - 1]);
201+
for (int n = 0; n < right_padding; ++n) {
202+
ngram->append(separator_);
203+
ngram->append(right_pad_);
204+
}
205+
} else {
206+
// If we don't have tokens, then the last item inserted into the ngram
207+
// has been the separator from the left padding loop above. Hence,
208+
// output right pad and separator and make sure to finish with a
209+
// padding, not a separator.
210+
for (int n = 0; n < right_padding - 1; ++n) {
211+
ngram->append(right_pad_);
212+
ngram->append(separator_);
213+
}
184214
ngram->append(right_pad_);
185215
}
186216

Diff for: tensorflow/core/kernels/string_ngrams_op_test.cc

+34
Original file line number | Diff line number | Diff line change
@@ -542,6 +542,40 @@ TEST_F(NgramKernelTest, TestEmptyInput) {
542542
assert_int64_equal(expected_splits, *GetOutput(1));
543543
}
544544

545+
TEST_F(NgramKernelTest, TestNoTokens) {
546+
MakeOp("|", {3}, "L", "R", -1, false);
547+
// Batch items are:
548+
// 0:
549+
// 1: "a"
550+
AddInputFromArray<tstring>(TensorShape({1}), {"a"});
551+
AddInputFromArray<int64>(TensorShape({3}), {0, 0, 1});
552+
TF_ASSERT_OK(RunOpKernel());
553+
554+
std::vector<tstring> expected_values(
555+
{"L|L|R", "L|R|R", // no input in first split
556+
"L|L|a", "L|a|R", "a|R|R"}); // second split
557+
std::vector<int64> expected_splits({0, 2, 5});
558+
559+
assert_string_equal(expected_values, *GetOutput(0));
560+
assert_int64_equal(expected_splits, *GetOutput(1));
561+
}
562+
563+
TEST_F(NgramKernelTest, TestNoTokensNoPad) {
564+
MakeOp("|", {3}, "", "", 0, false);
565+
// Batch items are:
566+
// 0:
567+
// 1: "a"
568+
AddInputFromArray<tstring>(TensorShape({1}), {"a"});
569+
AddInputFromArray<int64>(TensorShape({3}), {0, 0, 1});
570+
TF_ASSERT_OK(RunOpKernel());
571+
572+
std::vector<tstring> expected_values({});
573+
std::vector<int64> expected_splits({0, 0, 0});
574+
575+
assert_string_equal(expected_values, *GetOutput(0));
576+
assert_int64_equal(expected_splits, *GetOutput(1));
577+
}
578+
545579
TEST_F(NgramKernelTest, ShapeFn) {
546580
ShapeInferenceTestOp op("StringNGrams");
547581
INFER_OK(op, "?;?", "[?];[?]");

0 commit comments

Comments
 (0)