##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.kaggle.com/windmaple/gemma-kaggle-tpu-only"><img src="https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png" height="32" width="70"/>Run in Kaggle</a>
  </td>
</table>

# Gemma instruction tuning on Kaggle TPU using a Chinese dataset

This notebook is adapted from the official [Gemma distributed tuning tutorial](https://ai.google.dev/gemma/docs/distributed_tuning) and [Gemma Vertex AI tutorial](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb). It demonstrates how to instruction tune Gemma 2B (the non-instruction-tuned variant) on a Kaggle TPU so that the finetuned model can better follow Chinese instructions.

(Note that the instruction-tuned variant of Gemma 2B does have some basic capability to follow Chinese instructions but the technique used here can be used to further enhance it.)

This notebook is also available directly on [Kaggle](https://www.kaggle.com/windmaple/gemma-kaggle-tpu-only).

## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Google Gemini models. Gemma can be further finetuned to suit specific needs. However, Large Language Models such as Gemma can be very large, and some of them may not fit on a single accelerator for finetuning. In this case there are two general approaches for finetuning them:
1. Parameter Efficient Fine-Tuning (PEFT), which seeks to shrink the effective model size by sacrificing some fidelity. LoRA falls in this category and the [Finetune Gemma models in Keras using LoRA](https://ai.google.dev/gemma/docs/lora_tuning) tutorial demonstrates how to finetune the Gemma 2B model `gemma_2b_en` with LoRA using KerasNLP on a single GPU.
2. Full parameter finetuning with model parallelism. Model parallelism distributes a single model's weights across multiple devices and enables horizontal scaling. You can find out more about distributed training in this [Keras guide](https://keras.io/guides/distribution/).

This tutorial walks you through using Keras with a JAX backend to finetune the Gemma 2B model with LoRA and model-parallel distributed training on Google's Tensor Processing Unit (TPU). Note that LoRA can be turned off in this tutorial for a slower but more accurate full-parameter tuning.
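As a rough, back-of-the-envelope sketch of why sharding helps, the snippet below estimates how much memory the weights alone would occupy per device. The parameter counts (roughly 2.5 billion for Gemma 2B and 8.5 billion for Gemma 7B) are approximations assumed for illustration, and the helper `weight_gib` is hypothetical. The estimate ignores activations, gradients, and optimizer state, and it assumes every weight is split evenly, which is not the case in this notebook (only some weights are sharded).

```python
# Back-of-the-envelope estimate of weight memory per device.
# Parameter counts are approximate and assumed for illustration only.
APPROX_PARAMS = {"gemma_2b": 2.5e9, "gemma_7b": 8.5e9}
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}
GIB = 1024 ** 3


def weight_gib(num_params, dtype="float32", num_devices=1):
    """Weights-only memory per device in GiB, assuming an even split."""
    return num_params * BYTES_PER_PARAM[dtype] / num_devices / GIB


for name, num_params in APPROX_PARAMS.items():
    single = weight_gib(num_params)
    sharded = weight_gib(num_params, num_devices=8)
    print(f"{name}: {single:.1f} GiB on one device, "
          f"{sharded:.1f} GiB per device across 8")
```

Under these assumptions, the float32 weights of the larger model exceed the 16 GB available on a single TPU v3 core, while an even 8-way split stays well below it.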

## Using accelerators

Technically you can use either TPU or GPU for this tutorial.

### Notes on TPU environments

Google has 3 products that provide TPUs:
* [Colab](https://colab.sandbox.google.com/) provides TPU v2, which is not sufficient for this tutorial.
* [Kaggle](https://www.kaggle.com/) offers TPU v3 for free and they work for this tutorial.
* [Cloud TPU](https://cloud.google.com/tpu?hl=en) offers TPU v3 and newer generations. One way to set it up is:
  1. Create a new [TPU VM](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms)
  2. Set up [SSH port forwarding](https://cloud.google.com/solutions/connecting-securely#port-forwarding-over-ssh) for your intended Jupyter server port
  3. Install Jupyter and start it on the TPU VM, then connect to Colab through "Connect to a local runtime"

### Notes on multi-GPU setup

Although this tutorial focuses on the TPU use case, you can easily adapt it for your own needs if you have a multi-GPU machine.

If you prefer to work through Colab, it's also possible to provision a multi-GPU VM for Colab directly through "Connect to a custom GCE VM" in the Colab Connect menu.


We will focus on using the **free TPU from Kaggle** here.

## Before you begin

### Gemma setup

To complete this tutorial, first follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


## Installation

Install Keras and KerasNLP with the Gemma model.

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q tensorflow-cpu
!pip install -q -U keras-nlp tensorflow-hub
# Quote the requirement so the shell does not treat ">=3" as an output redirect.
!pip install -q -U "keras>=3"
!pip install -qU transformers
!pip install -U sentencepiece

### Set up Keras JAX backend

Import JAX and run a sanity check on TPU. Kaggle offers TPUv3-8 devices which have 8 TPU cores with 16GB of memory each.

In [3]:
import jax

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

## Load model

In [5]:
import keras
import keras_nlp

### Notes on mixed precision training on NVIDIA GPUs

When training on NVIDIA GPUs, mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) can be used to speed up training with minimal effect on training quality. In most cases, it is recommended to turn on mixed precision because it saves both memory and time. However, be aware that at small batch sizes it can inflate memory usage by 1.5x (weights will be loaded twice, at half precision and at full precision).

For inference, half-precision (`keras.config.set_floatx("bfloat16")`) will work and save memory while mixed-precision is not applicable.
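The 1.5x figure follows from simple byte counting: under mixed precision each weight is held once in float32 (4 bytes) and once in bfloat16 (2 bytes), compared with a single float32 copy otherwise. A minimal sketch of that arithmetic, counting weights only (the helper `weight_bytes` is hypothetical, and activations are ignored):

```python
FLOAT32_BYTES = 4
BFLOAT16_BYTES = 2


def weight_bytes(num_params, mixed_precision):
    """Bytes used by the weights alone under each policy."""
    if mixed_precision:
        # Full-precision master copy plus a half-precision compute copy.
        return num_params * (FLOAT32_BYTES + BFLOAT16_BYTES)
    return num_params * FLOAT32_BYTES


num_params = 1_000_000  # arbitrary illustrative size
ratio = weight_bytes(num_params, True) / weight_bytes(num_params, False)
print(ratio)
```

At larger batch sizes the half-precision activations usually outweigh this overhead, which is why mixed precision still tends to save memory overall.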

In [6]:
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

To load the model with the weights and tensors distributed across TPUs, first create a new `DeviceMesh`. `DeviceMesh` represents a collection of hardware devices configured for distributed computation and was introduced in Keras 3 as part of the unified distribution API.

The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new [Keras 3 distribution API guide](https://keras.io/guides/distribution/).

In [7]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())
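A mesh shape has to account for every device: the product of its dimensions should equal the device count. The sketch below checks that relationship in plain Python without calling Keras. `describe_mesh` is a hypothetical helper, and the `(2, 4)` layout is shown only as an illustrative alternative that would combine data and model parallelism.

```python
import math


def describe_mesh(shape, axis_names, num_devices):
    """Check a mesh shape against the device count and name each axis size."""
    if len(shape) != len(axis_names):
        raise ValueError("need exactly one axis name per mesh dimension")
    if math.prod(shape) != num_devices:
        raise ValueError(
            f"mesh shape {shape} does not cover {num_devices} devices")
    return dict(zip(axis_names, shape))


# Pure model parallelism, as used in this notebook.
print(describe_mesh((1, 8), ["batch", "model"], num_devices=8))
# Illustrative alternative: 2-way data parallel, 4-way model parallel.
print(describe_mesh((2, 4), ["batch", "model"], num_devices=8))
```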

`LayoutMap` from the distribution API specifies how weights and tensors should be sharded or replicated. Its string keys, such as `token_embedding/embeddings` below, are treated as regular expressions and matched against tensor paths. Matched tensors are sharded along the `model` mesh dimension (8 TPU cores); all others are fully replicated.

In [8]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (None, model_dim)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, model_dim, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
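To see which tensor paths the keys above would select, the following standalone sketch applies the same patterns with Python's `re` module. It does not call Keras. It assumes search-style matching where the first matching rule wins, and the names `layout_rules` and `lookup_layout` are hypothetical. The example paths follow the `decoder_block_N/...` naming that the model reports.

```python
import re

# Same patterns and layouts as the layout map above, kept in a plain dict.
layout_rules = {
    "token_embedding/embeddings": (None, "model"),
    "decoder_block.*attention.*(query|key|value).*kernel": (None, "model", None),
    "decoder_block.*attention_output.*kernel": (None, None, "model"),
    "decoder_block.*ffw_gating.*kernel": ("model", None),
    "decoder_block.*ffw_linear.*kernel": (None, "model"),
}


def lookup_layout(path):
    """Return the layout of the first rule found in `path`, else None."""
    for pattern, layout in layout_rules.items():
        if re.search(pattern, path):
            return layout
    return None  # No rule matched: the tensor would be fully replicated.


for path in [
    "decoder_block_1/attention/query/kernel",
    "decoder_block_1/attention/attention_output/kernel",
    "decoder_block_1/ffw_gating_2/kernel",
    "decoder_block_1/pre_attention_norm/scale",
]:
    print(f"{path:<52} {lookup_layout(path)}")
```

Note that `ffw_gating_2` is picked up by the `ffw_gating` pattern because the match is a search rather than a full match, and the normalization scale matches no rule at all.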

`ModelParallel` allows you to shard model weights or activation tensors across all devices on the `DeviceMesh`. In this case, some of the Gemma 2B model weights are sharded across the 8 TPU cores according to the `layout_map` defined above. Now load the model in a distributed way.

In [9]:
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Now verify that the model has been partitioned correctly. Let's take `decoder_block_1` as an example.

In [10]:
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')

<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (2048,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (8, 2048, 256)    PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (1, 2048, 256)    PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (1, 2048, 256)    PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (8, 256, 2048)    PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (2048,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (2048, 16384)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (2048, 16384)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (
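A partition spec has one entry per tensor dimension: a mesh axis name means that dimension is split across that axis, and `None` means it is left whole. The sketch below turns the shapes and specs printed above into the shape each device would hold. `shard_shape` is a hypothetical helper that assumes an even split and is not part of JAX or Keras.

```python
def shard_shape(shape, spec, mesh_sizes):
    """Per-device shape of a tensor partitioned according to `spec`."""
    local = []
    for dim, axis in zip(shape, spec):
        if axis is None:
            local.append(dim)  # Dimension is not split.
        else:
            size = mesh_sizes[axis]
            if dim % size:
                raise ValueError(f"dimension {dim} is not divisible by {size}")
            local.append(dim // size)
    return tuple(local)


mesh_sizes = {"batch": 1, "model": 8}
# attention/query/kernel from the output above
print(shard_shape((8, 2048, 256), (None, "model", None), mesh_sizes))
# ffw_gating/kernel from the output above
print(shard_shape((2048, 16384), ("model", None), mesh_sizes))
```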

## Load instruction dataset

In [11]:
!wget -O baike.jsonl https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese/raw/main/baike.jsonl

--2024-02-25 06:20:25--  https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese/raw/main/baike.jsonl
Resolving huggingface.co (huggingface.co)... 18.244.202.118, 18.244.202.60, 18.244.202.68, ...
Connecting to huggingface.co (huggingface.co)|18.244.202.118|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5005244 (4.8M) [text/plain]
Saving to: ‘baike.jsonl’


2024-02-25 06:20:25 (14.6 MB/s) - ‘baike.jsonl’ saved [5005244/5005244]



In [12]:
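import json

# Chinese system-style context: "You are a knowledgeable AI assistant. Users
# will ask you questions in Chinese, and you will answer truthfully in Chinese
# based on your knowledge."
context = "你是一个知识丰富的人工智能助手，用户将用中文向你提问，你将根据你的知识用中文来如实回答问题。\n"
# "问题" means "question" and "答案" means "answer".
template = context + "问题：\n{question}\n答案：\n{human_answers[0]}"

data = []
with open("baike.jsonl", encoding="utf-8") as file:
    for line in file:
        features = json.loads(line)
        # Use the first human-written answer as the target response.
        data.append(template.format(**features))

# Manually constructed test case: "I have an information science question,
# please answer in Chinese: what is zsh". The finetuning dataset was checked
# to contain nothing about zsh.
test_prompt = context + "问题：\n我有一个信息科学相关的问题，请用中文回答，什么是 zsh\n答案：\n"
# The baike split has 4616 examples in total.
train_data = data[:4600]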
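To make the prompt format concrete, here is a small sketch that pushes one made-up record through the same template. The record and the shortened `context` placeholder are invented for illustration; only the `question` and `human_answers` field names come from the cell above.

```python
import json

# Placeholder standing in for the Chinese context string used above.
context = "CONTEXT\n"
template = context + "问题：\n{question}\n答案：\n{human_answers[0]}"

# A made-up record with the same fields the template reads from baike.jsonl.
line = json.dumps({
    "question": "What is a TPU?",
    "human_answers": ["An accelerator.", "A second answer."],
})
features = json.loads(line)

# str.format can index into a list argument, so only the first answer is used.
example = template.format(**features)
print(example)
```

Each training example is therefore a single string that contains the context, the question, and the first human answer, so the model learns to continue the prompt with an answer.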

## Inference before finetuning

In [13]:
gemma_lm.generate(test_prompt, max_length=200)

'你是一个知识丰富的人工智能助手，用户将用中文向你提问，你将根据你的知识用中文来如实回答问题。\n问题：\n我有一个信息科学相关的问题，请用中文回答，什么是 zsh\n答案：\nzsh 是一个命令行界面（CLI）的 shell，它支持许多命令行工具，包括 bash， fish， ksh， mksh， pdksh， tcsh， zsh， 和 yash。\nzsh 是一个命令行界面（CLI）的 shell，它支持许多命令行工具，包括 bash， fish， ksh， mksh， pdksh， tcsh， zsh， 和 yash。\nzsh 是一个命令行界面（CLI）的 shell，它支持许多命令行工具，包括 bash， fish， ksh， mksh， pdksh， tcsh， zsh， 和 yash。\nzsh 是一个命令行界面（CLI）的 shell，它支持'

The pretrained model repeats the same sentence until it reaches the `max_length` limit, so it does not yet answer the instruction well.
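Repetition like this can also be spotted programmatically. As the output above shows, `generate` returns the prompt followed by the completion, so one option is to strip the prompt and count duplicate lines. The helper below, `repeated_lines`, is a hypothetical sketch and not part of KerasNLP; the sample strings are invented.

```python
from collections import Counter


def repeated_lines(generated, prompt):
    """Return completion lines that occur more than once, with their counts."""
    if generated.startswith(prompt):
        completion = generated[len(prompt):]
    else:
        completion = generated
    counts = Counter(
        line.strip() for line in completion.split("\n") if line.strip())
    return {line: n for line, n in counts.items() if n > 1}


prompt = "Q: what is zsh\nA:\n"
generated = prompt + "zsh is a shell.\nzsh is a shell.\nzsh is a shell.\nzsh is"
print(repeated_lines(generated, prompt))
```

In the notebook this could be called as `repeated_lines(gemma_lm.generate(test_prompt, max_length=200), test_prompt)`.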

## Finetune

Perform finetuning using [Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA). LoRA is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the full weights of the model and inserting a smaller number of new trainable weights into the model. In essence, LoRA keeps each large weight matrix frozen and learns an update to it expressed as the product of two much smaller low-rank matrices, A and B, which makes training much faster and more memory-efficient.
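To get a feel for the savings, consider a single weight matrix of shape d x k adapted with rank r: the frozen matrix has d * k entries, while the two trainable factors A (d x r) and B (r x k) together have only r * (d + k). The sizes below are illustrative and the helper `lora_trainable_params` is hypothetical; the actual set of adapted layers is decided by the library.

```python
def lora_trainable_params(d, k, rank):
    """Trainable parameters when a d x k weight gets factors A (d x r), B (r x k)."""
    return rank * (d + k)


d, k, rank = 2048, 2048, 4  # illustrative sizes
full = d * k
lora = lora_trainable_params(d, k, rank)
print(f"full: {full:,}  lora: {lora:,}  reduction: {full // lora}x")
```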

In [14]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)

In [15]:
# Fine-tune on the HC3-Chinese baike dataset loaded above.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(train_data, epochs=5, batch_size=32)

Epoch 1/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 138s 800ms/step - loss: 2.9890 - sparse_categorical_accuracy: 0.4374
Epoch 2/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 91s 629ms/step - loss: 1.9191 - sparse_categorical_accuracy: 0.6265
Epoch 3/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 91s 630ms/step - loss: 1.8129 - sparse_categorical_accuracy: 0.6410
Epoch 4/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 91s 629ms/step - loss: 1.7858 - sparse_categorical_accuracy: 0.6431
Epoch 5/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 91s 629ms/step - loss: 1.7715 - sparse_categorical_accuracy: 0.6446

<keras.src.callbacks.history.History at 0x78d960646b90>

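Note that enabling LoRA reduces the number of trainable parameters significantly. For Gemma 2B with rank 4, the count drops from about 2.5 billion to about 1.3 million; the `gemma_lm.summary()` output reports the exact numbers.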

In total, finetuning took less than 10 minutes.
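The step count and the total time can be cross-checked against the training log with a little arithmetic. This sketch uses only numbers that appear above: 4600 training examples, a batch size of 32, and the per-epoch wall-clock times.

```python
import math

num_examples = 4600
batch_size = 32
# The last batch is partial, so round up.
steps_per_epoch = math.ceil(num_examples / batch_size)

epoch_seconds = [138, 91, 91, 91, 91]  # per-epoch times from the log above
total_minutes = sum(epoch_seconds) / 60

print(f"{steps_per_epoch} steps per epoch, {total_minutes:.1f} minutes in total")
```

The first epoch is slower than the rest, which is typical when it includes compilation time.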

## Inference after finetuning

In [16]:
gemma_lm.generate(test_prompt, max_length=200)

'你是一个知识丰富的人工智能助手，用户将用中文向你提问，你将根据你的知识用中文来如实回答问题。\n问题：\n我有一个信息科学相关的问题，请用中文回答，什么是 zsh\n答案：\nzsh（Z Shell）是一种UNIX/Linux操作系统中的一个内部命令。 \nzsh是一种高效率的交互式终端用户命令语言。它在命令行中提供了一个类似于Unix Bourne或Shell（BASH）的shell。 \nzsh的优点是：\n支持命令行历史记录（command history）'

The finetuned model now gives a much better answer in Chinese than the pretrained variant.

## Convert to Hugging Face

Many people prefer to work with Hugging Face Transformers rather than Keras. Converting the finetuned model is straightforward.

In [17]:
# Finetuned model
FINETUNED_MODEL_DIR = "./finetuned_gemma"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/model.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"

# Converted model
HUGGINGFACE_MODEL_DIR = "./gemma_huggingface"

MODEL_NAME = "gemma_2b_en"

# Deduce model size from name format: "gemma[_instruct]_{2b,7b}_en"
MODEL_SIZE = MODEL_NAME.split("_")[-2]
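The size lookup relies on the size token always being the second-to-last piece of the preset name. The sketch below runs the same split over the four preset names that follow the `gemma[_instruct]_{2b,7b}_en` pattern and adds a guard for names that do not fit. `model_size_from_name` is a hypothetical helper, not part of KerasNLP.

```python
def model_size_from_name(name):
    """Extract the size token from names like 'gemma[_instruct]_{2b,7b}_en'."""
    size = name.split("_")[-2]
    if size not in {"2b", "7b"}:
        raise ValueError(f"unexpected model name: {name!r}")
    return size


for name in ["gemma_2b_en", "gemma_instruct_2b_en",
             "gemma_7b_en", "gemma_instruct_7b_en"]:
    print(name, "->", model_size_from_name(name))
```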

In [18]:
# Make sure the directory exists
%mkdir -p $FINETUNED_MODEL_DIR

gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)

gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)

In [19]:
!du -shc $FINETUNED_MODEL_DIR/*

9.4G	./finetuned_gemma/model.weights.h5
4.1M	./finetuned_gemma/vocabulary.spm
9.4G	total


In [20]:
# Download the conversion script from KerasNLP tools
!wget -nv -nc https://raw.githubusercontent.com/keras-team/keras-nlp/master/tools/gemma/export_gemma_to_hf.py

# Run the conversion script
# Note: it uses the PyTorch backend of Keras (hence the KERAS_BACKEND env variable)
!KERAS_BACKEND=torch python export_gemma_to_hf.py \
    --weights_file $FINETUNED_WEIGHTS_PATH \
    --size $MODEL_SIZE \
    --vocab_path $FINETUNED_VOCAB_PATH \
    --output_dir $HUGGINGFACE_MODEL_DIR

2024-02-25 06:31:03 URL:https://raw.githubusercontent.com/keras-team/keras-nlp/master/tools/gemma/export_gemma_to_hf.py [11761/11761] -> "export_gemma_to_hf.py" [1]

-> Loading Keras weights from file `./finetuned_gemma/model.weights.h5`...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
  trackable.load_own_variables(weights_store.get(inner_path))

-> Loading HuggingFace Gemma `2B` model...

✅ Model loading complete.

-> Converting weights from Ker

In [21]:
import transformers
model = transformers.GemmaForCausalLM.from_pretrained(
    HUGGINGFACE_MODEL_DIR,
    local_files_only=True,
    device_map="auto",  # Library "accelerate" to auto-select GPU
)
tokenizer = transformers.GemmaTokenizer.from_pretrained(
    HUGGINGFACE_MODEL_DIR,
    local_files_only=True,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.69it/s]


In [22]:
def test_transformers_model(
    model: transformers.GemmaForCausalLM,
    tokenizer: transformers.GemmaTokenizer,
) -> None:
    # Tokenize the test prompt and move the tensors to the model's device.
    inputs = tokenizer([test_prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=200)

    # Decode the first (and only) sequence and print it with a separator line.
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"{output}\n{'- '*40}")

# This runs on CPU, so it is a bit slow.
test_transformers_model(model, tokenizer)

你是一个知识丰富的人工智能助手，用户将用中文向你提问，你将根据你的知识用中文来如实回答问题。
问题：
我有一个信息科学相关的问题，请用中文回答，什么是 zsh
答案：
zsh（Z Shell）是一个POSIX兼容的shell，它在BSD/OS和Linux系统上被广泛使用。 
zsh是Z shell的缩写，Z shell是Unix shell的一种，它继承了Bourne shell的特性，并增加了许多新的特性。 
zsh的特性包括： 
1.支持多级目录 
2.支持命令别名 
3.支持命令补全 
4.支持命令历史 
5.支持命令行参数
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 


This is very similar to the KerasNLP output we had before, which suggests that the Hugging Face conversion worked.

## Final notes

* Here we used Gemma 2B. Technically you can use Gemma 7B, but Kaggle only offers 20 GB of disk space, so you can't easily store the converted Hugging Face files.
* TPU v3 is much faster than the free T4 GPU on Google Colab.