# Training LED for summarization

This document, provided by Senacor Technologies, demonstrates how the model used in the Proof of Concept (PoC) solution was trained and explains how to reproduce and adapt the approach.

## Table Of Contents

- [Requirements](#Requirements)
- [Overview](#Overview)
- [Step 1: Language Model Pretraining](#Step-1:-Language-Model-Pretraining)
  - [Training Data for Pretraining](#Training-Data-for-Pretraining)
  - [Text Infilling Data Collator](#Text-Infilling-Data-Collator)
  - [Masked Language Model Training](#Masked-Language-Model-Training)
- [Step 2: Summarization Finetuning](#Step-2:-Summarization-Finetuning)
  - [Training Data for Finetuning](#Training-Data-for-Finetuning)
  - [Summarization Training](#Summarization-Training)
- [Generating a Summary](#Generating-a-Summary)


## Requirements

This notebook assumes an environment with Python 3.8.10 and the following third-party dependencies:

In [None]:
!pip install torch==1.9.0 transformers==4.6.1 datasets==1.8.0 rouge_score
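
As a minimal sketch, the pinned versions can be compared against the active environment before running the notebook. The helper name `find_version_mismatches` is hypothetical and not part of the original notebook. `rouge_score` is left out because the install command does not pin it.

```python
from importlib import metadata  # part of the standard library since Python 3.8

# Versions pinned in the install command above.
EXPECTED = {"torch": "1.9.0", "transformers": "4.6.1", "datasets": "1.8.0"}


def find_version_mismatches(expected):
    """Return {package: (expected, installed)} for every package that deviates."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


for name, (wanted, installed) in find_version_mismatches(EXPECTED).items():
    print(f"{name}: expected {wanted}, found {installed}")
```

The comparison is an exact string match, so a build with a local version suffix (for example a CUDA-specific `torch` wheel) would be reported as a mismatch even though it may work fine.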

## Overview

As described in the provided documentation, the summarization model trained for and utilized in the PoC solution is a _Longformer-Encoder-Decoder_ (`LED`) as introduced in [_Longformer: The Long-Document Transformer_](https://arxiv.org/pdf/2004.05150.pdf) by Beltagy et al. of [Allen AI](https://allenai.org/).

We use the implementation of `LED` provided as part of [`🤗 Transformers`](https://huggingface.co/transformers/index.html). Of the three checkpoints published together with the aforementioned paper, we found through experimentation that `led-large-16384-arxiv` is the most promising starting point. This checkpoint is pretrained on [the arXiv dataset](https://huggingface.co/datasets/scientific_papers) for long-document summarization and, as reported in the paper, exhibits state-of-the-art performance on this dataset.

To adapt this model to the PoC objective, we found that the following two-step approach yields the best results:

First, to adapt the model's language to the ECB's domain, we perform an additional short language model pretraining on texts published by the ECB. While we find that long language model pretraining negatively impacts the summarization performance, a relatively short masked language model training greatly improves the quality of the text generated by the final model.

Second, we use a relatively small dataset of manually created summaries of selected ECB publications to finetune the model directly on the summarization objective.

In the following, these two steps are described in more detail, including the fully functional code to recreate the training steps that we performed to obtain the model version used in the PoC solution.

## Step 1: Language Model Pretraining

As the first step, we perform an additional short language model pretraining on the `led-large-16384-arxiv` checkpoint provided by _Allen AI_ to adapt it to the ECB's domain.


### Training Data for Pretraining
Self-supervised language model training requires a body of representative text. In the following, this notebook assumes that such a dataset is stored locally as a [`🤗 Dataset`](https://huggingface.co/docs/datasets/). It is assumed that this dataset contains both a `train` and an `eval` split.

For the PoC, we obtained the training data by crawling [the ECB's website](https://www.ecb.europa.eu/home/html/index.en.html) for publicly available text documents, sanitizing the texts and pre-processing them using the tokenizer provided by _Allen AI_ with the `led-large-16384-arxiv` checkpoint.

### Text Infilling Data Collator

Since `LED` is based on `BART`, we would like to use the text infilling task for language model training, which has been shown to be a major contribution to `BART`'s outstanding performance on summarization tasks in the pretraining experiments reported in [the original `BART` paper](https://arxiv.org/pdf/1910.13461.pdf) by Lewis et al. of _Facebook AI_.
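
To make the task concrete, the following toy sketch shows text infilling on a list of words: every selected span, whatever its length, is replaced by a single mask token, so the model has to predict both the missing content and how much of it is missing. The function `infill` and the example sentence are illustrative assumptions and operate on words rather than token ids.

```python
from typing import List, Tuple


def infill(tokens: List[str], spans: List[Tuple[int, int]], mask: str = "<mask>") -> List[str]:
    """Replace each (start, length) span with a single mask token.

    Spans are assumed to be non-overlapping and to have length >= 1.
    """
    result, position = [], 0
    for start, length in sorted(spans):
        result.extend(tokens[position:start])  # copy the untouched tokens
        result.append(mask)                    # the whole span collapses to one mask
        position = start + length
    result.extend(tokens[position:])
    return result


tokens = "the governing council decided to keep interest rates unchanged".split()
corrupted = infill(tokens, spans=[(1, 2), (6, 1)])
print(" ".join(corrupted))
```

In the `BART` paper the span lengths are drawn from a Poisson distribution with λ = 3; here they are fixed by hand to keep the example deterministic.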

There is currently no text infilling data collator in [`🤗 Transformers`](https://huggingface.co/transformers/index.html), so we provide our own. The implementation borrows ideas from `fairseq`'s more complex [DenoisingDataset](https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/fairseq/data/denoising_dataset.py). It is likely that a further refined version of this collator will become part of `🤗 Transformers` in an upcoming release ([PR 12370](https://github.com/huggingface/transformers/pull/12370)).

In [None]:
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
from transformers.data.data_collator import _collate_batch
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase


@dataclass
class DataCollatorForTextInfilling:
    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15
    poisson_lambda: float = 3.0
    pad_to_multiple_of: Optional[int] = None

    def __post_init__(self):
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token, which is required for text infilling."
            )

    def __call__(self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
                 ) -> Dict[str, torch.Tensor]:
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], (dict, BatchEncoding)):
            batch = self.tokenizer.pad(examples,
                                       return_tensors="pt",
                                       pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {"input_ids": _collate_batch(examples,
                                                 self.tokenizer,
                                                 pad_to_multiple_of=self.pad_to_multiple_of)}

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )

        return batch

    def mask_tokens(self,
                    inputs: torch.Tensor,
                    special_tokens_mask: Optional[torch.Tensor] = None
                    ) -> Tuple[torch.Tensor, torch.Tensor]:
        labels = inputs.clone()

        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()


        # determine how many tokens we need to mask in total
        is_token = ~(inputs == self.tokenizer.pad_token_id) & ~special_tokens_mask
        num_to_mask = int(math.ceil(is_token.float().sum() * self.mlm_probability))

        if num_to_mask == 0:
            return inputs, labels

        # generate a sufficient number of span lengths
        poisson_distribution = torch.distributions.Poisson(rate=self.poisson_lambda)
        lengths = poisson_distribution.sample(sample_shape=(num_to_mask,))
        while torch.cumsum(lengths, 0)[-1] < num_to_mask:
            lengths = torch.cat([lengths, poisson_distribution.sample(sample_shape=(num_to_mask,))])

        # remove all spans of length 0
        # Note that BART inserts additional mask tokens where length == 0,
        # which we do not implement for now as it adds additional complexity
        lengths = lengths[lengths > 0]

        # trim to about num_to_mask tokens
        idx = torch.argmin(torch.abs(torch.cumsum(lengths, 0) - num_to_mask)) + 1
        lengths = lengths[:idx + 1]

        # select span start indices
        token_indices = is_token.nonzero(as_tuple=False)
        span_starts = torch.randperm(token_indices.shape[0])[:lengths.shape[0]]

        # prepare mask
        masked_indices = token_indices[span_starts]
        mask = torch.full_like(inputs, fill_value=False)

        # mask span start indices
        for mi in masked_indices:
            mask[tuple(mi)] = True
        lengths -= 1

        # fill up spans
        max_index = inputs.shape[1] - 1
        remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)
        while torch.any(remaining):
            masked_indices[remaining, 1] += 1
            for mi in masked_indices:
                mask[tuple(mi)] = True
            lengths -= 1
            remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)

        # place the mask tokens
        mask[special_tokens_mask] = False
        inputs[mask.bool()] = self.tokenizer.mask_token_id
        labels[~mask.bool()] = -100

        # remove mask tokens that are not starts of spans
        to_remove = mask.bool() & mask.bool().roll(1, 1)
        new_inputs = torch.full_like(inputs, fill_value=self.tokenizer.pad_token_id)
        for i, example in enumerate(torch.split(inputs, split_size_or_sections=1, dim=0)):
            new_example = example[0][~to_remove[i]]
            new_inputs[i, 0:new_example.shape[0]] = new_example

        return new_inputs, labels
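
The arithmetic behind `mask_tokens` can be sketched in a few lines. The masking budget is the ceiling of `mlm_probability` times the number of regular (non-padding, non-special) tokens in the batch, and because zero-length draws are discarded, the average span is slightly longer than `poisson_lambda`. The helper names below are hypothetical and the batch size of 4096 tokens is only an example value.

```python
import math


def masking_budget(num_tokens: int, mlm_probability: float = 0.15) -> int:
    """Number of tokens to mask, mirroring the ceil() used in mask_tokens."""
    return math.ceil(num_tokens * mlm_probability)


def mean_nonzero_span_length(poisson_lambda: float = 3.0) -> float:
    """Mean of a Poisson draw given that zero-length draws are discarded."""
    return poisson_lambda / (1.0 - math.exp(-poisson_lambda))


budget = masking_budget(4096)             # tokens to mask in a 4096-token batch
mean_length = mean_nonzero_span_length()  # average length of the kept spans
print(budget, round(mean_length, 3), round(budget / mean_length))
```

The last printed value is only a rough estimate of the number of spans. The collator trims the sampled lengths to approximately the budget, and spans can overlap or be cut short at special tokens and at the end of a sequence.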

### Masked Language Model Training

The masked language model training is based on the training utilities provided in [`🤗 Transformers`](https://huggingface.co/transformers/index.html).

The following is a fully functional but greatly simplified version of the training script used to generate the model version used in the proof of concept, where training was performed on a GPU cluster.

In [None]:
from datasets import load_from_disk
from transformers import LEDTokenizer, LEDForConditionalGeneration
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

def mlm(batch_size: int, num_train_epochs: float, warmup_steps: int, learning_rate: float,
        pretrained_model_name: str, dataset_path: str, 
        new_model_name: str):

    # load training/validation dataset
    ds = load_from_disk(dataset_path)

    # initialize model and tokenizer
    pretrained_model = LEDForConditionalGeneration.from_pretrained(pretrained_model_name,
                                                                   gradient_checkpointing=True)

    pretrained_model.config.num_beams = 1  # ensure that we use greedy search in generation
    
    tokenizer = LEDTokenizer.from_pretrained(pretrained_model_name)

    # initialize the data collator
    dc = DataCollatorForTextInfilling(tokenizer=tokenizer,
                                      mlm_probability=.15,
                                      pad_to_multiple_of=8)
    
    # initialize the trainer
    training_args = Seq2SeqTrainingArguments(
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=4,
        num_train_epochs=num_train_epochs,
        warmup_steps=warmup_steps,
        learning_rate=learning_rate,
        output_dir="./mlm_outputs",
        overwrite_output_dir=True
    )
    
    trainer = Seq2SeqTrainer(
        model=pretrained_model,
        tokenizer=tokenizer,
        train_dataset=ds["train"],
        eval_dataset=ds["eval"],
        data_collator=dc,
        args=training_args
    )
   
    # perform training
    train_result = trainer.train()
    trainer.save_metrics("train", train_result.metrics)

    # store the model
    trainer.save_model(f"./{new_model_name}")

    # evaluate the model
    eval_metrics = trainer.evaluate()
    trainer.save_metrics("eval", eval_metrics)

We perform masked language model training for half an epoch, i.e. on about half of the roughly 1600 samples in our pretraining dataset.

In [None]:
mlm(batch_size=8, num_train_epochs=0.5, warmup_steps=80, learning_rate=5e-7,
    pretrained_model_name="allenai/led-large-16384-arxiv",
    dataset_path="./ecb-publications",
    new_model_name="led-ecb-lm-arxiv")

## Step 2: Summarization Finetuning

As the second step, we perform summarization finetuning on the pretrained model.

### Training Data for Finetuning

Summarization finetuning requires a selection of summaries for representative texts. In the following, this notebook assumes that such a dataset is stored locally as a [`🤗 Dataset`](https://huggingface.co/docs/datasets/). It is assumed that this dataset contains both a `train` and an `eval` split.

For the PoC, we created the training data by hand-crafting summaries of different lengths and with different topical focus for selected publicly available ECB publications. The texts and summaries were then pre-processed for `LED` summarization finetuning using the tokenizer provided by _Allen AI_ with the `led-large-16384-arxiv` checkpoint.
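
The exact pre-processing is not reproduced in this notebook. As a hedged sketch, one plausible layout for a single pre-processed example is shown below, assuming that texts and summaries have already been tokenized and padded to fixed lengths. The function `build_example`, the column names other than `labels`, and the token ids in the usage example are assumptions. Replacing padding in the labels with `-100` matches what `compute_metrics` in the training code undoes before decoding.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value that the loss in 🤗 Transformers ignores


def build_example(input_ids: List[int], summary_ids: List[int],
                  pad_token_id: int) -> Dict[str, List[int]]:
    """Assemble model inputs and labels from already padded token id lists."""
    return {
        "input_ids": list(input_ids),
        # attend to every token that is not padding
        "attention_mask": [int(t != pad_token_id) for t in input_ids],
        # global attention on the first token only
        "global_attention_mask": [int(i == 0) for i in range(len(input_ids))],
        # padding positions in the summary must not contribute to the loss
        "labels": [IGNORE_INDEX if t == pad_token_id else t for t in summary_ids],
    }


# Made-up token ids; pad_token_id=1 is an assumption for illustration.
example = build_example([0, 5, 6, 2, 1, 1], [0, 7, 2, 1], pad_token_id=1)
print(example)
```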


### Summarization Training

The summarization training is based on the training utilities provided in [`🤗 Transformers`](https://huggingface.co/transformers/index.html).

The following is a fully functional but greatly simplified version of the training script used to generate the model version used in the proof of concept, where training was performed on a GPU cluster.

In [None]:
import torch

from datasets import load_from_disk, load_metric
from transformers import LEDForConditionalGeneration, LEDTokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments


def _create_compute_metrics(tokenizer: LEDTokenizer):
    rouge = load_metric("rouge")

    def compute_metrics(pred):
        labels_ids = pred.label_ids
        pred_ids = pred.predictions

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        labels_ids[labels_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

        rouge_output = rouge.compute(
            predictions=pred_str, references=label_str, rouge_types=["rouge2"]
        )["rouge2"].mid

        return {
            "rouge2_precision": round(rouge_output.precision, 4),
            "rouge2_recall": round(rouge_output.recall, 4),
            "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
        }

    return compute_metrics


def summarization(batch_size: int, num_train_epochs: float, warmup_steps: int, learning_rate: float,
                  pretrained_model_path: str, dataset_path: str, new_model_name: str):
    
    # load training/validation dataset
    ds = load_from_disk(dataset_path)

    # initialize model and tokenizer
    pretrained_model = LEDForConditionalGeneration.from_pretrained(pretrained_model_path,
                                                                   gradient_checkpointing=True,
                                                                   use_cache=False)
    
    tokenizer = LEDTokenizer.from_pretrained(pretrained_model_path)

    # set generation hyperparameters for training
    pretrained_model.config.num_beams = 2
    pretrained_model.config.max_length = 512
    pretrained_model.config.min_length = 100
    pretrained_model.config.length_penalty = 2.0
    pretrained_model.config.early_stopping = True
    pretrained_model.config.no_repeat_ngram_size = 3

    # initialize the trainer
    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=4,
        num_train_epochs=num_train_epochs,
        warmup_steps=warmup_steps,
        learning_rate=learning_rate,
        output_dir="./summarization_outputs",
        overwrite_output_dir=True
    )

    trainer = Seq2SeqTrainer(
        model=pretrained_model,
        tokenizer=tokenizer,
        args=training_args,
        compute_metrics=_create_compute_metrics(tokenizer),
        train_dataset=ds["train"],
        eval_dataset=ds["eval"]
    )

    # perform training
    train_result = trainer.train()
    trainer.save_metrics("train", train_result.metrics)

    # store the model
    trainer.save_model(f"./{new_model_name}")

    # evaluate the model
    eval_metrics = trainer.evaluate()
    trainer.save_metrics("eval", eval_metrics)
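
To clarify what the reported ROUGE-2 values mean, the following simplified sketch computes bigram precision, recall and F-measure for a single prediction/reference pair. It is not the `rouge_score` implementation used above: it only lowercases and splits on whitespace, and it does no stemming and no aggregation over several examples. The function names and the example sentences are made up.

```python
from collections import Counter
from typing import Dict


def bigrams(text: str) -> Counter:
    """Count the word bigrams of a lowercased, whitespace-tokenized text."""
    words = text.lower().split()
    return Counter(zip(words, words[1:]))


def rouge2(prediction: str, reference: str) -> Dict[str, float]:
    pred, ref = bigrams(prediction), bigrams(reference)
    overlap = sum((pred & ref).values())              # clipped bigram matches
    precision = overlap / max(sum(pred.values()), 1)  # share of predicted bigrams that match
    recall = overlap / max(sum(ref.values()), 1)      # share of reference bigrams that are found
    denominator = precision + recall
    fmeasure = 2 * precision * recall / denominator if denominator else 0.0
    return {"precision": precision, "recall": recall, "fmeasure": fmeasure}


scores = rouge2("the ecb kept rates unchanged",
                "the ecb kept interest rates unchanged")
print(scores)
```

In this example three of the four predicted bigrams also occur among the five reference bigrams, which gives a precision of 0.75 and a recall of 0.6.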

We perform finetuning for 20 epochs on the roughly 80 documents in our summarization dataset:

In [None]:
summarization(batch_size=2, num_train_epochs=20, warmup_steps=16, learning_rate=5e-5,
              pretrained_model_path="./led-ecb-lm-arxiv", dataset_path="./ecb-summaries",
              new_model_name="led-ecb-lm-arxiv-sum")
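
For orientation, the number of optimizer updates implied by these settings can be estimated as follows. This is a rough sketch that approximates how the trainer derives its step count. It assumes a single device and that all of the roughly 80 documents are in the `train` split, which overstates the real number because some documents belong to the `eval` split. The function name is hypothetical.

```python
import math


def optimizer_steps(num_samples: int, batch_size: int, grad_accum: int,
                    num_epochs: float, num_devices: int = 1) -> int:
    """Estimate the total number of optimizer updates for a training run."""
    batches_per_epoch = math.ceil(num_samples / (batch_size * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return math.ceil(num_epochs * updates_per_epoch)


total_steps = optimizer_steps(num_samples=80, batch_size=2, grad_accum=4, num_epochs=20)
effective_batch_size = 2 * 4  # per-device batch size times accumulation steps
warmup_share = 16 / total_steps
print(total_steps, effective_batch_size, f"{warmup_share:.0%}")
```

Under these assumptions each update sees 8 documents, and the 16 warmup steps cover a small share of the run. Training on several devices reduces the step count accordingly.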

Note that in the context of the PoC we do not distinguish between short and long summaries in training. Possible refinements include finetuning the model with different generation settings for short and long summaries, or hinting to the model which length of summary is desired.

## Generating a Summary

Now that the model is trained and stored, we can use it to generate summaries of long texts.

In [None]:
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

trained_model = LEDForConditionalGeneration.from_pretrained("./led-ecb-lm-arxiv-sum")
tokenizer = LEDTokenizer.from_pretrained("./led-ecb-lm-arxiv-sum")

beam_search_args = {
    "num_beams": 2,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 3,
    "early_stopping": True
}

def summarize(text: str, length: str) -> str:
    if length == "short":
        min_length, max_length = 80, 200
    elif length == "long":
        min_length, max_length = 350, 600
    else:
        raise ValueError(f"Length has to be either 'long' or 'short', not {length}.")
        
    inputs = tokenizer.encode(text, return_tensors="pt")

    # Global attention on the first token (cf. Beltagy et al. 2020)
    global_attention_mask = torch.zeros_like(inputs)
    global_attention_mask[:, 0] = 1

    summary_ids = trained_model.generate(inputs, global_attention_mask=global_attention_mask,
                                         min_length=min_length, max_length=max_length,
                                         **beam_search_args)

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

In [None]:
with open("sample.txt", "rt") as f:
    text = f.read()
    
print(summarize(text, "short"))
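
The `summarize` function above encodes the full text without a length limit. Texts that exceed the model's maximum input length, presumably the 16384 positions suggested by the checkpoint name, would need to be shortened first. The following is a minimal sketch of such a guard on a plain list of token ids. The function name and the token ids are hypothetical, and in practice the tokenizer's own `truncation` and `max_length` arguments are likely the more convenient route.

```python
from typing import List


def truncate_ids(token_ids: List[int], max_length: int, eos_token_id: int) -> List[int]:
    """Cut a token id sequence to max_length while keeping a final end-of-sequence token."""
    if len(token_ids) <= max_length:
        return list(token_ids)
    return list(token_ids[:max_length - 1]) + [eos_token_id]


# Made-up ids; 0 and 2 stand in for the start and end-of-sequence tokens.
shortened = truncate_ids([0, 5, 6, 7, 8, 2], max_length=4, eos_token_id=2)
print(shortened)
```

Plain truncation discards everything after the cut, so a summary of a very long document would only reflect its beginning.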