Skip to content

Commit

Permalink
fix #494
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed May 17, 2024
1 parent d07ec8a commit 5cdf3cc
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
9 changes: 7 additions & 2 deletions examples/macbert/model_correction_pipeline_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@

model1 = MacBertCorrector()
# add confusion corrector for post process
confusion_dict = {"喝小明同学": "喝小茗同学", "老人让坐": "老人让座", "平净": "平静", "分知": "分支",
"天氨门": "天安门"}
confusion_dict = {
"喝小明同学": "喝小茗同学",
"老人让坐": "老人让座",
"平净": "平静",
"分知": "分支",
"天氨门": "天安门"
}
model2 = ConfusionCorrector(custom_confusion_path_or_dict=confusion_dict)
for line in error_sentences:
r1 = model1.correct(line)
Expand Down
25 changes: 14 additions & 11 deletions pycorrector/confusion_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
功能:1)补充纠错对,提升召回率;2)对误杀加白,提升准确率
"""
import os
import re
from typing import List

from ahocorasick import Automaton
from loguru import logger


Expand All @@ -21,7 +21,12 @@ def __init__(self, custom_confusion_path_or_dict=''):
self.custom_confusion = self.load_custom_confusion_dict(custom_confusion_path_or_dict)
else:
raise ValueError('custom_confusion_path_or_dict must be dict or str.')
logger.debug('Loaded confusion size: %d' % len(self.custom_confusion))

self.automaton = Automaton()
for idx, (err, truth) in enumerate(self.custom_confusion.items()):
self.automaton.add_word(err, (idx, err, truth))
self.automaton.make_automaton()
logger.debug('Loaded confusion size: %d, make automaton done' % len(self.custom_confusion))

@staticmethod
def load_custom_confusion_dict(path):
Expand Down Expand Up @@ -53,21 +58,19 @@ def correct(self, sentence: str):
:param sentence: str, 待纠错的文本
:return: dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
"""
corrected_sentence = sentence
corrected_sentence = list(sentence)
details = []
# 自定义混淆集加入疑似错误词典
for err, truth in self.custom_confusion.items():
for i in re.finditer(err, sentence):
start, end = i.span()
corrected_sentence = corrected_sentence[:start] + truth + corrected_sentence[end:]
details.append((err, truth, start))
return {'source': sentence, 'target': corrected_sentence, 'errors': details}
for end_index, (idx, err, truth) in self.automaton.iter(sentence):
start_index = end_index - len(err) + 1
corrected_sentence[start_index:end_index + 1] = list(truth)
details.append((err, truth, start_index))

return {'source': sentence, 'target': ''.join(corrected_sentence), 'errors': details}

def correct_batch(self, sentences: List[str]):
    """
    Correct every sentence in a batch by delegating to ``self.correct``.

    :param sentences: list of sentence strings to correct
    :return: list of {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]},
        one entry per input sentence, in input order
    """
    results = []
    # Sentences are corrected independently, one after another.
    for text in sentences:
        results.append(self.correct(text))
    return results
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pypinyin
numpy
six
loguru
pyahocorasick
scikit-learn>=0.19.1
torch>=1.3.1
pytorch-lightning>=1.1.2
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ datasets
numpy
pandas
six
loguru
loguru
pyahocorasick

0 comments on commit 5cdf3cc

Please sign in to comment.