    # **TinyBERT实现案例**

    BERT模型是NLP领域最著名的模型之一，它的出现带动了NLP领域预训练+微调方法的快速发展。BERT模型拥有优秀的自然语言理解能力，但模型参数庞大，训练时间耗费较长。TinyBERT是缩小版的BERT，对BERT的架构进行了简化。从推理角度看，TinyBERT比BERT-base（BERT模型基础版本）体积小了7.5倍、速度快了9.4倍，自然语言理解的性能表现更突出。

    **论文** :

    Transformer：https://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html

    bert：https://arxiv.org/abs/1810.04805 

    tinyBERT:https://arxiv.org/abs/1909.10351

# BERT介绍
TinyBERT的模型架构依然是BERT的架构，因此如果想掌握tinyBERT的结构，那么就必须知
道BERT的模型结构。反过来说，如果了解BERT的模型结构，那么自然会对tinyBERT有了清晰的认知。因此这里先对BERT的内部结构进行介绍。

  了解BERT模型之前，希望你对Transformer架构已经有了基本的了解。Transformer来源于自注意力机制。简单的介绍一下，Transformer的结构分为编码器和解码器，编码器可以把文字编码成高纬度的特征，而解码器可以用这些特征去完成相应任务，比如生成翻译。Transformer非常强大，NLP届的两大巨头模型都是由他的一部分改变而来，如下图。GPT取用的是Transformer的解码器部分，而BERT取的是Transformer的编码器，因此Bert的作用，大家可以理解为一个编码器,将输入的文字抽取为高维的特征。
   ![Transformer](ipyphoto/BERTandGPT.jpg)
                             
                                 
                                
  <div align="center"><i>图1：Transformer, BERT 和 GPT </i></div>

   在CV领域，在imagenet上训练一个编码器，迁移到其他任务，已经是惯用手段了。 而在BERT以前，NLP领域使用预训练技术的并不多。BERT是在大规模的无标注文本上进行预训练的，它为无标注的文本设计了两个自监督预训练任务，第一个是MLM（Masked Language Model）任务： 遮盖掉句子中一定比例的单词，使用剩余单词来预测被遮盖的词。 第二个是NSP（Next Sentence Prediction）任务：预测输入的两个句子是否是上下文关系。这样，无需标注即可在海量文本数据上进行预训练，得到一个效果十分良好的编码器，然后使用此编码器在其他文本任务上进行微调。BERT所建立的预训练加微调的模式在后续的NLP网络中得到了广泛的应用。
   
## BERT结构  
   BERT模型主要由三部分构成：嵌入层（embedding），BERT layer堆叠层， 和输出层。 一句文本，首先会被分词，然后进入BERT，经过嵌入层，得到嵌入向量。之后嵌入向量会进入BERT layer的自注意力模块中进行提取，最后得到深度特征，最后经过输出层输出。下面详细介绍这个步骤。
        
        
        
    
   <img src="ipyphoto/BERT.png" alt="Drawing" style="width: 400px;" align="mid"/>
                      

                               
   <div align="center"><i>图2： BERT结构 </i></div>
   
   
   
   **一，嵌入层（embedding layer）**：
   
   嵌入，有时候又称为向量化。就是把输入映射为高维的向量。比如输入一个词，一般只有一个编号，是一维的。在嵌入层会被转换成高维的向量。
   
   <img src="ipyphoto/trans_emb.jpg" alt="Drawing" style="width: 500px;" align="mid"/>
   
                                               
   <div align="center"><i>图3： Transformer嵌入层 </i></div>
   BERT源自Transformer，他们的嵌入层也是很相似的。上图是Transformer的嵌入层，我们可以看到，当一个分词变为输入时，首先要经过词嵌入（Embedding），变为长度为hidden_size的向量（图中为6，实际一般很长），字所在的位置也要进行嵌入。两个嵌入向量直接数值相加，我们就得到了嵌入层的输出。
   
   下图则是BERT的嵌入层，它的输出来源于三个嵌入向量相加。我们可以看到有两个区别。1：增加了句子嵌入（segment Embeddings），因为BERT的输入是两个句子，因此要在这里用句子嵌入值标识出是哪一句。2：出现了字符token，如E[cls],E[sep]。CLStoken一般用来统计全局的信息，最后可以用此token的特征进行下游分类任务。SEPtoken则表示句子的分割和中止。  其实还存在着第三个区别：Transformer的位置嵌入是固定的，也就是公式算出的值，而BERT的位置嵌入则是模型训练得到的。（其他嵌入也都是训练得到）。
   
   <img src="ipyphoto/BERT_emb.jpg" alt="Drawing" style="width: 700px;" align="mid"/>
   
                                             
           
  <div align="center"><i>图4： BERT嵌入层 </i></div>
   
   
   
   **二，多层BERT layer（BERT layers）**：
   
   这部分是BERT的核心部分，一般BERT系列的模型都会堆叠多个BERT层。首先需要知道的是，每个BERT layer ，都是不改变特征的维度的（如下图）。由于输入和输出维度一直相同，因此可以堆叠无数层。根据模型的大小，层数会有变化，BERT的base模型堆叠了12层，large模型堆叠了24层。而基础tinyBERT模型，只有4层。
   
   <img src="ipyphoto/BERT_layer.png" alt="Drawing" style="width: 700px;" align="mid"/>
   
<div align="center"><i>图5： 多层BERT layer </i></div>


   单个BERT layer是由多头自注意力层和MLP组成的。首先介绍多头自注意力层。自注意力层，有时候写作Transformer blocks，是进行特征交互，提取的关键部分。如下图所示。 图中的a1，a2与上图是对应的，指的就是一个token的特征。每一个token，都会分别经过三个不同的线性映射（也就是三个全连接），得到query,key和value(q,k,v)。然后对于每一个token，它的q会与其他token的k相乘得到一个权重，对应的 v 按照这些权重加起来，就得到了这个token的输出。
   
   <img src="ipyphoto/self_att.png" alt="Drawing" style="width: 700px;" align="mid"/>
   <div align="center"><i>图6： 自注意力机制 </i></div>
   
                                          
   一般在模型中使用的是多头注意力机制。通过上图，我们了解到了注意力机制是如何工作的，那么多头注意力机制其实非常简单，就是将一个长维度的向量，分发到多个head中，多个token的向量在对应的head内计算输出，最后合起来。 举例说明：如下图。 一个长为6的特征 被分到3个注意力头中。每个头中，仅需长度为6/3=2的特征计算自注意力，最后得到3个长度为2的特征，再拼在一起就得到了输出。维度和输入相同，每一个token都这样操作，因此能保证输入输出的维度不变。
      <img src="ipyphoto/multihead.png" alt="Drawing" style="width: 700px;" align="mid"/>
      <div align="center"><i> 图7： 多头自注意力 </i></div>
   
                                         
   向量经过多头注意力层编码后，会经过MLP层。mlp层是两次线性映射，首先通过一个全连接从长度$L$的向量变为$L*ratio$，然后再通过一个全连接从$L*ratio$ 到$L$。当然在BERT层中也加入了残差连接的结构。
   
   用一个实例的参数来回顾多头注意力层：对于一个BERT-base模型，他的特征长度L为768，注意力头的个数为12.这样每个头计算的特征长度就为64. 而ratio值为4.也就是先从768映射到到3072，再从3072到768.而他的输入token数量，最大为512。
   
   
   **三，输出层（pooler out）**：
   输出前面已经提到了，与输入是一样的。 如输入是$L_{token} *L_{emb}$ 那么输出依然是$L_{token} *L_{emb}$（$L_{token}$:token数量，$L_{emb}$：特征维度）。你可以认为输入和输出是一一对应的，也可以认为他们并不对应。BERT的工作其实到这里就结束了，我们要得到的就是一个编码器而已。 
   
   输出层又作 池化输出。也就是将$L_{token} *L_{emb}$ 的特征池化为$1*L_{emb}$。一般常用的是将第一个token的特征作为池化输出，当然也可以采取平均池化等方式得到。
   
   BERT的两个预训练任务，其实也表示了bert完成生成和分类两个下游任务的一般方式。 MLM任务： 可以取遮盖token对应的输出token的特征，通过一个分类，得到输出的词，这样可以做生成任务。SEP任务： 取第一个token的特征，也就是CLStoken 进行二分类。
   
   
   
   
这就是BERT的整体结构。

# tinyBERT
tinyBERT是缩小版的BERT，与BERT的基础模型BERT-base存在着结构上的差异。此外，与BERT系列其他模型，如BERT-small，BERT-large等，不同的是，tinyBERT并没有采取直接在大规模预训练数据上无监督训练，在下游数据集微调的方法。而是对BERT-base模型进行蒸馏学习，来获取优秀的性能。


## tinyBERT结构：
为了介绍tinyBERT的结构，我们先来读一个BERT的设置文档BERT config，一个config便可以决定一个BERT的结构。
     
     tinyBERT:
    {
      "hidden_size": 384,                      #决定token被编码的长度，即特征长度，$L_{emb}$ 

      "intermediate_size": 1536,                # MLP层第一次映射的长度，这里特征长度乘以4

      "max_position_embeddings": 512,                # 最大的输入长度。

      "model_type": "tiny_bert",                

      "num_attention_heads": 12,               # 注意力头个数

      "num_hidden_layers": 4,                 # 堆叠多少层

      "vocab_size": 30522                          # 训练词典个数，与训练语料有关

    }

    BERT-base:

    {
      "hidden_size": 768,           

      "intermediate_size": 3072,

      "max_position_embeddings": 512,

      "model_type": "bert",

      "num_attention_heads": 12,

      "num_hidden_layers": 12,

      "vocab_size": 30522

    }
BERT模型的结构主要由上面这些参数决定。其中上方是tinyBERT,下方是BERT-base。我们可以看到他们结构上的不同之处。参数的具体意思，可以参考上面的BERT结构。 
    在tinyBERT中，首先是词被编码的特征维度减少一半，变为384。对应的mlp层的映射层维度也减少一半（保持四倍）。 BERTlayer变为4层。注意力头个数不变，这样每个注意力头中的特征自然也会减少至一半。
    这样tinyBERT自然会参数减少很多。

## 蒸馏学习：

  参数量的降低，一般就会带来性能的降低。为了拥有良好的性能，tinyBERT并不像BERT家族那样在无标注上预训练，而是采取了蒸馏学习的方式进行训练。通过对BERT-base的蒸馏，得到了很好的性能。简单介绍一下蒸馏学习。 在一般的蒸馏学习中，有一个teacher模型和一个student模型。通过让student模型的输出去模仿teacher模型的输出，即让他们的输出靠的更近，来对student模型进行训练。
  
  与一般的蒸馏不同的是，tinyBERT是在模型的前向过程中进行多次蒸馏。而且在整个训练过程也进行多次蒸馏。
  
### 前向过程中的多次蒸馏：
    
   tinyBERT，对于n层的teacher bert，设计了一个mapping function ：n = g ( m )， 将student bert的第m层映射为原来的teacher的第n层，即让tinyBERT的第m层的输出去靠近teacher bert第n层的输出。 其实这个映射函数非常简单，就是$n = k*m$。k就是多少层当作tinyBERT的一层。当m=0时，对应的就是embedding layer。我们可以通过下图理解。图中仅为示例，tinyBERT每层的输出都去蒸馏学习Teacher net三层的输出，就是“一层顶三层”。
   
   <img src="ipyphoto/zhengliu.png" alt="Drawing" style="width: 700px;" align="mid"/>
     <div align="center"><i> 图8： 蒸馏对应层 </i></div>
   
   实际上的BERT-base有12层， 对于4层的tinyBERT，正好是三层对一层。 对于蒸馏学习，我们需要根据两个模型对应层的输出来计算loss，更新模型。从上图中，我们可以看到一共有四种loss，下面分别介绍。
   
   
-   **Embedding-layer distillation**

   <img src="ipyphoto/l_emb.png" alt="Drawing" style="width: 700px;" align="mid"/>
   
  这个是对embedding 矩阵的蒸馏loss，说是矩阵，其实是计算两个模型embedding输出的MSEloss。 而因为student的embedding层的特征维度和Teacher是不一样的，因此要乘上一个转换的映射矩阵$W_e$，此矩阵在模型中是一层全连接，在训练时学习。
   
   
   
   
-   **Attention based distillation and Hidden states based distillation**



  前文我们提到，BERT layer每一层包含两部分，一个是自注意力层，一个是MLP（也就是两层全连接）。attention指的是注意力层注意力分数矩阵，也就是对q和k乘算出来的那个值的蒸馏学习。 下面的hidden states层的蒸馏就是指的对MLP层的输出进行蒸馏学习。

 <img src="ipyphoto/l_hid.png" alt="Drawing" style="width: 700px;" align="mid"/>
   
   
   
   
-   **Prediction-layer distillation**


<img src="ipyphoto/l_pre.png" alt="Drawing" style="width: 700px;" align="mid"/>

  pred蒸馏，是最初的蒸馏方法。是对最终输出层输出结果的softmax进行蒸馏学习。T是蒸馏学习中的温度系数。对输出蒸馏时，采用的是带温度系数的交叉熵loss。
   
 
 
 ### 训练过程中的多次蒸馏：
 
 <img src="ipyphoto/train.png" alt="Drawing" style="width: 700px;" align="mid"/>
    <div align="center"><i> tinyBERT训练过程 </i></div>
     
 tinyBERT 并不是像其他蒸馏那样，直接根据成品的Teacher model在分类时蒸馏，而是去模型BERT的训练过程，在预训练和微调阶段都进行蒸馏。我们知道BERT模型使用时要经过预训练和下游任务微调。所以tinyBERT的蒸馏同样分为两步：General Distillation 与 Task-specific Distillation。前者是对BERT在大规模语料库进行预训练蒸馏学习，后者则是在特定的任务上进行蒸馏学习。值得注意的是，在预训练蒸馏阶段，使用的Teacher模型是仅仅经过预训练未微调的BERT，而在特定任务分类蒸馏训练阶段，使用的Teacher 模型是在特定任务上经过微调的BERT。  



## 总结：

这一部分介绍了tinyBERT结构和训练的过程。可以从参数设置部分看出tinyBERT是如何缩减参数量的，又可以从蒸馏学习部分看出tinyBERT是如何在参数量大大减少的同时保持优秀的性能。下面我们将通过具体的代码对tinyBERT的结构和训练过程进行更加详细的介绍。

# **训练数据准备**

BERT是在大规模无标注数据上进行预训练的，这里采用的是在wiki上下载文章进行预训练。 ms提供了处理这些数据文件的接口，但我们要把数据转为tfrecord格式或者是ms格式。 

https://mp.csdn.net/mp_blog/creation/editor/127061849


## 生成通用蒸馏阶段数据集，即wiki数据。

下载[wiki](http://t.zoukankan.com/dhName-p-11859318.html)数据集进行预训练，

使用[WikiExtractor](https://gitee.com/link?target=https%3A%2F%2Fgithub.com%2Fattardi%2Fwikiextractor)提取和整理数据集中的文本，使用步骤如下：

pip install wikiextractor
python -m wikiextractor.WikiExtractor [Wikipedia dump file] -o [output file path] -b 2G
下载[BERT](https://gitee.com/link?target=https%3A%2F%2Fgithub.com%2Fgoogle-research%2Fbert)代码仓，并下载模型文件[BERT-Base, Uncased](https://gitee.com/link?target=https%3A%2F%2Fstorage.googleapis.com%2Fbert_models%2F2018_10_18%2Funcased_L-12_H-768_A-12.zip)，其中包含了转化需要使用的vocab.txt, bert_config.json和预训练模型

使用create_pretraining_data.py文件，将下载得到的文件转化成tfrecord数据集，详细用法请参考bert的readme文件，其中input_file第2步会生成多个文本文件，请转化为bert0.tfrecord-bertx.tfrecord，如果出现AttributeError: module 'tokenization' has no attribute 'FullTokenizer’，请安装bert-tensorflow

将下载得到的tensorflow模型转化为mindspore模型，注意这个需要环境中同时存在tensorflow 和 mindspore。PATH为bert模型存放位置。

```python3
cd bert/ms2tf
python ms_and_tf_checkpoint_transfer_tools.py --tf_ckpt_path=PATH/model.ckpt　\
    --new_ckpt_path=PATH/ms_model_ckpt.ckpt　\
    --tarnsfer_option=tf2ms
```

## 生成下游任务蒸馏阶段数据集

下载数据集进行微调和评估，如GLUE，使用download_glue_data.py脚本下载SST2, MNLI, QNLI数据集等。后面的训练是以QNLI为例。

将数据集文件从JSON格式转换为TFRecord格式。使用通用蒸馏阶段的第三步BERT代码，在处理QNLI数据时，需要加入QNLI数据集的处理代码。参考readme使用代码仓中的run_classifier.py文件。  run_classifier.py代码中包含了训练，推理和预测的代码，对于转化tfrecord数据集来说，这部分代码是多余的，可以将这部分代码注释掉，只保留转化数据集的代码．其中task_name指定为QNLI，bert_config_file指定为通用蒸馏阶段下载得到的bert_config.json文件，max_seq_length为64. 更详细的下载转换方式见 [数据下载转换指南](https://mp.csdn.net/mp_blog/creation/editor/127061849)

~~~python3
...
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""

def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")),
                       "dev_matched")

def get_labels(self):
    """See base class."""
    return ["entailment", "not_entailment"]

def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
        if i == 0:
            continue
        guid = "%s-%s" % (set_type, line[0])
        text_a = line[1]
        text_b = line[2]
        label = line[-1]
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
...
"qnli": QnliProcessor,
...
~~~

# 代码部分
    
## 导入环境
    

In [1]:

import mindspore.nn as nn
from mindspore import context
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.parameter import Parameter
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import Tensor
from mindspore.train.callback import Callback
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal, initializer
import mindspore.dataset as ds
from mindspore.dataset import transforms
from mindspore.common import set_seed
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell

In [2]:
from enum import Enum
import math
import copy
import re
import os
import numpy as np
import ast
from pprint import pformat
import yaml
import argparse
import datetime

# 模型代码部分

这一部分，会定义所有使用到的模型代码。看代码时，可以对照上面对bert模型的结构介绍来看。


## BERT config
BERT的config是模型构建的依据所在。规定了模型的整体结构。

In [3]:
class BertConfig:
    """
    Configuration for `BertModel`.

    Args:
        seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): The shape of each embedding vector. Default: 32000.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
                           cell. Default: 12.
        num_attention_heads (int): Number of attention heads in the BertTransformer
                             encoder cell. Default: 12.
        intermediate_size (int): Size of intermediate layer in the BertTransformer
                           encoder cell. Default: 3072.
        hidden_act (str): Activation function used in the BertTransformer encoder
                    cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 16.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 seq_length=128,
                 vocab_size=32000,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32):
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.use_relative_positions = use_relative_positions
        self.dtype = dtype
        self.compute_type = compute_type

## BERT模型。
bert的整体模型可以对照图片和代码。

<img src="ipyphoto/BERT.png" alt="Drawing" style="width: 400px;" align="mid"/>

对应代码中是，
- EmbeddingPostprocessor 
- BertTransformer
- pooler out

但是ms官方的代码中，并没有严格按照这个来行。 输入的embedding和 pooler 都放在主体部分进行，而 位置和句子的embedding放在EmbeddingPostprocessor中进行。bert layers 在BertTransformer中。

## 位置和句子embedding

In [4]:
class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessors apply positional and token type embeddings to word embeddings.

    Args:
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
                         each embedding vector.
        use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
        token_type_vocab_size (int): Size of token type vocab. Default: 16.
       use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """
    def __init__(self,
                 use_relative_positions,
                 embedding_size,
                 embedding_shape,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 max_position_embeddings=512,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = use_token_type
        self.token_type_vocab_size = token_type_vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.token_type_embedding = nn.Embedding(
            vocab_size=token_type_vocab_size,
            embedding_size=embedding_size,
            use_one_hot=use_one_hot_embeddings)
        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.1, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.Gather()
        self.use_relative_positions = use_relative_positions
        self.slice = P.StridedSlice()
        _, seq, _ = self.shape
        self.full_position_embedding = nn.Embedding(
            vocab_size=max_position_embeddings,
            embedding_size=embedding_size,
            use_one_hot=False)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
        self.add = P.Add()

    def construct(self, token_type_ids, word_embeddings):
        """Postprocessors apply positional and token type embeddings to word embeddings."""
        output = word_embeddings
        if self.use_token_type:
            token_type_embeddings = self.token_type_embedding(token_type_ids)
            output = self.add(output, token_type_embeddings)
        if not self.use_relative_positions:
            position_embeddings = self.full_position_embedding(self.position_ids)
            output = self.add(output, position_embeddings)
        output = self.layernorm(output)
        output = self.dropout(output)
        return output


下面是一些关于相对位置和格式转换的代码。 我们在这里不用相对位置编码。

In [5]:

class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates matrix of relative positions between inputs.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
    """
    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1
        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        """position matrix generator"""
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)
        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)
        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates tensor of size [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maxmum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]))
        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = P.Shape()
        self.gather = P.Gather()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        """position embedding generation"""
        relative_positions_matrix_out = self.relative_positions_matrix()
        # Generate embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
    the danger that the value will overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """
    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = np.finfo(np_type).min
        max_type = np.finfo(np_type).max
        self.tensor_min_type = Tensor([min_type], dtype=src_type)
        self.tensor_max_type = Tensor([max_type], dtype=src_type)
        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        """saturate cast"""
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)

## BERT layer
bert layer（BertEncoderCell） 主要由  BertSelfAttention 与 BertOutput 即自注意力层和线性映射层组成。

而BertSelfAttention 主要包含 BertAttention 与 BertOutput 即自注意力计算和线性映射。


In [6]:
class BertOutput(nn.Cell):
    """
    Apply a linear computation to hidden status and a residual computation to input.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 initializer_range=0.02,
                 dropout_prob=0.1,
                 compute_type=mstype.float32):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.add = P.Add()
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
            self.compute_type = compute_type
        else:
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)

        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        """bert output"""
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        if self.is_gpu:
            output = self.cast(output, self.compute_type)
        return output

    
class BertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        from_tensor_width (int): Size of last dim of from_tensor.
        to_tensor_width (int): Size of last dim of to_tensor.
        from_seq_length (int): Length of from_tensor sequence.
        to_seq_length (int): Length of to_tensor sequence.
        num_attention_heads (int): Number of attention heads. Default: 1.
        size_per_head (int): Size of each attention head. Default: 512.
        query_act (str): Activation function for the query transform. Default: None.
        key_act (str): Activation function for the key transform. Default: None.
        value_act (str): Activation function for the value transform. Default: None.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
                             tensor. Default: False.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
    """
    def __init__(self,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 compute_type=mstype.float32):
        super(BertAttention, self).__init__()
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width,
                                    units,
                                    activation=query_act,
                                    weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width,
                                  units,
                                  activation=key_act,
                                  weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width,
                                    units,
                                    activation=value_act,
                                    weight_init=weight).to_float(compute_type)
        self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
        self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
        self.matmul = P.BatchMatMul()
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:
            self.shape_return = (-1, num_attention_heads * size_per_head)
        else:
            self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head)
        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=to_seq_length,
                                           depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        """bert attention"""
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)
        value_out = self.value_layer(to_tensor_2d)
        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)
        attention_scores = self.matmul_trans_b(query_layer, key_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_keys is [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          -1,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  -1,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t
        attention_scores = self.multiply(self.scores_mul, attention_scores)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))
            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)
        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_values is [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 -1,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    -1,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t
        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)
        return context_layer, attention_scores

class BertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        seq_length (int): Length of input sequence.
        hidden_size (int): Size of the bert encoder layers.
        num_attention_heads (int): Number of attention heads. Default: 12.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
    """
    def __init__(self,
                 seq_length,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))
        self.size_per_head = int(hidden_size / num_attention_heads)
        self.attention = BertAttention(
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            from_seq_length=seq_length,
            to_seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            do_return_2d_tensor=True,
            compute_type=compute_type)
        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        """bert self attention"""
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output, attention_scores


class BertEncoderCell(nn.Cell):
    """
    Encoder cells used in BertTransformer.

    Args:
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        seq_length (int): Length of input sequence. Default: 512.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of intermediate layer. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.02.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
    """
    def __init__(self,
                 hidden_size=768,
                 seq_length=512,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.02,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32):
        super(BertEncoderCell, self).__init__()
        self.attention = BertSelfAttention(
            hidden_size=hidden_size,
            seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            use_relative_positions=use_relative_positions,
            compute_type=compute_type)
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=hidden_act,
                                     weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.output = BertOutput(in_channels=intermediate_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type)
    def construct(self, hidden_states, attention_mask):
        """bert encoder cell"""
        # self-attention
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)
        # feed construct
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output, attention_scores


## BERT layers 

BERT layer的堆叠。

In [7]:
class BertTransformer(nn.Cell):
    """
    Multi-layer bert transformer.

    Args:
        hidden_size (int): Size of the encoder layers.
        seq_length (int): Length of input sequence.
        num_hidden_layers (int): Number of hidden layers in encoder cells.
        num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
        intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
    """
    def __init__(self,
                 hidden_size,
                 seq_length,
                 num_hidden_layers,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 return_all_encoders=False):
        super(BertTransformer, self).__init__()
        self.return_all_encoders = return_all_encoders
        layers = []
        for _ in range(num_hidden_layers):
            layer = BertEncoderCell(hidden_size=hidden_size,
                                    seq_length=seq_length,
                                    num_attention_heads=num_attention_heads,
                                    intermediate_size=intermediate_size,
                                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                                    use_one_hot_embeddings=use_one_hot_embeddings,
                                    initializer_range=initializer_range,
                                    hidden_dropout_prob=hidden_dropout_prob,
                                    use_relative_positions=use_relative_positions,
                                    hidden_act=hidden_act,
                                    compute_type=compute_type)
            layers.append(layer)
        self.layers = nn.CellList(layers)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)
        self.out_shape = (-1, seq_length, hidden_size)
    def construct(self, input_tensor, attention_mask):
        """bert transformer"""
        prev_output = self.reshape(input_tensor, self.shape)
        all_encoder_layers = ()
        all_encoder_atts = ()
        all_encoder_outputs = ()
        all_encoder_outputs += (prev_output,)
        for layer_module in self.layers:
            layer_output, encoder_att = layer_module(prev_output, attention_mask)
            prev_output = layer_output
            if self.return_all_encoders:
                all_encoder_outputs += (layer_output,)
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers += (layer_output,)
                all_encoder_atts += (encoder_att,)
        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers += (prev_output,)
        return all_encoder_layers, all_encoder_outputs, all_encoder_atts


可以选择在模型内或者外 创建attension mask。

In [8]:
class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for BertModel.
    """
    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask = None
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (-1, 1, config.seq_length)

    def construct(self, input_mask):
        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        return attention_mask


## BERT主模型.

上面规定了BERT的组件，这些组件将在BERT主模型中拼接起来形成BERT模型


In [9]:
class BertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False):
        super(BertModel, self).__init__()
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None
        self.last_idx = self.num_hidden_layers - 1
        output_embedding_shape = [-1, self.seq_length,
                                  self.embedding_size]
        self.bert_embedding_lookup = nn.Embedding(
            vocab_size=config.vocab_size,
            embedding_size=self.embedding_size,
            use_one_hot=use_one_hot_embeddings)
        self.bert_embedding_postprocessor = EmbeddingPostprocessor(
            use_relative_positions=config.use_relative_positions,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)
        self.bert_encoder = BertTransformer(
            hidden_size=self.hidden_size,
            seq_length=self.seq_length,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            use_relative_positions=config.use_relative_positions,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type,
            return_all_encoders=True)
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.squeeze_1 = P.Squeeze(axis=1)
        self.dense = nn.Dense(self.hidden_size, self.hidden_size,
                              activation="tanh",
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        """bert model"""
        # embedding
        embedding_tables = self.bert_embedding_lookup.embedding_table
        word_embeddings = self.bert_embedding_lookup(input_ids)
        embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings)
        # attention mask [batch_size, seq_length, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)
        # bert encoder
        encoder_output, encoder_layers, layer_atts = self.bert_encoder(self.cast_compute_type(embedding_output),
                                                                       attention_mask)
        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
        # pooler
        batch_size = P.Shape()(input_ids)[0]
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)
        encoder_outputs = ()
        for output in encoder_layers:
            encoder_outputs += (self.cast(output, self.dtype),)
        attention_outputs = ()
        for output in layer_atts:
            attention_outputs += (self.cast(output, self.dtype),)
        return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs


## BERT应用

BERT用于分类和NER两个任务的模型。

In [10]:
class BertModelCLS(nn.Cell):
    """
    This class is responsible for classification task evaluation,
    i.e. XNLI(num_labels=3), LCQMC(num_labels=2), Chnsenti(num_labels=2).
    The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0,
                 use_one_hot_embeddings=False, phase_type="student"):
        super(BertModelCLS, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.phase_type = phase_type
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.ReLU()

    def construct(self, input_ids, token_type_id, input_mask):
        """classification bert model"""
        _, pooled_output, _, seq_output, att_output = self.bert(input_ids, token_type_id, input_mask)
        cls = self.cast(pooled_output, self.dtype)
        cls = self.dropout(cls)
        logits = self.dense_1(cls)
        logits = self.cast(logits, self.dtype)
        log_probs = self.log_softmax(logits)
        if self._phase == 'train' or self.phase_type == "teacher":
            return seq_output, att_output, logits, log_probs
        return log_probs

class BertModelNER(nn.Cell):
    """
    This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
    The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
    """
    def __init__(self, config, is_training, num_labels=11, dropout_prob=0.0,
                 use_one_hot_embeddings=False, phase_type="student"):
        super(BertModelNER, self).__init__()
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.hidden_probs_dropout_prob = 0.0
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.ReLU()
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)
        self.origin_shape = (-1, config.seq_length, self.num_labels)
        self.phase_type = phase_type

    def construct(self, input_ids, input_mask, token_type_id):
        """Return the final logits as the results of log_softmax."""
        sequence_output, _, _, encoder_outputs, attention_outputs = \
            self.bert(input_ids, token_type_id, input_mask)
        seq = self.dropout(sequence_output)
        seq = self.reshape(seq, self.shape)
        logits = self.dense_1(seq)
        logits = self.cast(logits, self.dtype)
        return_value = self.log_softmax(logits)
        if self._phase == 'train' or self.phase_type == "teacher":
            return encoder_outputs, attention_outputs, logits, return_value
        return return_value


## tinyBERT 预训练蒸馏模型和下游蒸馏模型。

下面是tinyBERT蒸馏时用到的几个模型。 tinyBERT本身可以使用BERT的模型，导入tinyBERT的config即可。

注意在ms中，可以将梯度计算和梯度传递写进模型中，下面有几个模型就是进行这一项工作的。

先规定一些处理梯度的杂项。

In [11]:
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad

grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)

class ClipGradients(nn.Cell):
    """
    Clip gradients.

    Args:
        grads (list): List of gradient tuples.
        clip_type (Tensor): The way to clip, 'value' or 'norm'.
        clip_value (Tensor): Specifies how much to clip.

    Returns:
        List, a list of clipped_grad tuples.
    """
    def __init__(self):
        super(ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.cast = P.Cast()
        self.dtype = P.DType()

    def construct(self,
                  grads,
                  clip_type,
                  clip_value):
        """clip gradients"""
        if clip_type not in (0, 1):
            return grads
        new_grads = ()
        for grad in grads:
            dt = self.dtype(grad)
            if clip_type == 0:
                t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
                                    self.cast(F.tuple_to_array((clip_value,)), dt))
            else:
                t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
            new_grads = new_grads + (t,)
        return new_grads

class SoftCrossEntropy(nn.Cell):
    """SoftCrossEntropy loss"""
    def __init__(self):
        super(SoftCrossEntropy, self).__init__()
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.softmax = P.Softmax(axis=-1)
        self.reduce_mean = P.ReduceMean()
        self.cast = P.Cast()

    def construct(self, predicts, targets):
        likelihood = self.log_softmax(predicts)
        target_prob = self.softmax(targets)
        loss = self.reduce_mean(-target_prob * likelihood)

        return self.cast(loss, mstype.float32)

##  预训练蒸馏模型
下面是预训练蒸馏时所用到的计算loss的模型。 注意层级之间的对应关系。

In [12]:
class BertNetworkWithLoss_gd(nn.Cell):
    """
    Provide bert pre-training loss through network.
    Args:
        config (BertConfig): The config of BertModel.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, teacher_config, teacher_ckpt, student_config, is_training, use_one_hot_embeddings=False,
                 is_att_fit=True, is_rep_fit=True):
        super(BertNetworkWithLoss_gd, self).__init__()
        # load teacher model
        self.teacher = BertModel(teacher_config, False, use_one_hot_embeddings)
        param_dict = load_checkpoint(teacher_ckpt)
        new_param_dict = {}
        for key, value in param_dict.items():
            new_key = re.sub('^bert.bert.', 'teacher.', key)
            new_param_dict[new_key] = value
        load_param_into_net(self.teacher, new_param_dict)
        # no_grad
        self.teacher.set_train(False)
        params = self.teacher.trainable_params()
        for param in params:
            param.requires_grad = False
        # student model
        self.bert = BertModel(student_config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.fit_dense = nn.Dense(student_config.hidden_size,
                                  teacher_config.hidden_size).to_float(teacher_config.compute_type)
        self.teacher_layers_num = teacher_config.num_hidden_layers
        self.student_layers_num = student_config.num_hidden_layers
        self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
        self.is_att_fit = is_att_fit
        self.is_rep_fit = is_rep_fit
        self.loss_mse = nn.MSELoss()
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.dtype = teacher_config.dtype

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id):
        """general distill network with loss"""
        # teacher model
        _, _, _, teacher_seq_output, teacher_att_output = self.teacher(input_ids, token_type_id, input_mask)
        # student model
        _, _, _, student_seq_output, student_att_output = self.bert(input_ids, token_type_id, input_mask)
        total_loss = 0
        if self.is_att_fit:
            selected_teacher_att_output = ()
            selected_student_att_output = ()
            for i in range(self.student_layers_num):
                selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],)
                selected_student_att_output += (student_att_output[i],)
            att_loss = 0
            for i in range(self.student_layers_num):
                student_att = selected_student_att_output[i]
                teacher_att = selected_teacher_att_output[i]
                student_att = self.select(student_att <= self.cast(-100.0, mstype.float32), self.zeroslike(student_att),
                                          student_att)
                teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32), self.zeroslike(teacher_att),
                                          teacher_att)
                att_loss += self.loss_mse(student_att, teacher_att)
            total_loss += att_loss
        if self.is_rep_fit:
            selected_teacher_seq_output = ()
            selected_student_seq_output = ()
            for i in range(self.student_layers_num + 1):
                selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],)
                fit_dense_out = self.fit_dense(student_seq_output[i])
                fit_dense_out = self.cast(fit_dense_out, self.dtype)
                selected_student_seq_output += (fit_dense_out,)
            rep_loss = 0
            for i in range(self.student_layers_num + 1):
                teacher_rep = selected_teacher_seq_output[i]
                student_rep = selected_student_seq_output[i]
                rep_loss += self.loss_mse(student_rep, teacher_rep)
            total_loss += rep_loss
        return self.cast(total_loss, mstype.float32)

class BertTrainWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        init = F.depend(init, loss)
        clear_status = self.clear_status(init)
        scaling_sens = F.depend(scaling_sens, clear_status)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if not overflow:
            self.optimizer(grads)
        return (loss, cond, scaling_sens)

class BertTrainCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(BertTrainCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.sens = sens
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        self.optimizer(grads)
        return loss


## 下游蒸馏模型

下面是下游蒸馏时所用到的计算loss的模型。

BertEvaluationCell中规定了梯度的回传。

In [13]:
class BertNetworkWithLoss_td(nn.Cell):
    """
    Provide bert pre-training loss through network.
    Args:
        config (BertConfig): The config of BertModel.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, teacher_config, teacher_ckpt, student_config, student_ckpt,
                 is_training, task_type, num_labels, use_one_hot_embeddings=False,
                 is_predistill=True, is_att_fit=True, is_rep_fit=True,
                 temperature=1.0, dropout_prob=0.1):
        super(BertNetworkWithLoss_td, self).__init__()
        # load teacher model
        if task_type == "classification":
            self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
                                        use_one_hot_embeddings, "teacher")
            self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
                                     use_one_hot_embeddings, "student")
        elif task_type == "ner":
            self.teacher = BertModelNER(teacher_config, False, num_labels, dropout_prob,
                                        use_one_hot_embeddings, "teacher")
            self.bert = BertModelNER(student_config, is_training, num_labels, dropout_prob,
                                     use_one_hot_embeddings, "student")
        else:
            raise ValueError(f"Not support task type: {task_type}")
        # param_dict = load_checkpoint(teacher_ckpt)
        # new_param_dict = {}
        # for key, value in param_dict.items():
        #     new_key = re.sub('^bert.', 'teacher.', key)
        #     new_param_dict[new_key] = value
        # load_param_into_net(self.teacher, new_param_dict)

        # no_grad
        self.teacher.set_train(False)
        # params = self.teacher.trainable_params()
        # for param in params:
        #     param.requires_grad = False
        # # load student model
        # param_dict = load_checkpoint(student_ckpt)
        # if is_predistill:
        #     new_param_dict = {}
        #     for key, value in param_dict.items():
        #         new_key = re.sub('tinybert_', 'bert_', 'bert.' + key)
        #         new_param_dict[new_key] = value
        #     load_param_into_net(self.bert, new_param_dict)
        # else:
        #     new_param_dict = {}
        #     for key, value in param_dict.items():
        #         new_key = re.sub('tinybert_', 'bert_', key)
        #         new_param_dict[new_key] = value
        #     load_param_into_net(self.bert, new_param_dict)
        self.cast = P.Cast()
        self.fit_dense = nn.Dense(student_config.hidden_size,
                                  teacher_config.hidden_size).to_float(teacher_config.compute_type)
        self.teacher_layers_num = teacher_config.num_hidden_layers
        self.student_layers_num = student_config.num_hidden_layers
        self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
        self.is_predistill = is_predistill
        self.is_att_fit = is_att_fit
        self.is_rep_fit = is_rep_fit
        self.use_soft_cross_entropy = task_type in ["classification", "ner"]
        self.temperature = temperature
        self.loss_mse = nn.MSELoss()
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.dtype = student_config.dtype
        self.num_labels = num_labels
        self.dtype = teacher_config.dtype
        self.soft_cross_entropy = SoftCrossEntropy()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids):
        """task distill network with loss"""
        # teacher model
        teacher_seq_output, teacher_att_output, teacher_logits, _ = self.teacher(input_ids, token_type_id, input_mask)
        # student model
        student_seq_output, student_att_output, student_logits, _ = self.bert(input_ids, token_type_id, input_mask)
        total_loss = 0
        if self.is_predistill:
            if self.is_att_fit:
                selected_teacher_att_output = ()
                selected_student_att_output = ()
                for i in range(self.student_layers_num):
                    selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],)
                    selected_student_att_output += (student_att_output[i],)
                att_loss = 0
                for i in range(self.student_layers_num):
                    student_att = selected_student_att_output[i]
                    teacher_att = selected_teacher_att_output[i]
                    student_att = self.select(student_att <= self.cast(-100.0, mstype.float32),
                                              self.zeroslike(student_att),
                                              student_att)
                    teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32),
                                              self.zeroslike(teacher_att),
                                              teacher_att)
                    att_loss += self.loss_mse(student_att, teacher_att)
                total_loss += att_loss
            if self.is_rep_fit:
                selected_teacher_seq_output = ()
                selected_student_seq_output = ()
                for i in range(self.student_layers_num + 1):
                    selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],)
                    fit_dense_out = self.fit_dense(student_seq_output[i])
                    fit_dense_out = self.cast(fit_dense_out, self.dtype)
                    selected_student_seq_output += (fit_dense_out,)
                rep_loss = 0
                for i in range(self.student_layers_num + 1):
                    teacher_rep = selected_teacher_seq_output[i]
                    student_rep = selected_student_seq_output[i]
                    rep_loss += self.loss_mse(student_rep, teacher_rep)
                total_loss += rep_loss
        else:
            if self.use_soft_cross_entropy:
                cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
            else:
                cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])
            total_loss += cls_loss
        return self.cast(total_loss, mstype.float32)

class BertEvaluationCell(nn.Cell):
    """
    Especially defined for finetuning where only four inputs tensor are needed.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(BertEvaluationCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.sens = sens
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        self.optimizer(grads)
        return loss

class BertEvaluationWithLossScaleCell(nn.Cell):
    """
    Especially defined for finetuning where only four inputs tensor are needed.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        init = F.depend(init, loss)
        clear_status = self.clear_status(init)
        scaling_sens = F.depend(scaling_sens, clear_status)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if not overflow:
            self.optimizer(grads)
        return (loss, cond, scaling_sens)

# 数据集

  创建tinybert训练所需的dataset。

In [14]:
##########################data_############################
class DataType(Enum):
    """Enumerate supported dataset format"""
    TFRECORD = 1
    MINDRECORD = 2

def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
                            do_shuffle="true", data_dir=None, schema_dir=None,
                            data_type=DataType.TFRECORD):
    """create tinybert dataset"""
    files = os.listdir(data_dir)
    data_files = []
    for file_name in files:
        if "record" in file_name and "db" not in file_name:
            data_files.append(os.path.join(data_dir, file_name))
    if task == "td":
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    else:
        columns_list = ["input_ids", "input_mask", "segment_ids"]

    shard_equal_rows = True
    shuffle = (do_shuffle == "true")
    if device_num == 1:
        shard_equal_rows = False
        shuffle = False

    if data_type == DataType.MINDRECORD:
        data_set = ds.MindDataset(data_files, columns_list=columns_list,
                                  shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
    else:
        data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=columns_list,
                                      shuffle=shuffle, num_shards=device_num, shard_id=rank,
                                      shard_equal_rows=shard_equal_rows)
    if device_num == 1 and shuffle is True:
        data_set = data_set.shuffle(10000)

    type_cast_op = transforms.TypeCast(mstype.int32)
    data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids")
    data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
    data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
    if task == "td":
        data_set = data_set.map(operations=type_cast_op, input_columns="label_ids")
    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    return data_set

# 训练部分。

由于大部分关于训练的部分都已经写在模型之中，所以这一部分只简单的规定了学习率优化器，评价指标， 模型保存等内容。


In [15]:
class Accuracy():
    """Accuracy"""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)

class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr


class LossCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss in NAN or INF terminating training.
    Note:
        if per_print_times is 0 do not print loss.
    Args:
        per_print_times (int): Print loss every times. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        """step end and print loss"""
        cb_params = run_context.original_args()
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num,
                                                           cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))

class ModelSaveCkpt(Callback):
    """
    Saves checkpoint.
    If the loss in NAN or INF terminating training.
    Args:
        network (Network): The train network for training.
        save_ckpt_num (int): The number to save checkpoint, default is 1000.
        max_ckpt_num (int): The max checkpoint number, default is 3.
    """
    def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir):
        super(ModelSaveCkpt, self).__init__()
        self.count = 0
        self.network = network
        self.save_ckpt_step = save_ckpt_step
        self.max_ckpt_num = max_ckpt_num
        self.output_dir = output_dir

    def step_end(self, run_context):
        """step end and save ckpt"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.save_ckpt_step == 0:
            saved_ckpt_num = cb_params.cur_step_num / self.save_ckpt_step
            if saved_ckpt_num > self.max_ckpt_num:
                oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num
                # path = os.path.join(self.output_dir, "tiny_bert_{}_{}.ckpt".format(int(oldest_ckpt_index),
                #                                                                    self.save_ckpt_step))
                path = os.path.join(self.output_dir, "tiny_bert_wiki.ckpt".format(int(oldest_ckpt_index),
                                                                   self.save_ckpt_step))
                if os.path.exists(path):
                    os.remove(path)
            # save_checkpoint(self.network, os.path.join(self.output_dir,
            #                                            "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
            #                                                                          self.save_ckpt_step)))
            save_checkpoint(self.network, os.path.join(self.output_dir,
                                                       "tiny_bert_wiki.ckpt".format(int(saved_ckpt_num),
                                                                                     self.save_ckpt_step)))

class EvalCallBack(Callback):
    """Evaluation callback"""
    def __init__(self, network, dataset):
        super(EvalCallBack, self).__init__()
        self.network = network
        self.global_acc = 0.0
        self.dataset = dataset

    def epoch_end(self, run_context):
        """step end and do evaluation"""
        callback = Accuracy()
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in self.dataset.create_dict_iterator(num_epochs=1):
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, input_mask, token_type_id, label_ids = input_data
            self.network.set_train(False)
            logits = self.network(input_ids, token_type_id, input_mask)
            self.network.set_train(True)
            callback.update(logits, label_ids)
        acc = callback.acc_num / callback.total_num
        with open("./eval.log", "a+") as f:
            f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num,
                                                                     callback.acc_num / callback.total_num))
            f.write('\n')

        if acc > self.global_acc:
            self.global_acc = acc
            print("The best acc is {}".format(acc))
            eval_model_ckpt_file = "eval_model.ckpt"
            if os.path.exists(eval_model_ckpt_file):
                os.remove(eval_model_ckpt_file)

# Config 设置

config 是代码运行的基础。这里通过手动设置和yaml文件读取两个方式共同创建config。

首先定义从文件读取config的函数。
- Config ： 定义config
- parse_yaml ： 读取yaml文件
- parse_cli_to_yaml： 把yaml文件加入argparse中
- extra_operations : 将总config分离成几个负责单独部分的config。

In [16]:
class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="pretrain_base_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        try:
            if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
                help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
                choice = choices[item] if item in choices else None
                if isinstance(cfg[item], bool):
                    parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                        help=help_description)
                else:
                    parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                        help=help_description)
        except:
            pass
    args = parser.parse_args(args=[])
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            # print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg

def extra_operations(cfg):
    """
    Do extra work on config

    Args:
        config: Object after instantiation of class 'Config'.
    """
    def create_filter_fun(keywords):
        return lambda x: not (True in [key in x.name.lower() for key in keywords])

    if cfg.description == 'general_distill':
        cfg.common_cfg.loss_scale_value = 2 ** 16
        cfg.common_cfg.AdamWeightDecay.decay_filter = create_filter_fun(cfg.common_cfg.AdamWeightDecay.decay_filter)
        cfg.bert_teacher_net_cfg.dtype = mstype.float32
        cfg.bert_teacher_net_cfg.compute_type = mstype.float16
        cfg.bert_student_net_cfg.dtype = mstype.float32
        cfg.bert_student_net_cfg.compute_type = mstype.float16
        cfg.bert_teacher_net_cfg = BertConfig(**cfg.bert_teacher_net_cfg.__dict__)
        cfg.bert_student_net_cfg = BertConfig(**cfg.bert_student_net_cfg.__dict__)
    elif cfg.description == 'task_distill':
        cfg.phase1_cfg.loss_scale_value = 2 ** 8
        cfg.phase1_cfg.optimizer_cfg.AdamWeightDecay.decay_filter = create_filter_fun(
            cfg.phase1_cfg.optimizer_cfg.AdamWeightDecay.decay_filter)
        cfg.phase2_cfg.loss_scale_value = 2 ** 16
        cfg.phase2_cfg.optimizer_cfg.AdamWeightDecay.decay_filter = create_filter_fun(
            cfg.phase2_cfg.optimizer_cfg.AdamWeightDecay.decay_filter)
        cfg.td_teacher_net_cfg.dtype = mstype.float32
        cfg.td_teacher_net_cfg.compute_type = mstype.float16
        cfg.td_student_net_cfg.dtype = mstype.float32
        cfg.td_student_net_cfg.compute_type = mstype.float16
        cfg.td_teacher_net_cfg = BertConfig(**cfg.td_teacher_net_cfg.__dict__)
        cfg.td_student_net_cfg = BertConfig(**cfg.td_student_net_cfg.__dict__)
    else:
        pass
    return cfg


# 预训练蒸馏。
定义好了模型，数据集，config和一些杂项，我们便可以开始在wiki数据集上的预训练蒸馏任务。


## 预训练蒸馏config
config 是代码运行的基础。这里通过手动设置和yaml文件读取两个方式共同创建config。 
其中 与模型和训练相关的大部分config都在 yaml文件中。

In [17]:
import argparse

parser_gen = argparse.ArgumentParser(description="default name", add_help=False)

parser_gen.add_argument("--distribute", default="False",
                    help="if distribute")

parser_gen.add_argument("--device_target", default="Ascend",
                    help="device_target")

parser_gen.add_argument("--epoch_size", default=3,
                    help="epoch_size")


parser_gen.add_argument("--save_ckpt_step", default=1,
                    help="save_ckpt_step")

parser_gen.add_argument("--max_ckpt_num", default=1,
                    help="max_ckpt_num")

parser_gen.add_argument("--save_ckpt_path", default="save/tinybert_wiki",
                    help="save_ckpt_path")

parser_gen.add_argument("--data_dir", default="data/wiki",
                    help="data_dir")

parser_gen.add_argument("--load_teacher_ckpt_path", default="bert/ms_model_ckpt.ckpt",
                    help="load_teacher_ckpt_path")

parser_gen.add_argument("--dataset_type", default="tfrecord",
                    help="dataset_type")

config_gen_path = 'config/gd_config.yaml'
default_gen, helper_gen, choices_gen = parse_yaml(config_gen_path)
args_gen = parse_cli_to_yaml(parser=parser_gen, cfg=default_gen, helper=helper_gen, choices=choices_gen, cfg_path=config_gen_path)

final_config_gen = merge(args_gen, default_gen)
config_obj_gen = Config(final_config_gen)

config = extra_operations(config_obj_gen)


common_cfg = config.common_cfg
bert_teacher_net_cfg = config.bert_teacher_net_cfg
bert_student_net_cfg = config.bert_student_net_cfg

args_opt = config
print(config)

{'bert_student_net_cfg': <__main__.BertConfig object at 0xffff26024d90>,
 'bert_teacher_net_cfg': <__main__.BertConfig object at 0xffff24e84490>,
 'checkpoint_url': '',
 'common_cfg': {'AdamWeightDecay': {'decay_filter': <function extra_operations.<locals>.create_filter_fun.<locals>.<lambda> at 0xffff24e16170>,
 'end_learning_rate': 1e-14,
 'eps': 1e-06,
 'learning_rate': 5e-05,
 'power': 1.0,
 'weight_decay': 0.0001},
 'batch_size': 32,
 'loss_scale_value': 65536,
 'scale_factor': 2,
 'scale_window': 1000},
 'data_dir': 'data/wiki',
 'data_path': '/cache/data',
 'data_sink_steps': 1,
 'data_url': '',
 'dataset_type': 'tfrecord',
 'description': 'general_distill',
 'device_id': 0,
 'device_num': 1,
 'device_target': 'Ascend',
 'distribute': 'False',
 'do_shuffle': 'true',
 'enable_data_sink': 'true',
 'enable_modelarts': False,
 'enable_profiling': False,
 'epoch_size': 3,
 'folder_name_under_zip_file': './',
 'load_path': '/cache/checkpoint_path',
 'load_teacher_ckpt_path': 'bert/ms_m

## 环境设置。

context 是用来存储训练时的环境变量的。 这里定义了存储文件夹和训练时的一些基础环境。具体含义见[ms文档](https://www.mindspore.cn/docs/zh-CN/r1.8/index.html)， 

在这里不启用分布式训练。在Ascend 环境中，也不使用混合精度训练。



In [18]:

set_seed(0)

context.set_context(mode=context.PYNATIVE_MODE, device_target=args_opt.device_target,
                    reserve_class_name_in_scope=False)


# save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
#                              datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
save_ckpt_dir = args_opt.save_ckpt_path

if not os.path.exists(save_ckpt_dir):
    os.makedirs(save_ckpt_dir)
    
if args_opt.distribute == "true":
    if args_opt.device_target == 'Ascend':
        D.init()
        device_num = args_opt.device_num
        rank = args_opt.device_id % device_num
    else:
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
    save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                      device_num=device_num)
else:
    rank = 0
    device_num = 1
    
enable_loss_scale = True 

if args_opt.device_target == "Ascend":
    context.set_context(device_id=args_opt.device_id)

## 读取数据

这里数据格式为TFRECORD， 利用上面定义的函数创建数据集

In [19]:
 
if args_opt.dataset_type == "tfrecord":
    dataset_type = DataType.TFRECORD
elif args_opt.dataset_type == "mindrecord":
    dataset_type = DataType.MINDRECORD
else:
    raise Exception("dataset format is not supported yet")
    
dataset = create_tinybert_dataset('gd', common_cfg.batch_size, device_num, rank,
                                  args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
                                  data_type=dataset_type)
dataset_size = dataset.get_dataset_size()
print('dataset size: ', dataset_size)
print("dataset repeatcount: ", dataset.get_repeat_count())


repeat_count = args_opt.epoch_size
time_monitor_steps = dataset_size

dataset size:  251
dataset repeatcount:  1


## 定义模型
  
我们使用刚才定义的用于general distill的模型

In [20]:
netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                     teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                     student_config=bert_student_net_cfg,
                                     is_training=True, use_one_hot_embeddings=False)

## 定义学习率和优化器

学习率使用上面定义的BERT的学习率函数，优化器用adam。

netwithgrads 模型中已经定义好了梯度计算与回传，将学习率与优化器传入即可。


ModelSaveCkpt 定义了模型的保存方式。 注意因为这里只是预训练示例，因此训练数据非常小。 示例中只能把保存的step：<font color=black size=2 face=雅黑>**save_ckpt_step**</font> 调整的非常小。这里定义为1. 实际模型中可以考虑设置为100 或者 200等。这里为了简化， 将保存的模型统一命名为 <font color=black size=2 face=雅黑>**tiny_bert_wiki.ckpt**</font>. 方便后续使用。

In [21]:
lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
                               end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
                               warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
                               decay_steps=int(dataset_size * args_opt.epoch_size),
                               power=common_cfg.AdamWeightDecay.power)
params = netwithloss.trainable_params()
decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
                {'params': other_params, 'weight_decay': 0.0},
                {'order_params': params}]

optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)

callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                           args_opt.save_ckpt_step,
                                                                           args_opt.max_ckpt_num,
                                                                           save_ckpt_dir)]

In [22]:
if enable_loss_scale:
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                             scale_factor=common_cfg.scale_factor,
                                             scale_window=common_cfg.scale_window)
    netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    print('enable_loss_scale=True')
else:
    netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)

enable_loss_scale=True


## 模型训练

mindspore中，[Model](https://mindspore.cn/docs/zh-CN/r1.7/api_python/mindspore/mindspore.Model.html#mindspore.Model)是MindSpore提供的高阶API，可以进行模型训练、评估和推理。 

具体见[Model详情页](https://www.mindspore.cn/tutorials/zh-CN/r1.7/advanced/train/model.html)

我们可以把上面定义的模型，训练设置等传入Model，Model即可自动训练。


In [23]:
model = Model(netwithgrads)
model.train(repeat_count, dataset, callbacks=callback,
            dataset_sink_mode=(args_opt.enable_data_sink == "true"),
            sink_size=args_opt.data_sink_steps)



epoch: 1, step: 1, outputs are (Tensor(shape=[], dtype=Float32, value= 52.5542), Tensor(shape=[], dtype=Bool, value= False), Parameter (name=loss_scale, shape=(), dtype=Float32, requires_grad=True))
Train epoch time: 48191.528 ms, per step time: 48191.528 ms




epoch: 2, step: 2, outputs are (Tensor(shape=[], dtype=Float32, value= 48.2302), Tensor(shape=[], dtype=Bool, value= False), Parameter (name=loss_scale, shape=(), dtype=Float32, requires_grad=True))
Train epoch time: 3890.811 ms, per step time: 3890.811 ms
epoch: 3, step: 3, outputs are (Tensor(shape=[], dtype=Float32, value= 49.1845), Tensor(shape=[], dtype=Bool, value= False), Parameter (name=loss_scale, shape=(), dtype=Float32, requires_grad=True))
Train epoch time: 858.037 ms, per step time: 858.037 ms


# 下游任务蒸馏。

下面我们开始以QNLI数据集为例，进行任务蒸馏。

QNLI是从另一个数据集The Stanford Question Answering Dataset(斯坦福问答数据集, SQuAD 1.0）转换而来的。SQuAD 1.0是有一个问题-段落对组成的问答数据集，其中段落来自维基百科，段落中的一个句子包含问题的答案。

QNLI目标是判断问题（question）和句子（sentence，维基百科段落中的一句）是否蕴含，蕴含和不蕴含，二分类。我们可以简单的当作分类任务即可。

## 任务蒸馏config
config 是代码运行的基础。这里通过手动设置和yaml文件读取两个方式共同创建config。 
其中 与模型和训练相关的大部分config都在 yaml文件中。

注意在代码中，任务蒸馏分成了两个阶段。 而两个阶段仅仅是训练超参不同，然后仅在第二阶段进行对验证集的预测。 所以config也要分开为两个。 这一点， 可以在自己训练时根据自己喜好，只进行一个阶段也是可以的。



In [24]:
###################task_config####################


parser_task = argparse.ArgumentParser(description="default name", add_help=False)

parser_task.add_argument("--do_train", default="True",
                    help="do_train")

parser_task.add_argument("--do_eval", default="True",
                    help="do_eval")

parser_task.add_argument("--device_target", default="Ascend",
                    help="device_target")

parser_task.add_argument("--device_id", default=0,
                    help="device_id")

parser_task.add_argument("--td_phase1_epoch_size", default=1,
                    help="td_phase1_epoch_size")

parser_task.add_argument("--td_phase2_epoch_size", default=3,
                    help="td_phase2_epoch_size")

parser_task.add_argument("--do_shuffle", default="true",
                    help="do_shuffle")

parser_task.add_argument("--max_ckpt_num", default=1,
                    help="max_ckpt_num")

parser_task.add_argument("--load_teacher_ckpt_path", default="bert/ms_model_ckpt.ckpt",
                    help="load_teacher_ckpt_path")

parser_task.add_argument("--load_gd_ckpt_path", default="save/tinybert_wiki/tiny_bert_wiki.ckpt",
                    help="load_gd_ckpt_path")

parser_task.add_argument("--load_td1_ckpt_path", default="",
                    help="load_td1_ckpt_path")

parser_task.add_argument("--train_data_dir", default="data/glue/qnli",
                    help="train_data_dir")

parser_task.add_argument("--eval_data_dir", default="data/glue/qnli",
                    help="eval_data_dir")

parser_task.add_argument("--dataset_type", default="tfrecord",
                    help="dataset_type")

parser_task.add_argument("--task_type", default="classification",
                    help="task_type")

parser_task.add_argument("--task_name", default="QNLI",
                    help="task_name")

parser_task.add_argument("--assessment_method", default="accuracy",
                    help="assessment_method")


config_task_path = 'config/td_config_qnli.yaml'
default_task, helper_task, choices_task = parse_yaml(config_task_path)
args_task = parse_cli_to_yaml(parser=parser_task, cfg=default_task, helper=helper_task, choices=choices_task, cfg_path=config_task_path)
final_config_task = merge(args_task, default_task)
config_obj_task = Config(final_config_task)
config_task = extra_operations(config_obj_task)
config = config_task

phase1_cfg = config.phase1_cfg
phase2_cfg = config.phase2_cfg
eval_cfg = config.eval_cfg
td_teacher_net_cfg = config.td_teacher_net_cfg
td_student_net_cfg = config.td_student_net_cfg


print(config)
args_opt = config

_cur_dir = os.getcwd()
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'save/tinybert_td_phase1_save_ckpt')
td_phase2_save_ckpt_dir = os.path.join(_cur_dir, 'save/tinybert_td_phase2_save_ckpt')
if not os.path.exists(td_phase1_save_ckpt_dir):
    os.makedirs(td_phase1_save_ckpt_dir)
if not os.path.exists(td_phase2_save_ckpt_dir):
    os.makedirs(td_phase2_save_ckpt_dir)
    
enable_loss_scale = False
set_seed(123)
ds.config.set_seed(12345)
dataset_type = DataType.TFRECORD
cfg = phase1_cfg

{'assessment_method': 'accuracy',
 'checkpoint_url': '',
 'ckpt_file': '',
 'data_path': '/cache/data',
 'data_sink_steps': 1,
 'data_url': '',
 'dataset_type': 'tfrecord',
 'description': 'task_distill',
 'device_id': 0,
 'device_target': 'Ascend',
 'do_eval': 'True',
 'do_shuffle': 'true',
 'do_train': 'True',
 'enable_data_sink': 'true',
 'enable_modelarts': False,
 'enable_profiling': False,
 'eval_cfg': {'batch_size': 32},
 'eval_data_dir': 'data/glue/qnli',
 'file_format': 'MINDIR',
 'file_name': 'tinybert',
 'folder_name_under_zip_file': '',
 'load_gd_ckpt_path': 'save/tinybert_wiki/tiny_bert_wiki.ckpt',
 'load_path': '/cache/checkpoint_path',
 'load_td1_ckpt_path': '',
 'load_teacher_ckpt_path': 'bert/ms_model_ckpt.ckpt',
 'max_ckpt_num': 1,
 'modelarts_dataset_unzip_name': '',
 'num_labels': 2,
 'onnx_path': '',
 'output_path': '/cache/train',
 'phase1_cfg': {'batch_size': 32,
 'loss_scale_value': 256,
 'optimizer_cfg': {'AdamWeightDecay': {'decay_filter': <function extra_oper

## 第一阶段

### 环境 数据集 与模型设置

  与预训练蒸馏相似。

In [25]:
rank = 0
device_num = 1

dataset = create_tinybert_dataset('td', cfg.batch_size,
                                  device_num, rank, args_opt.do_shuffle,
                                  args_opt.train_data_dir, args_opt.schema_dir,
                                  data_type=dataset_type)

dataset_size = dataset.get_dataset_size()
print('td1 dataset size: ', dataset_size)
print('td1 dataset repeatcount: ', dataset.get_repeat_count())
args_opt.data_sink_steps = dataset_size
repeat_count = args_opt.td_phase1_epoch_size
time_monitor_steps = dataset_size

load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = args_opt.load_gd_ckpt_path
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                     student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                     is_training=True, task_type=args_opt.task_type,
                                     num_labels=args_opt.num_labels, is_predistill=True)

td1 dataset size:  3443
td1 dataset repeatcount:  1


### 设置学习率与优化器

 ModelSaveCkpt 规定了保存的位置。

In [26]:
optimizer_cfg = cfg.optimizer_cfg

lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                               end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                               warmup_steps=int(dataset_size / 10),
                               decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                               power=optimizer_cfg.AdamWeightDecay.power)





params = netwithloss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                {'params': other_params, 'weight_decay': 0.0},
                {'order_params': params}]

optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)

callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                           dataset_size,
                                                                           args_opt.max_ckpt_num,
                                                                           td_phase1_save_ckpt_dir)]


### 载入梯度模型并训练。

In [27]:
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)

model = Model(netwithgrads)
model.train(repeat_count, dataset, callbacks=callback,
            dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
            sink_size=args_opt.data_sink_steps)



epoch: 1, step: 1, outputs are 5.313729
epoch: 1, step: 2, outputs are 5.312618
epoch: 1, step: 3, outputs are 5.2948303
epoch: 1, step: 4, outputs are 5.285376
epoch: 1, step: 5, outputs are 5.257374
epoch: 1, step: 6, outputs are 5.221371
epoch: 1, step: 7, outputs are 5.2034407
epoch: 1, step: 8, outputs are 5.147038
epoch: 1, step: 9, outputs are 5.101118
epoch: 1, step: 10, outputs are 5.045333
epoch: 1, step: 11, outputs are 4.9651184
epoch: 1, step: 12, outputs are 4.8991895
epoch: 1, step: 13, outputs are 4.870689
epoch: 1, step: 14, outputs are 4.816765
epoch: 1, step: 15, outputs are 4.786441
epoch: 1, step: 16, outputs are 4.7112
epoch: 1, step: 17, outputs are 4.6605344
epoch: 1, step: 18, outputs are 4.569983
epoch: 1, step: 19, outputs are 4.5126414
epoch: 1, step: 20, outputs are 4.4686213
epoch: 1, step: 21, outputs are 4.394112
epoch: 1, step: 22, outputs are 4.3465514
epoch: 1, step: 23, outputs are 4.2930713
epoch: 1, step: 24, outputs are 4.1946926
epoch: 1, step: 2

## 二阶段

 二阶段我们载入一阶段保存的模型， 调整超参，开始新的训练，其他与一阶段相似。
 
 
 ### 载入最新一阶段模型

In [28]:
lists = os.listdir(td_phase1_save_ckpt_dir)
if lists:
    lists.sort(key=lambda fn: os.path.getmtime(td_phase1_save_ckpt_dir + '/' + fn))
    name_ext = os.path.splitext(lists[-1])
    assert name_ext[-1] == ".ckpt", "Invalid file, checkpoint file should be .ckpt file"
    ckpt_file = os.path.join(td_phase1_save_ckpt_dir, lists[-1])
    if ckpt_file == '':
        raise ValueError("Student ckpt file should not be None")
    cfg = phase2_cfg
    
    
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = ckpt_file
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                     student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                     is_training=True, task_type=args_opt.task_type,
                                     num_labels=args_opt.num_labels, is_predistill=False)

### 重新读入数据集

在二阶段 我们需要对模型的结果进行验证。

In [29]:
rank = 0
device_num = 1
train_dataset = create_tinybert_dataset('td', cfg.batch_size,
                                        device_num, rank, args_opt.do_shuffle,
                                        args_opt.train_data_dir, args_opt.schema_dir,
                                        data_type=dataset_type)

dataset_size = train_dataset.get_dataset_size()
print('td2 train dataset size: ', dataset_size)
print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())

repeat_count = args_opt.td_phase2_epoch_size

time_monitor_steps = dataset_size

eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
                                       args_opt.eval_data_dir, args_opt.schema_dir,
                                       data_type=dataset_type)
print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

td2 train dataset size:  3443
td2 train dataset repeatcount:  1
td2 eval dataset size:  3443


### 二阶段学习率与优化器  

注意在callback中传入了测试集。
在EvalCallBack 中规定了模型的保存。

In [30]:
optimizer_cfg = cfg.optimizer_cfg

lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                               end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                               warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10),
                               decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size),
                               power=optimizer_cfg.AdamWeightDecay.power)

params = netwithloss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                {'params': other_params, 'weight_decay': 0.0},
                {'order_params': params}]

optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)

callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
            EvalCallBack(netwithloss.bert, eval_dataset)]

### 训练

注意 这里训练时， 如果没有屏蔽一阶段的训练，或者一阶段中途停止。 可能会出现报错:C++ Call Stack。 如果出现报错，可以先屏蔽一阶段训练的代码段。

In [None]:
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)

model = Model(netwithgrads)
model.train(repeat_count, train_dataset, callbacks=callback,
            dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
            sink_size=args_opt.data_sink_steps)



epoch: 1, step: 1, outputs are 0.34506848
epoch: 1, step: 2, outputs are 0.34500617
epoch: 1, step: 3, outputs are 0.34492475
epoch: 1, step: 4, outputs are 0.34461087
epoch: 1, step: 5, outputs are 0.3443854
epoch: 1, step: 6, outputs are 0.34384793
epoch: 1, step: 7, outputs are 0.3433299
epoch: 1, step: 8, outputs are 0.34268647
epoch: 1, step: 9, outputs are 0.34155688
epoch: 1, step: 10, outputs are 0.34076613
epoch: 1, step: 11, outputs are 0.33978412
epoch: 1, step: 12, outputs are 0.33741572
epoch: 1, step: 13, outputs are 0.33684337
epoch: 1, step: 14, outputs are 0.33472025
epoch: 1, step: 15, outputs are 0.33327097
epoch: 1, step: 16, outputs are 0.3312292
epoch: 1, step: 17, outputs are 0.32820725
epoch: 1, step: 18, outputs are 0.32503647
epoch: 1, step: 19, outputs are 0.3230511
epoch: 1, step: 20, outputs are 0.3192774
epoch: 1, step: 21, outputs are 0.31737298
epoch: 1, step: 22, outputs are 0.31302541
epoch: 1, step: 23, outputs are 0.31051105
epoch: 1, step: 24, outpu