In [1]:
!pip3 install transformers
!pip3 install tqdm
import nltk
nltk.download('cmudict')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')
from nltk.corpus import cmudict, wordnet
from nltk.tokenize import word_tokenize, wordpunct_tokenize, sent_tokenize
!pip install pronouncing
import pronouncing

[nltk_data] Downloading package cmudict to /root/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [0]:
import os
import csv
import json
import logging
import warnings

import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
from tqdm import tqdm
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW

logging.getLogger().setLevel(logging.CRITICAL)
warnings.filterwarnings('ignore')

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [0]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)

In [0]:
def intersection(lst1, lst2): 
    return list(set(lst1) & set(lst2)) 

# Function to first select topN tokens from the probability list and then based on the selected N word distribution
# get random token ID
def choose_from_top(probs, rhyme_list, n=5):
    arr = []
    ind = np.argpartition(probs, -n)[-n:]
    for i in ind:
      arr.append(tokenizer.decode(torch.ones((1,1)).long().to(device) * i))
    # print(arr)
    # print(rhyme_list)
    inter = intersection(arr, rhyme_list)
    try:
      return inter[-1]
    except:
      try:
        return rhyme_list[-1]
      except:
        top_prob = probs[ind]
        top_prob = top_prob / np.sum(top_prob) # Normalize
        
        choice = np.random.choice(n, 1, p = top_prob)
        token_id = ind[choice][0]
        return tokenizer.decode(torch.ones((1,1)).long().to(device) * int(token_id))

In [0]:
class LimerickDataset(Dataset):
  def __init__(self, train=True):
    super().__init__()
    self.train = train
    self.limericks = []
    self.EOT = "<|endoftext|>"

  def load(self, path):
    if path.endswith(".txt"):
      self._load_csv(path)
    if path.endswith(".json"):
      self._load_json(path)
    return self

  def _load_json(self, path):
    with open(path) as json_file:
      json_reader = json.load(json_file)
      for limerick in json_reader:
        limerick.append(self.EOT)
        self.limericks.append("\n".join(limerick))

  def _load_csv(self, path):
    with open(path) as csv_file:
      csv_reader = csv.reader(csv_file, delimiter="\n")
      limerick, skip_count = [], 0
      for row in csv_reader:
        if len(row) == 0:
          if self.train:
            limerick[-2] = limerick[-2] + limerick[-1]
            limerick.pop()
          limerick.append(self.EOT)
          self.limericks.append("\n".join(limerick))
          limerick, skip_count = [], 0
        elif skip_count < 2:
          skip_count += 1
        else:
          limerick.append(" ".join(row))

  def __len__(self):
    return len(self.limericks)

  def __getitem__(self, item):
    return self.limericks[item]

In [6]:
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 3e-4
WARMUP_STEPS = 5000
TRAINING_STEPS = 5000
MAX_SEQ_LEN = 400

train_dataset = LimerickDataset().load("limericks.json").load("limericks.txt")
print(len(train_dataset))
dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, WARMUP_STEPS, TRAINING_STEPS)
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

tmp_limerick_tens = None
work_limerick_tens = None
models_folder = "trained_models"
if not os.path.exists(models_folder):
  os.mkdir(models_folder)

for epoch in range(EPOCHS):
  print(f"Epoch {epoch + 1}")
  for idx, limerick in enumerate(dataloader):
    limerick_tens = torch.tensor(tokenizer.encode(limerick[0])).unsqueeze(0).to(device)
    #Skip sample from dataset if it is longer than MAX_SEQ_LEN
    if limerick_tens.size()[1] > MAX_SEQ_LEN:
      continue
        
    #The first limerick sequence in the sequence
    if not torch.is_tensor(tmp_limerick_tens):
      tmp_limerick_tens = limerick_tens
      continue
    else:
      #The next limerick does not fit in so we process the sequence and leave the last limerick 
      #as the start for next sequence 
      if tmp_limerick_tens.size()[1] + limerick_tens.size()[1] > MAX_SEQ_LEN:
        work_limerick_tens = tmp_limerick_tens
        tmp_limerick_tens = limerick_tens
      else:
        #Add the limerick to sequence, continue and try to add more
        tmp_limerick_tens = torch.cat([tmp_limerick_tens, limerick_tens[:,1:]], dim=1)
        continue
  
  outputs = model(work_limerick_tens, labels=work_limerick_tens)
  loss, logits = outputs[:2]                        
  loss.backward()  
  sum_loss += loss.detach().data
  proc_seq_count += 1
  if proc_seq_count == BATCH_SIZE:
    proc_seq_count = 0    
    batch_count += 1
    optimizer.step()
    scheduler.step() 
    optimizer.zero_grad()
    model.zero_grad()

if batch_count == 100:
    print(f"sum loss {sum_loss}")
    batch_count = 0
    sum_loss = 0.0

47307
Epoch 1
Epoch 2
Epoch 3
Epoch 4
Epoch 5
Epoch 6
Epoch 7
Epoch 8
Epoch 9
Epoch 10


In [0]:
def cmp(p, wt):
  pt = set([s.pos() for s in wordnet.synsets(wt)])
  # print(wt, pt)
  return len(pt & p) > 0 or len(p) == 0 or len(pt) == 0
 
def rhymes(s):
  try:
    (w, l, p, pos) = s[0]
    try:
      if pos[0] == 'N':
        pos = {'n'}
      elif pos[0] == 'V':
        pos = {'v'}
      elif pos[0:2] == 'RB' or pos == 'WRB':
        pos = {'r'}
      elif pos[0] == 'J':
        pos = {'a'}
      else:
        pos = set([s.pos() for s in wordnet.synsets(w)])
      # print(w, pos)
      filtered = [' '+wt for (wt, pt) in cmudict.entries() \
                  if ((l == len(pt) \
                  and p[-2:] == pt[-2:]) \
                  or (l == len(pt)-1 \
                  and p == pt[-2:]) \
                  or (l == len(pt)+1 \
                  and pt == p[-2:]) \
                  or (l == 2 \
                  and len(pt) == 2 \
                  and p[-1:] == pt[-1:])) \
                  and (nltk.distance.edit_distance(w, wt) > 2 \
                  or not w[0:2] == wt[0:2]) \
                  and cmp(pos, wt) \
                  and len(nltk.corpus.wordnet.synsets(wt)) > 0]
      return filtered
    except:
      return [w]
  except:
    return []

In [85]:
count = 5
test_dataset = LimerickDataset(train=False).load("limericks.txt")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
print(len(test_dataset))
with torch.no_grad():
  for idx, limerick in enumerate(test_dataloader):
      if not count:
        break
      count -= 1
      prompt = "\n".join(limerick[0].split("\n")[:-2])
      sample = limerick[0].split("\n")[:-2]
      tokens = [wordpunct_tokenize(s) for s in sample]
      tagged = [nltk.pos_tag(t) for t in tokens]
      punct = set(['.', ',', '!', ':', ';', '-'])
      filtered = [ [w for w in sentence if w[0] not in punct ] for sentence in tagged]
      last = [ sentence[len(sentence) - 1] for sentence in filtered][0:2]
      # print(last)
      syllables = [[(word, len(pron), pron, w[1]) for (word, pron) in cmudict.entries() if word == w[0]] for w in last]
      rhyme_list = [rhymes(s) for s in syllables]
      rhyme_list = np.concatenate(rhyme_list)

      cur_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

      non_words = set(tokenizer.encode("\n,."))
      next_token_id = list(non_words)[0]
      # while next_token_id in non_words:
      outputs = model(cur_ids, labels=cur_ids)
      loss, logits = outputs[:2]
      softmax_logits = torch.softmax(logits[0,-1], dim=0)
      next_token = choose_from_top(softmax_logits.to('cpu').numpy(), rhyme_list, n=50000)

      cur_ids = prompt + next_token # Add the last word to the running sequence

      # output = list(cur_ids.squeeze().to('cpu').numpy())
      print(limerick[0])
      print()
      print(cur_ids)
      print("======================================")

50
He's not a good RPG shooter,
nor much of a coin and ring looter.
My kid is quite lame
at video games,
so we've hired the poor guy a
Tutor
<|endoftext|>

He's not a good RPG shooter,
nor much of a coin and ring looter.
My kid is quite lame
at video games,
so we've hired the poor guy a writer
When the world is in steady decline,
quit pretending that everything's fine.
You'll be easing your pain
if you moan and complain.
It's much better for you if you
Whine
<|endoftext|>

When the world is in steady decline,
quit pretending that everything's fine.
You'll be easing your pain
if you moan and complain.
It's much better for you if you combine
They can see how I twitch when I'm dreaming.
I sure hope that I don't wake up screaming.
At a dollar a night,
the hotel is priced right
'cause my room's on a feed that's live
Streaming
<|endoftext|>

They can see how I twitch when I'm dreaming.
I sure hope that I don't wake up screaming.
At a dollar a night,
the hotel is priced right
'cause my room's