In [1]:
pip install -U transformers datasets evaluate nltk accelerate trl wandb peft rouge_score

Collecting wandb
  Obtaining dependency information for wandb from https://files.pythonhosted.org/packages/45/e0/7ba3b78a74413b7467300cb7a5d486b9871ee464a7cade98ea869d3ca3df/wandb-0.19.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading wandb-0.19.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting pydantic<3,>=2.6 (from wandb)
  Obtaining dependency information for pydantic<3,>=2.6 from https://files.pythonhosted.org/packages/62/51/72c18c55cf2f46ff4f91ebcc8f75aa30f7305f3d726be3f4ebffb4ae972b/pydantic-2.10.3-py3-none-any.whl.metadata
  Downloading pydantic-2.10.3-py3-none-any.whl.metadata (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.0/172.0 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
Collecting annotated-types>=0.6.0 (from pydantic<3,>=2.6->wandb)
  Obtaining dependency information for annotated-types>=0.6.0 from https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e

In [6]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 
from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model

model_name = "google/flan-t5-base"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.truncate_side = "left"
model.config.max_length = 256

In [7]:
from datasets import load_dataset
import numpy as np

dataset = load_dataset("Jise/hh-rlhf-helpful-base")

def preprocess(examples):
    inputs = ["".join([m["role"].capitalize() + m["content"] + "\n\n" for m in x]) + "Assistant:" for x in examples["prompt"]]
    model_inputs = tokenizer(inputs, max_length=511, truncation=True)
    labels = [x[0]["content"] for x in examples["chosen"]]
    labels = tokenizer(text_target=labels, max_length=256, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

dataset = dataset.map(preprocess, batched=True)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['chosen', 'rejected', 'prompt', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 43835
    })
    test: Dataset({
        features: ['chosen', 'rejected', 'prompt', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2354
    })
})


In [8]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from trl import DPOTrainer, DPOConfig
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
import requests
import pickle
from transformers.optimization import Adafactor, AdafactorSchedule
import wandb
import nltk
import evaluate

nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("rouge")

run_name = "Flan-T5_DPO_LoRA_HH-RLHF"

with open("TOKENS.pkl", "rb") as f:
    TOKENS = pickle.load(f)

WANDB_TOKEN = TOKENS["WANDB_TOKEN"]
HF_TOKEN = TOKENS["HF_TOKEN"]

wandb.login(key=WANDB_TOKEN)

def compute_metrics(eval_preds):
   preds, labels = eval_preds

   # decode preds and labels
   labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
   decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
   decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

   # rougeLSum expects newline after each sentence
   decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
   decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

   result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
  
   return result

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

sft_training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-sft-lora",
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    learning_rate=1e-4,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=100,
    weight_decay=0.01,
    bf16=True,
    predict_with_generate=True,
    push_to_hub=True,
    report_to="wandb",
    run_name=run_name,
    hub_token=HF_TOKEN,
    hub_model_id="Jise/flan-t5-hh-dpo-lora",
    save_safetensors=False,
)

sft_trainer = Seq2SeqTrainer(
    model=model,
    args=sft_training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/idies/.netrc
  sft_trainer = Seq2SeqTrainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
sft_trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1000,2.457,2.269583,0.199893,0.063281,0.163376,0.177872
2000,2.4187,2.2532,0.194891,0.062697,0.160896,0.174012
3000,2.3289,2.243706,0.208327,0.06526,0.169937,0.185981
4000,2.3473,2.236511,0.205002,0.06521,0.167471,0.18372
5000,2.309,2.230092,0.204255,0.06369,0.166663,0.182348
6000,2.2701,2.229096,0.208357,0.066447,0.16999,0.185868
7000,2.287,2.226198,0.206204,0.065627,0.16839,0.184685
8000,2.2587,2.224727,0.20884,0.066966,0.169965,0.186928


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_clas

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 
from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from trl import DPOTrainer, DPOConfig
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
import requests
import pickle
from transformers.optimization import Adafactor, AdafactorSchedule
import wandb
import nltk
import evaluate
from datasets import load_dataset
import numpy as np

checkpoint = "./flan-t5-sft-lora/checkpoint-8220"

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

dataset = load_dataset("Jise/hh-rlhf-helpful-base")

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

if tokenizer.chat_template is None:
    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("rouge")

run_name = "Flan-T5_DPO_LoRA_HH-RLHF"

with open("TOKENS.pkl", "rb") as f:
    TOKENS = pickle.load(f)

WANDB_TOKEN = TOKENS["WANDB_TOKEN"]
HF_TOKEN = TOKENS["HF_TOKEN"]

wandb.login(key=WANDB_TOKEN)

lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

peft_model = get_peft_model(model, lora_config)

dpo_training_args = DPOConfig(
    output_dir="./flan-t5-dpo-lora",
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=6,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=1e-5,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=100,
    weight_decay=0.01,
    max_prompt_length=511,
    max_length=256,
    save_total_limit=1,
    truncation_mode='keep_end',
    bf16=True,
    push_to_hub=True,
    report_to="wandb",
    run_name=run_name,
    hub_token=HF_TOKEN,
    hub_model_id="Jise/flan-t5-hh-dpo-lora",
    save_safetensors=False,
)


dpo_trainer = DPOTrainer(
    model=peft_model,
    args=dpo_training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
)

2024-12-05 04:06:07.555081: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  return self.fget.__get__(instance, owner)()
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjiseshen[0m ([33mjise[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/idies/.netrc
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [3]:
dpo_trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/chosen,Logps/rejected,Logits/chosen,Logits/rejected
100,0.7074,0.69252,-0.001593,-0.002884,0.540713,0.001291,-148.483109,-118.399284,-17.018328,-17.126493
200,0.7108,0.690911,-0.005331,-0.010041,0.593723,0.00471,-148.520462,-118.470856,-17.053373,-17.161543
300,0.7014,0.688665,-0.012032,-0.021987,0.600085,0.009955,-148.587494,-118.590332,-17.110165,-17.21841
400,0.6979,0.686078,-0.022928,-0.039478,0.609839,0.01655,-148.696442,-118.765236,-17.188845,-17.298573
500,0.6885,0.6831,-0.039983,-0.064101,0.612383,0.024118,-148.866989,-119.011452,-17.290449,-17.401979
600,0.68,0.680043,-0.059577,-0.092961,0.60687,0.033384,-149.062943,-119.300056,-17.406548,-17.521002
700,0.6817,0.677302,-0.082866,-0.125386,0.613232,0.04252,-149.295837,-119.624313,-17.529215,-17.647436
800,0.6667,0.674662,-0.107003,-0.158802,0.614928,0.051798,-149.537201,-119.958473,-17.641436,-17.763874
900,0.6547,0.673123,-0.138239,-0.197792,0.610687,0.059553,-149.849548,-120.348358,-17.771891,-17.898695
1000,0.6502,0.670639,-0.173015,-0.242611,0.61408,0.069596,-150.197327,-120.79657,-17.897585,-18.029146


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


TrainOutput(global_step=8217, training_loss=0.5805910762040825, metrics={'train_runtime': 43558.6142, 'train_samples_per_second': 3.019, 'train_steps_per_second': 0.189, 'total_flos': 0.0, 'train_loss': 0.5805910762040825, 'epoch': 2.9991787571858746})

In [4]:
results = dpo_trainer.evaluate()
print(results)

trainer.save_model("./flan-t5-dpo-trained")
tokenizer.save_pretrained("./flan-t5-dpo-trained")

{'eval_loss': 0.6549020409584045, 'eval_runtime': 233.4999, 'eval_samples_per_second': 10.081, 'eval_steps_per_second': 1.683, 'eval_rewards/chosen': -1.223189353942871, 'eval_rewards/rejected': -1.6155133247375488, 'eval_rewards/accuracies': 0.6458864212036133, 'eval_rewards/margins': 0.39232388138771057, 'eval_logps/chosen': -160.6990509033203, 'eval_logps/rejected': -134.52557373046875, 'eval_logits/chosen': -20.92389488220215, 'eval_logits/rejected': -21.557830810546875, 'epoch': 2.9991787571858746}


NameError: name 'trainer' is not defined