In [82]:
import numpy as np
import pandas as pd
from IPython.display import display

import torch
import torch.nn.functional as F

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMaskedLM

In [233]:
# Load the pretrained uncased BERT-base masked-LM and its WordPiece tokenizer.
# NOTE(review): do_lower_case=False on an *uncased* vocab means capitalised input is not
# lowercased before lookup, so words like 'Alcatraz' come out as '[UNK]'.
# do_lower_case=True would match the model; it is left False here to preserve the outputs below.
model_id = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_id, do_lower_case=False)
model = BertForMaskedLM.from_pretrained(model_id)
# switch to inference mode; the trailing ';' suppresses the module repr in Jupyter
model.eval();

In [239]:
# the tokenizer can spell words out if necessary: every lowercase letter is
# available as a '##' continuation piece (checked below)
not any('##' + letter not in tokenizer.vocab for letter in 'abcdefghijklmnopqrstuvwxyz')

True

In [235]:
# a word missing from the vocab is split into WordPiece sub-tokens
tokenizer.tokenize("xkcd")

['x', '##k', '##cd']

In [236]:
# but it's not all that smart: compare the unspaced compound with the spaced phrase
tuple(tokenizer.tokenize(phrase) for phrase in ('foodservice', 'food service'))

(['foods', '##er', '##vic', '##e'], ['food', 'service'])

In [242]:
# and the vocab is missing many proper nouns.
# The '[UNK]' below is because the tokenizer was built with do_lower_case=False on an
# uncased vocab: the capital 'A' stops the word matching any lowercase pieces.
# (A cased model will spell this from parts.)
tokenizer.tokenize("Alcatraz")

['[UNK]']

In [243]:
# peek at 200 random vocabulary entries (unseeded, so the sample differs per run)
np.random.choice(list(tokenizer.vocab), size=200)

array(['firearms', 'irritated', 'highest', 'adoptive', 'comedians',
       'coached', 'seeking', 'shin', 'recruits', 'fighters', 'ultrasound',
       '##arium', '##ul', 'pose', 'shut', 'nj', 'loops', 'figure',
       '##moor', 'orton', '##yev', 'replace', 'grade', 'ears',
       '[unused337]', '##ry', '##oja', 'soluble', '##omorphic',
       'sectional', 'territories', 'consultant', 'deported', 'little',
       'hungary', 'immunity', 'hines', 'majority', 'gradual', '##ph',
       'koppen', 'munich', 'diseases', 'violence', 'descendant',
       'custody', '##oi', 'freight', '##ン', '▪', 'usable', '##rwood',
       'questioning', 'stabbed', 'ports', 'fear', 'scanned', 'dialogues',
       'nail', 'artwork', '##力', 'ramsey', '##oy', '##ister', 'pneumonia',
       'roadside', 'mrna', 'piedmont', 'bullying', 'madden', 'journeys',
       'confident', 'superintendent', 'darker', 'rake', 'scroll', '##そ',
       'conservatoire', '##phi', 'obscene', 'skopje', 'liberty',
       'specialization', 's

In [244]:
# Segment A: the context sentence the model conditions on.
# NOTE(review): capitalised words in the commented alternatives will likely tokenize to
# '[UNK]' because the tokenizer uses do_lower_case=False with an uncased vocab.
# Alternatives:
# text = tokenizer.tokenize('this is a simple test sentence.')
# text = tokenizer.tokenize('certain methods handle self-referential sentences poorly.')
# text = tokenizer.tokenize('who was jim henson?')
# text = tokenizer.tokenize('Is there life on Mars?')
# text = tokenizer.tokenize('The Occupation of Alcatraz was an occupation of Alcatraz Island by 89 American Indians and supporters, led by Richard Oakes, LaNada Means, and others.')
text = tokenizer.tokenize('beam search uses breadth-first search to build its search tree.')

# Segment B: the response template. [MASK] slots are filled by the sampling cells below;
# 'computer' is pinned at index 3 and the response ends with '!'.
response = ['[MASK]', '[MASK]', '[MASK]', 'computer', '[MASK]', '!']
# Alternatives:
# response = tokenizer.tokenize('jim henson was a puppeteer')
# response = list(np.random.choice([t for t in tokenizer.vocab if not t.startswith('[unused')], size=5)) + ['.']
# response = tokenizer.tokenize('so is this one.')
# response = tokenizer.tokenize('however, some may perform better than others.')
# response = tokenizer.tokenize('They chose the name Indians of All Tribes and John Trudell was the spokesperson.')
# response = tokenizer.tokenize('At each level of the tree, it generates all successors of the states at the current level, sorting them in increasing order of heuristic cost')
# response = ['[MASK]']*(len(response)-1) + ['.']
# response[0] = '[MASK]'
# response[3] = '[MASK]'
# response[5] = '[MASK]'

In [245]:
# Model input: [CLS] text [SEP] response.
# Segment id 0 covers [CLS] + text + [SEP]; segment id 1 covers the response.
# NOTE(review): BERT inputs are conventionally '[CLS] A [SEP] B [SEP]'. The trailing
# [SEP] is omitted here, and later cells assume the response runs to the end of `tokens`.
# (Earlier variant without special tokens:
#   tokens = torch.tensor(tokenizer.convert_tokens_to_ids(text+response))
#   segments = torch.tensor([0]*len(text) + [1]*len(response)))
tokens = torch.tensor(tokenizer.convert_tokens_to_ids(['[CLS]', *text, '[SEP]', *response]))
segments = torch.tensor([0] * (len(text) + 2) + [1] * len(response))

In [246]:
# Per-token multiplicative weights applied to predicted probabilities in some of the
# sampling cells below. All ones means no reweighting.
token_prior = torch.ones(len(tokenizer.vocab))
# Examples: zero out id 1012, or down-weight ids 1000/1006/1007/1010:
# token_prior[1012] *= 0
# token_prior[[1000, 1006, 1007, 1010]] *= 0.1

In [247]:
# show the full model input as space-separated tokens
display(' '.join(tokenizer.convert_ids_to_tokens(tokens.tolist())))

'[CLS] beam search uses breadth - first search to build its search tree . [SEP] [MASK] [MASK] [MASK] computer [MASK] !'

In [248]:
def display_text(t):
    """Display token ids `t` as one space-joined string.

    WordPiece continuation pieces ('##xx') are glued onto the preceding token
    by removing every ' ##' from the joined string.
    """
    pieces = tokenizer.convert_ids_to_tokens(np.asarray(t))
    joined = ' '.join(pieces)
    display(joined.replace(' ##', ''))

In [249]:
# Deterministically fill the [MASK] positions one at a time. Each step commits the
# single (position, token) pair with the highest prior-weighted probability among the
# remaining masks, then re-runs the model. This is similar to greedy decoding
# (beam search with beam size 1).
mask_id = tokenizer.vocab['[MASK]']  # 103 for bert-base-uncased
pred_tokens = tokens.clone()
idxs_to_predict = [i for i, t in enumerate(tokens.tolist()) if t == mask_id]
with torch.no_grad():  # inference only: don't build an autograd graph on every step
    while idxs_to_predict:
        pred = (
            model(pred_tokens.unsqueeze(dim=0), segments.unsqueeze(dim=0))[0]
            .index_select(0, torch.LongTensor(idxs_to_predict))
        )
        pred = F.softmax(pred, -1) * token_prior

        # best token for each remaining mask, then the most confident mask overall
        p, i = torch.max(pred, dim=-1)
        most_confident_idx = torch.argmax(p).item()
        token_id = i[most_confident_idx].item()
        pred_tokens[idxs_to_predict[most_confident_idx]] = token_id

        display_text(pred_tokens[len(text)+2:])
        del idxs_to_predict[most_confident_idx]

'[MASK] [MASK] [MASK] computer ! !'

'[MASK] : [MASK] computer ! !'

'example : [MASK] computer ! !'

'example : hello computer ! !'

In [250]:
# Iteratively resample *every* response position in parallel from the model's
# per-position distributions (a "naive joint"). All response positions, including the
# pinned 'computer' and '!', are overwritten each round. After the first round the model
# typically sees no [MASK]s at all.
# Does not resemble the data distribution, but generates pretty concrete poetry.
temp = 1
n_rounds = 32
start = len(text) + 2  # first response position, after [CLS] text [SEP]
pred_tokens = tokens.clone()
with torch.no_grad():  # inference only: don't build an autograd graph each round
    for _ in range(n_rounds):
        pred = model(pred_tokens.unsqueeze(dim=0), segments.unsqueeze(dim=0))[0, start:]
        pred = F.softmax(pred, -1).pow(1/temp)
        pred = pred * token_prior
        # multinomial accepts unnormalised non-negative weights
        pred_tokens[start:] = pred.multinomial(1).flatten()
        display_text(pred_tokens[start:])

'otherwise : from computer ! !'

'otherwise : from computer ! !'

'otherwise : from computer ! !'

'otherwise : from computer ! !'

'otherwise : from ! ! !'

'otherwise : from ! ! !'

'otherwise : from ! ! !'

'otherwise : from ! ! !'

'otherwise : from ! ! !'

'example : = ! ! !'

'example : = ! ! !'

'example : = ! ! !'

'example : = ! ! !'

'example : = ! ! !'

'example : = ! ! !'

'example : = ! ! !'

'example : = ! ! !'

'example : = + ! !'

'example : = + ! !'

'example : = + ! !'

'example : = + ! !'

'example : = + ! !'

'example : = + ! !'

'example : = + ! !'

'example : = + ! !'

'example : = + + !'

'example : = + + !'

'example : = + + !'

'example : = + + !'

'example : = + + +'

'example : = + + +'

'example : = + + +'

In [251]:
# Sample the response left to right, one position per model call, from the softmax
# (with temperature). Every response position is resampled, not only [MASK]s, so the
# pinned 'computer' and '!' may be replaced. token_prior is not applied in this cell.
temp = 1
pred_tokens = tokens.clone()
with torch.no_grad():  # inference only: don't build an autograd graph per position
    for i in range(len(text)+2, len(tokens)):
        pred = model(pred_tokens.unsqueeze(dim=0), segments.unsqueeze(dim=0))[0, i]
        pred = F.softmax(pred, -1)
        pred = pred.pow(1/temp)
        pred_tokens[i] = pred.multinomial(1).item()
        display_text(pred_tokens[len(text)+2:])

'mark [MASK] [MASK] computer [MASK] !'

'mark goodnight [MASK] computer [MASK] !'

'mark goodnight from computer [MASK] !'

'mark goodnight from computer [MASK] !'

'mark goodnight from computer help !'

'mark goodnight from computer help !'

In [252]:
# Fill the [MASK] positions in a uniformly random order. Each one is sampled from the
# model's softmax (with temperature), conditioned on everything filled so far.
temp = 1
mask_id = tokenizer.vocab['[MASK]']  # 103 for bert-base-uncased
pred_tokens = tokens.clone()
idxs_to_predict = [i for i, t in enumerate(tokens.tolist()) if t == mask_id]
with torch.no_grad():  # inference only: don't build an autograd graph per step
    # Loop on the remaining positions rather than `mask_id in pred_tokens`. The model can
    # itself sample [MASK]; that would leave a [MASK] in the sequence after the position
    # list is exhausted, and np.random.choice(0) would raise.
    while idxs_to_predict:
        i = np.random.choice(len(idxs_to_predict))
        pos = idxs_to_predict[i]
        # only the chosen position's distribution is needed
        logits = model(pred_tokens.unsqueeze(dim=0), segments.unsqueeze(dim=0))[0, pos]
        pred_tokens[pos] = F.softmax(logits, -1).pow(1/temp).multinomial(1).item()
        display_text(pred_tokens[len(text)+2:])
        del idxs_to_predict[i]

'they [MASK] [MASK] computer [MASK] !'

'they watch [MASK] computer [MASK] !'

'they watch their computer [MASK] !'

'they watch their computer ! !'

In [253]:
# Iteratively sample one response position at a time. The chosen position is the one
# whose predictive distribution has the highest entropy. Any response position can be
# chosen, including already-filled ones such as 'computer', so tokens may be overwritten.
# Usually terminates quickly, but it can get stuck when a filled position keeps a
# *higher* entropy than every remaining [MASK]; `max_iters` caps the loop.
temp = 1
max_iters = 64
mask_id = tokenizer.vocab['[MASK]']  # 103 for bert-base-uncased
pred_tokens = tokens.clone()
k = len(text)+2  # first response position, after [CLS] text [SEP]
it = 0
with torch.no_grad():  # inference only: don't build an autograd graph per step
    while mask_id in pred_tokens and it < max_iters:
        it += 1
        logits = model(pred_tokens.unsqueeze(dim=0), segments.unsqueeze(dim=0))[0, k:]
        # log_softmax keeps log-probs finite, avoiding 0 * log(0) = nan when a
        # probability underflows to zero
        log_pred = F.log_softmax(logits, dim=-1)
        pred = log_pred.exp()
        ent = -(pred * log_pred).sum(dim=-1)
        i = ent.argmax().item()
        pred_tokens[i+k] = pred[i].pow(1/temp).multinomial(1).item()
        display_text(pred_tokens[k:])

'stars [MASK] [MASK] computer [MASK] !'

'stars [MASK] [MASK] arcade [MASK] !'

'stars [MASK] of arcade [MASK] !'

'stars appropriated of arcade [MASK] !'

'stars reminiscent of arcade [MASK] !'

'stars reminiscent of arcade fire !'

In [254]:
# Iteratively fill one [MASK] at a time. The chosen mask is the one whose predictive
# distribution has the lowest entropy, i.e. the model's most certain position goes first.
temp = 1
mask_id = tokenizer.vocab['[MASK]']  # 103 for bert-base-uncased
pred_tokens = tokens.clone()
idxs_to_predict = [i for i, t in enumerate(tokens.tolist()) if t == mask_id]
with torch.no_grad():  # inference only: don't build an autograd graph per step
    # Loop on the remaining positions rather than `mask_id in pred_tokens`. If the model
    # samples [MASK] itself, the position list empties while a [MASK] remains, and
    # argmin over an empty tensor would raise.
    while idxs_to_predict:
        logits = (
            model(pred_tokens.unsqueeze(dim=0), segments.unsqueeze(dim=0))[0]
            .index_select(0, torch.LongTensor(idxs_to_predict))
        )
        # log_softmax avoids 0 * log(0) = nan when a probability underflows to zero
        log_pred = F.log_softmax(logits, dim=-1)
        pred = log_pred.exp()
        ent = -(pred * log_pred).sum(dim=-1)
        i = ent.argmin().item()

        pred_tokens[idxs_to_predict[i]] = pred[i].pow(1/temp).multinomial(1).item()
        display_text(pred_tokens[len(text)+2:])
        del idxs_to_predict[i]

'[MASK] [MASK] - computer [MASK] !'

'[MASK] facebook - computer [MASK] !'

'[MASK] facebook - computer search !'

'facebook facebook - computer search !'

In [255]:
# Iteratively fix one [MASK] at a time. The (remaining masks x vocab) logits are treated
# as one big softmax, so a (position, token) pair is sampled jointly.
temp = 1
mask_id = tokenizer.vocab['[MASK]']  # 103 for bert-base-uncased
pred_tokens = tokens.clone()
idxs_to_predict = [i for i, t in enumerate(tokens.tolist()) if t == mask_id]
# idxs_to_predict = list(range(2+len(text), len(tokens)))
with torch.no_grad():  # inference only: don't build an autograd graph per step
    while True:
        display_text(pred_tokens[len(text)+2:])

        if not idxs_to_predict:
            break

        logits = (
            model(pred_tokens.unsqueeze(dim=0), segments.unsqueeze(dim=0))[0]
            .index_select(0, torch.LongTensor(idxs_to_predict))
        )
        vocab_size = logits.shape[-1]
        # Subtract the global max before exp() so large logits cannot overflow to inf.
        # This rescales all weights equally, so the sampling distribution is unchanged.
        pred = ((logits - logits.max()).exp() * token_prior).view(-1)
        if temp != 1:
            pred = pred / pred.sum()
            pred = pred.pow(1/temp)

        # flat index -> (row in idxs_to_predict, token id), matching row-major view(-1)
        ij = pred.multinomial(1).item()
        i, j = divmod(ij, vocab_size)
        pred_tokens[idxs_to_predict[i]] = j
        del idxs_to_predict[i]

'[MASK] [MASK] [MASK] computer [MASK] !'

'[MASK] [MASK] means computer [MASK] !'

'[MASK] h means computer [MASK] !'

'[MASK] h means computer search !'

'field h means computer search !'

In [78]:
# Idea: "linguistic erosion stalactites" -- iteratively delete, replace, or shorten tokens:
# - at the max-entropy position, zero the probability of the current token, then resample
# - delete the min-entropy position
# - at a position, zero the probability of every token longer than the current one, then resample