## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Google Gemini models. Gemma can be further fine-tuned to suit specific needs. However, large language models such as Gemma can be very large, and some of them may not fit on a single accelerator for fine-tuning.

In the case of Gemma 2, the 9-billion-parameter model is too large to fit on a single Kaggle accelerator. This can make both fine-tuning and inference in Kaggle Notebooks difficult.

One option is quantizing a model. Quantization can allow efficient inference for larger models with less GPU or TPU memory.

Another option is model parallelism, which splits the model's parameters across a number of accelerators and can easily be used for both inference and fine-tuning. This guide shows how to fine-tune a Gemma 2 model on 8 TPU v3 cores through the Kaggle TPU runtime. You can find out more about distributed training in this [Keras guide](https://keras.io/guides/distribution/).
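
To make the memory argument concrete, here is a back-of-the-envelope sketch. It counts only the weights (no activations, optimizer state, or cache), treats "9 billion" as an approximate parameter count, and uses the 16 GB per core that this guide quotes for Kaggle's TPU v3-8. The helper name `weights_gib` is made up for this illustration.

```python
# Rough weight-memory estimate. Assumptions: ~9e9 parameters, 16 GB per core,
# 8 cores, and weights only (activations and optimizer state are ignored).
NUM_PARAMS = 9e9
CORE_MEMORY_GIB = 16
NUM_CORES = 8

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1}


def weights_gib(num_params, dtype):
    """Return the size of the weights in GiB for the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


for dtype in BYTES_PER_PARAM:
    total = weights_gib(NUM_PARAMS, dtype)
    per_core = total / NUM_CORES
    fits = "fits" if total <= CORE_MEMORY_GIB else "does not fit"
    print(
        f"{dtype:>8}: {total:5.1f} GiB total ({fits} on one core), "
        f"{per_core:4.1f} GiB per core if split evenly over {NUM_CORES} cores"
    )
```

Under these assumptions, the 16-bit weights alone exceed a single core's memory, while an even split over 8 cores leaves room for activations and LoRA training state. Lower-precision storage such as int8 shrinks the footprint, which is the idea behind the quantization option.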

## Using accelerators

Technically you can use either TPU or GPU for this tutorial.

### Notes on TPU environments

Google has 3 products that provide TPUs:
* [Colab](https://colab.sandbox.google.com/) provides TPU v2, which is not sufficient for this tutorial.
* [Kaggle](https://www.kaggle.com/) offers TPU v3 for free, which works for this tutorial.
* [Cloud TPU](https://cloud.google.com/tpu?hl=en) offers TPU v3 and newer generations. One way to set it up is:
  1. Create a new [TPU VM](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms)
  2. Set up [SSH port forwarding](https://cloud.google.com/solutions/connecting-securely#port-forwarding-over-ssh) for your intended Jupyter server port
  3. Install Jupyter and start it on the TPU VM, then connect to Colab through "Connect to a local runtime"

### Notes on multi-GPU setup

Although this tutorial focuses on the TPU use case, you can easily adapt it for your own needs if you have a multi-GPU machine.

If you prefer to work through Colab, it's also possible to provision a multi-GPU VM for Colab directly through "Connect to a custom GCE VM" in the Colab Connect menu.


We will focus on using the **free TPU from Kaggle** here.

## Before you begin

### Gemma setup

To complete this tutorial, you first need to accept the Gemma Terms of Use. You can navigate to the [Keras Gemma 2 Page](https://www.kaggle.com/models/keras/gemma2) to do this. You will see a banner at the top of the page with a button to "Request Access" if you have not already done this for your Kaggle user.

## Installation

Install Keras and KerasNLP with the Gemma model.

In [1]:
!pip install -q -U keras-nlp tensorflow-text
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
!pip install -q -U tensorflow-cpu


### Set up Keras JAX backend

Import JAX and run a sanity check on the TPU. Kaggle offers TPU v3-8 devices, which have 8 TPU cores with 16 GB of memory each.

In [2]:
import jax

jax.devices()

E0000 00:00:1725105590.948976    4406 common_lib.cc:798] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:479
E0831 11:59:50.983822766    4500 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:"2024-08-31T11:59:50.983779709+00:00", grpc_status:2}


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate all TPU memory to minimize memory fragmentation and allocation overhead.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

## Load model

In [4]:
import keras
import keras_nlp



To load the model with the weights and tensors distributed across TPUs, first create a new `DeviceMesh`. `DeviceMesh` represents a collection of hardware devices configured for distributed computation and was introduced in Keras 3 as part of the unified distribution API.

The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new [Keras 3 distribution API guide](https://keras.io/guides/distribution/).

In [5]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices(),
)

`LayoutMap` from the distribution API specifies how weights and tensors should be sharded or replicated. Its string keys, such as `token_embedding/embeddings` below, are treated as regular expressions and matched against tensor paths. Matched tensors are sharded along the `model` mesh dimension (the 8 TPU cores); all other tensors are fully replicated.

In [6]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in attention layers
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)
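
To see how the keys above pick out tensors, the sketch below mimics the lookup with plain Python. It assumes `re.search` semantics (a key may match anywhere in the path) and that a path should match at most one key; the real `LayoutMap` lookup may differ in its details. The dictionary `layout_keys` and the helper `lookup_layout` are hypothetical stand-ins, not part of the Keras API.

```python
import re

# Same keys and layouts as the cell above, held in a plain dict for illustration.
layout_keys = {
    "token_embedding/embeddings": ("model", None),
    "decoder_block.*attention.*(query|key|value)/kernel": ("model", None, None),
    "decoder_block.*attention_output/kernel": ("model", None, None),
    "decoder_block.*ffw_gating.*/kernel": (None, "model"),
    "decoder_block.*ffw_linear/kernel": ("model", None),
}


def lookup_layout(path):
    """Return the layout of the single key matching `path`, or None if no key matches."""
    matches = [key for key in layout_keys if re.search(key, path)]
    if len(matches) > 1:
        raise ValueError(f"{path} matches several keys: {matches}")
    return layout_keys[matches[0]] if matches else None


for path in [
    "token_embedding/embeddings",
    "decoder_block_1/attention/query/kernel",
    "decoder_block_1/attention/attention_output/kernel",
    "decoder_block_1/ffw_gating_2/kernel",
    "decoder_block_1/pre_attention_norm/scale",
]:
    layout = lookup_layout(path)
    print(f"{path:<52} {layout if layout else 'replicated'}")
```

Normalization scales match none of the keys, so they stay replicated on every device.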

`ModelParallel` allows you to shard model weights or activation tensors across all devices on the `DeviceMesh`. In this case, some of the Gemma 2 9B model weights are sharded across the 8 TPU cores according to the `layout_map` defined above. Now load the model in a distributed way.

In [7]:
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Let's verify that the model has been partitioned correctly. Let's take `decoder_block_1` as an example.

In [8]:
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<48}  {str(variable.shape):<14}  {str(variable.value.sharding.spec)}')

<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale          (3584,)         PartitionSpec(None,)
decoder_block_1/pre_attention_norm/scale          (3584,)         PartitionSpec(None,)
decoder_block_1/attention/query/kernel            (16, 3584, 256)  PartitionSpec('model', None, None)
decoder_block_1/attention/key/kernel              (8, 3584, 256)  PartitionSpec('model', None, None)
decoder_block_1/attention/value/kernel            (8, 3584, 256)  PartitionSpec('model', None, None)
decoder_block_1/attention/attention_output/kernel  (16, 256, 3584)  PartitionSpec('model', None, None)
decoder_block_1/pre_ffw_norm/scale                (3584,)         PartitionSpec(None,)
decoder_block_1/post_ffw_norm/scale               (3584,)         PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                 (3584, 14336)   PartitionSpec(None, 'model')
decoder_block_1/ffw_gating_2/kernel               (3584, 14336)   Partition
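
The partition specs translate directly into per-device shapes. The sketch below works through that arithmetic for a few of the tensors listed above, using the `(1, 8)` mesh with axes `batch` and `model`. The helper `shard_shape` is hypothetical, and it assumes every sharded dimension divides evenly by the mesh axis size.

```python
# Mesh axis sizes from the DeviceMesh created earlier: shape (1, 8).
MESH_SHAPE = {"batch": 1, "model": 8}


def shard_shape(shape, spec):
    """Return the per-device shape of a tensor with the given partition spec."""
    local = []
    for size, axis in zip(shape, spec):
        if axis is None:
            # Dimension is not sharded: every device holds it in full.
            local.append(size)
            continue
        num_shards = MESH_SHAPE[axis]
        if size % num_shards:
            # Simplifying assumption for this sketch: only even splits.
            raise ValueError(f"{size} is not divisible by {num_shards}")
        local.append(size // num_shards)
    return tuple(local)


examples = {
    "attention/query/kernel": ((16, 3584, 256), ("model", None, None)),
    "attention/key/kernel": ((8, 3584, 256), ("model", None, None)),
    "ffw_gating/kernel": ((3584, 14336), (None, "model")),
    "pre_attention_norm/scale": ((3584,), (None,)),
}

for name, (shape, spec) in examples.items():
    print(f"{name:<26} global {shape} -> per device {shard_shape(shape, spec)}")
```

For example, the query kernel's 16 heads are split so that each core holds 2 of them, while the feed-forward gating kernel is split along its second dimension.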

## Inference before finetuning

Let's try asking the model a question.

In [9]:
print(gemma_lm.generate("How can I plan a trip to Europe?", max_length=512))

How can I plan a trip to Europe?

[User 0001]

I'm planning a trip to Europe for the first time. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm going to be in Europe for 2 weeks. I'm go

We are using the base Gemma 2 model, which means it has not been fine-tuned for any particular task. It has simply been trained to guess the next word on a vast number of source documents.

Such a model is not yet a good fit for question answering. It tends to keep predicting likely words, often continuing the question itself as if it were a random snippet of a random document on the web. It can easily get stuck in loops of high-probability sequences.

To make it more useful, we can fine-tune on a question answering dataset. In this tutorial, we will use the Databricks Dolly dataset. This dataset contains 15,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs to follow instructions. Such fine-tuning is often called instruction fine-tuning, or IFT for short.

## Instruction fine-tuning

In [10]:
import json
data = []
with open('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl') as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Truncate our data to speed up training.
data = data[:1000]

Let's look at a single training example.

In [11]:
data[0]

'Instruction:\nWhich is a species of fish? Tope or Rope\n\nResponse:\nTope'
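
The formatting step can be tried in isolation. The sketch below rebuilds the example above from a hand-written record (reconstructed from the output shown, not read from the dataset file; real records may carry additional fields) and reuses the same template with an empty response to build an inference prompt.

```python
import json

TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Hand-written stand-in for one JSONL line of the dataset.
sample_line = json.dumps(
    {
        "instruction": "Which is a species of fish? Tope or Rope",
        "context": "",
        "response": "Tope",
    }
)

features = json.loads(sample_line)
# str.format ignores keyword arguments the template does not use, such as "context".
example = TEMPLATE.format(**features)
print(repr(example))

# At inference time the response is left empty so the model fills it in.
prompt = TEMPLATE.format(instruction="How can I plan a trip to Europe?", response="")
print(repr(prompt))
```

Keeping the training template and the inference prompt in one place helps ensure the model is prompted in the format it was tuned on.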

We will perform fine-tuning using [Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA). LoRA is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the full weights of the model and inserting a smaller number of new trainable weights. In essence, LoRA learns the update to each large weight matrix as the product of two much smaller low-rank matrices, A and B, and trains only those, which makes training much faster and more memory-efficient.
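
The saving is easy to quantify for a single matrix. For a full weight matrix of shape `d x k`, a rank-`r` adaptation trains one `d x r` matrix and one `r x k` matrix. The sketch below borrows the `3584 x 14336` shape of the feed-forward kernel printed earlier purely as an illustration; it does not claim that this is one of the layers KerasNLP actually adapts. The helper `lora_param_counts` is hypothetical.

```python
def lora_param_counts(d, k, rank):
    """Return (full, lora) trainable-parameter counts for one d x k weight matrix."""
    full = d * k             # training the full matrix directly
    lora = rank * (d + k)    # training a d x rank and a rank x k matrix instead
    return full, lora


full, lora = lora_param_counts(3584, 14336, rank=4)
print(f"full matrix:  {full:,} trainable parameters")
print(f"rank-4 LoRA:  {lora:,} trainable parameters")
print(f"ratio:        {lora / full:.4%}")
```

Because the LoRA count grows with `d + k` rather than `d * k`, the relative saving is largest for the biggest matrices.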

In [12]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)

In [13]:
# Limit the input sequence length to 1024 to control memory usage.
gemma_lm.preprocessor.sequence_length = 1024
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly, from 9 billion to only ~30 million.

Let's fine-tune our model!

In [14]:
gemma_lm.fit(data, epochs=5, batch_size=4)

Epoch 1/5
250/250 ━━━━━━━━━━━━━━━━━━━━ 327s 1s/step - loss: 0.2337 - sparse_categorical_accuracy: 0.5396
Epoch 2/5
250/250 ━━━━━━━━━━━━━━━━━━━━ 264s 1s/step - loss: 0.1903 - sparse_categorical_accuracy: 0.5847
Epoch 3/5
250/250 ━━━━━━━━━━━━━━━━━━━━ 265s 1s/step - loss: 0.1860 - sparse_categorical_accuracy: 0.5893
Epoch 4/5
250/250 ━━━━━━━━━━━━━━━━━━━━ 265s 1s/step - loss: 0.1823 - sparse_categorical_accuracy: 0.5962
Epoch 5/5
250/250 ━━━━━━━━━━━━━━━━━━━━ 265s 1s/step - loss: 0.1778 - sparse_categorical_accuracy: 0.6030


<keras.src.callbacks.history.History at 0x7821f864e110>
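
The step count in the progress bar follows from the dataset size and batch size. A quick check, assuming the truncated dataset really holds 1000 examples:

```python
import math

num_examples = 1000  # data was truncated with data[:1000]
batch_size = 4
epochs = 5

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * epochs
print(f"{steps_per_epoch} steps per epoch, {total_steps} optimizer steps in total")
```

This is useful when choosing a learning-rate schedule or estimating run time before launching a longer job.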

## Inference after finetuning

Now that we have fine-tuned our model, we can try prompting with a question again. This time, we will use the specific format we used to combine our prompts and responses from the Databricks Dolly dataset.

In [15]:
print(gemma_lm.generate("Instruction:\nHow can I plan a trip to Europe?\n\nResponse:\n", max_length=512))

Instruction:
How can I plan a trip to Europe?

Response:
Planning a trip to Europe can be a daunting task. There are so many countries to choose from, and each one has its own unique culture and attractions. Here are a few tips to help you plan your trip:

1. Decide which countries you want to visit. Europe is a large continent, and it's impossible to see everything in one trip. Choose a few countries that interest you, and focus on those.

2. Research each country's attractions. Once you've chosen your countries, start researching the different attractions each one has to offer. Make a list of the things you want to see and do, and prioritize them.

3. Plan your itinerary. Now that you know what you want to see and do, it's time to start planning your itinerary. Decide how many days you want to spend in each country, and start booking your flights and accommodations.

4. Pack for the weather. Europe has a wide range of climates, so it's important to pack for the weather. Do some resea

Much better! We could improve this model even more by fine-tuning with more data and tuning our learning rate and LoRA rank.

Alternatively, the Gemma models come with instruction-tuned checkpoints that can be used for question answering and a chat-like experience out of the box. See [Gemma 2 inference using KerasNLP](https://www.kaggle.com/code/nilaychauhan/gemma-2-inference-using-kerasnlp) as an example.

## Save LoRA weights

In [17]:
gemma_lm.backbone.save_lora_weights("LoRa_Gemma2_9b_en.lora.h5")
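
To reuse the saved adapter later, the weights can be loaded back into a freshly loaded base model. The sketch below is a minimal outline, assuming the model was created from the same preset (and under the same distribution setup) and that the backbone exposes `load_lora_weights` as the counterpart of `save_lora_weights`. The wrapper function `restore_lora_weights` is hypothetical.

```python
def restore_lora_weights(gemma_lm, path="LoRa_Gemma2_9b_en.lora.h5", rank=4):
    """Re-enable LoRA with the rank used for training, then load the saved adapter."""
    # The rank must match the one used when the weights were saved (4 in this guide).
    gemma_lm.backbone.enable_lora(rank=rank)
    gemma_lm.backbone.load_lora_weights(path)
    return gemma_lm
```

Only the small LoRA matrices are stored in the file, so the base Gemma 2 weights still need to be loaded from the preset first.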