# RLLM SDK Quick Start: Make Any Agent Trainable With Almost No Adaptation

This tutorial shows how to make **any existing agent code** trainable with minimal changes.

**The key insight:** Just replace your OpenAI client with the SDK client, and every LLM call is automatically tracked for training!

*Note: This tutorial focuses on showing you the mechanics. We'll explain how training works at the end.*

## Step 1: Start the Proxy

Start the proxy for testing. During training, the Trainer manages this automatically. The proxy logs all LLM calls to a SQLite database.

In [None]:
from pathlib import Path
from rllm.sdk.proxy.proxy_manager import ProxyManager

# Setup
DB_PATH = "/tmp/rllm_demo.db"
MODEL = "gpt-4o-mini"

openai_api_key = "sk-xxx"  # Fill in your OpenAI API key

# Clean up
Path(DB_PATH).unlink(missing_ok=True)

# Start proxy
proxy_manager = ProxyManager(proxy_port=4000, admin_token="my-shared-secret")
config = {
    "model_list": [
        {
            "model_name": MODEL,
            "litellm_params": {
                "model": MODEL,
                "api_key": openai_api_key,
            },
        }
    ]
}
proxy_manager.start_proxy_subprocess(config=config, db_path=DB_PATH, project="demo")
proxy_url = proxy_manager.get_proxy_url(include_v1=True)

print(f"✓ Proxy started at {proxy_url}")
print(f"✓ Database: {DB_PATH}")

In [None]:
!python rllm/examples/solver_judge/prepare_countdown_data.py

In [None]:
from rllm.data.dataset import DatasetRegistry

train_dataset = DatasetRegistry.load_dataset("countdown", "train")
test_dataset = DatasetRegistry.load_dataset("countdown", "test")

train_dataset[0]

**The Countdown Task:**  
Given a set of numbers and a target, find an arithmetic expression that reaches the target using each number exactly once. For example, numbers `[30, 32, 76]` and target `78` → one solution is `76 + 32 - 30 = 78`.
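To make the task concrete, here is a small brute-force search that finds a valid expression by repeatedly combining two numbers with one operator until a single value remains. It is only an illustration of the task definition (`solve_countdown` and `_search` are hypothetical names, not part of rLLM), and its cost grows exponentially with the number of inputs, which is fine for the three or four numbers used here.

```python
from itertools import combinations
from typing import Optional


def solve_countdown(numbers: list[int], target: float) -> Optional[str]:
    """Search for an expression that uses every number exactly once and hits the target."""
    # Each item pairs a numeric value with the expression string that produced it
    return _search([(float(n), str(n)) for n in numbers], target)


def _search(items: list[tuple[float, str]], target: float) -> Optional[str]:
    if len(items) == 1:
        value, expr = items[0]
        return expr if abs(value - target) < 1e-6 else None
    # Pick two items, combine them with one operator, and recurse on the shorter list
    for i, j in combinations(range(len(items)), 2):
        (a, ea), (b, eb) = items[i], items[j]
        rest = [item for k, item in enumerate(items) if k not in (i, j)]
        candidates = [
            (a + b, f"({ea} + {eb})"),
            (a * b, f"({ea} * {eb})"),
            (a - b, f"({ea} - {eb})"),
            (b - a, f"({eb} - {ea})"),
        ]
        if b != 0:
            candidates.append((a / b, f"({ea} / {eb})"))
        if a != 0:
            candidates.append((b / a, f"({eb} / {ea})"))
        for value, expr in candidates:
            found = _search(rest + [(value, expr)], target)
            if found is not None:
                return found
    return None
```

For the example above, `solve_countdown([30, 32, 76], 78)` returns the fully parenthesized `(76 - (30 - 32))`, which is equivalent to `76 + 32 - 30`; unsolvable inputs return `None`.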

## Step 2: Your Original Agent Code

A typical agent using the standard OpenAI client. This agent follows a Solver-Judge workflow: generate multiple solution attempts, then select the best one.
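Before looking at the real agent, here is an offline sketch of the Solver-Judge control flow with stubbed model calls, so it runs without an API key. `fake_solve`, `fake_judge`, and `solver_judge` are hypothetical names; the stubs stand in for the two `chat.completions.create` calls in the agent below. Unlike that agent, which awaits its solver calls one at a time, this sketch launches them concurrently with `asyncio.gather`; that is a design choice, not something the tutorial requires.

```python
import asyncio


async def fake_solve(problem: str, attempt: int) -> str:
    """Hypothetical stand-in for an LLM solver call."""
    await asyncio.sleep(0)  # pretend to wait on the network
    # Attempt 0 returns a wrong answer, later attempts a correct one
    return "<answer>76 + 32 + 30</answer>" if attempt == 0 else "<answer>76 + 32 - 30</answer>"


async def fake_judge(problem: str, solutions: list[str]) -> int:
    """Hypothetical judge: 1-based index of the first solution that evaluates to 78, or 0."""
    for i, sol in enumerate(solutions, start=1):
        expr = sol.removeprefix("<answer>").removesuffix("</answer>")
        if eval(expr, {"__builtins__": {}}) == 78:
            return i
    return 0


async def solver_judge(problem: str, n_solutions: int = 2) -> str:
    # Solver stage: launch all attempts concurrently
    solutions = await asyncio.gather(*(fake_solve(problem, k) for k in range(n_solutions)))
    # Judge stage: pick one attempt by its 1-based index (0 means none was correct)
    choice = await fake_judge(problem, list(solutions))
    return solutions[choice - 1] if choice else ""
```

In a notebook, call it with `await solver_judge("...")`; in a script, use `asyncio.run(solver_judge("..."))`. Either way it returns the second, correct attempt.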

In [None]:
from openai import AsyncOpenAI
import re

judge_prompt = """You are an expert verifier. Given a countdown problem and multiple solution attempts, select a correct solution.
Problem:
{problem}
Solutions to evaluate:
{solutions}
A correct solution must satisfy the following criteria:
1. The solution uses only the given numbers.
2. Each number is used exactly once.
3. Only basic arithmetic operations (+, -, *, /) are used.
4. The calculation results in the target number.
5. The final answer is clearly marked within <answer>...</answer> tags.
Output the index of your selected solution within <answer>...</answer> tags, e.g., <answer>1</answer> for the first solution, <answer>2</answer> for the second solution, etc. If multiple solutions are correct, output the index of the first correct solution."""


class CountdownAgent:
    """A simple math solving agent - ORIGINAL VERSION"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        # Standard OpenAI client
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model

    async def solve(self, problem: str) -> str:
        """Solve a math problem."""
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": f"{problem}. Output the final answer within <answer>...</answer>"}],
            max_tokens=1000,
        )
        return self.parse_solver_answer(response.choices[0].message.content)

    async def judge(self, problem, solutions: list[str]) -> str:
        """Judge multiple solutions to a problem."""
        formatted_solutions = "\n".join([f"Solution {i + 1}:\n{sol}\n" for i, sol in enumerate(solutions)])
        prompt = judge_prompt.format(problem=problem, solutions=formatted_solutions)

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
        )
        return response.choices[0].message.content

    def parse_solver_answer(self, solution):
        # Find all <answer> tags and return the last one
        answer_matches = re.findall(r"<answer>(.*?)</answer>", solution, re.IGNORECASE | re.DOTALL)
        if answer_matches:
            return "<answer>" + answer_matches[-1].strip() + "</answer>"
        return "No solution found"

    def parse_selected_solution(self, judgment, solutions):
        # Find all <answer> tags and use the last one
        answer_matches = re.findall(r"<answer>(.*?)</answer>", judgment, re.IGNORECASE | re.DOTALL)
        if answer_matches:
            answer_text = answer_matches[-1].strip()
            try:
                solution_index = int(answer_text)
                return solutions[solution_index - 1]
            except (ValueError, IndexError):
                return ""
        return ""

    async def run(self, problem: str, n_solutions: int = 2) -> str:
        """Generate multiple solutions and judge them."""
        solutions = []
        for i in range(n_solutions):
            sol = await self.solve(problem)
            solutions.append(sol)

        judgment = await self.judge(problem, solutions)
        selected_solution = self.parse_selected_solution(judgment, solutions)
        return selected_solution


# Use it
agent = CountdownAgent(api_key=openai_api_key, model=MODEL)
result = await agent.run(train_dataset[0]["question"])

print(result)

## Step 3: Make It Trainable (2 Simple Changes!)

**Change 1:** Import the SDK client instead of the OpenAI client  
**Change 2:** Point the client at the proxy URL

**What's `session()`?** A lightweight context manager that tracks every LLM call made within its scope and injects session metadata into each call. Calls made inside `with session() as sess:` are grouped under the session's ID (`sess._uid`), and the recorded calls are available as `sess.llm_calls`.

In [None]:
from rllm.sdk import get_chat_client_async, session  # Change 1: Import SDK client
import re


class TrainableAgent:
    """A simple math solving agent - TRAINABLE VERSION"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        # Replace standard OpenAI client with SDK client
        # self.client = AsyncOpenAI(api_key=api_key)
        self.client = get_chat_client_async(api_key=api_key, base_url=proxy_url)
        self.model = model

    async def solve(self, problem: str) -> str:
        """Solve a math problem."""
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": f"{problem}. Output the final answer within <answer>...</answer>"}],
            max_tokens=1000,
        )
        return self.parse_solver_answer(response.choices[0].message.content)

    async def judge(self, problem, solutions: list[str]) -> str:
        """Judge multiple solutions to a problem."""
        formatted_solutions = "\n".join([f"Solution {i + 1}:\n{sol}\n" for i, sol in enumerate(solutions)])
        prompt = judge_prompt.format(problem=problem, solutions=formatted_solutions)

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
        )
        return response.choices[0].message.content

    def parse_solver_answer(self, solution):
        # Find all <answer> tags and return the last one
        answer_matches = re.findall(r"<answer>(.*?)</answer>", solution, re.IGNORECASE | re.DOTALL)
        if answer_matches:
            return "<answer>" + answer_matches[-1].strip() + "</answer>"
        return "No solution found"

    def parse_selected_solution(self, judgment, solutions):
        # Find all <answer> tags and use the last one
        answer_matches = re.findall(r"<answer>(.*?)</answer>", judgment, re.IGNORECASE | re.DOTALL)
        if answer_matches:
            answer_text = answer_matches[-1].strip()
            try:
                solution_index = int(answer_text)
                return solutions[solution_index - 1]
            except (ValueError, IndexError):
                return ""
        return ""

    async def run(self, problem: str, n_solutions: int = 2) -> str:
        """Generate multiple solutions and judge them."""
        solutions = []
        for _ in range(n_solutions):
            sol = await self.solve(problem)
            solutions.append(sol)

        judgment = await self.judge(problem, solutions)
        selected_solution = self.parse_selected_solution(judgment, solutions)
        return selected_solution


# Use it
agent = TrainableAgent(api_key=openai_api_key, model=MODEL)
with session() as sess:
    result = await agent.run(train_dataset[0]["question"])

print(result)

### Why This Makes It Trainable: Automatic LLM Call Tracking

The `session()` primitive enables training by capturing every LLM interaction inside its scope. You can access the recorded traces directly via `sess.llm_calls`.

Each trace contains:
- `model`: The model that served the call
- `input`: Prompt messages sent to the model
- `output`: Model's response
- `tokens`: Exact token IDs, so training uses the tokens the model actually produced instead of re-tokenizing the text (which can yield different IDs)
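If you want to inspect or export traces in bulk, a small helper can flatten each one into a plain record. This is a sketch: `trace_to_record` is a hypothetical helper, and it assumes only the field shapes used in the inspection cell below (`trace.model`, an OpenAI-style `trace.input["messages"]` list, and a chat-completion dict in `trace.output`). It deliberately skips `tokens`, whose exact layout is not shown in this tutorial.

```python
from types import SimpleNamespace


def trace_to_record(trace) -> dict:
    """Flatten one trace into a model/prompt/response record (assumed field shapes)."""
    return {
        "model": trace.model,
        "prompt": [(m["role"], m["content"]) for m in trace.input["messages"]],
        "response": trace.output["choices"][0]["message"]["content"],
    }


# Example with a hand-built stand-in for a real trace
fake = SimpleNamespace(
    model="gpt-4o-mini",
    input={"messages": [{"role": "user", "content": "2 + 2?"}]},
    output={"choices": [{"message": {"content": "<answer>4</answer>"}}]},
)
record = trace_to_record(fake)
```

With a real session you would write `records = [trace_to_record(t) for t in sess.llm_calls]`.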

In [None]:
# Access traces directly from the session
traces = sess.llm_calls

print(f"✅ Retrieved {len(traces)} trace(s)\n")

# Inspect the first trace
trace = traces[0]

print("=" * 70)
print("TRACE DETAILS")
print("=" * 70)
print(f"Model: {trace.model}")
print("\nInput Messages:")
for msg in trace.input["messages"]:
    print(f"  [{msg['role']}]: {msg['content']}")
print("\nOutput:")
print(f"  {trace.output['choices'][0]['message']['content']}")
print("=" * 70)

print("\n💡 This trace contains everything you need for training!")

## Step 4: Add Rewards and Train

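Define a reward function that scores the agent's final answer, wrap the agent call and the reward in a rollout function that returns that reward, then pass the rollout function to the trainer: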

In [None]:
train_dataset[0]

In [None]:
def extract_solution(solution_str):
    # Look for answer pattern in the entire string
    answer_pattern = r"<answer>(.*?)</answer>"
    match = re.finditer(answer_pattern, solution_str)
    matches = list(match)
    if matches:
        final_answer = matches[-1].group(1).strip()
    else:
        final_answer = None
    return final_answer


def validate_equation(equation_str, available_numbers):
    """Validate that equation only uses available numbers and each number once."""
    try:
        # Extract all numbers from the equation
        numbers_in_eq = [int(n) for n in re.findall(r"\d+", equation_str)]

        # Check if all numbers in equation are available
        available_numbers = sorted(available_numbers)
        numbers_in_eq = sorted(numbers_in_eq)

        # Each number should be used exactly once
        return numbers_in_eq == available_numbers
    except Exception:
        return False


def evaluate_equation(equation_str):
    """Safely evaluate the arithmetic equation using eval() with precautions."""
    try:
        # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation_str):
            raise ValueError("Invalid characters in equation.")

        # Evaluate the equation with restricted globals and locals
        result = eval(equation_str, {"__builtins__": None}, {})
        return result
    except Exception:
        return None


def reward_fn(solution_str, numbers, target):
    """The scoring function for countdown task.

    Args:
        solution_str: the solution text
        numbers: list of numbers
        target: target number

    Returns:
        float: 1.0 if correct, 0.0 if incorrect
    """
    equation = extract_solution(solution_str=solution_str)

    if equation is None:
        return 0.0

    # Validate equation uses correct numbers
    if not validate_equation(equation, numbers):
        return 0.0

    # Evaluate equation
    try:
        result = evaluate_equation(equation)

        if result is None:
            return 0.0

        if abs(result - target) < 1e-5:  # Account for floating point precision
            return 1.0
        else:
            return 0.0
    except Exception:
        return 0.0


async def rollout_v1(question: str, ground_truth: str, nums: list, target: float, model="Qwen/Qwen3-4B-Instruct-2507", **kwargs) -> float:
    # The trainer calls this rollout function for each task; it runs the agent and returns a scalar reward
    agent = TrainableAgent(api_key=openai_api_key, model=model)
    # agent = TrainableAgent(api_key=openai_api_key, model="gpt-4o-mini")
    response = await agent.run(question)
    print(response)
    reward = reward_fn(response, nums, target)
    return reward


await rollout_v1(**train_dataset[0], model="gpt-4o-mini")

In [None]:
# Training
from rllm.trainer import AgentTrainer
from hydra import initialize_config_dir, compose

with initialize_config_dir(config_dir="/workspace/rllm/examples/sdk", version_base=None):
    config = compose(config_name="tutorial_config")

trainer = AgentTrainer(
    agent_run_func=rollout_v1,  # or use rollout_v2 for step-level rewards
    config=config,
    train_dataset=train_dataset,
    val_dataset=test_dataset,
)

In [None]:
trainer.train()

## Bonus: Using @trajectory Decorator for Step-Level Control

The `@trajectory` decorator is **equivalent to `with session()`** - both track LLM calls using contextvar.

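**Step-level access:** both provide `.steps` for fine-grained control:
- `with session() as sess:` → `sess.steps`
- `@trajectory(name="...")` → the decorated function returns a `TrajectoryView` whose `.steps` holds its LLM calls and whose `.result` holds the function's original return value

**Key difference:** `@trajectory` changes what the decorated function returns. Callers receive a `TrajectoryView` instead of the raw value and read `.result` to get it (as `run()` does in the next cell).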

**When to use:**
- `with session()`: Simple episode tracking
- `@trajectory`: Multi-step agents where you want explicit step-level rewards (e.g., reward solver differently from judge)
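To see why per-step rewards matter for a Solver-Judge agent, here is a toy comparison in plain Python. `ToyStep` and `ToyTrajectory` are hypothetical stand-ins for rLLM's step and `TrajectoryView` objects, used only to show the idea: with a single episode reward every step inherits the final outcome, so a wrong solver attempt is rewarded just because the judge picked a different, correct one; with step-level rewards each attempt is scored on its own.

```python
from dataclasses import dataclass, field


@dataclass
class ToyStep:
    """Hypothetical stand-in for one recorded LLM call."""
    name: str
    reward: float = 0.0


@dataclass
class ToyTrajectory:
    """Hypothetical stand-in for a TrajectoryView: a named group of steps."""
    name: str
    steps: list[ToyStep] = field(default_factory=list)


def episode_level(trajs: list[ToyTrajectory], final_reward: float) -> None:
    # Every step inherits the single episode reward
    for traj in trajs:
        for step in traj.steps:
            step.reward = final_reward


def step_level(trajs: list[ToyTrajectory], per_traj_reward: dict[str, float]) -> None:
    # Each trajectory (solver attempt or judge) gets its own reward
    for traj in trajs:
        for step in traj.steps:
            step.reward = per_traj_reward[traj.name]


trajs = [
    ToyTrajectory("solver_1", [ToyStep("solve")]),
    ToyTrajectory("solver_2", [ToyStep("solve")]),
    ToyTrajectory("judge", [ToyStep("judge")]),
]
# Suppose solver_1 was wrong, solver_2 was right, and the judge picked solver_2
step_level(trajs, {"solver_1": 0.0, "solver_2": 1.0, "judge": 1.0})
```

After `step_level`, the step rewards are `[0.0, 1.0, 1.0]`; calling `episode_level(trajs, 1.0)` instead would give every step `1.0`, including the wrong solver attempt.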

In [None]:
from rllm.sdk import trajectory
from rllm.sdk.protocol import TrajectoryView


class TrainableAgentV2:
    """A simple math solving agent - TRAINABLE VERSION V2"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        # Replace standard OpenAI client with SDK client
        # self.client = AsyncOpenAI(api_key=api_key)
        self.client = get_chat_client_async(api_key=api_key, base_url=proxy_url)
        self.model = model

    @trajectory(name="solver")
    async def solve(self, problem: str) -> str:
        """Solve a math problem."""
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": f"{problem}. Output the final answer within <answer>...</answer>"}],
            max_tokens=1000,
        )
        return response.choices[0].message.content

    @trajectory(name="judge")
    async def judge(self, problem, solutions: list[str]) -> str:
        """Judge multiple solutions to a problem."""
        formatted_solutions = "\n".join([f"Solution {i + 1}:\n{sol}\n" for i, sol in enumerate(solutions)])
        prompt = judge_prompt.format(problem=problem, solutions=formatted_solutions)

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
        )
        return self.parse_solver_answer(response.choices[0].message.content)

    def parse_solver_answer(self, solution):
        # Find all <answer> tags and return the last one
        answer_matches = re.findall(r"<answer>(.*?)</answer>", solution, re.IGNORECASE | re.DOTALL)
        if answer_matches:
            return "<answer>" + answer_matches[-1].strip() + "</answer>"
        return "No solution found"

    def parse_selected_solution(self, judgment, solutions):
        # Find all <answer> tags and use the last one
        answer_matches = re.findall(r"<answer>(.*?)</answer>", judgment, re.IGNORECASE | re.DOTALL)
        if answer_matches:
            answer_text = answer_matches[-1].strip()
            try:
                solution_index = int(answer_text)
                return solutions[solution_index - 1]
            except (ValueError, IndexError):
                return ""
        return ""

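    async def run(self, problem: str, nums: list, target: float, n_solutions: int = 2) -> list[TrajectoryView]:
        """Generate multiple solutions, judge them, and attach step-level rewards."""
        # Each @trajectory-decorated call returns a TrajectoryView; the method's
        # own return value is available as `.result`
        solution_trajs = []
        for _ in range(n_solutions):
            solution_trajs.append(await self.solve(problem))

        # The judge needs the solution texts, not the TrajectoryView objects
        solution_texts = [traj.result for traj in solution_trajs]
        judgment = await self.judge(problem, solution_texts)
        selected_solution = self.parse_selected_solution(judgment.result, solution_texts)

        # Reward each solver trajectory on its own answer
        for traj in solution_trajs:
            traj.reward = reward_fn(traj.result, nums, target)
            traj.steps[0].reward = traj.reward

        # Reward the judge on whether the solution it selected is correct
        judgment.reward = reward_fn(selected_solution, nums, target)
        judgment.steps[0].reward = judgment.reward

        return solution_trajs + [judgment]


# Use it
async def rollout_v2(question: str, nums: list, target: float, model="Qwen/Qwen3-4B-Instruct-2507", **kwargs) -> list[TrajectoryView]:
    # Returns one TrajectoryView per solver attempt plus one for the judge, each with its own reward
    agent = TrainableAgentV2(api_key=openai_api_key, model=model)
    return await agent.run(question, nums=nums, target=target)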

## How Does Training Work?

Here's what happens under the hood:

1. **Trace Collection:** The proxy captures all LLM calls (inputs, outputs, token IDs, latency)
2. **Reward Assignment:** Your rollout function scores each episode or step (e.g., correct answer = 1.0, wrong = 0.0)
3. **Training Loop:** The trainer pairs the recorded traces with their rewards
4. **Learning:** An RL update adjusts the model's weights so that high-reward outputs become more likely
5. **Improvement:** Over many iterations, the model learns the behaviors that earn rewards

This is reinforcement learning: try different approaches, get feedback, learn what works.
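As a rough illustration of the "Learning" step, one common recipe (used, for example, by GRPO-style methods) compares each rollout's reward with the other rollouts for the same problem: above-average rollouts get a positive advantage and are reinforced, below-average ones get a negative advantage. This is a generic sketch, not necessarily the algorithm configured in `tutorial_config`; `group_advantages` is a hypothetical helper.

```python
from statistics import mean, pstdev


def group_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Normalize the rewards of several rollouts for the same prompt into advantages."""
    baseline = mean(rewards)
    scale = pstdev(rewards) + eps  # eps avoids dividing by zero when all rewards match
    return [(r - baseline) / scale for r in rewards]
```

For rewards `[1.0, 0.0, 0.0, 1.0]` the advantages are roughly `[1, -1, -1, 1]`. If every rollout gets the same reward, all advantages are zero, so that group provides no learning signal.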

## Design Details (For The Curious)

**Why a proxy?**  
Transparent LLM call interception without modifying agent code. Works with any OpenAI-compatible API.

**How does session tracking work?**  
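Uses Python's **contextvars** for automatic context propagation. `with session()` or `@trajectory` creates a context that automatically groups all LLM calls inside it. Each thread and asyncio task sees its own context (new tasks start from a copy of their creator's), so concurrent rollouts do not mix their calls and no manual tracking is needed.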
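A toy model of this mechanism, using only the standard library (`toy_session` and `fake_llm_call` are hypothetical names, not rLLM's implementation): a `ContextVar` holds the list of calls for the active session, and any call made while it is set appends to that list, including calls inside tasks spawned with `asyncio.gather`, because each new task starts from a copy of the current context.

```python
import asyncio
import contextvars
from contextlib import contextmanager

# Holds the call list of the innermost active session, or None outside any session
_current_calls = contextvars.ContextVar("current_calls", default=None)


@contextmanager
def toy_session():
    """Collect every call recorded while this context is active."""
    calls = []
    token = _current_calls.set(calls)
    try:
        yield calls
    finally:
        _current_calls.reset(token)  # restore whatever was active before


async def fake_llm_call(prompt: str) -> str:
    """Stand-in for a client call; records itself in the active session, if any."""
    calls = _current_calls.get()
    if calls is not None:
        calls.append(prompt)
    await asyncio.sleep(0)
    return f"response to {prompt}"


async def main():
    with toy_session() as outer:
        await fake_llm_call("solve")
        # gather wraps each call in a task that copies the current context
        await asyncio.gather(fake_llm_call("a"), fake_llm_call("b"))
    await fake_llm_call("untracked")  # outside any session: not recorded
    return outer
```

Running `asyncio.run(main())` returns the three calls made inside the session (`"solve"`, `"a"`, `"b"`) and omits `"untracked"`.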

**Session vs Trajectory:**  
Both use contextvar under the hood:
- `with session()`: Returns session object with `._uid` for retrieval
- `@trajectory`: Returns `TrajectoryView` with `.steps` for fine-grained control

**Why SQLite storage?**  
Offline training with no live service dependencies. Query and analyze traces anytime.
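Because the traces live in an ordinary SQLite file, standard tools can open it. The sketch below lists every table and its row count without assuming anything about the proxy's schema; `table_row_counts` is a hypothetical helper, and the demo database it builds is a throwaway stand-in (in this tutorial you would pass `DB_PATH` instead).

```python
import sqlite3
import tempfile
from contextlib import closing
from pathlib import Path


def table_row_counts(db_path: str) -> dict[str, int]:
    """Return {table_name: row_count} for every table in a SQLite file."""
    # mode=ro opens read-only and fails instead of silently creating a missing file
    with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as conn:
        names = [row[0] for row in conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name")]
        counts = {}
        for name in names:
            quoted = '"' + name.replace('"', '""') + '"'  # quote the identifier safely
            counts[name] = conn.execute(f"SELECT COUNT(*) FROM {quoted}").fetchone()[0]
        return counts


# Example on a throwaway database with a made-up table
demo_path = str(Path(tempfile.mkdtemp()) / "demo.db")
with closing(sqlite3.connect(demo_path)) as conn:
    conn.execute("CREATE TABLE calls (id INTEGER PRIMARY KEY, prompt TEXT)")
    conn.executemany("INSERT INTO calls (prompt) VALUES (?)", [("a",), ("b",)])
    conn.commit()

demo_counts = table_row_counts(demo_path)
```

Here `demo_counts` is `{"calls": 2}`; pointing the same function at `DB_PATH` after a session shows which tables the proxy wrote and how many rows each holds.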

In [None]:
# Cleanup
proxy_manager.shutdown_proxy()
print("✓ Proxy shutdown complete")