# Prerequisites

## Uninstall default colab dependencies

Here, we uninstall the default Colab packages that cause version conflicts with the rLLM, VERL, and vLLM dependencies.

In [None]:
# Remove Colab packages that conflict with the stack installed below (-y skips the confirmation prompt).
!pip uninstall -y fastai albumentations albucore dopamine-rl bigframes \
  opencv-python opencv-python-headless spacy torchvision

In [None]:
# Remove the existing PyTorch stack and NumPy; "|| true" is meant to keep the cell going if the uninstall fails.
%pip uninstall -y torch torchvision torchaudio numpy || true
# Remove gcsfs and fsspec.
%pip uninstall -y gcsfs fsspec
# Remove the OpenCV variants together with thinc and spacy.
%pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless thinc spacy
# Nothing is installed in this cell; the pinned replacements come from the "Installing" section below.

In [None]:
# Remove gymnasium and the browsergym packages if they are present (-y skips the confirmation prompt).
!pip uninstall -y gymnasium browsergym-core browsergym

## Installing
Now we install the dependencies required to train our solver-judge workflow!
- Colab may prompt you to restart the session. Make sure to do so before running the subsequent cells.

In [None]:
# Install vLLM together with exactly pinned torch/torchvision/torchaudio/tensordict builds (torchdata is left unpinned).
# --no-cache-dir makes pip skip its local download cache.
!pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata

In [None]:
# Install the Hugging Face stack, NumPy (kept below 2.0), Ray, logging/config helpers and dev tooling.
!pip install "transformers[hf_xet]>=4.57.0" accelerate datasets peft hf-transfer \
    "numpy<2.0.0" "pyarrow>=15.0.0" pandas \
    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \
    pytest py-spy pyext pre-commit ruff tensorboard

# Install the remaining packages with minimum-version constraints.
!pip install "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1"

In [None]:
!wget -q https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
!pip install -q --no-cache-dir flash_attn-2.8.3+cu12torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl

In [None]:
# Download and install the prebuilt FlashInfer wheel; its filename targets CUDA 12.4 and torch 2.6.
!wget -q https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl
!pip install -q --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl

In [None]:
# Reinstall OpenCV after the earlier uninstall.
!pip install opencv-python
# Install opencv-fixer and, only if that succeeds (&&), run its AutoFix() entry point.
!pip install opencv-fixer && \
    python -c "from opencv_fixer import AutoFix; AutoFix()"

In [None]:
# Clone rLLM (including its submodules) into /content/src and work from there.
%cd /content
!git clone --recurse-submodules https://github.com/rllm-org/rllm.git src
%cd /content/src
# Switch to v0.2, then re-sync the submodules for that checkout.
!git switch v0.2
!git submodule update --init --recursive

# Use the VERL that ships inside the repo
%pip install -q -e ./verl
# Install rLLM itself
%pip install -q -e .

# Train Solver and Judge Workflow

rLLM provides an agent workflow engine for training different workflows with reinforcement learning. You do not have to deal with this engine directly; we will just go over how to use `AgentTrainer` with your workflow logic.

## Solver and Judge definition

Here, we'll define a custom workflow, which is `SolverJudgeWorkflow` in this tutorial.

---

### Solver Class
`Solver` class generates n different solutions to the input problem (in parallel). It returns a list of n trajectories (without reward).

### Judge Class
`Judge` class selects the best solution from among the candidates generated by the solver. It returns a trajectory (without reward) containing the selected solution.

**Note:** Both classes query the model using the `RolloutEngine`.


In [None]:
import os

# Export VLLM_USE_V1=1 for this process (presumably opts into vLLM's V1 engine - confirm against the vLLM docs).
os.environ.update({"VLLM_USE_V1": "1"})

In [None]:
import asyncio
import re

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.workflows.workflow import Workflow


class Solver:
    """Produces candidate solutions for a problem by querying the rollout engine.

    Each candidate is returned as a single-step ``Trajectory`` named ``"solver"``;
    no reward is assigned here.
    """

    def __init__(self, rollout_engine: RolloutEngine, **kwargs):
        """Store the engine used to query the model. Extra keyword arguments are ignored."""
        self.rollout_engine = rollout_engine

    async def generate_solution(self, problem: str) -> Trajectory:
        """Ask the model for one solution to ``problem``.

        Returns:
            A one-step trajectory whose step ``action`` is the parsed
            ``<answer>...</answer>`` string (or the "No solution found" sentinel).
        """
        messages = [{"role": "user", "content": f"{problem}. Output the final answer within <answer>...</answer>"}]
        output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        return Trajectory(
            name="solver",
            steps=[
                Step(
                    chat_completions=messages + [{"role": "assistant", "content": output.content, "reasoning": output.reasoning}],
                    thought=output.reasoning,
                    action=self._parse_solver_response(output.content),
                    model_output=output,
                )
            ],
        )

    async def generate_solutions(self, problem: str, n_solutions: int = 2) -> list[Trajectory]:
        """Generate ``n_solutions`` independent solutions concurrently."""
        tasks = [asyncio.create_task(self.generate_solution(problem)) for _ in range(n_solutions)]
        return await asyncio.gather(*tasks)

    def _parse_solver_response(self, response: str | None) -> str:
        """Extract the first ``<answer>...</answer>`` span from ``response``.

        The tag match is case-insensitive and may span multiple lines. The inner
        text is stripped and re-wrapped in lowercase ``<answer>`` tags.

        Returns:
            The normalized answer string, or ``"No solution found"`` when the
            response is empty/``None`` or contains no answer tags.
        """
        # A missing/empty completion would make re.search raise TypeError; treat it
        # the same way as a completion without answer tags.
        if not response:
            return "No solution found"
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.IGNORECASE | re.DOTALL)
        if answer_match:
            return f"<answer>{answer_match.group(1).strip()}</answer>"
        return "No solution found"


class Judge:
    """Asks the model to pick one solution out of a list of candidates.

    The result is a single-step ``Trajectory`` named ``"judge"`` whose step
    ``action`` is the selected solution string; no reward is assigned here.
    """

    def __init__(self, rollout_engine: RolloutEngine, **kwargs):
        """Store the engine used to query the model. Extra keyword arguments are ignored."""
        self.rollout_engine = rollout_engine

    async def judge_solutions(self, problem: str, solutions: list[str]) -> Trajectory:
        """Query the model with the judge prompt and wrap its choice in a trajectory."""
        messages = [{"role": "user", "content": self._create_judge_prompt(problem, solutions)}]
        output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        return Trajectory(
            name="judge",
            steps=[
                Step(
                    chat_completions=messages + [{"role": "assistant", "content": output.content, "reasoning": output.reasoning}],
                    thought=output.reasoning,
                    action=self._parse_judge_response(output.content, solutions),
                    model_output=output,
                )
            ],
        )

    def _parse_judge_response(self, response: str | None, solutions: list[str]) -> str:
        """Map the judge's ``<answer>k</answer>`` reply to the k-th solution (1-based).

        Returns:
            ``solutions[k - 1]`` when ``k`` is an integer in ``1..len(solutions)``;
            otherwise the empty string (no/empty response, no answer tag,
            non-integer content, or an out-of-range index).
        """
        if not response:
            return ""
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.IGNORECASE | re.DOTALL)
        if not answer_match:
            return ""
        try:
            solution_index = int(answer_match.group(1).strip())
        except ValueError:
            return ""
        # Explicit range check: without it, 0 or a negative index would wrap around
        # through Python's negative indexing and silently select the wrong solution.
        if 1 <= solution_index <= len(solutions):
            return solutions[solution_index - 1]
        return ""

    def _create_judge_prompt(self, problem: str, solutions: list[str]) -> str:
        """Create a prompt for the judge to evaluate solutions.

        Solutions are listed in order and numbered from 1, matching the index the
        judge is asked to output.
        """
        prompt = f"""You are an expert verifier. Given a countdown problem and multiple solution attempts, select a correct solution.
Problem:
{problem}
Solutions to evaluate:
"""
        for i, solution in enumerate(solutions, 1):
            prompt += f"\nSolution {i}:\n{solution}\n"

        prompt += """
A correct solution must satisfy the following criteria:
1. The solution uses only the given numbers.
2. Each number is used exactly once.
3. Only basic arithmetic operations (+, -, *, /) are used.
4. The calculation results in the target number.
5. The final answer is clearly marked within <answer>...</answer> tags.
Output the index of your selected solution within <answer>...</answer> tags, e.g., <answer>1</answer> for the first solution, <answer>2</answer> for the second solution, etc. If multiple solutions are correct, output the index of the first correct solution."""
        return prompt


class SolverJudgeWorkflow(Workflow):
    """Workflow that samples several solver answers and lets a judge pick one.

    ``reward_function`` is called as ``reward_function(task, solution)`` and its
    result must expose ``.reward`` and ``.is_correct``.
    """

    def __init__(self, rollout_engine: RolloutEngine, n_solutions: int = 2, reward_function: RewardFunction = None, **kwargs):
        super().__init__(rollout_engine, **kwargs)
        self.n_solutions = n_solutions
        self.reward_function = reward_function
        # Solver and judge share the same rollout engine.
        self.solver = Solver(rollout_engine)
        self.judge = Judge(rollout_engine)

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """Run one solver/judge round on ``task`` and return the resulting episode.

        The episode holds every solver trajectory followed by the judge
        trajectory; each trajectory's only step carries its own reward.
        """
        self.reset(task, uid)
        question = task["question"]

        # Phase 1: sample the candidate solutions and score each one.
        candidates = await self.solver.generate_solutions(question, self.n_solutions)
        proposed = []
        for candidate in candidates:
            first_step = candidate.steps[0]
            answer = first_step.action
            proposed.append(answer)
            first_step.reward = self.reward_function(task, answer).reward

        # Phase 2: the judge chooses among the candidates; score its choice.
        verdict = await self.judge.judge_solutions(question, proposed)
        verdict_step = verdict.steps[0]
        outcome = self.reward_function(task, verdict_step.action)
        verdict_step.reward = outcome.reward
        solved = outcome.is_correct

        # Metrics: mean solver reward, and whether the judge's pick was correct.
        mean_solver_reward = sum(candidate.steps[0].reward for candidate in candidates) / len(candidates)
        metrics = {"solver_acc": mean_solver_reward, "judge_acc": int(solved)}

        # Phase 3: package all trajectories into one episode.
        return Episode(
            id=uid,
            task=task,
            trajectories=[*candidates, verdict],
            is_correct=solved,
            metrics=metrics,
        )

## Dataset Creation

We load the countdown task dataset from Hugging Face.

In [None]:
import random

from datasets import load_dataset

from rllm.data.dataset import DatasetRegistry


def prepare_countdown_data(test_size: int = 1024, stage_size: int = 5, seed: int | None = 42):
    """
    Prepare the countdown task dataset from HuggingFace and register its splits.

    The first ``test_size`` examples become the test set and the remainder the
    training set. Two disjoint random subsets of the training set are also
    registered as the "stage2_train" and "stage3_train" splits.

    Args:
        test_size: Number of leading examples held out as the test set.
        stage_size: Number of examples in each stage subset. The default of 5 keeps
            the tutorial fast; the original full-scale value was 50000.
        seed: Seed for choosing the stage subsets, so re-running the cell yields
            the same splits. Pass ``None`` for a different random draw each run.

    Returns:
        Tuple ``(train, test, stage2, stage3)`` of the objects returned by
        ``DatasetRegistry.register_dataset``.
    """
    # Load the countdown dataset
    dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")

    # Create train/test split: leading `test_size` rows for test, rest for training
    total_size = len(dataset)
    test_dataset = dataset.select(range(test_size))
    train_dataset = dataset.select(range(test_size, total_size))

    def preprocess_fn(example):
        """
        Convert countdown task format to math problem format.
        Example: target=98, nums=[44, 19, 35] becomes a math word problem.
        """
        target = example["target"]
        nums = example["nums"]

        # Format as a math problem
        nums_str = ", ".join(map(str, nums))
        question = f"Using the numbers {nums_str}, find a way to reach the target number {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your step-by-step calculation and output the final answer within <answer>...</answer>, for example <answer> (1 + 2) / 3 </answer>."

        return {
            "question": question,
            "ground_truth": str(target),
            "data_source": "countdown",
            "target": target,
            "nums": nums,
        }

    # Apply preprocessing
    train_dataset = train_dataset.map(preprocess_fn)
    test_dataset = test_dataset.map(preprocess_fn)

    # Ensure we have enough data for both stages; shrink the stages if not
    train_size = len(train_dataset)
    if train_size < 2 * stage_size:
        print(f"Warning: Training set has only {train_size} examples, but need {2 * stage_size} for both stages")
        stage_size = min(stage_size, train_size // 2)

    # Draw 2 * stage_size distinct indices with a private, seeded generator.
    # Sampling avoids shuffling the full index list and leaves the global
    # `random` state untouched.
    rng = random.Random(seed)
    picked_indices = rng.sample(range(train_size), 2 * stage_size)
    stage2_indices = picked_indices[:stage_size]
    stage3_indices = picked_indices[stage_size:]

    # Create stage datasets
    stage2_dataset = train_dataset.select(stage2_indices)
    stage3_dataset = train_dataset.select(stage3_indices)

    # Register datasets
    train_dataset = DatasetRegistry.register_dataset("countdown", train_dataset, "train")
    test_dataset = DatasetRegistry.register_dataset("countdown", test_dataset, "test")
    stage2_dataset = DatasetRegistry.register_dataset("countdown", stage2_dataset, "stage2_train")
    stage3_dataset = DatasetRegistry.register_dataset("countdown", stage3_dataset, "stage3_train")

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")
    print(f"Stage 2 train dataset size: {len(stage2_dataset)}")
    print(f"Stage 3 train dataset size: {len(stage3_dataset)}")

    return train_dataset, test_dataset, stage2_dataset, stage3_dataset


if __name__ == "__main__":
    # Build and register every split, then report where each one was stored.
    train_dataset, test_dataset, stage2_dataset, stage3_dataset = prepare_countdown_data()

    path_report = [
        ("Train dataset path:", train_dataset),
        ("Test dataset path:", test_dataset),
        ("Stage 2 train dataset path:", stage2_dataset),
        ("Stage 3 train dataset path:", stage3_dataset),
    ]
    for caption, dataset_handle in path_report:
        print(caption, dataset_handle.get_data_path())

    # Show the first record of each training split as a sanity check.
    sample_report = [
        ("\nSample train example:", train_dataset),
        ("\nSample stage 2 train example:", stage2_dataset),
        ("\nSample stage 3 train example:", stage3_dataset),
    ]
    for caption, dataset_handle in sample_report:
        print(caption)
        print(dataset_handle[0])

## Training configuration
In this section, we configure the trainer with information such as the model, the batch size, the Wandb API key used for logging, and the rollout engine.
Here, we use Hydra to load the base `agent_ppo_trainer` config and OmegaConf to merge our overrides into it, including specific PPO settings.

For now, LoRA is disabled, but it can be enabled by setting `lora_rank` to a positive number.

In [None]:
import os

# Run from the rLLM checkout.
os.chdir("/content/src")

# Never hardcode the Weights & Biases key in the notebook: keep a key that is already
# exported, otherwise prompt for it without echoing. (Previously a placeholder string
# was written unconditionally, clobbering any real key that was already set.)
if not os.environ.get("WANDB_API_KEY"):
    from getpass import getpass

    _wandb_api_key = getpass("Enter your Weights & Biases API key: ").strip()
    if _wandb_api_key:
        os.environ["WANDB_API_KEY"] = _wandb_api_key


from rllm.data.dataset import DatasetRegistry
from omegaconf import OmegaConf
from rllm.trainer.agent_trainer import AgentTrainer
from rllm.rewards.countdown_reward import countdown_reward_fn
from hydra import compose, initialize_config_module
from hydra.core.global_hydra import GlobalHydra
import torch


# Probe the machine: visible CUDA devices and logical CPUs (8 if the count is unknown).
num_gpus = torch.cuda.device_count()
num_cpus = os.cpu_count() or 8
print(f"Detected {num_gpus} GPUs and {num_cpus} CPUs")

# Pick batch size and workflow parallelism from the GPU count:
#   exactly 1 GPU -> 1 / 1,   8 or more GPUs -> 64 / 128,   anything else -> 16 / 16
is_single_gpu = num_gpus == 1
if is_single_gpu:
    batch_size, n_parallel = 1, 1
elif num_gpus >= 8:
    batch_size, n_parallel = 64, 128
else:
    batch_size, n_parallel = 16, 16


# Initialise Hydra against the "rllm.trainer.config" config module for the duration of the
# with-block and compose the base "agent_ppo_trainer" config from it.
with initialize_config_module(version_base=None, config_module="rllm.trainer.config"):
    base_config = compose(config_name="agent_ppo_trainer")

# Tutorial-specific settings, later merged on top of the base config.
# Values that depend on the detected hardware come from batch_size, n_parallel,
# is_single_gpu and num_gpus computed above.
overrides = OmegaConf.create(
    {
        "data": {
            "train_batch_size": batch_size,
            "max_prompt_length": 1024,
            "max_response_length": 1024,
            "dataloader_num_workers": 0,
        },
        "actor_rollout_ref": {
            "model": {
                "path": "Qwen/Qwen3-0.6B",
                "enable_gradient_checkpointing": True,
                "lora_rank": 0,  # Set to positive value to enable LoRA
                "lora_alpha": 2,
                "use_remove_padding": True,
            },
            "actor": {
                "optim": {"lr": 1e-6},
                "loss_agg_mode": "seq-mean-token-mean",
                "use_dynamic_bsz": True,
                "ppo_max_token_len_per_gpu": 32768,
                "ppo_mini_batch_size": batch_size,
                "use_kl_loss": False,
                "kl_loss_coef": 0.001,
                "kl_loss_type": "low_var_kl",
                "entropy_coeff": 0.0,
                "clip_ratio_low": 0.2,
                "clip_ratio_high": 0.28,
                "ulysses_sequence_parallel_size": 1,
                "fsdp_config": {
                    # Offloading is switched on only when exactly one GPU was detected.
                    "param_offload": is_single_gpu,
                    "optimizer_offload": is_single_gpu,
                },
            },
            "rollout": {
                "name": "vllm",
                "mode": "async",
                "enforce_eager": False,
                "temperature": 0.6,
                "gpu_memory_utilization": 0.5,
                "tensor_model_parallel_size": 1,
                "n": 1,
                "val_kwargs": {
                    "n": 1,
                    "temperature": 0.6,
                    "top_p": 0.95,
                },
                "load_format": "auto",
            },
            "ref": {
                "fsdp_config": {
                    "param_offload": is_single_gpu,
                },
            },
            "hybrid_engine": True,
        },
        "algorithm": {
            "adv_estimator": "grpo",
        },
        "rllm": {
            "workflow": {
                "use_workflow": True,
                "n_parallel_tasks": n_parallel,
                "retry_limit": 1,
            },
            "stepwise_advantage": {
                "enable": True,
                "mode": "per_step",
            },
            "compact_filtering": {
                "enable": True,
                "mask_max_prompt_length_exceeded": True,
                "mask_max_response_length_exceeded": True,
                "mask_max_turns_exceeded": False,
                "mask_timeout": True,
            },
            "rejection_sample": {
                "enable": False,
                "multiplier": 1.0,
            },
        },
        "trainer": {
            "critic_warmup": 0,
            "project_name": "solver-judge-workflow",
            "experiment_name": "countdown-solver-judge",
            "total_epochs": 1,
            # Falls back to 1 when no GPU was detected.
            "n_gpus_per_node": num_gpus if num_gpus > 0 else 1,
            "nnodes": 1,
            "logger": ["console", "wandb"],  # remove "wandb" if you do not have a Wandb API key
            "val_before_train": True,
            "test_freq": 5,
            "save_freq": 1000,
            "default_hdfs_dir": None,
        },
    }
)


# Merge the overrides into the base config; on conflicting keys the later argument
# (overrides) takes precedence.
train_config = OmegaConf.merge(base_config, overrides)


# Load datasets (the "countdown" splits registered in the Dataset Creation section)
train_dataset = DatasetRegistry.load_dataset("countdown", "train")
test_dataset = DatasetRegistry.load_dataset("countdown", "test")

# Create trainer
# workflow_args supplies n_solutions and reward_function, which match keyword
# parameters of SolverJudgeWorkflow.__init__ (presumably forwarded to it - confirm in AgentTrainer).
trainer = AgentTrainer(
    workflow_class=SolverJudgeWorkflow,
    workflow_args={
        "n_solutions": 2,
        "reward_function": countdown_reward_fn,
    },
    config=train_config,
    train_dataset=train_dataset,
    val_dataset=test_dataset,
)

print("Trainer ready!")

In [None]:
# Start training with the trainer built in the previous cell.
trainer.train()

In [None]:
# Print the installed vLLM version (the "Version" line(s) from pip's package metadata).
!pip show vllm | grep Version

# Train Visualization

Since we saved our training logs to Wandb, we can use the following code to plot the results. Make sure to replace `wandb_run` with the actual run created by the training above.

In [None]:
# Install wandb if not already installed (-q keeps pip's output quiet)
!pip install -q wandb

import wandb
import matplotlib.pyplot as plt
import pandas as pd
from google.colab import auth

# Login to wandb (will prompt for API key if not logged in)
wandb.login()

# Initialize wandb API
api = wandb.Api()

# Fetch the specific run; replace the placeholder with the path of your own run
wandb_run = "YOUR WANDB RUN"
run = api.run(wandb_run)

# Get run history (metrics over time) as a DataFrame
history = run.history()

# Print available columns
print("Available metrics:")
print(history.columns.tolist())
print(f"\nTotal steps: {len(history)}")

# Create visualizations for all numeric columns
numeric_cols = history.select_dtypes(include=["float64", "int64"]).columns.tolist()
# Remove _step and _timestamp columns
numeric_cols = [col for col in numeric_cols if not col.startswith("_")]

if numeric_cols:
    # Calculate number of subplots needed (two per row, rounded up)
    n_metrics = len(numeric_cols)
    n_cols = 2
    n_rows = (n_metrics + n_cols - 1) // n_cols

    # squeeze=False makes plt.subplots always return a 2-D array of Axes, so flattening
    # is valid for every grid shape. Previously a single metric produced a 1x2 array
    # that was wrapped in a list, and axes[0].plot(...) then failed on the ndarray.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, metric in enumerate(numeric_cols):
        ax = axes[idx]
        # Metrics that are not logged on every row contain NaN gaps; drop them so
        # sparsely logged metrics are drawn as a connected, visible line.
        series = history[metric].dropna()
        ax.plot(series.index, series.values, linewidth=2, marker=".")
        ax.set_xlabel("Step")
        ax.set_ylabel(metric)
        ax.set_title(f"{metric} over time")
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for idx in range(n_metrics, len(axes)):
        axes[idx].axis("off")

    plt.tight_layout()
    plt.show()
else:
    print("No numeric metrics found to plot")

# Print summary statistics
print("\nRun Summary:")
for key, value in run.summary.items():
    print(f"{key}: {value}")

# Inference
We can also run the solver-judge workflow with vLLM.

## vLLM inference

In [None]:
import time
import requests

# Configuration
# Model served by vLLM and the local port of its OpenAI-compatible API.
# NOTE(review): this is a different model from the "Qwen/Qwen3-0.6B" used in the training config - confirm this is intended.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
PORT = 30000


def is_server_running():
    """Return True if the server on localhost:PORT answers GET /v1/models with HTTP 200.

    Connection errors and timeouts (2 second limit) are treated as "not running".
    """
    try:
        response = requests.get(f"http://localhost:{PORT}/v1/models", timeout=2)
        return response.status_code == 200
    except requests.RequestException:
        # Only swallow request-level failures (refused connection, timeout, ...);
        # a bare except would also hide KeyboardInterrupt and programming errors.
        return False


# Start or check server
if is_server_running():
    print(f"Serverrunning on port {PORT}")
else:
    print(f"Starting vLLM with {MODEL_NAME}...")

    # Start vLLM server in background
    !nohup python -m vllm.entrypoints.openai.api_server \
        --model {MODEL_NAME} \
        --port {PORT} \
        --max-model-len 4096 \
        > /dev/null 2>&1 &

    print("Server starting in background")

# Save config
SERVER_CONFIG = {"model_name": MODEL_NAME, "base_url": f"http://localhost:{PORT}/v1", "port": PORT}

print(f"\nServer URL: {SERVER_CONFIG['base_url']}")

# Misc

If training fails with an error saying that there are not enough GPUs, run the following code to shut down the Ray instances, then re-run the trainer cell above.

In [None]:
import ray

# Shut down Ray in this process so a fresh trainer run can start it again.
ray.shutdown()  # reset if previously inited