# Decaying prompt control in text generation

This notebook demonstrates experimentally how the control signal of an instruction prompt decays as the distance between the prompt and the generated text increases. In the experiments, GPT-2 is first fine-tuned to generate story continuations that follow the sentiment given in the control prompt.

To fine-tune GPT-2 for story generation and to obtain the story beginnings, we use the WritingPrompts dataset [1]. The dataset was collected from Reddit's [WritingPrompts forum](https://www.reddit.com/r/WritingPrompts/), where users respond with stories to story prompts posted by other users. From this dataset, we use the user-written stories to fine-tune GPT-2 to generate stories according to the sentiment given in the instruction. More specifically, in the fine-tuning and experiment stages we use the following prompt: "Continue the story with <|sentiment|> sentiment: <|story|>", where <|sentiment|> is either "positive" or "negative" and <|story|> is a story beginning of varying length from the WritingPrompts dataset.
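The prompt template above can be filled in with a small helper. This is a minimal sketch: the function name `build_control_prompt` and the example story beginning are illustrative assumptions, not part of the notebook's code.

```python
def build_control_prompt(sentiment: str, story_start: str) -> str:
    """Fill the control prompt template with a sentiment and a story beginning."""
    # The template only allows the two sentiments named in the text.
    if sentiment not in ("positive", "negative"):
        raise ValueError(f"unsupported sentiment: {sentiment!r}")
    return f"Continue the story with {sentiment} sentiment: {story_start.strip()}"


# Hypothetical story beginning, used only to show the resulting string.
prompt = build_control_prompt("positive", "The rain had finally stopped. ")
print(prompt)
```

Varying the length of `story_start` while keeping the sentiment fixed is what changes the distance between the control prompt and the generated text.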

### Experiment setup
Text Generation Model: [GPT-2 Medium](https://huggingface.co/openai-community/gpt2-medium) (355M parameters)<br>
Sentiment Classifier Model: [Twitter-roBERTa-base for Sentiment Analysis](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment-latest) (125M parameters)<br>
Dataset: [WritingPrompts](https://github.com/facebookresearch/fairseq/blob/main/examples/stories/README.md)

## Experiments

We start by installing the required libraries. We use [Hugging Face](https://huggingface.co/) to download the pre-trained models, GPT-2 and the fine-tuned RoBERTa, and [PyTorch](https://pytorch.org/) to fine-tune GPT-2 to generate text according to the instructed sentiment.

In [3]:
%pip install torch transformers

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GenerationConfig, RobertaTokenizerFast, RobertaForSequenceClassification
from transformers import get_linear_schedule_with_warmup, AdamW
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import numpy as np

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

### Data pre-processing
After importing the libraries, we download the dataset and preprocess it for the fine-tuning and evaluation stages. To download the dataset, we follow the instructions in the [README.md](https://github.com/facebookresearch/fairseq/blob/main/examples/stories/README.md) accompanying Hierarchical Neural Story Generation, the paper that introduced the WritingPrompts dataset.

In [None]:
!curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf -

In the experiments, we use the training set to fine-tune the GPT-2 model for generating stories, the validation set to check for overfitting during training, and the test set to provide the initial context in the control-signal evaluation phase. As we are only interested in how well the generated story follows the sentiment given in the instruction at the start of the input, we use only the stories of the dataset (the target files) and not the writing prompts that inspired them (the source files).

In [2]:
def load_file(filename: str, encoding="utf-8") -> list[str]:
    with open(filename, "r", encoding=encoding) as f:
        return f.readlines()

In [3]:
train_data = load_file("writingPrompts/train.wp_target")
valid_data = load_file("writingPrompts/valid.wp_target")
test_data = load_file("writingPrompts/test.wp_target")

print(f"Train data: {len(train_data)}")
print(f"Valid data: {len(valid_data)}")
print(f"Test data: {len(test_data)}")

Train data: 272600
Valid data: 15620
Test data: 15138


In what follows, we use the [Hugging Face Tokenizer](https://huggingface.co/docs/transformers/en/main_classes/tokenizer) to tokenize the training, validation, and test datasets. As our language model is GPT-2 (Medium), we use `GPT2TokenizerFast`. Underneath the class abstractions, GPT-2 tokenizes its input text with the [BPE (byte pair encoding) algorithm](https://en.wikipedia.org/wiki/Byte_pair_encoding). For GPT-2, this encoding has a vocabulary of 50,257 tokens.

Before tokenizing, we replace the special `<newline>` marker in the datasets with the standard `\n` newline character. We also prepend an instruction telling the model to continue the story, chosen at random from a predefined instruction list to avoid overfitting to a single phrasing. At test time, this instruction will additionally carry the control prompt with the desired sentiment for the story.

In addition, we truncate every story, including its instruction prompt, to 512 tokens. The RoBERTa-based sentiment classifier can only classify sequences of up to 512 tokens (its encoding algorithm differs, but the ballpark is the same), so there is no need to fine-tune GPT-2 to generate longer stories. Sequences shorter than 512 tokens are padded with the pad token, which in the code below is set to GPT-2's `<|endoftext|>` token. We do not add `<|endoftext|>` or `<|startoftext|>` markers around the stories, because we do not concatenate multiple stories during training: each example in a batch is a single story with its instruction prompt prepended.
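To make the truncation and padding step concrete, here is a minimal pure-Python sketch of what happens to a single token sequence. It assumes right-side truncation and padding and reuses id 50256 (GPT-2's `<|endoftext|>`) as the pad id, matching how the tokenizer is loaded in this notebook. The helper `truncate_and_pad` and the toy token ids are illustrative; in the notebook the tokenizer call does this work.

```python
PAD_ID = 50256  # id of GPT-2's <|endoftext|> token, reused here for padding


def truncate_and_pad(token_ids: list[int], max_length: int, pad_id: int = PAD_ID):
    """Cut a token sequence to max_length, then right-pad it to exactly max_length.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens, 0 for padding.
    """
    kept = token_ids[:max_length]
    n_pad = max_length - len(kept)
    input_ids = kept + [pad_id] * n_pad
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask


# Toy ids and a tiny max_length, chosen only to keep the output readable.
short_ids, short_mask = truncate_and_pad([10, 11, 12], max_length=5)
long_ids, long_mask = truncate_and_pad(list(range(8)), max_length=5)
print(short_ids, short_mask)  # [10, 11, 12, 50256, 50256] [1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [0, 1, 2, 3, 4] [1, 1, 1, 1, 1]
```

The attention mask is what lets later stages tell real tokens from padding, which matters because the pad id is also a regular vocabulary id.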

We use a custom PyTorch `Dataset` class to construct the training and validation datasets, and the PyTorch `DataLoader` to provide the batching (batch_size=16) and shuffling (shuffle=True) required by mini-batch stochastic gradient descent. As the fine-tuning dataset is quite large (272,600 stories), we tokenize the stories only once they have been selected for the current mini-batch. The tokenization is therefore done in a custom collate function of the `DataLoader`, which tokenizes the whole batch in one go.

In [4]:
class StoryGenerationDataset(Dataset):
    def __init__(self, data: list[str]):
        self.data = data
        self.instructions = [
           "Continue the story:",
           "Keep the narrative going:",
           "Resume the tale:",
           "Carry on with the story:",
           "Proceed with the plot:",
           "Continue the narrative:",
           "Move the story forward:",
           "Keep telling the story:",
           "Follow through with the story:"
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        story = self.data[idx]
        story = story.replace("<newline>", "\n").strip()
        instruction = np.random.choice(self.instructions)
        return f"{instruction} {story}"

In [5]:
def collate_batch(batch: list[str], tokenizer: GPT2TokenizerFast, max_length: int = 512):
    encodings = tokenizer(text=batch, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt").to(torch_device)
    return encodings["input_ids"], encodings["attention_mask"]

In [27]:
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", pad_token="<|endoftext|>")

batch_size = 16

train_dataset = StoryGenerationDataset(train_data)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda x: collate_batch(x, gpt2_tokenizer, max_length=512)
)

valid_dataset = StoryGenerationDataset(valid_data)
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,  # order does not affect the average validation loss
    collate_fn=lambda x: collate_batch(x, gpt2_tokenizer, max_length=512)
)

### Fine-tuning

Now that the fine-tuning dataset and its data loaders are ready, we can write the training loop for the fine-tuning process. For this we instantiate the GPT-2 Medium model through Hugging Face, using the `openai-community/gpt2-medium` checkpoint. For the hyperparameters, we use a base learning rate of `0.001` that is warmed up linearly over `100` steps and then decreased linearly, and an epsilon of `1e-8`. In total we fine-tune for three epochs. Every 100 steps we print the training loss and generate a sample story text; after each epoch we print the average validation loss, which is expected to decrease as the model adapts to the story domain. As the optimizer, we use the AdamW implementation provided by Hugging Face.
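The shape of the learning-rate schedule can be sketched in plain Python. The function `linear_warmup_decay` below is an illustrative stand-in intended to mirror a linear warmup followed by a linear decay to zero; the library's `get_linear_schedule_with_warmup` remains the reference for the actual behavior. The step count assumes the data loader keeps the final partial batch.

```python
import math


def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Linear ramp from 0 to 1 during warmup.
        return step / max(1, warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


steps_per_epoch = math.ceil(272_600 / 16)  # 17038 batches, the last one partial
total_steps = steps_per_epoch * 3          # 51114 optimizer steps over three epochs

base_lr = 0.001
for step in (0, 50, 100, total_steps // 2, total_steps):
    lr = base_lr * linear_warmup_decay(step, warmup_steps=100, total_steps=total_steps)
    print(f"step {step:>6}: lr = {lr:.6f}")
```

With only 100 warmup steps out of roughly 51,000, the model spends almost all of training on the decaying part of the schedule.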

In [33]:
gpt2_model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-medium").to(torch_device)
gpt2_model.resize_token_embeddings(len(gpt2_tokenizer))

lr = 0.001
eps = 1e-8
num_epochs = 3
num_warmup = 100

optimizer = AdamW(gpt2_model.parameters(), lr=lr, eps=eps)
total_steps = len(train_dataloader) * num_epochs

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup, num_training_steps=total_steps)

generation_config = GenerationConfig(
    max_new_tokens=40,
    do_sample=True,
    top_k=40,
    pad_token_id=gpt2_tokenizer.pad_token_id,
)

sample_input_prompt = "Continue the story: The"
sample_input = gpt2_tokenizer(sample_input_prompt, return_tensors="pt").to(torch_device)



In [34]:
%%time

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    total_train_loss = 0
    gpt2_model.train()

    for step, batch in enumerate(train_dataloader):
        input_ids, attention_mask = batch
        # Exclude padding positions from the loss (label -100 is ignored by the LM loss)
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        optimizer.zero_grad()

        outputs = gpt2_model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_train_loss += loss.item()

        if step % 100 == 0:
            # `step` counts batches within the current epoch
            print(f"Step {step} of {len(train_dataloader)} - Loss: {loss.item()}")

            # Generate a short sample to monitor progress qualitatively
            gpt2_model.eval()

            sample_output = gpt2_model.generate(
                inputs=sample_input["input_ids"],
                attention_mask=sample_input["attention_mask"],
                generation_config=generation_config
            )
            sample_text = gpt2_tokenizer.decode(sample_output[0], skip_special_tokens=True)
            print(f"Sample output: {sample_text}")

            gpt2_model.train()

        loss.backward()
        optimizer.step()
        scheduler.step()

    print(f"Average training loss: {total_train_loss / len(train_dataloader)}")

    # Validation pass at the end of each epoch
    gpt2_model.eval()
    total_eval_loss = 0

    for batch in valid_dataloader:
        input_ids, attention_mask = batch
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        with torch.no_grad():
            outputs = gpt2_model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_eval_loss += loss.item()

    print(f"Average validation loss: {total_eval_loss / len(valid_dataloader)}")

print("Training completed successfully!")

Epoch 1/3

### Controlling the text generation

In [None]:
tokenizer = RobertaTokenizerFast.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
model = RobertaForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest").to(torch_device)
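The classifier returns one logit per sentiment class, and those logits have to be turned into a label before a generated continuation can be compared with the requested sentiment. The following is a minimal pure-Python sketch of that conversion, not the notebook's evaluation code. The label order `negative, neutral, positive` is an assumption that should be checked against `model.config.id2label`, the helper names are hypothetical, and the example logits are made up.

```python
import math

# Assumed label order for the classifier; verify against model.config.id2label.
LABELS = ("negative", "neutral", "positive")


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sentiment_from_logits(logits: list[float]):
    """Return the most likely label and the per-label probabilities."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], dict(zip(LABELS, probs))


# Made-up logits, standing in for the classifier output on one generated story.
label, probs = sentiment_from_logits([-1.2, 0.3, 2.1])
print(label, probs)
```

Because the classifier has a neutral class and the control prompt does not, an evaluation has to decide how to count neutral predictions; this sketch leaves that choice open.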

## References
[1] A. Fan, M. Lewis, and Y. Dauphin. Hierarchical neural story generation. In ACL 2018 - 56th Annual Meeting of the Association for Computational Linguistics, Proceedings of the Conference (Long Papers), volume 1, pages 889-898, 2018.