Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#17 from LiuChiachi/improve-data-augme…
Browse files Browse the repository at this point in the history
…nt-for-distill-lstm

Improve data augmentation for distilling Bi-LSTM
  • Loading branch information
guoshengCS committed Feb 26, 2021
2 parents cb6e992 + f0d7e16 commit c8ebb53
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,24 @@

在模型蒸馏中,较大的模型(在本例中是BERT)通常被称为教师模型,较小的模型(在本例中是Bi-LSTM)通常被成为学生模型。知识的蒸馏通常是通过模型学习蒸馏相关的损失函数实现,在本实验中,损失函数是均方误差损失函数,传入函数的两个参数分别是学生模型的输出和教师模型的输出。

[论文](https://arxiv.org/abs/1903.12136)的模型蒸馏阶段,作者为了能让教师模型表达出更多的知识供学生模型学习,对训练数据进行了数据增强。作者使用了三种数据增强方式,分别是:1.Masking,即以一定的概率将原数据中的word token替换成`[MASK]`;2. POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;3. n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
[论文](https://arxiv.org/abs/1903.12136)的模型蒸馏阶段,作者为了能让教师模型表达出更多的知识供学生模型学习,对训练数据进行了数据增强。作者使用了三种数据增强方式,分别是:

1. Masking,即以一定的概率将原数据中的word token替换成`[MASK]`

2. POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;

3. n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
在英文数据集任务上,本文使用了Google News语料[预训练的Word Embedding](https://code.google.com/archive/p/word2vec/)初始化小模型的Embedding层。

本实验分为三个训练过程:在特定任务上对BERT的fine-tuning、在特定任务上对基于Bi-LSTM的小模型的训练(用于评价蒸馏效果)、将BERT模型的知识蒸馏到基于Bi-LSTM的小模型上。

## 环境要求
运行本目录下的范例模型需要安装PaddlePaddle 2.0及以上版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](https://www.paddlepaddle.org.cn/#quick-start)中的说明更新 PaddlePaddle 安装版本。
另外,本项目还依赖paddlenlp,可以使用下面的命令进行安装:

另外,本文下载并在对英文数据集的训练中使用了Google News语料[预训练的Word Embedding](https://code.google.com/archive/p/word2vec/)初始化小模型的Embedding层,并使用gensim包对该Word Embedding文件进行读取。因此,运行本实验还需要安装`gensim`及下载预训练的Word Embedding。

```shell
pip install paddlenlp==2.0.0rc
```

## 数据、预训练模型介绍及获取

Expand Down Expand Up @@ -61,11 +70,11 @@ python -u ./run_bert_finetune.py \
--num_train_epochs 3 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ../distill/ditill_lstm/model/$TASK_NAME/ \
--output_dir ../model_compression/distill_lstm/pretrained_modelss/$TASK_NAME/ \
--n_gpu 1 \

```
训练完成之后,可将训练效果最好的模型保存在本项目下的`models/$TASK_NAME/`下。模型目录下有`model_config.json`, `model_state.pdparams`, `tokenizer_config.json``vocab.txt`这几个文件。
训练完成之后,可将训练效果最好的模型保存在本项目下的`pretrained_models/$TASK_NAME/`下。模型目录下有`model_config.json`, `model_state.pdparams`, `tokenizer_config.json``vocab.txt`这几个文件。


### 训练小模型
Expand All @@ -83,8 +92,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
--optimizer adam \
--lr 3e-4 \
--dropout_prob 0.2 \
--use_pretrained_emb False \
--vocab_path senta_word_dict_subset.txt
--vocab_path senta_word_dict_subset.txt \
--output_dir small_models/senta/

```

```shell
Expand All @@ -95,7 +105,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
--batch_size 64 \
--lr 1.0 \
--dropout_prob 0.4 \
--use_pretrained_emb True
--output_dir small_models/SST-2 \
--embedding_name w2v.google_news.target.word-word.dim300.en

```

```shell
Expand All @@ -106,7 +118,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
--batch_size 256 \
--lr 2.0 \
--dropout_prob 0.4 \
--use_pretrained_emb True
--output_dir small_models/QQP \
--embedding_name w2v.google_news.target.word-word.dim300.en

```

### 蒸馏模型
Expand All @@ -121,9 +135,10 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
--dropout_prob 0.1 \
--batch_size 64 \
--model_name bert-wwm-ext-chinese \
--use_pretrained_emb False \
--teacher_path model/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
--vocab_path senta_word_dict_subset.txt
--teacher_path pretrained_models/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
--vocab_path senta_word_dict_subset.txt \
--output_dir distilled_models/senta

```

```shell
Expand All @@ -136,8 +151,10 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
--dropout_prob 0.2 \
--batch_size 128 \
--model_name bert-base-uncased \
--use_pretrained_emb True \
--teacher_path model/SST-2/best_model_610/model_state.pdparams
--embedding_name w2v.google_news.target.word-word.dim300.en \
--output_dir distilled_models/SST-2 \
--teacher_path pretrained_models/SST-2/best_model_610/model_state.pdparams

```

```shell
Expand All @@ -149,24 +166,25 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
--dropout_prob 0.2 \
--batch_size 256 \
--model_name bert-base-uncased \
--use_pretrained_emb True \
--embedding_name w2v.google_news.target.word-word.dim300.en \
--n_iter 10 \
--teacher_path model/QQP/best_model_17000/model_state.pdparams
--output_dir distilled_models/QQP \
--teacher_path pretrained_models/QQP/best_model_17000/model_state.pdparams

```

各参数的具体说明请参阅 `args.py` ,注意在训练不同任务时,需要调整对应的超参数。


## 蒸馏实验结果
本蒸馏实验基于GLUE的SST-2、QQP、中文情感分类ChnSentiCorp数据集。实验效果均使用每个数据集的验证集(dev)进行评价,评价指标是准确率(acc),其中QQP中包含f1值。利用基于BERT的教师模型去蒸馏基于Bi-LSTM的学生模型,对比Bi-LSTM小模型单独训练,在SST-2、QQP、senta(中文情感分类)任务上分别有3.2%、1.8%、1.4%的提升。

| Model | SST-2(dev acc) | QQP(dev acc/f1) | ChnSentiCorp(dev acc) | ChnSentiCorp(dev acc) |
| -------------- | ----------------- | -------------------------- | --------------------- | --------------------- |
| teacher model | bert-base-uncased | bert-base-uncased | bert-base-chinese | bert-wwm-ext-chinese |
| Teacher | 0.930046 | 0.905813(acc)/0.873472(f1) | 0.951667 | 0.955000 |
| Student | 0.853211 | 0.856171(acc)/0.806057(f1) | 0.920833 | 0.920800 |
| Distilled | 0.885321 | 0.874375(acc)/0.829581(f1) | 0.930000 | 0.935000 |

本蒸馏实验基于GLUE的SST-2、QQP、中文情感分类ChnSentiCorp数据集。实验效果均使用每个数据集的验证集(dev)进行评价,评价指标是准确率(acc),其中QQP中包含f1值。利用基于BERT的教师模型去蒸馏基于Bi-LSTM的学生模型,对比Bi-LSTM小模型单独训练,在SST-2、QQP、senta(中文情感分类)任务上分别有3.3%、1.9%、1.4%的提升。

| Model | SST-2(dev acc) | QQP(dev acc/f1) | ChnSentiCorp(dev acc) | ChnSentiCorp(dev acc) |
| ----------------- | ----------------- | -------------------------- | --------------------- | --------------------- |
| Teacher model | bert-base-uncased | bert-base-uncased | bert-base-chinese | bert-wwm-ext-chinese |
| BERT-base | 0.930046 | 0.905813(acc)/0.873472(f1) | 0.951667 | 0.955000 |
| Bi-LSTM | 0.854358 | 0.856616(acc)/0.799682(f1) | 0.920000 | 0.920000 |
| Distilled Bi-LSTM | 0.887615 | 0.875216(acc)/0.831254(f1) | 0.932500 | 0.934167 |

## 参考文献

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ def parse_args():
parser.add_argument(
"--num_layers", type=int, default=1, help="Layers number of LSTM.")

parser.add_argument(
'--use_pretrained_emb',
type=eval,
default=False,
help='Whether to use pre-trained embedding tensor.')

parser.add_argument(
"--emb_dim", type=int, default=300, help="Embedding dim.")

Expand Down Expand Up @@ -84,6 +78,12 @@ def parse_args():
default=10,
help="The frequency to print evaluation logs.")

parser.add_argument(
"--save_steps",
type=int,
default=100,
help="The frequency to print evaluation logs.")

parser.add_argument(
"--padding_idx",
type=int,
Expand All @@ -106,6 +106,24 @@ def parse_args():
default='/root/.paddlenlp/models/bert-base-uncased/bert-base-uncased-vocab.txt',
help="Student model's vocab path.")

parser.add_argument(
"--output_dir",
type=str,
default='models',
help="Directory to save models .")

parser.add_argument(
"--whole_word_mask",
action="store_true",
help="If True, use whole word masking method in data augmentation in distilling."
)

parser.add_argument(
"--embedding_name",
type=str,
default=None,
help="The name of pretrained word embedding.")

parser.add_argument(
"--vocab_size",
type=int,
Expand All @@ -118,5 +136,12 @@ def parse_args():
default=0.0,
help="Weight balance between cross entropy loss and mean square loss.")

parser.add_argument(
"--seed",
type=int,
default=2021,
help="Random seed for model parameter initialization, data augmentation and so on."
)

args = parser.parse_args()
return args
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import paddle
import paddle.nn as nn
from paddle.metric import Metric, Accuracy, Precision, Recall
from paddle.metric import Accuracy

from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from paddlenlp.transformers.tokenizer_utils import whitespace_tokenize
from paddlenlp.transformers import BertForSequenceClassification
from paddlenlp.metrics import AccuracyAndF1
from paddlenlp.datasets import GlueSST2, GlueQQP, ChnSentiCorp

from args import parse_args
from small import BiLSTM
from data import create_distill_loader, load_embedding
from data import create_distill_loader

TASK_CLASSES = {
"sst-2": (GlueSST2, Accuracy),
Expand All @@ -36,7 +36,6 @@

class TeacherModel(object):
def __init__(self, model_name, param_path):
self.tokenizer = BertTokenizer.from_pretrained(model_name)
self.model = BertForSequenceClassification.from_pretrained(model_name)
self.model.set_state_dict(paddle.load(param_path))
self.model.eval()
Expand Down Expand Up @@ -78,14 +77,14 @@ def do_train(agrs):
vocab_path=args.vocab_path,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
n_iter=args.n_iter)

emb_tensor = load_embedding(
args.vocab_path) if args.use_pretrained_emb else None
n_iter=args.n_iter,
whole_word_mask=args.whole_word_mask,
seed=args.seed)

model = BiLSTM(args.emb_dim, args.hidden_size, args.vocab_size,
args.output_dim, args.padding_idx, args.num_layers,
args.dropout_prob, args.init_scale, emb_tensor)
args.output_dim, args.vocab_path, args.padding_idx,
args.num_layers, args.dropout_prob, args.init_scale,
args.embedding_name)

if args.optimizer == 'adadelta':
optimizer = paddle.optimizer.Adadelta(
Expand Down Expand Up @@ -143,12 +142,21 @@ def do_train(agrs):
acc = evaluate(args.task_name, model, metric, dev_data_loader)
print("eval done total : %s s" % (time.time() - tic_eval))
tic_train = time.time()

if i % args.save_steps == 0:
paddle.save(
model.state_dict(),
os.path.join(args.output_dir,
"step_" + str(global_step) + ".pdparams"))
paddle.save(optimizer.state_dict(),
os.path.join(args.output_dir,
"step_" + str(global_step) + ".pdopt"))

global_step += 1


if __name__ == '__main__':
paddle.seed(2021)
args = parse_args()
print(args)

paddle.seed(args.seed)
do_train(args)
Loading

0 comments on commit c8ebb53

Please sign in to comment.