In [6]:
import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import re

t = GPT2TokenizerFast.from_pretrained("gpt2")
m = GPT2LMHeadModel.from_pretrained("gpt2")


In [7]:
starting_sentence = "You are in my"
encoded_text = t(starting_sentence, return_tensors="pt")

#1. step to get the logits of the next token
with torch.inference_mode():
  outputs = m(**encoded_text)

next_token_logits = outputs.logits[0, -1, :]

topk_next_tokens= torch.topk(next_token_logits, 5)

#putting it together
print(f'{starting_sentence}: Next word: {[t.decode(idx)[1:] for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values)]}')

You are in my: Next word: ['office', 'company', 'life', 'home', 'house']


In [8]:
#load test data from LSTM testing set
import pickle
with open("data/test_nextWord.pkl", "rb") as fp:  
    actual_words = pickle.load(fp)
with open("data/test_sentences.pkl", "rb") as fp:  
    previous_text = pickle.load(fp)




In [9]:
# import re
def predict_next_five(starting_sentence):
    if starting_sentence in [""," "]:
        starting_sentence = "  "
    encoded_text = t(starting_sentence[:-1], return_tensors="pt")

    #1. step to get the logits of the next token
    with torch.inference_mode():
        outputs = m(**encoded_text)

    next_token_logits = outputs.logits[0, -1, :]
    p = re.compile("[^a-z0-9']")
    # often returns punctuation for this version we only want words so take top 20 just in case and then filter down to have only words using regex
    topk_next_tokens= torch.topk(next_token_logits, 20, sorted = True)
    top_five_words = [p.sub('',t.decode(idx).lower()) for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) if len(p.sub('', t.decode(idx))) >0][:5]
    # print(f'{starting_sentence}: Next word: {top_five_words}')
    return top_five_words
predicted_words = []
for line in previous_text:
    predicted_words.append(predict_next_five(line))
# print(predicted_words)

In [10]:
j=0
for i in range (len(actual_words)):
    # print(f'{previous_text[i]}: {predicted_words[i]}: actual: {actual_words[i]}')
    if actual_words[i] in predicted_words[i]: 
        j+=1
        
print(j/len(actual_words))


0.4375


In [15]:
print(predict_next_five(" "))
print( predict_next_five("where") )
print( predict_next_five("I will be back ") )



['the', 'of', '1', '2', 'is']
['a', 'as', 'ing', 'i', 'r']
['in', 'soon', 'to', 'with', 'on']


## Test final package


In [1]:
from next_word_pred import predict_gpt2 
predictor = predict_gpt2.GPT2_next_word_pred() 
    
print(predictor.predict("I will be back "))
print(predictor.predict(""))
print(predictor.predict("where"))

['in', 'soon', 'to', 'with', 'on']
['the', 'of', '1', '2', 'is']
['the', 'and', 'is', 'of', 'a']
