Skip to content

基于PyTorch框架实现的BERT模型及其相关下游微调任务(Bert PyTorch Implementation)

Notifications You must be signed in to change notification settings

XuRui314/BertWithPretrained

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

BertWithPretrained

本项目是一个基于PyTorch从零实现的BERT模型及相关下游任务示例的代码仓库,同时也包含了BERT模型以及每个下有任务原理的详细讲解。

BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

在学习使用本项目之前需要清楚Transformer的相关原理,更多关于Transformer内容的介绍可以参考文章 This post is all you need(上卷)——层层剥开Transformer ,近4万余字、50张图、3个实战示例(翻译分类对联生成 ),带你一网打尽Transformer!

经过几个月磨磨蹭蹭地梳理,掌柜总算是把整个BERT模型的基本原理、代码实现以及原论文中所提到的4个微调任务场景都给详细地介绍了一遍。完整PDF点击此处 This post is all you need(下卷)——步步走进BERT模型 获取!

模型详细解析

工程结构

  • bert_base_chinese目录中是BERT base中文预训练模型以及配置文件

    模型下载地址:https://huggingface.co/bert-base-chinese/tree/main

  • bert_base_uncased_english目录中是BERT base英文预训练模型以及配置文件

    模型下载地址:https://huggingface.co/bert-base-uncased/tree/main

    注意:config.json中需要添加"pooler_type": "first_token_transform"这个参数

  • data目录中是各个下游任务所使用到的数据集

    • SingleSentenceClassification是今日头条的15分类中文数据集;
    • PairSentenceClassification是MNLI(The Multi-Genre Natural Language Inference Corpus, 多类型自然语言推理数据库)数据集;
    • MultipeChoice是SWAG问题选择数据集
    • SQuAD是斯坦福大学开源的问答数据集1.1版本
    • WikiText是维基百科英文语料用于模型预训练
    • SongCi是宋词语料用于中文模型预训练
  • model目录中是各个模块的实现

    • BasicBert中是基础的BERT模型实现模块
      • MyTransformer.py是自注意力机制实现部分;
      • BertEmbedding.py是Input Embedding实现部分;
      • BertConfig.py用于导入开源的config.json配置文件;
      • Bert.py是BERT模型的实现部分;
    • DownstreamTasks目录是下游任务各个模块的实现
      • BertForSentenceClassification.py是单标签句子分类的实现部分;
      • BertForMultipleChoice.py是问题选择模型的实现部分;
      • BertForQuestionAnswering.py是问题回答(text span)模型的实现部分;
      • BertForNSPAndMLM.py是BERT模型预训练的两个任务实现部分;
  • Task目录中是各个具体下游任务的训练和推理实现

    • TaskForSingleSentenceClassification.py是单标签单文本分类任务的训练和推理实现,可用于普通的文本分类任务;
    • TaskForPairSentence.py是文本对分类任务的训练和推理实现,可用于蕴含任务(例如MNLI数据集);
    • TaskForMultipleChoice.py是问答选择任务的训练和推理实现,可用于问答选择任务(例如SWAG数据集);
    • TaskForSQuADQuestionAnswering.py是问题回答任务的训练和推理实现,可用于问题问答任务(例如SQuAD数据集);
    • TaskForPretraining.py是BERT模型中MLM和NSP两个预训练任务的实现部分,可用于BERT模型预训练;
  • test目录中是各个模块的测试案例

  • utils是各个工具类的实现

    • data_helpers.py是各个下游任务的数据预处理及数据集构建模块;
    • log_helper.py是日志打印模块;
    • creat_pretraining_data.py是用于构造BERT预训练任务的数据集;

环境

Python版本为3.6,其它相关包的版本如下:

torch==1.5.0
torchtext==0.6.0
torchvision==0.6.0
transformers==4.5.1
numpy==1.19.5
pandas==1.1.5
scikit-learn==0.24.0
tqdm==4.61.0

使用方式

Step 1. 下载数据

下载完成各个数据集以及相应的BERT预训练模型(如果为空),并放入对应的目录中。具体可以查看每个数据(data)目录下的README.md文件。

Step 2. 运行模型

进入Tasks目录,运行相关模型.

2.1 中文文本分类任务

python TaskForSingleSentenceClassification.py

运行结果:

-- INFO: Epoch: 0, Batch[0/4186], Train loss :2.862, Train acc: 0.125
-- INFO: Epoch: 0, Batch[10/4186], Train loss :2.084, Train acc: 0.562
-- INFO: Epoch: 0, Batch[20/4186], Train loss :1.136, Train acc: 0.812        
-- INFO: Epoch: 0, Batch[30/4186], Train loss :1.000, Train acc: 0.734
...
-- INFO: Epoch: 0, Batch[4180/4186], Train loss :0.418, Train acc: 0.875
-- INFO: Epoch: 0, Train loss: 0.481, Epoch time = 1123.244s
...
-- INFO: Epoch: 9, Batch[4180/4186], Train loss :0.102, Train acc: 0.984
-- INFO: Epoch: 9, Train loss: 0.100, Epoch time = 1130.071s
-- INFO: Accurcay on val 0.884
-- INFO: Accurcay on val 0.888

2.2 英文文本蕴含任务

python TaskForPairSentenceClassification.py

运行结果:

-- INFO: Epoch: 0, Batch[0/17181], Train loss :1.082, Train acc: 0.438
-- INFO: Epoch: 0, Batch[10/17181], Train loss :1.104, Train acc: 0.438
-- INFO: Epoch: 0, Batch[20/17181], Train loss :1.129, Train acc: 0.250     
-- INFO: Epoch: 0, Batch[30/17181], Train loss :1.063, Train acc: 0.375
...
-- INFO: Epoch: 0, Batch[17180/17181], Train loss :0.367, Train acc: 0.909
-- INFO: Epoch: 0, Train loss: 0.589, Epoch time = 2610.604s
...
-- INFO: Epoch: 9, Batch[0/17181], Train loss :0.064, Train acc: 1.000
-- INFO: Epoch: 9, Train loss: 0.142, Epoch time = 2542.781s
-- INFO: Accurcay on val 0.827
-- INFO: Accurcay on val 0.830

2.3 SWAG多项选择任务

python TaskForMultipleChoice.py

运行结果:

[2021-11-11 21:32:50] - INFO: Epoch: 0, Batch[0/4597], Train loss :1.433, Train acc: 0.250
[2021-11-11 21:32:58] - INFO: Epoch: 0, Batch[10/4597], Train loss :1.277, Train acc: 0.438
[2021-11-11 21:33:01] - INFO: Epoch: 0, Batch[20/4597], Train loss :1.249, Train acc: 0.438
        ......
[2021-11-11 21:58:34] - INFO: Epoch: 0, Batch[4590/4597], Train loss :0.489, Train acc: 0.875
[2021-11-11 21:58:36] - INFO: Epoch: 0, Batch loss :0.786, Epoch time = 1546.173s
[2021-11-11 21:28:55] - INFO: Epoch: 0, Batch[0/4597], Train loss :1.433, Train acc: 0.250
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, squats alongside flies side to side with his gun.  ## False
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, throws a dart at a dartboard.   ## False
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, collapses and falls to the floor.   ## False
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, is standing next to him.    ## True
[2021-11-11 21:30:52] - INFO: Accuracy on val 0.794

2.4 SQuAD问题回答任务

python TaskForSQuADQuestionAnswering.py

运行结果:

[2022-01-02 14:42:17]缓存文件 ~/BertWithPretrained/data/SQuAD/dev-v1_128_384_64.pt 不存在重新处理并缓存!
[2022-01-02 14:42:17] - DEBUG: <<<<<<<<  进入新的example  >>>>>>>>>
[2022-01-02 14:42:17] - DEBUG: ## 正在预处理数据 utils.data_helpers is_training = False
[2022-01-02 14:42:17] - DEBUG: ## 问题 id: 56be5333acb8001400a5030d
[2022-01-02 14:42:17] - DEBUG: ## 原始问题 text: Which performers joined the headliner during the Super Bowl 50 halftime show?
[2022-01-02 14:42:17] - DEBUG: ## 原始描述 text: CBS broadcast Super Bowl 50 in the U.S., and charged an average of $5 million for a  ....
[2022-01-02 14:42:17]- DEBUG: ## 上下文长度为:87, 剩余长度 rest_len 为 : 367
[2022-01-02 14:42:17] - DEBUG: ## input_tokens: ['[CLS]', 'which', 'performers', 'joined', 'the', 'headline', '##r', 'during', 'the', ...]
[2022-01-02 14:42:17] - DEBUG: ## input_ids:[101, 2029, 9567, 2587, 1996, 17653, 2099, 2076, 1996, 3565, 4605, 2753, 22589, 2265, 1029, 102, 6568, ....]
[2022-01-02 14:42:17] - DEBUG: ## segment ids:[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]
[2022-01-02 14:42:17] - DEBUG: ## orig_map:{16: 0, 17: 1, 18: 2, 19: 3, 20: 4, 21: 5, 22: 6, 23: 7, 24: 7, 25: 7, 26: 7, 27: 7, 28: 8, 29: 9, 30: 10,....}
[2022-01-02 14:42:17] - DEBUG: ======================
....
[2022-01-02 15:13:50] - INFO: Epoch:0, Batch[810/7387] Train loss: 0.998, Train acc: 0.708
[2022-01-02 15:13:55] - INFO: Epoch:0, Batch[820/7387] Train loss: 1.130, Train acc: 0.708
[2022-01-02 15:13:59] - INFO: Epoch:0, Batch[830/7387] Train loss: 1.960, Train acc: 0.375
[2022-01-02 15:14:04] - INFO: Epoch:0, Batch[840/7387] Train loss: 1.933, Train acc: 0.542
......
[2022-01-02 15:15:27] - INFO:  ### Quesiotn: [CLS] when was the first university in switzerland founded..
[2022-01-02 15:15:27] - INFO:    ## Predicted answer: 1460
[2022-01-02 15:15:27] - INFO:    ## True answer: 1460
[2022-01-02 15:15:27] - INFO:    ## True answer idx: (tensor(46, tensor(47))
[2022-01-02 15:15:27] - INFO:  ### Quesiotn: [CLS] how many wards in plymouth elect two councillors?
[2022-01-02 15:15:27] - INFO:    ## Predicted answer: 17 of which elect three .....
[2022-01-02 15:15:27] - INFO:    ## True answer: three
[2022-01-02 15:15:27] - INFO:    ## True answer idx: (tensor(25, tensor(25))

运行结束后,data/SQuAD目录中会生成一个名为best_result.json的预测文件,此时只需要切换到该目录下,并运行以下代码即可得到在dev-v1.1.json的测试结果:

python evaluate-v1.1.py dev-v1.1.json best_result.json

"exact_match" : 80.879848628193, "f1": 88.338575234135

2.5 NSP与MLM任务训练及推理

if __name__ == '__main__':
    config = ModelConfig()
    train(config)
    sentences_1 = ["I no longer love her, true, but perhaps I love her.",
                   "Love is so short and oblivion so long."]

    sentences_2 = ["我住长江头,君住长江尾。",
                   "日日思君不见君,共饮长江水。",
                   "此水几时休,此恨何时已。",
                   "只愿君心似我心,定不负相思意。"]
    inference(config, sentences_2, masked=False, language='zh')

上述代码运行结束后将会看到类似如下所示的输出结果:

- INFO: ## 成功载入已有模型进行推理……
- INFO:  ### 原始:我住长江头,君住长江尾。
- INFO:   ## 掩盖:我住长江头,[MASK]住长[MASK]尾。
- INFO:   ## 预测:我住长江头,君住长河尾。  
- INFO: ====================
- INFO:  ### 原始:日日思君不见君,共饮长江水。
- INFO:   ## 掩盖:日日思君不[MASK]君,共[MASK]长江水。
- INFO:   ## 预测:日日思君不见君,共饮长江水。
#   ......

About

基于PyTorch框架实现的BERT模型及其相关下游微调任务(Bert PyTorch Implementation)

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%