The dependencies for baseline

In [None]:
# Install all necessary libraries
!pip install --upgrade \
    transformers datasets evaluate rouge_score nltk sacrebleu tqdm

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sacrebleu
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metada

Baseline training pipeline. Bart-large-cnn model with all evaluation metrics, rouge1, rouge2, rougeL, meteor and sacrebleu metrics.

In [None]:
import warnings, logging, torch, numpy as np
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    logging as hf_logging
)
import evaluate
import nltk

# ─── Silence logs & download NLTK data ────────────────────────────────────────
warnings.filterwarnings("ignore")
hf_logging.set_verbosity_error()
logging.getLogger("datasets").setLevel(logging.ERROR)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)

# ─── CONFIG ───────────────────────────────────────────────────────────────────
MODEL_NAME        = "facebook/bart-large-cnn"
OUTPUT_DIR        = "./baseline_bart_cnn"
TRAIN_SIZE        = 100    # 100 for quick debug
VAL_SIZE          = 10
TEST_SIZE         = 10
BATCH_SIZE        = 32       # adjust for different runtimes
NUM_EPOCHS        = 4       # bump to 3–5 for full baseline
MAX_INPUT_LENGTH  = 512
MAX_OUTPUT_LENGTH = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# ─── 1. LOAD & SAMPLE ─────────────────────────────────────────────────────────
ds = load_dataset("cnn_dailymail", "3.0.0")
def sample(split, n):
    d = ds[split]
    return d.shuffle(seed=42).select(range(n)) if n else d

train_ds = sample("train", TRAIN_SIZE)
val_ds   = sample("validation", VAL_SIZE)
test_ds  = sample("test", TEST_SIZE)
print(f"Sizes → train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}")

# ─── 2. TOKENIZER & PREPROCESS ────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def preprocess(batch):
    inp = tokenizer(batch["article"],
                    max_length=MAX_INPUT_LENGTH,
                    truncation=True,
                    padding="max_length")
    lbl = tokenizer(batch["highlights"],
                    max_length=MAX_OUTPUT_LENGTH,
                    truncation=True,
                    padding="max_length").input_ids
    inp["labels"] = lbl
    return inp

train_tok = train_ds.map(preprocess, batched=True, remove_columns=["article","highlights"])
val_tok   = val_ds.map(preprocess, batched=True, remove_columns=["article","highlights"])
test_tok  = test_ds.map(preprocess, batched=True, remove_columns=["article","highlights"])

# ─── 3. MODEL & TRAINER SETUP ─────────────────────────────────────────────────
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir             = OUTPUT_DIR,
    num_train_epochs       = NUM_EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size  = BATCH_SIZE,
    logging_steps          = 50,
    eval_strategy    = "epoch",
    save_strategy          = "epoch",
    save_total_limit       = 2,
    predict_with_generate  = True,
    load_best_model_at_end = True,
    metric_for_best_model  = "rougeL",
    greater_is_better      = True,
    report_to              = "none",
)
# ─── 4. LOAD METRICS ──────────────────────────────────────────────────────────
rouge     = evaluate.load("rouge")
meteor    = evaluate.load("meteor")
sacrebleu = evaluate.load("sacrebleu")

def compute_metrics(preds_labels):
    preds, labels = preds_labels
    if isinstance(preds, tuple):
        preds = preds[0]
    # Decode
    decoded_preds  = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute each metric
    r = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    m = meteor.compute(predictions=decoded_preds, references=decoded_labels)
    sb = sacrebleu.compute(predictions=decoded_preds, references=decoded_labels)

    # Extract the floats directly
    return {
        "rouge1":    r["rouge1"],
        "rouge2":    r["rouge2"],
        "rougeL":    r["rougeL"],
        "meteor":    m["meteor"],
        "sacrebleu": sb["score"],
    }

trainer = Seq2SeqTrainer(
    model            = model,
    args             = training_args,
    train_dataset    = train_tok,
    eval_dataset     = val_tok,
    tokenizer        = tokenizer,
    data_collator    = collator,
    compute_metrics  = compute_metrics,
)


# ─── 5. SUPERVISED TRAINING ───────────────────────────────────────────────────
print(" Training baseline...")
trainer.train()

# ─── 6. EVALUATE ON TEST ─────────────────────────────────────────────────────
print("\n Test metrics:")
test_metrics = trainer.evaluate(test_tok)
for k,v in test_metrics.items():
    if k.startswith("eval_"):
        print(f"{k}: {v:.4f}")

# ─── 7. SAMPLE OUTPUTS (FULL ARTICLE vs SUMMARY vs HIGHLIGHTS) ───────────────
print("\n Full Article | Generated Summary | Reference Highlights (3 samples)")
for i in range(min(3, len(test_ds))):
    sample = test_ds[i]
    article = sample["article"]
    reference = sample["highlights"]

    inputs = tokenizer(article, return_tensors="pt",
                       truncation=True, max_length=MAX_INPUT_LENGTH).to(device)
    with torch.no_grad():
        out = model.generate(**inputs,
                             max_length=MAX_OUTPUT_LENGTH,
                             num_beams=4,
                             early_stopping=True)
    generated = tokenizer.decode(out[0], skip_special_tokens=True)

    print(f"\n--- SAMPLE {i+1} ---")
    print("\nFULL ARTICLE:\n", article)
    print("\nGENERATED SUMMARY:\n", generated)
    print("\nREFERENCE HIGHLIGHTS:\n", reference)
    print("\n" + "="*80)


Device: cpu


README.md:   0%|          | 0.00/15.6k [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Sizes → train: 100, val: 10, test: 10


config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.02k [00:00<?, ?B/s]

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


Downloading builder script:   0%|          | 0.00/8.15k [00:00<?, ?B/s]

 Training baseline...
{'eval_loss': 4.164383888244629, 'eval_rouge1': 0.4346190151557174, 'eval_rouge2': 0.1914825565138844, 'eval_rougeL': 0.29184170881184057, 'eval_meteor': 0.3852253480667711, 'eval_sacrebleu': 15.510409633704695, 'eval_runtime': 73.5098, 'eval_samples_per_second': 0.136, 'eval_steps_per_second': 0.014, 'epoch': 1.0}
{'eval_loss': 2.2790589332580566, 'eval_rouge1': 0.42064773040579495, 'eval_rouge2': 0.17819284667808355, 'eval_rougeL': 0.2821752698810519, 'eval_meteor': 0.3806397977534133, 'eval_sacrebleu': 13.26528689509695, 'eval_runtime': 56.7855, 'eval_samples_per_second': 0.176, 'eval_steps_per_second': 0.018, 'epoch': 2.0}
{'eval_loss': 1.7856388092041016, 'eval_rouge1': 0.4419297995079012, 'eval_rouge2': 0.18768261435697298, 'eval_rougeL': 0.2922554082170653, 'eval_meteor': 0.39644982163974235, 'eval_sacrebleu': 14.16258338876916, 'eval_runtime': 50.3796, 'eval_samples_per_second': 0.198, 'eval_steps_per_second': 0.02, 'epoch': 3.0}
{'eval_loss': 1.6572692394

summary generation settings, for different variety of generated samples to choose from

In [None]:
from itertools import product

# define a few generation settings
gen_configs = [
  {"num_beams": 4,  "length_penalty": 1.0, "max_length": 128},
  {"num_beams": 4,  "length_penalty": 2.0, "max_length": 128},
  {"do_sample": True, "top_k": 50,  "max_length": 128},
  {"do_sample": True, "top_p": 0.9, "max_length": 128},
]

def generate_candidates(article: str):
    inputs = tokenizer(article, return_tensors="pt",
                       truncation=True, max_length=512).to(device)
    candidates = []
    for cfg in gen_configs:
        out = model.generate(**inputs, **cfg)
        cand = tokenizer.decode(out[0], skip_special_tokens=True)
        candidates.append(cand)
    return candidates


We are exporting 100 articles and forming 4 candidates for each of them.

In [None]:
import os, torch, pandas as pd
from tqdm.auto import tqdm

# ─── 0) Max out CPU threads ───────────────────────────────────────────────────
n_threads = os.cpu_count() or 1
torch.set_num_threads(n_threads)
DEVICE = torch.device("cpu")
print(f"Using {n_threads} CPU threads on {DEVICE}")

# ─── 1) Sample size & batch settings ──────────────────────────────────────────
NUM_SAMPLES = 100      # total articles to export
BATCH_GEN   = 8        # articles per generate() batch
MAX_IN_LEN  = 512
MAX_OUT_LEN = 128

gen_configs = [
    {"num_beams": 4, "length_penalty": 1.0, "max_length": MAX_OUT_LEN},
    {"num_beams": 4, "length_penalty": 2.0, "max_length": MAX_OUT_LEN},
    {"do_sample": True, "top_k": 50,  "max_length": MAX_OUT_LEN},
    {"do_sample": True, "top_p": 0.9, "max_length": MAX_OUT_LEN},
]

# ─── 2) Grab first N articles ─────────────────────────────────────────────────
split    = "validation"    # choose "train"/"validation"/"test"
subset   = ds[split].select(range(NUM_SAMPLES))
articles = [ex["article"] for ex in subset]

# ─── 3) Batched generation per config ─────────────────────────────────────────
all_cands = {i: [] for i in range(len(gen_configs))}

for cfg_idx, cfg in enumerate(gen_configs):
    print(f"\n Generating config #{cfg_idx+1}: {cfg}")
    for start in tqdm(range(0, NUM_SAMPLES, BATCH_GEN), desc="Batches"):
        batch = articles[start : start + BATCH_GEN]
        tokens = tokenizer(
            batch,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=MAX_IN_LEN
        ).to(DEVICE)
        outs = model.generate(**tokens, **cfg)
        texts = tokenizer.batch_decode(outs, skip_special_tokens=True)
        all_cands[cfg_idx].extend(texts)

    assert len(all_cands[cfg_idx]) == NUM_SAMPLES, "Wrong count"

# ─── 4) Build DataFrame & Save CSV ────────────────────────────────────────────
rows = []
for idx in range(NUM_SAMPLES):
    row = {"article_id": idx, "article": articles[idx]}
    for cfg_idx in range(len(gen_configs)):
        row[f"candidate_{cfg_idx+1}"] = all_cands[cfg_idx][idx]
    rows.append(row)

df = pd.DataFrame(rows)
csv_path = "summarization_candidates_100_batched.csv"
df.to_csv(csv_path, index=False)
print(f"\nWrote {csv_path} with columns: {df.columns.tolist()}")


Using 8 CPU threads on cpu

▶ Generating config #1: {'num_beams': 4, 'length_penalty': 1.0, 'max_length': 128}


Batches:   0%|          | 0/13 [00:00<?, ?it/s]


▶ Generating config #2: {'num_beams': 4, 'length_penalty': 2.0, 'max_length': 128}


Batches:   0%|          | 0/13 [00:00<?, ?it/s]


▶ Generating config #3: {'do_sample': True, 'top_k': 50, 'max_length': 128}


Batches:   0%|          | 0/13 [00:00<?, ?it/s]


▶ Generating config #4: {'do_sample': True, 'top_p': 0.9, 'max_length': 128}


Batches:   0%|          | 0/13 [00:00<?, ?it/s]


Wrote summarization_candidates_100_batched.csv with columns: ['article_id', 'article', 'candidate_1', 'candidate_2', 'candidate_3', 'candidate_4']


 Here we are generating combinations,choosing 2 candidates from 4 which gives us 6 combinations and thus 600 combinations.

In [None]:
import pandas as pd
from itertools import combinations

# 1) Load your 4-candidate CSV
df = pd.read_csv("summarization_candidates_100_batched.csv")

# 2) Build pairwise rows
pairs = []
for _, row in df.iterrows():
    art_id = row["article_id"]
    article = row["article"]
    # pull the four candidates into a list
    cands = [row[f"candidate_{i}"] for i in range(1,5)]
    # all i<j combinations
    for i,j in combinations(range(4), 2):
        pairs.append({
            "article_id": art_id,
            "article":    article,
            "summary_A":  cands[i],
            "summary_B":  cands[j],
            "preferred":  ""      # leave blank for annotator to fill "A" or "B"
        })

pair_df = pd.DataFrame(pairs)
out_path = "preference_pairs.csv"
pair_df.to_csv(out_path, index=False)
print(f"→ Wrote {out_path} with {len(pair_df)} pairwise rows (4-way → pairs)")


→ Wrote preference_pairs.csv with 600 pairwise rows (4-way → pairs)


Here i am filtering out the similar pairs of candidates and ended up with exactly 365 pairs which i specifically annotates after inspecting

In [None]:
import pandas as pd
import difflib
import textwrap

# Load your filtered (dissimilar) pairs
pairs = pd.read_csv("dissimilar_pairs.csv")

# Function to pretty-print multi-paragraph text
def print_paragraphs(text, width=80):
    for para in text.split("\n"):
        para = para.strip()
        if not para:
            print()  # preserve blank lines
        else:
            print(textwrap.fill(para, width=width))
    print()  # extra newline at end

# Annotation loop
annotated = []
for idx, row in pairs.iterrows():
    print("\n" + "="*80)
    print(f"PAIR {idx+1}/{len(pairs)}    (sim={row.sim:.2f})\n")

    print("FULL ARTICLE:\n")
    print_paragraphs(row.article, width=100)

    print("---- SUMMARY A ----\n")
    print_paragraphs(row.summary_A, width=100)

    print("---- SUMMARY B ----\n")
    print_paragraphs(row.summary_B, width=100)

    choice = None
    while choice not in ("A","B"):
        choice = input("Which do you prefer? (A/B): ").strip().upper()

    annotated.append({
        "article_id": row.article_id,
        "preferred":  choice,
        "sim":        row.sim
    })

# Save results
out_df = pd.DataFrame(annotated)
out_df.to_csv("annotated_dissimilar_paragraphs.csv", index=False)
print("\n\BABABABAABABA Saved to annotated_dissimilar_paragraphs.csv")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
borders; one that has not yet decided whether it is a Western society or a Middle Eastern one; one
that cannot decide whether it wants to be religious or secular, Jewish or bi-national. All of these
critical issues -- none of them decided on -- have been cast aside, ignored, covered up or denied by
a country that has busied itself with the important business of recycled bottles at the prime
minister's residence. There is a big elephant in the room, but Israel is turning its back to it.
There is a big elephant in the room, but Israel believes that if nobody talks about it, the elephant
does not exist. This elephant is absent from the Israeli discourse on a day-to-day basis, and it is
absent during elections -- a time when public discourse should be only be focused on what really
matters. The elephant in the Israeli room is the unending occupation of Palestinian territories, and
nobody is talking about it. Most of the parti

Here we are training a distlbert-uncased (i dont know why), with the files below with self-explainatory names.

In [None]:
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    Trainer, TrainingArguments
)
import numpy as np
import evaluate

# ─── 1) Load original pair text + your annotations ────────────────────────────
pairs_df = pd.read_csv("dissimilar_pairs.csv")
ann_df   = pd.read_csv("annotated_dissimilar_paragraphs.csv")

# Quick sanity check
print("pairs_df cols:", pairs_df.columns.tolist())
print("ann_df   cols:",   ann_df.columns.tolist())

# ─── 2) Merge so each annotated row has article, summary_A/B, and preferred ───
# We'll join on article_id AND sim to be safe
merged = ann_df.merge(
    pairs_df,
    on=["article_id", "sim"],
    how="left",
    suffixes=("_ann","")
)
assert merged["article"].notna().all(), "some rows failed to merge!"

# Keep only the columns we need
merged = merged[[
    "article_id",
    "article",
    "summary_A",
    "summary_B",
    "preferred"
]]
print("After merge, examples:", len(merged))
print(merged.head())

# ─── 3) Build a HuggingFace Dataset with (text1, text2, label) ───────────────
records = []
for _, row in merged.iterrows():
    art   = row["article"].replace("\n"," ")
    A     = row["summary_A"].strip()
    B     = row["summary_B"].strip()
    label = 0 if row["preferred"] == "A" else 1
    records.append({
        "text1": f"{art} </s> {A}",
        "text2": B,
        "label": label
    })

hf_ds = Dataset.from_pandas(pd.DataFrame(records))
split = hf_ds.train_test_split(test_size=0.2, seed=42)
train_ds, val_ds = split["train"], split["test"]
print("Train/Val sizes:", len(train_ds), len(val_ds))

# ─── 4) Tokenize ──────────────────────────────────────────────────────────────
MODEL_NAME = "distilbert-base-uncased"
tokenizer  = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess(batch):
    return tokenizer(
        batch["text1"], batch["text2"],
        truncation=True, padding="max_length", max_length=256
    )

train_ds = train_ds.map(preprocess, batched=True, remove_columns=["text1","text2"])
val_ds   = val_ds.map(preprocess,   batched=True, remove_columns=["text1","text2"])

# ─── 5) Setup & train the reward model ─────────────────────────────────────────
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2
)

accuracy = evaluate.load("accuracy")
def compute_metrics(preds_labels):
    preds, labels = preds_labels
    preds = np.argmax(preds, axis=1)
    return accuracy.compute(predictions=preds, references=labels)

training_args = TrainingArguments(
    output_dir              = "./reward_model",
    eval_strategy     = "epoch",
    save_strategy           = "epoch",
    learning_rate           = 2e-5,
    per_device_train_batch_size = ,
    per_device_eval_batch_size  = 8,
    num_train_epochs        = 3,
    load_best_model_at_end  = True,
    metric_for_best_model   = "accuracy",
    logging_steps           = 10,
    report_to               = "none",
)

trainer = Trainer(
    model           = model,
    args            = training_args,
    train_dataset   = train_ds,
    eval_dataset    = val_ds,
    tokenizer       = tokenizer,
    compute_metrics = compute_metrics,
)

print(" Fine-tuning reward model…")
trainer.train()

print("\n Validation metrics:", trainer.evaluate())
trainer.save_model("./reward_model_best")
print(" Reward model saved to ./reward_model_best")


pairs_df cols: ['article_id', 'article', 'summary_A', 'summary_B', 'preferred', 'sim']
ann_df   cols: ['article_id', 'preferred', 'sim']
After merge, examples: 537
   article_id                                            article  \
0           0  (CNN)Share, and your gift will be multiplied. ...   
1           1  (CNN)On the 6th of April 1996, San Jose Clash ...   
2           1  (CNN)On the 6th of April 1996, San Jose Clash ...   
3           1  (CNN)On the 6th of April 1996, San Jose Clash ...   
4           1  (CNN)On the 6th of April 1996, San Jose Clash ...   

                                           summary_A  \
0  Zully Broussard selflessly gave one of her kid...   
1  The first Major League Soccer match took place...   
2  The first Major League Soccer match took place...   
3  The first Major League Soccer match took place...   
4  The first Major League Soccer match took place...   

                                           summary_B  preferred  
0  Zully Broussard selfl

Map:   0%|          | 0/429 [00:00<?, ? examples/s]

Map:   0%|          | 0/108 [00:00<?, ? examples/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


 Fine-tuning reward model…


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0071,0.004299,1.0
2,0.0026,0.002068,1.0
3,0.0021,0.001737,1.0



 Validation metrics: {'eval_loss': 0.004298726562410593, 'eval_accuracy': 1.0, 'eval_runtime': 11.0037, 'eval_samples_per_second': 9.815, 'eval_steps_per_second': 1.272, 'epoch': 3.0}
 Reward model saved to ./reward_model_best


  We are comparing the reward model we trained and the bart-large-mnli model.  code compares two models for predicting human preferences between two summaries (A and B) of the same article. I didn tklnow what other metric could be helpful.

In [None]:
import torch
import pandas as pd
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)
from sklearn.metrics import accuracy_score
from scipy.stats import pearsonr

# ─── Device ───────────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using", DEVICE)

# ─── 1) Reload & merge your annotations + text ────────────────────────────────
pairs_df = pd.read_csv("dissimilar_pairs.csv")                    # full text + sim
ann_df   = pd.read_csv("annotated_dissimilar_paragraphs.csv")     # human A/B

# Clean up stray whitespace
pairs_df.columns = pairs_df.columns.str.strip()
ann_df.columns   = ann_df.columns.str.strip()

merged = ann_df.merge(pairs_df, on=["article_id","sim"], how="left",
                      suffixes=("_ann","_orig"))
merged = merged.drop(columns=["preferred_orig"])  # drop the blank col
print("Examples:", len(merged))

# Prepare human labels
labels = merged["preferred_ann"].map({"A":0,"B":1}).tolist()

# Prepare text tuples
texts = [
    (
      merged.loc[i,"article"].replace("\n"," "),
      merged.loc[i,"summary_A"].strip(),
      merged.loc[i,"summary_B"].strip()
    )
    for i in range(len(merged))
]

# ─── 2) Load your fine-tuned reward model ────────────────────────────────────
our_tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
our_mod = AutoModelForSequenceClassification.from_pretrained(
    "./reward_model_best"
).to(DEVICE).eval()

# ─── 3) Load HF pretrained NLI model (BART-MNLI) ─────────────────────────────
nli_tok = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
nli_mod = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli"
).to(DEVICE).eval()

# ─── 4) Scoring functions ────────────────────────────────────────────────────
import torch.nn.functional as F

def our_score(article, A, B):
    # 1) Tokenize without pulling in any leftover labels
    enc = our_tok(
        article,
        A,
        return_tensors="pt",
        truncation=True,
        padding="longest",
        max_length=256
    )
    input_ids     = enc["input_ids"].to(DEVICE)
    attention_mask= enc["attention_mask"].to(DEVICE)

    # 2) Forward only the tensors we want
    with torch.no_grad():
        logits = our_mod(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).logits.squeeze()

    # 3) Return difference between the "A-preferred" vs "B-preferred" logits
    return (logits[0] - logits[1]).item()


def nli_score(article, A, B):
    # Same for the NLI model, only pass input_ids & attention_mask
    encA = nli_tok(
        article, A,
        return_tensors="pt",
        truncation=True,
        padding="longest",
        max_length=256
    )
    encB = nli_tok(
        article, B,
        return_tensors="pt",
        truncation=True,
        padding="longest",
        max_length=256
    )

    with torch.no_grad():
        logitsA = nli_mod(
            input_ids=encA["input_ids"].to(DEVICE),
            attention_mask=encA["attention_mask"].to(DEVICE)
        ).logits.squeeze()
        logitsB = nli_mod(
            input_ids=encB["input_ids"].to(DEVICE),
            attention_mask=encB["attention_mask"].to(DEVICE)
        ).logits.squeeze()

    # MNLI classes are [contradiction, neutral, entailment]
    probA_entail = torch.softmax(logitsA, dim=-1)[2].item()
    probB_entail = torch.softmax(logitsB, dim=-1)[2].item()
    return probA_entail - probB_entail


def nli_entail_prob(premise, hypothesis):
    inp = nli_tok(premise, hypothesis, return_tensors="pt",
                  truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        logits = nli_mod(**inp).logits.squeeze()
    probs = F.softmax(logits, dim=-1)  # [contradiction, neutral, entailment]
    return probs[2].item()

# ─── 5) Compute scores & predictions ──────────────────────────────────────────
our_scores, nli_scores = [], []
for art, A, B in texts:
    our_scores.append( our_score(art, A, B) )
    nli_scores.append(nli_score(art, A, B))

our_preds = [0 if s>0 else 1 for s in our_scores]
nli_preds = [0 if s>0 else 1 for s in nli_scores]

# ─── 6) Compare to human labels ──────────────────────────────────────────────
print("Our reward-model accuracy:", accuracy_score(labels, our_preds))
print("NLI-model accuracy:       ", accuracy_score(labels, nli_preds))

corr, pval = pearsonr(our_scores, nli_scores)
print(f"Correlation of raw scores: r={corr:.3f}, p={pval:.2e}")


Using cpu
Examples: 537
Our reward-model accuracy: 0.6405959031657356
NLI-model accuracy:        0.45251396648044695
Correlation of raw scores: r=0.008, p=8.61e-01


From here it is just the PPO inegration. The installation snippets are all to figure out what dependencies are right.

In [None]:
# 1) Uninstall any mismatched packages
!pip uninstall -y trl transformers accelerate

# 2) Install compatible versions
!pip install --upgrade \
    transformers==4.33.2 \
    accelerate==0.25.0

# 3) Install the latest TRL from Hugging Face
!pip install git+https://github.com/huggingface/trl.git@main


[0mFound existing installation: transformers 4.51.3
Uninstalling transformers-4.51.3:
  Successfully uninstalled transformers-4.51.3
Found existing installation: accelerate 1.6.0
Uninstalling accelerate-1.6.0:
  Successfully uninstalled accelerate-1.6.0
Collecting transformers==4.33.2
  Downloading transformers-4.33.2-py3-none-any.whl.metadata (119 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.9/119.9 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate==0.25.0
  Downloading accelerate-0.25.0-py3-none-any.whl.metadata (18 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.33.2)
  Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.10.0->accelerate==0.25.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torc

Collecting git+https://github.com/huggingface/trl.git@main
  Cloning https://github.com/huggingface/trl.git (to revision main) to /tmp/pip-req-build-4dao245f
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/trl.git /tmp/pip-req-build-4dao245f
  Resolved https://github.com/huggingface/trl.git to commit 89d44caece2cd7d085bb66d49be55bce7ca2c1ca
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.34.0 (from trl==0.18.0.dev0)
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers>=4.46.0 (from trl==0.18.0.dev0)
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers>=4.46.0->trl==0.18.0.dev0)
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Downloading accele

In [None]:
import transformers, accelerate, trl
print("transformers:", transformers.__version__)
print("accelerate:  ", accelerate.__version__)
print("trl:         ", trl.__version__)

# Try importing the PPO classes
from trl import PPOConfig, PPOTrainer
print(" PPOConfig & PPOTrainer available")


transformers: 4.51.3
accelerate:   1.6.0
trl:          0.18.0.dev0
 PPOConfig & PPOTrainer available


In [None]:
# install the bleeding‑edge trl
!pip install git+https://github.com/huggingface/trl.git@main

# run PPO, saving into a local folder named "ppo_output" (no "./" prefix)
!python -m trl.scripts.ppo \
  --model_name_or_path     facebook/bart-large-cnn \
  --sft_model_path         facebook/bart-large-cnn \
  --reward_model_path      ./reward_model_best \
  --dataset_name           trl-internal-testing/descriptiveness-sentiment-trl-style \
  --dataset_split          train \
  --learning_rate          1e-5 \
  --batch_size             4 \
  --num_ppo_epochs         1 \
  --min_length             50 \
  --max_length             128 \
  --top_k                  4 \
  --top_p                  1.0 \
  --output_dir             ppo_output


Collecting git+https://github.com/huggingface/trl.git@main
  Cloning https://github.com/huggingface/trl.git (to revision main) to /tmp/pip-req-build-tc3nyx1i
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/trl.git /tmp/pip-req-build-tc3nyx1i
  Resolved https://github.com/huggingface/trl.git to commit cc044e35b285be7dc062764b3364e1e684db4c7c
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
/usr/bin/python3: No module named trl.scripts.ppo


In [None]:
import inspect
from trl import PPOConfig, PPOTrainer

print("PPOConfig signature:\n", inspect.signature(PPOConfig), "\n")
print("PPOTrainer signature:\n", inspect.signature(PPOTrainer))


PPOConfig signature:

PPOTrainer signature:
 (args: trl.trainer.ppo_config.PPOConfig, processing_class: Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType], model: torch.nn.modules.module.Module, ref_model: Optional[torch.nn.modules.module.Module], reward_model: torch.nn.modules.module.Module, train_dataset: datasets.arrow_dataset.Dataset, value_model: Optional[torch.nn.modules.module.Module] = None, data_collator: Optional[transformers.data.data_collator.DataCollatorWithPadding] = None, eval_dataset: Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None, optimizers: tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), callbacks: Optional[list[transformers.trainer_callback.TrainerCallback]] = None, peft_config: Optional

In [None]:
# 1) Pin to the TRL / transformers versions we tested
!pip install --quiet \
    transformers==4.26.1 \
    accelerate==0.15.0 \
    trl==0.4.6

# 2) Restart your runtime! (Colab: Runtime → Restart runtime)


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/191.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m191.5/191.5 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
!pip install git+https://github.com/huggingface/trl.git@main transformers accelerate

Collecting git+https://github.com/huggingface/trl.git@main
  Cloning https://github.com/huggingface/trl.git (to revision main) to /tmp/pip-req-build-ej2odcz0
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/trl.git /tmp/pip-req-build-ej2odcz0
  Resolved https://github.com/huggingface/trl.git to commit cc044e35b285be7dc062764b3364e1e684db4c7c
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate
  Using cached accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers
  Using cached transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Using cached tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Using cached transformers-4.51.3-py3-none-any.whl (10.4 MB)
Using cached accelerate-1.6.0-py3-none-any.wh

In [None]:
# 0) pip-install & restart:
# !pip install git+https://github.com/huggingface/trl.git@main transformers accelerate

import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from datasets import Dataset

DEVICE    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LM_NAME   = "gpt2-medium"
NLI_NAME  = "facebook/bart-large-mnli"

# 1) Tokenizer & plain LM
tokenizer           = AutoTokenizer.from_pretrained(LM_NAME)
tokenizer.pad_token = tokenizer.eos_token
policy              = AutoModelForCausalLM.from_pretrained(LM_NAME).to(DEVICE)
ref_policy          = AutoModelForCausalLM.from_pretrained(LM_NAME).to(DEVICE)
ref_policy.eval()

# 2) LM-with-value-head and patch
value_policy = (
    AutoModelForCausalLMWithValueHead.from_pretrained(LM_NAME)
    .to(DEVICE)
    .eval()
)
value_policy.base_model_prefix = "transformer"
# ← **this** fixes the AttributeError

# 3) MNLI reward model
nli_tok = AutoTokenizer.from_pretrained(NLI_NAME)
nli_mod = AutoModelForSequenceClassification.from_pretrained(NLI_NAME).to(DEVICE).eval()
def nli_reward(p, c):
    enc   = nli_tok(p, c, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        logits = nli_mod(**enc).logits.squeeze()
    return F.softmax(logits, dim=-1)[2].item()

# 4) Tiny demo prompts
prompts = ["The Earth’s climate is changing because", "In 1929 the stock market"]
ds      = Dataset.from_list([{"prompt": p} for p in prompts])

# 5) PPO config
ppo_config = PPOConfig(
    learning_rate   = 1.4e-5,
    batch_size      = 1,
    num_ppo_epochs  = 1,
    kl_coef         = 0.1,
    cliprange       = 0.2,
    total_episodes  = len(prompts),
)

# 6) Trainer factory
def make_trainer():
    return PPOTrainer(
        ppo_config,
        tokenizer,
        policy,
        ref_policy,
        nli_mod,
        ds,
        value_model=value_policy,
    )

# 7) PPO loop
def run_ppo():
    trainer = make_trainer()
    for i, ex in enumerate(ds):
        prompt = ex["prompt"]
        inp    = tokenizer([prompt], return_tensors="pt", padding=True).to(DEVICE)
        out_ids = trainer.generate(
            inp["input_ids"],
            max_length=50,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            top_k=50,
        )
        comp = tokenizer.decode(out_ids[0, inp["input_ids"].shape[-1]:], skip_special_tokens=True)
        r    = nli_reward(prompt, comp)
        print(f"\n—Episode {i+1}—\nPrompt: {prompt}\nComp:   {comp}\nReward: {r:.3f}")
        trainer.step(inp["input_ids"], out_ids, torch.tensor([r]).to(DEVICE))
    print("\n✓ PPO done.")

if __name__ == "__main__":
    run_ppo()


ModuleNotFoundError: No module named 'trl'

In [None]:
# Clean install with verified versions
!pip uninstall -y torch torchvision torchaudio
!pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
!pip install transformers==4.35.0 datasets==2.14.6 accelerate==0.24.1
!pip install trl==0.7.10 peft==0.6.0
!pip install rouge-score nltk evaluate

Found existing installation: torch 2.0.1
Uninstalling torch-2.0.1:
  Successfully uninstalled torch-2.0.1
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch==2.1.0
  Downloading https://download.pytorch.org/whl/cu118/torch-2.1.0%2Bcu118-cp311-cp311-linux_x86_64.whl (2325.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 GB[0m [31m743.8 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.16.0
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.16.0%2Bcu118-cp311-cp311-linux_x86_64.whl (6.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.2/6.2 MB[0m [31m88.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting to

Collecting transformers==4.35.0
  Downloading transformers-4.35.0-py3-none-any.whl.metadata (123 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/123.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m122.9/123.1 kB[0m [31m4.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.1/123.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets==2.14.6
  Downloading datasets-2.14.6-py3-none-any.whl.metadata (19 kB)
Collecting accelerate==0.24.1
  Downloading accelerate-0.24.1-py3-none-any.whl.metadata (18 kB)
Collecting tokenizers<0.15,>=0.14 (from transformers==4.35.0)
  Downloading tokenizers-0.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting fsspec<=2023.10.0,>=2023.1.0 (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets==2.14.6)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
Co

Collecting trl==0.7.10
  Downloading trl-0.7.10-py3-none-any.whl.metadata (10 kB)
Collecting peft==0.6.0
  Downloading peft-0.6.0-py3-none-any.whl.metadata (23 kB)
Collecting tyro>=0.5.11 (from trl==0.7.10)
  Downloading tyro-0.9.20-py3-none-any.whl.metadata (10 kB)
Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl==0.7.10)
  Downloading shtab-1.7.2-py3-none-any.whl.metadata (7.4 kB)
Downloading trl-0.7.10-py3-none-any.whl (150 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.9/150.9 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading peft-0.6.0-py3-none-any.whl (134 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.9/134.9 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tyro-0.9.20-py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.3/125.3 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading shtab-1.7.2-py3-none-any.whl (14 kB)
Installing collected packages: shtab, ty



In [None]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
    DataCollatorForSeq2Seq
)
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from datasets import load_dataset

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model names
BASE_MODEL = "facebook/bart-large-cnn"
REWARD_MODEL = "facebook/bart-large-mnli"  # We'll use this as reward model

# Initialize tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
ref_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

# Initialize reward model
reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL)
reward_model = AutoModelForSequenceClassification.from_pretrained(REWARD_MODEL).to(device)

RuntimeError: Failed to import diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion because of the following error (look up to see its traceback):
Failed to import diffusers.loaders.single_file because of the following error (look up to see its traceback):
No module named 'torch.sparse._triton_ops'

In [None]:
def prepare_dataset(split="train[:1%]", max_length=512):
    dataset = load_dataset("cnn_dailymail", "3.0.0", split=split)

    def tokenize_function(examples):
        model_inputs = tokenizer(
            examples["article"],
            max_length=max_length,
            truncation=True,
            padding="max_length"
        )
        return model_inputs

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["article", "highlights", "id"]
    )
    return tokenized_dataset

train_dataset = prepare_dataset()

In [None]:
def compute_rewards(summaries, original_texts):
    rewards = []
    for summary, text in zip(summaries, original_texts):
        # Use NLI model to compute entailment probability
        inputs = reward_tokenizer(
            text,
            summary,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)

        with torch.no_grad():
            logits = reward_model(**inputs).logits
            prob = torch.softmax(logits, dim=1)[0][2].item()  # entailment probability

        rewards.append(prob)

    return torch.tensor(rewards, device=device)

In [None]:
ppo_config = PPOConfig(
    batch_size=8,
    learning_rate=1.41e-5,
    log_with="wandb",  # Optional: for logging
    ppo_epochs=3,
    mini_batch_size=4,
    init_kl_coef=0.2,
    adap_kl_ctrl=True
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
def train_ppo():
    ppo_trainer = PPOTrainer(
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        dataset=train_dataset,
        data_collator=data_collator,
        config=ppo_config
    )

    generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 128
    }

    for epoch in range(3):  # Adjust number of epochs
        for batch in ppo_trainer.dataloader:
            articles = batch["input_ids"]

            # Generate summaries
            summary_ids = ppo_trainer.generate(
                articles,
                return_prompt=False,
                **generation_kwargs
            )
            summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)

            # Compute rewards
            original_texts = tokenizer.batch_decode(articles, skip_special_tokens=True)
            rewards = compute_rewards(summaries, original_texts)

            # PPO step
            stats = ppo_trainer.step(summary_ids, articles, rewards)
            ppo_trainer.log_stats(stats, batch, rewards)

        print(f"Epoch {epoch+1} completed")

if __name__ == "__main__":
    train_ppo()
    model.save_pretrained("ppo_summarizer")