In [2]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM

In [3]:
model_name = 'bert-large-uncased-whole-word-masking'

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU so
# this cell still executes on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained(model_name)

# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained(model_name)
# The notebook only runs inference, so switch the model to evaluation mode.
model.eval()
model = model.to(device)

In [4]:
from torch import nn

# Log-softmax over the last axis. Passing `dim` explicitly avoids the
# deprecated implicit-dimension behaviour (and its warning on every call);
# for the 1-D score vectors used in this notebook the result is unchanged.
logsm = nn.LogSoftmax(dim=-1)
logsm = logsm.to('cuda')

In [5]:
def get_predictions(seq: str, masked_indexes=None):
    """Run the masked-LM on ``seq`` and return its scores at the masked positions.

    Args:
        seq: Input text. It is handed to ``tokenizer.tokenize`` unchanged, so
            the caller writes any special tokens (``[CLS]``, ``[SEP]``,
            ``[MASK]``) directly into the string.
        masked_indexes: Positions in the tokenized sequence whose predictions
            should be returned. When ``None``, every position holding a
            ``[MASK]`` token is used.

    Returns:
        ``predictions[0, masked_indexes]``: the rows of the model's first
        output, for the single sequence in the batch, at the requested
        positions. The values are not normalized here.

    Raises:
        ValueError: If ``masked_indexes`` is ``None`` and the tokenized input
            contains no ``[MASK]`` token.
    """
    tokenized_text = tokenizer.tokenize(seq)
    if masked_indexes is None:
        masked_indexes = [i for i, tok in enumerate(tokenized_text) if tok == "[MASK]"]
        if not masked_indexes:
            raise ValueError("[MASK] token not in input")
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Create the batch (of one sequence) on the same device as the model's
    # weights, so the function works wherever the model was placed.
    model_device = next(model.parameters()).device
    tokens_tensor = torch.tensor([indexed_tokens], device=model_device)
    with torch.no_grad():
        outputs = model(tokens_tensor)
    predictions = outputs[0]
    return predictions[0, masked_indexes]

In [6]:
# Quick check: raw scores for the single [MASK] position in this sentence.
get_predictions("[CLS] I love eating fruits such as [MASK] [SEP]")

tensor([[-3.4119, -5.4709, -6.2785,  ..., -4.6543, -5.6838, -4.9877]],
       device='cuda:0')

In [7]:
def predictions_to_list(predictions):
    """Rank the whole vocabulary for each row of ``predictions``.

    Args:
        predictions: Iterable of 1-D score tensors indexed by vocabulary id,
            one per masked position (e.g. the result of ``get_predictions``).

    Returns:
        A list with one entry per input row. Each entry is a list of
        ``(token, log_probability, score)`` tuples ordered from the highest
        score to the lowest; ``log_probability`` and ``score`` are 0-dim
        tensors.
    """
    results = []
    for pred in predictions:
        # Ascending argsort flipped end-to-end gives the descending order.
        # Keeping it as a single index tensor avoids building (and indexing
        # with) a Python list of vocabulary-size 0-dim tensors.
        order = torch.argsort(pred).flip(0)
        sorted_scores = pred[order]
        # Explicit dim: normalize over the vocabulary axis of the 1-D row.
        log_probs = torch.log_softmax(sorted_scores, dim=-1)
        predicted_tokens = tokenizer.convert_ids_to_tokens(order.tolist())
        results.append(list(zip(predicted_tokens, log_probs, sorted_scores)))
    return results

In [8]:
def print_predictions(seq: str, limit: int, masked_index=None):
    """Print the top ``limit`` candidates for every masked position in ``seq``.

    Args:
        seq: Input sentence, forwarded to ``get_predictions``.
        limit: Number of ranked candidates to print per masked position.
        masked_index: Forwarded to ``get_predictions`` as its second
            argument; ``None`` lets it locate the ``[MASK]`` tokens itself.

    One line is printed per candidate (token, log probability, raw score),
    followed by a ``----`` separator after each masked position.
    """
    scores = get_predictions(seq, masked_index)
    ranked_per_mask = predictions_to_list(scores)
    for ranked in ranked_per_mask:
        top_candidates = ranked[:limit]
        for token, prob, p in top_candidates:
            print(f"token: {token}\t log probability: {prob}\t p: {p}")
        print("----")

In [9]:
# Single masked slot at the end of a definition: show the 10 top-ranked fillers.
sentence = "[CLS] apple is a [MASK] . [SEP]"
print_predictions(sentence, 10)

  


token: fruit	 log probability: -1.031285285949707	 p: 15.024468421936035
token: plant	 log probability: -1.5805988311767578	 p: 14.475154876708984
token: vegetable	 log probability: -2.7677001953125	 p: 13.288053512573242
token: flower	 log probability: -3.0948991775512695	 p: 12.960854530334473
token: tree	 log probability: -3.323002815246582	 p: 12.73275089263916
token: seed	 log probability: -3.6802635192871094	 p: 12.375490188598633
token: cereal	 log probability: -3.6965322494506836	 p: 12.359221458435059
token: fungus	 log probability: -3.8061017990112305	 p: 12.249651908874512
token: grain	 log probability: -3.8994436264038086	 p: 12.156310081481934
token: food	 log probability: -4.320439338684082	 p: 11.73531436920166
----


In [10]:
# Masked slot at the start of the sentence, followed by "such as" examples.
sentence = "[CLS] [MASK], such as saturday and friday [SEP]"
print_predictions(sentence, 10)

  


token: weekdays	 log probability: -2.153414726257324	 p: 12.686539649963379
token: sunday	 log probability: -2.1941633224487305	 p: 12.645791053771973
token: sundays	 log probability: -2.382669448852539	 p: 12.457284927368164
token: holidays	 log probability: -2.4000444412231445	 p: 12.439909934997559
token: monday	 log probability: -2.827853202819824	 p: 12.012101173400879
token: saturday	 log probability: -2.838412284851074	 p: 12.001542091369629
token: weekends	 log probability: -2.953669548034668	 p: 11.886284828186035
token: weekday	 log probability: -3.213428497314453	 p: 11.62652587890625
token: friday	 log probability: -3.275822639465332	 p: 11.564131736755371
token: fridays	 log probability: -3.424633026123047	 p: 11.415321350097656
----


In [11]:
# Two consecutive [MASK] tokens: each position gets its own ranked list,
# printed as separate blocks divided by "----".
sentence = "[CLS] [MASK] [MASK], such as elvis and the beatles [SEP]"
print_predictions(sentence, 10)

  


token: pop	 log probability: -1.2998237609863281	 p: 12.097662925720215
token: popular	 log probability: -2.0287046432495117	 p: 11.368782043457031
token: rock	 log probability: -2.095531463623047	 p: 11.301955223083496
token: musical	 log probability: -3.063312530517578	 p: 10.334174156188965
token: famous	 log probability: -3.370610237121582	 p: 10.026876449584961
token: music	 log probability: -3.593216896057129	 p: 9.804269790649414
token: celebrity	 log probability: -3.9175596237182617	 p: 9.479927062988281
token: notable	 log probability: -4.284005165100098	 p: 9.113481521606445
token: superstar	 log probability: -4.489748001098633	 p: 8.90773868560791
token: mainstream	 log probability: -4.56675910949707	 p: 8.830727577209473
----
token: artists	 log probability: -1.5210762023925781	 p: 12.927189826965332
token: musicians	 log probability: -1.8203630447387695	 p: 12.62790298461914
token: music	 log probability: -2.247682571411133	 p: 12.200583457946777
token: icons	 log probabil

In [12]:
# Two consecutive [MASK] tokens at the end of the sentence, after a list of
# quoted song titles; each masked position is ranked separately.
sentence = '[CLS] "let it be", "strawberry fields forever" and "I wanna hold your hand" are songs by [MASK] [MASK] . [SEP]'
print_predictions(sentence, 10)

  


token: the	 log probability: -3.1562976837158203	 p: 9.283946990966797
token: michael	 log probability: -3.59041690826416	 p: 8.849827766418457
token: david	 log probability: -3.74137020111084	 p: 8.698874473571777
token: james	 log probability: -4.008335113525391	 p: 8.431909561157227
token: neil	 log probability: -4.092800140380859	 p: 8.347444534301758
token: billy	 log probability: -4.180458068847656	 p: 8.259786605834961
token: george	 log probability: -4.258735656738281	 p: 8.181509017944336
token: tom	 log probability: -4.492015361785889	 p: 7.9482293128967285
token: elton	 log probability: -4.507739067077637	 p: 7.9325056076049805
token: bryan	 log probability: -4.555271148681641	 p: 7.884973526000977
----
token: jackson	 log probability: -3.8180456161499023	 p: 8.245426177978516
token: adams	 log probability: -4.385303020477295	 p: 7.678168773651123
token: williams	 log probability: -4.399244785308838	 p: 7.66422700881958
token: springsteen	 log probability: -4.633245468139648

In [14]:
# Hebrew prompt with two adjacent [MASK] tokens (no space between them).
# NOTE(review): the checkpoint name suggests an English uncased vocabulary,
# which is presumably why the candidates are single-character word pieces — confirm.
sentence = '[CLS] דמויות מהתנ"ך, כמו [MASK][MASK], היו דמויות מרכזיות בתקופת ממלכת יהודה. [SEP]'
print_predictions(sentence, 10)

  


token: "	 log probability: -1.9399681091308594	 p: 13.210062026977539
token: ##י	 log probability: -2.0456180572509766	 p: 13.104412078857422
token: ##ו	 log probability: -2.2604904174804688	 p: 12.88953971862793
token: ##ד	 log probability: -2.605360984802246	 p: 12.544669151306152
token: ##ל	 log probability: -2.8682384490966797	 p: 12.281791687011719
token: ##נ	 log probability: -3.020265579223633	 p: 12.129764556884766
token: ##מ	 log probability: -3.2315244674682617	 p: 11.918505668640137
token: ##פ	 log probability: -3.300172805786133	 p: 11.849857330322266
token: ##ת	 log probability: -3.436406135559082	 p: 11.713624000549316
token: ##ר	 log probability: -3.470670700073242	 p: 11.679359436035156
----
token: ##ה	 log probability: -1.3304805755615234	 p: 14.118085861206055
token: ##ת	 log probability: -1.6360149383544922	 p: 13.812551498413086
token: ##ו	 log probability: -2.020946502685547	 p: 13.427619934082031
token: ##ם	 log probability: -2.962373733520508	 p: 12.4861927032470