Reference:
1. [How to Code BERT Using PyTorch](https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial)

In [1]:
maxlen = 30  
batch_size = 6  
max_pred = 5  
n_layers = 6  
n_heads = 12  
d_model = 768  
d_ff = 768*4  
d_k = d_v = 64  
n_segments = 2  
num_epochs = 40

In [2]:
import re
import numpy as np

text = (
    'Hello, how are you? I am Romeo.\n' 
    'Hello, Romeo My name is Juliet. Nice to meet you.\n'
    'Nice to meet you too. How are you today?\n'
    'Great. My baseball team won the competition.\n'
    'Oh Congratulations, Juliet\n'
    'Thank you Romeo'
)
sentences = re.sub("[.,!?-]", '', text.lower()).split('\n')  # filter '.', ',', '?', '!'
word_list = list(set(" ".join(sentences).split()))
word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_list):
   word_dict[w] = i + 4
   number_dict = {i: w for i, w in enumerate(word_dict)}
vocab_size = len(word_dict)
number_dict, vocab_size

({0: '[PAD]',
  1: '[CLS]',
  2: '[SEP]',
  3: '[MASK]',
  4: 'thank',
  5: 'great',
  6: 'oh',
  7: 'i',
  8: 'how',
  9: 'today',
  10: 'are',
  11: 'name',
  12: 'competition',
  13: 'you',
  14: 'nice',
  15: 'baseball',
  16: 'won',
  17: 'to',
  18: 'congratulations',
  19: 'too',
  20: 'romeo',
  21: 'my',
  22: 'hello',
  23: 'juliet',
  24: 'is',
  25: 'team',
  26: 'am',
  27: 'the',
  28: 'meet'},
 29)

In [4]:
token_list = list()
for sentence in sentences:
   arr = [word_dict[s] for s in sentence.split()]
   token_list.append(arr)
   # print(sentence)
   # print(arr)

token_list

[22, 8, 10, 13, 7, 26, 20]
[22, 20, 21, 11, 24, 23, 14, 17, 28, 13]
[14, 17, 28, 13, 19, 8, 10, 13, 9]
[5, 21, 15, 25, 16, 27, 12]
[6, 18, 23]
[4, 13, 20]


[[22, 8, 10, 13, 7, 26, 20],
 [22, 20, 21, 11, 24, 23, 14, 17, 28, 13],
 [14, 17, 28, 13, 19, 8, 10, 13, 9],
 [5, 21, 15, 25, 16, 27, 12],
 [6, 18, 23],
 [4, 13, 20]]

In [7]:
from random import randrange, shuffle, randint, random


def make_batch():
   batch = []
   positive = negative = 0
   while positive != batch_size/2 or negative != batch_size/2:
       tokens_a_index, tokens_b_index= randrange(len(sentences)), randrange(len(sentences))

       tokens_a, tokens_b= token_list[tokens_a_index], token_list[tokens_b_index]

       input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
       segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

       # MASK LM
       n_pred =  min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15 % of tokens in one sentence
       cand_maked_pos = [i for i, token in enumerate(input_ids)
                         if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
       shuffle(cand_maked_pos)
       masked_tokens, masked_pos = [], []
       for pos in cand_maked_pos[:n_pred]:
           masked_pos.append(pos)
           masked_tokens.append(input_ids[pos])
           if random() < 0.8:  # 80%
               input_ids[pos] = word_dict['[MASK]'] # make mask
           elif random() < 0.5:  # 10%
               index = randint(0, vocab_size - 1) # random index in vocabulary
               input_ids[pos] = word_dict[number_dict[index]] # replace

       # Zero Paddings
       n_pad = maxlen - len(input_ids)
       input_ids.extend([0] * n_pad)
       segment_ids.extend([0] * n_pad)

       # Zero Padding (100% - 15%) tokens
       if max_pred > n_pred:
           n_pad = max_pred - n_pred
           masked_tokens.extend([0] * n_pad)
           masked_pos.extend([0] * n_pad)

       if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
           batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
           positive += 1
       elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
           batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
           negative += 1
   return batch

In [8]:
import torch
import torch.utils.data as Data

batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = (
    torch.LongTensor(input_ids),
    torch.LongTensor(segment_ids),
    torch.LongTensor(masked_tokens),
    torch.LongTensor(masked_pos),
    torch.LongTensor(isNext)
)


class MyDataSet(Data.Dataset):
    def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
        self.masked_pos = masked_pos
        self.isNext = isNext

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]
    
loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)

In [5]:
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)   # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k


In [6]:
import torch

# 定义查询序列和键序列
seq_q = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])  # batch_size x len_q
seq_k = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])  # batch_size x len_k

# 定义 get_attn_pad_mask 函数
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)   # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

# 生成填充掩码
pad_mask = get_attn_pad_mask(seq_q, seq_k)

print("查询序列 (seq_q):")
print(seq_q)
print("\n键序列 (seq_k):")
print(seq_k)
print("\n填充掩码 (pad_mask):")
print(pad_mask)

查询序列 (seq_q):
tensor([[1, 2, 3, 0],
        [4, 5, 0, 0]])

键序列 (seq_k):
tensor([[1, 2, 3, 0],
        [4, 5, 0, 0]])

填充掩码 (pad_mask):
tensor([[[False, False, False,  True],
         [False, False, False,  True],
         [False, False, False,  True],
         [False, False, False,  True]],

        [[False, False,  True,  True],
         [False, False,  True,  True],
         [False, False,  True,  True],
         [False, False,  True,  True]]])


In [None]:
import torch.nn as nn


def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / np.sqrt(2.0)))
                      

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(maxlen, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, segment_ids):
        seq_len = input_ids.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(input_ids)
        embedding = self.tok_embed(input_ids) + self.pos_embed(pos) + self.seg_embed(segment_ids)
        return self.norm(embedding)
    

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual), attn
    

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, inputs):
        return self.fc2(gelu(self.fc1(inputs)))
    

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn
    

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(d_model, 2)
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        embed_weight = self.embedding.tok_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        h_pooled = self.fc(output[:, 0])
        logits_clsf = self.classifier(h_pooled)

        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model)  # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.activ2(self.linear(h_masked))
        logits_lm = self.fc2(h_masked)
        
        return logits_lm, logits_clsf
    
model = BERT()
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
for epoch in range(num_epochs):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
        logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
        loss_lm = criteria(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
        loss_lm = (loss_lm.float()).mean()
        loss_clsf = criteria(logits_clsf, isNext) # for sentence classification
        loss = loss_lm + loss_clsf
        if (epoch + 1) % 10 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# save model
# torch.save(model.state_dict(), 'bert.model')

Epoch: 0010 cost = 1.043114
Epoch: 0020 cost = 0.914230
Epoch: 0030 cost = 0.926474
Epoch: 0040 cost = 0.865482


In [14]:
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[1]
print(text)
print([number_dict[w] for w in input_ids if number_dict[w] != '[PAD]'])

logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ', [pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ', [pos for pos in logits_lm if pos != 0])

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ', True if logits_clsf else False)

Hello, how are you? I am Romeo.
Hello, Romeo My name is Juliet. Nice to meet you.
Nice to meet you too. How are you today?
Great. My baseball team won the competition.
Oh Congratulations, Juliet
Thank you Romeo
['[CLS]', 'great', '[MASK]', 'baseball', 'team', 'won', '[MASK]', 'competition', '[SEP]', 'hello', 'how', 'are', '[MASK]', 'i', 'am', 'romeo', '[SEP]']
masked tokens list :  [14, 4, 17]
predict masked tokens list :  [14, 4, 7]
isNext :  False
predict isNext :  False
