# Introduction
In this notebook we explore T5 vs. FlanT5 on the MuP dataset.

### Dataset

In [1]:
# load MuP dataset from huggingface
from datasets import load_dataset

dataset_name = "allenai/mup" # allenai/mup-full

dataset = load_dataset(dataset_name, split="validation")
dataset[:1]

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset csv (/home/tannaaman/.cache/huggingface/datasets/allenai___csv/allenai--mup-c30ba3347ec8183d/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


{'paper_name': ['Practical Locally Private Federated Learning with Communication Efficiency'],
 'text': ["1 INTRODUCTION . 1.1 BACKGROUND . Federated learning ( FL ) Kairouz et al . ( 2019 ) ; Konečnỳ et al . ( 2016 ) is a rapidly evolving application of distributed optimization to large-scale learning or estimation scenarios where multiple entities . called clients , collaborate in solving a machine learning problem , under the coordination of a central server . Each client ’ s raw data is stored locally and not exchanged or transferred . To achieve the learning objective , the server collects minimal information from the clients for immediate aggregation . FL is particularly suitable for mobile and edge device applications since the ( sensitive ) individual data never directly leave the device and has seen deployments in industries ( ? Hard et al. , 2019 ; Leroy et al. , 2019 ) . While FL offers significant practical privacy improvements over centralizing all the training data , it

### Summarization prompts
We take the prompts from the templates [here](https://github.com/google-research/FLAN/blob/main/flan/v2/templates.py: 
* XSum
* Gigaword
* CNN daily mail

Note that we only choose those that do summarization, not expansion (given summary, write article).

In [2]:
all_prompts = [
    ("Summarize:\n\n{text}\n\nSummary:", "{summary}"),
        ("Summarize this article:\n\n{text}\n\nSummary:", "{summary}"),
        ("Summarize this article in one sentence.\n\n{text}\n\nSummary:",
         "{summary}"),
        ("{text}\nWhat is a summary of this text?", "{summary}"),
        ("{text}\nWhat was that article about?", "{summary}"),
        ("{text}\n\nThis article was about:", "{summary}"),
        ("Article:{text}\n\nA summary of the above article is?", "{summary}"),
        ("Article:{text}\n\nSummarize the main points of that article.",
         "{summary}"),
        ("Write an article based on this summary:\n\n{summary}\n\nArticle:",
         "{text}"),
        ("Write an article based on this \"{summary}\"\n\nArticle:", "{text}"),
         ("Write a short summary for this text: {text}\n\nSummary:",
         "{summary}"),
        ("Briefly summarize this sentence: {text}\n\nSummary:", "{summary}"),
        ("Generate a short summary this sentence:\n{text}\n\nSummary:",
         "{summary}"),
        ("What is a shorter version of this:\n\n{text}\n\nSummary:",
         "{summary}"),
        ("{text}\n\nWrite a brief summary in a sentence or less.", "{summary}"),
        ("{text}\n\nWhat is a very short summary of the above text?",
         "{summary}"),
        ("{text}\nSummarize the aforementioned text in a single phrase.",
         "{summary}"),
        ("{text}\nCan you generate a short summary of the above paragraph?",
         "{summary}"),
        ("Write a text based on this summary: {summary}\n\nText:", "{text}"),
        ("Write a text based on \"{summary}\"\n\nText:", "{text}"),
        ("Write highlights for this article:\n\n{text}\n\nHighlights:",
         "{highlights}"),
        ("Write some highlights for the following "
         "article:\n\n{text}\n\nHighlights:", "{highlights}"),
        ("{text}\n\nWrite highlights for this article.", "{highlights}"),
        ("{text}\n\nWhat are highlight points for this article?",
         "{highlights}"),
        ("{text}\nSummarize the highlights of this article.", "{highlights}"),
        ("{text}\nWhat are the important parts of this article?",
         "{highlights}"),
        ("{text}\nHere is a summary of the highlights for this article:",
         "{highlights}"),
        ("Write an article using the following "
         "points:\n\n{highlights}\n\nArticle:", "{text}"),
        ("Use the following highlights to write an "
         "article:\n\n{highlights}\n\nArticle:", "{text}"),
        ("{highlights}\n\nWrite an article based on these highlights.",
         "{text}"),
        
]

summarization_prompts = [p for p in all_prompts if "{text}" in p[0].lower()]

from pprint import pprint
pprint(summarization_prompts)


[('Summarize:\n\n{text}\n\nSummary:', '{summary}'),
 ('Summarize this article:\n\n{text}\n\nSummary:', '{summary}'),
 ('Summarize this article in one sentence.\n\n{text}\n\nSummary:', '{summary}'),
 ('{text}\nWhat is a summary of this text?', '{summary}'),
 ('{text}\nWhat was that article about?', '{summary}'),
 ('{text}\n\nThis article was about:', '{summary}'),
 ('Article:{text}\n\nA summary of the above article is?', '{summary}'),
 ('Article:{text}\n\nSummarize the main points of that article.', '{summary}'),
 ('Write a short summary for this text: {text}\n\nSummary:', '{summary}'),
 ('Briefly summarize this sentence: {text}\n\nSummary:', '{summary}'),
 ('Generate a short summary this sentence:\n{text}\n\nSummary:', '{summary}'),
 ('What is a shorter version of this:\n\n{text}\n\nSummary:', '{summary}'),
 ('{text}\n\nWrite a brief summary in a sentence or less.', '{summary}'),
 ('{text}\n\nWhat is a very short summary of the above text?', '{summary}'),
 ('{text}\nSummarize the afore

### Evaluation 

In [3]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import List
import torch

def generate_summaries(model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, prompt: str, documents_to_summarize: List[str], max_length=150, num_beams=4, no_repeat_ngram_size=2, early_stopping=True):
    """
    Takes in a prompt and list of documents to summarize and returns a list of summaries. Keyword for replacement in the prompt is {text}.
    """
    summaries = []
    batch_size = 4

    # 0. model to device
    model.to("cuda")
    with torch.no_grad():

        for i in range(0, len(documents_to_summarize), batch_size):
            batch = documents_to_summarize[i:i+batch_size]

            # 1. include prompt for every document
            prompt_batch = [prompt.format(text=text) for text in batch]

            # 2. tokenize prompt and move to device
            inputs = tokenizer(prompt_batch, return_tensors="pt", max_length=max_length, truncation=True, padding="max_length")
            for k, v in inputs.items():
                inputs[k] = v.to("cuda")

            # 3. generate summary
            outputs = model.generate(**inputs, max_length=max_length, num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size, early_stopping=early_stopping).to("cpu")
            summaries.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
        return summaries

In [4]:
from evaluation_utils import evaluate_rouge_score
model_names = [
    "google/flan-t5-large",
    "t5-small",
    "t5-base",
    "t5-large",
    "google/flan-t5-base",
    "google/flan-t5-small", 
]
sample_size = 16
dataset = dataset.shuffle(seed=42).select(range(sample_size))

out = dict() # model -> prompt -> rouge scores
for model_name in model_names: 
    if model_name not in out: 
        out[model_name] = dict()
    prompt_count = 0
    for prompt, output_format in summarization_prompts:
        
        print(f"generating for {model_name} with prompt {prompt_count} and output format {output_format}...")
    
        # 1. load model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # 2. generate summaries
        summaries = generate_summaries(model, tokenizer, prompt, documents_to_summarize=dataset["text"])

        # 3. do evaluation
        rouge_scores = evaluate_rouge_score(summaries, dataset["summary"])

        # 4. save results
        out[model_name][prompt] = rouge_scores
        prompt_count += 1

# save to pickle
import pickle
with open("t5-vs-flant5.pkl", "wb") as f:
    pickle.dump(out, f)


Loading cached shuffled indices for dataset at /home/tannaaman/.cache/huggingface/datasets/allenai___csv/allenai--mup-c30ba3347ec8183d/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-e5c346cff315c88f.arrow


generating for google/flan-t5-large with prompt 0 and output format {summary}...
generating for google/flan-t5-large with prompt 1 and output format {summary}...
generating for google/flan-t5-large with prompt 2 and output format {summary}...
generating for google/flan-t5-large with prompt 3 and output format {summary}...
generating for google/flan-t5-large with prompt 4 and output format {summary}...
generating for google/flan-t5-large with prompt 5 and output format {summary}...
generating for google/flan-t5-large with prompt 6 and output format {summary}...
generating for google/flan-t5-large with prompt 7 and output format {summary}...
generating for google/flan-t5-large with prompt 8 and output format {summary}...
generating for google/flan-t5-large with prompt 9 and output format {summary}...
generating for google/flan-t5-large with prompt 10 and output format {summary}...
generating for google/flan-t5-large with prompt 11 and output format {summary}...
generating for google/flan

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


generating for t5-small with prompt 1 and output format {summary}...
generating for t5-small with prompt 2 and output format {summary}...
generating for t5-small with prompt 3 and output format {summary}...
generating for t5-small with prompt 4 and output format {summary}...
generating for t5-small with prompt 5 and output format {summary}...
generating for t5-small with prompt 6 and output format {summary}...
generating for t5-small with prompt 7 and output format {summary}...
generating for t5-small with prompt 8 and output format {summary}...
generating for t5-small with prompt 9 and output format {summary}...
generating for t5-small with prompt 10 and output format {summary}...
generating for t5-small with prompt 11 and output format {summary}...
generating for t5-small with prompt 12 and output format {summary}...
generating for t5-small with prompt 13 and output format {summary}...
generating for t5-small with prompt 14 and output format {summary}...
generating for t5-small with 

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


generating for t5-base with prompt 1 and output format {summary}...
generating for t5-base with prompt 2 and output format {summary}...
generating for t5-base with prompt 3 and output format {summary}...
generating for t5-base with prompt 4 and output format {summary}...
generating for t5-base with prompt 5 and output format {summary}...
generating for t5-base with prompt 6 and output format {summary}...
generating for t5-base with prompt 7 and output format {summary}...
generating for t5-base with prompt 8 and output format {summary}...
generating for t5-base with prompt 9 and output format {summary}...
generating for t5-base with prompt 10 and output format {summary}...
generating for t5-base with prompt 11 and output format {summary}...
generating for t5-base with prompt 12 and output format {summary}...
generating for t5-base with prompt 13 and output format {summary}...
generating for t5-base with prompt 14 and output format {summary}...
generating for t5-base with prompt 15 and o

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


generating for t5-large with prompt 1 and output format {summary}...
generating for t5-large with prompt 2 and output format {summary}...
generating for t5-large with prompt 3 and output format {summary}...
generating for t5-large with prompt 4 and output format {summary}...
generating for t5-large with prompt 5 and output format {summary}...
generating for t5-large with prompt 6 and output format {summary}...
generating for t5-large with prompt 7 and output format {summary}...
generating for t5-large with prompt 8 and output format {summary}...
generating for t5-large with prompt 9 and output format {summary}...
generating for t5-large with prompt 10 and output format {summary}...
generating for t5-large with prompt 11 and output format {summary}...
generating for t5-large with prompt 12 and output format {summary}...
generating for t5-large with prompt 13 and output format {summary}...
generating for t5-large with prompt 14 and output format {summary}...
generating for t5-large with 

Downloading (…)okenizer_config.json: 100%|██████████| 2.54k/2.54k [00:00<00:00, 217kB/s]
Downloading spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 37.9MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.42M/2.42M [00:01<00:00, 1.65MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 2.20k/2.20k [00:00<00:00, 504kB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 1.40k/1.40k [00:00<00:00, 141kB/s]
Downloading pytorch_model.bin: 100%|██████████| 990M/990M [00:11<00:00, 89.5MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 147/147 [00:00<00:00, 11.5kB/s]


generating for google/flan-t5-base with prompt 1 and output format {summary}...
generating for google/flan-t5-base with prompt 2 and output format {summary}...
generating for google/flan-t5-base with prompt 3 and output format {summary}...
generating for google/flan-t5-base with prompt 4 and output format {summary}...
generating for google/flan-t5-base with prompt 5 and output format {summary}...
generating for google/flan-t5-base with prompt 6 and output format {summary}...
generating for google/flan-t5-base with prompt 7 and output format {summary}...
generating for google/flan-t5-base with prompt 8 and output format {summary}...
generating for google/flan-t5-base with prompt 9 and output format {summary}...
generating for google/flan-t5-base with prompt 10 and output format {summary}...
generating for google/flan-t5-base with prompt 11 and output format {summary}...
generating for google/flan-t5-base with prompt 12 and output format {summary}...
generating for google/flan-t5-base wi

Downloading (…)okenizer_config.json: 100%|██████████| 2.54k/2.54k [00:00<00:00, 574kB/s]
Downloading spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 36.1MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.42M/2.42M [00:00<00:00, 3.22MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 2.20k/2.20k [00:00<00:00, 361kB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 1.40k/1.40k [00:00<00:00, 238kB/s]
Downloading pytorch_model.bin: 100%|██████████| 308M/308M [00:02<00:00, 110MB/s]  
Downloading (…)neration_config.json: 100%|██████████| 147/147 [00:00<00:00, 25.4kB/s]


generating for google/flan-t5-small with prompt 1 and output format {summary}...
generating for google/flan-t5-small with prompt 2 and output format {summary}...
generating for google/flan-t5-small with prompt 3 and output format {summary}...
generating for google/flan-t5-small with prompt 4 and output format {summary}...
generating for google/flan-t5-small with prompt 5 and output format {summary}...
generating for google/flan-t5-small with prompt 6 and output format {summary}...
generating for google/flan-t5-small with prompt 7 and output format {summary}...
generating for google/flan-t5-small with prompt 8 and output format {summary}...
generating for google/flan-t5-small with prompt 9 and output format {summary}...
generating for google/flan-t5-small with prompt 10 and output format {summary}...
generating for google/flan-t5-small with prompt 11 and output format {summary}...
generating for google/flan-t5-small with prompt 12 and output format {summary}...
generating for google/fla

In [12]:
# average by model, across prompts
average_by_model = dict()
for model_name in model_names:
    results = out[model_name]
    average_by_model[model_name] = sum(v for  r in results.values() for k, v in r.items() if "_fmeasure" in k) / len(results.values())

# average by prompt, across models
average_by_prompt = dict()
for prompt, filler in summarization_prompts:
    results = [out[model_name][prompt] for model_name in model_names]
    average_by_prompt[prompt] = sum(v for r in results for k, v in r.items()) / len(results)

# results
print("Average by model")
pprint(average_by_model)
print("\nAverage by prompt")
pprint(average_by_prompt)

# save to pickle
with open("t5-vs-flant5-average-by-model-large-scale.pkl", "wb") as f:
    pickle.dump(average_by_model, f)
with open("t5-vs-flant5-average-by-prompt.pkl-scale", "wb") as f:
    pickle.dump(average_by_prompt, f)


Average by model
{'google/flan-t5-base': tensor(0.1789),
 'google/flan-t5-large': tensor(0.1709),
 'google/flan-t5-small': tensor(0.1798),
 't5-base': tensor(0.3292),
 't5-large': tensor(0.4032),
 't5-small': tensor(0.2919)}

Average by prompt
{'Article:{text}\n\nA summary of the above article is?': tensor(0.9923),
 'Article:{text}\n\nSummarize the main points of that article.': tensor(0.9923),
 'Briefly summarize this sentence: {text}\n\nSummary:': tensor(1.3868),
 'Generate a short summary this sentence:\n{text}\n\nSummary:': tensor(0.9381),
 'Summarize this article in one sentence.\n\n{text}\n\nSummary:': tensor(1.3837),
 'Summarize this article:\n\n{text}\n\nSummary:': tensor(1.2982),
 'Summarize:\n\n{text}\n\nSummary:': tensor(1.4133),
 'What is a shorter version of this:\n\n{text}\n\nSummary:': tensor(1.4437),
 'Write a short summary for this text: {text}\n\nSummary:': tensor(1.1277),
 'Write highlights for this article:\n\n{text}\n\nHighlights:': tensor(1.4129),
 'Write some hig

In [7]:
import pickle
with open("t5-vs-flant5.pkl", "wb") as f:
    pickle.dump(out, f)