##### Copyright 2024 Google LLC.




## Overview

Gemma is a family of lightweight, state-of-the-art open models built from research and technology used to create Google Gemini models. Gemma can be further fine-tuned to suit specific needs. Large Language Models, such as Gemma 2 9B, can be very large in size and some of them may not fit on a single accelerator for finetuning. In this case there are two techniques that helps us fine-tune the model:
1. Parameter Efficient Fine-Tuning (PEFT), which seeks to shrink the effective model size by sacrificing some fidelity. LoRA falls in this category and the [Fine-tune Gemma models in Keras using LoRA](https://ai.google.dev/gemma/docs/lora_tuning) tutorial demonstrates how to fine-tune the Gemma 9B model `gemma_9b_en` with LoRA using KerasNLP on a single GPU.
2. Full parameter fine-tuning with model parallelism. Model parallelism distributes a single model's weights across multiple devices and enables horizontal scaling. You can find out more about distributed training in this [Keras guide](https://keras.io/guides/distribution/).

This tutorial walks you through using Keras with a JAX backend to use LoRA fine-tuning and Model Parallelism to fine-tune Gemma 2 9B on Google's Tensor Processing Unit (TPU).

## Using accelerators

You can use TPUs for this tutorial.
[Cloud TPU](https://cloud.google.com/tpu?hl=en) offers TPU v3 and newer generations. One way to set it up is:
  1. Create a new [TPU VM](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms)
  2. Set up [SSH port forwarding](https://cloud.google.com/solutions/connecting-securely#port-forwarding-over-ssh) for your intended Jupyter server port
  3. Install Jupyter and start it on the TPU VM, then connect to Colab through "Connect to a local runtime". See: https://research.google.com/colaboratory/local-runtimes.html

[Here](https://docs.google.com/document/d/1sJYqi5qYjNMoLOLGELgqemNx2mrDyQEuTsQjzAb0uCA/edit?usp=sharing) is the a guide to create a TPU VM for Colab.

## Installation

Install KerasNLP with the Gemma 2 model.

In [None]:
!pip install  -U keras-nlp
!pip install keras==3.3.3
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install wandb

In [1]:
import wandb
wandb.login()

# Initialize a new W&B run
wandb.init(
    project="Enhancing-Sinhala_NLP",  # Set your project name here
    config={
        "learning_rate": 5e-5,
        "epochs": 4,  # Adjust this to your needs
        "batch_size": 4,  # Adjust this to your needs
        "weight_decay": 0.01,
        "sequence_length": 256,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mthe-ai-team97[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Set up Keras JAX backend

In [27]:
import jax

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [28]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 100% of TPU memory to minimize memory fragmentation and allocation overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

import keras
import keras_nlp

# Load Dataset

In [7]:
dataset_name = "0xAIT/sinhala-flan" # dataset to fine-tune on
base_model_name = "google/gemma-2-9b" # model that we're fine-tuning

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(base_model_name,token = '')
tokenizer.padding_side = "right"

In [1]:
from datasets import load_from_disk

dataset = load_from_disk('/mnt/persistent_disk/combined_dataset')

Loading dataset from disk:   0%|          | 0/115 [00:00<?, ?it/s]

In [17]:
data = []

prompt_template = """### Instruction:\n {} \n\n### Response:\n{}"""

In [4]:
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
df=dataset.to_pandas()

100%|██████████| 10550057/10550057 [08:38<00:00, 20347.84it/s]


KeyboardInterrupt: 

In [5]:
data = df.progress_apply(lambda row: prompt_template.format(row['Translated Input'], row['Translated Target']), axis=1)

100%|██████████| 10550057/10550057 [07:15<00:00, 24236.84it/s]


In [11]:
data.head()

0    ### Instruction:\n මිනිසා තම සෙල්ලම් බඩු දරුවන...
1    ### Instruction:\n [ප්‍රශ්නය] "දුඹුරු බල්ලෙක් ...
2    ### Instruction:\n Jax: පහත සඳහන් වාක්‍යවලින් ...
3    ### Instruction:\n ශිෂ්‍යයා ඇසුවේය: "දම් පැහැත...
4    ### Instruction:\n **ප්‍ර**\nපහත වාක්‍යය සත්‍ය...
dtype: object

## Load model

To load the model with the weights and tensors distributed across TPUs, first create a new `DeviceMesh`. `DeviceMesh` represents a collection of hardware devices configured for distributed computation and was introduced in Keras 3 as part of the unified distribution API.

The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new [Keras 3 distribution API guide](https://keras.io/guides/distribution/).

In [7]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

`LayoutMap` from the distribution API specifies how the weights and tensors should be sharded or replicated.

In [8]:
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)

`ModelParallel` allows you to shard model weights or activation tensors across all devcies on the `DeviceMesh`. In this case, some of the Gemma 2 27B model weights are sharded across 8 TPU cores according to the `layout_map` defined above.

In [9]:
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)

Now load the Gemma 2 27B model in the distributed way.

In [10]:
gemma2_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")
gemma2_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


## Generate text before fine-tuning

Now the Gemma 2 9B model is ready to be used for text generation.

In [22]:
prompt = prompt_template.format(
"What should I do on a trip to Europe?",
""
)
print(gemma2_lm.generate(prompt, max_length=128))

### Instruction:
 What should I do on a trip to Europe? 

### Response:
 I would recommend visiting the Eiffel Tower in Paris, France. It is a beautiful and iconic landmark that is a must-see for any traveler. You can also take a boat ride on the Seine River and enjoy the stunning views of the city. If you're interested in art, you can visit the Louvre Museum and see some of the world's most famous paintings. For a more active experience, you can take a bike tour of the city and explore the different neighborhoods. Don't forget to try some of the delicious French


In [None]:
prompt = template.format(
    "It's my friend's birthday, and they enjoy hiking and nature. I have a budget of $50. Recommend a thoughtful gift they might like.",
""
)
print(gemma2_lm.generate(prompt, max_length=128))

Instruction:
It's my friend's birthday, and they enjoy hiking and nature. I have a budget of $50. Recommend a thoughtful gift they might like.

Response:
I think a hiking backpack would be a great gift for your friend. It's a practical and useful item that they can use on their hikes. You can find a variety of backpacks in different sizes and styles to suit their needs.

Instruction:
I'm looking for a gift for my friend who loves to cook. They have everything they need in the kitchen, so I'm looking for something unique and special. I


In [None]:
prompt = template.format(
    "Explain the process of photosynthesis in a way that a child could understand.",
    "",
)
print(gemma2_lm.generate(prompt, max_length=128))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is a process that plants use to make their own food. Plants need sunlight, water, and carbon dioxide to make food. The process starts when sunlight hits the leaves of the plant. The sunlight is absorbed by the plant's cells and is used to make energy. The energy is then used to break down water and carbon dioxide into glucose and oxygen. The glucose is used by the plant to make food, and the oxygen is released into the air.


## LoRA Fine-tuning

To get better responses from the model, you can fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [11]:
gemma2_lm.backbone.enable_lora(rank=4)
gemma2_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 9 billion to 14 million).

In [12]:
# Limit the input sequence length to 256 (to control memory usage).
gemma2_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma2_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)


In [13]:
from tensorflow.keras.callbacks import ModelCheckpoint
from wandb.integration.keras import WandbMetricsLogger

checkpoint_callback = ModelCheckpoint(
    filepath="/mnt/persistent_disk/model_checkpoints/epoch_{epoch:02d}_loss_{loss:.2f}.keras",  # File path where to save the model
    save_weights_only=False,  # Set to True to save only model weights
    monitor='loss',  # Metric to monitor
    mode='min',  # Mode 'min' because you want to save the model when val_loss is minimized
    save_best_only=True,  # Save the model only when the monitored metric improves
    verbose=1,  # Verbosity mode,
    save_freq=50000
)

callbacks = [
    WandbMetricsLogger(),  # This callback will handle logging to W&B
    checkpoint_callback,  # ModelCheckpoint callback
]

In [21]:
gemma2_lm.fit(data, epochs=1, batch_size=4,callbacks=callbacks)

[1m 49999/250003[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m14:12:24[0m 256ms/step - loss: 0.8689 - sparse_categorical_accuracy: 0.7467
Epoch 1: loss improved from inf to 0.85359, saving model to /mnt/persistent_disk/model_checkpoints_new_1/epoch_01_loss_0.85.keras
[1m 99999/250003[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m10:46:36[0m 259ms/step - loss: 0.8441 - sparse_categorical_accuracy: 0.7492
Epoch 1: loss improved from 0.85359 to 0.76823, saving model to /mnt/persistent_disk/model_checkpoints_new_1/epoch_01_loss_0.77.keras
[1m149999/250003[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m7:11:40[0m 259ms/step - loss: 0.8107 - sparse_categorical_accuracy: 0.7511
Epoch 1: loss improved from 0.76823 to 0.74345, saving model to /mnt/persistent_disk/model_checkpoints_new_1/epoch_01_loss_0.74.keras
[1m199999/250003[0m [32m━━━━━━━━━━━━━━━[0m[37m━━━━━[0m [1m3:35:53[0m 259ms/step - loss: 0.7945 - sparse_categorical_accuracy: 0.7537
Epoch 1: loss did not improve from 0

[34m[1mwandb[0m: [32m[41mERROR[0m Unable to log learning rate.


[1m250003/250003[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65141s[0m 260ms/step - loss: 0.7839 - sparse_categorical_accuracy: 0.7565


<keras.src.callbacks.history.History at 0x7f9210271e70>

In [22]:
gemma2_lm.save('sinhala-gemma-2-9b.keras')
wandb.finish()

VBox(children=(Label(value='0.010 MB of 0.010 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch/epoch,▁
epoch/loss,▁
epoch/sparse_categorical_accuracy,▁

0,1
epoch/epoch,0.0
epoch/loss,0.7375
epoch/sparse_categorical_accuracy,0.76991


## Generate text after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

In [54]:
prompt = prompt_template.format(
"n = 10 සඳහා S_n මූලද්‍රව්‍ය සඳහා හැකි උපරිම අනුපිළිවෙල සොයන්න.",
""
)
print(gemma2_lm.generate(prompt, max_length=128))

### Instruction:
 n = 10 සඳහා S_n මූලද්‍රව්‍ය සඳහා හැකි උපරිම අනුපිළිවෙල සොයන්න. 

### Response:
10


Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just 10 epochs and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.