In [1]:
import json
import multiprocessing
import os
import torch
from torch import nn
import sys
sys.path.append('..')
from d2l_helpers import *

In [2]:
# Pick the compute device and load the 'bert.small' pretrained model together
# with its vocabulary. NOTE(review): the architecture arguments presumably
# have to match the saved checkpoint -- confirm against load_pretrained_model.
devices = [get_device()]
bert, vocab = load_pretrained_model(
    'bert.small',
    num_hiddens=256,
    ffn_num_hiddens=512,
    num_heads=4,
    num_layers=2,
    dropout=0.1,
    max_len=512,
    devices=devices,
)

In [4]:
# If you hit an out-of-memory error, reduce `batch_size`.
# In the original BERT model, max_len=512.
batch_size = 512
max_len = 128
num_workers = 4
data_dir = '../data/snli_1.0'

# Build the SNLI training and test datasets from the files under `data_dir`.
train_set = SNLIBERTDataset(read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(read_snli(data_dir, False), max_len, vocab)

# Only the training loader shuffles; both start their worker processes with
# the "forkserver" multiprocessing context.
train_iter = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    multiprocessing_context="forkserver",
)
test_iter = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    num_workers=num_workers,
    multiprocessing_context="forkserver",
)

read 549367 examples
read 9824 examples


In [4]:
class BERTClassifier(nn.Module):
    """Classification head on top of a pretrained BERT model.

    Reuses the ``encoder`` and ``hidden`` submodules of ``bert`` and adds a
    new linear output layer that is trained from scratch.

    Args:
        bert: Pretrained model exposing ``encoder`` and ``hidden`` attributes.
        num_hiddens: Size of the feature vector produced by ``bert.hidden``.
            Defaults to 256, the value previously hard-coded here.
        num_classes: Number of output classes. Defaults to 3, the value
            previously hard-coded here.
    """

    def __init__(self, bert, num_hiddens=256, num_classes=3):
        super().__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(num_hiddens, num_classes)

    def forward(self, inputs):
        """Return class logits for a batch.

        Args:
            inputs: A 3-tuple ``(tokens_X, segments_X, valid_lens_x)`` that is
                forwarded unchanged to the encoder.

        Returns:
            The output layer applied to ``hidden(encoded_X[:, 0, :])``.
        """
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        # Only index 0 along axis 1 is fed to the classifier
        # (presumably the '<cls>' token representation -- TODO confirm).
        return self.output(self.hidden(encoded_X[:, 0, :]))

In [5]:
# Build the fine-tuning model by wrapping the pretrained `bert` in BERTClassifier.
net = BERTClassifier(bert)

In [6]:
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # 前向传播
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # 计算遮蔽语言模型损失
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # 计算下一句子预测任务的损失
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l

In [7]:
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps, lr=0.01):
    """Pretrain ``net`` on batches from ``train_iter`` for ``num_steps`` steps.

    ``train_iter`` is re-iterated as many times as needed to reach
    ``num_steps``. Running MLM/NSP loss averages are printed after every
    step, followed by a final summary with throughput.

    Args:
        train_iter: Iterable yielding 7-tuples ``(tokens_X, segments_X,
            valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)``.
        net: Model to train; it is wrapped in ``nn.DataParallel``.
        loss: Loss callable passed to ``_get_batch_loss_bert``.
        vocab_size: Vocabulary size passed to ``_get_batch_loss_bert``.
        devices: Devices for ``nn.DataParallel``; batches are moved to
            ``devices[0]``.
        num_steps: Number of optimizer steps to run.
        lr: Adam learning rate. Defaults to 0.01, the value previously
            hard-coded here.

    Raises:
        ValueError: If a full pass over ``train_iter`` yields no batches.
            Without this check the outer loop would never terminate.
    """
    print(f'training on {devices}')
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    step, timer = 0, Timer()
    # Sum of masked language model losses, sum of next sentence prediction
    # losses, number of sentence pairs, step count
    metric = Accumulator(4)
    # The inner `break` fires exactly when step == num_steps, so the loop
    # condition alone is enough to stop the outer loop.
    while step < num_steps:
        steps_before_pass = step
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            print(f'iter {step}, MLM loss {metric[0] / metric[3]:.3f}, '
                  f'NSP loss {metric[1] / metric[3]:.3f}')
            step += 1
            if step == num_steps:
                break
        if step == steps_before_pass:
            # An empty iterator can never advance `step`; fail loudly
            # instead of spinning forever.
            raise ValueError('train_iter yielded no batches')

    if step == 0:
        # num_steps <= 0: nothing was accumulated, so the averages below
        # would divide by a zero step count.
        print('no training steps were run')
        return

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')

In [8]:
# Fine-tuning hyperparameters.
lr = 1e-4
num_epochs = 5
# reduction='none' makes the loss return one value per example.
# NOTE(review): train_ch13 presumably reduces these itself -- confirm.
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

training on [device(type='mps')]
epoch: 0, loss 0.952, train acc 0.532
epoch: 0, loss 0.884, train acc 0.583
epoch: 0, loss 0.848, train acc 0.609
epoch: 0, loss 0.822, train acc 0.626
epoch: 0, loss 0.802, train acc 0.639
epoch: 0, loss 0.802, train acc 0.639
epoch: 1, loss 0.687, train acc 0.711
epoch: 1, loss 0.684, train acc 0.712
epoch: 1, loss 0.676, train acc 0.716
epoch: 1, loss 0.672, train acc 0.718
epoch: 1, loss 0.667, train acc 0.720
epoch: 1, loss 0.667, train acc 0.720
epoch: 2, loss 0.611, train acc 0.749
epoch: 2, loss 0.611, train acc 0.749
epoch: 2, loss 0.610, train acc 0.749
epoch: 2, loss 0.607, train acc 0.750
epoch: 2, loss 0.605, train acc 0.751
epoch: 2, loss 0.605, train acc 0.751
epoch: 3, loss 0.558, train acc 0.773
epoch: 3, loss 0.561, train acc 0.772
epoch: 3, loss 0.563, train acc 0.772
epoch: 3, loss 0.563, train acc 0.772
epoch: 3, loss 0.562, train acc 0.772
epoch: 3, loss 0.562, train acc 0.772
epoch: 4, loss 0.524, train acc 0.790
epoch: 4, loss 0.