## CS310 Natural Language Processing
## Assignment 5 (part 1): Pretraining BERT with Masked Language Modeling and Next Sentence Prediction on Toy Data

In [None]:
import math
import re
import random
from typing import List, Dict
from pprint import pprint
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

## Preprocessing
We start by defining a small raw text for training. For comparison, the original BERT pretraining corpus (BooksCorpus plus English Wikipedia) contains about 3.3B words.

In [None]:
# Toy pretraining corpus: one sentence per line.
text = '\n'.join([
    'Hello, how are you? I am Romeo.',
    'Hello, Romeo! My name is Juliet. Nice to meet you.',
    'Nice meet you too. How are you today?',
    'Great. My baseball team won the competition.',
    'Oh Congratulations, Juliet',
    'Thank you Romeo',
])

Next, remember that `BERT` uses special tokens during training. The following table explains the purpose of each of them:

| Token      | Purpose |
| ----------- | ----------- |
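| [CLS]      | Always the first token; its final representation is used for classification (next sentence prediction)      |
| [SEP]  |   Separates the two sentences and marks the end of the second one      |
| [PAD]   |  Pads sequences so that all of them in a batch have equal length       |
| [MASK]   |    Replaces a word that the model has to predict (masked language modeling)     |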

These tokens should be included in the word dictionary, where every token and word in the vocabulary is assigned an index number.


In [None]:
# Lower-case the corpus, strip '.', ',', '!', '?' and '-', and keep one sentence per line.
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
word_types = sorted(set(" ".join(sentences).split()))

# Vocabulary: the four special tokens take ids 0-3; the alphabetically sorted words follow from id 4.
word_to_id = {tok: idx for idx, tok in enumerate(['[PAD]', '[CLS]', '[SEP]', '[MASK]'] + word_types)}
id_to_word = {idx: tok for tok, idx in word_to_id.items()}
VOCAB_SIZE = len(word_to_id)

# Convert every sentence into its list of token ids.
tokens_list = [[word_to_id[word] for word in sentence.split()] for sentence in sentences]

# Test
pprint(tokens_list)

# You should expect to see the following results:
# [[10, 11, 5, 28, 12, 4, 20],
#  [10, 20, 16, 17, 13, 14, 18, 24, 15, 28],
#  [18, 15, 28, 26, 11, 5, 28, 25],
#  [9, 16, 6, 21, 27, 23, 7],
#  [19, 8, 14],
#  [22, 28, 20]]

Set hyperparameters:

In [None]:
MAX_LEN = 30   # every input sequence is zero-padded to this many tokens
MAX_PRED = 5   # maximum number of masked (MLM) predictions per sequence

batch_size = 6
n_layers = 6                    # number of stacked encoder layers
n_heads = 12                    # attention heads per multi-head attention layer
d_model = 768                   # hidden / embedding size
d_ff = 4 * d_model              # inner size of the position-wise feed-forward network (3072)
d_k = d_v = d_model // n_heads  # per-head size of queries/keys (d_k) and of values (d_v): 64
n_segments = 2                  # number of segments: sentence A and sentence B

Now the next step will be to create masking. 

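As described in the original paper, BERT randomly selects 15% of the tokens in each sequence for prediction. Keep in mind that special tokens (here `[CLS]` and `[SEP]`) are never selected; we use conditional statements to exclude them.

We first select these 15% of the positions. Following the paper, each selected token is replaced with `[MASK]` 80% of the time, replaced with a random word 10% of the time, and left unchanged the remaining 10% of the time. Once that is done, we add padding.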

Padding is to make sure that all the sentence pairs in the batch are of equal length. 

For example, take the sentence pair: “*The dog is barking at the tree. The cat is walking.*”

With the special tokens and padding added, it looks like this:

“`[CLS]` The dog is barking at the tree. `[SEP]` The cat is walking. `[SEP]` `[PAD]` `[PAD]` `[PAD]`”


In [None]:
def make_batch(tokens_list: List[List[int]], batch_size: int, word_to_id: Dict[str, int]):
    """Sample a batch of sentence pairs for BERT's MLM and NSP pretraining tasks.

    Every example is ``[CLS] A [SEP] B [SEP]``, built from two sentences drawn
    uniformly at random from ``tokens_list`` and zero-padded to ``MAX_LEN``.

    MLM: about 15% of the positions (at least 1, at most ``MAX_PRED``) that do
    not hold ``[CLS]``/``[SEP]`` are selected. Each selected token is replaced
    by ``[MASK]`` with probability 0.8, replaced by a random ordinary word with
    probability 0.1, and left unchanged with probability 0.1.

    NSP: a pair is positive (``is_next = 1``) when B is the sentence that
    directly follows A in ``tokens_list``; otherwise it is negative (``is_next = 0``).

    Args:
        tokens_list: token ids of each sentence, in corpus order.
        batch_size: number of examples; ``batch_size // 2`` are positive and
            the remaining ones are negative.
        word_to_id: vocabulary; must contain '[CLS]', '[SEP]' and '[MASK]'.

    Returns:
        A list of ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]``.
        ``input_ids``/``segment_ids`` have length ``MAX_LEN`` and
        ``masked_tokens``/``masked_pos`` have length ``MAX_PRED``; all are
        zero-padded.

    Raises:
        ValueError: if positive pairs are requested but ``tokens_list`` holds
            fewer than two sentences (no pair could ever be positive), or if a
            sampled pair does not fit into ``MAX_LEN`` tokens.
    """
    n_positive = batch_size // 2
    n_negative = batch_size - n_positive  # odd batch sizes get one extra negative
    if n_positive > 0 and len(tokens_list) < 2:
        raise ValueError('at least two sentences are needed to sample positive NSP pairs')

    cls_id, sep_id, mask_id = word_to_id['[CLS]'], word_to_id['[SEP]'], word_to_id['[MASK]']
    special_ids = {word_to_id[tok] for tok in ('[PAD]', '[CLS]', '[SEP]', '[MASK]') if tok in word_to_id}
    # Candidates for the "random word" replacement: every non-special token id.
    random_word_ids = sorted(set(word_to_id.values()) - special_ids)

    batch = []
    positive = negative = 0

    while positive < n_positive or negative < n_negative:
        sent_a_index, sent_b_index = random.randrange(len(tokens_list)), random.randrange(len(tokens_list))
        tokens_a, tokens_b = tokens_list[sent_a_index], tokens_list[sent_b_index]

        input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
        # Segment 1 = "[CLS] A [SEP]", segment 2 = "B [SEP]"; 0 is reserved for padding.
        segment_ids = [1] * (1 + len(tokens_a) + 1) + [2] * (len(tokens_b) + 1)
        if len(input_ids) > MAX_LEN:
            raise ValueError(f'sentence pair of length {len(input_ids)} exceeds MAX_LEN={MAX_LEN}')

        # The following code is used for the Masked Language Modeling (MLM) task.
        n_pred = min(MAX_PRED, max(1, int(round(len(input_ids) * 0.15))))  # predict at most 15 % of the tokens
        masked_candidates_pos = [i for i, token in enumerate(input_ids)
                                 if token != cls_id and token != sep_id]
        random.shuffle(masked_candidates_pos)
        masked_tokens, masked_pos = [], []
        for pos in masked_candidates_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            dice = random.random()
            if dice < 0.8:                            # 80%: replace with [MASK]
                input_ids[pos] = mask_id
            elif dice < 0.9 and random_word_ids:      # 10%: replace with a random ordinary word
                input_ids[pos] = random.choice(random_word_ids)
            # remaining 10%: keep the original token

        # Make zero paddings (id 0 is [PAD]; Embedding's padding_idx and get_pad_attn_mask use 0).
        n_pad = MAX_LEN - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Pad the MLM targets up to MAX_PRED. Use the real number of masked tokens,
        # which is smaller than n_pred when there are too few candidate positions.
        n_target_pad = MAX_PRED - len(masked_tokens)
        masked_tokens.extend([0] * n_target_pad)
        masked_pos.extend([0] * n_target_pad)

        # The following code is used for the Next Sentence Prediction (NSP) task.
        is_next = sent_b_index == sent_a_index + 1
        if is_next and positive < n_positive:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 1])
            positive += 1
        elif not is_next and negative < n_negative:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 0])
            negative += 1

    return batch

Creating attention mask
BERT needs attention masks in a specific format, and the following code creates them.

It marks every key position that holds `[PAD]` (token id 0) with `1` (`True`) and all other positions with `0` (`False`), so that attention to padding can be suppressed.

In [None]:
def get_pad_attn_mask(seq_q, seq_k):
    """Build a boolean attention mask that hides padding keys.

    Args:
        seq_q: query token ids, a 2-D tensor (batch_size, len_q).
        seq_k: key token ids, a 2-D tensor (batch_size, len_k); id 0 is [PAD].

    Returns:
        Bool tensor of shape (batch_size, len_q, len_k) that is True wherever the
        key position holds [PAD], i.e. where attention must be blocked.
    """
    _, len_q = seq_q.size()
    n_batch, len_k = seq_k.size()
    # One row per batch element, broadcast (without copying) over every query position.
    key_is_pad = seq_k.eq(0)[:, None, :]
    return key_is_pad.expand(n_batch, len_q, len_k)

In [None]:
# Test
# Re-seed Python's RNG so make_batch draws the same sentence pairs and mask positions every run.
random.seed(0)
batch = make_batch(tokens_list, batch_size, word_to_id)
# zip(*batch) regroups the examples field by field; each field becomes one int64 tensor.
input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(*batch))

sample = 2  # index of the example in the batch to inspect

print('sampled text:')
# Map ids back to words, skipping [PAD] for readability.
print([id_to_word[w.item()] for w in input_ids[sample] if id_to_word[w.item()] != '[PAD]'])
print()
print('input_ids:', input_ids[sample])
print('segment_ids:', segment_ids[sample])
print('masked_tokens:', masked_tokens[sample])
print('masked_pos:', masked_pos[sample])
print('is_next:', is_next[sample].item())

# NOTE(review): the exact values below assume make_batch consumes random numbers in the
# same order as the reference solution; the shapes and structure should match in any case.
# You should expect to see the following results:
# sampled text:
# ['[CLS]', '[MASK]', 'my', 'baseball', 'team', 'won', 'the', 'competition', '[SEP]', 'oh', 'congratulations', 'juliet', '[SEP]']

# input_ids: tensor([ 1,  3, 16,  6, 21, 27, 23,  7,  2, 19,  8, 14,  2,  0,  0,  0,  0,  0,
#          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
# segment_ids: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0])
# masked_tokens: tensor([ 9, 19,  0,  0,  0])
# masked_pos: tensor([1, 9, 0, 0, 0])
# is_next: 1

#### Embedding layer
The embedding is the first layer in BERT: it takes the input token ids and looks up a vector for each of them in a table. The parameters of the embedding layers are learnable, so after training the embeddings tend to place similar words close together.

The embeddings can also capture different relationships between words, such as semantic and syntactic ones. Because BERT is bidirectional, the layers built on top of the embeddings also capture contextual relationships.

In the case of BERT, it creates three embeddings for 
- Token, 
- Segments,
- Position. 

In [None]:
class Embedding(nn.Module):
    """Sum of token, position and segment embeddings, followed by LayerNorm.

    Sizes come from the module-level ``VOCAB_SIZE``, ``MAX_LEN``, ``n_segments``
    and ``d_model``. With ``padding_idx=0``, row 0 of the token and segment
    tables is initialised to zeros and receives no gradient updates.
    """

    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(VOCAB_SIZE, d_model, padding_idx=0)  # token embedding
        self.pos_embed = nn.Embedding(MAX_LEN, d_model)  # position embedding
        # segment ids are 1 (sentence A), 2 (sentence B) and 0 (padding), hence n_segments + 1 rows
        self.seg_embed = nn.Embedding(n_segments + 1, d_model, padding_idx=0)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        """Embed token ids ``x`` with segment ids ``seg`` (both (batch_size, seq_len)).

        Returns a tensor of shape (batch_size, seq_len, d_model).
        """
        seq_len = x.size(1)
        # Create the position ids on the input's device so the module also works off-CPU.
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

### Encoder
The encoder has two main components: 

- Multi-head Attention
- Position-wise feed-forward network. 

The work of the encoder is to find representations and patterns from the input and attention mask. 

#### Multi-head attention

This is the first of the main components of the encoder. 

The attention module takes as input a sequence `x` and a corresponding attention mask `attn_mask`.

Queries, keys, and values are computed from `x`. The mask then blocks attention to `[PAD]` positions while the contextual representation is computed.

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / sqrt(d_k)) V, suppressing positions where the mask is True."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        """Return ``(context, attn)``: the attended values and the attention weights.

        ``attn_mask`` is a bool tensor broadcastable to the score matrix
        (batch_size x n_heads x seq_len x seq_len); True marks blocked positions.
        """
        raw_scores = Q @ K.transpose(-1, -2) / np.sqrt(d_k)
        # A very large negative score turns into ~0 probability after the softmax.
        raw_scores = raw_scores.masked_fill(attn_mask, -1e9)
        weights = torch.softmax(raw_scores, dim=-1)
        return weights @ V, weights

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention followed by a residual connection and LayerNorm.

    Sizes come from the module-level ``d_model``, ``d_k``, ``d_v`` and ``n_heads``.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        # The output projection and LayerNorm must be registered submodules. Building
        # them inside forward() would create fresh random weights on every call that
        # the optimizer never sees and state_dict never saves.
        self.W_O = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask):
        """Attend from every position of ``x`` to every non-masked position.

        Args:
            x: hidden states, [batch_size x seq_len x d_model].
            attn_mask: bool mask, [batch_size x seq_len x seq_len]; True blocks attention.

        Returns:
            (output, attn): output is [batch_size x seq_len x d_model]; attn holds
            the attention weights returned by ScaledDotProductAttention.
        """
        residual, batch_size = x, x.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(x).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # q_s: [batch_size x n_heads x seq_len x d_k]
        k_s = self.W_K(x).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # k_s: [batch_size x n_heads x seq_len x d_k]
        v_s = self.W_V(x).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # v_s: [batch_size x n_heads x seq_len x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x seq_len x seq_len]

        # context: [batch_size x n_heads x seq_len x d_v], attn: [batch_size x n_heads x seq_len x seq_len]
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)  # context: [batch_size x seq_len x n_heads * d_v]
        output = self.W_O(context)
        return self.norm(output + residual), attn  # output: [batch_size x seq_len x d_model]


#### Position-Wise Feed Forward Network

The output of the multi-head attention goes into the feed-forward network, which concludes the encoder part.

Let’s take a breath and review what we’ve learned so far:

The input goes through the embedding layer, and an attention mask is built from it. Both are fed into the encoder, which consists of a multi-head attention module and a feed-forward network.
The multi-head attention module in turn uses scaled dot-product attention, which combines the projected embeddings and the attention mask using dot products.

In [None]:
# Activation function: GELU (Gaussian Error Linear Unit), exact erf-based form
def gelu(x):
    """Return x * Phi(x), where Phi is the standard normal CDF computed via erf."""
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))

class PoswiseFeedForwardNet(nn.Module):
    """Two-layer MLP applied independently at every position: d_model -> d_ff -> d_model."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand to the inner size
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to the model size

    def forward(self, x):
        """Map (batch_size, seq_len, d_model) to a tensor of the same shape."""
        hidden = gelu(self.fc1(x))  # (batch_size, seq_len, d_ff)
        return self.fc2(hidden)

Now, we can use `Multi-head Attention` and `Position-Wise Feed Forward Network` to construct the Encoder. 

*Hint*: Call `self.enc_self_attn` and `self.pos_ffn` in the forward function.

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by the position-wise FFN."""

    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Run self-attention over ``enc_inputs``, then the feed-forward network.

        Args:
            enc_inputs: hidden states from the previous layer (or the embedding).
            enc_self_attn_mask: padding mask passed straight to ``enc_self_attn``.

        Returns:
            (enc_outputs, attn): the FFN output and the attention weights
            produced by ``enc_self_attn``.
        """
        # Queries, keys and values all come from the same input (self-attention).
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

In [None]:
# Test
# Seed both RNGs: torch for the layer weights, Python's random for make_batch.
torch.manual_seed(0)
random.seed(0)
batch = make_batch(tokens_list, batch_size, word_to_id)
# zip(*batch) regroups the examples field by field; each field becomes one int64 tensor.
input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(*batch))

enc_layer = EncoderLayer()
# Self-attention mask: queries and keys are the same sequence.
enc_self_attn_mask = get_pad_attn_mask(input_ids, input_ids)
embedding = Embedding()
enc_inputs = embedding(input_ids, segment_ids)
enc_outputs, attn = enc_layer(enc_inputs, enc_self_attn_mask)

# Only the shapes are checked here.
print('enc_outputs:', enc_outputs.size())
print('attn:', attn.size())

# You should expect to see the following results:
# enc_outputs: torch.Size([6, 30, 768])
# attn: torch.Size([6, 12, 30, 30])

### Assembling all the components
Let’s continue from where we left, i.e. the output from the encoder.

The encoder yields two outputs:

- the hidden representations produced by the feed-forward layer, and
- the attention weights.

It’s key to remember that BERT does not use a decoder. Instead, it feeds the encoder output into small task-specific heads to get the desired result.

In place of the Transformer's decoder, BERT uses shallow networks that can be used for classification, as shown in the code below.
BERT therefore outputs two results: one for next sentence prediction (the classifier) and one for masked language modeling.

In [None]:
class BERT(nn.Module):
    """BERT encoder with its two pretraining heads: an NSP classifier and an MLM decoder."""

    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

        # for NSP task: pooler (dense + tanh) on the [CLS] state, then a 2-way classifier
        self.fc1 = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()
        self.classifier = nn.Linear(d_model, 2)

        # for MLM task: dense + GELU + LayerNorm, then a projection onto the vocabulary
        self.fc2 = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        self.norm = nn.LayerNorm(d_model)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight  # decoder is shared with embedding layer
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        """Encode a batch and return the MLM and NSP logits.

        Args:
            input_ids: token ids, [batch_size, seq_len].
            segment_ids: segment ids, [batch_size, seq_len].
            masked_pos: int64 positions to predict, [batch_size, n_pred].

        Returns:
            (logits_lm, logits_clsf): logits_lm is [batch_size, n_pred, n_vocab],
            logits_clsf is [batch_size, 2].
        """
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_pad_attn_mask(input_ids, input_ids)
        for layer in self.layers:
            # output : [batch_size, seq_len, d_model], attn : [batch_size, n_heads, seq_len, seq_len]
            output, _ = layer(output, enc_self_attn_mask)

        # Use the representation of [CLS] (position 0) to produce logits for NSP task
        h_pooled = self.activ1(self.fc1(output[:, 0]))  # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled)  # [batch_size, 2]

        # Gather the representations of masked tokens to produce logits for MLM task
        gather_index = masked_pos.unsqueeze(-1).expand(-1, -1, output.size(-1))  # [batch_size, n_pred, d_model]
        h_masked = torch.gather(output, 1, gather_index)  # [batch_size, n_pred, d_model]
        h_masked = self.norm(self.activ2(self.fc2(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias  # [batch_size, n_pred, n_vocab]

        return logits_lm, logits_clsf

In [None]:
# Test
# Seed both RNGs: torch for the model weights, Python's random for make_batch.
torch.manual_seed(0)
random.seed(0)

batch = make_batch(tokens_list, batch_size, word_to_id)
# zip(*batch) regroups the examples field by field; each field becomes one int64 tensor.
input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(*batch))

model = BERT()
logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)

# Only the shapes are checked: [batch_size, MAX_PRED, VOCAB_SIZE] and [batch_size, 2].
print('logits_lm:', logits_lm.size())
print('logits_clsf:', logits_clsf.size())

# You should expect to see the following results:
# logits_lm: torch.Size([6, 5, 29])
# logits_clsf: torch.Size([6, 2])

Finally, we’ll start the training. 

In [None]:
random.seed(0)
torch.manual_seed(0)

model = BERT()
criterion = nn.CrossEntropyLoss()  # NSP loss over the is_next labels {0, 1}
# MLM loss: masked_tokens is zero-padded up to MAX_PRED and those padded slots
# carry no real target, so they must not contribute to the loss.
criterion_lm = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch = make_batch(tokens_list, batch_size, word_to_id)
input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(*batch))

for epoch in range(500):
    optimizer.zero_grad()

    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
    # logits_lm is [batch_size, MAX_PRED, VOCAB_SIZE]; CrossEntropyLoss expects the
    # class dimension second, so flatten to [batch_size * MAX_PRED, VOCAB_SIZE].
    loss_lm = criterion_lm(logits_lm.flatten(0, 1), masked_tokens.flatten())
    loss_clsf = criterion(logits_clsf, is_next)
    loss = loss_lm + loss_clsf

    if (epoch + 1) % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()

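If you observe a quick decrease in the loss, you are on the right track. It should stabilize at around 2.xx after a few hundred epochs. The exact value depends on how you compute the loss, for example whether the padded MLM targets are ignored.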

In [None]:
# Test
sample = 1

# Predict mask tokens and is_next for one example of the training batch.
# zip(batch[sample]) wraps every field in a 1-tuple, so each tensor gets a batch dimension of 1.
input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(batch[sample]))
# print(text)
print([id_to_word[w.item()] for w in input_ids[0] if id_to_word[w.item()] != '[PAD]'])

# Inference only: no autograd graph is needed.
with torch.no_grad():
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
# squeeze(0) drops only the batch dimension, so this stays a list even when MAX_PRED == 1.
predicted_ids = logits_lm.argmax(dim=2).squeeze(0).tolist()
# masked_tokens is zero-padded, so its non-zero entries are the real MLM targets.
n_masked_tokens = torch.sum(masked_tokens[0] != 0).item()
print('masked tokens ground truth: ', [id_to_word[pos.item()] for pos in masked_tokens[0][:n_masked_tokens]])
print('predicted masked tokens: ', [id_to_word[pos] for pos in predicted_ids[:n_masked_tokens]])

predicted_is_next = logits_clsf.argmax(dim=1)[0].item()
print('is_next ground truth : ', bool(is_next))
print('predicted is_next: ', bool(predicted_is_next))

# No fixed expected output is given for this cell; the predictions depend on your training run.
# Best of luck with your training!