Add BigBird Model (PaddlePaddle#6)
* merge pretrain and classifier scripts

* Fix bigbird classifier model
* remove parse_args
* use BigbirdTokenizer instead of spm
* use paddlenlp.data.Imdb instead of ImdbDataset
* remove BertConfig function
* use Adam instead of AdamW

* use BigBirdForTokenClassification instead of Classifier

* remove useless args
* add save model

* add multi process training

* add help message of --model_name_or_path

* add pretraining

* finish training

* update BigBirdPretrainingCriterion impl

* add the tokenizer encode function for the bigbird

* temp

* temp

* fix the pretrain download parameter

* fix the bs 32 for pre-train

* update the learning_rate for training

* remove the unused code for bigbird

* fix the dropout and format the code for bigbird

* remove the unused code in run_classifier.py

* merge

* fix bigbird sparse grad bug

* add the readme for the bigbird

* remove chinese comment

* add the example data for the bigbird pretrain

* fix the load model for the bigbird

* upgrade Linear3D

* remove useless expand; change to zeros_like

* update the seed for the finetune model

* change the lr in readme

* fix the tokenizer index bug for the bigbird

* fix the tokenizer for the bigbird

* fix readme

* fix readme and comment

* change model_name_or_path default value

* fix doc description; remove useless code

Co-authored-by: fangzeyang <fangzeyang@baidu.com>
Co-authored-by: wawltor <fangzeyang0904@hotmail.com>
3 people committed Mar 5, 2021
1 parent 80c1f77 commit 16f14c6
Showing 10 changed files with 2,626 additions and 0 deletions.
103 changes: 103 additions & 0 deletions examples/language_model/bigbird/README.md
@@ -0,0 +1,103 @@
# Big Bird

## Model Introduction
[Big Bird](https://arxiv.org/abs/2007.14062) (Transformers for Longer Sequences) is a pre-trained model for long sequences proposed by researchers at Google. It uses a sparse attention mechanism that reduces both computational and memory complexity to linear in the sequence length, greatly improving the model's ability to handle long-sequence tasks.
This project is a PaddlePaddle implementation of Big Bird, covering model training and evaluation. The brief directory layout of this example is shown below, followed by a small conceptual sketch of the sparse attention pattern:

```text
.
├── args.py            # configuration for the pre-training task
├── run_classifier.py  # classification task on the IMDB dataset
├── run_pretrain.py    # pre-training script
├── README.md          # documentation
└── data/              # example data
```
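
The linear cost comes from the block-sparse attention pattern: each query block attends only to a few global blocks, its sliding-window neighbours, and a handful of random blocks. The snippet below is a small, self-contained NumPy sketch of that pattern for illustration only; it is not the Paddle implementation used by the model, and the function name and defaults are invented for this example.

```python
# Conceptual sketch of BigBird's block-sparse attention mask (NumPy only).
# Not the model's Paddle implementation; purely illustrative.
import numpy as np

def bigbird_block_mask(num_blocks, num_global=1, window=1, num_random=2, seed=0):
    rng = np.random.default_rng(seed)
    mask = np.zeros((num_blocks, num_blocks), dtype=bool)
    mask[:num_global, :] = True   # global blocks attend to everything
    mask[:, :num_global] = True   # every block attends to the global blocks
    for i in range(num_blocks):
        lo, hi = max(0, i - window), min(num_blocks, i + window + 1)
        mask[i, lo:hi] = True     # sliding-window neighbours
        mask[i, rng.choice(num_blocks, size=num_random, replace=False)] = True  # random blocks
    return mask

print(bigbird_block_mask(num_blocks=8).astype(int))
# Each row contains O(1) allowed blocks, so total attention cost grows linearly
# with the number of blocks instead of quadratically.
```
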
## Quick Start

### Installation

* Install PaddlePaddle

This project requires PaddlePaddle 2.0.1 or later (or an appropriate develop build). Please follow the [installation guide](https://www.paddlepaddle.org.cn/install/quick).

* Install PaddleNLP

```shell
pip install "paddlenlp>=2.0.0rc5"
```

* Download the code

Clone this repository to your local machine.


### Data Preparation
According to the paper, Big Bird's pre-training corpus is built mainly from four sources: Books, CC-News, Stories, and Wikipedia. You can download and clean the corresponding data as needed; a small example dataset is provided in the `data` directory.


### Pre-training

The pre-training task is run as follows:

```shell
unset CUDA_VISIBLE_DEVICES
python -m paddle.distributed.launch --gpus "0" run_pretrain.py --model_name_or_path bigbird-base-uncased \
--input_dir "./data" \
--output_dir "output" \
--batch_size 4 \
--weight_decay 0.01 \
--learning_rate 1e-5 \
--max_steps 100000 \
--save_steps 10000 \
--logging_steps 1 \
--max_encoder_length 512 \
--max_pred_length 75
```

The parameters are described below:

- `model_name_or_path`: the model to train, or a checkpoint from a previous run.
- `input_dir`: the input files; a directory may be given, in which case all files in that directory are used.
- `output_dir`: the output directory.
- `batch_size`: training batch size.
- `weight_decay`: weight decay rate for AdamW.
- `learning_rate`: learning rate for training.
- `max_steps`: maximum number of training steps.
- `save_steps`: interval (in steps) between saved checkpoints.
- `logging_steps`: interval (in steps) between log messages.
- `max_encoder_length`: maximum number of input tokens for the MLM task.
- `max_pred_length`: maximum number of masked tokens to predict in the MLM task (both lengths are illustrated in the sketch below).
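
To make `max_encoder_length` and `max_pred_length` concrete, here is a minimal, hypothetical sketch of capping an MLM example in plain Python. It is not the preprocessing used by this example (that is handled by the BigBird tokenizer and run_pretrain.py), and the helper name is invented for illustration:

```python
# Hypothetical illustration of max_encoder_length / max_pred_length.
# Not the repository's preprocessing code; names are made up.
import random

MASK_ID = 0  # placeholder id standing in for the [MASK] token

def build_mlm_example(token_ids, max_encoder_length=512, max_pred_length=75, mask_rate=0.15):
    token_ids = token_ids[:max_encoder_length]                   # cap the encoder input
    num_to_mask = min(max_pred_length, max(1, int(len(token_ids) * mask_rate)))
    positions = sorted(random.sample(range(len(token_ids)), num_to_mask))
    labels = [token_ids[p] for p in positions]                   # tokens the MLM head must predict
    masked = list(token_ids)
    for p in positions:
        masked[p] = MASK_ID
    return masked, positions, labels

masked, positions, labels = build_mlm_example(list(range(100, 700)))
print(len(masked), len(positions))  # at most 512 input tokens, at most 75 masked positions
```
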


### Evaluation
After pre-training finishes, the pre-trained parameters can be validated by fine-tuning on the [IMDB dataset](http://ai.stanford.edu/~amaas/data/sentiment/). IMDB is a movie-review sentiment analysis dataset containing 50,000 strongly polarized reviews, split into 25,000 for training and 25,000 for testing, with labels pos (positive) and neg (negative); it is a sequence (text) classification task. The fine-tuning script is run as follows:


```shell
export CUDA_VISIBLE_DEVICES=0
python run_classifier.py --model_name_or_path bigbird-base-uncased-finetune \
--output_dir "output" \
--batch_size 2 \
--learning_rate 1e-5 \
--max_steps 10000 \
--save_steps 1000 \
--max_encoder_length 3072
```

The parameters are described below:

- `model_name_or_path`: the model to train, or a checkpoint from a previous run.
- `output_dir`: the output directory.
- `batch_size`: training batch size.
- `learning_rate`: learning rate for training.
- `max_steps`: maximum number of training steps.
- `save_steps`: interval (in steps) between saved checkpoints.
- `logging_steps`: interval (in steps) between log messages.
- `max_encoder_length`: maximum number of input tokens.


Fine-tuning `bigbird-base-uncased-finetune` on the IMDB task gives the following result on the evaluation set (a sketch of how this accuracy metric can be accumulated follows the table):

| Task | Metric | Result |
|:-----:|:----------------------------:|:-----------------:|
| IMDB | Accuracy | 0.9449 |
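
For reference, the snippet below is a minimal sketch of accumulating such an accuracy number with Paddle's built-in metric; the actual evaluation loop lives in run_classifier.py and may differ.

```python
# Illustrative only: accumulate classification accuracy with paddle.metric.Accuracy.
# The real evaluation loop is in run_classifier.py and may differ.
import paddle

metric = paddle.metric.Accuracy()

def eval_step(logits, labels):
    # logits: [batch_size, 2] class scores; labels: [batch_size, 1] gold labels
    correct = metric.compute(logits, labels)
    metric.update(correct)

# After looping eval_step over the whole IMDB test set:
# print("accuracy:", metric.accumulate())
```
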
130 changes: 130 additions & 0 deletions examples/language_model/bigbird/args.py
@@ -0,0 +1,130 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type",
default="bigbird",
type=str,
help="Model type selected in training model.")

parser.add_argument(
"--model_name_or_path",
default="bigbird-base-uncased",
type=str,
help="Path to pre-trained model or shortcut model name for training model."
)

parser.add_argument(
"--input_dir",
default=None,
type=str,
required=True,
help="The input directory where the data will be read from.")

parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written."
)

parser.add_argument(
"--batch_size",
default=8,
type=int,
help="Batch size per GPU/CPU for training.")

parser.add_argument(
"--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for AdamW.")

parser.add_argument(
"--warmup_steps",
default=10000,
type=int,
help="Linear warmup over warmup_steps.")

parser.add_argument(
"--weight_decay",
default=0.01,
type=float,
help=" Weight decay rate if we apply in the optimizer of Adamw.")

parser.add_argument(
"--adam_epsilon",
default=1e-6,
type=float,
help="Epsilon for AdamW optimizer.")

parser.add_argument(
"--max_steps",
default=100000,
type=int,
help="If > 0: set total number of training steps to perform.")

parser.add_argument(
"--logging_steps",
type=int,
default=1,
help="Log every X updates steps.")

parser.add_argument(
"--save_steps",
type=int,
default=500,
help="Save checkpoint every X updates steps.")

parser.add_argument(
"--seed", type=int, default=42, help="Random seed for initialization.")

parser.add_argument(
"--device",
type=str,
default="gpu",
help="Select cpu, gpu, xpu devices to train model.")

parser.add_argument(
"--epochs",
type=int,
default=10,
help="Number of epoches for training.")

parser.add_argument(
"--max_encoder_length",
type=int,
default=512,
help="The maximum total input sequence length after SentencePiece tokenization."
)

parser.add_argument(
"--max_pred_length",
default=75,
type=int,
help="The maximum total of masked tokens in input sequence.")

    parser.add_argument(
        "--use_nsp",
        default=False,
        # Note: argparse's `type=bool` treats any non-empty string as True,
        # so parse the flag value explicitly to make "--use_nsp False" work.
        type=lambda v: str(v).lower() in ("true", "t", "yes", "1"),
        help="Whether or not to add the NSP loss to the total loss.")

args = parser.parse_args()
return args
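
As a usage note, the sketch below shows one way a training entry point might consume these arguments. It is a minimal, hypothetical example, not the run_pretrain.py added by this commit (which is not reproduced in this excerpt):

```python
# Minimal, hypothetical sketch of consuming args.py in a training entry point.
# Not the run_pretrain.py from this commit.
import paddle

from args import parse_args

def main():
    args = parse_args()
    paddle.set_device(args.device)  # --device: cpu, gpu, or xpu
    paddle.seed(args.seed)          # --seed: reproducible initialization
    print(f"pretraining {args.model_name_or_path} for {args.max_steps} steps "
          f"(lr={args.learning_rate}, batch_size={args.batch_size})")

if __name__ == "__main__":
    main()
```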