##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tuning Gemma for Function Calling

Welcome to this step-by-step guide on fine-tuning [Gemma](https://huggingface.co/google/gemma-2b) for function calling.


[**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.

**Function calling fine-tuning** teaches an LLM to emit structured function calls. The model is trained on a dataset of prompts paired with the corresponding function calls, so that it learns to recognize the intent behind a prompt and to select the appropriate function, with suitable arguments, for the task.
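To make the idea concrete, here is a minimal sketch of what an application might do with a function call emitted by the model. It assumes the model produces a well-formed JSON payload after a `<functioncall>` tag; the `calculate_median` implementation, the `TOOLS` registry, and the `dispatch` helper are hypothetical names introduced only for this illustration.

```python
import json


def calculate_median(numbers):
    """Hypothetical tool implementation, used only for this illustration."""
    ordered = sorted(numbers)
    mid = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2


# Registry mapping function names to Python callables (an assumption of this sketch).
TOOLS = {"calculate_median": calculate_median}


def dispatch(model_output):
    """Extract the JSON payload after the <functioncall> tag and run the named tool."""
    tag = "<functioncall>"
    start = model_output.index(tag) + len(tag)
    payload = model_output[start:].replace("</functioncall>", "").strip()
    call = json.loads(payload)
    return TOOLS[call["name"]](**call["arguments"])


example_output = (
    '<functioncall> {"name": "calculate_median", '
    '"arguments": {"numbers": [5, 2, 9, 1, 7, 4, 6, 3, 8]}}'
)
result = dispatch(example_output)
print(result)
```

In a real application you would also validate the arguments against the function's schema before calling it, since the model's output is not guaranteed to be valid.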

This notebook uses [Torch XLA](https://github.com/pytorch/xla) and Hugging Face's [**Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) framework for function calling fine-tuning.

[**Torch XLA**](https://pytorch.org/xla/) enables you to leverage the computational power of TPUs (Tensor Processing Units) for efficient training of deep learning models. By interfacing PyTorch with the [XLA (Accelerated Linear Algebra)](https://openxla.org/xla) compiler, Torch XLA translates PyTorch operations into XLA operations that can be executed on TPUs. This means you can write your models in PyTorch as usual, and Torch XLA handles the underlying computations to run them efficiently on TPUs.

[**Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) is a framework developed by Hugging Face to fine-tune and align both transformer language and diffusion models using methods such as Supervised Fine-Tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), Direct Preference Optimization (DPO), and others.

To learn more about using Torch XLA and TRL to fine-tune Gemma, see the **Finetune with Torch XLA** notebook in the [Gemma Cookbook](https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/Finetune_with_Torch_XLA.ipynb).

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_Function_Calling.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>
<br><br>

[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/notebooks/welcome?src=https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_Function_Calling.ipynb)

## Setup

### Selecting the Runtime Environment

To start, you can choose either **Google Colab** or **Kaggle** as your platform. Select one, and proceed from there.

- #### **Google Colab** <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png" alt="Google Colab" width="30"/>

  1. Click **Open in Colab**.
  2. In the menu, go to **Runtime** > **Change runtime type**.
  3. Under **Hardware accelerator**, select **TPU**.
  4. Ensure that the **TPU type** is set to **TPU v2-8**.

- #### **Kaggle** <img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png" alt="Kaggle" width="40"/>

  1. Click **Open in Kaggle**.
  2. Click on **Settings** in the right sidebar.
  3. Under **Accelerator**, select **TPUs**.
    - Note: Kaggle currently provides **TPU v3-8**.
  4. Save the settings, and the notebook will restart with TPU support.


### Gemma using Hugging Face

Before diving into the tutorial, let's set up Gemma:

1. **Create a Hugging Face Account**: If you don't have one, you can sign up for a free account [here](https://huggingface.com/join).
2. **Access the Gemma Model**: Visit the [Gemma model page](https://huggingface.com/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage conditions.
3. **Generate a Hugging Face Token**: Go to your Hugging Face [settings page](https://huggingface.com/settings/tokens) and generate a new access token (preferably with `write` permissions). You'll need this token later in the tutorial.

**Once you've completed these steps, you're ready to move on to the next section, where you'll set up environment variables in your Colab or Kaggle environment.**

### Configure Your Credentials

To access private models and datasets, you need to log in to the Hugging Face (HF) ecosystem.

- #### **Google Colab** <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png" alt="Google Colab" width="30"/>
  If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets manager:
  1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
  2. **Add Hugging Face Token**:
    - Create a new secret with the **name** `HF_TOKEN`.
    - Copy/paste your token key into the **Value** input box of `HF_TOKEN`.
    - **Toggle** the button on the left to allow notebook access to the secret.

- #### **Kaggle** <img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png" alt="Kaggle" width="40"/>
  To securely use your Hugging Face token (`HF_TOKEN`) in this notebook, you'll need to add it as a secret in your Kaggle environment:  
  1. Open your Kaggle notebook and locate the **Add-ons** menu at the top of the notebook interface.
  2. Click on **Secrets** to manage your environment secrets.  
  <img src="https://i.imgur.com/vxrtJuM.png" alt="The Secrets option is found at the top." width=50%>
  3. **Add Hugging Face Token**:
      - Click on the **Add secret** button.
      - In the **Label** field, enter `HF_TOKEN`.  
      - In the **Value** field, paste your Hugging Face token.
      - Click **Save** to add the secret.

This code retrieves your secrets and sets them as environment variables, which you will use later in the tutorial.

In [None]:
import os
import sys

if 'google.colab' in sys.modules:
    # Running on Colab
    from google.colab import userdata
    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')
elif os.path.exists('/kaggle/working'):
    # Running on Kaggle
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    os.environ['HF_TOKEN'] = user_secrets.get_secret("HF_TOKEN")
else:
    # Not running on Colab or Kaggle
    raise EnvironmentError('This notebook is designed to run on Google Colab or Kaggle.')

### Install dependencies

Next, you'll set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model on a TPU VM using Torch XLA.


In [None]:
# Uninstall any existing TensorFlow installations, then install the CPU-only version to avoid conflicts while using the TPU.
!pip uninstall -y tensorflow tf-keras
!pip install tensorflow==2.18.0 tf-keras==2.18.0

!pip uninstall tensorflow -y
!pip install tensorflow-cpu==2.18.0 -q

# Install the appropriate Hugging Face libraries to ensure compatibility with the Gemma model and PEFT.
!pip install transformers==4.46.1 -U -q
!pip install datasets==3.1.0 -U -q
!pip install trl==0.12.0 peft==0.13.2 -U -q
!pip install accelerate==0.34.0 -U -q

# Install PyTorch and Torch XLA with versions compatible with the TPU runtime, ensuring efficient TPU utilization.
!pip install -qq torch~=2.5.0 --index-url https://download.pytorch.org/whl/cpu
!pip install -qq torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html

# Install the `tpu-info` package to display TPU-related information
!pip install tpu-info

Found existing installation: tensorflow 2.15.0
Uninstalling tensorflow-2.15.0:
  Successfully uninstalled tensorflow-2.15.0
Found existing installation: tf_keras 2.15.1
Uninstalling tf_keras-2.15.1:
  Successfully uninstalled tf_keras-2.15.1
Collecting tensorflow==2.18.0
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tf-keras==2.18.0
  Downloading tf_keras-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting tensorboard<2.19,>=2.18 (from tensorflow==2.18.0)
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting keras>=3.5.0 (from tensorflow==2.18.0)
  Downloading keras-3.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting ml-dtypes<0.5.0,>=0.4.0 (from tensorflow==2.18.0)
  Downloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting namex (from keras>=3.5.0->tensorflow==2.18.0)
  Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes

**Note**: Ensure that your PyTorch and Torch XLA versions are compatible with the TPU you're using.

### Verify TPU Setup

Run `!tpu-info` to verify that the TPU has been properly initialized.

In [None]:
!tpu-info

TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━┓
┃ Chip        ┃ Type        ┃ Devices ┃ PID  ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━┩
│ /dev/accel0 │ TPU v2 chip │ 2       │ None │
│ /dev/accel1 │ TPU v2 chip │ 2       │ None │
│ /dev/accel2 │ TPU v2 chip │ 2       │ None │
│ /dev/accel3 │ TPU v2 chip │ 2       │ None │
└─────────────┴─────────────┴─────────┴──────┘
Libtpu metrics unavailable. Is there a framework using the TPU? See https://github.com/google/cloud-accelerator-diagnostics/tree/main/tpu_info for more information


If everything is set up correctly, you should see the TPU details printed out.

## Finetuning Gemma 2 for Function Calling

### Initializing Gemma 2 model

You will initialize `AutoModelForCausalLM` from the `transformers` library by loading a pre-trained Gemma 2 model from Hugging Face. You will also initialize the tokenizer for the selected model (`google/gemma-2-2b-it`) using `AutoTokenizer` from the same library.

In [None]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

# Define model names
model_name = "google/gemma-2-2b-it"
new_model = "gemma-func-ft"

# Load the Gemma pre-trained model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16
)

# You must disable the cache to prevent issues during training
model.config.use_cache = False

# Load the Gemma tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# You adjust the tokenizer's padding side to ensure compatibility during TPU
# training.
tokenizer.padding_side = "right" # Fix overflow issue with bf16/fp16 training

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Enable Single Program Multiple Data (SPMD) mode, which allows for parallel execution across multiple TPU cores.


In [None]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd()

### Load a dataset

For this guide, you'll use an existing dataset from Hugging Face. You can replace it with your dataset if you prefer.

The dataset chosen for this guide is [**lilacai/glaive-function-calling-v2-sharegpt**](https://huggingface.co/datasets/lilacai/glaive-function-calling-v2-sharegpt), which is a ShareGPT version of the original **glaive-function-calling-v2** dataset by glaiveai. The glaive-function-calling-v2 dataset is a collection of roughly 113,000 prompts and corresponding function calls that can be used to fine-tune language models to accurately identify the appropriate function for a given task.

**Credits:** **https://huggingface.co/lilacai**

In [None]:
from datasets import Dataset, load_dataset

# Only the first 15% of the `train` split is used for training. A smaller
# subsection of the dataset is selected to avoid out-of-memory crashes.
dataset = load_dataset("lilacai/glaive-function-calling-v2-sharegpt", split="train[:15%]")

README.md:   0%|          | 0.00/2.51k [00:00<?, ?B/s]

(…)-00000-of-00002-6f3344faa23e9b0a.parquet:   0%|          | 0.00/98.0M [00:00<?, ?B/s]

(…)-00001-of-00002-41f063cddf49c933.parquet:   0%|          | 0.00/98.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/112960 [00:00<?, ? examples/s]

Let's look at a few samples to understand the data.

In [None]:
dataset[10]['conversations']

[{'from': 'system',
  'value': 'You are a helpful assistant with access to the following functions. Use them if required -\n{\n    "name": "calculate_discount",\n    "description": "Calculate the discount amount based on original price and discount percentage",\n    "parameters": {\n        "type": "object",\n        "properties": {\n            "original_price": {\n                "type": "number",\n                "description": "The original price of the item"\n            },\n            "discount_percentage": {\n                "type": "number",\n                "description": "The percentage discount"\n            }\n        },\n        "required": [\n            "original_price",\n            "discount_percentage"\n        ]\n    }\n}\n'},
 {'from': 'human',
  'value': "Hi, I saw a dress that I liked in a store. It was originally priced at $200 but it's on a 20% discount. Can you help me calculate how much I will save?"},
 {'from': 'gpt',
  'value': '<functioncall> {"name": "cal

### Create a custom chat template

Hugging Face supports chat templates, which define the structure and format for converting a conversation into a single tokenizable string, the input format expected by the language model. See the [chat templates documentation](https://huggingface.co/docs/transformers/main/en/chat_templating) to learn more about templates and how to create a custom one.

Since Gemma doesn't support system instructions, you will provide system input as user input. To read more about the format expected by Gemma, check out the [Gemma formatting doc](https://ai.google.dev/gemma/docs/formatting).

In [None]:
# Reference: https://github.com/unslothai/unsloth/blob/main/unsloth/chat_templates.py#L383

chat_template = \
    "{{ bos_token }}"\
    "{% if messages[0]['from'] == 'system' %}"\
        "{{'<start_of_turn>user\n' + messages[0]['value'] | trim + ' ' + messages[1]['value'] | trim + '<end_of_turn>\n'}}"\
        "{% set messages = messages[2:] %}"\
    "{% endif %}"\
    "{% for message in messages %}"\
        "{% if message['from'] == 'human' %}"\
            "{{'<start_of_turn>user\n' + message['value'] | trim + '<end_of_turn>\n'}}"\
        "{% elif message['from'] == 'gpt' %}"\
            "{{'<start_of_turn>model\n' + message['value'] | trim + '<end_of_turn>\n' }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}"\
        "{{ '<start_of_turn>model\n' }}"\
    "{% endif %}"

tokenizer.chat_template = chat_template
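
To see what this template produces without loading the tokenizer, the following plain-Python sketch mirrors its logic step by step. It is an illustration only, not the Jinja template itself: the `render_gemma_prompt` helper and the sample conversation are hypothetical, and `<bos>` is assumed as the beginning-of-sequence token.

```python
def render_gemma_prompt(messages, bos_token="<bos>", add_generation_prompt=False):
    """Plain-Python mirror of the Jinja chat template above (illustrative only)."""
    parts = [bos_token]
    if messages and messages[0]["from"] == "system":
        # Gemma has no system role, so the system text is merged into the first user turn.
        merged = messages[0]["value"].strip() + " " + messages[1]["value"].strip()
        parts.append("<start_of_turn>user\n" + merged + "<end_of_turn>\n")
        messages = messages[2:]
    for message in messages:
        if message["from"] == "human":
            parts.append("<start_of_turn>user\n" + message["value"].strip() + "<end_of_turn>\n")
        elif message["from"] == "gpt":
            parts.append("<start_of_turn>model\n" + message["value"].strip() + "<end_of_turn>\n")
        # Any other role is skipped, because the template has no branch for it.
    if add_generation_prompt:
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


# Hypothetical three-turn conversation in the ShareGPT layout used by the dataset.
sample = [
    {"from": "system", "value": "You can call tools.\n"},
    {"from": "human", "value": " What is 2 + 2? "},
    {"from": "gpt", "value": "4"},
]
rendered = render_gemma_prompt(sample)
print(rendered)
```

Note that turns whose `from` field is neither `human` nor `gpt` do not appear in the rendered text, so it is worth checking which roles your dataset contains before training.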

### Define the formatting function

The formatting function applies the template created above to each row in the dataset and converts it into a format suited for training.

In [None]:
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False,
                      add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True,)

Map:   0%|          | 0/16944 [00:00<?, ? examples/s]

### Clean up the dataset

Remove the leftover `<|endoftext|>` tokens from the formatted text.

In [None]:
import pandas as pd

df_train = pd.DataFrame(dataset)
df_train["text"] = df_train["text"].apply(
    lambda x: x.replace("<|endoftext|>", ""))

pd.set_option('display.max_colwidth', None)
print(df_train.head(1))

                                                chat  \
0  USER: Hi, I have a list of numbers and I need to find the median. The numbers are 5, 2, 9, 1, 7, 4, 6, 3, 8.\n\n\nASSISTANT: <functioncall> {"name": "calculate_median", "arguments": '{"numbers": [5, 2, 9, 1, 7, 4, 6, 3, 8]}'} <|endoftext|>\n\n\nFUNCTION RESPONSE: {"median": 5}\n\n\nASSISTANT: The median of your list of numbers is 5. <|endoftext|>\n\n\n

Convert the dataset back to Hugging Face's `Dataset` format.

In [None]:
dataset = Dataset.from_pandas(df_train[['text']])

dataset

Dataset({
    features: ['text'],
    num_rows: 16944
})

### LoRA configuration

LoRA (Low-Rank Adaptation) introduces small, trainable matrices into the model's architecture, typically targeting the attention layers of Transformer models. Instead of updating the full weight matrices, LoRA adds rank-decomposed matrices, making adaptation more efficient. In this notebook, the adapters are attached to both the attention projections and the MLP projections listed in `target_modules`.

Here, you set the following parameters:
- `r` to 16, which controls the rank of the adaptation matrices.
- `lora_alpha` to 16 for scaling.
- `lora_dropout` to 0, which disables dropout in the LoRA layers.
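
As a rough illustration of what these values imply, the sketch below computes the LoRA scaling factor (`lora_alpha / r`, the default scaling in PEFT) and the number of trainable parameters that a rank-`r` adapter adds to a single weight matrix. The 2048 x 2048 layer shape is a hypothetical example, not Gemma's actual dimension, and the helper names are introduced only for this sketch.

```python
def lora_scaling(lora_alpha, r):
    """Factor applied to the low-rank update B @ A."""
    return lora_alpha / r


def lora_param_count(d_in, d_out, r):
    """Trainable parameters added to one weight matrix of shape (d_out, d_in)."""
    # A has shape (r, d_in) and B has shape (d_out, r).
    return r * d_in + d_out * r


# Hypothetical layer shape, chosen only to make the arithmetic easy to follow.
d_in, d_out = 2048, 2048
r, lora_alpha = 16, 16

scaling = lora_scaling(lora_alpha, r)
adapter_params = lora_param_count(d_in, d_out, r)
full_params = d_in * d_out

print(f"scaling factor: {scaling}")
print(f"adapter parameters: {adapter_params}")
print(f"share of the full matrix: {adapter_params / full_params:.4%}")
```

With `r` equal to `lora_alpha`, the low-rank update is applied without extra scaling, and the adapter for this example layer trains under 2% as many parameters as the full matrix.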

In [None]:
from peft import LoraConfig, PeftModel

# Load LoRA configuration
peft_config = LoraConfig(
    lora_alpha=16,       # Alpha parameter for LoRA scaling
    lora_dropout=0,      # Dropout probability for LoRA layers
    r=16,                # Rank of the LoRA update matrices
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",]
)

The **Fully Sharded Data Parallel (FSDP)** configuration is defined in `fsdp_config`. It enables [**full model sharding**](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy) and, through `xla_fsdp_grad_ckpt`, [**gradient checkpointing**](https://huggingface.co/docs/transformers/v4.19.4/en/performance#gradient-checkpointing) for memory efficiency on TPUs.

In [None]:
# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": [
        "Gemma2DecoderLayer"
    ],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True
}

### Set training configuration

Set up the training arguments that define how the model will be trained.

Here, you'll define the following parameters:

- For training:
  - `output_dir`
  - `max_steps`
  - `per_device_train_batch_size`

- To optimize the training process:
  - `learning_rate`
  - `optim`
  - `lr_scheduler_type`

**Note:** `max_steps` is set to 100 to keep this run short. For a full pass over the data, remove `max_steps` and set `num_train_epochs=1` instead.

In [None]:
from trl import SFTTrainer, SFTConfig

# Set training parameters
training_arguments = SFTConfig(
    # ---Output settings---
    # Output directory where model predictions and checkpoints will be stored
    output_dir="./results",
    overwrite_output_dir=True,
    save_strategy="no",
    # ---Training settings---
    # Number of training epochs
    #num_train_epochs=1,
    # Number of training steps (overrides num_train_epochs)
    max_steps=100,
    # Train batch size. With SPMD this is treated as the global batch size
    # rather than a per-core value
    per_device_train_batch_size=32,
    # Number of update steps to accumulate the gradients for
    gradient_accumulation_steps=1,
    # Optimizer to use
    optim="adafactor",
    # Required for SPMD
    dataloader_drop_last=True,
    fsdp="full_shard",
    fsdp_config=fsdp_config,
    # Initial learning rate (adafactor optimizer)
    learning_rate=0.0002,
    # Enable bfloat16 precision
    bf16=True,
    # Maximum gradient norm (gradient clipping)
    max_grad_norm=0.3,
    # Ratio of steps for a linear warmup (from 0 to learning rate)
    warmup_ratio=0.03,
    # Learning rate schedule (linear decay after warmup)
    lr_scheduler_type="linear",
    # Maximum sequence length to use
    max_seq_length=1024,
    dataset_text_field="text",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": False,
    },
    # Pack multiple short examples in the same input sequence
    # to increase efficiency
    packing=True,
    # ---Logging---
    # Log every X update steps
    logging_steps=1,
    report_to="none",
    seed=42
)
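
The `packing=True` option deserves a short illustration. The sketch below shows the general idea only: token sequences are concatenated into one stream and cut into blocks of `max_seq_length`. It is a simplified stand-in, not TRL's actual implementation, and the `pack_sequences` helper and the toy token IDs are hypothetical.

```python
def pack_sequences(tokenized_examples, max_seq_length):
    """Concatenate token lists and cut the stream into fixed-length blocks.

    Tokens that do not fill a final complete block are dropped in this sketch.
    """
    stream = [token for example in tokenized_examples for token in example]
    n_blocks = len(stream) // max_seq_length
    return [
        stream[i * max_seq_length:(i + 1) * max_seq_length]
        for i in range(n_blocks)
    ]


# Four short "examples" with 10 tokens in total (toy token IDs).
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
packed = pack_sequences(examples, max_seq_length=4)
print(packed)
```

Because every block is completely filled, no compute is spent on padding tokens, which is why packing tends to help when most examples are much shorter than `max_seq_length`.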



### Train the model

[Hugging Face's TRL](https://huggingface.co/docs/trl/index) offers a user-friendly API for building SFT models and training them on your dataset with just a few lines of code. Here you will use TRL's `SFTTrainer` class to train the model. This class inherits from the `Trainer` class available in the Transformers library but is specifically optimized for supervised fine-tuning (instruction tuning). Read more about `SFTTrainer` in the [official TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer).

In [None]:
# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    args=training_arguments
)

Generating train split: 0 examples [00:00, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


Now, let's start the fine-tuning process by calling `trainer.train()`, which uses `SFTTrainer` to handle the training loop, including data loading, forward and backward passes, and optimizer steps, all configured according to the settings you've provided.

In [None]:
trainer.train()



| Step | Training Loss |
|------|---------------|
| 1 | 2.0938 |
| 2 | 2.125 |
| 3 | 2.0781 |
| 4 | 1.7656 |
| 5 | 1.5469 |
| 6 | 1.2188 |
| 7 | 1.1641 |
| 8 | 1.0391 |
| 9 | 1.125 |
| 10 | 0.9531 |




TrainOutput(global_step=100, training_loss=0.708515625, metrics={'train_runtime': 849.5865, 'train_samples_per_second': 3.767, 'train_steps_per_second': 0.118, 'total_flos': 5.18083433201664e+16, 'train_loss': 0.708515625, 'epoch': 0.3861003861003861})
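
The reported metrics can be cross-checked with a little arithmetic. The sketch below assumes that the batch size of 32 acts as the global batch size and that gradient accumulation is 1, as configured earlier; the variable names are introduced only for this check.

```python
# Values taken from the training configuration and the TrainOutput above.
max_steps = 100
global_batch_size = 32
train_runtime = 849.5865            # seconds
reported_epoch = 0.3861003861003861

# Each optimizer step consumes one global batch of packed sequences.
samples_seen = max_steps * global_batch_size
samples_per_second = round(samples_seen / train_runtime, 3)
steps_per_second = round(max_steps / train_runtime, 3)

# The epoch fraction implies how many steps a full epoch would take.
steps_per_epoch = round(max_steps / reported_epoch)

print(samples_seen, samples_per_second, steps_per_second, steps_per_epoch)
```

The throughput figures match the reported `train_samples_per_second` and `train_steps_per_second`. The implied 259 steps per epoch suggests that the packed training set holds roughly 259 x 32, or about 8,288, sequences, so a full epoch would take a little over two and a half times as long as this run.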

After training is complete, save the LoRA adapter. Moving the model to the CPU with `trainer.model.to('cpu')` takes it off the TPU before saving, and `save_pretrained(new_model)` then writes the adapter weights and configuration files to the directory specified by `new_model` (**gemma-func-ft**). This allows you to reload the fine-tuned adapter later for inference or further training.

In [None]:
# Remove the model weights directory if it exists
!rm -rf gemma-func-ft

# Save the LoRA adapter
trainer.model.to('cpu').save_pretrained(new_model)

## Prompt using the newly fine-tuned model


Now that you've fine-tuned your custom Gemma model, reload the LoRA adapter weights so that you can prompt the model and verify that it works as intended.

To do this, use the following steps to correctly reload the adapter weights:

- Use `AutoModelForCausalLM.from_pretrained` to first load the **base Gemma model**, setting `low_cpu_mem_usage=True` to reduce host memory consumption while loading and `torch_dtype=torch.bfloat16` for consistency with the fine-tuned model.

- Load the **fine-tuned LoRA adapter** that you've previously saved into the base model using `PeftModel.from_pretrained`, where `new_model` is the directory containing your fine-tuned weights.

- The `model.merge_and_unload` function **merges** the **LoRA adapter weights** with the **base model weights** and unloads the adapter, resulting in a standalone model ready for inference.
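
Conceptually, merging folds the low-rank update into the base weight, so that a single matrix reproduces the output of the base layer plus its adapter. The sketch below demonstrates this with tiny hand-written matrices in plain Python. It is a conceptual illustration, not PEFT's implementation, and all values and helper names are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b, scale=1.0):
    """Element-wise a + scale * b for two matrices of the same shape."""
    return [
        [x + scale * y for x, y in zip(row_a, row_b)]
        for row_a, row_b in zip(a, b)
    ]


W = [[1.0, 0.0], [0.0, 1.0]]  # base weight, shape (d_out=2, d_in=2)
A = [[0.5, -0.5]]             # LoRA A matrix, shape (r=1, d_in=2)
B = [[2.0], [4.0]]            # LoRA B matrix, shape (d_out=2, r=1)
lora_alpha, r = 2, 1
scaling = lora_alpha / r

x = [[3.0], [1.0]]            # input as a column vector

# Adapter kept separate: W x + scaling * B (A x)
adapter_out = add(matmul(W, x), matmul(B, matmul(A, x)), scaling)

# Adapter merged into the weight: (W + scaling * B A) x
W_merged = add(W, matmul(B, A), scaling)
merged_out = matmul(W_merged, x)

print(adapter_out, merged_out)
```

Both paths give the same result, which is why the merged model can be used for inference without the PEFT wrapper.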

In [None]:
# Reload the fine-tuned Gemma model
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Reload the tokenizer to ensure it matches the model configuration, adjusting the padding side as before.

In [None]:
# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

Now, test the fine-tuned model with a sample prompt by first using the tokenizer to generate the input ids, and then relying on the reloaded fine-tuned model to generate a response using `model.generate()`.

In [None]:
input_text = """\
<start_of_turn>user
You are a helpful assistant with access to the following functions. Use them if required -
{
    "name": "calculate_median",
    "description": "Calculate the median of a list of numbers",
    "parameters": {
        "type": "object",
        "properties": {
             "numbers": {
                 "type": "array",
                 "items": {
                     "type": "number"
                 },
                 "description": "The list of numbers"
             }
        }
        "required": [
            "numbers"
        ]
    }
}
To use these functions respond with:
<functioncall> {"name": "function_name", "arguments": {"arg_1": "value_1", "arg_1": "value_1", ...}} </functioncall>

Then finally respond with:
Answer:

<end_of_turn>
<start_of_turn>user
USER: Hi, I have a list of numbers and I need to find the median. The numbers are [5, 2, 9, 1, 7, 4, 6, 3, 8]
<end_of_turn>
<start_of_turn>model
<functioncall>
"""

In [None]:
input_ids = tokenizer(input_text, return_tensors="pt").to("cpu")
outputs = model.generate(**input_ids, max_new_tokens = 512)

Finally, decode the output tokens back into human-readable text with `tokenizer.decode` and print the result to see how the fine-tuned model responds to the prompt.

In [None]:
print(tokenizer.decode(outputs[0]))

<bos><start_of_turn>user
You are a helpful assistant with access to the following functions. Use them if required -
{
    "name": "calculate_median",
    "description": "Calculate the median of a list of numbers",
    "parameters": {
        "type": "object",
        "properties": { 
             "numbers": {
                 "type": "array",
                 "items": {
                     "type": "number"              
                 },
                 "description": "The list of numbers"
             }      
        }       
        "required": [
            "numbers"       
        ]    
    }
}
To use these functions respond with:
<functioncall> {"name": "function_name", "arguments": {"arg_1": "value_1", "arg_1": "value_1", ...}} </functioncall>

Then finally respond with:
Answer:

<end_of_turn>
<start_of_turn>user
USER: Hi, I have a list of numbers and I need to find the median. The numbers are [5, 2, 9, 1, 7, 4, 6, 3, 8]
<end_of_turn>
<start_of_turn>model
<functioncall>
{"nam

Congratulations! You've successfully fine-tuned Gemma for Function Calling using Torch XLA and PEFT with LoRA on TPUs. With that, you've covered the entire process, from setting up the environment to training and testing the model.

## What's next?
Your next steps could include the following:

- **Experiment with Different Datasets**: Try fine-tuning on other function calling datasets in [Hugging Face](https://huggingface.co/docs/datasets/en/index) or your own data.

- **Tune Hyperparameters**: Adjust training parameters (e.g., learning rate, batch size, epochs, LoRA settings) to optimize performance and improve training efficiency.

- **Try different templates**: Experiment with alternative chat templates to see whether they improve performance.

By exploring these activities, you'll deepen your understanding and further enhance your fine-tuned Gemma model. Happy experimenting!