In [None]:
!pip install bert-score

In [23]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

import argparse
import json
import gc
import datasets
import transformers
import torch
import evaluate
from tqdm import tqdm
import json
import numpy as np
from trl import PPOConfig, PPOTrainer, AutoModelForSeq2SeqLMWithValueHead
from peft import LoraConfig, TaskType, get_peft_model, PeftModel

"""
Train fine tuned T5 model with Proximal Policy Optimization (PPO) algorithm.
"""
parser = argparse.ArgumentParser()
#parser.add_argument("--model_name", type=str, default="t5-small")
#parser.add_argument("--highlight", type=bool, default=True)
#args = parser.parse_args()


bertscore = evaluate.load("bertscore")
average_question_length = 10.0

HIGHLIGHT = True
TOKEN_QUESTION = "<question>"
TOKEN_END_QUESTION = "<question>"
TOKEN_CONTEXT = "<context>"
TOKEN_END_CONTEXT = "<context>"
TOKEN_ANSWER = "<answer>"
TOKEN_END_ANSWER = "<answer>"
HIGHLIGHT_ANSWER = "<hl>"
SPLIT_SEED = 42
NPROC = 32

model_name = "t5-small"
HIGHLIGHT = True
if HIGHLIGHT:
    model_name = f"{model_name}-hl"

In [24]:
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q", "v"],
)

In [26]:
# from here we can try this:
# https://ankushmulkar.medium.com/finetune-flan-t5-using-reinforcement-learning-ppo-and-peft-for-producing-non-toxic-summaries-fed2695ee4f4

# Alternatively, we can use the following:

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(f"./models/t5-base-hl/", device_map="auto", load_in_8bit=True)

peft_model = PeftModel.from_pretrained(model,"./models/t5-small-hl/",
                                       lora_config=peft_config, 
                                       device_map="auto",                                       
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{peft_model.print_trainable_parameters()}\n')



AttributeError: 'AutoModelForSeq2SeqLMWithValueHead' object has no attribute '_prepare_encoder_decoder_kwargs_for_generation'

In [None]:
# Second method failing: Use PeftModel.from_pretrained and then AutoModelForSeq2SeqLMWithValueHead.from_pretrained
#peft_model = PeftModel.from_pretrained(f"./models/{model_name}", device_map="auto", peft_config=peft_config, is_trainable=True, model_id="./models/t5-small-hl/")



# First method failing : 

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(f"./models/{model_name}", device_map="auto", load_in_8bit=True)

peft_model = get_peft_model(model, peft_config)

peft_model.print_trainable_parameters()

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model, device_map="auto", load_in_8bit=True)

tokenizer = transformers.AutoTokenizer.from_pretrained(f"./models/{model_name}", model_max_length=512)

torch.cuda.empty_cache()



In [None]:
# Import model given mode_name
#model = transformers.T5ForConditionalGeneration.from_pretrained(f"./models/{model_name}")

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(f"./models/{model_name}", device_map="auto", load_in_8bit=True, peft_config=peft_config)

# Get peft model
peft_model = get_peft_model(model, peft_config)

peft_model.print_trainable_parameters()

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model, device_map="auto", load_in_8bit=True)


tokenizer = transformers.AutoTokenizer.from_pretrained(f"./models/{model_name}", model_max_length=512)

torch.cuda.empty_cache()



In [None]:
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q", "v"],
)

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    f"./models/{model_name}/", device_map="auto", load_in_8bit=True, peft_config=peft_config
)

tokenizer = transformers.AutoTokenizer.from_pretrained(f"./models/{model_name}", model_max_length=512)
torch.cuda.empty_cache()

In [None]:
def get_inputs_target(e):
    answer_start = e["answers"]["answer_start"][0]
    # add highlight token to context
    ans_len = len(e["answers"]["text"][0])

    if HIGHLIGHT:
        e["context"] = (
            e["context"][:answer_start]
            + " "
            + HIGHLIGHT_ANSWER
            + " "
            + e["context"][answer_start : answer_start + ans_len]
            + " "
            + HIGHLIGHT_ANSWER
            + " "
            + e["context"][answer_start + ans_len :]
        )

    return {
        # answer + context
        "inputs": f'generate question: {TOKEN_ANSWER} {e["answers"]["text"][0]} {TOKEN_END_ANSWER} {TOKEN_CONTEXT} {e["context"]} {TOKEN_END_CONTEXT}',
        # question
        "target": f'{TOKEN_QUESTION} {e["question"]} {TOKEN_END_QUESTION}',
    }


def preprocess_squad_dataset(dataset_name="squad", split="train"):
    dataset = datasets.load_dataset(dataset_name, split=split)
    # Add question, answer and context tokens to dataset in a new column named text
    dataset = dataset.map(
        lambda e: {
            # answer + context
            # changed to 'query' for PPO
            "query": f'generate question: {TOKEN_ANSWER} {e["answers"]["text"][0]} {TOKEN_END_ANSWER} {TOKEN_CONTEXT} {e["context"]} {TOKEN_END_CONTEXT}',
            # question
            "target": f'{TOKEN_QUESTION} {e["question"]} {TOKEN_END_QUESTION}',
        },
        num_proc=NPROC,
    )

    return dataset

# Need to have training dataset aligned with PPO input format

train_dataset = preprocess_squad_dataset(dataset_name="squad", split="train") 

In [None]:
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["query"], return_tensors="pt", truncation=True, padding="max_length", max_length=512).squeeze()
    return sample

# Tokenize the dataset
train_dataset = train_dataset.map(tokenize, num_proc=NPROC)

In [None]:
config = PPOConfig(
    learning_rate=1e-5,
    log_with='tensorboard',
    project_kwargs={'logging_dir': f'./logs/{model_name}'},
    batch_size=1,
)

ppo_trainer = PPOTrainer(
    model=model,
    config=config,
    dataset=train_dataset,
    tokenizer=tokenizer,
)

In [None]:
# https://huggingface.co/docs/trl/main/en/how_to_train#how-to-generate-text-for-training
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

In [None]:
def reward_model(prediction, example):
    """
    this function will return a reward function for PPO
    """

    context = example["context"]
    question = example["question"]
    answer = example["answers"]["text"][0]

    reward = bertscore.compute(
        predictions=[prediction],
        references=[question],
        lang="en",
        model_type="bert-base-uncased",
    )["f1"].item()

    repetition_penalty = -1.0 if answer.lower() in prediction.lower() else 1.0
    question_word_penalty = -0.5 if question.split()[0].lower() != prediction.split()[0].lower() else 0.5
    ans_in_question_penalty = -1.0 if answer.lower() in question.lower() else 1.0
    question_length_penalty = -0.5 if len(question.split()) > average_question_length else 0.5

    reward += repetition_penalty + question_word_penalty + ans_in_question_penalty + question_length_penalty
    return reward

In [None]:
for batch in tqdm(ppo_trainer.dataloader):
    query_tensors = batch["input_ids"]
    response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
    batch["prediction"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    print(batch["prediction"])
    pipe_outputs = reward_model(*zip(batch["query"], batch["prediction"]))
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)