# Setup and Imports

In [None]:
import os

BASE_PATH = os.getcwd()
# Both paths end with '/' because later cells build file paths by plain
# string concatenation (e.g. DATA_PATH + input_file).
DATA_PATH = BASE_PATH + '/data/'
MODEL_PATH = BASE_PATH + '/models/'

# Create the directories if they are missing. exist_ok=True is idempotent and
# avoids the check-then-create race of os.path.exists() + os.makedirs().
os.makedirs(DATA_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

In [None]:
import os
import time
import datetime

import pandas as pd
import seaborn as sns
import numpy as np
import random

import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup
import json
from tqdm import tqdm
import pdb
# Seed every RNG up front. seed_model() is defined further down (Seeding
# section), so it cannot be called from this cell on a fresh kernel.
random.seed(7)
np.random.seed(7)
torch.manual_seed(7)
torch.cuda.manual_seed_all(7)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('using device: {}'.format(device))

# Data Loading

In [None]:
import json

class GPT2Dataset(Dataset):
    """
    Custom dataset containing the GPT2 tokens from the subreddit.

    Each post in the JSON file is framed as
    ``'<|startoftext|> ' + post + ' <|endoftext|>'``, tokenized, truncated and
    padded to ``max_length``, and stored as a pair of tensors
    (input ids, attention mask). ``__getitem__`` returns that pair.
    """

    def __init__(self, input_file, tokenizer, gpt2_type="gpt2", max_length=768, data_path=None):
        """
        Args:
            input_file: name of a JSON file containing a list of post strings.
            tokenizer: callable tokenizer (e.g. a Hugging Face GPT-2 tokenizer)
                returning a mapping with 'input_ids' and 'attention_mask'.
            gpt2_type: unused; kept so existing callers keep working.
            max_length: length every encoded post is truncated/padded to.
            data_path: directory holding ``input_file``; defaults to DATA_PATH.
        """
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        directory = DATA_PATH if data_path is None else data_path
        # JSON text is UTF-8; be explicit so posts with emoji/accents do not
        # depend on the platform's default locale encoding.
        with open(os.path.join(directory, input_file), encoding='utf-8') as f:
            posts = json.load(f)

        for post in posts:
            encodings_dict = tokenizer('<|startoftext|> ' + post + ' <|endoftext|>',
                                       truncation=True,
                                       max_length=max_length,
                                       padding="max_length")

            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]


# GPT-2 BPE tokenizer configured with the special tokens GPT2Dataset uses:
# posts are framed by <|startoftext|> ... <|endoftext|> and padded with <|pad|>.
tokenizer = GPT2Tokenizer.from_pretrained(
    'gpt2',
    bos_token='<|startoftext|>',
    eos_token='<|endoftext|>',
    pad_token='<|pad|>',
)

# Every post is truncated/padded to 512 tokens.
shower_thoughts_dataset = GPT2Dataset(
    'posts.json',
    tokenizer,
    max_length=512,
)

# Seeding

In [None]:
def seed_model(seed=7):
    """Seed Python's `random`, NumPy's global RNG, and PyTorch (CPU and all CUDA devices) with `seed`."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


seed_model()

# Data Split

In [None]:
# Hold out fixed-size validation/test sets and train on everything else.
# Deriving the training size from the dataset length (rather than hard-coding
# it) keeps random_split from failing when posts.json changes size.
VALIDATION_SIZE = 1000
TEST_SIZE = 1000
train_size = len(shower_thoughts_dataset) - VALIDATION_SIZE - TEST_SIZE
if train_size <= 0:
    raise ValueError(
        'Dataset has {} posts; need more than {} to hold out validation and test sets.'
        .format(len(shower_thoughts_dataset), VALIDATION_SIZE + TEST_SIZE))
train_data, validation_data, test_data = torch.utils.data.random_split(
    shower_thoughts_dataset, [train_size, VALIDATION_SIZE, TEST_SIZE])

In [None]:
# Report how many posts landed in each split.
split_sizes = (len(train_data), len(validation_data), len(test_data))
print('Training Samples: {} - Validation Samples: {} - Test Samples: {}'.format(*split_sizes))

# GPT-2 Fine-Tuning

In [None]:
BATCH_SIZE = 2  # do not modify - can run into out-of-memory issues
EPOCHS = 5
LR = 5e-5  # AdamW learning rate
WARMUP_STEPS = 100  # linear warm-up length for the LR scheduler, as an integer step count
EPS = 1e-8  # AdamW epsilon

In [None]:
# Only the training set needs shuffling. The validation/test losses are
# averages, so batch order does not change them; reading those sets
# sequentially keeps evaluation deterministic.
train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=BATCH_SIZE)
validation_dataloader = DataLoader(validation_data, sampler=SequentialSampler(validation_data), batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=BATCH_SIZE)

In [None]:
LOAD_FROM_DISK = True

MODEL_NAME = "final_model_epoch_3"

# GPT-2 configuration; hidden states are not requested as model outputs.
configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)

if LOAD_FROM_DISK:
    # Load a checkpoint previously written with save_pretrained()
    # (presumably one of the per-epoch checkpoints from the training loop).
    print('Loading model from disk.')
    model = GPT2LMHeadModel.from_pretrained(MODEL_PATH + MODEL_NAME + '.pt')
else:
    # Start from Hugging Face's pretrained GPT-2 language model.
    print('Loading base gpt model.')
    model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)

# The tokenizer was built with extra special tokens (e.g. '<|pad|>'), so the
# token-embedding matrix must be resized to len(tokenizer).
model.resize_token_embeddings(len(tokenizer))
model.to(device)

# Hugging Face's AdamW (imported from transformers), not torch.optim.AdamW.
optimizer = AdamW(model.parameters(), lr=LR, eps=EPS)

# The scheduler is stepped once per training batch, so the schedule spans
# (batches per epoch) x (number of epochs) steps.
total_steps = EPOCHS * len(train_dataloader)

# Linear warm-up over WARMUP_STEPS, then linear decay over the remaining steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_steps,
)

# Sampling

In [None]:
def sample(model, n=10, max_length=100, top_k=50, top_p=0.95):
    """
    Print ``n`` texts sampled from ``model`` using top-k / nucleus sampling.

    Generation starts from the '<|startoftext|>' prompt and uses the
    module-level ``tokenizer`` and ``device``. The model's train/eval mode is
    restored before returning.

    Args:
        model: Hugging Face causal LM (e.g. GPT2LMHeadModel) already on ``device``.
        n: number of sequences to sample; all are produced by one generate() call.
        max_length: maximum length of each sequence in tokens, prompt included.
        top_k: top-k sampling cutoff.
        top_p: nucleus (top-p) sampling cutoff.
    """
    was_training = model.training
    model.eval()

    prompt = "<|startoftext|>"  # initializes sampling with beginning-of-sentence token
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

    # The prompt ids are the generation input. bos_token_id expects a single
    # int id, not a tensor. pad_token_id lets sequences that end early be
    # padded with the tokenizer's pad token.
    sample_outputs = model.generate(generated,
                                    attention_mask=torch.ones_like(generated),
                                    do_sample=True,
                                    top_k=top_k,
                                    max_length=max_length,
                                    top_p=top_p,
                                    num_return_sequences=n,
                                    pad_token_id=tokenizer.pad_token_id)

    for i, sample_output in enumerate(sample_outputs):
        print("{}: {}\n".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

    # Restore whichever mode the caller had, instead of forcing train mode.
    model.train(was_training)

# Training

In [None]:

training_stats = []

model = model.to(device)

for epoch in range(EPOCHS):
    # ---- Training ----

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, EPOCHS))
    print('Training...')

    train_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):

        input_ids = batch[0].to(device)
        mask = batch[1].to(device)
        # Label padding positions with -100 so the LM loss ignores them;
        # otherwise the model is also scored on predicting '<|pad|>' tokens.
        labels = input_ids.masked_fill(mask == 0, -100)
        model.zero_grad()

        outputs = model(input_ids, labels=labels, attention_mask=mask, token_type_ids=None)

        loss = outputs[0]  # language-modeling loss for this batch

        train_loss += loss.item()

        loss.backward()

        optimizer.step()

        scheduler.step()

        if step % 10 == 0 and step > 0:
            # loss is already a per-batch mean and step is 0-based, so the
            # running average is over step + 1 batches.
            print('Average train loss: {}'.format(train_loss / (step + 1)))

    # Calculate the average loss over all of the batches.
    train_loss = train_loss / len(train_dataloader)
    # The LM loss is natural-log cross-entropy, so perplexity is e ** loss.
    train_perplexity = float(np.exp(train_loss))

    print("")
    print("     Training Loss: {0:.2f} - Training Perplexity {1:.2f}".format(train_loss, train_perplexity))

    # ---- Validation ----

    print("")
    print("Validating...")

    model.eval()

    eval_loss = 0

    for batch in validation_dataloader:

        input_ids = batch[0].to(device)
        mask = batch[1].to(device)
        labels = input_ids.masked_fill(mask == 0, -100)

        with torch.no_grad():
            outputs = model(input_ids, labels=labels, attention_mask=mask, token_type_ids=None)
            loss = outputs[0]

        eval_loss += loss.item()

    # Average over the validation batches.
    eval_loss = eval_loss / len(validation_dataloader)
    eval_perplexity = float(np.exp(eval_loss))

    print("     Validation Loss: {0:.2f} - Validation Perplexity {1:.2f}".format(eval_loss, eval_perplexity))

    sample(model, 10)

    training_stats.append({'epoch': epoch + 1,
                           'Training Loss': train_loss,
                           'Training Perplexity': train_perplexity,
                           'Valid. Loss': eval_loss,
                           'Valid. Perplexity': eval_perplexity
                           })

    # Checkpoint name uses the 0-based epoch index.
    model.save_pretrained(MODEL_PATH + MODEL_NAME + '_epoch_{}.pt'.format(epoch))

print("")
print("Training complete")

## Save model to disk

In [None]:
# Persist the final fine-tuned model under MODEL_PATH/MODEL_NAME.
model.save_pretrained(os.path.join(MODEL_PATH, MODEL_NAME))

## Save training stats to disk

In [None]:
import pickle
print("Saving training stats to disk...")

# Per-epoch metrics (a list of dicts), read back by the Plots section.
with open('training_stats.pickle', mode='wb') as stats_file:
    pickle.dump(training_stats, stats_file, protocol=pickle.HIGHEST_PROTOCOL)

# Testing

In [None]:
# ---- Testing ----
print("")
print("Testing...")

model.eval()

test_loss = 0

# One pass over the held-out test set, without gradient tracking.
for batch in test_dataloader:

    input_ids = batch[0].to(device)
    mask = batch[1].to(device)
    # Label padding positions with -100 so the LM loss scores only real tokens.
    labels = input_ids.masked_fill(mask == 0, -100)

    with torch.no_grad():
        outputs = model(input_ids, labels=labels, attention_mask=mask, token_type_ids=None)
        loss = outputs[0]  # recovers loss from GPT-2 output

    test_loss += loss.item()

test_loss = test_loss / len(test_dataloader)
# The LM loss is natural-log cross-entropy, so perplexity is e ** loss.
test_perplexity = float(np.exp(test_loss))

print("     Test Loss: {0:.2f} - Test Perplexity {1:.2f}".format(test_loss, test_perplexity))

testing_stats = {'Test Loss': test_loss, 'Test Perplexity': test_perplexity}

## Save testing stats to disk

In [None]:
print("Saving testing stats to disk...")

# Test loss/perplexity dict, read back by the Plots section.
with open('testing_stats_epoch_3.pickle', mode='wb') as stats_file:
    pickle.dump(testing_stats, stats_file, protocol=pickle.HIGHEST_PROTOCOL)

# Plots

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import pickle
import seaborn as sns


sns.set_theme()

# These pickles are written by this notebook; never point pickle.load at
# untrusted files.
with open('training_stats.pickle', 'rb') as file:
    training_stats = pickle.load(file)

with open('testing_stats_epoch_3.pickle', 'rb') as file:
    testing_stats = pickle.load(file)

train_loss = [data['Training Loss'] for data in training_stats]
train_perplexity = [data['Training Perplexity'] for data in training_stats]

validation_loss = [data['Valid. Loss'] for data in training_stats]
validation_perplexity = [data['Valid. Perplexity'] for data in training_stats]

test_loss = testing_stats['Test Loss']
test_perplexity = testing_stats['Test Perplexity']

x = range(1, len(training_stats) + 1)

# The figure's dpi is taken from rcParams when the figure is created, so it
# must be set before plt.subplots() to affect this plot.
plt.rcParams['figure.dpi'] = 300
fig, ax = plt.subplots()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

ax.axhline(test_perplexity, label='test (epochs=3)', color='red')
ax.plot(x, train_perplexity, label='train')
ax.plot(x, validation_perplexity, label='validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Perplexity')
ax.set_title('GPT-2 fine-tuning perplexity')
ax.legend();

# Demo

In [None]:
# All 500 sequences come from a single generate() call; lower n if memory runs out.
sample(model, n=500)