## **Warm-starting BERT2BERT for CNN/Dailymail**

***Note***: For demonstration purposes, this notebook trains and evaluates on only a subset of the data. To fine-tune an encoder-decoder model on the full training data, change the training and data preprocessing parameters as indicated in the code comments.


### **Data Preprocessing**


In [1]:
%%capture
!pip install datasets
!pip install transformers
!pip install rouge-score

import datasets
import transformers

In [2]:
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

raw_data = datasets.load_dataset("cnn_dailymail", "3.0.0")

In [3]:
encoder_max_length=512
decoder_max_length=128

def process_data_to_model_inputs(batch):
  # tokenize the inputs and labels
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  batch["decoder_input_ids"] = outputs.input_ids
  batch["decoder_attention_mask"] = outputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
  # We have to make sure that the PAD token is ignored
  batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch
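
The label masking step above can be checked in isolation. The following is a minimal sketch, assuming the `[PAD]` id is `0` (as it is for `bert-base-uncased`) and using arbitrary example token ids; `mask_padding` is a hypothetical helper, not part of the notebook. The value `-100` is the index that PyTorch's cross-entropy loss ignores by default, so padded positions do not contribute to the loss.

```python
PAD_TOKEN_ID = 0      # assumption: [PAD] id of bert-base-uncased
IGNORE_INDEX = -100   # index skipped by PyTorch's cross-entropy loss by default


def mask_padding(label_ids, pad_token_id=PAD_TOKEN_ID):
    """Replace every padding id with IGNORE_INDEX, leaving other tokens untouched."""
    return [
        [IGNORE_INDEX if token == pad_token_id else token for token in sequence]
        for sequence in label_ids
    ]


# Two already-padded label sequences with arbitrary token ids.
labels = [
    [101, 7592, 2088, 102, 0, 0],
    [101, 7592, 102, 0, 0, 0],
]
masked = mask_padding(labels)
print(masked)
```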

In [4]:
train_data = raw_data['train'].select(range(20000))
val_data = raw_data['validation'].select(range(2000))

In [5]:
batch_size = 8

train_data = train_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)

train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)


In [6]:
val_data = val_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)
val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)


### **Warm-starting the Encoder-Decoder Model**

In [7]:
from transformers import EncoderDecoderModel

bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

In [8]:
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

# sensible parameters for beam search
bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4
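
To make the effect of `no_repeat_ngram_size = 3` more concrete, here is a rough sketch of the idea in plain Python: before each generation step, any token that would complete a 3-gram already present in the generated sequence is banned. The function `banned_next_tokens` is a hypothetical illustration of the technique, not the implementation used inside `generate()`.

```python
def banned_next_tokens(generated, ngram_size):
    """Return the tokens that would repeat an n-gram already present in `generated`."""
    if ngram_size <= 0 or len(generated) < ngram_size:
        return set()

    # The last (n - 1) tokens form the prefix of the n-gram the next token would complete.
    prefix = tuple(generated[len(generated) - (ngram_size - 1):]) if ngram_size > 1 else ()

    banned = set()
    for start in range(len(generated) - ngram_size + 1):
        ngram = generated[start:start + ngram_size]
        if tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned


# The sequence ends in (5, 6), and the 3-gram (5, 6, 7) already occurred,
# so generating 7 next would repeat it.
print(banned_next_tokens([5, 6, 7, 8, 5, 6], ngram_size=3))
```

In beam search, the scores of the banned tokens would be set to negative infinity so that no beam can pick them.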

### **Fine-Tuning Warm-Started Encoder-Decoder Models**

For the `EncoderDecoderModel` framework, we will use the `Seq2SeqTrainingArguments` and the `Seq2SeqTrainer`. Let's import them.

In [9]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

We also need a function that computes the ROUGE score during validation. ROUGE measures summary quality more directly than the language modeling loss alone, so it is the better metric to track during training.

In [15]:
import nltk
import numpy as np
nltk.download('punkt')

from datasets import load_metric
metric = load_metric("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Keep the mid F-measure of each ROUGE type, scaled to the range 0-100
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
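
For intuition about what the metric measures, ROUGE-N can be sketched by hand as n-gram overlap between a prediction and a reference. The snippet below is a simplified illustration, assuming whitespace tokenization and no stemming, so its numbers will not match the `rouge` metric exactly; `rouge_n` and `ngram_counts` are hypothetical helper names. ROUGE-L, which is based on the longest common subsequence, is not covered here.

```python
from collections import Counter


def ngram_counts(tokens, n):
    """Count all n-grams of size `n` in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(prediction, reference, n=1):
    """Return (precision, recall, f-measure) of the clipped n-gram overlap."""
    pred = ngram_counts(prediction.lower().split(), n)
    ref = ngram_counts(reference.lower().split(), n)
    overlap = sum((pred & ref).values())  # clipped counts of shared n-grams
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    if precision + recall == 0:
        return 0.0, 0.0, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


prediction = "the cat sat on the mat"
reference = "the cat lay on the mat"
print(rouge_n(prediction, reference, n=1))  # 5 of 6 unigrams overlap
print(rouge_n(prediction, reference, n=2))  # 3 of 5 bigrams overlap
```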


Cool! Finally, we start training.

In [16]:
# set training arguments - these params are not really tuned, feel free to change

training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=1000,  # set to 1000 for full training
    save_steps=500,  # set to 500 for full training
    eval_steps=1500,  # set to 8000 for full training
    warmup_steps=2000,  # set to 2000 for full training
    overwrite_output_dir=True,
    save_total_limit=3,
    fp16=True, 
)
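
To see what `warmup_steps=2000` does to the learning rate, here is a small sketch of a linear warmup followed by linear decay. It assumes the `Trainer` defaults that this notebook does not override (a linear scheduler and a peak learning rate of `5e-5`) and the 7500 total steps reported in the training log of this run; `linear_schedule_lr` is a hypothetical helper written for illustration.

```python
def linear_schedule_lr(step, base_lr=5e-5, warmup_steps=2000, total_steps=7500):
    """Learning rate at `step`: linear ramp-up to `base_lr`, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


for step in (0, 1000, 2000, 4750, 7500):
    print(step, linear_schedule_lr(step))
```

With these values, more than a quarter of the run is spent warming up, which is worth keeping in mind when shortening the training for a demo.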

In [17]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=bert2bert,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)

In [18]:
trainer.train()

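| Step | Training Loss | Validation Loss | Rouge1 | RougeL | Gen Len | Runtime | Samples Per Second |
|------|---------------|-----------------|--------|--------|---------|---------|--------------------|
| 1500 | 4.2322 | 4.492015 | 18.2249 | 12.9033 | 69.857 | 718.8185 | 2.782 |
| 3000 | 3.9789 | 4.213501 | 20.9799 | 14.4822 | 66.529 | 695.5165 | 2.876 |
| 4500 | 3.6941 | 4.026018 | 22.0524 | 14.8849 | 66.5205 | 699.6991 | 2.858 |
| 6000 | 3.1138 | 3.95222 | 24.0855 | 16.3023 | 66.6005 | 697.3062 | 2.868 |
| 7500 | 3.0887 | 3.894571 | 24.6289 | 16.644 | 66.075 | 690.4254 | 2.897 |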



TrainOutput(global_step=7500, training_loss=3.6303883463541666, metrics={'train_runtime': 9643.8687, 'train_samples_per_second': 0.778, 'total_flos': 56992524134400000, 'epoch': 3.0})

In [19]:
!ls

checkpoint-6500  checkpoint-7000  checkpoint-7500  runs  sample_data
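
The step count and the checkpoints listed above follow from the settings used in this notebook. The arithmetic below is a sketch that assumes a single device, no gradient accumulation, and the default of 3 training epochs (matching `'epoch': 3.0` in the training output).

```python
import math

num_train_samples = 20000  # train_data was cut to 20000 samples
batch_size = 8             # per_device_train_batch_size
num_epochs = 3             # assumption: Trainer default, consistent with 'epoch': 3.0
save_steps = 500
save_total_limit = 3

steps_per_epoch = math.ceil(num_train_samples / batch_size)
total_steps = steps_per_epoch * num_epochs

# A checkpoint is written every `save_steps`; only the newest `save_total_limit` are kept.
saved = list(range(save_steps, total_steps + 1, save_steps))
kept = saved[-save_total_limit:]

print(steps_per_epoch, total_steps)
print([f"checkpoint-{step}" for step in kept])
```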


### **Evaluation**

Awesome, we finished training our dummy model. Let's now evaluate the model on the test data. We use the dataset's handy `.map()` function to generate a summary for each sample of the test data.

In [20]:
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_pretrained("./checkpoint-7500")
model.to("cuda")

test_data = raw_data['test']

# only use 16 test examples for notebook - DELETE LINE FOR FULL EVALUATION
test_data = test_data.select(range(16))

batch_size = 16  # change to 64 for full evaluation

In [21]:
# map data correctly
def generate_summary(batch):
    # The BERT tokenizer automatically adds its special tokens: [CLS] <text> [SEP]
    # cut off at BERT max length 512
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    outputs = model.generate(input_ids, attention_mask=attention_mask)

    # all special tokens, including padding, will be removed
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred"] = output_str

    return batch

In [31]:
results = test_data.map(generate_summary, batched=True, batch_size=batch_size)


In [33]:
results

Dataset(features: {'article': Value(dtype='string', id=None), 'highlights': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'pred': Value(dtype='string', id=None)}, num_rows: 16)

In [23]:
pred_str = results["pred"]
label_str = results["highlights"]

In [27]:
len(pred_str)

16

In [34]:
for i in range(8):
    print("Original Text: %s" % results[i]['article'])
    print("\nActual Summary: %s" % label_str[i])
    print("\nPredicted Summary: %s" % pred_str[i])
    print("=====================================================================\n")

Original Text: (CNN)James Best, best known for his portrayal of bumbling sheriff Rosco P. Coltrane on TV's "The Dukes of Hazzard," died Monday after a brief illness. He was 88. Best died in hospice in Hickory, North Carolina, of complications from pneumonia, said Steve Latshaw, a longtime friend and Hollywood colleague. Although he'd been a busy actor for decades in theater and in Hollywood, Best didn't become famous until 1979, when "The Dukes of Hazzard's" cornpone charms began beaming into millions of American homes almost every Friday night. For seven seasons, Best's Rosco P. Coltrane chased the moonshine-running Duke boys back and forth across the back roads of fictitious Hazzard County, Georgia, although his "hot pursuit" usually ended with him crashing his patrol car. Although Rosco was slow-witted and corrupt, Best gave him a childlike enthusiasm that got laughs and made him endearing. His character became known for his distinctive "kew-kew-kew" chuckle and for goofy catchphras

In [26]:
# reuse the ROUGE metric object (`metric`) loaded before training
rouge_output = metric.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

print(rouge_output)

Score(precision=0.03913796113227547, recall=0.06732131280287557, fmeasure=0.04840349920181923)
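
Note that the printed F-measure is not the harmonic mean of the printed precision and recall. As far as we understand the aggregation behind `.mid`, precision, recall, and F-measure are each averaged over the examples separately, so the aggregated F-measure is an average of per-example F-measures. The sketch below illustrates the difference with two hypothetical per-example scores and then checks it against the numbers printed above.

```python
def f_measure(precision, recall):
    """Harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Hypothetical per-example (precision, recall) pairs, chosen for illustration only.
per_example = [(0.5, 0.1), (0.1, 0.5)]

mean_p = sum(p for p, _ in per_example) / len(per_example)
mean_r = sum(r for _, r in per_example) / len(per_example)
mean_of_f = sum(f_measure(p, r) for p, r in per_example) / len(per_example)
f_of_means = f_measure(mean_p, mean_r)
print(mean_of_f, f_of_means)  # the two ways of aggregating disagree

# Values copied from the Score printed above.
reported_p, reported_r, reported_f = 0.03913796113227547, 0.06732131280287557, 0.04840349920181923
harmonic_from_reported = f_measure(reported_p, reported_r)
print(harmonic_from_reported, reported_f)
```

Scaled by 100, the reported F-measure corresponds to a ROUGE-2 of roughly 4.8 for this short demo run on 16 test examples.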


The fully trained *BERT2BERT* model is uploaded to the 🤗 model hub under [patrickvonplaten/bert2bert_cnn_daily_mail](https://huggingface.co/patrickvonplaten/bert2bert_cnn_daily_mail).

The model achieves a ROUGE-2 score of **18.22**, which is even a little better than reported in the paper.

For some summarization examples, the reader is advised to use the online inference API of the model, [here](https://huggingface.co/patrickvonplaten/bert2bert_cnn_daily_mail).