# Fine-tuning a DistilBERT model with QLoRA and LoRA

## Import libraries

In [None]:
import torch
import matplotlib.pyplot as plt
import warnings
def warn(*args, **kwargs):
    """Silently discard a warning; installed below in place of ``warnings.warn``.

    Accepts any positional/keyword arguments so it is call-compatible with the
    standard ``warnings.warn`` signature, and always returns ``None``.
    """
    del args, kwargs  # intentionally ignored


# Route every warnings.warn call to the no-op above, and also ignore anything
# that reaches the warnings filter machinery by another path.
warnings.warn = warn
warnings.filterwarnings('ignore')

import json
import numpy as np

from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType, replace_lora_weights_loftq, prepare_model_for_kbit_training

In [None]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): `device` is not referenced by the visible cells; the Trainer
# appears to handle model placement itself — confirm before relying on it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Define helper functions

In [None]:
# help functions
def save_to_json(data, file_path):
    """
    Serialize a dictionary as 4-space-indented JSON and write it to disk.

    Any existing file at ``file_path`` is overwritten. A confirmation line is
    printed once the write has completed.

    Args:
        data (dict): The JSON-serializable dictionary to write.
        file_path (str): Destination path of the JSON file.
    """
    confirmation = f"Data successfully saved to {file_path}"
    with open(file_path, mode='w') as handle:
        json.dump(data, handle, indent=4)
    print(confirmation)
    
    
def load_from_json(file_path):
    """
    Load data from a JSON file.

    The file is decoded as UTF-8 (the encoding JSON mandates), independent of the
    platform's locale encoding.

    Args:
        file_path (str): The path to the JSON file.

    Returns:
        dict: The data loaded from the JSON file.
    """
    # An explicit encoding avoids misdecoding non-ASCII content on systems whose
    # default locale codec is not UTF-8 (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)

## Load IMDB dataset

In [None]:
# Fetch the IMDB movie-review dataset (all of its splits) via `datasets`.
imdb = load_dataset(path="imdb")

In [None]:
# Collect the distinct label values that occur in the training split.
train_labels = imdb['train']['label']
unique_labels = set(train_labels)
print("\nUnique labels in the dataset (class information):", unique_labels, sep="\n")

In [None]:
# Human-readable names for the two label ids (same mapping as id2label, lower-cased).
class_names = dict(zip((0, 1), ("negative", "positive")))

In [None]:
# Shuffle both splits with the same fixed seed so runs are reproducible, then report sizes.
train_dataset, test_dataset = (imdb[split].shuffle(seed=42) for split in ("train", "test"))
for split_label, split_ds in (("training", train_dataset), ("testing", test_dataset)):
    print(f"No. of {split_label} sample: {len(split_ds)}")

## Define the tokenizer and preprocess the text

In [None]:
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def preprocess_function(examples):
    """Tokenize a batch of IMDB reviews.

    Args:
        examples (dict): A batch produced by ``Dataset.map(..., batched=True)``;
            its ``"text"`` entry holds the raw review strings.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``, ...) with
        each review truncated to at most 512 tokens and left unpadded.
    """
    # No padding here: padding=True would pad every `map` chunk to its longest
    # review and store those pad tokens in the dataset. The Trainer cells pass
    # `tokenizer=`, so batches are padded dynamically at collation time instead.
    return tokenizer(examples["text"], truncation=True, max_length=512)


tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_test = test_dataset.map(preprocess_function, batched=True)

## Define metrics calculation function

In [None]:
def compute_metrics(eval_pred):
    """Compute classification accuracy for the Hugging Face ``Trainer``.

    Args:
        eval_pred: A ``(logits, labels)`` pair (e.g. a ``transformers.EvalPrediction``)
            as passed by ``Trainer`` during evaluation. ``logits`` may also be a
            tuple whose first element holds the per-class scores.

    Returns:
        dict: ``{"accuracy": float}``, the fraction of argmax predictions that
        equal ``labels``.
    """
    logits, labels = eval_pred
    # Some models return extra outputs alongside the logits; keep only the scores.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    # Computed directly with NumPy: `datasets.load_metric` is deprecated (removed in
    # datasets 3.0), and loading a metric script on every evaluation is wasted work.
    accuracy = float(np.mean(predictions == np.asarray(labels)))
    return {"accuracy": accuracy}

In [None]:
# Label-id <-> label-name mappings stored in the model config.
id2label = dict(enumerate(("NEGATIVE", "POSITIVE")))
label2id = {name: idx for idx, name in id2label.items()}

## Load the DistilBERT-base-uncased model and configure it for QLoRA fine-tuning

In [None]:
# QLoRA: load DistilBERT with 4-bit quantized weights and attach LoRA adapters.
# Configure BitsAndBytes
config_bnb = BitsAndBytesConfig(
    load_in_4bit=True,  # quantize the model weights to 4 bits at load time
    bnb_4bit_quant_type="nf4",  # NormalFloat4, a 4-bit type suited to normally distributed weights
    bnb_4bit_use_double_quant=True,  # nested quantization: also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the de-quantized matmuls
    # Keep the classification head un-quantized. Despite the "int8" in its name,
    # this option also excludes the listed modules from 4-bit quantization.
    llm_int8_skip_modules=["classifier", "pre_classifier"],
)

# Load a quantized version of the pretrained model
model_qlora = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased",
    id2label=id2label,
    label2id=label2id,
    num_labels=2,
    quantization_config=config_bnb,
)
model_qlora = prepare_model_for_kbit_training(model_qlora)
qlora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  # sequence classification
    r=8,  # rank of the low-rank update matrices
    lora_alpha=16,  # scaling factor of the LoRA update
    lora_dropout=0.1,  # dropout applied on the LoRA branch
    target_modules=['q_lin', 'k_lin', 'v_lin'],  # DistilBERT attention projections
    # The base checkpoint has no classification head, so both head layers start
    # from random init. PEFT's SEQ_CLS default only trains modules named
    # "classifier"/"score", which would leave DistilBERT's `pre_classifier`
    # frozen at its random weights; train both head layers explicitly.
    modules_to_save=["pre_classifier", "classifier"],
)

peft_model_qlora = get_peft_model(model_qlora, qlora_config)
# Re-initialise the LoRA weights with LoftQ to offset the 4-bit quantization error.
replace_lora_weights_loftq(peft_model_qlora)
peft_model_qlora.print_trainable_parameters()

## Load the DistilBERT-base-uncased model and configure it for LoRA fine-tuning

In [None]:
# LoRA model configuration (full-precision base weights, no quantization)
model_lora = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased",
    id2label=id2label,
    label2id=label2id,
    num_labels=2,
)
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  # sequence classification
    r=8,  # rank of the low-rank update matrices
    lora_alpha=16,  # scaling factor of the LoRA update
    lora_dropout=0.1,  # dropout applied on the LoRA branch
    target_modules=['q_lin', 'k_lin', 'v_lin'],  # DistilBERT attention projections
    # The base checkpoint has no classification head, so both head layers start
    # from random init. PEFT's SEQ_CLS default only trains modules named
    # "classifier"/"score", which would leave DistilBERT's `pre_classifier`
    # frozen at its random weights; train both head layers explicitly.
    modules_to_save=["pre_classifier", "classifier"],
)

peft_model_lora = get_peft_model(model_lora, lora_config)
peft_model_lora.print_trainable_parameters()

## Configure the QLoRA training arguments and train the quantized model

In [None]:
# Configure and train the QLoRA model: evaluated on the test split after every epoch.
# NOTE(review): `evaluation_strategy` is named `eval_strategy` in newer transformers
# releases; confirm against the installed version.
training_args_qlora = TrainingArguments(
    output_dir="./results_qlora",
    evaluation_strategy="epoch",
    num_train_epochs=10,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    label_names=["labels"],  # name the label field explicitly for the PEFT-wrapped model
)
trainer_qlora = Trainer(
    model=peft_model_qlora,
    args=training_args_qlora,
    tokenizer=tokenizer,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    compute_metrics=compute_metrics,
)

trainer_qlora.train()

## Configure the LoRA training arguments and train the non-quantized model

In [None]:
# Configure and train the LoRA model: evaluated on the test split after every epoch.
# NOTE(review): `evaluation_strategy` is named `eval_strategy` in newer transformers
# releases; confirm against the installed version.
training_args_lora = TrainingArguments(
    output_dir="./results_lora",
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    evaluation_strategy="epoch",
    weight_decay=0.01,
    # Mirrors the QLoRA run so the two comparison runs differ only in the model;
    # names the label field explicitly for the PEFT-wrapped model.
    label_names=["labels"],
)
trainer_lora = Trainer(
    model=peft_model_lora,
    args=training_args_lora,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer_lora.train()

## Save the trained models and compare the metrics of the models fine-tuned with QLoRA and LoRA

In [None]:
# Save the trained QLoRA model and collect its per-epoch evaluation metrics.
trainer_qlora.save_model("./qlora_final_model")
log_history_qlora = trainer_qlora.state.log_history


def get_metric_qlora(metric, log_history_qlora):
    """Return, in logging order, the value of *metric* from every log entry that has it."""
    return [entry[metric] for entry in log_history_qlora if metric in entry]


eval_accuracy_qlora = get_metric_qlora('eval_accuracy', log_history_qlora)
eval_loss_qlora = get_metric_qlora('eval_loss', log_history_qlora)

In [None]:
# Save the trained LoRA model and collect its per-epoch evaluation metrics.
trainer_lora.save_model("./lora_final_model")
log_history_lora = trainer_lora.state.log_history


def get_metric_lora(metric, log_history_lora):
    """Return, in logging order, the value of *metric* from every log entry that has it."""
    return [entry[metric] for entry in log_history_lora if metric in entry]


eval_accuracy_lora = get_metric_lora('eval_accuracy', log_history_lora)
eval_loss_lora = get_metric_lora('eval_loss', log_history_lora)

In [None]:
# Compare QLoRA with LoRA: evaluation accuracy (circles) and loss (triangles),
# blue for QLoRA and green for LoRA.
# Evaluation runs once per epoch, so the i-th recorded value belongs to epoch i + 1;
# plot against 1-based epoch numbers so the points line up with the "epoch" axis.
plt.plot(range(1, len(eval_accuracy_qlora) + 1), eval_accuracy_qlora,
         label='accuracy_qlora', marker="o", color='b')
plt.plot(range(1, len(eval_loss_qlora) + 1), eval_loss_qlora,
         label='loss_qlora', marker="^", color='b')
plt.plot(range(1, len(eval_accuracy_lora) + 1), eval_accuracy_lora,
         label='accuracy_lora', marker="o", color='g')
plt.plot(range(1, len(eval_loss_lora) + 1), eval_loss_lora,
         label='loss_lora', marker="^", color='g')
plt.xlabel("epoch")
plt.ylabel("eval accuracy / eval loss")
plt.legend()
plt.title("Accuracy and Loss with QLoRA and LoRA")
plt.show()