# CNNDM TEXT SUMMARIZATION WITH T5-BASE MODEL

**Author**: Pendo Abbo

### Overview

This tutorial demonstrates how to use a pre-trained T5 model for text summarization on the CNN-DailyMail dataset. We will use the torchtext library to:

1. build a text pre-processing pipeline for a T5 model
2. read in the CNN-DM dataset and pre-process the text
3. instantiate a pre-trained T5 model with base configuration, and perform text summarization on input text

### Common Imports

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
```


### Data Transformation

The T5 model does not work with raw text. Instead, it requires the text to be transformed into numerical form in order to perform training and inference. The following transformations are required for the T5 model:

1. Tokenize text
2. Convert tokens into (integer) IDs
3. Truncate the sequences to a specified maximum length
4. Add end-of-sequence (EOS) and padding token IDs

T5 uses a SentencePiece model for text tokenization. Below, we use a pre-trained SentencePiece model to build the text pre-processing pipeline with torchtext's `T5Transform`. Note that the transform supports both batched and non-batched text input (i.e., one can pass either a single sentence or a list of sentences); however, the T5 model expects its input to be batched.
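To make the four steps concrete, here is a toy pipeline that mimics them without SentencePiece. It is a minimal sketch under loud assumptions: a whitespace split stands in for the real tokenizer, and `TOY_VOCAB` and `UNK_IDX` are hypothetical; only `padding_idx = 0` and `eos_idx = 1` match the values used below. Like `T5Transform`, it accepts either a single string or a list of strings but always returns a batch. The sketch reserves one slot for EOS when truncating and pads to the longest sequence in the batch.

```python
from typing import List, Union

# Special IDs matching the values used in this tutorial.
PADDING_IDX = 0
EOS_IDX = 1
# Hypothetical toy vocabulary and unknown-token ID (illustration only).
UNK_IDX = 2
TOY_VOCAB = {"summarize:": 3, "the": 4, "cat": 5, "sat": 6, "down": 7}


def toy_transform(text: Union[str, List[str]], max_seq_len: int) -> List[List[int]]:
    # Accept a single string or a list of strings; always return a batch.
    batch = [text] if isinstance(text, str) else list(text)
    sequences = []
    for sentence in batch:
        # 1. Tokenize (a whitespace split stands in for SentencePiece).
        tokens = sentence.lower().split()
        # 2. Convert tokens into integer IDs.
        ids = [TOY_VOCAB.get(tok, UNK_IDX) for tok in tokens]
        # 3. Truncate, reserving one position for EOS.
        ids = ids[: max_seq_len - 1]
        # 4a. Append the end-of-sequence ID.
        sequences.append(ids + [EOS_IDX])
    # 4b. Pad every sequence to the length of the longest one in the batch.
    longest = max(len(seq) for seq in sequences)
    return [seq + [PADDING_IDX] * (longest - len(seq)) for seq in sequences]


print(toy_transform(["summarize: the cat sat down", "the dog"], max_seq_len=4))
# → [[3, 4, 5, 1], [4, 2, 1, 0]]
```

The padded, rectangular batch is what lets the model process several articles at once; the zeros are later masked out via `padding_idx`.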

```python
from torchtext.prototype.models import T5Transform

padding_idx = 0
eos_idx = 1
max_seq_len = 512
t5_sp_model_path = "https://download.pytorch.org/models/text/t5_tokenizer_base.model"

transform = T5Transform(
    sp_model_path=t5_sp_model_path,
    max_seq_len=max_seq_len,
    eos_idx=eos_idx,
    padding_idx=padding_idx,
)
```


Alternatively, we can use the transform shipped with the pre-trained model, which performs all of the above steps out of the box:

```python
from torchtext.prototype.models import T5_BASE_GENERATION

transform = T5_BASE_GENERATION.transform()
```


### Dataset

torchtext provides several standard NLP datasets. For a complete list, refer to the documentation at https://pytorch.org/text/stable/datasets.html. These datasets are built using composable torchdata datapipes and hence support standard flow control and mapping/transformation using user-defined functions and transforms. Below, we demonstrate how to pre-process the CNNDM dataset to include the prefix the model needs to identify the task it is performing.

The CNNDM dataset has train, validation, and test splits. Below, we demonstrate on the test split.

```python
from functools import partial
from torch.utils.data import DataLoader
from torchtext.datasets.cnndm import CNNDM

batch_size = 5
test_datapipe = CNNDM(split="test")
task = 'summarize'

def apply_prefix(task, x):
    return f'{task}: ' + x[0], x[1]

test_datapipe = test_datapipe.map(partial(apply_prefix, task))
test_datapipe = test_datapipe.batch(batch_size)
test_datapipe = test_datapipe.rows2columnar(["article", "abstract"])
test_dataloader = DataLoader(test_datapipe, batch_size=None)
```
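The datapipe chain above turns `(article, abstract)` tuples into column-oriented batches. The plain-Python approximation below shows roughly what `map`, `batch`, and `rows2columnar` do to the data. It is a sketch on hypothetical toy rows, not the torchdata implementation; the helper names `batch_rows` and `rows_to_columnar` are made up for this illustration.

```python
from functools import partial
from typing import Dict, Iterator, List, Sequence, Tuple


def apply_prefix(task: str, x: Tuple[str, str]) -> Tuple[str, str]:
    # Same logic as the tutorial: prefix the article, leave the abstract untouched.
    return f"{task}: " + x[0], x[1]


def batch_rows(rows: Sequence[Tuple[str, str]], batch_size: int) -> Iterator[List[Tuple[str, str]]]:
    # Group consecutive rows into lists of at most batch_size (the last batch may be shorter).
    for start in range(0, len(rows), batch_size):
        yield list(rows[start:start + batch_size])


def rows_to_columnar(batch: List[Tuple[str, str]], column_names: List[str]) -> Dict[str, List[str]]:
    # Turn a list of row tuples into one list per column, keyed by column name.
    return {name: [row[i] for row in batch] for i, name in enumerate(column_names)}


# Hypothetical toy rows standing in for (article, abstract) pairs from CNNDM.
rows = [("Article one.", "Summary one."), ("Article two.", "Summary two."), ("Article three.", "Summary three.")]
prefixed = list(map(partial(apply_prefix, "summarize"), rows))  # like .map(partial(apply_prefix, task))
batches = [rows_to_columnar(b, ["article", "abstract"]) for b in batch_rows(prefixed, batch_size=2)]

print(batches[0])
# → {'article': ['summarize: Article one.', 'summarize: Article two.'], 'abstract': ['Summary one.', 'Summary two.']}
```

Each element yielded by the dataloader therefore looks like `batches[0]`: a dict whose `"article"` entry can be passed straight to the transform as a list of strings.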


Alternatively, we can use the batched API (i.e., apply the prefix to the whole batch at once):

```python
def batch_prefix(task, x):
    return {
        "article": [f'{task}: ' + y for y in x["article"]],
        "abstract": x["abstract"]
    }

batch_size = 5
test_datapipe = CNNDM(split="test")
task = 'summarize'

test_datapipe = test_datapipe.batch(batch_size).rows2columnar(["article", "abstract"])
test_datapipe = test_datapipe.map(partial(batch_prefix, task))
test_dataloader = DataLoader(test_datapipe, batch_size=None)
```
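Both approaches should produce identical batches; the batched version simply calls the user function once per batch instead of once per row. The sketch below checks this on a hypothetical two-row toy batch, using plain lists in place of datapipes. `to_columns` is a hypothetical stand-in for `rows2columnar(["article", "abstract"])` on a single batch.

```python
from typing import Dict, List, Tuple


def apply_prefix(task: str, x: Tuple[str, str]) -> Tuple[str, str]:
    # Per-row version: called once for every (article, abstract) tuple.
    return f"{task}: " + x[0], x[1]


def batch_prefix(task: str, x: Dict[str, List[str]]) -> Dict[str, List[str]]:
    # Per-batch version: called once for a whole columnar batch.
    return {"article": [f"{task}: " + y for y in x["article"]], "abstract": x["abstract"]}


def to_columns(rows: List[Tuple[str, str]]) -> Dict[str, List[str]]:
    # Hypothetical equivalent of rows2columnar(["article", "abstract"]) for one batch.
    return {"article": [a for a, _ in rows], "abstract": [s for _, s in rows]}


# Hypothetical toy batch of (article, abstract) pairs.
rows = [("Article one.", "Summary one."), ("Article two.", "Summary two.")]

# Approach 1: prefix each row, then convert the batch to columns.
per_row = to_columns([apply_prefix("summarize", row) for row in rows])
# Approach 2: convert the batch to columns, then prefix the whole batch at once.
per_batch = batch_prefix("summarize", to_columns(rows))

print(per_row == per_batch)  # → True
```

Because the batched function runs fewer times, it can be the cheaper choice when the per-call overhead of the user function is significant.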


### Model Preparation

torchtext provides state-of-the-art pre-trained models that can be used directly for NLP tasks or fine-tuned on downstream tasks. Below, we use the pre-trained T5 model with the standard base architecture to perform text summarization. For additional details on the available pre-trained models, please refer to the documentation at https://pytorch.org/text/main/models.html

```python
t5_base = T5_BASE_GENERATION
transform = t5_base.transform()
model = t5_base.get_model()
model.to(DEVICE)
model.eval()  # disable dropout for inference
```


### Sequence Generator

We can define a sequence generator to produce an output sequence from the provided input sequence. It calls the model's encoder and decoder and iteratively extends the decoded sequences until the end-of-sequence token has been generated for every sequence in the batch. The `generate` function shown below uses greedy search (i.e., it extends each sequence with the most probable next token).
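The bookkeeping in greedy search (start from the padding ID, pick the arg-max token, keep appending padding to sequences that already emitted EOS, stop when all are done) can be seen without a model. In this minimal sketch, a hypothetical scoring function `toy_scores` stands in for the T5 decoder and simply prefers a scripted token at each step; `SCRIPTS` and `VOCAB_SIZE` are made up. Unlike the `generate` function in this section, which scores the whole batch with one decoder call per step, the sketch loops over rows for clarity.

```python
from typing import Callable, List

PADDING_IDX = 0  # T5 uses the padding ID as the decoder's start token
EOS_IDX = 1
VOCAB_SIZE = 6  # hypothetical toy vocabulary size

# Hypothetical "model": each batch row has a scripted sequence it prefers to emit.
SCRIPTS = [[4, 5, EOS_IDX], [3, EOS_IDX]]


def toy_scores(row: int, seq: List[int]) -> List[float]:
    # One-hot scores favouring the next scripted token (EOS once the script runs out).
    step = len(seq) - 1
    script = SCRIPTS[row]
    target = script[step] if step < len(script) else EOS_IDX
    scores = [0.0] * VOCAB_SIZE
    scores[target] = 1.0
    return scores


def greedy_decode(score_fn: Callable[[int, List[int]], List[float]], batch_size: int, max_len: int) -> List[List[int]]:
    sequences = [[PADDING_IDX] for _ in range(batch_size)]
    incomplete = [1] * batch_size  # 1 while a row has not yet produced EOS
    for _ in range(max_len):
        for row, seq in enumerate(sequences):
            scores = score_fn(row, seq)
            # Greedy choice: the index of the highest score.
            next_token = max(range(len(scores)), key=scores.__getitem__)
            # Finished rows get 0, which equals PADDING_IDX, instead of a new token.
            next_token *= incomplete[row]
            incomplete[row] -= int(next_token == EOS_IDX)
            seq.append(next_token)
        # Stop early once every row has emitted EOS.
        if not any(incomplete):
            break
    return sequences


print(greedy_decode(toy_scores, batch_size=2, max_len=10))
# → [[0, 4, 5, 1], [0, 3, 1, 0]]
```

The second row finishes one step early, so it is padded with a trailing `0` while the first row completes; this is the same shape of output that `transform.decode` later has to handle.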

```python
from torch import Tensor
from torchtext.prototype.models import T5Model


@torch.no_grad()  # inference only: skip building the autograd graph
def generate(
    encoder_tokens: Tensor,
    eos_idx: int,
    model: T5Model,
) -> Tensor:
    # create all new tensors on the same device as the input (and the model)
    device = encoder_tokens.device

    # pass tokens through encoder
    encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
    encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
    encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]

    encoder_output = model.norm1(encoder_output)
    encoder_output = model.dropout2(encoder_output)

    # initialize decoder input sequence; T5 uses the padding index as the start token of the decoder sequence
    decoder_tokens = torch.full((encoder_tokens.size(0), 1), model.padding_idx, dtype=torch.long, device=device)

    # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
    incomplete_sentences = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long, device=device)

    # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
    for _ in range(model.config.max_seq_len):

        # causal mask (True = future position, not attended to) and padding mask for decoder sequence
        tgt_len = decoder_tokens.shape[1]
        decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), device=device), diagonal=1).bool()
        decoder_padding_mask = decoder_tokens.eq(model.padding_idx)

        # T5 implementation uses the padding index to start the sequence; don't mask that position
        decoder_padding_mask[:, 0] = False

        # pass decoder sequence through decoder
        decoder_embeddings = model.dropout3(model.token_embeddings(decoder_tokens))
        decoder_output = model.decoder(
            decoder_embeddings,
            memory=encoder_output,
            tgt_mask=decoder_mask,
            tgt_key_padding_mask=decoder_padding_mask,
            memory_key_padding_mask=encoder_padding_mask,
        )[0]

        decoder_output = model.norm2(decoder_output)
        decoder_output = model.dropout4(decoder_output)
        decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
        decoder_output = model.lm_head(decoder_output)

        # greedy search: pick the highest-scoring token at the last position
        next_token = decoder_output[:, -1].argmax(dim=-1, keepdim=True)

        # finished sentences get 0 (the padding index) instead of a new token
        next_token *= incomplete_sentences

        # update incomplete_sentences to remove those that were just ended
        incomplete_sentences = incomplete_sentences - (next_token == eos_idx).long()

        # update decoder sequences to include new tokens
        decoder_tokens = torch.cat((decoder_tokens, next_token), 1)

        # early stop if all sentences have been ended
        if (incomplete_sentences == 0).all():
            break

    return decoder_tokens
```
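The two decoder masks in `generate` are easy to misread, so here is a plain-Python version of each for a toy batch in which the second sequence has already finished. It mirrors `torch.triu(..., diagonal=1).bool()`, where `True` marks a future position the decoder must not attend to, and `decoder_tokens.eq(model.padding_idx)` with the first column cleared. The helper names `causal_mask` and `decoder_padding_mask` are hypothetical.

```python
from typing import List

PADDING_IDX = 0


def causal_mask(tgt_len: int) -> List[List[bool]]:
    # Same pattern as torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool():
    # True above the diagonal marks future positions that position i must not attend to.
    return [[j > i for j in range(tgt_len)] for i in range(tgt_len)]


def decoder_padding_mask(decoder_tokens: List[List[int]]) -> List[List[bool]]:
    # True wherever the token is the padding ID (like decoder_tokens.eq(padding_idx)) ...
    mask = [[tok == PADDING_IDX for tok in row] for row in decoder_tokens]
    # ... except column 0, where the padding ID serves as T5's start token.
    for row in mask:
        row[0] = False
    return mask


print(causal_mask(3))
# → [[False, True, True], [False, False, True], [False, False, False]]
print(decoder_padding_mask([[0, 4, 5, 1], [0, 3, 1, 0]]))
# → [[False, False, False, False], [False, False, False, True]]
```

Without clearing column 0, every sequence would mask out its own start token, which is why `generate` sets `decoder_padding_mask[:, 0] = False`.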


### Generate Summaries

Finally, we put all of the components together to generate summaries for the first batch of articles in the CNNDM test set.

```python
batch = next(iter(test_dataloader))
input_text = batch["article"]
model_input = transform(input_text).to(DEVICE)  # keep inputs on the same device as the model
model_output = generate(
    model=model,
    encoder_tokens=model_input,
    eos_idx=eos_idx,
)
output_text = transform.decode(model_output.tolist())
target = batch["abstract"]
```


```python
for i in range(batch_size):
    print(f"Example {i+1}:\n")
    print(f"prediction: {output_text[i]}\n")
    print(f"target: {target[i]}\n\n")
```


Example 1:

prediction: the Palestinians officially become the 123rd member of the international criminal court . the move gives the court jurisdiction over alleged crimes committed in the occupied Palestinian territory . the ICC opened a preliminary examination into the situation in the occupied territories .

target: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .


Example 2:

prediction: a stray pooch in Washington state has used up at least three of her own after being hit by a car . the dog staggers to a nearby farm, dirt-covered and emaciated, where she is found . she suffered a dislocated jaw, leg injuries and a caved-in sinus cavity .

target: Theia, a bully breed mix, was apparently hit by a car, whacked with a hammer and buried in a field . "She's a true miracle dog and she deserves a good li