# Install all dependencies

In [None]:
# Mount Google Drive so the dataset and helper code under /content/drive are reachable.
from google.colab import drive
drive.mount(mountpoint="/content/drive", force_remount=True)

In [None]:
# Python packages and source repositories used by this notebook.
!pip install --upgrade --no-cache-dir gdown
!git clone https://github.com/unitaryai/detoxify
!pip install transformers==4.16.2
!pip install bitsandbytes-cuda111
!git clone https://github.com/robgon-art/GRUEN
!pip install wmd
# Download files from Google Drive by file ID. Recent gdown releases take the ID
# positionally; the old `--id` flag was deprecated and removed.
!gdown 1S-l0L_YOzn5KhYHdB8iS37qKwuUhHP0G
!gdown 10LpkO5Vm_zOu723FVk6cCeRsv_qyYLdL
# -o overwrites without prompting so the cell can be re-run.
!unzip -o cola_model.zip
!pip install phonemizer
# -y answers apt's confirmation prompt non-interactively.
!sudo apt-get install -y festival
!python -m spacy download en_core_web_lg

# Paths and configuration

In [None]:
# Paths used below -- adjust them to your own Drive layout before running the notebook.
# Directory that holds the true_poetry post-processing code (limerick_generator).
POSTPROCESSING_DIR = '/content/drive/MyDrive/true_poetry/'
# CSV file with the limerick dataset.
LIMERICK_DATA_PATH = '/content/drive/MyDrive/true_poetry/limerick_dataset.csv'
# Local clone of the GRUEN repository.
GRUEN_DIR = "GRUEN"

# Import all necessary libraries

In [None]:
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
import torch.nn.functional as F
import csv
import json
import sys
import copy
import sys

# Add imports of source code synced from git.
# GRUEN source code taken from https://github.com/robgon-art/GRUEN (cloned in the install cell).
# Putting GRUEN_DIR on sys.path presumably lets GRUEN's own modules import each
# other by bare name -- TODO confirm.
sys.path.append(GRUEN_DIR)
import GRUEN.Main as gruen
# Postprocessing library synced from https://github.com/summerstay/true_poetry;
# it is imported from POSTPROCESSING_DIR on the mounted Google Drive.
sys.path.append(POSTPROCESSING_DIR)
from limerick_generator import init_limerick_generator, generate_limerick

In [None]:
# Some of the libraries we use emit warnings from their internals.
# The lines below suppress these warnings to keep the training output readable.
# `warnings` must be imported before it is patched below; it is not imported
# anywhere earlier in the notebook, so patching first raised NameError.
import os
import warnings

# Silence TensorFlow's C++ logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def warn(*args, **kwargs):
    """No-op replacement for ``warnings.warn``; swallows every warning."""
    pass


warnings.warn = warn
warnings.filterwarnings('ignore')


# Data import

In [1]:
# Column names of the limerick dataset CSV.
IS_LIMERICK, AUTHOR, ID, LIMERICK = 'is_limerick', 'author', 'id', 'limerick'

In [None]:
# Load the dataset, keep only rows flagged as limericks, and drop the metadata
# columns. `df` is the frame the `Limericks` dataset class reads from.
limericks = pd.read_csv(LIMERICK_DATA_PATH)
# Explicit comparison with True (rather than boolean indexing) so rows whose flag
# is not exactly True are excluded instead of raising on non-boolean dtypes.
limericks = limericks[limericks[IS_LIMERICK].eq(True)]
df = limericks.drop(columns=[AUTHOR, ID, IS_LIMERICK])

In [None]:
class Limericks(Dataset):
    """Torch ``Dataset`` of GPT-2-tokenized limericks read from the global ``df``.

    Each item is a pair ``(token_ids, text)``:
    ``token_ids`` is a 1-D integer tensor holding the GPT-2 encoding of the
    limerick followed by ``<|endoftext|>``. ``text`` is the limerick with
    ``'\\r\\n'`` replaced by a space.
    """

    def __init__(self, control_code, gpt2_type="gpt2", max_length=1024):
        """Tokenize every limerick in ``df[LIMERICK]``.

        Args:
            control_code: accepted for interface compatibility but not used; rows
                are read from the module-level ``df[LIMERICK]``.
                NOTE(review): the notebook passes ``df[LIMERICK]`` here, so
                reading from this argument may have been intended.
            gpt2_type: name of the pretrained GPT-2 tokenizer to load.
            max_length: maximum number of characters kept from each limerick,
                and maximum number of tokens per encoded example. The trailing
                ``<|endoftext|>`` token is always kept.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.limericks = []
        self.limericks_text = []

        for row in df[LIMERICK]:
            token_ids = self.tokenizer.encode(f"{row[:max_length]}<|endoftext|>")
            # Truncating characters alone does not bound the token count: the EOS
            # marker adds one token and some characters encode to several tokens.
            # Cap the length in tokens so examples fit the model context, keeping
            # the final (EOS) token.
            if len(token_ids) > max_length:
                token_ids = token_ids[:max_length - 1] + token_ids[-1:]
            self.limericks.append(torch.tensor(token_ids))
            self.limericks_text.append(row.replace('\r\n', ' '))
        self.limerick_count = len(self.limericks)

    def __len__(self):
        return self.limerick_count

    def __getitem__(self, item):
        return self.limericks[item], self.limericks_text[item]

In [None]:
# Build the training dataset from the filtered limericks. `train` creates its own
# DataLoader, so this loader is not passed to it.
dataset = Limericks(control_code=df[LIMERICK], gpt2_type="gpt2")
train_dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# Training

In [None]:
# Load the pretrained GPT-2 language model to fine-tune, then its tokenizer.
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [None]:
# Hyperparameter definitions
import math

num_epochs = 1
lr = 2e-5
batch_size = 1
max_seq_len = 400
warmup_steps = 200
# Added to the GRUEN-loss denominator in `train` to avoid division by zero.
eps = 1e-5

optimizer = AdamW(model.parameters(), lr=lr)
# The linear schedule needs the real number of training steps. With
# num_training_steps=-1 its multiplier is 0 for every step after warmup, which
# froze the learning rate at 0 after `warmup_steps` steps.
num_training_steps = num_epochs * math.ceil(len(dataset) / batch_size)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
# Names of the optional custom losses that `train(custom_loss_fn=...)` recognises.
GRUEN_LOSS, L2_LOSS = "GRUEN_LOSS", "L2_LOSS"

In [None]:
# Intended to combine examples into larger batches. Training currently runs with
# batch_size = 1 because of compute limits, so only the "nothing packed yet" case
# is supported.
# TODO: implement real packing to support larger batches.
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    """Return ``new_tensor`` when nothing has been packed yet.

    Args:
        new_tensor: the newly loaded example.
        packed_tensor: the tensor packed so far, or ``None``.
        max_seq_len: reserved for a future packing implementation; unused.

    Raises:
        NotImplementedError: if ``packed_tensor`` is not ``None``.
    """
    if packed_tensor is not None:
        raise NotImplementedError
    return new_tensor

In [None]:
def _compute_loss(outputs, input_tensor, tokenizer, custom_loss_fn):
    """Compute the training loss for one batch.

    Args:
        outputs: result of ``model(input_tensor, labels=input_tensor)``.
            ``outputs[0]`` is used as the language-model loss and ``outputs[1]``
            as the logits.
        input_tensor: token ids fed to the model. The ``[0]`` indexing below
            assumes a batch of size 1.
        tokenizer: tokenizer used to decode the greedy prediction for GRUEN.
        custom_loss_fn: ``None`` (or any falsy value) for the plain LM loss,
            ``GRUEN_LOSS`` or ``L2_LOSS``.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: if ``custom_loss_fn`` is not a supported loss name.
    """
    if not custom_loss_fn:
        return outputs[0]

    # Greedy (argmax) token at every position of the single sequence in the batch.
    greedy_op = torch.argmax(outputs[1][0, :, :], dim=1)

    if custom_loss_fn == GRUEN_LOSS:
        gruen_score = gruen.get_gruen([tokenizer.decode(greedy_op)])
        # Divide the LM loss by (1 - GRUEN score); eps guards a zero denominator.
        # NOTE(review): a higher GRUEN score makes the loss larger here; confirm
        # that this scaling direction is intended.
        return outputs[0] / (torch.tensor(1 - gruen_score[0]) + eps)

    if custom_loss_fn == L2_LOSS:
        # NOTE(review): argmax is not differentiable and neither operand requires
        # grad, so this loss has no grad_fn and `loss.backward()` will raise.
        # Logits at position t predict token t+1 but are compared with token t.
        # TODO: replace with a differentiable, shifted formulation.
        return torch.norm(((greedy_op.float() - input_tensor[0].float()) ** 2), p=2)

    raise ValueError(f"Unknown custom_loss_fn: {custom_loss_fn!r}")


def train(dataset, model, tokenizer, batch_size=batch_size, epochs=num_epochs, lr=lr, max_seq_len=max_seq_len, save_model_on_epoch=True, custom_loss_fn=None, save_path='/content/drive/MyDrive/project_model.pt'):
    """Fine-tune ``model`` on ``dataset`` and return it.

    Uses the module-level ``optimizer`` and ``scheduler``. ``lr`` is accepted for
    compatibility, but the learning rate actually comes from ``optimizer``.

    Args:
        dataset: dataset yielding ``(token_ids, text)`` pairs.
        model: causal LM called as ``model(ids, labels=ids)``.
        tokenizer: tokenizer used by the GRUEN loss to decode predictions.
        batch_size: DataLoader batch size and optimizer-step interval.
        epochs: number of passes over ``dataset``.
        lr: unused (see above).
        max_seq_len: forwarded to ``pack_tensor``.
        save_model_on_epoch: if true, write a checkpoint after every epoch.
        custom_loss_fn: ``None``, ``GRUEN_LOSS`` or ``L2_LOSS``.
        save_path: checkpoint file, overwritten after each epoch.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``custom_loss_fn`` is not a supported loss name.
    """
    # Fail fast, before the expensive setup, on an unsupported loss name.
    if custom_loss_fn and custom_loss_fn not in (GRUEN_LOSS, L2_LOSS):
        raise ValueError(f"Unknown custom_loss_fn: {custom_loss_fn!r}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Guard the per-epoch average against an empty dataset.
    num_batches = max(1, len(train_dataloader))
    accumulating_batch_count = 0
    init_limerick_generator()

    for epoch in range(epochs):
        batch_bar = tqdm(total=len(train_dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train')
        print(f"Training epoch {epoch}")
        total_loss = 0

        for idx, (entry, _text) in enumerate(train_dataloader):
            # Nothing is carried over between iterations, so every batch is packed
            # starting from an empty (None) pack.
            input_tensor = pack_tensor(entry, None, max_seq_len).to(device)
            outputs = model(input_tensor, labels=input_tensor)
            loss = _compute_loss(outputs, input_tensor, tokenizer, custom_loss_fn)

            total_loss += float(loss)
            loss.backward()

            # Step once every `batch_size` iterations (every iteration when batch_size == 1).
            if accumulating_batch_count % batch_size == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                # Also clear grads on `model` in case it is not the model the
                # global optimizer was built from.
                model.zero_grad()
            accumulating_batch_count += 1

            # Show the running mean loss and the current learning rate.
            batch_bar.set_postfix(
                loss="{:.04f}".format(total_loss / (idx + 1)),
                lr="{:.2e}".format(float(optimizer.param_groups[0]['lr'])),
            )
            batch_bar.update()

        batch_bar.close()
        epoch_loss = total_loss / num_batches

        if save_model_on_epoch:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, save_path)

        print("Epoch {}/{}: Loss {:.04f}, Learning Rate {:.2e}".format(
            epoch + 1, epochs, epoch_loss, optimizer.param_groups[0]['lr']))

    return model

In [None]:
# Fine-tune with the default hyperparameters and the plain language-model loss.
model = train(dataset=dataset, model=model, tokenizer=tokenizer)

# Example poetry generation

In [None]:
# Generate a limerick from a prompt with a CPU copy of the fine-tuned model,
# leaving the trained `model` itself on its current device.
prompt = "The broadcasts and newspapers pull"
model_cpy = copy.deepcopy(model).to('cpu')
generate_limerick(prompt, model_cpy)