Skip to content

Latest commit

 

History

History
200 lines (155 loc) · 8.77 KB

README.md

File metadata and controls

200 lines (155 loc) · 8.77 KB

MacBertMaskedLM For Chinese Spelling Correction

本项目是 MacBERT 改变网络结构的中文文本纠错模型,可支持 BERT 类模型为 backbone。

"MacBERT shares the same pre-training tasks as BERT with several modifications." —— (Cui et al., Findings of the EMNLP 2020)

MacBert4csc 模型网络结构:

  • 本项目是 MacBERT 改变网络结构的中文文本纠错模型,可支持 BERT 类模型为 backbone。
  • 在通常 BERT 模型上进行了魔改,追加了一个全连接层作为错误检测即 detection, 与 SoftMaskedBERT 模型不同点在于,本项目中的 MacBERT 中,只是利用 detection 层和 correction 层的 loss 加权得到最终的 loss。不像 SoftmaskedBERT 中需要利用 detection 层的置信概率来作为 correction 的输入权重。

macbert_network

MacBERT 简介

MacBERT 全称为 MLM as correction BERT,其中 MLM 指的是 masked language model。
MacBERT 的模型网络结构上可以选择任意 BERT 类模型,其主要特征在于预训练时不同的 MLM task 设计:

  • 使用全词屏蔽 (wwm, whole-word masking) 以及 N-gram 屏蔽策略来选择 candidate tokens 进行屏蔽;
  • BERT 类模型通常使用 [MASK] 来屏蔽原词,而 MacBERT 使用第三方的同义词工具来为目标词生成近义词用于屏蔽原词,特别地,当原词没有近义词时,使用随机 n-gram 来屏蔽原词;
  • 和 BERT 类模型相似地,对于每个训练样本,输入中 80% 的词被替换成近义词(原为[MASK])、10%的词替换为随机词,10%的词不变。

MLM as Correction Mask strategies:
macbert_strategies

使用说明

快速加载

pycorrector调用预测

example: examples/macbert/demo.py

from pycorrector import MacBertCorrector
m = MacBertCorrector("shibing624/macbert4csc-base-chinese")
print(m.correct_batch(['今天新情很好', '你找到你最喜欢的工作,我也很高心。']))

output:

[{'source': '今天新情很好', 'target': '今天心情很好', 'errors': [('', '', 2)]},
{'source': '你找到你最喜欢的工作,我也很高心。', 'target': '你找到你最喜欢的工作,我也很高兴。', 'errors': [('', '', 15)]}]

transformers调用预测

当然,你也可使用官方的transformers库进行调用。

import torch
from transformers import BertTokenizerFast, BertForMaskedLM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizerFast.from_pretrained("shibing624/macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
model.to(device)

texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
with torch.no_grad():
    outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))

result = []
for ids, text in zip(outputs.logits, texts):
    _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    print(text, ' => ', corrected_text)
    result.append(corrected_text)
print(result)

output:

今天新情很好  =>  今天心情很好 
你找到你最喜欢的工作,我也很高心。  =>  你找到你最喜欢的工作,我也很高兴。 

模型文件组成:

macbert4csc-base-chinese
    ├── config.json
    ├── added_tokens.json
    ├── pytorch_model.bin
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    └── vocab.txt

Evaluate

提供评估脚本pycorrector/utils/evaluate_utils.py

  • 评估样本集:SIGHAN2015测试集。
  • 计算纠错准召率:采用保守计算方式,简单把纠错之后与正确句子完成匹配的视为正确,否则为错。

执行该评估脚本后,shibing624/macbert4csc-base-chinese 在 SIGHAN2015 测试集纠错效果评估如下:

  • Char Level: precision:0.9372, recall:0.8640, f1:0.8991
  • Sentence Level: precision:0.8264, recall:0.7366, f1:0.7789

由于训练使用的数据使用了 SIGHAN2015 的训练集(复现paper),在 SIGHAN2015 的测试集上达到SOTA水平。

训练

安装依赖

pip install transformers>=4.1.1 pytorch-lightning==1.4.9 torch>=1.7.0 yacs torchmetrics==0.6.0

训练数据集

toy数据集

sighan 2015中文拼写纠错数据(2k条):examples/data/sighan_2015/train.json

data format:

[
    {
        "id": "A2-0003-1",
        "original_text": "但是我不能去参加,因为我有一点事情阿!",
        "wrong_ids": [
            17
        ],
        "correct_text": "但是我不能去参加,因为我有一点事情啊!"
    },
]

SIGHAN+Wang271K中文纠错数据集

数据集 语料 下载链接 压缩包大小
SIGHAN+Wang271K中文纠错数据集 SIGHAN+Wang271K(27万条) 百度网盘(密码01b9)
shibing624/CSC
106M
原始SIGHAN数据集 SIGHAN13 14 15 官方csc.html 339K
原始Wang271K数据集 Wang271K Automatic-Corpus-Generation dimmywang提供 93M

SIGHAN+Wang271K中文纠错数据集,数据格式:

[
    {
        "id": "B2-4029-3",
        "original_text": "晚间会听到嗓音,白天的时候大家都不会太在意,但是在睡觉的时候这嗓音成为大家的恶梦。",
        "wrong_ids": [
            5,
            31
        ],
        "correct_text": "晚间会听到噪音,白天的时候大家都不会太在意,但是在睡觉的时候这噪音成为大家的恶梦。"
    }
]

下载SIGHAN+Wang271K中文纠错数据集,下载后新建output文件夹并放里面,文件位置同上。

自有数据集

把自己数据集标注好,保存为跟训练样本集一样的json格式,然后加载模型继续训练即可。

  1. 已有大量业务相关错误样本,主要标注错误位置(wrong_ids)和纠错后的句子(correct_text)
  2. 没有现成的错误样本,可以手动写脚本生成错误样本(original_text),根据音似、形似等特征把正确句子的指定位置(wrong_ids)字符改为错字,附上 第三方同音字生成脚本同音词替换

训练 MacBert4CSC 模型

python train.py --config_file train_macbert4csc.yml

注意:MacBert4CSC模型只能处理对齐长度的文本纠错问题,不能处理多字、少字的错误,所以训练集original_text需要和correct_text长度一样。 否则会报错:“ValueError: Expected input batch_size (*A) to match target batch_size (*B).”

预测

  • 方法一:直接加载保存的ckpt文件:
python predict_ckpt.py
  • 方法二:加载pytorch_model.bin文件: 把outputs_macbert4csc文件夹路径赋值给MacBertCorrector类的model_name_or_path, 就可以像上面快速加载使用pycorrector或者transformers调用。
outputs-macbert4csc
    ├── config.json
    ├── pytorch_model.bin
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    └── vocab.txt

示例predict.py:

python predict.py

训练 SoftMaskedBert4CSC 模型

python train.py --config_file train_softmaskedbert4csc.yml

Release model

基于SIGHAN+Wang271K中文纠错数据集训练的macbert4csc模型,已经release到HuggingFace models: https://huggingface.co/shibing624/macbert4csc-base-chinese

Reference