In [1]:
import transformers as trf
import torch as pt
import numpy as np
from collections import Counter, defaultdict
from tqdm import tqdm

Intel(R) Data Analytics Acceleration Library (Intel(R) DAAL) solvers for sklearn enabled: https://intelpython.github.io/daal4py/sklearn.html


In [2]:
lmbert = trf.BertForMaskedLM.from_pretrained('bert-base-uncased')

In [3]:
tokenizer = trf.BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
def collect(sentence, model=lmbert, tokenizer=tokenizer):
    tks = tokenizer.encode(sentence)
    mask_pos = tks.index(tokenizer.mask_token_id)
    bert_input = pt.tensor(tks).unsqueeze(0)
    all_tk = tokenizer.convert_ids_to_tokens(list(range(tokenizer.vocab_size)))
    predictions = defaultdict(lambda: [])
    for i, tk in tqdm(enumerate(all_tk)):
        bert_input[0, mask_pos] = i
        bert_mask_output = model(bert_input)[0].squeeze()[mask_pos]
        predicted = bert_mask_output.argmax()
        if i != predicted:
            pred_tk = all_tk[predicted]
    #         print(f"{i} :: {tk} :: {pred_tk}")
            predictions[pred_tk].append(tk)
    return predictions

In [5]:
def do_present(sentence):
    predictions = collect(sentence)
    print(predictions)
    prediction_counters = list(map(lambda t: (t[0], len(t[1])), predictions.items()))
    print("##############################################")
    print(sorted(prediction_counters, key=lambda t: -t[1]))

In [6]:
do_present('Francesco Bartolomeo Conti was born in [MASK].')

30522it [17:08, 29.67it/s]

##############################################
[('rome', 2519), ('bologna', 685), ('genoa', 586), ('ca', 391), ('verona', 363), ('como', 320), ('florence', 273), ('c', 186), ('venice', 152), ('town', 95), ('padua', 81), ('villa', 75), ('ferrara', 69), ('st', 56), ('the', 51), ('turin', 50), ('est', 49), ('alexandria', 43), ('pisa', 42), ('it', 38), ('po', 37), ('milan', 36), ('san', 36), ('here', 35), ('[UNK]', 31), ('val', 30), ('.', 29), ('this', 27), ('palazzo', 26), ('e', 25), ('italy', 19), ('a', 19), ('##e', 18), ('port', 18), ('##o', 18), ('lima', 17), ('##gno', 15), ('campo', 15), (',', 14), ('parma', 14), ('harbour', 14), ('##eto', 14), ('m', 13), ('mt', 13), ('there', 13), ('s', 13), ('nearby', 13), ('##gli', 12), ('##a', 11), ('siena', 11), ('marche', 10), ('split', 9), ('piazza', 9), ('edo', 8), ('roma', 8), ('ad', 8), ('milano', 8), ('м', 7), ('ss', 7), ('bern', 7), ('bray', 7), ('sts', 7), ('n', 6), ('village', 6), ('##oda', 6), ('production', 5), ('point', 5), ('##ggio',




In [7]:
do_present("If you want to get [MASK] then you should drink wine or beer.")

30522it [15:17, 33.27it/s]

defaultdict(<function collect.<locals>.<lambda> at 0x7fa0c4479ea0>, {'drunk': ['[PAD]', '[unused0]', '[unused1]', '[unused2]', '[unused3]', '[unused4]', '[unused5]', '[unused6]', '[unused7]', '[unused8]', '[unused9]', '[unused10]', '[unused11]', '[unused12]', '[unused13]', '[unused14]', '[unused15]', '[unused16]', '[unused17]', '[unused18]', '[unused19]', '[unused20]', '[unused21]', '[unused22]', '[unused23]', '[unused24]', '[unused25]', '[unused26]', '[unused27]', '[unused28]', '[unused29]', '[unused30]', '[unused31]', '[unused32]', '[unused33]', '[unused34]', '[unused35]', '[unused36]', '[unused37]', '[unused38]', '[unused39]', '[unused40]', '[unused41]', '[unused42]', '[unused43]', '[unused44]', '[unused45]', '[unused46]', '[unused47]', '[unused48]', '[unused49]', '[unused50]', '[unused51]', '[unused52]', '[unused53]', '[unused54]', '[unused55]', '[unused56]', '[unused57]', '[unused58]', '[unused59]', '[unused60]', '[unused61]', '[unused62]', '[unused63]', '[unused64]', '[unused65]'




In [8]:
do_present("Static electricity can give you a [MASK] shock when you touch metal objects during dry weather.")

30522it [15:52, 32.04it/s]

##############################################





In [9]:
do_present('By avoiding crowds during peak flu season, you reduce your chances of [MASK].')

30522it [16:11, 31.43it/s]

##############################################





In [10]:
do_present('You would write a story to [MASK], teach or entertain.')

30522it [15:26, 32.96it/s]

##############################################
[('read', 5016), ('entertain', 4261), ('yourself', 1289), ('be', 1002), ('write', 769), ('learn', 534), ('you', 533), ('teach', 270), ('play', 268), ('tell', 181), ('them', 164), ('educate', 145), ('do', 136), ('inspire', 136), ('create', 76), ('think', 72), ('see', 63), ('it', 60), ('watch', 57), ('hear', 54), ('draw', 52), ('win', 52), ('eat', 47), ('grow', 45), ('sing', 38), ('sell', 37), ('speak', 35), ('make', 34), ('share', 28), ('understand', 28), ('kill', 27), ('fly', 27), ('me', 26), ('judge', 22), ('paint', 22), ('challenge', 22), ('connect', 21), ('find', 19), ('talk', 19), ('build', 19), ('help', 19), ('feed', 19), ('defend', 18), ('come', 17), ('impress', 17), ('students', 17), ('fight', 17), ('remember', 15), ('get', 14), ('listen', 14), ('trade', 14), ('##t', 13), ('end', 13), ('lie', 12), ('live', 12), ('protect', 12), ('cook', 12), ('travel', 12), ('sleep', 11), ('dance', 11), ('say', 11), ('hunt', 11), ('comfort', 11), ('




In [11]:
do_present('[MASK] is a fabric made from the hair of sheep.')

30522it [14:37, 34.79it/s]

defaultdict(<function collect.<locals>.<lambda> at 0x7fa0984ecd08>, {'it': ['[PAD]', '[unused0]', '[unused1]', '[unused2]', '[unused3]', '[unused4]', '[unused5]', '[unused6]', '[unused7]', '[unused8]', '[unused9]', '[unused10]', '[unused11]', '[unused12]', '[unused13]', '[unused14]', '[unused15]', '[unused16]', '[unused17]', '[unused18]', '[unused19]', '[unused20]', '[unused21]', '[unused22]', '[unused23]', '[unused24]', '[unused25]', '[unused26]', '[unused27]', '[unused28]', '[unused29]', '[unused30]', '[unused31]', '[unused32]', '[unused33]', '[unused34]', '[unused35]', '[unused36]', '[unused37]', '[unused38]', '[unused39]', '[unused40]', '[unused41]', '[unused42]', '[unused43]', '[unused44]', '[unused45]', '[unused46]', '[unused47]', '[unused48]', '[unused49]', '[unused50]', '[unused51]', '[unused52]', '[unused53]', '[unused54]', '[unused55]', '[unused56]', '[unused57]', '[unused58]', '[unused59]', '[unused60]', '[unused61]', '[unused62]', '[unused63]', '[unused64]', '[unused65]', '




In [12]:
do_present('The effect of bringing home some fish is a new pet or a tasty [MASK].')

30522it [16:11, 31.42it/s]

##############################################
[('fish', 6560), ('meal', 5938), ('dish', 5168), ('food', 2071), ('one', 313), ('taste', 227), ('pet', 188), ('product', 132), ('catch', 116), ('drink', 99), ('toy', 86), (',', 76), ('sauce', 71), ('treat', 70), ('flavor', 59), ('aquarium', 53), ('snack', 43), ('diet', 26), ('fry', 21), ('person', 20), ('game', 18), ('combination', 16), ('flavour', 14), ('.', 13), ('item', 13), ('is', 13), ('place', 11), ('egg', 11), ('lure', 11), ('flower', 10), ('film', 10), ('dance', 10), ('player', 10), ('shape', 10), ('or', 10), ('thing', 9), ('challenge', 9), ('look', 9), ('feed', 9), ('response', 8), ('piece', 8), ('book', 7), ('name', 7), ('character', 7), ('sound', 7), ('plant', 7), ('change', 7), ('skin', 7), ('symbol', 6), ('dinner', 6), ('shell', 6), ('size', 6), ('home', 6), ('trophy', 6), ('image', 6), ('effect', 5), ('play', 5), ('business', 5), ('form', 5), ('meat', 5), ('lesson', 5), ('environment', 5), ('feeling', 5), ('test', 5), ('addit


