In [20]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Initialize the BERT model and tokenizer
model = BertForMaskedLM.from_pretrained('bert-large-uncased')
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

# Tokenize the input text and mask the first word
tokens = ["[CLS]", "I", "flew", "to", "[MASK]", "which", "is", "the", "capital", "of", "germany", "[SEP]"]
mask_idx = 5
print(tokens)
input_ids = torch.tensor(tokenizer.encode(tokens)).unsqueeze(0)

# Get the top 10 most likely words
with torch.no_grad():
    print(input_ids)
    out = model(input_ids)
    print(out['logits'].size())
    logits = out['logits'][0][mask_idx]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, k=20)
    top_probs = top_probs.tolist()
    top_words = [tokenizer.convert_ids_to_tokens([i])[0] for i in top_indices]

# Print the results
for word, prob in zip(top_words, top_probs):
    print(f"{word}: {prob:.4f}")


Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

['[CLS]', 'I', 'flew', 'to', '[MASK]', 'which', 'is', 'the', 'capital', 'of', 'germany', '[SEP]']
tensor([[ 101,  101,  100, 5520, 2000,  103, 2029, 2003, 1996, 3007, 1997, 2762,
          102,  102]])
torch.Size([1, 14, 30522])
berlin: 0.6312
frankfurt: 0.0643
munich: 0.0583
stuttgart: 0.0492
bonn: 0.0465
cologne: 0.0305
dusseldorf: 0.0132
hamburg: 0.0111
leipzig: 0.0064
nuremberg: 0.0063
prague: 0.0057
hanover: 0.0048
dresden: 0.0046
halle: 0.0045
vienna: 0.0041
aachen: 0.0041
potsdam: 0.0038
wrocław: 0.0031
heidelberg: 0.0026
warsaw: 0.0023
