In [1]:
from pickle import load
import numpy as np
from torch import nn
import torch
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy,softmax, relu

import utils
from GPT import GPT
import os
import pickle

In [2]:
# Fraction of candidate token positions selected per sequence for the MLM loss
# (consumed by _get_loss_mask, which also enforces a floor of 2 positions).
MASK_RATE = 0.15

In [3]:
class BERT(GPT):
    """BERT-style model built on top of the project's ``GPT`` class.

    Adds a joint training step for masked language modelling (MLM) and
    next-sentence prediction (NSP), and defines a padding-only ``mask``.
    The layers, the forward pass and the optimizer ``self.opt`` are expected
    to come from ``GPT`` (not visible in this notebook).
    """

    def __init__(self,model_dim,max_len,num_layer,num_head,n_vocab,lr,
                 max_seg=3,drop_rate=0.2,padding_idx=0) -> None:
        # Delegate all construction to the parent initializer so the inherited
        # attributes are created and the parent's setup logic runs unchanged.
        super().__init__(model_dim,max_len,num_layer,num_head,n_vocab,lr,max_seg,drop_rate,padding_idx)

    def step(self, seqs, segs, seqs_, loss_mask, nsp_labels):
        """
        Run one forward pass, loss computation, backward pass and optimizer update.

        :param seqs: input token ids (already masked / replaced by the caller)
        :param segs: segment ids aligned with ``seqs``
        :param seqs_: target token ids (the uncorrupted sequence)
        :param loss_mask: boolean mask of the positions that contribute to the MLM
            loss; it must broadcast against the MLM logits and have a size-1
            axis at index 2 (the caller in this notebook passes
            ``(batch, seq_len, 1)``)
        :param nsp_labels: next-sentence-prediction labels, flattened before use
        :return: tuple ``(loss, mlm_logits)`` where ``loss`` is the combined loss as
            a NumPy scalar array on the CPU and ``mlm_logits`` is the raw MLM output
        """
        self.opt.zero_grad()
        # Calling the module runs forward(); training=True is forwarded to it.
        mlm_logits, nsp_logits = self(seqs, segs, training=True)
        vocab_size = mlm_logits.shape[2]
        # Keep only the selected positions: masked_select flattens the chosen
        # logits, so they are reshaped back to one row of vocab scores per position.
        mlm_loss = cross_entropy(
            torch.masked_select(mlm_logits, loss_mask).reshape(-1, vocab_size),
            torch.masked_select(seqs_, loss_mask.squeeze(2)),
        )
        nsp_loss = cross_entropy(nsp_logits, nsp_labels.reshape(-1))
        # The NSP term is down-weighted relative to the MLM term.
        loss = mlm_loss + 0.2 * nsp_loss
        loss.backward()
        self.opt.step()
        # detach() is the supported way to read the value without autograd history.
        return loss.detach().cpu().numpy(), mlm_logits

    def mask(self, seqs):
        """
        Build a boolean padding mask for ``seqs``.

        Entries are True where the token equals ``self.padding_idx``. Two
        singleton axes are inserted after the first axis, so a
        ``(batch, seq_len)`` input yields ``(batch, 1, 1, seq_len)``.
        NOTE(review): presumably consumed by the attention layers in ``GPT`` to
        ignore padded positions -- confirm against the parent class.
        """
        mask = torch.eq(seqs, self.padding_idx)
        return mask[:, None, None, :]

# torch.masked_select()

`torch.masked_select` 是一个用于按照指定掩码从张量中选择元素的函数。它的作用是根据给定的掩码，在输入张量中选择满足条件的元素。

具体来说，对于第一行代码 `torch.masked_select(mlm_logits, loss_mask).reshape(-1, mlm_logits.shape[2])`：

- `mlm_logits` 是 MLM 模型的预测结果，它是一个形状为 `[batch_size, seq_len, n_vocab]` 的张量，表示每个位置对应每个词的未归一化得分（logits，尚未经过 softmax，因此还不是概率）。
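- `loss_mask` 是损失掩码，它是一个形状为 `[batch_size, seq_len, 1]` 的布尔类型张量（由 `random_mask_or_replace` 中的 `unsqueeze(2)` 得到），用于标识被随机选中、需要参与 MLM 损失计算的位置（而不是填充位置）；最后一维的 1 会在 `masked_select` 中广播到 `n_vocab`。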
- `torch.masked_select(mlm_logits, loss_mask)` 使用 `loss_mask` 对 `mlm_logits` 进行掩码选择，即只选择掩码为 True 的位置对应的元素。
- `reshape(-1, mlm_logits.shape[2])` 将选择的元素进行形状重塑，将第二维的大小调整为 `mlm_logits.shape[2]`，也就是词汇表大小，以适应后续计算交叉熵损失的需要。

对于第二行代码 `torch.masked_select(seqs_, loss_mask.squeeze(2))`：

- `seqs_` 是目标序列，形状为 `[batch_size, seq_len]`，表示每个位置的真实词索引。
- `loss_mask.squeeze(2)` 对 `loss_mask` 进行压缩操作，去除索引为 2 的那个大小为 1 的维度，得到形状为 `[batch_size, seq_len]` 的张量，与 `seqs_` 的形状一致。
- `torch.masked_select(seqs_, loss_mask.squeeze(2))` 使用 `loss_mask.squeeze(2)` 对 `seqs_` 进行掩码选择，即只选择掩码为 True 的位置对应的目标词索引。

这两行代码的目的是根据掩码选择模型的预测结果和目标序列，以便后续计算 MLM 的交叉熵损失。

In [6]:
def _get_loss_mask(len_arange, seq, pad_id):
    """生成用于掩码语言建模任务的掩码数组,生成的掩码数组可以用于在训练过程中计算 MLM 的损失函数"""
    rand_id = np.random.choice(len_arange, size=max(2, int(MASK_RATE * len(len_arange))), replace=False)    #从序列长度范围 len_arange 中随机选择一些索引作为要进行掩码的位置
    loss_mask = np.full_like(seq, pad_id, dtype=np.bool)    #创建一个与输入序列 seq 形状相同的数组 loss_mask，并将其填充为布尔类型的 pad_id
    loss_mask[rand_id] = True   #将选择的掩码位置在 loss_mask 中标记为 True，表示这些位置要进行掩码
    return loss_mask[None, :], rand_id  #返回形状为 [1, seq_len] 的掩码数组 loss_mask，以及选择的掩码位置索引 rand_id

def do_mask(seq, len_arange, pad_id, mask_id):
    """Overwrite randomly chosen positions of ``seq`` with ``mask_id`` (in place).

    The positions come from ``_get_loss_mask``; its loss mask is returned so the
    caller knows which positions were corrupted.
    """
    selection = _get_loss_mask(len_arange, seq, pad_id)
    row_mask, positions = selection
    seq[positions] = mask_id
    return row_mask

def do_replace(seq, len_arange, pad_id, word_ids):
    """Overwrite randomly chosen positions of ``seq`` with random ids (in place).

    One id is drawn from ``word_ids`` (with replacement) for every position
    selected by ``_get_loss_mask``; the corresponding loss mask is returned.
    """
    row_mask, positions = _get_loss_mask(len_arange, seq, pad_id)
    # Draw exactly as many substitute ids as there are selected positions.
    sampled_ids = np.random.choice(word_ids, size=len(positions))
    substitutes = torch.from_numpy(sampled_ids).type(torch.IntTensor)
    seq[positions] = substitutes
    return row_mask

def do_nothing(seq, len_arange, pad_id):
    """Leave ``seq`` untouched and return only the loss mask of the selected positions."""
    return _get_loss_mask(len_arange, seq, pad_id)[0]

In [7]:
def random_mask_or_replace(data,arange,dataset):
    """Corrupt a batch for MLM training using one randomly chosen strategy.

    A single random draw decides the strategy for the whole batch: below 0.7
    the selected positions are overwritten with ``dataset.mask_id``, from 0.7
    to below 0.85 the input is left as is, otherwise the selected positions
    are overwritten with random ids from ``dataset.word_ids``. ``seqs`` is
    modified in place; ``seqs_`` is a copy taken beforehand and serves as
    the target.

    :param data: batch tuple ``(seqs, segs, xlen, nsp_labels)``
    :param arange: array of position indices that candidate positions are sliced from
    :param dataset: object providing ``pad_id``, ``mask_id`` and ``word_ids``
    :return: ``(seqs, segs, seqs_, loss_mask, xlen, nsp_labels)`` where
        ``loss_mask`` is a torch tensor with an extra trailing axis of size 1
    """
    seqs, segs, xlen, nsp_labels = data
    seqs_ = seqs.data.clone()

    draw = np.random.random()
    if draw < 0.7:
        # mask
        corrupt, extra_args = do_mask, (dataset.mask_id,)
    elif draw < 0.85:
        # do nothing
        corrupt, extra_args = do_nothing, ()
    else:
        # replace
        corrupt, extra_args = do_replace, (dataset.word_ids,)

    row_masks = []
    for i in range(len(seqs)):
        first_len = xlen[i, 0]
        # Candidates are the indices [0, first_len) plus
        # [first_len + 1, xlen[i].sum() + 1); index first_len itself is skipped
        # (presumably the separator between the two sentences -- TODO confirm).
        candidates = np.concatenate((arange[:first_len], arange[first_len + 1:xlen[i].sum() + 1]))
        row_masks.append(corrupt(seqs[i], candidates, dataset.pad_id, *extra_args))

    stacked = np.concatenate(row_masks, axis=0)
    loss_mask = torch.from_numpy(stacked).unsqueeze(2)
    return seqs, segs, seqs_, loss_mask, xlen, nsp_labels

In [9]:
def export_attention(model,device,data,name="bert"):
    """Reload the saved weights, run one batch and pickle its attention maps.

    :param model: model to load the checkpoint into; must expose ``attentions``
        after a forward pass
    :param device: torch device used for ``map_location`` and for the input tensors
    :param data: dataset object; ``data[:32]`` must yield
        ``(seqs, segs, xlen, nsp_labels)`` as NumPy arrays and ``data.i2v`` must
        map token ids to strings
    :param name: prefix of the output file
        ``./visual/tmp/<name>_attention_matrix.pkl``
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model.load_state_dict(torch.load("./visual/models/bert/model.pth",map_location=device))
    # Only the token and segment ids are needed for the forward pass.
    seqs, segs, _, _ = data[:32]
    seqs = torch.from_numpy(seqs).type(torch.LongTensor).to(device)
    segs = torch.from_numpy(segs).type(torch.LongTensor).to(device)
    # Inference only: skip autograd bookkeeping. The third positional argument
    # is the forward pass's `training` flag.
    with torch.no_grad():
        model(seqs,segs,False)
    token_ids = seqs.cpu().numpy()
    # Use a separate name for the payload so the `data` argument is not shadowed.
    payload = {
        "src": [[data.i2v[i] for i in row] for row in token_ids],
        "attentions": model.attentions,
    }
    path = "./visual/tmp/%s_attention_matrix.pkl" % name
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(payload, f)

In [8]:
def train():
    """Train BERT on the MRPC data, then save the weights and export attention maps.

    Every 100th batch of each epoch prints the loss together with the first
    sample's input tokens, the model's predictions, and the target / predicted
    words at the positions selected for the MLM loss.
    """
    model_dim, n_layer, learning_rate = 256, 4, 1e-4
    dataset = utils.MRPCData("./MRPC", 2000)
    print("num word: ", dataset.num_word)
    model = BERT(
        model_dim=model_dim,
        max_len=dataset.max_len,
        num_layer=n_layer,
        num_head=4,
        n_vocab=dataset.num_word,
        lr=learning_rate,
        max_seg=dataset.num_seg,
        drop_rate=0.2,
        padding_idx=dataset.pad_id,
    )

    on_gpu = torch.cuda.is_available()
    if on_gpu:
        print("GPU train avaliable")
    device = torch.device("cuda" if on_gpu else "cpu")
    model = model.cuda() if on_gpu else model.cpu()

    def as_long(tensor):
        # Cast to int64 on the CPU first, then move to the training device.
        return tensor.type(torch.LongTensor).to(device)

    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    positions = np.arange(0, dataset.max_len)
    for epoch in range(500):
        for batch_idx, batch in enumerate(loader):
            seqs, segs, seqs_, loss_mask, xlen, nsp_labels = random_mask_or_replace(batch, positions, dataset)
            seqs, segs, seqs_ = as_long(seqs), as_long(segs), as_long(seqs_)
            nsp_labels = nsp_labels.to(device)
            loss_mask = loss_mask.to(device)
            loss, pred = model.step(seqs, segs, seqs_, loss_mask, nsp_labels)
            if batch_idx % 100 != 0:
                continue

            # Report on the first sample of the batch only.
            pred = pred[0].cpu().data.numpy().argmax(axis=1)
            shown_len = xlen[0].sum() + 1
            row_mask = loss_mask[0].view(-1)
            pad_token = dataset.v2i["<PAD>"]
            input_ids = seqs[0].cpu().data.numpy()[:shown_len]
            # Multiplying by the mask zeroes unselected positions; entries equal
            # to the <PAD> id are then dropped from the word lists.
            target_ids = (seqs_[0] * row_mask).cpu().data.numpy()
            predicted_ids = pred * (row_mask.cpu().data.numpy())
            print(
                "\n\nEpoch: ", epoch,
                "|batch: ", batch_idx,
                "| loss: %.3f" % loss,
                "\n| tgt: ", " ".join([dataset.i2v[i] for i in input_ids]),
                "\n| prd: ", " ".join([dataset.i2v[i] for i in pred[:shown_len]]),
                "\n| tgt word: ", [dataset.i2v[i] for i in target_ids if i != pad_token],
                "\n| prd word: ", [dataset.i2v[i] for i in predicted_ids if i != pad_token],
            )

    # Weights are written once, after the last epoch.
    os.makedirs("./visual/models/bert", exist_ok=True)
    torch.save(model.state_dict(), "./visual/models/bert/model.pth")
    export_attention(model, device, dataset)

In [10]:
# Entry point: start training when this code runs as the main module
# (which includes executing this cell in the notebook kernel).
if __name__ == "__main__":
    train()

num word:  12880
GPU train avaliable


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  loss_mask = np.full_like(seq, pad_id, dtype=np.bool)    #创建一个与输入序列 seq 形状相同的数组 loss_mask，并将其填充为布尔类型的 pad_id




Epoch:  0 |batch:  0 | loss: 9.824 
| tgt:  <GO> rumsfeld , who has been feuding for two years with army leadership , passed over nine active-duty four-star generals . <SEP> rumsfeld has been feuding for a long time with army leadership , and he passed over nine active-duty four-star generals 
| prd:  solutia tarantella prepares byers dementia vaudevillians permit relationship consumed cunha walker step utilities walker dementia blocks loan related specific dropping lesser detection surged folders realized abetted dreams status giadone nevada milledge initiative entree realized bishops within hoped realized columbus scoffed issues unofficial 
| tgt word:  ['years', 'passed', '<SEP>', 'over', 'active-duty', 'generals'] 
| prd word:  ['cunha', 'dementia', 'detection', 'realized', 'scoffed', 'unofficial']


Epoch:  1 |batch:  0 | loss: 7.764 
| tgt:  <GO> shares of <MASK> interactive rose $ <NUM> , or <NUM> percent , to $ <MASK> on <MASK> in nasdaq <MASK> market composite trading and ha

KeyboardInterrupt: 