# Text summarization with T5 on XSum

We are going to fine-tune the [T5 model, implemented by HuggingFace](https://huggingface.co/t5-small), for text summarization on the [Extreme Summarization (XSum)](https://huggingface.co/datasets/xsum) dataset.
The dataset consists of news articles and their corresponding one-sentence summaries.

We will be using the following model sizes available from HuggingFace:

| Variant                                     |   Parameters    |
|:-------------------------------------------:|----------------:|
| [T5-small](https://huggingface.co/t5-small) |    60,506,624   | 
| [T5-large](https://huggingface.co/t5-large) |   737,668,096   | 
| [T5-3b](https://huggingface.co/t5-3b)       | 2,851,598,336   | 


More info:
* This notebook is based on the script [run_summarization_no_trainer.py](https://github.com/huggingface/transformers/blob/v4.12.5/examples/pytorch/summarization/run_summarization_no_trainer.py) from HuggingFace
* [T5 on HuggingFace docs](https://huggingface.co/transformers/model_doc/t5.html)

In [None]:
import os
import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq
from torch.utils.data import DataLoader

In [None]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


# Keep the notebook output compact: switch off the `datasets` caching
# and the `datasets` progress bars for the rest of the session.
disable_caching()
disable_progress_bar()

In [None]:
# Name of the pretrained checkpoint to fine-tune
# (one of the variants listed in the table at the top of the notebook).
hf_model = "t5-large"

# Tokenizer and model files are cached in a `cache` folder under the
# current working directory.
t5_cache = os.path.join(os.getcwd(), "cache")

In [None]:
# The tokenizer gets its own sub-directory of the cache, named after the checkpoint.
tokenizer_cache_dir = os.path.join(t5_cache, '{}_tokenizer'.format(hf_model))

# Load the tokenizer that matches the chosen checkpoint (fast implementation requested).
tokenizer = AutoTokenizer.from_pretrained(
    hf_model,
    cache_dir=tokenizer_cache_dir,
    use_fast=True,
)

In [None]:
# The model weights get their own sub-directory of the cache, named after the checkpoint.
model_cache_dir = os.path.join(t5_cache, '{}_model'.format(hf_model))

# Load the pretrained sequence-to-sequence model for the chosen checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(hf_model, cache_dir=model_cache_dir)

In [None]:
# Collect the trainable parameters once, as a list: a `filter` object is a
# one-shot iterator and would be exhausted after computing the sum below.
parameters = [p for p in model.parameters() if p.requires_grad]

# `numel()` is the number of elements of each parameter tensor.
num_params = sum(p.numel() for p in parameters)
print(f'{num_params:,} parameters\n')

In [None]:
# Load the XSum dataset; the "train" and "validation" splits are used below.
hf_dataset = load_dataset(path='xsum')

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of examples for summarization.

    Args:
        examples: batch (as passed by `Dataset.map(..., batched=True)`)
            with a 'document' column holding the articles and a
            'summary' column holding the reference summaries.

    Returns:
        The tokenizer output for the prompted documents, with the token
        ids of the summaries stored under the "labels" key.
    """
    documents = examples['document']
    summaries = examples['summary']

    # Prepend the task prefix so the model is prompted to summarize
    prompts = ['summarize: {}'.format(doc) for doc in documents]

    # No padding here: sequences are only truncated, batches are padded later
    shared_opts = {'padding': False, 'truncation': True}

    encoded = tokenizer(prompts, max_length=1024, **shared_opts)

    # Tokenize the summaries with the tokenizer switched to target mode
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(summaries, max_length=128, **shared_opts)

    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

In [None]:
# The raw columns are dropped after tokenization so that only the
# tokenizer outputs (and the labels) remain in the processed dataset.
raw_columns = hf_dataset["train"].column_names

# Tokenize every split in batches, using 12 worker processes.
processed_datasets = hf_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=12,
    remove_columns=raw_columns,
)

In [None]:
# The collator assembles the (unpadded) tokenized examples into batches.
data_collator = DataCollatorForSeq2Seq(tokenizer)

train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]

# Training batches: 4 examples each, drawn in random order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=data_collator,
)

# Evaluation batches: a single example each, also drawn in random order.
eval_dataloader = DataLoader(
    dataset=eval_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=data_collator,
)

In [None]:
# Parameters whose name contains one of these substrings go into the
# second ("no decay") group.
no_decay = ["bias", "LayerNorm.weight"]

# Split the model parameters into the two groups in a single pass.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(tag in param_name for tag in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

# NOTE: both groups currently use a weight decay of 0.0, so the split
# has no effect on the optimization as written.
optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": 0.0},
    {"params": no_decay_params, "weight_decay": 0.0},
]

In [None]:
# AdamW over the two parameter groups, with a learning rate of 5e-5.
optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters, lr=5e-5)

In [None]:
# Use the first GPU (index 0) when CUDA is available; otherwise fall back
# to the CPU so the notebook does not crash on machines without a GPU.
device = 0 if torch.cuda.is_available() else "cpu"
model.to(device)

# Put the model in training mode; the last expression shows the
# resulting `training` flag as the cell output.
model.train()
model.training

In [None]:
# Number of optimizer steps to run; kept small because this is only a demo.
max_train_steps = 100

for step, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    # The batch includes the labels, so the loss is read from the model output
    outputs = model(**batch.to(device))
    loss = outputs.loss
    loss.backward()
    optimizer.step()

    # stop once `max_train_steps` optimizer steps have been performed
    if step + 1 >= max_train_steps:
        break

## Evaluation

In [None]:
# only to print with style
from rich import print as pprint
from rich.console import Console

In [None]:
model.eval()

gen_kwargs = {
    "max_length": 128,
    "num_beams": None,
}
# Only forward the options that are actually set: entries left as None are
# dropped so that `generate` uses its own defaults for them.
active_gen_kwargs = {key: value for key, value in gen_kwargs.items()
                     if value is not None}

# A single console is enough to draw the separator after every example
console = Console()

for step, batch in enumerate(eval_dataloader):
    # only show a handful of examples
    if step > 10:
        break

    # No gradients are needed for generation
    with torch.no_grad():
        generated_tokens = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            **active_gen_kwargs,
        )

    generated_tokens = generated_tokens.cpu().numpy()
    labels = batch["labels"].cpu().numpy()

    # Label positions set to -100 cannot be decoded; replace them with the
    # pad token, which is skipped together with the other special tokens.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Input document, generated summary and reference summary
    pprint(':page_facing_up:', tokenizer.batch_decode(batch["input_ids"])[0])
    pprint(':robot_face:', decoded_preds[0])
    pprint(':white_check_mark:', decoded_labels[0])
    console.rule(style='black')