# Fine Tune Flan-T5 using LoRA
In this notebook, we fine-tune Google's Flan-T5 model for text summarization using LoRA (Low-Rank Adaptation).

## 1. Installations

In [1]:
!pip install --upgrade pip
!pip install torch \
    torchdata \
    transformers \
    datasets \
    evaluate \
    rouge_score \
    loralib \
    peft


## 2. Imports

In [20]:
import time
import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType, PeftModel, PeftConfig, prepare_model_for_kbit_training

## 3. Dataset
The dataset we will use is DialogSum, available on the Hugging Face Hub as `knkarthick/dialogsum`. It contains more than 10,000 dialogues, each with a manually labeled summary and topic.

In [3]:
dataset = load_dataset("knkarthick/dialogsum")


In [4]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})


Check a single sample:

In [5]:
dataset["train"][0]

{'id': 'train_0',
 'dialogue': "#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?\n#Person2#: I found it would be a good idea to get a check-up.\n#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.\n#Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor?\n#Person1#: Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good.\n#Person2#: Ok.\n#Person1#: Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith?\n#Person2#: Yes.\n#Person1#: Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit.\n#Person2#: I've tried hundreds of times, but I just can't seem to kick the habit.\n#Person1#: Well, we have classes and some medications that might help. I'll give you more information before you leave.\n#Person2#: Ok, thanks doctor.",
 'summary': "Mr. Smith'

Load the Flan-T5 tokenizer:

In [6]:
model_name='google/flan-t5-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)


## 4. Preprocess Data

In [7]:
def tokenize_function(example):
    # Wrap each dialogue in the instruction prompt used for summarization.
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    # Pad or truncate prompts and reference summaries to the tokenizer's maximum length.
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids

    return example
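
One detail about the function above: `padding="max_length"` also pads the `labels`, so the padded positions hold the tokenizer's pad token id. A common convention in Hugging Face seq2seq training is to replace those positions with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, so that padding does not contribute to the loss. The notebook as run does not do this. The sketch below shows the idea on plain Python lists. The helper name `mask_label_padding` and the example ids are hypothetical, and a pad id of 0 (the T5 pad token id) is assumed.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_label_padding(label_ids, pad_token_id=0, ignore_index=IGNORE_INDEX):
    """Replace padding ids in a batch of label sequences with ignore_index."""
    return [
        [ignore_index if tok == pad_token_id else tok for tok in seq]
        for seq in label_ids
    ]


# Hypothetical, already padded label ids (not real tokenizer output).
padded_labels = [[8774, 6, 1, 0, 0], [27, 43, 3, 1, 0]]
masked_labels = mask_label_padding(padded_labels)
print(masked_labels)
```

If adopted, the same replacement would be applied to `example['labels']` inside `tokenize_function`. Note that the loss values reported later in this notebook were produced without it.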

In [8]:
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary',])


In [9]:
print(tokenized_datasets)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 500
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1500
    })
})


## 5. Model Building

In [10]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)


Set up the PEFT/LoRA model for fine-tuning:

In [32]:
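# Recompute activations during the backward pass to trade compute for lower memory use.
model.gradient_checkpointing_enable()
# Freezes the base weights. This helper targets quantized (k-bit) models; the model here is loaded in bfloat16.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=32, # Rank of the low-rank update matrices
    lora_alpha=32, # The LoRA update is scaled by lora_alpha / r
    target_modules=["q", "v"], # Attention query and value projections
    lora_dropout=0.05,
    bias="none", # Do not train bias terms
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)
peft_model = get_peft_model(model, lora_config)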
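
To make the roles of `r` and `lora_alpha` concrete, here is a minimal sketch of the LoRA forward pass for one linear layer, `h = W x + (lora_alpha / r) * B (A x)`, written with plain Python lists. It is an illustration under simplifying assumptions (toy 2x2 weights, no dropout), not the PEFT implementation. The helper names `matvec` and `lora_forward` are hypothetical.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, r, lora_alpha):
    """Frozen base output plus the scaled low-rank update."""
    scaling = lora_alpha / r
    base = matvec(W, x)                # frozen pretrained projection
    update = matvec(B, matvec(A, x))   # low-rank path: project down, then up
    return [b + scaling * u for b, u in zip(base, update)]


# Toy example: 2 input/output features, rank 1.
W = [[1.0, 0.0], [0.0, 1.0]]   # frozen weight (out x in)
A = [[1.0, 1.0]]               # trainable, shape r x in
B = [[0.5], [0.0]]             # trainable, shape out x r
x = [2.0, 3.0]

h = lora_forward(W, A, B, x, r=1, lora_alpha=1)

# With B set to zero the layer reproduces the base output exactly.
B_zero = [[0.0], [0.0]]
h_start = lora_forward(W, A, B_zero, x, r=1, lora_alpha=1)
print(h, h_start)
```

With `r=32` and `lora_alpha=32` as configured above, the scaling factor is 1.0, the same ratio used in the toy call.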

In [22]:
print(peft_model)

PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): T5ForConditionalGeneration(
      (shared): Embedding(32128, 768)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 768)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=768, out_features=32, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=32, out_features=768, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
            

In [23]:
peft_model.print_trainable_parameters()

trainable params: 3,538,944 || all params: 251,116,800 || trainable%: 1.4092820552029972
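
The trainable parameter count can be reproduced by hand. The sketch below assumes that flan-t5-base has 12 encoder and 12 decoder blocks, that each encoder block has one attention module (self-attention) and each decoder block has two (self- and cross-attention), and that every `q` and `v` projection is 768x768 as in the printed model. Under those assumptions the arithmetic matches the figure above.

```python
d_model = 768         # in/out features of q and v, from the printed model
rank = 32             # r in LoraConfig
encoder_blocks = 12   # assumption: flan-t5-base layer counts
decoder_blocks = 12

# Each adapted Linear gains lora_A (d_model x r) and lora_B (r x d_model).
params_per_module = rank * (d_model + d_model)

# q and v are adapted in every attention module:
# 1 attention per encoder block, 2 per decoder block.
adapted_modules = encoder_blocks * 1 * 2 + decoder_blocks * 2 * 2

lora_params = params_per_module * adapted_modules
all_params = 251_116_800  # total reported by print_trainable_parameters()
trainable_pct = 100 * lora_params / all_params

print(lora_params, trainable_pct)
```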


## 6. Training PEFT Adapter

In [41]:
output_dir = f'./peft-dialogue-summary-training-{int(time.time())}'

training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,  # Retry with a smaller batch size on out-of-memory errors
    learning_rate=2e-5,
    num_train_epochs=1,
    logging_steps=1,  # Log the training loss at every step
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"]
)

In [42]:
trainer.train()

| Step | Training Loss |
|------|---------------|
| 1    | 8.8103        |
| 2    | 8.5886        |
| 3    | 8.8697        |
| 4    | 8.4443        |
| 5    | 8.0577        |
| 6    | 8.6409        |
| 7    | 9.2052        |
| 8    | 8.5139        |
| 9    | 8.6767        |
| 10   | 7.6818        |

TrainOutput(global_step=1558, training_loss=2.4135568084168955, metrics={'train_runtime': 4336.9584, 'train_samples_per_second': 2.873, 'train_steps_per_second': 0.359, 'total_flos': 8667537195663360.0, 'train_loss': 2.4135568084168955, 'epoch': 1.0})
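
The figures in `TrainOutput` can be cross-checked with a little arithmetic. The sketch below assumes an effective batch size of 8, which is the `Trainer` default for `per_device_train_batch_size` on a single device. It also assumes `auto_find_batch_size` did not have to reduce it. Under that assumption the step count and throughput match the reported values.

```python
import math

train_samples = 12460         # rows in the train split
train_runtime_s = 4336.9584   # 'train_runtime' from TrainOutput
batch_size = 8                # assumption: default per-device batch size, one device

# One optimizer step per batch; the last batch is partial.
steps = math.ceil(train_samples / batch_size)
samples_per_second = train_samples / train_runtime_s
steps_per_second = steps / train_runtime_s

print(steps, round(samples_per_second, 3), round(steps_per_second, 3))
```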

In [46]:
trainer.evaluate()

{'eval_loss': 0.4536179304122925,
 'eval_runtime': 50.3408,
 'eval_samples_per_second': 9.932,
 'eval_steps_per_second': 1.251,
 'epoch': 1.0}

## 7. Saving model and tokenizer

In [43]:
peft_model_path="./peft-dialogue-summary-checkpoint"

trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

('./peft-dialogue-summary-checkpoint/tokenizer_config.json',
 './peft-dialogue-summary-checkpoint/special_tokens_map.json',
 './peft-dialogue-summary-checkpoint/spiece.model',
 './peft-dialogue-summary-checkpoint/added_tokens.json',
 './peft-dialogue-summary-checkpoint/tokenizer.json')

## 8. Loading the Adapter on top of FLAN-T5

In [44]:
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

peft_model = PeftModel.from_pretrained(peft_model_base,
                                       './peft-dialogue-summary-checkpoint/',
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False)

`Note`: We set `is_trainable=False` because we only plan to run inference with this PEFT model. To continue training the adapter, set `is_trainable=True` instead.

## 9. Testing the Model

In [51]:
from transformers import GenerationConfig
import pprint

index = 45
dialogue = dataset['test'][index]['dialogue']
baseline_human_summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary: """

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

pprint.pprint(prompt)
print(f'PEFT MODEL: {peft_model_text_output}')

('\n'
 'Summarize the following conversation.\n'
 '\n'
 '#Person1#: Would you like to go to the party tonight?\n'
 '#Person2#: Whose party?\n'
 "#Person1#: Ruojia's. Don't you know that? Ruojia has got married.\n"
 "#Person2#: What! Is she really? I can't believe it!\n"
 '#Person1#: Yes. Yesterday.\n'
 "#Person2#: Good gracious. That's incredible! I feel so happy for her!\n"
 '#Person1#: Yes, me too.\n'
 '#Person2#: But how do you know that?\n'
 '#Person1#: I saw the news from her twitter. And she sent an email about it.\n'
 "#Person2#: What? I didn't receive it!\n"
 '#Person1#: Maybe you should check your email.\n'
 '#Person2#: Oh yes, I find it. Tonight at her home. Will you bring '
 'something?\n'
 '#Person1#: Yes, a pair of wineglasses and a card to wish her happy '
 'marriage.\n'
 '#Person2#: I will buy a tea set.\n'
 '\n'
 'Summary: ')
PEFT MODEL: You're going to Ruojia's party tonight.
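
The cell above stores `baseline_human_summary` but does not compare it with the generated text. One way to quantify the match is a ROUGE-style unigram overlap. The sketch below is a simplified stand-in that uses whitespace tokenization and no stemming. It is not the `rouge_score` implementation, and the reference string is hypothetical rather than the actual DialogSum label. The function name `rouge1_f1` is also an assumption for illustration.

```python
from collections import Counter


def rouge1_f1(candidate, reference):
    """Unigram-overlap F1 between two strings (simplified ROUGE-1)."""
    cand_counts = Counter(candidate.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both.
    overlap = sum((cand_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


generated = "You're going to Ruojia's party tonight."
# Hypothetical reference, for illustration only.
hypothetical_reference = "Two friends plan to go to Ruojia's party tonight."

score = rouge1_f1(generated, hypothetical_reference)
print(f"ROUGE-1 F1 (simplified): {score:.4f}")
```

For a real evaluation, the `evaluate` and `rouge_score` packages installed in section 1 would be the natural choice, applied over the whole test split rather than a single example.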


## 10. Conclusion
LoRA can bring training costs down significantly, typically with little loss in output quality: here only 3,538,944 of 251,116,800 parameters (about 1.41%) were trainable. These costs can be reduced further by combining LoRA with quantization techniques (QLoRA). As the steps above show, training with LoRA takes only a few lines of code. We fine-tuned for just one epoch because the dataset is fairly large (12,460 training dialogues), and that single epoch took about 4,337 seconds, or roughly 72 minutes.