Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tf.regex_match for regex match support #19160

Merged
merged 8 commits into from
May 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
op {
graph_op_name: "RegexFullMatch"
in_arg {
name: "input"
description: <<END
A string tensor of the text to be processed.
END
}
in_arg {
name: "pattern"
description: <<END
A 1-D string tensor of the regular expression to match the input.
END
}
out_arg {
name: "output"
description: <<END
A bool tensor with the same shape as `input`.
END
}
summary: "Check if the input matches the regex pattern."
description: <<END
The input is a string tensor of any shape. The pattern is a scalar
string tensor which is applied to every element of the input tensor.
The boolean values (True or False) of the output tensor indicate
if the input matches the regex pattern provided.

The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
END
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
op {
graph_op_name: "RegexFullMatch"
visibility: HIDDEN
}
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4249,6 +4249,7 @@ cc_library(
":as_string_op",
":base64_ops",
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
":string_join_op",
":string_split_op",
Expand Down Expand Up @@ -4285,6 +4286,12 @@ tf_kernel_library(
deps = STRING_DEPS,
)

tf_kernel_library(
name = "regex_full_match_op",
prefix = "regex_full_match_op",
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
)

tf_kernel_library(
name = "regex_replace_op",
prefix = "regex_replace_op",
Expand Down
59 changes: 59 additions & 0 deletions tensorflow/core/kernels/regex_full_match_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "re2/re2.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

class RegexFullMatchOp : public OpKernel {
public:
explicit RegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
const auto& input_flat = input_tensor->flat<string>();

const Tensor* pattern_tensor;
OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
errors::InvalidArgument("Pattern must be scalar, but received ",
pattern_tensor->shape().DebugString()));
const string pattern = pattern_tensor->flat<string>()(0);
const RE2 match(pattern);
OP_REQUIRES(ctx, match.ok(),
errors::InvalidArgument("Invalid pattern: ", pattern,
", error: ", match.error()));

Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
&output_tensor));
auto output_flat = output_tensor->flat<bool>();
for (size_t i = 0; i < input_flat.size(); ++i) {
output_flat(i) = RE2::FullMatch(input_flat(i), match);
}
}
};

REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
RegexFullMatchOp);

} // namespace tensorflow
11 changes: 11 additions & 0 deletions tensorflow/core/ops/string_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ REGISTER_OP("RegexReplace")
return Status::OK();
});

REGISTER_OP("RegexFullMatch")
.Input("input: string")
.Input("pattern: string")
.Output("output: bool")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
c->set_output(0, c->input(0));
return Status::OK();
});

REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,18 @@ tf_py_test(
],
)

tf_py_test(
name = "regex_full_match_op_test",
size = "small",
srcs = ["regex_full_match_op_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:string_ops",
],
)

tf_py_test(
name = "save_restore_ops_test",
size = "small",
Expand Down
54 changes: 54 additions & 0 deletions tensorflow/python/kernel_tests/regex_full_match_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RegexFullMatch op from string_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class RegexFullMatchOpTest(test.TestCase):

def testRegexFullMatch(self):
values = ["abaaba", "abcdabcde"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
self.assertAllEqual([True, False], matched)

def testEmptyMatch(self):
values = ["abc", "1"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
matched = string_ops.regex_full_match(input_vector, "").eval()
self.assertAllEqual([False, False], matched)

def testInvalidPattern(self):
values = ["abc", "1"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
matched = string_ops.regex_full_match(input_vector, invalid_pattern)
with self.assertRaisesOpError("Invalid pattern"):
matched.eval()


if __name__ == "__main__":
test.main()
2 changes: 2 additions & 0 deletions tensorflow/python/ops/string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import

# Expose regex_full_match in strings namespace
tf_export("strings.regex_full_match")(regex_full_match)

@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
Expand Down
1 change: 1 addition & 0 deletions tensorflow/tools/api/generator/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ genrule(
"api/profiler/__init__.py",
"api/python_io/__init__.py",
"api/resource_loader/__init__.py",
"api/strings/__init__.py",
"api/saved_model/__init__.py",
"api/saved_model/builder/__init__.py",
"api/saved_model/constants/__init__.py",
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/tools/api/golden/tensorflow.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ tf_module {
name: "string"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
name: "strings"
mtype: "<type \'module\'>"
}
member {
name: "summary"
mtype: "<type \'module\'>"
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/tools/api/golden/tensorflow.strings.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
path: "tensorflow.strings"
tf_module {
member_method {
name: "regex_full_match"
argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}