## Text summarization with BERT

Code adapted from: https://huggingface.co/blog/warm-starting-encoder-decoder

In [1]:
# Runtime switch: True when running on Google Colab (enables the pip installs,
# the google.colab import and the Google Drive data path), False for a local run.
is_colab = True

## Import packages

Install packages

In [None]:
if is_colab:
    !pip install datasets
    !pip install rouge_score
    !pip install transformers==4.5.0

Import packages

In [5]:
import os

import datasets
from datasets import load_dataset, load_metric
import numpy as np
import pandas as pd
import torch
from transformers import (BertTokenizerFast, EncoderDecoderModel,
                          Seq2SeqTrainer, Seq2SeqTrainingArguments)

if is_colab:
    from google.colab import drive

## Load data

In [None]:
# Only the location of the data folder depends on the environment.
if is_colab:
    # On Colab the CSV files are read from Google Drive, which is mounted first.
    drive.mount('/content/drive')
    data_folder = '/content/drive/MyDrive/deep_learning/project/data/cxr'
else:
    # Local run: read from a directory on this machine.
    data_folder = "/home/labuser/Documents/deep_learning/project/data/cxr"

# One CSV file per split, all inside the same folder.
train_path, val_path, test_path = (
    os.path.join(data_folder, csv_name)
    for csv_name in ("train.csv", "validation.csv", "test.csv")
)

# NOTE(review): every file is requested with split='train', including the
# validation and test files — presumably because a single CSV passed through
# data_files is exposed under the 'train' split; confirm against the datasets docs.
train_data = load_dataset('csv', data_files=train_path, split='train')
val_data = load_dataset('csv', data_files=val_path, split='train')
test_data = load_dataset('csv', data_files=test_path, split='train')

## Tokenize data

In [12]:
def map_to_length(x):
    """
    Annotate one example with token counts and length-threshold flags.

    Uses the module-level `tokenizer` to count tokens.

    Parameters
    ----------
    x : dict-like
        Example providing "text" and "summary" entries.

    Returns
    -------
    x : dict-like
        The same object, extended with "text_len", "text_longer_128",
        "text_longer_256", "summary_len", "summary_longer_64" and
        "summary_longer_128". The "*_longer_*" entries are 0/1 integers that
        are 1 only when the length is strictly greater than the threshold.
    """

    # Source text: token count plus the two threshold flags.
    text_tokens = len(tokenizer(x["text"]).input_ids)
    x["text_len"] = text_tokens
    x["text_longer_128"] = int(text_tokens > 128)
    x["text_longer_256"] = int(text_tokens > 256)

    # Reference summary: same idea with smaller thresholds.
    summary_tokens = len(tokenizer(x["summary"]).input_ids)
    x["summary_len"] = summary_tokens
    x["summary_longer_64"] = int(summary_tokens > 64)
    x["summary_longer_128"] = int(summary_tokens > 128)

    return x


def compute_and_print_stats(x, sample_size=10000):
    """
    Print length statistics for a batch produced by `map_to_length`.

    Nothing is printed unless the batch holds exactly `sample_size` rows.

    Parameters
    ----------
    x : dict-like
        Batch with the columns added by `map_to_length`.
    sample_size : int
        Expected number of rows in the batch; also the divisor used for the
        means and percentages.
    """

    # Any batch that is not exactly the expected sample is skipped silently.
    if len(x["summary_len"]) != sample_size:
        return

    def mean(column):
        return sum(x[column]) / sample_size

    def percent(column):
        return sum(x[column]) / sample_size * 100

    report = (
        "Text Mean: {:.3f}, %-Text > 128: {:.3f}, %-Text > 256: {:.3f}\n"
        "Summary Mean: {:.3f}, %-Summary > 64: {:.3f}, %-Summary > 128: {:.3f}"
    )
    print(report.format(
        mean("text_len"),
        percent("text_longer_128"),
        percent("text_longer_256"),
        mean("summary_len"),
        percent("summary_longer_64"),
        percent("summary_longer_128"),
    ))


def tokenize_batch(batch, enc_max_len=256, dec_max_len=128):
    """
    Tokenize the "text" and "summary" columns of a batch.

    Uses the module-level `tokenizer`. Both columns are padded and truncated
    to a fixed length.

    Parameters
    ----------
    batch : dict-like
        Batch providing "text" and "summary" columns.
    enc_max_len : int
        Length the encoder inputs ("text") are padded/truncated to.
    dec_max_len : int
        Length the decoder inputs ("summary") are padded/truncated to.

    Returns
    -------
    batch : dict-like
        The same object with "input_ids", "attention_mask",
        "decoder_input_ids", "decoder_attention_mask" and "labels" added.
    """

    # Both sides share the same padding/truncation policy; only the length differs.
    fixed_length = dict(padding="max_length", truncation=True)
    source = tokenizer(batch["text"], max_length=enc_max_len, **fixed_length)
    target = tokenizer(batch["summary"], max_length=dec_max_len, **fixed_length)

    batch["input_ids"] = source.input_ids
    batch["attention_mask"] = source.attention_mask
    batch["decoder_input_ids"] = target.input_ids
    batch["decoder_attention_mask"] = target.attention_mask

    # Labels are the target ids with every PAD id replaced by -100 so that
    # padding is ignored; new lists are built, so decoder_input_ids keeps its PADs.
    pad_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_id else token for token in target_ids]
        for target_ids in target.input_ids
    ]
    return batch


def tokenize_data(data):
    """
    Tokenize a whole dataset with `tokenize_batch` and expose it as tensors.

    Parameters
    ----------
    data : Dataset
        Dataset with "study_id", "subject_id", "text" and "summary" columns.

    Returns
    -------
    Dataset
        Result of mapping `tokenize_batch` over `data` in batches of 16, with
        the four original columns removed and the format set to "torch" for
        the model input columns.
    """

    raw_columns = ["study_id", "subject_id", "text", "summary"]
    model_columns = [
        "input_ids", "attention_mask",
        "decoder_input_ids", "decoder_attention_mask",
        "labels",
    ]

    tokenized = data.map(tokenize_batch, batched=True, batch_size=16,
                         remove_columns=raw_columns)
    # set_format is called for its effect on `tokenized`; its return value is unused.
    tokenized.set_format(type="torch", columns=model_columns)
    return tokenized

Initialize tokenizer

In [None]:
# Checkpoint whose vocabulary is used to tokenize both the texts and the summaries.
# NOTE(review): presumably this should match the enc_name/dec_name checkpoint
# chosen in the model-selection cell — confirm when switching models.
tokenizer_name = "dmis-lab/biobert-base-cased-v1.1"
#tokenizer_name = "emilyalsentzer/Bio_ClinicalBERT"

# Module-level tokenizer; map_to_length, tokenize_batch and compute_metrics read it as a global.
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name)

Compute data statistics

In [None]:
# Use at most 10000 training examples, capped at the dataset size so that
# select() never indexes past the end of a smaller training split.
stats_sample_size = min(10000, len(train_data))
data_stats = train_data.select(range(stats_sample_size)).map(map_to_length, num_proc=4)

# compute_and_print_stats only reports when the batch length equals its
# sample_size argument, so the actual sample size is passed explicitly and
# batch_size=-1 is used to hand it the sample as one batch.
output = data_stats.map(
    compute_and_print_stats,
    batched=True,
    batch_size=-1,
    fn_kwargs={"sample_size": stats_sample_size},
)

Tokenize data

In [None]:
# Tokenize the training and validation splits; both are later handed to the trainer.
# NOTE(review): test_data is loaded earlier but is not tokenized in this cell.
train_data_tokenized = tokenize_data(train_data)
val_data_tokenized = tokenize_data(val_data)

## Initialize model

Select encoder and decoder, and whether to share parameters

In [16]:
# Available configurations, keyed by run name:
#   run name -> (checkpoint used for both encoder and decoder, tie encoder/decoder weights?)
MODEL_PRESETS = {
    "biobert2biobert": ("dmis-lab/biobert-base-cased-v1.1", False),
    "biobertshare": ("dmis-lab/biobert-base-cased-v1.1", True),
    "clinicalbert2clinicalbert": ("emilyalsentzer/Bio_ClinicalBERT", False),
    "clinicalbertshare": ("emilyalsentzer/Bio_ClinicalBERT", True),
}

# Change only this line to select another configuration.
model_name = "biobert2biobert"

enc_name, tie_encoder_decoder = MODEL_PRESETS[model_name]
dec_name = enc_name

Initialize model

In [None]:
# Build the encoder-decoder model from the selected pretrained checkpoints.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    enc_name,
    dec_name,
    tie_encoder_decoder=tie_encoder_decoder,
)

model_config = model.config

# Special tokens come from the tokenizer: CLS starts decoding, SEP ends a
# sequence, PAD is the padding id. The vocabulary size is copied from the encoder.
model_config.decoder_start_token_id = tokenizer.cls_token_id
model_config.eos_token_id = tokenizer.sep_token_id
model_config.pad_token_id = tokenizer.pad_token_id
model_config.vocab_size = model_config.encoder.vocab_size

# Generation settings: beam search and length limits.
# NOTE(review): min_length=56 / max_length=142 are fixed here while tokenize_batch
# pads and truncates summaries to 128 tokens by default — confirm these limits
# suit the target summaries.
model_config.num_beams = 4
model_config.early_stopping = True
model_config.no_repeat_ngram_size = 3
model_config.length_penalty = 2.0
model_config.min_length = 56
model_config.max_length = 142

## Train model

In [21]:
def compute_metrics(pred):
    """
    Compute ROUGE-2 scores for predicted summaries.

    Uses the module-level `tokenizer` to decode token ids and the
    module-level `rouge` metric to score the decoded strings. The input
    object is not modified.

    Parameters
    ----------
    pred : object with `predictions` and `label_ids` attributes
        `predictions` holds the predicted token ids and `label_ids` the
        reference token ids, in which -100 marks positions to ignore.

    Returns
    -------
    dict
        "rouge2_precision", "rouge2_recall" and "rouge2_fmeasure", each
        rounded to 4 decimals.
    """

    pred_ids = pred.predictions

    # Put PAD back wherever the labels carry the -100 ignore marker so they
    # can be decoded. np.where builds a new array, so pred.label_ids is left
    # untouched instead of being overwritten in place.
    raw_label_ids = np.asarray(pred.label_ids)
    labels_ids = np.where(raw_label_ids == -100, tokenizer.pad_token_id,
                          raw_label_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # Only ROUGE-2 is requested; `.mid` is the entry of the aggregate that is reported.
    rouge_output = rouge.compute(predictions=pred_str, references=label_str,
                                 rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


# ROUGE metric object; compute_metrics reads it as a module-level global.
rouge = load_metric("rouge")

training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    # batch sizes, warm-up and precision
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=2000,
    fp16=False,
    # evaluation schedule
    evaluation_strategy="steps",
    eval_steps=7500,
    predict_with_generate=True,
    # logging and checkpointing
    logging_steps=1000,
    save_steps=500,
    save_total_limit=3,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_data_tokenized,
    eval_dataset=val_data_tokenized,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training with the trainer configured in the previous cell.
trainer.train()