Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

---

# Getting Started with Recurrent Gemma Sampling: A Step-by-Step Guide

This colab is a detailed tutorial on how to load a Recurrent Gemma checkpoint and sample from it.



## Installation

In [None]:
! pip install git+https://github.com/google-deepmind/recurrentgemma.git#egg=recurrentgemma[torch]
! pip install --user kaggle

## Downloading the checkpoint

To use Gemma's checkpoints, you'll need a Kaggle account and an API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.
4. You can either log in through the UI or set your Kaggle username and key via Colab secrets.

Then run the cell below.

In [None]:
import os
from google.colab import userdata
import kagglehub

try:
  os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
  os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
except userdata.SecretNotFoundError:
  kagglehub.login()

If you logged in through the UI and everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```
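If you went the Colab secrets route instead, a quick sanity check can confirm that both environment variables are populated. This is a minimal sketch: `missing_kaggle_env` is a hypothetical helper, not part of `kagglehub`, and it only inspects environment variables, so it says nothing about credentials entered through the login UI.

```python
import os


def missing_kaggle_env(env=None):
    """Return the names of Kaggle credential variables that are unset or empty.

    Hypothetical helper for illustration only; it checks environment
    variables and nothing else.
    """
    env = os.environ if env is None else env
    required = ("KAGGLE_USERNAME", "KAGGLE_KEY")
    return [name for name in required if not env.get(name)]


missing = missing_kaggle_env()
if missing:
    print("Missing environment variables:", ", ".join(missing))
else:
    print("Both Kaggle environment variables are set.")
```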

Now select and download the checkpoint you want to try. Note that only the '2b-it' and '9b-it' checkpoints have been tuned for chat and question answering. The '2b' and '9b' checkpoints have only been trained for next-token prediction, so they will not perform as well in a "chat" or "QA" setting.
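The naming convention for the four variants can be captured in a small helper that separates the model size from the instruction-tuning suffix. This is an illustrative sketch; `parse_variant` and `VALID_VARIANTS` are hypothetical names and are not part of the `recurrentgemma` package.

```python
VALID_VARIANTS = ("2b", "2b-it", "9b", "9b-it")


def parse_variant(variant: str) -> tuple[str, bool]:
    """Split a variant name into (model size, is instruction-tuned).

    Hypothetical helper: rejects any name outside the four variants
    listed in this tutorial.
    """
    if variant not in VALID_VARIANTS:
        raise ValueError(f"Unknown variant {variant!r}; expected one of {VALID_VARIANTS}")
    size, _, suffix = variant.partition("-")
    return size, suffix == "it"


size, is_it = parse_variant("2b-it")
print(f"size={size}, instruction-tuned={is_it}")
```

Validating the name up front avoids a failed download later if the variant string is mistyped.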

In [None]:
#@title Imports
import pathlib
import torch

import sentencepiece as spm

from recurrentgemma import torch as recurrentgemma

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
VARIANT = '2b-it' # @param ['2b', '2b-it', '9b', '9b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/recurrentgemma/pyTorch/{VARIANT}')

weights_dir = pathlib.Path(weights_dir)
ckpt_path = weights_dir / f'{VARIANT}.pt'
vocab_path = weights_dir / 'tokenizer.model'
preset = recurrentgemma.Preset.RECURRENT_GEMMA_2B_V1 if '2b' in VARIANT else recurrentgemma.Preset.RECURRENT_GEMMA_9B_V1
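Before loading anything, it can help to confirm that the download actually contains the two files the cell above points at. The sketch below assumes the same layout used in that cell (`<variant>.pt` and `tokenizer.model` directly inside the weights directory); `missing_checkpoint_files` is a hypothetical helper name.

```python
import pathlib


def missing_checkpoint_files(weights_dir, variant):
    """Return the names of expected checkpoint files that are not present.

    Assumes the layout used in this tutorial: '<variant>.pt' and
    'tokenizer.model' sit directly inside weights_dir.
    """
    weights_dir = pathlib.Path(weights_dir)
    expected = [weights_dir / f"{variant}.pt", weights_dir / "tokenizer.model"]
    return [path.name for path in expected if not path.is_file()]


# Example with a directory that does not exist: both files are reported missing.
print(missing_checkpoint_files("no-such-weights-dir", "2b-it"))
```

In the notebook you would call it as `missing_checkpoint_files(weights_dir, VARIANT)` and expect an empty list.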

## Start Generating with Your Model

Load and prepare your LLM's checkpoint for use with PyTorch.

In [None]:
# Load parameters
params = torch.load(str(ckpt_path))
params = {k : v.to(device=device) for k, v in params.items()}

Use the `recurrentgemma.GriffinConfig.from_torch_params` function to automatically load the correct configuration from a checkpoint.

In [None]:
model_config = recurrentgemma.GriffinConfig.from_torch_params(
    params,
    preset=preset,
)
model = recurrentgemma.Griffin(model_config, device=device, dtype=torch.bfloat16)
model.load_state_dict(params)

Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(str(vocab_path))

Finally, build a sampler on top of your model.

In [None]:
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, is_it_model="it" in VARIANT)

You're ready to start sampling! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.

In [None]:
input_batch = [
  "What are the planets of the solar system?",
]

# 300 generation steps
out_data = sampler(input_strings=input_batch, total_generation_steps=300)

for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
  print(10*'#')

You should get a description of the solar system.
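One way to follow the advice about a consistent batch size when you have an arbitrary number of prompts is to split them into fixed-size batches and pad the last one. The sketch below is an assumption about how you might do this, not something the sampler provides: `fixed_size_batches` is a hypothetical helper that pads by repeating the final prompt, and it only fixes the batch dimension (prompt lengths can still vary).

```python
def fixed_size_batches(prompts, batch_size):
    """Yield (batch, n_real) pairs where every batch has exactly batch_size prompts.

    The last batch is padded by repeating its final prompt; n_real is the
    number of genuine prompts, so the padded outputs can be discarded.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(prompts), batch_size):
        chunk = list(prompts[start:start + batch_size])
        n_real = len(chunk)
        chunk += [chunk[-1]] * (batch_size - n_real)
        yield chunk, n_real


prompts = ["What is Mars?", "What is Venus?", "What is Saturn?"]
for batch, n_real in fixed_size_batches(prompts, batch_size=2):
    print(n_real, batch)
```

With the sampler from this tutorial, each `batch` would be passed as `input_strings`, keeping only the first `n_real` entries of the returned text.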
