##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## **[goo.gle/ai-kaggle-keras-gemma](https://goo.gle/ai-kaggle-keras-gemma)**

## Introduction

This tutorial demonstrates how to fine-tune Gemma on a Kaggle dataset and share your model with the community. We'll be using a [Medical Q&A Dataset](https://www.kaggle.com/datasets/jpmiller/layoutlm/data) from Kaggle and fine-tuning Gemma to answer questions about complex conditions.

**Please note that this tutorial is purely for educational purposes and should not be used for medical consultation.**

## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on [kaggle.com](https://kaggle.com).
* Select a Colab runtime with sufficient resources to run
  the Gemma 2B model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### Select the runtime

To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU or an A100 GPU (recommended, if available):

1. In the upper-right corner of the Colab window, select ▾ (**Additional connection options**).
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU** or **A100 GPU**.

### Configure your API key

To use Gemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.

In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.

### Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`.

In [None]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
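
Outside Colab, `userdata.get` is not available. One option is to read the `kaggle.json` file downloaded earlier, which stores the credentials under the `username` and `key` fields. The following is a minimal sketch; the helper name `load_kaggle_credentials` is an assumption made for illustration and is not part of any Kaggle or Keras API. The default path is the location the Kaggle CLI conventionally uses.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json"):
    """Copy the username and key from a kaggle.json file into env vars.

    Hypothetical helper: the name is illustrative, not a library function.
    """
    creds = json.loads(Path(path).expanduser().read_text(encoding="utf-8"))
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
    return creds["username"]
```

Call `load_kaggle_credentials("path/to/kaggle.json")` before downloading the model so that the Kaggle client can pick up both variables.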

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U tf-keras
!pip install -q -U keras-nlp==0.10.0
# Quote version specifiers so the shell does not treat `>` as a redirect.
!pip install -q -U "kagglehub>=0.2.4"
!pip install -q -U "keras>=3"


### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [None]:
# Set the backend before importing Keras; it is read at import time.
os.environ["KERAS_BACKEND"] = "jax"
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

### Import packages

Import Keras, KerasNLP, and the `csv` package.

In [None]:
import keras_nlp
import keras
import csv

print("KerasNLP version: ", keras_nlp.__version__)
print("Keras version: ", keras.__version__)

KerasNLP version:  0.10.0
Keras version:  3.5.0


## Load Model

Let's download the 2B variant of Gemma from Kaggle. You can see the model page [here](https://www.kaggle.com/models/keras/gemma/keras/gemma_2b_en).

In [None]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/metadata.json...


100%|██████████| 143/143 [00:00<00:00, 118kB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/config.json...


100%|██████████| 555/555 [00:00<00:00, 937kB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/model.weights.h5...


100%|██████████| 4.67G/4.67G [04:33<00:00, 18.3MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/tokenizer.json...


100%|██████████| 401/401 [00:00<00:00, 132kB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/assets/tokenizer/vocabulary.spm...


100%|██████████| 4.04M/4.04M [00:01<00:00, 2.62MB/s]


In [None]:
gemma_lm.summary()

## Load Dataset

Let's download a [Medical Question Answering Dataset](https://www.kaggle.com/datasets/jpmiller/layoutlm/data) from Kaggle for this fine-tuning example.

In [None]:
!kaggle datasets download -d jpmiller/layoutlm -f medquad.csv

Dataset URL: https://www.kaggle.com/datasets/jpmiller/layoutlm
License(s): CC-BY-SA-4.0
Downloading medquad.csv.zip to /content
 61% 3.00M/4.95M [00:00<00:00, 5.72MB/s]
100% 4.95M/4.95M [00:00<00:00, 7.26MB/s]


In [None]:
!unzip medquad.csv.zip

Archive:  medquad.csv.zip
  inflating: medquad.csv             


After unzipping `medquad.csv`, format each row of the CSV into a question-and-answer example. These examples make up the dataset the model will be fine-tuned on.

In [None]:
data = []

# Template that turns each CSV row into one question-answer training example.
# It is defined once here and reused later to build prompts for inference.
template = "Question:\n{question}\n\nAnswer:\n{answer}"

# The CSV file includes 'question' and 'answer' columns.
with open("medquad.csv", mode='r', encoding='utf-8', newline='') as file:
    reader = csv.DictReader(file)
    for row in reader:
        data.append(template.format(**row))

Let's take a look at an example to make sure the data has been formatted correctly with the Question-Answer template:

In [None]:
print(data[3])

Question:
What are the treatments for Glaucoma ?

Answer:
Although open-angle glaucoma cannot be cured, it can usually be controlled. While treatments may save remaining vision, they do not improve sight already lost from glaucoma. The most common treatments for glaucoma are medication and surgery. Medications  Medications for glaucoma may be either in the form of eye drops or pills. Some drugs reduce pressure by slowing the flow of fluid into the eye. Others help to improve fluid drainage. (Watch the video to learn more about coping with glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard.) For most people with glaucoma, regular use of medications will control the increased fluid pressure. But, these drugs may stop working over time. Or, they may cause side effects. If a problem occurs, the eye care professional may select other drugs, change the dose, or suggest other ways to deal with 
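
To see what the template does without downloading the dataset, the same formatting step can be run on a small in-memory CSV. This is a sketch with a made-up placeholder row that is not taken from `medquad.csv`; the extra `source` column is an assumption included only to show that `str.format(**row)` ignores keys the template does not use.

```python
import csv
import io

template = "Question:\n{question}\n\nAnswer:\n{answer}"

# A made-up placeholder row for illustration; it is not real dataset content.
sample_csv = io.StringIO(
    "question,answer,source\n"
    "What is X?,X is a placeholder answer.,example\n"
)

# Each row becomes one formatted training string.
examples = [template.format(**row) for row in csv.DictReader(sample_csv)]
print(examples[0])
```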

### Inference before fine-tuning

The original Gemma model has a lot of general knowledge, but fine-tuning can help improve domain-specific knowledge.

To test the pre-trained model on more specific medical knowledge, let's pick a more complex disease: **Chronic Eosinophilic Leukemia**.

Let's prompt Gemma by asking about treatments for that disease, making sure to format our prompt using the Question-Answer template we previously defined.

In [None]:
prompt = template.format(
    question="What are the treatments for Chronic Eosinophilic Leukemia?",
    answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
What are the treatments for Chronic Eosinophilic Leukemia?

Answer:
Chronic Eosinophilic Leukemia (CEL) is a rare type of blood cancer that affects the bone marrow. It is a type of leukemia that is caused by an abnormal increase in the number of white blood cells called eosinophils. Eosinophils are a type of white blood cell that are involved in the immune system. They are responsible for fighting infections and allergies.

The exact cause of Chronic Eosinophilic Leukemia is not known. However, it is thought to be caused by a combination of genetic and environmental factors. Some risk factors for


As you can see, the resulting answer from Gemma simply defines the disease, breaking down the definition of leukemia and eosinophils. However, it isn't able to answer the question on treatments!

This is where fine-tuning on our medical dataset can help.

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using our Medical Question-Answer dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, or 16), which is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Then gradually increase the rank in subsequent trials to see whether that further boosts performance.
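
To make the rank trade-off concrete: for a single weight matrix of shape `d_in × d_out`, LoRA keeps the original `d_in * d_out` weights frozen and trains two small matrices of shapes `d_in × rank` and `rank × d_out`. The sketch below counts those parameters. The 2048 × 2048 shape is an assumed example chosen for illustration, not a claim about which Gemma layers LoRA is applied to, and `lora_trainable_params` is a hypothetical helper.

```python
def lora_trainable_params(d_in, d_out, rank):
    """Parameters in the two low-rank factors A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


d_in, d_out = 2048, 2048  # assumed shape, for illustration only
full = d_in * d_out  # size of the frozen original matrix
for rank in (4, 8, 16):
    added = lora_trainable_params(d_in, d_out, rank)
    print(f"rank={rank:>2}: {added:,} trainable vs {full:,} frozen ({added / full:.2%})")
```

Doubling the rank doubles the number of trainable parameters for each adapted matrix, which is why starting small keeps experiments cheap.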

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA significantly reduces the number of trainable parameters, because only the small low-rank matrices are trained while the original weights stay frozen.

In [None]:
# Fine-tune on the Medical QA dataset.

# Limit the input sequence length to 128 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

 1493/16412 ━━━━━━━━━━━━━━━━━━━━ 1:53:14 455ms/step - loss: 1.5942 - sparse_categorical_accuracy: 0.5979
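
The progress line above also shows why a full epoch is slow: with `batch_size=1`, one epoch takes one step per example. A rough estimate of the remaining time can be worked out from the numbers in that log line. This is back-of-the-envelope arithmetic rather than part of the Keras API, and both helper names are hypothetical.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of optimizer steps needed to see every example once."""
    return math.ceil(num_examples / batch_size)


def remaining_seconds(total_steps, done_steps, seconds_per_step):
    """Estimated time left, assuming a constant step duration."""
    return (total_steps - done_steps) * seconds_per_step


# Values taken from the progress line: 1493/16412 steps at 455ms/step.
total = steps_per_epoch(16412, batch_size=1)
eta = remaining_seconds(total, done_steps=1493, seconds_per_step=0.455)
print(f"{total} steps, about {eta / 3600:.1f} hours remaining")
```

To shorten experiments, one option is to train on a subset such as `data[:1000]`, at the cost of the model seeing fewer examples.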

### Inference after fine-tuning

After fine-tuning the model, let's try the same prompt again and ask about treatments for the disease.

In [None]:
prompt = template.format(
    question="What are the treatments for Chronic Eosinophilic Leukemia?",
    answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
What are the treatments for Chronic Eosinophilic Leukemia?

Answer:
The treatment for chronic eosinophilic leukemia depends on the type of leukemia and the severity of the disease.
                
Treatment may include
                
- chemotherapy  - radiation therapy  - bone marrow transplant  - targeted therapy  - immunotherapy  - supportive care
                
Chemotherapy
                
Chemotherapy is the use of drugs to kill cancer cells. Chemotherapy for chronic eosinophilic leukemia may include
                
- anthracyclines  - vincristine  - cyclophosphamide  - prednisone  - rituximab


The response is much more helpful than before fine-tuning, readily listing potential treatment options for Chronic Eosinophilic Leukemia.
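
As both outputs show, `generate` returns the prompt followed by the completion. If only the answer text is needed, it can be split off at the template's `Answer:` marker. This is a minimal sketch that assumes the output follows the Question-Answer template; `extract_answer` is a hypothetical helper, not a KerasNLP function, and the sample string is a placeholder.

```python
def extract_answer(generated, marker="Answer:\n"):
    """Return the text after the first answer marker, or the whole input if absent."""
    _, found, answer = generated.partition(marker)
    return answer.strip() if found else generated.strip()


# Placeholder text standing in for real model output.
sample = "Question:\nWhat is X?\n\nAnswer:\nX is a placeholder answer."
print(extract_answer(sample))
```

With the fine-tuned model, this could be used as `extract_answer(gemma_lm.generate(prompt, max_length=128))`.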

## Upload your model to Kaggle

Create a preset directory for your model files.

Then, save the model to that preset directory.

In [None]:
preset = "./medical_gemma"
# Save the model to the preset directory.
gemma_lm.save_to_preset(preset)

Create a Kaggle URI for your model. It should have the following format:

`kaggle://{KAGGLE USERNAME}/{MODEL NAME}/keras/{VARIATION NAME}`

In [None]:
kaggle_username = userdata.get('KAGGLE_USERNAME')
model_name = "gemma"
variation_name = "medical_gemma"

uri = f"kaggle://{kaggle_username}/{model_name}/keras/{variation_name}"
uri

'kaggle://nkovela/gemma/keras/medical_gemma'

Then, upload the preset to Kaggle!

If this is your first upload of this model, a Kaggle model page will be created associated with your profile.

You can view all your models on your [Work Page](https://www.kaggle.com/work/models).

In [None]:
# Upload preset to Kaggle
keras_nlp.upload_preset(uri, preset)

Starting upload for file task.json
Uploading: 100%|██████████| 1.91k/1.91k [00:00<00:00, 2.27kB/s]
Upload successful: task.json (2KB)
Starting upload for file tokenizer.json
Uploading: 100%|██████████| 315/315 [00:00<00:00, 374B/s]
Upload successful: tokenizer.json (315B)
Starting upload for file preprocessor.json
Uploading: 100%|██████████| 831/831 [00:00<00:00, 990B/s]
Upload successful: preprocessor.json (831B)
Starting upload for file config.json
Uploading: 100%|██████████| 501/501 [00:00<00:00, 582B/s]
Upload successful: config.json (501B)
Starting upload for file metadata.json
Uploading: 100%|██████████| 143/143 [00:00<00:00, 176B/s]
Upload successful: metadata.json (143B)
Starting upload for file model.weights.h5
Uploading: 100%|██████████| 10.0G/10.0G [06:47<00:00, 24.6MB/s]
Upload successful: model.weights.h5 (9GB)
Starting upload for file vocabulary.spm
Uploading: 100%|██████████| 4.24M/4.24M [00:02<00:00, 1.75MB/s]
Upload successful: vocabulary.spm (4MB)
Your model instance 

Now view the model page using the URL in the output of the previous cell.

Verify that your new model instance was uploaded successfully.
Note that this can take several minutes if this is your first upload of this model type.
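
Once the upload has finished, anyone with access to the model should be able to load it by passing the same Kaggle URI to `from_preset`, just as the base model was loaded from `gemma_2b_en`. The sketch below has not been run against a live upload. `kaggle_model_uri` and `load_shared_model` are hypothetical helper names, and the import is deferred so that building the URI does not require KerasNLP.

```python
def kaggle_model_uri(username, model_name, variation_name):
    """Build a URI in the kaggle://{username}/{model}/keras/{variation} format."""
    return f"kaggle://{username}/{model_name}/keras/{variation_name}"


def load_shared_model(uri):
    """Load a fine-tuned Gemma preset from Kaggle.

    Requires keras_nlp to be installed and Kaggle credentials to be configured.
    """
    import keras_nlp  # deferred import: only needed when actually loading

    return keras_nlp.models.GemmaCausalLM.from_preset(uri)


# Same values as the URI built earlier in this tutorial.
example_uri = kaggle_model_uri("nkovela", "gemma", "medical_gemma")
print(example_uri)
```

A reader could then run `shared_lm = load_shared_model(example_uri)` and call `shared_lm.generate(prompt, max_length=128)` as in the inference cells above.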

**That's it!** You've now learned how to fine-tune Gemma using Kaggle and Keras and share your model with the community.