In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
from datasets import load_dataset
import torch

# Pull the "original" configuration of the competition_math test split and
# print a single shuffled batch of 32 records.
eval_dataset = load_dataset(
    path="jeggers/competition_math",
    name="original",
    split="test",
)
eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=32, shuffle=True)
examples = next(iter(eval_loader))
print(examples)

{'problem': ['What is the area of trapezoid $OBCD$ below? [asy]\nsize(200);\ndefaultpen(linewidth(0.8));\nxaxis("$x$",-4,10);\nyaxis("$y$",-3,5);\ndot(Label("$O$",align=SW),(0,0));\ndot(Label("$D(2,3)$",align=NW),(2,3));\ndot(Label("$C(4,3)$",align=NE),(4,3));\ndot(Label("$B(8,0)$",align=S),(8,0));\ndraw((0,0)--(2,3)--(4,3)--(8,0));\n[/asy]', 'Find all values of $k$ for which the system\n\\begin{align*}\nx + ky - z &= 0, \\\\\nkx - y - z &= 0, \\\\\nx + y - kz &= 0\n\\end{align*}has a non-trivial solution.  (In other words, find all values of $k$ for which the system has a solution other than $(x,y,z) = (0,0,0).$)', 'If the roots of the quadratic equation $\\frac12x^2+99x+c=0$ are $x=-99+\\sqrt{8001}$ and $x=-99-\\sqrt{8001}$, then what is the value of $c$?', 'Let A = 1, B = 2, C = 3, ..., Z = 26. The product value of a word is equal to the product of the values of its letters. For example, CAB has a product value of 3 $\\times$ 1 $\\times$ 2 = 6. What common English word has a product

In [None]:
from dataclasses import dataclass
from typing import Literal
import torch
import os

# Set TOKENIZERS_PARALLELISM to "false" for this process (presumably to avoid the
# Hugging Face tokenizers fork warning once DataLoader workers are spawned — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@dataclass
class Config:
    """Hyperparameters, model identifiers and data locations for the GRPO run.

    Two derived attributes are computed in ``__post_init__``:

    * ``mini_batch_size`` - number of rollouts processed per forward/backward
      pass, i.e. ``train_batch_size // gradient_accumulation_steps``.
    * ``n_prompts_per_rollout_batch`` - number of distinct prompts sampled per
      GRPO step, i.e. ``rollout_batch_size // group_size``.

    Both are computed once at construction time; they are NOT refreshed if the
    fields they depend on are reassigned afterwards.

    Raises:
        ValueError: if ``train_batch_size`` is not a positive multiple of
            ``gradient_accumulation_steps`` or ``rollout_batch_size`` is not a
            positive multiple of ``group_size``.
    """

    # --- GRPO schedule / optimisation ---
    n_grpo_steps: int = 200
    learning_rate: float = 1e-5
    advantage_eps: float = 1e-6
    rollout_batch_size: int = 256
    group_size: int = 8

    # --- Rollout sampling ---
    sampling_temperature: float = 1.0
    sampling_min_tokens: int = 4  # As in Expiter, disallow empty string responses
    sampling_max_tokens: int = 1024

    # --- Training batches ---
    epochs_per_rollout_batch: int = 1  # On-policy
    train_batch_size: int = 256  # On-policy
    gradient_accumulation_steps: int = 128  # microbatch size is 2, will fit on H100
    gpu_memory_utilization: float = 0.09
    loss_type: Literal[
        "no_baseline",
        "reinforce_with_baseline",
        "grpo_clip",
    ] = "reinforce_with_baseline"
    use_std_normalization: bool = True

    # --- Model / data ---
    model_id: str = "Qwen/Qwen2.5-Math-1.5B"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    train_data: str = "results/math_1.5B_train.jsonl"
    eval_data: str = "jeggers/competition_math"
    prompt_template: str = ""

    def __post_init__(self):
        """Validate the batch-size arithmetic and compute the derived sizes."""
        if (
            self.gradient_accumulation_steps <= 0
            or self.train_batch_size <= 0
            or self.train_batch_size % self.gradient_accumulation_steps != 0
        ):
            raise ValueError(
                "train_batch_size must be a positive multiple of "
                "gradient_accumulation_steps "
                f"(got {self.train_batch_size} and {self.gradient_accumulation_steps})"
            )
        if (
            self.group_size <= 0
            or self.rollout_batch_size <= 0
            or self.rollout_batch_size % self.group_size != 0
        ):
            raise ValueError(
                "rollout_batch_size must be a positive multiple of group_size "
                f"(got {self.rollout_batch_size} and {self.group_size})"
            )

        # The microbatch is a slice of the *train* batch: accumulating
        # `gradient_accumulation_steps` of them must add up to exactly one
        # train batch, independent of how many rollouts were sampled.
        self.mini_batch_size = (
            self.train_batch_size // self.gradient_accumulation_steps
        )
        self.n_prompts_per_rollout_batch = self.rollout_batch_size // self.group_size


cfg = Config()
# Load the prompt template; the training cell fills it in with
# `.format(question=...)`. The encoding is pinned so the decoded text does not
# depend on the machine's locale.
with open("prompts/r1_zero.prompt", encoding="utf-8") as f:
    cfg.prompt_template = f.read()

# For test only: run a single GRPO step. Remove this override for a real run.
cfg.n_grpo_steps = 1

In [None]:
from expert_iteration import init_policy_model, init_vllm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from vllm import LLM, SamplingParams


# Trainable policy model; its parameters are what the optimizer below updates.
policy = init_policy_model(cfg.model_id, cfg.device)
# vLLM engine created from the same model id; the training cell calls
# `eval_model.generate(...)` on it to sample rollouts.
# NOTE(review): the engine log printed under this cell reports `seed=0` although
# seed=42 is passed here — confirm that init_vllm forwards the seed.
eval_model = init_vllm(
    cfg.model_id,
    cfg.device,
    seed=42,
    gpu_memory_utilization=cfg.gpu_memory_utilization,
)
# Tokenizer for the same model id; handed to tokenize_prompt_and_output in the training cell.
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
# AdamW over all policy parameters, with weight decay disabled.
optimizer = torch.optim.AdamW(
    policy.parameters(),
    lr=cfg.learning_rate,
    weight_decay=0.0,
    betas=(0.9, 0.95),
)

INFO 11-30 09:16:18 __init__.py:190] Automatically detected platform cuda.


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 11-30 09:16:31 config.py:542] This model supports multiple tasks: {'embed', 'score', 'generate', 'classify', 'reward'}. Defaulting to 'generate'.
INFO 11-30 09:16:31 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.2) with config: model='Qwen/Qwen2.5-Math-1.5B', speculative_config=None, tokenizer='Qwen/Qwen2.5-Math-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2.5-Math-1.5B, num_scheduler_st

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 11-30 09:16:35 model_runner.py:1115] Loading model weights took 2.8797 GB
INFO 11-30 09:16:36 worker.py:267] Memory profiling takes 0.55 seconds
INFO 11-30 09:16:36 worker.py:267] the current vLLM instance can use total_gpu_memory (79.25GiB) x gpu_memory_utilization (0.09) = 7.13GiB
INFO 11-30 09:16:36 worker.py:267] model weights take 2.88GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 1.40GiB; the rest of the memory reserved for KV Cache is 2.76GiB.
INFO 11-30 09:16:36 executor_base.py:110] # CUDA blocks: 6469, # CPU blocks: 9362
INFO 11-30 09:16:36 executor_base.py:115] Maximum concurrency for 4096 tokens per request: 25.27x
INFO 11-30 09:16:39 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utili

Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:15<00:00,  2.30it/s]

INFO 11-30 09:16:54 model_runner.py:1562] Graph capturing finished in 15 secs, took 0.21 GiB
INFO 11-30 09:16:54 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 19.01 seconds





### Algorithm 3 Group Relative Policy Optimization (GRPO)

**Input:** initial policy model $\pi_{\theta_{\text{init}}}$; reward function $R$; task questions $\mathcal{D}$

1.  policy model $\pi_\theta \leftarrow \pi_{\theta_{\text{init}}}$
2.  **for** step = 1, ..., `n_grpo_steps` **do**
3.  $\quad$ Sample a batch of questions $\mathcal{D}_b$ from $\mathcal{D}$
4.  $\quad$ Set the old policy model $\pi_{\theta_{\text{old}}} \leftarrow \pi_\theta$
5.  $\quad$ Sample $G$ outputs $\{o^{(i)}\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)$ for each question $q \in \mathcal{D}_b$
6.  $\quad$ Compute rewards $\{r^{(i)}\}_{i=1}^G$ for each sampled output $o^{(i)}$ by running reward function $R(q, o^{(i)})$
7.  $\quad$ Compute $A^{(i)}$ with group normalization (Eq. 28)
8.  $\quad$ **for** train step = 1, ..., `n_train_steps_per_rollout_batch` **do**
9.  $\quad\quad$ Update the policy model $\pi_\theta$ by maximizing the GRPO-Clip objective (to be discussed, Eq. 29)
10. $\quad$ **end for**
11. **end for**

**Output** $\pi_\theta$

In [None]:
from cs336_alignment.grpo import compute_log_probs
from cs336_alignment.sft import tokenize_prompt_and_output
from cs336_alignment.grpo import (
    compute_group_normalized_rewards,
    grpo_microbatch_train_step,
)
from cs336_alignment.drgrpo_grader import r1_zero_reward_fn
from datasets import load_dataset
from cs336_alignment.expert_iteration import load_policy_into_vllm_instance

# Questions to roll out on: the "original" configuration, train split, of the
# dataset named by cfg.eval_data.
rollout_dataset = load_dataset(cfg.eval_data, "original", split="train")
rollout_loader = torch.utils.data.DataLoader(
    rollout_dataset,
    batch_size=cfg.n_prompts_per_rollout_batch,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    # The training loop slices every rollout batch assuming it holds exactly
    # cfg.rollout_batch_size samples; a short final batch would break that, so
    # incomplete batches are dropped.
    drop_last=True,
)
# One request per prompt yields `group_size` completions (n=...). Generation
# stops at the closing answer tag, which is kept in the returned text.
rollout_sampling_params: SamplingParams = SamplingParams(
    temperature=cfg.sampling_temperature,
    top_p=1.0,
    max_tokens=cfg.sampling_max_tokens,
    min_tokens=cfg.sampling_min_tokens,
    stop=["</answer>"],
    include_stop_str_in_output=True,
    n=cfg.group_size,
)
data_iter = iter(rollout_loader)

# Clip range handed to grpo_microbatch_train_step (presumably only consulted when
# cfg.loss_type == "grpo_clip" — confirm in cs336_alignment.grpo).
CLIP_RANGE = 0.2
# Upper bound on the global gradient norm applied before every optimizer step.
MAX_GRAD_NORM = 1

for grpo_step in range(cfg.n_grpo_steps):
    # Push the current policy weights into the vLLM engine before sampling.
    load_policy_into_vllm_instance(policy, eval_model)

    # Start a fresh pass over the loader once it is exhausted, rather than
    # letting StopIteration abort the run.
    try:
        examples = next(data_iter)
    except StopIteration:
        data_iter = iter(rollout_loader)
        examples = next(data_iter)

    prompt_strs = [
        cfg.prompt_template.format(question=problem) for problem in examples["problem"]
    ]
    # Repeat every prompt / ground truth `group_size` times so that they line up
    # one-to-one with the flattened list of sampled completions.
    repeated_ground_truths = [
        extracted_solution
        for extracted_solution in examples["extracted_solution"]
        for _ in range(cfg.group_size)
    ]
    flat_prompt_strs = [
        prompt for prompt in prompt_strs for _ in range(cfg.group_size)
    ]
    rollout_outputs = eval_model.generate(prompt_strs, rollout_sampling_params)
    flat_response_strs = [
        completion.text
        for request_output in rollout_outputs
        for completion in request_output.outputs
    ]

    tokenized_dict = tokenize_prompt_and_output(
        flat_prompt_strs, flat_response_strs, tokenizer
    )
    prompts = tokenized_dict["input_ids"].to(cfg.device)
    responses = tokenized_dict["labels"].to(cfg.device)
    response_mask = tokenized_dict["response_mask"].to(cfg.device)
    # Size the microbatch loop from the tensors actually produced, so the slices
    # below can never run past the end of the batch.
    n_rollouts = prompts.shape[0]

    with torch.no_grad():
        # Log-probs under the policy that generated the rollouts.
        # NOTE(review): the meaning of the 4th positional argument (True) is
        # defined in cs336_alignment.grpo.compute_log_probs — confirm. This is a
        # single forward pass over the whole rollout batch; check memory headroom.
        old_log_probs = compute_log_probs(prompts, responses, policy, True)
        # Use the normalisation settings from cfg instead of hard-coded values,
        # so Config.advantage_eps / Config.use_std_normalization take effect.
        advantages, raw_rewards, _ = compute_group_normalized_rewards(
            reward_fn=r1_zero_reward_fn,
            rollout_responses=flat_response_strs,
            repeated_ground_truths=repeated_ground_truths,
            group_size=cfg.group_size,
            advantage_eps=cfg.advantage_eps,
            normalize_by_std=cfg.use_std_normalization,
        )
        # Add a trailing dimension of size 1 to both tensors.
        advantages = advantages.unsqueeze(-1).to(cfg.device)
        raw_rewards = raw_rewards.unsqueeze(-1).to(cfg.device)

    for epoch in range(cfg.epochs_per_rollout_batch):
        for mb_step, mb_start in enumerate(
            range(0, n_rollouts, cfg.mini_batch_size)
        ):
            mb_slice = slice(mb_start, mb_start + cfg.mini_batch_size)
            mb_prompts = prompts[mb_slice]
            mb_responses = responses[mb_slice]
            mb_response_mask = response_mask[mb_slice]
            mb_raw_rewards = raw_rewards[mb_slice]
            mb_advantages = advantages[mb_slice]
            mb_old_log_probs = old_log_probs[mb_slice]
            policy_log_probs = compute_log_probs(mb_prompts, mb_responses, policy)

            loss, _ = grpo_microbatch_train_step(
                policy_log_probs=policy_log_probs,
                response_mask=mb_response_mask,
                gradient_accumulation_steps=cfg.gradient_accumulation_steps,
                loss_type=cfg.loss_type,
                raw_rewards=mb_raw_rewards,
                advantages=mb_advantages,
                old_log_probs=mb_old_log_probs,
                cliprange=CLIP_RANGE,
            )

            # Apply an optimizer update once per `gradient_accumulation_steps`
            # microbatches.
            if (mb_step + 1) % cfg.gradient_accumulation_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    policy.parameters(), max_norm=MAX_GRAD_NORM
                )
                optimizer.step()
                optimizer.zero_grad()

Processed prompts:  12%|█▎        | 32/256 [00:11<01:23,  2.70it/s, est. speed input: 483.72 toks/s, output: 8029.39 toks/s]


range(0, 2)
torch.Size([2, 1])
range(2, 4)
torch.Size([2, 1])
range(4, 6)
torch.Size([2, 1])
range(6, 8)
torch.Size([2, 1])
range(8, 10)
torch.Size([2, 1])
range(10, 12)
torch.Size([2, 1])
range(12, 14)
torch.Size([2, 1])
range(14, 16)
torch.Size([2, 1])
range(16, 18)
torch.Size([2, 1])
range(18, 20)
torch.Size([2, 1])
range(20, 22)
torch.Size([2, 1])
range(22, 24)
torch.Size([2, 1])
range(24, 26)
torch.Size([2, 1])
range(26, 28)
torch.Size([2, 1])
range(28, 30)
torch.Size([2, 1])
range(30, 32)
torch.Size([2, 1])
range(32, 34)
torch.Size([2, 1])
range(34, 36)
torch.Size([2, 1])
range(36, 38)
torch.Size([2, 1])
range(38, 40)
torch.Size([2, 1])
range(40, 42)
torch.Size([2, 1])
range(42, 44)
torch.Size([2, 1])
range(44, 46)
torch.Size([2, 1])
range(46, 48)
torch.Size([2, 1])
range(48, 50)
torch.Size([2, 1])
range(50, 52)
torch.Size([2, 1])
range(52, 54)
torch.Size([2, 1])
range(54, 56)
torch.Size([2, 1])
range(56, 58)
torch.Size([2, 1])
range(58, 60)
torch.Size([2, 1])
range(60, 62)
torch

In [None]:
# Print the shape of each tensor left in the namespace by the training cell,
# then the number of repeated ground-truth answers.
inspected_tensors = (
    prompts,
    responses,
    response_mask,
    policy_log_probs,
    advantages,
    raw_rewards,
    loss,
)
for tensor_shape in (tensor.shape for tensor in inspected_tensors):
    print(tensor_shape)
print(len(repeated_ground_truths))

torch.Size([6, 1305])
torch.Size([6, 1305])
torch.Size([6, 1305])
torch.Size([6, 1305])
torch.Size([6, 1])
torch.Size([6, 1])
torch.Size([6, 1305])
6
