# Bert2Bert on Xsum

## Libraries and environment preparation

In [None]:
#Install essential packages
%%capture
!pip install datasets rouge-score nltk wandb
!pip install transformers==4.11.0
!apt install git-lfs

In [None]:
#Colab Environment Check for GPU and RAM
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

#GPU check
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Your runtime has 13.6 gigabytes of available RAM

Sun Feb 13 18:15:40 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------

In [None]:
# Make sure your version of Transformers is at 4.11.0
# to run the following code correctly:
import datasets
import transformers

In [None]:
# Import Wandb 
import os
import wandb
API_KEY = '39991c538626bee25c64d4f8a4c3403dd635537c'
os.environ["WANDB_API_KEY"] = API_KEY

## Loading the dataset

In [None]:
# import dataset and metrics with huggingface
raw_datasets = datasets.load_dataset('xsum')

Downloading:   0%|          | 0.00/2.05k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/954 [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934...


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.00M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

## Preprocessing the data

In [None]:
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "


In [None]:
encoder_max_length=512
decoder_max_length=64

def process_data_to_model_inputs(batch):
  # tokenize the inputs and labels
  inputs = tokenizer(batch["document"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["summary"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  batch["decoder_input_ids"] = outputs.input_ids
  batch["decoder_attention_mask"] = outputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
  # We have to make sure that the PAD token is ignored
  batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch

In [None]:
train_data = raw_datasets["train"]
train_data = train_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    remove_columns=["document", "summary", "id"]
)
train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

  0%|          | 0/205 [00:00<?, ?ba/s]

In [None]:
# bert2bert validation step is high computation cost. We choose about 1/5
val_data = raw_datasets['validation'].select(range(2500))
val_data = val_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    remove_columns=["document", "summary", "id"]
)
val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

  0%|          | 0/3 [00:00<?, ?ba/s]

## Fine-tuning the model

In [None]:
# Import tokenizer from model checkpoint and print detail
from transformers import EncoderDecoderModel
#bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased", tie_encoder_decoder=True)
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "


Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relatio

Setting the special tokens.
`bert-base-cased` does not have a `decoder_start_token_id` or `eos_token_id`, so we will use its `cls_token_id` and `sep_token_id` respectively. 
Also, we should define a `pad_token_id` on the config and make sure the correct `vocab_size` is set.

In [None]:
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

# sensible parameters for beam search
bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

Define `Seq2SeqTrainer` to compute the metrics from the predictions, and also do a bit of pre-processing to decode the predictions into texts:

In [None]:
# keep track with wandb
wandb.init(project="bert2bert")

[34m[1mwandb[0m: Currently logged in as: [33mshusunny[0m (use `wandb login --relogin` to force relogin)


In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [None]:
import nltk
import numpy as np
nltk.download('punkt')

from datasets import load_metric
metric = load_metric("rouge")

def compute_metrics(eval_pred):
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    return {k: round(v, 4) for k, v in result.items()}

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


Downloading:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

In [None]:
# set training arguments - these params are not really tuned, feel free to change
batch_size = 8
epochs = 1
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    load_best_model_at_end="eval_loss",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=1000,  # set to 1000 for full training
    save_steps=2500,  # set to 500 for full training
    eval_steps=2500,  # set to 8000 for full training
    warmup_steps=500,  # set to 2000 for full training
    overwrite_output_dir=True,
    num_train_epochs=epochs,
    save_total_limit=3,
    fp16=True, 
    report_to="wandb",
)

PyTorch: setting up devices


In [None]:
# Pass into the trainer
trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp fp16 backend


In [None]:
trainer.train()

***** Running training *****
  Num examples = 204045
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 25506
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
2500,3.9125,3.48907,20.2537,4.6131,15.2903,15.2999
5000,3.349,3.147529,23.194,6.2188,17.5282,17.5353
7500,3.1551,2.950679,24.1366,6.8518,18.2626,18.2668
10000,2.9674,2.820365,25.1193,7.4978,18.8538,18.8583
12500,2.8906,2.727035,26.2209,8.0839,19.5425,19.5549
15000,2.7885,2.652417,26.9716,8.5288,20.1322,20.1232
17500,2.7227,2.596484,27.3394,8.9009,20.4316,20.4377


***** Running Evaluation *****
  Num examples = 2500
  Batch size = 8
Saving model checkpoint to ./checkpoint-2500
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Configuration saved in ./checkpoint-2500/config.json
Model weights saved in ./checkpoint-2500/pytorch_model.bin
tokenizer config file saved in ./checkpoint-2500/tokenizer_config.json
Special tokens file saved in ./checkpoint-2500/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 2500
  Batch size = 8
Saving model checkpoint to ./checkpoint-5000
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Configuration saved in ./checkpoint-5000/config.json
Model weights saved in ./checkpoint-5000/pytorch_model.bin
tokenizer config file saved in ./checkpoint-5000/tokenizer_config.json
Special tokens file saved in ./checkpoint-5000/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 2500

Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
2500,3.9125,3.48907,20.2537,4.6131,15.2903,15.2999
5000,3.349,3.147529,23.194,6.2188,17.5282,17.5353
7500,3.1551,2.950679,24.1366,6.8518,18.2626,18.2668
10000,2.9674,2.820365,25.1193,7.4978,18.8538,18.8583
12500,2.8906,2.727035,26.2209,8.0839,19.5425,19.5549
15000,2.7885,2.652417,26.9716,8.5288,20.1322,20.1232
17500,2.7227,2.596484,27.3394,8.9009,20.4316,20.4377
20000,2.6652,2.539667,27.9718,9.331,20.8453,20.8556
22500,2.6296,2.501847,28.1762,9.5876,20.9497,20.9583
25000,2.6026,2.478509,28.3023,9.6417,21.0191,21.0233


***** Running Evaluation *****
  Num examples = 2500
  Batch size = 8
Saving model checkpoint to ./checkpoint-20000
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Configuration saved in ./checkpoint-20000/config.json
Model weights saved in ./checkpoint-20000/pytorch_model.bin
tokenizer config file saved in ./checkpoint-20000/tokenizer_config.json
Special tokens file saved in ./checkpoint-20000/special_tokens_map.json
Deleting older checkpoint [checkpoint-12500] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2500
  Batch size = 8
Saving model checkpoint to ./checkpoint-22500
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Configuration saved in ./checkpoint-22500/config.json
Model weights saved in ./checkpoint-22500/pytorch_model.bin
tokenizer config file saved in ./checkpoint-22500/tokenizer_config.json
Special tokens file saved in ./checkpoin

TrainOutput(global_step=25506, training_loss=3.0276054837587236, metrics={'train_runtime': 27793.0182, 'train_samples_per_second': 7.342, 'train_steps_per_second': 0.918, 'total_flos': 1.408187721739968e+17, 'train_loss': 3.0276054837587236, 'epoch': 1.0})

In [25]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
eval/loss,█▆▄▃▃▂▂▁▁▁
eval/rouge1,▁▄▄▅▆▇▇███
eval/rouge2,▁▃▄▅▆▆▇███
eval/rougeL,▁▄▅▅▆▇▇███
eval/rougeLsum,▁▄▅▅▆▇▇███
eval/runtime,█▂█▄▃▃▃▂▁▇
eval/samples_per_second,▁▇▁▅▆▆▆▇█▂
eval/steps_per_second,▁▆▁▄▆▆▆▇█▂
train/epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇████

0,1
eval/loss,2.47851
eval/rouge1,28.3023
eval/rouge2,9.6417
eval/rougeL,21.0191
eval/rougeLsum,21.0233
eval/runtime,957.867
eval/samples_per_second,2.61
eval/steps_per_second,0.327
train/epoch,1.0
train/global_step,25506.0


## Results and Evaluation

In [29]:
!ls -lh

total 2.6G
-rw-r--r-- 1 root root 2.6G Feb 14 02:19 bert2bert-xsum.zip
drwxr-xr-x 2 root root 4.0K Feb 14 00:37 checkpoint-20000
drwxr-xr-x 2 root root 4.0K Feb 14 01:22 checkpoint-22500
drwxr-xr-x 2 root root 4.0K Feb 14 02:08 checkpoint-25000
drwx------ 5 root root 4.0K Feb 14 02:20 drive
drwxr-xr-x 1 root root 4.0K Feb  1 14:32 sample_data
drwxr-xr-x 3 root root 4.0K Feb 13 18:26 wandb


In [27]:
!zip -r bert2bert-xsum.zip checkpoint-25000/

  adding: checkpoint-25000/ (stored 0%)
  adding: checkpoint-25000/special_tokens_map.json (deflated 53%)
  adding: checkpoint-25000/vocab.txt (deflated 53%)
  adding: checkpoint-25000/training_args.bin (deflated 48%)
  adding: checkpoint-25000/optimizer.pt (deflated 10%)
  adding: checkpoint-25000/config.json (deflated 80%)
  adding: checkpoint-25000/tokenizer.json (deflated 59%)
  adding: checkpoint-25000/tokenizer_config.json (deflated 39%)
  adding: checkpoint-25000/scheduler.pt (deflated 49%)
  adding: checkpoint-25000/trainer_state.json (deflated 78%)
  adding: checkpoint-25000/pytorch_model.bin (deflated 7%)
  adding: checkpoint-25000/rng_state.pth (deflated 27%)
  adding: checkpoint-25000/scaler.pt (deflated 55%)


In [28]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [30]:
!cp bert2bert-xsum.zip '/content/drive/My Drive/weights/'

In [None]:

# only use 16 training examples for notebook - DELETE LINE FOR FULL TRAINING
test_data = test_data.select(range(16))

results = test_data.map(generate_summary, batched=True, batch_size=batch_size)

  0%|          | 0/2 [00:00<?, ?ba/s]