In [1]:
import torch as t
import torch.nn as nn
import transformers

In [2]:
# Model card: https://huggingface.co/bert-base-cased
# Tokenizer and masked-LM head are loaded from the same checkpoint so their
# vocabularies match. The "weights not used" message printed on load lists
# only cls.seq_relationship.*, which BertForMaskedLM does not define; the
# message itself says this is expected in that situation.
CHECKPOINT = "bert-base-cased"

tokenizer = transformers.BertTokenizer.from_pretrained(CHECKPOINT)
model = transformers.BertForMaskedLM.from_pretrained(CHECKPOINT)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def ascii_art_probs(text, k=5):
    """Print the top-k masked-LM predictions for every mask slot in ``text``.

    Relies on the notebook-level ``tokenizer`` and ``model`` globals.

    Args:
        text: Input string containing one or more occurrences of the
            tokenizer's mask token (``[MASK]`` for BERT).
        k: Number of candidates to list per mask slot. Clamped to the
            vocabulary size so an oversized ``k`` cannot make ``topk`` fail.

    Returns:
        None. Prints ``text`` with every mask token shown as ``___``,
        followed by one block per mask slot (in order of appearance) of
        ``<percent>%<TAB><token>`` lines, most likely candidate first.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # Sequence positions (within the single batch row) that hold the mask token.
    mask_positions, = t.where(inputs["input_ids"][0] == tokenizer.mask_token_id)

    # Inference only: skip building the autograd graph.
    with t.no_grad():
        outputs = model(**inputs)

    # Softmax is independent per position, so normalising only the masked rows
    # gives the same probabilities as normalising the whole sequence.
    probs = t.nn.functional.softmax(outputs.logits[0, mask_positions], dim=-1)
    k = min(k, probs.shape[-1])
    top_probs, top_ids = t.topk(probs, k, dim=-1)

    s = text.replace(tokenizer.mask_token, '___') + '\n\n'
    for slot_ids, slot_probs in zip(top_ids, top_probs):
        # Map each id to its token on its own. Decoding all k ids as a single
        # string and splitting on whitespace would merge WordPiece
        # continuations ("##ed") and punctuation into neighbouring candidates,
        # which shortens the word list and misaligns it with the percentages.
        words = tokenizer.convert_ids_to_tokens(slot_ids.tolist())
        candidates = [
            '%d%%\t%s' % (round(float(percent * 100)), word)
            for word, percent in zip(words, slot_probs)
        ]
        s += '\n'.join(candidates) + '\n\n'
    print(s)

# Demo: a sentence with a single [MASK] slot; the call prints the sentence with
# the slot shown as ___ followed by the candidate fills and their percentages
# (k is left at its default of 5).
text = "The firetruck was painted a bright [MASK]."
ascii_art_probs(text)

The firetruck was painted a bright ___.

48%	red
15%	yellow
10%	blue
8%	pink
6%	orange


