In [1]:
from IPython.core.display import display, HTML, Image
display(HTML("<style>.container { width:95% !important; }</style>"))
%config IPCompleter.use_jedi=False

In [2]:
from transformers import AutoTokenizer, TFAutoModelForMaskedLM
import tensorflow as tf
import re
from itertools import chain
from string import punctuation
import os
import numpy as np
from itertools import chain
from collections import Counter

In [3]:
# Only show TensorFlow log records at ERROR level or above.
tf.get_logger().setLevel('ERROR')
# Load the pretrained "roberta-base" tokenizer and its TensorFlow masked-LM head.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = TFAutoModelForMaskedLM.from_pretrained("roberta-base")
# The model is only used for inference in this notebook, so mark its weights non-trainable.
model.trainable = False

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/627M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFRobertaForMaskedLM.

All the layers of TFRobertaForMaskedLM were initialized from the model checkpoint at roberta-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForMaskedLM for predictions without further training.


In [29]:
def filter_logits(logits, indices_to_filter=None, fill_with=-np.inf):
    """Overwrite selected entries of a 1-D logits tensor with ``fill_with``.

    Args:
        logits: tensor of logits, indexed along its first axis.
        indices_to_filter: positions to overwrite; ``None`` means "no filtering".
        fill_with: value written at every filtered position (``-inf`` by default).

    Returns:
        ``logits`` itself when ``indices_to_filter`` is ``None``, otherwise a new
        tensor with the selected positions replaced.

    Note (from the original author): an index of ``-1`` behaves differently on
    GPU and on CPU.
    TODO: add the ability to prefer certain logits.
    """
    if indices_to_filter is None:
        return logits

    # tensor_scatter_nd_update expects one index vector per update -> shape (n, 1).
    scatter_indices = tf.expand_dims(indices_to_filter, axis=1)
    replacement_values = tf.fill([len(scatter_indices)], fill_with)
    return tf.tensor_scatter_nd_update(logits, scatter_indices, replacement_values)
    

def pairwise(iterable):
    """Group consecutive items two by two: s -> (s0, s1), (s2, s3), (s4, s5), ...

    A trailing item without a partner is dropped.
    """
    shared = iter(iterable)
    # Zipping the same iterator with itself consumes two items per output tuple.
    return zip(*[shared] * 2)

def sample_from_logits(logits, num_samples=1):
    """Draw ``num_samples`` indices from the categorical distribution given by ``logits``.

    Args:
        logits: 1-D tensor of unnormalised log-probabilities.
        num_samples: how many independent draws to make.

    Returns:
        int32 tensor of shape ``(num_samples,)`` holding the sampled indices.
    """
    # tf.random.categorical works on a batch, so wrap the logits as a batch of one.
    drawn = tf.random.categorical([logits], num_samples, dtype=tf.int32)
    # Drop the batch axis again: (1, num_samples) -> (num_samples,).
    return tf.reshape(drawn, [-1])

def temperature(logits, temp):
    """Apply temperature scaling by dividing ``logits`` by ``temp``."""
    scaled = logits / temp
    return scaled

def top_k_filter(logits, top_k):
    """Keep the ``top_k`` largest logits and mask out all the others.

    Args:
        logits: 1-D tensor of logits.
        top_k: number of highest-scoring entries to keep.

    Returns:
        Tensor in which every entry outside the top ``top_k`` has been replaced
        via ``filter_logits``. When ``top_k`` is at least the number of logits
        there is nothing to remove and ``logits`` is returned unchanged.
    """
    num_to_drop = len(logits) - top_k
    if num_to_drop <= 0:
        # top_k already covers every entry; tf.math.top_k would reject a negative k.
        return logits
    # top_k over the negated logits yields the indices of the smallest entries.
    indices_to_filter = tf.math.top_k(-logits, k=num_to_drop).indices
    return filter_logits(logits, indices_to_filter=indices_to_filter)

def top_p_filter(logits, top_p):
    """Nucleus filtering: keep the most probable entries up to mass ``top_p``.

    Entries are ranked by probability; the shortest prefix whose cumulative
    probability exceeds ``top_p`` is kept and everything after it is masked via
    ``filter_logits``.

    Args:
        logits: tensor of logits; size-1 axes are squeezed away first.
        top_p: cumulative-probability threshold.

    Returns:
        The squeezed logits with every entry outside the nucleus replaced.
    """
    logits = tf.squeeze(logits)
    indices_sorted = tf.argsort(logits, direction="DESCENDING")
    probs_sorted = tf.nn.softmax(tf.gather(logits, indices_sorted), axis=-1)
    # Count the leading entries whose cumulative mass is still within top_p. The
    # entry right after them is the one that crosses the threshold, so it is kept too.
    # Counting (rather than argmax over a boolean mask) also covers the case where
    # rounding keeps the cumulative sum from ever exceeding top_p: the count then
    # equals the vocabulary size and nothing is filtered, instead of argmax
    # returning 0 and everything but the top entry being masked.
    num_within = tf.reduce_sum(tf.cast(tf.cumsum(probs_sorted) <= top_p, tf.int32))
    indices_to_filter = indices_sorted[num_within + 1:]
    return filter_logits(logits, indices_to_filter=indices_to_filter)

def filter_masked_logits(logits, top_k=0, top_p=1.0, temp=1.0, custom=None):
    """Run the logits through the configured filters, in a fixed order.

    Order of application: custom index filter, temperature, top-k, top-p.
    A stage is skipped when its parameter has its neutral value
    (``custom=None``, ``temp=1.0``, ``top_k<=0``, ``top_p>=1.0``).

    Args:
        logits: logits for a single masked position.
        top_k: number of entries kept by the top-k stage.
        top_p: probability mass kept by the top-p stage.
        temp: temperature divisor.
        custom: indices to mask out before anything else.

    Returns:
        The filtered logits.
    """
    filtered = filter_logits(logits, indices_to_filter=custom)
    use_temperature = temp != 1.0
    use_top_k = top_k > 0
    use_top_p = top_p < 1.0

    if use_temperature:
        filtered = temperature(filtered, temp=temp)
    if use_top_k:
        filtered = top_k_filter(filtered, top_k=top_k)
    if use_top_p:
        filtered = top_p_filter(filtered, top_p=top_p)
    return filtered

def tokenize_sentence(sentence, tokenizer):
    """Encode ``sentence`` with ``tokenizer`` and return its input ids with size-1 axes squeezed away."""
    encoded = tokenizer(sentence, return_tensors='tf')
    return tf.squeeze(encoded['input_ids'])

def frozen_indices(sentence, tokenizer, separator="&"):
    """Find the tokens enclosed by separator markers and strip the markers.

    Markers are read in pairs: the text between the 1st and 2nd marker, the
    3rd and 4th, ... is "frozen". An unpaired trailing marker is removed from
    the text but freezes nothing.

    Args:
        sentence: text containing the markers.
        tokenizer: callable that accepts ``return_offsets_mapping=True`` and
            returns a mapping whose ``'offset_mapping'`` lists one
            ``(start, end)`` character span per token, with one extra entry at
            each end (dropped here via ``[1:-1]``); it must also provide
            ``tokenize``.
        separator: regular-expression pattern matching a single marker. The
            same pattern is used both to locate and to remove the markers, and
            markers may be longer than one character.

    Returns:
        ``(frozen, clean_sentence)``: the set of frozen token positions
        (counted from 1, i.e. the position in the full id sequence that still
        carries its leading extra token) and the sentence without markers.
    """
    # Character offset of each marker, expressed in the marker-free sentence:
    # subtract the total length of all markers that precede it.
    marker_offsets = []
    removed_chars = 0
    for match in re.finditer(separator, sentence):
        marker_offsets.append(match.start() - removed_chars)
        removed_chars += match.end() - match.start()
    clean_sentence = re.sub(separator, "", sentence)

    # (start, end-exclusive) character span of every frozen region.
    char_spans = zip(marker_offsets[0::2], marker_offsets[1::2])

    encoding = tokenizer(clean_sentence, return_offsets_mapping=True,
                         return_attention_mask=False, return_token_type_ids=False)
    offsets = encoding['offset_mapping'][1:-1]
    num_tokens = len(tokenizer.tokenize(clean_sentence))
    position_of = dict(zip(offsets, range(1, num_tokens + 1)))

    frozen = set()
    for span_start, span_end in char_spans:
        # Tokens that begin where the region begins or end where it ends.
        boundary_hits = [position for offset, position in position_of.items()
                         if offset[0] == span_start or offset[1] == span_end]
        if len(boundary_hits) == 2:
            # First and last token of the region: freeze everything in between too.
            first, last = boundary_hits
            frozen.update(range(first, last + 1))
        else:
            # Single-token region (or no clean boundary match): freeze just the hits.
            frozen.update(boundary_hits)
    return frozen, clean_sentence
    
def energy_norm(tokens, tokenizer, model):
    """Mean masked-LM log-probability of the tokens in a sequence.

    Every position except the first and last is masked in turn; the model's
    log-softmax score of the original token at that position is collected and
    the scores are averaged.

    The average is taken over the scored positions only. (Earlier the
    accumulator was seeded with a dummy ``0.`` that was also counted, so the
    sum was divided by n + 1 instead of n.)

    Args:
        tokens: 1-D tensor of input ids; first and last entries are skipped.
        tokenizer: provides ``mask_token_id``.
        model: masked-LM model called on a batch of one sequence; element
            ``[0]`` of its output is used as the logits.

    Returns:
        Scalar tensor with the mean log-probability, or ``0.`` if there is no
        position to score.
    """
    mask_token = tf.constant([tokenizer.mask_token_id])
    log_probs = []
    for i in range(1, len(tokens) - 1):
        masked = tf.tensor_scatter_nd_update(tokens, tf.constant([[i]]), mask_token)
        masked_token_logits = tf.squeeze(model(tf.expand_dims(masked, axis=0))[0])[i]
        # Negative cross-entropy against the true token == its log-softmax score.
        log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tokens[i], logits=masked_token_logits))
    if not log_probs:
        return tf.constant(0.)
    return tf.reduce_mean(tf.stack(log_probs))

def energy_raw(tokens, tokenizer, model):
    """Mean raw masked-LM logit of the tokens in a sequence.

    Every position except the first and last is masked in turn; the model's
    unnormalised logit for the original token at that position is collected
    and the logits are averaged.

    The average is taken over the scored positions only. (Earlier the
    accumulator was seeded with a dummy ``0.`` that was also counted, so the
    sum was divided by n + 1 instead of n.)

    Args:
        tokens: 1-D tensor of input ids; first and last entries are skipped.
        tokenizer: provides ``mask_token_id``.
        model: masked-LM model called on a batch of one sequence; element
            ``[0]`` of its output is used as the logits.

    Returns:
        Scalar tensor with the mean logit, or ``0.`` if there is no position
        to score.
    """
    mask_token = tf.constant([tokenizer.mask_token_id])
    logit_values = []
    for i in range(1, len(tokens) - 1):
        masked = tf.tensor_scatter_nd_update(tokens, tf.constant([[i]]), mask_token)
        masked_token_logits = tf.squeeze(model(tf.expand_dims(masked, axis=0))[0])[i]
        # Raw (pre-softmax) score the model assigns to the token that was masked out.
        logit_values.append(masked_token_logits[tokens[i]])
    if not logit_values:
        return tf.constant(0.)
    return tf.reduce_mean(tf.stack(logit_values))

def metropolis_hastings_one_iteration(tokens, mask_token, tokenizer, model, frozen=None,
                                      randomized=True, mode='norm', top_k=0, top_p=1.0, temp=1.0, custom=None):
    """Run one Metropolis-Hastings sweep over the non-frozen token positions.

    For each visited position the token is masked, a replacement is sampled
    from the (filtered) masked-LM distribution, and the replacement is
    accepted with probability

        min(1, exp(score_new - score_old) * q_old / q_new)

    where ``score`` is the value returned by ``energy_norm`` / ``energy_raw``
    (higher means more likely, so the target density is proportional to
    ``exp(score)``) and ``q`` is the proposal probability of the old / new
    token at that position. Positions that currently hold the mask token are
    always filled with the sampled candidate.

    Args:
        tokens: 1-D tensor of input ids; the first and last positions are never touched.
        mask_token: tensor holding the mask token id (as built by ``generate``).
        tokenizer: forwarded to the energy function.
        model: masked-LM model called on a batch of one sequence.
        frozen: positions that must not be resampled; ``None`` means none.
        randomized: visit positions in random order instead of left to right.
        mode: ``'norm'`` to score with ``energy_norm``, ``'raw'`` for ``energy_raw``.
        top_k, top_p, temp, custom: proposal filters, see ``filter_masked_logits``.

    Returns:
        The token tensor after the sweep.

    Raises:
        ValueError: if ``mode`` is neither ``'norm'`` nor ``'raw'``.
    """
    energy_functions = {'norm': energy_norm, 'raw': energy_raw}
    if mode not in energy_functions:
        raise ValueError("mode must be 'norm' or 'raw', got %r" % (mode,))
    energy = energy_functions[mode]

    frozen = set() if frozen is None else set(frozen)
    # Explicit left-to-right order, so randomized=False does not depend on set iteration order.
    ix = [i for i in range(1, len(tokens) - 1) if i not in frozen]
    if randomized:
        np.random.shuffle(ix)

    # Score of the current `tokens`; computed lazily and reused until `tokens` changes.
    current_score = None
    for i in ix:
        position = tf.constant([[i]])
        masked = tf.tensor_scatter_nd_update(tokens, position, mask_token)
        masked_token_logits = tf.squeeze(model(tf.expand_dims(masked, axis=0))[0])[i]
        masked_token_logits = filter_masked_logits(masked_token_logits, top_k=top_k, top_p=top_p, temp=temp, custom=custom) # custom = custom[i]??
        candidate_token = sample_from_logits(masked_token_logits, num_samples=1)
        candidate_tokens = tf.tensor_scatter_nd_update(tokens, position, candidate_token)

        if tf.reduce_all(tokens[i] == mask_token):
            # A still-masked slot is always filled; no scoring is needed for that.
            tokens = candidate_tokens
            current_score = None
            continue

        if current_score is None:
            current_score = energy(tokens, tokenizer, model)
        candidate_score = energy(candidate_tokens, tokenizer, model)

        proposal_probs = tf.nn.softmax(masked_token_logits)
        q_old = proposal_probs[tokens[i]]
        q_new = proposal_probs[tf.squeeze(candidate_token)]
        # The scores are log-domain quantities, so the density ratio is the
        # exponential of their difference (not the ratio of the scores).
        accept = tf.minimum(1., tf.exp(candidate_score - current_score) * (q_old / q_new))
        u = tf.random.uniform(shape=())
        if u <= accept:
            tokens = candidate_tokens
            current_score = candidate_score
    return tokens

def generate(sentence, tokenizer, model, num_iterations=10, randomized=True, mode='norm', 
             top_k=0, top_p=1.0, temp=1.0, custom=None, print_generated=True):
    """Repeatedly resample a sentence with Metropolis-Hastings sweeps.

    Separator markers in ``sentence`` are resolved by ``frozen_indices``; the
    positions it reports are excluded from resampling.

    Args:
        sentence: input text, possibly with separator markers and mask tokens.
        tokenizer, model: tokenizer and masked-LM model used throughout.
        num_iterations: number of sweeps to run.
        randomized, mode, top_k, top_p, temp, custom: forwarded unchanged to
            ``metropolis_hastings_one_iteration``.
        print_generated: print the decoded sentence after every sweep.

    Returns:
        List of token tensors: the initial tokenization followed by the state
        after each sweep.
    """
    frozen, clean_sentence = frozen_indices(sentence, tokenizer)
    current = tokenize_sentence(clean_sentence, tokenizer)
    # TODO (original note): create the filtering dict here.
    mask_token = tf.constant([tokenizer.mask_token_id])
    sweep_options = dict(frozen=frozen, randomized=randomized, mode=mode,
                         top_k=top_k, top_p=top_p, temp=temp, custom=custom)

    history = [current]
    for _ in range(num_iterations):
        current = metropolis_hastings_one_iteration(current, mask_token, tokenizer, model, **sweep_options)
        history.append(current)
        if print_generated:
            print(tokenizer.decode(current))
    return history

def lipogram_filter(tokenizer, letter="e"):
    """Return the ids of all vocabulary tokens whose decoded text contains ``letter``.

    Matching is case-insensitive in both directions, so ``letter="e"`` and
    ``letter="E"`` give the same result. Tokens are tested one by one rather
    than through a dict keyed by decoded text, so tokens that decode to the
    same string are all reported.

    Args:
        tokenizer: provides ``vocab`` (token -> id) and ``convert_tokens_to_string``.
        letter: character (or substring) to ban.

    Returns:
        List of token ids, suitable as the ``custom`` filter of ``generate``.
    """
    needle = letter.lower()
    return [token_id for token, token_id in tokenizer.vocab.items()
            if needle in tokenizer.convert_tokens_to_string(token).lower()]

In [19]:
# Map each vocabulary token's decoded text to its id (same construction as inside lipogram_filter).
vocabulary = {tokenizer.convert_tokens_to_string(k):v for k,v in tokenizer.vocab.items()}

In [None]:
# List the decoded vocabulary entries that contain the given substring in either capitalisation.
[subword for subword in vocabulary if ("sex" in subword or "Sex" in subword)]

In [30]:
%%time
# Three masked positions followed by a period, nothing frozen; positions are visited
# in fixed order (randomized=False) and scored with energy_norm (mode='norm').
sentence="<mask><mask><mask>."
tokens = generate(sentence, tokenizer, model, randomized=False, mode='norm', num_iterations=5, top_p=0.8, temp=1.0, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

{1: 0, 2: 50264, 3: 50264, 4: 50264}
<s>7.1.</s>
{1: 0, 2: 406, 3: 4, 4: 134}
<s>7. 8.</s>
{1: 0, 2: 406, 3: 4, 4: 290}
<s>7. Introduction.</s>
{1: 0, 2: 406, 3: 4, 4: 24474}
<s>7. California legislature</s>
{1: 0, 2: 406, 3: 4, 4: 886}
<s>C. Iowa legislature</s>
CPU times: user 20.7 s, sys: 48.7 ms, total: 20.8 s
Wall time: 20.9 s


['<s><mask><mask><mask>.</s>',
 '<s>7.1.</s>',
 '<s>7. 8.</s>',
 '<s>7. Introduction.</s>',
 '<s>7. California legislature</s>',
 '<s>C. Iowa legislature</s>']

In [18]:
%%time
# The text between the '&' markers is kept fixed; the eight masked positions are resampled,
# in fixed order, scored with energy_raw (mode='raw').
sentence="&In this prison of flesh& <mask><mask><mask><mask><mask><mask><mask><mask>"
tokens = generate(sentence, tokenizer, model, randomized=False, mode='raw', num_iterations=5, top_p=0.8, temp=1.05, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>In this prison of flesh and blood, they will be liberated.</s>
<s>In this prison of flesh and blood, they must be alive.</s>
<s>In this prison of flesh and blood, they might stay alive.</s>
<s>In this prison of flesh and blood, they might stay alive.</s>
<s>In this prison of flesh and blood, they might remain apart.</s>
CPU times: user 2min 16s, sys: 404 ms, total: 2min 16s
Wall time: 2min 16s


['<s>In this prison of flesh<mask><mask><mask><mask><mask><mask><mask><mask></s>',
 '<s>In this prison of flesh and blood, they will be liberated.</s>',
 '<s>In this prison of flesh and blood, they must be alive.</s>',
 '<s>In this prison of flesh and blood, they might stay alive.</s>',
 '<s>In this prison of flesh and blood, they might stay alive.</s>',
 '<s>In this prison of flesh and blood, they might remain apart.</s>']

In [None]:
%%time
# Same prompt as the previous cell, but with generate()'s defaults for order and scoring
# (randomized=True, mode='norm') and 20 sweeps.
sentence="&In this prison of flesh& <mask><mask><mask><mask><mask><mask><mask><mask>"
tokens = generate(sentence, tokenizer, model, num_iterations=20, top_p=0.8, temp=1.05, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

In [10]:
%%time
# Frozen prompt between the '&' markers, followed by eight masked positions; 20 sweeps.
sentence="&I traveled far and& <mask><mask><mask><mask><mask><mask><mask><mask>"
tokens = generate(sentence, tokenizer, model, num_iterations=20, top_p=0.8, temp=1.05, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>I traveled far and wide while I was learning to fly.</s>
<s>I traveled far and wide while I was learning to read.</s>
<s>I traveled far and wide while I was able to spare.</s>
<s>I traveled far and wide while I was able to rest.</s>
<s>I traveled far and wide while I was able to survive.</s>
<s>I traveled far and wide while I was able to drive.</s>
<s>I traveled far and wide while I was able to teach.</s>
<s>I traveled far and wide while I was teaching to teach.</s>
<s>I traveled far and wide while I was able to teach.</s>
<s>I traveled far and wide while I was preparing to teach.</s>
<s>I traveled far and wide as I was preparing to teach.</s>
<s>I traveled far and wide as I was preparing to teach.</s>
<s>I traveled far and wide as I was preparing to teach.</s>
<s>I traveled far and wide as I was preparing to teach.</s>
<s>I traveled far and wide while I was preparing to teach.</s>
<s>I traveled far and wide when I was preparing to travel.</s>
<s>I traveled far and fast while I was 

['<s>I traveled far and<mask><mask><mask><mask><mask><mask><mask><mask></s>',
 '<s>I traveled far and wide while I was learning to fly.</s>',
 '<s>I traveled far and wide while I was learning to read.</s>',
 '<s>I traveled far and wide while I was able to spare.</s>',
 '<s>I traveled far and wide while I was able to rest.</s>',
 '<s>I traveled far and wide while I was able to survive.</s>',
 '<s>I traveled far and wide while I was able to drive.</s>',
 '<s>I traveled far and wide while I was able to teach.</s>',
 '<s>I traveled far and wide while I was teaching to teach.</s>',
 '<s>I traveled far and wide while I was able to teach.</s>',
 '<s>I traveled far and wide while I was preparing to teach.</s>',
 '<s>I traveled far and wide as I was preparing to teach.</s>',
 '<s>I traveled far and wide as I was preparing to teach.</s>',
 '<s>I traveled far and wide as I was preparing to teach.</s>',
 '<s>I traveled far and wide as I was preparing to teach.</s>',
 '<s>I traveled far and wide wh

In [12]:
%%time
# Two frozen regions (the first line plus "But", and the closing " die.") with nine
# masked positions to fill in between them.
sentence="&I traveled far and wide while I was learning to fly,\nBut& <mask><mask><mask><mask><mask><mask><mask><mask><mask>& die.&"
tokens = generate(sentence, tokenizer, model, num_iterations=20, top_p=0.8, temp=1.05, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>I traveled far and wide while I was learning to fly,
But for some reason, they decided to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for some reason my destiny refused to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for some reason my destiny decided to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for some reason, God decided to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for some reason, they decided to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for some reason, fate decided to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for some reason, fate chose to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for some reason, he chose to let me die.</s>
<s>I traveled far and wide while I was learning to fly,
But for whatever reason, people chose to let me fly.</s>
<s>I travele

['<s>I traveled far and wide while I was learning to fly,\nBut<mask><mask><mask><mask><mask><mask><mask><mask><mask> die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for some reason, they decided to let me die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for some reason my destiny refused to let me die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for some reason my destiny decided to let me die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for some reason, God decided to let me die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for some reason, they decided to let me die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for some reason, fate decided to let me die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for some reason, fate chose to let me die.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut for 

In [20]:
# Add filter for question words

In [19]:
%%time
# Frozen prefix ending in "But how", followed by nine masked positions; this run uses
# top_p=0.85 and temp=0.95 instead of the 0.8 / 1.05 of the cells above.
sentence="&I traveled far and wide while I was learning to fly,\nBut how& <mask><mask><mask><mask><mask><mask><mask><mask><mask>"
tokens = generate(sentence, tokenizer, model, num_iterations=20, top_p=0.85, temp=0.95, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>I traveled far and wide while I was learning to fly,
But how can be, and I said I flew.</s>
<s>I traveled far and wide while I was learning to fly,
But how could blame you and I said I flew.</s>
<s>I traveled far and wide while I was learning to fly,
But how to blame you and I said I flew.</s>
<s>I traveled far and wide while I was learning to fly,
But how to blame yourself and have said you flew.</s>
<s>I traveled far and wide while I was learning to fly,
But how to blame others and have said you flew.</s>
<s>I traveled far and wide while I was learning to fly,
But how to blame others and have said you flew."</s>
<s>I traveled far and wide while I was learning to fly,
But how to blame others and have said you flew?"</s>
<s>I traveled far and wide while I was learning to fly,
But how to blame others and they said you flew?</s>
<s>I traveled far and wide while I was learning to fly,
But how to blame others and they said you flew?"</s>
<s>I traveled far and wide while I was learning t

['<s>I traveled far and wide while I was learning to fly,\nBut how<mask><mask><mask><mask><mask><mask><mask><mask><mask></s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how can be, and I said I flew.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how could blame you and I said I flew.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how to blame you and I said I flew.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how to blame yourself and have said you flew.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how to blame others and have said you flew.</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how to blame others and have said you flew."</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how to blame others and have said you flew?"</s>',
 '<s>I traveled far and wide while I was learning to fly,\nBut how to blame others and they said you fle

In [12]:
%%time
# "This pain" and the final "." are frozen; " is" and the masked position are resampled.
# top_p=0.1 keeps only a very small nucleus of candidates.
sentence="&This pain& is <mask>&.&"
tokens = generate(sentence, tokenizer, model, num_iterations=20, top_p=0.1, temp=1.05, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is unbearable.</s>
<s>This pain is excruciating.</s>
<s>This pain is unbearable.</s>
<s>This pain is unbearable.</s>
<s>This pain is unbearable.</s>
<s>This pain is unbearable.</s>
<s>This pain is excruciating.</s>
<s>This pain is unbearable.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is excruciating.</s>
<s>This pain is unbearable.</s>
CPU times: user 37.9 s, sys: 98.3 ms, total: 38 s
Wall time: 37.6 s


['<s>This pain is<mask>.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is unbearable.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is unbearable.</s>',
 '<s>This pain is unbearable.</s>',
 '<s>This pain is unbearable.</s>',
 '<s>This pain is unbearable.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is unbearable.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is excruciating.</s>',
 '<s>This pain is unbearable.</s>']

In [13]:
%%time
# Same prompt as above, additionally masking out every token returned by
# lipogram_filter(tokenizer) (default letter "e") before sampling.
sentence="&This pain& is <mask>&.&"
tokens = generate(sentence, tokenizer, model, num_iterations=20, top_p=0.1, temp=1.05, print_generated=True, custom=lipogram_filter(tokenizer))
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>This pain is normal.</s>
<s>This pain is normal.</s>
<s>This pain is normal.</s>
<s>This pain is normal.</s>
<s>This pain is normal.</s>
<s>This pain is familiar.</s>
<s>This pain is familiar.</s>
<s>This pain is familiar.</s>
<s>This pain is normal.</s>
<s>This pain is normal.</s>
<s>This pain is familiar.</s>
<s>This pain is familiar.</s>
<s>This pain is normal.</s>
<s>This pain is normal.</s>
<s>This pain is familiar.</s>
<s>This pain is familiar.</s>
<s>This pain is normal.</s>
<s>This pain is normal.</s>
<s>This pain is familiar.</s>
<s>This pain is familiar.</s>
CPU times: user 39.5 s, sys: 177 ms, total: 39.7 s
Wall time: 39.3 s


['<s>This pain is<mask>.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is normal.</s>',
 '<s>This pain is familiar.</s>',
 '<s>This pain is familiar.</s>']

In [15]:
%%time
# Same prompt again, this time masking out the tokens returned by lipogram_filter for letter "a".
sentence="&This pain& is <mask>&.&"
tokens = generate(sentence, tokenizer, model, num_iterations=20, top_p=0.1, temp=1.05, print_generated=True, custom=lipogram_filter(tokenizer, letter="a"))
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>This pain is chronic.</s>
<s>This pain is intense.</s>
<s>This pain is intense.</s>
<s>This pain is chronic.</s>
<s>This pain is chronic.</s>
<s>This pain is chronic.</s>
<s>This pain is intense.</s>
<s>This pain is intense.</s>
<s>This pain is intense.</s>
<s>This pain is intense.</s>
<s>This pain is chronic.</s>
<s>This pain is intense.</s>
<s>This pain is intense.</s>
<s>This pain is chronic.</s>
<s>This pain is intense.</s>
<s>This pain is intense.</s>
<s>This pain is chronic.</s>
<s>This pain is intense.</s>
<s>This pain is chronic.</s>
<s>This pain is intense.</s>
CPU times: user 39.8 s, sys: 220 ms, total: 40 s
Wall time: 39.7 s


['<s>This pain is<mask>.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is intense.</s>',
 '<s>This pain is chronic.</s>',
 '<s>This pain is intense.</s>']

In [24]:
%%time
# "Beyond this" and the final "." are frozen; four masked positions in between.
# temp is left at generate()'s default of 1.0 here.
sentence="&Beyond this& <mask><mask><mask><mask>&.&" #PUNCTUATION MATTERS.
tokens = generate(sentence, tokenizer, model, num_iterations=10, top_p=0.85, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>Beyond this, life goes on.</s>
<s>Beyond this, life goes on.</s>
<s>Beyond this, it goes downhill.</s>
<s>Beyond this, everything went downhill.</s>
<s>Beyond this, everything went downhill.</s>
<s>Beyond this, everything went downhill.</s>
<s>Beyond this, everything went downhill.</s>
<s>Beyond this, everything went downhill.</s>
<s>Beyond this, everything went downhill.</s>
<s>Beyond this, everything went downhill.</s>
CPU times: user 52.5 s, sys: 138 ms, total: 52.7 s
Wall time: 52.1 s


['<s>Beyond this<mask><mask><mask><mask>.</s>',
 '<s>Beyond this, life goes on.</s>',
 '<s>Beyond this, life goes on.</s>',
 '<s>Beyond this, it goes downhill.</s>',
 '<s>Beyond this, everything went downhill.</s>',
 '<s>Beyond this, everything went downhill.</s>',
 '<s>Beyond this, everything went downhill.</s>',
 '<s>Beyond this, everything went downhill.</s>',
 '<s>Beyond this, everything went downhill.</s>',
 '<s>Beyond this, everything went downhill.</s>',
 '<s>Beyond this, everything went downhill.</s>']

In [177]:
%%time
# "Beyond" and the final "." are frozen; five masked positions in between,
# sampled with the letter-"a" lipogram filter.
sentence="&Beyond&<mask><mask><mask><mask><mask>&.&"
tokens = generate(sentence, tokenizer, model, num_iterations=5, top_p=0.75, temp=1.05, print_generated=True, custom=lipogram_filter(tokenizer, letter="a"))
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>Beyond those lines, the unforeseen.</s>
<s>Beyond those lines, no book.</s>
<s>Beyond these two is no book.</s>
<s>Beyond these two is the book.</s>
<s>Beyond these two is the book.</s>
CPU times: user 3min, sys: 17.2 s, total: 3min 17s
Wall time: 1min 25s


['<s>Beyond<mask><mask><mask><mask><mask>.</s>',
 '<s>Beyond those lines, the unforeseen.</s>',
 '<s>Beyond those lines, no book.</s>',
 '<s>Beyond these two is no book.</s>',
 '<s>Beyond these two is the book.</s>',
 '<s>Beyond these two is the book.</s>']

In [178]:
%%time
# A full first sentence and the final "." are frozen; six masked positions in between,
# again with the letter-"a" lipogram filter.
sentence="&Beyond those lines, the unforeseen.&<mask><mask><mask><mask><mask><mask>&.&"
tokens = generate(sentence, tokenizer, model, num_iterations=5, top_p=0.75, temp=1.05, print_generated=True, custom=lipogram_filter(tokenizer, letter="a"))
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

<s>Beyond those lines, the unforeseen. The worst will never be predicted.</s>
<s>Beyond those lines, the unforeseen. The worst will not be predicted.</s>
<s>Beyond those lines, the unforeseen. The worst will not be predicted.</s>
<s>Beyond those lines, the unforeseen. The unexpected will not be possible.</s>
<s>Beyond those lines, the unforeseen. The unexpected will never be possible.</s>
CPU times: user 8min 27s, sys: 49.6 s, total: 9min 16s
Wall time: 3min 53s


['<s>Beyond those lines, the unforeseen.<mask><mask><mask><mask><mask><mask>.</s>',
 '<s>Beyond those lines, the unforeseen. The worst will never be predicted.</s>',
 '<s>Beyond those lines, the unforeseen. The worst will not be predicted.</s>',
 '<s>Beyond those lines, the unforeseen. The worst will not be predicted.</s>',
 '<s>Beyond those lines, the unforeseen. The unexpected will not be possible.</s>',
 '<s>Beyond those lines, the unforeseen. The unexpected will never be possible.</s>']

In [None]:
%%time
# NOTE: `sentence` is not defined in this cell; it is whatever an earlier executed cell
# left in the kernel, so the result depends on execution order.
tokens = generate(sentence, tokenizer, model, num_iterations=30, top_p=0.7, temp=1.05, print_generated=True)
# Decode the initial state and the state after every sweep.
tokenizer.batch_decode(tf.stack(tokens))

In [155]:
%%time
# Run a single Metropolis-Hastings sweep directly (bypassing generate) on four masked
# positions, with the default letter-"e" lipogram filter applied to the proposals.
tokens = tokenize_sentence("<mask> <mask> <mask> <mask>", tokenizer)
mask_token = tf.constant([tokenizer.mask_token_id])
a = metropolis_hastings_one_iteration(tokens, mask_token, tokenizer, model, top_p=0.8, temp=1.1, custom=lipogram_filter(tokenizer))
print(a)
tokenizer.decode(a)

tf.Tensor([   0  500 1827  298  111    2], shape=(6,), dtype=int32)
CPU times: user 16.5 s, sys: 1.67 s, total: 18.2 s
Wall time: 8.24 s


'<s>Roush -</s>'

In [66]:
%%time
tokens = tokenize_sentence("This is a test sentence", tokenizer)
mask_token = tf.constant([tokenizer.mask_token_id])
a, logits = metropolis_hastings_one_iteration(tokens, mask_token, tokenizer, model, top_p=0.8, temp=1.9)
print(a)
tokenizer.decode(a)

tf.Tensor([-0.6050985 -4.2053432  6.859257  ... -1.996237  -3.6504614  2.9491684], shape=(50265,), dtype=float32)
HERE
tf.Tensor([-inf -inf -inf ... -inf -inf -inf], shape=(50265,), dtype=float32)
tf.Tensor([3703], shape=(1,), dtype=int32)
tf.Tensor([   0  713 3703   10 1296 3645    2], shape=(7,), dtype=int32)
tf.Tensor([-3.4178288 -4.643718  11.854994  ... -4.0670986 -5.0556707  1.6806355], shape=(50265,), dtype=float32)
HERE
tf.Tensor([     -inf      -inf 6.2394705 ...      -inf      -inf      -inf], shape=(50265,), dtype=float32)
tf.Tensor([1151], shape=(1,), dtype=int32)
tf.Tensor([   0  713 3703   10 1296 1151    2], shape=(7,), dtype=int32)
tf.Tensor([-1.7849343 -3.9210784  0.8064407 ... -3.577879  -4.052895   1.3995967], shape=(50265,), dtype=float32)
HERE
tf.Tensor([      -inf       -inf       -inf ...       -inf       -inf 0.73662984], shape=(50265,), dtype=float32)
tf.Tensor([20178], shape=(1,), dtype=int32)
tf.Tensor([    0   713  3703 20178  1296  1151     2], shape=(7,), 

'<s>Creating ISGameMem moment</s>'

In [79]:
len(lipogram_filter(tokenizer))

24233

In [72]:
%%time
# Single Metropolis-Hastings sweep on a fully specified sentence, with the default
# letter-"e" lipogram filter applied to the proposals.
tokens = tokenize_sentence("This is a test sentence", tokenizer)
print(tokens)
mask_token = tf.constant([tokenizer.mask_token_id])
a = metropolis_hastings_one_iteration(tokens, mask_token, tokenizer, model, top_p=0.8, temp=1.9, custom=lipogram_filter(tokenizer))
print(a)
tokenizer.decode(a)

tf.Tensor([   0  713   16   10 1296 3645    2], shape=(7,), dtype=int32)
tf.Tensor([-1.5155997 -4.703458  12.018816  ... -3.687202  -4.2801313  2.5221024], shape=(50265,), dtype=float32)
HERE
HERE
tf.Tensor([     -inf      -inf 6.3256927 ...      -inf      -inf      -inf], shape=(50265,), dtype=float32)
tf.Tensor([5674], shape=(1,), dtype=int32)
tf.Tensor([   0  713   16   10 1296 5674    2], shape=(7,), dtype=int32)
tf.Tensor(
[-1.8672018  -3.6670442   7.969133   ... -5.02354    -3.6320744
  0.48664474], shape=(50265,), dtype=float32)
HERE
HERE
tf.Tensor([     -inf      -inf 4.1942806 ...      -inf      -inf      -inf], shape=(50265,), dtype=float32)
tf.Tensor([46571], shape=(1,), dtype=int32)
tf.Tensor([    0 46571    16    10  1296  3645     2], shape=(7,), dtype=int32)
tf.Tensor(
[-3.4993885  -4.169795    2.837184   ... -4.8185396  -6.3750577
  0.60346913], shape=(50265,), dtype=float32)
HERE
HERE
tf.Tensor([     -inf      -inf 1.4932548 ...      -inf      -inf      -inf], shape=(5

'<s>Location 8??? test sentence</s>'

In [68]:
a

(<tf.Tensor: shape=(7,), dtype=int32, numpy=array([    0, 20839,    16,     5, 14692,  3645,     2], dtype=int32)>,
 <tf.Tensor: shape=(50265,), dtype=float32, numpy=array([-inf, -inf, -inf, ..., -inf, -inf, -inf], dtype=float32)>)

In [29]:
tokenizer.convert_ids_to_tokens(4625)

'Ġrepresented'

In [None]:
tokenizer("This is a test").token_to_chars(1)

In [23]:
a

<tf.Tensor: shape=(7,), dtype=int32, numpy=array([    0, 47632,    25,    10,  1050,  3645,     2], dtype=int32)>

In [36]:
logits = tf.squeeze(model(tf.expand_dims(tokens, axis=0))[0])[2]

In [151]:
tokenizer.decode(list(sample_from_logits(filter_logits(logits, indices_to_filter=lipogram_filter(tokenizer)), num_samples=10).numpy()))

' Pak Immortal simulation180 card MatchAction ultra visibilityThis'

In [135]:
tokenizer.convert_ids_to_tokens(sample_from_logits(filter_logits(logits, indices_to_filter=lipogram_filter(tokenizer)), num_samples=1).numpy())

['ĠHERO']

In [136]:
'E' in 'ĠHERO'

True

In [101]:
0 in lipogram_filter(tokenizer)

False

In [None]:
t