In [1]:
!pip3 install transformers datasets torch accelerate trl rouge_score



In [2]:
import torch
from torch import nn
from datasets import load_dataset
from transformers import (
    GPT2Config,
    GPT2Tokenizer,
    DataCollatorWithPadding,
    GPT2LMHeadModel
)
from trl import PPOConfig, PPOTrainer
from rouge_score import rouge_scorer

  from .autonotebook import tqdm as notebook_tqdm


# 1) Model w/ Value Head 

In [3]:
class GPT2WithValueHead(nn.Module):
    def __init__(self, base_model_name="gpt2"):
        super().__init__()
        self.config = GPT2Config.from_pretrained(base_model_name)
        from transformers import GPT2LMHeadModel
        self.transformer = GPT2LMHeadModel.from_pretrained(
            base_model_name, config=self.config
        )
        self.value_head = nn.Linear(self.config.hidden_size, 1, bias=False)
        self.value_head.weight.data.normal_(
            0.0, self.config.initializer_range
        )

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        lm_out = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            **kwargs
        )
        hidden = lm_out.hidden_states[-1]                # (B, T, H)
        values = self.value_head(hidden).squeeze(-1)      # (B, T)
        return {
            "logits": lm_out.logits,
            "loss": lm_out.loss,
            "values": values,
        }



# 2) Data, Tokenizer, Model 

In [None]:
train_ds = load_dataset("cnn_dailymail", "3.0.0", split="train[:100000]")
val_ds  = load_dataset("cnn_dailymail", "3.0.0", split="validation[:10000]")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

class MyCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad = DataCollatorWithPadding(tokenizer)
    def __call__(self, features):
        refs = [f.pop("refs") for f in features]
        batch = self.pad(features)
        batch["refs"] = refs
        return batch

data_collator = MyCollator(tokenizer)

# 3) Reward model

In [5]:
class RougeRewardModel(nn.Module):
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

    def forward(self, prompts, response_toks, refs, **kwargs):
        hyps = self.tokenizer.batch_decode(response_toks, skip_special_tokens=True)
        scores = [
            self.scorer.score(r, h)["rougeL"].fmeasure
            for r, h in zip(refs, hyps)
        ]
        r = torch.tensor(scores, dtype=torch.float32, device=response_toks.device)
        # normalize
        return (r - r.mean()) / (r.std() + 1e-8)


# 4) PPO Setup

In [8]:
ppo_config = PPOConfig(
    learning_rate=1.4e-5,
    per_device_train_batch_size=4,
    num_ppo_epochs=4,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.1,
    kl_coef=0.1,
    gamma=1.0,
    lam=0.95,
    num_sample_generations=10,
    response_length=64,
    output_dir="gpt2-ppo-summarization",
    overwrite_output_dir=True,
    logging_dir="runs/ppo_summarization",
    logging_steps=100,
    save_steps=500,
)

# 5) Instantiate Models

In [9]:
model = GPT2WithValueHead("gpt2")
ref_model = GPT2WithValueHead("gpt2")
ref_model.load_state_dict(model.state_dict())
ref_model.eval()

model.generation_config = model.transformer.generation_config
ref_model.generation_config = ref_model.transformer.generation_config

reward_model = RougeRewardModel(tokenizer)
value_model = model.transformer

In [10]:
trainer = PPOTrainer(
    args=ppo_config,
    processing_class=tokenizer,
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    train_dataset=train_ds,
    value_model=value_model,
    data_collator=data_collator,
    eval_dataset=val_ds,
)

In [11]:
trainer.train()

trainer.save_model("gpt2-ppo-summarization")
tokenizer.save_pretrained("gpt2-ppo-summarization")

===training policy===


ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['article', 'highlights', 'id']

In [None]:

model.eval()
all_scores = []
for i in range(0, len(val_ds), 4):
    batch = val_ds[i : i + 4]
    arts = batch["article"]
    refs = batch["highlights"]
    prompts = [f"Summarize:\n{a}\nTL;DR:" for a in arts]
    toks = trainer.generate(prompts, max_length=64)
    hyps = tokenizer.batch_decode(toks, skip_special_tokens=True)
    for ref, hyp in zip(refs, hyps):
        all_scores.append(
            rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
            .score(ref, hyp)["rougeL"].fmeasure
        )

avg = sum(all_scores) / len(all_scores)
print(f"\n>> Final Validation ROUGE-L: {avg:.4f}")
