Comparison of T5, BART, and LLM Summarization Models on the CNN/DailyMail Dataset Using ROUGE Score Metrics

In [2]:
# Install required libraries
!pip install --upgrade --quiet torch transformers rouge-score datasets tiktoken langchain langchain-google-genai langchain-huggingface beautifulsoup4

import os
from transformers import BartForConditionalGeneration, BartTokenizer, T5ForConditionalGeneration, T5Tokenizer
from datasets import load_dataset
from rouge_score import rouge_scorer
import numpy as np
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate

# Set Google API key.
# Only fall back to the placeholder when no key is already present, so that a
# real GOOGLE_API_KEY exported in the environment is not overwritten by this cell.
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = "<Google API key>"  # Replace <Google API key> with your Google api key

# Load a small portion (first 5 examples) of the CNN/DailyMail dataset (v3.0.0) for testing.
dataset = load_dataset("cnn_dailymail", "3.0.0", split="test[:5]")

# Extract articles and reference summaries from the dataset.
texts = dataset['article']
references = dataset['highlights']

# Load BART and T5 models and tokenizers.
# Each tokenizer is loaded from the same checkpoint name as its model.
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")

# Function to summarize using BART
def summarize_bart(text):
    """Return a summary of ``text`` generated by the module-level BART model.

    The text is tokenized with ``bart_tokenizer`` (truncated to at most 1024
    tokens), summarized with ``bart_model.generate`` using 4 beams and a
    40-150 token output length window, and decoded back to a plain string
    with special tokens removed.
    """
    encoded = bart_tokenizer(
        [text],
        max_length=1024,
        return_tensors='pt',
        truncation=True,
        padding=True,
    )
    generation_options = dict(
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    generated = bart_model.generate(encoded['input_ids'], **generation_options)
    # Only one text was submitted, so the first generated sequence is the summary.
    return bart_tokenizer.decode(
        generated[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

# Function to summarize using T5
def summarize_t5(text):
    """Return a summary of ``text`` generated by the module-level T5 model.

    The text is prefixed with ``"summarize: "``, encoded with ``t5_tokenizer``
    (truncated to at most 512 tokens), summarized with ``t5_model.generate``
    using 4 beams and a 40-150 token output length window, and decoded back
    to a plain string with special tokens removed.
    """
    token_ids = t5_tokenizer.encode(
        "summarize: " + text,
        max_length=512,
        return_tensors='pt',
        truncation=True,
    )
    generation_options = dict(
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    generated = t5_model.generate(token_ids, **generation_options)
    # A single prompt was encoded, so the first generated sequence is the summary.
    return t5_tokenizer.decode(
        generated[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

# Load LLM (Gemini model)
def load_llm(model="gemini-1.5-flash"):
    """Create a ``ChatGoogleGenerativeAI`` client for the requested model.

    Args:
        model: Either ``"gemini-1.5-pro"`` or ``"gemini-1.5-flash"``.

    Returns:
        A ``ChatGoogleGenerativeAI`` instance configured with temperature 0,
        no token limit, no timeout, and up to 2 retries.

    Raises:
        ValueError: If ``model`` is not one of the two supported names.
    """
    # Reject unknown names up front; both supported models share one configuration.
    if model not in ("gemini-1.5-pro", "gemini-1.5-flash"):
        raise ValueError("Invalid model name")
    return ChatGoogleGenerativeAI(
        model=model,
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )

# Get prompt template for LLM summarization
def get_prompt_template():
    """Build the chat prompt used for LLM summarization.

    Returns:
        A ``ChatPromptTemplate`` with a system message asking for a concise
        summary in ``{num_words}`` words and a human message carrying the
        ``{context}`` to summarize.
    """
    system_message = ("system", "Write a concise summary of the following in {num_words} words:\n\n")
    human_message = ("human", "{context}")
    return ChatPromptTemplate.from_messages([system_message, human_message])

# Function to summarize using the LLM
def summarize_llm(text, num_words=50, model="gemini-1.5-flash"):
    """Summarize ``text`` with a Gemini chat model.

    Args:
        text: The content to summarize; fills the prompt's ``context`` slot.
        num_words: Word count requested in the prompt's ``num_words`` slot.
        model: Model name forwarded to ``load_llm``.

    Returns:
        The ``content`` attribute of the model's response.
    """
    chat_model = load_llm(model)
    # Pipe the prompt template into the model to form a single runnable chain.
    pipeline = get_prompt_template() | chat_model
    response = pipeline.invoke({"context": text, "num_words": num_words})
    return response.content

# Build the ROUGE scorer once: ROUGE-1, ROUGE-2 and ROUGE-L, with stemming enabled.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

# Walk through the dataset, numbering articles from 1 for display purposes.
for i, example in enumerate(dataset, start=1):
    article, reference_summary = example['article'], example['highlights']

    # Nothing to compare when either the article or its highlights is blank.
    if not (article.strip() and reference_summary.strip()):
        continue

    # Produce one candidate summary per model.
    bart_summary = summarize_bart(article)
    t5_summary = summarize_t5(article)
    llm_summary = summarize_llm(article, num_words=50)

    # Score each candidate against the reference highlights.
    bart_scores, t5_scores, llm_scores = (
        scorer.score(reference_summary, candidate)
        for candidate in (bart_summary, t5_summary, llm_summary)
    )

    # Report the texts, labelled with the article number.
    print(f"\n=== Article {i} ===")
    print("\nOriginal Article: \n", article[:500], "...")  # Only the first 500 characters are shown
    for heading, body in (
        ("\nReference Summary: \n", reference_summary),
        ("\nBART Summary: \n", bart_summary),
        ("\nT5 Summary: \n", t5_summary),
        ("\nLLM Summary: \n", llm_summary),
    ):
        print(heading, body)

    # Report the F-measure of every ROUGE variant, rounded to 4 decimals.
    print("\n--- ROUGE Scores ---")
    for label, scores in (
        ("BART: ", bart_scores),
        ("T5:   ", t5_scores),
        ("LLM:  ", llm_scores),
    ):
        print(label, {metric: round(result.fmeasure, 4) for metric, result in scores.items()})
    print("-" * 150)



=== Article 1 ===

Original Article: 
 (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, includin ...

Reference Summary: 
 Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .

BART Summary: 
 The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel a