In [None]:
!pip install -q datasets peft trl bitsandbytes accelerate wandb sentencepiece ml_dtypes
!pip install -q typing-extensions --upgrade

In [None]:
!pip install -q git+https://github.com/huggingface/transformers

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import (
    Adafactor,
    AutoTokenizer,
    LlamaTokenizer,
    HfArgumentParser,
    pipeline
)
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from huggingface_hub import notebook_login

In [None]:
# Log in to the Hugging Face Hub before any checkpoints or datasets are downloaded.
notebook_login()

In [None]:
# DEFAULT_PAD_TOKEN = "[PAD]"
# DEFAULT_EOS_TOKEN = "</s>"
# DEFAULT_BOS_TOKEN = "</s>"
# DEFAULT_UNK_TOKEN = "</s>"

# Register tqdm's pandas integration.
# NOTE(review): no visible cell calls a pandas progress method -- confirm this is still needed.
tqdm.pandas()

**Models and datasets**

In [None]:
# Hugging Face Hub identifiers used by the loading cells below.
dataset_name = "berkeley-nest/Nectar"  # prompts are read from its "prompt" column
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # policy model and its tokenizer
reward_model_name = "Nexusflow/Starling-RM-34B"  # reward model checkpoint

In [None]:
def preprocess_function(examples):
    """Build PPO queries and their token ids from a batch of prompts.

    Args:
        examples: batch mapping whose ``"prompt"`` entry is a list of strings.

    Returns:
        A dict with ``"query"`` (each prompt wrapped as
        ``"Question: ...\\n\\nAnswer: "``) and ``"input_ids"`` (the token ids the
        notebook-level ``tokenizer`` produces for each query), in prompt order.
    """
    queries = ["Question: " + prompt + "\n\nAnswer: " for prompt in examples["prompt"]]
    # `tokenizer` is the notebook-level tokenizer created in a later cell; it
    # must exist by the time this function is actually called.
    token_ids = [tokenizer(text, truncation=True)["input_ids"] for text in queries]
    return {"query": queries, "input_ids": token_ids}

In [None]:
# Load the training split of the dataset named in `dataset_name`.
train_dataset = load_dataset(dataset_name, split="train")
# Remember the raw column names so the later `map` call can drop them.
original_columns = train_dataset.column_names
train_dataset

In [None]:
from datasets import Dataset

# Keep only the first 1101 examples for this run.
# `select` stays inside the `datasets` library: it keeps the original feature
# schema and does not flip the dataset's output format to pandas in place,
# unlike the previous set_format('pandas') -> slice -> Dataset.from_pandas round trip.
train_dataset = train_dataset.select(range(min(1101, len(train_dataset))))
train_dataset

In [None]:
# Tokenizer of the policy model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Fall back to the EOS token for padding when the checkpoint defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Show the id of the end-of-sequence token (the pad token falls back to EOS when unset).
tokenizer.eos_token_id

In [None]:
# Turn raw prompts into queries + token ids, dropping every original column.
ds = train_dataset.map(preprocess_function, batched=True, remove_columns=original_columns)

# Keep only queries shorter than 512 tokens.
ds = ds.filter(lambda example: len(example["input_ids"]) < 512, batched=False)

# Return torch tensors when rows are accessed.
ds.set_format(type="torch")

In [None]:
# Display the preprocessed, length-filtered prompt dataset.
ds

**LoRA Config**

In [None]:
# LoRA adapter settings; handed to the policy-model loader below through `peft_config`.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Policy model with a value head for PPO, loaded from `model_name` with the
# LoRA settings from `lora_config`. `load_in_4bit=True` requests 4-bit weights;
# `device_map={"": 0}` presumably places the whole model on device 0 --
# confirm before running on a multi-GPU host.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map={"": 0},
    peft_config=lora_config,
)

In [None]:
# PPO hyperparameters. Batch size, mini-batch size, gradient accumulation and
# PPO epochs are all set to 1.
config = PPOConfig(
    model_name=model_name,
    learning_rate=1.41e-5,
    log_with='wandb',  # the explicit wandb login cell further down is commented out
    batch_size = 1,
    mini_batch_size=1,
    gradient_accumulation_steps=1,
    ppo_epochs=1
    # steps=1080
)

In [None]:
# rw_kwargs = {
#     "return_all_scores": True,
#     "function_to_apply": "none",
#     "batch_size": 16,
#     "truncation": True
# }

In [None]:
def collator(data):
    """Collate a list of example dicts into one dict of per-key lists.

    The keys (and their order) are taken from the first example; every example
    must provide each of those keys. No padding or stacking is performed.
    """
    first_example = data[0]
    return {key: [example[key] for example in data] for key in first_example}

In [None]:
# Build the PPO trainer from the config, policy model, tokenizer and the
# tokenised prompt dataset. No separate reference model is supplied
# (`ref_model=None`); batches are assembled by `collator`.
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=ds,
    data_collator=collator,
)

In [None]:
import os
import torch
from torch import nn
from transformers import AutoTokenizer, LlamaPreTrainedModel,LlamaModel
import math

## Define the reward model function class

class LlamaForSequenceClassification(LlamaPreTrainedModel):
    """Llama backbone with a scalar value head, used in this notebook as the reward model.

    Every token's final-layer hidden state is projected to one scalar. The score
    reported for a sequence is the scalar at the position just before the first
    ``PAD_ID`` token, or at the last position when the row contains no ``PAD_ID``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = LlamaModel(config)
        # One scalar per token position, no bias term.
        self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
        # Token id treated as padding when locating the end of a sequence.
        self.PAD_ID = 0
        # Initialize weights and apply final processing
        self.post_init()

    def get_device(self):
        """Return the device of the underlying transformer."""
        return self.transformer.device

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
    ):
        """Score each sequence in the batch.

        ``past_key_values`` is accepted but not forwarded to the backbone.

        Returns:
            ``{"scores": tensor}`` with one scalar per row of ``input_ids``.
        """
        backbone_out = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        final_hidden = backbone_out.hidden_states[-1]
        # Drop the trailing size-1 dimension left by the value head.
        token_rewards = self.v_head(final_hidden).squeeze(-1)

        seq_len = input_ids.shape[1]
        sequence_scores = []
        for row_idx, row_ids in enumerate(input_ids):
            pad_positions = (row_ids == self.PAD_ID).nonzero()
            if len(pad_positions) > 0:
                end = pad_positions[0].item()
            else:
                end = seq_len
            # Note: if the very first token is PAD_ID, `end - 1` is -1 and the
            # final position of the row is read.
            sequence_scores.append(token_rewards[row_idx, end - 1])

        return {"scores": torch.stack(sequence_scores)}

## Load the model and tokenizer

# Load the reward model in 4-bit from the checkpoint named by `reward_model_name`
# (defined with the other model/dataset identifiers) rather than repeating the literal.
reward_model = LlamaForSequenceClassification.from_pretrained(reward_model_name, load_in_4bit=True)
# The reward model is paired with the Yi-34B-Chat tokenizer.
reward_tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B-Chat")
# When an input is too long, drop tokens from the start so the end of the text is kept.
reward_tokenizer.truncation_side = "left"

# Inference only: switch to eval mode and freeze all parameters.
reward_model.eval().requires_grad_(False)

## Define the reward function


In [None]:
# Display the reward model's module structure.
reward_model

In [None]:
reward_device = "cuda"
reward_batch_size = 1
def get_reward(samples):
    """Score texts with the reward model.

    Args:
        samples: list of strings to score. A single string is treated as a
            one-element list.

    Returns:
        1-D tensor with one score per sample, in input order. An empty input
        yields an empty (CPU) tensor without touching the tokenizer or model.
    """
    if isinstance(samples, str):
        # A bare string would otherwise be measured character by character.
        samples = [samples]
    if len(samples) == 0:
        # torch.hstack([]) raises, so short-circuit the empty case.
        return torch.empty(0)

    # Pad/truncate every sample to 2048 tokens; the reward model locates the
    # end of each sequence from the padding tokens.
    encodings_dict = reward_tokenizer(
        samples,
        truncation=True,
        max_length=2048,
        padding="max_length",
        return_tensors="pt",
    ).to(reward_device)
    input_ids = encodings_dict["input_ids"]
    attention_masks = encodings_dict["attention_mask"]

    mbs = reward_batch_size
    out = []
    # Pure inference: no autograd bookkeeping is needed for reward scores.
    with torch.no_grad():
        for start in range(0, len(samples), mbs):
            rewards = reward_model(
                input_ids=input_ids[start : start + mbs],
                attention_mask=attention_masks[start : start + mbs],
            )
            out.extend(rewards["scores"])
    return torch.hstack(out)

In [None]:
# Smoke-test the reward pipeline on a single chat-formatted exchange.
test_sample = ["<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>"]
reward_for_test_sample = get_reward(test_sample)
print(reward_for_test_sample)
# Show the score as a plain Python float; the earlier
# item -> tensor -> mean -> cpu -> numpy chain produced the same number.
reward_for_test_sample[0].item()

In [None]:
# Sampling settings for policy generation: no top-k cut-off (0), no nucleus
# cut-off (top_p=1.0), i.e. sampling from the full distribution.
generation_kwargs = {
    "top_k": 0,  # top_k is an integer parameter; it was previously given as the float 0.0
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}
# Response length bounds handed to the length sampler.
output_min_length = 32
output_max_length = 256
output_length_sampler = LengthSampler(output_min_length, output_max_length)

In [None]:
# # monitoring login -- never hardcode the API key; read it from the environment
# wandb.login(key=os.environ["WANDB_API_KEY"])
# run = wandb.init(project='JSL-MedMNX-7B', job_type="training", anonymous="allow")

In [None]:
save_freq = 500  # checkpoint interval checked by the training loop's save condition
output_dir = "./llama-3-ppo"  # base location for saved checkpoints
reward_baseline = 0.0  # constant offset meant to be subtracted from reward scores

In [None]:
# for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
#     question_tensors = batch["input_ids"]

#     response_tensors = ppo_trainer.generate(
#         question_tensors,
#         return_prompt=False,
#         length_sampler=output_length_sampler,
#         **generation_kwargs,
#     )
#     batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

#     # Compute sentiment score
#     texts = [q + r for q, r in zip(batch["query"], batch["response"])]
#     reward_outputs = reward_model(texts, **rw_kwargs)
#     rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in reward_outputs]

#     # Run PPO step
#     stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
#     ppo_trainer.log_stats(stats, batch, rewards)

#     if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
#         ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")

In [None]:
# PPO training loop: generate a response per query, score query+response with
# the reward model, then take one PPO step on (query, response, reward).
epochs = 1
# Number of PPO steps taken so far across all epochs; drives checkpointing.
# (The outer `epoch` index stays at 0 for a single epoch, so keying the
# checkpoint condition on it never triggers a save.)
global_step = 0
for epoch in tqdm(range(epochs), "epoch: "):
    for batch in tqdm(ppo_trainer.dataloader):
        question_tensors = batch["input_ids"]

        response_tensors = ppo_trainer.generate(
            question_tensors,
            return_prompt=False,
            length_sampler=output_length_sampler,
            **generation_kwargs,
        )

        batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

        # Score each query + response with the reward model.
        # NOTE(review): the reward smoke-test cell feeds chat-template text
        # ("<|im_start|>..."), while these texts use the plain
        # "Question: ... Answer: ..." format -- confirm the reward model is
        # meant to score this format.
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        reward_outputs = get_reward(texts)
        # One scalar float32 tensor per sample, offset by the configured
        # baseline. (Each list entry must come from its own `score`, not from
        # the whole `reward_outputs` tensor.)
        rewards = [score.to(torch.float32) - reward_baseline for score in reward_outputs]

        # Run PPO step
        stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

        global_step += 1
        if save_freq and global_step % save_freq == 0:
            # Join with a path separator so checkpoints land inside output_dir.
            ppo_trainer.save_pretrained(os.path.join(output_dir, f"step_{global_step}"))