In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer first, then the causal-LM weights, both from the same
# Hugging Face Hub checkpoint.
tokenizer, model = (
    loader.from_pretrained("openai-community/gpt2")
    for loader in (AutoTokenizer, AutoModelForCausalLM)
)

# Source poem, one list entry per line of verse.
poem = """\
One must have a mind of winter
To regard the frost and the boughs
Of the pine-trees crusted with snow;
And have been cold a long time
To behold the junipers shagged with ice,
The spruces rough in the distant glitter
Of the January sun; and not to think
Of any misery in the sound of the wind,
In the sound of a few leaves,
Which is the sound of the land
Full of the same wind
That is blowing in the same bare place
For the listener, who listens in the snow,
And, nothing himself, beholds
Nothing that is not there and the nothing that is.
""".splitlines()

In [None]:
# P+7 ALGORITHM: replace the last word of each line with the model's 7th most probable next token

def replace_last_word(line, x=6):
  """Replace the last word of `line` with a next-token candidate from the model.

  The line is split on whitespace, its final word is dropped, and the
  remaining words (re-joined with single spaces) are passed through the
  module-level `tokenizer` and `model`. The next-token candidates are ordered
  from most to least probable and the one at position `x` is appended.

  Args:
    line: Text whose last whitespace-separated word is replaced.
    x: 0-based index into the candidates sorted by descending probability.
      The default 6 selects the 7th most probable token (P+7). Negative
      values count back from the least probable candidate.

  Returns:
    The prompt, a single space, and the chosen token's decoded text with
    surrounding whitespace stripped. A line with fewer than two words is
    returned unchanged, because no prompt would be left for the model.

  Raises:
    IndexError: If `x` falls outside the range of candidate tokens.
  """
  words = line.split()

  # With zero or one word the prompt would be empty, leaving the model no
  # context to predict from, so hand the line back untouched.
  if len(words) < 2:
    return line

  # Everything except the last word becomes the prompt, tokenized to tensors.
  prompt = " ".join(words[:-1])
  inputs = tokenizer(prompt, return_tensors="pt")

  # Inference only: no gradients are needed.
  with torch.no_grad():
    outputs = model(**inputs)

  # Raw scores at the last input position, i.e. for the token that follows.
  next_token_logits = outputs.logits[0, -1]

  # Turn scores into probabilities and rank token ids, most probable first.
  probs = torch.softmax(next_token_logits, dim=0)
  ranked_token_ids = torch.argsort(probs, descending=True)

  # Pick the candidate at index x and decode it back to text.
  chosen_token_id = ranked_token_ids[x].item()
  chosen_word = tokenizer.decode([chosen_token_id]).strip()

  return prompt + " " + chosen_word

# Rewrite every line of the poem, keeping the original line order.
p7_poem = list(map(replace_last_word, poem))

# Show the rewritten poem, one line per row of output.
for line in p7_poem:
  print(line)

In [None]:
# P+X ALGORITHM: like P+7, but the 0-based rank index x of the replacement token is chosen by the caller

def replace_last_word(line, x):
  """Replace the last word of `line` with the next-token candidate at rank index `x`.

  The line is split on whitespace, its final word is dropped, and the
  remaining words (re-joined with single spaces) are passed through the
  module-level `tokenizer` and `model`. The next-token candidates are ordered
  from most to least probable and the one at position `x` is appended.

  Args:
    line: Text whose last whitespace-separated word is replaced.
    x: 0-based index into the candidates sorted by descending probability:
      0 is the most probable token, 6 the 7th most probable, and negative
      values count back from the least probable (-1 is the least probable).

  Returns:
    The prompt, a single space, and the chosen token's decoded text with
    surrounding whitespace stripped. A line with fewer than two words is
    returned unchanged, because no prompt would be left for the model.

  Raises:
    IndexError: If `x` falls outside the range of candidate tokens.
  """
  words = line.split()

  # With zero or one word the prompt would be empty, leaving the model no
  # context to predict from, so hand the line back untouched.
  if len(words) < 2:
    return line

  # Everything except the last word becomes the prompt, tokenized to tensors.
  prompt = " ".join(words[:-1])
  inputs = tokenizer(prompt, return_tensors="pt")

  # Inference only: no gradients are needed.
  with torch.no_grad():
    outputs = model(**inputs)

  # Raw scores at the last input position, i.e. for the token that follows.
  next_token_logits = outputs.logits[0, -1]

  # Turn scores into probabilities and rank token ids, most probable first.
  probs = torch.softmax(next_token_logits, dim=0)
  ranked_token_ids = torch.argsort(probs, descending=True)

  # Pick the candidate at the caller-supplied index and decode it to text.
  chosen_token_id = ranked_token_ids[x].item()
  chosen_word = tokenizer.decode([chosen_token_id]).strip()

  return prompt + " " + chosen_word

# x = -1 indexes the final entry of the descending ranking, so each line ends
# with the token the model scored lowest.
p7_poem = list(replace_last_word(verse, -1) for verse in poem)

# Show the rewritten poem, one line per row of output.
for line in p7_poem:
  print(line)