# Fine Tuning Gemma2 9B base model on Medical dataset

## Installation
Install KerasNLP with the Gemma 2 model.

In [1]:
!pip install -q -U keras-nlp tensorflow-text
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
!pip install -q -U tensorflow-cpu

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.16.0 requires tensorflow<2.17,>=2.16, but you have tensorflow 2.17.0 which is incompatible.[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


Install huggingface datasets library

In [2]:
!pip install -q -U datasets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


Login to huggingface account

In [3]:
from huggingface_hub import login
login()

  from .autonotebook import tqdm as notebook_tqdm


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


## Set up Keras JAX backend

In [5]:
import jax

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [6]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate all TPU memory to minimize memory fragmentation and allocation overhead.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

In [7]:
import keras
import keras_nlp

## Load model

To load the model with the weights and tensors distributed across TPUs, first create a new `DeviceMesh`. `DeviceMesh` represents a collection of hardware devices configured for distributed computation and was introduced in Keras 3 as part of the unified distribution API.

The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new [Keras 3 distribution API guide](https://keras.io/guides/distribution/).

In [8]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices(),
)

LayoutMap from the distribution API specifies how the weights and tensors should be sharded or replicated, using the string keys, for example, token_embedding/embeddings below, which are treated like regex to match tensor paths. Matched tensors are sharded with model dimensions (8 TPUs); others will be fully replicated.

In [9]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in attention layers
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)

**ModelParallel** allows you to shard model weights or activation tensors across all devcies on the DeviceMesh. In this case, some of the Gemma 7B model weights are sharded across 8 TPU chips according the layout_map defined above. Now load the model in the distributed way.

In [10]:
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

keras.distribution.set_distribution(model_parallel)


Now load the Gemma 2 27B model in the distributed way.

In [11]:
gemma2_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")
gemma2_lm.summary()

Attaching 'model.safetensors' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'task.json' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'model.safetensors' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma2/keras/gemma2_9b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma2/keras/gemma2_9b_e

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


## Generate text before fine-tuning

Now the Gemma 2 9B model is ready to be used for text generation.

In [12]:
print(gemma2_lm.generate("Hello doctor,Can I pull out hard teeth out no pain because I got a hole in my teeth? Whenever I use to eat something, it got stuck in it and caused severe pain. Kindly give me some advice about pulling it out without any pain.", max_length=512))

Hello doctor,Can I pull out hard teeth out no pain because I got a hole in my teeth? Whenever I use to eat something, it got stuck in it and caused severe pain. Kindly give me some advice about pulling it out without any pain. I am 20 years old. I am a student. I am not able to eat anything. I am not able to sleep. I am not able to concentrate on my studies. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not able to do anything. I am not 

In [13]:
print(gemma2_lm.generate("Hello doctor, The top of my belly button, that is, the skin of my belly button is hard and firm, and the rest of my belly button is not. I feel as if there is something under the surface. I changed my diet, so I lost some weight. I was 168 lbs, and now I am 135 lbs. If I place two fingers over my belly button and do wide circles, I feel like there is something hard underneath. Is this a hernia or where my umbilical cord was? I do not have pain or symptoms, but I just feel it firm and hard. What is this?", max_length=512))

Hello doctor, The top of my belly button, that is, the skin of my belly button is hard and firm, and the rest of my belly button is not. I feel as if there is something under the surface. I changed my diet, so I lost some weight. I was 168 lbs, and now I am 135 lbs. If I place two fingers over my belly button and do wide circles, I feel like there is something hard underneath. Is this a hernia or where my umbilical cord was? I do not have pain or symptoms, but I just feel it firm and hard. What is this?


## Load Dataset
Download medical chatbot dataset from huggingface

In [14]:
import pandas as pd

df = pd.read_parquet("hf://datasets/ruslanmv/ai-medical-chatbot/dialogues.parquet")

In [15]:
df.head(3)

Unnamed: 0,Description,Patient,Doctor
0,Q. What does abutment of the nerve root mean?,"Hi doctor,I am just wondering what is abutting...",Hi. I have gone through your query with dilige...
1,Q. What should I do to reduce my weight gained...,"Hi doctor, I am a 22-year-old female who was d...",Hi. You have really done well with the hypothy...
2,Q. I have started to get lots of acne on my fa...,Hi doctor! I used to have clear skin but since...,Hi there Acne has multifactorial etiology. Onl...


In [16]:
med_records = df.sample(n=1000, random_state=2).to_dict('records')

In [17]:
med_records[0]

{'Description': 'Is eloctrophoresis which shows a discrete band consistent with plasma cell dyscrasia anything to worry ?',
 'Patient': 'i recently had a blood test because i was having a lot of pain which comes and goes i have a crush fracture of t9 blood tests are generally ok, except for the eloctrophoresis which shows a discrete band consistent with plasma cell dyscrasia such as myeloma or mgus',
 'Doctor': 'your electrophoresis reports says that it is consistent with plasma cell dyscariasis. however not specified that it is myeloma or mgus. you need further investigation to confirm that what the disease you have. it depends on your monoclonal Ig levels and your symptoms. go for Monoclonal Ig levels,  bone marrow study, x-ray skull also needed. also scan you have any lytic lesion aor not. what is creatinine level and albumin level is also important. but one thing is sure you need further investigation for plasma cell dyscariasis. go for that and take treatment accordingly.'}

Format medical dataset for fine-tuning

In [18]:
import json
data = []

template = """Instruction:
{Patient}

Response:
{Doctor}"""
for record in med_records:
    data.append(json.dumps(template.format(**record)))

In [19]:
data[0]

'"Instruction:\\ni recently had a blood test because i was having a lot of pain which comes and goes i have a crush fracture of t9 blood tests are generally ok, except for the eloctrophoresis which shows a discrete band consistent with plasma cell dyscrasia such as myeloma or mgus\\n\\nResponse:\\nyour electrophoresis reports says that it is consistent with plasma cell dyscariasis. however not specified that it is myeloma or mgus. you need further investigation to confirm that what the disease you have. it depends on your monoclonal Ig levels and your symptoms. go for Monoclonal Ig levels,  bone marrow study, x-ray skull also needed. also scan you have any lytic lesion aor not. what is creatinine level and albumin level is also important. but one thing is sure you need further investigation for plasma cell dyscariasis. go for that and take treatment accordingly."'

## LoRA Fine-tuning

To get better responses from the model, you can fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [20]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma2_lm.backbone.enable_lora(rank=4)
gemma2_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 9 billion to 14 million).

In [21]:
# Limit the input sequence length to 256 (to control memory usage).
gemma2_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma2_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

Start fine-tuning job

In [22]:
gemma2_lm.fit(data, epochs=10, batch_size=4)

Epoch 1/10
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 289ms/step - loss: 7.0965 - sparse_categorical_accuracy: 0.4179
Epoch 2/10
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 260ms/step - loss: 2.0822 - sparse_categorical_accuracy: 0.4465
Epoch 3/10
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 260ms/step - loss: 1.9817 - sparse_categorical_accuracy: 0.4642
Epoch 4/10
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 260ms/step - loss: 1.9552 - sparse_categorical_accuracy: 0.4683
Epoch 5/10
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 259ms/step - loss: 1.9378 - sparse_categorical_accuracy: 0.4704
Epoch 6/10
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 261ms/step - loss: 1.9212 - sparse_categorical_accuracy: 0.4735
Epoch 7/10
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 260ms/step - loss: 1.9030 - sparse_categorical_accuracy: 0.4766
Epoch

<keras.src.callbacks.history.History at 0x78c07bfc3190>

## Generate text after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

In [23]:
prompt = template.format(
    Patient="i recently had a blood test because i was having a lot of pain which comes and goes i have a crush fracture of t9 blood tests are generally ok, except for the eloctrophoresis which shows a discrete band consistent with plasma cell dyscrasia such as myeloma or mgus. Is it a thing to worry ?",
    Doctor="",
)
print(gemma2_lm.generate(prompt, max_length=512))

Instruction:
i recently had a blood test because i was having a lot of pain which comes and goes i have a crush fracture of t9 blood tests are generally ok, except for the eloctrophoresis which shows a discrete band consistent with plasma cell dyscrasia such as myeloma or mgus. Is it a thing to worry ?

Response:
Hello, Thank you for posting your query. I have gone through your query and understand your concern. I would like to inform you that the discrete band in electrophoresis is suggestive of monoclonal gammopathy. This is a condition in which there is an increase in the level of a particular type of protein in the blood. This is usually seen in multiple myeloma. So, you need to get a bone marrow biopsy done to rule out multiple myeloma. Hope I have answered your query. I will be happy to help you further. Wish you good health. Thanks.


In [43]:
prompt = template.format(
    Patient="""my daughter 9 yrs old went to community swimming pool for 3 days there .she developed some skin infection around her right eye and on her chin and lip. both areas have got inflamed and she is complaining of pain and itching. now my question is whether her condition can be treated by only topical application or she needs oral medication too.
    """,
    Doctor="",
)
print(gemma2_lm.generate(prompt, max_length=512))

Instruction:
my daughter 9 yrs old went to community swimming pool for 3 days there .she developed some skin infection around her right eye and on her chin and lip. both areas have got inflamed and she is complaining of pain and itching. now my question is whether her condition can be treated by only topical application or she needs oral medication too.
    

Response:
Hi, Thanks for posting your query. I have gone through your query and understand your concern. I would suggest you to apply topical antibiotic ointment like mupirocin ointment 2% on the affected area. You can also apply topical steroid ointment like mometasone furoate 0.1% on the affected area. You can also apply topical antihistamine ointment like


## Save finetuned model to Kaggle/HuggingFace

In [None]:
# Save the finetuned model as a KerasNLP preset.
gemma.save_to_preset("./gemma2-medical-base-7b")

# Upload the preset as a new model variant on Kaggle
kaggle_uri = "kaggle://my_kaggle_username/gemma-medical/keras/gemma2-medical-base-7b"
keras_nlp.upload_preset(kaggle_uri, "./gemma2-medical-base-7b")

In [None]:
# Then save it as a KerasNLP preset.
gemma2_lm.save_to_preset('./gemma2-medical-base-7b)

# Upload the preset to Hugging Face Hub
hf_uri = "hf://my_hf_username/gemma2-medical-base-7b"
keras_nlp.upload_preset(hf_uri, './gemma2-medical-base-7b)