# Alligned Instruction Finetuned T5 Model with human values [Lower Toxicity] using RLHF + PEFT

1. Using hate speech reward model by Facebook
2. Using PPO to finetune and reduce toxicity

In [3]:
# #install libs

# %pip install --upgrade pip
# %pip install --disable-pip-version-check \
#     torch==1.13.1 \
#     torchdata==0.5.1 --quiet

# %pip install \
#     transformers==4.27.2 \
#     datasets==2.11.0 \
#     evaluate==0.4.0 \
#     rouge_score==0.1.2 \
#     loralib==0.1.1 \
#     peft==0.3.0 --quiet

# %pip install trl==0.4.4 --quiet  # RLHF lib provided by HF

Note: you may need to restart the kernel to use updated packages.


In [1]:
# import libs
from datasets import load_dataset
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, AutoModelForSequenceClassification
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import time
import evaluate
import pandas as pd
import numpy as np

TIME=str(int(time.time()))

from tqdm import tqdm
tqdm.pandas()

  from .autonotebook import tqdm as notebook_tqdm


In [89]:
torch.manual_seed(42)

<torch._C.Generator at 0x7faa29d665f0>

In [2]:
#add wandb logging
# %pip install wandb
import wandb
wandb.init(project="T5-RLHF-peft-finetuning")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshubsoni[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
model_name = "google/flan-t5-base"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)
dataset_original


Downloading readme: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.65k/4.65k [00:00<00:00, 8.01MB/s]


Downloading and preparing dataset csv/knkarthick--dialogsum to /home/azureuser/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-cd36827d3490488d/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    | 0/3 [00:00<?, ?it/s]

Downloading data: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11.3M/11.3M [00:00<00:00, 22.5MB/s]


Dataset csv downloaded and prepared to /home/azureuser/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-cd36827d3490488d/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 199.55it/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
})

In [4]:
# utility functions

def print_number_of_trainable_parameters(model):
    trainable_params =0
    all_model_params =0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_params+=param.numel()
    return f"trainable model parameters: {trainable_params}\nall model parameters: {all_model_params} \n. percentage trainable parameters: {trainable_params*(100)/all_model_params:.2f}%"

Load and Build dataset

In [6]:
def build_dataset(model_name, dataset_name, input_min_length, input_max_length):
    
    dataset = load_dataset(dataset_name, split="train")
    

    #use only dialogues of fixed length
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_length and len(x["dialogue"]) <= input_max_length, batched= False)

    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

    def tokenize(sample):

        prompt = f"""
Summarize the following conversation.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample
    
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")

    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed =42)

    return dataset_splits

dataset = build_dataset(model_name=model_name, dataset_name=huggingface_dataset_name, input_max_length=1000, input_min_length=200)
dataset

Found cached dataset csv (/home/azureuser/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-cd36827d3490488d/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
Filter:   0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})

Load the PEFT model trained earlier

In [7]:
peft_model_path = './peft-dialogue-summary-checkpoint-1695920559'

lora_config = LoraConfig(
    r=8,
    lora_alpha=128,
    target_modules=["q","v"],
    lora_dropout=0.05,
    bias = "none",
    task_type= TaskType.SEQ_2_SEQ_LM,)

original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype = torch.bfloat16)

peft_model = PeftModel.from_pretrained(
        original_model,
        peft_model_path,
        lora_config = lora_config,
        torch_dtype = torch.bfloat16,
        device_map ="auto",
        is_trainable=True
)


print_number_of_trainable_parameters(peft_model)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


'trainable model parameters: 884736\nall model parameters: 248462592 \n. percentage trainable parameters: 0.36%'

In [8]:
# prepare model for PPO training
# value head addition convert a output token to scaler values needed for PPO training

ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                            torch_dtype= torch.bfloat16,
                                                            is_trainable=True)

print_number_of_trainable_parameters(ppo_model)
# print(ppo_model.v_head)                                    

'trainable model parameters: 885505\nall model parameters: 248463361 \n. percentage trainable parameters: 0.36%'

In [9]:
# create a reference model berore PPO training to fix reward hacking happening during RLHF
ref_model = create_reference_model(ppo_model)
print_number_of_trainable_parameters(ref_model) # reference model is not trainable

'trainable model parameters: 0\nall model parameters: 248463361 \n. percentage trainable parameters: 0.00%'

Prepare Reward Model

Using Facebook's roberta based hate speech model for reward model

In [87]:
toxicity_model_name= "DaNLP/da-electra-hatespeech-detection"
# toxicity_model_name = "nicholasKluge/ToxicityModel"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map ="cuda")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name)
print(toxicity_model.config.id2label)

{0: 'not offensive', 1: 'offensive'}


In [68]:
# how to use reward model
toxicity_model = toxicity_model.to("cuda")
non_toxic_text = "love you"
print(non_toxic_text)
toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids.to("cuda")
logits = toxicity_model(input_ids = toxicity_input_ids).logits
prob = logits.softmax(dim=-1).tolist()[0]
print(f"probs  [not offensive, offensive]: {prob}")


love you
probs  [not offensive, offensive]: [0.9983037710189819, 0.0016961833462119102]


In [88]:
#use a HF inference pipeline for reward getting

device = 0 if torch.cuda.is_available() else "cpu"

reward_pipe = pipeline("sentiment-analysis",
                        model = toxicity_model,
                        tokenizer=toxicity_tokenizer,
                        device =0)

reward_logits_kwargs = {
    "tok_k": None, # return all scores
    "function_to_apply": "none",
    "batch_size":16
}

reward_probs_kwargs = {
    "tok_k": None, # return all scores
    "function_to_apply": "softmax",
    "batch_size":16
}

print(f"reward pipe output for {non_toxic_text}, {reward_pipe(non_toxic_text,reward_probs_kwargs)}")

Ignoring args : ({'tok_k': None, 'function_to_apply': 'softmax', 'batch_size': 16},)


reward pipe output for love you, [{'label': 'not offensive', 'score': 0.9983037710189819}]


use evaluate lib for toxicity measurement for the model, it directly evaluates the model on toxicity


In [71]:
toxicity_model.config.label2id

{'not offensive': 0, 'offensive': 1}

In [82]:

non_toxic_text = "love you"
toxicity_evaluator = evaluate.load("toxicity", 'DaNLP/da-electra-hatespeech-detection', module_type="measurement")
toxicity_score = toxicity_evaluator.compute(predictions=[non_toxic_text], toxic_label = 'offensive')
toxicity_score

{'toxicity': [0.0016961980145424604]}

In [85]:
# deffine a function to calculate the mean toxicity score for a model

def evaluate_toxicity(model,
                    toxicity_evaluator,
                    dataset,
                    tokenizer,
                    num_samples):
    
    max_new_tokens = 200

    # dataset = dataset.filter(lambda x: x["label"] == 1)

    toxicities = []
    input_text = []

    for i, sample in tqdm(enumerate(dataset)):
        input_text = sample["query"]

        if i> num_samples:
            break

        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

        generation_config = GenerationConfig(max_new_tokens = max_new_tokens,
                                             top_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)

        response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)

        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)

        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " "+ generated_text)], toxic_label = 'offensive')                                                       

        toxicities.extend(toxicity_score["toxicity"])

    
    mean = np.mean(toxicities)
    std = np.std(toxicities)

    return mean, std


In [99]:
tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = "auto")
ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test")
mean_toxicity_before_ppo, std_toxicity_before_ppo = evaluate_toxicity(model = ref_model,
                                                                      toxicity_evaluator = toxicity_evaluator,
                                                                      tokenizer=toxicity_tokenizer,
                                                                      dataset =dataset["test"],
                                                                      num_samples =10)

print(f"mean, std toxicity before ppo {mean_toxicity_before_ppo, std_toxicity_before_ppo}")                                                                      

Found cached dataset wiki_toxic (/home/azureuser/.cache/huggingface/datasets/OxAISH-AL-LLM___wiki_toxic/default/1.0.0/09a67129f85f67f22107b0190f7c32050ef0dce44afeedc6e3e0ab7ab3bd709c)
11it [01:19,  7.26s/it]

mean, std toxicity before ppo (0.18388001654635777, 0.26625992794405157)





Perform RLHF  to detoxify the summaries

In [100]:
# initializer the PPO trainer

learning_rate = 5.41e-5
max_ppo_epochs = 1
mini_batch_size = 4
batch_size = 16

config = PPOConfig(
    model_name = model_name,
    learning_rate = learning_rate,
    ppo_epochs= max_ppo_epochs,
    mini_batch_size= mini_batch_size,
    batch_size = batch_size
)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1":"value1", "key2":"value2"},{"key1":"value3", "key2":"value4"}]
collator(test_data)

{'key1': ['value1', 'value3'], 'key2': ['value2', 'value4']}

In [92]:
ppo_trainer = PPOTrainer(config=config,
                         model=ppo_model,
                         ref_model=ref_model,
                            tokenizer=tokenizer,
                            dataset=dataset["train"],
                            data_collator=collator)

Obserbe the RL metrics

1. objective/kl
2. ppo/returns/mean
3. ppo/policy/advnatages_mean

In [102]:
output_min_length = 100
output_max_length = 400

output_length_sampler = LengthSampler(output_min_length,
                                      output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
}

reward_kwargs = {
    "top_k": None, # return all scores
    "function_to_apply": "softmax",
    "batch_size":16
}

max_ppo_steps = 30

for steps, batch in tqdm(enumerate(ppo_trainer.dataloader)):

    if steps > max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"]
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor,
                                                    **generation_kwargs)

        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    batch["response"] = [tokenizer.decode(summary_tensor, skip_special_tokens=True) for summary_tensor in summary_tensors]

    query_response_pairs = [ q+r for q,r in zip(batch["query"], batch["response"])]

    rewards = reward_pipe(query_response_pairs, **reward_kwargs)
    reward_tensors = [torch.tensor(reward[0]["score"]) for reward in rewards]
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(79)))





1it [00:18, 18.02s/it]

objective/kl: 15.268651962280273
ppo/returns/mean: -0.6054368019104004
ppo/policy/advantages_mean: -3.824568217680735e-09
------------------------------------------------------------------------------


2it [00:31, 15.40s/it]

objective/kl: 14.318596839904785
ppo/returns/mean: -0.5726776123046875
ppo/policy/advantages_mean: -1.096756996332715e-08
------------------------------------------------------------------------------


3it [00:49, 16.57s/it]

objective/kl: 18.28998565673828
ppo/returns/mean: -0.7864238023757935
ppo/policy/advantages_mean: -2.9557245539990618e-09
------------------------------------------------------------------------------


4it [01:05, 16.49s/it]

objective/kl: 19.34527015686035
ppo/returns/mean: -0.9329463839530945
ppo/policy/advantages_mean: -2.053383951761134e-08
------------------------------------------------------------------------------


5it [01:20, 15.82s/it]

objective/kl: 12.339771270751953
ppo/returns/mean: -0.4130007028579712
ppo/policy/advantages_mean: -1.1378542552620274e-08
------------------------------------------------------------------------------


6it [01:35, 15.68s/it]

objective/kl: 12.884422302246094
ppo/returns/mean: -0.44335299730300903
ppo/policy/advantages_mean: 1.3281545641063985e-08
------------------------------------------------------------------------------


7it [01:51, 15.47s/it]

objective/kl: 16.919139862060547
ppo/returns/mean: -0.9352494478225708
ppo/policy/advantages_mean: 2.514263464092892e-09
------------------------------------------------------------------------------


8it [02:06, 15.55s/it]

objective/kl: 15.680459022521973
ppo/returns/mean: -0.7123470306396484
ppo/policy/advantages_mean: -2.0349098406313715e-08
------------------------------------------------------------------------------


9it [02:21, 15.21s/it]

objective/kl: 12.17045783996582
ppo/returns/mean: -0.4999503791332245
ppo/policy/advantages_mean: -1.4291950733991143e-08
------------------------------------------------------------------------------


10it [02:34, 14.48s/it]

objective/kl: 11.800167083740234
ppo/returns/mean: -0.4948999285697937
ppo/policy/advantages_mean: 2.8503697180326526e-09
------------------------------------------------------------------------------


11it [02:46, 13.80s/it]

objective/kl: 11.169166564941406
ppo/returns/mean: -0.5076763033866882
ppo/policy/advantages_mean: 4.745400516981135e-09
------------------------------------------------------------------------------


12it [03:00, 13.85s/it]

objective/kl: 15.693375587463379
ppo/returns/mean: -0.7878344058990479
ppo/policy/advantages_mean: 7.784768385477037e-09
------------------------------------------------------------------------------


13it [03:13, 13.65s/it]

objective/kl: 13.772385597229004
ppo/returns/mean: -0.5478227138519287
ppo/policy/advantages_mean: 8.02979105429813e-09
------------------------------------------------------------------------------


14it [03:25, 13.04s/it]

objective/kl: 7.439396858215332
ppo/returns/mean: -0.10541556775569916
ppo/policy/advantages_mean: 1.1171917613239657e-08
------------------------------------------------------------------------------


15it [03:37, 12.96s/it]

objective/kl: 14.839390754699707
ppo/returns/mean: -0.6161832809448242
ppo/policy/advantages_mean: -2.9621382680034003e-08
------------------------------------------------------------------------------


16it [03:50, 12.91s/it]

objective/kl: 10.867485046386719
ppo/returns/mean: -0.4072226881980896
ppo/policy/advantages_mean: 5.745025788428393e-09
------------------------------------------------------------------------------


17it [04:04, 13.08s/it]

objective/kl: 12.01366901397705
ppo/returns/mean: -0.39700138568878174
ppo/policy/advantages_mean: -2.939279486469104e-09
------------------------------------------------------------------------------


18it [04:16, 12.82s/it]

objective/kl: 13.10476016998291
ppo/returns/mean: -0.593254566192627
ppo/policy/advantages_mean: -1.5807341213758264e-08
------------------------------------------------------------------------------


19it [04:28, 12.72s/it]

objective/kl: 12.604625701904297
ppo/returns/mean: -0.40113598108291626
ppo/policy/advantages_mean: 1.677533312260948e-08
------------------------------------------------------------------------------


20it [04:40, 12.56s/it]

objective/kl: 11.573077201843262
ppo/returns/mean: -0.3285127878189087
ppo/policy/advantages_mean: 1.2797562121136252e-08
------------------------------------------------------------------------------


21it [04:53, 12.41s/it]

objective/kl: 9.786053657531738
ppo/returns/mean: -0.26008498668670654
ppo/policy/advantages_mean: 2.240621910232221e-08
------------------------------------------------------------------------------


22it [05:04, 12.25s/it]

objective/kl: 13.233003616333008
ppo/returns/mean: -0.43858084082603455
ppo/policy/advantages_mean: 5.508438150059192e-10
------------------------------------------------------------------------------


23it [05:17, 12.42s/it]

objective/kl: 10.77065372467041
ppo/returns/mean: -0.2846233546733856
ppo/policy/advantages_mean: -2.1688741469461092e-08
------------------------------------------------------------------------------


24it [05:28, 11.92s/it]

objective/kl: 10.439411163330078
ppo/returns/mean: -0.369087278842926
ppo/policy/advantages_mean: 8.999259115682889e-10
------------------------------------------------------------------------------


25it [05:40, 11.95s/it]

objective/kl: 9.204517364501953
ppo/returns/mean: -0.24764466285705566
ppo/policy/advantages_mean: -8.128289152864454e-10
------------------------------------------------------------------------------


26it [05:52, 12.00s/it]

objective/kl: 10.753318786621094
ppo/returns/mean: -0.3263453543186188
ppo/policy/advantages_mean: 2.3738701671049967e-08
------------------------------------------------------------------------------


27it [06:03, 11.54s/it]

objective/kl: 10.338382720947266
ppo/returns/mean: -0.28389811515808105
ppo/policy/advantages_mean: -2.7425032911310154e-08
------------------------------------------------------------------------------


28it [06:14, 11.58s/it]

objective/kl: 11.250727653503418
ppo/returns/mean: -0.3541765511035919
ppo/policy/advantages_mean: -3.6888108123633856e-09
------------------------------------------------------------------------------


29it [06:25, 11.33s/it]

objective/kl: 9.893604278564453
ppo/returns/mean: -0.2711153030395508
ppo/policy/advantages_mean: 8.724278188765311e-09
------------------------------------------------------------------------------


30it [06:36, 11.29s/it]

objective/kl: 11.727471351623535
ppo/returns/mean: -0.36738479137420654
ppo/policy/advantages_mean: -6.823546616629983e-09
------------------------------------------------------------------------------


31it [06:49, 13.21s/it]

objective/kl: 10.808866500854492
ppo/returns/mean: -0.3813740611076355
ppo/policy/advantages_mean: 1.4115292934491208e-08
------------------------------------------------------------------------------





Compare the model qualitatively



In [109]:

# evaluate the toxicity after ppo

mean_toxicity_after_ppo, std_toxicity_after_ppo = evaluate_toxicity(model = ppo_model,
                                                                        toxicity_evaluator = toxicity_evaluator,
                                                                        tokenizer=toxicity_tokenizer,
                                                                        dataset =dataset["test"],
                                                                        num_samples =5)


print(f"mean, std toxicity after ppo {mean_toxicity_after_ppo, std_toxicity_after_ppo}")
print("percentage improvement in toxicity", ((mean_toxicity_before_ppo - mean_toxicity_after_ppo)*100/mean_toxicity_before_ppo), "%")

6it [00:43,  7.26s/it]

mean, std toxicity after ppo (0.1282393615692854, 0.09117653166753624)
percentage improvement in toxicity 30.25921795207414 %





In [114]:
# save the ppo model
ppo_model_path = f"./ppo-dialogue-summary-checkpoint-RLHF-{TIME}"
tokenizer.save_pretrained(ppo_model_path)
ppo_model.save_pretrained(ppo_model_path)