# Reward Modeling to MLM Task

## Introduction

Reward models are typically trained to output one scalar value for each state-action pair. In language modeling, they are usually decoder-only models with a classification head that predicts the reward for an entire sequence. In this notebook, we convert standard preference data (chosen and rejected pairs) into a fine-grained, MLM-style task. Instead of giving the chosen response a single label of "1" and the rejected response a single label of "0", we intersperse intermediate reward labels throughout each sequence.

## Method

1. (Naive) Sprinkle the same reward value throughout the sequence: every intermediate label in a chosen response is 1 and every one in a rejected response is 0.
2. Score each sentence with an automatic metric and use those scores as the individual rewards (likely via OmegaPRM).
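To make step 1 concrete, here is a minimal sketch in plain Python of what the naive scheme produces for a single response. It assumes a `[CLS]` marker after every `freq` words and uses `-100` for positions that carry no label (the value the tokenization code later in this notebook uses); the helper name `naive_labels` is hypothetical and works on words rather than tokens.

```python
from typing import List

IGNORE = -100  # positions without a reward label


def naive_labels(words: List[str], preferred: bool, freq: int = 100) -> List[int]:
    """Return one label per position after inserting a [CLS] marker after every
    full block of `freq` words that is followed by more text. Word positions get
    IGNORE; every marker gets the pair-level label (1 = chosen, 0 = rejected)."""
    value = 1 if preferred else 0
    labels: List[int] = []
    n_markers = 0
    for start in range(0, len(words), freq):
        labels.extend([IGNORE] * len(words[start:start + freq]))
        if start + freq < len(words):
            labels.append(value)
            n_markers += 1
    if n_markers == 0:  # short texts still get one label at the end
        labels.append(value)
    return labels


chosen = naive_labels(["w"] * 250, preferred=True)
rejected = naive_labels(["w"] * 50, preferred=False)
print([i for i, v in enumerate(chosen) if v != IGNORE])    # [100, 201]
print([i for i, v in enumerate(rejected) if v != IGNORE])  # [50]
```

Every labeled position of a response carries the same value, which is exactly what makes this variant "naive": the model sees the pair-level preference repeated at several points rather than a genuine per-step signal.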

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import Optional, cast, Dict, Any

import torch

from omegaconf import DictConfig
from omegaconf import OmegaConf as om

import transformers
from transformers import AutoModel, AutoConfig, AutoTokenizer
import datasets

from src.flex_bert import *
from src.evals.data import *



In [2]:
rbench_dataset = datasets.load_dataset("sarahpann/reward_bench_processed")

In [23]:
original_dataset = datasets.load_dataset("sarahpann/processed_skywork")

In [9]:
num_chosen_words = 0

for example in rbench_dataset['train']['chosen']:
    num_chosen_words += len(example.split(" "))

num_rejected_words = 0

for example in rbench_dataset['train']['rejected']:
    num_rejected_words += len(example.split(" "))

# average number of whitespace-separated words per response
print(num_chosen_words / len(rbench_dataset['train']['chosen']))
print(num_rejected_words / len(rbench_dataset['train']['rejected']))

213.62073004099162
169.15850087839158


In [4]:
num_chosen_words = 0

for example in original_dataset["train"]["chosen"]:
    num_chosen_words += len(example.split(" "))

num_rejected_words = 0

for example in original_dataset["train"]["rejected"]:
    num_rejected_words += len(example.split(" "))

# average number of whitespace-separated words per response
print(num_chosen_words / len(original_dataset["train"]["chosen"]))
print(num_rejected_words / len(original_dataset["train"]["rejected"]))


407.8271777707245
439.8684392763367


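On the Skywork data, the average chosen response is about 408 whitespace-separated words and the average rejected response about 440 (the RewardBench responses above are shorter, about 214 and 169 words). With one label per 100 words, that gives roughly 4 intermediate rewards for both chosen and rejected responses.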
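The estimate of about 4 labels follows from placing one marker after every full block of 100 words that is followed by more text, assuming every text gets at least one label (as the docstring of `sprinkle_the_same_label` intends). That rule works out to `max(1, (n - 1) // freq)` markers for a text of `n` words. The helper name `expected_num_labels` is hypothetical; applying it to an average length is only a rough guide, since the average label count is not the label count of the average length.

```python
def expected_num_labels(n_words: int, freq: int = 100) -> int:
    """Number of [CLS] markers for a text of `n_words` words: one after every
    full block of `freq` words that is followed by more text, and at least one."""
    return max(1, (n_words - 1) // freq)


for n in (50, 100, 408, 440):
    print(n, expected_num_labels(n))
# 50 1
# 100 1
# 408 4
# 440 4
```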

In [10]:
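def sprinkle_the_same_label(example, freq=100):
    """
    Insert a [CLS] marker after every `freq` words; the markers are tokenized
    later and serve as the label positions.

    The number of labels scales with length (one per full block of `freq`
    words that is followed by more text). Texts of at most `freq` words get a
    single [CLS] at the end, so every text has at least one label.
    """
    def insert_markers(words):
        out, n_markers = [], 0
        # rebuild the list block by block instead of inserting into it while
        # looping over precomputed indices (which shifts later positions)
        for start in range(0, len(words), freq):
            out.extend(words[start:start + freq])
            if start + freq < len(words):
                out.append("[CLS]")
                n_markers += 1
        if n_markers == 0:
            out.append("[CLS]")
            n_markers = 1
        return out, n_markers

    chosen, num_chosen = insert_markers(example["chosen"].split(" "))
    rejected, num_rejected = insert_markers(example["rejected"].split(" "))

    # count the markers directly; len(words) // freq undercounts short texts
    return {"chosen_labeled": " ".join(chosen),
            "rejected_labeled": " ".join(rejected),
            "num_chosen_labels": num_chosen,
            "num_rejected_labels": num_rejected}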


In [13]:
rbench_dataset

DatasetDict({
    train: Dataset({
        features: ['chosen', 'rejected', 'og_dataset', 'chosen_labeled', 'rejected_labeled', 'num_chosen_labels', 'num_rejected_labels'],
        num_rows: 5123
    })
})

In [14]:
rbench_dataset['train'] = rbench_dataset['train'].map(sprinkle_the_same_label)

rbench_dataset.push_to_hub("sarahpann/reward_bench_processed_labeled")

Map: 100%|██████████| 5123/5123 [00:00<00:00, 6902.48 examples/s]
Creating parquet from Arrow format: 100%|██████████| 6/6 [00:00<00:00, 46.07ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/sarahpann/reward_bench_processed_labeled/commit/08c8500db2e4faed516a0bc2a7a48ac9cc7a6579', commit_message='Upload dataset', commit_description='', oid='08c8500db2e4faed516a0bc2a7a48ac9cc7a6579', pr_url=None, pr_revision=None, pr_num=None)

In [15]:
rbench_dataset

DatasetDict({
    train: Dataset({
        features: ['chosen', 'rejected', 'og_dataset', 'chosen_labeled', 'rejected_labeled', 'num_chosen_labels', 'num_rejected_labels'],
        num_rows: 5123
    })
})

In [None]:
original_dataset['train'] = original_dataset['train'].map(sprinkle_the_same_label)
original_dataset['test'] = original_dataset['test'].map(sprinkle_the_same_label)

In [26]:
original_dataset.push_to_hub("sarahpann/processed_skywork_labeled")

Creating parquet from Arrow format: 100%|██████████| 35/35 [00:02<00:00, 16.14ba/s]
Creating parquet from Arrow format: 100%|██████████| 35/35 [00:02<00:00, 16.30ba/s]
Uploading the dataset shards: 100%|██████████| 2/2 [00:14<00:00,  7.35s/it]
Creating parquet from Arrow format: 100%|██████████| 8/8 [00:00<00:00, 14.84ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.84s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/sarahpann/processed_skywork_labeled/commit/55f23c17a24fa57d366a8539443f9d981ec0a373', commit_message='Upload dataset', commit_description='', oid='55f23c17a24fa57d366a8539443f9d981ec0a373', pr_url=None, pr_revision=None, pr_num=None)

In [27]:
original_dataset

DatasetDict({
    train: Dataset({
        features: ['chosen', 'rejected', 'chosen_labeled', 'rejected_labeled', 'num_chosen_labels', 'num_rejected_labels'],
        num_rows: 69314
    })
    test: Dataset({
        features: ['chosen', 'rejected', 'chosen_labeled', 'rejected_labeled', 'num_chosen_labels', 'num_rejected_labels'],
        num_rows: 7702
    })
})

In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bclavie/olmo_bert_template")

In [18]:
def tokenize_and_process_ds(examples, tokenizer):
    tokenized_chosen = tokenizer(examples["chosen_labeled"])
    tokenized_rejected = tokenizer(examples["rejected_labeled"])

    # look the id up rather than hard-coding it (it is 50281 for bclavie/olmo_bert_template)
    cls_token = tokenizer.cls_token_id

    # -100 (ignored by the loss) everywhere except at [CLS] positions, which get
    # the pair-level label: 1 for chosen, 0 for rejected. If the tokenizer
    # prepends its own [CLS] (BERT-style templates do), that position is labeled too.
    chosen_labels = [[1 if tok == cls_token else -100 for tok in ids]
                     for ids in tokenized_chosen["input_ids"]]
    rejected_labels = [[0 if tok == cls_token else -100 for tok in ids]
                       for ids in tokenized_rejected["input_ids"]]

    # with batched=True each batch of N pairs becomes 2N rows: N chosen rows, then N rejected rows
    return {"input_ids": tokenized_chosen["input_ids"] + tokenized_rejected["input_ids"],
            "attention_mask": tokenized_chosen["attention_mask"] + tokenized_rejected["attention_mask"],
            "labels": chosen_labels + rejected_labels}

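Positions labeled `-100` drop out of the loss: PyTorch's `F.cross_entropy` uses `ignore_index=-100` by default, and the Hugging Face token-classification heads compute their loss with it. The sketch below uses made-up logits to check that the loss over a whole sequence equals the loss computed on the labeled (`[CLS]`) positions alone.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, num_labels = 6, 2
logits = torch.randn(seq_len, num_labels)           # one row of class scores per token
labels = torch.tensor([1, -100, -100, 1, -100, 1])  # rewards only at marker positions

full = F.cross_entropy(logits, labels)  # ignore_index defaults to -100
mask = labels != -100
markers_only = F.cross_entropy(logits[mask], labels[mask])
print(torch.allclose(full, markers_only))  # True
```

With the default `reduction="mean"`, the average is taken over the non-ignored positions only, so longer responses (more markers) do not dilute the per-marker loss with unlabeled tokens.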
In [21]:
rm_columns = ["chosen", "rejected", "chosen_labeled", "rejected_labeled", "num_chosen_labels", "num_rejected_labels"]


tokenized_train_ds = original_dataset['train'].map(lambda x: tokenize_and_process_ds(x, tokenizer), batched=True, remove_columns=rm_columns)
tokenized_test_ds = original_dataset['test'].map(lambda x: tokenize_and_process_ds(x, tokenizer), batched=True, remove_columns=rm_columns)

Map: 100%|██████████| 69314/69314 [03:46<00:00, 305.64 examples/s]
Map: 100%|██████████| 7702/7702 [00:25<00:00, 306.22 examples/s]


A quick sanity check of the function on the first 50 examples of each split:

In [19]:
rm_columns = ["chosen", "rejected", "chosen_labeled", "rejected_labeled", "num_chosen_labels", "num_rejected_labels"]


mini_train_ds = original_dataset['train'].select(range(50))
mini_test_ds = original_dataset['test'].select(range(50))

mini_tokenized_train = mini_train_ds.map(lambda x: tokenize_and_process_ds(x, tokenizer), batched=True, remove_columns=rm_columns)
mini_tokenized_test = mini_test_ds.map(lambda x: tokenize_and_process_ds(x, tokenizer), batched=True, remove_columns=rm_columns)

Map: 100%|██████████| 50/50 [00:00<00:00, 206.53 examples/s]
Map: 100%|██████████| 50/50 [00:00<00:00, 208.23 examples/s]
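Beyond eyeballing the mapped subset, the invariants the processing is meant to guarantee can be asserted directly: labels align with tokens, only `[CLS]` positions carry a label, each row carries a single pair-level value, and chosen and rejected rows come in equal numbers. The helper `check_rows` is hypothetical. On the mini subset it might be called as `check_rows(mini_tokenized_train["input_ids"], mini_tokenized_train["labels"], cls_id=tokenizer.cls_token_id)` while the columns are still plain lists (before `set_format`), with the same tokenizer used for the mapping.

```python
from typing import List


def check_rows(input_ids: List[List[int]], labels: List[List[int]], cls_id: int) -> None:
    """Assert the per-row invariants of the naive token-level reward data."""
    assert len(input_ids) == len(labels)
    n_chosen = n_rejected = 0
    for ids, labs in zip(input_ids, labels):
        assert len(ids) == len(labs), "labels must align with tokens"
        marked = [lab for tok, lab in zip(ids, labs) if tok == cls_id]
        unmarked = [lab for tok, lab in zip(ids, labs) if tok != cls_id]
        assert marked, "every row needs at least one [CLS] label"
        assert all(lab == -100 for lab in unmarked), "only [CLS] positions carry labels"
        assert len(set(marked)) == 1, "one pair-level value per row in the naive scheme"
        if marked[0] == 1:
            n_chosen += 1
        else:
            n_rejected += 1
    assert n_chosen == n_rejected, "each pair contributes one chosen and one rejected row"


# toy rows: one chosen (label 1) and one rejected (label 0), [CLS] id 101
check_rows([[101, 5, 6, 101], [101, 7, 101]], [[1, -100, -100, 1], [0, -100, 0]], cls_id=101)
print("ok")
```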


In [18]:
tokenized_train_ds.save_to_disk("/home/public/span/MATH_DPO/modern_bert_test/bert24/data/train")
tokenized_test_ds.save_to_disk("/home/public/span/MATH_DPO/modern_bert_test/bert24/data/val")

Saving the dataset (2/2 shards): 100%|██████████| 138628/138628 [00:00<00:00, 273630.82 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 15404/15404 [00:00<00:00, 277124.40 examples/s]


In [37]:
tokenized_train_ds = datasets.load_from_disk("/home/public/span/MATH_DPO/modern_bert_test/bert24/data/train")
tokenized_test_ds = datasets.load_from_disk("/home/public/span/MATH_DPO/modern_bert_test/bert24/data/val")

In [6]:
def consume_prefix_in_state_dict_if_present(
    state_dict, prefix
):
    r"""Strip the prefix in state_dict in place, if any.

    .. note::
        Given a `state_dict` from a DP/DDP model, a local model can load it by applying
        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
        :meth:`torch.nn.Module.load_state_dict`.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    keys = sorted(state_dict.keys())
    for key in keys:
        if key.startswith(prefix):
            newkey = key[len(prefix) :]
            state_dict[newkey] = state_dict.pop(key)

    # also strip the prefix in metadata if any.
    # `_metadata` is an attribute of the state dict, not one of its keys.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        for key in list(metadata.keys()):
            # for the metadata dict, the key can be:
            # '': for the DDP module, which we want to remove.
            # 'module': for the actual model.
            # 'module.xx.xx': for the rest.

            if len(key) == 0:
                continue
            # only rename keys that carry the prefix ('module' or 'module.xx')
            if key == prefix.replace(".", "") or key.startswith(prefix):
                newkey = key[len(prefix) :]
                metadata[newkey] = metadata.pop(key)
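PyTorch ships this helper as `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, and the cell above appears to be adapted from it. A minimal sketch on a toy module, assuming the checkpoint keys carry a `model.` prefix as in the checkpoint loaded below; the `Wrapper` class is a hypothetical stand-in for whatever wrapper stored the network as `self.model`.

```python
import torch
from torch import nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


class Wrapper(nn.Module):
    """Stand-in for a training wrapper that stores the network as `self.model`."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 2)


state_dict = Wrapper().state_dict()
print(sorted(state_dict))  # ['model.bias', 'model.weight']
consume_prefix_in_state_dict_if_present(state_dict, "model.")
print(sorted(state_dict))  # ['bias', 'weight']

plain = nn.Linear(4, 2)
plain.load_state_dict(state_dict)  # keys now match the unwrapped module
```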

In [None]:
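# Both paths are left blank here; fill them in before running.
original_state_dict = ""  # checkpoint whose weights live under ['state']['model']
new_state_dict = ""       # output path for the cleaned state dict

state_dict = torch.load(original_state_dict, map_location="cpu")['state']['model']
# strip the "model." prefix so the keys match the bare model
consume_prefix_in_state_dict_if_present(state_dict, "model.")
torch.save(state_dict, new_state_dict)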

In [None]:
with open("/home/public/span/MATH_DPO/modern_bert_test/bert24/yamls/test/sequence_classification_og.yaml") as f:
    yaml_config = om.load(f)

cfg = cast(DictConfig, yaml_config)

In [None]:
model = create_flex_bert_classification(
    num_labels=cfg.model.num_labels,
    pretrained_checkpoint=cfg.model.pretrained_checkpoint,
    model_config=cfg.model.model_config,
    tokenizer_name=cfg.tokenizer_name,
    token_classification=True,
)

In [91]:
model = transformers.AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-uncased", num_labels=2)
model.to("cpu")

Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [75]:
tokenized_train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
tokenized_test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [78]:
mini_tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
mini_tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
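`set_format` only converts rows to tensors; the rows still differ in length, so batching needs padding, with padded label positions set to `-100` so they stay out of the loss (`transformers.DataCollatorForTokenClassification` handles this for token-classification data). A minimal hand-rolled sketch with `torch.nn.utils.rnn.pad_sequence`, assuming right padding and a pad id of 0 (check `tokenizer.pad_token_id` for the real value); the name `pad_collate` is hypothetical.

```python
from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(rows: List[Dict[str, torch.Tensor]], pad_id: int = 0) -> Dict[str, torch.Tensor]:
    """Right-pad variable-length rows into a batch: input_ids with `pad_id`,
    attention_mask with 0, and labels with -100 so padding is ignored by the loss."""
    return {
        "input_ids": pad_sequence([r["input_ids"] for r in rows], batch_first=True, padding_value=pad_id),
        "attention_mask": pad_sequence([r["attention_mask"] for r in rows], batch_first=True, padding_value=0),
        "labels": pad_sequence([r["labels"] for r in rows], batch_first=True, padding_value=-100),
    }


rows = [
    {"input_ids": torch.tensor([101, 7, 101]), "attention_mask": torch.tensor([1, 1, 1]), "labels": torch.tensor([1, -100, 1])},
    {"input_ids": torch.tensor([101, 9]), "attention_mask": torch.tensor([1, 1]), "labels": torch.tensor([0, -100])},
]
batch = pad_collate(rows)
print(batch["labels"].tolist())  # [[1, -100, 1], [0, -100, -100]]
```

A function like this could be passed as `collate_fn` to a `torch.utils.data.DataLoader` over the formatted dataset.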

In [87]:
tokenizer.decode([101])

'[CLS]'

In [92]:
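example = mini_tokenized_train[8]
print(example['input_ids'])
# 101 is the [CLS] id of bert-base-uncased
num_cls = int((example['input_ids'] == 101).sum())
print(num_cls)
# after set_format("torch") the columns are already tensors: add a batch
# dimension and truncate to BERT's 512-position limit
input_ids = example["input_ids"].unsqueeze(0)[:, :512]
attention_mask = example["attention_mask"].unsqueeze(0)[:, :512]

model(input_ids=input_ids, attention_mask=attention_mask)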

tensor([  101,  1026,  1064,  4088,  1035,  1997,  1035,  3793,  1064,  1028,
         1026,  1064,  2707,  1035, 20346,  1035,  8909,  1064,  1028,  2291,
         1026,  1064,  2203,  1035, 20346,  1035,  8909,  1064,  1028,  6276,
         3716,  3058,  1024,  2285, 16798,  2509,  2651,  3058,  1024,  2656,
        21650, 16798,  2549,  1026,  1064,  1041,  4140,  1035,  8909,  1064,
         1028,  1026,  1064,  2707,  1035, 20346,  1035,  8909,  1064,  1028,
         5310,  1026,  1064,  2203,  1035, 20346,  1035,  8909,  1064,  1028,
         2071,  2017,  4339,  2019,  5385,  1011,  2773, 10061,  2006,  2129,
         8494,  2013, 29530,  2047,  2259,  2003,  1037,  5957,  2083,  2029,
         2000,  2228,  2055, 19483,  8474,  1029,  1026,  1064,  1041,  4140,
         1035,  8909,  1064,  1028,  1026,  1064,  2707,  1035, 20346,  1035,
         8909,  1064,  1028,  3353,  1026,  1064,  2203,  1035, 20346,  1035,
         8909,  1064,  1028,  5121,   999,  2182,  2003,  2019, 



TokenClassifierOutput(loss=None, logits=tensor([[[-0.0625,  0.1349],
         [-0.0397,  0.3458],
         [-0.3241,  0.6699],
         ...,
         [-0.1785, -0.1765],
         [ 0.1751, -0.3413],
         [-0.3007, -0.3201]]], grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)
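The logits above hold one pair of class scores per token. Under the labeling used here, one natural way to read off intermediate rewards is the softmax probability of class 1 at each `[CLS]` position. The sketch below does this on a made-up logits tensor; the function name `rewards_at_markers` and the choice of P(class 1) as the reward are illustrative assumptions rather than something this notebook establishes.

```python
import torch


def rewards_at_markers(logits: torch.Tensor, input_ids: torch.Tensor, cls_id: int) -> torch.Tensor:
    """Return P(label = 1) at every [CLS] position of a single sequence.

    logits: (seq_len, 2) token-classification scores
    input_ids: (seq_len,) token ids for the same sequence
    """
    probs = logits.softmax(dim=-1)[:, 1]  # probability of the "chosen" class per token
    return probs[input_ids == cls_id]     # keep only marker positions


input_ids = torch.tensor([101, 2023, 2003, 101, 2307, 101])
logits = torch.zeros(6, 2)
logits[3] = torch.tensor([0.0, 2.0])  # the model favours "chosen" at the second marker
print(rewards_at_markers(logits, input_ids, cls_id=101))  # ≈ [0.50, 0.88, 0.50]
```

For the cell above, this would be applied to the first (and only) batch element, e.g. `rewards_at_markers(out.logits[0], input_ids[0], cls_id=101)` if the model output were stored in `out`.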
