## MLM and NSP summary

#### 1. Masked Language Modeling (MLM)
MLM is a technique where a portion of the input tokens in a sentence are masked at random, and the model is trained to predict the original tokens based on the context provided by the remaining tokens. This helps the model to understand the relationship between words in a sentence from both directions (left and right).</br>

* Why important?</br>
MLM enables BERT to learn bidirectional representations by considering both the left and right context in all layers.</br>

- Masking Tokens: During training, 15% of the tokens in each sequence are selected for masking. Out of these, 80% are replaced with the [MASK] token, 10% with a random token, and 10% remain unchanged.</br>
</br>
- Prediction Task: The model then attempts to predict the original tokens that were masked out, based on the context provided by the other tokens in the sequence.</br>
</br>
#### 2. Next Sentence Prediction (NSP)
NSP is a binary classification task that helps BERT understand the relationship between two sentences. Specifically, it predicts whether a given sentence B follows sentence A in the original text.</br>

* Why important?</br> 
NSP helps BERT to capture information about the relationships between sentences, which is crucial for tasks like question answering and natural language inference</br>

- Sentence Pair Preparation: During training, the model receives pairs of sentences. In 50% of the cases, sentence B is the actual next sentence that follows sentence A (labeled as IsNext). In the other 50% of the cases, sentence B is a random sentence from the corpus (labeled as NotNext).</br>
- Prediction Task: The model is trained to classify whether sentence B follows sentence A or not.</br>

In [2]:
from transformers import BertTokenizer, BertForPreTraining

In [3]:
import torch

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [58]:
# tokenize a paragraph of text from the Wikipedia page on the American Civil War. 
text = ("After Abraham Lincoln won the November 1860 presidential [MASK] on an "
        "anti-slavery platform, an initial seven slave states declared their "
        "secession from the country to form the Confederacy. War broke out in "
        "April 1861 when secessionist forces [MASK] Fort Sumter in South "
        "Carolina, just over a month after Lincoln's inauguration.")

In [35]:
inputs = tokenizer(text, return_tensors='pt')

In [36]:
outputs = model(**inputs)

In [37]:
outputs.keys()

odict_keys(['prediction_logits', 'seq_relationship_logits'])

In [38]:
outputs.prediction_logits

tensor([[[ -7.6192,  -7.5433,  -7.6124,  ...,  -6.7155,  -6.7375,  -4.6122],
         [-12.5489, -12.3772, -12.6500,  ..., -11.8644, -11.4446,  -9.1151],
         [ -6.2346,  -6.3590,  -5.9091,  ...,  -6.1258,  -6.2720,  -5.0268],
         ...,
         [ -2.2497,  -2.1352,  -2.1812,  ...,  -1.7201,  -1.2728,  -7.8301],
         [-14.2654, -14.3100, -14.2294,  ..., -11.4669, -11.7212, -10.3129],
         [-11.5071, -12.0389, -11.6046,  ..., -11.2875,  -9.1655,  -9.1732]]],
       grad_fn=<ViewBackward0>)

In [60]:
# There are 62 tokens (60 + [CLS] and [SEP]), we can see this reflected 
# in the prediction_logits.shape:
outputs.prediction_logits.shape

torch.Size([1, 62, 30522])

In [40]:
inputs.input_ids.shape

torch.Size([1, 62])

In [41]:
outputs.seq_relationship_logits

tensor([[ 2.8257, -1.6897]], grad_fn=<AddmmBackward0>)

## MLM

Convert our prediction_logits into token predictions.</br>
To do this, we'll need to get a mapping between index values and words from the model</br> vocab, which we can extract from the tokenizer.</br>

In [42]:
token2idx = tokenizer.get_vocab()

In [43]:
token2idx

{'[PAD]': 0,
 '[unused0]': 1,
 '[unused1]': 2,
 '[unused2]': 3,
 '[unused3]': 4,
 '[unused4]': 5,
 '[unused5]': 6,
 '[unused6]': 7,
 '[unused7]': 8,
 '[unused8]': 9,
 '[unused9]': 10,
 '[unused10]': 11,
 '[unused11]': 12,
 '[unused12]': 13,
 '[unused13]': 14,
 '[unused14]': 15,
 '[unused15]': 16,
 '[unused16]': 17,
 '[unused17]': 18,
 '[unused18]': 19,
 '[unused19]': 20,
 '[unused20]': 21,
 '[unused21]': 22,
 '[unused22]': 23,
 '[unused23]': 24,
 '[unused24]': 25,
 '[unused25]': 26,
 '[unused26]': 27,
 '[unused27]': 28,
 '[unused28]': 29,
 '[unused29]': 30,
 '[unused30]': 31,
 '[unused31]': 32,
 '[unused32]': 33,
 '[unused33]': 34,
 '[unused34]': 35,
 '[unused35]': 36,
 '[unused36]': 37,
 '[unused37]': 38,
 '[unused38]': 39,
 '[unused39]': 40,
 '[unused40]': 41,
 '[unused41]': 42,
 '[unused42]': 43,
 '[unused43]': 44,
 '[unused44]': 45,
 '[unused45]': 46,
 '[unused46]': 47,
 '[unused47]': 48,
 '[unused48]': 49,
 '[unused49]': 50,
 '[unused50]': 51,
 '[unused51]': 52,
 '[unused52]': 53,

In [44]:
token2idx['hello']

7592

In [45]:
idx2token = {value:key for key, value in token2idx.items()}

In [46]:
idx2token[7592]

'hello'

In [47]:
outputs.prediction_logits[0][2].shape

torch.Size([30522])

In [48]:
len(token2idx)

30522

Now all we need to do is take the softmax to get a probability distribution across the 30522 tokens, and extract the most probable using an argmax function:

In [49]:
softmax = torch.nn.functional.softmax(outputs.prediction_logits[0][2], dim=-1)

In [50]:
argmax = torch.argmax(softmax)

In [51]:
idx2token[argmax.item()]

'abraham'

In [52]:
softmax = torch.nn.functional.softmax(outputs.prediction_logits[0], dim=0)
argmax = torch.argmax(softmax, dim=1)

In [53]:
argmax

tensor([28191,  2348,  8181, 16628,  2180,  3882,  2281,  7313,  4883, 27419,
         2006,  2010,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
         8914,  2163, 13520,  2037,  4336,  2013,  1996,  2406,  2000,  2433,
        28775, 18179, 16363,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
        18232,  2923,  2749,  4548,  3481,  7680,  5017,  2005,  2148,  3792,
        24901,  2074,  2058,  1037,  3204,  2077,  3946,  1005,  1055, 17331,
         1025, 25656])

In [54]:
argmax.shape

torch.Size([62])

In [55]:
for idx in argmax:
    print(idx2token[idx.item()], end=' ')

##ecin although abraham lincolnshire won 1948 november 1860 presidential primaries on his anti - slavery platform , an initial seven tributary states declare their independence from the country to form ##ici confederacy ##yre war broke out in april 1861 when ##oya ##ist forces occupied fort sum ##mer for south carolina ##trip just over a month before grant ' s inauguration ; ##tson 

In [56]:
text

"After Abraham Lincoln won the November 1860 presidential [MASK] on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secessionist forces [MASK] Fort Sumter in South Carolina, just over a month after Lincoln's inauguration."

We can see here that the predicted word for 'election' is 'primaries', which can is a reasonably close word match - although certainly not perfect or correct. For 'attacked' we see 'occupied' as the predicted word, again, not correct but pretty close.

## NSP (Next Sentence Prediction) 

In [62]:
# tokenize a paragraph of text from the Wikipedia page on the American Civil War. 
text = ("After Abraham Lincoln won the November 1860 presidential [MASK] on an "
        "anti-slavery platform, an initial seven slave states declared their "
        "secession from the country to form the Confederacy.")
text2 = ("War broke out in April 1861 when secessionist forces [MASK] Fort "
         "Sumter in South Carolina, just over a month after Lincoln's " 
         "inauguration.")

In [64]:
inputs = tokenizer(text, text2, return_tensors='pt')

In [65]:
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,   103,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,   103,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102,  2162,  3631,  2041,  1999,  2258,  6863,  2043, 22965,
          2923,  2749,   103,  3481,  7680,  3334,  1999,  2148,  3792,  1010,
          2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,  1012,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [66]:
outputs = model(**inputs)

In [67]:
outputs.seq_relationship_logits

tensor([[ 6.4874, -6.4341]], grad_fn=<AddmmBackward0>)

In [69]:
argmax = torch.argmax(outputs.seq_relationship_logits)

In [70]:
argmax

tensor(0)

Index 0 represents BERTs IsNext class, meaning that sentence B is the next sentence after A. Index 1 represents the NotNext class, meaning sentence B is not the next sentence after B. We can write this as:



In [72]:
'NotNext' if argmax.item() else 'IsNext'

'IsNext'