Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
- name: Prepare Graph
run: |
mkdir graph
cp tn/zh_tn_normalizer.far graph
cp itn/zh_itn_normalizer.far graph
cp tn/*.fst graph
cp itn/*.fst graph

- name: Upload Graph
uses: actions/upload-artifact@v3
Expand Down
23 changes: 2 additions & 21 deletions itn/chinese/inverse_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from tn.processor import Processor
from itn.chinese.rules.cardinal import Cardinal
from itn.chinese.rules.char import Char
Expand All @@ -26,36 +24,19 @@
from itn.chinese.rules.time import Time
from itn.chinese.rules.preprocessor import PreProcessor

from pynini import Far
from pynini.lib.pynutil import add_weight, delete
from importlib_resources import files


class InverseNormalizer(Processor):

def __init__(self, cache_dir=None, overwrite_cache=False,
def __init__(self, cache_dir='itn', overwrite_cache=False,
enable_standalone_number=True,
enable_0_to_9=True):
super().__init__(name='inverse_normalizer', ordertype='itn')
self.cache_dir = cache_dir
self.overwrite_cache = overwrite_cache
self.convert_number = enable_standalone_number
self.enable_0_to_9 = enable_0_to_9

far_file = files('itn').joinpath('zh_itn_normalizer.far')
if self.cache_dir:
os.makedirs(self.cache_dir, exist_ok=True)
far_file = os.path.join(self.cache_dir, 'zh_itn_normalizer.far')

if far_file and os.path.exists(far_file) and not overwrite_cache:
self.tagger = Far(far_file)['tagger']
self.verbalizer = Far(far_file)['verbalizer']
else:
self.build_tagger()
self.build_verbalizer()

if self.cache_dir and self.overwrite_cache:
self.export(far_file)
self.build_fst('zh_itn', cache_dir, overwrite_cache)

def build_tagger(self):
tagger = (add_weight(Date().tagger, 1.02)
Expand Down
9 changes: 3 additions & 6 deletions runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

project(text_processing VERSION 0.1)
project(wetextprocessing VERSION 0.1)
set(CMAKE_CXX_STANDARD 14)

set(CMAKE_VERBOSE_MAKEFILE OFF)
option(BUILD_TESTING "whether to build unit test" ON)
option(FST_HAVE_BIN "whether to build fst binaries" OFF)
option(BUILD_TESTING "whether to build unit test" OFF)

include(FetchContent)
include(ExternalProject)
set(FETCHCONTENT_QUIET OFF)
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(FETCHCONTENT_BASE_DIR ${fc_base})
Expand All @@ -21,9 +19,8 @@ if(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
endif()

include(openfst)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${PROJECT_SOURCE_DIR})
add_subdirectory(utils)
add_dependencies(utils openfst)
add_subdirectory(processor)
add_subdirectory(bin)

Expand Down
14 changes: 7 additions & 7 deletions runtime/bin/processor_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,26 @@
#include <string>

#include "processor/processor.h"
#include "processor/token_parser.h"
#include "utils/flags.h"

DEFINE_string(text, "", "input string");
DEFINE_string(file, "", "input file");
DEFINE_string(far, "", "FST archives");
DEFINE_string(tagger, "", "tagger fst path");
DEFINE_string(verbalizer, "", "verbalizer fst path");

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

if (FLAGS_far.empty()) {
LOG(FATAL) << "Please provide the FST archives.";
if (FLAGS_tagger.empty() || FLAGS_verbalizer.empty()) {
LOG(FATAL) << "Please provide the tagger and verbalizer fst files.";
}
wenet::Processor processor(FLAGS_far);
wenet::Processor processor(FLAGS_tagger, FLAGS_verbalizer);

if (!FLAGS_text.empty()) {
std::string tagged_text = processor.tag(FLAGS_text);
std::cout << tagged_text << std::endl;
std::string normalized_text = processor.normalize(FLAGS_text);
std::string normalized_text = processor.verbalize(tagged_text);
std::cout << normalized_text << std::endl;
}

Expand All @@ -46,7 +46,7 @@ int main(int argc, char* argv[]) {
while (getline(file, line)) {
std::string tagged_text = processor.tag(line);
std::cout << tagged_text << std::endl;
std::string normalized_text = processor.normalize(line);
std::string normalized_text = processor.verbalize(tagged_text);
std::cout << normalized_text << std::endl;
}
}
Expand Down
42 changes: 25 additions & 17 deletions runtime/cmake/openfst.cmake
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
include(gflags)
# We can't build glog with gflags, unless gflags is pre-installed.
# If build glog with pre-installed gflags, there will be conflict.
set(WITH_GFLAGS OFF CACHE BOOL "whether build glog with gflags" FORCE)
include(glog)

set(CONFIG_FLAGS "")
if(NOT FST_HAVE_BIN)
set(CONFIG_FLAGS "--disable-bin")
set(HAVE_BIN OFF CACHE BOOL "Build the fst binaries" FORCE)
set(HAVE_SCRIPT OFF CACHE BOOL "Build the fstscript" FORCE)
set(HAVE_COMPACT OFF CACHE BOOL "Build compact" FORCE)
set(HAVE_CONST OFF CACHE BOOL "Build const" FORCE)
set(HAVE_GRM OFF CACHE BOOL "Build grm" FORCE)
set(HAVE_FAR OFF CACHE BOOL "Build far" FORCE)
set(HAVE_PDT OFF CACHE BOOL "Build pdt" FORCE)
set(HAVE_MPDT OFF CACHE BOOL "Build mpdt" FORCE)
set(HAVE_LINEAR OFF CACHE BOOL "Build linear" FORCE)
set(HAVE_LOOKAHEAD OFF CACHE BOOL "Build lookahead" FORCE)
set(HAVE_NGRAM OFF CACHE BOOL "Build ngram" FORCE)
set(HAVE_SPECIAL OFF CACHE BOOL "Build special" FORCE)

if(MSVC)
add_compile_options(/W0 /wd4244 /wd4267)
endif()

# "OpenFST port for Windows" builds openfst with cmake for multiple platforms.
# Openfst is compiled with glog/gflags to avoid log and flag conflicts with log and flags in wenet/libtorch.
# To build openfst with gflags and glog, we comment out some vars of {flags, log}.h and flags.cc.
set(openfst_SOURCE_DIR ${fc_base}/openfst-src CACHE PATH "OpenFST source directory")
set(openfst_PREFIX_DIR ${fc_base}/openfst-subbuild/openfst-populate-prefix CACHE PATH "OpenFST prefix directory")
ExternalProject_Add(openfst
URL https://github.com/mjansche/openfst/archive/1.7.2.zip
URL_HASH MD5=96656fee440ee2d71006a4900ef9ac00
PREFIX ${openfst_PREFIX_DIR}
SOURCE_DIR ${openfst_SOURCE_DIR}
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure ${CONFIG_FLAGS} --enable-far --prefix=${openfst_PREFIX_DIR}
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread"
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
BUILD_COMMAND make -j$(nproc)
FetchContent_Declare(openfst
URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.7.2.1.tar.gz
URL_HASH SHA256=e04e1dabcecf3a687ace699ccb43a8a27da385777a56e69da6e103344cc66bca
PATCH_COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
)
add_dependencies(openfst gflags glog)
link_directories(${openfst_PREFIX_DIR}/lib)
FetchContent_MakeAvailable(openfst)
add_dependencies(fst gflags glog)
target_link_libraries(fst PUBLIC gflags_nothreads_static glog)
include_directories(${openfst_SOURCE_DIR}/src/include)
5 changes: 4 additions & 1 deletion runtime/patch/openfst/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ install(DIRECTORY include/ DESTINATION include/
FILES_MATCHING PATTERN "*.h")

add_subdirectory(lib)
add_subdirectory(script)

if(HAVE_SCRIPT)
add_subdirectory(script)
endif(HAVE_SCRIPT)

if(HAVE_BIN)
add_subdirectory(bin)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ if(HAVE_BIN)
sigma-fst.cc
)

set_target_properties(fstspecial-bin PROPERTIE
set_target_properties(fstspecial-bin PROPERTIES
FOLDER special/bin
OUTPUT_NAME fstspecial
)
Expand Down
54 changes: 0 additions & 54 deletions runtime/patch/openfst/src/test/CMakeLists.txt

This file was deleted.

2 changes: 1 addition & 1 deletion runtime/processor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ add_library(processor STATIC
processor.cc
token_parser.cc
)
target_link_libraries(processor PUBLIC fstfar utils)
target_link_libraries(processor PUBLIC utils)
50 changes: 16 additions & 34 deletions runtime/processor/processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,30 @@

#include "processor/processor.h"

#include "fst/fstlib.h"

#include "utils/utils.h"

using fst::FarReader;
using fst::StdVectorFst;
using fst::StringTokenType;

namespace wenet {

Processor::Processor(const std::string& far_path) {
FarReader<StdArc>* reader = FarReader<StdArc>::Open(far_path);
CHECK_NOTNULL(reader);

CHECK_GT(reader->Find("tagger"), 0) << "Tagger is missing.";
tagger_ = reader->GetFst()->Copy();

CHECK_GT(reader->Find("verbalizer"), 0) << "Verbalizer is missing.";
verbalizer_ = reader->GetFst()->Copy();

delete reader;

Processor::Processor(const std::string& tagger_path,
const std::string& verbalizer_path) {
tagger_.reset(StdVectorFst::Read(tagger_path));
verbalizer_.reset(StdVectorFst::Read(verbalizer_path));
compiler_ = std::make_shared<StringCompiler<StdArc>>(StringTokenType::BYTE);

if (far_path.find("_tn_") != far_path.npos) {
if (tagger_path.find("_tn_") != tagger_path.npos) {
parse_type_ = ParseType::kTN;
} else if (far_path.find("_itn_") != far_path.npos) {
} else if (tagger_path.find("_itn_") != tagger_path.npos) {
parse_type_ = ParseType::kITN;
} else {
LOG(FATAL) << "Invalid far prefix, prefix should contain"
LOG(FATAL) << "Invalid fst prefix, prefix should contain"
<< " either \"_tn_\" or \"_itn_\".";
}
}

Processor::~Processor() {
delete tagger_;
delete verbalizer_;
}

std::string Processor::compose(const std::string& input,
const Fst<StdArc>* fst) {
const StdVectorFst* fst) {
StdVectorFst input_fst;
compiler_->operator()(input, &input_fst);

Expand All @@ -64,21 +47,20 @@ std::string Processor::compose(const std::string& input,
}

std::string Processor::tag(const std::string& input) {
return compose(input, tagger_);
return compose(input, tagger_.get());
}

std::string Processor::verbalize(const std::string& input) {
return compose(input, verbalizer_);
}

std::string Processor::normalize(const std::string& input) {
std::string output = tag(input);
if (output.empty()) {
if (input.empty()) {
return "";
}
TokenParser parser(parse_type_);
output = parser.reorder(output);
return verbalize(output);
std::string output = parser.reorder(input);
return compose(output, verbalizer_.get());
}

std::string Processor::normalize(const std::string& input) {
return verbalize(tag(input));
}

} // namespace wenet
13 changes: 6 additions & 7 deletions runtime/processor/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,29 @@
#ifndef PROCESSOR_PROCESSOR_H_
#define PROCESSOR_PROCESSOR_H_

#include "fst/extensions/far/farlib.h"
#include "fst/fstlib.h"

#include "processor/token_parser.h"

using fst::Fst;
using fst::StdArc;
using fst::StdVectorFst;
using fst::StringCompiler;

namespace wenet {

class Processor {
public:
Processor(const std::string& far_path);
~Processor();
Processor(const std::string& tagger_path, const std::string& verbalizer_path);
std::string tag(const std::string& input);
std::string verbalize(const std::string& input);
std::string normalize(const std::string& input);

private:
std::string compose(const std::string& input, const Fst<StdArc>* fst);
std::string compose(const std::string& input, const StdVectorFst* fst);

ParseType parse_type_;
Fst<StdArc>* tagger_ = nullptr;
Fst<StdArc>* verbalizer_ = nullptr;
std::shared_ptr<StdVectorFst> tagger_ = nullptr;
std::shared_ptr<StdVectorFst> verbalizer_ = nullptr;
std::shared_ptr<StringCompiler<StdArc>> compiler_ = nullptr;
};

Expand Down
1 change: 0 additions & 1 deletion runtime/processor/token_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ std::string TokenParser::reorder(const std::string& input) {
for (auto& token : tokens) {
output += token.string(orders) + " ";
}
tokens.clear();
Comment thread
xingchensong marked this conversation as resolved.
return trim(output);
}

Expand Down
Loading