diff --git a/examples/macbert/model_correction_pipeline_demo.py b/examples/macbert/model_correction_pipeline_demo.py
index ab981a27..95cc911b 100644
--- a/examples/macbert/model_correction_pipeline_demo.py
+++ b/examples/macbert/model_correction_pipeline_demo.py
@@ -34,8 +34,13 @@
 
     model1 = MacBertCorrector()
     # add confusion corrector for post process
-    confusion_dict = {"喝小明同学": "喝小茗同学", "老人让坐": "老人让座", "平净": "平静", "分知": "分支",
-                      "天氨门": "天安门"}
+    confusion_dict = {
+        "喝小明同学": "喝小茗同学",
+        "老人让坐": "老人让座",
+        "平净": "平静",
+        "分知": "分支",
+        "天氨门": "天安门"
+    }
     model2 = ConfusionCorrector(custom_confusion_path_or_dict=confusion_dict)
     for line in error_sentences:
         r1 = model1.correct(line)
diff --git a/pycorrector/confusion_corrector.py b/pycorrector/confusion_corrector.py
index 7c1e99e4..5cc5c865 100644
--- a/pycorrector/confusion_corrector.py
+++ b/pycorrector/confusion_corrector.py
@@ -5,9 +5,9 @@
 功能:1)补充纠错对,提升召回率;2)对误杀加白,提升准确率
 """
 import os
-import re
 from typing import List
 
+from ahocorasick import Automaton
 from loguru import logger
 
 
@@ -21,7 +21,12 @@ def __init__(self, custom_confusion_path_or_dict=''):
             self.custom_confusion = self.load_custom_confusion_dict(custom_confusion_path_or_dict)
         else:
             raise ValueError('custom_confusion_path_or_dict must be dict or str.')
-        logger.debug('Loaded confusion size: %d' % len(self.custom_confusion))
+
+        self.automaton = Automaton()
+        for idx, (err, truth) in enumerate(self.custom_confusion.items()):
+            self.automaton.add_word(err, (idx, err, truth))
+        self.automaton.make_automaton()
+        logger.debug('Loaded confusion size: %d, make automaton done' % len(self.custom_confusion))
 
     @staticmethod
     def load_custom_confusion_dict(path):
@@ -53,21 +58,19 @@ def correct(self, sentence: str):
         :param sentence: str, 待纠错的文本
         :return: dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
         """
-        corrected_sentence = sentence
+        corrected_sentence = list(sentence)
         details = []
-        # 自定义混淆集加入疑似错误词典
-        for err, truth in self.custom_confusion.items():
-            for i in re.finditer(err, sentence):
-                start, end = i.span()
-                corrected_sentence = corrected_sentence[:start] + truth + corrected_sentence[end:]
-                details.append((err, truth, start))
-        return {'source': sentence, 'target': corrected_sentence, 'errors': details}
+        for end_index, (idx, err, truth) in self.automaton.iter(sentence):
+            start_index = end_index - len(err) + 1
+            corrected_sentence[start_index:end_index + 1] = list(truth)
+            details.append((err, truth, start_index))
+
+        return {'source': sentence, 'target': ''.join(corrected_sentence), 'errors': details}
 
     def correct_batch(self, sentences: List[str]):
         """
         批量句子纠错
         :param sentences: 句子文本列表
-        :param kwargs: 其他参数
         :return: list of {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
         """
         return [self.correct(s) for s in sentences]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 0bb077ea..5e39ce3b 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,6 +3,7 @@ pypinyin
 numpy
 six
 loguru
+pyahocorasick
 scikit-learn>=0.19.1
 torch>=1.3.1
 pytorch-lightning>=1.1.2
diff --git a/requirements.txt b/requirements.txt
index 3863c23f..bd85ea1d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ datasets
 numpy
 pandas
 six
-loguru
\ No newline at end of file
+loguru
+pyahocorasick
\ No newline at end of file