In [1]:
from IPython.core.display import display, HTML, Image
display(HTML("<style>.container { width:95% !important; }</style>"))
%config IPCompleter.use_jedi=False

In [2]:
from transformers import AutoTokenizer, TFAutoModelForMaskedLM
import tensorflow as tf
import re
from itertools import chain
from string import punctuation
import os
import numpy as np
from itertools import chain
from collections import Counter

In [3]:
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

In [4]:
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = TFAutoModelForMaskedLM.from_pretrained("roberta-base")
model.trainable = False

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/657M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFRobertaForMaskedLM.

All the layers of TFRobertaForMaskedLM were initialized from the model checkpoint at roberta-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForMaskedLM for predictions without further training.


In [1]:
def prepare_sampler_input(sentence, pre='', post='', pre_mask=0, post_mask=5, mask_token='[MASK]',cap=False, low=False):
    if cap:
        sentence = sentence.capitalize()
    if low:
        sentence = sentence.lower()
    return f"{pre_mask*(mask_token+' ')}&{pre}{sentence}{post}&{post_mask*(' '+mask_token)}."

def tokenize_sentence(sentence):
        tokenized_tensor = tokenizer(sentence, return_tensors='tf')['input_ids']
        tokens = tf.squeeze(tokenized_tensor)
        return tokens

def pairwise(iterable):
    "s -> (s0, s1), (s2, s3), (s4, s5), ..."
    a = iter(iterable)
    return zip(a, a)
    
def sample_from_logits(logits, num_samples=10):
    sample = tf.random.categorical([logits], num_samples, dtype=tf.int32)
    sample = tf.reshape(sample, shape=(tf.shape(sample)[1],))
    return sample


def temperature(logits, temp=1.0):
    return logits / temp

def filter_logits(logits, indices_to_filter, fill_with=-np.inf):
    indices_to_filter = tf.expand_dims(indices_to_filter, axis=1)
    filters = tf.fill([len(indices_to_filter)],fill_with)
    filtered_logits = tf.tensor_scatter_nd_update(logits, indices_to_filter, filters)    
    return filtered_logits

def top_k_filter(logits, top_k=100):
    indices_to_filter = tf.math.top_k(-logits, k=len(logits)-top_k).indices #TODO CHECK IF THIS SUBSTRACTION IS CORRECT
    return filter_logits(logits, indices_to_filter=indices_to_filter)

#     indices_to_filter = tf.expand_dims(indices_to_filter, axis=1)
#     filters = tf.fill([len(indices_to_filter)],-np.inf)
#     filtered_logits = tf.tensor_scatter_nd_update(logits, indices_to_filter, filters)
#     return filtered_logits

def top_p_filter(logits, top_p=0.8):
    logits = tf.squeeze(logits)
    indices_sorted = tf.argsort(logits, direction="DESCENDING")
    logits_sorted = tf.gather(logits, indices_sorted)
    probs_sorted = tf.nn.softmax(logits_sorted, axis=-1)
    cutoff_index = tf.argmax(tf.cumsum(probs_sorted)>top_p)
    indices_to_filter = indices_sorted[cutoff_index+1:]
    return filter_logits(logits, indices_to_filter=indices_to_filter)
#     indices_to_filter = tf.expand_dims(indices_to_filter, axis=1)
#     filters = tf.fill([len(indices_to_filter)],-np.inf)
#     filtered_logits = tf.tensor_scatter_nd_update(logits, indices_to_filter, filters)    
#     return filtered_logits

def frozen_indices(sentence, separator="&"):
    indices = [(np.array(match.span()) - i) for i,match in enumerate(re.finditer(separator, sentence))]
    indices = [np.concatenate(thing)[[0,-1]] for thing in list(pairwise(indices))]
    indices = [(thing - np.array((0,1))) for thing in indices]
    mapping = tokenizer(sentence.replace("&", ""), return_offsets_mapping=True, return_attention_mask=False, return_token_type_ids=False)['offset_mapping'][1:-1]
    tokenized = tokenizer.tokenize(sentence.replace("&", ""))
    r = len(tokenized)+1
    mapping = dict(zip(mapping, list(range(1,r))))
    all_ix = {}
    for span in indices:
        a,b = span
        try:
            i,j = {k:v for k,v in mapping.items() if k[0]==a or k[1]==b}.values()
            ix_span = list(range(i, j+1))
        except:
            ix_span = {k:v for k,v in mapping.items() if k[0]==a or k[1]==b}.values()
        all_ix.update({i:tokenized[i-1] for i in ix_span})
    return set(all_ix.keys()), sentence.replace("&", "")

def gibbs_one_iteration(tokens, frozen={}, randomized=True, top_k=0, top_p=1.0, temp=1.0):
    ix = list(range(1,len(tokens)-1))
    ix = list(set(ix)-frozen) # ONLY THE TOKENS ix WILL BE SAMPLED. THE REST ARE FROZEN.
    if randomized:
        np.random.shuffle(ix)
    mask_token = tf.constant([tokenizer.mask_token_id])
    for i in ix:
        indices = tf.constant([[i]])
        updates = mask_token
        masked = tf.tensor_scatter_nd_update(tokens, indices, updates) #FEED THIS INTO BERT        
        masked_token_logits = tf.squeeze(model(tf.expand_dims(masked, axis=0))[0])[i] #LOGITS FOR MASKED TOKEN
        
        if temp!=1.0:
            masked_token_logits = temperature(masked_token_logits, temp=temp) #TEMPERATURE ADJUSTMENT. MORE THAN 1 MAKES THE MODEL LESS CONFIDENT AND MORE LIKELY TO SAMPLE LESS LIKELY TOKENS 
        if top_k>0:
            masked_token_logits = top_k_filter(masked_token_logits, top_k=top_k)
        if top_p<1.0:
            masked_token_logits = top_p_filter(masked_token_logits, top_p=top_p)
        
        replacement = sample_from_logits(masked_token_logits, num_samples=1)        
        tokens = tf.tensor_scatter_nd_update(tokens, indices, replacement)
    
    new_sentence = tokenizer.decode(tokens)
    return new_sentence, tokens


def gibbs_sampler(sentence, num_iterations=10, top_k=0, top_p=1.0, temp=1.0, print_generations=True):
    frozen, sentence = frozen_indices(sentence)
    new_tokens = tokenize_sentence(sentence)
    sentences = []
    for i in range(num_iterations):
        new_sentence, new_tokens = gibbs_one_iteration(tokens=new_tokens, frozen=frozen, top_k=top_k, top_p=top_p, temp=temp)
        if print_generations:
            print(new_sentence)
            print("-----------------")
        sentences.append(new_sentence)            
    return sentences

def sentence_log_potential(sentence):
    tokenized_tensor = tokenizer(sentence, return_tensors='tf')['input_ids']
    tokens = tf.squeeze(tokenized_tensor)
    ix = list(range(1,len(tokens)-1))
    mask_token = tf.constant([tokenizer.mask_token_id])
    log_probs = tf.constant([0.])
    for i in ix:
        indices = tf.constant([[i]])
        updates = mask_token
        masked = tf.tensor_scatter_nd_update(tokens, indices, updates) #FEED THIS INTO BERT  
        masked_token_logits = tf.squeeze(model(tf.expand_dims(masked, axis=0))[0])[i] #LOGITS FOR MASKED TOKEN        
        log_prob = -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tokens[i], logits=masked_token_logits)
        log_prob = tf.expand_dims(log_prob, axis=0)
        log_probs = tf.concat([log_probs, log_prob], axis=0)
    return (tf.reduce_sum(log_probs) / len(log_probs))

NameError: name 'np' is not defined

In [59]:
tokenizer.vocab

{'403': 34030,
 '301': 28167,
 'Ġ1937': 31556,
 'Ġlatent': 42715,
 'Ġdisbel': 45668,
 'ĠRuss': 18184,
 'Ġhumanity': 9187,
 'Ġsyndrome': 14115,
 'Ġcenterpiece': 28478,
 'Ġ408': 31060,
 'ĠMode': 22158,
 '727': 39125,
 'ĠDrill': 37411,
 'Launch': 47269,
 'Ġspaced': 42926,
 'Ġbarrage': 23080,
 'ukemia': 44336,
 'Ġflesh': 18940,
 'ĠHa': 4936,
 'ĠWin': 5711,
 'ĠTill': 15628,
 'ĠTools': 19514,
 'ĠVil': 22153,
 '316': 35092,
 'Ġreading': 2600,
 'ĠNost': 39233,
 'ĠMess': 15212,
 '044': 40847,
 'Ġbanker': 15573,
 'ized': 1538,
 'ĠIbn': 43609,
 'Ġmating': 40297,
 'Ġany': 143,
 'Ġtherapy': 5804,
 'Ġphysically': 7217,
 'charge': 15040,
 'Ġrespectable': 25031,
 'Ġcred': 18994,
 'ING': 1862,
 'ggie': 30210,
 'Indust': 36926,
 'Ġtriangular': 42265,
 'ĠDaredevil': 42796,
 'bell': 11312,
 'Ġ{}': 49153,
 'ï¸ı': 12605,
 'opt': 19693,
 'Ġbasil': 32394,
 'Ġdesert': 10348,
 'ĠLebanon': 8398,
 'pri': 13718,
 'Ġmotorcycles': 21027,
 'TM': 14386,
 'ĠGillespie': 25418,
 'redit': 35979,
 'Ġbreeds': 29441,
 'Ġfear

In [75]:
tokenizer.vocab['Ġtimeout']

25386

In [76]:
tokenizer.vocab['timeout']

49109

In [77]:
tokenizer.vocab['Timeout']

49405

In [67]:
tokenizer('Timeout')

{'input_ids': [0, 49405, 2], 'attention_mask': [1, 1, 1]}

In [60]:
'Ġsyndrome' in tokenizer.vocab

True

In [18]:
%%timeit
sentence_log_potential("This is a test sentence")

424 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [54]:
float(sentence_log_potential("This is a test sentence"))

-2.9686391353607178

In [25]:
%%timeit
sentence_log_potential("This is a test sentence")

412 ms ± 2.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [39]:
log_p = sentence_log_potential("Siko biko")

In [42]:
tf.reduce_sum(log_p)/len(log_p)

<tf.Tensor: shape=(), dtype=float32, numpy=-6.3395705>

In [8]:
tokens = [3,2]
masked_token_logits = [1,1,3,5.]

In [10]:
categorical = tfd.Categorical(logits=masked_token_logits)
log_prob = categorical.log_prob(tokens[1])

In [11]:
log_prob

<tf.Tensor: shape=(), dtype=float32, numpy=-2.158683>

In [12]:
tf.nn.softmax(masked_token_logits)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.01562813, 0.01562813, 0.11547709, 0.85326666], dtype=float32)>

In [13]:
-tf.nn.sparse_softmax_cross_entropy_with_logits(labels=2, logits=masked_token_logits)

<tf.Tensor: shape=(), dtype=float32, numpy=-2.158683>

In [15]:
np.exp(log_prob)

0.11547709

In [None]:
-tf.nn.sparse_softmax_cross_entropy_with_logits

In [15]:
logits=np.array([1,2,3,1,4])
logits

array([1, 2, 3, 1, 4])

In [19]:
tf.squeeze(logits)

<tf.Tensor: shape=(5,), dtype=int64, numpy=array([1, 2, 3, 1, 4])>

In [22]:
len(i)

3

In [17]:
i = tf.math.top_k(-logits, k=len(logits)-2).indices
i

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 3, 1], dtype=int32)>

In [23]:
len(tf.expand_dims(i, axis=1))

3

In [14]:
sentence="&I cannot& <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask>&.&"

In [17]:
%%time
gibbs_sampler(sentence, top_k=100, temp=0.9, print_generations=False, num_iterations=5)

CPU times: user 21.3 s, sys: 1.95 s, total: 23.3 s
Wall time: 10.3 s


['<s>I cannot take you on a deeper dive just now.</s>',
 '<s>I cannot take it for any deeper interrogation just yet.</s>',
 '<s>I cannot understand him beyond a basic level, yet.</s>',
 '<s>I cannot understand them in a rational sense, either.</s>',
 '<s>I cannot judge him in any objective way, however.</s>']

In [15]:
%%time
gibbs_sampler(sentence, top_k=5, temp=0.9, print_generations=False, num_iterations=5)

CPU times: user 21.3 s, sys: 1.93 s, total: 23.2 s
Wall time: 10.1 s


['<s>I cannot see the future, because I am lost.</s>',
 '<s>I cannot see the future, and I am afraid.</s>',
 '<s>I cannot see the moon, and I am afraid.</s>',
 '<s>I cannot see the stars, and I am alone.</s>',
 '<s>I cannot see the sun, and I am afraid.</s>']

In [28]:
%%time
gibbs_sampler(sentence, top_k=5, temp=0.9, print_generations=False, num_iterations=5)

CPU times: user 21.3 s, sys: 2.26 s, total: 23.6 s
Wall time: 10.2 s


['<s>I cannot leave you. You have just passed away.</s>',
 '<s>I cannot understand you. You have just moved away.</s>',
 '<s>I cannot believe this. I have just found out.</s>',
 '<s>I cannot believe it, I have just found out.</s>',
 '<s>I cannot believe it! I have finally found freedom.</s>']

In [33]:
%%time
gibbs_sampler(sentence, top_k=5, temp=0.9, print_generations=False, num_iterations=5)

50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
50260
CPU times: user 23.5 s, sys: 2.67 s, total: 26.2 s
Wall time: 12.4 s


['<s>I cannot help you. I have to kill you.</s>',
 '<s>I cannot leave you. We want to help you.</s>',
 '<s>I cannot help you. People need to help you.</s>',
 '<s>I cannot help you. I want to help you.</s>',
 '<s>I cannot leave you. I have to find you.</s>']

In [30]:
top_k_filter(logits, top_k=1)

4
4


<tf.Tensor: shape=(5,), dtype=float32, numpy=array([-inf, -inf, -inf, -inf,   4.], dtype=float32)>

In [39]:
%%time
gibbs_sampler(sentence, top_p=0.1, temp=0.9, print_generations=False, num_iterations=25)

CPU times: user 17.6 s, sys: 90.7 ms, total: 17.7 s
Wall time: 17.8 s


['<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I have to live.</s>',
 '<s>I cannot afford to die. I 

In [41]:
%%time
gibbs_sampler(sentence, top_p=0.5, temp=0.9, print_generations=False, num_iterations=25)

CPU times: user 17.3 s, sys: 108 ms, total: 17.4 s
Wall time: 17.5 s


['<s>I cannot stand to be the way I am now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot continue to live the way I do now.</s>',
 '<s>I cannot c

In [100]:
tokenizer.vocab_size

50265

In [48]:
[thing for thing in tokenizer.vocab if not 'e' in thing]

['Why',
 '310',
 'ĠAsian',
 'agascar',
 'Ġfly',
 'ĠHIS',
 'ĠÃ¾',
 'riott',
 'OTA',
 'ĠTL',
 'uristic',
 '604',
 'ĠMiliband',
 'Ġ39',
 'Ġwatching',
 'Ġpathways',
 'ĠSac',
 'KC',
 'Cruz',
 'Ġcolonization',
 'aryl',
 'ĠRuin',
 'zos',
 'ĠNorton',
 'onica',
 'Ġcouncils',
 'Ġpromoting',
 'Ġstir',
 'Ġmamm',
 'ĠLoyal',
 'Ġpony',
 'If',
 '";',
 'icating',
 '117',
 'Ġdysfunctional',
 'achi',
 'icc',
 'ĠTrib',
 'Fast',
 'ĠArts',
 'Ġpav',
 'Ġloan',
 'ĠHick',
 'ĠOval',
 'Entity',
 'HO',
 'Ġma',
 'copy',
 'ĠSpock',
 'Ġpulmonary',
 'Ġbiologically',
 'Ġstupidity',
 'ĠDEN',
 'Ġtort',
 'ĠSAM',
 'xc',
 'FW',
 '")',
 'ãĥĩãĤ£',
 'Ġpopping',
 'Downloadha',
 'ĠHaku',
 'ĠMort',
 'ĠIndians',
 'ĠWalking',
 'Ġagon',
 '652',
 'ãĢĲ',
 'ĠSTUD',
 'ĠIgn',
 'ĠMissions',
 'Ġcollisions',
 'Ġgigg',
 'Ġdub',
 'Ġafloat',
 'Ġmull',
 '969',
 'ĠÎ±',
 'åī',
 '707',
 'tor',
 'ASED',
 'anch',
 'Ġsatisfy',
 'Ġallocations',
 'ito',
 'ĠPittsburgh',
 'ĠXVI',
 'Ġshouting',
 'ilar',
 'Ġlangu',
 'riot',
 'Ġpatriarchal',
 'ĠXY',
 'IDS',

'ãĥĪ'

In [79]:
tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(49280))

'ト'

In [81]:
tokenizer.convert_tokens_to_string('ĠMiliband')

' Miliband'