<a href='https://colab.research.google.com/github/prane-eth/AI_projects/blob/main/projects/LLM_fine-tuning.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

### Project: Fine-tuning a language model

In [1]:
try:
    __import__('unsloth')
except ImportError:
	# %%capture
	%pip install pandas groq python-dotenv datasets
	%pip install 'unsloth @ git+https://github.com/unslothai/unsloth.git'
	%pip install --no-deps 'xformers<0.0.26' trl tyro peft accelerate bitsandbytes
	%pip install torch==2.2.2

In [2]:
import os
import re
import sys
from datasets import Dataset
from groq import Groq
from io import StringIO
import pandas as pd
import torch
from transformers import TrainingArguments, set_seed
from trl import SFTTrainer
from unsloth import FastLanguageModel
from common_functions import display_md

random_state = 42
set_seed(random_state)  # transformers helper: seeds random, numpy and torch

datasets_folder = 'datasets'
os.makedirs(datasets_folder, exist_ok=True)  # no-op if the folder already exists

topic = 'customer_support'
data_filename = os.path.join(datasets_folder, f'{topic}_bot_finetune_data.csv')
model_checkpoint_path = os.path.join(datasets_folder, f'{topic}_saved_model')
rlaif_data_filepath = os.path.join(datasets_folder, f'{topic}_bot_rlaif_data.csv')

groq_api_key = os.getenv('GROQ_API_KEY')

# On Colab, fall back to the notebook's secrets store
if not groq_api_key and 'google.colab' in sys.modules:
	from google.colab import userdata
	try:
		groq_api_key = userdata.get('GROQ_API_KEY')
	except Exception as error:  # secret missing or notebook access not granted
		print(f'Could not read GROQ_API_KEY from Colab secrets: {error}')

if not groq_api_key:
	raise ValueError('GROQ_API_KEY is not set in the environment variables')

### Generate synthetic data for fine-tuning
**Data generation using an LLM**: a large model (Llama-3 70B) generates synthetic training data, which is then used to fine-tune a small model (Phi-3 mini, 3.8B).

In [3]:
# Groq API client; ask_larger_llm below uses it to call the larger LLM
client = Groq(api_key=groq_api_key)

def _extract_quoted_text(response):
	"""Return the text inside the first ``` fenced block of `response`, or None if there is none.

	An optional language tag alone on the opening fence line (e.g. ```csv) is dropped,
	and one pair of quote characters wrapping the whole text is removed.
	"""
	# the model may stop before closing the fence, so close it ourselves
	if not response.endswith('```'):
		response += '```'

	# the tag group only matches a bare word followed by a newline, so text that starts
	# right after the fence (```instruction,output\n...) is kept intact
	match = re.search(r'```(?:[\w+-]*[ \t]*\n)?(.*?)```', response, re.DOTALL)
	if not match:
		return None
	quoted_text = match.group(1).strip()
	if not quoted_text:
		return None

	# sometimes quotes are used to wrap the whole text. remove them, but only for real quote
	# characters, and only if the first line doesn't end with the same quote (e.g. a quoted CSV header)
	first_char = quoted_text[0]
	first_line_end_character = quoted_text.split('\n')[0][-1] if '\n' in quoted_text else None
	if (len(quoted_text) > 1 and first_char in '"\'`'
			and quoted_text[-1] == first_char and first_char != first_line_end_character):
		quoted_text = quoted_text[1:-1]
	return quoted_text


def ask_larger_llm(prompt, model='llama3-70b-8192', return_quoted=True):
	"""Send `prompt` as a single user message to a larger model on Groq and return its reply.

	Args:
		prompt: the user message sent to the chat completion API.
		model: Groq model name.
		return_quoted: if True (default), return only the text inside the first ``` block
			of the reply (see _extract_quoted_text); otherwise return the raw reply.

	Raises:
		SystemExit: if the API returns an empty reply, or if return_quoted is True and the
			reply contains no non-empty ``` block (the raw reply is printed first).
	"""
	chat_completion = client.chat.completions.create(
		messages=[{ 'role': 'user', 'content': prompt }],
		model=model,
	)
	response = chat_completion.choices[0].message.content
	if not response:
		raise SystemExit('No response from the API.')

	if not return_quoted:
		return response

	quoted_text = _extract_quoted_text(response)
	if quoted_text is None:
		print(response)
		raise SystemExit('No data found in the response.')
	return quoted_text


# Generate the fine-tuning data once with the larger LLM and cache it as a CSV file.
# UTF-8 is explicit on both sides: LLM output often contains non-ASCII characters, and
# pd.read_csv (next cell) decodes as UTF-8 regardless of the platform's default encoding.
if os.path.exists(data_filename):
	with open(data_filename, 'r', encoding='utf-8') as file:
		csv_text = file.read()
else:
	num_lines = 100
	generation_prompt = f'Generate high-quality data for fine-tuning in csv for {topic} chatbot' \
			f' for an ecommerce platform in at least {num_lines} lines of data. ' \
			'Include the csv file text in triple quotes ```. ' \
			'response should include no other text. fields: instruction, output.'
	csv_text = ask_larger_llm(generation_prompt)
	with open(data_filename, 'w', encoding='utf-8') as file:
		file.write(csv_text)


training_data = pd.read_csv(data_filename)
# LLM-written headers may carry stray spaces (the prompt lists 'instruction, output'), and an
# index column written by the LLM is read back as 'Unnamed: 0'; normalize both.
training_data = training_data.rename(columns=str.strip)
training_data = training_data.loc[:, ~training_data.columns.str.startswith('Unnamed')]

required_columns = {'instruction', 'output'}
missing_columns = required_columns - set(training_data.columns)
if missing_columns:
	raise ValueError(f'{data_filename} is missing required columns: {sorted(missing_columns)}')
# rows without an instruction or output would be formatted as the string 'nan' during training
training_data = training_data.dropna(subset=['instruction', 'output']).reset_index(drop=True)
print(f'Data size: {len(training_data)}')

training_data.head()

Data size: 56


Unnamed: 0,instruction,output
0,What is the status of my order?,Your order is currently being processed. Pleas...
1,I want to return my item,Please contact our customer service team to in...
2,I forgot my password,No worries! Click on the 'Forgot Password' lin...
3,I want to cancel my order,We're sorry to hear that. Please contact our c...
4,Where is my order?,Tracking information will be sent to you via e...


### Prepare the model for fine-tuning

In [4]:
max_seq_length = 2048
model = None
tokenizer = None
restored_finetuned_model = False

# Try to restore a previously fine-tuned model saved by train_model
if os.path.exists(model_checkpoint_path):
	try:
		model, tokenizer = FastLanguageModel.from_pretrained(
			model_checkpoint_path, trust_remote_code=True,
			max_seq_length=max_seq_length,
			dtype=None, load_in_4bit=True, device_map='cuda',
		)
		restored_finetuned_model = True
		print('Model loaded successfully.')
	except Exception as e:
		print('Error loading the model. Will train a new model.')
		print(e)

# Load the base model with fresh LoRA adapters whenever nothing was restored: either no
# checkpoint exists, or it failed to load (otherwise `model` would stay None).
if not restored_finetuned_model:
	model, tokenizer = FastLanguageModel.from_pretrained(
		model_name = 'unsloth/Phi-3-mini-4k-instruct',
		max_seq_length = max_seq_length,
		dtype = None,  # None for auto-detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
		load_in_4bit = True,  # 4-bit quantization to reduce memory usage
	)

	model = FastLanguageModel.get_peft_model(
		model,
		r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
		target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj',
						'gate_proj', 'up_proj', 'down_proj',],
		lora_alpha = 16,
		lora_dropout = 0,  # Supports any, but = 0 is optimized
		bias = 'none',  # Supports any, but = 'none' is optimized
		# 'unsloth' uses 30% less VRAM, fits 2x larger batch sizes!
		use_gradient_checkpointing = 'unsloth', # True or 'unsloth' for very long context
		random_state = random_state,
		use_rslora = False,
		loftq_config = None,
	)

==((====))==  Unsloth: Fast Mistral patching release 2024.5
   \\   /|    GPU: NVIDIA GeForce RTX 3050 Laptop GPU. Max memory: 3.804 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Unsloth: datasets/customer_support_saved_model has no tokenizer.model file.
Just informing you about this - this is not a critical error.


Unsloth 2024.5 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


Model loaded successfully.


### Prepare the dataset for fine-tuning

In [5]:
prompt = '''You are a customer support chatbot.
Below is an instruction that describes a task that provides further context.
Write a response that appropriately completes the request.

### Instruction:
{}

### Response:
{}'''

def create_dataset(training_data):
	"""Render every instruction/output pair through `prompt` into a one-column Dataset.

	Each rendered example ends with the tokenizer's EOS token; without it the fine-tuned
	model does not learn where a response stops and generation goes on forever.

	Returns:
		A datasets.Dataset with a single 'text' column, one entry per row of training_data.
	"""
	pairs = zip(training_data['instruction'], training_data['output'])
	rendered_texts = [prompt.format(question, answer) + tokenizer.eos_token for question, answer in pairs]
	return Dataset.from_dict({ 'text': rendered_texts })

### Train the model

In [6]:
trainer = None

def train_model(training_data, force_train=False):
	"""Fine-tune the global `model` on `training_data` and save it to `model_checkpoint_path`.

	Training runs when `force_train` is True, or whenever the model-preparation cell did not
	restore a fine-tuned model (no checkpoint existed, or loading it failed).

	Args:
		training_data: DataFrame with 'instruction' and 'output' columns (see create_dataset).
		force_train: train even if a fine-tuned model was restored.

	Raises:
		RuntimeError: if no model/tokenizer has been loaded.
	"""
	global trainer

	# Train whenever nothing was restored, including a checkpoint folder that exists but
	# failed to load; otherwise the notebook would continue with an untrained model.
	if not restored_finetuned_model and not force_train:
		if os.path.exists(model_checkpoint_path):
			print('Saved model could not be restored. Training from scratch.')
		else:
			print('Model not found. Training from scratch.')
		force_train = True

	if not force_train:
		return

	if model is None or tokenizer is None:
		raise RuntimeError('No model is loaded; run the model preparation cell first.')

	train_dataset = create_dataset(training_data)
	use_bf16 = torch.cuda.is_bf16_supported()
	trainer = SFTTrainer(
		model = model,
		tokenizer = tokenizer,
		train_dataset = train_dataset,
		dataset_text_field = 'text',
		max_seq_length = max_seq_length,
		dataset_num_proc = 2,
		packing = False,  # packing=True can make training 5x faster for short sequences
		args = TrainingArguments(
			per_device_train_batch_size = 2,
			gradient_accumulation_steps = 4,  # effective batch size of 2 * 4 = 8 per device
			warmup_steps = 5,
			max_steps = 60,
			learning_rate = 2e-4,
			fp16 = not use_bf16,
			bf16 = use_bf16,
			logging_steps = 1,
			optim = 'adamw_8bit',
			weight_decay = 0.01,
			lr_scheduler_type = 'linear',
			seed = random_state,
			output_dir = model_checkpoint_path,
		),
	)

	trainer.train()
	# save model and tokenizer where the model-preparation cell looks for them
	model.save_pretrained(model_checkpoint_path)
	tokenizer.save_pretrained(model_checkpoint_path)

train_model(training_data, force_train=False)

### Test the model

In [7]:
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

def ask_query(query, display=False, max_new_tokens=64):
	"""Generate the fine-tuned chatbot's answer to `query`.

	Args:
		query: the customer's message; inserted as the instruction in `prompt`.
		display: if True, render the answer as markdown with display_md and return None.
		max_new_tokens: upper bound on the number of generated tokens.

	Returns:
		The stripped answer text when `display` is False, otherwise None.
	"""
	model_input = prompt.format(
		query,
		'',  # output - leave this blank for generation!
	)
	inputs = tokenizer([model_input], return_tensors = 'pt').to(model.device)

	outputs = model.generate(**inputs, max_new_tokens = max_new_tokens, use_cache = True)

	# Decode only the newly generated tokens. Searching the decoded text for 'Response:'
	# breaks when the query itself contains 'Response:', and skip_special_tokens drops
	# markers such as '<|endoftext|>' wherever they appear.
	prompt_length = inputs['input_ids'].shape[1]
	output = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens = True).strip()

	if display:
		display_md(output)
	else:
		return output

ask_query('What are the payment options?', display=True)

We accept all major credit cards and PayPal.

In [8]:
ask_query(query='May I know the return policy?', display=True)

Our return policy allows for returns within 30 days of purchase with a valid receipt. Please see our full return policy for more details.

In [9]:
ask_query(query='Who are you?', display=True)

We are the customer service team at XYZ Company. How can we assist you?

### RLAIF: Reinforcement Learning from AI (LLM) Feedback

In [17]:
# Build the RLAIF table: instruction, expected_output (LLM-generated training target) and
# current_output (what the fine-tuned model answers now). It is cached in rlaif_data_filepath.
if os.path.exists(rlaif_data_filepath):
	rlaif_data = pd.read_csv(rlaif_data_filepath)
	# improved outputs are regenerated by the next cell, so discard any saved ones
	rlaif_data = rlaif_data.drop(columns=['improved_output'], errors='ignore')
else:
	# provide expected response and current response to AI, ask to improve the response, fine-tune the model again
	rlaif_data = training_data.rename(columns={ 'output': 'expected_output' })

if 'current_output' not in rlaif_data:
	rlaif_data['current_output'] = None
# object dtype so strings can be stored even if the column was read back as all-NaN floats
rlaif_data['current_output'] = rlaif_data['current_output'].astype(object)

# generate current_output with ask_query for every row that lacks one; pd.isna also catches
# NaN read back from the CSV, which a plain truthiness check misses (NaN is truthy)
# no parallel processing due to CPU heat concerns
missing_rows = rlaif_data.index[rlaif_data['current_output'].isna()]
for count, row_num in enumerate(missing_rows):
	if count % 5 == 0:
		print(f'Processing row {count+1}/{len(missing_rows)}')
	rlaif_data.at[row_num, 'current_output'] = ask_query(rlaif_data.at[row_num, 'instruction'])

if len(missing_rows) > 0 or not os.path.exists(rlaif_data_filepath):
	rlaif_data.to_csv(rlaif_data_filepath, index=False)

rlaif_data.head()

Unnamed: 0,instruction,expected_output,current_output
0,What is the status of my order?,Your order is currently being processed. Pleas...,Your order is currently being processed. You w...
1,I want to return my item,Please contact our customer service team to in...,We're happy to help! Please contact our custom...
2,I forgot my password,No worries! Click on the 'Forgot Password' lin...,"No worries! Click on the ""Forgot Password?"" li..."
3,Where is my order?,Tracking information will be sent to you via e...,Tracking information is available once your or...
4,I need a refund,We apologize for any inconvenience. Please con...,We're sorry to hear that. Please contact our c...


In [18]:
# Set to True to regenerate improved outputs even if rlaif_data already has some
forced = True
if forced:
	rlaif_data['improved_output'] = None

# if column is not loaded from file, or empty
if 'improved_output' not in rlaif_data or rlaif_data['improved_output'].notna().sum() == 0:
	rlaif_data['improved_output'] = None
	improvement_prompt = '''
		I am fine-tuning a customer-support chatbot.
		I provided the instruction, current_output, expected_output (provided by you in the past).
		For each row, write an improved_output for the chatbot, using expected_output as a reference.
		If current_output needs no change, write "no improvement" as the improved_output.
		Copy each instruction exactly as given.
		Include csv text in the response in triple quotes ```.
		return only these headers: instruction, improved_output.
	'''

	# pass 15 rows at a time to the AI to improve the response
	chunk_size = 15
	prompt_columns = ['instruction', 'current_output', 'expected_output']
	for start in range(0, len(rlaif_data), chunk_size):
		end = min(start + chunk_size, len(rlaif_data))
		print(f'Processing rows {start} to {end} of {len(rlaif_data)} rows')
		chunk_csv = rlaif_data.iloc[start:end][prompt_columns].to_csv(index=False)
		# dedicated name: `prompt` is the fine-tuning template used by ask_query/create_dataset
		improvement_request = f'{improvement_prompt}\n```{chunk_csv}```'
		response_csv = ask_larger_llm(improvement_request)
		try:
			response_data = pd.read_csv(StringIO(response_csv))
		except Exception as error:
			print(f'Failed to parse csv data from the response: {error}')
			print(response_csv)
			continue  # keep processing the remaining chunks

		if not {'instruction', 'improved_output'}.issubset(response_data.columns):
			print('Response is missing the instruction/improved_output columns; skipping chunk.')
			print(response_csv)
			continue

		# for each row's instruction value in response_data, update the corresponding row in rlaif_data
		for instruction, improved_output in zip(response_data['instruction'], response_data['improved_output']):
			if pd.isna(instruction) or pd.isna(improved_output):
				continue
			improved_output = str(improved_output)
			if 'no improvement' in improved_output.lower():
				continue
			matches = rlaif_data['instruction'] == instruction
			if not matches.any():
				continue  # the LLM altered or invented the instruction text
			current_output = rlaif_data.loc[matches, 'current_output'].iloc[0]
			if current_output != improved_output:
				rlaif_data.loc[matches, 'improved_output'] = improved_output

	# save every row (not only improved ones) so the cached file does not shrink on each re-run
	rlaif_data.to_csv(rlaif_data_filepath, index=False)
	rlaif_data = rlaif_data.dropna(subset=['improved_output'])

rlaif_data.head()

Processing rows 0 to 15 of 12 rows


Unnamed: 0,instruction,expected_output,current_output,improved_output
0,What is the status of my order?,Your order is currently being processed. Pleas...,Your order is currently being processed. You w...,Your order is currently being processed. Pleas...
1,I want to return my item,Please contact our customer service team to in...,We're happy to help! Please contact our custom...,Please contact our customer service team to in...
2,I forgot my password,No worries! Click on the 'Forgot Password' lin...,"No worries! Click on the ""Forgot Password?"" li...",No worries! Click on the 'Forgot Password' lin...
3,Where is my order?,Tracking information will be sent to you via e...,Tracking information is available once your or...,Tracking information will be sent to you via e...
4,I need a refund,We apologize for any inconvenience. Please con...,We're sorry to hear that. Please contact our c...,We apologize for any inconvenience. Please con...
