### Dataset preparation

In [None]:
import re
import torch
import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)

sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
word_list = list(set(" ".join(sentences).split(" ")))
word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, word in enumerate(word_list):
    word2idx[word] = i + 4
idx2word = {i: word for word, i in word2idx.items()}
vocab_size = len(word2idx)

token_list = list()
for sentence in sentences:
    arr = [word2idx[word] for word in sentence.split(" ")]
    token_list.append(arr)

token_list

[[28, 26, 30, 19, 20, 35, 5],
 [28, 5, 18, 33, 22, 27, 14, 17, 15, 19],
 [14, 15, 19, 4, 26, 30, 19, 29],
 [9, 18, 34, 23, 36, 7, 24],
 [31, 38, 27],
 [11, 19, 5],
 [21, 30, 19, 16, 29],
 [20, 35, 16, 8, 37, 25, 19],
 [20, 35, 16, 17, 12, 18, 10, 6, 22, 13, 39, 32]]

### Model parameters

In [2]:
maxlen = 30 # set the length of token number in a batch, padding with [PAD] token
batch_size = 6
max_pred = 5 # the max number for the prediction token in the masked token prediction task
n_layers = 6 # the encoder layer number
n_heads = 12 # the number of head in multihead attention
d_model = 768 # the token embedding, segement embedding, position embedding dimension
d_ff = d_model * 4 # the dimension of the FFN layer in the encoder layer
d_kq = d_v = 64 # dimension of K, Q, V
n_segments = 2 # the number of the sentence for encoder input

### Data preprocessing
* ##### The 15% of the token in the sentence nned to be replaced or masked in one sentence
* ##### Two random sentences need to be concatenate 

In [3]:
def make_data():
    batch = list()
    positive, negative = 0, 0 
    # postive: number of two continuous sentences
    # negative: number of two not continuous sentences
    # the ratio should be around 1:1
    while positive != batch_size / 2 or negative != batch_size / 2:
        tokens_a_idx, tokens_b_idx = random.randint(0, len(sentences)-1), random.randint(0, len(sentences)-1)
        tokens_a, tokens_b = token_list[tokens_a_idx], token_list[tokens_b_idx]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # masking
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
        token_maked_pos = [i for i, token in enumerate(input_ids) if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]
        random.shuffle(token_maked_pos)
        masked_tokens, masked_pos = list(), list()
        for pos in token_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random.random() < 0.8:
                input_ids[pos] = word2idx['[MASK]']
            elif random.random() > 0.9:
                index = random.randint(4, vocab_size-1)
                input_ids[pos] = index

        # zero-padding for sentence
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # zero padding for mask
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_idx + 1 == tokens_b_idx and positive < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
            positive += 1
        elif tokens_a_idx + 1 != tokens_b_idx and negative < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
            negative += 1

    return batch

len(make_data()), make_data()[1]

(6,
 [[1,
   21,
   30,
   19,
   3,
   3,
   2,
   9,
   18,
   34,
   23,
   36,
   7,
   24,
   2,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [29, 16, 0, 0, 0],
  [5, 4, 0, 0, 0],
  False])

In [4]:
class MyDataset(Dataset):
    def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, is_next):
          super().__init__()
          self.input_ids = input_ids
          self.segment_ids = segment_ids
          self.masked_tokens = masked_tokens
          self.masked_pos = masked_pos
          self.is_next = is_next

    def __getitem__(self, idx):
        return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.is_next[idx]

    def __len__(self):
         return len(self.input_ids)

In [5]:
batch = make_data()
input_ids, segment_ids, masked_tokens, masked_pos, is_next = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, is_next = \
    torch.LongTensor(input_ids),  torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens),\
    torch.LongTensor(masked_pos), torch.LongTensor(is_next)
loader = DataLoader(MyDataset(input_ids, segment_ids, masked_tokens, masked_pos, is_next), batch_size=batch_size, shuffle=True)


### Model structure

In [6]:
def get_atten_pad_mask(seq_q):
    batch_size, seq_len = seq_q.size()
    pad_atten_mask = (seq_q == 0).unsqueeze(dim=-1) # (batch_size, seq_len, 1)
    return pad_atten_mask.expand(batch_size, seq_len, seq_len)

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, atten_mask, d_kq):
        scores = Q @ K.transpose(-1, -2) / d_kq**0.5
        scores.masked_fill_(atten_mask, -1e9)
        atten = F.softmax(scores, dim=-1)
        context = atten @ V
        return context

class MutliHeadAttention(nn.Module):
    def __init__(self, d_model, d_kq, d_v, n_heads):
        super().__init__()
        self.wq = nn.Linear(d_model, d_kq * n_heads)
        self.wk = nn.Linear(d_model, d_kq * n_heads)
        self.wv = nn.Linear(d_model, d_v * n_heads)
        self.d_model = d_model
        self.d_kq = d_kq
        self.d_v = d_v
        self.n_heads = n_heads

    def forward(self, Q, K, V, atten_mask):
        # Q, K, V: (batch_size, seq_len, d_model)
        residual = Q
        batch_size, seq_len, d_model = Q.shape
        q_s = self.wq(Q).reshape(batch_size, -1, self.n_heads, self.d_kq).permute(0, 2, 1, 3)
        k_s = self.wk(K).reshape(batch_size, -1, self.n_heads, self.d_kq).permute(0, 2, 1, 3)
        v_s = self.wv(V).reshape(batch_size, -1, self.n_heads, self.d_v).permute(0, 2, 1, 3)
        # q_s, k_s, v_s: (batch_size, n_heads, seq_len, d_qkv)

        atten_mask = atten_mask.unsqueeze(dim=1).repeat(1, self.n_heads, 1, 1) # (batch_size, n_heads, seq_len, seq_len)

        context = ScaledDotProductAttention()(q_s, k_s, v_s, atten_mask, self.d_kq)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.n_heads * self.d_v)
        output = nn.Linear(self.n_heads * self.d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual)
    
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc_atten = MutliHeadAttention(d_model, d_kq, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_input, enc_atten_mask):
        enc_output = self.enc_atten(enc_input, enc_input, enc_input, enc_atten_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output
    
class Embedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(maxlen, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.shape[1]
        pos = torch.arange(seq_len)
        pos = pos.unsqueeze(dim=0).expand(*x.shape)
        embedding = self.token_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)


class BERT(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.enc_atten = MutliHeadAttention(d_model, d_kq, d_v, n_heads)
        self.embedding = Embedding()
        self.layers = nn.ModuleList([
            EncoderLayer() for _ in range(n_layers)
        ])
        self.fc = nn.Sequential(
            nn.Linear(self.d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh()
        )
        self.classifier = nn.Linear(d_model, 2)
        self.linear = nn.Linear(d_model, d_model)
        self.gelu = nn.GELU()
        embed_weight = self.embedding.token_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_atten_mask = get_atten_pad_mask(input_ids)
        for layer in self.layers:
            output = layer(output, enc_atten_mask)
        # get [CLS] token
        h_pooled = self.fc(output[:, 0])
        # for is_next prediction
        logits_clsf = self.classifier(h_pooled)

        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model)
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.gelu(self.linear(h_masked))
        logits_lm = self.fc2(h_masked)

        return logits_clsf, logits_lm

model = BERT(d_model)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters(), lr=0.001)    

In [7]:
for epoch in range(180):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
      logits_clsf, logits_lm = model(input_ids, segment_ids, masked_pos)
      loss_lm = loss_fn(logits_lm.view(-1, vocab_size), masked_tokens.view(-1)) # for masked LM
      loss_lm = (loss_lm.float()).mean()
      loss_clsf = loss_fn(logits_clsf, isNext) # for sentence classification
      loss = loss_lm + loss_clsf
      if (epoch + 1) % 10 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

Epoch: 0010 loss = 1.447332
Epoch: 0020 loss = 1.007936
Epoch: 0030 loss = 0.873854
Epoch: 0040 loss = 0.855126
Epoch: 0050 loss = 0.875040
Epoch: 0060 loss = 0.851540
Epoch: 0070 loss = 0.885756
Epoch: 0080 loss = 0.897899
Epoch: 0090 loss = 0.799609
Epoch: 0100 loss = 0.883441
Epoch: 0110 loss = 0.860890
Epoch: 0120 loss = 0.818617
Epoch: 0130 loss = 0.829771
Epoch: 0140 loss = 0.803633
Epoch: 0150 loss = 0.789039
Epoch: 0160 loss = 0.897564
Epoch: 0170 loss = 0.775942
Epoch: 0180 loss = 0.852835


In [8]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[0]
print(text)
print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])

logits_clsf, logits_lm = model(torch.LongTensor([input_ids]), \
                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)

Hello, how are you? I am Romeo.
Hello, Romeo My name is Juliet. Nice to meet you.
Nice meet you too. How are you today?
Great. My baseball team won the competition.
Oh Congratulations, Juliet
Thank you Romeo
Where are you going today?
I am going shopping. What about you?
I am going to visit my grandmother. she is not very well
['[CLS]', 'thank', 'you', 'romeo', '[SEP]', 'great', 'my', 'baseball', 'team', 'won', 'hello', 'competition', '[SEP]']
masked tokens list :  [7]
predict masked tokens list :  [7]
isNext :  False
predict isNext :  True
