## Fine-tune FLAN-T5-base

This example demonstrates how to fine-tune [google/flan-t5-base](https://huggingface.co/google/flan-t5-base) for chat and dialogue summarization. It closely follows this [walk-through by Philipp Schmid](https://www.philschmid.de/fine-tune-flan-t5).

Papers: 

[Finetuned Language Models are Zero-Shot Learners](https://arxiv.org/abs/2109.01652) \
[Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf) \
[The Flan Collection: Designing Data and Methods for Effective Instruction Tuning](https://arxiv.org/abs/2301.13688)

Code: 

https://github.com/google-research/FLAN

Metrics:

[ROUGE (metric)](https://en.wikipedia.org/wiki/ROUGE_(metric)) \
[What is the ROUGE metric (video from HF course)](https://www.youtube.com/watch?v=TMshhnrEXlg) \
[An intro to ROUGE, and how to use it to evaluate summaries](https://www.freecodecamp.org/news/what-is-rouge-and-how-it-works-for-evaluation-of-summaries-e059fb8ac840/) \
[Two minutes NLP — Learn the ROUGE metric by examples](https://medium.com/nlplanet/two-minutes-nlp-learn-the-rouge-metric-by-examples-f179cc285499)

Other Resources:

[HF NLP Course - Part 7: Summarization](https://huggingface.co/learn/nlp-course/en/chapter7/5?fw=pt#summarization) \
[Training summarization & translation models with fastai & blurr: W&B Study Group](https://www.youtube.com/watch?v=Jsz4E2iNXUA)

Notes:

* "These models have been fine-tuned on more than 1000 additional tasks covering also more languages"
* Improves on T5 through instruction fine-tuning on more tasks, including chain-of-thought data
* Dataset = [samsum](https://huggingface.co/datasets/samsum) ("16k messenger-like conversations with summaries")

## Imports

In [None]:
import os

# NOTE: To stop HF's Trainer/Accelerate from using all the GPUs, you need to set this environment variable BEFORE you
# import any related package (torch, transformers, accelerate, ...)!
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
from random import randrange

from datasets import concatenate_datasets, load_dataset
import evaluate
from huggingface_hub import HfFolder
import matplotlib.pyplot as plt
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
import wandb

In [None]:
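nltk.download("punkt")
# Recent NLTK releases load the sentence-tokenizer data used by `sent_tokenize` from "punkt_tab"
nltk.download("punkt_tab")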

## Config

In [None]:
ignore_tok_id = -100

dataset_id = "samsum"
model_id = "google/flan-t5-base"
hf_repo_id = f"{model_id.split('/')[1]}-{dataset_id}"

hf_tokenizer = AutoTokenizer.from_pretrained(model_id)
hf_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

In [None]:
from huggingface_hub import notebook_login

notebook_login()

## Data

### Step 1: Load dataset

In [None]:
raw_training_ds = load_dataset(dataset_id)

raw_training_ds

In [None]:
print(f"Train dataset size: {len(raw_training_ds['train'])}")
print(f"Test dataset size: {len(raw_training_ds['test'])}")

In [None]:
sample = raw_training_ds['train'][randrange(len(raw_training_ds["train"]))]

print(f"dialogue: \n{sample['dialogue']}\n---------------")
print(f"summary: \n{sample['summary']}\n---------------")

In [None]:
# Find the longest dialogue (source) and summary (target), in tokens, across the train and test splits.
# These become `max_length` in preprocessing: longer sequences are truncated and, with padding="max_length",
# shorter ones are padded. Note that `truncation=True` already caps lengths at the tokenizer's
# `model_max_length` (512 for flan-t5-base).
all_ds = concatenate_datasets([raw_training_ds["train"], raw_training_ds["test"]])

tokenized_inputs = all_ds.map(
    lambda x: hf_tokenizer(x["dialogue"], truncation=True), batched=True, remove_columns=["dialogue", "summary"]
)
source_lengths = [len(x) for x in tokenized_inputs["input_ids"]]
max_source_length = max(source_lengths)
print(f"Max source length: {max_source_length}")

tokenized_targets = all_ds.map(
    lambda x: hf_tokenizer(x["summary"], truncation=True), batched=True, remove_columns=["dialogue", "summary"]
)
target_lengths = [len(x) for x in tokenized_targets["input_ids"]]
max_target_length = max(target_lengths)
print(f"Max target length: {max_target_length}")


In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10,3.5), sharey=True)

ax[0].hist(source_lengths, bins=20, color="C0", edgecolor="C0")
ax[0].set_title("Dialogue Token Length")
ax[0].set_xlabel("Length")
ax[0].set_ylabel("Count")

ax[1].hist(target_lengths, bins=20, color="C0", edgecolor="C0")
ax[1].set_title("Summary Token Length")
ax[1].set_xlabel("Length")
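The histograms show that most examples are much shorter than the maximum. One alternative to padding/truncating to the absolute maximum, sketched here under the assumption that truncating the longest few percent of examples is acceptable, is to pick `max_length` from a percentile of the length distribution. `choose_max_length` and `fraction_truncated` are hypothetical helpers and `toy_lengths` are made-up values; in the notebook you would pass `source_lengths` or `target_lengths`.

```python
import numpy as np


def choose_max_length(lengths, percentile=95, multiple_of=8):
    """Hypothetical helper: a max_length that covers `percentile`% of examples,
    rounded up to a multiple of `multiple_of` (matching pad_to_multiple_of=8 later on)."""
    cutoff = int(np.ceil(np.percentile(lengths, percentile)))
    # Round up to the next multiple, e.g. 277 -> 280 for multiple_of=8
    return int(np.ceil(cutoff / multiple_of) * multiple_of)


def fraction_truncated(lengths, max_length):
    """Share of examples that would lose tokens if truncated to max_length."""
    lengths = np.asarray(lengths)
    return float((lengths > max_length).mean())


# Made-up lengths standing in for `source_lengths`
toy_lengths = [12, 30, 45, 60, 75, 90, 110, 130, 250, 512]
max_len = choose_max_length(toy_lengths, percentile=90)
print(max_len, fraction_truncated(toy_lengths, max_len))
```

The trade-off is shorter, cheaper batches versus losing the tail of the longest dialogues, so it is worth checking `fraction_truncated` before committing to a cutoff.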

### Step 2: Tokenize

"In T5, every NLP task is formulated in terms of a prompt prefix like `summarize:` which conditions the model to adapt the generated text to the prompt."

![T5 text-to-text framework with task prefixes](https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/t5.svg)

In [None]:
def preprocess_examples(sample, padding=False):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue"]]

    # tokenize inputs
    model_inputs = hf_tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = hf_tokenizer(text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True)

    # If padding to max_length here, replace pad tokens in the labels with -100 so they are ignored by the loss.
    if padding == "max_length":
        labels["input_ids"] = [[(l if l != hf_tokenizer.pad_token_id else ignore_tok_id) for l in label] for label in labels["input_ids"]]

    model_inputs["labels"] = labels["input_ids"]
    
    return model_inputs
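To see what the `padding="max_length"` branch does, here is a toy version with made-up token IDs, assuming `pad_token_id = 0` and EOS `</s>` = 1 as in T5's tokenizer. Padded label positions become `-100`, which is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so padding contributes nothing to the loss.

```python
IGNORE_TOK_ID = -100
PAD_TOKEN_ID = 0  # assumption: T5's pad token id


def mask_pad_labels(label_batch, pad_token_id=PAD_TOKEN_ID, ignore_id=IGNORE_TOK_ID):
    """Replace pad tokens in each label sequence with ignore_id (same logic as preprocess_examples)."""
    return [[tok if tok != pad_token_id else ignore_id for tok in seq] for seq in label_batch]


# Two made-up label sequences padded to max_length=6 (1 = </s>, 0 = <pad>)
labels = [
    [37, 1782, 5, 1, 0, 0],
    [8, 1, 0, 0, 0, 0],
]
masked = mask_pad_labels(labels)
print(masked)

# Only these positions contribute to the loss
n_loss_positions = sum(tok != IGNORE_TOK_ID for seq in masked for tok in seq)
print(n_loss_positions)
```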


In [None]:
tok_training_ds = raw_training_ds.map(preprocess_examples, batched=True, remove_columns=["dialogue", "summary", "id"])

print(f"Keys of tokenized dataset: {list(tok_training_ds['train'].features)}")

## Train


In [None]:
wandb.init(project="llms_ft_t5_base_samsum")  # Change to your own W&B project name

### Step 1: Metrics

`ROUGE` is "developed for applications like summarization where high recall is more important than just precision ... we check how many n-grams in the reference text also occur in the generated text"

$$\text{Recall} = \frac{\#\ \text{overlapping words}}{\#\ \text{words in the reference summary}}$$

$$\text{Precision} = \frac{\#\ \text{overlapping words}}{\#\ \text{words in the generated summary}}$$

The reported score is generally the F1 for each ROUGE sub-metric (i.e., the harmonic mean of precision and recall).
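A minimal worked example of ROUGE-N, using the recall and precision definitions above with F1 as their harmonic mean. This is a simplified sketch: it lowercases and splits on whitespace with no stemming or punctuation handling, so its numbers can differ from `evaluate`'s `rouge` metric. `rouge_n` is a hypothetical helper, not a library function.

```python
from collections import Counter


def rouge_n(reference, candidate, n=1):
    """Simplified ROUGE-N: clipped n-gram overlap, whitespace tokenization, no stemming."""
    def ngrams(text):
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    ref, cand = ngrams(reference), ngrams(candidate)
    overlap = sum((ref & cand).values())  # each n-gram counted at most min(ref count, cand count) times
    recall = overlap / max(sum(ref.values()), 1)
    precision = overlap / max(sum(cand.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


reference = "the cat was found under the bed"
candidate = "the cat was under the bed"
print(rouge_n(reference, candidate, n=1))  # ROUGE-1
print(rouge_n(reference, candidate, n=2))  # ROUGE-2
```

Here every candidate unigram appears in the reference (precision 1.0) but "found" is missing (recall 6/7), giving F1 = 12/13. For bigrams, dropping "found" breaks two reference bigrams, so ROUGE-2 is noticeably lower.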

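**Longest Common Subsequence (LCS) Score:**

`ROUGE-L` = computes the LCS between the generated and reference text treated as one sequence (newlines are ignored) \
`ROUGE-LSUM` = splits the text into sentences at newlines and computes a summary-level ("union") LCS over sentence pairs, which is why `postprocess_text` below puts each sentence on its own line

Note that the LCS is normalized (by the reference length for recall, by the generated length for precision) to account for summaries of different lengths.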
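A minimal sketch of sentence-level ROUGE-L: compute the LCS of the two token sequences with dynamic programming, then normalize by the reference length (recall) and the candidate length (precision). It uses whitespace tokenization and no stemming, and does not implement the summary-level union-LCS behind `rougeLsum`. `lcs_length` and `rouge_l` are hypothetical helpers.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of sequences a and b (O(len(a) * len(b)) DP)."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, start=1):
        for j, y in enumerate(b, start=1):
            if x == y:
                dp[i][j] = dp[i - 1][j - 1] + 1  # extend the LCS of the two prefixes
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])  # skip a token from one side
    return dp[-1][-1]


def rouge_l(reference, candidate):
    ref, cand = reference.lower().split(), candidate.lower().split()
    lcs = lcs_length(ref, cand)
    recall = lcs / max(len(ref), 1)
    precision = lcs / max(len(cand), 1)
    f1 = 0.0 if lcs == 0 else 2 * precision * recall / (precision + recall)
    return {"lcs": lcs, "precision": precision, "recall": recall, "f1": f1}


# Matched words need not be contiguous, only in the same order
print(rouge_l("the cat was found under the bed", "the cat was under the bed"))
# Same words, different order: only "the gunman" survives as a common subsequence
print(rouge_l("police killed the gunman", "the gunman killed police"))
```

Unlike ROUGE-1, which would score the second pair perfectly (all four words overlap), ROUGE-L rewards word order, so it drops to 0.5.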

In [None]:
rouge_score = evaluate.load("rouge")

In [None]:
# Helper function to postprocess text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
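    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # Generated sequences can be padded with -100 when eval batches are concatenated; map those back to the pad token
    preds = np.where(preds != ignore_tok_id, preds, hf_tokenizer.pad_token_id)
    decoded_preds = hf_tokenizer.batch_decode(preds, skip_special_tokens=True)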
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != ignore_tok_id, labels, hf_tokenizer.pad_token_id)
    decoded_labels = hf_tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Ensure generated text is formatted correctly for rouge
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = rouge_score.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    
    prediction_lens = [np.count_nonzero(pred != hf_tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    
    return result


### Step 2: DataCollator

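During training we use "teacher forcing" on the decoder side: at each time step the decoder is fed the ground-truth previous tokens of the summary (not its own predictions) and is trained to predict the next one. The decoder inputs are the labels shifted one position to the right, with the decoder start token prepended; because we pass `model=hf_model`, `DataCollatorForSeq2Seq` creates these `decoder_input_ids` for us. Together with the decoder's masked (causal) self-attention, this ensures that no position can see future tokens when it makes its prediction.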
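The label shift can be sketched in NumPy. As far as we can tell, this mirrors what T5's `prepare_decoder_input_ids_from_labels` does when the collator calls it: prepend the decoder start token (the pad token, ID 0, for T5), drop the last label, and replace any `-100` with the pad ID. `shift_right` is our own hypothetical helper and the token IDs are made up.

```python
import numpy as np


def shift_right(labels, decoder_start_token_id=0, pad_token_id=0, ignore_id=-100):
    """Hypothetical helper: build decoder_input_ids from labels for teacher forcing."""
    labels = np.asarray(labels)
    decoder_input_ids = np.empty_like(labels)
    decoder_input_ids[:, 0] = decoder_start_token_id  # the decoder starts from the start token
    decoder_input_ids[:, 1:] = labels[:, :-1]         # every later input is the previous gold token
    # -100 is not a valid token ID, so ignored positions are fed the pad token instead
    decoder_input_ids[decoder_input_ids == ignore_id] = pad_token_id
    return decoder_input_ids


labels = np.array([[37, 1782, 5, 1, -100, -100]])
decoder_input_ids = shift_right(labels)
print(decoder_input_ids)

# With causal attention, position t sees decoder_input_ids[:, :t+1] and is trained to predict labels[:, t]
for t in range(4):
    print(decoder_input_ids[0, : t + 1].tolist(), "->", labels[0, t])
```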

In [None]:
# NOTE: With mixed precision, NVIDIA Tensor Cores work most efficiently when tensor dimensions are multiples of 8,
# so we pad each batch to a multiple of 8 (harmless when mixed precision is off)
data_collator = DataCollatorForSeq2Seq(hf_tokenizer, model=hf_model, label_pad_token_id=ignore_tok_id, pad_to_multiple_of=8)
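Roughly, the collator pads each batch dynamically: `input_ids` are padded with the pad token (and `attention_mask` with 0s), `labels` are padded with `label_pad_token_id` (-100 here), and each padded length is rounded up to a multiple of `pad_to_multiple_of`. The pure-Python sketch below imitates that for right padding. `pad_batch` is a hypothetical stand-in, not the real `DataCollatorForSeq2Seq` (which also builds `decoder_input_ids` and returns tensors), and the token IDs are made up.

```python
import math


def pad_batch(features, pad_token_id=0, label_pad_token_id=-100, pad_to_multiple_of=8):
    """Hypothetical sketch of dynamic right-padding for a seq2seq batch."""
    def padded_len(seqs):
        longest = max(len(s) for s in seqs)
        return math.ceil(longest / pad_to_multiple_of) * pad_to_multiple_of

    src_len = padded_len([f["input_ids"] for f in features])
    tgt_len = padded_len([f["labels"] for f in features])
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = src_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["labels"] + [label_pad_token_id] * (tgt_len - len(f["labels"])))
    return batch


features = [
    {"input_ids": [21603, 10, 363, 1], "labels": [37, 1782, 1]},
    {"input_ids": [21603, 10, 1], "labels": [8, 1]},
]
batch = pad_batch(features)
print({k: [len(row) for row in v] for k, v in batch.items()})
```

Padding per batch (instead of `padding="max_length"` in `preprocess_examples`) keeps most batches short, which is why the notebook tokenizes without padding and leaves it to the collator.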


### Step 3: Trainer

In [None]:
# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=f"./{hf_repo_id}",
    overwrite_output_dir=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    learning_rate=3e-5,
    # T5 models are prone to overflow (NaN loss) in fp16, so keep it off
    fp16=False,
    # So we can evaluate generations as part of the training loop (uses `generate()` instead of model's forward pass to create preds)
    predict_with_generate=True,  
    # --- logging & evaluation strategies ---
    logging_dir=f"./{hf_repo_id}/logs",
    logging_strategy="steps",
    logging_steps=500,
    evaluation_strategy="epoch",  # renamed to `eval_strategy` in newer transformers releases
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="rougeLsum",
    # --- push to hub parameters ---
    report_to="wandb",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=hf_repo_id,
    hub_token=HfFolder.get_token(),
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=hf_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tok_training_ds["train"],
    eval_dataset=tok_training_ds["test"],
    compute_metrics=compute_metrics,
)
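With `logging_steps=500` it helps to know how many optimizer steps training will take. A back-of-the-envelope sketch, assuming a single GPU, no gradient accumulation, and 14,732 samsum training examples (substitute the train size printed earlier if yours differs). `total_train_steps` is a hypothetical helper.

```python
import math


def total_train_steps(num_examples, per_device_batch_size, num_epochs, num_devices=1):
    """Optimizer steps without gradient accumulation.
    The last, partial batch of each epoch still counts as a step (dataloader_drop_last=False)."""
    steps_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    return steps_per_epoch, steps_per_epoch * num_epochs


# Assumption: 14,732 training examples, batch size 8, 3 epochs, 1 GPU (as configured above)
steps_per_epoch, total_steps = total_train_steps(14732, per_device_batch_size=8, num_epochs=3)
num_log_events = total_steps // 500  # logging_steps=500
print(steps_per_epoch, total_steps, num_log_events)
```

Under these assumptions the training loss is logged about 11 times, while evaluation and checkpointing happen 3 times (once per epoch).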


In [None]:
# How are we doing BEFORE training
trainer.evaluate()

In [None]:
trainer.train()

In [None]:
trainer.evaluate()

In [None]:
# Save our tokenizer and create model card
hf_tokenizer.save_pretrained(hf_repo_id)
trainer.create_model_card()

# Push the results to the hub
trainer.push_to_hub()

## Inference

In [None]:
from transformers import pipeline

# Load the fine-tuned model and tokenizer from the Hugging Face Hub into a summarization pipeline
# (replace "wgpubs/flan-t5-base-samsum" with your own "<username>/flan-t5-base-samsum" repo)
summarizer = pipeline("summarization", model="wgpubs/flan-t5-base-samsum", device=0)

# select a random test sample
sample = raw_training_ds['test'][randrange(len(raw_training_ds["test"]))]
print(f"dialogue: \n{sample['dialogue']}\n---------------")

# summarize dialogue
res = summarizer(sample["dialogue"])

print(f"flan-t5-base summary:\n{res[0]['summary_text']}")