# BERT for grammar / spell check

In [1]:
# imports

from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
import torch
from torch.nn import Softmax

In [2]:
BERT_MODEL = 'bert-base-cased'  # cased on purpose: capitalisation can matter for grammar / spelling

# Tokenizer and masked-LM model come from the same checkpoint so that token ids
# line up with the MLM decoder's output vocabulary.
bert_tokenizer, model = (
    BertTokenizer.from_pretrained(BERT_MODEL),
    BertForMaskedLM.from_pretrained(BERT_MODEL),
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Display the module tree of the loaded BertForMaskedLM (BERT encoder + MLM head `cls`).
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [4]:
# Inspect the MLM head. Note the decoder's output size equals the size of the tokenizer's vocab,
# so it is crucial to use the tokenizer that matches this checkpoint.
model.cls

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=28996, bias=True)
  )
)

In [5]:
# The MLM decoder must produce exactly one logit per tokenizer vocabulary entry;
# enforce that invariant instead of eyeballing the two numbers.
assert bert_tokenizer.vocab_size == model.cls.predictions.decoder.out_features, (
    'tokenizer vocabulary size does not match the MLM decoder output size'
)
bert_tokenizer.vocab_size

28996

In [6]:
def top_predictions(phrase, top_n=1, tokenizer=None, mlm_model=None):
    """Print and return the rank-``top_n`` MLM prediction at every position of ``phrase``.

    Despite its name, ``top_n`` is a *rank*, not a count: ``top_n=1`` reports the
    most likely token at each position, ``top_n=2`` the second most likely, etc.
    Every input position is reported, including the special tokens the tokenizer
    adds and the two pad tokens inserted around the phrase (with a BERT tokenizer
    the first two rows are [CLS]/[PAD] and the last two are [PAD]/[SEP]).

    Args:
        phrase: Text to run through the masked-LM.
        top_n: 1-based rank of the prediction to report at each position.
        tokenizer: Tokenizer to use; defaults to the notebook's ``bert_tokenizer``.
        mlm_model: Masked-LM to use; defaults to the notebook's ``model``. Its
            output vocabulary must match ``tokenizer``.

    Returns:
        1-D tensor of the predicted token ids, one per input position.

    Raises:
        ValueError: if ``top_n`` is not between 1 and the model's vocabulary size
            (``top_n=0`` would otherwise silently report the *least* likely token).
    """
    tokenizer = bert_tokenizer if tokenizer is None else tokenizer
    mlm_model = model if mlm_model is None else mlm_model
    if top_n < 1:
        raise ValueError(f'top_n must be >= 1, got {top_n}')

    # Add a pad token before and after the phrase.
    #  I find this helps as BERT often will neglect the first and last token otherwise
    padded = f'{tokenizer.pad_token} {phrase} {tokenizer.pad_token}'
    input_ids = tokenizer.encode(padded, return_tensors="pt")

    # Inference only: don't build an autograd graph over the whole model.
    with torch.no_grad():
        prediction_scores = mlm_model(input_ids).logits  # (batch, seq, vocab)

    vocab_size = prediction_scores.shape[-1]
    if top_n > vocab_size:
        raise ValueError(f'top_n must be <= vocab size ({vocab_size}), got {top_n}')

    # A single descending sort gives both the ranked token ids and, via softmax over
    # the sorted scores, their probabilities (softmax is permutation-equivariant).
    sorted_scores, sorted_ids = prediction_scores.sort(dim=-1, descending=True)
    sorted_probas = torch.softmax(sorted_scores, dim=-1)

    rank = top_n - 1  # 0-based column in the descending order
    predicted_tokens = sorted_ids[:, :, rank].reshape(-1)
    token_probas = sorted_probas[:, :, rank].reshape(-1)

    for proba, token in zip(token_probas.tolist(), predicted_tokens.tolist()):
        print(f'Token: {tokenizer.decode([token])} ({token})  Probability: {proba:.4f}')

    return predicted_tokens
        

In [7]:
# "me bill" is ungrammatical: check which token BERT prefers in place of "me".
top_predictions('Last time I went here, me bill was too high.', 1)

Token: . (119)  Probability: 0.0636
Token: " (107)  Probability: 0.9721
Token: Last (4254)  Probability: 0.8593
Token: time (1159)  Probability: 0.9999
Token: I (146)  Probability: 0.9995
Token: went (1355)  Probability: 0.4761
Token: here (1303)  Probability: 0.9999
Token: , (117)  Probability: 1.0000
Token: my (1139)  Probability: 0.9564
Token: bill (4550)  Probability: 0.9953
Token: was (1108)  Probability: 0.9999
Token: too (1315)  Probability: 1.0000
Token: high (1344)  Probability: 0.9989
Token: . (119)  Probability: 1.0000
Token: " (107)  Probability: 0.9807
Token: . (119)  Probability: 1.0000


tensor([ 119,  107, 4254, 1159,  146, 1355, 1303,  117, 1139, 4550, 1108, 1315,
        1344,  119,  107,  119])

In [8]:
# Control: a grammatical sentence, where the rank-1 prediction at each word is the word itself.
top_predictions('My wonderful teacher is so great!', 1)

Token: . (119)  Probability: 0.0563
Token: " (107)  Probability: 0.9262
Token: My (1422)  Probability: 0.9989
Token: wonderful (7310)  Probability: 0.9551
Token: teacher (3218)  Probability: 0.9954
Token: is (1110)  Probability: 0.9981
Token: so (1177)  Probability: 0.9991
Token: great (1632)  Probability: 0.9953
Token: ! (106)  Probability: 1.0000
Token: " (107)  Probability: 0.9189
Token: . (119)  Probability: 0.9683


tensor([ 119,  107, 1422, 7310, 3218, 1110, 1177, 1632,  106,  107,  119])

In [9]:
top_predictions('My wonderful teacher is so great!', 2)  # 2nd choice for wonderful is brilliant

Token: , (117)  Probability: 0.0202
Token: ' (112)  Probability: 0.0596
Token: my (1139)  Probability: 0.0006
Token: brilliant (8431)  Probability: 0.0154
Token: instructor (10332)  Probability: 0.0009
Token: was (1108)  Probability: 0.0014
Token: very (1304)  Probability: 0.0004
Token: wonderful (7310)  Probability: 0.0027
Token: . (119)  Probability: 0.0000
Token: ' (112)  Probability: 0.0763
Token: ! (106)  Probability: 0.0311


tensor([  117,   112,  1139,  8431, 10332,  1108,  1304,  7310,   119,   112,
          106])

In [10]:
top_predictions('My wonderful teacher is so great!', 3)  # 3rd choice for wonderful is great

Token: the (1103)  Probability: 0.0174
Token: . (119)  Probability: 0.0066
Token: The (1109)  Probability: 0.0002
Token: great (1632)  Probability: 0.0093
Token: Teacher (14208)  Probability: 0.0007
Token: isn (2762)  Probability: 0.0001
Token: such (1216)  Probability: 0.0002
Token: brilliant (8431)  Probability: 0.0005
Token: ? (136)  Probability: 0.0000
Token: ! (106)  Probability: 0.0039
Token: ? (136)  Probability: 0.0004


tensor([ 1103,   119,  1109,  1632, 14208,  2762,  1216,  8431,   136,   106,
          136])

In [11]:
# Lookahead prediction

def look_ahead(phrase, n_ranks=3, tokenizer=None, mlm_model=None):
    """Print BERT's ranked guesses for the word that follows ``phrase``.

    A mask token (the slot to fill) and then a pad token are appended to
    ``phrase``. For each rank 1..``n_ranks`` the predictions at the last three
    input positions are printed: the masked slot followed by the two positions
    after it (with a BERT tokenizer, the appended [PAD] and the trailing [SEP]).

    Args:
        phrase: Beginning of a sentence to continue.
        n_ranks: How many ranked guesses to print (default 3).
        tokenizer: Tokenizer to use; defaults to the notebook's ``bert_tokenizer``.
        mlm_model: Masked-LM to use; defaults to the notebook's ``model``.

    Returns:
        1-D tensor of the rank-``n_ranks`` predicted token ids at *every* input
        position (the last rank printed), kept for backward compatibility.

    Raises:
        ValueError: if ``n_ranks`` is not between 1 and the model's vocabulary size.
    """
    tokenizer = bert_tokenizer if tokenizer is None else tokenizer
    mlm_model = model if mlm_model is None else mlm_model
    if n_ranks < 1:
        raise ValueError(f'n_ranks must be >= 1, got {n_ranks}')

    # Append a mask token (the word to predict) followed by a pad token.
    padded = f'{phrase} {tokenizer.mask_token} {tokenizer.pad_token}'
    input_ids = tokenizer.encode(padded, return_tensors="pt")

    # Inference only: don't build an autograd graph over the whole model.
    with torch.no_grad():
        prediction_scores = mlm_model(input_ids).logits  # (batch, seq, vocab)

    vocab_size = prediction_scores.shape[-1]
    if n_ranks > vocab_size:
        raise ValueError(f'n_ranks must be <= vocab size ({vocab_size}), got {n_ranks}')

    # Sorting the vocabulary and the softmax don't depend on the rank: do them once.
    sorted_scores, sorted_ids = prediction_scores.sort(dim=-1, descending=True)
    sorted_probas = torch.softmax(sorted_scores, dim=-1)

    # Show only the masked slot and the two positions after it.
    first_shown = input_ids.shape[1] - 3

    for rank in range(1, n_ranks + 1):
        print(f'Top Score {rank}')
        predicted_tokens = sorted_ids[:, :, rank - 1].reshape(-1)
        token_probas = sorted_probas[:, :, rank - 1].reshape(-1)

        shown = zip(token_probas.tolist()[first_shown:], predicted_tokens.tolist()[first_shown:])
        for proba, token in shown:
            print(f'Token: {tokenizer.decode([token])} ({token})  Probability: {proba:.4f}')
        print()
    return predicted_tokens


In [12]:
# Ranked guesses for the word after "Can we split the".
look_ahead('Can we split the')

Top Score 1
Token: time (1159)  Probability: 0.0528
Token: ? (136)  Probability: 0.9924
Token: . (119)  Probability: 0.9999

Top Score 2
Token: money (1948)  Probability: 0.0303
Token: . (119)  Probability: 0.0056
Token: ? (136)  Probability: 0.0001

Top Score 3
Token: numbers (2849)  Probability: 0.0271
Token: ! (106)  Probability: 0.0015
Token: ! (106)  Probability: 0.0000



tensor([1103,  117, 1284, 2866, 1412, 2849,  106,  106])

In [13]:
# Ranked guesses for the word after "Where are we".
look_ahead('Where are we')

Top Score 1
Token: going (1280)  Probability: 0.8487
Token: ? (136)  Probability: 0.9920
Token: . (119)  Probability: 0.9986

Top Score 2
Token: now (1208)  Probability: 0.0605
Token: . (119)  Probability: 0.0046
Token: ? (136)  Probability: 0.0013

Top Score 3
Token: headed (2917)  Probability: 0.0298
Token: ! (106)  Probability: 0.0032
Token: ; (132)  Probability: 0.0000



tensor([ 107,  117, 1231, 1128, 2917,  106,  132])

In [14]:
# Ranked guesses for the word after "This class is kind of".
look_ahead('This class is kind of')

Top Score 1
Token: unique (3527)  Probability: 0.0218
Token: . (119)  Probability: 0.9514
Token: . (119)  Probability: 0.9967

Top Score 2
Token: fun (4106)  Probability: 0.0216
Token: ; (132)  Probability: 0.0225
Token: ? (136)  Probability: 0.0010

Top Score 3
Token: special (1957)  Probability: 0.0206
Token: ! (106)  Probability: 0.0186
Token: ! (106)  Probability: 0.0008



tensor([ 107,  117, 4370, 1108, 1472, 4106, 1957,  106,  106])