[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_two_langs.ipynb)
[GitHub](https://github.com/elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_two_langs.ipynb)
# Fine-tuning BART for summarization in two languages

---

In a world of ever-growing amount of data, the task of automatically creating coherent and fluent summaries is gaining importance. Coming up with a shorter, concise version of a document, can help to derive value from large volumes of text.

This notebook contains an example of fine-tuning [Bart](https://huggingface.co/transformers/model_doc/bart.html) for generating summaries of article sections from the [WikiLingua](https://huggingface.co/datasets/wiki_lingua) dataset. WikiLingua is a multilingual set of articles. We will run the same code for two Bart checkpoints, including a non-English model from the [Hugging Face Model Hub](https://huggingface.co/models). We will be using:
- the **English** portion of WikiLingua with [sshleifer/distilbart-xsum-12-3](https://huggingface.co/sshleifer/distilbart-xsum-12-3) Bart checkpoint, and
- **French** articles from the WikiLingua with the [moussaKam/barthez-orangesum-abstract](https://huggingface.co/moussaKam/barthez-orangesum-abstract) model.

## Setup

---

In [3]:
! pip install transformers
! pip install datasets
! pip install sentencepiece
! pip install rouge_score
! pip install accelerate -U
# ! pip install wandb



In [4]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

## Set language

---

Pick a language. In our example, there is a choice of English and French.

French

In [5]:
language = "french"

English

In [6]:
#language = "english"

## Model and tokenizer

---

Download model and tokenizer. Use default parameters or try custom values (see [HF Bart configuration](https://huggingface.co/transformers/_modules/transformers/configuration_bart.html) and [Fairseq Bart](https://github.com/pytorch/fairseq/tree/master/examples/bart)).

In [15]:
model_name = "sshleifer/distilbart-xsum-12-3"
if language == "french":
    model_name = "moussaKam/barthez-orangesum-abstract"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model parameters or use the default
# print(model.config)

# tokenization
encoder_max_length = 256  # demo
decoder_max_length = 64



## Data

---

### Download

For demonstration, we are only using a small portion of the data.

In [8]:
# Trained on validation set because of memory and time limitation

dataset = datasets.load_dataset("csv", data_files = "validation.csv")
dataset = dataset["train"]

train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.3).values()

In [9]:
dataset

Dataset({
    features: ['text', 'titles'],
    num_rows: 1500
})

### Prepare

**Format and split into train and validation sets**

**Preprocess and tokenize**

In [16]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    source, target = batch["text"], batch["titles"]
    source_tokenized = tokenizer(
        source, padding="max_length", truncation=True, max_length=max_source_length
    )
    target_tokenized = tokenizer(
        target, padding="max_length", truncation=True, max_length=max_target_length
    )

    batch = {k: v for k, v in source_tokenized.items()}
    # Ignore padding in the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in l]
        for l in target_tokenized["input_ids"]
    ]
    return batch


train_data = train_data_txt.map(
    lambda batch: batch_tokenize_preprocess(
        batch, tokenizer, encoder_max_length, decoder_max_length
    ),
    batched=True,
    remove_columns=train_data_txt.column_names,
)

validation_data = validation_data_txt.map(
    lambda batch: batch_tokenize_preprocess(
        batch, tokenizer, encoder_max_length, decoder_max_length
    ),
    batched=True,
    remove_columns=validation_data_txt.column_names,
)

Map:   0%|          | 0/1050 [00:00<?, ? examples/s]

Map:   0%|          | 0/450 [00:00<?, ? examples/s]

## Training

---

### Metrics

In [17]:
# Borrowed from https://github.com/huggingface/transformers/blob/master/examples/seq2seq/run_summarization.py

nltk.download("punkt", quiet=True)

metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results from ROUGE
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


### Training arguments

In [18]:
training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    num_train_epochs=1,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=4,  # demo
    per_device_eval_batch_size=4,
    learning_rate=3e-05,
    warmup_steps=500,
    weight_decay=0.1,
    #label_smoothing_factor=0.1,
    generation_max_length = 100,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=50,
    save_total_limit=3,
    #report_to = None,
)

#data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    #data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


### Train

Wandb integration

In [19]:
import os
os.environ["WANDB_DISABLED"] = "true"

Evaluate before fine-tuning

In [20]:
trainer.evaluate()



OverflowError: out of range integral type conversion attempted

Train the model

In [21]:
#%%wandb
# uncomment to display Wandb charts

trainer.train()

Step,Training Loss
50,2.3932
100,2.305
150,2.2024
200,2.1511
250,2.1778


TrainOutput(global_step=263, training_loss=2.247692797120533, metrics={'train_runtime': 52.3177, 'train_samples_per_second': 20.07, 'train_steps_per_second': 5.027, 'total_flos': 160060774809600.0, 'train_loss': 2.247692797120533, 'epoch': 1.0})

Evaluate after fine-tuning

In [None]:
trainer.evaluate()

## Evaluation

---

**Generate summaries from the fine-tuned model and compare them with those generated from the original, pre-trained one.**

In [24]:
def generate_summary(test_samples, model):
    inputs = tokenizer(
        test_samples["text"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens = 120)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)

test_samples = validation_data_txt.select(range(16))

summaries_before_tuning = generate_summary(test_samples, model_before_tuning)[1]
summaries_after_tuning = generate_summary(test_samples, model)[1]

In [25]:
print(
    tabulate(
        zip(
            range(len(summaries_after_tuning)),
            summaries_after_tuning,
            summaries_before_tuning,
        ),
        headers=["Id", "Summary after", "Summary before"],
    )
)
print("\nTarget summaries:\n")
print(
    tabulate(list(enumerate(test_samples["titles"])), headers=["Id", "Target summary"])
)
print("\nSource documents:\n")
print(tabulate(list(enumerate(test_samples["text"])), headers=["Id", "Document"]))

  Id  Summary after                                                                                                                                                                                        Summary before
----  ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
   0  Le chef de l'Etat a condamné les violences de manifestants à l'issue de l'intrusion de syndicalistes à l'Assemblée nationale.                                                                        Le chef de l'Etat a condamné avec la plus grande fermeté les violences à l'encontre de syndicalistes.
   1  Un violent impact de foudre s'est abattu mardi 4 juin sur le 

## Generate test

In [26]:
import pandas as pd

def make_submission(filepath, output_file):
    # format CSV
    test_df = pd.read_csv(filepath)
    test_df["titles"] = ""
    test_df[["text", "titles"]].to_csv("test_formatted.csv")

    dataset = datasets.load_dataset("csv", data_files = "test_formatted.csv")
    test_samples = dataset["train"]

    summaries_dict = {}

    for i, sample in enumerate(test_samples):
        summaries = generate_summary(test_samples[i], model)[1]
        summaries_dict[i] = {"id": sample["Unnamed: 0"], "text" : sample["text"], "titles" : summaries}
    print("Summaries generated")

    test = pd.from_dict(summaries, "index")
    test['titles'] = test['titles'].map(lambda x : x[0])
    test[["ID", "titles"]].to_csv(output_file, index = False)
    print("File saved")

In [27]:
make_submission("data/test_text.csv", "submission_0.csv")

Dataset({
    features: ['text', 'titles'],
    num_rows: 16
})