In [0]:
!pip install transformers

# Next Word Prediction Demo

At first, every pre-trained model I could find used one-hot encoding and an LSTM to predict the next characters in a sequence. Transfer learning from those models was not feasible, because the text USE (Universal Sentence Encoder) deals with is very different from the data those models were trained on. Next, I considered appending every word in the English vocabulary to the user prompt and picking the closest resulting vector, but that needs one encoder pass per vocabulary word for every single prediction, which is impractical.
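To make the cost of the brute-force idea concrete, here is a minimal sketch of one possible reading of it. It assumes the candidate is scored by how close the embedding of "prompt + word" is to the embedding of the prompt alone; that scoring rule is an assumption. `toy_embed` is a hypothetical bag-of-characters stand-in for USE, used only so the sketch runs without a model. Whatever the scoring rule, the loop makes one encoder call per vocabulary word.

```python
import math
from collections import Counter


def toy_embed(text):
    # HYPOTHETICAL stand-in for a sentence encoder such as USE: a
    # bag-of-characters count vector, only so the sketch runs without a model.
    return Counter(text.lower())


def cosine(a, b):
    # Cosine similarity between two sparse count vectors (Counters).
    dot = sum(a[k] * b[k] for k in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def brute_force_next_word(prompt, vocabulary, embed=toy_embed):
    """Append each vocabulary word to the prompt, embed the result, and keep
    the word whose sentence embedding is closest to the prompt's embedding
    (ASSUMED scoring rule). Returns (best_word, number_of_encoder_calls)."""
    target = embed(prompt)
    calls = 1  # one call for the prompt itself
    best_word, best_score = None, float("-inf")
    for word in vocabulary:
        score = cosine(target, embed(prompt + " " + word))
        calls += 1  # one more encoder call per vocabulary word
        if score > best_score:
            best_word, best_score = word, score
    return best_word, calls


word, calls = brute_force_next_word("this demo can be", ["used", "run", "downloaded"])
print(word, calls)
```

With a real English vocabulary of tens of thousands of words, `calls` grows to tens of thousands of encoder passes per keystroke, which is why this idea was dropped.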

Finally, I chose Hugging Face's pre-trained OpenAI GPT-2 because, like the transformer variant of USE, it is built on the transformer's self-attention mechanism, so it can use context longer than a single word to predict the next one, which was the original goal. The tradeoff is a larger model than the DAN version of USE. However, since user prompts are expected to be short, even if the model is evaluated on the server, roughly 100 ms of CPU evaluation time (measured below) plus an estimated 200 ms of additional network latency gives about 300 ms, which should not be a big issue.
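The 100 ms figure comes from single timed runs, which can include one-off warm-up costs. A slightly more robust way to check the latency budget is to warm up first and report the median of several runs. The helper name `median_latency` and its defaults are illustrative choices, not part of any library; the budget constants restate the numbers from the paragraph above.

```python
import statistics
import time


def median_latency(fn, runs=20, warmup=3):
    """Call fn() `warmup` times untimed, then `runs` times timed, and
    return the median wall-clock latency in milliseconds."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()  # monotonic, high-resolution timer
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


# Latency budget from the paragraph above, in milliseconds.
MODEL_MS = 100    # rough CPU evaluation time per prompt
NETWORK_MS = 200  # estimated additional networking latency
print("estimated end-to-end latency:", MODEL_MS + NETWORK_MS, "ms")  # 300 ms
```

With the notebook's `model` and `tokenizer` loaded, one could wrap a forward pass, e.g. `with torch.no_grad(): print(median_latency(lambda: model(torch.tensor([tokenizer.encode('This demo can be')]))))`, and compare the result against `MODEL_MS`.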

In [0]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import time
# Optional logging
# import logging
# logging.basicConfig(level=logging.INFO)

# Tokenizer and model initialization and weight loading
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()  # switch to inference mode (disables dropout)

In [74]:
# User prompts
prompts = ['This demo can be',
           'How to approximate the area of the']

for text in prompts:
  start = time.perf_counter()
  # The GPT-2 tokenizer maps text to the token ids the GPT-2 LM was trained on
  encoded_prompt = tokenizer.encode(text)

  # Disable autograd: gradients are not needed for inference
  with torch.no_grad():
    outputs = model(torch.tensor([encoded_prompt]))
    predictions = outputs[0]  # logits, shape (batch, seq_len, vocab_size)

  # Greedy choice: the highest-scoring token at the last position
  predicted_index = torch.argmax(predictions[0, -1, :]).item()

  # Wrapping the decoded token in a list prints its quotes, so the leading space is visible
  print(text, '+', [tokenizer.decode([predicted_index])])
  print('^ that took', time.perf_counter() - start, 'secs\n')

This demo can be + [' downloaded']
^ that took 0.1089472770690918 secs

How to approximate the area of the + [' earth']
^ that took 0.12757611274719238 secs
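`torch.argmax` returns only the single best token. If several suggestions are wanted, the same logits row can be turned into probabilities and ranked. Below is a minimal pure-Python sketch of softmax plus top-k over one row of logits (the role played by `predictions[0, -1, :]` above); the function name and the toy logits are made up for illustration. Note that GPT-2 predicts BPE subword tokens, so a "word" here is really one token; a word the vocabulary splits into several pieces would need more than one prediction step.

```python
import math


def top_k_next_tokens(logits, k=3):
    """Return the k highest-scoring (token_id, probability) pairs from one
    row of next-token logits, e.g. predictions[0, -1, :].tolist()."""
    # Numerically stable softmax: subtract the max logit before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank token ids by probability, highest first, and keep the first k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Toy logits over a 5-token vocabulary (made-up numbers, not GPT-2 output).
print(top_k_next_tokens([0.1, 2.0, -1.0, 3.0, 0.5], k=2))  # token ids 3 then 1
```

The first entry always matches the `argmax` choice. Staying in PyTorch, `torch.topk(torch.softmax(predictions[0, -1, :], dim=-1), k=5)` gives the same ranking as tensors, and each index can be shown with `tokenizer.decode([i])`.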

