In [11]:
!pip3 install transformers
!pip3 install tqdm
!pip install pronouncing

import nltk
nltk.download('cmudict')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')
from nltk.corpus import cmudict, wordnet
from nltk.tokenize import word_tokenize, wordpunct_tokenize, sent_tokenize
import pronouncing
import re

[nltk_data] Downloading package cmudict to /root/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [0]:
import os
import csv
import json
import logging
import warnings

import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
from tqdm import tqdm
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW

logging.getLogger().setLevel(logging.CRITICAL)
warnings.filterwarnings('ignore')

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [0]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)

In [0]:
def cmp(wt):
  pt = set([s.pos() for s in wordnet.synsets(wt)])
  if 'n' in pt or 'v' in pt or 'r' in pt or 'a' in pt:
    return True
  else:
    return False

def rhymes_1(w):
    try:
      return pronouncing.rhymes(w)
    except:
      return []

def rhymes_2(s):
  try:
    (w, l, p) = s[0]
    try:
      filtered = [wt for (wt, pt) in cmudict.entries() \
                  if ((l == len(pt) \
                  and p[-2:] == pt[-2:]) \
                  or (l == len(pt)-1 \
                  and p == pt[-2:]) \
                  or (l == len(pt)+1 \
                  and pt == p[-2:]) \
                  or (l == 2 \
                  and len(pt) == 2 \
                  and p[-1:] == pt[-1:])) \
                  and (nltk.distance.edit_distance(w, wt) > 2 \
                  or not w[0:2] == wt[0:2]) \
                  and cmp(wt) \
                  and len(nltk.corpus.wordnet.synsets(wt)) > 0 \
                  and len(wt) > 2]
      return filtered
    except:
      return [w]
  except:
    return [] 

def passes_rhyme_check(rhymes_list, pred):
    i = int(pred)
    decoded_prediction = [tokenizer.decode(torch.ones((1,1)).long().to(device) * pred)][0].lower()
    decoded_prediction = decoded_prediction[1:] if decoded_prediction[0] == ' ' else decoded_prediction
    decoded_prediction = re.sub(r'[^\w\s]','',decoded_prediction)
    if decoded_prediction in rhymes_list:
      return True
    else:
      return False


# Function to first select topN tokens from the probability list and then based on the selected N word distribution
# get random token ID
def choose_from_top(rhymes_list_0, rhymes_list_1, probs, n=5):
    choice = -1
    ind = np.argsort(probs)[::-1]

    for j, pred in enumerate(ind):
      if passes_rhyme_check(rhymes_list_1, pred):
        return int(ind[j])

    for k, pred in enumerate(ind):
      if passes_rhyme_check(rhymes_list_0, pred):
        return int(ind[k])

    if choice == -1:
      # print(choice)
      return int(ind[0])

In [0]:
class LimerickDataset(Dataset):
  def __init__(self, train=True):
    super().__init__()
    self.train = train
    self.limericks = []
    self.EOT = "<|endoftext|>"

  def load(self, path):
    if path.endswith(".txt"):
      self._load_csv(path)
    if path.endswith(".json"):
      self._load_json(path)
    return self

  def _load_json(self, path):
    with open(path) as json_file:
      json_reader = json.load(json_file)
      for limerick in json_reader:
        limerick.append(self.EOT)
        self.limericks.append("\n".join(limerick))

  def _load_csv(self, path):
    with open(path) as csv_file:
      csv_reader = csv.reader(csv_file, delimiter="\n")
      limerick, skip_count = [], 0
      for row in csv_reader:
        if len(row) == 0:
          if self.train:
            limerick[-2] = limerick[-2] + limerick[-1]
            limerick.pop()
          limerick.append(self.EOT)
          self.limericks.append("\n".join(limerick))
          limerick, skip_count = [], 0
        elif skip_count < 2:
          skip_count += 1
        else:
          limerick.append(" ".join(row))

  def __len__(self):
    return len(self.limericks)

  def __getitem__(self, item):
    return self.limericks[item]

In [25]:
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 3e-4
WARMUP_STEPS = 5000
TRAINING_STEPS = 5000
MAX_SEQ_LEN = 400

train_dataset = LimerickDataset().load("limericks.json").load("limericks.txt")
print(len(train_dataset))
dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, WARMUP_STEPS, TRAINING_STEPS)
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

tmp_limerick_tens = None
work_limerick_tens = None
models_folder = "trained_models"
if not os.path.exists(models_folder):
  os.mkdir(models_folder)

for epoch in range(EPOCHS):
  print(f"Epoch {epoch + 1}")
  for idx, limerick in enumerate(dataloader):
    limerick_tens = torch.tensor(tokenizer.encode(limerick[0])).unsqueeze(0).to(device)
    #Skip sample from dataset if it is longer than MAX_SEQ_LEN
    if limerick_tens.size()[1] > MAX_SEQ_LEN:
      continue
        
    #The first limerick sequence in the sequence
    if not torch.is_tensor(tmp_limerick_tens):
      tmp_limerick_tens = limerick_tens
      continue
    else:
      #The next limerick does not fit in so we process the sequence and leave the last limerick 
      #as the start for next sequence 
      if tmp_limerick_tens.size()[1] + limerick_tens.size()[1] > MAX_SEQ_LEN:
        work_limerick_tens = tmp_limerick_tens
        tmp_limerick_tens = limerick_tens
      else:
        #Add the limerick to sequence, continue and try to add more
        tmp_limerick_tens = torch.cat([tmp_limerick_tens, limerick_tens[:,1:]], dim=1)
        continue
  
  outputs = model(work_limerick_tens, labels=work_limerick_tens)
  loss, logits = outputs[:2]                        
  loss.backward()  
  sum_loss += loss.detach().data
  proc_seq_count += 1
  if proc_seq_count == BATCH_SIZE:
    proc_seq_count = 0    
    batch_count += 1
    optimizer.step()
    scheduler.step() 
    optimizer.zero_grad()
    model.zero_grad()

if batch_count == 100:
    print(f"sum loss {sum_loss}")
    batch_count = 0
    sum_loss = 0.0

47307
Epoch 1
Epoch 2
Epoch 3
Epoch 4
Epoch 5
Epoch 6
Epoch 7
Epoch 8
Epoch 9
Epoch 10
Epoch 11
Epoch 12
Epoch 13
Epoch 14
Epoch 15
Epoch 16
Epoch 17
Epoch 18
Epoch 19
Epoch 20
Epoch 21
Epoch 22
Epoch 23
Epoch 24
Epoch 25
Epoch 26
Epoch 27
Epoch 28
Epoch 29
Epoch 30
Epoch 31
Epoch 32
Epoch 33
Epoch 34
Epoch 35
Epoch 36
Epoch 37
Epoch 38
Epoch 39
Epoch 40
Epoch 41
Epoch 42
Epoch 43
Epoch 44
Epoch 45
Epoch 46
Epoch 47
Epoch 48
Epoch 49
Epoch 50
Epoch 51
Epoch 52
Epoch 53
Epoch 54
Epoch 55
Epoch 56
Epoch 57
Epoch 58
Epoch 59
Epoch 60
Epoch 61
Epoch 62
Epoch 63
Epoch 64
Epoch 65
Epoch 66
Epoch 67
Epoch 68
Epoch 69
Epoch 70
Epoch 71
Epoch 72
Epoch 73
Epoch 74
Epoch 75
Epoch 76
Epoch 77
Epoch 78
Epoch 79
Epoch 80
Epoch 81
Epoch 82
Epoch 83
Epoch 84
Epoch 85
Epoch 86
Epoch 87
Epoch 88
Epoch 89
Epoch 90
Epoch 91
Epoch 92
Epoch 93
Epoch 94
Epoch 95
Epoch 96
Epoch 97
Epoch 98
Epoch 99
Epoch 100


In [46]:
count = 5
y_truth = []
y_hat = []
test_dataset = LimerickDataset(train=False).load("limericks.txt")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
print(len(test_dataset))
with torch.no_grad():
  for idx, limerick in enumerate(test_dataloader):
      if not count:
        break
      count -= 1
      prompt = "\n".join(limerick[0].split("\n")[:-2])

      sample = limerick[0].split("\n")[:-2]
      tokens = [wordpunct_tokenize(s) for s in sample]
      punct = set(['.', ',', '!', ':', ';', '-', ')', '(', '?', ').'])
      filtered = [ [w for w in sentence if w not in punct ] for sentence in tokens]
      last = [ sentence[len(sentence) - 1] for sentence in filtered][0:2]
      rhymes_list_0 = [rhymes_1(s) for s in last]

      syllables = [[(word, len(pron), pron) for (word, pron) in cmudict.entries() if word == w] for w in last]
      # print(syllables)
      rhymes_list_1 = [rhymes_2(s) for s in syllables]
      rhymes_list_1 = rhymes_list_1[0] + rhymes_list_1[1]

      cur_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
      non_words = set(tokenizer.encode("\n,."))
      next_token_id = list(non_words)[0]
      while next_token_id in non_words:
        outputs = model(cur_ids, labels=cur_ids)
        loss, logits = outputs[:2]
        softmax_logits = torch.softmax(logits[0,-1], dim=0)
        next_token_id = choose_from_top(rhymes_list_0 ,rhymes_list_1, softmax_logits.to('cpu').numpy())

      cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word to the running sequence

      output = list(cur_ids.squeeze().to('cpu').numpy())
      print(limerick[0])
      print()
      print(tokenizer.decode(output))
      print("======================================")
      y_truth.append(limerick[0].split("\n")[-2:-1][0].lower())
      decoded_prediction = [tokenizer.decode(torch.ones((1,1)).long().to(device) * next_token_id)][0].lower()
      decoded_prediction = decoded_prediction[1:] if decoded_prediction[0] == ' ' else decoded_prediction
      decoded_prediction = re.sub(r'[^\w\s]','',decoded_prediction)
      y_hat.append(decoded_prediction)

sum = 0
for n, _ in enumerate(y_truth):
  sum += y_hat[n] == y_truth[n]
acc = sum/len(y_truth)
print('\nAccuracy:', acc)

50
From that night, I was lit as a skunk.
All my memories are trapped in a funk.
But at this museum
is where I can see 'em.
They honor those nights we were
Drunk
<|endoftext|>

From that night, I was lit as a skunk.
All my memories are trapped in a funk.
But at this museum
is where I can see 'em.
They honor those nights we were drunk
Armored tanks aren't known to be plush.
But for this one, all soldiers will gush.
While bombs are incoming,
we use indoor plumbing.
Our tanks now have toilets that
Flush
<|endoftext|>

Armored tanks aren't known to be plush.
But for this one, all soldiers will gush.
While bombs are incoming,
we use indoor plumbing.
Our tanks now have toilets that flush
I feel like my mouth has been stung.
The loss of space there goes unsung.
To get better sleep,
my mouth needs a clean sweep.
I need to lose weight in my
Tongue
<|endoftext|>

I feel like my mouth has been stung.
The loss of space there goes unsung.
To get better sleep,
my mouth needs a clean sweep.
I need to