T5-base code

In [None]:
# --------------------------------------------
# 1. Install Required Libraries
# --------------------------------------------
!pip install -q transformers datasets evaluate rouge-score accelerate nltk

# --------------------------------------------
# 2. Import Libraries
# --------------------------------------------
import torch
from datasets import load_dataset
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
import evaluate
import nltk
nltk.download("punkt")
from tqdm.auto import tqdm

# --------------------------------------------
# 3. Setup Device
# --------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# --------------------------------------------
# 4. Load Dataset
# --------------------------------------------
dataset = load_dataset("cnn_dailymail", "3.0.0")
dataset["train"] = dataset["train"].select(range(2000))
dataset["validation"] = dataset["validation"].select(range(500))
dataset["test"] = dataset["test"].select(range(200))
print("Loaded and trimmed dataset.")

# --------------------------------------------
# 5. Tokenizer and Preprocessing
# --------------------------------------------
tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)  # Disable legacy mode

max_input_len = 512
max_target_len = 128

def preprocess(example):
    inputs = ["summarize: " + article for article in example["article"]]
    model_inputs = tokenizer(inputs, max_length=max_input_len, truncation=True, padding="max_length")
    labels = tokenizer(
        text_target=example["highlights"],  # New recommended way
        max_length=max_target_len,
        truncation=True,
        padding="max_length"
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)

# --------------------------------------------
# 6. Load Model
# --------------------------------------------
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)

# --------------------------------------------
# 7. Evaluation Metrics
# --------------------------------------------
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_labels = [[label] for label in decoded_labels]

    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    meteor_result = meteor.compute(predictions=decoded_preds, references=decoded_labels)

    return {
        "rouge1": round(rouge_result["rouge1"], 4),
        "rouge2": round(rouge_result["rouge2"], 4),
        "rougeL": round(rouge_result["rougeL"], 4),
        "meteor": round(meteor_result["meteor"], 4)
    }

# --------------------------------------------
# 8. Training Args (UPDATED)
# --------------------------------------------
args = Seq2SeqTrainingArguments(
    output_dir="./t5_cnn_baseline",
    eval_strategy="epoch",  # Changed from evaluation_strategy
    learning_rate=3e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    save_total_limit=2,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    logging_dir="./logs",
    report_to="none"  # Disable wandb reporting
)

# --------------------------------------------
# 9. Trainer
# --------------------------------------------
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_metrics
)

# --------------------------------------------
# 10. Train & Save
# --------------------------------------------
print("\nTraining model...")
trainer.train()

print("\nSaving model...")
trainer.save_model("t5_cnn_baseline_final")
tokenizer.save_pretrained("t5_cnn_baseline_final")

# --------------------------------------------
# 11. Final Evaluation
# --------------------------------------------
print("\nEvaluating on test set...")
test_results = trainer.evaluate(eval_dataset=tokenized_dataset["test"], metric_key_prefix="test")
print("\nTest Results:")
for key, value in test_results.items():
    if isinstance(value, float):
        print(f"{key}: {value:.4f}")

# --------------------------------------------
# 12. Generate Sample Summaries
# --------------------------------------------
def generate_summary(text):
    inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=max_input_len).to(device)
    outputs = model.generate(**inputs, max_length=max_target_len, num_beams=4, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print("\nSample Summaries:")
for i in range(3):
    sample = dataset["test"][i]
    print(f"\nArticle {i+1} (excerpt):\n{sample['article'][:200]}...")
    print(f"\nReference Summary:\n{sample['highlights']}")
    print(f"\nGenerated Summary:\n{generate_summary(sample['article'])}")
    print("\n" + "-"*80)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Using device: cuda
Loaded and trimmed dataset.


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
  trainer = Seq2SeqTrainer(



Training model...


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Meteor
1,0.8991,0.75919,0.2438,0.0943,0.198,0.1473
2,0.6552,0.80167,0.2411,0.0896,0.1951,0.1452
3,0.5018,0.834631,0.2507,0.0937,0.2036,0.1517
4,0.4144,0.868797,0.2442,0.0933,0.1992,0.1494
5,0.4152,0.875853,0.2443,0.0934,0.1995,0.1492



Saving model...

Evaluating on test set...



Test Results:
test_loss: 0.9011
test_rouge1: 0.2468
test_rouge2: 0.0890
test_rougeL: 0.1998
test_meteor: 0.1495
test_runtime: 41.1904
test_samples_per_second: 4.8560
test_steps_per_second: 1.2140
epoch: 5.0000

Sample Summaries:

Article 1 (excerpt):
(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territor...

Reference Summary:
Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .

Generated Summary:
Palestinian Authority becomes 123rd member of International Criminal Court . Court gives jurisdiction over alleged crimes committed in Palestinian territories . Palestinians may be subject to counter-charges as well .

---------------------------------------------------