Skip to content

Latest commit

 

History

History
237 lines (181 loc) · 9.57 KB

README.md

File metadata and controls

237 lines (181 loc) · 9.57 KB

MacBertMaskedLM For Correction

本项目是 MacBERT 改变网络结构的中文文本纠错模型,可支持 BERT 类模型为 backbone。

"MacBERT shares the same pre-training tasks as BERT with several modifications." —— (Cui et al., Findings of the EMNLP 2020)

MacBert4csc 模型网络结构:

  • 本项目是 MacBERT 改变网络结构的中文文本纠错模型,可支持 BERT 类模型为 backbone。
  • 在通常 BERT 模型上进行了魔改,追加了一个全连接层作为错误检测即 detection, 与 SoftMaskedBERT 模型不同点在于,本项目中的 MacBERT 中,只是利用 detection 层和 correction 层的 loss 加权得到最终的 loss。不像 SoftmaskedBERT 中需要利用 detection 层的置信概率来作为 correction 的输入权重。

macbert_network

MacBERT 简介

MacBERT 全称为 MLM as correction BERT,其中 MLM 指的是 masked language model。
MacBERT 的模型网络结构上可以选择任意 BERT 类模型,其主要特征在于预训练时不同的 MLM task 设计:

  • 使用全词屏蔽 (wwm, whole-word masking) 以及 N-gram 屏蔽策略来选择 candidate tokens 进行屏蔽;
  • BERT 类模型通常使用 [MASK] 来屏蔽原词,而 MacBERT 使用第三方的同义词工具来为目标词生成近义词用于屏蔽原词,特别地,当原词没有近义词时,使用随机 n-gram 来屏蔽原词;
  • 和 BERT 类模型相似地,对于每个训练样本,输入中 80% 的词被替换成近义词(原为[MASK])、10%的词替换为随机词,10%的词不变。

MLM as Correction Mask strategies:
macbert_strategies

使用说明

快速加载

pycorrector调用

example: correct_demo.py

from pycorrector.macbert.macbert_corrector import MacBertCorrector

nlp = MacBertCorrector("shibing624/macbert4csc-base-chinese").macbert_correct

i = nlp('今天新情很好')
print(i)

transformers调用

当然,你也可使用官方的transformers库进行调用。

1.先pip安装transformers库:

pip install transformers>=4.1.1

2.使用以下示例执行:

import operator
import torch
from transformers import BertTokenizerFast, BertForMaskedLM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizerFast.from_pretrained("shibing624/macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
model.to(device)

texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
with torch.no_grad():
    outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))

def get_errors(corrected_text, origin_text):
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # add unk word
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # pass english upper char
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details

result = []
for ids, text in zip(outputs.logits, texts):
    _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    corrected_text, details = get_errors(corrected_text, text)
    print(text, ' => ', corrected_text, details)
    result.append((corrected_text, details))
print(result)

output:

今天新情很好  =>  今天心情很好 [('', '', 2, 3)]
你找到你最喜欢的工作,我也很高心。  =>  你找到你最喜欢的工作,我也很高兴。 [('', '', 15, 16)]

模型文件组成:

macbert4csc-base-chinese
    ├── config.json
    ├── added_tokens.json
    ├── pytorch_model.bin
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    └── vocab.txt

Evaluate

提供评估脚本pycorrector/utils/eval.py,该脚本有两个功能:

  • 构建评估样本集:评估集pycorrector/data/eval_corpus.json, 包括字粒度错误100条、词粒度错误100条、语法错误100条,正确句子200条。用户可以修改条数生成其他评估样本分布。
  • 计算纠错准召率:采用保守计算方式,简单把纠错之后与正确句子完成匹配的视为正确,否则为错。

执行该评估脚本后,

shibing624/macbert4csc-base-chinese 模型在 corpus500 纠错效果评估如下:

  • Sentence Level: acc:0.6560, precision:0.7797, recall:0.5919, f1:0.6730

规则方法(加入自定义混淆集)在corpus500纠错效果评估如下:

  • Sentence Level: acc:0.6400, recall:0.5067

shibing624/macbert4csc-base-chinese 在 SIGHAN2015 测试集纠错效果评估如下:

  • Char Level: precision:0.9372, recall:0.8640, f1:0.8991
  • Sentence Level: precision:0.8264, recall:0.7366, f1:0.7789

由于训练使用的数据使用了 SIGHAN2015 的训练集(复现paper),在 SIGHAN2015 的测试集上达到SOTA水平。

评估case

  • run python tests/macbert_corrector_test.py result 在 SIGHAN2015 的测试集上达到了SOTA水平。

训练

安装依赖

pip install transformers>=4.1.1 pytorch-lightning==1.4.9 torch>=1.7.0 yacs

训练数据集

toy数据集(约1千条)

cd macbert
python preprocess.py

得到toy数据集文件:

macbert/output
|-- dev.json
|-- test.json
`-- train.json

SIGHAN+Wang271K中文纠错数据集

数据集 语料 下载链接 压缩包大小
SIGHAN+Wang271K中文纠错数据集 SIGHAN+Wang271K(27万条) 百度网盘(密码01b9)
shibing624/CSC
106M
原始SIGHAN数据集 SIGHAN13 14 15 官方csc.html 339K
原始Wang271K数据集 Wang271K Automatic-Corpus-Generation dimmywang提供 93M

SIGHAN+Wang271K中文纠错数据集,数据格式:

[
    {
        "id": "B2-4029-3",
        "original_text": "晚间会听到嗓音,白天的时候大家都不会太在意,但是在睡觉的时候这嗓音成为大家的恶梦。",
        "wrong_ids": [
            5,
            31
        ],
        "correct_text": "晚间会听到噪音,白天的时候大家都不会太在意,但是在睡觉的时候这噪音成为大家的恶梦。"
    }
]

下载SIGHAN+Wang271K中文纠错数据集,下载后新建output文件夹并放里面,文件位置同上。

自有数据集

把自己数据集标注好,保存为跟训练样本集一样的json格式,然后加载模型继续训练即可。

  1. 已有大量业务相关错误样本,主要标注错误位置(wrong_ids)和纠错后的句子(correct_text)
  2. 没有现成的错误样本,可以手动写脚本生成错误样本(original_text),根据音似、形似等特征把正确句子的指定位置(wrong_ids)字符改为错字,附上 第三方同音字生成脚本同音词替换

训练 MacBert4CSC 模型

python train.py

注意:MacBert4CSC模型只能处理对齐文本的纠错问题,不能处理多字、少字的错误,所以训练集original_text需要和correct_text长度一样。 否则会报错:“ValueError: Expected input batch_size (*A) to match target batch_size (*B).”

预测

  • 方法一:直接加载保存的ckpt文件:
python infer.py
  • 方法二:加载pytorch_model.bin文件: 把output/macbert4csc文件夹下以下模型文件复制到~/.pycorrector/datasets/macbert_models/chinese_finetuned_correction目录下, 就可以像上面快速加载使用pycorrector或者transformers调用。
output
└── macbert4csc
    ├── config.json
    ├── pytorch_model.bin
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    └── vocab.txt

demo示例macbert_corrector.py:

python3 macbert_corrector.py

训练 SoftMaskedBert4CSC 模型

python train.py --config_file train_softmaskedbert4csc.yml

Reference