# Gemma in PyTorch

This is a quick demo of running Gemma inference in PyTorch.
For more details, see the GitHub repository of the [official PyTorch implementation](https://github.com/google/gemma_pytorch).


## Kaggle access

To log in to Kaggle, either store your `kaggle.json` credentials file at
`~/.kaggle/kaggle.json` or run the following in a Colab environment. See the
[`kagglehub` package documentation](https://github.com/Kaggle/kagglehub#authenticate)
for more details.

In [1]:
import kagglehub

kagglehub.login()

Kaggle credentials set.
Kaggle credentials successfully validated.
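
If you prefer the credentials-file route, it can help to check that the file is in place before deciding whether to call `kagglehub.login()`. The following is a minimal sketch, not part of `kagglehub`: the helper name `has_kaggle_credentials` is hypothetical, and it assumes the usual `kaggle.json` layout with `username` and `key` fields.

```python
import json
from pathlib import Path


def has_kaggle_credentials(path=None):
    """Return True if `path` is a JSON file with non-empty 'username' and 'key'.

    Hypothetical helper for illustration; defaults to ~/.kaggle/kaggle.json.
    """
    path = Path(path) if path is not None else Path.home() / '.kaggle' / 'kaggle.json'
    if not path.is_file():
        return False
    try:
        creds = json.loads(path.read_text())
    except (OSError, json.JSONDecodeError):
        # Unreadable or malformed file: treat as "no usable credentials".
        return False
    return (
        isinstance(creds, dict)
        and bool(creds.get('username'))
        and bool(creds.get('key'))
    )
```

A notebook could then call `kagglehub.login()` only when `has_kaggle_credentials()` returns `False`. This only checks that the file looks plausible; it does not validate the credentials with Kaggle.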


## Install dependencies

In [2]:
!pip install -q -U torch immutabledict sentencepiece

## Download model weights

In [3]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']
MACHINE_TYPE = 'cpu' #@param ['cuda', 'cpu']

In [4]:
import os

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/pyTorch/2b-it/2/download...
100%|██████████| 3.75G/3.75G [00:53<00:00, 75.9MB/s]
Extracting model files...
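
The `assert` checks above stop at the first missing file, and Python skips `assert` statements entirely when run with `-O`. As an alternative, here is a small sketch that reports every expected file that is absent. The helper name `missing_weight_files` is hypothetical; the two file names are the ones the cell above looks for.

```python
import os


def missing_weight_files(weights_dir, variant):
    """Return the expected file names that are absent from `weights_dir`.

    Hypothetical helper; the expected names mirror the checks in the cell above.
    """
    expected = ['tokenizer.model', f'gemma-{variant}.ckpt']
    return [
        name for name in expected
        if not os.path.isfile(os.path.join(weights_dir, name))
    ]
```

An empty list means both files were found; otherwise the list can be put into an error message, for example `raise FileNotFoundError(f'Missing: {missing}')`.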


## Download the model implementation

In [5]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git

fatal: destination path 'gemma_pytorch' already exists and is not an empty directory.
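
The `fatal` message above appears when the cell is re-run after the repository has already been cloned; it is harmless here. One way to avoid it is to clone only when the directory does not exist yet. This is a sketch under that assumption, with a hypothetical helper name `clone_if_missing`, and it requires `git` to be on the path.

```python
import os
import subprocess


def clone_if_missing(repo_url, dest):
    """Clone `repo_url` into `dest` unless `dest` already exists.

    Returns True if a clone was run, False if the directory was already there.
    Hypothetical helper for illustration.
    """
    if os.path.isdir(dest):
        return False
    # check=True raises CalledProcessError if git exits with a non-zero status.
    subprocess.run(['git', 'clone', repo_url, dest], check=True)
    return True
```

For this notebook the call would be `clone_if_missing('https://github.com/google/gemma_pytorch.git', 'gemma_pytorch')`. Note that the check only looks for the directory; it does not verify that an existing directory is a complete or up-to-date clone.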


In [6]:
import sys

sys.path.append('gemma_pytorch')

In [7]:
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

## Set up the model

In [8]:
import torch

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
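
The setup cell derives the config and the quantization flag from substrings of `VARIANT`, so any string that does not contain `2b` silently falls back to the 7B config. The sketch below makes that parsing explicit and rejects unknown values. The helper name `describe_variant` is hypothetical, and the list of known variants is copied from the `#@param` options earlier in this notebook.

```python
# Variant names copied from the #@param list in the "Download model weights" cell.
KNOWN_VARIANTS = ('2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant')


def describe_variant(variant):
    """Split a variant string such as '7b-it-quant' into its properties.

    Hypothetical helper for illustration; raises ValueError on unknown variants.
    """
    if variant not in KNOWN_VARIANTS:
        raise ValueError(f'Unknown variant {variant!r}; expected one of {KNOWN_VARIANTS}')
    parts = variant.split('-')
    return {
        'size': parts[0],                   # '2b' or '7b'
        'instruction_tuned': 'it' in parts,
        'quantized': 'quant' in parts,
    }
```

With this in place, `describe_variant(VARIANT)['quantized']` yields the same value the cell assigns to `model_config.quant` for every listed variant, and a typo in `VARIANT` fails early instead of loading the wrong config.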

## Run inference

Below are examples of generating in chat mode and generating from a single,
unformatted prompt.

The instruction-tuned Gemma models were trained with a specific formatter that
annotates instruction tuning examples with extra information, both during
training and inference. The annotations (1) indicate roles in a conversation,
and (2) delineate turns in a conversation. Below we show a sample code snippet
for formatting the model prompt using the user and model chat templates in a
multi-turn conversation. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of dialogue turn
- `<end_of_turn>`: end of dialogue turn

Read about the Gemma formatting for instruction tuning and system instructions
[here](https://ai.google.dev/gemma/docs/formatting).
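
The turn markers described above can be assembled by a small function instead of by hand. The following is a minimal sketch, assuming a conversation is given as `(role, text)` pairs; the helper name `build_chat_prompt` is hypothetical and is not part of the `gemma_pytorch` repository.

```python
def build_chat_prompt(turns):
    """Format (role, text) pairs as a Gemma chat prompt.

    Each turn is wrapped in <start_of_turn>/<end_of_turn>, and the prompt ends
    with an open model turn so that generation continues as the model.
    Hypothetical helper for illustration.
    """
    parts = []
    for role, text in turns:
        if role not in ('user', 'model'):
            raise ValueError(f'Unknown role: {role!r}')
        parts.append(f'<start_of_turn>{role}\n{text}<end_of_turn>\n')
    # Leave the final model turn open; the model's reply fills it in.
    parts.append('<start_of_turn>model\n')
    return ''.join(parts)
```

For example, `build_chat_prompt([('user', 'Hi')])` produces a single user turn followed by the open model turn, which can be passed directly to `model.generate`.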

In [9]:
# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='I have been feeling anxious lately. Can you suggest some coping strategies?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='Sure! Mindfulness exercises and deep breathing techniques can help.')
    + USER_CHAT_TEMPLATE.format(prompt='How about dealing with stress at work?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

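# `prompt` already contains the full chat formatting, so pass it as-is.
# Wrapping it in USER_CHAT_TEMPLATE again would nest the whole conversation
# inside a single user turn.
model.generate(
    prompt,
    device=device,
    output_len=100,
)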


Chat prompt:
 <start_of_turn>user
I have been feeling anxious lately. Can you suggest some coping strategies?<end_of_turn>
<start_of_turn>model
Sure! Mindfulness exercises and deep breathing techniques can help.<end_of_turn>
<start_of_turn>user
How about dealing with stress at work?<end_of_turn>
<start_of_turn>model



'**Mindfulness Meditation:**\n- Focus on your breath, sensations of your body, and thoughts.\n- Inhale for 4 counts, hold your breath for 7 counts, and exhale for 8 counts.\n- Repeat this for 5-10 minutes.\n\n**Progressive Muscle Relaxation:**\n- Tense and release different muscle groups in your body, one by one.\n- Contract your muscles for a few seconds, then release them.\n- Continue this process from head to'

In [10]:
# Generate sample
model.generate(
    'What are some self-care practices that can help with depression?',
    device=device,
    output_len=60,
)

'\n\n**Answer:**\n\n**Self-care practices that can help with depression:**\n\n**Physical Health:**\n\n* **Exercise regularly:** Aim for at least 30 minutes of moderate-intensity exercise most days of the week.\n* **Eat a healthy diet:** Focus on consuming fruits, vegetables'
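
The call above passes the question without any chat formatting. For an instruction-tuned variant, each question can instead be wrapped in a user turn followed by an open model turn. The sketch below does this for a list of questions, calling the generation function once per prompt. The helper name `answer_questions` is hypothetical, and `generate_fn` is assumed to accept a prompt string plus keyword arguments, as `model.generate` does in the cells above.

```python
def answer_questions(generate_fn, questions, **generate_kwargs):
    """Wrap each question in a single user turn and call `generate_fn` per prompt.

    Hypothetical helper for illustration; returns one result per question,
    in the same order.
    """
    answers = []
    for question in questions:
        prompt = (
            f'<start_of_turn>user\n{question}<end_of_turn>\n'
            '<start_of_turn>model\n'
        )
        answers.append(generate_fn(prompt, **generate_kwargs))
    return answers
```

In this notebook it could be used as `answer_questions(model.generate, ['First question?', 'Second question?'], device=device, output_len=60)`. Calling once per prompt keeps the sketch simple; it makes no attempt to batch the requests.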