# PicoGPT: GPT-2 in 60 Lines of NumPy

This notebook demonstrates how to implement GPT-2 from scratch using only NumPy.

**References:**
- Original blog post: [GPT in 60 Lines of NumPy](https://jaykmody.com/blog/gpt-from-scratch/)
- Chinese translation: [60行NumPy手搓GPT](https://zhuanlan.zhihu.com/p/640935459)
- GitHub: [picoGPT](https://github.com/jaymody/picoGPT)

---

# PicoGPT：用60行NumPy实现GPT-2

本notebook演示如何仅使用NumPy从零实现GPT-2。

**参考资料:**
- 原文: [GPT in 60 Lines of NumPy](https://jaykmody.com/blog/gpt-from-scratch/)
- 中文翻译: [60行NumPy手搓GPT](https://zhuanlan.zhihu.com/p/640935459)

## 1. Install Dependencies / 安装依赖

First, let's install the required packages.

首先，安装必要的依赖包。

In [None]:
!pip install numpy regex requests tqdm tensorflow

## 2. Import Libraries / 导入库

In [None]:
import json
import os
import re
from functools import lru_cache

import numpy as np
import regex
import requests
import tensorflow as tf
from tqdm import tqdm

## 3. BPE Tokenizer / BPE分词器

GPT-2 uses Byte Pair Encoding (BPE) for tokenization. This code is from OpenAI's official implementation.

GPT-2使用字节对编码（BPE）进行分词。这段代码来自OpenAI的官方实现。

In [None]:
@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character lookup table used by the BPE code.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte is assigned the next unused
    code point starting at 256, in increasing byte order, so all 256 bytes
    end up with a distinct one-character string.

    Returns:
        A dict with 256 entries mapping each byte value (int) to a
        single-character string. The result is memoized by ``lru_cache``.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    # Bytes outside the ranges above, kept in ascending order.
    shifted = [byte for byte in range(2**8) if byte not in self_mapped]

    byte_values = self_mapped + shifted
    code_points = self_mapped + [2**8 + offset for offset in range(len(shifted))]
    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}


def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sequence of symbols (a tuple of strings or a plain string).

    Returns:
        A set of ``(left, right)`` tuples, one for each pair of neighbouring
        symbols. A word with fewer than two symbols yields an empty set.
    """
    # zip stops at the shorter argument, so empty and single-symbol words
    # produce no pairs instead of raising IndexError on word[0].
    return set(zip(word, word[1:]))


class Encoder:
    """BPE Encoder for GPT-2 / GPT-2的BPE编码器"""
    
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode text to token ids / 将文本编码为token id"""
        bpe_tokens = []
        for token in regex.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Decode token ids to text / 将token id解码为文本"""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text


def get_encoder(model_name, models_dir):
    """
    Build an ``Encoder`` from the tokenizer files of a downloaded model.

    Args:
        model_name: Name of the model sub-directory inside ``models_dir``.
        models_dir: Directory that contains one sub-directory per model.

    Returns:
        An ``Encoder`` built from ``encoder.json`` (the vocabulary) and
        ``vocab.bpe`` (the merge list) found in ``models_dir/model_name``.
    """
    model_dir = os.path.join(models_dir, model_name)
    # Both files are read as UTF-8 explicitly so the result does not depend
    # on the platform's default text encoding.
    with open(os.path.join(model_dir, "encoder.json"), "r", encoding="utf-8") as f:
        encoder = json.load(f)
    with open(os.path.join(model_dir, "vocab.bpe"), "r", encoding="utf-8") as f:
        bpe_data = f.read()
    # [1:-1] drops the first line and the last split element — presumably a
    # header line and the empty string after the trailing newline
    # (TODO confirm against vocab.bpe).
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges)

## 4. Model Loading Utilities / 模型加载工具

Functions to download and load pre-trained GPT-2 weights from OpenAI.

下载和加载OpenAI预训练GPT-2权重的函数。

In [None]:
def download_gpt2_files(model_size, model_dir):
    """
    Download the GPT-2 checkpoint and tokenizer files for one model size.

    Args:
        model_size: One of "124M", "355M", "774M", "1558M".
        model_dir: Existing directory the files are written into.

    Raises:
        AssertionError: If ``model_size`` is not one of the supported sizes.
        requests.HTTPError: If the server answers a request with an error status.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]
    url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    chunk_size = 1000
    for filename in [
        "checkpoint",
        "encoder.json",
        "hparams.json",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt.index",
        "model.ckpt.meta",
        "vocab.bpe",
    ]:
        # Using the response as a context manager releases the streamed
        # connection even if writing the file fails part-way.
        with requests.get(f"{url}/{model_size}/{filename}", stream=True) as r:
            r.raise_for_status()
            file_size = int(r.headers["content-length"])

            with open(os.path.join(model_dir, filename), "wb") as f:
                with tqdm(
                    ncols=100,
                    desc="Fetching " + filename,
                    total=file_size,
                    unit_scale=True,
                    unit="b",
                ) as pbar:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        # Advance by the bytes actually received; the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(chunk))


def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams):
    """
    Read every variable of a TensorFlow checkpoint into a nested dict.

    Variable names have their leading "model/" prefix stripped and are split
    on "/" to form the path inside the result. Variables whose name starts
    with "h" are expected to look like ``h<N>/<rest>`` and are stored under
    ``params["blocks"][N]`` at path ``<rest>``; everything else is stored
    directly under ``params``.

    Args:
        tf_ckpt_path: Checkpoint path understood by ``tf.train``.
        hparams: Dict providing ``"n_layer"``, the number of block dicts to create.

    Returns:
        A dict with a ``"blocks"`` list of ``n_layer`` dicts plus the
        top-level entries; the leaves are squeezed NumPy arrays.
    """
    def assign_at_path(tree, path, value):
        # Walk (creating dicts as needed) down to the parent of the leaf.
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value

    block_name = re.compile(r"h([0-9]+)/(.*)")
    params = {"blocks": [{} for _ in range(hparams["n_layer"])]}

    for variable_name, _ in tf.train.list_variables(tf_ckpt_path):
        array = np.squeeze(tf.train.load_variable(tf_ckpt_path, variable_name))
        name = variable_name[len("model/") :]

        if not name.startswith("h"):
            assign_at_path(params, name.split("/"), array)
            continue

        match = block_name.match(name)
        layer_index = int(match[1])
        assign_at_path(params["blocks"][layer_index], match[2].split("/"), array)

    return params


def load_encoder_hparams_and_params(model_size, models_dir):
    """
    Load the tokenizer, hyperparameters and weights for one GPT-2 model size.

    If no checkpoint is found in ``models_dir/model_size``, the directory is
    created and the model files are downloaded first.

    Args:
        model_size: One of "124M", "355M", "774M", "1558M".
        models_dir: Directory that holds one sub-directory per model size.

    Returns:
        A tuple ``(encoder, hparams, params)``.

    Raises:
        AssertionError: If ``model_size`` is not one of the supported sizes.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]

    model_dir = os.path.join(models_dir, model_size)
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    if not tf_ckpt_path:  # download files if necessary
        os.makedirs(model_dir, exist_ok=True)
        download_gpt2_files(model_size, model_dir)
        tf_ckpt_path = tf.train.latest_checkpoint(model_dir)

    encoder = get_encoder(model_size, models_dir)
    # Read through a context manager so the file handle is closed promptly.
    with open(os.path.join(model_dir, "hparams.json"), encoding="utf-8") as f:
        hparams = json.load(f)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)

    return encoder, hparams, params

## 5. GPT-2 Model Implementation / GPT-2模型实现

Now for the exciting part - the actual GPT-2 implementation in pure NumPy!

现在是激动人心的部分——用纯NumPy实现GPT-2！

### 5.1 Basic Building Blocks / 基础构建块

In [None]:
def gelu(x):
    """
    GELU (Gaussian Error Linear Unit) activation, tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
    element-wise.
    """
    inner = x + 0.044715 * x**3
    gate = 1 + np.tanh(np.sqrt(2 / np.pi) * inner)
    return 0.5 * x * gate


def softmax(x):
    """
    Softmax over the last axis.

    The maximum of each row is subtracted before exponentiating, which keeps
    ``exp`` from overflowing without changing the result.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)


def layer_norm(x, g, b, eps=1e-5):
    """
    Layer normalization over the last axis.

    Each row is shifted to zero mean and divided by ``sqrt(variance + eps)``,
    then scaled by ``g`` (gamma) and offset by ``b`` (beta).
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    spread = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return g * (centered / spread) + b


def linear(x, w, b):
    """
    Affine map ``y = x @ w + b``.

    Shape: [m, in] @ [in, out] + [out] -> [m, out]
    """
    projected = x @ w
    return projected + b

### 5.2 Feed-Forward Network / 前馈网络

In [None]:
def ffn(x, c_fc, c_proj):
    """
    Position-wise feed-forward network.

    Shape: [n_seq, n_embd] -> [n_seq, n_embd]

    Applies ``linear`` with the ``c_fc`` parameters (project up, per the shape
    notes below: n_embd -> 4*n_embd), then GELU, then ``linear`` with the
    ``c_proj`` parameters (project back down to n_embd).

    Args:
        x: Input activations.
        c_fc: Keyword arguments for the first ``linear`` call.
        c_proj: Keyword arguments for the second ``linear`` call.
    """
    hidden = linear(x, **c_fc)  # [n_seq, n_embd] -> [n_seq, 4*n_embd]
    activated = gelu(hidden)
    return linear(activated, **c_proj)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]

### 5.3 Attention Mechanism / 注意力机制

The core of the Transformer architecture.

Transformer架构的核心。

In [None]:
def attention(q, k, v, mask):
    """
    Scaled dot-product attention.

    Shape: [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]

    Computes ``softmax(q @ k.T / sqrt(d_k) + mask) @ v``, where ``d_k`` is
    the size of the last axis of ``q`` and ``mask`` is added to the scores
    before the softmax.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = softmax(scores + mask)
    return weights @ v


def mha(x, c_attn, c_proj, n_head):
    """
    Multi-head causal self-attention.

    Shape: [n_seq, n_embd] -> [n_seq, n_embd]

    Steps:
        1. Project the input with ``c_attn`` and split the result into q, k, v.
        2. Split each of q, k, v into ``n_head`` heads along the last axis.
        3. Run ``attention`` per head with a causal mask.
        4. Concatenate the heads and project with ``c_proj``.
    """
    projected = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]

    # three [n_seq, n_embd] arrays: queries, keys, values
    q_all, k_all, v_all = np.split(projected, 3, axis=-1)

    # per tensor: n_head arrays of shape [n_seq, n_embd/n_head]
    per_head = [np.split(part, n_head, axis=-1) for part in (q_all, k_all, v_all)]

    # Positions above the diagonal (the future) get a large negative score so
    # the softmax gives them ~0 weight.
    n_seq = projected.shape[0]
    causal_mask = (1 - np.tri(n_seq, dtype=projected.dtype)) * -1e10  # [n_seq, n_seq]

    head_outputs = [attention(q, k, v, causal_mask) for q, k, v in zip(*per_head)]

    merged = np.hstack(head_outputs)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]

    return linear(merged, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]

### 5.4 Transformer Block / Transformer块

In [None]:
def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
    """
    One transformer block.

    Shape: [n_seq, n_embd] -> [n_seq, n_embd]

    Two residual sub-layers, each normalizing its input first:
        1. ``x + mha(layer_norm(x, ln_1), attn)``
        2. ``x + ffn(layer_norm(x, ln_2), mlp)``

    Args:
        x: Input activations.
        mlp: Keyword arguments for ``ffn``.
        attn: Keyword arguments for ``mha`` (besides ``n_head``).
        ln_1: Keyword arguments for the layer norm before attention.
        ln_2: Keyword arguments for the layer norm before the feed-forward network.
        n_head: Number of attention heads.
    """
    attended = mha(layer_norm(x, **ln_1), **attn, n_head=n_head)
    x = x + attended  # residual around attention

    transformed = ffn(layer_norm(x, **ln_2), **mlp)
    return x + transformed  # residual around the feed-forward network

### 5.5 GPT-2 Forward Pass / GPT-2前向传播

In [None]:
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
    """
    GPT-2 forward pass.

    Shape: [n_seq] -> [n_seq, n_vocab]

    Steps:
        1. Sum the token embeddings ``wte[inputs]`` and the position
           embeddings for positions ``0..len(inputs)-1``.
        2. Apply every block in ``blocks`` in order.
        3. Apply the final layer norm ``ln_f``.
        4. Project onto the vocabulary with ``wte.T`` (the output weights are
           tied to the token embedding matrix).
    """
    token_embeddings = wte[inputs]
    position_embeddings = wpe[range(len(inputs))]
    x = token_embeddings + position_embeddings  # [n_seq] -> [n_seq, n_embd]

    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)

    normalized = layer_norm(x, **ln_f)  # [n_seq, n_embd] -> [n_seq, n_embd]
    return normalized @ wte.T  # [n_seq, n_embd] -> [n_seq, n_vocab]

### 5.6 Text Generation / 文本生成

In [None]:
def generate(inputs, params, n_head, n_tokens_to_generate):
    """
    Greedy auto-regressive generation.

    For each new token the full forward pass is run on the current ids, the
    id with the highest logit at the last position is chosen, and that id is
    appended to ``inputs``.

    Note:
        ``inputs`` is extended in place; pass a copy if the caller needs the
        original list untouched.

    Args:
        inputs: List of token ids to continue from.
        params: Keyword arguments for ``gpt2`` (besides ``n_head``).
        n_head: Number of attention heads.
        n_tokens_to_generate: How many tokens to append.

    Returns:
        A list holding only the newly generated ids.
    """
    steps = tqdm(range(n_tokens_to_generate), "generating")
    for _ in steps:
        logits = gpt2(inputs, **params, n_head=n_head)
        best_id = int(np.argmax(logits[-1]))  # greedy: most likely next token
        inputs.append(best_id)

    first_new = len(inputs) - n_tokens_to_generate
    return inputs[first_new:]

## 6. Main Function / 主函数

Put it all together!

整合所有部分！

In [None]:
def main(prompt, n_tokens_to_generate=40, model_size="124M", models_dir="models"):
    """
    Generate a continuation for ``prompt`` with a pre-trained GPT-2 model.

    Args:
        prompt: Input text.
        n_tokens_to_generate: Number of tokens to generate.
        model_size: Model size ("124M", "355M", "774M" or "1558M").
        models_dir: Directory used to store and load the model files.

    Returns:
        The generated text (without the prompt).

    Raises:
        AssertionError: If prompt length plus ``n_tokens_to_generate`` is not
            below the model's ``n_ctx``.
    """
    tokenizer, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)

    prompt_ids = tokenizer.encode(prompt)

    # Stay within the model's maximum sequence length.
    assert len(prompt_ids) + n_tokens_to_generate < hparams["n_ctx"]

    generated_ids = generate(prompt_ids, params, hparams["n_head"], n_tokens_to_generate)

    return tokenizer.decode(generated_ids)

## 7. Run Demo / 运行演示

Let's test our GPT-2 implementation! The first run will automatically download the pre-trained model weights (~500MB for 124M).

让我们测试GPT-2实现！首次运行会自动下载预训练模型权重（124M约500MB）。

> **About the weights / 关于模型权重:**
>
> This notebook implements the GPT-2 **inference (forward pass)** from scratch using NumPy — that's the ~60 lines of code above. However, the model also needs **pre-trained weights** (the hundreds of millions to over a billion numbers learned during training, depending on the model size) to produce meaningful output.
>
> 本notebook用NumPy从零实现了GPT-2的**推理（前向传播）**——就是上面约60行代码。但模型还需要**预训练权重**（训练过程中学到的上亿到十几亿个数字，视模型大小而定）才能产生有意义的输出。
>
> The weights are **openly released by OpenAI** under the MIT license. They are downloaded from OpenAI's public storage:
> `https://openaipublic.blob.core.windows.net/gpt-2/models/`
>
> 这些权重由**OpenAI以MIT许可证开源发布**，从OpenAI的公开存储下载：
> `https://openaipublic.blob.core.windows.net/gpt-2/models/`
>
> We use the **124M** (smallest) version here. Available sizes:
>
> 这里使用的是**124M**（最小）版本。可用的模型大小：
>
> | Model | Layers | Hidden Size | Heads | Params | Download Size |
> |-------|--------|-------------|-------|--------|---------------|
> | 124M  | 12     | 768         | 12    | ~124M  | ~500MB        |
> | 355M  | 24     | 1024        | 16    | ~355M  | ~1.5GB        |
> | 774M  | 36     | 1280        | 20    | ~774M  | ~3GB          |
> | 1558M | 48     | 1600        | 25    | ~1.5B  | ~6GB          |
>
> **In short: the code is ours, the weights are OpenAI's.**
>
> **简而言之：代码是我们从零实现的，权重是OpenAI开源的。**

In [None]:
# Example 1: continue an Alan Turing sentence and print the result.
prompt = "Alan Turing theorized that computers would one day become"
print("Prompt:", prompt)
print("\nGenerating...\n")

output = main(prompt, n_tokens_to_generate=40)
print("\n" + "=" * 50)
print("Generated text:", output)
print("=" * 50)
print("\nFull output: " + prompt + output)

In [None]:
# Example 2: try your own prompt.

# Change this to whatever you want.
your_prompt = "The meaning of life is"

print("Prompt:", your_prompt)
print("\nGenerating...\n")

output = main(your_prompt, n_tokens_to_generate=50)
print("\n" + "=" * 50)
print("Full output: " + your_prompt + output)

## 9. Model Architecture Summary / 模型架构总结

```
GPT-2 Architecture (124M model):
GPT-2架构（124M模型）:

├── Token Embedding (wte): [50257, 768]     # 词嵌入
├── Position Embedding (wpe): [1024, 768]   # 位置嵌入
├── 12 Transformer Blocks                   # 12个Transformer块
│   ├── Layer Norm 1                        # 层归一化1
│   ├── Multi-Head Attention (12 heads)     # 多头注意力（12头）
│   │   ├── Q, K, V Projection              # Q、K、V投影
│   │   ├── Scaled Dot-Product Attention    # 缩放点积注意力
│   │   └── Output Projection               # 输出投影
│   ├── Residual Connection                 # 残差连接
│   ├── Layer Norm 2                        # 层归一化2
│   ├── Feed-Forward Network                # 前馈网络
│   │   ├── Linear (768 -> 3072)            # 线性变换
│   │   ├── GELU                            # GELU激活
│   │   └── Linear (3072 -> 768)            # 线性变换
│   └── Residual Connection                 # 残差连接
├── Final Layer Norm                        # 最终层归一化
└── Output (tied with wte.T)                # 输出（与wte.T共享权重）

Model configurations by size / 各规模模型配置:
- 124M: 12 layers, 768 hidden, 12 heads
- 355M: 24 layers, 1024 hidden, 16 heads
- 774M: 36 layers, 1280 hidden, 20 heads
- 1558M: 48 layers, 1600 hidden, 25 heads
```

## 10. Key Takeaways / 关键要点

1. **The entire GPT-2 forward pass is just ~40 lines of NumPy!**
   
   整个GPT-2前向传播只需约40行NumPy代码！

2. **Core components / 核心组件:**
   - `gelu`: Activation function / 激活函数
   - `softmax`: Probability distribution / 概率分布
   - `layer_norm`: Normalization / 归一化
   - `linear`: Matrix multiplication / 矩阵乘法
   - `attention`: The key innovation / 关键创新
   - `mha`: Multi-head version / 多头版本
   - `ffn`: Position-wise processing / 逐位置处理
   - `transformer_block`: Combines attention + FFN / 组合注意力+FFN
   - `gpt2`: Stack of blocks / 块的堆叠

3. **One-line GPU/TPU acceleration with JAX!**
   
   使用JAX只需一行代码即可切换到GPU/TPU加速！

4. **Limitations of this implementation / 此实现的局限性:**
   - Inference only (no training) / 仅推理（无训练）
   - No batching / 无批处理
   - Greedy sampling only / 仅贪婪采样
   - Slow (pure NumPy), but can be accelerated with JAX / 速度慢（纯NumPy），但可通过JAX加速

5. **The magic is in the pre-trained weights!**
   
   魔法在于预训练的权重！

In [None]:
!pip install jax jaxlib

In [None]:
import jax
import jax.numpy as jnp

# Report the backend and devices JAX will run on.
print("JAX backend:", jax.default_backend())
print("JAX devices:", jax.devices())

In [None]:
# JAX versions of the model functions.
#
# Every array operation below calls `jnp` (jax.numpy) explicitly. Rebinding the
# global `np` to jax.numpy only while the functions are being defined does not
# work: Python looks up global names when a function is *called*, so as soon as
# `np` points at NumPy again the functions would run NumPy code.


def gelu_jax(x):
    """GELU activation (tanh approximation), computed with jax.numpy."""
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))


def softmax_jax(x):
    """Softmax over the last axis; the row maximum is subtracted first to avoid overflow."""
    exp_x = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)


def layer_norm_jax(x, g, b, eps=1e-5):
    """Normalize the last axis to zero mean / unit variance, then scale by g and offset by b."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    variance = jnp.var(x, axis=-1, keepdims=True)
    x = (x - mean) / jnp.sqrt(variance + eps)
    return g * x + b


def linear_jax(x, w, b):
    """Affine map ``x @ w + b``."""
    return x @ w + b


def ffn_jax(x, c_fc, c_proj):
    """Feed-forward network: linear (c_fc) -> GELU -> linear (c_proj)."""
    a = gelu_jax(linear_jax(x, **c_fc))
    x = linear_jax(a, **c_proj)
    return x


def attention_jax(q, k, v, mask):
    """Scaled dot-product attention with an additive mask."""
    return softmax_jax(q @ k.T / jnp.sqrt(q.shape[-1]) + mask) @ v


def mha_jax(x, c_attn, c_proj, n_head):
    """Multi-head causal self-attention: project, split into heads, attend, merge, project."""
    x = linear_jax(x, **c_attn)
    qkv = jnp.split(x, 3, axis=-1)
    qkv_heads = [jnp.split(part, n_head, axis=-1) for part in qkv]
    # Large negative scores above the diagonal hide future positions.
    causal_mask = (1 - jnp.tri(x.shape[0], dtype=x.dtype)) * -1e10
    out_heads = [attention_jax(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
    x = jnp.hstack(out_heads)
    x = linear_jax(x, **c_proj)
    return x


def transformer_block_jax(x, mlp, attn, ln_1, ln_2, n_head):
    """One transformer block: attention and feed-forward sub-layers, each with a residual."""
    x = x + mha_jax(layer_norm_jax(x, **ln_1), **attn, n_head=n_head)
    x = x + ffn_jax(layer_norm_jax(x, **ln_2), **mlp)
    return x


def gpt2_jax(inputs, wte, wpe, blocks, ln_f, n_head):
    """GPT-2 forward pass: embeddings -> blocks -> final layer norm -> projection onto wte.T."""
    # Index with explicit arrays: JAX arrays do not accept plain Python lists
    # as indices the way NumPy arrays do.
    token_ids = jnp.asarray(inputs)
    positions = jnp.arange(len(inputs))
    x = wte[token_ids] + wpe[positions]
    for block in blocks:
        x = transformer_block_jax(x, **block, n_head=n_head)
    x = layer_norm_jax(x, **ln_f)
    return x @ wte.T


def generate_jax(inputs, params, n_head, n_tokens_to_generate):
    """
    Greedy generation with the JAX forward pass.

    ``inputs`` (a list of token ids) is extended in place; only the newly
    generated ids are returned.
    """
    for _ in tqdm(range(n_tokens_to_generate), "generating (JAX)"):
        logits = gpt2_jax(inputs, **params, n_head=n_head)
        next_id = int(jnp.argmax(logits[-1]))
        inputs.append(next_id)
    return inputs[len(inputs) - n_tokens_to_generate :]


# The global `np` is never rebound here, so the NumPy cells keep working.
import numpy  # noqa: E402,F401 - keeps the `numpy` module name available to later cells

print("JAX model functions defined!")

In [None]:
# Compare NumPy (CPU) vs JAX (GPU/TPU) speed
import time

prompt = "Alan Turing theorized that computers would one day become"
n_tokens = 20

# Load the model. Files already present under "models" are reused; a download
# only happens when no checkpoint is found there.
encoder, hparams, params = load_encoder_hparams_and_params("124M", "models")
input_ids = encoder.encode(prompt)

# --- NumPy (CPU) ---
print("=" * 50)
print("NumPy (CPU):")
input_ids_np = list(input_ids)  # copy: generate() appends to the list it is given
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
output_ids_np = generate(input_ids_np, params, hparams["n_head"], n_tokens)
time_np = time.perf_counter() - start
print(f"Output: {encoder.decode(output_ids_np)}")
print(f"Time: {time_np:.2f}s")

# --- JAX (GPU/TPU) ---
print("\n" + "=" * 50)
print(f"JAX ({jax.default_backend().upper()}):")

# Convert the parameter tree to JAX arrays.
def to_jax(d):
    """
    Recursively convert the leaves of a nested dict/list structure to JAX arrays.

    Dicts and lists are rebuilt with the same keys and order; any other value
    is passed to ``jnp.array``.
    """
    if isinstance(d, dict):
        return {key: to_jax(value) for key, value in d.items()}
    if isinstance(d, list):
        return list(map(to_jax, d))
    return jnp.array(d)

params_jax = to_jax(params)

input_ids_jax = list(input_ids)  # copy: generate_jax() appends to the list it is given
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
output_ids_jax = generate_jax(input_ids_jax, params_jax, hparams["n_head"], n_tokens)
time_jax = time.perf_counter() - start
print(f"Output: {encoder.decode(output_ids_jax)}")
print(f"Time: {time_jax:.2f}s")

# Summary
print("\n" + "=" * 50)
print(f"Speedup / 加速比: {time_np / time_jax:.1f}x")
print("(Note: JAX's first run includes JIT compilation overhead)")
print("(注意：JAX首次运行包含JIT编译开销)")