Skip to content

Commit

Permalink
[C2] Add string equality operator
Browse files Browse the repository at this point in the history
Summary: This diff adds a string equality checking operator.

Test Plan: Unit tests

Differential Revision: D24042344

fbshipit-source-id: c8997c6130e3438f2ae95dae69f76978e2e95527
  • Loading branch information
Pawel Garbacki authored and facebook-github-bot committed Oct 5, 2020
1 parent 162717e commit cf48872
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
26 changes: 26 additions & 0 deletions caffe2/operators/string_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ struct EndsWith {
std::string suffix_;
};

struct Equals {
explicit Equals(OperatorBase& op)
: text_(op.GetSingleArgument<std::string>("text", "")) {}
bool operator()(const std::string& str) {
return str == text_;
}

private:
std::string text_;
};

struct Prefix {
explicit Prefix(OperatorBase& op)
: length_(op.GetSingleArgument<int>("length", 3)) {}
Expand Down Expand Up @@ -108,6 +119,9 @@ REGISTER_CPU_OPERATOR(
REGISTER_CPU_OPERATOR(
StringEndsWith,
StringElementwiseOp<EndsWith, FixedType<bool>>);
REGISTER_CPU_OPERATOR(
StringEquals,
StringElementwiseOp<Equals, FixedType<bool>>);
REGISTER_CPU_OPERATOR(StringJoin, StringJoinOp<CPUContext>);

OPERATOR_SCHEMA(StringPrefix)
Expand Down Expand Up @@ -164,6 +178,17 @@ Returns tensor of boolean of the same dimension of input.
.Input(0, "strings", "Tensor of std::string.")
.Output(0, "bools", "Tensor of bools of same shape as input.");

OPERATOR_SCHEMA(StringEquals)
.NumInputs(1)
.NumOutputs(1)
.SetDoc(R"DOC(
Performs equality check on each string in the input tensor.
Returns tensor of booleans of the same dimension as input.
)DOC")
.Arg("text", "The text to check input strings equality against.")
.Input(0, "strings", "Tensor of std::string.")
.Output(0, "bools", "Tensor of bools of same shape as input.");

OPERATOR_SCHEMA(StringJoin)
.NumInputs(1)
.NumOutputs(1)
Expand All @@ -187,6 +212,7 @@ SHOULD_NOT_DO_GRADIENT(StringPrefix);
SHOULD_NOT_DO_GRADIENT(StringSuffix);
SHOULD_NOT_DO_GRADIENT(StringStartsWith);
SHOULD_NOT_DO_GRADIENT(StringEndsWith);
SHOULD_NOT_DO_GRADIENT(StringEquals);
SHOULD_NOT_DO_GRADIENT(StringJoin);
}
} // namespace caffe2
27 changes: 27 additions & 0 deletions caffe2/python/operator_test/string_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,33 @@ def string_ends_with_ref(strings):
[strings],
string_ends_with_ref)

@given(strings=st.text(alphabet=['a', 'b']))
@settings(deadline=1000)
def test_string_equals(self, strings):
text = ""
if strings:
text = strings[0]

strings = np.array(
[str(a) for a in strings], dtype=np.object
)

def string_equals_ref(strings):
return (
np.array([a == text for a in strings], dtype=bool),
)

op = core.CreateOperator(
'StringEquals',
['strings'],
['bools'],
text=text)
self.assertReferenceChecks(
hu.cpu_do,
op,
[strings],
string_equals_ref)

if __name__ == "__main__":
import unittest
unittest.main()

0 comments on commit cf48872

Please sign in to comment.