##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.kaggle.com/howecnchen/gemma-kaggle-tpu-only"><img src="https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png" height="32" width="70"/>Run in Kaggle</a>
  </td>
</table>

**The original notebook is [here](https://www.kaggle.com/windmaple/gemma-kaggle-tpu-only).**

# Gemma instruction tuning on Kaggle TPU using Chinese dataset

This notebook is adapted from the official [Gemma distributed tuning tutorial](https://ai.google.dev/gemma/docs/distributed_tuning) and [Gemma Vertex AI tutorial](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb). It demonstrates how to instruction-tune Gemma 2B (the pretrained, non-instruction-tuned variant) on a Kaggle TPU so that the finetuned model follows Chinese instructions better.

This notebook is also available directly on [Kaggle](https://www.kaggle.com/windmaple/gemma-kaggle-tpu-only).

## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Google Gemini models. Gemma can be further finetuned to suit specific needs. However, large language models such as Gemma can be very large, and some of them may not fit on a single accelerator for finetuning. In this case, there are two general approaches to finetuning them:
1. Parameter Efficient Fine-Tuning (PEFT), which seeks to shrink the effective model size by sacrificing some fidelity. LoRA falls in this category and the [Finetune Gemma models in Keras using LoRA](https://ai.google.dev/gemma/docs/lora_tuning) tutorial demonstrates how to finetune the Gemma 2B model `gemma_2b_en` with LoRA using KerasNLP on a single GPU.
2. Full parameter finetuning with model parallelism. Model parallelism distributes a single model's weights across multiple devices and enables horizontal scaling. You can find out more about distributed training in this [Keras guide](https://keras.io/guides/distribution/).

This tutorial walks you through using Keras with a JAX backend to finetune the Gemma 2B model with LoRA and model-parallel distributed training on Google's Tensor Processing Units (TPUs). Note that LoRA can be turned off in this tutorial for slower but more accurate full-parameter tuning.

## Using accelerators

Technically you can use either TPU or GPU for this tutorial.

### Notes on TPU environments

Google has 3 products that provide TPUs:
* [Colab](https://colab.sandbox.google.com/) provides TPU v2, which is not sufficient for this tutorial.
* [Kaggle](https://www.kaggle.com/) offers TPU v3 for free, which works for this tutorial.
* [Cloud TPU](https://cloud.google.com/tpu?hl=en) offers TPU v3 and newer generations. One way to set it up is:
  1. Create a new [TPU VM](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms)
  2. Set up [SSH port forwarding](https://cloud.google.com/solutions/connecting-securely#port-forwarding-over-ssh) for your intended Jupyter server port
  3. Install Jupyter and start it on the TPU VM, then connect to it from Colab through "Connect to a local runtime"

### Notes on multi-GPU setup

Although this tutorial focuses on the TPU use case, you can easily adapt it for your own needs if you have a multi-GPU machine.

If you prefer to work through Colab, it's also possible to provision a multi-GPU VM for Colab directly through "Connect to a custom GCE VM" in the Colab Connect menu.


We will focus on using the **free TPU from Kaggle** here.

## Before you begin

### Gemma setup

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).

Gemma models are hosted on Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


## Installation

Install Keras and KerasNLP with the Gemma model.

In [3]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q tensorflow-cpu
!pip install -q -U keras-nlp tensorflow-hub
# Quote the version specifier: unquoted, the shell treats ">" as an output redirect.
!pip install -q -U "keras>=3"
!pip install -qU transformers
!pip install -U sentencepiece

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
keras-nlp 0.8.1 requires keras-core, which is not installed.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have

### Set up Keras JAX backend

Import JAX and run a sanity check on TPU. Kaggle offers TPUv3-8 devices which have 8 TPU cores with 16GB of memory each.

In [4]:
import jax
# The next cell explains why we need JAX
jax.devices()

E0307 02:44:18.585609393     257 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:"2024-03-07T02:44:18.585591655+00:00"}


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [5]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

## Load model

In [6]:
import keras
import keras_nlp

### Notes on mixed precision training on NVIDIA GPUs

When training on NVIDIA GPUs, mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) can be used to speed up training with minimal effect on training quality. In most cases, it is recommended to turn on mixed precision, as it saves both memory and time. However, be aware that at small batch sizes it can inflate memory usage by 1.5x (weights will be loaded twice, at half precision and full precision).

For inference, half precision (`keras.config.set_floatx("bfloat16")`) works and saves memory, whereas mixed precision is not applicable.
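Taking the explanation above at face value (one float32 copy of the weights plus one bfloat16 copy), the 1.5x figure is simple byte arithmetic. The sketch below assumes a round 2.5 billion parameters as a stand-in for a Gemma-2B-sized model and counts weights only, ignoring activations, gradients, and optimizer state; `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(num_params, bytes_per_copy):
    """Memory needed to hold `num_params` weights once per entry in `bytes_per_copy`."""
    return num_params * sum(bytes_per_copy) / 1024**3


NUM_PARAMS = 2.5e9  # assumption: a round number in the range of Gemma 2B

full_precision = weight_memory_gib(NUM_PARAMS, [4])      # float32 weights only
mixed_precision = weight_memory_gib(NUM_PARAMS, [4, 2])  # float32 + bfloat16 copies
half_precision = weight_memory_gib(NUM_PARAMS, [2])      # bfloat16 only (inference)

print(f"float32 only:   {full_precision:.1f} GiB")
print(f"float32 + bf16: {mixed_precision:.1f} GiB")
print(f"bfloat16 only:  {half_precision:.1f} GiB")
print(round(mixed_precision / full_precision, 3))  # 1.5
```

The ratio does not depend on the parameter count, which is why the overhead is quoted as a multiplier rather than an absolute size.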

In [None]:
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

To load the model with the weights and tensors distributed across TPUs, first create a new `DeviceMesh`. `DeviceMesh` represents a collection of hardware devices configured for distributed computation and was introduced in Keras 3 as part of the unified distribution API.

The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new [Keras 3 distribution API guide](https://keras.io/guides/distribution/).

In [7]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPU cores.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

`LayoutMap` from the distribution API specifies how the weights and tensors should be sharded or replicated, using string keys (for example, `token_embedding/embeddings` below) that are treated like regular expressions to match tensor paths. Matched tensors are sharded along the `model` mesh dimension (8 TPU cores); all others are fully replicated.

In [8]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (None, model_dim)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, model_dim, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
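To see which weights each `layout_map` key above will catch, the sketch below mimics the lookup with Python's `re.search` on weight paths of the kind printed later in this notebook. It is an approximation written for illustration, assuming Keras 3 tries an exact key match first, then regex matching, and rejects a path that matches more than one key; `LAYOUT_RULES` and `lookup_layout` are hypothetical names, not part of the Keras API.

```python
import re

# Hypothetical copy of the layout_map rules defined above (key -> sharding spec).
LAYOUT_RULES = {
    "token_embedding/embeddings": (None, "model"),
    "decoder_block.*attention.*(query|key|value).*kernel": (None, "model", None),
    "decoder_block.*attention_output.*kernel": (None, None, "model"),
    "decoder_block.*ffw_gating.*kernel": ("model", None),
    "decoder_block.*ffw_linear.*kernel": (None, "model"),
}


def lookup_layout(path):
    """Return the spec for `path`, or None if no key matches (weight stays replicated)."""
    if path in LAYOUT_RULES:  # an exact key match is tried first
        return LAYOUT_RULES[path]
    matches = [key for key in LAYOUT_RULES if re.search(key, path)]
    if len(matches) > 1:
        raise ValueError(f"{path!r} matches several layout keys: {matches}")
    return LAYOUT_RULES[matches[0]] if matches else None


for path in (
    "decoder_block_1/pre_attention_norm/scale",
    "decoder_block_1/attention/query/kernel",
    "decoder_block_1/attention/attention_output/kernel",
    "decoder_block_1/ffw_gating_2/kernel",
):
    print(f"{path:<52} -> {lookup_layout(path)}")
```

Note that `ffw_gating.*kernel` also matches `ffw_gating_2/kernel`, because regex matching is by substring search rather than by exact name.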

`ModelParallel` allows you to shard model weights or activation tensors across all devices on the `DeviceMesh`. In this case, some of the Gemma 2B model weights are sharded across the 8 TPU cores according to the `layout_map` defined above. Now load the model in a distributed way.

In [9]:
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)

# Use the `GemmaCausalLM` class from the `keras_nlp.models` module:
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/gemma_causal_lm.py
# Create a `gemma_lm` object by loading the "gemma_2b_en" preset.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Now verify that the model has been partitioned correctly. Let's take `decoder_block_1` as an example.

In [10]:
# 'decoder_block_1' is the second Transformer decoder block in the backbone
# (blocks are numbered from 0: decoder_block_0, decoder_block_1, ...).
# Gemma is a decoder-only model, so there is no encoder: each block transforms
# the hidden states produced by the block before it.
# Numbered names like this are a common convention for telling blocks apart.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
# Iterate over every weight variable in decoder_block_1 and print:
# * variable.path: where the variable lives in the model.
# * str(variable.shape): the variable's dimensions.
# * str(variable.value.sharding.spec): its PartitionSpec, i.e. how the
#   variable's data is split across devices (here, the 8 TPU cores).
#
# decoder_block_1.weights contains all of the block's weights: those matched
# by a layout_map key are sharded along the 'model' axis, and the rest are
# fully replicated on every device.
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')

<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (2048,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (8, 2048, 256)    PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (1, 2048, 256)    PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (1, 2048, 256)    PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (8, 256, 2048)    PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (2048,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (2048, 16384)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (2048, 16384)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (
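The PartitionSpecs above translate directly into how much of each tensor a single TPU core holds. The sketch below computes the per-device slice shape for the `(1, 8)` mesh used here; `per_device_shape` is a hypothetical helper, and it assumes every sharded dimension divides evenly by the mesh axis size, which holds for the shapes shown.

```python
# The (1, 8) DeviceMesh built earlier: axis name -> number of devices along it.
MESH_AXES = {"batch": 1, "model": 8}


def per_device_shape(shape, spec, mesh_axes=MESH_AXES):
    """Shape of the slice each device holds, given a PartitionSpec-like tuple."""
    if len(spec) != len(shape):
        raise ValueError("spec needs exactly one entry per tensor dimension")
    local = []
    for dim, axis in zip(shape, spec):
        if axis is None:
            local.append(dim)  # not split: every device holds the full dimension
            continue
        n = mesh_axes[axis]
        if dim % n:
            raise ValueError(f"dim {dim} is not divisible by mesh axis {axis!r} of size {n}")
        local.append(dim // n)  # split evenly across the devices on this axis
    return tuple(local)


print(per_device_shape((8, 2048, 256), (None, "model", None)))  # (8, 256, 256)
print(per_device_shape((2048, 16384), ("model", None)))         # (256, 16384)
print(per_device_shape((2048,), (None,)))                       # (2048,), i.e. replicated
```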

## Load instruction dataset

In [11]:
!wget -O baike.jsonl https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese/raw/main/baike.jsonl

--2024-03-07 02:45:21--  https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese/raw/main/baike.jsonl
Resolving huggingface.co (huggingface.co)... 65.8.243.90, 65.8.243.46, 65.8.243.92, ...
Connecting to huggingface.co (huggingface.co)|65.8.243.90|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5005244 (4.8M) [text/plain]
Saving to: ‘baike.jsonl’


2024-03-07 02:45:24 (9.56 MB/s) - ‘baike.jsonl’ saved [5005244/5005244]



In [15]:
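import json

# Prompt preamble (kept in Chinese, since the model is prompted in Chinese). It reads:
# "You are a knowledgeable AI assistant. The user will ask you questions in Chinese,
#  and you will answer them truthfully in Chinese based on your knowledge."
context = "你是一个知识丰富的人工智能助手，用户将用中文向你提问，你将根据你的知识用中文来如实回答问题。\n"
# "问题" means "Question" and "答案" means "Answer"; each example uses the first human-written answer.
template = context + "问题：\n{question}\n答案：\n{human_answers[0]}"
data = []
with open("baike.jsonl", encoding="utf-8") as file:
    for line in file:
        features = json.loads(line)
        data.append(template.format(**features))

# Manually construct a test case.
# We already made sure the finetuning dataset contains nothing about zsh.
# The question reads: "I have a question related to information science; please
# answer in Chinese: what is zsh?"
test_prompt = context + "问题：\n我有一个信息科学相关的问题，请用中文回答，什么是 zsh\n答案：\n"
# The baike split has 4616 examples in total; use the first 4600 for training.
train_data = data[:4600]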
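The formatting step above relies on `str.format` accepting an index inside a placeholder (`{human_answers[0]}`), so each JSON record's keys become format arguments. The sketch below isolates that step as a hypothetical `format_example` helper and adds one safeguard the original loop does not have: records with an empty or missing `human_answers` list are skipped instead of raising an error. The sample record is made up.

```python
import json
from typing import Optional

TEMPLATE = "问题：\n{question}\n答案：\n{human_answers[0]}"


def format_example(line: str, context: str = "") -> Optional[str]:
    """Turn one JSONL record into a training string, or None if it has no human answer."""
    record = json.loads(line)
    if not record.get("human_answers"):
        return None  # skip records without a human-written answer
    # Extra keys in the record (e.g. other answer fields) are ignored by format().
    return context + TEMPLATE.format(**record)


# A made-up record with the two fields the template uses.
sample_line = json.dumps(
    {"question": "什么是 Python", "human_answers": ["Python 是一种编程语言。"]},
    ensure_ascii=False,
)
print(format_example(sample_line, context="CONTEXT\n"))
```

Prepending `context` after formatting, rather than baking it into the template, means any braces in the preamble would never be read as placeholders.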

## Inference before finetuning

In [13]:
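# Call `generate` on `gemma_lm`, using `test_prompt` as the prompt.
# `max_length=200` caps the total sequence length (prompt plus generated text)
# at 200 tokens; it is not a word count.
# The model tries to continue the text plausibly based on its training data.
gemma_lm.generate(test_prompt, max_length=200)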

'你是一个知识丰富的人工智能助手，用户将用中文向你提问，你将根据你的知识用中文来如实回答问题。\n问题：\n我有一个信息科学相关的问题，请用中文回答，什么是 zsh\n答案：\nzsh 是一个命令行界面（CLI）的 shell，它支持许多命令行工具，包括 bash， fish， ksh， mksh， pdksh， tcsh， zsh， 和 yash。\nzsh 是一个命令行界面（CLI）的 shell，它支持许多命令行工具，包括 bash， fish， ksh， mksh， pdksh， tcsh， zsh， 和 yash。\nzsh 是一个命令行界面（CLI）的 shell，它支持许多命令行工具，包括 bash， fish， ksh， mksh， pdksh， tcsh， zsh， 和 yash。\nzsh 是一个命令行界面（CLI）的 shell，它支持'

Before finetuning, the model produces one plausible sentence about zsh and then repeats it verbatim until it reaches `max_length`, which is not good.
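One crude way to put a number on this kind of repetition is the share of non-empty lines that duplicate an earlier line. The helper below is a hypothetical illustration, not a standard metric; applied to a `generate` output it also counts the prompt's lines, so it is most useful for comparing outputs for the same prompt, e.g. before and after finetuning.

```python
def duplicate_line_ratio(text: str) -> float:
    """Fraction of non-empty lines that repeat an earlier line (0.0 means no repeats)."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if not lines:
        return 0.0
    seen = set()
    repeats = 0
    for line in lines:
        if line in seen:
            repeats += 1
        else:
            seen.add(line)
    return repeats / len(lines)


print(duplicate_line_ratio("zsh is a shell.\nzsh is a shell.\nzsh is a shell.\n"))
```

In the notebook it could be called as `duplicate_line_ratio(gemma_lm.generate(test_prompt, max_length=200))`.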

In [14]:
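# Save the model config (architecture and compile settings; `to_json` does not
# include the weights) so it can be compared after finetuning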
model_json = gemma_lm.to_json()
with open("before-model.json", "w") as f:
    f.write(model_json)

## Finetune

Perform finetuning using [Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA). LoRA is a finetuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the full weights of the model and inserting a small number of new trainable weights into it. Concretely, LoRA represents the update to each large weight matrix as the product of two much smaller low-rank matrices, A×B, and trains only those, which makes training much faster and more memory-efficient.
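A minimal NumPy sketch of the idea: the pretrained matrix `W` stays frozen and only the two low-rank factors are trained. Following the LoRA paper, one factor is zero-initialized so the adapted layer initially behaves exactly like the original. The names, sizes, and the row-vector `x @ W` convention are choices made for this illustration and do not mirror KerasNLP internals.

```python
import numpy as np


def lora_forward(x, w_frozen, lora_a, lora_b, scale=1.0):
    """Compute x @ (W + scale * A @ B) without materializing the d_in x d_out update."""
    return x @ w_frozen + scale * (x @ lora_a) @ lora_b


rng = np.random.default_rng(0)
d_in, d_out, rank = 64, 32, 4

w_frozen = rng.normal(size=(d_in, d_out))           # pretrained weight, never updated
lora_a = rng.normal(scale=0.01, size=(d_in, rank))  # trainable, random init
lora_b = np.zeros((rank, d_out))                    # trainable, zero init
x = rng.normal(size=(2, d_in))                      # a batch of 2 input vectors

y = lora_forward(x, w_frozen, lora_a, lora_b)
print(np.allclose(y, x @ w_frozen))  # True: with lora_b = 0 the layer is unchanged

full_params = d_in * d_out           # trainable params if W itself were finetuned
lora_params = rank * (d_in + d_out)  # trainable params with LoRA
print(full_params, lora_params)      # 2048 384
```

Because the learned update is just `lora_a @ lora_b`, it can be added into `W` after training, so inference does not need to keep the extra matrices around.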

In [17]:
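# Enable LoRA for the model and set the LoRA rank to 4.
# This enables Low-Rank Adaptation (LoRA) on the Gemma backbone with rank 4.
# `gemma_lm.backbone` is the Transformer backbone of the Gemma model, i.e. the
# stack of embedding and decoder layers that processes the input text.
# `enable_lora(rank=4)` freezes the backbone weights and adds trainable low-rank
# adapters to (some of) its layers.
#  * `rank=4` sets the inner dimension of the two low-rank factors that
#    approximate the weight update. A lower rank means fewer trainable
#    parameters but less capacity, which may reduce model quality.
gemma_lm.backbone.enable_lora(rank=4)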

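**A brief introduction to LoRA**:

LoRA is a parameter-efficient finetuning technique. Instead of updating a full weight matrix, it freezes the pretrained weights and learns an additive update expressed as the product of two low-rank matrices. This drastically reduces the number of *trainable* parameters, which helps in two ways:

* Lower training memory: gradients and optimizer state are only kept for the small LoRA matrices. The frozen base weights must still be loaded, so LoRA does not by itself shrink the model needed for inference.
* Faster training: with far fewer parameters to update, training can be faster.

**What the code does**:

By enabling LoRA with rank 4, the code freezes the weight matrices in the Gemma backbone and attaches trainable rank-4 adapters to selected layers, which can lead to:

* Lower memory usage during training.
* Faster training (although this depends on various factors).

**Important note**:

LoRA also introduces trade-offs and may affect model accuracy compared with full-parameter finetuning. It is therefore essential to evaluate the model after applying LoRA to make sure it meets your accuracy requirements.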

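**Choosing the LoRA rank**:

There is no hard rule; the rank usually has to be tuned for the specific task and model. Factors that influence the choice include:

**1. Task complexity**:

* For simple tasks, the rank can usually be kept low, e.g. 1 or 2.
* For complex tasks, a higher rank such as 4 or 8 may be needed for better accuracy.

**2. Model size**:

* For smaller models, a lower rank is often sufficient.
* Larger models may need a higher rank to capture more information.

**3. Accuracy requirements**:

* If accuracy requirements are modest, a lower rank can be used.
* If higher accuracy is needed, a higher rank may be required.

**4. Compute resources**:

* A higher rank requires more compute for training (and for inference, if the adapters are not merged into the base weights).
* Balance accuracy against the compute you have available.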
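Since the trainable parameter count of a LoRA adapter is `rank * (d_in + d_out)`, it grows linearly with the rank. The sketch below tabulates this for a single `2048 x 16384` matrix, the `ffw_gating/kernel` shape printed earlier; the shape is borrowed purely as an example and is not necessarily one of the layers `enable_lora` targets. `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters in the two LoRA factors: (d_in x rank) and (rank x d_out)."""
    return rank * (d_in + d_out)


D_IN, D_OUT = 2048, 16384  # example shape, borrowed from ffw_gating/kernel above
FULL = D_IN * D_OUT        # parameters in the full matrix

for rank in (1, 2, 4, 8, 16, 32, 64):
    n = lora_param_count(D_IN, D_OUT, rank)
    print(f"rank {rank:>2}: {n:>10,} trainable params ({n / FULL:.2%} of the full matrix)")
```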

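**Rules of thumb**:

* Start with a low rank and increase it as needed.
* Use a held-out validation set (or cross-validation) to find the best rank.
* Draw on experience from similar tasks.

**Some specific starting points** (rough heuristics, not established results):

* For most natural language processing tasks, rank 4 or 8 is a good starting point.
* For image classification tasks, rank 16 or 32 may be necessary.
* For speech recognition tasks, rank 64 or higher may be needed.
* Keep in mind that choosing the LoRA rank is an empirical process that has to be adjusted to the situation.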


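**Summary**:

There is no fixed rule for choosing the LoRA rank; adjust it based on task complexity, model size, accuracy requirements, and compute resources. Start with a low rank and increase it as needed, and draw on experience from similar tasks.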

In [18]:
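# Next, finetune the gemma_lm model on the task with a fairly standard training setup:
#
# * AdamW is a common optimizer for Transformer-based language models.
# * The sequence length is reduced to save memory.
# * `weighted_metrics` is used so that the padding mask (which the preprocessor
#   returns as sample weights) is applied when computing accuracy.

# Fine-tune on the HC3-Chinese `baike` instruction data loaded above.

# Limit the input sequence length to 128 tokens to control memory usage.
# Longer examples are truncated to this length. This is mainly for memory
# efficiency, which matters with large language models such as Gemma.
gemma_lm.preprocessor.sequence_length = 128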

# Use AdamW (a common optimizer for transformer models).
# AdamW is a variant of Adam with decoupled weight decay; weight decay is a
# regularization technique that helps prevent overfitting.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
# Variables whose names contain "bias" or "scale" are excluded from weight decay.
# In Gemma, "scale" belongs to the normalization layers (RMSNorm), such as
# pre_attention_norm/scale and pre_ffw_norm/scale. Pulling these values toward
# zero is generally not useful and can hurt training, so they are commonly excluded.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

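# Compile the model, specifying:
# * The loss: SparseCategoricalCrossentropy on the model's logits. Causal language
#   modeling predicts the next token at every position, which is a multi-class
#   classification over the vocabulary with integer token IDs as labels.
# * The optimizer: the AdamW optimizer configured above.
# * The metric to track: SparseCategoricalAccuracy, i.e. next-token prediction accuracy.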
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
# Start training on the prepared training data.
# * epochs=5: the number of complete passes over the training set.
# * batch_size=32: the number of examples per gradient update, so the model is
#   updated after every 32 training examples.
gemma_lm.fit(train_data, epochs=5, batch_size=32)

Epoch 1/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 140s 834ms/step - loss: 2.9920 - sparse_categorical_accuracy: 0.4369
Epoch 2/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 90s 623ms/step - loss: 1.9053 - sparse_categorical_accuracy: 0.6308
Epoch 3/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 89s 620ms/step - loss: 1.8146 - sparse_categorical_accuracy: 0.6418
Epoch 4/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 89s 620ms/step - loss: 1.7925 - sparse_categorical_accuracy: 0.6428
Epoch 5/5
144/144 ━━━━━━━━━━━━━━━━━━━━ 89s 620ms/step - loss: 1.7777 - sparse_categorical_accuracy: 0.6442


<keras.src.callbacks.history.History at 0x7cf9b4667190>
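To make the optimizer settings in the finetuning cell concrete, here is a textbook-style sketch of one AdamW step on a single scalar, plus the name-based check that `exclude_from_weight_decay(var_names=...)` performs (per the Keras docs, a variable is excluded if any of the given strings appears in its name). The helper names are hypothetical, and Keras' internal update differs in small details such as where `epsilon` enters, so treat this as an illustration of decoupled weight decay rather than a reproduction of Keras.

```python
import math


def adamw_step(param, grad, m, v, t, lr=5e-5, weight_decay=0.01,
               beta_1=0.9, beta_2=0.999, epsilon=1e-7, apply_decay=True):
    """One AdamW update for a scalar parameter at step t (t starts at 1)."""
    if apply_decay:
        # Decoupled weight decay: shrink the weight directly, independent of the gradient.
        param -= lr * weight_decay * param
    m = beta_1 * m + (1 - beta_1) * grad          # first-moment estimate
    v = beta_2 * v + (1 - beta_2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta_1 ** t)                 # bias corrections
    v_hat = v / (1 - beta_2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + epsilon)
    return param, m, v


def uses_weight_decay(var_name, excluded=("bias", "scale")):
    """Decay a variable unless any excluded string appears in its name."""
    return not any(s in var_name for s in excluded)


for name in ("decoder_block_1/pre_ffw_norm/scale", "decoder_block_1/ffw_linear/kernel"):
    print(name, "-> decayed" if uses_weight_decay(name) else "-> not decayed")

# With a zero gradient, only the weight-decay term moves the parameter.
new_param, _, _ = adamw_step(1.0, grad=0.0, m=0.0, v=0.0, t=1)
print(new_param)
```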

In total, training took under 10 minutes.
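The numbers in the training log are consistent with the setup: 4600 training examples at `batch_size=32` give 144 steps per epoch (the last batch is partial), and summing the per-epoch times from the log gives the total. The first epoch is noticeably slower, most likely because it includes XLA compilation.

```python
import math

num_train_examples = 4600  # len(train_data)
batch_size = 32
steps_per_epoch = math.ceil(num_train_examples / batch_size)
print(steps_per_epoch)  # 144, matching "144/144" in the log

epoch_seconds = [140, 90, 89, 89, 89]  # per-epoch wall times copied from the log above
print(f"{sum(epoch_seconds) / 60:.1f} minutes")  # 8.3 minutes
```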

## Inference after finetuning

In [19]:
gemma_lm.generate(test_prompt, max_length=200)

'你是一个知识丰富的人工智能助手，用户将用中文向你提问，你将根据你的知识用中文来如实回答问题。\n问题：\n我有一个信息科学相关的问题，请用中文回答，什么是 zsh\n答案：\nzsh（Zsh Shell）是一个命令行用户界面，在Linux和Unix操作系统上运行的shell程序，是Z shell（zsh）的缩写。 \n它是一个基于POSIX规范的命令行用户界面，支持命令行参数和文件参数传递，支持标准命令的自动补齐。'

Now it gives a much better answer in Chinese than the pretrained variant.

In [26]:
# Save the new model config for comparison
model_json = gemma_lm.to_json()
with open("after-model.json", "w") as f:
    f.write(model_json)

In [50]:
!pip install json_tools

Collecting json_tools
  Downloading json_tools-0.4.1.tar.gz (7.2 kB)
  Preparing metadata (setup.py) ... done
Collecting colorama
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Building wheels for collected packages: json_tools
  Building wheel for json_tools (setup.py) ... done
  Created wheel for json_tools: filename=json_tools-0.4.1-py3-none-any.whl size=10287 sha256=1c4b0807f847b3988d9afd705e6b5b4178e394c550ca8fdf7ed93d708662e72e
  Stored in directory: /root/.cache/pip/wheels/9a/38/dd/3e5cdf6112b06d0965456b456a2e8905e1076850910a39b987
Successfully built json_tools
Installing collected packages: colorama, json_tools
Successfully installed colorama-0.4.6 json_tools-0.4.1


In [52]:
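import json
import json_tools

# Compare the model configs saved before and after finetuning
with open("before-model.json", "r") as f:
    before_model_dict = json.load(f)
with open("after-model.json", "r") as f:
    after_model_dict = json.load(f)

diff_value = json_tools.diff(before_model_dict, after_model_dict)
# Print the differences between the model before and after finetuning.
# Since to_json() records configuration rather than weights, the diff shows
# settings changes, not parameter updates. Several of them, notably cutting
# sequence_length from 8192 to 128, trade accuracy for training speed and memory.
# They look like adjustments made for this demo; for production finetuning,
# most of them would likely need to be revisited.
print(json.dumps(diff_value, indent=2))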

[
    {
        "replace": "/config/preprocessor/config/sequence_length",
        "value": 128,
        "prev": 8192
    },
    {
        "replace": "/compile_config/optimizer/class_name",
        "value": "AdamW",
        "prev": "Adam"
    },
    {
        "replace": "/compile_config/optimizer/config/name",
        "value": "adamw",
        "prev": "adam"
    },
    {
        "replace": "/compile_config/optimizer/config/learning_rate",
        "value": 4.999999873689376e-05,
        "prev": 1.9999999494757503e-05
    },
    {
        "replace": "/compile_config/optimizer/config/weight_decay",
        "value": 0.01,
        "details": "type",
        "prev": null
    },
    {
        "replace": "/compile_config/metrics",
        "value": null,
        "details": "type",
        "prev": [
            {
                "module": "keras.metrics",
                "class_name": "SparseCategoricalAccuracy",
                "config": {
                    "name": "sparse_categorical_accuracy

## Convert to Hugging Face

Many people prefer Hugging Face Transformers to Keras, and converting the finetuned model is easy.

In [None]:
# Finetuned model
FINETUNED_MODEL_DIR = "./finetuned_gemma"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/model.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"

# Converted model
HUGGINGFACE_MODEL_DIR = "./gemma_huggingface"

MODEL_NAME = "gemma_2b_en"

# Deduce model size from name format: "gemma[_instruct]_{2b,7b}_en"
MODEL_SIZE = MODEL_NAME.split("_")[-2]

In [None]:
# Make sure the directory exists
%mkdir -p $FINETUNED_MODEL_DIR

gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)

gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)

In [None]:
!du -shc $FINETUNED_MODEL_DIR/*

In [None]:
# Download the conversion script from KerasNLP tools
!wget -nv -nc https://raw.githubusercontent.com/keras-team/keras-nlp/master/tools/gemma/export_gemma_to_hf.py

# Run the conversion script
# Note: it uses the PyTorch backend of Keras (hence the KERAS_BACKEND env variable)
!KERAS_BACKEND=torch python export_gemma_to_hf.py \
    --weights_file $FINETUNED_WEIGHTS_PATH \
    --size $MODEL_SIZE \
    --vocab_path $FINETUNED_VOCAB_PATH \
    --output_dir $HUGGINGFACE_MODEL_DIR

In [None]:
import transformers
model = transformers.GemmaForCausalLM.from_pretrained(
    HUGGINGFACE_MODEL_DIR,
    local_files_only=True,
    device_map="auto",  # Requires the "accelerate" library; uses a GPU if one is available
)
tokenizer = transformers.GemmaTokenizer.from_pretrained(
    HUGGINGFACE_MODEL_DIR,
    local_files_only=True,
)

In [None]:
def test_transformers_model(
    model: transformers.GemmaForCausalLM,
    tokenizer: transformers.GemmaTokenizer,
) -> None:   
    inputs = tokenizer([test_prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=200)

    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"{output}\n{'- '*40}")

# This runs on CPU, so it is a bit slow
test_transformers_model(model, tokenizer)

This is very similar to the KerasNLP output we had before, so I think our HF conversion worked.

# Final note

* Here we used Gemma 2B. Technically you can use Gemma 7B, but sadly Kaggle only offers 20 GB of disk space, so you can't easily store the converted HF files.
* The instruction-tuned variant of Gemma 2B, which we didn't use, does have some basic ability to follow Chinese instructions, but the technique used here can further enhance it.
* TPU v3 is so much faster than the free T4 GPU on Google Colab. 
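A rough back-of-the-envelope check of the disk-space point: assuming approximate parameter counts of 2.5 billion for Gemma 2B and 8.5 billion for Gemma 7B, and 4 bytes (float32) or 2 bytes (bfloat16) per stored parameter, the sketch below estimates weights-only checkpoint sizes. The dtypes actually written by `save_weights` and by the conversion script are assumptions here, as is ignoring file-format overhead; `checkpoint_gib` is a hypothetical helper.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}
APPROX_PARAMS = {"Gemma 2B": 2.5e9, "Gemma 7B": 8.5e9}  # assumption: rounded parameter counts


def checkpoint_gib(num_params: float, dtype: str) -> float:
    """Approximate size of a weights-only checkpoint in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


for model, n in APPROX_PARAMS.items():
    sizes = ", ".join(f"{dtype}: {checkpoint_gib(n, dtype):.1f} GiB" for dtype in BYTES_PER_PARAM)
    print(f"{model}: {sizes}")
```

Under these assumptions, a float32 Gemma 7B checkpoint alone would exceed 20 GB, and even in bfloat16, keeping both the Keras weights and a converted copy would not fit.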
