<a href="https://colab.research.google.com/github/spenc34/wwdtm-limerick-solver/blob/master/gpt_2_limerick.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip3 install transformers
!pip3 install tqdm

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/37/ba/dda44bbf35b071441635708a3dd568a5ca6bf29f77389f7c7c6818ae9498/transformers-2.7.0-py3-none-any.whl (544kB)
[K     |████████████████████████████████| 552kB 3.4MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 55.6MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)
[K     |████████████████████████████████| 870kB 27.1MB/s 
Collecting tokenizers==0.5.2
[?25l  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)
[K     |████

In [0]:
import os
import csv
import json
import logging
import warnings

import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
from tqdm import tqdm
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW

logging.getLogger().setLevel(logging.CRITICAL)
warnings.filterwarnings('ignore')

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [5]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)

HBox(children=(IntProgress(value=0, description='Downloading', max=1042301, style=ProgressStyle(description_wi…




HBox(children=(IntProgress(value=0, description='Downloading', max=456318, style=ProgressStyle(description_wid…




HBox(children=(IntProgress(value=0, description='Downloading', max=341, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=1520013706, style=ProgressStyle(description…




In [0]:
# Function to first select topN tokens from the probability list and then based on the selected N word distribution
# get random token ID
def choose_from_top(probs, n=5):
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob) # Normalize
    
    choice = np.random.choice(n, 1, p = top_prob)
    token_id = ind[choice][0]
    return int(token_id)

In [0]:
class LimerickDataset(Dataset):
  def __init__(self, train=True):
    super().__init__()
    self.train = train
    self.limericks = []
    self.EOT = "<|endoftext|>"

  def load(self, path):
    if path.endswith(".txt"):
      self._load_csv(path)
    if path.endswith(".json"):
      self._load_json(path)
    return self

  def _load_json(self, path):
    with open(path) as json_file:
      json_reader = json.load(json_file)
      for limerick in json_reader:
        limerick.append(self.EOT)
        self.limericks.append("\n".join(limerick))

  def _load_csv(self, path):
    with open(path) as csv_file:
      csv_reader = csv.reader(csv_file, delimiter="\n")
      limerick, skip_count = [], 0
      for row in csv_reader:
        if len(row) == 0:
          if self.train:
            limerick[-2] = limerick[-2] + limerick[-1]
            limerick.pop()
          limerick.append(self.EOT)
          self.limericks.append("\n".join(limerick))
          limerick, skip_count = [], 0
        elif skip_count < 2:
          skip_count += 1
        else:
          limerick.append(" ".join(row))

  def __len__(self):
    return len(self.limericks)

  def __getitem__(self, item):
    return self.limericks[item]

In [49]:
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 3e-4
WARMUP_STEPS = 5000
TRAINING_STEPS = 5000
MAX_SEQ_LEN = 400

train_dataset = LimerickDataset().load("limericks.json").load("limericks.txt")
print(len(train_dataset))
dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, WARMUP_STEPS, TRAINING_STEPS)
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

tmp_limerick_tens = None
work_limerick_tens = None
models_folder = "trained_models"
if not os.path.exists(models_folder):
  os.mkdir(models_folder)

for epoch in range(EPOCHS):
  print(f"Epoch {epoch + 1}")
  for idx, limerick in enumerate(dataloader):
    limerick_tens = torch.tensor(tokenizer.encode(limerick[0])).unsqueeze(0).to(device)
    #Skip sample from dataset if it is longer than MAX_SEQ_LEN
    if limerick_tens.size()[1] > MAX_SEQ_LEN:
      continue
        
    #The first limerick sequence in the sequence
    if not torch.is_tensor(tmp_limerick_tens):
      tmp_limerick_tens = limerick_tens
      continue
    else:
      #The next limerick does not fit in so we process the sequence and leave the last limerick 
      #as the start for next sequence 
      if tmp_limerick_tens.size()[1] + limerick_tens.size()[1] > MAX_SEQ_LEN:
        work_limerick_tens = tmp_limerick_tens
        tmp_limerick_tens = limerick_tens
      else:
        #Add the limerick to sequence, continue and try to add more
        tmp_limerick_tens = torch.cat([tmp_limerick_tens, limerick_tens[:,1:]], dim=1)
        continue
  
  outputs = model(work_limerick_tens, labels=work_limerick_tens)
  loss, logits = outputs[:2]                        
  loss.backward()  
  sum_loss += loss.detach().data
  proc_seq_count += 1
  if proc_seq_count == BATCH_SIZE:
    proc_seq_count = 0    
    batch_count += 1
    optimizer.step()
    scheduler.step() 
    optimizer.zero_grad()
    model.zero_grad()

if batch_count == 100:
    print(f"sum loss {sum_loss}")
    batch_count = 0
    sum_loss = 0.0

47307
Epoch 0
Epoch 1
Epoch 2
Epoch 3
Epoch 4
Epoch 5
Epoch 6
Epoch 7
Epoch 8
Epoch 9


In [54]:
count = 5
test_dataset = LimerickDataset(train=False).load("limericks.txt")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
print(len(test_dataset))
with torch.no_grad():
  for idx, limerick in enumerate(test_dataloader):
      if not count:
        break
      count -= 1
      prompt = "\n".join(limerick[0].split("\n")[:-2])
      cur_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

      non_words = set(tokenizer.encode("\n,."))
      next_token_id = list(non_words)[0]
      while next_token_id in non_words:
        outputs = model(cur_ids, labels=cur_ids)
        loss, logits = outputs[:2]
        softmax_logits = torch.softmax(logits[0,-1], dim=0)
        next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=3)

      cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word to the running sequence

      output = list(cur_ids.squeeze().to('cpu').numpy())
      print(limerick[0])
      print()
      print(tokenizer.decode(output))
      print("======================================")

50
With coffee and juice as our drinkies,
we spoon in this stuff with raised pinkies.
With snack cakes by Hostess,
our breakfast is toastless.
We're having a bowl full of
Twinkies
<|endoftext|>

With coffee and juice as our drinkies,
we spoon in this stuff with raised pinkies.
With snack cakes by Hostess,
our breakfast is toastless.
We're having a bowl full of cereal
From that night, I was lit as a skunk.
All my memories are trapped in a funk.
But at this museum
is where I can see 'em.
They honor those nights we were
Drunk
<|endoftext|>

From that night, I was lit as a skunk.
All my memories are trapped in a funk.
But at this museum
is where I can see 'em.
They honor those nights we were born
That duct tape banana, I hate it.
More than 100 grand is what I rate it.
Oh, sure it is art.
So I'm playing my part.
Because I was hungry, I
Ate it
<|endoftext|>

That duct tape banana, I hate it.
More than 100 grand is what I rate it.
Oh, sure it is art.
So I'm playing my part.
Because I was hung