In [2]:
import sys
print(sys.executable) 

/usr/local/Cellar/jupyterlab/2.2.2/libexec/bin/python3.8


In [3]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import transformers
from transformers import BertConfig,BertModel, BertTokenizer
from transformers.modeling_bert import BertEmbeddings, BertEncoder,BertPooler,BertPreTrainedModel

In [4]:
''' Proposed in the Paper of ACL 2020: Spelling Error Correction with Soft-Masked BERT(2020_ACL)'''

class Soft_Masked_BERT(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # self.config contains all parameters of Correction_Network Bert network.
        self.config = config
        
        '''1) build layers of Detection_Network'''
        self.enc_bi_gru = torch.nn.GRU(input_size=768, hidden_size=256, dropout=0.2, bidirectional=True)
        self.detection_network_dense_out = torch.nn.Linear(512, 2)
        self.soft_masking_coef_mapping = torch.nn.Linear(512, 1)
        
        '''2) build 3 layers of Correction_Network in BertModel'''
        # embedding layer
        self.embeddings = BertEmbeddings(config)
        # 12-layer multi-head self attention
        self.encoder = BertEncoder(config)
        # pooling-layer BertPooler
        self.pooler = BertPooler(config)
        self.init_weights()
        
        self.mask_embeddings = self.embeddings.word_embeddings.weight[103]  # 此时,mask_embedding张量的形状为(768,)
        
        self.soft_masked_bert_dense_out = torch.nn.Linear(self.config.hidden_size, self.embeddings.word_embeddings.weight.shape[0])
        
        
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
             
            
    '''build Detection_Network'''
    def Detection_Network(self, input_embeddings: torch.Tensor, attention_mask: torch.Tensor):
        # input_embeddings:(seq_len, batch_size, embed_size)->(seq_len, batch_size, 768)
        # attention_mask:(batch_size, seq_len)
        # input embedding: every sentence's character's word embedding+position embedding+segment embeddings
        h_0 = torch.zeros(2, input_embeddings.shape[1], 256)
        bi_gru_final_hidden_layer = self.enc_bi_gru(input_embeddings, h_0)[0]
        bi_gru_final_hidden_layer = bi_gru_final_hidden_layer.permute(1,0,2)
        
        # (batch_size, seq_len, 2)
        detection_network_output = self.detection_network_dense_out(bi_gru_final_hidden_layer) 
        # (batch_size, seq_len, 1)
        soft_masking_coefs = torch.nn.functional.sigmoid(self.soft_masking_coef_mapping(bi_gru_final_hidden_layer) ) 
        # (batch_size, seq_len,1)
        attention_mask = attention_mask.unsqueeze(dim=2)
        
    
        attention_mask = (attention_mask != 0) 
        soft_masking_coefs[~attention_mask] = 0
        
        return detection_network_output, soft_masking_coefs
    
    
    '''build Soft Masking Connection'''
    # 在错误探查网络error detection network输出一个句子中每个位置的字符为错误拼写字符的概率之后，利用此概率作为[MASK] embeddings的权重，
    # 而1减去这个概率作为句子中每个字符character的input embeddings的权重，[MASK] embeddings乘以权重的结果再加上input embeddings乘以权重的结果后
    # 所得到的嵌入结果soft-masked embeddings即为之后的错误纠正网络error correction network的输入。
    def Soft_Masking_Connection(self,input_embeddings: torch.Tensor,
               mask_embeddings: torch.Tensor,
               soft_masking_coefs: torch.Tensor):
        
        # 此时输入Soft_Masking_Connection模块中:
        # input_embeddings张量形状为:(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768);
        # mask_embeddings为只包含"遮罩特殊符[MASK]"的embedding嵌入的张量,其形状也为:(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768)；
        # soft_masking_coefs张量可被称为：soft-masking系数张量, 其为计算soft-masked embeddings时和mask_embeddings相乘的系数p的张量,形状为(batch_size, seq_len, 1);
        # 输入模型起始处的嵌入张量input embedding由一句sentence中每个character的word embedding、position embedding、segment embeddings三者相加而成.
        
        # 得到soft-masking系数张量:soft_masking_coefs张量之后,利用soft_masking_coefs张量作为[MASK] embeddings的权重，
        # 而1减去这个概率作为句子中每个字符character的input embeddings的权重，[MASK] embeddings乘以权重的结果再加上
        # input embeddings乘以权重的结果后所得到的嵌入结果soft-masked embeddings即为之后的错误纠正网络error correction network的输入.
        # 此时soft_masked_embeddings形状也为(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768),
        soft_masked_embeddings = soft_masking_coefs * mask_embeddings + (1 - soft_masking_coefs) * input_embeddings
        
        return soft_masked_embeddings
        
        
    
    '''forward函数.'''
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None,
                output_attentions=None,):
        
        # 利用张量的long()函数确保这些张量全为int型张量.
        input_ids = input_ids.long()
        attention_mask = attention_mask.long()
        token_type_ids = token_type_ids.long()
        position_ids = position_ids.long()

        
        '''以下部分为transformers库中BertModel类中的forward()部门的一小部分源码, 放在此处是为了和源BertModel类保持一致防止出错.'''
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D ou 3D attention mask is provided for the cross-attention
        # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        '''以上部分为transformers库中BertModel类中的forward()部门的一小部分源码, 放在此处是为了和源BertModel类保持一致防止出错.'''
        
        
        
        
        # 输入模型起始处的嵌入张量input embedding由一句sentence中每个character的word embedding、segment embeddings、position embedding三者相加而成。 
        # 此时input_embeddings张量的形状为(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768),
        # 应将input_embeddings张量的第一第二维度互换, 将其形状变为(seq_len, batch_size, embed_size)->(seq_len, batch_size, 768)才方便输入进
        # 后方的错误探查网络Detection_Network中的Bi-GRU网络中(双向GRU).
        input_embeddings = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
        # 形状变为(seq_len, batch_size, embed_size)->(seq_len, batch_size, 768).
        input_embeddings = input_embeddings.permute(1,0,2)
        
        
        # (1)错误探查网络Detection_Network中的双向GRU编码层的输出为(seq_len, batch_size, enc_hid_size * 2),
        # 将其交换维度变形为(batch_size, seq_len, enc_hid_size * 2),再将双向GRU编码层的变形后的输出输入self.detection_network_dense_out层中,
        # 映射为形状(batch_size, seq_len, 2)的张量detection_network_output, 这样方便后面进行判断句子序列中每一个字符是否为拼写错误字符的二分类任务的交叉熵损失值计算.
        # (2)此时soft_masking_coefs张量可被称为：soft-masking系数张量, 其形状为(batch_size, seq_len, 1).
        detection_network_output, soft_masking_coefs = self.Detection_Network(input_embeddings=input_embeddings, attention_mask=attention_mask)
        
        
        # 此时需再将input_embeddings张量的第一第二维度交换, 将其形状再变回(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768),
        # 这样input_embeddings张量才方便输入进self.soft.masking_connection模块中计算soft_masked_embeddings.
        input_embeddings = input_embeddings.permute(1,0,2)
        
        # 遮罩特殊符[MASK]的张量self.mask_embedding的形状要变为和Bert模型嵌入层BertEmbeddings()的输出input_embeddings张量的形状一样,
        # 此时self.mask_embeddings张量的形状要为(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768).
        self.mask_embeddings = self.mask_embeddings.unsqueeze(0).unsqueeze(0).repeat(1,input_embeddings.shape[1],1).repeat(input_embeddings.shape[0],1,1)
        
        
        # 在错误探查网络detection network输出一个句子中每个位置的字符为错误拼写字符的概率之后，利用此概率作为[MASK] embeddings的权重，
        # 而1减去这个概率作为句子中每个字符character的input embeddings的权重，[MASK] embeddings乘以权重的结果再加上input embeddings乘以权重的结果后
        # 所得到的嵌入结果soft-masked embeddings即为之后的拼写错误纠正网络correction network的输入。
        soft_masked_embeddings = self.Soft_Masking_Connection(input_embeddings=input_embeddings, mask_embeddings=self.mask_embeddings,
                                                             soft_masking_coefs=soft_masking_coefs)
        
        
        '''拼写错误纠正网络Correction_Network'''
        '''soft_masked_embeddings输入错误纠正网络correction network的Bert模型后的结果经过最后的输出层与Softmax层后，
        即为句子中每个位置的字符经过错误纠正网络correction network计算后预测的正确字符索引结果的概率。'''
        
        '''注意: 最新版本的transformers.modeling_bert中的BertEncoder()类中forward()方法所需传入的参数中不再有output_attentions这个参数.'''
        encoder_outputs = self.encoder(soft_masked_embeddings, 
                                       attention_mask=extended_attention_mask, 
                                       head_mask=head_mask, 
                                       encoder_hidden_states=encoder_hidden_states, 
                                       encoder_attention_mask=encoder_extended_attention_mask,)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)
        
        # add hidden_states and attentions if they are here
        # outputs为一个包含四个元素的tuple：sequence_output, pooled_output, (hidden_states), (attentions)
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
        
        
        
        # outputs[0]代表Bert模型中最后一个隐藏层的输出(此时Bert模型中的隐藏层有12层,即num_hidden_layers参数为12),
        # 注意此处和循环神经网络的输出形状不同,循环网络隐藏层状态的输出为(seq_len, batch_size, bert_hidden_size)，
        # 此时outputs[0]的张量bert_output_final_hidden_layer的形状为(batch_size, seq_len, bert_hidden_size)—>(batch_size, seq_len, 768).
        bert_output_final_hidden_layer = outputs[0]
        
        # 注意!: 在soft_masked_embeddings输入拼写错误纠正网络correction network中的Bert模型后,其计算结果输入进最终的输出层与Softmax层之前，
        # 拼写错误纠正网络correction network的结果需通过残差连接residual connection与输入模型一开始的input embeddings相加，
        # 相加的结果才输入最终的输出层与Softmax层中做最终的正确字符预测。
        residual_connection_outputs = bert_output_final_hidden_layer + input_embeddings
    
        
        '''self.soft_masked_bert_dense_out即为拼写错误纠正网络correction network之后的输出层, 其会将经过残差连接模块residual connection之后
           的输出的维度由768投影到纠错词表的索引空间. (此处输出层self.soft_masked_bert_dense_out的输出final_outputs张量即可被视为Soft_Masked_BERT模型的最终输出).'''
        final_outputs = self.soft_masked_bert_dense_out(residual_connection_outputs)        
        
        
        # 此处输出层self.soft_masked_bert_dense_out的输出final_outputs张量即可被视为Soft_Masked_BERT模型的最终输出.
        return final_outputs
      

In [9]:
# test
config = BertConfig.from_pretrained("./bert_chinese_model/bert_config.json")
tokenizer = BertTokenizer.from_pretrained('./bert_chinese_model/vocab.txt')
soft_masked_bert = Soft_Masked_BERT.from_pretrained("./bert_chinese_model/pytorch_model.bin", config=config)

text = '上海的填空美'
token = tokenizer.tokenize(text)
ids = tokenizer.convert_tokens_to_ids(token)
input_ids = torch.Tensor([ids]).long()
# input_ids = torch.Tensor([[101,768,867,117,102,0]]).long()
attention_mask = torch.Tensor([[1,1,1,1,1,0]]).long()
token_type_ids = torch.Tensor([[0,0,0,0,0,0]]).long()
position_ids = torch.Tensor([[0,1,2,3,4,5]]).long()

output = soft_masked_bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids)
print(output.shape)
output

Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated
Some weights of Soft_Masked_BERT were not initialized from the model checkpoint at ./bert_chinese_model/pytorch_model.bin and are newly initialized: ['enc_bi_gru.weight_ih_l0', 'enc_bi_gru.weight_hh_l0', 'enc_bi_gru.bias_ih_l0', 'enc_bi_gru.bias_hh_l0', 'enc_bi_gru.weight_ih_l0_reverse', 'enc_bi_gru.weight_hh_l0_reverse', 'enc_bi_gru.bias_ih_l0_reverse', 'enc_bi_gru.bias_hh_l0_reverse', 'detection_network_dense_out.weight', 'detection_network_dense_out.bias', 'soft_masking_coef_mapping.weight', 'soft_masking_coef_mapping.bias', 'soft_masked_bert_dense_out.weight', 'soft_masked_bert_dense_out.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([1, 6, 21128])


tensor([[[ 0.9036,  0.2973, -0.3927,  ..., -0.0348,  0.7091, -0.8938],
         [ 0.2234, -0.5968, -0.2733,  ..., -0.9220,  0.5116, -0.6441],
         [ 0.8974, -0.2807,  0.1874,  ..., -0.1327,  0.5778, -1.0925],
         [ 2.0729, -1.0279, -0.9150,  ..., -1.0137,  0.4178,  0.4004],
         [ 1.4600, -0.5652,  0.6940,  ...,  0.5275,  0.7452, -0.6747],
         [ 0.7080, -0.9226, -0.4265,  ...,  0.8992, -0.6006,  0.1721]]],
       grad_fn=<AddBackward0>)

In [27]:
words = []
for i in output[0]:
    ids = torch.argmax(i)
    print(ids)
    tokens = tokenizer.convert_ids_to_tokens([ids])
    string = tokenizer.convert_tokens_to_string(tokens)
    words.append(string)
text = "".join(words)

tensor(11286)
tensor(1228)
tensor(8451)
tensor(20571)
tensor(14735)
tensor(9943)


In [26]:
text

'##more労love##頁##嘿##hy'