<a href="https://colab.research.google.com/github/sochachai/dialogue_topic_summary_llm_lora_finetune/blob/main/llm_code_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Install dependencies

In [1]:
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install accelerate -U
!pip install transformers[torch]
!pip install peft

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

#### Import libraries

In [2]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, GenerationConfig
import evaluate
import pandas as pd
import numpy as np

#### Load data

In [3]:
dataset = load_dataset("HuggingFaceTB/smoltalk","everyday-conversations")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/9.25k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/946k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2260 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/119 [00:00<?, ? examples/s]

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['full_topic', 'messages'],
        num_rows: 2260
    })
    test: Dataset({
        features: ['full_topic', 'messages'],
        num_rows: 119
    })
})

In [5]:
dataset['train'][2]['full_topic']

'Shopping/Window shopping/Window shopping etiquette'

In [6]:
dataset['train'][2]['messages']

[{'content': 'Hi', 'role': 'user'},
 {'content': 'Hello! How can I help you today?', 'role': 'assistant'},
 {'content': "I'm going to the mall to do some window shopping. What's the point of window shopping if I'm not going to buy anything?",
  'role': 'user'},
 {'content': 'Window shopping can be a fun way to browse and get inspiration for future purchases, or to simply enjoy looking at products and displays without feeling pressured to buy.',
  'role': 'assistant'},
 {'content': 'That makes sense. What are some basic rules I should follow while window shopping to be polite to the store staff?',
  'role': 'user'},
 {'content': 'Some basic rules include not touching or handling merchandise excessively, not blocking store entrances or aisles, and being respectful of staff and other customers.',
  'role': 'assistant'},
 {'content': "Alright, I'll keep those in mind. Is it okay to take pictures of store displays or products?",
  'role': 'user'},
 {'content': "Yes, it's usually okay to tak

#### Load model

In [7]:
model_name = "google/flan-t5-base"
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

In [8]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(base_model))

trainable model parameters: 247577856
all model parameters: 247577856
percentage of trainable model parameters: 100.00%


In [9]:
dataset['train'][0]['messages'] = "".join([str(item['role'] + ':' + item['content']+'\n') for item in dataset['train'][0]['messages']])

In [10]:
dataset['train'][0]['full_topic']

'Travel/Vacation destinations/Beach resorts'

In [11]:
"".join([str(item['role'] + ':' + item['content']+'\n') for item in dataset['train'][0]['messages']])

"user:Hi there\nassistant:Hello! How can I help you today?\nuser:I'm looking for a beach resort for my next vacation. Can you recommend some popular ones?\nassistant:Some popular beach resorts include Maui in Hawaii, the Maldives, and the Bahamas. They're known for their beautiful beaches and crystal-clear waters.\nuser:That sounds great. Are there any resorts in the Caribbean that are good for families?\nassistant:Yes, the Turks and Caicos Islands and Barbados are excellent choices for family-friendly resorts in the Caribbean. They offer a range of activities and amenities suitable for all ages.\nuser:Okay, I'll look into those. Thanks for the recommendations!\nassistant:You're welcome. I hope you find the perfect resort for your vacation.\n"

In [12]:
import datasets
import pandas as pd
from datasets import Dataset, DatasetDict
train_df = pd.DataFrame({
     "dialogue" : ["".join([str(item['role'] + ':' + item['content']+'\n') for item in row['messages']]) for row in dataset['train']],
     "topic" : [row['full_topic'] for row in dataset['train']]
})

test_df = pd.DataFrame({
     "dialogue" : ["".join([str(item['role'] + ':' + item['content']+'\n') for item in row['messages']]) for row in dataset['test']],
     "topic" : [row['full_topic'] for row in dataset['test']]
})

train_dataset = Dataset.from_dict(train_df)
test_dataset = Dataset.from_dict(test_df)
my_dataset_dict = datasets.DatasetDict({"train":train_dataset, "test":test_dataset})

In [13]:
my_dataset_dict['test'][100]

{'dialogue': "user:Hi\nassistant:Hello! How can I help you today?\nuser:I'm looking for grocery delivery services in my area. Do you know of any?\nassistant:Yes, there are several options available. Some popular ones include Instacart, Shipt, and Peapod. Would you like me to help you find one near you?\nuser:That sounds great, thank you. How do I get started with Instacart?\nassistant:You can download the Instacart app or visit their website to sign up and enter your zip code. They'll show you available stores and delivery times in your area.\n",
 'topic': 'Shopping/Grocery shopping/Grocery delivery services'}

In [14]:
def tokenize_function(example):
    start_prompt = 'What is the topic of the following conversation?\n\n'
    end_prompt = '\n\nTopic: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example['dialogue']]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["topic"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    return example

tokenized_datasets = my_dataset_dict.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['dialogue', 'topic'])
#tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 100 == 0, with_indices=True)

Map:   0%|          | 0/2260 [00:00<?, ? examples/s]

Map:   0%|          | 0/119 [00:00<?, ? examples/s]

In [15]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 2260
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 119
    })
})

In [16]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

peft_model_train = get_peft_model(base_model, lora_config)
print(print_number_of_trainable_model_parameters(peft_model_train))

trainable model parameters: 884736
all model parameters: 248462592
percentage of trainable model parameters: 0.36%


In [17]:
output_dir = "./peft-dialogue-topic-training"

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,
    num_train_epochs=5,
)

In [18]:
peft_trainer = Trainer(
    model=peft_model_train,
    args=peft_training_args,
    train_dataset=tokenized_datasets["train"],
)

peft_trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
500,1.2675
1000,0.0436
1500,0.0356
2000,0.032
2500,0.0286


TrainOutput(global_step=2825, training_loss=0.2521726925605166, metrics={'train_runtime': 2802.5182, 'train_samples_per_second': 4.032, 'train_steps_per_second': 1.008, 'total_flos': 7768470454272000.0, 'train_loss': 0.2521726925605166, 'epoch': 5.0})

In [19]:
peft_model_path = "./peft-topic-analysis"
peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

('./peft-topic-analysis/tokenizer_config.json',
 './peft-topic-analysis/special_tokens_map.json',
 './peft-topic-analysis/spiece.model',
 './peft-topic-analysis/added_tokens.json',
 './peft-topic-analysis/tokenizer.json')

In [20]:
start_prompt = 'What is the topic of the following conversation?\n\n'
end_prompt = '\n\nTopic: '
prompt = [start_prompt + my_dataset_dict['test'][100]['dialogue']+ end_prompt]
input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids

In [21]:
my_dataset_dict['test'][100]

{'dialogue': "user:Hi\nassistant:Hello! How can I help you today?\nuser:I'm looking for grocery delivery services in my area. Do you know of any?\nassistant:Yes, there are several options available. Some popular ones include Instacart, Shipt, and Peapod. Would you like me to help you find one near you?\nuser:That sounds great, thank you. How do I get started with Instacart?\nassistant:You can download the Instacart app or visit their website to sign up and enter your zip code. They'll show you available stores and delivery times in your area.\n",
 'topic': 'Shopping/Grocery shopping/Grocery delivery services'}

In [22]:
from peft import PeftModel

peft_model_base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
peft_model = PeftModel.from_pretrained(peft_model_base, peft_model_path, is_trainable=False)

peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=500, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)


In [23]:
peft_model_text_output

'Food delivery services'

In [24]:
start_prompt = 'What is the topic of the following conversation?\n\n'
end_prompt = '\n\nTopic: '
prompt = [start_prompt + my_dataset_dict['test'][10]['dialogue']+ end_prompt]
input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
print(my_dataset_dict['test'][10])
peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=500, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)
print(peft_model_text_output)

{'dialogue': "user:Hi\nassistant:Hello! How can I help you today?\nuser:I'm interested in painting as a hobby. What kind of paint should I use for a beginner?\nassistant:As a beginner, I would recommend using acrylic paint. It's easy to clean up and dries quickly.\nuser:That sounds great. How can I preserve my artwork so it lasts a long time?\nassistant:To preserve your artwork, make sure to varnish it once it's completely dry. You can also frame it behind glass to protect it from dust and sunlight.\nuser:That's really helpful, thank you.\n", 'topic': 'Hobbies/Painting/Art preservation'}
Paint for beginners


In [25]:
index = 20
start_prompt = 'What is the topic of the following conversation?\n\n'
end_prompt = '\n\nTopic: '
prompt = [start_prompt + my_dataset_dict['test'][index]['dialogue']+ end_prompt]
input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
print(my_dataset_dict['test'][index])
peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=500, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)
print(peft_model_text_output)

{'dialogue': "user:Hi\nassistant:Hello! How can I help you today?\nuser:I'm looking for a new frying pan. What are some good materials?\nassistant:Frying pans come in various materials, but popular ones include stainless steel, non-stick, and cast iron. Each has its own benefits.\nuser:What's the difference between a saucepan and a Dutch oven?\nassistant:A saucepan is typically smaller and used for heating sauces or cooking small meals, while a Dutch oven is larger and often used for slow-cooking, braising, or roasting.\nuser:I think I'll get a non-stick pan. Do I need special utensils for it?\nassistant:Yes, it's best to use silicone, wooden, or plastic utensils with non-stick pans to avoid scratching the surface. Metal utensils can damage the non-stick coating.\n", 'topic': 'Cooking/Kitchen tools/Cooking utensils'}
Food & Cooking


In [26]:
index = 30
start_prompt = 'What is the topic of the following conversation?\n\n'
end_prompt = '\n\nTopic: '
prompt = [start_prompt + my_dataset_dict['test'][index]['dialogue']+ end_prompt]
input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
print(my_dataset_dict['test'][index])
peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=500, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)
print(peft_model_text_output)

Weather:Have a heat wave


In [27]:
index = 111
start_prompt = 'What is the topic of the following conversation?\n\n'
end_prompt = '\n\nTopic: '
prompt = [start_prompt + my_dataset_dict['test'][index]['dialogue']+ end_prompt]
input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
print(my_dataset_dict['test'][index])
peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=500, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)
print(peft_model_text_output)

{'dialogue': "user:Hello\nassistant:Hello! How can I help you today?\nuser:I'm looking for a new hobby. Do you have any suggestions?\nassistant:Yes, I do. Have you considered cooking? It's a fun and rewarding hobby that can be very creative.\nuser:That sounds interesting. What are some basic food safety tips I should know if I start cooking?\nassistant:Great question. Always wash your hands before handling food, make sure to store food at the right temperature, and cook food to the recommended internal temperature to avoid foodborne illness.\nuser:That's really helpful, thank you.\n", 'topic': 'Hobbies/Cooking/Food safety'}
Cooking
