diff --git a/test/test_jit.py b/test/test_jit.py index b12baf823fc9f..21c7999431972 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -12079,14 +12079,29 @@ def fn(x): self.checkScript(fn, ("abcde",)) - def test_str_cmp(self): - def test(a, b): + def test_str_ops(self): + def test_str_is(s): + # type: (str) -> Tuple[bool, bool, bool, bool, bool, bool] + return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \ + s.isalnum(), s.isalpha() + + def test_str_to(s): + # type: (str) -> Tuple[str, str] + return s.upper(), s.lower() + + inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ", + " \t", " \n", "\na", "abc"] + + for input in inputs: + self.checkScript(test_str_is, (input,)) + self.checkScript(test_str_to, (input,)) + + def test_str_cmp(a, b): # type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool] return a != b, a == b, a < b, a > b, a <= b, a >= b - self.checkScript(test, ("1", "2")) - self.checkScript(test, ("2", "1")) - self.checkScript(test, ("1", "1")) + for i in range(len(inputs) - 1): + self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1])) def test_ord(self): def fn(x): diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 51cdeb6feca45..479e402bc6001 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -1913,6 +1914,70 @@ RegisterOperators reg2({ Operator( "aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str", stringSlice), + +// python string is methods return false if empty +#define DEFINE_STRING_IS_OP(op_name, char_op) \ + Operator(#op_name "(str self) -> bool", [](Stack& stack) { \ + auto string = pop(stack).toStringRef(); \ + push( \ + stack, \ + string.size() != 0 && \ + std::all_of(string.begin(), string.end(), [](char c) { \ + return char_op(c); \ + })); \ + return 0; \ + }) + + // upper and lower require there to be at least one alpha character, + // and ignore all other characters + Operator( + "aten::isupper(str self) -> bool", + [](Stack& stack) { + auto string = pop(stack).toStringRef(); + bool found_alpha = false; + bool is_upper = true; + for (size_t i = 0; i < string.size() && is_upper; ++i) { + char c = string[i]; + found_alpha |= std::isalpha(c); + is_upper &= (!std::isalpha(c) || std::isupper(c)); + } + push(stack, found_alpha && is_upper); + return 0; + }), + Operator( + "aten::islower(str self) -> bool", + [](Stack& stack) { + auto string = pop(stack).toStringRef(); + bool found_alpha = false; + bool is_lower = true; + for (size_t i = 0; i < string.size() && is_lower; ++i) { + char c = string[i]; + found_alpha |= std::isalpha(c); + is_lower &= (!std::isalpha(c) || std::islower(c)); + } + push(stack, found_alpha && is_lower); + return 0; + }), + + DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit), + DEFINE_STRING_IS_OP(aten::isspace, std::isspace), + DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum), + DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha), + +#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op) \ + Operator(#op_name "(str self) -> str", [](Stack& stack) { \ + auto string = pop(stack).toStringRef(); \ + std::stringstream ss; \ + for (char c : string) { \ + ss << static_cast(char_op(c)); \ + } \ + push(stack, ss.str()); \ + return 0; \ + }) + + DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper), + DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower), + Operator( "prim::StringIndex(str string, int index) -> str", [](Stack& stack) {