# Synth GEPA Demo - OOLONG RLM (MIT)

This notebook demonstrates running GEPA prompt optimization through a **Synth Local API** task app that uses an **RLM (Recursive Language Model)** from the `rlm` library on the OOLONG dataset.

We build a local task app that calls `rlm.RLM` for each rollout, expose it via a tunnel (or local URL), and then run a GEPA job with Synth.


In [None]:
# Step 0: Install dependencies (Colab only)
# Outside Colab this cell only prints a message. Inside Colab it installs the
# packages once, writes a marker file, and then kills the kernel process.
import sys

# Colab is detected by the google.colab module being present in sys.modules.
if "google.colab" in sys.modules:
    import os

    # Marker file written after a successful install so that re-running this
    # cell (for example after the restart below) skips the installation.
    _INSTALLED_MARKER = "/content/.synth_deps_v2"

    if os.path.exists(_INSTALLED_MARKER):
        print("Dependencies ready.")
    else:
        print("Installing dependencies...")
        # IPython magic: notebook syntax, not plain Python.
        %pip install -q git+https://github.com/alexzhang13/rlm.git datasets synth-ai

        with open(_INSTALLED_MARKER, 'w') as f:
            f.write("ok")

        print("Done! Restarting runtime...")
        # Signal 9 terminates the current process; presumably Colab then starts
        # a fresh runtime that sees the newly installed packages - confirm.
        os.kill(os.getpid(), 9)
else:
    print("Not in Colab - assuming dependencies installed")

## Step 1: Setup

Configure imports and API keys in one place.


In [None]:
# Step 1: Setup - imports, config, and API keys
import json
import os
import re
import textwrap
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional

from datasets import load_dataset

from synth_ai.sdk.api.train.prompt_learning import PromptLearningJob
from synth_ai.sdk.localapi import LocalAPIConfig, create_local_api
from synth_ai.sdk.localapi.auth import ensure_localapi_auth
from synth_ai.sdk.localapi.helpers import extract_api_key
from synth_ai.sdk.task import run_server_background
from synth_ai.sdk.task.contracts import RolloutMetrics, RolloutRequest, RolloutResponse, TaskInfo
from synth_ai.sdk.tunnels import TunneledLocalAPI, TunnelBackend, kill_port, wait_for_health_check

from rlm import RLM
from rlm.utils.prompts import RLM_SYSTEM_PROMPT, USER_PROMPT, USER_PROMPT_WITH_ROOT

# Work around rlm QueryMetadata typing bug under Python 3.11
from rlm.core import rlm as rlm_core
from rlm.core import types as rlm_types


class PatchedQueryMetadata:
    """Length summary of an RLM prompt, installed in place of rlm's QueryMetadata.

    Attributes set by the constructor:
        context_type: ``"str"``, ``"dict"`` or ``"list"`` depending on the prompt.
        context_lengths: one ``len(...)`` per chunk of the prompt.
        context_total_length: sum of ``context_lengths``.
    """

    def __init__(self, prompt):
        """Measure ``prompt``.

        Args:
            prompt: a string, a dict (its values are measured), or a list. For a
                list whose first item is a dict containing ``"content"``, the
                ``"content"`` value of every item is measured; otherwise each
                item is measured directly.

        Raises:
            ValueError: if ``prompt`` is not a str, dict or list.
        """
        if isinstance(prompt, str):
            kind = "str"
            lengths = [len(prompt)]
        elif isinstance(prompt, dict):
            kind = "dict"
            lengths = [len(value) for value in prompt.values()]
        elif isinstance(prompt, list):
            kind = "list"
            # Only the first item decides whether this is a chat-style list.
            chat_style = (
                bool(prompt) and isinstance(prompt[0], dict) and "content" in prompt[0]
            )
            if chat_style:
                lengths = [len(item["content"]) for item in prompt]
            else:
                lengths = [len(item) for item in prompt]
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        self.context_type = kind
        self.context_lengths = lengths
        self.context_total_length = sum(lengths)


# Install the patched class under the QueryMetadata name in both imported rlm
# modules (rlm.core.types and rlm.core.rlm) so either lookup resolves to it.
rlm_types.QueryMetadata = PatchedQueryMetadata
rlm_core.QueryMetadata = PatchedQueryMetadata


def patched_build_rlm_system_prompt(system_prompt, query_metadata=None, **_kwargs):
    """Return the two opening messages for an RLM conversation.

    Args:
        system_prompt: text placed in the system message.
        query_metadata: accepted for signature compatibility; not used.
        **_kwargs: accepted and ignored.

    Returns:
        A list with the system message followed by an assistant message whose
        content is the literal ``{context_metadata}`` placeholder.
    """
    system_message = {"role": "system", "content": system_prompt}
    # The placeholder is emitted verbatim; nothing is substituted here.
    metadata_message = {"role": "assistant", "content": "{context_metadata}"}
    return [system_message, metadata_message]


from rlm.utils import prompts as rlm_prompts

# Replace the prompt builder in both rlm.utils.prompts and rlm.core.rlm; the
# second assignment covers the name as it is bound inside the core module.
rlm_prompts.build_rlm_system_prompt = patched_build_rlm_system_prompt
rlm_core.build_rlm_system_prompt = patched_build_rlm_system_prompt


from synth_ai.core.urls import BACKEND_URL_BASE
from synth_ai.sdk.auth import get_or_mint_synth_api_key

# Backend base URL; used for the key setup below and for tunnel provisioning.
SYNTH_API_BASE = BACKEND_URL_BASE
# presumably returns an existing Synth API key or creates one - confirm in synth_ai.sdk.auth
SYNTH_API_KEY = get_or_mint_synth_api_key(backend_url=SYNTH_API_BASE)
# Key that is later handed to wait_for_health_check when probing the local app.
ENVIRONMENT_API_KEY = ensure_localapi_auth(
    backend_base=SYNTH_API_BASE,
    synth_api_key=SYNTH_API_KEY,
)

# Both settings can be overridden through environment variables of the same name.
LOCAL_API_PORT = int(os.getenv("LOCAL_API_PORT", "8115"))
# Any of "1", "true", "yes", "y" (case-insensitive) enables the tunnel.
USE_TUNNEL = os.getenv("USE_TUNNEL", "true").lower() in {"1", "true", "yes", "y"}

# Export the key to the process environment as well as keeping it in a variable.
os.environ["SYNTH_API_KEY"] = SYNTH_API_KEY

print("Config loaded")

# System prompt used by the task app when the rendered prompt sections contain
# no system message; also the first half of the composed baseline system prompt.
RLM_BASE_SYSTEM_PROMPT = (
    "You are a recursive language model. Use the REPL with the context variable to reason. "
    "Call llm_query or llm_query_batched as needed. When finished, answer with FINAL."
)

## Step 2: Dataset Loader (OOLONG)

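We load `oolongbench/oolong-real` (config: `dnd`) through a small wrapper that caches each split after its first load; this cell loads the `validation` and `test` splits up front.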


In [None]:
# Step 2: OOLONG dataset wrapper
@dataclass
class OolongSample:
    """One OOLONG example as produced by ``OolongDataset.sample``."""

    index: int  # row position inside the split, after wrap-around
    split: str  # name of the split the row was read from
    query: str  # question text, "" when the row has none
    context: str  # context text, "" when the row has none
    answer: str  # gold answer rendered as a string, "" when missing


class OolongDataset:
    """Per-split cache around a Hugging Face dataset.

    Each split is fetched with ``load_dataset`` the first time it is requested
    and kept in memory afterwards.
    """

    def __init__(self, hf_dataset: str = "oolongbench/oolong-real", hf_config: str = "dnd"):
        """Remember which dataset/config to load; nothing is fetched here."""
        self.hf_dataset = hf_dataset
        self.hf_config = hf_config
        self._cache: Dict[str, Any] = {}

    def _load_split(self, split: str):
        """Return the dataset for ``split``, loading it on first use."""
        if split not in self._cache:
            self._cache[split] = load_dataset(self.hf_dataset, self.hf_config, split=split)
        return self._cache[split]

    def ensure_ready(self, splits: Iterable[str]) -> None:
        """Load every split in ``splits`` so later calls hit the cache."""
        for split in splits:
            self._load_split(split)

    def size(self, split: str) -> int:
        """Return the number of rows in ``split``."""
        return len(self._load_split(split))

    def sample(self, split: str, index: int) -> OolongSample:
        """Return the row at ``index`` of ``split`` as an ``OolongSample``.

        ``index`` wraps around the split length, so any integer selects a row.
        Query and context fall back through alternative column names and
        default to ``""``.

        Raises:
            ZeroDivisionError: if the split is empty.
        """
        ds = self._load_split(split)
        idx = index % len(ds)
        row = ds[idx]
        query = row.get("query") or row.get("question") or ""
        context = row.get("context_window_text") or row.get("context") or row.get("text") or ""
        # Only a missing answer becomes "": a falsy gold value such as 0 must
        # survive as "0", otherwise scoring compares against an empty string.
        answer = row.get("answer")
        if answer is None:
            answer = ""
        return OolongSample(
            index=idx,
            split=split,
            query=str(query),
            context=str(context),
            answer=str(answer),
        )


# Shared dataset instance used by the task app defined later in the notebook.
oolong = OolongDataset()
# Load both splits now so the first rollout does not pay the loading cost.
oolong.ensure_ready(["validation", "test"])
print("Dataset ready:", oolong.size("validation"), oolong.size("test"))

## Step 3: Prompt Template Rendering

GEPA sends candidate prompts to the task app via `request.policy.config.prompt_template`.
We render those sections into messages, and use them to drive the RLM.


In [None]:
# Step 3: Prompt template helpers
def _normalize_prompt_template(policy_config: Dict[str, Any]) -> Dict[str, Any]:
    template = policy_config.get("prompt_template") or {}
    if not isinstance(template, dict):
        template = {}
    return template


def _get_prompt_sections(policy_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Locate the prompt sections in ``policy_config`` and return them ordered.

    The first truthy value among ``prompt_template["sections"]``,
    ``prompt_template["prompt_sections"]`` and
    ``policy_config["prompt_sections"]`` is used. A value that is not a list
    yields ``[]``. Sections are sorted by their ``"order"`` key (default 0).
    """
    template = _normalize_prompt_template(policy_config)
    candidates = (
        template.get("sections"),
        template.get("prompt_sections"),
        policy_config.get("prompt_sections"),
    )
    found = next((value for value in candidates if value), [])
    if not isinstance(found, list):
        return []

    def _order_of(section):
        return section.get("order", 0)

    return sorted(found, key=_order_of)


def _fill_placeholders(pattern: str, placeholders: Dict[str, str]) -> str:
    """Substitute ``placeholders`` into ``pattern`` without failing on stray braces.

    A pattern that is a valid format string is rendered with ``str.format``.
    If formatting fails (literal braces such as a JSON example, or an unknown
    placeholder), only exact ``{name}`` tokens whose name is a key of
    ``placeholders`` are replaced, in a single pass; everything else is left
    untouched.
    """
    try:
        return pattern.format(**placeholders)
    except (KeyError, IndexError, ValueError, AttributeError):
        # Single regex pass so text inserted for one placeholder is never
        # re-scanned for another placeholder.
        def _substitute(match):
            name = match.group(1)
            if name in placeholders:
                return str(placeholders[name])
            return match.group(0)

        return re.sub(r"\{(\w+)\}", _substitute, pattern)


def render_prompt_sections(
    sections: List[Dict[str, Any]], placeholders: Dict[str, str]
) -> List[Dict[str, str]]:
    """Turn prompt sections into chat messages with placeholders filled in.

    Args:
        sections: dicts with an optional ``"role"`` (default ``"user"``) and the
            text under ``"content"`` or, failing that, ``"pattern"``.
        placeholders: values substituted for ``{name}`` tokens in the text.

    Returns:
        One ``{"role", "content"}`` dict per section, in input order.

    Candidate prompts can contain literal braces or placeholders that are not
    supplied; such text is kept as-is instead of raising, so one malformed
    candidate does not abort the rollout.
    """
    rendered: List[Dict[str, str]] = []
    for section in sections:
        role = section.get("role", "user")
        pattern = section.get("content") or section.get("pattern") or ""
        rendered.append({"role": role, "content": _fill_placeholders(pattern, placeholders)})
    return rendered


def split_system_and_user(messages: List[Dict[str, str]]) -> tuple[str, str]:
    """Collapse messages into one system string and one non-system string.

    Messages with role ``"system"`` go to the first string; every other message
    (including those without a role) goes to the second. Within each group the
    contents are joined with a blank line and the result is stripped.
    """
    system_chunks: List[str] = []
    other_chunks: List[str] = []
    for message in messages:
        bucket = system_chunks if message.get("role") == "system" else other_chunks
        bucket.append(message["content"])
    return "\n\n".join(system_chunks).strip(), "\n\n".join(other_chunks).strip()

## Step 4: Local API Factory (RLM Task App)

We implement a Local API that calls `rlm.RLM` for each rollout.
The task app reads the prompt template from `request.policy.config`.


In [None]:
# Step 4: Local API for OOLONG RLM
# Identifier and display name passed to LocalAPIConfig and echoed in TaskInfo.
APP_ID = "oolong_rlm"
APP_NAME = "OOLONG RLM (Recursive Language Model) QA"


def normalize_answer(text: str) -> str:
    """Canonicalize an answer for exact-match comparison.

    ``None`` becomes ``""``; non-strings are converted with ``str``. The text is
    stripped, every character that is neither alphanumeric nor whitespace is
    dropped, the remaining characters are lower-cased one by one, and the
    result is stripped again. Interior whitespace is kept as-is.
    """
    if text is None:
        return ""
    raw = text if isinstance(text, str) else str(text)
    kept = [ch for ch in raw.strip() if ch.isalnum() or ch.isspace()]
    # Lower-case per character (not the whole string) to match the original
    # character-by-character behavior.
    return "".join(map(str.lower, kept)).strip()


def create_oolong_rlm_local_api():
    """Build the Local API app that serves OOLONG rollouts through ``rlm.RLM``.

    Returns:
        The object produced by ``create_local_api`` for a ``LocalAPIConfig``
        wired to the rollout, taskset-description and task-instance callbacks
        defined below.
    """

    async def run_rollout(request: RolloutRequest, fastapi_request) -> RolloutResponse:
        """Run one RLM completion for the requested sample and score it.

        The sample is chosen by the request's seed and the ``split`` entry of
        its environment config. The reward is 1.0 when the normalized
        prediction equals the normalized gold answer, otherwise 0.0.

        Raises:
            ValueError: if the policy config has no inference URL, or no API
                key can be resolved.
        """
        policy_config = request.policy.config or {}
        # The environment settings and the seed travel on the request's env
        # spec, alongside the policy spec read above.
        # NOTE(review): assumes RolloutRequest.env exposes `config` and `seed`
        # like the synth_ai task contracts - confirm against the SDK version.
        env_config = request.env.config or {}
        seed = request.env.seed or 0
        split = env_config.get("split", "validation")

        sample = oolong.sample(split=split, index=seed)
        placeholders = {
            "query": sample.query,
            "context": sample.context,
            # Rendered back to the same token so it is still present in the
            # messages handed to the RLM.
            "context_metadata": "{context_metadata}",
        }

        sections = _get_prompt_sections(policy_config)
        if not sections:
            # Fallback to the baseline prompt. These module-level names are
            # defined in the GEPA cell, which must have run before a rollout
            # without prompt sections arrives.
            sections = [
                {"role": "system", "content": COMPOSED_SYSTEM_PROMPT, "order": 0},
                {"role": "assistant", "content": RLM_CONTEXT_METADATA_PATTERN, "order": 1},
                {"role": "user", "content": RLM_FIRST_USER_PROMPT, "order": 2},
                {"role": "user", "content": BASELINE_USER_PROMPT, "order": 3},
            ]
        rendered = render_prompt_sections(sections, placeholders)
        # Unrendered patterns are reported back in the reward details.
        messages_for_validation = [
            {
                "role": section.get("role", "user"),
                "content": section.get("content") or section.get("pattern") or "",
            }
            for section in sections
        ]

        # Only the system part is needed here; the full rendered message list
        # is what gets sent to the RLM below.
        system_prompt, _ = split_system_and_user(rendered)
        custom_system_prompt = system_prompt or RLM_BASE_SYSTEM_PROMPT

        inference_url = (
            policy_config.get("inference_url")
            or policy_config.get("api_base")
            or policy_config.get("base_url")
        )
        if not inference_url:
            raise ValueError("Missing inference_url in policy config")

        api_key = policy_config.get("api_key") or SYNTH_API_KEY
        if not api_key:
            raise ValueError("Missing policy api_key for inference proxy")

        model_name = policy_config.get("model", "gpt-4o-mini")
        max_iterations = int(env_config.get("max_iterations", 2))
        max_depth = int(env_config.get("max_depth", 0))

        recursive_lm = RLM(
            backend="openai",
            backend_kwargs={
                "model_name": model_name,
                "api_key": api_key,
                "base_url": inference_url,
            },
            environment="local",
            environment_kwargs={},
            custom_system_prompt=custom_system_prompt,
            max_iterations=max_iterations,
            max_depth=max_depth,
            verbose=False,
        )

        # NOTE(review): this call is synchronous inside a coroutine, so the
        # event loop is occupied until the completion returns.
        completion = recursive_lm.completion(rendered)

        # The completion may be a plain string or an object with `.response`.
        if isinstance(completion, str):
            predicted = completion
        else:
            predicted = completion.response or ""
        gold = sample.answer or ""

        reward = 1.0 if normalize_answer(predicted) == normalize_answer(gold) else 0.0

        return RolloutResponse(
            run_id=request.run_id,
            reward_info=RolloutMetrics(
                outcome_reward=reward,
                details={"messages": messages_for_validation, "predicted": predicted, "gold": gold},
            ),
            trace=None,
            trace_correlation_id=policy_config.get("trace_correlation_id"),
        )

    def provide_taskset_description():
        """Describe the available splits and their sizes."""
        return {
            "splits": ["validation", "test"],
            "sizes": {"validation": oolong.size("validation"), "test": oolong.size("test")},
        }

    def provide_task_instances(seeds):
        """Yield one ``TaskInfo`` per seed, always drawn from the validation split."""
        for seed in seeds:
            sample = oolong.sample(split="validation", index=seed)
            yield TaskInfo(
                task={"id": APP_ID, "name": APP_NAME},
                dataset={"id": APP_ID, "split": sample.split, "index": sample.index},
                inference={"tool": "rlm_repl"},
                limits={"max_turns": 1},
                task_metadata={"query": sample.query},
            )

    return create_local_api(
        LocalAPIConfig(
            app_id=APP_ID,
            name=APP_NAME,
            description="OOLONG RLM local API for prompt optimization.",
            provide_taskset_description=provide_taskset_description,
            provide_task_instances=provide_task_instances,
            rollout=run_rollout,
            cors_origins=["*"],
        )
    )

## Step 5: Start the Local API

This spins up the task app and exposes it (optionally via a tunnel).


In [None]:
# Step 5: Start Local API
# Build the task app, serve it in the background, wait until it answers its
# health check, then pick the URL that the GEPA job will call.
print("Starting local API...")
app = create_oolong_rlm_local_api()

# presumably frees the port from a previous run of this cell - confirm in synth_ai.sdk.tunnels
kill_port(LOCAL_API_PORT)
run_server_background(app, LOCAL_API_PORT)

print(f"Waiting for local API on port {LOCAL_API_PORT}...")
# Top-level await: this cell relies on the notebook's running event loop.
await wait_for_health_check("localhost", LOCAL_API_PORT, ENVIRONMENT_API_KEY, timeout=60.0)
print("Local API ready!")

if USE_TUNNEL:
    print("Provisioning Cloudflare tunnel...")
    tunnel = await TunneledLocalAPI.create(
        local_port=LOCAL_API_PORT,
        backend=TunnelBackend.CloudflareManagedTunnel,
        api_key=SYNTH_API_KEY,
        backend_url=SYNTH_API_BASE,
        progress=True,
    )
    LOCAL_API_URL = tunnel.url
else:
    # Without a tunnel the task app is only reachable from this machine.
    LOCAL_API_URL = f"http://localhost:{LOCAL_API_PORT}"

print("Local API URL:", LOCAL_API_URL)

## Step 6: Run GEPA Prompt Optimization

We configure GEPA to optimize the prompt sections passed to the local RLM task app.


In [None]:
# Step 6: GEPA optimization
# Baseline prompt pieces. They seed the GEPA initial prompt below and are also
# the fallback sections used by the task app's run_rollout when a request
# carries no prompt sections, so this cell must run before such a rollout.
BASELINE_SYSTEM_PROMPT = "Answer questions using the context."
# {query} and {context} are filled in by the task app for each sample.
BASELINE_USER_PROMPT = (
    "Query: {query}\n\nContext:\n{context}\n\nAnswer the query using the context."
)

# Passed through rendering unchanged: the task app maps this placeholder to itself.
RLM_CONTEXT_METADATA_PATTERN = "{context_metadata}"
# First-turn instruction followed by the rlm library's USER_PROMPT text.
RLM_FIRST_USER_PROMPT = (
    "You have not interacted with the REPL environment or seen your prompt / context yet. "
    "Your next action should be to look through and figure out how to answer the prompt, "
    "so don't just provide a final answer yet.\n\n" + USER_PROMPT
)

# RLM instructions first, then the task-specific baseline instruction.
COMPOSED_SYSTEM_PROMPT = RLM_BASE_SYSTEM_PROMPT + " " + BASELINE_SYSTEM_PROMPT

# GEPA job configuration: which task app to call, the prompt to start from,
# the policy model, the search settings, and the environment settings.
config_body = {
    "prompt_learning": {
        "algorithm": "gepa",
        "task_app_url": LOCAL_API_URL,
        "env_name": "oolong",
        # Starting prompt: four ordered messages built from the baseline pieces.
        "initial_prompt": {
            "messages": [
                {"role": "system", "order": 0, "pattern": COMPOSED_SYSTEM_PROMPT},
                {"role": "assistant", "order": 1, "pattern": RLM_CONTEXT_METADATA_PATTERN},
                {"role": "user", "order": 2, "pattern": RLM_FIRST_USER_PROMPT},
                {"role": "user", "order": 3, "pattern": BASELINE_USER_PROMPT},
            ],
            # Placeholder names used by the patterns above.
            "wildcards": {
                "query": "REQUIRED",
                "context": "REQUIRED",
                "context_metadata": "REQUIRED",
            },
        },
        "policy": {
            "model": "gpt-4o-mini",
            "inference_mode": "synth_hosted",
            "provider": "openai",
            "temperature": 0.0,
            "max_completion_tokens": 256,
        },
        "gepa": {
            "env_name": "oolong",
            # Seeds 0-12 for evaluation, 13-14 for validation; the task app
            # uses a seed as the sample index.
            "evaluation": {
                "seeds": list(range(13)),
                "validation_seeds": list(range(13, 15)),
            },
            # Deliberately small values so the demo finishes quickly.
            "rollout": {"budget": 6, "max_concurrent": 3, "minibatch_size": 3},
            "mutation": {"rate": 0.3},
            "population": {"initial_size": 2, "num_generations": 1, "children_per_generation": 1},
            "archive": {"size": 10, "pareto_set_size": 10},
            "token": {"counting_model": "gpt-4"},
        },
        # Keys the task app reads when handling a rollout:
        # split, max_iterations and max_depth.
        "env_config": {
            "split": "validation",
            "max_iterations": 2,
            "max_depth": 0,
        },
    },
}

job = PromptLearningJob.from_dict(
    config_dict=config_body,
)

job_id = job.submit()
print("GEPA job submitted:", job_id)
# Blocks for up to one hour, polling every five seconds.
result = job.poll_until_complete(timeout=3600.0, interval=5.0, progress=True)
print("Final status:", result.status.value)
print("Best score:", result.best_score)

## Step 7: Next Steps

- Inspect the best prompt and rerun a manual rollout.
- Increase rollout budget and population size for stronger results.
