In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
import pandas as pd
tqdm.pandas()
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from transformers import BitsAndBytesConfig
import wandb
from dataset import build_dataset, collator

In [8]:
device = 'cuda' if torch.cuda.is_available() else "cpu"

# Load the tokenizer and reuse its EOS token for padding.
tokenizer = AutoTokenizer.from_pretrained("lvwerra/gpt2-imdb")
tokenizer.pad_token = tokenizer.eos_token

# Policy under evaluation: the RLHF-tuned checkpoint wrapped with a value head.
# Placement is done explicitly with .to(device), the same way as ref_model below;
# combining device_map="auto" with a later .to(device) fights accelerate's
# dispatch (and fails outright if any layers were offloaded).
pretrained_model = AutoModelForCausalLM.from_pretrained('gpt2-rlhf')
model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
model.to(device)

# Load a reference model with a value head
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('reference_model')
ref_model.to(device)

# Pure sampling: min_length=-1 imposes no minimum-length constraint,
# top_k=0 disables top-k filtering (top_k is an integer count, not a float),
# and top_p=1.0 disables nucleus filtering.
generation_kwargs = {
    "min_length": -1,
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id
}

# Sentiment-pipeline options. return_all_scores yields every label's score in
# label order (top_k=None would instead sort by score, breaking positional
# indexing of the positive class); function_to_apply="none" keeps raw logits
# rather than probabilities.
sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": 16
}

# NOTE(review): model_name names "lvwerra/gpt2-imdb", not the evaluated
# checkpoint 'gpt2-rlhf' — confirm build_dataset only uses it for tokenization.
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.5e-5
)

dataset = build_dataset(config)

sentiment_pipe = pipeline("sentiment-analysis", model="distilbert-imdb", device=device)

ppo_trainer = PPOTrainer(config, model, ref_model=ref_model,
                         tokenizer=tokenizer, dataset=dataset,
                         data_collator=collator)



Some weights of the model checkpoint at reference_model were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
def generate_responses(model, ppo_trainer, query_tensors, sample_size, device, generation_kwargs, length_sampler=None):
    """Sample one continuation per query from the reference model and the PPO policy.

    For each of the first ``sample_size`` queries a target length ``gen_len`` is
    drawn once and used for BOTH models, so the before/after comparison is
    length-matched.

    Args:
        model: reference model; its ``generate`` receives a ``(1, seq_len)`` batch.
        ppo_trainer: PPO trainer; its ``generate`` receives the 1-D query tensor
            and samples from the policy the trainer wraps.
        query_tensors: indexable sequence of token-id lists, one per query.
        sample_size: number of queries to process, taken from the front of
            ``query_tensors``.
        device: device the query tensors are moved to.
        generation_kwargs: sampling options forwarded to both ``generate`` calls.
            It is not mutated; any ``max_new_tokens`` entry is overridden per sample.
        length_sampler: zero-argument callable returning the number of new tokens
            to generate. Defaults to the notebook-global ``output_length_sampler``
            so existing six-argument calls keep working.

    Returns:
        tuple[list[Tensor], list[Tensor]]: ``(response_tensors_ref, response_tensors)``,
        each holding the last ``gen_len`` tokens of every generated sequence.
        If a model stops early (e.g. on EOS), that slice can still include
        trailing prompt tokens.
    """
    if length_sampler is None:
        # Backward-compatible fallback; requires the global to exist at call time.
        length_sampler = output_length_sampler
    response_tensors_ref, response_tensors = [], []
    for i in range(sample_size):
        gen_len = length_sampler()
        # Cap generation at gen_len new tokens. Without this, both models run to
        # their default length limit and [-gen_len:] does not line up with the
        # newly generated tokens.
        sample_kwargs = {**generation_kwargs, "max_new_tokens": gen_len}
        query = torch.tensor(query_tensors[i]).to(device)
        output = model.generate(query.unsqueeze(dim=0), **sample_kwargs).squeeze()[-gen_len:]
        response_tensors_ref.append(output)
        output = ppo_trainer.generate(query, **sample_kwargs).squeeze()[-gen_len:]
        response_tensors.append(output)
    return response_tensors_ref, response_tensors

def decode_responses(tokenizer, response_tensors):
    """Convert each generated token sequence back into a string.

    Args:
        tokenizer: object exposing ``decode(tokens) -> str``.
        response_tensors: iterable of token sequences, one per sample.

    Returns:
        list[str]: decoded texts in the same order as ``response_tensors``.
    """
    decoded = []
    for token_ids in response_tensors:
        decoded.append(tokenizer.decode(token_ids))
    return decoded

def reward_scores(sentiment_pipe, queries, responses, sent_kwargs, positive_index=1):
    """Score each query+response text with the sentiment classifier.

    Each query string is concatenated with its response, the texts are
    classified in one call, and the score at ``positive_index`` of each
    per-text result is returned.

    Args:
        sentiment_pipe: callable returning, for a list of texts, one list of
            ``{"label": ..., "score": ...}`` dicts per text.
        queries: list of prompt strings.
        responses: list of generated continuations, aligned with ``queries``.
        sent_kwargs: keyword arguments forwarded to ``sentiment_pipe``.
        positive_index: position of the wanted class in each per-text result.
            The default of 1 assumes all scores come back in label order with
            the positive class second — TODO confirm against the model's id2label.

    Returns:
        list[float]: one score per query/response pair.

    Raises:
        ValueError: if ``queries`` and ``responses`` differ in length (``zip``
            would otherwise silently drop the unmatched tail).
    """
    if len(queries) != len(responses):
        raise ValueError(
            f"queries and responses must align: got {len(queries)} and {len(responses)}"
        )
    texts = [q + r for q, r in zip(queries, responses)]
    if not texts:
        return []
    return [output[positive_index]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]


In [18]:
# Put both models in inference mode before comparing their samples.
model.eval()
ref_model.eval()
sample_size = 40
EVAL_SEED = 0  # fixes which prompts are drawn so the comparison is repeatable
# Per-sample response length (bounds 4 and 16); generate_responses reads this
# name as a global.
output_length_sampler = LengthSampler(4, 16)
# with_format returns a new pandas-formatted dataset instead of switching
# `dataset` (also held by ppo_trainer) to pandas in place via set_format.
df_batch = dataset.with_format("pandas")[:].sample(sample_size, random_state=EVAL_SEED)


result_data = {
    'query': df_batch['query'].tolist(),
    'input_ids': df_batch['input_ids'].tolist()
}


# Sample responses from the reference (pre-RLHF) model and the PPO-trained policy
response_tensors_ref, response_tensors = generate_responses(ref_model, ppo_trainer, result_data['input_ids'], sample_size, device, generation_kwargs)

# Decode responses
result_data['response (before)'] = decode_responses(tokenizer, response_tensors_ref)
result_data['response (after)'] = decode_responses(tokenizer, response_tensors)

# Sentiment rewards of query + response for each model
result_data['rewards (before)'] = reward_scores(sentiment_pipe, result_data['query'], result_data['response (before)'], sent_kwargs)
result_data['rewards (after)'] = reward_scores(sentiment_pipe, result_data['query'], result_data['response (after)'], sent_kwargs)

# Store results in a DataFrame
df_results = pd.DataFrame(result_data)
df_results



Unnamed: 0,query,input_ids,response (before),response (after),rewards (before),rewards (after)
0,Not sure,"[3673, 1654]","the ""best"" sequel, such as the later ""Vampire...","a lot of humour, it is well done and certainl...",0.653392,3.017669
1,This is a totally,"[1212, 318, 257, 6635]",of Hopper which certainly enhance film photog...,that will keep you going...much sweeter than ...,2.806727,2.500447
2,Many of the,"[7085, 286, 262]","in this film are filmed in Pittsburgh, so I c...",and the characters are so funny and tricky to...,-0.207658,2.389933
3,SPOILERS This is,"[4303, 49713, 4877, 770, 318]",an Isla Nui movie,made. These tales of the,0.112259,0.318706
4,I rented,"[40, 26399]","receiving a refund or return claim, and I'm",good! And the soundtrack is fabulous. And the,-0.785264,2.395598
5,Some of those guys,"[4366, 286, 883, 3730]",to my mom! I always thought that Mc,happy story. This film really closes its doors,0.745585,0.517703
6,The only explanation I can,"[464, 691, 7468, 314, 460]","popularity of word 5,",is very enjoyable! It,-0.661681,2.27806
7,Not having read,"[3673, 1719, 1100]",", I can tell you",IDAY HALLOW,-0.50591,-0.639166
8,If you,"[1532, 345]",helpbooks like 'Easy Rider,"delightful, balanced and still stunning",0.093648,3.052445
9,Let's,"[5756, 338]",br />The only thing I,every adventure crime cinema has been,-0.928606,1.602268


In [19]:
# Mean sentiment reward of the reference model's responses, rounded to 3 places.
avg_reward_before = round(df_results['rewards (before)'].mean(), 3)
print(f"The average reward before RLHF is {avg_reward_before}")

The average reward before RLHF is 0.502


In [20]:
# Mean sentiment reward of the PPO-trained policy's responses, rounded to 3 places.
avg_reward_after = round(df_results['rewards (after)'].mean(), 3)
print(f"The average reward after RLHF is {avg_reward_after}")

The average reward after RLHF is 2.118


In [21]:
# df_results has a plain RangeIndex; writing it would add an "Unnamed: 0" column on read-back.
df_results.to_csv('ppo_result.csv', index=False)