<a href="https://colab.research.google.com/github/xinh3ng/ds-research/blob/colab/gpt2_fine_tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Links:**

*   https://towardsdatascience.com/teaching-gpt-2-a-sense-of-humor-fine-tuning-large-transformer-models-on-a-single-gpu-in-pytorch-59e8cec40912
*   https://gist.github.com/mf1024/3df214d2f17f3dcc56450ddf0d5a4cd7#file-fine-tuning-gpt2-medium-in-pytorch-ipynb




# Generating text with a pre-trained GPT2 in PyTorch

This notebook was created as part of a blog post: [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](https://mf1024.github.io/2019/11/12/Fun-With-GPT-2/).

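In this notebook, I use a pre-trained GPT-2 model from the [huggingface](https://github.com/huggingface/transformers) *transformers* library to generate some text. Note that the code loads the small `gpt2` checkpoint (about 124M parameters), not the medium-sized `gpt2-medium` one.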

The easiest way to use the Hugging Face transformers library is to install its pip package, *transformers*.

In [1]:
!pip install transformers



In [2]:
from google.colab import drive
import logging
import numpy as np
import pandas as pd
import sys
import torch

logging.getLogger().setLevel(logging.CRITICAL)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

pd.set_option("display.precision", 4)  # the bare "precision" key is ambiguous in newer pandas versions

print("Python version is %s" % sys.version)
print("Device is: %s" % device)

drive.mount("/content/gdrive")

Python version is 3.6.9 (default, Jul 17 2020, 12:50:27) 
[GCC 8.4.0]
Device is: cuda
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


### Models and classes

I use the [GPT2LMHeadModel](https://github.com/huggingface/transformers/blob/master/transformers/modeling_gpt2.py#L491) class for the language model. It is [GPT2Model](https://github.com/huggingface/transformers/blob/master/transformers/modeling_gpt2.py#L326) with an additional linear layer on top whose weights are tied to the input embedding layer, so it performs the inverse of the embedding lookup: it turns GPT-2's output hidden states into a vector of logits over the vocabulary.
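A minimal NumPy sketch of the weight-tied output layer described above, assuming toy sizes (a 5-token vocabulary and hidden size 3) instead of GPT-2's real ones. The same matrix `E` is used both to look up token vectors and to project hidden states back onto the vocabulary (`h @ E.T`). The names `embed` and `lm_head` are illustrative, not the library's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden_size = 5, 3  # toy sizes; GPT-2 small uses 50257 and 768
E = rng.normal(size=(vocab_size, hidden_size))  # the single shared embedding matrix


def embed(token_ids):
    """Embedding lookup: one row of E per token ID."""
    return E[token_ids]


def lm_head(hidden_states):
    """Tied output layer: dot each hidden state with every embedding row to get logits."""
    return hidden_states @ E.T


token_ids = np.array([1, 4, 2])
hidden = embed(token_ids)  # stand-in for the transformer blocks' output hidden states
logits = lm_head(hidden)
print(logits.shape)  # one logit per vocabulary entry at each of the 3 positions
```

In the real model the hidden states come out of the transformer blocks, not straight from the embedding. With a loaded `GPT2LMHeadModel`, `model.lm_head.weight is model.transformer.wte.weight` should evaluate to `True`, which confirms the tie.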

[GPT2Tokenizer](https://github.com/huggingface/transformers/blob/master/transformers/tokenization_gpt2.py#L106) is a byte-level byte-pair encoding (BPE) tokenizer that transforms input text into the tokens GPT-2 was trained on.
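To illustrate the core idea of byte-pair encoding, here is a toy sketch of the BPE *training* loop on a tiny hypothetical word list: count adjacent symbol pairs, merge the most frequent pair into a new symbol, and repeat. GPT-2's actual tokenizer works on bytes, applies a pre-tokenization regex and uses a fixed list of about 50k learned merges, so this is only an illustration; the helper names are hypothetical.

```python
from collections import Counter


def most_frequent_pair(words):
    """Count adjacent symbol pairs over all words (each word is a tuple of symbols)."""
    pairs = Counter()
    for word, freq in words.items():
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += freq
    # Ties are broken by first occurrence
    return max(pairs, key=pairs.get) if pairs else None


def merge_pair(words, pair):
    """Replace every occurrence of `pair` inside each word with one merged symbol."""
    merged = {}
    for word, freq in words.items():
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
                out.append(word[i] + word[i + 1])
                i += 2
            else:
                out.append(word[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged


def learn_bpe(corpus, num_merges):
    # Start from single characters, keeping a frequency count per distinct word
    words = dict(Counter(tuple(w) for w in corpus.split()))
    merges = []
    for _ in range(num_merges):
        pair = most_frequent_pair(words)
        if pair is None:
            break
        merges.append(pair)
        words = merge_pair(words, pair)
    return merges, words


merges, words = learn_bpe("low low low lower lowest", num_merges=2)
print(merges)
print(words)
```

After two merges, the frequent word "low" has become a single symbol, while the rarer suffixes are still split into characters.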

In [3]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = model.to(device)
# sys.getsizeof() is the shallow size of the Python wrapper object only, not of the weights
print("model has %s Bytes" % sys.getsizeof(model))

model has 56 Bytes
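To estimate what the weights actually occupy, you can count the parameters. With a loaded PyTorch model, `sum(p.numel() for p in model.parameters())` does this directly. The sketch below instead derives the count by hand from GPT-2 small's configuration (50257-token vocabulary, 1024 positions, width 768, 12 transformer blocks), assuming the LM head is tied to the token embeddings (so it adds no parameters) and fp32 weights at 4 bytes each. The function name `gpt2_param_count` is hypothetical.

```python
def gpt2_param_count(vocab_size=50257, n_positions=1024, n_embd=768, n_layer=12):
    """Count GPT-2 parameters from its configuration (LM head tied to the token embeddings)."""
    token_emb = vocab_size * n_embd
    pos_emb = n_positions * n_embd
    layer_norm = 2 * n_embd  # scale + shift
    # Attention: fused query/key/value projection plus the output projection (weights + biases)
    attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)
    # MLP: expand to 4 * n_embd, then project back (weights + biases)
    mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    per_block = 2 * layer_norm + attn + mlp  # two layer norms per block
    return token_emb + pos_emb + n_layer * per_block + layer_norm  # + final layer norm


n_params = gpt2_param_count()
bytes_fp32 = n_params * 4  # 4 bytes per float32 weight
print(f"{n_params:,} parameters ~ {bytes_fp32 / 1e6:.0f} MB in fp32")
```

This gives roughly half a gigabyte for the small model. Passing `n_embd=1024, n_layer=24` gives the corresponding count for `gpt2-medium`.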


In [4]:
def choose_from_top(probs: np.ndarray, n: int = 5) -> int:
    """Top-n sampling: keep the n most probable token IDs, renormalize their probabilities,
    and draw one token ID at random from that reduced distribution."""
    ind = np.argpartition(probs, -n)[-n:]  # indices of the n largest probabilities (unordered)
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob)  # renormalize so the n probabilities sum to 1
    choice = np.random.choice(n, 1, p=top_prob)  # position within `ind`
    token_id = ind[choice][0]
    return int(token_id)
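To make the effect of `choose_from_top` concrete, the hypothetical helper below returns the renormalized distribution that top-n sampling draws from, instead of a single random draw. Every token outside the top n gets probability 0, and the surviving probabilities keep their relative proportions.

```python
import numpy as np


def top_k_distribution(probs: np.ndarray, n: int) -> np.ndarray:
    """Keep the n largest probabilities, zero out the rest, and renormalize to sum to 1."""
    keep = np.argpartition(probs, -n)[-n:]  # indices of the n largest entries (unordered)
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()


probs = np.array([0.05, 0.5, 0.1, 0.2, 0.15])
print(top_k_distribution(probs, n=2))  # all mass moves onto tokens 1 and 3, in a 5:2 ratio
```

A small `n` (such as the `n=3` used later for joke generation) keeps the output close to the model's most likely continuations, while a larger `n` gives more varied but less coherent text.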

### Text generation

At each prediction step, the GPT-2 model needs the entire preceding sequence to predict the next token. The function below tokenizes the starting text and then, in a loop, predicts one new token per step and appends it to the sequence, which is fed back into the model at the next step. Finally, the token list is decoded back into text.
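The loop structure can be sketched without GPT-2 at all. Below, `toy_model` is a hypothetical stand-in that returns next-token logits computed from the *whole* prefix, and the next token is chosen greedily (argmax) so the output is deterministic; the notebook's function samples from the top 10 tokens instead.

```python
from typing import List

import numpy as np

VOCAB_SIZE = 6


def toy_model(token_ids: List[int]) -> np.ndarray:
    """Hypothetical stand-in for GPT-2: next-token logits that depend on every prefix token."""
    logits = np.zeros(VOCAB_SIZE)
    logits[(sum(token_ids) + 1) % VOCAB_SIZE] = 5.0
    return logits


def generate(prompt_ids: List[int], steps: int) -> List[int]:
    ids = list(prompt_ids)
    for _ in range(steps):
        logits = toy_model(ids)  # the full sequence is fed back in at every step
        next_id = int(np.argmax(logits))  # greedy choice; the notebook uses top-n sampling
        ids.append(next_id)  # the new token becomes part of the next step's input
    return ids


print(generate([1, 2], steps=4))
```

Because the whole sequence is re-encoded at every step, the cost of each step grows with the sequence length.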

In [5]:
def generate_some_text(input_str, text_len=250):
    """Extend `input_str` by `text_len` tokens using top-10 sampling and print the result."""
    cur_ids = torch.tensor(tokenizer.encode(input_str)).unsqueeze(0).long().to(device)
    model.eval()

    with torch.no_grad():
        for i in range(text_len):
            # No labels are passed: only the logits are needed, so no loss is computed
            logits = model(cur_ids)[0]

            # Take the first (only) batch element and the prediction for the last position
            softmax_logits = torch.softmax(logits[0, -1], dim=0)

            # Randomly choose the next token from the top-10 tokens, weighted by their probabilities
            next_token_id = choose_from_top(softmax_logits.to("cpu").numpy(), n=10)

            # Append the new token; the whole sequence is fed back in at the next step
            next_token = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
            cur_ids = torch.cat([cur_ids, next_token], dim=1)

        output_list = list(cur_ids.squeeze().to("cpu").numpy())
        output_text = tokenizer.decode(output_list)
        print(output_text)

## Generating the text

I give three different sentence beginnings to GPT-2 and let it generate the rest:


***1. The Matrix is everywhere. It is all around us. Even now, in this very room. You can see it when you look out your window or when you turn on your television. You can feel it when you go to work… when you go to church… when you pay your taxes. It is the world that has been pulled over your eyes to blind you from the truth…***

***2. Artificial general intelligence is…***

***3. The Godfather: “I’m going to make him an offer he can’t refuse.”…***

In [6]:
generate_some_text(
    "The Matrix is everywhere. It is all around us. Even now, in this very room. You can see it when you look out your window or when you turn on your television. You can feel it when you go to work... when you go to church... when you pay your taxes. It is the world that has been pulled over your eyes to blind you from the truth. "
)

The Matrix is everywhere. It is all around us. Even now, in this very room. You can see it when you look out your window or when you turn on your television. You can feel it when you go to work... when you go to church... when you pay your taxes. It is the world that has been pulled over your eyes to blind you from the truth.  You must learn to understand it."
 "I'm not saying you should be scared to die. But I'm saying that you should learn to understand it."   The Matrix is all around us. In our heads, we see everything. The Matrix is everywhere. It is all around us. Even now, in this very room. You can see it when you look out your window or when you turn on your television. You can feel it when you go to work... when you go to church... when your wages are paid. It is the world that has been pulled over your eyes to blind you from the truth.  You must learn to understand it." "I'm not saying you should be scared to die. But I'm saying that you should learn to understand it."
"It is

In [7]:
generate_some_text(" Artificial general intelligence is ")

 Artificial general intelligence is vernacular as it is. It's a great tool for the development of new ideas that might otherwise only be explored by people who know nothing about science or computer science. In short, it's great for developing new ideas that will help us better understand the world around us.

There is a great deal of work that needs to be done to make AI the most useful tool we can be. There is a lot more that needs to be done, but at the moment there are two main things we do:

We need to be able to use artificial intelligence to solve problems and solve problems that don't exist,

We need to be able to use artificial intelligence to help solve problems that are not actually human in nature, and

We need to be able to use artificial intelligence to solve problems that do not exist and do not have any human involvement whatsoever.

So let me give an example. Let's say we want to solve a problem. We have two main goals. One is to understand the physics of the problem. 

In [8]:
generate_some_text(" The Godfather: \"I'm going to make him an offer he can't refuse.\" ")

 The Godfather: "I'm going to make him an offer he can't refuse."  And it was just such a deal.  He offered to pay his $5,000 to get him to take a photo of himself on his Facebook page and post it.  He also offered a $10,000 reward for those who could get him to write down the name of his mother and get her to take the picture.  So I was pretty excited at that point.  I think the reason why I was so excited was because I thought there could be a way around it because I think the way people were going to react when I posted was that they would say "you know what, you're really a man!" It's like I thought "I'm just being nice and I'm not being mean!" But that's not how the world works.  It's not like the world works for you. I've had people say, "You know, I want to be a guy and that's why I'm doing that." But I think it was just like it was the only answer that I'd ever get from people that said, "Well I want your money for that."  And it's just like you've got to figure things out.  Yo

In [9]:
"""
Jokes data set
"""
import csv
import os

from torch.utils.data import Dataset, DataLoader


class JokesDataset(Dataset):
    """Reads shortjokes.csv and wraps each joke as "JOKE:<text><|endoftext|>"."""

    def __init__(self, jokes_dataset_path: str):
        super().__init__()
        short_jokes_path = os.path.join(jokes_dataset_path, "shortjokes.csv")
        self.joke_list = []
        self.end_of_text_token = "<|endoftext|>"

        with open(short_jokes_path, newline="", encoding="utf-8") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            next(csv_reader, None)  # skip the CSV header row
            for row in csv_reader:
                # row[0] is the joke ID, row[1] the joke text
                joke_str = f"JOKE:{row[1]}{self.end_of_text_token}"
                self.joke_list.append(joke_str)

    def __len__(self):
        return len(self.joke_list)

    def __getitem__(self, item):
        return self.joke_list[item]


jokes_dataset_path = "/content/gdrive/My Drive/xheng/data/jokes_data/"  # jokes dataset's path

dataset = JokesDataset(jokes_dataset_path=jokes_dataset_path)
joke_loader = DataLoader(dataset, batch_size=1, shuffle=True)
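The training cell below packs several tokenized jokes into one sequence of at most `MAX_SEQ_LEN` tokens before each forward pass, so the GPU processes fewer, fuller sequences. A pure-Python sketch of that packing rule, assuming each joke is already a list of token IDs (the function name `pack_sequences` is hypothetical): skip any joke longer than the limit, append jokes while they fit, and start a new sequence with the joke that does not fit. Unlike the notebook's loop, this sketch also returns the last, partially filled sequence.

```python
from typing import List


def pack_sequences(jokes: List[List[int]], max_len: int) -> List[List[int]]:
    """Greedily concatenate token-ID lists into sequences of at most max_len tokens."""
    packed, current = [], []
    for joke in jokes:
        if len(joke) > max_len:
            continue  # too long to fit in any sequence: skip it
        if current and len(current) + len(joke) > max_len:
            packed.append(current)  # the current sequence is full: emit it
            current = []
        current = current + joke  # the joke starts or extends the current sequence
    if current:
        packed.append(current)  # flush the remainder
    return packed


print(pack_sequences([[1, 2], [3, 4, 5], [6] * 9, [7, 8], [9]], max_len=5))
```

Since every joke ends with `<|endoftext|>`, the model can still tell where one joke stops and the next begins inside a packed sequence.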

In [None]:
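# transformers.AdamW is deprecated; torch.optim.AdamW is the standard replacement
# (weight_decay=0.0 and eps=1e-6 match the defaults of the transformers version)
from torch.optim import AdamW
from transformers import get_constant_schedule_with_warmup


BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-5
WARMUP_STEPS = 5000
MAX_SEQ_LEN = 400

# Train the model, save the weights after each epoch, and then generate jokes with each saved
# version of the weights to see which performs best.

model = model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)
# Linear warmup, then a constant learning rate. get_linear_schedule_with_warmup would need the
# real total number of steps: with num_training_steps=-1 it sets the learning rate to 0 once warmup ends.
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS)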

proc_seq_count = 0
sum_loss = 0.0
batch_count = 0
tmp_jokes_tens = None

models_folder = jokes_dataset_path + "trained_models"
if not os.path.exists(models_folder):
    os.mkdir(models_folder)

for epoch in range(EPOCHS):
    print(f"EPOCH: {epoch} started")
    for idx, joke in enumerate(joke_loader):
        # print(f"Starting with idx: {idx}, joke: {joke}")

        # Encode the joke; several jokes are packed into one sequence of up to MAX_SEQ_LEN tokens
        joke_tens = torch.tensor(tokenizer.encode(joke[0])).unsqueeze(0).to(device)

        # Skip the sample if it alone is longer than MAX_SEQ_LEN
        if joke_tens.size()[1] > MAX_SEQ_LEN:
            continue

        # The first joke of a new packed sequence
        if not torch.is_tensor(tmp_jokes_tens):
            tmp_jokes_tens = joke_tens
            continue
        else:
            # The next joke does not fit: process the current sequence and keep this joke
            # as the start of the next one
            if tmp_jokes_tens.size()[1] + joke_tens.size()[1] > MAX_SEQ_LEN:
                work_jokes_tens = tmp_jokes_tens
                tmp_jokes_tens = joke_tens
            else:
                # Append the whole joke (GPT-2's tokenizer adds no BOS token, so nothing to strip)
                tmp_jokes_tens = torch.cat([tmp_jokes_tens, joke_tens], dim=1)
                continue

        # Sequence ready: run it through the model (labels = inputs for language modeling)
        outputs = model(work_jokes_tens, labels=work_jokes_tens)
        loss = outputs[0]
        loss.backward()
        sum_loss += loss.item()

        # Gradient accumulation: step the optimizer once every BATCH_SIZE sequences
        proc_seq_count += 1
        if proc_seq_count == BATCH_SIZE:
            proc_seq_count = 0
            batch_count += 1
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Log the summed loss every 10 optimizer steps (10 * BATCH_SIZE sequences)
        if batch_count == 10:
            print(f"batch_count = {batch_count}, sum_loss = {sum_loss}")
            batch_count, sum_loss = 0, 0.0

    print("Storing the model after each epoch to compare the performance of them")
    torch.save(model.state_dict(), os.path.join(models_folder, f"gpt2_small_joker_{epoch}.pt"))

EPOCH: 0 started
batch_count = 10, sum_loss = 726.8229370117188
batch_count = 10, sum_loss = 728.3819580078125
batch_count = 10, sum_loss = 719.9417724609375
batch_count = 10, sum_loss = 720.7799682617188
batch_count = 10, sum_loss = 721.210693359375
batch_count = 10, sum_loss = 717.4074096679688
batch_count = 10, sum_loss = 710.0521240234375
batch_count = 10, sum_loss = 711.2081298828125
batch_count = 10, sum_loss = 705.300537109375
batch_count = 10, sum_loss = 698.4765014648438
batch_count = 10, sum_loss = 700.3212890625
batch_count = 10, sum_loss = 691.4007568359375
batch_count = 10, sum_loss = 684.748291015625
batch_count = 10, sum_loss = 682.7960205078125
batch_count = 10, sum_loss = 680.3382568359375
batch_count = 10, sum_loss = 670.200927734375
batch_count = 10, sum_loss = 670.744140625
batch_count = 10, sum_loss = 662.0491943359375
batch_count = 10, sum_loss = 650.756103515625
batch_count = 10, sum_loss = 642.2711791992188
batch_count = 10, sum_loss = 630.345947265625
batch_cou
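Each logged `sum_loss` above is a sum over 10 optimizer steps × `BATCH_SIZE` = 160 packed sequences, and each term is the model's mean per-token cross-entropy for one sequence. The sketch below converts the first and last complete log lines into a mean loss and an approximate perplexity (the exponential of the mean loss). It is approximate because it averages per-sequence means rather than pooling all tokens; the helper name `summarize` is hypothetical.

```python
import math

SEQS_PER_LOG = 10 * 16  # logging threshold on batch_count x BATCH_SIZE in the training cell


def summarize(sum_loss: float):
    """Return (mean per-sequence loss, approximate perplexity) for one logged sum_loss."""
    mean_loss = sum_loss / SEQS_PER_LOG
    return mean_loss, math.exp(mean_loss)


for logged in (726.8229370117188, 630.345947265625):  # first and last complete log lines above
    mean_loss, ppl = summarize(logged)
    print(f"sum_loss={logged:.1f} -> mean loss {mean_loss:.2f}, perplexity ~{ppl:.0f}")
```

Read this way, the log shows the mean loss falling from about 4.54 to about 3.94 nats per token over the first part of epoch 0.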

In [None]:
"""
Generating the jokes
"""
MODEL_EPOCH = 4
model_path = os.path.join(models_folder, f"gpt2_small_joker_{MODEL_EPOCH}.pt")
model.load_state_dict(torch.load(model_path, map_location=device))

jokes_output_file_path = jokes_dataset_path + f"generated_{MODEL_EPOCH}.jokes"

model.eval()
if os.path.exists(jokes_output_file_path):
    os.remove(jokes_output_file_path)

joke_num = 0
with torch.no_grad():
    for joke_idx in range(1000):
        joke_finished = False
        cur_ids = torch.tensor(tokenizer.encode("JOKE:")).unsqueeze(0).to(device)

        for i in range(100):
            logits = model(cur_ids)[0]  # no labels: only the logits are needed
            # Take the first (only) batch element and the prediction for the last position
            softmax_logits = torch.softmax(logits[0, -1], dim=0)
            # Sample widely (top 20) for the first 3 tokens, then conservatively (top 3)
            n = 20 if i < 3 else 3
            # Randomly (from the top-n probability distribution) select the next token
            next_token_id = choose_from_top(softmax_logits.to("cpu").numpy(), n=n)
            # Append the new token to the running sequence
            next_token = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
            cur_ids = torch.cat([cur_ids, next_token], dim=1)

            # Stop once the model emits <|endoftext|>, which marks the end of a joke
            if next_token_id == tokenizer.eos_token_id:
                joke_finished = True
                break

        if joke_finished:
            joke_num = joke_num + 1
            output_list = list(cur_ids.squeeze().to("cpu").numpy())
            output_text = tokenizer.decode(output_list)
            with open(jokes_output_file_path, "a") as f:
                f.write(f"{output_text} \n\n")