# Setup

In [18]:
!pip install gradio
!pip install transformers==3.0.2



In [22]:
from transformers import BertTokenizer, BertForMaskedLM
import torch
import string 

# Bert Next Word Prediction

Source: https://blog.jovian.ai/next-word-prediction-using-bert-388bf48f38f

# Load the model

Load the BERT tokenizer, which splits text into the tokens the model understands, and the pretrained, case-insensitive `bert-base-uncased` masked-language model. Calling `.eval()` switches the model to inference mode.

In [2]:
def load_model():
  """Load the pretrained ``bert-base-uncased`` tokenizer and masked-LM model.

  Returns:
    A ``(tokenizer, model)`` tuple. The model has ``.eval()`` applied, so it
    is ready for inference rather than training.
  """
  checkpoint = 'bert-base-uncased'
  tokenizer = BertTokenizer.from_pretrained(checkpoint)
  model = BertForMaskedLM.from_pretrained(checkpoint)
  model = model.eval()
  return tokenizer, model


# Masking

To predict the next word, we append a `<mask>` placeholder to the end of the input text. BERT is a masked-language model: it predicts which token belongs at a masked position, so the mask marks where the next word should go.

In [20]:
def get_prediction_eos(input_text, top_clean=5):
  """Predict the words most likely to follow ``input_text``.

  A ``<mask>`` placeholder is appended to the end of the text, and
  ``get_all_predictions`` is asked to fill it in.

  Args:
    input_text: Text whose next word should be predicted.
    top_clean: Maximum number of cleaned predictions to return. Passed as a
      parameter instead of being read from a global ``top_k``, so the cell
      works on a fresh kernel.

  Returns:
    A dict ``{'bert': str}`` whose value holds the predictions separated by
    newlines. If prediction fails, the error is printed and ``{'bert': ''}``
    is returned, so callers that index ``res['bert']`` keep working.
  """
  try:
    return get_all_predictions(input_text + ' <mask>', top_clean=int(top_clean))
  except Exception as error:
    # Best-effort: report the problem, but keep the success-shaped result
    # instead of returning None (which made callers fail on res['bert']).
    print(error)
    return {'bert': ''}


# Encode/Decode

We encode the input text with the BERT tokenizer loaded earlier. `add_special_tokens=True` wraps the sequence in BERT's special tokens (`[CLS]` at the start, `[SEP]` at the end), which the model expects. The tokenizer returns `input_ids`. From these we find `mask_idx`, the position of the mask token whose value the model will predict.

In [4]:
def encode(tokenizer, text_sentence, add_special_tokens=True):
  """Tokenize ``text_sentence`` and locate its mask token.

  Args:
    tokenizer: Tokenizer exposing ``mask_token``, ``mask_token_id`` and
      ``encode`` (e.g. the ``BertTokenizer`` returned by ``load_model``).
    text_sentence: Text containing a ``<mask>`` placeholder, or the
      tokenizer's own mask token.
    add_special_tokens: Forwarded to ``tokenizer.encode``.

  Returns:
    ``(input_ids, mask_idx)``. ``input_ids`` is a ``(1, seq_len)`` tensor;
    ``mask_idx`` is the position of the first mask token in it.

  Raises:
    ValueError: If the encoded text contains no mask token.
  """
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
  # If the mask is the last word, append a "." so the model is less
  # inclined to predict punctuation for the masked slot.
  words = text_sentence.split()
  if words and words[-1] == tokenizer.mask_token:
    text_sentence += ' .'

  # Encode unconditionally: the mask may sit anywhere in the text, not only
  # at the end.
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
  mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
  if not mask_positions:
    raise ValueError('text_sentence must contain a mask token')
  return input_ids, mask_positions[0]

The returned `input_ids` are passed to the BERT model, which outputs a score for every vocabulary token at each position, including the masked one. The next step is to decode the highest-scoring token ids back into words:

In [5]:
def decode(tokenizer, pred_idx, top_clean):
  """Turn predicted token ids into up to ``top_clean`` readable words.

  Args:
    tokenizer: Tokenizer used to map each id back to text.
    pred_idx: Iterable of candidate token ids, best candidate first.
    top_clean: Maximum number of words to keep after filtering.

  Returns:
    The surviving words, with WordPiece ``##`` markers removed, joined by
    newlines.
  """
  # Membership is checked per token against a set rather than as a
  # substring of one long string, so multi-character punctuation (e.g.
  # '...') and special tokens (e.g. '[UNK]') are filtered as well.
  special_tokens = {'[PAD]'} | set(getattr(tokenizer, 'all_special_tokens', ()))
  words = []
  for token_id in pred_idx:
    # Remove any whitespace that decode() places inside a single token.
    token = ''.join(tokenizer.decode(token_id).split())
    if not token or token in special_tokens:
      continue
    if all(ch in string.punctuation for ch in token):
      continue
    words.append(token.replace('##', ''))
  return '\n'.join(words[:top_clean])

# Complete Model Predictor

In [10]:
def get_all_predictions(text_sentence, top_clean=5, top_k=None):
  """Predict the masked word in ``text_sentence`` with BERT.

  Uses the module-level ``bert_tokenizer`` and ``bert_model`` created by
  ``load_model()``.

  Args:
    text_sentence: Text containing a ``<mask>`` placeholder.
    top_clean: Maximum number of words to return after filtering.
    top_k: Number of raw candidates taken from the model before
      ``decode`` filters out punctuation and special tokens. Defaults to
      ``max(2 * top_clean, 10)``, so the filtering rarely leaves fewer than
      ``top_clean`` words. This replaces the former dependency on a global
      ``top_k``.

  Returns:
    ``{'bert': str}`` with the predicted words separated by newlines.
  """
  if top_k is None:
    top_k = max(2 * top_clean, 10)
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
  with torch.no_grad():
    scores = bert_model(input_ids)[0]
  mask_scores = scores[0, mask_idx, :]
  # topk() fails when k exceeds the number of vocabulary entries.
  k = min(top_k, mask_scores.size(-1))
  bert = decode(bert_tokenizer, mask_scores.topk(k).indices.tolist(), top_clean)
  return {'bert': bert}

In [7]:
# Setting up tokenizer and bert model.
# load_model() returns (tokenizer, model) for 'bert-base-uncased'; on a fresh
# environment this downloads the weights (see the progress output below).
# get_all_predictions reads these two module-level names.
bert_tokenizer, bert_model  = load_model() 

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/433 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMaskedLM were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
# Quick check: suggest the word that follows "I want a".
# top_k stays a module-level name because helpers above may read it as a global.
top_k = 5
res = get_prediction_eos("I want a")
answer = list(res['bert'].split("\n"))
print(answer)
answer_as_string = "    ".join(answer)

Top_k 5
['life', 'job', 'baby', 'drink', 'shower']


# Interface

In [28]:
import gradio as gr


def greet(input_text):
  """Gradio callback: return next-word suggestions for ``input_text``.

  Args:
    input_text: Text typed into the Gradio textbox.

  Returns:
    The predicted words separated by four spaces, or ``""`` when no
    prediction is available.
  """
  # The former local ``top_k = 5`` was removed: as a local it never reached
  # get_prediction_eos, so it had no effect.
  res = get_prediction_eos(input_text)
  # get_prediction_eos may return None when prediction fails. Show an empty
  # result instead of raising TypeError inside the UI callback.
  if not res:
    return ""
  return "    ".join(res['bert'].split("\n"))


iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()

Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
This share link will expire in 72 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted
Running on External URL: https://46059.gradio.app


(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7864/',
 'https://46059.gradio.app')