<a href="https://colab.research.google.com/github/bradenwatkins/wwdtm-limerick-solver/blob/master/gpt_2_limerick.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip3 install transformers
!pip3 install tqdm



In [2]:
import os
import csv
import logging
import warnings

import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
from tqdm import tqdm
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW

logging.getLogger().setLevel(logging.CRITICAL)
warnings.filterwarnings('ignore')

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [0]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)

In [0]:
# Function to first select topN tokens from the probability list and then based on the selected N word distribution
# get random token ID
def choose_from_top(probs, n=5):
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob) # Normalize
    
    choice = np.random.choice(n, 1, p = top_prob)
    token_id = ind[choice][0]
    return int(token_id)

In [0]:
class LimerickDataset(Dataset):
  def __init__(self, path="limericks.txt"):
    super().__init__()

    self.limericks = []
    self.EOT = "<|endoftext|>"

    with open(path) as csv_file:
      csv_reader = csv.reader(csv_file, delimiter="\n")
      limerick, skip_count = [], 0
      for row in csv_reader:
        if len(row) == 0:
          limerick.append(self.EOT)
          self.limericks.append("\n".join(limerick))
          limerick, skip_count = [], 0
        elif skip_count < 2:
          skip_count += 1
        else:
          limerick.append(" ".join(row))

  def __len__(self):
    return len(self.limericks)

  def __getitem__(self, item):
    return self.limericks[item]

In [6]:
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-5
WARMUP_STEPS = 5000
TRAINING_STEPS = 5000
MAX_SEQ_LEN = 400

dataset = LimerickDataset()
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, WARMUP_STEPS, TRAINING_STEPS)
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

tmp_limerick_tens = None
work_limerick_tens = None
models_folder = "trained_models"
if not os.path.exists(models_folder):
  os.mkdir(models_folder)

for epoch in range(EPOCHS):
  print(f"Epoch {epoch}")
  for idx, limerick in enumerate(dataloader):
    limerick_tens = torch.tensor(tokenizer.encode(limerick[0])).unsqueeze(0).to(device)
    #Skip sample from dataset if it is longer than MAX_SEQ_LEN
    if limerick_tens.size()[1] > MAX_SEQ_LEN:
      continue
        
    #The first limerick sequence in the sequence
    if not torch.is_tensor(tmp_limerick_tens):
      tmp_limerick_tens = limerick_tens
      continue
    else:
      #The next limerick does not fit in so we process the sequence and leave the last limerick 
      #as the start for next sequence 
      if tmp_limerick_tens.size()[1] + limerick_tens.size()[1] > MAX_SEQ_LEN:
        work_limerick_tens = tmp_limerick_tens
        tmp_limerick_tens = limerick_tens
      else:
        #Add the limerick to sequence, continue and try to add more
        tmp_limerick_tens = torch.cat([tmp_limerick_tens, limerick_tens[:,1:]], dim=1)
        continue
  
  outputs = model(work_limerick_tens, labels=work_limerick_tens)
  loss, logits = outputs[:2]                        
  loss.backward()  
  sum_loss += loss.detach().data
  proc_seq_count += 1
  if proc_seq_count == BATCH_SIZE:
    proc_seq_count = 0    
    batch_count += 1
    optimizer.step()
    scheduler.step() 
    optimizer.zero_grad()
    model.zero_grad()

if batch_count == 100:
    print(f"sum loss {sum_loss}")
    batch_count = 0
    sum_loss = 0.0

Epoch 0
Epoch 1
Epoch 2
Epoch 3
Epoch 4


In [21]:
count = 0
with torch.no_grad():
  for idx, limerick in enumerate(dataloader):
      count += 1
      if count == 5:
        break
      prompt = "\n".join(limerick[0].split("\n")[:-2])
      cur_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

      non_words = set(tokenizer.encode("\n,."))
      next_token_id = list(non_words)[0]
      while next_token_id in non_words:
        outputs = model(cur_ids, labels=cur_ids)
        loss, logits = outputs[:2]
        softmax_logits = torch.softmax(logits[0,-1], dim=0)
        next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=3)

      cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word to the running sequence

      output = list(cur_ids.squeeze().to('cpu').numpy())
      print(limerick[0])
      print()
      print(tokenizer.decode(output))
      print("======================================")

Rod Stewart's not drinking Champagne yet.
He's rocked out and worked up a drained sweat.
He's back in his room
with toot, toot and zoom, zoom
because he tours with his big model
Train set
<|endoftext|>

Rod Stewart's not drinking Champagne yet.
He's rocked out and worked up a drained sweat.
He's back in his room
with toot, toot and zoom, zoom
because he tours with his big model.
When signs of decay don't appear,
the doctor will never be near.
Three-sixty-five days
he is keeping away
because our apples stay crisp a whole
Year
<|endoftext|>

When signs of decay don't appear,
the doctor will never be near.
Three-sixty-five days
he is keeping away
because our apples stay crisp a whole year
Playing video games is my plan.
There's no time for a plate, pot or pan.
My Christmas meal prop
is an easy pop-top.
There's three courses all packed in one
Can
<|endoftext|>

Playing video games is my plan.
There's no time for a plate, pot or pan.
My Christmas meal prop
is an easy pop-top.
There's three 