
## RLHF

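* Refs: OpenAI, CarperAI, Hugging Face (HF)
* Trains a GPT-J-based reward model, initialized from a supervised fine-tuned (SFT) summarization checkpoint, on pairwise chosen/rejected summary comparisons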


In [2]:

import os

import torch
from torch import nn
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, Trainer, TrainingArguments
from transformers import AutoModelForCausalLM

import random
import numpy as np




## Utility Functions


In [3]:

def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train", max_test_samples=5000):
    """Load a summary-comparison dataset and turn it into chosen/rejected text pairs.

    Each returned pair is ``{"chosen": prompt + "\\n" + chosen_summary,
    "rejected": prompt + "\\n" + rejected_summary}``. Rows are skipped when the
    two summaries are identical, or when either summary has fewer than 5
    whitespace-separated words.

    Args:
        path: dataset identifier passed to ``datasets.load_dataset``.
        split: split to load. For ``"test"``, only the first
            ``max_test_samples`` rows are used.
        max_test_samples: cap on rows taken from the test split. It is clamped
            to the split size, so a smaller split does not raise. ``None``
            disables the cap.

    Returns:
        list of ``{"chosen": str, "rejected": str}`` dicts.
    """
    dataset = load_dataset(path, split=split)
    if split == "test" and max_test_samples is not None:
        # Clamp: selecting indices past the end of the split would raise.
        dataset = dataset.select(range(min(max_test_samples, len(dataset))))

    pairs = []
    for sample in tqdm(dataset):
        prompt = sample["prompt"]
        chosen_summary = sample["chosen"]
        rejected_summary = sample["rejected"]
        # Identical summaries carry no preference signal.
        if chosen_summary == rejected_summary:
            continue
        # Drop very short summaries (fewer than 5 words).
        if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5:
            continue
        pairs.append({
            "chosen": prompt + "\n" + chosen_summary,
            "rejected": prompt + "\n" + rejected_summary,
        })
    return pairs


In [4]:

def compute_metrics(eval_preds):
    """Pairwise accuracy of the reward model for the HF ``Trainer``.

    ``eval_preds.predictions[0]`` holds the chosen end scores and
    ``eval_preds.predictions[1]`` holds the rejected end scores, one per pair.
    A pair counts as correct when the chosen score is strictly greater.

    Returns:
        ``{"accuracy": float}``. NaN when there are no pairs.
    """
    chosen_end_scores = np.asarray(eval_preds.predictions[0])    # chosen scores
    rejected_end_scores = np.asarray(eval_preds.predictions[1])  # rejected scores

    if rejected_end_scores.size == 0:
        # No pairs evaluated: report an undefined accuracy instead of dividing by zero.
        return {"accuracy": float("nan")}

    # Vectorised mean; Python's built-in sum() would iterate element by element.
    acc = float(np.mean(chosen_end_scores > rejected_end_scores))
    return {"accuracy": acc}


In [5]:

def set_seed(seed_val=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    Per the PyTorch docs, ``torch.manual_seed`` seeds the RNG on all devices, so
    no separate CUDA call is made.

    NOTE(review): this helper is not called in the cells shown here; the HF
    Trainer seeds itself from ``TrainingArguments``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_val)



## Classes


In [6]:



class GPTRewardModel(nn.Module):
    """Pairwise reward model: a causal-LM transformer trunk plus a scalar value head.

    Only ``model.transformer`` of the pretrained causal LM is kept. A bias-free
    ``nn.Linear(n_embd, 1)`` maps every hidden state to a per-token reward.

    Batches are the concatenation ``[chosen rows; rejected rows]`` along dim 0,
    as built by ``DataCollatorReward``; row ``i`` pairs with row ``i + bs``.
    A pair whose chosen and rejected rows are token-identical is treated as an
    inference request: the model then returns only chosen scores.
    """

    def __init__(self, model_path, tokenizer_path="EleutherAI/gpt-j-6B"):
        """
        Args:
            model_path: path / hub id of a causal LM whose ``.transformer``
                becomes the trunk.
            tokenizer_path: tokenizer whose pad token (set to EOS) marks where
                each sequence ends. Defaults to the GPT-J tokenizer used
                elsewhere in this notebook.
        """
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(model_path)
        self.config = model.config
        # `gpt-neo(x)` models use `hidden_size` attribute names instead of `n_embd`
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        self.transformer = model.transformer
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # Pad with EOS: the first pad/EOS token therefore marks the end of each sequence.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]

    def _end_index(self, seq):
        """Index of the first PAD_ID in 1-D ``seq``, or ``len(seq)`` if there is no padding."""
        pad_positions = (seq == self.PAD_ID).nonzero()
        return pad_positions[0].item() if len(pad_positions) > 0 else seq.shape[0]

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        """Score a batch of stacked chosen/rejected sequences.

        Args:
            input_ids: 2-D token ids. The first half of the rows are "chosen",
                the second half "rejected".
            past_key_values, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds: forwarded to the transformer trunk.
            mc_token_ids, labels, return_dict, output_attentions,
            output_hidden_states: accepted for call compatibility; unused.

        Returns:
            Normally a dict with three keys:
            - ``"loss"``: the mean over pairs of ``-log sigmoid(r_chosen - r_rejected)``,
              averaged over positions from the first divergent token up to the
              later of the two sequence ends.
            - ``"chosen_end_scores"``: the reward at the last of those positions,
              for each chosen row.
            - ``"rejected_end_scores"``: the same for each rejected row.

            If any pair is token-identical, only ``{"chosen_end_scores"}`` is returned.
        """
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        rewards = self.v_head(hidden_states).squeeze(-1)

        # Split the inputs and rewards into two parts, chosen and rejected
        assert len(input_ids.shape) == 2
        bs = input_ids.shape[0] // 2
        chosen, rejected = input_ids[:bs], input_ids[bs:]
        chosen_rewards, rejected_rewards = rewards[:bs], rewards[bs:]

        chosen_end_scores = []
        rejected_end_scores = []
        loss = 0
        inference = False
        for i in range(bs):
            c_ind = self._end_index(chosen[i])

            if torch.equal(chosen[i], rejected[i]):
                # Identical pair => inference: score the last real token of the chosen row.
                chosen_end_scores.append(chosen_rewards[i, c_ind - 1])
                inference = True
                continue

            r_ind = self._end_index(rejected[i])
            end_ind = max(c_ind, r_ind)

            # First index where the two trajectories diverge (the shared prefix is ignored).
            divergence_ind = (chosen[i] != rejected[i]).nonzero()[0]
            assert divergence_ind > 0

            c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
            r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]

            chosen_end_scores.append(c_truncated_reward[-1])
            rejected_end_scores.append(r_truncated_reward[-1])

            # logsigmoid is the numerically stable form of log(sigmoid(x)): the naive
            # version underflows to log(0) = -inf for large negative margins.
            loss += -nn.functional.logsigmoid(c_truncated_reward - r_truncated_reward).mean()
        loss = loss / bs

        if inference:
            return {"chosen_end_scores": torch.stack(chosen_end_scores)}

        return {
            "loss": loss,
            "chosen_end_scores": torch.stack(chosen_end_scores),
            "rejected_end_scores": torch.stack(rejected_end_scores),
        }


In [7]:

class DataCollatorReward:
    """Collate ``PairwiseDataset`` items into one reward-model batch.

    Each item is ``(chosen_ids, chosen_mask, rejected_ids, rejected_mask)``.
    Rows are concatenated along dim 0: all chosen rows first, then all rejected
    rows. ``GPTRewardModel.forward`` splits the batch in half and relies on this
    order. ``labels`` marks the halves (0 = chosen, 1 = rejected); the model
    accepts it but does not read it.
    """

    def __call__(self, data):
        # Positions of the id / mask tensors inside each dataset item (chosen, rejected).
        id_fields = (0, 2)
        mask_fields = (1, 3)
        count = len(data)
        return {
            "input_ids": torch.cat([item[field] for field in id_fields for item in data]),
            "attention_mask": torch.cat([item[field] for field in mask_fields for item in data]),
            "labels": torch.tensor([0] * count + [1] * count),
        }
    


In [8]:



class PairwiseDataset(Dataset):
    """Tokenized chosen/rejected pairs for reward-model training.

    Each text is wrapped as ``"<|startoftext|>" + text + "<|endoftext|>"``, then
    truncated and padded to ``max_length``. Each stored tensor keeps the leading
    batch dim of 1 that ``return_tensors="pt"`` produces for a single string.
    ``DataCollatorReward`` concatenates these along dim 0.

    Pairs whose token ids are identical after truncation are dropped. They carry
    no preference signal, and ``GPTRewardModel`` would treat them as inference.

    NOTE(review): ``"<|startoftext|>"`` may not be a special token in the GPT-J
    tokenizer, in which case it is split into ordinary subword tokens.

    Args:
        pairs: iterable of dicts with ``"chosen"`` and ``"rejected"`` strings.
        tokenizer: HF-style tokenizer callable.
        max_length: fixed sequence length for truncation and padding.
    """

    def __init__(self, pairs, tokenizer, max_length):
        self.chosen_input_ids = []
        self.chosen_attn_masks = []
        self.rejected_input_ids = []
        self.rejected_attn_masks = []
        for pair in tqdm(pairs):
            chosen_encodings_dict = self._encode(tokenizer, pair["chosen"], max_length)
            rejected_encodings_dict = self._encode(tokenizer, pair["rejected"], max_length)
            # Identical after truncation => no learning signal; skip the pair.
            if torch.equal(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"]):
                continue
            self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
            self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
            self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
            self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])

    @staticmethod
    def _encode(tokenizer, text, max_length):
        """Tokenize one wrapped text into fixed-length ``input_ids`` / ``attention_mask`` tensors."""
        return tokenizer(
            "<|startoftext|>" + text + "<|endoftext|>",
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )

    def __len__(self):
        return len(self.chosen_input_ids)

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, rejected_ids, rejected_mask)`` for pair ``idx``."""
        return (
            self.chosen_input_ids[idx],
            self.chosen_attn_masks[idx],
            self.rejected_input_ids[idx],
            self.rejected_attn_masks[idx],
        )





## Tokenizers


In [9]:

# GPT-J tokenizer used to build the pairwise datasets below. EOS doubles as the
# pad token so PairwiseDataset can pad every sequence to a fixed max_length.
tokenizer           = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token


In [10]:

# Directory for the Trainer's periodic checkpoints (the `output_dir` in the training arguments).
# exist_ok avoids the check-then-create race and keeps the cell safely re-runnable.
os.makedirs("reward_model_checkpoint", exist_ok=True)



## Training Arguments


In [11]:


# One pair per device step (2 rows after DataCollatorReward), with gradient
# accumulation over 4 steps => 4 pairs per optimizer update per device.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and deprecate `no_cuda` in favor of `use_cpu`; update these
# if the installed version warns or errors.
training_args = TrainingArguments(
    output_dir                    ="reward_model_checkpoint/",
    num_train_epochs              = 1,
    logging_steps                 = 10,
    gradient_accumulation_steps   = 4,
    save_strategy                 = "steps",
    evaluation_strategy           = "steps",
    per_device_train_batch_size   = 1,
    per_device_eval_batch_size    = 1,
    eval_accumulation_steps       = 1,
    eval_steps                    = 500,
    save_steps                    = 500,
    warmup_steps                  = 100,
    logging_dir                   = "./logs",
    ## fp16                       = True,
    ## bf16                       = False,
    learning_rate                 = 1e-5,
    ## deepspeed                  = "ds_config_gpt_j.json",
    ## save_total_limit           = 1,
    no_cuda                       = True   # train on CPU even if a GPU is visible
)



## Reward Model


In [12]:

# Initialize the reward model from the (supervised) fine-tuned GPT-J
# GPTRewardModel keeps only the SFT model's transformer trunk and attaches a
# freshly initialized scalar value head.

model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


## Freeze the first ~70% of the transformer blocks (only the top `int(0.3 * num_layers)` blocks stay trainable)


In [14]:

# `model.transformer.h` is the stack of transformer blocks; its length decides
# how many blocks get frozen in the next cell.
layers       = model.transformer.h
num_layers   = len( layers )

num_layers 


28

In [15]:

# Keep only the top 30% of transformer blocks trainable; freeze the rest.
# Slice with an explicit stop index: `layers[:-num_unfrozen]` would freeze
# nothing when num_unfrozen == 0, because `[:-0]` is the empty slice `[:0]`.
num_unfrozen = int(0.3 * num_layers)
num_frozen   = num_layers - num_unfrozen

for layer in layers[:num_frozen]:
    layer.requires_grad_(False)



## Create the comparison datasets (train and test splits)


In [16]:

    
# Build chosen/rejected text pairs. The test split (capped inside
# create_comparison_dataset) serves as the validation set.
data_path   = "CarperAI/openai_summarize_comparisons"

train_pairs = create_comparison_dataset(data_path, "train" )
val_pairs   = create_comparison_dataset(data_path, "test"  )



Found cached dataset parquet (/Users/user/.cache/huggingface/datasets/CarperAI___parquet/CarperAI--openai_summarize_comparisons-be6a3808a629348d/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|██████████████████████████████████████████████████████████| 92534/92534 [00:03<00:00, 28733.43it/s]
Found cached dataset parquet (/Users/user/.cache/huggingface/datasets/CarperAI___parquet/CarperAI--openai_summarize_comparisons-be6a3808a629348d/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|████████████████████████████████████████████████████████████| 5000/5000 [00:00<00:00, 29029.94it/s]



## Tokenize the pairs into pairwise datasets (train and validation)


In [17]:

# Fixed sequence length for truncation/padding; the inference section should
# use the same value so sequences are encoded identically.
max_length    = 550

train_dataset = PairwiseDataset(train_pairs, tokenizer, max_length=max_length)
val_dataset   = PairwiseDataset(val_pairs,   tokenizer, max_length=max_length)


100%|████████████████████████████████████████████████████████████| 92534/92534 [02:47<00:00, 553.72it/s]
100%|██████████████████████████████████████████████████████████████| 5000/5000 [00:08<00:00, 565.27it/s]



## Create the collator to gather batches of pairwise comparisons


In [18]:

   
# Stacks each batch as [all chosen rows; all rejected rows], the layout GPTRewardModel.forward splits on.
data_collator = DataCollatorReward()



## Train the PyTorch reward model with the Hugging Face `Trainer`


In [19]:

# The HF Trainer drives training/evaluation of the custom nn.Module.
# compute_metrics reads predictions[0] / predictions[1] as the chosen / rejected
# end scores, i.e. the non-"loss" entries of GPTRewardModel's output dict.
trainer_rc = Trainer(
    model              = model,
    args               = training_args,
    train_dataset      = train_dataset,
    compute_metrics    = compute_metrics,
    eval_dataset       = val_dataset,
    data_collator      = data_collator,
)


In [None]:

trainer_rc.train()

# train() only writes periodic `checkpoint-*` folders under `output_dir`.
# Persist the final weights to the path the inference section loads from.
os.makedirs("rm_checkpoint", exist_ok=True)
torch.save(trainer_rc.model.state_dict(), "rm_checkpoint/pytorch_model.bin")



## Inference: pairwise accuracy of the trained reward model on the test split


In [None]:

# Re-create the GPT-J tokenizer (EOS as pad) so the inference section can run on its own.
# NOTE(review): PAD_ID is not used by any later cell shown here.
tokenizer           = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
PAD_ID              = tokenizer(tokenizer.pad_token)["input_ids"][0]


In [None]:

model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")
# map_location="cpu" lets weights saved from a GPU run load on a CPU-only machine.
# load_state_dict then copies them into the model's existing parameters.
model.load_state_dict( torch.load("rm_checkpoint/pytorch_model.bin", map_location="cpu") )


In [None]:

# Must match the max_length used for training so sequences are truncated/padded identically.
max_length = 550

val_pairs   = create_comparison_dataset("CarperAI/openai_summarize_comparisons", "test")

dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length)



In [None]:

# Deterministic order; each batch of 6 pairs becomes 12 rows (chosen then rejected) after collation.
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward())


In [None]:

model.eval()

# Use half precision only on GPU. On CPU, many PyTorch builds lack fp16 kernels
# for ops such as matmul ("not implemented for 'Half'"), so stay in fp32 there.
if torch.cuda.is_available():
    model.cuda()
    model.half()


In [None]:

# Running totals for the evaluation loop: number of correctly ranked pairs and per-batch scores.
correct     = 0
chosen_list = []
reject_list = []


In [None]:


# Reset the totals here so re-running this cell does not double-count.
correct     = 0
chosen_list = []
reject_list = []

# Send every batch to wherever the model's weights live (CPU or GPU).
device = next(model.parameters()).device

with torch.no_grad():
    for batch in tqdm(dev_dataloader, total=len(dev_dataloader)):
        batch = {key: tensor.to(device) for key, tensor in batch.items()}
        outputs = model(**batch)
        # A pair is ranked correctly when the chosen summary outscores the rejected one.
        correct += (outputs["chosen_end_scores"] > outputs["rejected_end_scores"]).sum().item()
        chosen_list.append(outputs["chosen_end_scores"].cpu())
        reject_list.append(outputs["rejected_end_scores"].cpu())



In [None]:

# float() prints a plain number whether `correct` is an int or a 0-d tensor.
print("Total accuracy: ", float(correct) / len(dev_dataset))
