In [None]:
# @title Baseline Model:T5 base - Set-Up
# !pip install transformers datasets rouge-score

import torch
import numpy as np
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Dataset
from rouge_score import rouge_scorer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# @title Loading Datatset, Tokenizer and Model
dataset = load_dataset("cnn_dailymail", "3.0.0")
print(f"Dataset loaded - Train: {len(dataset['train'])}, Validation: {len(dataset['validation'])}, Test: {len(dataset['test'])}")

# Load tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
model = model.to(device)

Dataset loaded - Train: 287113, Validation: 13368, Test: 11490


In [None]:
# @title SummarizationDataset Class
class SummarizationDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_len=512, max_target_len=128):
        self.inputs = ["summarize: " + x for x in data["article"]]
        self.targets = data["highlights"]
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_enc = self.tokenizer(
            self.inputs[idx], padding="max_length", truncation=True,
            max_length=self.max_input_len, return_tensors="pt"
        )
        target_enc = self.tokenizer(
            self.targets[idx], padding="max_length", truncation=True,
            max_length=self.max_target_len, return_tensors="pt"
        )
        input_ids = input_enc["input_ids"].squeeze()
        attention_mask = input_enc["attention_mask"].squeeze()
        labels = target_enc["input_ids"].squeeze()
        labels[labels == tokenizer.pad_token_id] = -100  # Ignore pad tokens in loss

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [None]:
# @title Create Training and Evaluation Datatsets
train_dataset = SummarizationDataset(dataset["train"].select(range(2000)), tokenizer)
val_dataset = SummarizationDataset(dataset["validation"].select(range(500)), tokenizer)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 2000
Validation samples: 500


In [None]:
# @title Evaluation metrics
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge scores
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    rouge1 = rouge2 = rougeL = 0.0
    for pred, label in zip(decoded_preds, decoded_labels):
        scores = scorer.score(label, pred)
        rouge1 += scores['rouge1'].fmeasure
        rouge2 += scores['rouge2'].fmeasure
        rougeL += scores['rougeL'].fmeasure

    # Average scores
    count = len(decoded_preds)
    return {
        'rouge1': rouge1 / count,
        'rouge2': rouge2 / count,
        'rougeL': rougeL / count
    }

In [None]:
# @title Configure Training
training_args = TrainingArguments(
    output_dir="./t5-news-summarizer",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    learning_rate=3e-4,
    # Removed: evaluation_strategy="epoch",
    # Removed: save_strategy="epoch",
    save_steps=500,       # Save every 500 steps instead
    eval_steps=500,       # Evaluate every 500 steps
    save_total_limit=2,   # Keep only the 2 most recent checkpoints
    weight_decay=0.01,
    logging_steps=50,
    logging_dir="./logs",
    report_to="none"
)

In [None]:
# @title Train Model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

print("Starting training...")
trainer.train()
print("Training complete!")

# Save fine-tuned model
model.save_pretrained("./t5-news-summarizer")
tokenizer.save_pretrained("./t5-news-summarizer")
print("Model and tokenizer saved!")

Starting training...


Step,Training Loss
50,1.7472
100,1.6895
150,1.6612
200,1.6708
250,1.6396


Training complete!
Model and tokenizer saved!


In [None]:
# @title Summary Generation Function
def generate_summary(text, model, tokenizer, max_input_length=512, max_target_length=128):
    # Get the device that the model is on
    device = next(model.parameters()).device

    # Make sure input_ids are on the same device as the model
    input_ids = tokenizer("summarize: " + text, return_tensors="pt",
                         truncation=True, max_length=max_input_length).input_ids.to(device)

    # Generate summary
    output_ids = model.generate(input_ids, max_length=max_target_length,
                              num_beams=4, early_stopping=True)

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


In [None]:
# @title Samples generation
print("Generating sample summary...")
example_article = dataset["test"][0]["article"]
print("\nOriginal Article:\n", example_article[:500], "...\n")
generated_summary = generate_summary(example_article, model, tokenizer)
print("\nGenerated Summary:\n", generated_summary)
print("\nReference Summary:\n", dataset["test"][0]["highlights"])

Generating sample summary...

Original Article:
 (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, includin ...


Generated Summary:
 Palestinian Authority officially becomes 123rd member of the International Criminal Court . The move gives the court jurisdiction over alleged crimes committed in Palestinian territories . Israel and the United States, neither of which is an ICC member, oppose the move .

Reference Summary:
 Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the Uni

In [None]:
# @title Comprehensive Evaluation Function
def evaluate_summaries(dataset, model, tokenizer, num_samples=100):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = {'rouge1': 0, 'rouge2': 0, 'rougeL': 0}

    print(f"Evaluating on {min(num_samples, len(dataset))} samples...")

    for i in range(min(num_samples, len(dataset))):
        if i % 10 == 0:
            print(f"Processing sample {i}...")

        article = dataset[i]["article"]
        reference = dataset[i]["highlights"]
        generated = generate_summary(article, model, tokenizer)

        scores = scorer.score(reference, generated)
        for key in rouge_scores:
            rouge_scores[key] += scores[key].fmeasure

    # Average scores
    for key in rouge_scores:
        rouge_scores[key] /= min(num_samples, len(dataset))

    print(f"Evaluation complete!")
    return rouge_scores

In [None]:
# @title Evaluation
test_dataset = dataset["test"].select(range(100))
test_scores = evaluate_summaries(test_dataset, model, tokenizer)
print(f"Test set ROUGE scores: {test_scores}")

Evaluating on 100 samples...
Processing sample 0...
Processing sample 10...
Processing sample 20...
Processing sample 30...
Processing sample 40...
Processing sample 50...
Processing sample 60...
Processing sample 70...
Processing sample 80...
Processing sample 90...
Evaluation complete!
Test set ROUGE scores: {'rouge1': 0.3301861812648006, 'rouge2': 0.13454597845211605, 'rougeL': 0.252344467616384}


In [None]:
# @title Bart-Base pipeline
# --------------------------------------------
!pip install -q transformers==4.51.3 datasets evaluate rouge-score nltk accelerate peft trl

In [None]:
# @title pipeline-1
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from datasets import load_dataset, DatasetDict
import evaluate
import numpy as np

# 1. CONFIG
class SFTConfig:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.max_samples = {"train": 25, "validation": 5, "test": 5}
        self.batch_size = 2 if self.device == "cpu" else 8
        self.max_input_length = 512
        self.model_name = "facebook/bart-base"
        self.epochs = 1
        self.save_steps = 50
        self.max_output_length = 128  # Default value

# 2. DATA MANAGER (FIXED)
class SFDataManager:
    def __init__(self, config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.rouge = evaluate.load("rouge")

    def load_data(self):
        ds = load_dataset("cnn_dailymail", "3.0.0")
        data = DatasetDict({
            k: ds[k].shuffle().select(range(self.config.max_samples[k]))
            for k in ["train", "validation", "test"]
        })

        # SAFE length calculation
        if len(data["train"]["highlights"]) > 0:
            lens = [len(self.tokenizer.tokenize(x)) for x in data["train"]["highlights"][:10]]
            avg_len = sum(lens)/max(1, len(lens))
            self.config.max_output_length = min(int(avg_len * 2), 128)
        return data

    def tokenize(self, dataset):
        def process(examples):
            model_inputs = self.tokenizer(
                examples["article"],
                max_length=self.config.max_input_length,
                truncation=True,
                padding="max_length"
            )
            labels = self.tokenizer(
                text_target=examples["highlights"],
                max_length=self.config.max_output_length,
                truncation=True,
                padding="max_length"
            )
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs

        return dataset.map(process, batched=True, batch_size=4)

    def compute_metrics(self, eval_preds):
        preds, labels = eval_preds
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        return self.rouge.compute(
            predictions=decoded_preds,
            references=decoded_labels,
            use_stemmer=True
        )

# 3. TRAINER (FIXED)
class SFTrainer:
    def __init__(self, config):
        self.config = config
        self.data_mgr = SFDataManager(config)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            config.model_name,
            torch_dtype=torch.float32  # Safer for CPU
        ).to(config.device)
        self.data_collator = DataCollatorForSeq2Seq(
            self.data_mgr.tokenizer,
            model=self.model,
            padding=True
        )

    def train(self, tokenized_data):
        args = Seq2SeqTrainingArguments(
            output_dir="./saved_model",
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=2,
            predict_with_generate=True,
            eval_strategy="steps",
            eval_steps=25,
            logging_steps=10,
            save_strategy="steps",
            save_steps=self.config.save_steps,
            num_train_epochs=self.config.epochs,
            report_to="none",
            remove_unused_columns=True,
            generation_max_length=self.config.max_output_length,
            generation_num_beams=2
        )

        trainer = Seq2SeqTrainer(
            model=self.model,
            args=args,
            train_dataset=tokenized_data["train"],
            eval_dataset=tokenized_data["validation"],
            tokenizer=self.data_mgr.tokenizer,
            compute_metrics=self.data_mgr.compute_metrics,
            data_collator=self.data_collator
        )

        trainer.train()
        return trainer

# 4. EXECUTION
def main():
    print(f"Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

    print("Loading data...")
    config = SFTConfig()
    data_mgr = SFDataManager(config)
    raw_data = data_mgr.load_data()

    print(f"Using max_output_length = {config.max_output_length}")

    print("Tokenizing...")
    tokenized_data = data_mgr.tokenize(raw_data)

    print("Fine-tuning...")
    trainer = SFTrainer(config)
    trainer.train(tokenized_data)

    print("\nSample generation:")
    test_sample = raw_data["test"][0]["article"][:config.max_input_length]
    inputs = data_mgr.tokenizer(
        test_sample,
        max_length=config.max_input_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt"
    ).to(config.device)

    outputs = trainer.model.generate(
        **inputs,
        max_length=config.max_output_length,
        num_beams=2,
        early_stopping=True
    )

    print("\nInput:", test_sample[:200] + "...")
    print("\nGenerated:", data_mgr.tokenizer.decode(outputs[0], skip_special_tokens=True))
    print("\nReference:", raw_data["test"][0]["highlights"])

if __name__ == "__main__":
    main()

Using device: cpu
Loading data...
Using max_output_length = 128
Tokenizing...


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Fine-tuning...


  trainer = Seq2SeqTrainer(


Step,Training Loss,Validation Loss





Sample generation:

Input: This is the dramatic moment a golden retriever had to be rescued from the water after being swept half a mile out to sea while playing on the shoreline. Ten-year-old Martha had been with her owner on ...

Generated: This is the dramatic moment a golden retriever was swept out to sea after being swept into the sea by a strong currentNew Brighton RNLI crew spot golden retrievate Martha in the water after she was swept to the water in a strong position. Ten-year-old Martha had been with her owner on the beach in Leasowe in Merseyside and was paddling in the strong current. And with the strength of the outgoing tide, the dog was rapidly swept out of the water. But when she was caught out by the current, she was rescued and her owner launched an unsuccessful rescue attempt. New Brighton RN

Reference: Martha, aged 10, had been paddling in the water when she was swept out .
Her owner launched their own rescue but was unable to grab the pet .
RNLI eventually found 

In [2]:
# @title imports for code 2
# in a notebook cell, run:
!pip install -U datasets fsspec huggingface-hub
!pip install --upgrade evaluate datasets==2.13.4
!pip install --upgrade datasets
!pip install git+https://github.com/huggingface/evaluate.git


Collecting fsspec
  Using cached fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting evaluate
  Using cached evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
[31mERROR: Could not find a version that satisfies the requirement datasets==2.13.4 (from versions: 0.0.9, 1.0.0, 1.0.1, 1.0.2, 1.1.0, 1.1.1, 1.1.2, 1.1.3, 1.2.0, 1.2.1, 1.3.0, 1.4.0, 1.4.1, 1.5.0, 1.6.0, 1.6.1, 1.6.2, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 1.13.2, 1.13.3, 1.14.0, 1.15.0, 1.15.1, 1.16.0, 1.16.1, 1.17.0, 1.18.0, 1.18.1, 1.18.2, 1.18.3, 1.18.4, 2.0.0, 2.1.0, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.3.2, 2.4.0, 2.5.0, 2.5.1, 2.5.2, 2.6.0, 2.6.1, 2.6.2, 2.7.0, 2.7.1, 2.8.0, 2.9.0, 2.10.0, 2.10.1, 2.11.0, 2.12.0, 2.13.0, 2.13.1, 2.13.2, 2.14.0, 2.14.1, 2.14.2, 2.14.3, 2.14.4, 2.14.5, 2.14.6, 2.14.7, 2.15.0, 2.16.0, 2.16.1, 2.17.0, 2.17.1, 2.18.0, 2.19.0, 2.19.1, 2.19.2, 2.20.0, 2.21.0, 3.0.0, 3.0.1, 3.0.2, 3.1.0, 3.2.0, 3.3.0, 3.3.1, 3.3.2, 3.4.0, 3.4.1, 3.5.0, 3.5.1

In [1]:
# pip install transformers datasets evaluate tqdm nltk

import warnings
warnings.filterwarnings("ignore")

import logging
from transformers import logging as hf_logging
import datasets, evaluate

# ─── Silence library logs ────────────────────────────────────────────────────
hf_logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
evaluate.logging.set_verbosity_error()
logging.getLogger("nltk").setLevel(logging.ERROR)
logging.getLogger("transformers.generation_utils").setLevel(logging.ERROR)

import torch
import numpy as np
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

# ─── CONFIG (customize here) ────────────────────────────────────────────────
MODEL_NAME         = "facebook/bart-large-cnn"
OUTPUT_DIR         = "./bart_cnn_sum"
BATCH_SIZE         = 8
NUM_EPOCHS         = 3
MAX_INPUT_LENGTH   = 512
MAX_OUTPUT_LENGTH  = 142

# Set to an int to debug on a subset, or to None to use the full split:
TRAIN_SIZE = 100
VAL_SIZE   = 10
TEST_SIZE  = 10
NUM_EPOCHS = 3
BATCH_SIZE = 16

# Logging / checkpoint intervals (in steps)
LOG_STEPS        = 50
EVAL_STEPS       = 200
SAVE_STEPS       = 200
SAVE_TOTAL_LIMIT = 3

# ─── DEVICE ────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ─── METRICS ───────────────────────────────────────────────────────────────
rouge  = evaluate.load("rouge")
meteor = evaluate.load("meteor")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds  = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    r = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    m = meteor.compute(predictions=decoded_preds, references=decoded_labels)
    return {
        "rouge1": round(r["rouge1"],4),
        "rouge2": round(r["rouge2"],4),
        "rougeL": round(r["rougeL"],4),
        "meteor": round(m["meteor"],4),
    }

# ─── LOAD & SAMPLE ─────────────────────────────────────────────────────────
print("Loading CNN/DailyMail 3.0.0…")
raw = load_dataset("cnn_dailymail", "3.0.0")

def sample_split(split, n):
    ds = raw[split]
    return ds.shuffle(seed=42).select(range(n)) if n is not None else ds

data = DatasetDict({
    "train":      sample_split("train",      TRAIN_SIZE),
    "validation": sample_split("validation", VAL_SIZE),
    "test":       sample_split("test",       TEST_SIZE),
})
print(f"  → train={len(data['train'])}, val={len(data['validation'])}, test={len(data['test'])}")

# ─── TOKENIZER & PREPROCESS ─────────────────────────────────────────────────
print("Initializing tokenizer & tokenizing…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess(batch):
    inputs = tokenizer(
        batch["article"],
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        padding="max_length"
    )
    labels = tokenizer(
        batch["highlights"],
        max_length=MAX_OUTPUT_LENGTH,
        truncation=True,
        padding="max_length"
    ).input_ids
    inputs["labels"] = labels
    return inputs

# parallelize over 4 cores
tokenized = data.map(
    preprocess,
    batched=True,
    num_proc=4,
    remove_columns=["article","highlights"]
)

# ─── MODEL & COLLATOR ───────────────────────────────────────────────────────
print(f"Loading model to {device}…")
model     = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
collator  = DataCollatorForSeq2Seq(tokenizer, model=model)

# ─── TRAINING ARGS ──────────────────────────────────────────────────────────
training_args = Seq2SeqTrainingArguments(
    output_dir               = OUTPUT_DIR,
    num_train_epochs         = NUM_EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size  = BATCH_SIZE,

    logging_strategy         = "steps",
    logging_steps            = LOG_STEPS,
    logging_first_step       = True,

    evaluation_strategy      = "steps",
    eval_steps               = EVAL_STEPS,

    save_strategy            = "steps",
    save_steps               = SAVE_STEPS,
    save_total_limit         = SAVE_TOTAL_LIMIT,

    disable_tqdm             = True,
    predict_with_generate    = True,
    fp16                     = (device.type=="cuda"),
    load_best_model_at_end   = True,
    metric_for_best_model    = "rougeL",
    greater_is_better        = True,
    report_to                = "none",
)

# ─── TRAINER ────────────────────────────────────────────────────────────────
trainer = Seq2SeqTrainer(
    model           = model,
    args            = training_args,
    train_dataset   = tokenized["train"],
    eval_dataset    = tokenized["validation"],
    tokenizer       = tokenizer,
    data_collator   = collator,
    compute_metrics = compute_metrics,
)

# ─── TRAIN & RESUME ─────────────────────────────────────────────────────────
print("▶ Fine-tuning started…")
trainer.train()
# If you restart the script later, pick up from the last checkpoint with:
#   trainer.train(resume_from_checkpoint=True)

# ─── FINAL EVALUATION ───────────────────────────────────────────────────────
print("\n== Validation Metrics ==")
val_metrics = trainer.evaluate(tokenized["validation"])
print({k: val_metrics[k] for k in sorted(val_metrics) if k.startswith("eval_")})

print("\n== Test Metrics ==")
test_metrics = trainer.evaluate(tokenized["test"])
print({k: test_metrics[k] for k in sorted(test_metrics) if k.startswith("eval_")})


Using device: cpu


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Loading CNN/DailyMail 3.0.0…
  → train=100, val=10, test=10
Initializing tokenizer & tokenizing…


Map (num_proc=4):   0%|          | 0/100 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Loading model to cpu…


The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 100
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 21
  Number of trainable parameters = 406290432
You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


▶ Fine-tuning started…
{'loss': 7.2474, 'learning_rate': 4.761904761904762e-05, 'epoch': 0.14}




Training completed. Do not forget to share your model on huggingface.co/models =)


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 16
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'train_runtime': 1197.7744, 'train_samples_per_second': 0.25, 'train_steps_per_second': 0.018, 'train_loss': 2.7065482366652716, 'epoch': 3.0}

== Validation Metrics ==


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 16
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'eval_loss': 1.368355631828308, 'eval_rouge1': 0.4014, 'eval_rouge2': 0.1555, 'eval_rougeL': 0.2687, 'eval_meteor': 0.3281, 'eval_runtime': 49.5882, 'eval_samples_per_second': 0.202, 'eval_steps_per_second': 0.02, 'epoch': 3.0}
{'eval_loss': 1.368355631828308, 'eval_meteor': 0.3281, 'eval_rouge1': 0.4014, 'eval_rouge2': 0.1555, 'eval_rougeL': 0.2687, 'eval_runtime': 49.5882, 'eval_samples_per_second': 0.202, 'eval_steps_per_second': 0.02}

== Test Metrics ==
{'eval_loss': 1.203233003616333, 'eval_rouge1': 0.3793, 'eval_rouge2': 0.1493, 'eval_rougeL': 0.2591, 'eval_meteor': 0.3471, 'eval_runtime': 49.9964, 'eval_samples_per_second': 0.2, 'eval_steps_per_second': 0.02, 'epoch': 3.0}
{'eval_loss': 1.203233003616333, 'eval_meteor': 0.3471, 'eval_rouge1': 0.3793, 'eval_rouge2': 0.1493, 'eval_rougeL': 0.2591, 'eval_runtime': 49.9964, 'eval_samples_per_second': 0.2, 'eval_steps_per_second': 0.02}


In [5]:
# pip install transformers datasets evaluate tqdm nltk

import warnings
warnings.filterwarnings("ignore")

import logging
from transformers import logging as hf_logging
import datasets, evaluate

# ─── Silence library logs ────────────────────────────────────────────────────
hf_logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
evaluate.logging.set_verbosity_error()
logging.getLogger("nltk").setLevel(logging.ERROR)
logging.getLogger("transformers.generation_utils").setLevel(logging.ERROR)

import torch
import numpy as np
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

# ─── CONFIG (customize here) ────────────────────────────────────────────────
MODEL_NAME         = "facebook/bart-large-cnn"
OUTPUT_DIR         = "./bart_cnn_sum"
BATCH_SIZE         = 16         # per device
NUM_EPOCHS         = 3
MAX_INPUT_LENGTH   = 512
MAX_OUTPUT_LENGTH  = 64
GEN_MAX_LENGTH     = MAX_OUTPUT_LENGTH * 2

# Set to an int to debug on a subset, or None to use the full splits:
TRAIN_SIZE = 200
VAL_SIZE   = 15
TEST_SIZE  = 15

# Logging / checkpoint intervals (in steps)
LOG_STEPS        = 50
EVAL_STEPS       = 200
SAVE_STEPS       = 200
SAVE_TOTAL_LIMIT = 3

# ─── DEVICE ────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ─── METRICS ───────────────────────────────────────────────────────────────
rouge  = evaluate.load("rouge")
meteor = evaluate.load("meteor")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    dec_preds  = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels     = np.where(labels != -100, labels, tokenizer.pad_token_id)
    dec_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    r = rouge.compute(predictions=dec_preds, references=dec_labels, use_stemmer=True)
    m = meteor.compute(predictions=dec_preds, references=dec_labels)
    return {
        "rouge1": round(r["rouge1"],4),
        "rouge2": round(r["rouge2"],4),
        "rougeL": round(r["rougeL"],4),
        "meteor": round(m["meteor"],4),
    }

# ─── LOAD & SAMPLE ─────────────────────────────────────────────────────────
print("Loading CNN/DailyMail 3.0.0…")
raw = load_dataset("cnn_dailymail", "3.0.0")

def sample_split(split, size):
    ds = raw[split]
    return ds.shuffle(seed=42).select(range(size)) if size is not None else ds

data = DatasetDict({
    "train":      sample_split("train",      TRAIN_SIZE),
    "validation": sample_split("validation", VAL_SIZE),
    "test":       sample_split("test",       TEST_SIZE),
})
print(f"  → train={len(data['train'])}, val={len(data['validation'])}, test={len(data['test'])}")

# ─── TOKENIZER & PREPROCESS ─────────────────────────────────────────────────
print("Tokenizing…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess(batch):
    inputs = tokenizer(
        batch["article"],
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        padding="max_length"
    )
    labels = tokenizer(
        batch["highlights"],
        max_length=MAX_OUTPUT_LENGTH,
        truncation=True,
        padding="max_length"
    ).input_ids
    inputs["labels"] = labels
    return inputs

tokenized = data.map(
    preprocess,
    batched=True,
    num_proc=3,
    remove_columns=["article","highlights"]
)

# ─── MODEL & COLLATOR ───────────────────────────────────────────────────────
print(f"Loading model to {device}…")
model    = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# ─── TRAINING ARGS ──────────────────────────────────────────────────────────
training_args = Seq2SeqTrainingArguments(
    output_dir                 = OUTPUT_DIR,
    num_train_epochs           = NUM_EPOCHS,
    per_device_train_batch_size= BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,

    logging_strategy           = "steps",
    logging_steps              = LOG_STEPS,
    logging_first_step         = True,

    evaluation_strategy        = "steps",
    eval_steps                 = EVAL_STEPS,

    save_strategy              = "steps",
    save_steps                 = SAVE_STEPS,
    save_total_limit           = SAVE_TOTAL_LIMIT,

    disable_tqdm               = True,
    predict_with_generate      = True,
    fp16                       = (device.type=="cuda"),
    load_best_model_at_end     = True,
    metric_for_best_model      = "rougeL",
    greater_is_better          = True,
    report_to                  = "none",
)

# ─── TRAINER ────────────────────────────────────────────────────────────────
trainer = Seq2SeqTrainer(
    model           = model,
    args            = training_args,
    train_dataset   = tokenized["train"],
    eval_dataset    = tokenized["validation"],
    tokenizer       = tokenizer,
    data_collator   = collator,
    compute_metrics = compute_metrics,
)

# ─── TRAIN & RESUME ─────────────────────────────────────────────────────────
print(" Fine-tuning started…")
trainer.train()  # resume later with resume_from_checkpoint=True

# ─── FINAL EVALUATION ───────────────────────────────────────────────────────
print("\n== Validation Metrics ==")
val_m = trainer.evaluate(tokenized["validation"])
print({k: val_m[k] for k in sorted(val_m) if k.startswith("eval_")})

print("\n== Test Metrics ==")
test_m = trainer.evaluate(tokenized["test"])
print({k: test_m[k] for k in sorted(test_m) if k.startswith("eval_")})

# ─── EXAMPLE GENERATIONS ────────────────────────────────────────────────────
print("\n== Example generations on test set ==")
for idx in range(min(5, len(data["test"]))):
    sample = data["test"][idx]
    print(f"\n--- SAMPLE {idx+1} ---")
    print("\nARTICLE:")
    print(sample["article"])
    inputs = tokenizer(
        sample["article"],
        return_tensors="pt",
        truncation=True,
        padding="longest",
        max_length=MAX_INPUT_LENGTH
    ).to(device)
    outputs = model.generate(
        **inputs,
        max_length=GEN_MAX_LENGTH,
        num_beams=4,
        early_stopping=True
    )
    print("\nGENERATED SUMMARY:")
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    print("\nREFERENCE HIGHLIGHTS:")
    print(sample["highlights"])


Using device: cpu


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Loading CNN/DailyMail 3.0.0…
  → train=200, val=15, test=15
Tokenizing…


Map (num_proc=3):   0%|          | 0/200 [00:00<?, ? examples/s]

Map (num_proc=3):   0%|          | 0/15 [00:00<?, ? examples/s]

Map (num_proc=3):   0%|          | 0/15 [00:00<?, ? examples/s]

Loading model to cpu…


The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 200
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 39
  Number of trainable parameters = 406290432
You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


 Fine-tuning started…
{'loss': 3.2265, 'learning_rate': 4.871794871794872e-05, 'epoch': 0.08}




Training completed. Do not forget to share your model on huggingface.co/models =)


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 15
  Batch size = 16
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'train_runtime': 2099.4724, 'train_samples_per_second': 0.286, 'train_steps_per_second': 0.019, 'train_loss': 1.4119522754962628, 'epoch': 3.0}

== Validation Metrics ==


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 15
  Batch size = 16
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'eval_loss': 1.8956702947616577, 'eval_rouge1': 0.4363, 'eval_rouge2': 0.2098, 'eval_rougeL': 0.3005, 'eval_meteor': 0.403, 'eval_runtime': 75.0714, 'eval_samples_per_second': 0.2, 'eval_steps_per_second': 0.013, 'epoch': 3.0}
{'eval_loss': 1.8956702947616577, 'eval_meteor': 0.403, 'eval_rouge1': 0.4363, 'eval_rouge2': 0.2098, 'eval_rougeL': 0.3005, 'eval_runtime': 75.0714, 'eval_samples_per_second': 0.2, 'eval_steps_per_second': 0.013}

== Test Metrics ==


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'eval_loss': 1.6750235557556152, 'eval_rouge1': 0.4049, 'eval_rouge2': 0.1778, 'eval_rougeL': 0.2901, 'eval_meteor': 0.3941, 'eval_runtime': 79.9912, 'eval_samples_per_second': 0.188, 'eval_steps_per_second': 0.013, 'epoch': 3.0}
{'eval_loss': 1.6750235557556152, 'eval_meteor': 0.3941, 'eval_rouge1': 0.4049, 'eval_rouge2': 0.1778, 'eval_rougeL': 0.2901, 'eval_runtime': 79.9912, 'eval_samples_per_second': 0.188, 'eval_steps_per_second': 0.013}

== Example generations on test set ==

--- SAMPLE 1 ---

ARTICLE:
(CNN) I see signs of a revolution everywhere. I see it in the op-ed pages of the newspapers, and on the state ballots in nearly half the country. I see it in politicians who once preferred to play it safe with this explosive issue but are now willing to stake their political futures on it. I see the revolution in the eyes of sterling scientists, previously reluctant to dip a toe into this heavily stigmatized world, who are diving in head first. I see it in the new surgeon general 

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}




GENERATED SUMMARY:
CNN's John Sutter says he sees signs of a revolution in attitudes toward medical marijuana.
For the first time, 53% of Americans favor its legalization, with 77% supporting it for medical purposes.
Sutter: Support for legalization has risen 11 points in the past few years alone.
He says the issue is burning white hot among young people, but also shows up among parents.

REFERENCE HIGHLIGHTS:
CNN's Dr. Sanjay Gupta says we should legalize medical marijuana now .
He says he knows how easy it is do nothing "because I did nothing for too long"

--- SAMPLE 2 ---

ARTICLE:
He looks barely teenage. But this child has amassed thousands of Twitter followers with his pictorial updates of 'gang life'. The baby-faced boy from Memphis, Tennessee, poses with guns, cash, and bags of what looks like marijuana. Scroll down for video . Baby-faced: This little boy has amassed more than 3,000 followers on Twitter with pictures like these . In many pictures he is smoking suspicious subs

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}




GENERATED SUMMARY:
The baby-faced boy from Memphis, Tennessee, regularly posts pictures of 'gang life' on Twitter with captions like: 'I need a bad b****', 'f*** da police', and 'gang sh** n****'
In one video he laughs and points a gun at the camera in an apparent attempt to look menacing - as adults laugh in the background.
As a minor, DailyMail.com will not identify the boy, who has more than 3,000 followers.
He has prompted a wave of critics calling his stunts'sad' and'stressing out' his mother.

REFERENCE HIGHLIGHTS:
Child has amassed thousands of Twitter followers with 'gang life' photos .
In one video he points gun at camera as adults look on unfazed .
His tweets have prompted backlash with calls for intervention .

--- SAMPLE 3 ---

ARTICLE:
New Jersey Governor Chris Christie wasn't looking too presidential Tuesday night when he got into a heated debate with a veteran teacher at a town hall meeting. And now the state's largest teacher's union is calling him out for his 'bullyin

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}




GENERATED SUMMARY:
New Jersey Gov Chris Christie got into a heated debate with a veteran English teacher at a town hall meeting Tuesday night.
Kathy Mooney questioned Christie's motivations behind a $225million legal settlement with oil company ExxonMobil.
The state's largest teacher's union is calling him out for his 'bullying' behavior.
Steve Wollmer of the NJ Education Association said: 'He's always taken a very nasty and disrespectful tone with teachers'

REFERENCE HIGHLIGHTS:
The presidential hopeful held a town hall meeting in Kenilworth on Tuesday .
During the meeting, high school English teacher Kathy Mooney got up to ask the governor a question about pensions .
She asked why he didn't seek a higher legal settlement in a case with ExxonMobil that would have contributed to the state's pension system .
Christie responded by repeatedly asking how much Mooney knew about the deal instead of answering her question .

--- SAMPLE 4 ---

ARTICLE:
YouTube star Cassey Ho has hit back at 

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}




GENERATED SUMMARY:
The 28-year-old fitness blogger has released a new video in response to the barrage of negative comments she has received from viewers of her fitness-focused YouTube videos.
In the video, she explores whether conforming to society's standards is the key to happiness, while highlighting some of the cruel comments she's received online.
She says the trash-talking has brought her to tears on more than one occasion, and has made her feel self-conscious about her appearance.

REFERENCE HIGHLIGHTS:
Cassey Ho boasts over two million subscribers on her YouTube channel Blogilates .
The 28-year-old receives hundreds of comments a day telling her that she needs to lose weight .

--- SAMPLE 5 ---

ARTICLE:
British taekwondo fighter Aaron Cook has confirmed he plans to compete for Moldova at the 2016 Olympics in Rio. Dorset-born Cook, 24, who was overlooked for the Great Britain taekwondo squad at London 2012, applied for citizenship after receiving funding from Moldovan billion

In [6]:
# pip install transformers datasets evaluate tqdm nltk

import warnings, logging, contextlib, io
warnings.filterwarnings("ignore")

from transformers import logging as hf_logging
import datasets, evaluate

# ─── Silence library logs ───────────────────────────────────────────────────
hf_logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
evaluate.logging.set_verbosity_error()
logging.getLogger("nltk").setLevel(logging.ERROR)

import torch, numpy as np
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

# ─── CONFIG ────────────────────────────────────────────────────────────────
MODEL_NAME         = "facebook/bart-large-cnn"
OUTPUT_DIR         = "./bart_cnn_sum"
BATCH_SIZE         = 8
NUM_EPOCHS         = 3
MAX_INPUT_LENGTH   = 512
MAX_OUTPUT_LENGTH  = 142
GEN_MAX_LENGTH     = MAX_OUTPUT_LENGTH * 2

TRAIN_SIZE = 20    # or None for full
VAL_SIZE   = 3
TEST_SIZE  = 3

LOG_STEPS        = 50
EVAL_STEPS       = 200
SAVE_STEPS       = 200
SAVE_TOTAL_LIMIT = 3

# ─── DEVICE ────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ─── METRICS ───────────────────────────────────────────────────────────────
rouge  = evaluate.load("rouge")
meteor = evaluate.load("meteor")
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple): preds = preds[0]
    dec_preds  = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels     = np.where(labels != -100, labels, tokenizer.pad_token_id)
    dec_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    r = rouge.compute(predictions=dec_preds, references=dec_labels, use_stemmer=True)
    m = meteor.compute(predictions=dec_preds, references=dec_labels)
    return {
        "rouge1": round(r["rouge1"],4),
        "rouge2": round(r["rouge2"],4),
        "rougeL": round(r["rougeL"],4),
        "meteor": round(m["meteor"],4),
    }

# ─── LOAD & SAMPLE ─────────────────────────────────────────────────────────
print("Loading CNN/DailyMail 3.0.0…")
raw = load_dataset("cnn_dailymail", "3.0.0")
def sample_split(name, n):
    ds = raw[name]
    return ds.shuffle(seed=42).select(range(n)) if n is not None else ds

data = DatasetDict({
    "train":      sample_split("train",      TRAIN_SIZE),
    "validation": sample_split("validation", VAL_SIZE),
    "test":       sample_split("test",       TEST_SIZE),
})
print(f"  → train={len(data['train'])}, val={len(data['validation'])}, test={len(data['test'])}")

# ─── PREPROCESS ─────────────────────────────────────────────────────────────
print("Tokenizing…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def preprocess(batch):
    inp = tokenizer(batch["article"],
        max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
    lbl = tokenizer(batch["highlights"],
        max_length=MAX_OUTPUT_LENGTH, truncation=True, padding="max_length").input_ids
    inp["labels"] = lbl
    return inp

tokenized = data.map(
    preprocess,
    batched=True,
    num_proc=4,
    remove_columns=["article","highlights"]
)

# ─── MODEL & TRAINER ───────────────────────────────────────────────────────
print(f"Loading model onto {device}…")
model    = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir                 = OUTPUT_DIR,
    num_train_epochs           = NUM_EPOCHS,
    per_device_train_batch_size= BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,

    logging_strategy   ="steps",
    logging_steps      = LOG_STEPS,
    logging_first_step = True,

    evaluation_strategy="steps",
    eval_steps         = EVAL_STEPS,

    save_strategy      ="steps",
    save_steps         = SAVE_STEPS,
    save_total_limit   = SAVE_TOTAL_LIMIT,

    disable_tqdm       = True,
    predict_with_generate=True,
    fp16               = (device.type=="cuda"),
    load_best_model_at_end= True,
    metric_for_best_model ="rougeL",
    greater_is_better     = True,
    report_to             ="none",
)

trainer = Seq2SeqTrainer(
    model           = model,
    args            = training_args,
    train_dataset   = tokenized["train"],
    eval_dataset    = tokenized["validation"],
    tokenizer       = tokenizer,
    data_collator   = collator,
    compute_metrics = compute_metrics,
)

# ─── TRAIN & EVAL ──────────────────────────────────────────────────────────
print("▶ Fine-tuning started…")
trainer.train()

print("\n== Validation Metrics ==")
val_m = trainer.evaluate(tokenized["validation"])
print({k: val_m[k] for k in sorted(val_m) if k.startswith("eval_")})

print("\n== Test Metrics ==")
test_m = trainer.evaluate(tokenized["test"])
print({k: test_m[k] for k in sorted(test_m) if k.startswith("eval_")})

# ─── EXAMPLE GENERATIONS ───────────────────────────────────────────────────
print("\n== Example generations ==")
for i in range(min(5, len(data["test"]))):
    sample = data["test"][i]
    print(f"\n--- SAMPLE {i+1} ---")
    print("\nARTICLE:")
    print(sample["article"])
    inputs = tokenizer(
        sample["article"],
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        padding="longest",
        return_tensors="pt"
    ).to(device)

    # suppress the GenerationConfig print
    with contextlib.redirect_stdout(io.StringIO()):
        outputs = model.generate(
            **inputs,
            max_length=GEN_MAX_LENGTH,
            num_beams=4,
            early_stopping=True
        )

    print("\nGENERATED SUMMARY:")
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    print("\nREFERENCE HIGHLIGHTS:")
    print(sample["highlights"])


Using device: cpu


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Loading CNN/DailyMail 3.0.0…
  → train=20, val=3, test=3
Tokenizing…


Map (num_proc=4):   0%|          | 0/20 [00:00<?, ? examples/s]

Map (num_proc=3):   0%|          | 0/3 [00:00<?, ? examples/s]

Map (num_proc=3):   0%|          | 0/3 [00:00<?, ? examples/s]

Loading model onto cpu…


The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 20
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 9
  Number of trainable parameters = 406290432
You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


▶ Fine-tuning started…
{'loss': 7.7924, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.33}




Training completed. Do not forget to share your model on huggingface.co/models =)


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3
  Batch size = 8
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'train_runtime': 234.4663, 'train_samples_per_second': 0.256, 'train_steps_per_second': 0.038, 'train_loss': 4.841492811838786, 'epoch': 3.0}

== Validation Metrics ==


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: id. If id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3
  Batch size = 8
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'eval_loss': 2.9673423767089844, 'eval_rouge1': 0.4388, 'eval_rouge2': 0.2191, 'eval_rougeL': 0.3598, 'eval_meteor': 0.4504, 'eval_runtime': 19.5329, 'eval_samples_per_second': 0.154, 'eval_steps_per_second': 0.051, 'epoch': 3.0}
{'eval_loss': 2.9673423767089844, 'eval_meteor': 0.4504, 'eval_rouge1': 0.4388, 'eval_rouge2': 0.2191, 'eval_rougeL': 0.3598, 'eval_runtime': 19.5329, 'eval_samples_per_second': 0.154, 'eval_steps_per_second': 0.051}

== Test Metrics ==


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



{'eval_loss': 3.1277267932891846, 'eval_rouge1': 0.2882, 'eval_rouge2': 0.0607, 'eval_rougeL': 0.1626, 'eval_meteor': 0.259, 'eval_runtime': 23.8858, 'eval_samples_per_second': 0.126, 'eval_steps_per_second': 0.042, 'epoch': 3.0}
{'eval_loss': 3.1277267932891846, 'eval_meteor': 0.259, 'eval_rouge1': 0.2882, 'eval_rouge2': 0.0607, 'eval_rougeL': 0.1626, 'eval_runtime': 23.8858, 'eval_samples_per_second': 0.126, 'eval_steps_per_second': 0.042}

== Example generations ==

--- SAMPLE 1 ---

ARTICLE:
(CNN) I see signs of a revolution everywhere. I see it in the op-ed pages of the newspapers, and on the state ballots in nearly half the country. I see it in politicians who once preferred to play it safe with this explosive issue but are now willing to stake their political futures on it. I see the revolution in the eyes of sterling scientists, previously reluctant to dip a toe into this heavily stigmatized world, who are diving in head first. I see it in the new surgeon general who cites data

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}




GENERATED SUMMARY:
CNN's John Sutter sees signs of a medical marijuana revolution everywhere..."I see a revolution that is burning white hot among young people... among the parents and grandparents in my kids' school.. "I see it in the faces of good parents uprooting their lives to get medicine for their children." Support for legalization has risen 11 points in the past few years alone.. in 1969, the first time Pew asked the question about legalization, only 12% of the nation was in favor..

REFERENCE HIGHLIGHTS:
CNN's Dr. Sanjay Gupta says we should legalize medical marijuana now .
He says he knows how easy it is do nothing "because I did nothing for too long"

--- SAMPLE 2 ---

ARTICLE:
He looks barely teenage. But this child has amassed thousands of Twitter followers with his pictorial updates of 'gang life'. The baby-faced boy from Memphis, Tennessee, poses with guns, cash, and bags of what looks like marijuana. Scroll down for video . Baby-faced: This little boy has amassed more

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "length_penalty": 2.0,
  "max_length": 142,
  "min_length": 56,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}




GENERATED SUMMARY:
Baby-faced boy from Memphis, Tennessee, posts pictures of 'gang life' with guns, cash and marijuana. In many pictures he is smoking suspicious substances, with captions such as 'High Life'. Critics have called his stunts'sad' and 'disappointing'. The boy hit back at critics claiming he has had a tough year and the backlash is'stressing out' his mother.

REFERENCE HIGHLIGHTS:
Child has amassed thousands of Twitter followers with 'gang life' photos .
In one video he points gun at camera as adults look on unfazed .
His tweets have prompted backlash with calls for intervention .

--- SAMPLE 3 ---

ARTICLE:
New Jersey Governor Chris Christie wasn't looking too presidential Tuesday night when he got into a heated debate with a veteran teacher at a town hall meeting. And now the state's largest teacher's union is calling him out for his 'bullying' behavior. 'He's always taken a very nasty and disrespectful tone with teachers and other individuals who dare to question him a

In [None]:
# @title PPO
# pip install transformers datasets evaluate tqdm nltk trl accelerate

import warnings
warnings.filterwarnings("ignore")

# Silence excessive logs
import logging
from transformers import logging as hf_logging
import datasets, evaluate
hf_logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
evaluate.logging.set_verbosity_error()
logging.getLogger("nltk").setLevel(logging.ERROR)

import torch
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    AutoModelForSequenceClassification
)
from trl import Seq2SeqPPOConfig, Seq2SeqPPOTrainer

# ========== CONFIGURATION ===========
MODEL_NAME        = "facebook/bart-large-cnn"
OUTPUT_DIR        = "./bart_cnn_sum"
BATCH_SIZE        = 8
NUM_EPOCHS        = 3
MAX_INPUT_LENGTH  = 512
MAX_OUTPUT_LENGTH = 142
GEN_MAX_LENGTH    = MAX_OUTPUT_LENGTH * 2

TRAIN_SIZE = None   # e.g. 1000 or None
VAL_SIZE   = None
TEST_SIZE  = None

LOG_STEPS        = 50
EVAL_STEPS       = 200
SAVE_STEPS       = 200
SAVE_TOTAL_LIMIT = 3

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# ========== SUPERVISED FINE-TUNING ==========
# Load and tokenize CNN/DailyMail dataset
raw = load_dataset("cnn_dailymail", "3.0.0")
def sample_split(split, n):
    ds = raw[split]
    return ds.shuffle(seed=42).select(range(n)) if n is not None else ds

data = DatasetDict({
    "train":      sample_split("train", TRAIN_SIZE),
    "validation": sample_split("validation", VAL_SIZE),
    "test":       sample_split("test", TEST_SIZE),
})
print(f"Train={len(data['train'])}, Val={len(data['validation'])}, Test={len(data['test'])}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def preprocess(batch):
    inputs = tokenizer(batch["article"], max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
    labels = tokenizer(batch["highlights"], max_length=MAX_OUTPUT_LENGTH, truncation=True, padding="max_length").input_ids
    inputs["labels"] = labels
    return inputs

tokenized = data.map(preprocess, batched=True, num_proc=4, remove_columns=["article","highlights"])

# Supervised training
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)

train_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_steps=LOG_STEPS,
    evaluation_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_TOTAL_LIMIT,
    predict_with_generate=True,
    fp16=(DEVICE.type=="cuda"),
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    greater_is_better=True,
    report_to="none",
)

evaluator = evaluate.load("rouge")
def compute_sft_metrics(eval_preds):
    preds, labels = eval_preds
    preds = preds[0] if isinstance(preds, tuple) else preds
    dec_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = torch.where(torch.tensor(labels) != -100, torch.tensor(labels), tokenizer.pad_token_id).tolist()
    dec_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return evaluator.compute(predictions=dec_preds, references=dec_labels, use_stemmer=True)

trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_sft_metrics,
)

print("▶ Starting supervised fine-tuning...")
trainer.train()

# ========== RLHF WITH PPO ==========
# Load reward model trained on human preferences
reward_model = AutoModelForSequenceClassification.from_pretrained("openai/reward-model").to(DEVICE)
reward_model.eval()

# Convert policy to seq2seq PPO trainer
ppo_config = Seq2SeqPPOConfig(
    model_name=MODEL_NAME,
    learning_rate=1e-5,
    batch_size=BATCH_SIZE,
    ppo_epochs=1,
)

ppo_trainer = Seq2SeqPPOTrainer(
    ppo_config,
    policy_model=model,
    ref_model=None,  # optional reference model for KL control
    tokenizer=tokenizer,
    dataset=data["train"],
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

# PPO loop over a small subset
for sample in data["train"].shuffle(seed=0).select(range(100)):
    # Prepare inputs
    query = tokenizer(sample["article"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    # Generate response
    response_tokens = ppo_trainer.generate(query["input_ids"], max_length=GEN_MAX_LENGTH, num_beams=4)

    # Compute reward
    # Reward model expects concatenated input + response
    concat = tokenizer(
        sample["article"],
        tokenizer.decode(response_tokens[0], skip_special_tokens=True),
        return_tensors="pt",
        truncation=True,
        padding=True
    ).to(DEVICE)
    reward = reward_model(**concat).logits.squeeze().detach()

    # Run PPO step
    ppo_trainer.step(query["input_ids"], response_tokens, reward)

print("▶ PPO training completed.")

# ========== EVALUATION ==========
print("== Final evaluation on test set ==")
metrics = ppo_trainer.evaluate(tokenized["test"])
print(metrics)

# Generate examples
for i in range(5):
    art = data["test"][i]["article"]
    print(f"\n--- Example {i+1} ---")
    print("ARTICLE:", art)
    inp = tokenizer(art, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        out = model.generate(**inp, max_length=GEN_MAX_LENGTH, num_beams=4)
    summary = tokenizer.decode(out[0], skip_special_tokens=True)
    print("SUMMARY:", summary)
    print("HIGHLIGHTS:", data["test"][i]["highlights"])


In [2]:
# Uninstall any mismatched versions
!pip uninstall -y transformers huggingface_hub trl accelerate

# Freshly install all essentials in one go
!pip install --upgrade \
    datasets \
    transformers \
    accelerate \
    huggingface_hub \
    evaluate \
    trl



[0mFound existing installation: huggingface-hub 0.30.2
Uninstalling huggingface-hub-0.30.2:
  Successfully uninstalled huggingface-hub-0.30.2
[0mFound existing installation: accelerate 1.6.0
Uninstalling accelerate-1.6.0:
  Successfully uninstalled accelerate-1.6.0
Collecting datasets
  Using cached datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting transformers
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting accelerate
  Using cached accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface_hub
  Using cached huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting trl
  Using cached trl-0.17.0-py3-none-any.whl.metadata (12 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Using cached datasets-3.5.1-py3-none-any.whl (491 kB)
Downloading transformers-4.51.3-py3-none-any.whl (10.4 MB)
[2K   [90m━━━━━━━━

In [3]:
import datasets, transformers, accelerate, huggingface_hub, evaluate, trl
print("OK:", datasets.__version__, transformers.__version__, accelerate.__version__, huggingface_hub.__version__, evaluate.__version__, trl.__version__)


RuntimeError: operator torchvision::nms does not exist

In [1]:
# Install compatible versions (upgrade to latest if needed)
# !pip install --upgrade transformers huggingface_hub trl accelerate

import warnings
warnings.filterwarnings("ignore")

# Silence library logs
import logging
from transformers import logging as hf_logging
import datasets, evaluate
hf_logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
evaluate.logging.set_verbosity_error()
logging.getLogger("nltk").setLevel(logging.ERROR)

import torch
import numpy as np
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    AutoModelForSequenceClassification
)
from trl import Seq2SeqPPOConfig, Seq2SeqPPOTrainer

# ========== CONFIGURATION ===========
MODEL_NAME         = "facebook/bart-large-cnn"
OUTPUT_DIR         = "./bart_cnn_sum"
BATCH_SIZE         = 16
NUM_EPOCHS         = 3
MAX_INPUT_LENGTH   = 512
MAX_OUTPUT_LENGTH  = 64

# Dynamic generation lengths will be set after supervised training

TRAIN_SIZE = 20  # e.g. 1000 or None for full split
VAL_SIZE   = 3
TEST_SIZE  = 3

LOG_STEPS        = 50
EVAL_STEPS       = 200
SAVE_STEPS       = 200
SAVE_TOTAL_LIMIT = 3

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# ========== LOAD & SAMPLE DATA ===========
raw = load_dataset("cnn_dailymail", "3.0.0")
def sample_split(split, n):
    ds = raw[split]
    return ds.shuffle(seed=42).select(range(n)) if n is not None else ds

data = DatasetDict({
    "train":      sample_split("train",    TRAIN_SIZE),
    "validation": sample_split("validation", VAL_SIZE),
    "test":       sample_split("test",      TEST_SIZE),
})
print(f"Dataset sizes — train: {len(data['train'])}, val: {len(data['validation'])}, test: {len(data['test'])}")

# ========== TOKENIZER & PREPROCESS ===========
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def preprocess(batch):
    inp = tokenizer(batch["article"], max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
    lbl = tokenizer(batch["highlights"], max_length=MAX_OUTPUT_LENGTH, truncation=True, padding="max_length").input_ids
    inp["labels"] = lbl
    return inp
tokenized = data.map(preprocess, batched=True, num_proc=4, remove_columns=["article","highlights"])

# ========== SUPERVISED FINE-TUNING ===========
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
train_args = Seq2SeqTrainingArguments(
    output_dir                 = OUTPUT_DIR,
    num_train_epochs           = NUM_EPOCHS,
    per_device_train_batch_size= BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    logging_strategy           = "steps",
    logging_steps              = LOG_STEPS,
    evaluation_strategy        = "steps",
    eval_steps                 = EVAL_STEPS,
    save_strategy              = "steps",
    save_steps                 = SAVE_STEPS,
    save_total_limit           = SAVE_TOTAL_LIMIT,
    predict_with_generate      = True,
    fp16                       = (DEVICE.type == "cuda"),
    load_best_model_at_end     = True,
    metric_for_best_model      = "rougeL",
    greater_is_better          = True,
    report_to                  = "none",
)
evaluator = evaluate.load("rouge")

def compute_sft_metrics(preds_labels):
    preds, labels = preds_labels
    if isinstance(preds, tuple): preds = preds[0]
    decoded_preds  = tokenizer.batch_decode(preds,  skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return evaluator.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

trainer = Seq2SeqTrainer(
    model           = model,
    args            = train_args,
    train_dataset   = tokenized["train"],
    eval_dataset    = tokenized["validation"],
    tokenizer       = tokenizer,
    data_collator   = DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics = compute_sft_metrics,
)
print("▶ Starting supervised fine-tuning…")
trainer.train()

# ========== DYNAMIC GENERATION LENGTH SETUP ===========
highlight_lens = [len(tokenizer.tokenize(h)) for h in raw["train"]["highlights"]]
avg_highlight_len = int(sum(highlight_lens)/len(highlight_lens))
min_length = avg_highlight_len
max_length = avg_highlight_len * 2
print(f"Avg highlight length = {avg_highlight_len}; generation min={min_length}, max={max_length}")

# ========== RLHF + PPO ===========
# Load reward model
reward_model = AutoModelForSequenceClassification.from_pretrained("openai/reward-model").to(DEVICE)
reward_model.eval()

# PPO configuration
ppo_config = Seq2SeqPPOConfig(
    model_name  = MODEL_NAME,
    learning_rate=1e-5,
    batch_size   = BATCH_SIZE,
    ppo_epochs   = 1,
)
ppo_trainer = Seq2SeqPPOTrainer(
    ppo_config,
    policy_model = model,
    ref_model    = None,
    tokenizer    = tokenizer,
    dataset      = data["train"],
    data_collator= DataCollatorForSeq2Seq(tokenizer, model=model),
)

# PPO loop on subset
for sample in data["train"].shuffle(seed=0).select(range(100)):
    query = tokenizer(sample["article"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    response = ppo_trainer.generate(
        query["input_ids"],
        min_length=min_length,
        max_length=max_length,
        num_beams=5,
        length_penalty=1.1,
        early_stopping=True
    )
    concat_inputs = tokenizer(
        sample["article"],
        tokenizer.decode(response[0], skip_special_tokens=True),
        return_tensors="pt", truncation=True, padding=True
    ).to(DEVICE)
    reward = reward_model(**concat_inputs).logits.squeeze().detach()
    ppo_trainer.step(query["input_ids"], response, reward)
print(" PPO training completed.")

# ========== EVALUATION & EXAMPLE GENERATION ===========
print("== Validation / Test Metrics ==")
print(ppo_trainer.evaluate(tokenized["validation"]))
print(ppo_trainer.evaluate(tokenized["test"]))

print("\n== Example Generations ==")
for i, sample in enumerate(data["test"][:5]):
    print(f"--- Sample {i+1} ---")
    print("ARTICLE:", sample["article"])
    inp = tokenizer(sample["article"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        out = model.generate(
            **inp,
            min_length=min_length,
            max_length=max_length,
            num_beams=5,
            length_penalty=1.1,
            early_stopping=True
        )
    print("SUMMARY:", tokenizer.decode(out[0], skip_special_tokens=True))
    print("HIGHLIGHTS:", sample["highlights"])

ModuleNotFoundError: No module named 'transformers'