In [1]:
from transformers import BertTokenizer, BertForPreTraining

In [2]:
import torch

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

In [4]:
text = ("After Abraham Lincoln won the November 1860 presidential [MASK] on an "
        "anti-slavery platform, an initial seven slave states declared their "
        "secession from the country to form the Confederacy. War broke out in "
        "April 1861 when secessionist forces [MASK] Fort Sumter in South "
        "Carolina, just over a month after Lincoln's inauguration.")

In [5]:
inputs = tokenizer(text, return_tensors='pt')
outputs = model(**inputs)

In [6]:
inputs.input_ids.shape

torch.Size([1, 62])

In [7]:
outputs.keys()

odict_keys(['prediction_logits', 'seq_relationship_logits'])

In [8]:
outputs.prediction_logits

tensor([[[ -7.6192,  -7.5433,  -7.6124,  ...,  -6.7155,  -6.7375,  -4.6122],
         [-12.5489, -12.3772, -12.6500,  ..., -11.8644, -11.4446,  -9.1151],
         [ -6.2346,  -6.3590,  -5.9091,  ...,  -6.1258,  -6.2720,  -5.0268],
         ...,
         [ -2.2497,  -2.1352,  -2.1812,  ...,  -1.7201,  -1.2728,  -7.8301],
         [-14.2654, -14.3100, -14.2294,  ..., -11.4669, -11.7212, -10.3129],
         [-11.5071, -12.0389, -11.6046,  ..., -11.2875,  -9.1655,  -9.1733]]],
       grad_fn=<ViewBackward0>)

In [9]:
outputs.prediction_logits.shape

torch.Size([1, 62, 30522])

In [10]:
outputs.seq_relationship_logits

tensor([[ 2.8257, -1.6897]], grad_fn=<AddmmBackward0>)

### MLM
Masked Lanaguage Modelling

In [11]:
token2idx = tokenizer.get_vocab()

In [12]:
token2idx

{'[PAD]': 0,
 '[unused0]': 1,
 '[unused1]': 2,
 '[unused2]': 3,
 '[unused3]': 4,
 '[unused4]': 5,
 '[unused5]': 6,
 '[unused6]': 7,
 '[unused7]': 8,
 '[unused8]': 9,
 '[unused9]': 10,
 '[unused10]': 11,
 '[unused11]': 12,
 '[unused12]': 13,
 '[unused13]': 14,
 '[unused14]': 15,
 '[unused15]': 16,
 '[unused16]': 17,
 '[unused17]': 18,
 '[unused18]': 19,
 '[unused19]': 20,
 '[unused20]': 21,
 '[unused21]': 22,
 '[unused22]': 23,
 '[unused23]': 24,
 '[unused24]': 25,
 '[unused25]': 26,
 '[unused26]': 27,
 '[unused27]': 28,
 '[unused28]': 29,
 '[unused29]': 30,
 '[unused30]': 31,
 '[unused31]': 32,
 '[unused32]': 33,
 '[unused33]': 34,
 '[unused34]': 35,
 '[unused35]': 36,
 '[unused36]': 37,
 '[unused37]': 38,
 '[unused38]': 39,
 '[unused39]': 40,
 '[unused40]': 41,
 '[unused41]': 42,
 '[unused42]': 43,
 '[unused43]': 44,
 '[unused44]': 45,
 '[unused45]': 46,
 '[unused46]': 47,
 '[unused47]': 48,
 '[unused48]': 49,
 '[unused49]': 50,
 '[unused50]': 51,
 '[unused51]': 52,
 '[unused52]': 53,

In [14]:
token2idx['[MASK]'], token2idx['hello']

(103, 7592)

In [15]:
idx2token = {v: k for k, v in token2idx.items()}

In [16]:
idx2token

{0: '[PAD]',
 1: '[unused0]',
 2: '[unused1]',
 3: '[unused2]',
 4: '[unused3]',
 5: '[unused4]',
 6: '[unused5]',
 7: '[unused6]',
 8: '[unused7]',
 9: '[unused8]',
 10: '[unused9]',
 11: '[unused10]',
 12: '[unused11]',
 13: '[unused12]',
 14: '[unused13]',
 15: '[unused14]',
 16: '[unused15]',
 17: '[unused16]',
 18: '[unused17]',
 19: '[unused18]',
 20: '[unused19]',
 21: '[unused20]',
 22: '[unused21]',
 23: '[unused22]',
 24: '[unused23]',
 25: '[unused24]',
 26: '[unused25]',
 27: '[unused26]',
 28: '[unused27]',
 29: '[unused28]',
 30: '[unused29]',
 31: '[unused30]',
 32: '[unused31]',
 33: '[unused32]',
 34: '[unused33]',
 35: '[unused34]',
 36: '[unused35]',
 37: '[unused36]',
 38: '[unused37]',
 39: '[unused38]',
 40: '[unused39]',
 41: '[unused40]',
 42: '[unused41]',
 43: '[unused42]',
 44: '[unused43]',
 45: '[unused44]',
 46: '[unused45]',
 47: '[unused46]',
 48: '[unused47]',
 49: '[unused48]',
 50: '[unused49]',
 51: '[unused50]',
 52: '[unused51]',
 53: '[unused52]',

In [18]:
idx2token[7592]

'hello'

In [17]:
# Get the predicted token ids that are masked (masked tokens are the ones with the lowest probability)
outputs.prediction_logits[0, 2].shape


torch.Size([30522])

In [19]:
len(idx2token)

30522

In [20]:
softmax = torch.nn.functional.softmax(outputs.prediction_logits[0, 2], dim=-1)
argmax = torch.argmax(softmax)

In [21]:
argmax

tensor(8181)

In [22]:
idx2token[argmax.item()]

'abraham'

In [25]:
softmax = torch.nn.functional.softmax(outputs.prediction_logits[0], dim=0)
argmax = torch.argmax(softmax, dim=1)

In [26]:
argmax

tensor([28191,  2348,  8181, 16628,  2180,  3882,  2281,  7313,  4883, 27419,
         2006,  2010,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
         8914,  2163, 13520,  2037,  4336,  2013,  1996,  2406,  2000,  2433,
        28775, 18179, 16363,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
        18232,  2923,  2749,  4548,  3481,  7680,  5017,  2005,  2148,  3792,
        24901,  2074,  2058,  1037,  3204,  2077,  3946,  1005,  1055, 17331,
         1025, 25656])

In [27]:
for i in argmax:
    print(idx2token[i.item()], end=' ')

##ecin although abraham lincolnshire won 1948 november 1860 presidential primaries on his anti - slavery platform , an initial seven tributary states declare their independence from the country to form ##ici confederacy ##yre war broke out in april 1861 when ##oya ##ist forces occupied fort sum ##mer for south carolina ##trip just over a month before grant ' s inauguration ; ##tson 

### Next Sentence Predicition