In [1]:
import torch
from torch import nn
from d2l import torch as d2l

  warn(f"Failed to load image Python extension: {e}")


Input Representation

In [2]:
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Get tokens of the BERT input sequence and their segment IDs."""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0 and 1 are marking segment A and B, respectively
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

The model

In [3]:
#@save
class BERTEncoder(nn.Module):
    """BERT encoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout, max_len=1000, **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"{i}", d2l.TransformerEncoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, True))
        # In BERT, positional embeddings are learnable, thus we create a
        # parameter of positional embeddings that are long enough
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        # Shape of `X` remains unchanged in the following code snippet:
        # (batch size, max sequence length, `num_hiddens`)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

Hyperparameter

In [4]:
vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
ffn_num_input, num_blks, dropout = 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                      num_blks, dropout)



In [45]:
# test
tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 
                          0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
encoded_X.shape

torch.Size([2, 8, 768])

# Pretraining Task

MLM

a special “< mask >” token for 80% of the time (e.g., “this movie is great” becomes “this movie is < mask >”);

a random token for 10% of the time (e.g., “this movie is great” becomes “this movie is drink”);

the unchanged label token for 10% of the time (e.g., “this movie is great” becomes “this movie is great”).

In [33]:
# In forward inference, it takes two inputs:
# the encoded result of BERTEncoder and the token positions for prediction. 
# The output is the prediction results at these positions.

class MaskLM(nn.Module):
    """The masked language model task of BERT."""
    def __init__(self, vocab_size, num_hiddens, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.LazyLinear(num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.LazyLinear(vocab_size))

    def forward(self, X, pred_positions):
        num_pred_positions = pred_positions.shape[1]
        pred_positions = pred_positions.reshape(-1)
        batch_size = X.shape[0]
        print(f'num_pred_positions: {num_pred_positions}\npred_positions: {pred_positions}\nbatch_size: {batch_size}\n')
        batch_idx = torch.arange(0, batch_size)
        print(f'batch_idx: {batch_idx}\n')
        # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
        # `batch_idx` is `torch.tensor([0, 0, 0, 1, 1, 1])`
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
        print(f'new batch_idx: {batch_idx.shape}\n{batch_idx}\n')
        masked_X = X[batch_idx, pred_positions]
        print(f'masked_X: {masked_X.shape}\n{masked_X}\n')
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        print(f'new masked_X: {masked_X.shape}\n{masked_X}')
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat

In [35]:
encoded_X.shape # #seq=2, seq_length=8, hidden_dim=768

torch.Size([2, 8, 768])

In [38]:
mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])    # we want these posision's tokens context representation, pos<8 because seq_length=8
# encoded_X: 2 seqs, mlm_positions: 
# mlm_positions: the 3 indices to predict in either BERT input sequence of encoded_X
# mlm: returns prediction results mlm_Y_hat at all the masked positions mlm_positions of encoded_X
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape

num_pred_positions: 3
pred_positions: tensor([1, 5, 2, 6, 1, 5])
batch_size: 2

batch_idx: tensor([0, 1])

new batch_idx: torch.Size([6])
tensor([0, 0, 0, 1, 1, 1])

masked_X: torch.Size([6, 768])
tensor([[-0.3498,  1.1218, -0.6139,  ..., -0.5616, -0.6248, -1.2501],
        [-0.5589, -0.2799, -2.3278,  ...,  0.9386, -0.5574,  0.8928],
        [ 1.2462,  0.7204, -1.7339,  ...,  0.3420, -1.7081,  0.1696],
        [ 1.1380, -0.8836, -2.1321,  ..., -0.3173, -0.7220,  0.7464],
        [ 0.2241,  0.8395, -0.3604,  ..., -0.7947,  0.5012, -0.7594],
        [ 0.6268,  0.3732, -0.7302,  ...,  1.1053,  0.2988,  1.7736]],
       grad_fn=<IndexBackward0>)

new masked_X: torch.Size([2, 3, 768])
tensor([[[-0.3498,  1.1218, -0.6139,  ..., -0.5616, -0.6248, -1.2501],
         [-0.5589, -0.2799, -2.3278,  ...,  0.9386, -0.5574,  0.8928],
         [ 1.2462,  0.7204, -1.7339,  ...,  0.3420, -1.7081,  0.1696]],

        [[ 1.1380, -0.8836, -2.1321,  ..., -0.3173, -0.7220,  0.7464],
         [ 0.2241,  0.83

torch.Size([2, 3, 10000])

In [39]:
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]]) # the token should be these
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l.shape

torch.Size([6])

Next Sentence Prediction

In [40]:
# a classification task, input the <cls> token, output_size = 2
class NextSentencePred(nn.Module):
    """The next sentence prediction task of BERT."""
    def __init__(self, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.output = nn.LazyLinear(2)

    def forward(self, X):
        # `X` shape: (batch size, `num_hiddens`)
        return self.output(X)

In [46]:
encoded_X.shape

torch.Size([2, 8, 768])

In [53]:
# PyTorch by default will not flatten the tensor as seen in mxnet where, if
# flatten=True, all but the first axis of input data are collapsed together
encoded_flatten_X = torch.flatten(encoded_X, start_dim=1)
print(encoded_flatten_X.shape)
# input_shape for NSP: (batch size, `num_hiddens`)
nsp = NextSentencePred()
nsp_Y_hat = nsp(encoded_X)
nsp_Y_hat.shape
# here is the result without flatten

torch.Size([2, 6144])


torch.Size([2, 8, 2])

In [54]:
# here is the seq(sentence) level task, so we need flatten
nsp = NextSentencePred()
nsp_Y_hat = nsp(encoded_flatten_X)
nsp_Y_hat.shape

torch.Size([2, 2])

In [56]:
nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l.shape, nsp_l

(torch.Size([2]), tensor([1.0465, 0.7118], grad_fn=<NllLossBackward0>))

# Putting it all together

In [57]:
#@save
class BERTModel(nn.Module):
    """The BERT model."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_blks, dropout, max_len=1000):
        super(BERTModel, self).__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens,
                                   num_heads, num_blks, dropout,
                                   max_len=max_len)
        self.hidden = nn.Sequential(nn.LazyLinear(num_hiddens),
                                    nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens)
        self.nsp = NextSentencePred()

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # The hidden layer of the MLP classifier for next sentence prediction.
        # 0 is the index of the '<cls>' token
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat