## BERT代码实现

参考 huggingface/Transformers 库（3.1.0 版本）中 PyTorch 版本的 BERT 实现，这里动手实现一个自己的 BERT 模型。

### 架构如下图：

BertSelfAttention 和 BertSelfOutput 构成了 BertAttention，也就是 Attention 层的结构；BertAttention 再与全连接层 BertIntermediate 和 BertOutput 一起构成一个 BertLayer（即 encoder 中的一层）。

![BERT架构图](./imgs/BERT.jpg)

### 具体代码如下：

In [1]:
import pdb
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers import (BertConfig,
                          BertTokenizer,
                          set_seed,
                         )

In [2]:
# Fix the random seed so that repeated runs give the same results.
set_seed(2020)

#### 这一部分是为了标准化输出

In [3]:
@dataclass
class BaseModelOutput:
    """
    Base container for encoder outputs; hidden_states and attentions are optional.

    last_hidden_state: output of the last layer, (batch_size, seq_length, hidden_size), e.g. (64, 128, 768)
    hidden_states: tuple of (num_hidden_layers + 1) tensors (13 for a 12-layer model): the embedding
        output followed by the output of every layer
    attentions: tuple of num_hidden_layers tensors, each (batch_size, num_heads, seq_length, seq_length),
        e.g. 12 * (64, 12, 128, 128)
    """

    last_hidden_state: torch.FloatTensor
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [4]:
@dataclass
class BaseModelOutputWithPooling:
    """
    Output container of the base BERT model.

    last_hidden_state: output of the last layer, (batch_size, seq_length, hidden_size), e.g. (64, 128, 768)
    pooler_output: hidden state of the first ([CLS]) token of the last layer after the pooler's
        (hidden_size, hidden_size) dense layer and tanh, (batch_size, hidden_size), e.g. (64, 768)
    hidden_states: tuple of (num_hidden_layers + 1) tensors (13 for a 12-layer model): the embedding
        output followed by the output of every layer
    attentions: tuple of num_hidden_layers tensors, each (batch_size, num_heads, seq_length, seq_length),
        e.g. 12 * (64, 12, 128, 128)
    """

    last_hidden_state: torch.FloatTensor
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [5]:
@dataclass
class BertForPreTrainingOutput:
    """
    Output container of BertForPreTraining.

    loss: scalar; sum of the cross-entropy losses of the MLM task and the NSP task.
    prediction_logits: MLM logits (before softmax), (batch_size, seq_length, vocab_size),
        e.g. (64, 128, 28996)
    seq_relationship_logits: NSP logits, (batch_size, 2), e.g. (64, 2)
    hidden_states: tuple of (num_hidden_layers + 1) tensors (13 for a 12-layer model): the embedding
        output followed by the output of every layer
    attentions: tuple of num_hidden_layers tensors, each (batch_size, num_heads, seq_length, seq_length),
        e.g. 12 * (64, 12, 128, 128)
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    seq_relationship_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [6]:
@dataclass
class MaskedLMOutput:
    """
    Base output container for masked language models.

    loss: masked-LM loss, present only when labels were supplied.
    logits: per-token prediction scores over the vocabulary.
    hidden_states / attentions: optional per-layer tuples forwarded from the base model.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [7]:
@dataclass
class NextSentencePredictionOutput:
    """
    Output container of a model that performs only the NSP task.

    loss: next-sentence loss, present only when next_sentence_label was supplied.
    logits: NSP scores, two values per example.
    hidden_states / attentions: optional per-layer tuples forwarded from the base model.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [8]:
@dataclass
class SequenceClassifierOutput:
    """
    Base output container for sentence (sequence) classification models.

    All fields are optional and default to None.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [9]:
@dataclass
class TokenClassifierOutput:
    """
    Base output container for token classification models.

    All fields are optional and default to None.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

#### BERT模型的构建

BERT模型构建，包含三层：embedding、encoder、pooler。其中，encoder包含多层layer。每个layer包含attention、全连接层。

In [10]:
# Layer-normalization layer: plain alias for torch's nn.LayerNorm.
BertLayerNorm = nn.LayerNorm

In [11]:
# Embedding layer: add the three embeddings, then LayerNorm, then dropout.
class BertEmbeddings(nn.Module):
    """Builds the encoder input from word, position and token-type embeddings."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size

        # Three lookup tables whose outputs are summed.
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # Normalisation and regularisation applied to the sum.
        self.LayerNorm = BertLayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Default position ids 0..max_position_embeddings-1, shape (1, max_position_embeddings).
        default_positions = torch.arange(config.max_position_embeddings).expand(1, -1)
        self.register_buffer('position_ids', default_positions)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, inputs_embeds=None):
        """
        Args:
            input_ids: (batch, seq_length) token ids; ignored when inputs_embeds is given.
            position_ids: (batch, seq_length) or None; defaults to 0..seq_length-1.
            token_type_ids: (batch, seq_length) or None; defaults to all zeros.
            inputs_embeds: optional precomputed word embeddings, (batch, seq_length, hidden_size).

        Returns:
            Embeddings after LayerNorm and dropout, (batch, seq_length, hidden_size).
        """
        has_embeds = inputs_embeds is not None
        input_shape = inputs_embeds.shape[:-1] if has_embeds else input_ids.shape
        # Unpacking into two names also rejects anything that is not (batch, seq_length).
        _, seq_length = input_shape

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            # Everything belongs to the first segment by default.
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # Caller-provided embeddings take priority over the id lookup.
        word_embeds = inputs_embeds if has_embeds else self.word_embeddings(input_ids)
        summed = word_embeds + self.position_embeddings(position_ids) + self.type_embeddings(token_type_ids)
        return self.dropout(self.LayerNorm(summed))

In [12]:
class BertSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    The input is projected to query/key/value, split into
    ``num_attention_heads`` heads, combined as
    softmax(Q K^T / sqrt(head_size) + mask) V, and the heads are merged back.
    """

    def __init__(self, config):
        """
        Args:
            config: object providing ``hidden_size``, ``num_attention_heads`` and
                ``attention_probs_dropout_prob``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``
                and the config has no ``embedding_size`` attribute.
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, 'embedding_size'):
            raise ValueError(
                f'The hidden size {config.hidden_size} is not a multiple of the number of attention '
                f'heads ({config.num_attention_heads})'
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Each projection is (hidden_size, all_head_size), e.g. (768, 768).
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads.

        (batch_size, seq_length, all_head_size), e.g. (64, 128, 768)
        -> (batch_size, num_attention_heads, seq_length, attention_head_size), e.g. (64, 12, 128, 64)
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
    ):
        """
        Args:
            hidden_states: (batch_size, seq_length, hidden_size), e.g. (64, 128, 768).
            attention_mask: optional additive mask that broadcasts against
                (batch_size, num_heads, seq_length, seq_length), e.g. (64, 1, 1, 128).
            output_attentions: also return the attention probabilities when truthy.

        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)`` with
            context_layer (batch_size, seq_length, all_head_size) and
            attention_probs (batch_size, num_heads, seq_length, seq_length).
        """
        # Linear projections followed by the split into heads:
        # (batch_size, num_heads, seq_length, head_size), e.g. (64, 12, 128, 64).
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Q @ K^T -> (batch_size, num_heads, seq_length, seq_length), e.g. (64, 12, 128, 128).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Scale by sqrt(attention_head_size), e.g. sqrt(64) = 8.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive: large negative values remove positions after softmax.
            attention_scores = attention_scores + attention_mask

        # Normalise the scores into probabilities over the key positions.
        attention_probs = torch.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values: (batch_size, num_heads, seq_length, head_size).
        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the heads back: (batch_size, seq_length, all_head_size), e.g. (64, 128, 768).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        return (context_layer, attention_probs) if output_attentions else (context_layer,)

In [13]:
class BertSelfOutput(nn.Module):
    """
    Dense projection, dropout and residual LayerNorm applied after self-attention.
    Together with BertSelfAttention it forms BERT's attention block.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, then add ``input_tensor`` as residual and normalise.

        Both tensors are (batch_size, seq_length, hidden_size), e.g. (64, 128, 768).
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

In [14]:
class BertAttention(nn.Module):
    """
    BERT attention block: self-attention followed by its output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
    ):
        """Return ``(attention_output,)``, plus the attention probabilities when requested."""
        # Self-attention yields the context and, optionally, the attention weights.
        context, *extras = self.self(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        # Dense + residual + LayerNorm, with the block input as the residual.
        attention_output = self.output(context, hidden_states)
        # One element, or two when output_attentions is set.
        return (attention_output, *extras)

In [15]:
class BertIntermediate(nn.Module):
    """
    First half of the feed-forward sub-layer that sits on top of attention:
    a dense expansion to ``intermediate_size`` followed by the activation.
    """

    def __init__(self, config):
        """
        Args:
            config: object providing ``hidden_size``, ``intermediate_size`` and
                ``hidden_act`` (a callable, or one of the names
                'gelu', 'relu', 'tanh', 'sigmoid').

        Raises:
            ValueError: if ``hidden_act`` is a string that names no supported activation.
        """
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        activation = config.hidden_act
        if isinstance(activation, str):
            # Resolve the name here; storing the raw string would only fail
            # later, inside forward, with an opaque "str is not callable".
            named_activations = {
                'gelu': F.gelu,
                'relu': F.relu,
                'tanh': torch.tanh,
                'sigmoid': torch.sigmoid,
            }
            if activation not in named_activations:
                raise ValueError(
                    f'Unsupported hidden_act {activation!r}; expected a callable or one of '
                    f'{sorted(named_activations)}'
                )
            activation = named_activations[activation]
        self.intermediate_act_fn = activation

    def forward(self, hidden_states):
        """Map (..., hidden_size) to (..., intermediate_size) and apply the activation."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

In [16]:
class BertOutput(nn.Module):
    """
    Second half of the feed-forward sub-layer on top of attention:
    projection back to hidden_size, dropout, then residual LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, hidden)
        self.LayerNorm = BertLayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` down, add ``input_tensor`` as residual and normalise."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

In [17]:
class BertLayer(nn.Module):
    """
    One layer of the BERT encoder: attention block plus feed-forward block.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
    ):
        """Return ``(layer_output,)``, plus the attention probabilities when requested.

        layer_output is (batch_size, seq_length, hidden_size), e.g. (64, 128, 768).
        """
        # The first element is always the attention output; anything after it
        # (the attention probabilities) exists only when output_attentions is set.
        attention_output, *attention_extras = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

        # Feed-forward block, with the attention output as the residual.
        expanded = self.intermediate(attention_output)
        layer_output = self.output(expanded, attention_output)

        return (layer_output, *attention_extras)

In [18]:
class BertEncoder(nn.Module):
    """
    BERT encoder. The model has three parts (embeddings, encoder, pooler);
    the encoder is a stack of ``config.num_hidden_layers`` BertLayer modules.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # nn.ModuleList registers every layer as a proper submodule.
        layers = [BertLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """
        Run ``hidden_states`` through every layer.

        Returns a BaseModelOutput when ``return_dict`` is truthy, otherwise a
        tuple holding the final hidden states followed by whichever of
        (all hidden states, all attentions) were requested.
        """
        collected_states = [] if output_hidden_states else None
        collected_attentions = [] if output_attentions else None

        for layer_module in self.layer:
            if collected_states is not None:
                # Record the input of each layer (the first one is the embedding output).
                collected_states.append(hidden_states)

            # (hidden_states,) or (hidden_states, attention_probs)
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if collected_attentions is not None:
                collected_attentions.append(layer_outputs[1])

        if collected_states is not None:
            # Append the output of the last layer: num_hidden_layers + 1 entries in total.
            collected_states.append(hidden_states)

        all_hidden_states = None if collected_states is None else tuple(collected_states)
        all_attentions = None if collected_attentions is None else tuple(collected_attentions)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
            )
        # Tuple form: the final hidden states always come first, optional parts follow.
        candidates = (hidden_states, all_hidden_states, all_attentions)
        return tuple(item for item in candidates if item is not None)

In [19]:
class BertPooler(nn.Module):
    """
    Pooler part of BERT (the other two parts are embeddings and encoder).
    Takes the hidden state of the first ([CLS]) token and passes it through a
    (hidden_size, hidden_size) dense layer, e.g. (768, 768), followed by tanh.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return the pooled first-token representation, (batch_size, hidden_size)."""
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))

In [20]:
class BertPreTrainedModel(nn.Module):
    """
    Common base class of all BERT models: stores the config and provides
    the shared weight-initialisation logic.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

    def init_weights(self):
        """Apply ``_init_weights`` to this module and all of its submodules."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Initialise a single module:
        Linear/Embedding weights ~ N(0, initializer_range), Linear biases = 0,
        LayerNorm weight = 1 and bias = 0. Other modules are left untouched.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if is_linear and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, BertLayerNorm):
            # LayerNorm has a learnable scale (weight) and shift (bias).
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

In [21]:
class BertModel(BertPreTrainedModel):
    """
    The bare BERT model: embeddings, encoder and pooler. The encoder is a
    stack of layers, each made of an attention block and a feed-forward block.
    """

    def __init__(self, config):
        # The base class stores ``config`` on ``self.config``.
        super().__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def forward(
        self,
        # Input-related arguments; input_ids and inputs_embeds are mutually exclusive.
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        # Output-related arguments; None falls back to the value in the config.
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Args:
            input_ids: (batch_size, seq_length) token ids, e.g. (64, 128).
            attention_mask: (batch_size, seq_length) or
                (batch_size, from_seq_length, to_seq_length); 1 marks positions
                to attend to, 0 marks masked positions. Defaults to all ones.
            token_type_ids: (batch_size, seq_length); defaults to all zeros.
            position_ids: (batch_size, seq_length) or None.
            inputs_embeds: (batch_size, seq_length, hidden_size), used instead of input_ids.
            output_attentions: whether to return the attention probabilities.
            output_hidden_states: whether to return the hidden states of every layer.
            return_dict: return a BaseModelOutputWithPooling instead of a tuple.

        Returns:
            BaseModelOutputWithPooling, or the tuple
            ``(sequence_output, pooler_output, [hidden_states], [attentions])``.

        Raises:
            ValueError: if both or neither of input_ids / inputs_embeds are given,
                or if attention_mask is neither 2-D nor 3-D.
        """
        # Resolve the output options against the config defaults.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Work out the (batch_size, seq_length) shape and device from whichever input was given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError('You have to specify either input_ids or inputs_embeds')

        # Defaults: attend to everything, everything is segment 0.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Extend the mask so it broadcasts over
        # (batch_size, num_heads, from_seq_length, to_seq_length), e.g. (64, 1, 1, 128).
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f'wrong shape for input_ids (shape {input_shape}) '
                f'or attention_mask (shape {attention_mask.shape})'
            )
        # Cast to the model's floating dtype BEFORE the arithmetic: `1.0 - mask`
        # is not defined for bool masks, and a mask whose dtype differs from the
        # weights (half / double) breaks the matmul that follows the softmax.
        model_dtype = next(self.parameters()).dtype
        extended_attention_mask = extended_attention_mask.to(device=device, dtype=model_dtype)
        # 1.0 = keep, 0.0 = masked. Turn it into an additive mask: 0 for kept
        # positions, -10000 for masked ones, added to the scores before softmax.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Stage 1: embeddings.
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        # Stage 2: encoder.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Stage 3: pooler.
        sequence_output = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state
        pooler_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooler_output) + encoder_outputs[1:]
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

#### BERT的各种head

在上面的 BertModel 之上接入各种 head，既可以完成 MLM、NSP 等预训练任务，也可以完成分类、NER 等下游任务。

这里仅仅是head，并不是完整的模型。

In [22]:
class BertOnlyMLMHead(nn.Module):
    """
    Masked-language-model head used on top of the base BERT model:
    dense -> activation -> LayerNorm -> decoder over the vocabulary.
    """

    def __init__(self, config):
        """
        Args:
            config: object providing ``hidden_size``, ``vocab_size``,
                ``layer_norm_eps`` and ``hidden_act`` (a callable, or one of the
                names 'gelu', 'relu', 'tanh', 'sigmoid').

        Raises:
            ValueError: if ``hidden_act`` is a string that names no supported activation.
        """
        super().__init__()
        # Transform: dense layer, activation and LayerNorm.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        activation = config.hidden_act
        if isinstance(activation, str):
            # Resolve the name here; storing the raw string would only fail
            # later, inside forward, with an opaque "str is not callable".
            named_activations = {
                'gelu': F.gelu,
                'relu': F.relu,
                'tanh': torch.tanh,
                'sigmoid': torch.sigmoid,
            }
            if activation not in named_activations:
                raise ValueError(
                    f'Unsupported hidden_act {activation!r}; expected a callable or one of '
                    f'{sorted(named_activations)}'
                )
            activation = named_activations[activation]
        self.transform_act_fn = activation
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Decoder: projects to vocabulary size. The bias is kept as a separate
        # parameter on the head and shared with the decoder.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Map (..., hidden_size) hidden states to (..., vocab_size) prediction scores."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

In [23]:
class BertOnlyNSPHead(nn.Module):
    """
    Next-sentence-prediction head used on top of the base BERT model:
    a single linear layer producing two scores.
    """

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Map the pooled output (..., hidden_size) to NSP scores (..., 2)."""
        return self.seq_relationship(pooled_output)

In [24]:
class BertPreTrainingHeads(nn.Module):
    """
    Pre-training heads: the MLM head and the NSP classifier side by side.
    """

    def __init__(self, config):
        super().__init__()
        self.predictions = BertOnlyMLMHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(prediction_scores, seq_relationship_score)`` for the two tasks."""
        mlm_scores = self.predictions(sequence_output)
        nsp_scores = self.seq_relationship(pooled_output)
        return mlm_scores, nsp_scores

#### 接入head的各种bert模型

MLM、NSP、预训练、分类、NER等任务的bert模型。

In [25]:
class BertForMaskedLM(BertPreTrainedModel):
    """
    BERT with a masked-language-model head.
    Without ``labels`` only the prediction scores are produced;
    with ``labels`` the MLM loss is returned as well.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run BERT plus the MLM head; returns a MaskedLMOutput or the equivalent tuple."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Last-layer output, (batch_size, seq_length, hidden_size), e.g. (64, 128, 768).
        sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # Flatten to one row per token:
            #   scores (batch_size * seq_length, vocab_size), e.g. (8192, 28996)
            #   labels (batch_size * seq_length,),            e.g. (8192,)
            # CrossEntropyLoss skips targets equal to -100 (its default ignore_index).
            flat_scores = prediction_scores.view(-1, self.config.vocab_size)
            flat_labels = labels.view(-1)
            masked_lm_loss = CrossEntropyLoss()(flat_scores, flat_labels)

        if return_dict:
            return MaskedLMOutput(
                loss=masked_lm_loss,
                logits=prediction_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: drop sequence_output / pooler_output, keep the optional extras.
        output = (prediction_scores,) + outputs[2:]
        if masked_lm_loss is None:
            return output
        return (masked_lm_loss,) + output

In [26]:
class BertForNextSentencePrediction(BertPreTrainedModel):
    """
    BERT with a next-sentence-prediction (NSP) head.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.init_weights()

    def forward(
        self,
        # inputs
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        # label
        next_sentence_label=None,
        # output format
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run BERT plus the NSP head; returns a NextSentencePredictionOutput or the equivalent tuple."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The NSP head works on the pooled first-token representation.
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if next_sentence_label is not None:
            flat_scores = seq_relationship_scores.view(-1, 2)
            flat_labels = next_sentence_label.view(-1)
            next_sentence_loss = CrossEntropyLoss()(flat_scores, flat_labels)

        if not return_dict:
            # Tuple form: drop sequence_output / pooler_output, keep the optional extras.
            output = (seq_relationship_scores,) + outputs[2:]
            if next_sentence_loss is None:
                return output
            return (next_sentence_loss,) + output

        return NextSentencePredictionOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [27]:
class BertForPreTraining(BertPreTrainedModel):
    """
    BERT with both pre-training heads: masked language modelling (MLM) and
    next sentence prediction (NSP).
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        # Labels of the two pre-training tasks: MLM first, NSP second.
        labels=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Args:
            input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds:
                forwarded unchanged to the underlying BertModel.
            labels: MLM targets, (batch_size, seq_length), e.g. (64, 128).
            next_sentence_label: NSP targets, (batch_size,), e.g. (64,).
            output_attentions, output_hidden_states, return_dict: output options,
                forwarded to BertModel.

        Returns:
            BertForPreTrainingOutput, or the tuple
            ``([loss], prediction_scores, seq_relationship_score, [hidden_states], [attentions])``.
            The loss is computed only when BOTH label tensors are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids=input_ids,
            # The mask must be forwarded, otherwise padding positions are attended to.
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # bert_outputs: last_hidden_state, pooler_output, [hidden_states], [attentions];
        # the first two are always present.
        if return_dict:
            sequence_output, pooled_output = bert_outputs.last_hidden_state, bert_outputs.pooler_output
        else:
            sequence_output, pooled_output = bert_outputs[:2]

        # Scores of the two pre-training tasks.
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten to one row per token:
            #   scores (batch_size * seq_length, vocab_size), e.g. (8192, 28996)
            #   labels (batch_size * seq_length,),            e.g. (8192,)
            # CrossEntropyLoss skips targets equal to -100 (its default ignore_index).
            # The vocabulary size comes from the model's own config, not a global.
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            # seq_relationship_score (batch_size, 2), next_sentence_label (batch_size,)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if return_dict:
            return BertForPreTrainingOutput(
                loss=total_loss,
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )

        output = (prediction_scores, seq_relationship_score) + bert_outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output

In [28]:
class BertForSequenceClassification(BertPreTrainedModel):
    """
    BERT with a sequence-level head, used for text classification or regression.

    The pooled output of ``BertModel`` goes through dropout and one linear
    layer that produces ``config.num_labels`` values per sequence. When
    ``config.num_labels == 1`` the head is trained as a regressor (MSE loss),
    otherwise as a classifier (cross-entropy loss).
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Run the encoder and the classification head.

        ``labels`` is optional; when given, a loss is computed and returned
        together with the logits. All other arguments are forwarded unchanged
        to ``self.bert``.

        Returns a ``SequenceClassifierOutput`` when ``return_dict`` is truthy,
        otherwise a tuple ``(loss?, logits, *extra_encoder_outputs)`` where
        ``loss`` is present only if ``labels`` was supplied.
        """
        # Fall back to the config default when the caller does not decide.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # pooler_output: the last layer's [CLS] vector after the pooler's
        # (hidden_size, hidden_size) projection,
        # shape (batch_size, hidden_size) eg. (64, 768)
        if return_dict:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        # logits: (batch_size, num_labels) eg. (64, 280)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression task. MSELoss is reached through ``nn`` because
                # the bare name is not among the file's visible imports.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                # logits: (batch_size, num_labels) eg. (64, 280)
                # labels: (batch_size,) eg. (64)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Skip the first two encoder outputs (sequence output and pooled
            # output); whatever follows them is passed through.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [29]:
class BertForTokenClassification(BertPreTrainedModel):
    """
    BERT with a per-token classification head, for sequence labelling (NER etc.).

    Every position of the encoder's last hidden state goes through dropout
    and one linear layer that produces ``config.num_labels`` logits per token.
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Run the encoder and the token-level classification head.

        ``labels`` is optional; when given, a cross-entropy loss is computed.
        If ``attention_mask`` is also given, positions whose mask value is
        not 1 are excluded from the loss. All other arguments are forwarded
        unchanged to ``self.bert``.

        Returns a ``TokenClassifierOutput`` when ``return_dict`` is truthy,
        otherwise a tuple ``(loss?, logits, *extra_encoder_outputs)`` where
        ``loss`` is present only if ``labels`` was supplied.
        """
        # Fall back to the config default when the caller does not decide.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if return_dict:
            sequence_output = outputs.last_hidden_state
        else:
            sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        # logits: (batch_size, seq_length, num_labels) eg. (64, 128, 12)
        logits = self.classifier(sequence_output)

        loss = None
        # labels: (batch_size, seq_length) eg. (64, 128)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            if attention_mask is not None:
                # Only attended positions contribute to the loss.
                # attention_mask: (batch_size, seq_length) eg. (64, 128),
                # flattened to (batch_size*seq_length,) eg. (8192,).
                # 1 means "attend" -> True; 0 means "masked" -> False.
                active_loss = attention_mask.view(-1) == 1
                # (batch_size*seq_length, num_labels) eg. (8192, 12)
                active_logits = logits.view(-1, self.num_labels)
                # torch.where(cond, x, y) takes x where cond is True, else y.
                # Attended positions keep their real label; the rest get
                # ignore_index, which CrossEntropyLoss skips. type_as casts
                # the filler to the same tensor type as ``labels`` so both
                # branches of torch.where agree.
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                # This step was missing: without it the masked branch always
                # left ``loss`` as None even though labels were supplied.
                loss = loss_fct(active_logits, active_labels)
            else:
                # Unlike sequence classification there is no regression case:
                # a tagging scheme (e.g. B/I/O) always has more than one label.
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            # Skip the first two encoder outputs (sequence output and pooled
            # output); whatever follows them is passed through.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
                

#### 引入transformers中的Config和Tokenizer

In [30]:
# Location of the pretrained bert-base-cased files.
pretrained_model_name_or_path = r'E:\models\huggingface\bert-base-cased'

# Load the model configuration, overriding num_labels to 6. The
# classification test cells later reassign config.num_labels themselves.
config = BertConfig.from_pretrained(
    pretrained_model_name_or_path,
    num_labels=6,
)

# Load the tokenizer from the same location.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    config=config,
)

#### 测试我们的模型

In [31]:
# Shared toy inputs for the test cells below: a batch of 2 sequences of length 4.
input_ids = torch.tensor([[1, 5, 3, 1], [2, 1, 2, 1]])
position_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long)
token_type_ids = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=torch.long)
# Plain booleans. A stray trailing comma here used to bind the one-element
# tuple (True,) to output_hidden_states instead of True.
output_hidden_states = True
output_attentions = True

##### 1. 基本模型测试

In [32]:
# Smoke-test the bare encoder in both output modes (dataclass and tuple).
# The seed is reset before each construction so both iterations start from
# the same random state.
for return_dict in (True, False):
    set_seed(2020)
    model_class = BertModel
    print(f'return_dict: {return_dict}, model_class: {model_class}')
    model = model_class(config)
    output = model(
        input_ids=input_ids,
        position_ids=position_ids,
        token_type_ids=token_type_ids,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=return_dict,
    )
    # The last hidden state is the named field in dict mode and the first
    # tuple element otherwise.
    print(output.last_hidden_state if return_dict else output[0])

return_dict: True, model_class: <class '__main__.BertModel'>
tensor([[[-0.6173,  1.3926, -1.1268,  ..., -1.1984,  0.8287,  0.4786],
         [ 0.1401,  0.0600, -1.2550,  ..., -0.8059,  1.0439,  0.3112],
         [-0.4751,  1.1698, -1.6823,  ..., -0.1699,  0.6303,  1.2419],
         [-0.5519,  0.6194, -1.7218,  ..., -0.9902,  0.5199,  0.5989]],

        [[-1.2104,  0.9546,  0.0531,  ...,  0.1068, -1.0703,  0.5158],
         [-0.1883, -0.1162, -1.1842,  ..., -0.6901,  0.0073, -0.6441],
         [-0.6517,  0.7814,  0.7595,  ...,  0.1150,  0.1251,  0.0587],
         [-1.1482,  0.7543, -0.4053,  ..., -1.6291, -0.4497,  0.2858]]],
       grad_fn=<NativeLayerNormBackward>)
return_dict: False, model_class: <class '__main__.BertModel'>
tensor([[[-0.6173,  1.3926, -1.1268,  ..., -1.1984,  0.8287,  0.4786],
         [ 0.1401,  0.0600, -1.2550,  ..., -0.8059,  1.0439,  0.3112],
         [-0.4751,  1.1698, -1.6823,  ..., -0.1699,  0.6303,  1.2419],
         [-0.5519,  0.6194, -1.7218,  ..., -0.9902

##### 2. 预训练MLM和NSP的合成模型测试

In [33]:
# Targets for the combined pre-training objective: per-token MLM labels and
# one next-sentence label per sequence.
labels = torch.tensor([[1, 5, 3, 1], [2, 1, 2, 1]])
next_sentence_label = torch.tensor([0, 0])

# Smoke-test the MLM + NSP model in both output modes; the seed is reset
# before each construction so both iterations start from the same random state.
for return_dict in (True, False):
    set_seed(2020)
    model_class = BertForPreTraining
    print(f'return_dict: {return_dict}, model_class: {model_class}')
    model = model_class(config)
    output = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        labels=labels,
        next_sentence_label=next_sentence_label,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # The loss is the named field in dict mode and the first tuple element otherwise.
    print(output.loss if return_dict else output[0])

return_dict: True, model_class: <class '__main__.BertForPreTraining'>
tensor(11.1472, grad_fn=<AddBackward0>)
return_dict: False, model_class: <class '__main__.BertForPreTraining'>
tensor(11.1472, grad_fn=<AddBackward0>)


##### 3. MLM的模型测试

In [34]:
# Per-token targets for the masked-language-model objective.
labels = torch.tensor([[1, 5, 3, 1], [2, 1, 2, 1]])

# Smoke-test the MLM-only model in both output modes; the seed is reset
# before each construction so both iterations start from the same random state.
for return_dict in (True, False):
    set_seed(2020)
    model_class = BertForMaskedLM
    print(f'return_dict: {return_dict}, model_class: {model_class}')
    model = model_class(config)
    output = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        labels=labels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # The loss is the named field in dict mode and the first tuple element otherwise.
    print(output.loss if return_dict else output[0])

return_dict: True, model_class: <class '__main__.BertForMaskedLM'>
tensor(9.9528, grad_fn=<NllLossBackward>)
return_dict: False, model_class: <class '__main__.BertForMaskedLM'>
tensor(9.9528, grad_fn=<NllLossBackward>)


##### 4. NSP的模型测试

In [35]:
# One next-sentence target per sequence in the batch.
next_sentence_label = torch.tensor([0, 0])

# Smoke-test the NSP-only model in both output modes; the seed is reset
# before each construction so both iterations start from the same random state.
for return_dict in (True, False):
    set_seed(2020)
    model_class = BertForNextSentencePrediction
    print(f'return_dict: {return_dict}, model_class: {model_class}')
    model = model_class(config)
    output = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        next_sentence_label=next_sentence_label,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # The loss is the named field in dict mode and the first tuple element otherwise.
    print(output.loss if return_dict else output[0])

return_dict: True, model_class: <class '__main__.BertForNextSentencePrediction'>
tensor(0.7778, grad_fn=<NllLossBackward>)
return_dict: False, model_class: <class '__main__.BertForNextSentencePrediction'>
tensor(0.7778, grad_fn=<NllLossBackward>)


##### 5. 分类模型的测试

In [36]:
# One class index per sequence; the shared config is switched to 12 classes
# before the models are built.
labels = torch.tensor([3, 1])
config.num_labels = 12

# Smoke-test the sequence classifier in both output modes; the seed is reset
# before each construction so both iterations start from the same random state.
for return_dict in (True, False):
    set_seed(2020)
    model_class = BertForSequenceClassification
    print(f'return_dict: {return_dict}, model_class: {model_class}')
    model = model_class(config)
    output = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        labels=labels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # The loss is the named field in dict mode and the first tuple element otherwise.
    print(output.loss if return_dict else output[0])

return_dict: True, model_class: <class '__main__.BertForSequenceClassification'>
tensor(2.5378, grad_fn=<NllLossBackward>)
return_dict: False, model_class: <class '__main__.BertForSequenceClassification'>
tensor(2.5378, grad_fn=<NllLossBackward>)


##### 6. NER模型的测试

In [37]:
# One tag index per token; the shared config is set to 12 tags before the
# models are built.
labels = torch.tensor([[1, 2, 0, 0], [0, 0, 1, 2]])
config.num_labels = 12

# Smoke-test the token classifier in both output modes; the seed is reset
# before each construction so both iterations start from the same random state.
# No attention_mask is passed, so the loss covers every position.
for return_dict in (True, False):
    set_seed(2020)
    model_class = BertForTokenClassification
    print(f'return_dict: {return_dict}, model_class: {model_class}')
    model = model_class(config)
    output = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        labels=labels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # The loss is the named field in dict mode and the first tuple element otherwise.
    print(output.loss if return_dict else output[0])

return_dict: True, model_class: <class '__main__.BertForTokenClassification'>
tensor(2.4908, grad_fn=<NllLossBackward>)
return_dict: False, model_class: <class '__main__.BertForTokenClassification'>
tensor(2.4908, grad_fn=<NllLossBackward>)
