<table style="width:100%">
<tr>
<td style="vertical-align:middle; text-align:left;">
<font size="2">
Supplementary code for the <a href="http://mng.bz/orYv">Build a Large Language Model From Scratch</a> book by <a href="https://sebastianraschka.com">Sebastian Raschka</a><br>
<br>Code repository: <a href="https://github.com/rasbt/LLMs-from-scratch">https://github.com/rasbt/LLMs-from-scratch</a>
</font>
</td>
<td style="vertical-align:middle; text-align:left;">
<a href="http://mng.bz/orYv"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px"></a>
</td>
</tr>
</table>

# Gemma 3 270M From Scratch (A Standalone Notebook)

- This notebook is purposefully minimal and focuses on the code to re-implement Gemma 3 270M in pure PyTorch without relying on other external LLM libraries
- For more information, see the official [Gemma 3 270M model card](https://huggingface.co/google/gemma-3-270m)

- Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)
<br>

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gemma3/gemma3-vs-qwen3.webp?1">
  
  
- About the code:
  - all code is my own code, mapping the Gemma 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))

In [28]:
!pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt


In [2]:
from importlib.metadata import version

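# Packages whose installed versions we want to report
pkgs = [
    "huggingface_hub",  # to download pretrained weights
    "tokenizers",       # to implement the tokenizer
    "torch",            # to implement the model
]
# Print the installed version of each package
for p in pkgs:
    print(f"{p} version: {version(p)}")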

huggingface_hub version: 0.34.4
tokenizers version: 0.21.4
torch version: 2.6.0+cu124


- This notebook supports both the base model and the instruct model; which model to use can be controlled via the following flag:

In [1]:
# Boolean flag that controls whether the instruct model is used
USE_INSTRUCT_MODEL = True

&nbsp;
# 1. Architecture code

In [2]:
import torch
import torch.nn as nn

# Gated feed-forward network (GELU-gated linear unit)
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Gate projection: emb_dim -> hidden_dim
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        # Up projection: emb_dim -> hidden_dim
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        # Down projection: hidden_dim -> emb_dim
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        # Gate branch
        x_fc1 = self.fc1(x)
        # Up branch
        x_fc2 = self.fc2(x)
        # Apply the tanh-approximated GELU to the gate branch and multiply elementwise with the up branch
        x = nn.functional.gelu(x_fc1, approximate="tanh") * x_fc2
        # Project back to emb_dim
        return self.fc3(x)
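
- As a rough illustration of what `approximate="tanh"` means in the gated feed-forward layer above, the following dependency-free sketch (plain Python, no PyTorch) compares the tanh approximation of GELU with the exact erf-based definition on a few scalars; the helper names `gelu_exact`, `gelu_tanh`, and `gated_unit` are made up for this example and are not part of the model code

```python
import math


def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Tanh approximation of GELU
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gated_unit(gate, up):
    # Elementwise gating as in FeedForward.forward: GELU(gate) * up
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.5f}  tanh={gelu_tanh(x):+.5f}")

print(gated_unit([0.0, 1.0], [3.0, 2.0]))
```

- The two variants agree closely on these inputs; a gate value of 0 zeroes the corresponding output regardless of the up branch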

In [3]:
# RMSNorm layer
class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6, bias=False):
        super().__init__()
        self.eps = eps  # small constant to avoid division by zero
        # Gemma3 stores zero-centered weights and uses (1 + weight) during forward
        self.scale = nn.Parameter(torch.zeros(emb_dim))  # learnable scale
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None  # optional learnable shift

    def forward(self, x):
        # Match HF Gemma3: compute the norm in float32, then scale by (1 + w)
        input_dtype = x.dtype  # remember the original dtype
        x_f = x.float()  # upcast to float32
        var = x_f.pow(2).mean(dim=-1, keepdim=True)  # mean of squares (no mean subtraction)
        x_norm = x_f * torch.rsqrt(var + self.eps)  # RMS normalization
        out = x_norm * (1.0 + self.scale.float())  # apply the scale

        # Apply the shift if it exists
        if self.shift is not None:
            out = out + self.shift.float()

        return out.to(input_dtype)  # cast back to the original dtype
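
- To make the `(1 + weight)` convention concrete, here is a small plain-Python sketch of the same computation on a single vector; with the zero-initialized scale, the layer reduces to plain RMS normalization, so the output has a root mean square of about 1 (the function name `rms_norm` and the sample values are only for illustration)

```python
import math


def rms_norm(x, scale, eps=1e-6):
    # Mean of squares over the feature dimension (no mean subtraction)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Gemma-style scaling: the stored weight is zero-centered, so multiply by (1 + weight)
    return [v * inv_rms * (1.0 + s) for v, s in zip(x, scale)]


x = [1.0, -2.0, 3.0, -4.0]
zero_scale = [0.0] * len(x)  # mirrors the torch.zeros initialization

y = rms_norm(x, zero_scale)
rms_y = math.sqrt(sum(v * v for v in y) / len(y))
print(y)
print(f"RMS of output: {rms_y:.4f}")
```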

In [5]:
# Precompute the RoPE (rotary positional embedding) cos/sin tables
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles:
    # positions[:, None] has shape (context_length, 1)
    # inv_freq[None, :] has shape (1, head_dim // 2)
    # so the product has shape (context_length, head_dim // 2)
    angles = positions[:, None] * inv_freq[None, :]

    # Expand angles to match head_dim by concatenating them with themselves
    # Shape: (context_length, head_dim)
    angles = torch.cat([angles, angles], dim=1)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

# Apply RoPE to a query or key tensor
def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # first half
    x2 = x[..., head_dim // 2 :]  # second half

    # Adjust sin and cos shapes for broadcasting:
    # two unsqueeze(0) calls add batch and head dimensions, giving (1, 1, seq_len, head_dim)
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation:
    # negate x2 and concatenate it with x1 to build the rotated vector
    rotated = torch.cat((-x2, x1), dim=-1)
    # RoPE formula: x * cos + rotated * sin
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower precision after applying the cos and sin rotation
    return x_rotated.to(dtype=x.dtype)
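
- The reason for rotating queries and keys this way is that their dot product then depends only on the relative offset between the two positions; the plain-Python sketch below checks this on one toy head vector, using the same half-split pairing (index `i` with `i + head_dim // 2`) as `apply_rope` (the names `rope_single` and `dot` and the sample vectors are illustrative assumptions, not notebook code)

```python
import math


def rope_single(vec, pos, theta_base=10_000.0):
    # Rotate one head vector at position `pos`, pairing index i with i + d/2
    d = len(vec)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        angle = pos / (theta_base ** (2 * i / d))
        c, s = math.cos(angle), math.sin(angle)
        x1, x2 = vec[i], vec[i + half]
        out[i] = x1 * c - x2 * s         # first half:  x1 * cos + (-x2) * sin
        out[i + half] = x2 * c + x1 * s  # second half: x2 * cos + x1 * sin
    return out


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


q = [0.3, -1.2, 0.7, 2.0]
k = [1.5, 0.4, -0.9, 0.1]

score_a = dot(rope_single(q, 5), rope_single(k, 3))    # offset 2
score_b = dot(rope_single(q, 12), rope_single(k, 10))  # offset 2
score_c = dot(rope_single(q, 12), rope_single(k, 3))   # offset 9
print(score_a, score_b, score_c)
```

- The first two scores match up to floating-point error because both pairs are 2 positions apart; the third pair has a different offset and, in general, a different score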

In [6]:
# Grouped-query attention (GQA)
class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False,
        query_pre_attn_scalar=None, dtype=None,
    ):
        super().__init__()
        # Each K/V group must serve a whole number of query heads
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads  # number of query heads
        self.num_kv_groups = num_kv_groups  # number of key/value groups
        self.group_size = num_heads // num_kv_groups  # query heads per K/V group

        # If head_dim is not set, derive it from d_in and num_heads
        if head_dim is None:
            assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
            head_dim = d_in // num_heads

        self.head_dim = head_dim  # dimension of each head
        self.d_out = num_heads * head_dim  # concatenated size of all query heads

        # Linear projections for queries, keys, and values
        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)  # query projection
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)  # key projection
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)  # value projection

        # Output projection
        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        # Optional QK normalization layers
        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

        # Scaling factor applied to the queries
        if query_pre_attn_scalar is not None:
            self.scaling = (query_pre_attn_scalar) ** -0.5
        else:
            self.scaling = (head_dim) ** -0.5


    def forward(self, x, mask, cos, sin):
        b, num_tokens, _ = x.shape  # batch size and sequence length

        # Apply the projections
        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)
        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Reshape for multi-head attention:
        # split num_heads * head_dim into (num_heads, head_dim),
        # then transpose dims 1 and 2 so the head dimension comes before the token dimension
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        # Split num_kv_groups * head_dim into (num_kv_groups, head_dim),
        # then transpose dims 1 and 2 so the K/V group dimension comes before the token dimension
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Optional normalization
        if self.q_norm:
            queries = self.q_norm(queries)
        if self.k_norm:
            keys = self.k_norm(keys)

        # Apply RoPE
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand K and V to match the number of query heads:
        # repeat_interleave repeats each K/V group group_size times
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Scale the queries
        queries = queries * self.scaling

        # Attention scores: queries times transposed keys
        attn_scores = queries @ keys.transpose(2, 3)
        # Apply the mask by setting masked positions to negative infinity
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        # Softmax turns the scores into attention weights
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # Context vectors: attention weights times values
        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        # Apply the output projection
        return self.out_proj(context)
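
- The `repeat_interleave` step decides which K/V group each query head reads from; the short plain-Python sketch below spells out that mapping and the query scaling factor, under the assumption that a hypothetical helper `kv_group_for_each_head` is an acceptable stand-in for the tensor operation

```python
def kv_group_for_each_head(num_heads, num_kv_groups):
    # Mirrors keys.repeat_interleave(group_size, dim=1):
    # consecutive query heads share the same K/V group
    assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
    group_size = num_heads // num_kv_groups
    return [head // group_size for head in range(num_heads)]


# 4 query heads and 1 K/V group: every head reads the same keys and values
print(kv_group_for_each_head(4, 1))
# A hypothetical larger setup with 8 query heads and 2 K/V groups
print(kv_group_for_each_head(8, 2))

# With query_pre_attn_scalar=256, queries are multiplied by 256 ** -0.5
scaling = 256 ** -0.5
print(scaling)
```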

In [7]:
# Transformer block
class TransformerBlock(nn.Module):

    def __init__(self, cfg: dict, attn_type: str):
        super().__init__()
        self.attn_type = attn_type  # "sliding_attention" or "full_attention"

        # Grouped-query attention layer
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            head_dim=cfg["head_dim"],
            qk_norm=cfg["qk_norm"],
            query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
            dtype=cfg["dtype"],
        )
        # Feed-forward layer
        self.ff = FeedForward(cfg)
        # RMSNorm layers before and after each sublayer
        self.input_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)  # before attention
        self.post_attention_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)  # after attention
        self.pre_feedforward_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)  # before feed-forward
        self.post_feedforward_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)  # after feed-forward

    def forward(
        self,
        x,
        mask_global,  # mask for full (global) attention
        mask_local,  # mask for sliding-window attention
        cos_global,  # RoPE cos table for global attention
        sin_global,  # RoPE sin table for global attention
        cos_local,  # RoPE cos table for sliding-window attention
        sin_local,  # RoPE sin table for sliding-window attention
    ):
        # Shortcut connection for the attention block
        shortcut = x
        # Normalize the input
        x = self.input_layernorm(x)

        # Pick the mask and RoPE tables that match the attention type
        if self.attn_type == "sliding_attention":
            attn_mask = mask_local
            cos = cos_local
            sin = sin_local
        else:
            attn_mask = mask_global
            cos = cos_global
            sin = sin_global

        # Apply attention
        x_attn = self.att(x, attn_mask, cos, sin)
        # Normalize the attention output
        x_attn = self.post_attention_layernorm(x_attn)
        # Add the attention output to the shortcut
        x = shortcut + x_attn

        # Shortcut connection for the feed-forward block
        shortcut = x
        # Normalize before the feed-forward layer
        x_ffn = self.pre_feedforward_layernorm(x)
        # Apply the feed-forward layer
        x_ffn = self.ff(x_ffn)
        # Normalize the feed-forward output
        x_ffn = self.post_feedforward_layernorm(x_ffn)
        # Add the feed-forward output to the shortcut
        x = shortcut + x_ffn
        return x

In [9]:
# Gemma3 model
class Gemma3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # layer_types must be present and contain one entry per layer
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])  # token embedding layer

        # One transformer block per entry in layer_types (sliding-window or full attention)
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, attn_type) for attn_type in cfg["layer_types"]
        ])

        self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-6)  # final normalization layer
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])  # output head (predicts the next token)
        self.cfg = cfg  # keep the config dictionary

        # Reusable utilities
        # Local RoPE tables (for sliding-window attention)
        cos_local, sin_local = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_local_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        # Global RoPE tables (for full attention)
        cos_global, sin_global = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        # Register the RoPE tables as buffers: they are not trainable parameters,
        # and persistent=False keeps them out of the state_dict
        self.register_buffer("cos_local", cos_local, persistent=False)
        self.register_buffer("sin_local", sin_local, persistent=False)
        self.register_buffer("cos_global", cos_global, persistent=False)
        self.register_buffer("sin_global", sin_global, persistent=False)

    # Create the attention masks (True means "masked out")
    def _create_masks(self, seq_len, device):
        # Boolean tensor filled with True
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)

        # mask_global (future is masked: j > i)
        # triu with diagonal=1 keeps the strict upper triangle
        mask_global = torch.triu(ones, diagonal=1)

        # far_past (too far back is masked: i - j >= sliding_window)
        # triu with diagonal=sliding_window, then transpose
        far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T

        # Local (sliding-window) mask = future OR far past
        mask_local = mask_global | far_past
        return mask_global, mask_local

    def forward(self, input_ids):
        # Batch size and sequence length
        b, seq_len = input_ids.shape
        # Embed the tokens and scale by sqrt(emb_dim), as Gemma does
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
        # Create the attention masks
        mask_global, mask_local = self._create_masks(seq_len, x.device)

        # Run the input through each transformer block
        for block in self.blocks:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global,
                sin_global=self.sin_global,
                cos_local=self.cos_local,
                sin_local=self.sin_local,
            )

        # Apply the final normalization
        x = self.final_norm(x)
        # Cast to the configured dtype and compute the logits with the output head
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
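
- The two masks built in `_create_masks` are easier to see on a tiny example; the following plain-Python sketch applies the same rules (`j > i` for the future, `i - j >= sliding_window` for the far past) to a sequence of 6 tokens with an illustrative window of 3, where `build_masks` is a hypothetical helper and `x` marks a masked position

```python
def build_masks(seq_len, sliding_window):
    # True means "masked out", matching masked_fill(mask, -inf) in the attention code
    mask_global = [[j > i for j in range(seq_len)] for i in range(seq_len)]
    mask_local = [
        [(j > i) or (i - j >= sliding_window) for j in range(seq_len)]
        for i in range(seq_len)
    ]
    return mask_global, mask_local


mask_global, mask_local = build_masks(seq_len=6, sliding_window=3)
for name, mask in (("global", mask_global), ("local", mask_local)):
    print(name)
    for row in mask:
        print(" ".join("x" if masked else "." for masked in row))
```

- In the local mask, each row leaves at most `sliding_window` positions visible: the current token plus the tokens immediately before it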

&nbsp;
# 2. Initialize model

In [10]:
# Configuration dictionary for the Gemma3 270M model
GEMMA3_CONFIG_270M = {
    "vocab_size": 262_144,  # vocabulary size
    "context_length": 32_768,  # context length
    "emb_dim": 640,  # embedding dimension
    "n_heads": 4,  # number of attention (query) heads
    "n_layers": 18,  # number of transformer layers
    "hidden_dim": 2048,  # hidden dimension of the feed-forward network
    "head_dim": 256,  # dimension of each attention head
    "qk_norm": True,  # whether to apply QK normalization
    "n_kv_groups": 1,  # number of key/value groups (1 means all query heads share one K/V head)
    "rope_local_base": 10_000.0,  # theta_base for local (sliding-window) RoPE
    "rope_base": 1_000_000.0,  # theta_base for global RoPE
    "sliding_window": 512,  # sliding-window size
    "layer_types": [  # attention type of each layer
        "sliding_attention",  # sliding-window attention
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",  # full (global) attention
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention"
    ],
    "dtype": torch.bfloat16,  # dtype used for the model weights
    "query_pre_attn_scalar": 256,  # queries are scaled by this value ** -0.5 before attention
}
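
- The `layer_types` list above follows a regular pattern: five sliding-window layers followed by one full-attention layer, repeated three times; as a sketch, the same list could be generated programmatically (the helper `make_layer_types` and its `full_every` parameter are hypothetical and not part of the notebook's config code)

```python
def make_layer_types(n_layers, full_every=6):
    # Every `full_every`-th layer uses full attention; the others use sliding-window attention
    return [
        "full_attention" if (i + 1) % full_every == 0 else "sliding_attention"
        for i in range(n_layers)
    ]


layer_types = make_layer_types(18)
print(layer_types.count("sliding_attention"), layer_types.count("full_attention"))
print([i for i, t in enumerate(layer_types) if t == "full_attention"])
```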

In [11]:
torch.manual_seed(123)  # set the random seed for reproducibility
model = Gemma3Model(GEMMA3_CONFIG_270M)  # initialize the Gemma3 model from the config

In [12]:
model  # display the model structure

Gemma3Model(
  (tok_emb): Embedding(262144, 640)
  (blocks): ModuleList(
    (0-17): 18 x TransformerBlock(
      (att): GroupedQueryAttention(
        (W_query): Linear(in_features=640, out_features=1024, bias=False)
        (W_key): Linear(in_features=640, out_features=256, bias=False)
        (W_value): Linear(in_features=640, out_features=256, bias=False)
        (out_proj): Linear(in_features=1024, out_features=640, bias=False)
        (q_norm): RMSNorm()
        (k_norm): RMSNorm()
      )
      (ff): FeedForward(
        (fc1): Linear(in_features=640, out_features=2048, bias=False)
        (fc2): Linear(in_features=640, out_features=2048, bias=False)
        (fc3): Linear(in_features=2048, out_features=640, bias=False)
      )
      (input_layernorm): RMSNorm()
      (post_attention_layernorm): RMSNorm()
      (pre_feedforward_layernorm): RMSNorm()
      (post_feedforward_layernorm): RMSNorm()
    )
  )
  (final_norm): RMSNorm()
  (out_head): Linear(in_features=640, out_features

- A quick check that the forward pass works before continuing:

In [13]:
# Run one forward pass on a small input to check that the model works
# torch.tensor([1, 2, 3]) creates the token ID tensor
# .unsqueeze(0) adds a batch dimension at position 0
model(torch.tensor([1, 2, 3]).unsqueeze(0))

tensor([[[ 0.7500,  0.1055,  0.4844,  ...,  0.9414,  0.3984, -0.2354],
         [-0.3418, -0.0542,  0.8945,  ..., -0.2383,  0.4590,  0.8242],
         [-0.2676, -0.3301,  0.4141,  ...,  0.8672, -0.9688,  0.9844]]],
       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)

In [14]:
# Count all model parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Account for weight tying:
# subtract the token embedding parameters, because the output head shares its weights with the token embedding
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")

Total number of parameters: 435,870,336

Total number of unique parameters: 268,098,176
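
- Both numbers can be reproduced by hand from the configuration; the sketch below redoes the arithmetic in plain Python, copying the relevant values from `GEMMA3_CONFIG_270M` into a local dictionary so that it runs on its own

```python
cfg = {
    "vocab_size": 262_144, "emb_dim": 640, "n_heads": 4, "n_kv_groups": 1,
    "head_dim": 256, "hidden_dim": 2048, "n_layers": 18,
}

emb = cfg["vocab_size"] * cfg["emb_dim"]   # token embedding (the output head has the same size)
d_out = cfg["n_heads"] * cfg["head_dim"]
kv_dim = cfg["n_kv_groups"] * cfg["head_dim"]

attn = (
    cfg["emb_dim"] * d_out                 # W_query
    + 2 * cfg["emb_dim"] * kv_dim          # W_key and W_value
    + d_out * cfg["emb_dim"]               # out_proj
    + 2 * cfg["head_dim"]                  # q_norm and k_norm scales
)
ffn = 3 * cfg["emb_dim"] * cfg["hidden_dim"]  # fc1, fc2, fc3
norms = 4 * cfg["emb_dim"]                    # four RMSNorm scales per block
per_block = attn + ffn + norms

total = 2 * emb + cfg["n_layers"] * per_block + cfg["emb_dim"]  # last term: final_norm
unique = total - emb                                            # weight tying

print(f"Total number of parameters: {total:,}")
print(f"Total number of unique parameters: {unique:,}")
```

- Note that the token embedding alone accounts for 167,772,160 parameters, which is well over half of the unique total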


In [14]:
# Estimate the memory footprint of the model
def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0  # number of parameter elements
    total_grads = 0  # number of gradient elements
    for param in model.parameters():
        # Number of elements in this parameter
        param_size = param.numel()
        total_params += param_size
        # Gradients are stored only for parameters that require them
        if param.requires_grad:
            total_grads += param_size

    # Buffers are not parameters but still take up memory
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (parameters + gradients + buffers) * element size
    # This assumes parameters and gradients are stored in the same dtype as input_dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb

# Print the estimated model size for different dtypes
print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")

float32 (PyTorch default): 3.37 GB
bfloat16: 1.69 GB
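
- As a sanity check, the same estimate can be redone with plain arithmetic; the sketch below assumes the parameter count printed earlier, one gradient element per parameter, and four RoPE buffers of shape `(context_length, head_dim)` = `(32768, 256)`

```python
n_params = 435_870_336        # total parameter count printed earlier
n_grads = n_params            # every parameter requires a gradient by default
n_buffers = 4 * 32_768 * 256  # cos/sin tables for local and global RoPE


def estimate_gb(bytes_per_element):
    # Same formula as model_memory_size: element count times element size, in GB
    return (n_params + n_grads + n_buffers) * bytes_per_element / 1024**3


print(f"float32: {estimate_gb(4):.2f} GB")
print(f"bfloat16: {estimate_gb(2):.2f} GB")
```

- Because gradients are included, this is closer to a training-time estimate; for inference only, the gradient term would drop out and the numbers would be roughly halved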


In [15]:
# Use CUDA if available, otherwise MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")  # NVIDIA GPU
elif torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple Silicon GPU
else:
    device = torch.device("cpu")

model.to(device);  # move the model to the selected device

&nbsp;
# 3. Load pretrained weights

In [16]:
# Load pretrained weights into the Gemma model
def load_weights_into_gemma(model, param_config, params):

    # Helper: check that the shapes match, then return the source tensor as a parameter
    def assign(left, right, tensor_name="unknown"):
        if left.shape != right.shape:
            raise ValueError(
                f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}"
            )
        # Clone and detach the tensor, then wrap it in nn.Parameter
        return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))

    # Load the embedding weights
    if "model.embed_tokens.weight" in params:
        model.tok_emb.weight = assign(
            model.tok_emb.weight,
            params["model.embed_tokens.weight"],
            "model.embed_tokens.weight",
        )

    # Iterate over the transformer layers and load their weights
    for l in range(param_config["n_layers"]):
        block = model.blocks[l]  # current transformer block
        att = block.att  # attention module of the current block
        # Attention projection weights
        att.W_query.weight = assign(
            att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight",
        )
        att.W_key.weight = assign(
            att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight",
        )
        att.W_value.weight = assign(
            att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight",
        )
        att.out_proj.weight = assign(
            att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight",
        )
        # QK normalization weights
        att.q_norm.scale = assign(
            att.q_norm.scale,
            params[f"model.layers.{l}.self_attn.q_norm.weight"],
            f"model.layers.{l}.self_attn.q_norm.weight",
        )
        att.k_norm.scale = assign(
            att.k_norm.scale,
            params[f"model.layers.{l}.self_attn.k_norm.weight"],
            f"model.layers.{l}.self_attn.k_norm.weight",
        )
        # Feed-forward weights
        block.ff.fc1.weight = assign(
            block.ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight",
        )
        block.ff.fc2.weight = assign(
            block.ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight",
        )
        block.ff.fc3.weight = assign(
            block.ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight",
        )
        # Normalization weights around the attention sublayer
        block.input_layernorm.scale = assign(
            block.input_layernorm.scale,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight",
        )
        block.post_attention_layernorm.scale = assign(
            block.post_attention_layernorm.scale,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight",
        )
        # Pre- and post-feed-forward normalization weights
        pre_key = f"model.layers.{l}.pre_feedforward_layernorm.weight"
        post_key = f"model.layers.{l}.post_feedforward_layernorm.weight"
        if pre_key in params:
            block.pre_feedforward_layernorm.scale = assign(
                block.pre_feedforward_layernorm.scale,
                params[pre_key],
                pre_key,
            )
        if post_key in params:
            block.post_feedforward_layernorm.scale = assign(
                block.post_feedforward_layernorm.scale,
                params[post_key],
                post_key,
            )

    # Load the final normalization weights
    if "model.norm.weight" in params:
        model.final_norm.scale = assign(
            model.final_norm.scale,
            params["model.norm.weight"],
            "model.norm.weight",
        )
    # Load the output head weights
    if "lm_head.weight" in params:
        model.out_head.weight = assign(
            model.out_head.weight,
            params["lm_head.weight"],
            "lm_head.weight",
        )
    elif "model.embed_tokens.weight" in params:
        # Weight tying: reuse the embedding weights for the output head
        model.out_head.weight = assign(
            model.out_head.weight,
            params["model.embed_tokens.weight"],
            "model.embed_tokens.weight",
        )

- Please note that Google requires that you accept the Gemma 3 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [google/gemma-3-270m](https://huggingface.co/google/gemma-3-270m) repository to accept the terms
- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on "Settings"


<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1" width="300px">

- Then, create and copy the access token so you can copy & paste it into the next code cell

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1" width="600px">

In [29]:
# Run the following code to log in to the Hugging Face Hub if you are executing the notebook for the first time

from huggingface_hub import login
login()


In [30]:
from google.colab import output
output.enable_custom_widget_manager()

Support for third party widgets will remain active for the duration of the session. To disable support:

In [31]:
from google.colab import output
output.disable_custom_widget_manager()

In [32]:
import json
import os
from pathlib import Path
from safetensors.torch import load_file  # load weights from safetensors files
from huggingface_hub import hf_hub_download, snapshot_download  # download files from the Hugging Face Hub
from google.colab import userdata  # access Colab Secrets
from huggingface_hub import login

# Read the Hugging Face token from Colab Secrets and log in
try:
    hf_token = userdata.get('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
        print("Successfully logged in to Hugging Face Hub.")
    else:
        print("Hugging Face token not found in Colab Secrets. Please add it.")
except Exception as e:
    print(f"An error occurred during Hugging Face login: {e}")


CHOOSE_MODEL = "270m"  # model size

# Pick the Hugging Face Hub repository ID based on the USE_INSTRUCT_MODEL flag
if USE_INSTRUCT_MODEL:
    repo_id = f"google/gemma-3-{CHOOSE_MODEL}-it"  # instruct model
else:
    repo_id = f"google/gemma-3-{CHOOSE_MODEL}"  # base model


local_dir = Path(repo_id).parts[-1]  # local directory for the downloaded files

# Choose the download method based on the model size
if CHOOSE_MODEL == "270m":
    # The 270M model ships as a single model.safetensors file
    weights_file = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir=local_dir,
    )
    weights_dict = load_file(weights_file)  # load the weights into a dictionary
else:
    # For other model sizes (sharded checkpoints), download a snapshot of the whole repository
    repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)
    index_path = os.path.join(repo_dir, "model.safetensors.index.json")  # path to the weight index
    with open(index_path, "r") as f:
        index = json.load(f)  # load the weight index

    weights_dict = {}
    # Load every shard listed in the index and merge it into one dictionary
    for filename in set(index["weight_map"].values()):
        shard_path = os.path.join(repo_dir, filename)  # path to the shard
        shard = load_file(shard_path)  # load the shard
        weights_dict.update(shard)

# Copy the loaded weights into the model
load_weights_into_gemma(model, GEMMA3_CONFIG_270M, weights_dict)
model.to(device)  # move the model to the device
del weights_dict  # free the memory held by the weight dictionary

Successfully logged in to Hugging Face Hub.


model.safetensors:   0%|          | 0.00/536M [00:00<?, ?B/s]

&nbsp;
# 4. Load tokenizer

In [33]:
from tokenizers import Tokenizer

# Thin wrapper around the Gemma tokenizer
class GemmaTokenizer:
    def __init__(self, tokenizer_file_path: str):
        tok_file = Path(tokenizer_file_path)
        self._tok = Tokenizer.from_file(str(tok_file))  # load the tokenizer from file
        # Use <end_of_turn> as both the EOS and the padding token, stored as integer IDs
        eos_token = "<end_of_turn>"
        self.pad_token_id = self._tok.token_to_id(eos_token)
        self.eos_token_id = self._tok.token_to_id(eos_token)

    # Encode text into a list of token IDs
    def encode(self, text: str) -> list[int]:
        return self._tok.encode(text).ids

    # Decode a list of token IDs into text
    def decode(self, ids: list[int]) -> str:
        return self._tok.decode(ids, skip_special_tokens=False)  # keep special tokens

# Wrap the user input in the Gemma chat template
def apply_chat_template(user_text):
    return f"<start_of_turn>user\n{user_text}<end_of_turn>\n<start_of_turn>model\n"
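
- The template above covers a single user turn; as a sketch of how the same markers could be chained for a longer conversation, the hypothetical helper `apply_chat_template_multi` below closes each finished turn with `<end_of_turn>` and leaves a final model turn open (this helper is an assumption made for illustration and is not used elsewhere in this notebook)

```python
def apply_chat_template(user_text):
    # Single-turn template, same string as in the notebook cell above
    return f"<start_of_turn>user\n{user_text}<end_of_turn>\n<start_of_turn>model\n"


def apply_chat_template_multi(turns):
    # `turns` is a list of (role, text) pairs, with role "user" or "model" (hypothetical helper)
    parts = [f"<start_of_turn>{role}\n{text}<end_of_turn>\n" for role, text in turns]
    # Leave the final model turn open so generation continues from there
    return "".join(parts) + "<start_of_turn>model\n"


history = [("user", "Hi"), ("model", "Hello!"), ("user", "What is RoPE?")]
print(apply_chat_template_multi(history))
```

- For a single user message, the multi-turn helper produces exactly the same string as `apply_chat_template`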

In [34]:
tokenizer_file_path = os.path.join(local_dir, "tokenizer.json")  # path to the tokenizer file
# If the file does not exist locally, try to download it from the Hugging Face Hub
if not os.path.exists(tokenizer_file_path):
    try:
        tokenizer_file_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json", local_dir=local_dir)
    except Exception as e:
        print(f"Warning: failed to download tokenizer.json: {e}")
        tokenizer_file_path = "tokenizer.json"  # fall back to a file in the current directory

tokenizer = GemmaTokenizer(tokenizer_file_path=tokenizer_file_path)  # initialize the tokenizer

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

In [40]:
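# Prompt (in Chinese): "Explain the three-door (Monty Hall) problem from statistics in simple terms, in Simplified Chinese"
prompt = "深入浅出地用简体中文解释一下统计学中的三门问题"
# Format the prompt with the chat template
prompt = apply_chat_template(prompt)


input_token_ids = tokenizer.encode(prompt)  # encode the prompt into a list of token IDs
text = tokenizer.decode(input_token_ids)  # decode the token IDs back into text
text  # display the decoded text (including special tokens)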

'<bos><start_of_turn>user\n深入浅出地用简体中文解释一下统计学中的三门问题<end_of_turn>\n<start_of_turn>model\n'

&nbsp;
# 5. Generate text

In [41]:
# Basic greedy text generator that streams tokens one at a time
def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None):

    model.eval()  # switch the model to evaluation mode
    with torch.no_grad():  # disable gradient tracking
        # Generate at most max_new_tokens tokens
        for _ in range(max_new_tokens):
            # Forward pass; keep only the logits at the last token position
            out = model(token_ids)[:, -1]
            # Greedy decoding: pick the token with the highest logit
            next_token = torch.argmax(out, dim=-1, keepdim=True)

            # Stop if the EOS token was generated
            if (eos_token_id is not None
                    and torch.all(next_token == eos_token_id)):
                break

            yield next_token  # stream the new token to the caller

            # Append the new token to the running sequence
            token_ids = torch.cat([token_ids, next_token], dim=1)
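
- To show the control flow of the streaming generator without loading any weights, here is a plain-Python sketch with a toy stand-in for the model; `toy_logits` and `greedy_stream` are hypothetical names, and the toy "model" simply prefers the token ID that follows the last one

```python
def toy_logits(token_ids, vocab_size=5):
    # Stand-in for model(token_ids)[:, -1]: always prefers (last token + 1) mod vocab_size
    preferred = (token_ids[-1] + 1) % vocab_size
    return [1.0 if i == preferred else 0.0 for i in range(vocab_size)]


def greedy_stream(token_ids, max_new_tokens, eos_token_id=None):
    token_ids = list(token_ids)
    for _ in range(max_new_tokens):
        logits = toy_logits(token_ids)
        next_token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        if eos_token_id is not None and next_token == eos_token_id:
            break  # stop before yielding EOS, as in the generator above
        yield next_token
        token_ids.append(next_token)  # the full sequence is fed again on the next step


print(list(greedy_stream([0], max_new_tokens=10, eos_token_id=4)))
print(list(greedy_stream([0], max_new_tokens=6)))
```

- As in the notebook's generator, the whole sequence is passed through the model again at every step, so the cost per token grows with the sequence length; a KV cache would avoid this, but that is outside the scope of this minimal version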

In [42]:
# Convert the input token IDs to a tensor, add a batch dimension, and move it to the device
input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)

# Generate text with the streaming generator and print it as it arrives
for token in generate_text_basic_stream(
    model=model,
    token_ids=input_token_ids_tensor,
    max_new_tokens=500,  # maximum number of tokens to generate
    eos_token_id=tokenizer.encode("<end_of_turn>")[-1]  # EOS token ID
):
    token_id = token.squeeze(0).tolist()  # extract the token ID from the tensor as a list
    # Decode the token and print it without a newline, flushing the output buffer
    print(
        tokenizer.decode(token_id),
        end="",
        flush=True
    )

好的，下面是深入浅出地用简体中文解释一下统计学中的三门问题：

**1. 检验假设 (Hypothesis Testing)**

*   **定义:** 检验一个假设是否成立，即一个假设的概率大于或等于另一个假设的概率。
*   **原理:** 统计学中，假设检验的核心是检验一个假设的**概率**。 假设检验的目的是确定一个假设的**概率**，即它是否真的存在。
*   **步骤:**
    1.  **确定假设:** 确定一个假设的概率。
    2.  **确定假设的概率:** 确定假设的概率。
    3.  **计算假设的概率:** 计算假设的概率。
    4.  **检验假设:** 检查假设的概率是否大于或等于另一个假设的概率。
*   **例子:**
    *   假设： 假设一个新车价格是 1000 元。
    *   假设的概率： 1/1000
    *   假设的概率： 100%
    *   检验假设： 100% > 1/1000
    *   结果： 假设的概率大于另一个假设的概率。

**2. 统计量 (Statistical Tests)**

*   **定义:** 统计量是统计学中用来检验一个假设的**统计量**。 统计量是衡量一个假设的**统计特征**的指标。
*   **原理:** 统计量是统计学中用来检验一个假设的**统计量**。 统计量是衡量一个假设的**统计量**。
*   **步骤:**
    1.  **确定统计量:** 确定一个统计量。
    2.  **计算统计量:** 计算一个统计量。
    3.  **检验统计量:** 检查统计量是否满足一个**假设**。
*   **例子:**
    *   假设： 假设一个新公司的员工工资是 500 元。
    *   假设的统计量： 500
    *   假设的概率： 50%
    *   假设的概率： 50%
    *   检验统计量： 50% > 50%
    *   结果： 假设的

&nbsp;
# What's next?

- Check out the [README.md](./README.md) to use this model via the `llms_from_scratch` package
- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)

<a href="http://mng.bz/orYv"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px"></a>