
##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Inference with Gemma using JAX and Flax

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/jax_inference"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/jax_inference.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/jax_inference.ipynb"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/jax_inference.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview

Gemma is a family of lightweight, state-of-the-art open large language models based on Google DeepMind's Gemini research and technology. This tutorial demonstrates basic sampling/inference with the Gemma 2 2B Instruct model (`gemma2-2b-it`) using [Google DeepMind's `gemma` library](https://github.com/google-deepmind/gemma), which is written with [JAX](https://jax.readthedocs.io) (a high-performance numerical computing library), [Flax](https://flax.readthedocs.io) (the JAX-based neural network library), [Orbax](https://orbax.readthedocs.io/) (a JAX-based library for training utilities like checkpointing), and [SentencePiece](https://github.com/google/sentencepiece) (a tokenizer/detokenizer library). Flax is not used directly in this notebook, but it was used to create Gemma.

This notebook can run on Google Colab with the free T4 GPU (go to **Edit** > **Notebook settings**, then under **Hardware accelerator** select **T4 GPU**).

## Setup

### 1. Set up Kaggle access for Gemma

To complete this tutorial, you first need to follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:

* Get access to Gemma on [kaggle.com](https://www.kaggle.com/models/google/gemma/).
* Select a Colab runtime with sufficient resources to run the Gemma model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### 2. Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the "Grant access?" messages, agree to provide secret access.

In [None]:
import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
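
Outside Colab, `userdata` is not available. The following is a minimal sketch of one possible fallback, assuming the credentials live either in existing environment variables or in a Kaggle API token file (`kaggle.json`, which stores `username` and `key` fields). The helper name `load_kaggle_credentials` and its `environ` parameter are hypothetical and not part of any library.

```python
import json
import os
from typing import Mapping, Optional


def load_kaggle_credentials(json_path: str,
                            environ: Optional[Mapping[str, str]] = None) -> dict:
    """Return Kaggle credentials from environment variables or a kaggle.json file.

    Hypothetical helper for illustration; `environ` defaults to os.environ.
    """
    env = os.environ if environ is None else environ
    username = env.get("KAGGLE_USERNAME")
    key = env.get("KAGGLE_KEY")
    if username and key:
        # Both variables are already set, so use them as-is.
        return {"KAGGLE_USERNAME": username, "KAGGLE_KEY": key}
    # Otherwise fall back to the API token file ("username" and "key" fields).
    with open(json_path) as f:
        token = json.load(f)
    return {"KAGGLE_USERNAME": token["username"], "KAGGLE_KEY": token["key"]}
```

The returned dictionary can then be applied with `os.environ.update(...)` before any download is attempted.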

### 3. Install the `gemma` library

This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on **Edit** > **Notebook settings** > Select **T4 GPU** > **Save**.

Next, you need to install the Google DeepMind `gemma` library from [`github.com/google-deepmind/gemma`](https://github.com/google-deepmind/gemma). If you get an error about "pip's dependency resolver", you can usually ignore it.

**Note:** By installing `gemma`, you will also install [`flax`](https://flax.readthedocs.io), core [`jax`](https://jax.readthedocs.io), [`optax`](https://optax.readthedocs.io/en/latest/) (the JAX-based gradient processing and optimization library), [`orbax`](https://orbax.readthedocs.io/), and [`sentencepiece`](https://github.com/google/sentencepiece).

In [None]:
!pip install -q git+https://github.com/google-deepmind/gemma.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for gemma (pyproject.toml) ... done


## Load and prepare the Gemma model

1. Load the Gemma model with [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/bddefc718182282882b72f814d407d89e5d178c4/src/kagglehub/models.py#L12), which takes three arguments:

- `handle`: The model handle from Kaggle
- `path`: (Optional string) The local path
- `force_download`: (Optional boolean) Forces a re-download of the model

**Note:** Be mindful that the `gemma2-2b-it` model is around 3.7 GB in size.
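
As a rough sketch of how the arguments above fit together, the handle can be built from the variant name and `force_download` passed through explicitly. The wrapper names `gemma_handle` and `download_gemma` are hypothetical. The only library call is `kagglehub.model_download` with the `handle` and `force_download` arguments listed above, and passing `force_download` by keyword is an assumption about its signature.

```python
def gemma_handle(variant: str) -> str:
    """Build the Kaggle model handle for a Gemma 2 Flax variant (hypothetical helper)."""
    return f"google/gemma-2/flax/{variant}"


def download_gemma(variant: str, force_download: bool = False) -> str:
    """Download a variant (or reuse a previous download) and return its local path."""
    # Imported lazily so that gemma_handle() works without kagglehub installed.
    import kagglehub

    return kagglehub.model_download(gemma_handle(variant),
                                    force_download=force_download)
```

Setting `force_download=True` is mainly useful when a previous download was interrupted, given that the checkpoint is several gigabytes.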

In [None]:
GEMMA_VARIANT = 'gemma2-2b-it' # @param ['gemma2-2b', 'gemma2-2b-it'] {type:"string"}

In [None]:
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/flax/{GEMMA_VARIANT}')



Downloading 11 files:   0%|          | 0/11 [00:00<?, ?it/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/_METADATA...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/_CHECKPOINT_METADATA...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/descriptor/descriptor.pbtxt...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/manifest.ocdbt...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/834bb4bf1e3854eb09f6208c95c071b2...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/bf69258061ae5f35eb7a5669fe6877d4...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/manifest.ocdbt...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/d/b5a4695f4be0a2f41ec1e25616ebd7e7...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/fc20151969d7ca91ea9d8275bda0e219...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/tokenizer.model...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/flax/gemma2-2b-it/1/download/gemma2-2b-it/checkpoint...

In [None]:
print('GEMMA_PATH:', GEMMA_PATH)

GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1


**Note:** The path in the output above is where the model weights and tokenizer are saved locally. You will need it in the following steps.

2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file is in the main directory where you downloaded the model, while the model weights are in a sub-directory named after the variant. For example:

- The `tokenizer.model` file will be in `/LOCAL/PATH/TO/gemma-2/flax/gemma2-2b-it/1`.
- The model checkpoint will be in `/LOCAL/PATH/TO/gemma-2/flax/gemma2-2b-it/1/gemma2-2b-it`.

In [None]:
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)

CKPT_PATH: /root/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1/gemma2-2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1/tokenizer.model
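
Before loading anything, it can help to confirm that both paths actually exist, since a partial download would otherwise only surface later as a loading error. The following is a minimal sketch that assumes the layout described above (a checkpoint sub-directory named after the variant, next to a `tokenizer.model` file). The helper name `check_gemma_layout` is hypothetical.

```python
from pathlib import Path
from typing import Tuple


def check_gemma_layout(gemma_path: str, variant: str) -> Tuple[Path, Path]:
    """Return (checkpoint_dir, tokenizer_file), raising if either is missing."""
    root = Path(gemma_path)
    ckpt_path = root / variant                 # e.g. .../1/gemma2-2b-it
    tokenizer_path = root / "tokenizer.model"  # e.g. .../1/tokenizer.model
    if not ckpt_path.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
    if not tokenizer_path.is_file():
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
    return ckpt_path, tokenizer_path
```

In this notebook it would be called as `check_gemma_layout(GEMMA_PATH, GEMMA_VARIANT)`.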


## Perform sampling/inference

1. Load and format the Gemma model checkpoint with the [`gemma.params.load_and_format_params`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/params.py#L27) method:

In [None]:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)

2. Load the Gemma tokenizer, constructed using [`sentencepiece.SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423):

In [None]:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)

True

3. To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `Transformer` cache. Afterwards, instantiate the Gemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).

**Note:** The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release.

In [None]:
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
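
One way to reason about `cache_size` is as a token budget shared by the prompt and the generated text. The sketch below assumes that every prompt token and every generated token occupies one cache time step. That is an assumption made here for illustration, not something this tutorial or the library states, so treat the result as a conservative guide. The function name `max_generation_steps` is hypothetical.

```python
def max_generation_steps(num_prompt_tokens: int, cache_size: int) -> int:
    """Largest number of generation steps that keeps prompt + output within the cache.

    Assumes one cache time step per token (prompt and generated alike).
    """
    if num_prompt_tokens < 0 or cache_size < 0:
        raise ValueError("token counts must be non-negative")
    # Never return a negative budget when the prompt alone exceeds the cache.
    return max(cache_size - num_prompt_tokens, 0)


# With cache_size=1024, a 100-token prompt leaves room for 924 generated tokens.
budget = max_generation_steps(num_prompt_tokens=100, cache_size=1024)
```

The prompt length can be estimated with the tokenizer loaded earlier, for example `len(vocab.EncodeAsIds(text))`, keeping in mind that the sampler may add special tokens.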

3. Create a `sampler` with [`gemma.sampler.Sampler`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/sampler.py#L88) on top of the Gemma model checkpoint/weights and the tokenizer:

In [None]:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

4. Write a prompt in the `prompt` list and perform inference. You can tweak `total_generation_steps`, the number of steps performed when generating a response; this example uses `128` to preserve host memory.

**Note:** If you run out of memory, click on **Runtime** > **Disconnect and delete runtime**, and then **Runtime** > **Run all**.

In [None]:
prompt = [
    "Given this FEN: r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - - 0 24, what is the correct next move? Output the piece's origin square and its destination square. Do not print any explanation.",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=128,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(out_string.removesuffix("<end_of_turn>").strip())

e4
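
The cell above passes the prompt as raw text and strips a trailing `<end_of_turn>` marker from the reply. Instruction-tuned Gemma checkpoints are documented to use a turn-based format built from `<start_of_turn>` and `<end_of_turn>` markers, so wrapping the prompt in that format may yield more reliable answers. The following is a minimal sketch. The helper names `format_chat_prompt` and `clean_reply` are hypothetical, and whether the sampler adds any of these markers itself is not covered here.

```python
def format_chat_prompt(user_message: str) -> str:
    """Wrap a single user message in Gemma's instruction-tuned turn format."""
    return (
        f"<start_of_turn>user\n{user_message}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


def clean_reply(text: str) -> str:
    """Drop a trailing end-of-turn marker and surrounding whitespace."""
    return text.removesuffix("<end_of_turn>").strip()
```

A formatted prompt would be passed in the same way as before, for example `sampler(input_strings=[format_chat_prompt(question)], total_generation_steps=128)`.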


5. (Optional) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the `sampler` again in step 3 and customize and run the prompt in step 4.

In [None]:
del sampler

## Learn more

- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of modules you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py), [`gemma.transformer`](https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py), and [`gemma.sampler`](https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py).
- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), and [Orbax](https://orbax.readthedocs.io/).
- For `sentencepiece` tokenizer/detokenizer documentation, check out [Google's `sentencepiece` GitHub repo](https://github.com/google/sentencepiece).
- For `kagglehub` documentation, check out `README.md` on [Kaggle's `kagglehub` GitHub repo](https://github.com/Kaggle/kagglehub).
- Learn how to [use Gemma models with Google Cloud Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).

In [None]:
import os
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

In [None]:
!wget https://raw.githubusercontent.com/sam0109/llm_chess/refs/heads/main/lichess_puzzle_db.csv

--2024-12-25 05:49:20--  https://raw.githubusercontent.com/sam0109/llm_chess/refs/heads/main/lichess_puzzle_db.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1842940 (1.8M) [text/plain]
Saving to: ‘lichess_puzzle_db.csv.1’


2024-12-25 05:49:20 (41.7 MB/s) - ‘lichess_puzzle_db.csv.1’ saved [1842940/1842940]
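
The `.1` suffix in the log above is how `wget` avoids overwriting a file that already exists, which means re-running the cell keeps creating new copies. As a sketch of an alternative, the download can be made idempotent in Python with the standard library. The helper name `download_once` is hypothetical.

```python
import urllib.request
from pathlib import Path


def download_once(url: str, dest: str) -> Path:
    """Download `url` to `dest` unless `dest` already exists, then return the path."""
    path = Path(dest)
    if path.exists():
        # Reuse the existing copy instead of creating `dest.1`, `dest.2`, ...
        return path
    urllib.request.urlretrieve(url, str(path))
    return path
```

With this approach the file name stays stable across runs, so later cells can always read `lichess_puzzle_db.csv`.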



In [None]:
import pandas as pd

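# On a fresh run, wget saves the file as `lichess_puzzle_db.csv`; the `.1`
# suffix in the log above appears only when that file already exists.
puzzles = pd.read_csv('lichess_puzzle_db.csv')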
puzzles.head()

| Unnamed: 0 | PuzzleId | FEN | Moves | Rating | RatingDeviation | Popularity | NbPlays | Themes | GameUrl | OpeningTags |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 00008 | r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - ... | f2g3 e6e7 b2b1 b3c1 b1c1 h6c1 | 1838 | 75 | 95 | 7002 | crushing hangingPiece long middlegame | https://lichess.org/787zsVup/black#48 | |
| 1 | 0000D | 5rk1/1p3ppp/pq3b2/8/8/1P1Q1N2/P4PPP/3R2K1 w - ... | d3d6 f8d8 d6d8 f6d8 | 1455 | 74 | 96 | 29264 | advantage endgame short | https://lichess.org/F8M8OS71#53 | |
| 2 | 0008Q | 8/4R3/1p2P3/p4r2/P6p/1P3Pk1/4K3/8 w - - 1 64 | e7f7 f5e5 e2f1 e5e6 | 1314 | 75 | 90 | 658 | advantage endgame rookEndgame short | https://lichess.org/MQSyb3KW#127 | |
| 3 | 0009B | r2qr1k1/b1p2ppp/pp4n1/P1P1p3/4P1n1/B2P2Pb/3NBP... | b6c5 e2g4 h3g4 d1g4 | 1099 | 74 | 87 | 571 | advantage middlegame short | https://lichess.org/4MWQCxQ6/black#32 | Kings_Pawn_Game Kings_Pawn_Game_Leonardis_Vari... |
| 4 | 000VW | r4r2/1p3pkp/p5p1/3R1N1Q/3P4/8/P1q2P2/3R2K1 b -... | g6f5 d5c5 c2e4 h5g5 g7h8 g5f6 | 2830 | 88 | 92 | 112 | crushing endgame long | https://lichess.org/e9AY2m5j/black#50 | |
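
The `Moves` column holds space-separated moves in UCI notation (origin square followed by destination square, as in `f2g3`). In the Lichess puzzle database format, the first move is the opponent's move played from the given FEN, and the remaining moves are the solution line. This convention comes from the Lichess documentation rather than from this notebook, so verify it before relying on it. A minimal sketch for splitting a row and checking that a model reply is a well-formed UCI move follows. The helper names are hypothetical.

```python
import re
from typing import List, Tuple

# Origin square, destination square, and an optional promotion piece (e.g. e7e8q).
UCI_MOVE = re.compile(r"[a-h][1-8][a-h][1-8][qrbn]?")


def is_uci_move(text: str) -> bool:
    """Return True if `text` is a syntactically valid UCI move (legality not checked)."""
    return UCI_MOVE.fullmatch(text.strip()) is not None


def split_puzzle_moves(moves: str) -> Tuple[str, List[str]]:
    """Split a `Moves` cell into (opponent's setup move, solution moves)."""
    parts = moves.split()
    if not parts:
        raise ValueError("empty move list")
    return parts[0], parts[1:]


setup_move, solution = split_puzzle_moves("f2g3 e6e7 b2b1 b3c1 b1c1 h6c1")
```

A reply such as `e4` fails `is_uci_move`, which gives a simple format check to apply before comparing a model's answer with the first solution move.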
