<a href="https://www.kaggle.com/code/yahayamkayode/fine-tuning-gemma2b-model-using-lora-and-keras?scriptVersionId=208046475" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

<center><h1>Fine-tuning Gemma 2 Model Using LoRA and Keras with Custom Dataset</h1></center>

<center><img src="https://res.infoq.com/news/2024/02/google-gemma-open-model/en/headerimage/generatedHeaderImage-1708977571481.jpg" width="400"></center>


# Introduction

> 1. How to fine-tune a Gemma model using LoRA with a custom dataset (the Maize dataset 🌽)
> 2. Creating a specialised class to query the model about maize production
> 3. Sample results from querying the model about best practices for maize production

#### The following resources were used in implementing this project

> 1. Gemma 2 Model Card, Kaggle Models, https://www.kaggle.com/models/google/gemma-2/
> 2. Kaggle QA with Gemma - KerasNLP Starter, Kaggle Code, https://www.kaggle.com/code/awsaf49/kaggle-qa-with-gemma-kerasnlp-starter (Version 11)  
> 3. Fine-tune Gemma models in Keras using LoRA, Kaggle Code, https://www.kaggle.com/code/nilaychauhan/fine-tune-gemma-models-in-keras-using-lora (Version 1) 
> 4. Unlock the Power of Gemma 2: Prompt it like a Pro, https://www.kaggle.com/code/gpreda/unlock-the-power-of-gemma-2-prompt-it-like-a-pro  
> 5. Fine-tune Gemma using LoRA and Keras, https://www.kaggle.com/code/gpreda/fine-tune-gemma-using-lora-and-keras



> **Let's go**🕺🕺🕺


# What is Gemma 2?

> Gemma is a collection of lightweight, advanced open models developed by Google, leveraging the same research and technology behind the Gemini models. These models are text-to-text, decoder-only large language models available in English, with open weights provided for both pre-trained and instruction-tuned versions. Gemma models excel in a range of text generation tasks, such as question answering, summarization, and reasoning. Their compact size allows for deployment in resource-constrained environments like laptops, desktops, or personal cloud infrastructure, making state-of-the-art AI models more accessible and encouraging innovation for all. 

> Gemma 2 represents the second generation of Gemma models. These models were trained on text data drawn from a wide variety of sources. The **27B** model was trained with **13 trillion** tokens, the **9B** model with **8 trillion** tokens, and the **2B** model with **2 trillion** tokens.

> To learn more about Gemma 2, follow this link: [Gemma 2 Model Card](https://www.kaggle.com/models/google/gemma-2).


# What is LoRA?  

> **LoRA** stands for **Low-Rank Adaptation**. It is a method for fine-tuning large language models (LLMs) that freezes the pretrained weights and injects trainable rank-decomposition matrices into the model's layers. This considerably reduces the number of trainable parameters during fine-tuning. According to the **LoRA** paper, compared with fully fine-tuning GPT-3 175B, LoRA can reduce the number of trainable parameters by **10,000 times** and the GPU memory requirement by 3 times.
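
To make the idea concrete, here is a minimal NumPy sketch of a single LoRA-adapted linear layer. It is not how Keras implements LoRA internally; the layer sizes, the variable names (`W`, `A`, `B`) and the `lora_forward` helper are illustrative assumptions. The frozen weight `W` is left untouched, and only the two small matrices `A` and `B` would be trained.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, rank = 8, 6, 2  # toy sizes, chosen only for illustration

W = rng.normal(size=(d_out, d_in))         # frozen pretrained weight
A = rng.normal(size=(rank, d_in)) * 0.01   # trainable, small random init
B = np.zeros((d_out, rank))                # trainable, zero init

def lora_forward(x):
    """Frozen path plus the low-rank update B @ A applied to x."""
    return W @ x + B @ (A @ x)

x = rng.normal(size=d_in)

# Because B starts at zero, the adapted layer initially behaves
# exactly like the frozen layer.
print(np.allclose(lora_forward(x), W @ x))
```

Starting `B` at zero follows the initialization described in the LoRA paper: fine-tuning begins from the pretrained behaviour and only gradually moves away from it.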

# How do we proceed?

> To fine-tune with LoRA, we will follow these steps:

> 1. Install prerequisites
> 2. Load and process the maize data for fine-tuning
> 3. Initialize the code for the Gemma causal language model (Gemma Causal LM)
> 4. Perform fine-tuning
> 5. Test the fine-tuned model with questions from the data used for fine-tuning and with additional questions

# Prerequisites


## Install packages

We start by installing the `keras-nlp` and `keras` packages, together with JAX and `kagglehub`.

In [None]:
# Install dependencies
!pip install -q -U wurlitzer
!pip install -q -U keras-nlp
# Pin Keras 3.x (needed for the JAX backend) after keras-nlp so the pin is kept
!pip install -q -U keras==3.5.0
!pip install -q -U kagglehub
!pip install -q jax jaxlib


## Import packages

Now we can import the packages we just installed. We also import `os`, so that we can set the environment variables needed for the Keras backend. We will use `jax` as the `KERAS_BACKEND`.

Because we want to publish the model from the notebook, we also import `kagglehub` and read the Kaggle credentials from Kaggle secrets.

In [5]:
import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.
os.environ["JAX_PLATFORMS"] = ""
import keras
import keras_nlp
import kagglehub


from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("kaggle_username")
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("kaggle_key")

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas() # progress bar for pandas

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

## Configurations


We use a `Config` class to group the information needed to control the fine-tuning process:
* random seed
* dataset path
* preset - name of the pretrained Gemma 2 model
* sequence length - the maximum size of the input sequence for training
* batch size - size of the input batch in training
* lora rank - rank for LoRA; a higher rank means more trainable parameters
* learning rate - learning rate used in training
* epochs - number of epochs to train

In [6]:
class Config:
    seed = 42

    dataset_path = "/kaggle/input/dataset-maize"  # Use your dataset's Kaggle path
    preset = "gemma2_2b_en" # name of pretrained Gemma 2
    sequence_length = 512 # max size of input sequence for training
    batch_size = 1 # size of the input batch in training
    lora_rank = 4 # rank for LoRA, higher means more trainable parameters
    learning_rate=8e-5 # learning rate used in train
    epochs = 12 # number of epochs to train


In [7]:
keras.utils.set_random_seed(Config.seed)

# Load the data


We load the data we will use for fine-tuning.

In [8]:
df = pd.read_csv(f"{Config.dataset_path}/maize-dataset.csv")
df.sample(8)

| | Question | Answer | Intent | Entities | Category |
|---|---|---|---|---|---|
| 73 | What type of fertilizer should I use for maize... | In Lagos, you should use nitrogen-based fertil... | Fertilizer Use | Lagos, Forest, Maize | Fertilizer Application |
| 18 | What is the recommended planting spacing for m... | The recommended planting spacing for maize is ... | Planting Advice | Maize, Spacing, 75cm, 25cm | Maize Cultivation |
| 118 | What are the advantages of early planting in m... | Early planting helps maize avoid late-season d... | Planting Benefits | maize, early planting, drought, pest infestations | Agronomy |
| 78 | What type of fertilizer should I use for maize... | In Osun, you should use nitrogen-based fertili... | Fertilizer Use | Osun, Forest, Maize | Fertilizer Application |
| 76 | What is the ideal watering schedule for maize ... | In the Forest zone, maize requires watering ev... | Watering Schedule | Lagos, Forest, Maize | Water Management |
| 31 | What is the ideal watering schedule for maize ... | In the Sudan Savanna zone, maize requires wate... | Watering Schedule | Kano, Sudan Savanna, Maize | Water Management |
| 64 | How do I manage pests in Southern Guinea Savan... | In the Southern Guinea Savanna zone, regular m... | Pest Control | Benue, Southern Guinea Savanna, Maize | Pest Control |
| 141 | What are the signs of nitrogen deficiency in m... | Nitrogen deficiency in maize is indicated by y... | Deficiency Symptoms | maize, nitrogen deficiency | Nutrient Management |


For convenience, we create the following template for question answering:

In [9]:
template = "\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"

df["prompt"] = df.apply(lambda row: template.format(Question=row.Question,
                                                    Answer=row.Answer), axis=1)
data = df.prompt.tolist()
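
As a small illustration of how this template is used, the sketch below formats one question-answer pair (taken from the test outputs later in this notebook) in two ways: as a full training example, and as an inference prompt with an empty answer. The variable names `training_example` and `inference_prompt` are only for illustration.

```python
template = "\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"

question = "What is the best soil for maize?"
answer = ("Maize grows best in well-drained loamy or sandy soil "
          "with high organic matter content.")

# Training text: the question followed by its reference answer.
training_example = template.format(Question=question, Answer=answer)

# Inference prompt: same layout, but the answer is left empty so that
# the model continues the text right after "Answer:".
inference_prompt = template.format(Question=question, Answer="")

print(training_example)
print(repr(inference_prompt))
```

Because the inference prompt is an exact prefix of the training text, the fine-tuned model sees the same layout at query time as it did during training.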

## Template utility function

In [10]:
def colorize_text(text):
    for word, color in zip(["Question", "Answer"], ["red", "green"]):
        text = text.replace(f"\n\n{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

# Specialized class to query Gemma


We define a specialized class to query Gemma. But first, we need to initialize an object of the `GemmaCausalLM` class.

## Initialize the code for Gemma Causal LM

In [12]:
import keras_nlp
print(keras_nlp.__version__)

0.17.0


In [13]:
gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(Config.preset)

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [14]:
gemma_causal_lm.summary()

## Define the specialized class

Here we define the specialized class `GemmaQA`.
In `__init__`, we store a reference to the `GemmaCausalLM` object created before, along with the prompt template.
The `query` member function uses the `generate` member function of `GemmaCausalLM` to generate the answer, based on a prompt built from the question.

In [15]:
class GemmaQA:
    def __init__(self, max_length=512):
        self.max_length = max_length
        self.prompt = template
        self.gemma_causal_lm = gemma_causal_lm
        
    def query(self, question):
        response = self.gemma_causal_lm.generate(
            self.prompt.format(
                Question=question,
                Answer=""), 
            max_length=self.max_length)
        display(Markdown(colorize_text(response)))

## Gemma preprocessor


This preprocessing layer takes in batches of strings and returns outputs in a ```(x, y, sample_weight)``` format, where the `y` label is the next token id in the `x` sequence.

Printing the shapes in the code below shows that, after the preprocessor, each array has shape ```(num_samples, sequence_length)```.

In [16]:
x, y, sample_weight = gemma_causal_lm.preprocessor(data[0:2])
# x is a dict holding "token_ids" and "padding_mask"
print(x["token_ids"].shape, y.shape, sample_weight.shape)
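
To see what "the label is the next token id" means, here is a conceptual pure-Python sketch of how one tokenized sequence can be turned into an `(x, y, sample_weight)` triple. It is not the KerasNLP implementation; the helper `make_causal_lm_example`, the token ids, and the padding id `0` are made up for illustration.

```python
def make_causal_lm_example(token_ids, sequence_length, pad_id=0):
    """Pad/truncate to sequence_length + 1 tokens, then split into shifted inputs and labels."""
    ids = list(token_ids[: sequence_length + 1])
    n_real = len(ids)
    ids += [pad_id] * (sequence_length + 1 - n_real)

    inputs = ids[:-1]   # tokens the model reads
    labels = ids[1:]    # the token that follows each input position
    # Weight 1 where the label is a real token, 0 where it is padding.
    weights = [1 if i + 1 < n_real else 0 for i in range(sequence_length)]
    return inputs, labels, weights

x_ids, y_ids, weights = make_causal_lm_example([2, 15, 27, 9, 1], sequence_length=6)
print(x_ids)    # [2, 15, 27, 9, 1, 0]
print(y_ids)    # [15, 27, 9, 1, 0, 0]
print(weights)  # [1, 1, 1, 1, 0, 0]
```

The zero weights on padded positions are what keep padding from contributing to the loss and to weighted metrics during training.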

# Perform fine-tuning with LoRA

## Enable LoRA for the model

The LoRA rank determines the number of trainable parameters: a larger rank results in more parameters to train.

In [17]:
# Enable LoRA for the model and set the LoRA rank to the lora_rank as set in Config (4).
gemma_causal_lm.backbone.enable_lora(rank=Config.lora_rank)
gemma_causal_lm.summary()
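
The relationship between rank and trainable parameters can be checked with simple arithmetic. For a weight matrix of shape `d_out x d_in`, LoRA adds two matrices of shapes `rank x d_in` and `d_out x rank`. The sketch below uses a hypothetical 2048 x 2048 layer; this size is an assumption for illustration and is not read from the Gemma model, and `lora_param_count` is an illustrative helper rather than a Keras function.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix."""
    return rank * (d_in + d_out)

d_in = d_out = 2048  # hypothetical layer size, not taken from Gemma
frozen = d_in * d_out

for rank in (4, 8, 16):
    added = lora_param_count(d_in, d_out, rank)
    print(f"rank={rank:>2}: {added:,} trainable vs {frozen:,} frozen "
          f"({frozen // added}x fewer)")
```

The count grows linearly with the rank, so doubling the rank doubles the number of parameters to train for each adapted layer.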

## Run the training sequence

We set the `sequence_length` of the `GemmaCausalLM` preprocessor (from the configuration, 512).
We compile the model with the loss, optimizer, and metric.
The metric is `SparseCategoricalAccuracy`, which calculates how often predictions match integer labels.
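
Since the metric is passed through `weighted_metrics` in the compile call, positions with a sample weight of zero (such as padding) do not count towards the accuracy. The following pure-Python sketch shows the idea on made-up logits; `weighted_sparse_accuracy` is an illustrative helper, not the Keras implementation.

```python
def weighted_sparse_accuracy(y_true, logits, sample_weight):
    """Weighted fraction of positions where argmax(logits) equals the integer label."""
    correct = 0.0
    total = 0.0
    for label, row, weight in zip(y_true, logits, sample_weight):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax
        correct += weight * (predicted == label)
        total += weight
    return correct / total if total else 0.0

y_true = [2, 0, 1, 0]
logits = [
    [0.1, 0.2, 3.0],  # predicts 2, correct
    [2.0, 0.1, 0.3],  # predicts 0, correct
    [0.5, 0.4, 0.1],  # predicts 0, wrong (label is 1)
    [9.0, 0.0, 0.0],  # padding position, ignored via weight 0
]
sample_weight = [1, 1, 1, 0]

accuracy = weighted_sparse_accuracy(y_true, logits, sample_weight)
print(round(accuracy, 4))  # 0.6667
```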

In [18]:
#set sequence length cf. config (512)
gemma_causal_lm.preprocessor.sequence_length = Config.sequence_length 

# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=Config.learning_rate),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model
gemma_causal_lm.fit(data, epochs=Config.epochs, batch_size=Config.batch_size)

Epoch 1/12
150/150 ━━━━━━━━━━━━━━━━━━━━ 175s 947ms/step - loss: 0.1787 - sparse_categorical_accuracy: 0.5858
Epoch 2/12
150/150 ━━━━━━━━━━━━━━━━━━━━ 142s 820ms/step - loss: 0.0803 - sparse_categorical_accuracy: 0.7630
Epoch 3/12
150/150 ━━━━━━━━━━━━━━━━━━━━ 123s 820ms/step - loss: 0.0621 - sparse_categorical_accuracy: 0.8112
Epoch 4/12
150/150 ━━━━━━━━━━━━━━━━━━━━ 123s 820ms/step - loss: 0.0557 - sparse_categorical_accuracy: 0.8267
Epoch 5/12
150/150 ━━━━━━━━━━━━━━━━━━━━ 123s 820ms/step - loss: 0.0496 - sparse_categorical_accuracy: 0.8390
Epoch 6/12
150/150 ━━━━━━━━━━━━━━━━━━━━ 123s 820ms/step - loss: 0.0442 - sparse_categorical_accuracy: 0.8557
Epoch 7/12
150/150 ━━━━━━━━━━━━━━━━━━━━ 123s 820ms/step - loss: 0.0393 - sparse_categorical_accuracy: 0.8669

<keras.src.callbacks.history.History at 0x79d1b40df6d0>

# Test the fine-tuned model

We instantiate an object of class GemmaQA. Because `gemma_causal_lm` was fine-tuned using LoRA, `gemma_qa` defined here will use the fine-tuned model.

In [19]:
gemma_qa = GemmaQA()

To start, we test the model with some of the data from the training set itself.

In [20]:
row = df.iloc[0]
gemma_qa.query(row.Question)



**<font color='red'>Question:</font>**
What is the best soil for maize?

**<font color='green'>Answer:</font>**
Maize grows best in well-drained loamy or sandy soil with high organic matter content.

In [21]:
row = df.iloc[3]
gemma_qa.query(row.Question)



**<font color='red'>Question:</font>**
How can I control weeds in my maize farm?

**<font color='green'>Answer:</font>**
To control weeds in maize, use herbicides like Atrazine or Halosulfuron at 2-leaf or 2-3 leaf stage. Regular weeding is also important.

In [22]:
row = df.iloc[105]
gemma_qa.query(row.Question)



**<font color='red'>Question:</font>**
What is the ideal pH for maize cultivation?

**<font color='green'>Answer:</font>**
The ideal pH for maize cultivation is between 5.5 and 7.0. Adjust soil acidity with lime if necessary.

## Test the model with unseen question(s)

In [23]:
question = "What is the best time to plant Maize?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the best time to plant Maize?

**<font color='green'>Answer:</font>**
The best time to plant Maize is in the early rainy season, about 2-3 weeks before the first rainfall. This allows the crop to take advantage of the moisture provided by the early rains.

In [24]:
question = "How many seed of maize should I plant per hole?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
How many seed of maize should I plant per hole?

**<font color='green'>Answer:</font>**
Plant about 2-3 seeds per hole, spacing them 50-60 cm apart for good plant spacing.

In [25]:
question = "What is the ideal spacing for maize planting?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the ideal spacing for maize planting?

**<font color='green'>Answer:</font>**
For optimal spacing, plant maize at 75 cm between rows and 25 cm between plants.

In [26]:
question = "What is the right time to harvest maize?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the right time to harvest maize?

**<font color='green'>Answer:</font>**
Maize can be harvested when the silks turn brown, the husks turn yellow, or when the grains are hard and dry.

In [27]:
question = "What is the best way to store maize after harvest?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the best way to store maize after harvest?

**<font color='green'>Answer:</font>**
After harvesting maize, it is important to dry the grains quickly to prevent mold. Spread the dried maize in a well-ventilated area or use a maize dryer to ensure proper storage. Store the maize in airtight containers or in a cool, dry place, such as a ventilated granary or a dry room, to protect it from pests and rodents.

In [29]:
question = "How do I manage Fall Army Worm on maize crop?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
How do I manage Fall Army Worm on maize crop?

**<font color='green'>Answer:</font>**
Fall Army Worm can be managed through early detection, using pheromone traps, and applying insecticides like Spinosad or

<h2>Related Posts</h2>

* How to Manage Fall Army Worm in Maize
* How to Use Pheromone Traps for Fall Army Worm in Maize
* What are the Recommended Insecticides for Fall Army Worm in Maize?

In [30]:
question = "How do I get rid-off yellow maize leave?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
How do I get rid-off yellow maize leave?

**<font color='green'>Answer:</font>**
Yellow maize leaves can be a sign of nutrient deficiency, particularly iron or magnesium. To address this, apply a foliar spray of iron sulfate or use a magnesium-based fertilizer.

In [31]:
question = "How can I identify bad maize seed not to plant?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
How can I identify bad maize seed not to plant?

**<font color='green'>Answer:</font>**
Bad maize seeds can be identified by wilting leaves, reduced vigor, or by physical defects such as mold or insect damage.

In [32]:
question = "What is the best maize variety in Abuja for planting"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the best maize variety in Abuja for planting

**<font color='green'>Answer:</font>**
In Abuja, SAMMAZ 14, SAMMAZ 41, and SAMMAZ 10 are the recommended maize varieties for planting.

In [33]:
question = "What is the best maize variety in Kaduna state?"
gemma_qa.query(question)



**<font color='red'>Question:</font>**
What is the best maize variety in Kaduna state?

**<font color='green'>Answer:</font>**
In Kaduna, SAMMAZ 14, SAMMAZ 21, and SAMMAZ 23 are the recommended maize varieties.

# Save the model

In [35]:
# Raw string: "\g" is not a valid escape sequence in a normal string literal
preset_dir = r".\gemma2_2b_en_maize_model"
gemma_causal_lm.save_to_preset(preset_dir)

# Publish the model on Kaggle as a Kaggle Model

We now publish the saved model as a Kaggle Model.

In [37]:
kaggle_username = os.environ["KAGGLE_USERNAME"]
print(kaggle_username)

yahayamkayode


In [38]:
kaggle_username = os.environ["KAGGLE_USERNAME"]

kaggle_uri = f"kaggle://{kaggle_username}/maize-dataset/keras/gemma2_2b_en_maize_model"

# Proceed with the upload
keras_nlp.upload_preset(kaggle_uri, preset_dir)

Uploading Model https://www.kaggle.com/models/yahayamkayode/maize-dataset/keras/gemma2_2b_en_maize_model ...
Model 'maize-dataset' does not exist or access is forbidden for user 'yahayamkayode'. Creating or handling Model...
Model 'maize-dataset' Created.
Starting upload for file .\gemma2_2b_en_maize_model/task.json


Uploading: 100%|██████████| 2.98k/2.98k [00:00<00:00, 14.5kB/s]

Upload successful: .\gemma2_2b_en_maize_model/task.json (3KB)
Starting upload for file .\gemma2_2b_en_maize_model/preprocessor.json



Uploading: 100%|██████████| 1.41k/1.41k [00:00<00:00, 8.40kB/s]

Upload successful: .\gemma2_2b_en_maize_model/preprocessor.json (1KB)
Starting upload for file .\gemma2_2b_en_maize_model/model.weights.h5



Uploading: 100%|██████████| 10.5G/10.5G [01:59<00:00, 87.5MB/s]

Upload successful: .\gemma2_2b_en_maize_model/model.weights.h5 (10GB)
Starting upload for file .\gemma2_2b_en_maize_model/config.json



Uploading: 100%|██████████| 782/782 [00:00<00:00, 2.66kB/s]

Upload successful: .\gemma2_2b_en_maize_model/config.json (782B)
Starting upload for file .\gemma2_2b_en_maize_model/metadata.json



Uploading: 100%|██████████| 143/143 [00:00<00:00, 812B/s]

Upload successful: .\gemma2_2b_en_maize_model/metadata.json (143B)
Starting upload for file .\gemma2_2b_en_maize_model/tokenizer.json



Uploading: 100%|██████████| 591/591 [00:00<00:00, 3.39kB/s]

Upload successful: .\gemma2_2b_en_maize_model/tokenizer.json (591B)
Starting upload for file .\gemma2_2b_en_maize_model/assets/tokenizer/vocabulary.spm



Uploading: 100%|██████████| 4.24M/4.24M [00:00<00:00, 12.7MB/s]

Upload successful: .\gemma2_2b_en_maize_model/assets/tokenizer/vocabulary.spm (4MB)





Your model instance has been created.
Files are being processed...
See at: https://www.kaggle.com/models/yahayamkayode/maize-dataset/keras/gemma2_2b_en_maize_model


## Upload the model to Hugging Face

In [None]:
# Upload the preset to Hugging Face Hub
hf_uri = "hf://Justsp3cial/maize_model"
keras_nlp.upload_preset(hf_uri, '.\\gemma2_2b_en_maize_model')

# Conclusions



> - Fine-tuning of a **Gemma 2** model using LoRA has been demonstrated.
> - A class was also created to run queries against the **Gemma 2** model, and it was tested with some examples from the training data as well as with new, unseen questions.
> - The model was saved as a Keras model preset.
> - The model was evaluated using perplexity, with a recorded perplexity value of 2.601.
> - Finally, the model was published as a Kaggle Model on the Kaggle Models platform.
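
The notebook does not show how the perplexity value above was computed. As a hedged reminder of the standard definition, perplexity is the exponential of the mean negative log-likelihood per token. The sketch below uses made-up token probabilities, and the `perplexity` helper is illustrative rather than a Keras function.

```python
import math

def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood over the given tokens."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

# A model that assigns probability 0.25 to every correct token is as
# uncertain as a uniform choice among 4 options.
uniform_log_probs = [math.log(0.25)] * 8
ppl = perplexity(uniform_log_probs)
print(round(ppl, 4))  # 4.0

# Conversely, a perplexity of 2.601 corresponds to this mean per-token loss:
mean_loss = math.log(2.601)
print(round(mean_loss, 3))
```

Under this definition, a perplexity of 2.601 corresponds to a mean per-token cross-entropy of roughly 0.96 nats.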