# Trace rollouts with Weave to debug faster and reduce regressions

_Authored by: [Aakash Kumar Nain](...)_


This notebook extends the [TRL GRPO tutorial](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl). We add `wandb` for logging, `Weave` for tracing, and a few hyperparameter tweaks to show how sensitive RL is to them. <br>

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/rl_examples/blob/main/trl_examples/numinamath_grpo.ipynb)


### Why trace rollouts?
Compared to plain SFT, fine-tuning models with RL is notoriously hard. On the surface the two look similar, especially if you use a library like `TRL` for both SFT and RL on LLMs, but RL comes with many nuances. For example:

1. Is the reward trajectory a good enough signal to track the progress of a training run? How do you tell whether a rising reward reflects genuine progress or reward hacking?
2. What is the best way to define a grader? Is a static grader enough for grading your rollouts?
3. Is the model following the constraints? For example, say we want the model to always generate the reasoning tokens within `<think>...</think>` tags and the answer within `<answer>...</answer>` tags. As training progresses, how do we know if the model is following the right format? What if it follows one format correctly but ignores the other one completely? How do we debug that during the training run?
4. There is a fair chance that the model can generate structurally correct responses, yet the semantics of those responses are completely wrong. How far into the training run do we go before we decide to stop and start over, maybe with some changes?
5. If the reward keeps increasing, does that mean your responses are improving? How do you compare the completions for the same samples at different training steps?

### Weave: The value proposition for RL
Weave can do a lot more, but for RL it offers an easy way to answer the questions above. The screenshot below shows what tracing with Weave looks like: you can see both the raw and the rendered responses. This tells you how well-formed the model completions are as training progresses and makes it easy to spot what the model is missing or overemphasizing. <br>

<img src="https://raw.githubusercontent.com/wandb/rl_examples//main/assets/weave_rollouts.png" width="70%"><br><br>



You can also compare rollouts side by side, as shown below. Comparing rollouts makes it easy to detect whether the model is following the correct format, is on the brink of mode collapse, or is heading toward reward hacking. <br>

<img src="https://raw.githubusercontent.com/wandb/rl_examples//main/assets/weave_compare_rollouts.png" width="70%">


Without further ado, let's take a look at the code.

In [None]:
# !pip install -U -q trl peft math_verify wandb weave

In [None]:
import os

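# Do not forget to set your wandb API key: paste it below, or export
# WANDB_API_KEY in your environment before launching the notebook
WANDB_API_KEY = ""
if WANDB_API_KEY:
    # Only overwrite the environment variable when a key is provided here
    os.environ["WANDB_API_KEY"] = WANDB_API_KEY
if not os.environ.get("WANDB_API_KEY"):
    raise ValueError("WANDB_API_KEY not found!")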

# optional: set the project name
os.environ["WANDB_PROJECT"] = "trl_grpo_numinamath_example"

import wandb
wandb.login()


import re
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer, WeaveCallback
from math_verify import LatexExtractionConfig, parse, verify

### 1. Dataset
For demonstration purposes, we use only a fraction of each split in this notebook: the first 5% of the training split and the first 25% of the test split.

In [None]:
dataset_id = 'AI-MO/NuminaMath-TIR'
train_dataset, eval_dataset = load_dataset(dataset_id, split=['train[:5%]', 'test[:25%]'])
print(f"Number of samples in the training dataset: {len(train_dataset)}")
print(f"Number of samples in the evaluation dataset: {len(eval_dataset)}")
print(f"Column names: {train_dataset.column_names}")
print("\n======== Sample data point ============")
sample = train_dataset[0]
# Single quotes inside the f-string keep this valid on Python < 3.12
print(f"\nProblem:\n{sample['problem']}")
print(f"\nSolution:\n{sample['solution']}")

### 1.2 System Prompt
We need to give the model clear instructions for generating responses. For example, we must tell it to wrap the thinking process and the final answer in special tags. Let’s lay out those details in the prompt. You can try modifying the prompt to provide better instructions.

In [None]:
SYSTEM_PROMPT = r"""You are an expert math solver. Solve the given question step by step. Strictly follow the format given below:

- Use <think>...</think> tags for the thinking process. All the intermediate steps for solving the question go between these tags.
- Use <answer>...</answer> tags for the answer. Answer should be strictly one word.
"""

### 1.3 Modify the dataset
We prepend the system prompt to every problem in the dataset and convert the dataset into the conversational format.

In [None]:
def make_conversation(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }

train_dataset = train_dataset.map(make_conversation)
eval_dataset = eval_dataset.map(make_conversation)

# Optional: Remove unused columns from the training dataset
train_dataset = train_dataset.remove_columns(['messages', 'problem'])

# Check the samples again
print("Sample data point")
sample = train_dataset.take(1)[0]
# prompt[0] is the system message; prompt[1] is the user's problem
print(f"\nProblem:\n{sample['prompt'][1]['content']}")
print("-"*50)
print(f"\nSolution:\n{sample['solution']}")
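We keep the `solution` column because the accuracy reward defined later reads it. `GRPOTrainer` passes the other dataset columns (such as `solution`) to reward functions as keyword arguments. The snippet below is a simplified mimic of that mechanism, not TRL's actual code: `call_reward_func` and `toy_reward` are hypothetical helpers, and the real trainer also repeats each example once per generation.

```python
def toy_reward(completions, **kwargs):
    """Hypothetical reward: 1.0 if the completion text equals the solution."""
    solutions = kwargs["solution"]  # a list aligned with `completions`
    return [
        1.0 if completion[0]["content"].strip() == solution else 0.0
        for completion, solution in zip(completions, solutions)
    ]


def call_reward_func(reward_func, batch, completions):
    """Simplified mimic: forward every column except "prompt" as a keyword list."""
    extra_columns = {
        key: [example[key] for example in batch]
        for key in batch[0]
        if key != "prompt"
    }
    return reward_func(completions=completions, **extra_columns)


batch = [
    {"prompt": [{"role": "user", "content": "1 + 1 = ?"}], "solution": "2"},
    {"prompt": [{"role": "user", "content": "2 + 2 = ?"}], "solution": "4"},
]
completions = [
    [{"role": "assistant", "content": "2"}],
    [{"role": "assistant", "content": "5"}],
]
print(call_reward_func(toy_reward, batch, completions))  # [1.0, 0.0]
```

If `solution` were dropped before reaching the reward function, `kwargs["solution"]` would raise a `KeyError`, which is why the training config below keeps unused columns.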

### 2. Model

The original notebook used `Qwen2-0.5B-Instruct`. We use `Qwen2.5-3B-Instruct` instead, because models with 3B or more parameters are usually more capable and follow instructions better. If you are resource-constrained, pick a smaller model from the list below to fit your hardware:

- `Qwen/Qwen2.5-1.5B-Instruct`
- `Qwen/Qwen2.5-0.5B-Instruct`
- `Qwen/Qwen2-1.5B-Instruct`
- `Qwen/Qwen2-0.5B-Instruct`

In [None]:
model_id = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype="auto",
    device_map="auto",
)

### 2.1 Configuring LoRA

Next, we configure LoRA so that only a small set of adapter parameters is trained instead of the full model. We increase the rank `r` from `8` (as used in the original code) to `16`, and we do not use dropout in this example. These are hyperparameters you can play with to get a better model.

In [None]:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
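To see what raising the rank costs, here is a back-of-the-envelope sketch. For a frozen weight of shape `d_out x d_in`, standard LoRA trains two matrices, `A` (`r x d_in`) and `B` (`d_out x r`), and scales their product by `lora_alpha / r`. The layer shape below is a hypothetical example, not taken from the Qwen2.5-3B config.

```python
def lora_params_per_layer(d_in: int, d_out: int, r: int) -> int:
    # A has r * d_in entries, B has d_out * r entries
    return r * (d_in + d_out)


def lora_scaling(lora_alpha: int, r: int) -> float:
    # Standard (non-rsLoRA) scaling applied to the update B @ A
    return lora_alpha / r


# Hypothetical square projection, chosen only for illustration
d_in, d_out = 2048, 2048
for r in (8, 16):
    print(f"r={r}: {lora_params_per_layer(d_in, d_out, r):,} params, "
          f"scaling={lora_scaling(32, r)}")
# r=8: 32,768 params, scaling=4.0
# r=16: 65,536 params, scaling=2.0
```

Doubling `r` doubles the adapter parameters per adapted layer, and with `lora_alpha` kept at 32 it halves the scaling factor, so changing `r` alone also changes the effective magnitude of the update.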

### 2.2 Reward Functions
We need a way to grade the model’s completions. We can use simple static (rule-based) checkers or a reward model; choose based on the resources you have. Here we use static checkers to grade two aspects:

1. **Format correctness:** Does the model generate responses in the correct format? In our case, the model should place thinking tokens inside `<think>...</think>` and the final answer inside `<answer>...</answer>`.
2. **Numerical correctness:** Does the model produce the right answer for a given problem? In this example, we will use the `math_verify` package to check correctness.

The format check below assigns partial credit. This is optional; a simple binary (`0`/`1`) reward works too.

In [None]:
# Regex patterns
PATTERN_FULL = re.compile(r"^<think>\n?.*?\n?</think>\n?.*?\n?<answer>\n?.*?\n?</answer>$", re.DOTALL)
PATTERN_THINK = re.compile(r"<think>\n?(.*?)\n?</think>", re.DOTALL)
PATTERN_ANSWER = re.compile(r"<answer>\n?(.*?)\n?</answer>", re.DOTALL)


def has_valid_format(text):
    """Check if text follows the full <think>...</think><answer>...</answer> structure."""
    return bool(PATTERN_FULL.match(text))

def extract_sections(text):
    """Extracts the content inside <think> and <answer> tags.
    Returns a tuple (think_text, answer_text), or (None, None) if not found.
    """
    think_match = PATTERN_THINK.search(text)
    answer_match = PATTERN_ANSWER.search(text)
    think_text = think_match.group(1).strip() if think_match else None
    answer_text = answer_match.group(1).strip() if answer_match else None
    return think_text, answer_text


def accuracy_reward(completions, **kwargs):
    """Reward function that checks whether the completion matches the ground truth."""
    solutions = kwargs['solution']
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = [0.0] * len(completion_contents)

    for i, (ground_truth, prediction) in enumerate(zip(solutions, completion_contents)):
        try:
            gold_parsed = parse(
                ground_truth,
                extraction_mode="first_match",
                extraction_config=[LatexExtractionConfig()],
                parsing_timeout=10, # keeping it 10 seconds for now
            )
        except Exception as ex:
            print("Something went wrong when parsing the ground truth: ", ex)
            gold_parsed = []

        if len(gold_parsed) == 0:
            # The ground truth is not parseable: reward 1.0 to skip this example
            rewards[i] = 1.0
            continue

        try:
            answer_parsed = parse(
                prediction,
                extraction_mode="first_match",
                extraction_config=[LatexExtractionConfig()],
                parsing_timeout=10,
            )
            rewards[i] = float(verify(answer_parsed, gold_parsed))
        except Exception:
            # A prediction that cannot be parsed or verified earns no reward
            rewards[i] = 0.0
    return rewards



def format_reward(completions, **kwargs):
    """Reward function that gives partial credit:
      +0.15 if the full <think>...</think><answer>...</answer> structure exists, else -0.15
      If the structure exists:
        +0.10 if the <think> section has 30 to 128 words, else -0.10
        +0.25 if the <answer> section is exactly one word, else -0.25
    """
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = [0.0] * len(completion_contents)

    for i, content in enumerate(completion_contents):
        score = 0.0

        if has_valid_format(content):
            score += 0.15
            think_text, answer_text = extract_sections(content)
            # Another way to constrain the length of the thinking
            # response (counted in whitespace-separated words). This is
            # optional though, and you can modify it as you see fit.
            if 30 <= len(think_text.split()) <= 128:
                score += 0.10
            else:
                score += -0.10
            # We want the final answer to be strictly one-word.
            # Another option could have been to remove this constraint but
            # ask the model to put the final answer in \boxed{}.
            if len(answer_text.split()) == 1:
                score += 0.25
            else:
                score += -0.25
        else:
            score += -0.15
        rewards[i] = round(score, 2)
    return rewards
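As mentioned above, the partial-credit scheme is optional. A minimal binary alternative could look like the sketch below. `binary_format_reward` is a hypothetical name; it redefines the same full-structure regex as `PATTERN_FULL` so the snippet is self-contained and grades exactly the same format.

```python
import re

# Same pattern as PATTERN_FULL above, redefined so this snippet runs on its own
FULL_FORMAT = re.compile(
    r"^<think>\n?.*?\n?</think>\n?.*?\n?<answer>\n?.*?\n?</answer>$", re.DOTALL
)


def binary_format_reward(completions, **kwargs):
    """Hypothetical 0/1 variant: 1.0 if the full structure matches, else 0.0."""
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if FULL_FORMAT.match(content) else 0.0 for content in contents]


completions = [
    [{"role": "assistant", "content": "<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>"}],
    [{"role": "assistant", "content": "The answer is 4."}],
    [{"role": "assistant", "content": "<answer>4</answer>"}],
]
print(binary_format_reward(completions))  # [1.0, 0.0, 0.0]
```

A binary reward is simpler to reason about, but it gives the model no signal about partially correct structure, which is what the length and one-word checks above try to provide.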

### 2.3 Configuring GRPO Training Parameters

Next, let us configure the training parameters for `GRPO`. The values here differ slightly from the original. We recommend trying different hyperparameters and watching how they affect the training run.

PS: For starters, focus on `max_completion_length` and `num_generations`. For example:
- Do you want your model to generate longer thoughts or keep them short?
- What is the average number of reasoning tokens in the training set?
- Should your model produce a chain of thought of similar length to the dataset or not?
- Should you penalize the model if it starts drifting a lot?


Also, hyperparameters such as `per_device_train_batch_size` and `gradient_accumulation_steps` can be set based on your hardware. The values below are a reasonable starting point, but there is plenty of room to tune them.
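To see why `num_generations` (and later `reward_weights`) matter, here is a simplified sketch of the group-relative advantage used in GRPO: each completion's rewards are combined into one weighted sum, then normalized against the other completions generated for the same prompt. This illustrates the standard formulation, not TRL's exact implementation (its normalization options vary by version); the helper names and reward values are made up for illustration.

```python
import statistics


def combine_rewards(per_func_rewards, weights):
    """Weighted sum over reward functions, one value per completion."""
    # per_func_rewards[k][j] is the reward of function k for completion j
    return [
        sum(w * r for w, r in zip(weights, rewards_for_completion))
        for rewards_for_completion in zip(*per_func_rewards)
    ]


def group_advantages(rewards, num_generations, eps=1e-4):
    """Normalize each reward by the mean and std of its prompt's group."""
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample standard deviation
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Two prompts x num_generations=4 completions each (made-up rewards)
format_rewards = [0.5, 0.5, -0.15, 0.3, 0.5, 0.5, 0.5, 0.5]
accuracy_rewards = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]

total = combine_rewards([format_rewards, accuracy_rewards], weights=[0.5, 0.5])
advantages = group_advantages(total, num_generations=4)
print([round(a, 2) for a in advantages])
```

In the second group every completion earns the same total reward, so all its advantages are zero and that prompt contributes no learning signal. More generations per prompt make it more likely that a group mixes better and worse completions.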

In [None]:
learning_rate = 3e-5
lr_scheduler_type = "linear"
weight_decay = 0.0
remove_unused_columns = False # to access the solution column in accuracy_reward
gradient_accumulation_steps = 4
num_train_epochs = 1
bf16 = True
per_device_train_batch_size = 64
gradient_checkpointing = True
torch_compile = False

num_generations = 4
max_prompt_length = 256
max_completion_length = 512
temperature = 1.0
top_p = 1.0
top_k = None # turn off topk

# Parameters related to reporting and saving
report_to = ["wandb"]
logging_steps = 5
eval_steps = 5
save_steps = 10
push_to_hub = False
save_strategy = "steps"
eval_strategy = "steps"
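A quick sanity check on these batch settings: assuming, as TRL's GRPO documentation describes, that `per_device_train_batch_size` counts completions rather than prompts, the number of distinct prompts per optimizer step is the total number of completions divided by `num_generations`. The helper below is hypothetical and ignores version-specific details of how TRL schedules generation.

```python
def prompts_per_optimizer_step(per_device_train_batch_size, gradient_accumulation_steps,
                               num_generations, num_processes=1):
    """Hypothetical helper: distinct prompts seen per optimizer step."""
    # Assumption: per_device_train_batch_size counts completions, not prompts
    completions = per_device_train_batch_size * gradient_accumulation_steps * num_processes
    if completions % num_generations != 0:
        raise ValueError("Completions per step must be divisible by num_generations")
    return completions // num_generations


# Values from the cell above, on a single GPU: 256 completions -> 64 prompts
print(prompts_per_optimizer_step(64, 4, num_generations=4))  # 64
```

With the same settings on 4 GPUs, the helper gives 256 prompts per step, so scaling out changes the effective batch unless you adjust these values.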

### 2.4 Training
Let us train the model now with the chosen hyperparameters.

In [None]:
# Configure training arguments using GRPOConfig
training_args = GRPOConfig(
    output_dir="qwen_numinamath_rl_3b",
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    weight_decay=weight_decay,
    remove_unused_columns=remove_unused_columns,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=num_train_epochs,
    bf16=bf16,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_checkpointing=gradient_checkpointing,
    torch_compile=torch_compile,

    num_generations=num_generations,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,

    report_to=report_to,
    logging_steps=logging_steps,
    eval_steps=eval_steps,
    push_to_hub=push_to_hub,
    save_strategy=save_strategy,
    save_steps=save_steps,
    eval_strategy=eval_strategy,

    # Weights for the reward functions.
    # You are free to change these and observe the model behavior
    reward_weights=[0.5, 0.5],
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[format_reward, accuracy_reward],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)


# To trace the rollouts during training, we use `weave`. TRL
# already comes with a tight Weave integration: we attach a
# `WeaveCallback` to the trainer to trace the rollouts.

# Init the callback
weave_callback = WeaveCallback(
    trainer=trainer,
    # you can configure these generation params
    generation_config=GenerationConfig(temperature=temperature, max_new_tokens=max_completion_length),
)
# Add the callback to the trainer
trainer.add_callback(weave_callback)

# Train the model
trainer.train()

You can view the traces for the logged steps in the dashboard, as shown below:

<img src=https://raw.githubusercontent.com/wandb/rl_examples//main/assets/trl/numinamath/grpo/1.png width="70%">

<br><br><br>

You can then select any trace to view the logged prompts and completions, as shown below:

<img src=https://raw.githubusercontent.com/wandb/rl_examples//main/assets/trl/numinamath/grpo/2.png width="70%">

