<a href="https://colab.research.google.com/github/sdaigo/playground-transformers/blob/main/text_summarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers[sentencepiece]==4.19.0 datasets --quiet

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires foli

In [2]:
from datasets import load_dataset


dataset = load_dataset("cnn_dailymail", version="3.0.0")

Using custom data configuration default


Downloading and preparing dataset cnn_dailymail/default to /root/.cache/huggingface/datasets/cnn_dailymail/default/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de...


Dataset cnn_dailymail downloaded and prepared to /root/.cache/huggingface/datasets/cnn_dailymail/default/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de. Subsequent calls will reuse this data.



In [3]:
print(f"Features: {dataset['train'].column_names}")

Features: ['article', 'highlights', 'id']


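* `article`: the full text of the news article
* `highlights`: the reference summary, written as a few highlight sentences (one per line)
* `id`: an identifier for each article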

In [4]:
sample = dataset["train"][1]

print(f"""
Article (excerpt of 500 characters, total length: {len(sample["article"])}):
""")

print(sample["article"][:500])

print(f"\nSummary (length: {len(sample['highlights'])}):")
print(sample["highlights"])


Article (excerpt of 500 characters, total length: 4051):

Editor's note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events. Here, Soledad O'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the "forgotten floor," where many mentally ill inmates are housed in Miami before trial. MIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the "forgotten floor." Here, inmates with the most s

Summary (length: 281):
Mentally ill inmates in Miami are housed on the "forgotten floor"
Judge Steven Leifman says most are there as a result of "avoidable felonies"
While CNN tours facility, patient shouts: "I am the son of the president"
Leifman says the system is unjust and he's fighting for change .
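
To put the two lengths printed above in perspective, here is a minimal sketch that computes the character-level compression ratio. The helper name `compression_ratio` and the stand-in strings are assumptions for illustration; only the lengths 4051 and 281 come from the sample.

```python
def compression_ratio(article: str, summary: str) -> float:
    """Return summary length divided by article length, in characters."""
    if not article:
        raise ValueError("article must be non-empty")
    return len(summary) / len(article)


# Stand-in strings with the same lengths as the sample (4051 and 281 characters).
article_stub = "a" * 4051
summary_stub = "s" * 281

ratio = compression_ratio(article_stub, summary_stub)
print(f"{ratio:.3f}")  # prints 0.069
```

In other words, the reference highlights for this article are roughly 7% of its length in characters.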


# Text Summarization Pipelines

In [5]:
sample_text = dataset["train"][1]["article"][:2000]

summaries = {}

In [6]:
import nltk
from nltk.tokenize import sent_tokenize

In [7]:
nltk.download("punkt")

string = "The U.S. is a country. The U.N. is an organization."
sent_tokenize(string)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


['The U.S. is a country.', 'The U.N. is an organization.']

In [8]:
def three_sentence_summary(text):
  """Baseline: return the first three sentences of `text`, one per line."""
  return "\n".join(sent_tokenize(text)[:3])

In [9]:
summaries["baseline"] = three_sentence_summary(sample_text)

In [10]:
summaries

{'baseline': 'Editor\'s note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events.\nHere, Soledad O\'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the "forgotten floor," where many mentally ill inmates are housed in Miami before trial.\nMIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the "forgotten floor."'}
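
The baseline above is a "lead-3" summary: it keeps the first three sentences. The sketch below shows the same idea without NLTK. The regex splitter and the name `lead_n_summary` are assumptions made for this example. Unlike punkt, this splitter would break incorrectly after abbreviations such as "Dr." when they are followed by a capitalized word.

```python
import re


def lead_n_summary(text: str, n: int = 3) -> str:
    """Return the first n sentences of text, one per line."""
    # Naive splitter (an assumption for this sketch): break on whitespace that
    # follows ., ! or ? and precedes an uppercase letter.
    sentences = re.split(r"(?<=[.!?])\s+(?=[A-Z])", text.strip())
    return "\n".join(sentences[:n])


toy_text = "First point. Second point! Third point? Fourth point."
print(lead_n_summary(toy_text))
```

For news articles a lead-n baseline tends to be a reasonable point of comparison, since the opening sentences usually carry the main facts.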

## GPT-2

In [11]:
from transformers import pipeline, set_seed

set_seed(42)
pipe = pipeline("text-generation", model="gpt2-xl")
gpt2_query = sample_text + "\nTL;DR:\n"  # the "TL;DR" cue prompts GPT-2 to continue with a summary
pipe_out = pipe(gpt2_query, max_length=512, clean_up_tokenization_spaces=True)
summaries["gpt2"] = "\n".join(
    sent_tokenize(pipe_out[0]["generated_text"][len(gpt2_query) :])
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [12]:
summaries

{'baseline': 'Editor\'s note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events.\nHere, Soledad O\'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the "forgotten floor," where many mentally ill inmates are housed in Miami before trial.\nMIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the "forgotten floor."',
 'gpt2': '- Miami-Dade jails have more mentally ill inmates than other counties - About one-third of the mentally ill in Miami-Dade county are housed here - All the mentally ill prisoners have no shoes, no beds, and no mattresses\xa0 \n"You don\'t want to be the judge because you\'re dealing with a lot of people you don\'t know," Leif'}
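
The GPT-2 cell relies on two steps: appending a "TL;DR" cue to the article, then slicing the prompt off the generated text, because a text-generation model returns the prompt followed by its continuation. Here is a minimal sketch of that pattern with a stubbed generator, so it runs without downloading a model. `summarize_with_prompt` and `fake_generate` are hypothetical names, and the extra `.strip()` is an addition that the cell above does not have.

```python
from typing import Callable

TLDR_SUFFIX = "\nTL;DR:\n"


def summarize_with_prompt(text: str, generate: Callable[[str], str]) -> str:
    """Append the TL;DR cue, run a generator, and return only the continuation."""
    prompt = text + TLDR_SUFFIX
    generated = generate(prompt)
    # The generator echoes the prompt, so drop the first len(prompt) characters.
    return generated[len(prompt):].strip()


def fake_generate(prompt: str) -> str:
    # Hypothetical stand-in for a language model: prompt plus a canned continuation.
    return prompt + " A short canned summary."


print(summarize_with_prompt("Some long article text.", fake_generate))
```

With the real pipeline, `generate` would wrap a call such as `pipe(prompt, max_length=512)[0]["generated_text"]`, as in the cell above.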

## T5

In [13]:
pipe = pipeline("summarization", model="t5-large")
pipe_out = pipe(sample_text)

summaries["t5"] = "\n".join(sent_tokenize(pipe_out[0]["summary_text"]))

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


## BART

In [14]:
pipe = pipeline("summarization", model="facebook/bart-large-cnn")
pipe_out = pipe(sample_text)

summaries["bart"] = "\n".join(sent_tokenize(pipe_out[0]["summary_text"]))

## PEGASUS

In [15]:
!pip install transformers[sentencepiece] --quiet

In [16]:
pipe = pipeline("summarization", model="google/pegasus-cnn_dailymail")
pipe_out = pipe(sample_text)

summaries["pegasus"] = pipe_out[0]["summary_text"].replace("<n>", "\n")
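
The `replace("<n>", "\n")` call above turns the literal `<n>` marker in the PEGASUS output into line breaks. The sketch below does the same split into a list, one entry per summary line. The helper name `pegasus_to_lines` and the sample string are made up for illustration and are not actual model output.

```python
from typing import List


def pegasus_to_lines(summary_text: str) -> List[str]:
    """Split a summary on the literal "<n>" marker, dropping empty pieces."""
    return [part.strip() for part in summary_text.split("<n>") if part.strip()]


raw = "First highlight .<n>Second highlight .<n>Third highlight ."
for line in pegasus_to_lines(raw):
    print(line)
```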

# Comparing summaries

In [17]:
print(dataset["train"][1]["highlights"])

Mentally ill inmates in Miami are housed on the "forgotten floor"
Judge Steven Leifman says most are there as a result of "avoidable felonies"
While CNN tours facility, patient shouts: "I am the son of the president"
Leifman says the system is unjust and he's fighting for change .


In [18]:
print(dataset["train"][1]["article"])

Editor's note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events. Here, Soledad O'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the "forgotten floor," where many mentally ill inmates are housed in Miami before trial. MIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the "forgotten floor." Here, inmates with the most severe mental illnesses are incarcerated until they're ready to appear in court. Most often, they face drug charges or charges of assaulting an officer --charges that Judge Steven Leifman says are usually "avoidable felonies." He says the arrests often result from confrontations with police. Mentally ill people often won't do what they're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid, delusional, and less likely to follow dir

In [19]:
for model_name in summaries:
  print(f"{model_name.upper()}:")
  print(summaries[model_name])
  print("")

BASELINE:
Editor's note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events.
Here, Soledad O'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the "forgotten floor," where many mentally ill inmates are housed in Miami before trial.
MIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the "forgotten floor."

GPT2:
- Miami-Dade jails have more mentally ill inmates than other counties - About one-third of the mentally ill in Miami-Dade county are housed here - All the mentally ill prisoners have no shoes, no beds, and no mattresses  
"You don't want to be the judge because you're dealing with a lot of people you don't know," Leif

T5:
mentally ill inmates are housed on the ninth floor of a florida jail .
most face drug charges or charges of assaulting an officer .
judge says arrests often result from confrontations wit

## Measuring the quality of generated text

In [22]:
from datasets import list_metrics

list_metrics()

['accuracy',
 'bertscore',
 'bleu',
 'bleurt',
 'cer',
 'chrf',
 'code_eval',
 'comet',
 'competition_math',
 'coval',
 'cuad',
 'exact_match',
 'f1',
 'frugalscore',
 'glue',
 'google_bleu',
 'indic_glue',
 'mae',
 'mahalanobis',
 'matthews_correlation',
 'mauve',
 'mean_iou',
 'meteor',
 'mse',
 'pearsonr',
 'perplexity',
 'precision',
 'recall',
 'roc_auc',
 'rouge',
 'sacrebleu',
 'sari',
 'seqeval',
 'spearmanr',
 'squad',
 'squad_v2',
 'super_glue',
 'ter',
 'wer',
 'wiki_split',
 'xnli',
 'xtreme_s']

In [24]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Installing collected packages: rouge-score
Successfully installed rouge-score-0.0.4


In [25]:
from datasets import load_metric

rouge_metric = load_metric("rouge")

reference = dataset["train"][1]["highlights"]
records = []

rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

# Score each summary against the reference highlights and keep the F-measure.
for model_name in summaries:
  rouge_metric.add(prediction=summaries[model_name], reference=reference)
  score = rouge_metric.compute()
  rouge_dict = {name: score[name].mid.fmeasure for name in rouge_names}
  records.append(rouge_dict)

In [26]:
import pandas as pd

pd.DataFrame.from_records(records, index=summaries.keys())

| | rouge1 | rouge2 | rougeL | rougeLsum |
|---|---|---|---|---|
| baseline | 0.365079 | 0.145161 | 0.206349 | 0.285714 |
| gpt2 | 0.275229 | 0.093458 | 0.201835 | 0.275229 |
| t5 | 0.382979 | 0.130435 | 0.255319 | 0.382979 |
| bart | 0.475248 | 0.222222 | 0.316832 | 0.415842 |
| pegasus | 0.326531 | 0.208333 | 0.285714 | 0.326531 |
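
As a quick way to read the scores, the sketch below ranks the models per metric using the values copied from the table above. The helper `rank_models` is a hypothetical name. Keep in mind that these scores come from a single article, so the ranking is an illustration and not a benchmark.

```python
# Scores copied from the table above (one article, F-measure).
scores = {
    "baseline": {"rouge1": 0.365079, "rouge2": 0.145161, "rougeL": 0.206349, "rougeLsum": 0.285714},
    "gpt2": {"rouge1": 0.275229, "rouge2": 0.093458, "rougeL": 0.201835, "rougeLsum": 0.275229},
    "t5": {"rouge1": 0.382979, "rouge2": 0.130435, "rougeL": 0.255319, "rougeLsum": 0.382979},
    "bart": {"rouge1": 0.475248, "rouge2": 0.222222, "rougeL": 0.316832, "rougeLsum": 0.415842},
    "pegasus": {"rouge1": 0.326531, "rouge2": 0.208333, "rougeL": 0.285714, "rougeLsum": 0.326531},
}


def rank_models(scores, metric):
    """Return model names sorted from highest to lowest score on one metric."""
    return sorted(scores, key=lambda name: scores[name][metric], reverse=True)


for metric in ["rouge1", "rouge2", "rougeL", "rougeLsum"]:
    print(metric, rank_models(scores, metric))
```

On this sample BART scores highest on all four metrics, and GPT-2, which was not fine-tuned for summarization, scores lowest.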
