In [6]:
import torch
# Imported under an alias so the factory function is not overwritten by the
# pipeline object created below and stays available for building other pipelines.
from transformers import pipeline as hf_pipeline

# Masked-token prediction with BERT (uncased), loaded in half precision on
# accelerator index 0.
fill_mask = hf_pipeline(
    task="fill-mask",
    model="google-bert/bert-base-uncased",
    dtype=torch.float16,
    device=0
)

# Other cells call this object as `pipeline(...)`, so keep that name bound to it.
pipeline = fill_mask

Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use mps:0


In [7]:
# Ask the fill-mask pipeline for candidates for the [MASK] slot. The call is the
# cell's last expression, so its return value is shown as the cell output.
pipeline("ITMO is the university in [MASK].")

[{'score': 0.0728759765625,
  'token': 9701,
  'token_str': 'ghana',
  'sequence': 'itmo is the university in ghana.'},
 {'score': 0.035247802734375,
  'token': 16274,
  'token_str': 'mozambique',
  'sequence': 'itmo is the university in mozambique.'},
 {'score': 0.0341796875,
  'token': 11959,
  'token_str': 'tanzania',
  'sequence': 'itmo is the university in tanzania.'},
 {'score': 0.0311126708984375,
  'token': 16878,
  'token_str': 'macau',
  'sequence': 'itmo is the university in macau.'},
 {'score': 0.0213775634765625,
  'token': 10031,
  'token_str': 'uganda',
  'sequence': 'itmo is the university in uganda.'}]

In [104]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import torch

# Load the tokenizer and the masked-LM model from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

# Encode one sentence containing a [MASK] token as PyTorch tensors.
inputs = tokenizer("This movie is [MASK]!", return_tensors="pt")

# Inference only: disable gradient tracking so the forward pass does not build
# an autograd graph that is never used.
with torch.no_grad():
    outputs = model(**inputs)

# Displayed as the cell output: the shape of the raw prediction scores.
outputs.logits.shape

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([1, 7, 30522])

In [107]:
# Number of candidates to show. Defined in this cell so that it exists before
# the `torch.topk` call below when the notebook is run top to bottom on a
# fresh kernel.
top_k = 10

mask_token_id = tokenizer.mask_token_id
# `nonzero(as_tuple=True)` yields one index tensor per dimension; [1] selects the
# positions along the sequence dimension and [0] takes the first match.
mask_position = (inputs['input_ids'] == mask_token_id).nonzero(as_tuple=True)[1][0]

# Turn the scores at the masked position into a probability distribution.
probs = torch.softmax(outputs.logits[0][mask_position], dim=-1)

# Displayed as the cell output: ids of the `top_k` most probable tokens.
torch.topk(probs, top_k).indices

tensor([12476,  2307,  6429,  9788,  4121, 10392,  4689,  9951,  2204,  8235])

In [108]:
top_k = 10
# One topk call provides both the token ids and their probabilities.
best = torch.topk(probs, top_k)
top_indices = best.indices

print("Top predictions:")
# Iterate plain Python ints/floats; ranks are 1-based in the printout.
ranked = zip(top_indices.tolist(), best.values.tolist())
for rank, (token_id, token_prob) in enumerate(ranked, start=1):
    word = tokenizer.decode(token_id)
    print(f"{rank}. {word:15} (index: {token_id:5}) - {token_prob:.4f}")

Top predictions:
1. awesome         (index: 12476) - 0.1405
2. great           (index:  2307) - 0.0885
3. amazing         (index:  6429) - 0.0874
4. incredible      (index:  9788) - 0.0371
5. huge            (index:  4121) - 0.0292
6. fantastic       (index: 10392) - 0.0274
7. crazy           (index:  4689) - 0.0258
8. ridiculous      (index:  9951) - 0.0174
9. good            (index:  2204) - 0.0161
10. brilliant       (index:  8235) - 0.0158


In [109]:
# Look up the encoded batch once and take its first row, which is the one the
# prints below inspect.
encoded_batch = inputs['input_ids']
first_sequence = encoded_batch[0]

print("Input IDs:", encoded_batch)
print("Token IDs:", first_sequence.tolist())
print("Tokens:", tokenizer.convert_ids_to_tokens(first_sequence))
print("Decoded:", tokenizer.decode(first_sequence))

Input IDs: tensor([[ 101, 2023, 3185, 2003,  103,  999,  102]])
Token IDs: [101, 2023, 3185, 2003, 103, 999, 102]
Tokens: ['[CLS]', 'this', 'movie', 'is', '[MASK]', '!', '[SEP]']
Decoded: [CLS] this movie is [MASK]! [SEP]
