# Experimental implementation of Informal-GPT

This notebook fine-tunes the GPT-2 medium model to write in an informal style.

Will use

- [s-nlp/roberta-base-formality-ranker](https://huggingface.co/s-nlp/roberta-base-formality-ranker) as a reward model
- [gpt2-medium](https://huggingface.co/gpt2-medium) as an autoregressive text generator
- 60% [binhgiangnguyendanh/reddit_casual_conversation_for_alpaca_lora](https://huggingface.co/datasets/binhgiangnguyendanh/reddit_casual_conversation_for_alpaca_lora) and 40% [wikitext](https://huggingface.co/datasets/wikitext) as a text source

## Preparation & Configuration

In [1]:
import datetime
import os
import re
from typing import Any, Dict, List, Optional, Union

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd

# tqdm.pandas()

# Hugging Face transformers
from transformers import pipeline, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from datasets import load_dataset, Dataset, concatenate_datasets

# TRL library
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

# IDs of models to use
# Hugging Face Hub ID of the classifier whose scores are used as PPO rewards.
REWARD_MODEL_ID = "s-nlp/roberta-base-formality-ranker"
# Hugging Face Hub ID of the causal language model that gets fine-tuned.
TEXT_GENERATION_MODEL_ID = "gpt2-medium"

PPO Configurations

In [2]:
# Batch sizes: PPO batch for the text generator, inference batch for the reward classifier.
GPT_BATCH_SIZE = 330
BERT_BATCH_SIZE = 16

# os.environ["WANDB_PROJECT"] = "informal-gpt"
ppo_config = PPOConfig(
    model_name=TEXT_GENERATION_MODEL_ID,
    batch_size=GPT_BATCH_SIZE,
    # NOTE: This parameter is taken from OpenAI's paper: "Fine-Tuning Language Models from Human Preferences"
    learning_rate=1.41e-5,
    log_with="wandb",
)

# Run name: model description plus the current timestamp, with " " and ":" replaced
# so the string can be used as a directory name.
run_started_at = datetime.datetime.now()
datetime_str = str(run_started_at).replace(" ", "_").replace(":", "-")
checkpoint_dir_name = "gpt2-medium-and-bert-based-formality-ranker-" + datetime_str

# Keyword arguments passed to the reward pipeline on every call.
bert_kwargs = dict(
    return_all_scores=True,  # make BERT return scores for all classes, not only the most likely one
    function_to_apply="none",
    batch_size=BERT_BATCH_SIZE,
)

fatal: No names found, cannot describe anything.


connect to wandb

In [3]:
import wandb

# Start a Weights & Biases run in the "informal-gpt" project, named after the checkpoint directory.
wandb.init(project="informal-gpt", name=checkpoint_dir_name)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mt0d4[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Prepare the mixed **Reddit + wikitext** dataset

define function for dataset preparation

In [4]:
def build_dataset(
    dataset_size: int,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    input_min_text_length: int = 6,
    input_max_text_length: int = 10,
) -> Dataset:
    """Build the mixed Reddit + wikitext query dataset used for PPO.

    Steps:
      1. Load the Reddit casual-conversation dataset, drop consecutive
         near-duplicates, split every text at "@mention" tokens and keep only
         the pieces longer than 60 characters.
      2. Load wikitext-103, keep only rows free of a set of markup-like
         characters, drop blank rows and keep 0.66 wiki rows per Reddit row.
      3. Tokenize both parts, keeping only a short random-length prefix of
         each text as the generation prompt.
      4. Concatenate, shuffle and cap the result at `dataset_size` rows.

    Args:
        dataset_size: Upper bound on the number of rows returned. When the
            prepared data has fewer rows, all of them are returned.
        tokenizer: Tokenizer used to encode the texts and to decode the
            truncated prompt back to a string.
        input_min_text_length: `min_value` handed to `LengthSampler` for the
            prompt length (in tokens).
        input_max_text_length: `max_value` handed to `LengthSampler` for the
            prompt length (in tokens).

    Returns:
        A `datasets.Dataset` in torch format with the columns `text`,
        `input_ids` and `query`.
    """
    # ---- Reddit part -------------------------------------------------------
    ds_reddit = load_dataset("binhgiangnguyendanh/reddit_casual_conversation_for_alpaca_lora", split="train")
    collected_texts = []
    # list(...) so that `+` concatenates even if a column is not a plain list
    for text in list(ds_reddit["input"]) + list(ds_reddit["output"]):
        # skip a text whose first 20 characters repeat the previously kept one
        if not collected_texts or text[:20] != collected_texts[-1][:20]:
            collected_texts.append(text)

    # Split at mentions like "@user". Splitting already removes every match,
    # so no additional substitution is needed on the resulting pieces.
    mention_pattern = re.compile(r'@[^\s]+')
    sanitized_userwise_texts = []
    for whole_text in tqdm(collected_texts):
        sanitized_userwise_texts += [
            piece
            for piece in mention_pattern.split(whole_text)
            if len(piece) > 60 and not piece.isspace()
        ]
    ds_reddit = Dataset.from_dict({"text": sanitized_userwise_texts})
    ds_reddit = ds_reddit.shuffle()

    # ---- wikitext part -----------------------------------------------------
    excluded_chars = ["@", "(", ")", "<", ">", "/", ";", ":", "-", "="]
    ds_wiki = load_dataset("wikitext", "wikitext-103-v1", split="train")
    ds_wiki = ds_wiki.filter(
        # keep only rows that contain none of the excluded characters
        function=lambda item: all([ch not in item["text"] for ch in excluded_chars]),
        batched=False,
        num_proc=8,
    )
    ds_wiki = ds_wiki.shuffle()
    sanitized_wiki_texts = [
        text for text in ds_wiki["text"] if not (text.isspace() or text == "")
    ]
    # 0.66 wiki rows per Reddit row gives roughly a 60% / 40% Reddit / wiki mix
    wiki_row_count = int(len(sanitized_userwise_texts) * 0.66)
    ds_wiki = Dataset.from_dict({"text": sanitized_wiki_texts[:wiki_row_count]})

    # ---- tokenization ------------------------------------------------------
    # One sampler for all rows; each call draws a prompt length.
    get_input_size = LengthSampler(
        min_value=input_min_text_length,
        max_value=input_max_text_length,
    )

    def tokenize(item):
        """
            item["input_ids"] will be a sequence of first `get_input_size()` tokens converted from the input text.
            item["query"] will be a human-readable raw text decoded from item["input_ids"]
        """
        item["input_ids"] = tokenizer.encode(text=item["text"], truncation=True, max_length=1024)[:get_input_size()]
        item["query"] = tokenizer.decode(token_ids=item["input_ids"])
        return item

    ds_reddit = ds_reddit.map(
        function=tokenize,
        batched=False,
        num_proc=8,
    )
    ds_wiki = ds_wiki.map(
        function=tokenize,
        batched=False,
        num_proc=8,
    )

    ds = concatenate_datasets([ds_reddit, ds_wiki])
    ds = ds.shuffle()

    # Honor the requested size; taken after the shuffle so the Reddit / wiki
    # proportions are kept approximately.
    if dataset_size is not None and 0 <= dataset_size < len(ds):
        ds = ds.select(range(dataset_size))

    ds.set_format(type="torch")
    return ds

instantiate tokenizer and do preprocessing

In [5]:
# reference: imdb dataset's train split contains 24895 rows
DATASET_SIZE = 30000

# Tokenizer of the text generation model; use <EOS> token for padding
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Query dataset consumed by the PPO trainer
dataset = build_dataset(dataset_size=DATASET_SIZE, tokenizer=tokenizer)


def dataset_collator(data: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Collate a list of per-sample dicts into one dict of per-key lists.

    The keys are taken from the first sample; for each of them the values of
    all samples are gathered, in order, into a list.
    """
    first_sample = data[0]
    batch = {}
    for key in first_sample:
        batch[key] = [sample[key] for sample in data]
    return batch

100%|██████████| 11091/11091 [00:00<00:00, 144726.11it/s]


Map (num_proc=8):   0%|          | 0/9883 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/12000 [00:00<?, ? examples/s]

In [6]:
from pprint import pprint

# Dataset summary first, then the first five rows pretty-printed
print(dataset)
first_rows = dataset[:5]
pprint(first_rows)

Dataset({
    features: ['text', 'input_ids', 'query'],
    num_rows: 21883
})
{'input_ids': [tensor([30655,  2142, 26100,  7907, 25370,   254]),
               tensor([  383, 48665,   282, 32536,   286,   262,  7439, 41416]),
               tensor([ 9870,   502,    11, 38074, 15378,    11]),
               tensor([ 818, 1029, 1524,  314, 3214,  287,  351]),
               tensor([  383, 19955, 13399,  1061,   360,  3565,   705,  2878,   870])],
 'query': [' Yugoslav Partisans captured Š',
           ' The Patriarchal Cathedral of the Holy Ascension',
           ' Trust me, Aunt Lisa,',
           'In high school I fell in with',
           " The medieval texts follow Dares'structuring"],
 'text': [' Yugoslav Partisans captured Šibenik and Zadar by 3 November 1944 , '
          'but the war in the Adriatic continued until April 1945 . Allied '
          'destroyers never engaged large Kriegsmarine vessels in the Adriatic '
          'after November 1944 . Dwindling German naval assets 

## Load pre-trained text generation models

In [7]:
# Two models loaded from the same pre-trained weights: `model` is the one
# being trained, `ref_model` is passed to PPOTrainer as the reference model.
model, ref_model = (
    AutoModelForCausalLMWithValueHead.from_pretrained(
        pretrained_model_name_or_path=ppo_config.model_name,
    )
    for _ in range(2)
)

## Initialize PPOTrainer

In [8]:
# TODO: use learning rate scheduler
ppo_trainer = PPOTrainer(
    config=ppo_config,
    # models and tokenizer
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    # data: the trainer builds its dataloader from these
    dataset=dataset,
    data_collator=dataset_collator,
)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01667026193366231, max=1.0)…

In [9]:
# for batch in ppo_trainer.dataloader:
#     print(batch["input_ids"][0])
#     print(ppo_trainer.generate(batch["input_ids"][0]))
#     break

## Load BERT-based classification model to use as Reward Model

In [10]:
# Pick the device for the reward model: the accelerator's device by default;
# in a single-process run use GPU 0 when CUDA is available, otherwise the CPU
# (although this is not feasible).
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"

formality_pipe = pipeline(
    "text-classification",
    model=REWARD_MODEL_ID,
    device=device,
)

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Xformers is not installed correctly. If you want to use memorry_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.


## Optimize model

output directory setting

In [None]:
# checkpoints/<run name>; makedirs raises FileExistsError if the directory already exists
checkpoint_dirpath = os.path.join("checkpoints", checkpoint_dir_name)
os.makedirs(checkpoint_dirpath)

define generation settings

In [11]:
# Keyword arguments forwarded to `ppo_trainer.generate` in the training loop
generation_kwargs = dict(
    min_length=-1,
    top_k=50,  # NOTE: Here's room for optimization
    top_p=0.95,  # NOTE: Here's room for optimization
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

### Define training loop

In [None]:
# Each response gets its own token budget, drawn from this sampler.
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(
    min_value=output_min_length,
    max_value=output_max_length,
)

EPOCH = 2

batch_count = len(ppo_trainer.dataloader)
# Aim for about five checkpoints per epoch. Clamp to 1 so that a dataloader
# with fewer than five batches does not make the modulo below divide by zero.
save_interval = max(1, batch_count // 5)
for epoch in range(EPOCH):
    # tqdm wraps the dataloader (not the enumerate object) so it knows the total
    for batch_pos, batch in enumerate(
        tqdm(ppo_trainer.dataloader, total=batch_count, desc=f"epoch {epoch}")
    ):
        queries_tensor = batch["input_ids"]

        # Generate one completion per query.
        # TODO: batched generation with a shared `max_new_tokens` would be faster.
        completions_tensor = []
        for query in queries_tensor:
            generation_len = output_length_sampler()
            # per-call copy: the shared `generation_kwargs` dict is left untouched
            completion = ppo_trainer.generate(
                query,
                **{**generation_kwargs, "max_new_tokens": generation_len},
            )
            # `generate` returns prompt + continuation. Slice at the prompt
            # length instead of taking the last `generation_len` tokens: when
            # sampling stops early at <EOS> the continuation is shorter than
            # `generation_len`, and a tail slice would pull prompt tokens into
            # the response.
            completions_tensor.append(completion.squeeze(0)[query.shape[0]:])
        batch["response"] = [tokenizer.decode(c) for c in completions_tensor]

        # compute informality score on query + response
        completed_texts = [q + c for q, c in zip(batch["query"], batch["response"])]
        pipe_outputs = formality_pipe(completed_texts, **bert_kwargs)
        # output[0] is a score for "informal"
        # NOTE(review): relies on the reward model's label order - confirm against its config
        rewards = [torch.tensor(output[0]["score"]) for output in pipe_outputs]

        # execute PPO to tune parameter
        stats = ppo_trainer.step(
            queries=queries_tensor,
            responses=completions_tensor,
            scores=rewards,
        )
        ppo_trainer.log_stats(
            stats=stats,
            batch=batch,
            rewards=rewards,
        )

        # periodic checkpoint (also fires on the first batch of every epoch)
        if batch_pos % save_interval == 0:
            ppo_trainer.accelerator.save_model(
                model=model,
                save_directory=os.path.join(
                    checkpoint_dirpath, f"checkpoint-at-epoch-{epoch}-batch-{batch_pos}"
                )
            )

save the model at the end of training

In [None]:
# Save the final model and its tokenizer side by side in "checkpoint-last",
# the directory the prediction cell loads both of them from.
# `save_pretrained` is used for the model so that its config is written next
# to the weights and the directory can be reloaded with `from_pretrained`.
# The path does not depend on the training-loop variables, so this cell also
# works when the loop was interrupted.
last_checkpoint_dirpath = os.path.join(checkpoint_dirpath, "checkpoint-last")
if ppo_trainer.accelerator.is_main_process:  # write once in multi-process runs
    model.save_pretrained(last_checkpoint_dirpath)
    tokenizer.save_pretrained(save_directory=last_checkpoint_dirpath)

## try prediction

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead

# "checkpoint-last" directory of a finished training run; point this at the run to evaluate.
CHECKPOINT_LAST_DIR = "checkpoints/gpt2-medium-and-bert-based-formality-ranker-2023-07-20_01-53-49.109389/checkpoint-last"

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    pretrained_model_name_or_path=CHECKPOINT_LAST_DIR
)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=CHECKPOINT_LAST_DIR
)

gpt_generation_pipe = pipeline(
    task="text-generation",
    # hand the pipeline the wrapped language model rather than the value-head
    # wrapper: the pipeline is documented to take a PreTrainedModel
    model=model.pretrained_model,
    tokenizer=tokenizer,
    do_sample=True,  # top_p / top_k only take effect when sampling
    top_p=0.95,
    top_k=50,
    max_length=50,
)

In [None]:
# Generate a continuation of a short prompt with the model loaded from the checkpoint
gpt_generation_pipe("He is ")