<a href="https://www.kaggle.com/code/yahayamkayode/fine-tuning-gemma2b-model-using-lora-and-keras?scriptVersionId=208046475" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

<center><h1>Fine-tuning Gemma 2 Model Using LoRA and Keras with Custom Dataset</h1></center>


# Introduction

> In this project, I developed PolicyLens-India, a conversational AI chatbot that answers questions about Indian parliamentary debates and policies. The model is fine-tuned from Gemma 2 2B (`gemma2_2b_en`) on a custom Q&A-style dataset created from Indian parliamentary debate documents from 2024.

> The dataset contains 500 QA pairs derived from debates covering legislative processes, policy discussions, and key national issues. The chatbot was evaluated on 10 QA pairs generated from 2024 parliamentary debates, achieving over 90% accuracy.


#### The following resources were used in implementing this project

> 1. Fine-tune Gemma models in Keras using LoRA, Kaggle Code, https://www.kaggle.com/code/nilaychauhan/fine-tune-gemma-models-in-keras-using-lora (Version 1) 
> 2. Fine-tune Gemma using LoRA and Keras, https://www.kaggle.com/code/gpreda/fine-tune-gemma-using-lora-and-keras




# What is Gemma 2?

> Gemma is a collection of lightweight, advanced open models developed by Google, leveraging the same research and technology behind the Gemini models. These models are text-to-text, decoder-only large language models available in English, with open weights provided for both pre-trained and instruction-tuned versions. Gemma models excel in a range of text generation tasks, such as question answering, summarization, and reasoning. Their compact size allows for deployment in resource-constrained environments like laptops, desktops, or personal cloud infrastructure, making state-of-the-art AI models more accessible and encouraging innovation for all. 

> Gemma 2 is the second generation of Gemma models. These models were trained on a text dataset drawn from a wide variety of sources. The **27B** model was trained on **13 trillion** tokens, the **9B** model on **8 trillion** tokens, and the **2B** model on **2 trillion** tokens.

> To learn more about Gemma 2, follow this link: [Gemma 2 Model Card](https://www.kaggle.com/models/google/gemma-2).


# What is LoRA?  

> **LoRA** stands for **Low-Rank Adaptation**. It is a method for fine-tuning large language models (LLMs) that freezes the weights of the LLM and injects trainable rank-decomposition matrices, which considerably reduces the number of trainable parameters during fine-tuning. According to the **LoRA** paper, compared with full fine-tuning of GPT-3 175B with Adam, LoRA reduces the number of trainable parameters by **10,000 times** and the GPU memory requirement by **3 times**.
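As a minimal sketch of the idea (plain NumPy with made-up layer sizes; it follows the formulation in the LoRA paper and is not the Keras implementation), a LoRA-adapted dense layer keeps the pretrained weight `W` frozen and adds a scaled low-rank product `A @ B`. Because `B` starts at zero, the adapted layer initially behaves exactly like the pretrained one, and any update it learns has rank at most `rank`. The scaling factor `alpha` is an assumed hyperparameter name used only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out, rank = 64, 64, 4  # illustrative sizes, not Gemma 2's real dimensions
alpha = 8.0                    # hypothetical LoRA scaling hyperparameter

W = rng.normal(size=(d_in, d_out))             # frozen pretrained weight
A = rng.normal(scale=0.01, size=(d_in, rank))  # trainable factor, random init
B = np.zeros((rank, d_out))                    # trainable factor, zero init


def lora_dense(x, A, B):
    """Frozen path plus scaled low-rank update: x @ W + (alpha / rank) * x @ A @ B."""
    return x @ W + (alpha / rank) * (x @ A @ B)


x = rng.normal(size=(2, d_in))

# At initialisation B is zero, so the adapted layer reproduces the frozen layer.
print(np.allclose(lora_dense(x, A, B), x @ W))

# Pretend training has changed B: the implied weight update is still low-rank.
B_trained = rng.normal(size=(rank, d_out))
delta_W = (alpha / rank) * (A @ B_trained)
print(f"rank of the update: {np.linalg.matrix_rank(delta_W)} (at most {rank})")
```

In KerasNLP, `backbone.enable_lora(rank=...)` attaches adapters of this kind to selected layers of the backbone; which layers are adapted is decided by the library.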

# How do we proceed?

> To fine-tune with LoRA, we follow these steps:

> 1. Install prerequisites
> 2. Load and process the Indian parliamentary debates data for fine-tuning
> 3. Initialize the code for the Gemma causal language model (Gemma Causal LM)
> 4. Perform fine-tuning
> 5. Test the fine-tuned model with questions from the data used for fine-tuning and with additional questions

# Prerequisites


## Install packages

We start by installing `keras-nlp` and `keras` packages.

In [1]:
# Install dependencies (Keras 3 replaces the older keras-core package)
!pip install -q -U wurlitzer
!pip install -q -U keras-nlp
!pip install -q -U keras==3.5.0  # pin Keras 3.x to work with JAX
!pip install -q -U kagglehub
!pip install -q jax jaxlib





## Import packages

Now we can import the packages we just installed. We also import `os`, so that we can set the environment variables needed for the Keras backend. We will use `jax` as `KERAS_BACKEND`.

Because we want to publish the model from the notebook, we also import `kagglehub` and read the Kaggle API credentials from Kaggle Secrets.

In [2]:
import os
os.environ["KERAS_BACKEND"] = "jax"  # you can also use "tensorflow" or "torch"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # avoid memory fragmentation on JAX backend
os.environ["JAX_PLATFORMS"] = ""  # empty: use JAX's default platform selection

# Keras reads KERAS_BACKEND at import time, so set it before importing keras
import keras
import keras_nlp
import kagglehub

# Read Kaggle credentials from Kaggle Secrets and expose them as the
# environment variables that kagglehub uses for authentication
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("KAGGLE_USERNAME")

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas()  # progress bar for pandas

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

## Configurations


We use a `Config` class to group the information needed to control the fine-tuning process:
* random seed
* dataset path
* preset - name of the pretrained Gemma 2 model
* sequence length - maximum size of the input sequence for training
* batch size - size of the input batch in training
* lora rank - rank for LoRA; higher means more trainable parameters
* learning rate - learning rate used in training
* epochs - number of training epochs

In [3]:
class Config:
    seed = 42

    dataset_path = "/kaggle/input/indian-parliamentary-debates-data-2024"  # Use your dataset's Kaggle path
    preset = "gemma2_2b_en" # name of pretrained Gemma 2
    sequence_length = 512 # max size of input sequence for training
    batch_size = 1 # size of the input batch in training
    lora_rank = 4 # rank for LoRA, higher means more trainable parameters
    learning_rate = 8e-5 # learning rate used in training
    epochs = 12 # number of epochs to train


In [4]:
keras.utils.set_random_seed(Config.seed)

# Load the data


We load the data we will use for fine-tuning.

In [5]:
df = pd.read_csv(f"{Config.dataset_path}/PolicyData_2024_mini.csv")
df.sample(8)

| Unnamed: 0 | Question | Context | Answer |
|---|---|---|---|
| 460 | What did the Speaker emphasize about discussio... | The Speaker highlighted the need for discussio... | The Speaker emphasized that discussions should... |
| 73 | How does the government plan to support domest... | The government plans to develop infrastructure... | To support domestic tourism growth, the govern... |
| 231 | Who conveyed congratulations on behalf of the ... | Adv. Francis George from the Kerala Congress P... | Adv. Francis George congratulated Shri Om Birl... |
| 175 | Who is the member from Arani and in which lang... | Shri Tharaniventhan M.S. from Arani took the a... | Shri Tharaniventhan M.S. represents Arani and ... |
| 237 | Who expressed concerns about passing Bills wit... | Adv. Francis George from Kerala Congress expre... | Adv. Francis George voiced concerns about Bill... |
| 425 | What was Shri Rahul Gandhi request related to... | Shri Rahul Gandhi wanted to address student is... | He requested a discussion focused on student i... |
| 155 | Who represents Thane and took the oath in Mara... | Shri Naresh Ganpat Mhaske from Thane took the ... | Shri Naresh Ganpat Mhaske represents Thane and... |
| 55 | How has the National Education Policy 2020 cha... | The National Education Policy 2020 introduces ... | The National Education Policy 2020 has brought... |


For convenience, we define the following prompt template for the QA pairs:

In [6]:
template = "\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"

df["prompt"] = df.apply(lambda row: template.format(Question=row.Question,
                                                    Answer=row.Answer), axis=1)
data = df.prompt.tolist()
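Because the same `template` is used both for training and, with `Answer=""`, for querying, the query prompt is exactly the prefix of a training example up to `Answer:\n`. The short check below illustrates this with placeholder strings (the question and answer are hypothetical, not rows from the dataset).

```python
template = "\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"

# Hypothetical placeholder QA pair
question, answer = "[question text]", "[answer text]"

# Training string: the answer is filled in, so the model learns what follows "Answer:\n"
train_prompt = template.format(Question=question, Answer=answer)
# Query string: the answer slot is empty and the model is asked to continue it
query_prompt = template.format(Question=question, Answer="")

print(repr(query_prompt))
# The training string is the query string followed by the answer
print(train_prompt == query_prompt + answer)
```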

## Template utility function

In [7]:
def colorize_text(text):
    for word, color in zip(["Question", "Answer"], ["red", "green"]):
        text = text.replace(f"\n\n{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

# Specialized class to query Gemma


We define a specialized class to query Gemma. But first, we need to initialize an object of GemmaCausalLM class.

## Initialize the code for Gemma Causal LM

In [9]:
import keras_nlp
print(keras_nlp.__version__)

0.17.0


In [10]:
gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(Config.preset)

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [11]:
gemma_causal_lm.summary()

## Define the specialized class

Here we define the specialized class `GemmaQA`.
In `__init__`, it stores the prompt template and a reference to the `GemmaCausalLM` object created before (`gemma_causal_lm`).
The `query` member function calls the `GemmaCausalLM` member function `generate` to produce the answer from a prompt built with the template, with the question filled in and the answer left empty.

In [12]:
class GemmaQA:
    def __init__(self, max_length=512):
        self.max_length = max_length
        self.prompt = template
        self.gemma_causal_lm = gemma_causal_lm
        
    def query(self, question):
        response = self.gemma_causal_lm.generate(
            self.prompt.format(
                Question=question,
                Answer=""), 
            max_length=self.max_length)
        display(Markdown(colorize_text(response)))
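`generate` returns the prompt together with the generated continuation, which is why the displayed response repeats the question. If only the answer text is needed (for example, to compare it with a reference answer), a small helper such as the hypothetical `extract_answer` below could split it off after the `Answer:` marker. It assumes the response starts with a prompt built from `template`.

```python
template = "\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"
ANSWER_MARKER = "\n\nAnswer:\n"


def extract_answer(response: str) -> str:
    """Return the text after the first 'Answer:' marker, or the whole response if it is missing."""
    _, found, answer = response.partition(ANSWER_MARKER)
    return answer.strip() if found else response.strip()


# Simulated model output: the query prompt followed by a generated answer
response = template.format(Question="[question text]", Answer="") + "[generated answer]"
print(extract_answer(response))
```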

## Gemma preprocessor


This preprocessing layer will take in batches of strings, and return outputs in a ```(x, y, sample_weight)``` format, where the y label is the next token id in the x sequence.

After preprocessing, the token ids and the labels have shape ```(num_samples, sequence_length)```.

In [13]:
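x, y, sample_weight = gemma_causal_lm.preprocessor(data[0:2])
# x is a dict with "token_ids" and "padding_mask"; y holds the next-token labels
print(x["token_ids"].shape, x["padding_mask"].shape, y.shape, sample_weight.shape)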
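The relationship between `x`, `y`, and `sample_weight` can be illustrated with plain NumPy. The sketch below assumes the usual causal-LM convention (labels are the inputs shifted left by one position, and padded positions get zero weight); the token ids are made-up numbers, not real Gemma tokenizer output.

```python
import numpy as np

# One made-up padded sequence of length sequence_length + 1 (0 marks padding)
token_ids = np.array([[2, 11, 12, 13, 14, 0, 0, 0]])
padding_mask = token_ids != 0

x_token_ids = token_ids[:, :-1]  # model input: all but the last position
y = token_ids[:, 1:]             # label at position t is the token at t + 1
sample_weight = padding_mask[:, 1:].astype("float32")  # ignore labels that are padding

print(x_token_ids.shape, y.shape)  # both (num_samples, sequence_length)
print(y[0], sample_weight[0])
```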

# Perform fine-tuning with LoRA

## Enable LoRA for the model

The LoRA rank sets the number of trainable parameters: a larger rank results in more parameters to train.

In [14]:
# Enable LoRA for the model and set the LoRA rank to the lora_rank as set in Config (4).
gemma_causal_lm.backbone.enable_lora(rank=Config.lora_rank)
gemma_causal_lm.summary()
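LoRA adds two factors per adapted weight matrix, of shapes `(d_in, rank)` and `(rank, d_out)`, so the number of trainable parameters grows linearly with the rank. The hypothetical helper below is a back-of-the-envelope calculator; the 2048 × 2048 matrix is an illustrative size, not a dimension read from the Gemma 2 configuration (the `summary()` call above prints the model's actual trainable parameter count).

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the LoRA factors A (d_in x rank) and B (rank x d_out) for one matrix."""
    return d_in * rank + rank * d_out


d_in = d_out = 2048  # illustrative matrix size
full = d_in * d_out
for rank in (2, 4, 8, 16):
    lora = lora_trainable_params(d_in, d_out, rank)
    print(f"rank={rank:>2}: {lora:,} trainable vs {full:,} frozen ({lora / full:.2%})")
```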

## Run the training sequence

We set the preprocessor's `sequence_length` to the value from the configuration (512).
We compile the model with the loss, optimizer, and metric.
As the metric we use `SparseCategoricalAccuracy`, which calculates how often predictions match integer labels.
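The accuracy is computed per token. Because the metric is passed through `weighted_metrics`, Keras weights it with the `sample_weight` produced by the preprocessor, so padded positions do not count. A NumPy sketch of that computation on a made-up batch of logits looks like this (it mirrors the idea, not Keras's internal code):

```python
import numpy as np


def masked_token_accuracy(logits, labels, sample_weight):
    """Weighted fraction of positions where the arg-max token id equals the integer label."""
    predictions = logits.argmax(axis=-1)
    correct = (predictions == labels).astype("float32")
    return float((correct * sample_weight).sum() / sample_weight.sum())


# Made-up batch: 1 sequence, 4 positions, vocabulary of 5 tokens
logits = np.array([[[0.1, 2.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 3.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.5],
                    [0.0, 0.0, 0.0, 0.0, 9.0]]])
labels = np.array([[1, 2, 4, 4]])
sample_weight = np.array([[1.0, 1.0, 1.0, 0.0]])  # last position is padding

acc = masked_token_accuracy(logits, labels, sample_weight)
print(acc)  # 2 correct out of 3 non-padding positions
```

Without the weights the same batch would score 3 out of 4, because the padded last position happens to be "predicted" correctly.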

In [15]:
#set sequence length cf. config (512)
gemma_causal_lm.preprocessor.sequence_length = Config.sequence_length 

# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=Config.learning_rate),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model
gemma_causal_lm.fit(data, epochs=Config.epochs, batch_size=Config.batch_size)

Epoch 1/12
499/499 ━━━━━━━━━━━━━━━━━━━━ 459s 855ms/step - loss: 0.1931 - sparse_categorical_accuracy: 0.6039
Epoch 2/12
499/499 ━━━━━━━━━━━━━━━━━━━━ 427s 819ms/step - loss: 0.1353 - sparse_categorical_accuracy: 0.6834
Epoch 3/12
499/499 ━━━━━━━━━━━━━━━━━━━━ 409s 820ms/step - loss: 0.1244 - sparse_categorical_accuracy: 0.7014
Epoch 4/12
499/499 ━━━━━━━━━━━━━━━━━━━━ 409s 820ms/step - loss: 0.1145 - sparse_categorical_accuracy: 0.7190
Epoch 5/12
499/499 ━━━━━━━━━━━━━━━━━━━━ 409s 820ms/step - loss: 0.1037 - sparse_categorical_accuracy: 0.7405
Epoch 6/12
499/499 ━━━━━━━━━━━━━━━━━━━━ 409s 820ms/step - loss: 0.0925 - sparse_categorical_accuracy: 0.7658
Epoch 7/12
499/499 ━━━━━━━━━━━━━━━━━━━━ 409s 820ms/step - loss: 0.0805 - sparse_categorical_accuracy: 0.7891

<keras.src.callbacks.history.History at 0x7f90fc538550>

# Test the fine-tuned model

We instantiate an object of class GemmaQA. Because `gemma_causal_lm` was fine-tuned using LoRA, `gemma_qa` defined here will use the fine-tuned model.

In [16]:
gemma_qa = GemmaQA()

To start, we test the model with some questions from the training set itself.

In [17]:
row = df.iloc[0]
gemma_qa.query(row.Question)



**<font color='red'>Question:</font>**
What is the goal of the 'Sabka Saath, Sabka Vikas' philosophy?

**<font color='green'>Answer:</font>**
'Sabka Saath, Sabka Vikas' is a philosophy that emphasizes inclusive growth, intending to reach every social and geographical section in India, and foster equal opportunities and prosperity for all segments of society.

In [18]:
row = df.iloc[3]
gemma_qa.query(row.Question)



**<font color='red'>Question:</font>**
What is the PM-KISAN SAMMAN Yojana, and who benefits from it?

**<font color='green'>Answer:</font>**
PM-KISAN SAMMAN Yojana is a program that provides direct financial assistance to 11.8 crore farmers, focusing on supporting small and marginal farmers.

In [19]:
row = df.iloc[105]
gemma_qa.query(row.Question)



**<font color='red'>Question:</font>**
Who took the oath in Manipuri on June 25, 2024?

**<font color='green'>Answer:</font>**
Shri Angomcha Bimol Akoijam took the oath in Manipuri.

## Test the model with unseen question(s)

In [38]:
question = "How is the government promoting inclusive education for the disabled?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
How is the government promoting inclusive education for the disabled?

**<font color='green'>Answer:</font>**
Inclusive education promotes equal access for disabled students, with focus on modern facilities and resources.

In [47]:
question = "How has India promoted renewable energy?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
How has India promoted renewable energy?

**<font color='green'>Answer:</font>**
India promotes renewable energy, aiming for 5 gigawatts of wind energy and 10 gigawatts of solar energy by 2027.

In [48]:
question = "What disaster relief funds are available to Tamil Nadu?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What disaster relief funds are available to Tamil Nadu?

**<font color='green'>Answer:</font>**
The government has released Rs. 1,111 crore from the NDRF for Tamil Nadu.

In [49]:
question = "What is the current disaster relief fund status for Tamil Nadu?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the current disaster relief fund status for Tamil Nadu?

**<font color='green'>Answer:</font>**
The government has released partial aid for Tamil Nadu's relief and restoration efforts.

In [44]:
question = "Who congratulated Shri Om Birla on behalf of the INDIA Alliance?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
Who congratulated Shri Om Birla on behalf of the INDIA Alliance?

**<font color='green'>Answer:</font>**
Shri Rahul Gandhi congratulated Shri Om Birla on behalf of the INDIA Alliance, highlighting the importance of a non-partisan approach.

In [36]:
question = "What is the focus of India  National Solar Mission?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the focus of India  National Solar Mission?

**<font color='green'>Answer:</font>**
India  National Solar Mission aims to expand solar energy adoption across sectors.

In [35]:
question = "How is the Ministry of Health addressing the shortage of healthcare professionals?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
How is the Ministry of Health addressing the shortage of healthcare professionals?

**<font color='green'>Answer:</font>**
The Ministry is addressing the shortage by increasing medical and para-medical education seats.

# Save the model

In [30]:
preset_dir = "./gemma2_2b_en_policylens_model"  # forward slash: Kaggle notebooks run on Linux
gemma_causal_lm.save_to_preset(preset_dir)
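The conclusions state that the model was published to Kaggle Models, but the upload step is not shown here. One way to do it, sketched under assumptions, is `kagglehub.model_upload`, which uploads a local directory as a new model version under a handle of the form owner/model/framework/variation. The model and variation slugs below are hypothetical, and kagglehub must be able to authenticate (for example through `kagglehub.login()` or the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables).

```python
def kaggle_model_handle(owner: str, model: str, variation: str, framework: str = "keras") -> str:
    """Build a Kaggle Models handle: owner/model/framework/variation."""
    return f"{owner}/{model}/{framework}/{variation}"


def publish_preset(preset_dir: str, handle: str, version_notes: str = "") -> None:
    """Upload a saved Keras preset directory as a new Kaggle Model version."""
    import kagglehub  # imported here so the handle helper works without kagglehub installed

    kagglehub.model_upload(handle, preset_dir, version_notes=version_notes)


# Example usage in the notebook (hypothetical slugs; requires Kaggle credentials):
# handle = kaggle_model_handle("your-username", "policylens-india", "gemma2_2b_en_policylens")
# publish_preset("./gemma2_2b_en_policylens_model", handle, version_notes="LoRA rank 4")
```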

# Conclusions



> - Fine-tuning of a **Gemma 2** model using LoRA has been demonstrated.
> - A class was also created to query the **Gemma 2** model; it was tested with examples from the training data as well as with new, unseen questions.
> - The model was saved as a Keras model preset.
> - The model was evaluated using perplexity, recording a perplexity of 2.502.
> - Finally, the model was published as a Kaggle Model on the Kaggle Models platform.
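The perplexity evaluation mentioned above is not shown in the notebook. For reference, perplexity is the exponential of the mean next-token negative log-likelihood over non-padding positions; KerasNLP also provides a `keras_nlp.metrics.Perplexity` metric. The minimal NumPy sketch below illustrates the definition only (it is not the code used to obtain 2.502). With uniform logits over a vocabulary of size V the perplexity is exactly V, which makes a handy sanity check.

```python
import numpy as np


def masked_perplexity(logits, labels, sample_weight):
    """exp(mean negative log-likelihood of the labels), ignoring zero-weight positions."""
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability the model assigns to each true next token
    label_log_probs = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    mean_nll = -(label_log_probs * sample_weight).sum() / sample_weight.sum()
    return float(np.exp(mean_nll))


# Sanity check: uniform predictions over a 5-token vocabulary give perplexity 5
uniform_logits = np.zeros((1, 3, 5))
labels = np.array([[0, 3, 1]])
weights = np.ones((1, 3))
print(masked_perplexity(uniform_logits, labels, weights))
```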