diff --git a/inverse_normalize.py b/inverse_normalize.py index 2ab3e740..d314750c 100644 --- a/inverse_normalize.py +++ b/inverse_normalize.py @@ -31,7 +31,7 @@ def main(): parser.add_argument('--text', help='input string') parser.add_argument('--file', help='input file path') parser.add_argument('--overwrite_cache', action='store_true', - help='rebuild *.far') + help='rebuild *.fst') parser.add_argument('--enable_standalone_number', type=str, default='True', help='enable standalone number') diff --git a/normalize.py b/normalize.py index 8de82a4a..d17a7ddf 100644 --- a/normalize.py +++ b/normalize.py @@ -23,7 +23,7 @@ def main(): parser.add_argument('--text', help='input string') parser.add_argument('--file', help='input file path') parser.add_argument('--overwrite_cache', action='store_true', - help='rebuild *.far') + help='rebuild *.fst') args = parser.parse_args() normalizer = Normalizer(cache_dir='tn', diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index 7c7d8cbc..7fadde18 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -29,19 +29,9 @@ endif() include(openfst) include_directories(${PROJECT_SOURCE_DIR}) -add_library(processor STATIC - processor/processor.cc - processor/token_parser.cc - utils/utf8_string.cc -) -if(MSVC) - target_link_libraries(processor PUBLIC fst) -else() - target_link_libraries(processor PUBLIC dl fst) -endif() - -add_executable(processor_main bin/processor_main.cc) -target_link_libraries(processor_main PUBLIC processor) +add_subdirectory(utils) +add_subdirectory(processor) +add_subdirectory(bin) if(BUILD_TESTING) include(gtest) diff --git a/runtime/bin/CMakeLists.txt b/runtime/bin/CMakeLists.txt new file mode 100644 index 00000000..89d2b875 --- /dev/null +++ b/runtime/bin/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(processor_main processor_main.cc) +target_link_libraries(processor_main PUBLIC processor) diff --git a/runtime/bin/processor_main.cc b/runtime/bin/processor_main.cc index bcf48eaf..2b7cbbf5 100644 --- a/runtime/bin/processor_main.cc +++ b/runtime/bin/processor_main.cc @@ -34,9 +34,9 @@ int main(int argc, char* argv[]) { wetext::Processor processor(FLAGS_tagger, FLAGS_verbalizer); if (!FLAGS_text.empty()) { - std::string tagged_text = processor.tag(FLAGS_text); + std::string tagged_text = processor.Tag(FLAGS_text); std::cout << tagged_text << std::endl; - std::string normalized_text = processor.verbalize(tagged_text); + std::string normalized_text = processor.Verbalize(tagged_text); std::cout << normalized_text << std::endl; } @@ -44,9 +44,9 @@ int main(int argc, char* argv[]) { std::ifstream file(FLAGS_file); std::string line; while (getline(file, line)) { - std::string tagged_text = processor.tag(line); + std::string tagged_text = processor.Tag(line); std::cout << tagged_text << std::endl; - std::string normalized_text = processor.verbalize(tagged_text); + std::string normalized_text = processor.Verbalize(tagged_text); std::cout << normalized_text << std::endl; } } diff --git a/runtime/processor/CMakeLists.txt b/runtime/processor/CMakeLists.txt new file mode 100644 index 00000000..0cab5056 --- /dev/null +++ b/runtime/processor/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(processor STATIC + processor.cc + token_parser.cc +) +if(MSVC) + target_link_libraries(processor PUBLIC fst utils) +else() + target_link_libraries(processor PUBLIC dl fst utils) +endif() diff --git a/runtime/processor/processor.cc b/runtime/processor/processor.cc index fb666857..e3b138ca 100644 --- a/runtime/processor/processor.cc +++ b/runtime/processor/processor.cc @@ -34,7 +34,7 @@ Processor::Processor(const std::string& tagger_path, } } -std::string Processor::shortest_path(const StdVectorFst& lattice) { +std::string Processor::ShortestPath(const StdVectorFst& lattice) { StdVectorFst shortest_path; fst::ShortestPath(lattice, &shortest_path, 1, true); @@ -43,31 +43,34 @@ std::string Processor::shortest_path(const StdVectorFst& lattice) { return output; } -std::string Processor::compose(const std::string& input, +std::string Processor::Compose(const std::string& input, const StdVectorFst* fst) { StdVectorFst input_fst; compiler_->operator()(input, &input_fst); StdVectorFst lattice; fst::Compose(input_fst, *fst, &lattice); - return shortest_path(lattice); + return ShortestPath(lattice); } -std::string Processor::tag(const std::string& input) { - return compose(input, tagger_.get()); +std::string Processor::Tag(const std::string& input) { + return Compose(input, tagger_.get()); } -std::string Processor::verbalize(const std::string& input) { +std::string Processor::Verbalize(const std::string& input) { if (input.empty()) { return ""; } TokenParser parser(parse_type_); - std::string output = parser.reorder(input); - return compose(output, verbalizer_.get()); + std::string output = parser.Reorder(input); + + output = Compose(output, verbalizer_.get()); + output.erase(std::remove(output.begin(), output.end(), '\0'), output.end()); + return output; } -std::string Processor::normalize(const std::string& input) { - return verbalize(tag(input)); +std::string Processor::Normalize(const std::string& input) { + return Verbalize(Tag(input)); } } // namespace wetext diff --git a/runtime/processor/processor.h b/runtime/processor/processor.h index 16ba0a59..7412b51d 100644 --- a/runtime/processor/processor.h +++ b/runtime/processor/processor.h @@ -28,13 +28,13 @@ namespace wetext { class Processor { public: Processor(const std::string& tagger_path, const std::string& verbalizer_path); - std::string tag(const std::string& input); - std::string verbalize(const std::string& input); - std::string normalize(const std::string& input); + std::string Tag(const std::string& input); + std::string Verbalize(const std::string& input); + std::string Normalize(const std::string& input); private: - std::string shortest_path(const StdVectorFst& lattice); - std::string compose(const std::string& input, const StdVectorFst* fst); + std::string ShortestPath(const StdVectorFst& lattice); + std::string Compose(const std::string& input, const StdVectorFst* fst); ParseType parse_type_; std::shared_ptr tagger_ = nullptr; diff --git a/runtime/processor/token_parser.cc b/runtime/processor/token_parser.cc index 21a20ccc..fc02eff2 100644 --- a/runtime/processor/token_parser.cc +++ b/runtime/processor/token_parser.cc @@ -15,7 +15,7 @@ #include "processor/token_parser.h" #include "utils/log.h" -#include "utils/utf8_string.h" +#include "utils/string.h" namespace wetext { const std::string EOS = ""; @@ -41,113 +41,113 @@ const std::unordered_map> ITN_ORDERS = { TokenParser::TokenParser(ParseType type) { if (type == ParseType::kTN) { - orders = TN_ORDERS; + orders_ = TN_ORDERS; } else { - orders = ITN_ORDERS; + orders_ = ITN_ORDERS; } } -void TokenParser::load(const std::string& input) { - string2chars(input, &text); - CHECK_GT(text.size(), 0); - index = 0; - ch = text[0]; +void TokenParser::Load(const std::string& input) { + SplitUTF8StringToChars(input, &text_); + CHECK_GT(text_.size(), 0); + index_ = 0; + ch_ = text_[0]; } -bool TokenParser::read() { - if (index < text.size() - 1) { - index += 1; - ch = text[index]; +bool TokenParser::Read() { + if (index_ < text_.size() - 1) { + index_ += 1; + ch_ = text_[index_]; return true; } - ch = EOS; + ch_ = EOS; return false; } -bool TokenParser::parse_ws() { - bool not_eos = ch != EOS; - while (not_eos && ch == " ") { - not_eos = read(); +bool TokenParser::ParseWs() { + bool not_eos = ch_ != EOS; + while (not_eos && ch_ == " ") { + not_eos = Read(); } return not_eos; } -bool TokenParser::parse_char(const std::string& exp) { - if (ch == exp) { - read(); +bool TokenParser::ParseChar(const std::string& exp) { + if (ch_ == exp) { + Read(); return true; } return false; } -bool TokenParser::parse_chars(const std::string& exp) { +bool TokenParser::ParseChars(const std::string& exp) { bool ok = false; std::vector chars; - string2chars(exp, &chars); + SplitUTF8StringToChars(exp, &chars); for (const auto& x : chars) { - ok |= parse_char(x); + ok |= ParseChar(x); } return ok; } -std::string TokenParser::parse_key() { - CHECK_NE(ch, EOS); - CHECK_EQ(UTF8_WHITESPACE.count(ch), 0); +std::string TokenParser::ParseKey() { + CHECK_NE(ch_, EOS); + CHECK_EQ(UTF8_WHITESPACE.count(ch_), 0); std::string key = ""; - while (ASCII_LETTERS.count(ch) > 0) { - key += ch; - read(); + while (ASCII_LETTERS.count(ch_) > 0) { + key += ch_; + Read(); } return key; } -std::string TokenParser::parse_value() { - CHECK_NE(ch, EOS); +std::string TokenParser::ParseValue() { + CHECK_NE(ch_, EOS); bool escape = false; std::string value = ""; - while (ch != "\"") { - value += ch; - escape = ch == "\\" && !escape; - read(); + while (ch_ != "\"") { + value += ch_; + escape = ch_ == "\\" && !escape; + Read(); if (escape) { - value += ch; - read(); + value += ch_; + Read(); } } return value; } -void TokenParser::parse(const std::string& input) { - load(input); - while (parse_ws()) { - std::string name = parse_key(); - parse_chars(" { "); +void TokenParser::Parse(const std::string& input) { + Load(input); + while (ParseWs()) { + std::string name = ParseKey(); + ParseChars(" { "); Token token(name); - while (parse_ws()) { - if (ch == "}") { - parse_char("}"); + while (ParseWs()) { + if (ch_ == "}") { + ParseChar("}"); break; } - std::string key = parse_key(); - parse_chars(": \""); - std::string value = parse_value(); - parse_char("\""); - token.append(key, value); + std::string key = ParseKey(); + ParseChars(": \""); + std::string value = ParseValue(); + ParseChar("\""); + token.Append(key, value); } - tokens.emplace_back(token); + tokens_.emplace_back(token); } } -std::string TokenParser::reorder(const std::string& input) { - parse(input); +std::string TokenParser::Reorder(const std::string& input) { + Parse(input); std::string output = ""; - for (auto& token : tokens) { - output += token.string(orders) + " "; + for (auto& token : tokens_) { + output += token.String(orders_) + " "; } - return trim(output); + return Trim(output); } } // namespace wetext diff --git a/runtime/processor/token_parser.h b/runtime/processor/token_parser.h index c3ba1b85..c099504e 100644 --- a/runtime/processor/token_parser.h +++ b/runtime/processor/token_parser.h @@ -37,12 +37,12 @@ struct Token { Token(const std::string& name) : name(name) {} - void append(const std::string& key, const std::string& value) { + void Append(const std::string& key, const std::string& value) { order.emplace_back(key); members[key] = value; } - std::string string( + std::string String( const std::unordered_map>& orders) { std::string output = name + " {"; if (orders.count(name) > 0) { @@ -67,25 +67,25 @@ enum ParseType { class TokenParser { public: TokenParser(ParseType type); - std::string reorder(const std::string& input); + std::string Reorder(const std::string& input); private: - void load(const std::string& input); - bool read(); - bool parse_ws(); - bool parse_char(const std::string& exp); - bool parse_chars(const std::string& exp); - std::string parse_key(); - std::string parse_value(); - void parse(const std::string& input); + void Load(const std::string& input); + bool Read(); + bool ParseWs(); + bool ParseChar(const std::string& exp); + bool ParseChars(const std::string& exp); + std::string ParseKey(); + std::string ParseValue(); + void Parse(const std::string& input); - int index; - std::string ch; - std::vector text; - std::vector tokens; - std::unordered_map> orders; + int index_; + std::string ch_; + std::vector text_; + std::vector tokens_; + std::unordered_map> orders_; }; -} // wetext +} // namespace wetext #endif // PROCESSOR_TOKEN_PARSER_H_ diff --git a/runtime/test/CMakeLists.txt b/runtime/test/CMakeLists.txt index f2eb5653..8146bfa3 100644 --- a/runtime/test/CMakeLists.txt +++ b/runtime/test/CMakeLists.txt @@ -2,9 +2,9 @@ enable_testing() link_libraries(gtest_main gmock) include(GoogleTest) -add_executable(utf8_string_test utf8_string_test.cc) -target_link_libraries(utf8_string_test PUBLIC utils) -gtest_discover_tests(utf8_string_test) +add_executable(string_test string_test.cc) +target_link_libraries(string_test PUBLIC utils) +gtest_discover_tests(string_test) if(NOT MSVC) # token_parser_test uses the macro to access the private members diff --git a/runtime/test/processor_test.cc b/runtime/test/processor_test.cc index e13fe282..030d39e0 100644 --- a/runtime/test/processor_test.cc +++ b/runtime/test/processor_test.cc @@ -18,9 +18,9 @@ #include "gmock/gmock.h" #include "processor/processor.h" -#include "utils/utf8_string.h" +#include "utils/string.h" -std::vector> parse_test_case( +std::vector> ParseTestCase( const std::string& file_path) { const std::string delimiter = "=>"; std::ifstream file(file_path); @@ -30,14 +30,14 @@ std::vector> parse_test_case( while (getline(file, line)) { CHECK_NE(line.find(delimiter), string::npos); std::vector arr; - wetext::split_string(line, delimiter, &arr); + wetext::Split(line, delimiter, &arr); CHECK_GT(arr.size(), 0); CHECK_LE(arr.size(), 2); - std::string written = wetext::trim(arr[0]); + std::string written = wetext::Trim(arr[0]); std::string spoken = ""; if (arr.size() == 2) { - spoken = wetext::trim(arr[1]); + spoken = wetext::Trim(arr[1]); } test_cases.emplace_back(std::make_pair(written, spoken)); } @@ -63,10 +63,10 @@ class ProcessorTest }; TEST_P(ProcessorTest, NormalizeTest) { - EXPECT_EQ(processor->normalize(written), spoken); + EXPECT_EQ(processor->Normalize(written), spoken); } std::vector> test_cases = - parse_test_case("../tn/chinese/test/data/normalizer.txt"); + ParseTestCase("../tn/chinese/test/data/normalizer.txt"); INSTANTIATE_TEST_SUITE_P(NormalizeTest, ProcessorTest, testing::ValuesIn(test_cases)); diff --git a/runtime/test/utf8_string_test.cc b/runtime/test/string_test.cc similarity index 61% rename from runtime/test/utf8_string_test.cc rename to runtime/test/string_test.cc index e7518dbf..6ed75312 100644 --- a/runtime/test/utf8_string_test.cc +++ b/runtime/test/string_test.cc @@ -14,32 +14,32 @@ #include "gmock/gmock.h" -#include "utils/utf8_string.h" +#include "utils/string.h" class StringTest : public testing::Test {}; TEST(StringTest, StringLengthTest) { - EXPECT_EQ(wetext::string_length("A"), 1); - EXPECT_EQ(wetext::string_length("À"), 1); - EXPECT_EQ(wetext::string_length("啊"), 1); - EXPECT_EQ(wetext::string_length("✐"), 1); - EXPECT_EQ(wetext::string_length("你好"), 2); - EXPECT_EQ(wetext::string_length("world"), 5); + EXPECT_EQ(wetext::UTF8StringLength("A"), 1); + EXPECT_EQ(wetext::UTF8StringLength("À"), 1); + EXPECT_EQ(wetext::UTF8StringLength("啊"), 1); + EXPECT_EQ(wetext::UTF8StringLength("✐"), 1); + EXPECT_EQ(wetext::UTF8StringLength("你好"), 2); + EXPECT_EQ(wetext::UTF8StringLength("world"), 5); } -TEST(StringTest, String2CharsTest) { +TEST(StringTest, SplitUTF8StringToCharsTest) { std::vector chars; - wetext::string2chars("你好world", &chars); + wetext::SplitUTF8StringToChars("你好world", &chars); ASSERT_THAT(chars, testing::ElementsAre("你", "好", "w", "o", "r", "l", "d")); } TEST(StringTest, TrimTest) { - ASSERT_EQ(wetext::trim("\thello "), "hello"); - ASSERT_EQ(wetext::trim(" hello\t"), "hello"); + ASSERT_EQ(wetext::Trim("\thello "), "hello"); + ASSERT_EQ(wetext::Trim(" hello\t"), "hello"); } -TEST(StringTest, SplitStringTest) { +TEST(StringTest, SplitTest) { std::vector output; - wetext::split_string("written => spoken", " => ", &output); + wetext::Split("written => spoken", " => ", &output); ASSERT_THAT(output, testing::ElementsAre("written", "spoken")); } diff --git a/runtime/test/token_parser_test.cc b/runtime/test/token_parser_test.cc index ad96b85e..9170bfb1 100644 --- a/runtime/test/token_parser_test.cc +++ b/runtime/test/token_parser_test.cc @@ -33,46 +33,46 @@ class TokenParserTest : public testing::Test { }; TEST_F(TokenParserTest, ReadTest) { - parser->load(" "); - ASSERT_FALSE(parser->read()); - ASSERT_EQ(parser->ch, wetext::EOS); + parser->Load(" "); + ASSERT_FALSE(parser->Read()); + ASSERT_EQ(parser->ch_, wetext::EOS); } TEST_F(TokenParserTest, ParseWSTest) { - parser->load(" "); - ASSERT_FALSE(parser->parse_ws()); - ASSERT_EQ(parser->ch, wetext::EOS); + parser->Load(" "); + ASSERT_FALSE(parser->ParseWs()); + ASSERT_EQ(parser->ch_, wetext::EOS); - parser->load(" "); - ASSERT_FALSE(parser->parse_ws()); - ASSERT_EQ(parser->ch, wetext::EOS); + parser->Load(" "); + ASSERT_FALSE(parser->ParseWs()); + ASSERT_EQ(parser->ch_, wetext::EOS); - parser->load(" test"); - ASSERT_TRUE(parser->parse_ws()); - ASSERT_EQ(parser->ch, "t"); + parser->Load(" test"); + ASSERT_TRUE(parser->ParseWs()); + ASSERT_EQ(parser->ch_, "t"); } TEST_F(TokenParserTest, ParseCharsTest) { - parser->load("hello world"); - ASSERT_TRUE(parser->parse_chars("hello")); - ASSERT_EQ(parser->ch, " "); + parser->Load("hello world"); + ASSERT_TRUE(parser->ParseChars("hello")); + ASSERT_EQ(parser->ch_, " "); - parser->load("world"); - ASSERT_FALSE(parser->parse_chars("hello")); - ASSERT_EQ(parser->ch, "w"); + parser->Load("world"); + ASSERT_FALSE(parser->ParseChars("hello")); + ASSERT_EQ(parser->ch_, "w"); } TEST_F(TokenParserTest, ParseKeyTest) { - parser->load("key"); - ASSERT_EQ(parser->parse_key(), "key"); + parser->Load("key"); + ASSERT_EQ(parser->ParseKey(), "key"); - parser->load("key "); - ASSERT_EQ(parser->parse_key(), "key"); + parser->Load("key "); + ASSERT_EQ(parser->ParseKey(), "key"); } TEST_F(TokenParserTest, ParseValueTest) { - parser->load("value\""); - ASSERT_EQ(parser->parse_value(), "value"); + parser->Load("value\""); + ASSERT_EQ(parser->ParseValue(), "value"); } TEST_F(TokenParserTest, ParseTest) { @@ -82,8 +82,8 @@ TEST_F(TokenParserTest, ParseTest) { std::string input = "time { minute: \"零二分\" hour: \"两点\" } char { value: \"走\" }"; - parser->parse(input); - std::vector tokens = parser->tokens; + parser->Parse(input); + std::vector tokens = parser->tokens_; ASSERT_EQ(tokens.size(), 2); ASSERT_EQ(tokens[0].name, "time"); ASSERT_EQ(tokens[1].name, "char"); @@ -99,5 +99,5 @@ TEST_F(TokenParserTest, ReorderTest) { "time { minute: \"零二分\" hour: \"两点\" } char { value: \"走\" }"; std::string expected = "time { hour: \"两点\" minute: \"零二分\" } char { value: \"走\" }"; - ASSERT_EQ(parser->reorder(input), expected); + ASSERT_EQ(parser->Reorder(input), expected); } diff --git a/runtime/utils/CMakeLists.txt b/runtime/utils/CMakeLists.txt new file mode 100644 index 00000000..b33a5f2f --- /dev/null +++ b/runtime/utils/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(utils STATIC string.cc) + +target_link_libraries(utils PUBLIC glog) diff --git a/runtime/utils/utf8_string.cc b/runtime/utils/string.cc similarity index 81% rename from runtime/utils/utf8_string.cc rename to runtime/utils/string.cc index b4bfb990..36d82c07 100644 --- a/runtime/utils/utf8_string.cc +++ b/runtime/utils/string.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "utils/utf8_string.h" +#include "utils/string.h" #include "utils/log.h" namespace wetext { const char* WHITESPACE = " \n\r\t\f\v"; -int char_length(char ch) { +int UTF8CharLength(char ch) { int num_bytes = 1; CHECK_LE((ch & 0xF8), 0xF0); if ((ch & 0x80) == 0x00) { @@ -43,39 +43,40 @@ int char_length(char ch) { return num_bytes; } -int string_length(const std::string& str) { +int UTF8StringLength(const std::string& str) { int len = 0; int num_bytes = 1; for (size_t i = 0; i < str.length(); i += num_bytes) { - num_bytes = char_length(str[i]); + num_bytes = UTF8CharLength(str[i]); ++len; } return len; } -void string2chars(const std::string& str, std::vector* chars) { +void SplitUTF8StringToChars(const std::string& str, + std::vector* chars) { chars->clear(); int num_bytes = 1; for (size_t i = 0; i < str.length(); i += num_bytes) { - num_bytes = char_length(str[i]); + num_bytes = UTF8CharLength(str[i]); chars->push_back(str.substr(i, num_bytes)); } } -std::string ltrim(const std::string& str) { +std::string Ltrim(const std::string& str) { size_t start = str.find_first_not_of(WHITESPACE); return (start == std::string::npos) ? "" : str.substr(start); } -std::string rtrim(const std::string& str) { +std::string Rtrim(const std::string& str) { size_t end = str.find_last_not_of(WHITESPACE); return end == std::string::npos ? "" : str.substr(0, end + 1); } -std::string trim(const std::string& str) { return rtrim(ltrim(str)); } +std::string Trim(const std::string& str) { return Rtrim(Ltrim(str)); } -void split_string(const std::string& str, const std::string& delim, - std::vector* output) { +void Split(const std::string& str, const std::string& delim, + std::vector* output) { std::string s = str; size_t pos = 0; while ((pos = s.find(delim)) != std::string::npos) { diff --git a/runtime/utils/utf8_string.h b/runtime/utils/string.h similarity index 66% rename from runtime/utils/utf8_string.h rename to runtime/utils/string.h index 0d216c84..15a21cd9 100644 --- a/runtime/utils/utf8_string.h +++ b/runtime/utils/string.h @@ -21,20 +21,21 @@ namespace wetext { extern const char* WHITESPACE; -int char_length(char ch); +int UTF8CharLength(char ch); -int string_length(const std::string& str); +int UTF8StringLength(const std::string& str); -void string2chars(const std::string& str, std::vector* chars); +void SplitUTF8StringToChars(const std::string& str, + std::vector* chars); -std::string ltrim(const std::string& str); +std::string Ltrim(const std::string& str); -std::string rtrim(const std::string& str); +std::string Rtrim(const std::string& str); -std::string trim(const std::string& str); +std::string Trim(const std::string& str); -void split_string(const std::string& str, const std::string& delim, - std::vector* output); +void Split(const std::string& str, const std::string& delim, + std::vector* output); } // namespace wetext