In [1]:
import tensorflow as tf
from transformers import BertTokenizer, BertForMaskedLM
import torch
import pandas as pd

In [4]:
file_path = '../raw_data/limerick_dataset_oedilf_v3.json'
df = pd.read_json(file_path)
df = df[df.is_limerick == True]

In [None]:
lim = df['limerick'][8]
lim

In [None]:
lim_tk = "<LS> <KS> tomatoes <KE> Our tomatoes this year are abounding; <L0> They're lush, red and ripe — just astounding! <L1> We've run out of uses <L2> (We're making excuses) <L3> For produce that's all but dumbfounding. <L4> <LE>"
lim_tk

In [88]:
from data_processing.bert_preproc import zorro

masked_lim, unmasked = zorro(lim_tk, 0, 1, 2, 3, 4) 
print(masked_lim)
print(unmasked)
masked_lim = masked_lim[masked_lim.find('<KE>'):]
masked_lim = masked_lim.replace('<KE> ','')
masked_lim = masked_lim.replace('<L0> ','')
masked_lim = masked_lim.replace('<L1> ','')
masked_lim = masked_lim.replace('<L2> ','')
masked_lim = masked_lim.replace('<L3> ','')
masked_lim = masked_lim.replace(' <L4> <LE>','')
masked_lim

<LS> <KS> tomatoes <KE> Our tomatoes this year are [MASK]; <L0> They're lush, red and ripe — just [MASK]! <L1> We've run out of [MASK] <L2> (We're making [MASK]) <L3> For produce that's all but [MASK]. <L4> <LE>
{'masked': [(0, 'abounding'), (1, 'astounding'), (2, 'uses'), (3, 'excuses'), (4, 'dumbfounding')], 'unmasked': []}


"Our tomatoes this year are [MASK]; They're lush, red and ripe — just [MASK]! We've run out of [MASK] (We're making [MASK]) For produce that's all but [MASK]."

In [None]:
def rhyme_word_pred(masked_lim: str, top_k=50) -> tuple:
    """
    takes a limerick with masked words
    uses BERT to predicted possible replacement words in mask location
    """
    tz = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    model.eval()
    
    tokenized_text = tz.tokenize(masked_lim)
    indexed_tokens = tz.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    masked_pos = [i for i,d in enumerate(tokenized_text) if d=='[MASK]']
    
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    rhyme_list = ()
    for msk_ind in masked_pos:
        probs = torch.nn.functional.softmax(predictions[0, msk_ind], dim=-1)
        top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
        
        temp_list = []
        for i, pred_idx in enumerate(top_k_indices):
            
            predicted_token = tz.convert_ids_to_tokens([pred_idx])[0]
            token_weight = top_k_weights[i]
            temp_tpl = (predicted_token, float(token_weight))
        
            temp_list.append(temp_tpl)

    rhyme_list = rhyme_list + (temp_list,)
    
    return rhyme_list


In [None]:
from smartbard.bert_model.bert_predict_masked_words import rhyme_word_pred

rhyme_list = rhyme_word_pred(masked_lim)
rhyme_list

In [90]:
tokenized_text = tz.tokenize(masked_lim)
indexed_tokens = tz.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

In [102]:
token_ids = tz.encode(masked_lim, return_tensors='pt')
# masked_index = tokenized_text.index("[MASK]")
masked_pos = [i for i,d in enumerate(tokenized_text) if d=='[MASK]']
# masked_position = (token_ids.squeeze() == tz.mask_token_id).nonzero()
# masked_pos = [mask.item() for mask in masked_position ]
# masked_pos

In [107]:
with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

top_k = 10
words_and_proba = ()

# msk_ind = masked_pos[0]
# print(msk_ind)
# probs = torch.nn.functional.softmax(predictions[0, msk_ind], dim=-1)
# top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
# # print(masked_index)

list_temp = []
for i, pred_idx in enumerate(top_k_indices):
    predicted_token = tz.convert_ids_to_tokens([pred_idx])[0]
    token_weight = top_k_weights[i]
    # print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))
    list_temp.append((predicted_token, float(token_weight)))

list_temp

for msk_ind in masked_pos:
    probs = torch.nn.functional.softmax(predictions[0, msk_ind], dim=-1)
    top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
    
    temp_list = []
    for i, pred_idx in enumerate(top_k_indices):
        
        predicted_token = tz.convert_ids_to_tokens([pred_idx])[0]
        token_weight = top_k_weights[i]
        # print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))
        temp_tpl = (predicted_token, float(token_weight))
    
        temp_list.append(temp_tpl)

    words_and_proba = words_and_proba + (temp_list,)
words_and_proba


([('amazing', 0.03770241141319275),
  ('good', 0.030399274080991745),
  ('great', 0.024686096236109734),
  ('huge', 0.02320134826004505),
  ('beautiful', 0.022430414333939552),
  ('wonderful', 0.021738242357969284),
  ('incredible', 0.01915503665804863),
  ('fantastic', 0.018468396738171577),
  ('rare', 0.018459055572748184),
  ('perfect', 0.01667996123433113)],
 [('!', 0.12691663205623627),
  ('look', 0.04611518234014511),
  ('right', 0.040643367916345596),
  ('stop', 0.03317881003022194),
  ('wow', 0.0330524668097496),
  ('think', 0.02918442152440548),
  ('go', 0.022315895184874535),
  ('see', 0.01810493879020214),
  ('what', 0.017087459564208984),
  ('no', 0.015778031200170517)],
 [('time', 0.08695467561483383),
  ('options', 0.06752683222293854),
  ('money', 0.042667098343372345),
  ('space', 0.02524210326373577),
  ('produce', 0.019716963171958923),
  ('ingredients', 0.018899541348218918),
  ('supplies', 0.017176149412989616),
  ('food', 0.017155611887574196),
  ('orders', 0.01519

In [105]:
lim

"Our tomatoes this year are abounding;\nThey're lush, red and ripe — just astounding!\nWe've run out of uses\n(We're making excuses)\nFor produce that's all but dumbfounding."

In [None]:
for i, pred_idx in enumerate(top_k_indices):
    predicted_token = tz.convert_ids_to_tokens([pred_idx])[0]
    # print(predicted_token)
    token_weight = top_k_weights[i]
    print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))

In [9]:
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']

In [None]:
# tz = BertTokenizer.from_pretrained('bert-base-uncased')

In [15]:
# text = lim.replace("warm", "[MASK]")
# # text = "[CLS] %s [SEP]"%text
# text

"In a cozy, [MASK] alcove I knit.\nIt's my hole-in-the-wall, poorly lit.\nI drop many a stitch,\nYet my goal's to enrich\nMy dear husband, whom this thing may fit."

In [16]:
tz = BertTokenizer.from_pretrained('bert-base-uncased')

In [17]:
tokenized_text = tz.tokenize(text)
masked_index = tokenized_text.index("[MASK]")
indexed_tokens = tz.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

In [28]:
# with torch.no_grad():
#         outputs = model(tokens_tensor)
#         predictions = outputs[0]

# probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
# top_k = 5
# top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)

# for i, pred_idx in enumerate(top_k_indices):
#     predicted_token = tz.convert_ids_to_tokens([pred_idx])[0]
#     token_weight = top_k_weights[i]
#     print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))

[MASK]: 'a'  | weights: 0.11124002188444138
[MASK]: 'an'  | weights: 0.05427626520395279
[MASK]: 'the'  | weights: 0.03255700320005417
[MASK]: 'my'  | weights: 0.03207283094525337
[MASK]: 'small'  | weights: 0.012609540484845638


In [23]:
token_ids = tz.encode(text, return_tensors='pt')
token_ids

tensor([[  101,  1999,  1037, 26931,  1010,   103,  2632,  3597,  3726,  1045,
         22404,  1012,  2009,  1005,  1055,  2026,  4920,  1011,  1999,  1011,
          1996,  1011,  2813,  1010,  9996,  5507,  1012,  1045,  4530,  2116,
          1037, 26035,  1010,  2664,  2026,  3125,  1005,  1055,  2000,  4372,
         13149,  2026,  6203,  3129,  1010,  3183,  2023,  2518,  2089,  4906,
          1012,   102]])

In [None]:
token_ids_tk = tz.tokenize(text, return_tensors='pt')
token_ids_tk

In [25]:
masked_position = (token_ids.squeeze() == tz.mask_token_id).nonzero()
masked_pos = [mask.item() for mask in masked_position ]
masked_pos

[5]

In [26]:
with torch.no_grad():
    output = model(token_ids)
last_hidden_state = output[0].squeeze()
print ("sentence : ",text)

sentence :  In a cozy, [MASK] alcove I knit.
It's my hole-in-the-wall, poorly lit.
I drop many a stitch,
Yet my goal's to enrich
My dear husband, whom this thing may fit.


In [27]:
list_of_list =[]

for mask_index in masked_pos:
    mask_hidden_state = last_hidden_state[mask_index]
    idx = torch.topk(mask_hidden_state, k=100, dim=0)[1]
    words = [tz.decode(i.item()).strip() for i in idx]
    list_of_list.append(words)
    print (words)

['c o z y', 's m a l l', 'w a r m', 'd a r k', 'c o m f o r t a b l e', 'l i t t l e', 'q u i e t', 'i n t i m a t e', 'n e a t', 't i n y', 'p r i v a t e', 'i n v i t i n g', 'l o n e l y', 's e c l u d e d', 'n a r r o w', 's h a d e d', 'c o o l', 'm o d e s t', 'd i m', 'p l e a s a n t', 'd u s t y', 'c r a m p e d', 't i d y', 'd a m p', 's h a d y', 's u n n y', 's o f t', 's e c r e t', 'l o v e l y', 'r u s t i c', 'e m p t y', 's m o k y', 'c l e a n', 'd i s c r e e t', 'r o m a n t i c', 'u n c o m f o r t a b l e', 'w e l c o m i n g', 'c o l d', 'n i c e', 'i s o l a t e d', 's i l e n t', 's p a c i o u s', 'w h i t e', 'c h e e r f u l', 'e l e g a n t', 'p l u s h', 'f e m i n i n e', 's h a d o w y', 'r o u n d', 'e n c l o s e d', 's h a d o w e d', 'o p e n', 's h a l l o w', 'i n f o r m a l', 'h u m b l e', 's a f e', 'b e a u t i f u l', 'l i t', 'c o m f o r t i n g', 'w o o d e n', 's i m p l e', 'h i d d e n', 'c i r c u l a r', 'c h i l l y', 'd r y', 'h u s

In [29]:
best_guess = ""
for j in list_of_list:
    best_guess = best_guess+" "+j[0]

best_guess

' c o z y'

In [None]:
# Encode the sentence
encoded = tz.encode_plus(
    text=lim,  # the sentence to be encoded
    add_special_tokens=True,  # Add [CLS] and [SEP]
    max_length = 64,  # maximum length of a sentence
    truncation=True,
    padding='longest',  # Add [PAD]s
    return_attention_mask = True,  # Generate the attention mask
    return_tensors = 'pt',  # ask the function to return PyTorch tensors
)