## Full Fine-Tuning of FLAN-T5

### Overview
Previously, we explored zero-shot, one-shot, and few-shot prompting techniques in [Zero_Multishot_Inference.ipynb](./Zero_Multishot_Inference.ipynb).  
These approaches demonstrated how well FLAN-T5 can generate responses with pretrained knowledge and in-context learning.

Now, we move to full fine-tuning, where we train FLAN-T5 on a specific dataset to improve performance for a targeted task.  
Fine-tuning allows the model to:
- Learn task-specific patterns beyond its pretraining.
- Generalize better within the domain.
- Perform more consistently without needing complex prompts.

This notebook covers:
- Preparing the dataset  
- Setting up the fine-tuning pipeline  
- Training FLAN-T5 from its pretrained state  
- Evaluating and saving the fine-tuned model  

Let's begin.


In [3]:

!pip install --upgrade pip
!pip install transformers
!pip install datasets --quiet
!pip install torchdata
!pip install torch
!pip install streamlit
!pip install openai
!pip install langchain
!pip install unstructured
!pip install sentence-transformers
!pip install chromadb
!pip install evaluate==0.4.0
!pip install rouge_score==0.1.2
!pip install loralib==0.1.1
!pip install peft==0.3.0



In [4]:
import warnings
warnings.filterwarnings('ignore')

In [5]:
import torch
import evaluate
import time
import pandas as pd
import numpy as np
from datasets import load_dataset
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          GenerationConfig, TrainingArguments, Trainer)


# Model Fine-Tuning

In [6]:
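# Prefer Apple-silicon GPU (MPS); fall back to CUDA, then CPU, on other machines
DEVICE = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
torch_device = torch.device(DEVICE)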

## Load Dataset and LLM

In [34]:
hugging_face_dataset_name = "knkarthick/dialogsum"

In [35]:
dataset = load_dataset(hugging_face_dataset_name)

In [36]:
model_name='google/flan-t5-base'
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

## Checking Trainable Parameters

This function counts the model's total and trainable parameters and reports the trainable share as a percentage, which helps compare full fine-tuning with parameter-efficient fine-tuning (PEFT) methods.


In [37]:
def number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        # Only parameters with requires_grad=True are updated during training
        if param.requires_grad:
            trainable_model_params += param.numel()
    result = f"trainable model parameters: {trainable_model_params}\n"
    result += f"all model parameters: {all_model_params}\n"
    result += f"Percentage of model params: {(trainable_model_params/all_model_params)*100}"
    return result

In [38]:
print(number_of_trainable_model_parameters(original_model))

trainable model parameters: 247577856
all model parameters: 247577856
Percentage of model params: 100.0


## Trainable Parameters

In full fine-tuning, we train **all** of the model's parameters, about **248M** here, which makes it far more **resource-intensive** than PEFT methods.
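To put "resource-intensive" into rough numbers, the sketch below estimates the memory needed just to hold the model state during full fine-tuning. It assumes fp32 values (4 bytes each) and an Adam-style optimizer that keeps two extra tensors per parameter (AdamW is the `Trainer` default). Activations, batch data, and framework overhead are ignored, so real usage will be higher.

```python
# Rough memory estimate for full fine-tuning (model state only).
# Assumptions: fp32 everywhere, Adam-style optimizer with two moment buffers.
NUM_PARAMS = 247_577_856  # trainable parameters reported above
BYTES_FP32 = 4


def full_finetune_state_bytes(num_params: int, bytes_per_value: int = BYTES_FP32) -> dict:
    """Return bytes needed for weights, gradients, and optimizer state."""
    weights = num_params * bytes_per_value
    grads = num_params * bytes_per_value          # one gradient per trainable parameter
    optimizer = 2 * num_params * bytes_per_value  # Adam first and second moments
    return {
        "weights": weights,
        "gradients": grads,
        "optimizer": optimizer,
        "total": weights + grads + optimizer,
    }


state = full_finetune_state_bytes(NUM_PARAMS)
for name, size in state.items():
    print(f"{name:>10}: {size / 1e9:.2f} GB")
```

Under these assumptions the model state alone needs roughly 4 GB; a PEFT method shrinks the gradient and optimizer terms because only a small fraction of parameters is trainable.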



## Test the Model with Zero-Shot Inference

In [39]:
index = 200

dialogue = dataset['test'][index]['dialogue']
summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation

{dialogue}

Summary:
"""

# Move the tokenized inputs to the same device as the model
inputs = tokenizer(prompt, return_tensors='pt').to(torch_device)
output = tokenizer.decode(
    original_model.generate(
        inputs["input_ids"],
        max_new_tokens=200,
    )[0],
    skip_special_tokens=True
)
dash_line = "-".join("" for x in range(100))
print(dash_line)
print(f"Input Prompt:\n{prompt}")
print(dash_line)
print(f"Baseline Human Summary:\n{summary}\n")
print(dash_line)
print(f"Model Generation - Zero Shot: \n{output}")


---------------------------------------------------------------------------------------------------
Input Prompt:

Summarize the following conversation

#Person1#: Have you considered upgrading your system?
#Person2#: Yes, but I'm not sure what exactly I would need.
#Person1#: You could consider adding a painting program to your software. It would allow you to make up your own flyers and banners for advertising.
#Person2#: That would be a definite bonus.
#Person1#: You might also want to upgrade your hardware because it is pretty outdated now.
#Person2#: How can we do that?
#Person1#: You'd probably need a faster processor, to begin with. And you also need a more powerful hard disc, more memory and a faster modem. Do you have a CD-ROM drive?
#Person2#: No.
#Person1#: Then you might want to add a CD-ROM drive too, because most new software programs are coming out on Cds.
#Person2#: That sounds great. Thanks.

Summary:

--------------------------------------------------------------------

## Perform Full Fine-Tuning

### Preprocess the Dialog-Summary dataset

Convert the dialog-summary (prompt-response) pairs into explicit instructions for the LLM: prepend the instruction 'Summarize the following conversation.' to each dialog and append 'Summary: ' after it to mark where the summary should begin. The human-written summary is tokenized separately and used as the label.

In [40]:
def tokenize_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example['dialogue']]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids

    return example
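One refinement worth considering, not applied in this notebook's run: because the labels are padded to the maximum length, the loss is also computed on pad tokens. Hugging Face seq2seq models skip label positions set to `-100` when computing the cross-entropy loss, so a common fix is to replace pad ids in the labels with `-100`. The helper `mask_pad_labels` below is a hypothetical, framework-free sketch of that step; inside `tokenize_function` it would be applied to the label id lists with `tokenizer.pad_token_id` (0 for T5). Along the same lines, returning the tokenizer's `attention_mask` alongside `input_ids` would stop the encoder from attending to padding.

```python
from typing import List

IGNORE_INDEX = -100  # label value ignored by the loss in Hugging Face seq2seq models


def mask_pad_labels(batch_label_ids: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Replace padding ids in each label sequence with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if token_id == pad_token_id else token_id for token_id in labels]
        for labels in batch_label_ids
    ]


# Hypothetical usage inside tokenize_function (labels as Python lists, not tensors):
# labels = tokenizer(example["summary"], padding="max_length", truncation=True).input_ids
# example["labels"] = mask_pad_labels(labels, tokenizer.pad_token_id)
```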

## Handling Train, Validation, and Test Sets

The dataset is processed using the `datasets` library, which automatically manages splits:

- `dataset["train"]`: The training set used for model learning.
- `dataset["validation"]`: The validation set for tuning hyperparameters and preventing overfitting.
- `dataset["test"]`: The test set for evaluating final model performance.

The `map()` call applies `tokenize_function` to every split in batches, and `remove_columns()` drops the raw text and metadata columns so that only `input_ids` and `labels` remain.


In [41]:
# The dataset contains 3 different splits: train, validation, and test.
# map() runs tokenize_function over every split, in batches.

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary',])

Map: 100%|██████████| 12460/12460 [00:07<00:00, 1675.51 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 1713.19 examples/s]
Map: 100%|██████████| 1500/1500 [00:00<00:00, 1795.14 examples/s]


## Subsampling the Dataset

To reduce computational load and speed up training, we keep only every 100th example (about 1% of each split) instead of using the full data.
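Keeping the rows where `index % 100 == 0` retains `ceil(n / 100)` rows from a split of `n` rows. The sketch below checks that arithmetic against the split sizes shown in the `Map` progress bars (12460, 500, and 1500).

```python
import math

# Split sizes taken from the Map progress bars above
split_sizes = {"train": 12460, "validation": 500, "test": 1500}


def kept_rows(n: int, stride: int = 100) -> int:
    """Rows kept by filtering on index % stride == 0 (indices 0, stride, 2*stride, ...)."""
    return math.ceil(n / stride)


subsampled = {name: kept_rows(n) for name, n in split_sizes.items()}
print(subsampled)  # {'train': 125, 'validation': 5, 'test': 15}
```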


In [42]:
tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 100 == 0, with_indices=True)

Filter: 100%|██████████| 12460/12460 [00:07<00:00, 1566.54 examples/s]
Filter: 100%|██████████| 500/500 [00:00<00:00, 1348.14 examples/s]
Filter: 100%|██████████| 1500/1500 [00:00<00:00, 1512.20 examples/s]


In [43]:
print("Shapes of the datasets:")
print(f"Training: {tokenized_datasets['train'].shape}")
print(f"Validation: {tokenized_datasets['validation'].shape}")
print(f"Test: {tokenized_datasets['test'].shape}")

print(tokenized_datasets)

Shapes of the datasets:
Training: (125, 2)
Validation: (5, 2)
Test: (15, 2)
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 125
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 5
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 15
    })
})


### Fine-Tune the model with the Preprocessed Dataset

We now use the built-in Hugging Face `Trainer` class.

## Training Arguments and Trainer Setup

The `TrainingArguments` define key hyperparameters for training:

- **`output_dir`** → Directory to save model checkpoints.  
- **`learning_rate`** → Step size for model updates (1e-5).  
- **`num_train_epochs`** → Number of full passes over the dataset (1 epoch).  
- **`weight_decay`** → Regularization to prevent overfitting (0.01).  
- **`logging_steps`** → Logs training progress after every step (1).  
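- **`max_steps`** → Caps training at 1 optimizer step for quick testing. When set to a positive value it overrides `num_train_epochs`, so remove it to train for a full epoch.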

The `Trainer` manages training and evaluation:
- **`model`** → The FLAN-T5 model.  
- **`train_dataset`** → Training dataset.  
- **`eval_dataset`** → Validation dataset for evaluation.  
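Because `max_steps=1` overrides `num_train_epochs`, the run below processes only one batch. The arithmetic sketch that follows shows how many optimizer steps a full epoch over the 125 subsampled training examples would take, assuming the `TrainingArguments` defaults `per_device_train_batch_size=8` and `gradient_accumulation_steps=1` on a single device.

```python
import math


def steps_per_epoch(num_examples: int, per_device_batch_size: int = 8,
                    grad_accum_steps: int = 1, num_devices: int = 1) -> int:
    """Optimizer steps for one pass over the data (last partial batch included)."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    return math.ceil(num_examples / effective_batch)


n_train = 125  # rows left in the training split after subsampling
print(steps_per_epoch(n_train))  # 16
# With max_steps=1, training stops after the first of these steps,
# i.e. after seeing 8 of the 125 examples.
```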


In [44]:
output_dir = f"./dialogue-summary-training-{str(int(time.time()))}"

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=1,
    max_steps=1
)

trainer = Trainer(
    model=original_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation']
)

In [None]:
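trainer.train()
# Save the final weights to output_dir so they can be reloaded below
trainer.save_model(output_dir)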

In [46]:
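instruct_model = AutoModelForSeq2SeqLM.from_pretrained(output_dir).to(torch_device)
# Trainer updates the model passed to it in place, so reload the pretrained
# checkpoint to get an untouched zero-shot baseline for comparison
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(torch_device)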

## Evaluate the Model Qualitatively

In [47]:
index = 200
dialogue = dataset['test'][index]['dialogue']
human_baseline_summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation

{dialogue}

Summary:
"""

input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(torch_device)

original_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)

instruct_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
instruct_text_output = tokenizer.decode(instruct_outputs[0], skip_special_tokens=True)

dash_line = "-".join("" for x in range(100))
print(dash_line)
print(f"Input Prompt:\n{prompt}")
print(dash_line)
print(f"Baseline Human Summary:\n{human_baseline_summary}\n")
print(dash_line)
print(f"Original Model Generation - Zero Shot: \n{original_text_output}")
print(dash_line)
print(f"Instruct Model Generation - Fine Tune: \n{instruct_text_output}")

---------------------------------------------------------------------------------------------------
Input Prompt:

Summarize the following conversation

#Person1#: Have you considered upgrading your system?
#Person2#: Yes, but I'm not sure what exactly I would need.
#Person1#: You could consider adding a painting program to your software. It would allow you to make up your own flyers and banners for advertising.
#Person2#: That would be a definite bonus.
#Person1#: You might also want to upgrade your hardware because it is pretty outdated now.
#Person2#: How can we do that?
#Person1#: You'd probably need a faster processor, to begin with. And you also need a more powerful hard disc, more memory and a faster modem. Do you have a CD-ROM drive?
#Person2#: No.
#Person1#: Then you might want to add a CD-ROM drive too, because most new software programs are coming out on Cds.
#Person2#: That sounds great. Thanks.

Summary:

--------------------------------------------------------------------

## Evaluate the Model Quantitatively (with ROUGE Metric)

In [48]:
rouge = evaluate.load('rouge')

Downloading builder script: 100%|██████████| 6.27k/6.27k [00:00<00:00, 6.27MB/s]
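ROUGE-1 measures unigram overlap between a generated summary and a reference. The sketch below computes a ROUGE-1 F1 score for a single pair using lowercase whitespace tokenization. It is a simplified illustration, not a reimplementation of `rouge_score`, which applies its own tokenizer and optional Porter stemming (`use_stemmer=True` below), so its scores will not match exactly.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: clipped unigram overlap on lowercase whitespace tokens."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Each token counts at most min(count in prediction, count in reference) times
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("Masha and Hero are divorced.", "Masha and Hero are getting divorced.")
print(f"{score:.3f}")  # 0.909
```

Here all 5 predicted tokens appear in the 6-token reference, so precision is 1.0, recall is 5/6, and F1 is 10/11.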


In [49]:
dialogues = dataset['test'][0:10]['dialogue']
human_baseline_summaries = dataset['test'][0:10]['summary']

original_model_summaries = []
instruct_model_summaries = []
for dialogue in dialogues:
    prompt = f"""
Summarize the following conversation

{dialogue}

Summary:
    """
    # Move the inputs to the same device as both models
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(torch_device)
    
    original_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)
    original_model_summaries.append(original_text_output)

    instruct_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    instruct_text_output = tokenizer.decode(instruct_outputs[0], skip_special_tokens=True)
    instruct_model_summaries.append(instruct_text_output)

zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, instruct_model_summaries))

df = pd.DataFrame(zipped_summaries, columns=['human', 'original', 'instruct'])

In [50]:
df

| | human | original | instruct |
|---|---|---|---|
| 0 | Ms. Dawson helps #Person1# to write a memo to ... | Employees are required to use instant messagin... | #Person1# asks Ms. Dawson to take a dictation ... |
| 1 | In order to prevent employees from wasting tim... | This memo will be sent to all employees by thi... | #Person1# asks Ms. Dawson to take a dictation ... |
| 2 | Ms. Dawson takes a dictation for #Person1# abo... | Employees are required to use the Office of In... | #Person1# asks Ms. Dawson to take a dictation ... |
| 3 | #Person2# arrives late because of traffic jam.... | People are talking about the traffic in this c... | #Person2# got stuck in traffic again. #Person1... |
| 4 | #Person2# decides to follow #Person1#'s sugges... | #Person1: I'm finally here! | #Person2# got stuck in traffic again. #Person1... |
| 5 | #Person2# complains to #Person1# about the tra... | #Person1: I'm sorry to hear that you're stuck ... | #Person2# got stuck in traffic again. #Person1... |
| 6 | #Person1# tells Kate that Masha and Hero get d... | Masha and Hero are divorced. | #Person1# tells Kate Masha and Hero are gettin... |
| 7 | #Person1# tells Kate that Masha and Hero are g... | Masha and Hero are divorced. | #Person1# tells Kate Masha and Hero are gettin... |
| 8 | #Person1# and Kate talk about the divorce betw... | #Person1: Masha and Hero are getting a divorce. | #Person1# tells Kate Masha and Hero are gettin... |
| 9 | #Person1# and Brian are at the birthday party ... | #Person1#: Brian, thank you for coming to the ... | Brian's birthday is coming. Brian dances with ... |


In [51]:
original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries,
    use_aggregator=True,
    use_stemmer=True
)

In [52]:
instruct_model_results = rouge.compute(
    predictions=instruct_model_summaries,
    references=human_baseline_summaries[0:len(instruct_model_summaries)],
    use_aggregator=True,
    use_stemmer=True
)

print(f"Original Model: \n{original_model_results}")
print(f"Instruct Model: \n{instruct_model_results}")


Original Model: 
{'rouge1': 0.261052062988671, 'rouge2': 0.08531489481944488, 'rougeL': 0.224821552384684, 'rougeLsum': 0.22788611265447228}
Instruct Model: 
{'rouge1': 0.38857220563277894, 'rouge2': 0.13135692283806472, 'rougeL': 0.28167162470172985, 'rougeLsum': 0.28344342480768214}


## Performance Improvement and Training Time

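On these 10 test dialogues, the **fine-tuned model** scores higher than the zero-shot model on every ROUGE variant, most notably **ROUGE-1 (+12.8 percentage points, 0.261 → 0.389)** and **ROUGE-2 (+4.6 points, 0.085 → 0.131)**, indicating better summarization quality. With only 10 examples, treat this as an indicative rather than a statistically robust comparison.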
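The gains above are absolute differences in score, expressed in percentage points. The sketch below recomputes them from the two result dictionaries printed earlier and also reports the relative change for each metric.

```python
# Scores copied from the rouge.compute() output above
original = {'rouge1': 0.261052062988671, 'rouge2': 0.08531489481944488,
            'rougeL': 0.224821552384684, 'rougeLsum': 0.22788611265447228}
instruct = {'rouge1': 0.38857220563277894, 'rouge2': 0.13135692283806472,
            'rougeL': 0.28167162470172985, 'rougeLsum': 0.28344342480768214}


def compare_scores(baseline: dict, tuned: dict) -> dict:
    """Absolute gain in percentage points and relative gain in percent, per metric."""
    return {
        metric: {
            "abs_points": (tuned[metric] - baseline[metric]) * 100,
            "relative_pct": (tuned[metric] - baseline[metric]) / baseline[metric] * 100,
        }
        for metric in baseline
    }


gains = compare_scores(original, instruct)
for metric, gain in gains.items():
    print(f"{metric:>9}: {gain['abs_points']:+.1f} points ({gain['relative_pct']:+.1f}% relative)")
```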

### Training Time:
- **Device:** Mac M2 (`mps`)
- **Epochs:** 1
- **Total Time:** ~24 hours

## Next Steps for Improvement

To further enhance the model's efficiency and performance, consider the following improvements:

1. **Optimize Training**  
   - Use **LoRA/PEFT** to fine-tune fewer parameters efficiently.  
   - Reduce **sequence length** and **batch size** to lower memory usage.  

2. **Use Better Hardware**  
   - Train on a **GPU cloud service (AWS, Google Colab, Kaggle)** for faster processing.
      
3. **Deploy and Test**  
   - Convert the model to **ONNX/TorchScript** for optimized inference.  
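To see why LoRA/PEFT is so much cheaper, note that LoRA freezes a weight matrix of shape d×k and trains two low-rank factors of shapes d×r and r×k, adding r(d+k) parameters per adapted matrix. The sketch below counts them for FLAN-T5-base under illustrative assumptions that this notebook does not itself use: rank r=32, adapters on only the query (`q`) and value (`v`) projections of every attention layer, and the base config of d_model=768 with 12 encoder and 12 decoder blocks (each decoder block has both self- and cross-attention).

```python
def lora_params_per_matrix(d_in: int, d_out: int, rank: int) -> int:
    """Trainable params of LoRA factors A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)


def flan_t5_lora_params(rank: int = 32, d_model: int = 768, inner_dim: int = 768,
                        enc_layers: int = 12, dec_layers: int = 12,
                        matrices_per_attention: int = 2) -> int:
    """Count LoRA params when adapting q and v in every attention layer (assumed setup)."""
    # Encoder blocks have one self-attention; decoder blocks have self- and cross-attention
    attention_layers = enc_layers + 2 * dec_layers
    per_matrix = lora_params_per_matrix(d_model, inner_dim, rank)
    return attention_layers * matrices_per_attention * per_matrix


full_params = 247_577_856  # trainable params in full fine-tuning (reported above)
lora_params = flan_t5_lora_params()
print(f"LoRA trainable params: {lora_params:,}")
print(f"Share of full fine-tuning: {lora_params / full_params:.2%}")
```

Under these assumptions LoRA trains about 3.5M parameters, roughly 1.4% of what full fine-tuning updates, which also shrinks the gradient and optimizer memory proportionally.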
