# Minimalist Reproduction of R1-Zero

## Prerequisites

In [15]:
import os
from pathlib import Path

SCRATCH = Path.home().joinpath("scratch")

# Put the Hugging Face home directory under scratch storage.
os.environ["HF_HOME"] = str(SCRATCH.joinpath("hf_home"))
# CUDA toolkit location on the Mila cluster (hardcoded for now).
os.environ["CUDA_HOME"] = "/cvmfs/ai.mila.quebec/apps/arch/common/cuda/12.5.1"
# os.environ["HF_TOKEN"] = "..." # Optional. Only needed for Llama models

# DeepSpeed complains unless the torch.distributed rendezvous variables are
# set, even for a single-process run, so describe a one-process "cluster".
os.environ.update(
    {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "23643",
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
    }
)

In [None]:
import argparse
import random
import re
import socket
import time
from typing import List

import deepspeed
import numpy as np
import sglang
import torch
import wandb
from datasets import load_dataset
from tqdm import trange
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import compute_token_log_probs

## Hyperparameters

In [3]:
# RL parameters
NUM_ITERATIONS = 1000  # number of generate-then-update rounds
EPISODES_PER_ITERATION = 64  # rollouts per round (prompts x generations)
GENERATIONS_PER_SAMPLE = 4  # rollouts per prompt; advantages are normalized per group
KL_COEFFICIENT = 0.001  # weight of the KL penalty toward the reference model

# Training hyperparameters
PER_DEVICE_BATCH_SIZE = 4
LEARNING_RATE = 1e-6

# Sampling parameters
MAX_RESPONSE_TOKENS = 1024
TEMPERATURE = 1.0
TOP_P = 1.0
TOP_K = -1  # no top k

# DeepSpeed configuration for the trainable policy. One optimizer step covers
# all EPISODES_PER_ITERATION episodes, accumulated over micro-batches of
# PER_DEVICE_BATCH_SIZE (DeepSpeed checks batch == micro * accum * world_size).
deepspeed_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2, "overlap_comm": False},
    "train_batch_size": EPISODES_PER_ITERATION,
    "train_micro_batch_size_per_gpu": PER_DEVICE_BATCH_SIZE,
    "gradient_accumulation_steps": EPISODES_PER_ITERATION // PER_DEVICE_BATCH_SIZE,
    "gradient_clipping": 1.0,
    "optimizer": {
        "type": "AdamW",
        "params": dict(
            lr=LEARNING_RATE,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.0,
            torch_adam=True,
        ),
    },
}

# The frozen reference model reuses the batch geometry but has no optimizer
# and no ZeRO section. Its "bf16" dict is a fresh object (not shared).
ref_deepspeed_config = {
    "bf16": {"enabled": True},
    **{
        key: deepspeed_config[key]
        for key in (
            "train_batch_size",
            "train_micro_batch_size_per_gpu",
            "gradient_accumulation_steps",
        )
    },
}

RUN_NAME = "r1-zero"

## Prompt and Dataset

In [4]:
MODEL_NAME = "Qwen/Qwen2.5-3B"  # base model that is trained and served
MODEL_CHAT_NAME = f"{MODEL_NAME}-Instruct"  # only its tokenizer/chat template is used

In [5]:
# System prompt, kept verbatim (including its grammar) so results stay
# comparable with runs that used exactly this text.
SYSTEM_MESSAGE = (
    "You are a helpful assistant. "
    "You first thinks about the reasoning process in the mind "
    "and then provides the user with the answer."
)
# Countdown task prompt; `numbers` and `target` are filled via str.format.
PROMPT_TEMPLATE = (
    "Using the numbers {numbers}, create an equation that equals {target}. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. "
    "Show your work in <think> </think> tags. "
    "And return the final equation and answer in <answer> </answer> tags, "
    "for example <answer>(1 + 2) / (3 * 5)</answer>."
)

In [6]:
# Dataset configuration
from typing import Any, Dict


# Number of examples held out as the test split (passed to train_test_split below).
TEST_SPLIT_SIZE = 500

# Load and process dataset
# The Instruct tokenizer is loaded for its chat template (used by
# preprocess_example via apply_chat_template).
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHAT_NAME)
# The EOS id comes from the *base* model's tokenizer. Presumably this is because
# the trained/served model is MODEL_NAME (base), whose EOS may differ from
# the Instruct model's — NOTE(review): confirm this is intentional.
EOS_TOKEN_ID = AutoTokenizer.from_pretrained(MODEL_NAME).eos_token_id
# String form of that EOS id; format_reward_func strips it from the end of completions.
EOS_TOKEN = tokenizer.convert_ids_to_tokens(EOS_TOKEN_ID)


def preprocess_example(example: Dict[str, Any]):
    """Turn one Countdown row into a tokenized, assistant-prefilled prompt.

    The chat is system + user turns followed by a partial assistant turn that
    already opens ``<think>``. ``continue_final_message=True`` leaves that turn
    open so generation continues it.

    Args:
        example: Dataset row with ``nums`` (list of ints) and ``target`` (int).

    Returns:
        Dict with ``prompt`` (decoded text, special tokens kept) and
        ``input_ids`` (the token ids produced by the chat template).
    """
    user_content = PROMPT_TEMPLATE.format(
        numbers=example["nums"], target=example["target"]
    )
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": user_content},
        # Prefilled assistant turn; the reward functions prepend "<think>"
        # to completions because it lives here, in the prompt.
        {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
    ]
    token_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, continue_final_message=True
    )
    return {
        "prompt": tokenizer.decode(
            token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        ),
        "input_ids": token_ids,
    }

# Seed for the train/test split so the held-out test set is identical across
# runs (otherwise evaluation numbers from different runs are not comparable).
SPLIT_SEED = 42

dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
# Adds "prompt" and "input_ids" columns (see preprocess_example).
dataset = dataset.map(preprocess_example, num_proc=6)

# Split dataset
train_test_split = dataset.train_test_split(test_size=TEST_SPLIT_SIZE, seed=SPLIT_SEED)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

# ... display some examples ...

## Reward Function


In [7]:
from typing import Tuple


def format_reward_func(completion: str, eos_token: "str | None" = None) -> float:
    """
    Score whether a completion follows ``<think>...</think>\\n<answer>...</answer>``.

    A synthetic ``<think>`` is prepended first, because the prompt already
    prefills it for the assistant. The answer content is also checked against a
    whitelist: only digits, ``+ - * / ( ) .`` and whitespace.

    Args:
        completion (str): Generated output (without the prefilled ``<think>``).
        eos_token (str | None): End-of-sequence token string to strip from the
            end of the completion. Defaults to the module-level ``EOS_TOKEN``.
            An empty string disables stripping.

    Returns:
        float: 1.0 if the format and answer charset are both correct, 0.5 if
        the format is correct but the answer contains other characters, and
        0.0 otherwise (including on any unexpected error).
    """
    # Define the allowed pattern (only numbers, +, -, *, /, (, ), ., and whitespace)
    allowed_pattern = r"^[\d+\-*/().\s]+$"

    try:
        # Resolved inside the try so a missing global keeps the original
        # "any error -> 0 reward" behaviour.
        if eos_token is None:
            eos_token = EOS_TOKEN

        # Synthetically prepend <think> (the prompt prefills it for the assistant)
        completion = "<think>" + completion

        # Strip EOS token if present. Guard against an empty token: with "",
        # completion[:-0] would wipe the whole completion.
        if eos_token and completion.endswith(eos_token):
            completion = completion[: -len(eos_token)]

        # Pattern means:
        # 1) <think>...contents not including other <think>/</think> tags...</think>
        # 2) \n
        # 3) <answer>...anything, newlines included...</answer> at the very end
        regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
        match = re.search(regex, completion)
        if match is None:
            # Format is incorrect
            return 0.0

        # Extract the content inside <answer>...</answer>
        answer_content = match.group(2).strip()

        # Correct format; full reward only if the answer uses allowed characters.
        return 1.0 if re.match(allowed_pattern, answer_content) else 0.5
    except Exception:
        # Any error leads to 0 reward
        return 0.0


def equation_reward_func(completion: str, nums: List[int], target: int) -> float:
    """
    Evaluate the mathematical correctness of the ``<answer>`` equation.

    The first ``<answer>...</answer>`` span is taken, newlines included. This
    matches format_reward_func, which also accepts multi-line answers. The
    equation earns 1.0 only when all of these hold:

    * it uses exactly the numbers in ``nums``, each once;
    * it contains only digits, ``+ - * / ( ) .`` and whitespace;
    * it contains no ``**`` (exponentiation);
    * it evaluates to ``target`` within 1e-5.

    Args:
        completion (str): Generated output
        nums (list[int]): Available numbers to use in the equation
        target (int | float): Expected value of the equation

    Returns:
        float: 1.0 if correct, else 0.0 (including on any evaluation error).
    """
    try:
        # Extract the "answer" part from the completion
        match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
        if match is None:
            return 0.0
        equation = match.group(1).strip()

        # Check if all numbers are used exactly once
        used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
        if sorted(used_numbers) != sorted(nums):
            return 0.0

        # Only numbers, operators, parentheses, and whitespace are allowed
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation):
            return 0.0

        # The whitelist admits "**". Exponentiation is not a permitted operation,
        # and something like "44**19**35" would make eval() hang or exhaust memory.
        if "**" in equation:
            return 0.0

        # SECURITY: eval() runs model-generated text. It is only reached after
        # the character whitelist and "**" check above, with builtins disabled.
        result = eval(equation, {"__builtins__": None}, {})
        return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0
    except Exception:
        # If evaluation fails (syntax error, division by zero, ...), reward is 0
        return 0.0
    

def compute_reward(completion: str, sample: Dict[str, Any]) -> Tuple[float, Dict[str, float]]:
    """Total reward for one completion: format reward plus equation reward.

    Args:
        completion: Generated text (without the prefilled ``<think>``).
        sample: Dataset row providing ``nums`` and ``target``.

    Returns:
        ``(reward, stats)``. ``reward`` is the sum of the two components, and
        ``stats`` maps ``"format_reward"`` / ``"equation_reward"`` to each one.
    """
    components = {
        "format_reward": format_reward_func(completion),
        "equation_reward": equation_reward_func(
            completion=completion, nums=sample["nums"], target=sample["target"]
        ),
    }
    total = components["format_reward"] + components["equation_reward"]
    return total, components


In [8]:
def unit_test_compute_reward():
    """Check compute_reward against hand-computed expected rewards.

    Completions omit the leading ``<think>`` because the reward functions
    prepend it. Case 2 therefore has a doubled ``<think>``, which fails the
    format check. The original version compared compute_reward with itself
    and could never fail; a determinism check is kept as well.
    """
    test_cases = [
        # (completion, sample, expected total reward, expected component stats)
        ("hello</think>\n<answer>1+2+3+4</answer>", {"nums": [1, 2, 3, 4], "target": 10},
         2.0, {"format_reward": 1.0, "equation_reward": 1.0}),
        ("<think>hello</think>\n<answer>1+2+3+4</answer>", {"nums": [1, 2, 3, 4], "target": 10},
         1.0, {"format_reward": 0.0, "equation_reward": 1.0}),
        # 3 is used twice -> wrong number multiset
        ("hello</think>\n<answer>1+2+3+3</answer>", {"nums": [1, 2, 3], "target": 9},
         1.0, {"format_reward": 1.0, "equation_reward": 0.0}),
        # uses 5, which is not available
        ("hello</think>\n<answer>5+6+4+3-1</answer>", {"nums": [1, 3, 4, 6], "target": 17},
         1.0, {"format_reward": 1.0, "equation_reward": 0.0}),
        ("hello</think>\n<answer>(3-1)*(9+7)</answer>", {"nums": [1, 3, 9, 7], "target": 32},
         2.0, {"format_reward": 1.0, "equation_reward": 1.0}),
    ]

    for completion, sample, expected_reward, expected_stats in test_cases:
        reward, stats = compute_reward(completion, sample)
        assert reward == expected_reward, f"{completion!r}: reward {reward} != {expected_reward}"
        assert stats == expected_stats, f"{completion!r}: stats {stats} != {expected_stats}"
        assert compute_reward(completion, sample) == (reward, stats), "compute_reward is not deterministic"

unit_test_compute_reward()

## Episode Generation

In [9]:
def create_training_episodes(
    samples: List[Dict[str, Any]],
    all_generations: List[List[int]],
    all_finish_reasons: List[str],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Process model generations and calculate rewards for training episodes.

    Generations must be laid out in consecutive groups of
    GENERATIONS_PER_SAMPLE per input sample, in sample order. The function:

    1. Decodes each group and scores every response with ``compute_reward``.
    2. Normalizes rewards within the group into advantages:
       ``(r - mean) / (std + 1e-4)``.
    3. Appends EOS_TOKEN_ID to responses whose finish reason is ``"stop"``.
    4. Repeats each response's advantage once per token of the *final*
       response, so advantage and token lists always have equal length.

    Args:
        samples: Input samples, each containing:
            - input_ids: List[int], tokenized input prompt
            - nums: List[int], numbers to use in equation
            - target: int, target value for equation
        all_generations: Token ID sequences, one per generated response
        all_finish_reasons: Finish reason per generation ("stop" or other)

    Returns:
        Tuple containing:
        1. Dictionary with processed data for training:
            - all_query_token_ids: List[List[int]], input token IDs repeated for each generation
            - all_response_token_ids: List[List[int]], response token IDs with EOS added on "stop"
            - all_advantages: List[List[float]], per-token advantages (same lengths as responses)
        2. Dictionary with generation statistics (one entry per generation):
            - response_lengths: List[int], lengths of the final response token lists
            - rewards: List[float], total reward values
            - non_stop_rate: List[bool], True where the generation did NOT finish with "stop"
            - reward_metrics/<name>: List[float], each reward component from compute_reward

    Example (assuming GENERATIONS_PER_SAMPLE == 3; advantages are illustrative):
        samples = [{"input_ids": [1,2,3], "nums": [1,2,3], "target": 6}]
        generations = [[4,5], [6,7], [8,9]]
        finish_reasons = ["stop", "length", "stop"]
        episodes ->
        {
            'all_query_token_ids': [[1,2,3], [1,2,3], [1,2,3]],
            'all_response_token_ids': [[4,5,EOS], [6,7], [8,9,EOS]],
            'all_advantages': [[0.5,0.5,0.5], [-1.0,-1.0], [0.5,0.5,0.5]]
        }
    """
    assert len(all_generations) == len(all_finish_reasons)
    assert len(all_generations) == len(samples) * GENERATIONS_PER_SAMPLE

    # Indices of each sample's generations, e.g. [[0, 1, 2], [3, 4, 5], ...]
    groups = [
        list(range(i, i + GENERATIONS_PER_SAMPLE))
        for i in range(0, len(all_generations), GENERATIONS_PER_SAMPLE)
    ]

    all_query_token_ids, all_responses_token_ids, all_advantages = [], [], []

    stats = {
        "response_lengths": [],
        "rewards": [],
        "non_stop_rate": [],
    }

    for sample, group_indices in zip(samples, groups):
        response_token_ids = [list(all_generations[i]) for i in group_indices]
        finish_reasons = [all_finish_reasons[i] for i in group_indices]

        # Score the raw generations (before EOS is appended).
        responses = tokenizer.batch_decode(
            response_token_ids, skip_special_tokens=False
        )
        rewards_and_metrics = [
            compute_reward(response, sample) for response in responses
        ]
        rewards, reward_metrics = zip(*rewards_and_metrics)

        # Group-relative advantage; the epsilon keeps all-equal groups finite (0).
        rewards = np.array(rewards)
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)

        response_token_ids = [
            (r + [EOS_TOKEN_ID]) if fr == "stop" else r
            for r, fr in zip(response_token_ids, finish_reasons)
        ]
        # Built AFTER appending EOS so the EOS token also gets an advantage and
        # the advantage list lines up with the response tokens.
        per_token_advantages = [
            [float(adv)] * len(resp)
            for adv, resp in zip(advantages, response_token_ids)
        ]

        all_query_token_ids.extend([sample["input_ids"]] * len(group_indices))
        all_responses_token_ids.extend(response_token_ids)
        all_advantages.extend(per_token_advantages)

        stats["rewards"].extend(rewards.tolist())
        stats["non_stop_rate"].extend([fr != "stop" for fr in finish_reasons])
        stats["response_lengths"].extend([len(ids) for ids in response_token_ids])
        for rm in reward_metrics:
            for k, v in rm.items():
                stats.setdefault(f"reward_metrics/{k}", []).append(v)

    episodes = {
        "all_query_token_ids": all_query_token_ids,
        "all_response_token_ids": all_responses_token_ids,
        "all_advantages": all_advantages,
    }

    return episodes, stats

In [39]:
def unit_test_create_training_episodes():
    """Check the structural contract of create_training_episodes.

    The exact advantages depend on how the tokenizer decodes these toy token
    ids, so this checks invariants instead: queries repeated per generation,
    EOS appended exactly on "stop", one entry per generation, correct
    non-stop flags and response lengths. It also checks that repeated calls
    agree. The original compared the function with itself and could never fail.
    """
    test_cases = [
        {"sample": {"input_ids": [1,2,3], "nums": [1,2,3], "target": 6},
         "generations": [[4,5], [6,7], [8,9], [10,11]],
         "finish_reasons": ["stop", "length", "stop", "stop"]},

        {"sample": {"input_ids": [33,44], "nums": [11, 7, 8], "target": 26},
         "generations": [[1,2], [3,4], [5,6], [7,8]],
         "finish_reasons": ["stop", "stop", "length", "stop"]},

        {"sample": {"input_ids": [9, 8, 7, 6, 5, 4], "nums": [1,2,3,4], "target": 10},
         "generations": [[9,10], [11,12], [13,14], [15,16]],
         "finish_reasons": ["length", "length", "stop", "stop"]}
    ]

    for case in test_cases:
        sample = case["sample"]
        generations = case["generations"]
        finish_reasons = case["finish_reasons"]
        num_generations = len(generations)

        episodes, stats = create_training_episodes([sample], generations, finish_reasons)
        episodes_again, stats_again = create_training_episodes([sample], generations, finish_reasons)
        assert episodes == episodes_again, f"Non-deterministic episodes for sample: {sample}"
        assert stats == stats_again, f"Non-deterministic stats for sample: {sample}"

        expected_responses = [
            list(g) + [EOS_TOKEN_ID] if fr == "stop" else list(g)
            for g, fr in zip(generations, finish_reasons)
        ]
        assert episodes["all_query_token_ids"] == [sample["input_ids"]] * num_generations
        assert episodes["all_response_token_ids"] == expected_responses
        assert len(episodes["all_advantages"]) == num_generations
        assert stats["non_stop_rate"] == [fr != "stop" for fr in finish_reasons]
        assert stats["response_lengths"] == [len(r) for r in expected_responses]

unit_test_create_training_episodes()


## Policy Gradient

In [12]:
from typing import Union

from deepspeed import DeepSpeedEngine
from transformers import PreTrainedModel

def compute_pg_loss(
    policy_model: Union[DeepSpeedEngine, PreTrainedModel],
    reference_model: Union[DeepSpeedEngine, PreTrainedModel],
    batch: Dict[str, torch.Tensor],
    total_response_len: int,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute the policy gradient loss with KL penalty between policy and reference models.

    This function:
    1. Computes per-token log probabilities for the policy (with grad) and the
       reference model (under ``torch.no_grad``).
    2. Computes the per-token KL penalty with the non-negative k3 estimator
       ``exp(ref - logp) - (ref - logp) - 1``.
    3. Computes the policy gradient term ``-logp * advantage``.
    4. Masks both to label positions, sums them (KL scaled by
       KL_COEFFICIENT), and divides by ``total_response_len``.

    Args:
        policy_model: The model being trained
        reference_model: The reference model for KL penalty calculation
        batch: Dictionary containing:
            - input_ids: Tensor of shape [batch_size, seq_len]
            - attention_mask: Tensor of shape [batch_size, seq_len]
            - labels: Tensor of shape [batch_size, seq_len] with -100 for ignored positions
            - advantages: Tensor of shape [batch_size, seq_len]
        total_response_len: Number of response (non -100 label) tokens in the
            *whole* optimizer batch, not just this micro-batch. Summing the
            per-micro-batch losses then yields a token-level mean over the
            full batch.

    Returns:
        Tuple containing:
            - loss: Combined policy gradient and KL penalty loss (scalar tensor)
            - metrics: Dictionary with detailed loss components:
                - policy_loss: Sum of the masked policy-gradient terms (unnormalized)
                - kl_penalty: Sum of the masked KL terms (unnormalized)
                - entropy: Mean negative log-prob of the label tokens in this
                  micro-batch (a sample-based entropy estimate)
    """
    input_ids = batch["input_ids"]  # [batch_size, seq_len]
    attention_mask = batch["attention_mask"]  # [batch_size, seq_len]
    labels = batch["labels"]  # [batch_size, seq_len]
    advantages = batch["advantages"]  # [batch_size, seq_len]

    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

    # Token t's log-prob is predicted from position t-1, hence the shift by one.
    # Computed before the helper calls so the mask only depends on the labels
    # as passed in.
    labels_mask = (labels[..., 1:] != -100).float()  # [batch_size, seq_len-1]
    # Guard against a micro-batch with no label tokens (avoids 0/0 = NaN).
    num_label_tokens = labels_mask.sum().clamp(min=1.0)

    with torch.no_grad():
        ref_logps = compute_token_log_probs(
            reference_model, model_inputs, TEMPERATURE
        )  # [batch_size, seq_len-1]

    logps = compute_token_log_probs(policy_model, model_inputs, TEMPERATURE)  # [batch_size, seq_len-1]

    # k3 KL estimator (always >= 0 analytically), restricted to label tokens.
    log_ratio = ref_logps - logps
    kl_penalty = torch.exp(log_ratio) - log_ratio - 1  # [batch_size, seq_len-1]
    kl_penalty = kl_penalty * labels_mask  # [batch_size, seq_len-1]

    # Masked explicitly. This is a no-op if compute_token_log_probs already
    # zeroes ignored positions, but it keeps prompt/padding out of the metric otherwise.
    entropy = -(logps * labels_mask).sum() / num_label_tokens  # scalar

    policy_loss = -logps * advantages[..., 1:]  # [batch_size, seq_len-1]
    policy_loss = policy_loss * labels_mask  # [batch_size, seq_len-1]

    loss = (policy_loss + KL_COEFFICIENT * kl_penalty).sum() / total_response_len  # scalar

    metrics = {
        "policy_loss": policy_loss.sum().item(),
        "kl_penalty": kl_penalty.sum().item(),
        "entropy": entropy.item(),
    }

    return loss, metrics

## Training

In [None]:
# Initialize main and reference models
# Two independent copies of the same base checkpoint, both loaded in bf16
# with FlashAttention-2 onto GPU 0. The policy is trained; the reference
# stays fixed and is only used for the KL penalty.
policy_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)
reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)
# Trade compute for activation memory on the model that needs gradients.
policy_model.gradient_checkpointing_enable()


# Initialize DeepSpeed engines
# Policy: ZeRO-2 + AdamW from deepspeed_config; parameters passed so
# DeepSpeed builds the optimizer over them.
policy_model, *_ = deepspeed.initialize(
    model=policy_model,
    config=deepspeed_config,
    model_parameters=policy_model.parameters(),
)
# Reference: ref_deepspeed_config has no optimizer section and no
# model_parameters are passed, so this engine is used for forward passes only.
reference_model, *_ = deepspeed.initialize(
    model=reference_model,
    config=ref_deepspeed_config,
)

# Park the reference weights on CPU; they are moved to the GPU only while
# computing the loss (order matters: this must run after engine creation).
reference_model.module.cpu()

# Initialize SGLang (Inference) engine
# Serves MODEL_NAME for rollouts. The training loop later pushes the policy's
# weights into it and toggles its memory with resume/release_memory_occupation.
inference_engine = sglang.Engine(
    model_path=MODEL_NAME,
    enable_memory_saver=True,  # required for release/resume_memory_occupation
    skip_tokenizer_init=True,  # prompts are passed as token ids, outputs read as ids
    mem_fraction_static=0.20,  # presumably the GPU-memory fraction SGLang reserves — NOTE(review): confirm
    schedule_policy="fcfs",
    schedule_conservativeness=0.001,
    max_running_requests=10000,
)

In [23]:
def unit_test_compute_pg_loss():
    """Smoke-test compute_pg_loss on a tiny hand-built batch.

    Exact values depend on the model weights, so this checks the following:
    the loss and every metric are finite, the metric keys match the
    documented contract, and two identical calls return identical results.
    The original compared the function with itself and could only detect
    non-determinism.

    The reference model is moved to the GPU for the test and back to the CPU
    afterwards, even if an assertion fails.
    """
    import math

    reference_model.module.cuda()
    try:
        test_case = {
            "labels": torch.tensor([[0, 0, 0, 1, 1, 1, 0, 0]], device="cuda"),
            "input_ids": torch.tensor([[1, 2, 3, 4, 5, 6, 10, 10]], device="cuda"),
            "advantages": torch.tensor([[0.1, -0.5, +0.25, 0.3, 0.1, 0.1, 0.0, 0.0]], device="cuda"),
            "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]], device="cuda"),
        }

        loss_first, metrics_first = compute_pg_loss(policy_model, reference_model, test_case, total_response_len=8)
        loss_first = loss_first.item()

        loss_second, metrics_second = compute_pg_loss(policy_model, reference_model, test_case, total_response_len=8)
        loss_second = loss_second.item()

        assert math.isfinite(loss_first), f"Non-finite loss: {loss_first}"
        assert set(metrics_first) == {"policy_loss", "kl_penalty", "entropy"}, f"Unexpected metrics: {metrics_first}"
        for k, v in metrics_first.items():
            assert math.isfinite(v), f"Non-finite {k}: {v}"

        assert loss_first == loss_second, f"Mismatch in loss: {loss_first} != {loss_second}"
        for k in metrics_first:
            assert metrics_first[k] == metrics_second[k], f"Mismatch in {k}: {metrics_first[k]} != {metrics_second[k]}"
    finally:
        reference_model.module.cpu()

unit_test_compute_pg_loss()

In [None]:
# Make sure the reference weights are back on CPU (unit_test_compute_pg_loss
# moves them to the GPU) so the GPU is free for SGLang and the policy.
reference_model.module.cpu()

In [None]:
# Wandb for logging
wandb.init(
    project="r1-aha-moment",
    name=RUN_NAME,
    config={
        "model_name": MODEL_NAME,
        "learning_rate": LEARNING_RATE,
        "num_iterations": NUM_ITERATIONS,
        "episodes_per_iteration": EPISODES_PER_ITERATION,
        "rollouts_per_episode": GENERATIONS_PER_SAMPLE,
        "kl_coefficient": KL_COEFFICIENT,
        "temperature": TEMPERATURE,
    },
)

### Training loop

In [None]:
# Sampling parameters passed to SGLang for every generate() call.
# NOTE(review): keys follow SGLang's sampling-params API; confirm against the installed version.
SAMPLING_PARAMS = {
    "temperature": TEMPERATURE,
    "top_p": TOP_P,
    "top_k": TOP_K,
    "max_new_tokens": MAX_RESPONSE_TOKENS,
}
EVAL_EVERY = 25  # run the held-out evaluation every this many iterations


def _generate(prompt_token_ids):
    """Generate one completion per prompt; return (token_ids, finish_reasons) in prompt order."""
    outputs = inference_engine.generate(
        input_ids=prompt_token_ids, sampling_params=SAMPLING_PARAMS
    )
    # NOTE(review): assumes SGLang's output schema with skip_tokenizer_init=True
    # ("output_ids" and meta_info.finish_reason.type); verify for your version.
    token_ids = [list(out["output_ids"]) for out in outputs]
    finish_reasons = [out["meta_info"]["finish_reason"]["type"] for out in outputs]
    return token_ids, finish_reasons


def _prepare_model_inputs(query_token_ids, response_token_ids, advantages, device="cuda"):
    """Right-pad query+response sequences into training tensors.

    Returns ``input_ids``, ``attention_mask``, ``labels`` (-100 on query and
    padding positions) and per-position ``advantages`` (0 outside the
    response), each of shape [num_episodes, max_seq_len].
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = EOS_TOKEN_ID
    max_len = max(len(q) + len(r) for q, r in zip(query_token_ids, response_token_ids))

    input_ids, attention_mask, labels, advantage_rows = [], [], [], []
    for query, response, adv in zip(query_token_ids, response_token_ids, advantages):
        query, response = list(query), list(response)
        num_pad = max_len - len(query) - len(response)
        # create_training_episodes repeats one group-normalized scalar per
        # token, so broadcast it over the full response length.
        adv_value = float(adv[0]) if len(adv) > 0 else 0.0
        input_ids.append(query + response + [pad_token_id] * num_pad)
        attention_mask.append([1] * (len(query) + len(response)) + [0] * num_pad)
        labels.append([-100] * len(query) + response + [-100] * num_pad)
        advantage_rows.append([0.0] * len(query) + [adv_value] * len(response) + [0.0] * num_pad)

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long, device=device),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long, device=device),
        "labels": torch.tensor(labels, dtype=torch.long, device=device),
        "advantages": torch.tensor(advantage_rows, dtype=torch.float32, device=device),
    }


def _evaluate_on_test_set():
    """Generate one response per held-out prompt and return averaged ``test/*`` reward metrics."""
    test_samples = list(test_dataset)
    token_ids, _ = _generate([list(s["input_ids"]) for s in test_samples])
    responses = tokenizer.batch_decode(token_ids, skip_special_tokens=False)
    rewards, format_rewards, equation_rewards = [], [], []
    for response, sample in zip(responses, test_samples):
        reward, reward_stats = compute_reward(response, sample)
        rewards.append(reward)
        format_rewards.append(reward_stats["format_reward"])
        equation_rewards.append(reward_stats["equation_reward"])
    return {
        "test/reward_mean": float(np.mean(rewards)),
        "test/format_reward/mean": float(np.mean(format_rewards)),
        "test/equation_reward/mean": float(np.mean(equation_rewards)),
    }


for iteration in trange(NUM_ITERATIONS):
    iteration_start_time = time.time()
    print(f"Iteration {iteration}/{NUM_ITERATIONS}")

    # Sample training batch: EPISODES_PER_ITERATION rollouts in total.
    num_samples = EPISODES_PER_ITERATION // GENERATIONS_PER_SAMPLE
    batch_indices = np.random.choice(
        len(train_dataset), size=num_samples, replace=False
    )
    batch_samples = train_dataset.select(batch_indices)

    # Update model weights in SGLang engine
    torch.cuda.empty_cache()
    inference_engine.resume_memory_occupation()
    success, error = inference_engine.update_weights_from_tensor(
        list(policy_model.module.named_parameters())
    )
    if not success:
        raise RuntimeError(f"Weight update failed: {error}")

    eval_stats = None
    if iteration % EVAL_EVERY == 0:
        eval_stats = _evaluate_on_test_set()
        time.sleep(2)  # so sglang scheduler cools down

    # Generate responses. Each prompt is repeated GENERATIONS_PER_SAMPLE times
    # consecutively, which is the grouping create_training_episodes expects.
    generation_start_time = time.time()
    prompts = [
        list(ids)
        for ids in batch_samples["input_ids"]
        for _ in range(GENERATIONS_PER_SAMPLE)
    ]
    all_generations, all_finish_reasons = _generate(prompts)
    print(f"Generated {len(all_generations)} responses")
    inference_engine.release_memory_occupation()
    time.sleep(1)  # WARNING: hacky, to make sure the memory is released before training

    generation_time = time.time() - generation_start_time
    training_start_time = time.time()

    # Process responses and calculate rewards
    episodes, episode_stats = create_training_episodes(
        list(batch_samples), all_generations, all_finish_reasons
    )

    # Prepare training batch
    model_inputs = _prepare_model_inputs(
        episodes["all_query_token_ids"],
        episodes["all_response_token_ids"],
        episodes["all_advantages"],
        device="cuda",
    )

    # Calculate losses and update model
    policy_model.train()
    reference_model.module.cuda()
    reference_model.eval()

    # Response tokens in the whole optimizer batch; each micro-batch loss is
    # divided by this, so the summed gradients form a token-level mean.
    total_response_len = int((model_inputs["labels"] != -100).sum().item())

    # Track metrics
    total_policy_loss = 0.0
    total_kl_penalty = 0.0
    total_entropy = 0.0
    num_micro_batches = 0
    grad_norm = 0

    for i in range(0, EPISODES_PER_ITERATION, PER_DEVICE_BATCH_SIZE):
        print(f"Processing batch {i}/{EPISODES_PER_ITERATION}")
        batch = {k: v[i : i + PER_DEVICE_BATCH_SIZE] for k, v in model_inputs.items()}

        # Compute policy gradient loss
        loss, metrics = compute_pg_loss(
            policy_model=policy_model,
            reference_model=reference_model,
            batch=batch,
            total_response_len=total_response_len,
        )

        # Summing sum/total_response_len over micro-batches gives the
        # batch-level token mean; entropy is a per-micro-batch mean, averaged later.
        total_policy_loss += metrics["policy_loss"] / total_response_len
        total_kl_penalty += metrics["kl_penalty"] / total_response_len
        total_entropy += metrics["entropy"]
        num_micro_batches += 1

        # Backpropagation and optimization step. The loss is already
        # normalized over the full batch, so DeepSpeed must not rescale by GAS.
        policy_model.backward(loss, scale_wrt_gas=False)
        at_boundary = policy_model.is_gradient_accumulation_boundary()
        if at_boundary:
            # The reference model is not needed again until the next iteration.
            reference_model.module.cpu()

        # DeepSpeed only applies the optimizer update at the accumulation boundary.
        policy_model.step()
        if at_boundary:
            # Read after the real optimizer step so it reflects this iteration.
            grad_norm = policy_model.get_global_grad_norm()

    print("Finished training")

    training_time = time.time() - training_start_time
    total_iteration_time = time.time() - iteration_start_time

    # Aggregate generation/reward statistics (one entry per episode).
    format_rewards = np.asarray(episode_stats["reward_metrics/format_reward"], dtype=float)
    equation_rewards = np.asarray(episode_stats["reward_metrics/equation_reward"], dtype=float)
    total_rewards = format_rewards + equation_rewards
    response_lengths = np.asarray(episode_stats["response_lengths"])

    # Log metrics to wandb
    stats = {
        "iteration": iteration,
        # Generation quality metrics
        "train/generation/non_stop_rate": float(np.mean(episode_stats["non_stop_rate"])),
        "train/generation/mean_response_length": float(response_lengths.mean()),
        "train/generation/max_response_length": int(response_lengths.max()),
        # Overall reward metrics
        "train/reward_mean": float(total_rewards.mean()),
        "train/reward_std": float(total_rewards.std()),
        # Format reward metrics
        "train/format_reward/mean": float(format_rewards.mean()),
        "train/format_reward/std": float(format_rewards.std()),
        # Equation reward metrics
        "train/equation_reward/mean": float(equation_rewards.mean()),
        "train/equation_reward/std": float(equation_rewards.std()),
        # Training metrics (policy/KL are already batch-level token means)
        "train/policy_loss": total_policy_loss,
        "train/kl_penalty": total_kl_penalty,
        "train/entropy": total_entropy / max(num_micro_batches, 1),
        "train/grad_norm": grad_norm,
        "train/learning_rate": policy_model.get_lr()[0],
        # Timing metrics
        "train/generation_time": generation_time,
        "train/training_time": training_time,
        "train/total_iteration_time": total_iteration_time,
    }

    if eval_stats is not None:
        stats.update(eval_stats)

    wandb.log(stats)

    selected_keys = [
        "train/reward_mean",
        "train/format_reward/mean",
        "train/equation_reward/mean",
    ]
    if eval_stats is not None:
        selected_keys.extend(
            [
                "test/reward_mean",
                "test/format_reward/mean",
                "test/equation_reward/mean",
            ]
        )
    selected_stats = {k: stats[k] for k in selected_keys}
    print(f"key stats: {selected_stats}")

    # NOTE(review): with NUM_ITERATIONS = 1000 this only fires at iteration 0
    # (the untrained model), and the path is user-specific; confirm the
    # intended checkpoint interval and location.
    if iteration % 1001 == 0:
        save_dir = (
            f"/network/scratch/a/aghajohm/aha_models/r1_aha_moment_{iteration}"
        )
        policy_model.module.save_pretrained(save_dir)