In [27]:
import tensorflow as tf
from transformers import BertTokenizer, BertForMaskedLM

In [2]:
import torch

In [3]:
import pandas as pd
from data_processing.format_data import extract_ab_lines

In [4]:
file_path = '../raw_data/limerick_dataset_oedilf_v3.json'
df = pd.read_json(file_path)
df = df[df.is_limerick == True]
lines = extract_ab_lines(df)

In [5]:
a_lines = lines[0]
ex_line = a_lines[31]
ex_line

'Though indulgence has frequently showed'

In [6]:
tz = BertTokenizer.from_pretrained("bert-base-cased")

In [7]:
tz.tokenize(ex_line)

['Though', 'in', '##du', '##lge', '##nce', 'has', 'frequently', 'showed']

In [8]:
x = torch.rand(5, 3)
print(x)

tensor([[0.3343, 0.3567, 0.9100],
        [0.7196, 0.4538, 0.8579],
        [0.3113, 0.7047, 0.6230],
        [0.8664, 0.7172, 0.8465],
        [0.4166, 0.7034, 0.7823]])


In [12]:
# Encode the sentence
encoded = tz.encode_plus(
    text=ex_line,  # the sentence to be encoded
    add_special_tokens=True,  # Add [CLS] and [SEP]
    max_length = 64,  # maximum length of a sentence
    truncation=True,
    padding='longest',  # Add [PAD]s
    return_attention_mask = True,  # Generate the attention mask
    return_tensors = 'pt',  # ask the function to return PyTorch tensors
)

In [13]:
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']

In [14]:
input_ids

tensor([[  101,  3473,  1107,  7641, 21463,  3633,  1144,  3933,  2799,   102]])

In [15]:
attn_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [28]:
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

Downloading: 100%|██████████| 440M/440M [03:09<00:00, 2.33MB/s] 
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [29]:
text = ex_line + "[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]"
text = "[CLS] %s [SEP]"%text
text

'[CLS] Though indulgence has frequently showed[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [SEP]'

In [35]:
tz = BertTokenizer.from_pretrained('bert-base-uncased')

Downloading: 100%|██████████| 232k/232k [00:00<00:00, 376kB/s] 
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 10.8kB/s]


In [36]:
tokenized_text = tz.tokenize(text)
masked_index = tokenized_text.index("[MASK]")
indexed_tokens = tz.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

In [37]:
with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
top_k = 5
top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)

for i, pred_idx in enumerate(top_k_indices):
    predicted_token = tz.convert_ids_to_tokens([pred_idx])[0]
    token_weight = top_k_weights[i]
    print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))

[MASK]: 'a'  | weights: 0.4407033324241638
[MASK]: 'in'  | weights: 0.07142706960439682
[MASK]: 'its'  | weights: 0.05816581845283508
[MASK]: 'with'  | weights: 0.04569272696971893
[MASK]: 'to'  | weights: 0.033848367631435394


In [38]:
text = ex_line + "[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]"
text

'Though indulgence has frequently showed[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]'

In [40]:
token_ids = tz.encode(text, return_tensors='pt')
token_ids

tensor([[  101,  2295, 27427,  5313, 17905,  2038,  4703,  3662,   103,   103,
           103,   103,   103,   103,   103,   102]])

In [42]:
token_ids_tk = tz.tokenize(text, return_tensors='pt')
token_ids_tk

Keyword arguments {'return_tensors': 'pt'} not recognized.


['though',
 'ind',
 '##ul',
 '##gence',
 'has',
 'frequently',
 'showed',
 '[MASK]',
 '[MASK]',
 '[MASK]',
 '[MASK]',
 '[MASK]',
 '[MASK]',
 '[MASK]']

In [43]:
masked_position = (token_ids.squeeze() == tz.mask_token_id).nonzero()
masked_pos = [mask.item() for mask in masked_position ]
masked_pos

[8, 9, 10, 11, 12, 13, 14]

In [46]:
with torch.no_grad():
    output = model(token_ids)
last_hidden_state = output[0].squeeze()
print ("sentence : ",text)

sentence :  Though indulgence has frequently showed[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]


In [47]:
list_of_list =[]

for mask_index in masked_pos:
    mask_hidden_state = last_hidden_state[mask_index]
    idx = torch.topk(mask_hidden_state, k=100, dim=0)[1]
    words = [tz.decode(i.item()).strip() for i in idx]
    list_of_list.append(words)
    print (words)

['a', 'i n', 'i t s', 'w i t h', 't o', '"', 'a s', 'n o', 't h e', 'a n', ',', 'o f', 'v e r y', "'", 'f o r', 'h i s', 'a t', 's o m e', 's u c h', 't h e i r', 'h a s', 'i s', 'i t', 'i t s e l f', 'b y', 'a n d', 'i n t o', 'h e r', 'r a t h e r', 'm u c h', 'a n y', 'q u i t e', ':', 't h i s', 'm o r e', 'o n l y', '.', '-', 'o n', 'h i m', 't h r o u g h', 'n o t', 'o r', 'o n e', 't h a t', 'e v e n', 'b u t', 'l i t t l e', 's o m e t h i n g', '(', 'a l m o s t', ')', '...', 'a l s o', 'f r o m', 'w a s', 's o m e t i m e s', 'a b o u t', 'm a n y', 't h e m', 'o f t e n', 'f a r', 'g e n e r a l l y', '# # l y', 'h a d', ';', 'b o t h', 't h e r e', 'w i t h o u t', 'e x t r e m e l y', 'i n c r e a s i n g l y', '# # s', 'r e l a t i v e l y', 'y e t', 't o o', 's o', 's o m e w h a t', 'o t h e r', 'h i m s e l f', 'a n o t h e r', 'p o s s i b l e', 'a g a i n s t', 't w o', 's h o w n', 'o u t', 's h o w i n g', 'h a v e', 's l i g h t l y', 'h a v i n g', 's e v e r a l

In [48]:
best_guess = ""
for j in list_of_list:
    best_guess = best_guess+" "+j[0]

best_guess

' a a a n d a n d t o # # s .'