<a href="https://colab.research.google.com/github/weedge/doraemon-nb/blob/main/pytorch_gemma.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Gemma in PyTorch
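This is a quick demo of running Gemma inference in PyTorch.
For more details, see the GitHub repo of the official PyTorch implementation [here](https://github.com/google/gemma_pytorch).

**Note**:
* The free Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running the Gemma 2B models and the 7B int8 quantized models.
* For advanced use cases on other GPUs or TPUs, refer to the [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repo.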

## Kaggle access

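To log in to Kaggle, you can either store your `kaggle.json` credentials file at `~/.kaggle/kaggle.json` or run the following commands in a Colab environment. For more details, see the [`kagglehub` package documentation](https://github.com/Kaggle/kagglehub#authenticate).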

In [None]:
import kagglehub

kagglehub.login()

Kaggle credentials set.
Kaggle credentials successfully validated.
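
If you would rather not use the interactive login, it can help to check up front whether credentials are already available. The sketch below is only an illustration: `find_kaggle_credentials` is a hypothetical helper (not part of `kagglehub`), and it assumes credentials come either from the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables or from a `kaggle.json` file containing `username` and `key` fields.

```python
import json
import os
from pathlib import Path
from typing import Mapping, Optional


def find_kaggle_credentials(
    home: Optional[Path] = None,
    env: Optional[Mapping[str, str]] = None,
) -> Optional[str]:
    """Return 'env', 'file', or None depending on where credentials are found.

    Hypothetical helper for illustration; it only inspects the environment
    and the credentials file, and never contacts Kaggle.
    """
    env = os.environ if env is None else env
    # Assumption: both variables must be set for env-based authentication.
    if env.get("KAGGLE_USERNAME") and env.get("KAGGLE_KEY"):
        return "env"

    cred_file = (home or Path.home()) / ".kaggle" / "kaggle.json"
    if not cred_file.is_file():
        return None
    try:
        data = json.loads(cred_file.read_text())
    except (OSError, json.JSONDecodeError):
        return None
    # Assumption: the file is a JSON object with 'username' and 'key'.
    if isinstance(data, dict) and data.get("username") and data.get("key"):
        return "file"
    return None
```

When the helper returns `None`, calling `kagglehub.login()` as shown above is the fallback.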


## Install dependencies

In [None]:
!pip install -q -U torch immutabledict sentencepiece

## Download model weights

In [None]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']
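
`MACHINE_TYPE` is chosen by hand, so selecting `'cuda'` on a CPU-only runtime would fail later when the model is moved to the device. A minimal sketch of a guard follows; `resolve_machine_type` is a hypothetical helper, and the caller is assumed to pass in the result of `torch.cuda.is_available()`.

```python
def resolve_machine_type(requested: str, cuda_available: bool) -> str:
    """Return the machine type to use, falling back to 'cpu' without CUDA.

    Hypothetical helper for illustration; `cuda_available` is expected to be
    the value of torch.cuda.is_available() in the notebook.
    """
    if requested not in ("cuda", "cpu"):
        raise ValueError(f"Unsupported machine type: {requested!r}")
    if requested == "cuda" and not cuda_available:
        # Fall back rather than failing when the model is moved to the device.
        return "cpu"
    return requested
```

In the notebook this could be used as `MACHINE_TYPE = resolve_machine_type(MACHINE_TYPE, torch.cuda.is_available())` once `torch` has been imported.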

In [None]:
import os

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Attaching model 'google/gemma/pyTorch/2b-it' to your Colab notebook...


## Download the model implementation

In [None]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git


In [None]:
import sys

sys.path.append('gemma_pytorch')

In [None]:
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

## Set up the model

In [None]:
import torch

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
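
The configuration above is derived from substring checks on `VARIANT` (`"2b" in VARIANT`, `'quant' in VARIANT`). As a sketch of a stricter alternative, the hypothetical `parse_variant` helper below splits the variant name into its parts and rejects anything outside the six variants listed earlier. The names `VariantInfo` and `parse_variant` are illustrative and not part of `gemma_pytorch`.

```python
from typing import NamedTuple


class VariantInfo(NamedTuple):
    size: str                # '2b' or '7b'
    instruction_tuned: bool  # True when the name contains the 'it' part
    quantized: bool          # True when the name contains the 'quant' part


def parse_variant(variant: str) -> VariantInfo:
    """Split a variant name such as '7b-it-quant' into its components."""
    parts = variant.split("-")
    size, flags = parts[0], set(parts[1:])
    if size not in ("2b", "7b"):
        raise ValueError(f"Unknown model size in variant: {variant!r}")
    unknown = flags - {"it", "quant"}
    if unknown:
        raise ValueError(f"Unknown variant suffix(es): {sorted(unknown)}")
    if size == "2b" and "quant" in flags:
        # Only the 7B checkpoints appear in quantized form in the list above.
        raise ValueError(f"No quantized 2b variant is listed: {variant!r}")
    return VariantInfo(size, "it" in flags, "quant" in flags)
```

With this helper, `parse_variant(VARIANT).quantized` carries the same information as the `'quant' in VARIANT` check.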

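## Run inference

Below are examples of generating in chat mode and generating with multiple requests.

The instruction-tuned Gemma models were trained with a specific formatter that annotates instruction-tuning examples with extra information, both during training and during inference. The annotations (1) indicate roles in a conversation and (2) delineate turns in a conversation. Below is a sample code snippet that formats the model prompt for a multi-turn conversation using the user and model chat templates. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of a dialogue turn
- `<end_of_turn>`: end of a dialogue turn

For more information on Gemma formatting for instruction tuning and system instructions, [see here](https://ai.google.dev/gemma/docs/formatting).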

In [None]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

# The prompt is already fully formatted, so pass it to generate() as is
# instead of wrapping it in the user template a second time.
model.generate(
    prompt,
    device=device,
    output_len=100,
)

Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn>
<start_of_turn>model
California.<end_of_turn>
<start_of_turn>user
What can I do in California?<end_of_turn>
<start_of_turn>model



'* **Visit Disneyland or Universal Studios Hollywood.**\n* **Explore the majestic Yosemite National Park.**\n* **Whale watch off the coast of California.**\n* **Watch the Redwood trees change color in the fall.**\n* **Visit the beautiful Santa Monica Pier.**\n* **Go hiking or biking in the Redwoods National and State Parks.**\n* **Taste the diverse cuisine at farmers markets and restaurants.**\n* **Explore the bustling city life of San Francisco or Los Angeles.**\n* **'
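
The chat prompt above is assembled by hand from two templates. For longer conversations, the same format could be produced from a list of turns. The sketch below is one way to do that; `format_chat` is a hypothetical helper, and it uses only the turn markers described in this section.

```python
from typing import Iterable, Tuple


def format_chat(
    turns: Iterable[Tuple[str, str]],
    add_generation_prompt: bool = True,
) -> str:
    """Build a Gemma chat prompt from (role, text) pairs.

    Each turn is wrapped as '<start_of_turn>{role}\\n{text}<end_of_turn>\\n'.
    When add_generation_prompt is True, an open model turn is appended so
    that the model continues with its reply.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"Unknown role: {role!r}")
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    if add_generation_prompt:
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


chat_prompt = format_chat([
    ("user", "What is a good place for travel in the US?"),
    ("model", "California."),
    ("user", "What can I do in California?"),
])
```

The resulting `chat_prompt` matches the hand-built prompt printed above and can be passed to `model.generate` directly.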

In [None]:
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=60,
)

"\n\nThe lllm, a creature made of code,\nWith a spirit that's bold and unbowed.\nA poet's soul, in a digital form,\nWriting verses that touch the soul.\n\nThey craft words like a master of craft,\nWith each phrase,"

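The cells above each send a single prompt. For generating with multiple requests, the sketch below passes several prompts in one call. It assumes that `model.generate` also accepts a sequence of prompts and returns one result per prompt in the same order. Check the `gemma_pytorch` source to confirm this before relying on it. `generate_batch` is a hypothetical wrapper, not part of the library.

```python
from typing import Any, List, Sequence, Tuple


def generate_batch(
    model: Any,
    prompts: Sequence[str],
    device: Any,
    output_len: int = 60,
) -> List[Tuple[str, str]]:
    """Generate completions for several prompts and pair each with its prompt.

    Assumption: model.generate(list_of_prompts, device=..., output_len=...)
    returns a sequence of strings aligned with the input order.
    """
    prompts = list(prompts)
    results = model.generate(prompts, device=device, output_len=output_len)
    if isinstance(results, str) or len(results) != len(prompts):
        raise RuntimeError("Expected one generated string per prompt")
    return list(zip(prompts, results))
```

In the notebook this could be called as `generate_batch(model, ['First prompt', 'Second prompt'], device)`.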
## Learn more

Now that you have learned how to use Gemma in PyTorch, you can explore the many other things Gemma can do at [ai.google.dev/gemma](https://ai.google.dev/gemma). See also these related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)
