<a href="https://colab.research.google.com/github/weedge/doraemon-nb/blob/main/nano_rl_r1_zero_grpo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

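This project, **nanoAhaMoment: a single-file "RL for LLM" library**, aims to provide a minimal, efficient, and fully transparent implementation of reinforcement learning training for large language models.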

---

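### Highlights:

* **Runs on a single GPU**: no expensive hardware is required. (The 3B model was trained on an A100-80GB; this notebook trains the 1.5B and 0.5B models on an A100-40GB or L4.)
* **No external RL library**: no dependency on `TRL` or `Verl`. All RL code is written by hand, so every detail is visible.
* **Efficient**: the code is simplified without giving up training efficiency.
* **Supports 3B base models**: suitable for moderately sized base LLMs. (On Colab, the 0.5B model is used for test runs and debugging.)
* **Full-parameter R1-Zero training**: implements the R1-Zero training recipe directly, with full-parameter fine-tuning.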

---

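### Design philosophy:

**nanoAhaMoment** is inspired by projects such as [TinyZero](https://github.com/Jiayi-Pan/TinyZero) and [Mini-R1](https://www.philschmid.de/mini-deepseek-r1), but puts more emphasis on being:

* **Simpler**: the code structure is straightforward and easy to follow.
* **Clearer**: unnecessary abstractions are removed, so every line of code is visible.
* **Faster**: the code stays lean without sacrificing runtime efficiency.

The core goal of this project is to provide a **transparent, controllable, and easy-to-modify** foundation for RL training of LLMs, so that you can understand and control the entire training process.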

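R1-Zero is arguably the more interesting contribution of the DeepSeek R1 paper. The core idea: take a freshly pre-trained LLM (straight out of the unsupervised pre-training "oven") and continue training it with reinforcement learning, *without* any human feedback or supervision. The result? The model starts to show emergent behaviors such as self-reflection, verification, and backtracking, which researchers have been trying to inject into LLMs through hand-crafted tricks and inductive biases at least since o1.

In this notebook we build an R1-Zero-style training loop **from scratch**. The goal is a clear, hackable foundation for RL-style LLM training, one in which every moving part, and how the parts fit together, is visible at a glance. It is well suited for experimenting, extending, or modifying.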

---

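### Why another R1-Zero implementation?

There are already great implementations such as [TinyZero](https://github.com/Jiayi-Pan/TinyZero) and [Mini-R1](https://www.philschmid.de/mini-deepseek-r1). However, they rely on full-fledged RL libraries (such as `trl` or `verl`) to handle training.

These libraries exist for good reasons: efficient RL training of LLMs sits at the intersection of scalable training and fast inference, and getting it right takes a lot of engineering. But it also means the internals are usually abstracted away, hard to read, and even harder to modify.

This notebook is different: **no abstractions, no hiding**. You see everything, top to bottom: a lightweight, readable codebase that still follows best practices and runs efficiently on a single GPU.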

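### What exactly is this notebook?

We will use RL to train a base LLM to solve a reasoning-heavy algorithmic task. The setup:

- **Model**: Qwen2.5 3B-Base, Qwen2.5 1.5B-Base, or Qwen2.5 0.5B-Base; the 1.5B and 0.5B models are mainly for testing
- **Dataset**: Countdown-Tasks-3to4
- **Algorithm**: GRPO (a variant of policy gradient)

Yes, the task is a bit of a toy, but it captures the essence of R1-Zero: emergent behaviors such as self-reflection, verification, backtracking, and even language switching. This setup is well suited to rapid prototyping and experimentation.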

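### Who is this notebook for?

- Anyone interested in RL training of LLMs
- Researchers, especially those in academia exploring reasoning in language models

### What should I know before starting?

- Familiarity with the HuggingFace Transformers library
- Experience fine-tuning LLMs
- Familiarity with policy gradient methods (helpful but not required)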

## R1-Zero Recipe

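The goal of this project is to train a base LLM so that it can **reason**, and **re-evaluate** and **improve** its outputs on its own, all without human supervision. In this notebook we implement the surprisingly simple training recipe proposed in the DeepSeek R1 paper.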

---

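## Training method

The recipe in outline:

1.  **Initialization**: start with a base LLM and a dataset. The dataset contains only the prompts and their **final answers**, with no intermediate reasoning steps.
2.  **Iterative training**: for each of the `NUM_ITERATIONS` iterations:
    * **Sample prompts**: draw a random batch of $N$ prompts $\{x_i\}_{i=1}^N$ from the dataset.
    * **Generate responses**: for each prompt $x$ in the batch, the model generates $G$ different responses:
        $$y_1, y_2, \cdots, y_G \sim \pi_\theta(\cdot|x)$$
        In GRPO, these $G$ responses are called a "group".
    * **Compute rewards and advantages**: compute a reward $R_i$ for each generated response $y_i$, then normalize the rewards within each group to obtain the **GRPO advantages**.
    * **Build training samples**: create a list of $N \times G$ "episodes". Each episode is a prompt-response pair $(x, y)$ together with its advantage.
    * **Estimate the policy gradient**: use these episodes to estimate the **policy gradient** $\vec{g}_{pg}$.
    * **Update the model parameters**: take a gradient ascent step with learning rate $\eta$:
        $$\theta \leftarrow \theta + \eta \vec{g}_{pg}$$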
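
The recipe above can be sketched end to end on a toy problem. The block below is only an illustration under strong assumptions, not the notebook's implementation: the "policy" is a softmax over three canned answers instead of an LLM, the reward is 1 for one designated answer, and the update is plain REINFORCE with group-normalized advantages (no clipping, no KL term). The names `ANSWERS`, `CORRECT`, and `grpo_step` are hypothetical.

```python
import math
import random
from statistics import fmean, pstdev

random.seed(0)

ANSWERS = ["1+2", "1*2", "2-1"]  # toy "responses"; a real policy generates token sequences
CORRECT = 0                      # index of the rewarded answer (assumed for this demo)
G = 4                            # group size: responses sampled per prompt
LR = 0.5                         # learning rate (eta in the update rule)


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grpo_step(theta):
    probs = softmax(theta)
    # 1) Generate a group of G responses for the (single, implicit) prompt
    group = random.choices(range(len(ANSWERS)), weights=probs, k=G)
    # 2) Rule-based reward, then per-group normalization
    rewards = [1.0 if a == CORRECT else 0.0 for a in group]
    mu, sigma = fmean(rewards), pstdev(rewards)
    advantages = [(r - mu) / (sigma + 1e-6) for r in rewards]
    # 3) Policy gradient estimate: mean of advantage * grad log pi(a)
    #    For a softmax policy, d log pi(a) / d theta_k = 1[k == a] - probs[k]
    grad = [0.0] * len(theta)
    for a, adv in zip(group, advantages):
        for k in range(len(theta)):
            grad[k] += adv * ((1.0 if k == a else 0.0) - probs[k]) / G
    # 4) Gradient ascent step
    return [t + LR * g for t, g in zip(theta, grad)]


theta = [0.0, 0.0, 0.0]
for _ in range(200):
    theta = grpo_step(theta)

final_probs = softmax(theta)
print([round(p, 3) for p in final_probs])
```

Groups in which every response gets the same reward produce zero advantages and leave `theta` unchanged, so learning only happens on groups with mixed outcomes.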

---

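## Code structure overview

The code closely follows the training recipe above and consists of three core components:

1.  **Episode Generation**:
    * Generates the $(x, y)$ pairs and their advantages in each RL iteration.

2.  **Reward Calculation**:
    * Computes the reward for each generated response.

3.  **Policy Gradient Estimation**:
    * Uses the generated episodes to estimate the policy gradient and update the model.

Together, these three components form a simple loop that gradually trains the model to develop reasoning abilities through reinforcement learning.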

## Checkpoint Playground

In `notebooks/checkpoint_playground.ipynb`, you can load a model that we have already trained with this notebook and test its reasoning interactively. That notebook lets you enter custom prompts and inspect the model's responses.

# install

After the installation finishes, restart the session.

In [None]:
!pip install -q vllm==0.7.3 deepspeed==0.16.4 datasets==3.3.2 accelerate==1.4.0


In [None]:
!pip install -q flash-attn --no-build-isolation


In [None]:
!pip list | grep -E "torch|transformers|datasets|deepspeed|vllm|wandb|numba|flash|accelerate|numpy"

# run

In [None]:
import os
from pathlib import Path

# Set the environment variables for HuggingFace
# This is done to ensure that the cache directory for HuggingFace is set to a specific location,
# preventing the storage from being overwhelmed with model files and other data.
SCRATCH =  "/content/scratch"
os.environ["HF_HOME"] = f"{SCRATCH}/hf_home"

In [None]:
import json
import socket
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import wandb
from datasets import Dataset
from deepspeed import DeepSpeedEngine
from transformers import AutoTokenizer, PreTrainedModel
from vllm import LLM, SamplingParams

DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant. You first think about the reasoning process in the mind and then provide the user with the answer."
DEFAULT_PROMPT_TEMPLATE = "Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags, for example <answer>(1 + 2) / (3 * 5)</answer>."


def create_prompt(
    numbers: List[int],
    target: int,
    tokenizer: AutoTokenizer,
    system_message: str = DEFAULT_SYSTEM_MESSAGE,
    prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
) -> str:
    prefix = [
        {"role": "system", "content": system_message},
        {
            "role": "user",
            "content": prompt_template.format(numbers=numbers, target=target),
        },
        {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
    ]
    return tokenizer.apply_chat_template(prefix, tokenize=False, continue_final_message=True)


def prepare_model_inputs(
    query_token_ids: List[List[int]],
    response_token_ids: List[List[int]],
    advantages: List[List[float]],
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Prepare padded model inputs with attention masks, labels, and advantages.
    Args:
        query_token_ids: List of query token ids
        response_token_ids: List of response token ids
        advantages: List of lists of advantage values, matching response_token_ids structure
        device: Device to move the tensors to
    Returns:
        Dict with input_ids, attention_mask, labels, and advantages

    Example:
        >>> query_token_ids = [[1, 2, 3], [4, 5]]
        >>> response_token_ids = [[6, 7], [8]]
        >>> advantages = [[0.5, 0.8], [0.3]]
        >>> outputs = prepare_model_inputs(query_token_ids, response_token_ids, advantages, "cuda")
        >>> outputs
        {
            'input_ids': tensor([
                [1, 2, 3, 6, 7],
                [4, 5, 8, 0, 0]
            ]),
            'attention_mask': tensor([
                [1, 1, 1, 1, 1],
                [1, 1, 1, 0, 0]
            ]),
            'labels': tensor([
                [-100, -100, -100, 6, 7],
                [-100, -100, 8, -100, -100]
            ]),
            'advantages': tensor([
                [0.0, 0.0, 0.0, 0.5, 0.8],
                [0.0, 0.0, 0.3, 0.0, 0.0]
            ])
        }
    """
    max_seq_len = max(len(q) + len(r) for q, r in zip(query_token_ids, response_token_ids))
    inputs = {"input_ids": [], "attention_mask": [], "labels": [], "advantages": []}

    pad_token_id = 0  # Doesn't matter, will be masked
    ignore_index = -100

    for query, response, advantage in zip(query_token_ids, response_token_ids, advantages):
        combined_ids = query + response
        seq_len = len(combined_ids)

        # Create padded sequences
        input_ids = combined_ids + [pad_token_id] * (max_seq_len - seq_len)
        attention_mask = [1] * seq_len + [0] * (max_seq_len - seq_len)
        labels = [ignore_index] * len(query) + response + [ignore_index] * (max_seq_len - seq_len)
        advantages_seq = [0.0] * len(query) + advantage + [0.0] * (max_seq_len - seq_len)

        assert len(input_ids) == max_seq_len
        assert len(attention_mask) == max_seq_len
        assert len(labels) == max_seq_len
        assert len(advantages_seq) == max_seq_len

        inputs["input_ids"].append(input_ids)
        inputs["attention_mask"].append(attention_mask)
        inputs["labels"].append(labels)
        inputs["advantages"].append(advantages_seq)

    # Convert to tensors
    return {
        k: torch.tensor(v, dtype=torch.long if k != "advantages" else torch.float, device=device)
        for k, v in inputs.items()
    }


def compute_token_log_probs(
    model: Union[DeepSpeedEngine, PreTrainedModel],
    inputs: Dict[str, torch.Tensor],
    temperature: float,
) -> torch.Tensor:
    """
    Compute log probabilities for each token in the sequence, masked for valid labels only.

    This function:
    1. Runs the model forward pass
    2. Applies temperature scaling to logits
    3. Shifts the sequences for causal language modeling
    4. Computes log probabilities for the actual tokens that appeared in the sequence
    5. Masks the log probabilities to only include valid labels (non -100 positions)

    Args:
        model: The language model (either DeepSpeed-wrapped or regular HuggingFace model)
        inputs: Dictionary containing:
            - input_ids: Tensor of token ids [batch_size, seq_len]
            - attention_mask: Tensor of attention mask [batch_size, seq_len]
            - labels: Tensor of target labels [batch_size, seq_len] with -100 for ignored positions
        temperature: Temperature for scaling the logits before softmax

    Returns:
        torch.Tensor: Log probabilities tensor of shape [batch_size, seq_len-1], where:
            - Each value is the log probability of the actual token that appeared
            - Values are masked to 0.0 for positions where labels were -100
            - The sequence length is reduced by 1 due to the causal shift

    Example:
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = {
        ...     "input_ids": torch.tensor([[1, 2, 3]]),
        ...     "attention_mask": torch.tensor([[1, 1, 1]]),
        ...     "labels": torch.tensor([[-100, 2, 3]])
        ... }
        >>> log_probs = compute_token_log_probs(model, inputs, temperature=1.0)
        >>> log_probs.shape
        torch.Size([1, 2])  # batch_size=1, seq_len-1=2
        >>> # First position is 0 (masked), second position has actual log prob
    """
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        return_dict=True,
        use_cache=False,
    )

    logits = outputs.logits.float() / temperature  # Shape: [batch_size, seq_len, vocab_size]
    shift_logits = logits[..., :-1, :].contiguous()  # Shape: [batch_size, seq_len-1, vocab_size]
    shift_labels = inputs["labels"][..., 1:].contiguous()  # Shape: [batch_size, seq_len-1]

    # Create mask for valid labels
    label_mask = (shift_labels != -100).float()  # Shape: [batch_size, seq_len-1]
    shift_labels[shift_labels == -100] = 0  # Shape: [batch_size, seq_len-1]

    # Calculate log probabilities
    log_probs = torch.log_softmax(shift_logits, dim=-1)  # Shape: [batch_size, seq_len-1, vocab_size]
    log_probs = torch.gather(log_probs, dim=2, index=shift_labels.unsqueeze(2))  # Shape: [batch_size, seq_len-1, 1]
    log_probs = log_probs.squeeze(2)  # Shape: [batch_size, seq_len-1]
    log_probs = log_probs * label_mask  # Shape: [batch_size, seq_len-1]

    return log_probs


def find_free_port():
    """Find a free port on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        port = s.getsockname()[1]
    return port


def evaluate_on_test_set(
    inference_engine: LLM,
    test_dataset: Dataset,
    tokenizer: AutoTokenizer,
    eos_token: str,
    eval_sampling_params: SamplingParams,
    reward_func: Callable[[str, Dict[str, Any]], Tuple[float, Dict[str, float]]],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Evaluate the model on a test dataset by generating responses and computing rewards.

    Args:
        inference_engine: The vLLM `LLM` instance used for text generation
        test_dataset: Dataset containing test samples
        tokenizer: Tokenizer for decoding generated token IDs
        eos_token: End of sequence token string
        eval_sampling_params: vLLM `SamplingParams` controlling the generation process
        reward_func: Function that computes rewards for generated responses. Takes a response
            string and sample dict as input, returns a tuple of (overall_reward, reward_components)

    Returns:
        Tuple of (episodes, metrics):
        episodes: Dictionary containing:
            - all_query_token_ids: List of query token IDs for each episode
            - all_response_token_ids: List of response token IDs for each episode
        metrics: Dictionary containing evaluation statistics:
            - response_lengths: List of token counts for each generated response
            - rewards: List of overall reward values for each response
            - non_stop_rate: List of booleans, True if generation ended for a reason other than "stop"
            - reward_metrics/*: Lists of individual reward component values, prefixed with
              "reward_metrics/"

    Example:
        >>> episodes, episodes_stats = evaluate_on_test_set(
        ...     inference_engine=engine,
        ...     test_dataset=dataset,
        ...     tokenizer=tokenizer,
        ...     eos_token="</s>",
        ...     eval_sampling_params=SamplingParams(temperature=0.7, max_tokens=100),
        ...     reward_func=compute_reward
        ... )
        >>> print(f"Average reward: {np.mean(episodes_stats['rewards']):.3f}")
    """
    generations = inference_engine.generate(
        prompt_token_ids=test_dataset["input_ids"], sampling_params=eval_sampling_params
    )

    metrics = {
        "response_lengths": [],
        "rewards": [],
        "non_stop_rate": [],
    }

    all_query_token_ids = []
    all_responses_token_ids = []

    for i, sample in enumerate(test_dataset):
        query_token_ids = sample["input_ids"]
        response_token_ids = generations[i].outputs[0].token_ids
        finish_reason = generations[i].outputs[0].finish_reason

        response = tokenizer.decode(response_token_ids, skip_special_tokens=False)
        reward, reward_components = reward_func(response, sample)

        all_query_token_ids.append(query_token_ids)
        all_responses_token_ids.append(response_token_ids)

        metrics["rewards"].append(reward)
        metrics["non_stop_rate"].append(finish_reason != "stop")
        metrics["response_lengths"].append(len(response_token_ids))
        for k, v in reward_components.items():
            metrics.setdefault(f"reward_metrics/{k}", []).append(v)

    episodes = {
        "all_query_token_ids": all_query_token_ids,
        "all_response_token_ids": all_responses_token_ids,
    }

    return episodes, metrics


def dump_episodes(
    episodes: Dict[str, Any],
    episodes_stats: Dict[str, Any],
    exp_dir: Path,
    tokenizer: AutoTokenizer,
    iteration: int,
    is_eval: bool = False,
) -> wandb.Table:
    query_token_ids = episodes["all_query_token_ids"]
    response_token_ids = episodes["all_response_token_ids"]
    rewards = episodes_stats["rewards"]
    response_lengths = episodes_stats["response_lengths"]

    query_texts = tokenizer.batch_decode(
        query_token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    response_texts = tokenizer.batch_decode(
        response_token_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )

    if not is_eval:
        print(f"########## Example 1 (Reward: {rewards[0]}, Response Length: {response_lengths[0]})")
        print(f"#### Query:\n`{query_texts[0]}`")
        print(f"#### Response:\n`{response_texts[0]}`\n\n")

        print(f"########## Example 2 (Reward: {rewards[1]}, Response Length: {response_lengths[1]})")
        print(f"#### Query:\n`{query_texts[1]}`")
        print(f"#### Response:\n`{response_texts[1]}`\n\n")

    if is_eval:
        episodes_dir = exp_dir / "eval_episodes"
    else:
        episodes_dir = exp_dir / "episodes"
    episodes_dir.mkdir(parents=True, exist_ok=True)

    with open(episodes_dir / f"eps_{iteration:06d}.json", "w") as f:
        json.dump(
            [
                {
                    "query": query_texts[i],
                    "response": response_texts[i],
                    "reward": rewards[i],
                }
                for i in range(len(query_texts))
            ],
            f,
        )

    # Create wandb table
    table = wandb.Table(columns=["query", "response", "reward", "response_length"])
    for i in range(len(query_texts)):
        table.add_data(query_texts[i], response_texts[i], rewards[i], response_lengths[i])

    return table


def find_last_checkpoint(exp_dir: Path) -> Tuple[Optional[Path], Optional[int]]:
    checkpoint_dir = exp_dir / "checkpoints"
    checkpoints = list(checkpoint_dir.glob("ckpt_*"))
    # Filter out directories that don't have a deepspeed subdirectory
    checkpoints = [ckpt for ckpt in checkpoints if (ckpt / "deepspeed").exists()]
    if not checkpoints:
        return None, None
    ckpt_path = max(checkpoints, key=lambda x: int(x.stem.split("_")[-1]))
    ckpt_iter = int(ckpt_path.stem.split("_")[-1])
    return ckpt_path, ckpt_iter


def load_model_into_vllm(model: Union[DeepSpeedEngine, PreTrainedModel], llm: LLM) -> None:
    """
    Load weights from a HuggingFace model (either wrapped in DeepSpeed or not) into a vLLM inference engine.

    This function transfers the weights from a training model to a vLLM inference engine,
    allowing for efficient inference using the updated model weights.

    Args:
        model (Union[DeepSpeedEngine, PreTrainedModel]): The source model to copy weights from.
            Can be either a DeepSpeed-wrapped model or a regular HuggingFace PreTrainedModel.
        llm (LLM): The target vLLM inference engine to load the weights into.
            Must be already initialized and ready to accept new weights.

    Returns:
        None
    """
    state_dict = model.module.state_dict() if isinstance(model, DeepSpeedEngine) else model.state_dict()
    llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items())

### Import the required libraries

In [None]:
import gc
import re
import time
from typing import Any, Dict, List, Tuple, Union

import deepspeed
import numpy as np
import torch
from datasets import load_dataset
from deepspeed import DeepSpeedEngine
from tqdm import trange
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from vllm import LLM, SamplingParams

import wandb

# Needed to stop DeepSpeed from complaining
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_port())
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

## Hyperparameters

We now define the **hyperparameters** for training. Most of them follow the [Mini-R1](https://www.philschmid.de/mini-deepseek-r1) implementation.
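
The batch-related settings in the configuration cell are tied together by simple arithmetic. The sketch below recomputes the derived quantities from the values used in this notebook. `prompts_per_iteration` and `WORLD_SIZE` are illustrative names, and the sketch assumes each iteration samples just enough prompts to fill `EPISODES_PER_ITERATION` episodes.

```python
# Values copied from the configuration cell of this notebook
EPISODES_PER_ITERATION = 64
GENERATIONS_PER_SAMPLE = 4
PER_DEVICE_BATCH_SIZE = 4
WORLD_SIZE = 1  # single GPU

# Assumption: prompts sampled per iteration = episodes / group size
prompts_per_iteration = EPISODES_PER_ITERATION // GENERATIONS_PER_SAMPLE

# Micro-batches accumulated before one optimizer step
gradient_accumulation_steps = EPISODES_PER_ITERATION // PER_DEVICE_BATCH_SIZE

# DeepSpeed requires:
# train_batch_size == micro_batch_per_gpu * gradient_accumulation_steps * world_size
assert EPISODES_PER_ITERATION == (
    PER_DEVICE_BATCH_SIZE * gradient_accumulation_steps * WORLD_SIZE
)

print(prompts_per_iteration, gradient_accumulation_steps)  # 16 16
```

With these values, one RL iteration corresponds to exactly one optimizer step over all 64 episodes.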

In [None]:
# Model configuration
#MODEL_NAME = "Qwen/Qwen2.5-3B"
#MODEL_NAME = "Qwen/Qwen2.5-1.5B"

MODEL_NAME = "Qwen/Qwen2.5-0.5B"

MODEL_CHAT_NAME = MODEL_NAME + "-Instruct"

# Dataset configuration
DATASET_NAME = "Jiayi-Pan/Countdown-Tasks-3to4"

# Total number of training iterations
NUM_ITERATIONS = 1000
# Number of episodes to collect per iteration for training
EPISODES_PER_ITERATION = 64
# Number of responses to generate for each input prompt (i.e. group size in GRPO)
GENERATIONS_PER_SAMPLE = 4
# Controls how much the policy can deviate from the reference model
KL_COEFFICIENT = 0.001

# Training hyperparameters
# Batch size for each GPU device during training
PER_DEVICE_BATCH_SIZE = 4
# Learning rate for model updates
LEARNING_RATE = 1e-6

# Sampling parameters
# Maximum number of tokens to generate in each response
MAX_RESPONSE_TOKENS = 1024
# Controls randomness in generation (higher = more random)
TEMPERATURE = 1.0
# Nucleus sampling parameter (1.0 = disabled)
TOP_P = 1.0
# Top-k sampling parameter (-1 = disabled)
TOP_K = -1  # no top k

# DeepSpeed configuration
# DeepSpeed config for the policy model
deepspeed_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2, "overlap_comm": False},
    "train_batch_size": EPISODES_PER_ITERATION,
    "train_micro_batch_size_per_gpu": PER_DEVICE_BATCH_SIZE,
    "gradient_accumulation_steps": EPISODES_PER_ITERATION // PER_DEVICE_BATCH_SIZE,
    "gradient_clipping": 1.0,
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": LEARNING_RATE,
            "betas": (0.9, 0.999),
            "eps": 1e-8,
            "weight_decay": 0.0,
            "torch_adam": True,
        },
    },
}
# DeepSpeed config for the reference model
ref_deepspeed_config = {
    "bf16": {"enabled": True},
    # Note that we don't train the reference model
    # These are just for compatibility with DeepSpeed.
    "train_batch_size": EPISODES_PER_ITERATION,
    "train_micro_batch_size_per_gpu": PER_DEVICE_BATCH_SIZE,
    "gradient_accumulation_steps": EPISODES_PER_ITERATION // PER_DEVICE_BATCH_SIZE,
}

RUN_NAME = "r1-zero"
EXP_DIR = f"{SCRATCH}/deepseek_r1z_hackathon/{RUN_NAME}"
os.makedirs(EXP_DIR, exist_ok=True)
print(f"Logs and Checkpoints will be saved to: {EXP_DIR}")
print(Path(EXP_DIR))

## Generating the training prompts

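We will train on the [Countdown-Tasks-3to4](https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4) dataset. It provides problem statements and their final answers, but no reasoning steps.

### The Countdown task

The Countdown game is a number puzzle in which the player must reach a target number using a set of randomly chosen numbers and the basic arithmetic operations: addition, subtraction, multiplication, and division. Each number must be used exactly once.

Example:

```yaml
Target: 622
Available numbers: [25, 3, 6, 100]

# Not provided in the dataset
Solution: (100 × 6) + (25 − 3) = 622
```

This task is well suited for training LLMs to practice reasoning, search, and self-verification.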
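
To get a feel for the search involved, here is a small brute-force solver. It is not part of the training code and the name `solve_countdown` is hypothetical. It repeatedly combines two remaining numbers with one operation until a single value is left, using exact fractions so that division introduces no rounding error.

```python
import operator
from fractions import Fraction

OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}


def solve_countdown(nums, target):
    """Return an expression that uses every number exactly once and equals target, or None."""

    def search(items):
        # items: list of (exact value, expression string) pairs still to be combined
        if len(items) == 1:
            value, expr = items[0]
            return expr if value == target else None
        for i, (a, expr_a) in enumerate(items):
            for j, (b, expr_b) in enumerate(items):
                if i == j:
                    continue
                rest = [item for k, item in enumerate(items) if k not in (i, j)]
                for symbol, op in OPS.items():
                    if symbol == "/" and b == 0:
                        continue  # skip division by zero
                    found = search(rest + [(op(a, b), f"({expr_a} {symbol} {expr_b})")])
                    if found is not None:
                        return found
        return None

    return search([(Fraction(n), str(n)) for n in nums])


expression = solve_countdown([25, 3, 6, 100], 622)
print(expression)
```

Exhaustive search is cheap for three or four numbers. The model has to learn a comparable search-and-verify routine expressed in natural language.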

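Because we use the **base version** of the model, which has only been pre-trained on raw internet data, it has no prior understanding of system prompts or chat formats. We nevertheless use a **chat format**, so that the resulting model is compatible with downstream tools and frameworks that expect it.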

In [None]:
SYSTEM_MESSAGE = (
    "You are a helpful assistant. You first think about the reasoning process in the mind "
    "and then provide the user with the answer."
)
PROMPT_TEMPLATE = (
    "Using the numbers {numbers}, create an equation that equals {target}. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. "
    "Show your work in <think> </think> tags. And return the final equation and answer in "
    "<answer> </answer> tags, for example <answer>(1 + 2) / (3 * 5)</answer>."
)

With the system message and prompt template in place, we can generate the training prompts.

In [None]:
# Load and process dataset
def preprocess_example(example: Dict[str, Any]):
    numbers: List[int] = example["nums"]
    target: int = example["target"]

    prefix = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": PROMPT_TEMPLATE.format(numbers=numbers, target=target)},
        {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
    ]
    input_ids = tokenizer.apply_chat_template(
        prefix, tokenize=True, continue_final_message=True
    )
    prompt = tokenizer.decode(
        input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    return {"prompt": prompt, "input_ids": input_ids}

# Note that the base model and "instruct" model have different eos token.
# Here we make sure to use the correct one.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHAT_NAME)
EOS_TOKEN_ID = AutoTokenizer.from_pretrained(MODEL_NAME).eos_token_id
EOS_TOKEN = tokenizer.convert_ids_to_tokens(EOS_TOKEN_ID)

dataset = load_dataset(DATASET_NAME, split="train")
dataset = dataset.map(preprocess_example, num_proc=6)

# Split dataset
train_test_split = dataset.train_test_split(test_size=500, seed=42)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

len(train_dataset), len(test_dataset)

In [None]:
print("Target: ", train_dataset[0]["target"])
print("Available Numbers: ", train_dataset[0]["nums"])

In [None]:
print(train_dataset[0]["prompt"])

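Each prompt ends with the opening of the assistant turn, followed by the phrase **"Let me solve this step by step."** and an opening `<think>` tag. This nudges the model into **answering mode**. Without this prefill, the base model might simply continue the prompt instead of attempting the task, since it has no built-in ability to follow instructions.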
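
The prefill can be illustrated without loading a tokenizer. The sketch below hand-builds a prompt string assuming ChatML-style markers (`<|im_start|>`, `<|im_end|>`), which is our reading of the Qwen chat template. The output of `tokenizer.apply_chat_template` is authoritative, and `build_prefilled_prompt` is a hypothetical helper.

```python
def build_prefilled_prompt(system: str, user: str, prefill: str) -> str:
    # Assumption: ChatML-style turn markers; the real template may differ in details
    return (
        f"<|im_start|>system\n{system}<|im_end|>\n"
        f"<|im_start|>user\n{user}<|im_end|>\n"
        # The assistant turn is left open (no <|im_end|>) so the model continues it
        f"<|im_start|>assistant\n{prefill}"
    )


prompt = build_prefilled_prompt(
    system="You are a helpful assistant.",
    user="Using the numbers [1, 2], create an equation that equals 3.",
    prefill="Let me solve this step by step.\n<think>",
)
print(prompt)
```

Because the string ends right after `<think>`, the first tokens the model generates are already part of its reasoning.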

In addition, we tokenize each prompt and store the result as `input_ids`, which will be used later during training.

In [None]:
print(train_dataset[0]["input_ids"])

## Reward Function


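The DeepSeek R1 paper introduced **rule-based rewards** to evaluate whether the model's generated solution is correct. We take a similar approach and define two custom reward functions: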

---

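## Reward functions in detail

1.  **Format Reward**:
    * Checks whether the output follows the required format:
        `<think> [thinking] </think>\n<answer> [answer] </answer>`
    * Enforcing this format is mainly for easy answer extraction. It is not needed for correctness itself, but it greatly simplifies parsing during training.

2.  **Equation Reward**:
    * Extracts the equation from the `<answer>` tag.
    * Verifies that the equation evaluates to the target.
    * Ensures every available number is used exactly once in the equation.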

---

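## Final reward computation

The final reward assigned to an episode/trajectory (i.e. prompt + response) is simply the sum of these two components. Note that the reward is computed only at the **last token** of the output. From an RL perspective, this means all intermediate actions receive zero reward. We also apply no discounting (i.e. $\gamma = 1$).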
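
A consequence of a terminal-only reward with $\gamma = 1$ is that the return seen from every token position equals the final reward. The short sketch below checks this on a hypothetical 4-token response. `discounted_returns` is an illustrative helper, not a function used by the notebook.

```python
def discounted_returns(token_rewards, gamma=1.0):
    """Return-to-go for every position, accumulated from right to left."""
    returns, running = [], 0.0
    for reward in reversed(token_rewards):
        running = reward + gamma * running
        returns.append(running)
    return returns[::-1]


# Hypothetical response: format reward 1.0 + equation reward 1.0, paid at the last token
token_rewards = [0.0, 0.0, 0.0, 2.0]

undiscounted = discounted_returns(token_rewards, gamma=1.0)
discounted = discounted_returns(token_rewards, gamma=0.9)

print(undiscounted)  # [2.0, 2.0, 2.0, 2.0]
print(discounted)    # earlier tokens see a smaller return when gamma < 1
```

This is why a single scalar per response is enough to drive the update in this setting.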

In [None]:
def format_reward_func(completion: str) -> float:
    """
    Format: <think>...</think>\n<answer>...</answer>

    Also checks that the content within <answer>...</answer> conforms to a
    specified pattern (only digits, + - * / ( ) . and whitespace).

    Args:
        completion (str): Generated output

    Returns:
        float: Reward score
    """
    # Define the allowed pattern (only numbers, +, -, *, /, (, ), ., and whitespace)
    allowed_pattern = r"^[\d+\-*/().\s]+$"

    try:
        # add synthetic <think> as its already part of the prompt and prefilled
        # for the assistant to more easily match the regex
        completion = "<think>" + completion

        # Strip EOS token if present
        if completion.endswith(EOS_TOKEN):
            completion = completion[:-len(EOS_TOKEN)]

        # Check if the format is correct
        # Pattern means:
        # 1) <think>...contents not including other <think> tags...</think>
        # 2) \n
        # 3) <answer>...anything...</answer>
        regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
        match = re.search(regex, completion, re.DOTALL)

        if match is None or len(match.groups()) != 2:
            # Format is incorrect
            return 0.0
        else:
            # Extract the content inside <answer>...</answer>
            answer_content = match.group(2).strip()

            # Check if answer content matches the allowed pattern
            if not re.match(allowed_pattern, answer_content):
                # If it doesn't match, reward is 0.5
                return 0.5
            else:
                # If both format and pattern are correct, reward is 1
                return 1.0
    except Exception:
        # Any error leads to 0 reward
        return 0.0


def equation_reward_func(completion: str, nums: List[int], target: int) -> float:
    """
    Evaluates completion based on mathematical correctness of the answer

    Args:
        completion (str): Generated output
        nums (List[int]): Available numbers to use in the equation
        target (int): Expected answer

    Returns:
        float: Reward score
    """
    try:
        # Check if the format is correct
        match = re.search(r"<answer>(.*?)<\/answer>", completion)
        if match is None:
            return 0.0
        # Extract the "answer" part from the completion
        equation = match.group(1).strip()
        # Extract all numbers from the equation
        used_numbers = [int(n) for n in re.findall(r"\d+", equation)]

        # Check if all numbers are used exactly once
        if sorted(used_numbers) != sorted(nums):
            return 0.0
        # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation):
            return 0.0

        # Evaluate the equation with restricted globals and locals
        result = eval(equation, {"__builtins__": None}, {})
        # Check if the equation is correct and matches the ground truth
        if abs(float(result) - float(target)) < 1e-5:
            return 1.0
        else:
            return 0.0
    except Exception:
        # If evaluation fails, reward is 0
        return 0.0


def compute_reward(completion: str, sample: Dict[str, Any]) -> Tuple[float, Dict[str, float]]:
    nums = sample["nums"]
    target = sample["target"]

    format_reward = format_reward_func(completion)
    equation_reward = equation_reward_func(
        completion=completion, nums=nums, target=target
    )

    reward = format_reward + equation_reward

    metrics = {
        "format_reward": format_reward,
        "equation_reward": equation_reward,
    }

    return reward, metrics

In [None]:
# <think> is prefilled in the prompt. So, repeating it in the completion would be incorrect.
format_reward_func("<think>I think the answer is </think>\n<answer>1+2</answer>")

In [None]:
format_reward_func("I think the answer is </think>\n<answer>1+2</answer>")

In [None]:
format_reward_func("<think>I think the<think>and even more</think> answer is </think>\n<answer>1+2</answer>")

In [None]:
equation_reward_func("I think the answer is </think>\n<answer>1+2+2</answer>", [1,2], 3)

## Episode Generation

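The goal of episode generation is to create a collection of query-response pairs for policy training. From an RL perspective, the **query** acts as the initial state, and the tokens generated in the **response** are the actions taken by the policy.

The `create_training_episodes` function takes a list of prompts (initial states) and the corresponding completions generated by the model. In GRPO we always generate multiple responses per prompt, i.e. `GENERATIONS_PER_SAMPLE` is greater than 1. After episode generation, each RL iteration therefore yields `batch_size × GENERATIONS_PER_SAMPLE` episodes.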
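
The bookkeeping behind this count can be sketched with plain lists. The example assumes the generations arrive as one flat list ordered prompt by prompt, and it uses a group size of 3 and placeholder strings purely for illustration.

```python
GENERATIONS_PER_SAMPLE = 3  # small group size, for illustration only

prompts = ["prompt A", "prompt B"]
# Assumption: one flat list, with the responses of each prompt stored contiguously
generations = ["A1", "A2", "A3", "B1", "B2", "B3"]

# Indices of the generations that belong to each prompt
groups = [
    list(range(i, i + GENERATIONS_PER_SAMPLE))
    for i in range(0, len(generations), GENERATIONS_PER_SAMPLE)
]

# One episode per (prompt, response) pair
episodes = [
    (prompt, generations[idx])
    for prompt, group in zip(prompts, groups)
    for idx in group
]

print(groups)         # [[0, 1, 2], [3, 4, 5]]
print(len(episodes))  # 6 == len(prompts) * GENERATIONS_PER_SAMPLE
```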

---

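### Advantage Computation

Besides generating episodes, `create_training_episodes` also computes the **advantage** of each response token.

In RL terms, the advantage of a token measures how much better or worse that action is compared to the average token generated in that particular state (prompt + prefix). Ideally, we would compute the advantage separately for each token, to capture how much each step contributes to the overall reward.

In GRPO, however, there is no per-token advantage. Instead, we compute a single advantage value per response, which reflects how good the whole response is relative to the other responses generated for the same prompt. That single value is then assigned uniformly to all tokens of the response.

GRPO uses a simple formula for this:

1.  For each prompt $x$ and its group of responses $y_1, y_2, \ldots, y_G \sim \pi(\cdot|x)$, compute their rewards $R_1, R_2, \ldots, R_G$.
2.  Compute the mean and standard deviation of the group:

    $$\mu = \text{mean}(R_1, R_2, \ldots, R_G)$$

    $$\sigma = \text{std}(R_1, R_2, \ldots, R_G)$$

3.  Compute the **relative score** of each response:

    $$R^*_i = \frac{R_i - \mu}{\sigma}$$
4.  Assign this relative score $R^*_i$ as the advantage of every token of the $i$-th response:

    $$A_t^{(i)} = R^*_i$$

This **per-group normalization** encourages responses that are better than average and penalizes those that are worse.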
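
The four steps above can be sketched in pure Python. This is a minimal illustration rather than the notebook's implementation: `group_advantages` and `per_token_advantages` are hypothetical names, the standard deviation is the population one (the same default as NumPy's `std`), and a small `eps` is added to avoid division by zero.

```python
from statistics import fmean, pstdev


def group_advantages(rewards, group_size, eps=1e-6):
    """Normalize rewards within consecutive groups of `group_size` responses."""
    assert len(rewards) % group_size == 0
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu, sigma = fmean(group), pstdev(group)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


def per_token_advantages(advantages, response_lengths):
    """Repeat each response-level advantage once per response token."""
    return [[adv] * length for adv, length in zip(advantages, response_lengths)]


# Two prompts, four responses each: one mixed group and one all-correct group
rewards = [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
advantages = group_advantages(rewards, group_size=4)
token_advantages = per_token_advantages(advantages, [3, 2, 2, 4, 1, 1, 1, 1])

print([round(a, 3) for a in advantages])
```

The second group shows the degenerate case: identical rewards give zero advantages for every response in the group.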

---

### Advantages in practice

Consider a binary reward setting in which each response is either correct (1) or incorrect (0):

```python
>>> rewards = np.array([1, 1, 0, 0, 0])
>>> (rewards - rewards.mean()) / (rewards.std())
array([ 1.22474487,  1.22474487, -0.81649658, -0.81649658, -0.81649658])
```

Here the correct responses receive higher advantages and are therefore reinforced in future updates.

If only one response is correct:

```python
>>> rewards = np.array([1, 0, 0, 0, 0])
>>> (rewards - rewards.mean()) / (rewards.std())
array([ 2. , -0.5, -0.5, -0.5, -0.5])
```

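This corresponds to a prompt that is too hard for the model to answer correctly on average. If one response happens to be correct, it receives a high advantage, and all incorrect responses receive negative relative scores.

If all responses are incorrect: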

```python
>>> rewards = np.array([0, 0, 0, 0, 0])
>>> (rewards - rewards.mean()) / (rewards.std() + 1e-6)
array([0., 0., 0., 0., 0.])
```

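Since no response is better than average, the model receives no learning signal. (The `1e-6` in the denominator avoids a division by zero when the standard deviation is 0.)

If all responses are correct: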

```python
>>> rewards = np.array([1, 1, 1, 1, 1])
>>> (rewards - rewards.mean()) / (rewards.std() + 1e-6)
array([0., 0., 0., 0., 0.])
```

Again there is no learning signal, because there is nothing to improve.

In a more complex case:

```python
>>> rewards = np.array([1, 1, 1, 1, 0])
>>> (rewards - rewards.mean()) / (rewards.std())
array([ 0.5,  0.5,  0.5,  0.5, -2. ])
```

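This represents a relatively easy problem for the model: most responses are correct, and the occasional incorrect response is penalized heavily.

---

It is important to understand how GRPO's advantage computation pushes the model toward better responses even without per-token rewards. This is what lets the model improve on its own without human feedback.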

In [None]:
def create_training_episodes(
    samples: List[Dict[str, Any]],
    all_generations: List[List[int]],
    all_finish_reasons: List[str],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Process model generations and calculate rewards for training episodes.

    This function processes generated responses and calculates rewards for training episodes by:
    1. Grouping generations by sample (GENERATIONS_PER_SAMPLE responses per input)
    2. Computing rewards and advantages for each response
    3. Processing response tokens

    Args:
        samples: List of input samples, each containing:
            - input_ids: List[int], tokenized input prompt
            - nums: List[int], numbers to use in equation
            - target: int, target value for equation
        all_generations: List of token ID sequences for each generated response
        all_finish_reasons: List of finish reasons for each generation ("stop" or other)

    Returns:
        Tuple containing:
        1. Dictionary with processed data for training:
            - all_query_token_ids: List[List[int]], input token IDs repeated for each generation
            - all_response_token_ids: List[List[int]], response token IDs with EOS tokens added
            - all_advantages: List[List[float]], advantage values repeated for each token
        2. Dictionary with generation statistics:
            - response_lengths: List[int], lengths of generated responses
            - rewards: List[float], raw reward values
            - non_stop_rate: List[bool], True when a generation did not finish with "stop" (e.g. it hit the length limit)
            - reward_metrics/*: Various reward component metrics

    Example:
        >>> samples = [{"input_ids": [1,2,3], "nums": [1,2,3], "target": 6}]
        >>> generations = [[4,5, EOS_TOKEN_ID], [6,7], [8,9, EOS_TOKEN_ID]]  # 3 generations per sample
        >>> finish_reasons = ["stop", "length", "stop"]
        >>> episodes, stats = create_training_episodes(samples, generations, finish_reasons)
        >>> episodes
        {
            'all_query_token_ids': [[1,2,3], [1,2,3], [1,2,3]],
            'all_response_token_ids': [[4,5,EOS_TOKEN_ID], [6,7], [8,9,EOS_TOKEN_ID]],
            'all_advantages': [[0.5,0.5,0.5], [-1.0,-1.0], [0.5,0.5,0.5]]
        }
    """
    assert len(all_generations) == len(all_finish_reasons)
    assert len(all_generations) == len(samples) * GENERATIONS_PER_SAMPLE

    # Process responses and calculate rewards
    groups = [
        list(range(i, i + GENERATIONS_PER_SAMPLE))
        for i in range(0, len(all_generations), GENERATIONS_PER_SAMPLE)
    ]  # example: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    all_query_token_ids, all_responses_token_ids, all_advantages = [], [], []

    stats = {
        "response_lengths": [],
        "rewards": [],
        "non_stop_rate": [],
    }

    for sample, group_indices in zip(samples, groups):
        finish_reasons = [all_finish_reasons[i] for i in group_indices]
        response_token_ids = [all_generations[i] for i in group_indices]
        responses = tokenizer.batch_decode(response_token_ids, skip_special_tokens=False)

        rewards_and_metrics = [compute_reward(resp, sample) for resp in responses]
        rewards, reward_metrics = zip(*rewards_and_metrics)

        rewards = np.array(rewards) # [group_size]
        response_advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)

        advantages = [
            [resp_adv] * len(resp)
            for resp_adv, resp in zip(response_advantages, response_token_ids)
        ]

        all_query_token_ids.extend([sample["input_ids"]] * GENERATIONS_PER_SAMPLE)
        all_responses_token_ids.extend(response_token_ids)
        all_advantages.extend(advantages)

        stats["rewards"].extend(rewards)
        stats["non_stop_rate"].extend([fr != "stop" for fr in finish_reasons])
        stats["response_lengths"].extend([len(ids) for ids in response_token_ids])
        for rm in reward_metrics:
            for k, v in rm.items():
                stats.setdefault(f"reward_metrics/{k}", []).append(v)

    episodes = {
        "all_query_token_ids": all_query_token_ids,
        "all_response_token_ids": all_responses_token_ids,
        "all_advantages": all_advantages,
    }

    return episodes, stats

In [None]:
case_0 = {
    "sample": {"input_ids": [1,2,3], "nums": [1,2,3], "target": 6},
    "generations": [[4,5, 22, 33], [6,7], [8,9, 11], [10,11]],
    "finish_reasons": ["stop", "length", "stop", "stop"]
}

case = case_0
episodes, stats = create_training_episodes([case["sample"]], case["generations"], case["finish_reasons"])
episodes,stats

In [None]:
case_1 = {
    "sample": {"input_ids": [33, 44], "nums": [11, 7, 8], "target": 26},
    "generations": [[1,2], [3,4], [5,6], [7,8]],
    "finish_reasons": ["stop", "stop", "length", "stop"]
}
case = case_1
episodes, stats = create_training_episodes([case["sample"]], case["generations"], case["finish_reasons"])
episodes,stats

In [None]:
case_2 = {
    "sample": {"input_ids": [9, 8, 7, 6, 5, 4], "nums": [1,2,3,4], "target": 10},
    "generations": [[9,10], [11,12], [13,14], [15,16]],
    "finish_reasons": ["length", "length", "stop", "stop"]
}
case = case_2
episodes, stats = create_training_episodes([case["sample"]], case["generations"], case["finish_reasons"])
episodes,stats

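As the outputs above show, every episode generated from a single sample shares the same prompt: the sample's `input_ids` are repeated in `all_query_token_ids`, once per generation.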

## Policy Gradient


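Now that we have a batch of episodes with their advantages, we can compute the **policy gradient loss** to update the model.

GRPO uses the same loss formulation as PPO; the key difference lies in how the advantages are computed. To understand the implementation in `compute_pg_loss`, we first recall the original PPO objective: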

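$$
\mathcal{L}_{\text{PPO}} = \mathbb{E}\left[\min\left(
\frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)} A_t, \;
\text{clip}\left(
\frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)}, \;
1 - \epsilon, \; 1 + \epsilon
\right) A_t \right)\right]
$$

where:

- $\pi_{\theta}$ is the current policy,
- $\pi_{\theta_{\text{old}}}$ is the policy from the previous iteration (the policy the episodes were sampled from),
- $A_t$ is the advantage.

This objective increases or decreases the probability of each token according to its advantage $A_t$, but only while the ratio between the new and old policy probabilities stays within a small range controlled by the clipping threshold $\epsilon$. The clipping prevents large, destabilizing updates during training.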
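To make the clipping behavior concrete, here is a minimal per-token sketch of the term inside the expectation. The function name `clipped_surrogate` and the default `eps=0.2` are assumptions made for illustration; they are not taken from this notebook's code, which (as explained in the next section) never needs the clipped form.

```python
import math


def clipped_surrogate(logp_new: float, logp_old: float, advantage: float, eps: float = 0.2) -> float:
    """Per-token PPO term: min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped_ratio * advantage)


# New policy equals old policy (ratio 1): the term is just the advantage.
print(clipped_surrogate(-1.0, -1.0, advantage=2.0))
# Ratio above 1 + eps with a positive advantage: the gain is capped at (1 + eps) * A.
print(clipped_surrogate(-0.5, -1.0, advantage=1.0))
# Same ratio with a negative advantage: the unclipped, more pessimistic term is kept.
print(clipped_surrogate(-0.5, -1.0, advantage=-1.0))
```

Taking the minimum means clipping only ever removes the incentive to move further in a direction that already looks favorable; it never hides a penalty.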

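### Fully online setting: simplifying the objective

In PPO, the same batch of episodes is typically reused for several gradient updates. In our case, each iteration performs **a single gradient update** on freshly sampled episodes. This means:

- $\pi_{\theta} = \pi_{\theta_{\text{old}}}$
- Therefore,
$$\frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)} = 1 $$

Since the ratio is exactly 1:

  - The clipping function becomes inactive.
  - The $\min(\cdot,\cdot)$ operator simply returns the unclipped term.

So the objective **simplifies to**:

$$\mathcal{L}_{\text{PPO}} = \mathbb{E}\left[ \frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)} A_t \right]$$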

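Taking the gradient of this objective with respect to $\theta$ and evaluating it at $\theta = \theta_{\text{old}}$, where the ratio equals 1, we get:

$$\vec{g}_{\text{PPO}} = \nabla_\theta \mathcal{L}_{\text{PPO}} = \underbrace{\mathbb{E}\left[ \nabla_\theta \log \pi_\theta(y_t \mid y_{<t}, x) \cdot A_t \right]}_{\text{vanilla policy gradient with advantages}}$$

This is the **standard policy gradient** formula, in which log-probabilities are weighted by advantages. In effect, we recover vanilla REINFORCE-style learning.

> Note: this step uses the identity $\nabla_\theta \frac{\pi_\theta}{\pi_{\theta_{\text{old}}}} = \frac{\pi_\theta}{\pi_{\theta_{\text{old}}}} \nabla_\theta \log \pi_\theta$. With a ratio of 1, no extra constant factor appears.
> - In reinforcement learning, "Vanilla Policy Gradient" (VPG), a name often used interchangeably with REINFORCE, refers to the simplest policy gradient algorithm. It optimizes the policy directly, updating the parameters to raise the probability of actions that lead to higher reward and to lower the probability of actions that lead to lower reward.
> - https://spinningup.openai.com/en/latest/algorithms/vpg.html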

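This behavior is not unique to GRPO. In PPO, TRPO, and similar methods, the first gradient step after collecting fresh data always reduces to this same form. Clipping or trust-region constraints only come into play after that first optimization step.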
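The equivalence between the ratio-based objective and the log-probability objective at $\theta = \theta_{\text{old}}$ can be checked numerically. The sketch below uses a toy softmax policy over three tokens; the logits, the sampled token, and the advantage are made-up values, and the helper names are chosen only for this illustration.

```python
import math
from typing import Callable, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def numerical_grad(f: Callable[[List[float]], float], theta: List[float], h: float = 1e-6) -> List[float]:
    """Central finite differences of f at theta."""
    grad = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


theta_old = [0.2, -0.1, 0.5]  # toy logits of the sampling policy (made up)
action, advantage = 2, 1.5    # sampled token index and its advantage (made up)
p_old = softmax(theta_old)[action]


def surrogate(theta: List[float]) -> float:
    """Ratio-based objective: pi_theta / pi_theta_old * A."""
    return softmax(theta)[action] / p_old * advantage


def reinforce(theta: List[float]) -> float:
    """REINFORCE-style objective: log pi_theta * A."""
    return math.log(softmax(theta)[action]) * advantage


# Both gradients are evaluated at theta = theta_old, where the ratio is exactly 1.
grad_surrogate = numerical_grad(surrogate, theta_old)
grad_reinforce = numerical_grad(reinforce, theta_old)
print(grad_surrogate)
print(grad_reinforce)
```

The two printed gradients agree up to finite-difference error, which is the numerical counterpart of the simplification derived above.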

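### KL penalty

The final objective also includes a **KL penalty** term, which keeps the new policy from drifting too far from the reference policy:


$$ \mathcal{L} = \mathcal{L}_{\text{PPO}} - \beta \cdot \text{KL}(\pi_\theta \parallel \pi_{\theta_{\text{ref}}})$$

This objective is maximized; `compute_pg_loss` minimizes its negative, which is why the KL term is added to the loss there. We estimate the KL divergence with the **k3 estimator** from [this blog post by Schulman](http://joschu.net/blog/kl-approx.html):

$$\text{KL}(\pi_\theta \parallel \pi_{\theta_{\text{ref}}}) \approx \mathbb{E}\left[\frac{\pi_{\theta_{\text{ref}}}(y_t \mid y_{<t}, x)}{\pi_\theta(y_t \mid y_{<t}, x)} - \log\left(\frac{\pi_{\theta_{\text{ref}}}(y_t \mid y_{<t}, x)}{\pi_\theta(y_t \mid y_{<t}, x)}\right) - 1\right]$$

This regularizer softly constrains the updated model to stay close to the reference policy.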
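As a sanity check on the k3 formula, the sketch below evaluates it on a toy three-token vocabulary, where the expectation over samples from $\pi_\theta$ can be computed exactly and compared with the true KL divergence. The two distributions are made-up numbers, and `k3` and `exact_kl` are names chosen for this illustration; the per-sample expression has the same form as the `kl_penalty` line in `compute_pg_loss`.

```python
import math
from typing import List


def k3(logp: float, ref_logp: float) -> float:
    """k3 estimate for one token sampled from the policy: r - log(r) - 1, with r = pi_ref / pi_theta."""
    log_ratio = ref_logp - logp
    return math.exp(log_ratio) - log_ratio - 1.0


def exact_kl(p: List[float], q: List[float]) -> float:
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


policy = [0.7, 0.2, 0.1]     # toy next-token distribution of pi_theta (made up)
reference = [0.5, 0.3, 0.2]  # toy next-token distribution of pi_ref (made up)

# Exact expectation of k3 under samples drawn from the policy.
expected_k3 = sum(p * k3(math.log(p), math.log(q)) for p, q in zip(policy, reference))
print(expected_k3, exact_kl(policy, reference))
```

Each individual `k3` value is non-negative, and it is zero when the two log-probabilities match, which makes it a convenient per-token penalty.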

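### Main difference between GRPO and PPO/VinePPO

The main difference between **GRPO** and methods such as **PPO/VinePPO** lies in **how advantages are computed and applied**:

  - In **PPO/VinePPO**, the advantage is computed separately for each token/step, which allows fine-grained credit assignment within the sequence.
  - In **GRPO**, a **single scalar advantage** is computed for the whole response and **applied uniformly to every token in that response**.

The figures below illustrate this difference:

#### Successful response in GRPO: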

<img src="https://github.com/McGill-NLP/nano-aha-moment/blob/main/assets/grpo_successful.png?raw=true" alt="GRPO vs PPO/VinePPO: successful response" width="500">

#### Failed response in GRPO:
<img src="https://github.com/McGill-NLP/nano-aha-moment/blob/main/assets/grpo_unsuccessful.png?raw=true" alt="GRPO vs PPO/VinePPO: failed response" width="500">

In GRPO, all tokens in a response are updated with the same magnitude. In contrast, PPO/VinePPO updates each token/step with its own advantage value:

<img src="https://github.com/McGill-NLP/nano-aha-moment/blob/main/assets/ppo_and_vineppo.png?raw=true" alt="GRPO vs PPO/VinePPO: PPO and VinePPO" width="500">
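One consequence of sharing a single advantage across a response is that the summed per-token loss collapses to the advantage times the log-probability of the whole response. The sketch below illustrates this with made-up token log-probabilities and advantages; `pg_loss` is a helper defined only for this example and mirrors the `-logps * advantages` term in `compute_pg_loss`.

```python
import math
from typing import List


def pg_loss(token_logps: List[float], token_advantages: List[float]) -> float:
    """Policy-gradient loss for one response: sum over tokens of -logp_t * A_t."""
    return sum(-lp * adv for lp, adv in zip(token_logps, token_advantages))


token_logps = [math.log(0.9), math.log(0.5), math.log(0.8)]  # made-up token log-probs

# GRPO: one scalar advantage, repeated for every token of the response.
grpo_advantage = 0.5
grpo_loss = pg_loss(token_logps, [grpo_advantage] * len(token_logps))

# With a shared advantage, the loss equals -A * log p(response).
sequence_logp = sum(token_logps)
print(grpo_loss, -grpo_advantage * sequence_logp)

# PPO/VinePPO-style: a separate (here made-up) advantage for each token.
per_token_loss = pg_loss(token_logps, [0.1, 1.2, -0.3])
print(per_token_loss)
```

With per-token advantages, individual tokens can be pushed in different directions within the same response, which is the finer-grained credit assignment described above.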


In [None]:
def compute_pg_loss(
    policy_model: Union[DeepSpeedEngine, PreTrainedModel],
    reference_model: Union[DeepSpeedEngine, PreTrainedModel],
    batch: Dict[str, torch.Tensor],
    total_response_len: int,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute the policy gradient loss with KL penalty between policy and reference models.

    This function:
    1. Computes log probabilities for both policy and reference models
    2. Calculates KL divergence penalty between the models
    3. Computes policy gradient loss using advantages
    4. Combines the losses with KL coefficient

    Args:
        policy_model: The model being trained
        reference_model: The reference model for KL penalty calculation
        batch: Dictionary containing:
            - input_ids: Tensor of shape [batch_size, seq_len]
            - attention_mask: Tensor of shape [batch_size, seq_len]
            - labels: Tensor of shape [batch_size, seq_len] with -100 for ignored positions
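            - advantages: Tensor of shape [batch_size, seq_len]
        total_response_len: Total number of response tokens (labels != -100) across
            the whole iteration's batch, used to normalize the summed loss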

    Returns:
        Tuple containing:
            - loss: Combined policy gradient and KL penalty loss (scalar tensor)
            - metrics: Dictionary with detailed loss components:
                - policy_loss: Pure policy gradient loss
                - kl_penalty: KL divergence penalty
                - entropy: Policy entropy
    """
    input_ids = batch["input_ids"]  # [batch_size, seq_len]
    attention_mask = batch["attention_mask"]  # [batch_size, seq_len]
    labels = batch["labels"]  # [batch_size, seq_len]
    advantages = batch["advantages"]  # [batch_size, seq_len]

    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

    labels_mask = (labels[..., 1:] != -100).float()  # [batch_size, seq_len-1]

    with torch.no_grad():
        ref_logps = compute_token_log_probs(
            reference_model, model_inputs, TEMPERATURE
        )  # [batch_size, seq_len-1]

    logps = compute_token_log_probs(policy_model, model_inputs, TEMPERATURE)  # [batch_size, seq_len-1]

    kl_penalty = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1  # [batch_size, seq_len-1]
    kl_penalty = kl_penalty * labels_mask  # [batch_size, seq_len-1]

    entropy = -logps.sum() / labels_mask.sum()  # scalar

    policy_loss = -logps * advantages[..., 1:]  # [batch_size, seq_len-1]
    policy_loss = policy_loss * labels_mask  # [batch_size, seq_len-1]

    loss = (policy_loss + KL_COEFFICIENT * kl_penalty).sum() / total_response_len  # scalar

    metrics = {
        "policy_loss": policy_loss.sum().item() / total_response_len,
        "kl_penalty": kl_penalty.sum().item() / total_response_len,
        "entropy": entropy.item() / total_response_len,
    }

    return loss, metrics

## Training

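Before starting the reinforcement learning (RL) loop, we need to set up all the required components:

* **Policy model**: the main model, trained with policy gradients.
* **Reference model**: a frozen copy of the base model, used for KL regularization.
* **DeepSpeed**: both models are initialized with DeepSpeed.
* **vLLM inference engine**: used for fast batched inference during episode generation.
* **WandB logging**: we initialize WandB to track training metrics, hyperparameters, and checkpoints.

Finally, if an existing checkpoint is detected, we automatically resume training from where it left off.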

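A few notes:
* We move the reference model to the CPU and only move it back to the GPU while computing the policy gradient. Because the model is relatively small, moving it back and forth between GPU and CPU is fast.
* Although the whole training runs on a single GPU, we still use DeepSpeed ZeRO stage 2, because stage 2 comes with optimizations that avoid memory fragmentation and so make full use of GPU memory.
* Flash Attention is required in our setup, because it reduces the memory requirement of attention from $\mathcal{O}(n^2)$ to $\mathcal{O}(n)$, where $n$ is the sequence length.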
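To get a feel for the quadratic term, the sketch below computes how much memory a standard attention implementation would need just to materialize the $n \times n$ score matrix of one layer in bf16 (2 bytes per element). The head count is a hypothetical value and is not read from any model config; the sequence lengths go up to the `max_model_len=2048` used for the inference engine.

```python
def attention_scores_bytes(seq_len: int, num_heads: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to materialize one layer's [num_heads, seq_len, seq_len] score matrix."""
    return num_heads * seq_len * seq_len * bytes_per_element


NUM_HEADS = 16  # hypothetical head count, not taken from a specific model config

for n in (512, 1024, 2048):
    mib = attention_scores_bytes(n, NUM_HEADS) / 2**20
    print(f"seq_len={n}: {mib:.0f} MiB per sequence per layer")
```

Doubling the sequence length quadruples this buffer. Flash Attention processes the scores in tiles and never stores the full matrix, so this cost does not grow quadratically with $n$.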

In [None]:
import os
# https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


In [None]:
import os
from google.colab import userdata
os.environ["WANDB_API_KEY"]=userdata.get('WANDB_API_KEY')

In [None]:
# Initialize main and reference models
policy_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)
reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)
policy_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})


# Initialize DeepSpeed engines
policy_model, *_ = deepspeed.initialize(
    model=policy_model,
    config=deepspeed_config,
    model_parameters=policy_model.parameters(),
)
reference_model, *_ = deepspeed.initialize(
    model=reference_model,
    config=ref_deepspeed_config,
)

reference_model.module.cpu()

############################################
# Initialize vLLM (Inference) engine
############################################

inference_engine = LLM(
    model=MODEL_NAME,
    skip_tokenizer_init=False,
    gpu_memory_utilization=0.2,
    enable_prefix_caching=True,
    swap_space=1,
    scheduling_policy="fcfs",
    dtype=torch.bfloat16,
    max_model_len=2048,
    enable_sleep_mode=True,
)

# Wandb for logging
wandb.init(
    project="r1-aha-moment",
    name=RUN_NAME,
    config={
        "model_name": MODEL_NAME,
        "learning_rate": LEARNING_RATE,
        "num_iterations": NUM_ITERATIONS,
        "episodes_per_iteration": EPISODES_PER_ITERATION,
        "rollouts_per_episode": GENERATIONS_PER_SAMPLE,
        "kl_coefficient": KL_COEFFICIENT,
        "temperature": TEMPERATURE,
    },
)

# Load checkpoint if it exists
begin_iter = 0
ckpt_path, ckpt_iter = find_last_checkpoint(Path(EXP_DIR))
if ckpt_path is not None:
    print(f"Resuming from checkpoint {ckpt_path} at iteration {ckpt_iter}")
    out = policy_model.load_checkpoint(ckpt_path / "deepspeed")
    if out is None:
        raise RuntimeError(f"Failed to load checkpoint {ckpt_path}")
    begin_iter = ckpt_iter + 1
    load_model_into_vllm(policy_model, inference_engine)

### Training loop

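With everything set up, we can now start the main training loop. Each iteration of the loop performs the following steps:

---

#### Training loop steps

1.  **Evaluation** (optional):
    Every 25 iterations, the model is evaluated on the test set to monitor training progress.

2.  **Episode generation**:
    A batch of prompts is sampled from the dataset, and the inference engine generates multiple responses for each prompt. We then put the inference engine to "sleep".

3.  **Reward computation**:
    The reward and advantage of every generated episode are computed.

4.  **Policy gradient training**:
    Using the computed advantages, we compute the policy gradient loss and update the model parameters. Training uses **gradient accumulation** to handle large batches. Note that we apply only **one gradient update** per iteration.

5.  **Inference engine update**:
    The inference engine is "woken up" and updated with the latest model weights.

6.  **Logging**:
    Training and evaluation metrics are logged with WandB.

7.  **Checkpointing**:
    Every 50 iterations, the model weights and optimizer state are saved.

The loop runs until the specified number of iterations is reached.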
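The gradient accumulation in step 4 depends on how each micro-batch loss is normalized. The sketch below uses made-up per-token losses to show that dividing every micro-batch's summed loss by the token count of the whole batch (the role played by `total_response_len` in `compute_pg_loss`) makes the accumulated loss equal to the token-level mean over the full batch, whereas averaging per-micro-batch means does not. This is presumably also why the loop calls `backward` with `scale_wrt_gas=False`, although that reading is an assumption here.

```python
from typing import List


def micro_batch_loss(token_losses: List[float], total_response_len: int) -> float:
    """Summed token losses of one micro-batch, divided by the token count of the full batch."""
    return sum(token_losses) / total_response_len


# Toy per-token losses for three micro-batches of unequal size (made-up numbers).
micro_batches = [[0.5, 1.0, 1.5], [2.0], [0.25, 0.75]]
total_response_len = sum(len(mb) for mb in micro_batches)

# Accumulating the globally normalized micro-batch losses ...
accumulated = sum(micro_batch_loss(mb, total_response_len) for mb in micro_batches)
# ... matches the mean over all tokens of the full batch.
full_batch_mean = sum(sum(mb) for mb in micro_batches) / total_response_len
# Averaging per-micro-batch means instead would over-weight short micro-batches.
naive = sum(sum(mb) / len(mb) for mb in micro_batches) / len(micro_batches)

print(accumulated, full_batch_mean, naive)
```

With this normalization every response token contributes equally to the single update, regardless of which micro-batch it lands in.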

---

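#### vLLM's "sleep" mechanism

Before each training step begins, we put vLLM into "sleep" mode to release its KV cache and model weights, so that enough GPU memory is available for policy training. Once the training step finishes, vLLM is "woken up", re-initializes its KV cache, and is ready to sample the next round with the updated model parameters.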

In [None]:
for iteration in trange(begin_iter, NUM_ITERATIONS):  # start at begin_iter so a resumed run skips finished iterations
    print(f"Iteration {iteration}/{NUM_ITERATIONS}")

    metrics = {}

    #########################################################
    # Evaluation
    #########################################################

    eval_stats = None
    if iteration % 25 == 0:
        print("Evaluating on eval set...")
        eval_episodes, eval_stats = evaluate_on_test_set(
            inference_engine=inference_engine,
            test_dataset=test_dataset,
            tokenizer=tokenizer,
            eos_token=EOS_TOKEN,
            eval_sampling_params=SamplingParams(
                temperature=0.3,
                max_tokens=1024,
                n=1,
                detokenize=False,
                stop_token_ids=[EOS_TOKEN_ID],
            ),
            reward_func=lambda completion, sample: compute_reward(
                completion, sample
            ),
        )
        eval_episode_table = dump_episodes(
            episodes=eval_episodes,
            episodes_stats=eval_stats,
            exp_dir=Path(EXP_DIR),
            tokenizer=tokenizer,
            iteration=iteration,
            is_eval=True,
        )
        wandb.log({"eval/episodes": eval_episode_table, "iteration": iteration})


    #########################################################
    # Generate Episodes
    #########################################################

    # Sample training batch
    num_samples = EPISODES_PER_ITERATION // GENERATIONS_PER_SAMPLE
    indices = np.random.choice(
        len(train_dataset), size=num_samples, replace=False
    )
    samples = train_dataset.select(indices)

    # Sample responses
    outputs = inference_engine.generate(
        prompt_token_ids=samples["input_ids"],
        sampling_params=SamplingParams(
            n=GENERATIONS_PER_SAMPLE,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            max_tokens=MAX_RESPONSE_TOKENS,
            detokenize=False,
            stop_token_ids=[EOS_TOKEN_ID],
        )
    )
    all_generations = [list(g.token_ids) for out in outputs for g in out.outputs]
    all_finish_reasons = [g.finish_reason for out in outputs for g in out.outputs]
    inference_engine.sleep(1)

    print(f"Generated {len(all_generations)} responses")
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)

    # Process responses and calculate rewards
    episodes, episodes_stats = create_training_episodes(
        samples,
        all_generations,
        all_finish_reasons,
    )
    for k, v in episodes_stats.items():
        metrics.setdefault(k, []).extend(v)

    episode_table = dump_episodes(
        episodes=episodes,
        episodes_stats=episodes_stats,
        exp_dir=Path(EXP_DIR),
        tokenizer=tokenizer,
        iteration=iteration,
    )

    #########################################################
    # Training
    #########################################################

    # Prepare training batch
    model_inputs = prepare_model_inputs(
        query_token_ids=episodes["all_query_token_ids"],
        response_token_ids=episodes["all_response_token_ids"],
        advantages=episodes["all_advantages"],
        device="cuda"
    )

    # Calculate losses and update model
    policy_model.train()
    reference_model.module.cuda()
    reference_model.eval()

    total_response_len = (model_inputs["labels"] != -100).sum().item()

    for i in trange(0, EPISODES_PER_ITERATION, PER_DEVICE_BATCH_SIZE, desc="Gradient Accumulation"):
        batch = {
            k: v[i : i + PER_DEVICE_BATCH_SIZE]
            for k, v in model_inputs.items()
        }

        # Compute policy gradient loss
        loss, loss_metrics = compute_pg_loss(
            policy_model=policy_model,
            reference_model=reference_model,
            batch=batch,
            total_response_len=total_response_len,
        )

        # Track metrics
        metrics.setdefault("loss", []).append(loss.item())
        grad_norm = policy_model.get_global_grad_norm()
        if grad_norm is not None:
            grad_norm = grad_norm.item()
        metrics.setdefault("grad_norm", []).append(grad_norm)
        for k, v in loss_metrics.items():
            metrics.setdefault(k, []).append(v.item() if isinstance(v, torch.Tensor) else v)

        # Backpropagation and optimization step
        policy_model.backward(loss, scale_wrt_gas=False)

        # Free memory
        del loss, loss_metrics
        if policy_model.is_gradient_accumulation_boundary():
            reference_model.module.cpu()

        policy_model.step()

    #########################################################
    # Update inference engine weights
    #########################################################

    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)

    inference_engine.wake_up()
    load_model_into_vllm(policy_model, inference_engine)

    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)


    #########################################################
    # Log metrics
    #########################################################

    train_metrics = {
        k: np.mean(v) for k, v in metrics.items() if None not in v
    }
    train_metrics["learning_rate"] = policy_model.get_lr()[0]
    logs = {
        "iteration": iteration,
        f"episodes/iter_{iteration:06d}": episode_table,
        **{f"train/{k}": v for k, v in train_metrics.items()},
    }
    if eval_stats is not None:
        eval_metrics = {k: np.mean(v) for k, v in eval_stats.items() if None not in v}
        logs.update({f"eval/{k}": v for k, v in eval_metrics.items()})
    wandb.log(logs)

    selected_keys = [
        "train/kl_penalty",
        "train/rewards",
        "train/reward_metrics/format_reward",
        "train/reward_metrics/equation_reward",
        "eval/rewards",
        "eval/reward_metrics/format_reward",
        "eval/reward_metrics/equation_reward",
    ]
    selected_metrics = {k: logs[k] for k in selected_keys if k in logs}
    print(f"KEY METRICS: {selected_metrics}")

    if iteration % 50 == 0 and iteration != 0:
        policy_model.module.save_pretrained(
            str(Path(EXP_DIR) / "checkpoints" / f"ckpt_{iteration:06d}" / "hf_model")
        )
        policy_model.save_checkpoint(
            str(Path(EXP_DIR) / "checkpoints" / f"ckpt_{iteration:06d}" / "deepspeed")
        )

## Citation

If you use this codebase in your research, please cite it as follows:

```bibtex
@misc{Kazemnejad2025:NanoAhaMoment,
  author       = {Amirhossein Kazemnejad and Milad Aghajohari and Alessandro Sordoni and Aaron Courville and Siva Reddy},
  title        = {Nano Aha! Moment: Lunch Break Reproduction of DeepSeek R1-Zero from Scratch},
  year         = {2025},
  howpublished = {\url{https://github.com/McGill-NLP/nano-aha-moment}},
  note         = {GitHub repository}
}
```