Skip to content
Permalink
Browse files Browse the repository at this point in the history
Enhance validation of ngram op and handle case of 0 tokens.
PiperOrigin-RevId: 369940178
Change-Id: Ia82f42c09d14efe76e7dc013505b832a42282f0b
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Apr 22, 2021
1 parent 1cdd4da commit ba424dd
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 11 deletions.
52 changes: 41 additions & 11 deletions tensorflow/core/kernels/string_ngrams_op.cc
Expand Up @@ -61,16 +61,28 @@ class StringNGramsOp : public tensorflow::OpKernel {
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
const auto& splits_vec = splits->flat<SPLITS_TYPE>();

// Validate that the splits are valid indices into data
// Validate that the splits are valid indices into data, only if there are
// splits specified.
const int input_data_size = data->flat<tstring>().size();
const int splits_vec_size = splits_vec.size();
for (int i = 0; i < splits_vec_size; ++i) {
bool valid_splits = splits_vec(i) >= 0;
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
OP_REQUIRES(
context, valid_splits,
errors::InvalidArgument("Invalid split value ", splits_vec(i),
", must be in [0,", input_data_size, "]"));
if (splits_vec_size > 0) {
int prev_split = splits_vec(0);
OP_REQUIRES(context, prev_split == 0,
errors::InvalidArgument("First split value must be 0, got ",
prev_split));
for (int i = 1; i < splits_vec_size; ++i) {
bool valid_splits = splits_vec(i) >= prev_split;
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
OP_REQUIRES(context, valid_splits,
errors::InvalidArgument(
"Invalid split value ", splits_vec(i), ", must be in [",
prev_split, ", ", input_data_size, "]"));
prev_split = splits_vec(i);
}
OP_REQUIRES(context, prev_split == input_data_size,
errors::InvalidArgument(
"Last split value must be data size. Expected ",
input_data_size, ", got ", prev_split));
}

int num_batch_items = splits_vec.size() - 1;
Expand Down Expand Up @@ -174,13 +186,31 @@ class StringNGramsOp : public tensorflow::OpKernel {
ngram->append(left_pad_);
ngram->append(separator_);
}
// Only output first num_tokens - 1 pairs of data and separator
for (int n = 0; n < num_tokens - 1; ++n) {
ngram->append(data[data_start_index + n]);
ngram->append(separator_);
}
ngram->append(data[data_start_index + num_tokens - 1]);
for (int n = 0; n < right_padding; ++n) {
ngram->append(separator_);
// Handle case when there are no tokens or no right padding as these can
// result in consecutive separators.
if (num_tokens > 0) {
// If we have tokens, then output last and then pair each separator with
// the right padding that follows, to ensure ngram ends either with the
// token or with the right pad.
ngram->append(data[data_start_index + num_tokens - 1]);
for (int n = 0; n < right_padding; ++n) {
ngram->append(separator_);
ngram->append(right_pad_);
}
} else {
// If we don't have tokens, then the last item inserted into the ngram
// has been the separator from the left padding loop above. Hence,
// output right pad and separator and make sure to finish with a
// padding, not a separator.
for (int n = 0; n < right_padding - 1; ++n) {
ngram->append(right_pad_);
ngram->append(separator_);
}
ngram->append(right_pad_);
}

Expand Down
34 changes: 34 additions & 0 deletions tensorflow/core/kernels/string_ngrams_op_test.cc
Expand Up @@ -542,6 +542,40 @@ TEST_F(NgramKernelTest, TestEmptyInput) {
assert_int64_equal(expected_splits, *GetOutput(1));
}

TEST_F(NgramKernelTest, TestNoTokens) {
MakeOp("|", {3}, "L", "R", -1, false);
// Batch items are:
// 0:
// 1: "a"
AddInputFromArray<tstring>(TensorShape({1}), {"a"});
AddInputFromArray<int64>(TensorShape({3}), {0, 0, 1});
TF_ASSERT_OK(RunOpKernel());

std::vector<tstring> expected_values(
{"L|L|R", "L|R|R", // no input in first split
"L|L|a", "L|a|R", "a|R|R"}); // second split
std::vector<int64> expected_splits({0, 2, 5});

assert_string_equal(expected_values, *GetOutput(0));
assert_int64_equal(expected_splits, *GetOutput(1));
}

TEST_F(NgramKernelTest, TestNoTokensNoPad) {
MakeOp("|", {3}, "", "", 0, false);
// Batch items are:
// 0:
// 1: "a"
AddInputFromArray<tstring>(TensorShape({1}), {"a"});
AddInputFromArray<int64>(TensorShape({3}), {0, 0, 1});
TF_ASSERT_OK(RunOpKernel());

std::vector<tstring> expected_values({});
std::vector<int64> expected_splits({0, 0, 0});

assert_string_equal(expected_values, *GetOutput(0));
assert_int64_equal(expected_splits, *GetOutput(1));
}

TEST_F(NgramKernelTest, ShapeFn) {
ShapeInferenceTestOp op("StringNGrams");
INFER_OK(op, "?;?", "[?];[?]");
Expand Down

0 comments on commit ba424dd

Please sign in to comment.