# Bert2Bert for Cnn/daily

## Libraries and environment preparation

In [1]:
#Install essential packages
%%capture
!pip install datasets rouge-score nltk wandb
!pip install transformers==4.11.0
!apt install git-lfs

In [2]:
#Colab Environment Check for GPU and RAM
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

#GPU check
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Your runtime has 13.6 gigabytes of available RAM

Sun Feb  6 19:20:22 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------

In [3]:
# Make sure your version of Transformers is at 4.11.0
# to run the following code correctly:
import datasets
import transformers

In [4]:
# Import Wandb 
import os
import wandb
API_KEY = '39991c538626bee25c64d4f8a4c3403dd635537c'
os.environ["WANDB_API_KEY"] = API_KEY

## Loading the dataset

In [5]:
# import dataset and metrics with huggingface
raw_datasets = datasets.load_dataset('cnn_dailymail', '3.0.0')

Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

## Preprocessing the data

In [7]:
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "


**Following steps indentity the length of the article**

We can see that on average an article contains 848 tokens with *ca.* 3/4 of the articles being longer than the model's `max_length` 512. The summary is on average 57 tokens long. Over 30% of our 10000-sample summaries are longer than 64 tokens, but none are longer than 128 tokens.

`bert-base-cased` is limited to 512 tokens, which means we would have to cut possibly important information from the article. Because most of the important information is often found at the beginning of articles and because we want to be computationally efficient, we decide to stick to `bert-base-cased` with a `max_length` of 512 in this notebook. This choice is not optimal but has shown to yield [good results](https://arxiv.org/abs/1907.12461) on CNN/Dailymail. 

In [8]:
encoder_max_length=512
decoder_max_length=128

def process_data_to_model_inputs(batch):
  # tokenize the inputs and labels
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  batch["decoder_input_ids"] = outputs.input_ids
  batch["decoder_attention_mask"] = outputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
  # We have to make sure that the PAD token is ignored
  batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch

In [9]:
train_data = raw_datasets["train"].select(range(10000))
train_data = train_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    remove_columns=["article", "highlights", "id"]
)
train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

Loading cached processed dataset at /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234/cache-33ae1bb3f62403a4.arrow


In [10]:
# bert2bert validation step is high computation cost. We choose about 1/5
val_data = raw_datasets['validation'].select(range(1000))
val_data = val_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    remove_columns=["article", "highlights", "id"]
)
val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

  0%|          | 0/1 [00:00<?, ?ba/s]

## Fine-tuning the model

In [11]:
# Import tokenizer from model checkpoint and print detail
from transformers import EncoderDecoderModel
#bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased", tie_encoder_decoder=True)
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-

Setting the special tokens.
`bert-base-cased` does not have a `decoder_start_token_id` or `eos_token_id`, so we will use its `cls_token_id` and `sep_token_id` respectively. 
Also, we should define a `pad_token_id` on the config and make sure the correct `vocab_size` is set.

In [12]:
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

# sensible parameters for beam search
bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

Define `Seq2SeqTrainer` to compute the metrics from the predictions, and also do a bit of pre-processing to decode the predictions into texts:

In [13]:
# keep track with wandb
wandb.init(project="bert2bert")

[34m[1mwandb[0m: Currently logged in as: [33mshusunny[0m (use `wandb login --relogin` to force relogin)


In [14]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [15]:
import nltk
import numpy as np
nltk.download('punkt')

from datasets import load_metric
metric = load_metric("rouge")

def compute_metrics(eval_pred):
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    return {k: round(v, 4) for k, v in result.items()}

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [16]:
# set training arguments - these params are not really tuned, feel free to change
batch_size = 8
epochs = 16
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    load_best_model_at_end="eval_loss",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=1000,  # set to 1000 for full training
    save_steps=2000,  # set to 500 for full training
    eval_steps=2000,  # set to 8000 for full training
    warmup_steps=500,  # set to 2000 for full training
    overwrite_output_dir=True,
    num_train_epochs=epochs,
    save_total_limit=3,
    fp16=True, 
    report_to="wandb",
)

In [17]:
# Pass into the trainer
trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp fp16 backend


In [18]:
trainer.train()

***** Running training *****
  Num examples = 10000
  Num Epochs = 16
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 20000
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
2000,4.0143,4.088926,19.074,3.77,13.7252,13.7645
4000,3.1725,4.071744,20.9492,4.4045,14.9659,15.0016
6000,2.3879,4.111645,22.4568,4.8246,15.7005,15.7463
8000,1.8621,4.267831,23.1577,5.1581,15.9649,16.01
10000,1.4263,4.397253,23.2661,5.1622,16.1844,16.2134
12000,1.0173,4.578029,23.6312,5.0765,16.1414,16.177
14000,0.774,4.738756,23.6892,5.088,16.1619,16.2026
16000,0.5602,4.845972,23.6302,5.05,16.1245,16.1589
18000,0.4544,4.921787,23.8735,4.9613,16.2767,16.3138
20000,0.3856,4.954407,23.7362,5.1214,16.0812,16.1138


***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
Saving model checkpoint to ./checkpoint-2000
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Configuration saved in ./checkpoint-2000/config.json
Model weights saved in ./checkpoint-2000/pytorch_model.bin
tokenizer config file saved in ./checkpoint-2000/tokenizer_config.json
Special tokens file saved in ./checkpoint-2000/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
Saving model checkpoint to ./checkpoint-4000
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Configuration saved in ./checkpoint-4000/config.json
Model weights saved in ./checkpoint-4000/pytorch_model.bin
tokenizer config file saved in ./checkpoint-4000/tokenizer_config.json
Special tokens file saved in ./checkpoint-4000/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1000

TrainOutput(global_step=20000, training_loss=1.7576053955078126, metrics={'train_runtime': 15739.6597, 'train_samples_per_second': 10.165, 'train_steps_per_second': 1.271, 'total_flos': 1.22690820096e+17, 'train_loss': 1.7576053955078126, 'epoch': 16.0})

In [19]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
eval/loss,▁▁▁▃▄▅▆▇██
eval/rouge1,▁▄▆▇▇█████
eval/rouge2,▁▄▆████▇▇█
eval/rougeL,▁▄▆▇█████▇
eval/rougeLsum,▁▄▆▇█████▇
eval/runtime,▂▁██▄▆▆▆▄▆
eval/samples_per_second,▇█▂▁▅▃▃▃▅▃
eval/steps_per_second,▇█▂▁▅▄▄▃▅▃
train/epoch,▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇▇████

0,1
eval/loss,4.95441
eval/rouge1,23.7362
eval/rouge2,5.1214
eval/rougeL,16.0812
eval/rougeLsum,16.1138
eval/runtime,409.4674
eval/samples_per_second,2.442
eval/steps_per_second,0.305
train/epoch,16.0
train/global_step,20000.0


## Results and Evaluation

In [None]:
!ls -lh

total 2.6G
-rw-r--r-- 1 root root 2.6G Feb  6 01:13 bert2bert-cnn.zip
drwxr-xr-x 2 root root 4.0K Feb  5 23:25 checkpoint-20000
drwxr-xr-x 2 root root 4.0K Feb  6 00:12 checkpoint-22500
drwxr-xr-x 2 root root 4.0K Feb  6 01:00 checkpoint-25000
drwxr-xr-x 1 root root 4.0K Feb  1 14:32 sample_data
drwxr-xr-x 3 root root 4.0K Feb  5 17:02 wandb


In [None]:
!zip -r bert2bert-cnn.zip checkpoint-25000/

  adding: checkpoint-25000/ (stored 0%)
  adding: checkpoint-25000/training_args.bin (deflated 48%)
  adding: checkpoint-25000/vocab.txt (deflated 53%)
  adding: checkpoint-25000/pytorch_model.bin (deflated 7%)
  adding: checkpoint-25000/tokenizer.json (deflated 59%)
  adding: checkpoint-25000/tokenizer_config.json (deflated 39%)
  adding: checkpoint-25000/rng_state.pth (deflated 27%)
  adding: checkpoint-25000/special_tokens_map.json (deflated 53%)
  adding: checkpoint-25000/optimizer.pt (deflated 9%)
  adding: checkpoint-25000/scaler.pt (deflated 55%)
  adding: checkpoint-25000/scheduler.pt (deflated 49%)
  adding: checkpoint-25000/config.json (deflated 80%)
  adding: checkpoint-25000/trainer_state.json (deflated 80%)


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!cp bert2bert-cnn.zip '/content/drive/My Drive/weights/'

In [None]:

# only use 16 training examples for notebook - DELETE LINE FOR FULL TRAINING
test_data = test_data.select(range(16))

results = test_data.map(generate_summary, batched=True, batch_size=batch_size)

  0%|          | 0/2 [00:00<?, ?ba/s]