In [None]:
!pip install datasets==2.14.6 gcsfs==2023.6.0 fsspec==2023.6.0 --quiet
!pip install evaluate bert-score --quiet
!pip install transformers==4.28.1 trl==0.4.7 peft==0.2.0 --quiet
!pip install numpy==1.26.4 --quiet


# Implement PPO loop for fine-tuning T5 using the trained reward model

In [None]:
# PPO Fine-Tuning of T5 using Trained Reward Model with KL Divergence
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Run everything on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Policy under training: T5-base wrapped with a value head (PPO critic), plus its tokenizer.
model_name = "t5-base"
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define and load Reward Model
class RewardModel(nn.Module):
    """Scalar reward model: a pretrained encoder followed by a one-output linear head.

    The final hidden state at position 0 of each sequence is passed through
    dropout and a linear layer, producing one score per sequence.

    Args:
        base_model_name: Hub id or local path of the encoder passed to
            ``AutoModel.from_pretrained``.
    """

    def __init__(self, base_model_name="distilbert-base-uncased"):
        super().__init__()
        # Attribute names are the state_dict prefixes (encoder.*, regressor.*), so they
        # must stay in sync with the saved checkpoint; keep registration order as well.
        encoder = AutoModel.from_pretrained(base_model_name)
        hidden_size = encoder.config.hidden_size
        self.encoder = encoder
        self.dropout = nn.Dropout(0.1)
        self.regressor = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return a 1-D tensor holding one reward score per sequence in the batch."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the first position (presumably [CLS] for BERT-style encoders).
        pooled = encoded.last_hidden_state[:, 0, :]
        scores = self.regressor(self.dropout(pooled))
        # (batch, 1) -> (batch,)
        return scores.squeeze(-1)

# Load the trained reward-model weights.
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints you trust
# (torch>=1.13 also accepts weights_only=True for plain tensor state dicts).
state_dict = torch.load("best_reward_model.pt", map_location=device)

# The checkpoint stores the scalar head as `head.*`, while RewardModel names it `regressor.*`.
# Rename only keys that are actually present, so a checkpoint already using `regressor.*`
# also loads instead of raising KeyError. load_state_dict stays strict, so genuinely
# missing or unexpected keys are still reported. Renaming in place keeps any metadata
# attached to the loaded state dict.
for old_key, new_key in (("head.weight", "regressor.weight"), ("head.bias", "regressor.bias")):
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)

reward_model = RewardModel("distilbert-base-uncased")
reward_model.load_state_dict(state_dict)
# Used only as a frozen scorer: eval() disables the dropout inside RewardModel.
reward_model.eval().to(device)

# Reward tokenizer (must match the encoder the reward model was built on)
reward_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# PPO hyper-parameters. `batch_size` is also the number of articles sampled per step
# in the training loop; per trl's PPOConfig, each batch is optimised in mini-batches
# of `mini_batch_size`.
ppo_config = PPOConfig(
    batch_size=4,
    mini_batch_size=2,
    learning_rate=5e-6,
    optimize_cuda_cache=True,
)

# No reference model is passed. NOTE(review): trl is expected to build the frozen
# reference copy used for the KL penalty itself in that case; confirm for trl 0.4.7.
ppo_trainer = PPOTrainer(tokenizer=tokenizer, model=model, config=ppo_config)

# Load a small subset of CNN/DailyMail
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")

NUM_PPO_STEPS = 100  # number of PPO updates; each samples ppo_config.batch_size articles
failed_steps = []    # 1-based indices of steps whose ppo_trainer.step raised

# PPO fine-tuning loop
for step in range(NUM_PPO_STEPS):
    # Deterministic per-step sample: reshuffle with the step index as seed, take the first rows.
    batch = dataset.shuffle(seed=step).select(range(ppo_config.batch_size))
    queries = ["summarize: " + article for article in batch["article"]]

    # Generate one summary per query and keep the exact, unpadded token ids.
    # trl's PPOTrainer.step pads these lists itself, with an all-ones mask per tensor.
    # Pre-padded tensors would therefore make pad tokens count as real tokens in the
    # logprob/KL/value terms. Re-tokenising the decoded text would also drop T5's
    # leading decoder-start token from the response.
    query_tensors = []
    response_tensors = []
    responses = []
    for q in queries:
        inputs = tokenizer(q, return_tensors="pt", truncation=True).to(device)
        output_ids = model.generate(**inputs, max_length=128)
        query_tensors.append(inputs["input_ids"][0])
        response_tensors.append(output_ids[0])
        responses.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))

    # Compute rewards: one scalar tensor per response from the frozen reward model.
    # NOTE(review): only the summary text is scored, not the article; confirm this
    # matches how best_reward_model.pt was trained.
    rewards = []
    for resp in responses:
        reward_inputs = reward_tokenizer(
            resp, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(device)
        with torch.no_grad():
            score = reward_model(reward_inputs["input_ids"], reward_inputs["attention_mask"])
        rewards.append(torch.tensor(score.item(), dtype=torch.float32, device=device))

    try:
        ppo_trainer.step(query_tensors, response_tensors, rewards)
        print(f"PPO Step {step+1}/{NUM_PPO_STEPS} completed")
    except Exception as e:
        # Best-effort: a failing update skips the step instead of aborting the run,
        # but it is recorded so the skip does not go unnoticed.
        failed_steps.append(step + 1)
        print(f"Skipping step {step+1} due to error: {e}")

if failed_steps:
    print(f"{len(failed_steps)}/{NUM_PPO_STEPS} PPO steps were skipped: {failed_steps}")

# Save final model (trailing ';' suppresses the tuple of written file paths)
model.save_pretrained("t5-ppo-finetuned")
tokenizer.save_pretrained("t5-ppo-finetuned");


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


PPO Step 1/100 completed




PPO Step 2/100 completed




PPO Step 3/100 completed




PPO Step 4/100 completed




PPO Step 5/100 completed




PPO Step 6/100 completed




PPO Step 7/100 completed




PPO Step 8/100 completed




PPO Step 9/100 completed




PPO Step 10/100 completed




PPO Step 11/100 completed




PPO Step 12/100 completed




PPO Step 13/100 completed




PPO Step 14/100 completed




PPO Step 15/100 completed




PPO Step 16/100 completed




PPO Step 17/100 completed




PPO Step 18/100 completed




PPO Step 19/100 completed




PPO Step 20/100 completed




PPO Step 21/100 completed




PPO Step 22/100 completed




PPO Step 23/100 completed




PPO Step 24/100 completed




PPO Step 25/100 completed




PPO Step 26/100 completed




PPO Step 27/100 completed




PPO Step 28/100 completed




PPO Step 29/100 completed




PPO Step 30/100 completed




PPO Step 31/100 completed




PPO Step 32/100 completed




PPO Step 33/100 completed




PPO Step 34/100 completed




PPO Step 35/100 completed




PPO Step 36/100 completed




PPO Step 37/100 completed




PPO Step 38/100 completed




PPO Step 39/100 completed




PPO Step 40/100 completed




PPO Step 41/100 completed




PPO Step 42/100 completed




PPO Step 43/100 completed




PPO Step 44/100 completed




PPO Step 45/100 completed




PPO Step 46/100 completed




PPO Step 47/100 completed




PPO Step 48/100 completed




PPO Step 49/100 completed




PPO Step 50/100 completed




PPO Step 51/100 completed




PPO Step 52/100 completed




PPO Step 53/100 completed




PPO Step 54/100 completed




PPO Step 55/100 completed




PPO Step 56/100 completed




PPO Step 57/100 completed




PPO Step 58/100 completed




PPO Step 59/100 completed




PPO Step 60/100 completed




PPO Step 61/100 completed




PPO Step 62/100 completed




PPO Step 63/100 completed




PPO Step 64/100 completed




PPO Step 65/100 completed




PPO Step 66/100 completed




PPO Step 67/100 completed




PPO Step 68/100 completed




PPO Step 69/100 completed




PPO Step 70/100 completed




PPO Step 71/100 completed




PPO Step 72/100 completed




PPO Step 73/100 completed




PPO Step 74/100 completed




PPO Step 75/100 completed




PPO Step 76/100 completed




PPO Step 77/100 completed




PPO Step 78/100 completed




PPO Step 79/100 completed




PPO Step 80/100 completed




PPO Step 81/100 completed




PPO Step 82/100 completed




PPO Step 83/100 completed




PPO Step 84/100 completed




PPO Step 85/100 completed




PPO Step 86/100 completed




PPO Step 87/100 completed




PPO Step 88/100 completed




PPO Step 89/100 completed




PPO Step 90/100 completed




PPO Step 91/100 completed




PPO Step 92/100 completed




PPO Step 93/100 completed




PPO Step 94/100 completed




PPO Step 95/100 completed




PPO Step 96/100 completed




PPO Step 97/100 completed




PPO Step 98/100 completed




PPO Step 99/100 completed




PPO Step 100/100 completed


('t5-ppo-finetuned/tokenizer_config.json',
 't5-ppo-finetuned/special_tokens_map.json',
 't5-ppo-finetuned/spiece.model',
 't5-ppo-finetuned/added_tokens.json',
 't5-ppo-finetuned/tokenizer.json')

# Evaluate with metrics (ROUGE, METEOR, BERTScore)

In [None]:
import evaluate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Score the PPO-tuned policy on a small validation slice with ROUGE, METEOR and BERTScore.
# torch / load_dataset are imported here too so this cell does not silently rely on the PPO cell.
import torch
from datasets import load_dataset

rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
bertscore = evaluate.load("bertscore")

# Pick the device instead of hard-coding .cuda(), so evaluation also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load fine-tuned model
model = AutoModelForSeq2SeqLM.from_pretrained("t5-ppo-finetuned").to(device)
tokenizer = AutoTokenizer.from_pretrained("t5-ppo-finetuned")

# Load validation split
eval_data = load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")

predictions = []
references = []

# One article at a time, with the same prompt prefix and max_length used during PPO training.
for example in eval_data:
    prompt = "summarize: " + example["article"]
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    summary_ids = model.generate(**inputs, max_length=128)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    predictions.append(summary)
    references.append(example["highlights"])

rouge_result = rouge.compute(predictions=predictions, references=references)
meteor_result = meteor.compute(predictions=predictions, references=references)
bertscore_result = bertscore.compute(predictions=predictions, references=references, lang="en")

print("ROUGE:", rouge_result)
print("METEOR:", meteor_result)
# BERTScore yields per-example lists; report the mean of each metric over the slice.
print("BERTScore:", {
    metric: sum(bertscore_result[metric]) / len(bertscore_result[metric])
    for metric in ("precision", "recall", "f1")
})