diff --git a/benchmark/tau2/.gitignore b/benchmark/tau2/.gitignore new file mode 100644 index 000000000..2577e5885 --- /dev/null +++ b/benchmark/tau2/.gitignore @@ -0,0 +1,5 @@ +result/ +.env.tau2 +.external/ +.venv-tau2/ +__pycache__/ diff --git a/benchmark/tau2/README.md b/benchmark/tau2/README.md new file mode 100644 index 000000000..7ebdb0807 --- /dev/null +++ b/benchmark/tau2/README.md @@ -0,0 +1,146 @@ +# TAU-2 Benchmark + +This directory contains a small OpenViking-style entry point for TAU-2 memory +evaluation. The first version is intentionally narrow: + +- fresh OpenViking Memory V2 experience-only baseline; +- Memory V2 pre-write recall treatment. + +Trajectory / procedure-view prompts, category rerank, and other harness-only +diagnostics are intentionally left out of this first PR. + +## Layout + +```text +benchmark/tau2/ +├── config/ +│ ├── baseline.yaml +│ ├── official.yaml +│ └── prewrite.yaml +├── scripts/ +│ ├── run_eval.py +│ ├── setup_tau2_repo.sh +│ └── tau2_common.py +└── run_full_eval.sh +``` + +Generated artifacts are written to `benchmark/tau2/result//`. + +## Quick Start + +This benchmark delegates task simulation and scoring to an external TAU-2 +checkout. Point the runner at that checkout and CLI explicitly when they are not +on the default path: + +```bash +export TAU2_REPO=/path/to/tau2-bench +export TAU2_CLI=/path/to/tau2 +``` + +For a local one-command setup, clone and install TAU-2 into ignored benchmark +directories: + +```bash +benchmark/tau2/scripts/setup_tau2_repo.sh +source benchmark/tau2/.env.tau2 +``` + +Plan the default benchmark without running TAU-2: + +```bash +python benchmark/tau2/scripts/run_eval.py --config benchmark/tau2/config/baseline.yaml --plan-only +``` + +Add `--preflight` or `--strict-preflight` when you want the runner to write a +small environment/config check next to the run plan. + +After setup, verify the local TAU-2 link and write a one-cell run plan: + +```bash +benchmark/tau2/run_full_eval.sh \ + --config benchmark/tau2/config/baseline.yaml \ + --strict-preflight \ + --domain retail \ + --strategy-id memory_v2_experience_only \ + --task-id 5 \ + --repeat-count 1 +``` + +Plan a one-cell Memory V2 pre-write smoke: + +```bash +benchmark/tau2/run_full_eval.sh \ + --config benchmark/tau2/config/baseline.yaml \ + --domain retail \ + --strategy-id memory_v2_prewrite \ + --num-tasks 1 \ + --repeat-count 1 +``` + +Run the Memory V2 8-trial matrix (`retail + airline` x 2 strategies x 8 repeats): + +```bash +benchmark/tau2/run_full_eval.sh \ + --config benchmark/tau2/config/baseline.yaml \ + --execute +``` + +For a small E2E smoke, keep both the eval and train slices tiny: + +```bash +benchmark/tau2/run_full_eval.sh \ + --config benchmark/tau2/config/baseline.yaml \ + --domain retail \ + --strategy-id memory_v2_experience_only \ + --num-tasks 1 \ + --train-num-tasks 1 \ + --repeat-count 1 \ + --execute +``` + +When using Doubao through an OpenAI-compatible endpoint, set `OPENAI_API_KEY` +and `OPENAI_API_BASE` for LiteLLM before running upstream TAU-2. + +Start the OpenViking service before executing memory cells, and verify it with +`ov status`. For evidence runs, use a clean OpenViking workspace/config and set +`OPENVIKING_URL` explicitly so local custom memory templates do not pollute the +Memory V2 baseline. 
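As a minimal sketch of the environment an executed memory cell expects (the key and base-URL values below are placeholders; only `OPENVIKING_URL`'s default and `ov status` come from this README and `config/baseline.yaml`):

```bash
# LiteLLM credentials for the OpenAI-compatible Doubao endpoint (placeholder values).
export OPENAI_API_KEY=...
export OPENAI_API_BASE=...

# Point the benchmark at a clean OpenViking deployment and confirm the service is up.
export OPENVIKING_URL=http://localhost:1933
ov status
```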
+ +## Memory Adapter + +`memory_v2_experience_only` and `memory_v2_prewrite` cells run through a small +TAU-2 agent adapter in this directory: + +- train by writing TAU-2 training conversations into OpenViking sessions; +- evaluate by retrieving OpenViking experience memory at the first user turn; +- for pre-write recall, retrieve again before write-like tool calls and + regenerate that step with the matched memories; +- emit artifact metadata to identify the OpenViking account, agent, + corpus, retrieval mode, and simulator policy used by each cell. + +## User Simulator Policy + +The runner default is the official TAU-2 user simulator if +`eval.user_simulator_policy` is omitted. The bundled OpenViking memory benchmark +config sets `confirmation_aware`, because a memory benchmark should not treat +user confirmation as task completion before the backend write has happened. + +`confirmation_aware` applies a small idempotent prompt patch to the configured +TAU-2 checkout before planning or running. The patch appends only the behavioral +confirmation boundary to the TAU-2 user simulator guidelines; metadata such as +the upstream PR link is kept in run artifacts, not in the simulator prompt. +Reference: [sierra-research/tau2-bench#297](https://github.com/sierra-research/tau2-bench/pull/297). + +Use `config/official.yaml` with a clean TAU-2 checkout when you need an +official-user-simulator parity run. If the checkout was already patched, the +artifact records that boundary instead of labeling the run pure official. + +## Evidence Boundary + +Only completed `retail + airline` runs with the same config, same seeds/repeats, +and non-empty artifacts should be read as benchmark evidence. Partial runs, +single-task probes, or missing OpenViking corpus identity are diagnostics. +Executed runs write per-cell JSON under `cell_results/` and a strategy/domain +aggregate under `scoreboard.json`. Memory training artifacts are shared by +domain and strategy under `memory_corpora/`, so repeated eval cells reuse the +same fresh corpus instead of rewriting it. diff --git a/benchmark/tau2/config/baseline.yaml b/benchmark/tau2/config/baseline.yaml new file mode 100644 index 000000000..4c4a5060e --- /dev/null +++ b/benchmark/tau2/config/baseline.yaml @@ -0,0 +1,53 @@ +benchmark: + name: tau2_openviking_baseline + domains: + - retail + - airline + train_split_name: train + eval_split_name: test + repeat_count: 8 + task_max_concurrency: 10 + max_steps: 200 + seed: 300 + agent: llm_agent + user: user_simulator + reasoning_effort: high + +paths: + tau2_repo: ${TAU2_REPO:-data/external_benchmarks/tau2-bench} + tau2_cli: ${TAU2_CLI:-tau2} + output_dir: benchmark/tau2/result + +eval: + # The runner default is official if this field is omitted. The OpenViking + # memory benchmark config opts into a confirmation-aware TAU-2 user simulator + # prompt; run_eval.py applies that small prompt patch idempotently when needed. 
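  # Allowed values: official | confirmation_aware (validated by scripts/tau2_common.py).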
+ user_simulator_policy: confirmation_aware + +model: + agent_llm: ${TAU2_AGENT_LLM:-openai/doubao-seed-2-0-pro-260215} + user_llm: ${TAU2_USER_LLM:-openai/doubao-seed-2-0-pro-260215} + temperature: 0.0 + +openviking: + url: ${OPENVIKING_URL:-http://localhost:1933} + account: ${OPENVIKING_ACCOUNT:-default} + agent_id: ${OPENVIKING_AGENT_ID:-tau2-openviking-agent} + retrieval_top_k: 4 + replay_write_policy: read_only + +strategies: + - id: memory_v2_experience_only + label: OpenViking Memory V2 experience-only + memory_backend: openviking + train_required: true + corpus_id: memory_v2_experience_only + train_memory_mode: experience_only + retrieval_mode: first_user + - id: memory_v2_prewrite + label: OpenViking Memory V2 pre-write recall + memory_backend: openviking + train_required: true + corpus_id: memory_v2_experience_only + train_memory_mode: experience_only + retrieval_mode: first_user_prewrite diff --git a/benchmark/tau2/config/official.yaml b/benchmark/tau2/config/official.yaml new file mode 100644 index 000000000..d10bee872 --- /dev/null +++ b/benchmark/tau2/config/official.yaml @@ -0,0 +1,7 @@ +extends: baseline.yaml + +benchmark: + name: tau2_openviking_official_user_simulator + +eval: + user_simulator_policy: official diff --git a/benchmark/tau2/config/prewrite.yaml b/benchmark/tau2/config/prewrite.yaml new file mode 100644 index 000000000..834963b41 --- /dev/null +++ b/benchmark/tau2/config/prewrite.yaml @@ -0,0 +1,13 @@ +extends: baseline.yaml + +benchmark: + name: tau2_openviking_prewrite + +strategies: + - id: memory_v2_prewrite + label: OpenViking Memory V2 pre-write recall + memory_backend: openviking + train_required: true + corpus_id: memory_v2_experience_only + train_memory_mode: experience_only + retrieval_mode: first_user_prewrite diff --git a/benchmark/tau2/run_full_eval.sh b/benchmark/tau2/run_full_eval.sh new file mode 100755 index 000000000..ca69a7a32 --- /dev/null +++ b/benchmark/tau2/run_full_eval.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-python3}" +CONFIG="$SCRIPT_DIR/config/baseline.yaml" +EXECUTE=false +PREFLIGHT=false +STRICT_PREFLIGHT=false +RUN_ID="" +RUN_EVAL_EXTRA=() + +while [[ $# -gt 0 ]]; do + case "$1" in + --config) + CONFIG="$2" + shift 2 + ;; + --run-id) + RUN_ID="$2" + shift 2 + ;; + --execute) + EXECUTE=true + shift + ;; + --preflight) + PREFLIGHT=true + shift + ;; + --strict-preflight) + STRICT_PREFLIGHT=true + shift + ;; + --domain|--repeat-count|--strategy-id|--task-id|--num-tasks|--train-num-tasks) + RUN_EVAL_EXTRA+=("$1" "$2") + shift 2 + ;; + --help|-h) + cat <<'EOF' +Usage: + benchmark/tau2/run_full_eval.sh [--config PATH] [--run-id ID] [--execute] [--preflight] + +Without --execute the script only writes run_plan artifacts. 
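Cell filters (--domain, --strategy-id, --task-id, --num-tasks, --train-num-tasks,
--repeat-count) are forwarded unchanged to scripts/run_eval.py to shrink the planned
matrix. --strict-preflight writes the same preflight report as --preflight but aborts
before planning when the check fails.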
+EOF + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + exit 1 + ;; + esac +done + +RUN_ARGS=() +if [[ -n "$RUN_ID" ]]; then + RUN_ARGS+=(--run-id "$RUN_ID") +fi + +cd "$REPO_ROOT" +if [[ "$STRICT_PREFLIGHT" == true ]]; then + RUN_EVAL_EXTRA+=(--strict-preflight) +elif [[ "$PREFLIGHT" == true ]]; then + RUN_EVAL_EXTRA+=(--preflight) +fi + +if [[ "$EXECUTE" == true ]]; then + "$PYTHON_BIN" "$SCRIPT_DIR/scripts/run_eval.py" --config "$CONFIG" "${RUN_ARGS[@]}" "${RUN_EVAL_EXTRA[@]}" --execute +else + "$PYTHON_BIN" "$SCRIPT_DIR/scripts/run_eval.py" --config "$CONFIG" "${RUN_ARGS[@]}" "${RUN_EVAL_EXTRA[@]}" --plan-only +fi diff --git a/benchmark/tau2/scripts/run_eval.py b/benchmark/tau2/scripts/run_eval.py new file mode 100755 index 000000000..5458ba61a --- /dev/null +++ b/benchmark/tau2/scripts/run_eval.py @@ -0,0 +1,531 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import importlib.util +import json +import subprocess +import sys +from pathlib import Path +from typing import Any + +from tau2_common import ( + domains, + load_config, + output_dir, + normalize_litellm_env, + run_id, + simulator_policy_report, + split_file, + strategy_ids, + tau2_cli, + tau2_context, + tau2_repo, + user_simulator_policy, + write_json, +) + + +def _reward(sim: dict[str, Any]) -> float: + info = sim.get("reward_info") or {} + value = info.get("reward", sim.get("reward", 0.0)) + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + +def _db_match(sim: dict[str, Any]) -> bool | None: + info = sim.get("reward_info") or {} + db = info.get("db_check") or {} + if isinstance(db, dict): + if "score" in db: + return bool(db["score"]) + if "db_match" in db: + return bool(db["db_match"]) + return sim.get("db_match") + + +def _metrics_from_tau2_results(results_path: Path) -> dict[str, Any]: + data = json.loads(results_path.read_text(encoding="utf-8")) + sims = data.get("simulations") or [] + rewards = [_reward(sim) for sim in sims] + db_values = [_db_match(sim) for sim in sims] + db_known = [value for value in db_values if value is not None] + return { + "simulation_count": len(sims), + "avg_reward": sum(rewards) / len(rewards) if rewards else 0.0, + "db_match_rate": (sum(1 for value in db_known if value) / len(db_known)) if db_known else None, + } + + +def _tau2_command( + config: dict[str, Any], + *, + domain: str, + strategy: dict[str, Any], + configured_run_id: str, + run_label: str, + task_ids: list[str] | None, + num_tasks: int | None, + train_num_tasks: int | None, + seed: int, +) -> list[str] | None: + benchmark = config["benchmark"] + model = config["model"] + + reasoning_effort = benchmark.get("reasoning_effort") + agent_llm_args = '{"temperature":0.0}' + user_llm_args = '{"temperature":0.0}' + if reasoning_effort: + agent_llm_args = f'{{"temperature":0.0,"reasoning_effort":"{reasoning_effort}"}}' + user_llm_args = f'{{"temperature":0.0,"reasoning_effort":"{reasoning_effort}"}}' + + if ( + strategy.get("memory_backend") == "openviking" + and strategy.get("train_memory_mode") == "experience_only" + ): + openviking = config["openviking"] + corpus_id = str(strategy.get("corpus_id") or strategy["id"]) + account = f"{openviking['account']}-{configured_run_id}-{domain}-{corpus_id}" + agent_id = f"{openviking['agent_id']}-{domain}-{corpus_id}" + user = f"tau2-{domain}-{corpus_id}" + search_uri = f"viking://agent/{agent_id}/memories/experiences" + command = [ + sys.executable, + str(Path(__file__).with_name("run_memory_v2_eval.py")), + "--tau2-repo", + 
str(tau2_repo(config)), + "--run-dir", + str(output_dir(config, configured_run_id) / "memory_cells" / run_label), + "--corpus-dir", + str( + output_dir(config, configured_run_id) + / "memory_corpora" + / f"{domain}_{corpus_id}" + ), + "--run-label", + run_label, + "--strategy-id", + strategy["id"], + "--domain", + domain, + "--train-split-name", + str(benchmark.get("train_split_name", "train")), + "--eval-split-name", + str(benchmark.get("eval_split_name", "test")), + "--max-steps", + str(benchmark.get("max_steps", 200)), + "--max-concurrency", + str(benchmark.get("task_max_concurrency", 10)), + "--agent-llm", + str(model["agent_llm"]), + "--user-llm", + str(model["user_llm"]), + "--agent-llm-args", + agent_llm_args, + "--user-llm-args", + user_llm_args, + "--openviking-url", + str(openviking["url"]), + "--openviking-account", + account, + "--openviking-user", + user, + "--openviking-agent-id", + agent_id, + "--search-uri", + search_uri, + "--retrieval-top-k", + str(openviking.get("retrieval_top_k", 4)), + "--retrieval-mode", + str(strategy.get("retrieval_mode", "first_user")), + "--seed", + str(seed), + ] + if task_ids: + for task_id in task_ids: + command.extend(["--task-id", task_id]) + elif num_tasks is not None: + command.extend(["--num-tasks", str(num_tasks)]) + train_num_tasks = train_num_tasks if train_num_tasks is not None else strategy.get("train_num_tasks") + if train_num_tasks is not None: + command.extend(["--train-num-tasks", str(train_num_tasks)]) + return command + + if strategy.get("memory_backend") != "none": + return None + + command = [ + tau2_cli(config), + "run", + "--domain", + domain, + "--agent", + str(benchmark.get("agent", "llm_agent")), + "--user", + str(benchmark.get("user", "user_simulator")), + "--task-split-name", + str(benchmark.get("eval_split_name", "test")), + "--num-trials", + "1", + "--max-steps", + str(benchmark.get("max_steps", 200)), + "--max-concurrency", + str(benchmark.get("task_max_concurrency", 10)), + "--agent-llm", + str(model["agent_llm"]), + "--user-llm", + str(model["user_llm"]), + "--save-to", + run_label, + "--seed", + str(seed), + ] + + command.extend(["--agent-llm-args", agent_llm_args]) + command.extend(["--user-llm-args", user_llm_args]) + + if task_ids: + command.append("--task-ids") + command.extend(task_ids) + elif num_tasks is not None: + command.extend(["--num-tasks", str(num_tasks)]) + + return command + + +def _build_plan( + config: dict[str, Any], + configured_run_id: str, + *, + selected_domains: set[str] | None, + selected_strategy_ids: set[str] | None, + task_ids: list[str] | None, + num_tasks: int | None, + train_num_tasks: int | None, + repeat_count_override: int | None, +) -> dict[str, Any]: + repeat_count = repeat_count_override or int(config["benchmark"].get("repeat_count", 8)) + base_seed = int(config["benchmark"].get("seed", 300)) + policy_report = simulator_policy_report(config) + strategies = config.get("strategies") or [] + if selected_strategy_ids: + unknown = selected_strategy_ids - set(strategy_ids(config)) + if unknown: + raise ValueError(f"unknown strategy ids: {sorted(unknown)}") + strategies = [strategy for strategy in strategies if strategy["id"] in selected_strategy_ids] + cells = [] + plan_domains = domains(config) + if selected_domains: + unknown_domains = selected_domains - set(plan_domains) + if unknown_domains: + raise ValueError(f"unknown domains: {sorted(unknown_domains)}") + plan_domains = [domain for domain in plan_domains if domain in selected_domains] + for domain in plan_domains: + split_path 
= split_file(config, domain) + for strategy in strategies: + for repeat_index in range(repeat_count): + seed = base_seed + repeat_index + run_label = f"{configured_run_id}_{domain}_{strategy['id']}_r{repeat_index + 1}" + command = _tau2_command( + config, + domain=domain, + strategy=strategy, + configured_run_id=configured_run_id, + run_label=run_label, + task_ids=task_ids, + num_tasks=num_tasks, + train_num_tasks=train_num_tasks, + seed=seed, + ) + non_executable_reason = None + if command is None: + non_executable_reason = ( + "This OpenViking memory strategy is planned but not wired to " + "the TAU-2 adapter in this PR." + ) + cells.append( + { + "domain": domain, + "strategy_id": strategy["id"], + "strategy_label": strategy.get("label", strategy["id"]), + "repeat_index": repeat_index + 1, + "seed": seed, + "run_label": run_label, + "train_required": bool(strategy.get("train_required")), + "memory_backend": strategy.get("memory_backend"), + "corpus_id": strategy.get("corpus_id", strategy["id"]), + "retrieval_mode": strategy.get("retrieval_mode"), + "adapter_status": strategy.get("adapter_status", "ready"), + "executable": command is not None, + "user_simulator_policy": user_simulator_policy(config), + "user_simulator_policy_supported": policy_report["supported"], + "split_file": str(split_path), + "command": command, + "non_executable_reason": non_executable_reason, + } + ) + executable_cell_count = sum(1 for cell in cells if cell["executable"]) + return { + "schema_version": "openviking.tau2.run_plan.v0", + "run_id": configured_run_id, + "status": "planned", + "strategy_ids": strategy_ids(config), + "domains": plan_domains, + "tau2": tau2_context(config), + "simulator_policy": policy_report, + "cell_count": len(cells), + "executable_cell_count": executable_cell_count, + "pending_cell_count": len(cells) - executable_cell_count, + "cells": cells, + } + + +def _cell_artifacts(cell: dict[str, Any], repo: Path, out: Path) -> dict[str, str]: + if cell.get("memory_backend") == "openviking": + run_dir = out / "memory_cells" / cell["run_label"] + corpus_id = str(cell.get("corpus_id") or cell["strategy_id"]) + corpus_dir = out / "memory_corpora" / f"{cell['domain']}_{corpus_id}" + return { + "summary": str(run_dir / f"{cell['run_label']}.summary.json"), + "results": str(run_dir / f"{cell['run_label']}.json"), + "retrieval_trace": str(run_dir / f"{cell['run_label']}.retrieval_trace.jsonl"), + "corpus_manifest": str(corpus_dir / "corpus_manifest.json"), + } + return { + "results": str(repo / "data" / "simulations" / f"{cell['run_label']}.json") + } + + +def _cell_metrics(cell: dict[str, Any], artifacts: dict[str, str]) -> dict[str, Any] | None: + if cell.get("memory_backend") == "openviking": + summary_path = Path(artifacts["summary"]) + if not summary_path.is_file(): + return None + summary = json.loads(summary_path.read_text(encoding="utf-8")) + return summary.get("metrics") + + results_path = Path(artifacts["results"]) + if not results_path.is_file(): + return None + return _metrics_from_tau2_results(results_path) + + +def _summarize(rows: list[dict[str, Any]]) -> dict[str, Any]: + def weighted(rows_for_group: list[dict[str, Any]]) -> dict[str, Any]: + metric_rows = [row for row in rows_for_group if row.get("metrics")] + sim_count = sum(int(row["metrics"].get("simulation_count") or 0) for row in metric_rows) + reward_sum = sum( + float(row["metrics"].get("avg_reward") or 0.0) + * int(row["metrics"].get("simulation_count") or 0) + for row in metric_rows + ) + db_weighted_rows = [ + row + for row 
in metric_rows + if row["metrics"].get("db_match_rate") is not None + and int(row["metrics"].get("simulation_count") or 0) > 0 + ] + db_weight = sum(int(row["metrics"].get("simulation_count") or 0) for row in db_weighted_rows) + db_sum = sum( + float(row["metrics"]["db_match_rate"]) + * int(row["metrics"].get("simulation_count") or 0) + for row in db_weighted_rows + ) + return { + "cell_count": len(rows_for_group), + "completed_cell_count": len(metric_rows), + "simulation_count": sim_count, + "avg_reward": reward_sum / sim_count if sim_count else None, + "db_match_rate": db_sum / db_weight if db_weight else None, + } + + by_strategy: dict[str, dict[str, Any]] = {} + for row in rows: + strategy_id = row["strategy_id"] + strategy_summary = by_strategy.setdefault( + strategy_id, + { + "strategy_id": strategy_id, + "domains": {}, + "task_weighted_total": {}, + }, + ) + strategy_summary["domains"].setdefault(row["domain"], []).append(row) + + for strategy_summary in by_strategy.values(): + all_rows = [] + for domain, domain_rows in list(strategy_summary["domains"].items()): + strategy_summary["domains"][domain] = weighted(domain_rows) + all_rows.extend(domain_rows) + strategy_summary["task_weighted_total"] = weighted(all_rows) + + return { + "schema_version": "openviking.tau2.scoreboard.v0", + "strategies": by_strategy, + } + + +def _execute_cells(plan: dict[str, Any], repo: Path, out: Path) -> list[dict[str, Any]]: + policy_report = plan.get("simulator_policy") or {} + if not policy_report.get("supported", False): + raise RuntimeError( + "configured user simulator policy is not supported by this TAU-2 checkout: " + f"{policy_report}" + ) + rows = [] + for cell in plan["cells"]: + if not cell.get("executable"): + raise RuntimeError( + f"cell is not executable yet: {cell['run_label']} " + f"(strategy_id={cell['strategy_id']}, adapter_status={cell.get('adapter_status')})" + ) + print(f"[tau2] running {cell['run_label']}") + completed = subprocess.run( + cell["command"], + cwd=repo, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + row = { + "run_label": cell["run_label"], + "domain": cell["domain"], + "strategy_id": cell["strategy_id"], + "returncode": completed.returncode, + "stdout_tail": completed.stdout[-4000:], + "stderr_tail": completed.stderr[-4000:], + } + row["artifacts"] = _cell_artifacts(cell, repo, out) + row["metrics"] = _cell_metrics(cell, row["artifacts"]) + rows.append(row) + write_json(out / "cell_results" / f"{cell['run_label']}.json", row) + if completed.returncode != 0: + raise RuntimeError(f"cell failed: {cell['run_label']} returncode={completed.returncode}") + return rows + + +def _preflight(config: dict[str, Any], out: Path, *, strict: bool) -> int: + errors: list[str] = [] + llm_env = normalize_litellm_env() + tau2_info = tau2_context(config) + policy_report = simulator_policy_report(config) + if strict and not tau2_info["tau2_repo_exists"]: + errors.append(f"missing TAU-2 repo: {tau2_info['tau2_repo']}") + if strict and not tau2_info["tau2_cli_resolved"]: + errors.append(f"missing TAU-2 CLI: {tau2_info['tau2_cli']}") + if strict and not llm_env["has_api_key"]: + errors.append("missing LLM API key: set OPENAI_API_KEY or ARK_API_KEY") + if strict and not llm_env["has_base_url"]: + errors.append("missing OpenAI-compatible base URL: set OPENAI_API_BASE, OPENAI_BASE_URL, or ARK_BASE_URL") + if strict and not policy_report["supported"]: + errors.append( + "configured confirmation-aware user simulator policy requires a TAU-2 " + f"checkout 
with the prompt fix: {policy_report['prompt_files']}" + ) + split_rows = [] + for domain in domains(config): + path = split_file(config, domain) + exists = path.is_file() + split_rows.append({"domain": domain, "path": str(path), "exists": exists}) + if strict and not exists: + errors.append(f"missing split file for {domain}: {path}") + + import_rows = [] + for module in ("openviking", "openviking_cli", "tau2"): + ok = importlib.util.find_spec(module) is not None + import_rows.append({"module": module, "ok": ok}) + if strict and not ok: + errors.append(f"missing Python module: {module}") + + report = { + "status": "failed" if errors else "ok", + "strict": strict, + "tau2": tau2_info, + "llm_env": llm_env, + "simulator_policy": policy_report, + "domains": domains(config), + "strategies": strategy_ids(config), + "imports": import_rows, + "split_files": split_rows, + "errors": errors, + } + write_json(out / "preflight.json", report) + if errors: + for error in errors: + print(f"[preflight][ERROR] {error}", file=sys.stderr) + return 1 + print(f"[preflight][OK] wrote {out / 'preflight.json'}") + return 0 + + +def main() -> int: + parser = argparse.ArgumentParser(description="Plan or run TAU-2 benchmark cells.") + parser.add_argument("--config", type=Path, default=Path(__file__).parents[1] / "config" / "baseline.yaml") + parser.add_argument("--run-id", default=run_id()) + parser.add_argument("--domain", action="append", help="Run only this configured domain; may be repeated.") + parser.add_argument("--repeat-count", type=int, help="Override benchmark.repeat_count for smoke runs.") + parser.add_argument("--strategy-id", action="append", help="Run only this strategy id; may be repeated.") + parser.add_argument("--task-id", action="append", help="Run only this TAU-2 task id; may be repeated.") + parser.add_argument("--num-tasks", type=int, help="Run the first N tasks from the selected split.") + parser.add_argument("--train-num-tasks", type=int, help="Train OpenViking memory on the first N train tasks.") + parser.add_argument( + "--preflight", + action="store_true", + help="Write a lightweight environment/config preflight report.", + ) + parser.add_argument( + "--strict-preflight", + action="store_true", + help="Fail if optional runtime imports or split files are missing.", + ) + parser.add_argument("--plan-only", action="store_true", help="Only write run_plan.json.") + parser.add_argument("--execute", action="store_true", help="Execute planned cells.") + args = parser.parse_args() + normalize_litellm_env() + + if args.plan_only and args.execute: + raise SystemExit("--plan-only and --execute are mutually exclusive") + + config = load_config(args.config) + out = output_dir(config, args.run_id) + out.mkdir(parents=True, exist_ok=True) + if args.preflight or args.strict_preflight: + preflight_status = _preflight(config, out, strict=args.strict_preflight) + if args.strict_preflight and preflight_status != 0: + return preflight_status + + plan = _build_plan( + config, + args.run_id, + selected_domains=set(args.domain) if args.domain else None, + selected_strategy_ids=set(args.strategy_id) if args.strategy_id else None, + task_ids=args.task_id, + num_tasks=args.num_tasks, + train_num_tasks=args.train_num_tasks, + repeat_count_override=args.repeat_count, + ) + write_json(out / "run_plan.json", plan) + write_json(out / "resolved_config.json", config) + print(f"[tau2] wrote {out / 'run_plan.json'}") + + if args.execute: + try: + rows = _execute_cells(plan, tau2_repo(config), out) + plan["status"] = 
"succeeded" + plan["executed_cell_count"] = len(rows) + write_json(out / "run_plan.json", plan) + write_json(out / "scoreboard.json", _summarize(rows)) + except Exception as exc: + plan["status"] = "failed" + plan["error"] = str(exc) + write_json(out / "run_plan.json", plan) + print(f"[tau2][ERROR] {exc}", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/tau2/scripts/run_memory_v2_eval.py b/benchmark/tau2/scripts/run_memory_v2_eval.py new file mode 100644 index 000000000..de5ef5441 --- /dev/null +++ b/benchmark/tau2/scripts/run_memory_v2_eval.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import shutil +import sys +import time +from pathlib import Path +from typing import Any + +from tau2_common import normalize_litellm_env + + +AGENT_NAME = "openviking_memory_agent" +REPO_ROOT = Path(__file__).resolve().parents[3] +WRITE_TOOL_PREFIXES = ( + "toggle_", + "enable_", + "disable_", + "set_", + "reset_", + "update_", + "modify_", + "cancel_", + "book_", + "exchange_", + "return_", + "grant_", + "reboot_", +) + + +def _json(text: str) -> dict[str, Any]: + return json.loads(text) if text else {} + + +def _write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n") + + +def _add_tau2_to_path(tau2_repo: Path) -> None: + src = tau2_repo / "src" + sys.path.insert(0, str(REPO_ROOT)) + sys.path.insert(0, str(src if src.is_dir() else tau2_repo)) + + +def _save_to_arg(path: Path) -> str: + # Some TAU-2 versions append ".json"; newer versions treat save_to as a + # run directory and write results.json under it. 
+ return str(path.with_suffix("") if path.suffix == ".json" else path) + + +def _compat_results_path(path: Path) -> Path: + run_dir = path.with_suffix("") if path.suffix == ".json" else path + return run_dir / "results.json" + + +def _reward(sim: dict[str, Any]) -> float: + info = sim.get("reward_info") or {} + value = info.get("reward", sim.get("reward", 0.0)) + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + +def _db_match(sim: dict[str, Any]) -> bool | None: + info = sim.get("reward_info") or {} + db = info.get("db_check") or {} + if isinstance(db, dict): + if "score" in db: + return bool(db["score"]) + if "db_match" in db: + return bool(db["db_match"]) + return sim.get("db_match") + + +def _metrics(results_path: Path) -> dict[str, Any]: + data = json.loads(results_path.read_text()) + sims = data.get("simulations") or [] + rewards = [_reward(sim) for sim in sims] + db_values = [_db_match(sim) for sim in sims] + db_known = [value for value in db_values if value is not None] + return { + "simulation_count": len(sims), + "avg_reward": sum(rewards) / len(rewards) if rewards else 0.0, + "db_match_rate": (sum(1 for value in db_known if value) / len(db_known)) if db_known else None, + } + + +def _tool_call_name(tool_call: Any) -> str: + if isinstance(tool_call, dict): + return str(tool_call.get("name") or tool_call.get("function", {}).get("name") or "") + return str(getattr(tool_call, "name", "") or "") + + +def _tool_call_arguments(tool_call: Any) -> Any: + if isinstance(tool_call, dict): + return tool_call.get("arguments") or tool_call.get("function", {}).get("arguments") or {} + return getattr(tool_call, "arguments", {}) or {} + + +def _is_write_tool_call(tool_call: Any) -> bool: + name = _tool_call_name(tool_call) + return bool(name) and name.startswith(WRITE_TOOL_PREFIXES) + + +def _tool_call_query(tool_calls: list[Any], state_messages: list[Any]) -> str: + rendered = [] + for call in tool_calls: + rendered.append( + f"{_tool_call_name(call) or 'unknown_tool'}(" + f"{json.dumps(_tool_call_arguments(call), ensure_ascii=False, sort_keys=True)}" + ")" + ) + recent_user = [ + str(getattr(message, "content", "") or "") + for message in state_messages[-8:] + if str(getattr(message, "role", "")) == "user" and str(getattr(message, "content", "") or "").strip() + ] + recent_observations = [ + str(getattr(message, "content", "") or "")[:600] + for message in state_messages[-12:] + if str(getattr(message, "role", "")) == "tool" and str(getattr(message, "content", "") or "").strip() + ] + parts = [ + "Before executing write-like tool call(s): " + "; ".join(rendered), + "Recent user context: " + " | ".join(recent_user[-3:]), + ] + if recent_observations: + parts.append("Recent tool observations: " + " | ".join(recent_observations[-4:])) + return "\n".join(parts) + + +def _message_text(message: dict[str, Any]) -> tuple[str, str]: + role = str(message.get("role") or "assistant") + if role == "user": + return "user", str(message.get("content") or "") + if role == "tool": + return "assistant", f"Tool result: {message.get('content') or ''}" + calls = message.get("tool_calls") or [] + if calls: + rendered = [] + for call in calls: + name = call.get("name") or call.get("function", {}).get("name") or "unknown_tool" + arguments = call.get("arguments") or call.get("function", {}).get("arguments") or {} + rendered.append(f"{name}({json.dumps(arguments, ensure_ascii=False, sort_keys=True)})") + return "assistant", "Assistant tool call: " + "; ".join(rendered) + return "assistant", 
str(message.get("content") or "") + + +def _run_tau2( + *, + tau2_repo: Path, + domain: str, + split: str, + task_ids: list[str] | None, + num_tasks: int | None, + trials: int, + max_steps: int, + max_concurrency: int, + agent: str, + user: str, + agent_llm: str, + user_llm: str, + agent_llm_args: dict[str, Any], + user_llm_args: dict[str, Any], + seed: int, + save_to: Path, +): + _add_tau2_to_path(tau2_repo) + from tau2.data_model.simulation import RunConfig, TextRunConfig + from tau2.run import run_domain + + compat_results = _compat_results_path(save_to) + if save_to.exists(): + save_to.unlink() + if compat_results.parent.is_dir(): + shutil.rmtree(compat_results.parent) + config_cls = TextRunConfig if getattr(RunConfig, "__origin__", None) is not None else RunConfig + result = run_domain( + config_cls( + domain=domain, + task_split_name=split, + task_ids=task_ids, + num_tasks=num_tasks, + agent=agent, + llm_agent=agent_llm, + llm_args_agent=agent_llm_args, + user=user, + llm_user=user_llm, + llm_args_user=user_llm_args, + num_trials=trials, + max_steps=max_steps, + save_to=_save_to_arg(save_to), + max_concurrency=max_concurrency, + seed=seed, + log_level="INFO", + ) + ) + if not save_to.exists() and compat_results.exists(): + shutil.copyfile(compat_results, save_to) + return result + + +def _client(args: argparse.Namespace): + import openviking as ov + + client = ov.SyncHTTPClient( + url=args.openviking_url, + api_key="", + user=args.openviking_user, + agent_id=args.openviking_agent_id, + account=args.openviking_account, + timeout=args.openviking_timeout, + extra_headers={}, + ) + client.initialize() + return client + + +def _wait_task(client: Any, task_id: str | None, timeout: int) -> dict[str, Any]: + if not task_id: + return {"status": "no_task"} + deadline = time.time() + timeout + last = None + while time.time() < deadline: + last = client.get_task(task_id) + status = (last or {}).get("status") + if status == "completed": + return last or {"status": status} + if status in {"failed", "cancelled"}: + raise RuntimeError(f"OpenViking task {task_id} {status}: {last}") + time.sleep(2) + raise TimeoutError(f"OpenViking task {task_id} did not finish within {timeout}s: {last}") + + +def _read_memory_text(client: Any, match: Any) -> tuple[str, str | None]: + try: + return client.read(getattr(match, "uri", "")), None + except Exception as exc: + fallback = getattr(match, "abstract", "") or getattr(match, "overview", "") or "" + return fallback, f"{type(exc).__name__}: {exc}" + + +def _probe_corpus(args: argparse.Namespace, client: Any) -> dict[str, Any]: + result = client.search( + query=f"{args.domain} customer service order reservation booking cancellation exchange return update", + target_uri=args.search_uri, + limit=args.retrieval_top_k, + ) + memories = list(getattr(result, "memories", []) or []) + reads = [] + for match in memories[: args.retrieval_top_k]: + uri = getattr(match, "uri", "") + text, read_error = _read_memory_text(client, match) + row = { + "uri": uri, + "score": getattr(match, "score", None), + "text_chars": len(text), + "non_empty": bool(str(text).strip()), + } + if read_error: + row["read_error"] = read_error + reads.append(row) + return { + "query": f"{args.domain} customer service order reservation booking cancellation exchange return update", + "match_count": len(memories), + "read_non_empty_count": sum(1 for row in reads if row["non_empty"]), + "matches": reads, + } + + +def _train(args: argparse.Namespace, train_results: Path, corpus_manifest: Path) -> dict[str, 
Any]: + if corpus_manifest.is_file() and not args.force_train: + return json.loads(corpus_manifest.read_text()) + + _run_tau2( + tau2_repo=args.tau2_repo, + domain=args.domain, + split=args.train_split_name, + task_ids=args.train_task_ids, + num_tasks=args.train_num_tasks, + trials=1, + max_steps=args.max_steps, + max_concurrency=args.max_concurrency, + agent=args.base_agent, + user=args.user, + agent_llm=args.agent_llm, + user_llm=args.user_llm, + agent_llm_args=args.agent_llm_args, + user_llm_args=args.user_llm_args, + seed=args.seed, + save_to=train_results, + ) + + data = json.loads(train_results.read_text()) + client = _client(args) + committed = [] + try: + for sim in data.get("simulations") or []: + session_id = f"tau2-{args.domain}-train-{sim.get('task_id')}-trial-{sim.get('trial', 0)}" + created = client.create_session(session_id=session_id) + sid = created.get("session_id", session_id) + for msg in sim.get("messages") or []: + role, text = _message_text(msg) + if not text.strip(): + continue + client.add_message( + sid, + role=role, + parts=[{"type": "text", "text": text}], + created_at=msg.get("timestamp"), + ) + result = client.commit_session(sid, telemetry=True) + task = _wait_task(client, result.get("task_id"), args.openviking_wait_timeout) + committed.append( + { + "session_id": sid, + "task_id": sim.get("task_id"), + "commit_status": result.get("status"), + "openviking_task_id": result.get("task_id"), + "openviking_task_status": task.get("status"), + } + ) + finally: + client.close() + + client = _client(args) + try: + corpus_probe = _probe_corpus(args, client) + finally: + client.close() + + manifest = { + "domain": args.domain, + "train_results": str(train_results), + "openviking": { + "url": args.openviking_url, + "account": args.openviking_account, + "user": args.openviking_user, + "agent_id": args.openviking_agent_id, + "search_uri": args.search_uri, + }, + "committed_sessions": committed, + "committed_session_count": len(committed), + "corpus_probe": corpus_probe, + } + _write_json(corpus_manifest, manifest) + return manifest + + +def _register_memory_agent(args: argparse.Namespace, trace_path: Path) -> None: + _add_tau2_to_path(args.tau2_repo) + + from tau2.agent.llm_agent import LLMAgent, LLMAgentState + from tau2.data_model.message import AssistantMessage, MultiToolMessage, SystemMessage + from tau2.registry import registry + from tau2.utils.llm_utils import generate + + class OpenVikingMemoryAgent(LLMAgent): + def get_init_state(self, message_history=None): + state = super().get_init_state(message_history) + if args.retrieval_mode in {"first_user", "first_user_prewrite"}: + state.system_messages.append( + SystemMessage(role="system", content="") + ) + return state + + def _retrieve(self, query: str) -> tuple[str, list[dict[str, Any]]]: + client = _client(args) + rows: list[dict[str, Any]] = [] + try: + result = client.search(query=query, target_uri=args.search_uri, limit=args.retrieval_top_k) + memories = list(getattr(result, "memories", []) or []) + blocks = [] + for index, match in enumerate(memories[: args.retrieval_top_k], 1): + uri = getattr(match, "uri", "") + text, read_error = _read_memory_text(client, match) + row = { + "uri": uri, + "score": getattr(match, "score", None), + "level": getattr(match, "level", None), + "text_chars": len(text), + } + if read_error: + row["read_error"] = read_error + rows.append(row) + if text.strip(): + blocks.append(f"Memory {index} ({uri}):\n{text.strip()}") + return "\n\n".join(blocks), rows + finally: + client.close() + + 
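        # Each retrieval decision is appended to the trace file as one JSON object per
        # line (JSONL); illustrative first-user event:
        #   {"decision_node": "first_user", "query": "...", "match_count": 2,
        #    "matches": [...], "injected": true, "injected_count": 2,
        #    "retrieval_action_taken": "retrieve_and_inject"}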
def _trace(self, event: dict[str, Any]) -> None: + with trace_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(event, ensure_ascii=False, sort_keys=True) + "\n") + + @staticmethod + def _trace_injection_fields(block: str, matches: list[dict[str, Any]]) -> dict[str, Any]: + injected_count = sum(1 for row in matches if int(row.get("text_chars") or 0) > 0) + return { + "injected": bool(block.strip()), + "injected_count": injected_count if block.strip() else 0, + "retrieval_action_taken": "retrieve_and_inject" if block.strip() else "retrieve_no_injection", + } + + def _generate(self, messages): + def _is_empty_assistant(response) -> bool: + content = str(getattr(response, "content", "") or "") + tool_calls = getattr(response, "tool_calls", None) or [] + return not content.strip() and not tool_calls + + try: + response = generate( + model=self.llm, + tools=self.tools, + messages=messages, + **self.llm_args, + ) + if not _is_empty_assistant(response): + return response + except json.JSONDecodeError: + retry_messages = messages + [ + SystemMessage( + role="system", + content=( + "Retry the last assistant step once. If you call a tool, " + "the tool arguments must be syntactically valid JSON." + ), + ) + ] + else: + retry_messages = messages + [ + SystemMessage( + role="system", + content=( + "Retry the last assistant step once. Return either a useful " + "natural language response or a valid tool call; do not return " + "an empty assistant message." + ), + ) + ] + try: + response = generate( + model=self.llm, + tools=self.tools, + messages=retry_messages, + **self.llm_args, + ) + if not _is_empty_assistant(response): + return response + return AssistantMessage( + role="assistant", + content="I need to continue with the available task information.", + raw_data={"openviking_memory_agent_error": "empty_assistant_message"}, + ) + except json.JSONDecodeError as exc: + return AssistantMessage( + role="assistant", + content="I need to continue with the available task information.", + raw_data={ + "openviking_memory_agent_error": "invalid_tool_call_json", + "error": str(exc), + }, + ) + + def generate_next_message(self, message, state: LLMAgentState): + if isinstance(message, MultiToolMessage): + state.messages.extend(message.tool_messages) + else: + state.messages.append(message) + marker_index = next( + ( + i + for i, item in enumerate(state.system_messages) + if isinstance(item, SystemMessage) and item.content == "" + ), + None, + ) + role = getattr(message, "role", "") + role_value = getattr(role, "value", role) + if marker_index is not None and str(role_value) == "user": + query = str(getattr(message, "content", "") or "") + block, matches = self._retrieve(query) + prompt = ( + "No OpenViking memory matched this user request." 
+ if not block + else "Use these OpenViking experience memories only when they match the current task:\n\n" + + block + ) + state.system_messages[marker_index] = SystemMessage(role="system", content=prompt) + self._trace( + { + "decision_node": "first_user", + "query": query, + "match_count": len(matches), + "matches": matches, + **self._trace_injection_fields(block, matches), + } + ) + + assistant_message = self._generate(state.system_messages + state.messages) + if args.retrieval_mode in {"prewrite", "first_user_prewrite"}: + tool_calls = list(getattr(assistant_message, "tool_calls", None) or []) + write_calls = [call for call in tool_calls if _is_write_tool_call(call)] + if write_calls: + query = _tool_call_query(write_calls, state.messages) + block, matches = self._retrieve(query) + self._trace( + { + "decision_node": "before_write_tool_call", + "query": query, + "match_count": len(matches), + "matches": matches, + **self._trace_injection_fields(block, matches), + "tool_calls": [ + { + "name": _tool_call_name(call), + "arguments": _tool_call_arguments(call), + } + for call in write_calls + ], + } + ) + if block: + prompt = ( + "Before executing the pending write-like tool call, use these " + "OpenViking experience memories only when they match the current task:\n\n" + + block + ) + assistant_message = self._generate( + state.system_messages + + state.messages + + [SystemMessage(role="system", content=prompt)] + ) + state.messages.append(assistant_message) + return assistant_message, state + + if AGENT_NAME not in registry.get_agents(): + def create_openviking_memory_agent(tools, domain_policy, **kwargs): + return OpenVikingMemoryAgent( + tools=tools, + domain_policy=domain_policy, + llm=kwargs.get("llm"), + llm_args=kwargs.get("llm_args"), + ) + + if hasattr(registry, "register_agent"): + registry.register_agent(OpenVikingMemoryAgent, AGENT_NAME) + else: + registry.register_agent_factory(create_openviking_memory_agent, AGENT_NAME) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Run TAU-2 with OpenViking Memory V2.") + parser.add_argument("--tau2-repo", type=Path, required=True) + parser.add_argument("--run-dir", type=Path, required=True) + parser.add_argument("--corpus-dir", type=Path) + parser.add_argument("--run-label", required=True) + parser.add_argument("--strategy-id", default="memory_v2_experience_only") + parser.add_argument("--domain", required=True) + parser.add_argument("--train-split-name", default="train") + parser.add_argument("--eval-split-name", default="test") + parser.add_argument("--task-id", dest="task_ids", action="append") + parser.add_argument("--num-tasks", type=int) + parser.add_argument("--train-task-id", dest="train_task_ids", action="append") + parser.add_argument("--train-num-tasks", type=int) + parser.add_argument("--max-steps", type=int, default=200) + parser.add_argument("--max-concurrency", type=int, default=10) + parser.add_argument("--seed", type=int, default=300) + parser.add_argument("--base-agent", default="llm_agent") + parser.add_argument("--user", default="user_simulator") + parser.add_argument("--agent-llm", required=True) + parser.add_argument("--user-llm", required=True) + parser.add_argument("--agent-llm-args", type=_json, default={}) + parser.add_argument("--user-llm-args", type=_json, default={}) + parser.add_argument("--openviking-url", required=True) + parser.add_argument("--openviking-account", required=True) + parser.add_argument("--openviking-user", required=True) + parser.add_argument("--openviking-agent-id", 
required=True) + parser.add_argument("--openviking-timeout", type=float, default=600.0) + parser.add_argument("--openviking-wait-timeout", type=int, default=600) + parser.add_argument("--search-uri", required=True) + parser.add_argument("--retrieval-top-k", type=int, default=4) + parser.add_argument( + "--retrieval-mode", + choices=["first_user", "prewrite", "first_user_prewrite"], + default="first_user", + ) + parser.add_argument("--force-train", action="store_true") + args = parser.parse_args() + normalize_litellm_env() + + args.tau2_repo = args.tau2_repo.resolve() + args.run_dir.mkdir(parents=True, exist_ok=True) + corpus_dir = args.corpus_dir or args.run_dir + corpus_dir.mkdir(parents=True, exist_ok=True) + train_results = corpus_dir / "train_results.json" + corpus_manifest = corpus_dir / "corpus_manifest.json" + eval_results = args.run_dir / f"{args.run_label}.json" + trace_path = args.run_dir / f"{args.run_label}.retrieval_trace.jsonl" + summary_path = args.run_dir / f"{args.run_label}.summary.json" + + corpus = _train(args, train_results, corpus_manifest) + trace_path.touch() + _register_memory_agent(args, trace_path) + _run_tau2( + tau2_repo=args.tau2_repo, + domain=args.domain, + split=args.eval_split_name, + task_ids=args.task_ids, + num_tasks=args.num_tasks, + trials=1, + max_steps=args.max_steps, + max_concurrency=args.max_concurrency, + agent=AGENT_NAME, + user=args.user, + agent_llm=args.agent_llm, + user_llm=args.user_llm, + agent_llm_args=args.agent_llm_args, + user_llm_args=args.user_llm_args, + seed=args.seed, + save_to=eval_results, + ) + summary = { + "run_label": args.run_label, + "domain": args.domain, + "strategy_id": args.strategy_id, + "retrieval_mode": args.retrieval_mode, + "seed": args.seed, + "corpus": corpus, + "eval_results": str(eval_results), + "retrieval_trace": str(trace_path), + "metrics": _metrics(eval_results), + } + _write_json(summary_path, summary) + print(json.dumps(summary, ensure_ascii=False, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/tau2/scripts/setup_tau2_repo.sh b/benchmark/tau2/scripts/setup_tau2_repo.sh new file mode 100755 index 000000000..3cee2655a --- /dev/null +++ b/benchmark/tau2/scripts/setup_tau2_repo.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +TAU2_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +DEFAULT_REPO_DIR="$TAU2_DIR/.external/tau2-bench" +DEFAULT_VENV_DIR="$TAU2_DIR/.venv-tau2" + +REPO_URL="${TAU2_REPO_URL:-https://github.com/sierra-research/tau2-bench.git}" +REPO_DIR="${TAU2_REPO:-$DEFAULT_REPO_DIR}" +VENV_DIR="${TAU2_VENV:-$DEFAULT_VENV_DIR}" +REF="${TAU2_REF:-}" +INSTALL=true + +while [[ $# -gt 0 ]]; do + case "$1" in + --repo-url) + REPO_URL="$2" + shift 2 + ;; + --repo-dir) + REPO_DIR="$2" + shift 2 + ;; + --venv) + VENV_DIR="$2" + shift 2 + ;; + --ref) + REF="$2" + shift 2 + ;; + --no-install) + INSTALL=false + shift + ;; + --help|-h) + cat <<'EOF' +Usage: + benchmark/tau2/scripts/setup_tau2_repo.sh [--repo-url URL] [--repo-dir DIR] [--venv DIR] [--ref REF] [--no-install] + +Clones TAU-2 into a local ignored directory and optionally installs it into a +local virtualenv. The script writes benchmark/tau2/.env.tau2 with TAU2_REPO and +TAU2_CLI for the benchmark runner. +EOF + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + exit 1 + ;; + esac +done + +mkdir -p "$(dirname "$REPO_DIR")" +if [[ ! 
-d "$REPO_DIR/.git" ]]; then + git clone "$REPO_URL" "$REPO_DIR" +else + git -C "$REPO_DIR" fetch --all --prune +fi + +if [[ -n "$REF" ]]; then + git -C "$REPO_DIR" checkout "$REF" +fi + +TAU2_CLI="tau2" +if [[ "$INSTALL" == true ]]; then + python3 -m venv "$VENV_DIR" + "$VENV_DIR/bin/python" -m pip install --upgrade pip + "$VENV_DIR/bin/python" -m pip install -e "$REPO_DIR" + TAU2_CLI="$VENV_DIR/bin/tau2" +fi + +cat > "$TAU2_DIR/.env.tau2" < str: + return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +def normalize_litellm_env() -> dict[str, Any]: + aliases = [] + if not os.environ.get("OPENAI_API_KEY") and os.environ.get("ARK_API_KEY"): + os.environ["OPENAI_API_KEY"] = os.environ["ARK_API_KEY"] + aliases.append("OPENAI_API_KEY<-ARK_API_KEY") + ark_base = os.environ.get("ARK_BASE_URL") + openai_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL") + if not openai_base and ark_base: + os.environ["OPENAI_API_BASE"] = ark_base + os.environ["OPENAI_BASE_URL"] = ark_base + aliases.append("OPENAI_API_BASE<-ARK_BASE_URL") + elif os.environ.get("OPENAI_API_BASE") and not os.environ.get("OPENAI_BASE_URL"): + os.environ["OPENAI_BASE_URL"] = os.environ["OPENAI_API_BASE"] + aliases.append("OPENAI_BASE_URL<-OPENAI_API_BASE") + elif os.environ.get("OPENAI_BASE_URL") and not os.environ.get("OPENAI_API_BASE"): + os.environ["OPENAI_API_BASE"] = os.environ["OPENAI_BASE_URL"] + aliases.append("OPENAI_API_BASE<-OPENAI_BASE_URL") + return { + "aliases": aliases, + "has_api_key": bool(os.environ.get("OPENAI_API_KEY") or os.environ.get("ARK_API_KEY")), + "has_base_url": bool( + os.environ.get("OPENAI_API_BASE") + or os.environ.get("OPENAI_BASE_URL") + or os.environ.get("ARK_BASE_URL") + ), + } + + +def render_env(value: Any) -> Any: + if isinstance(value, str): + def replace(match: re.Match[str]) -> str: + name = match.group(1) + default = match.group(2) or "" + return os.environ.get(name, default) + + return _ENV_PATTERN.sub(replace, value) + if isinstance(value, list): + return [render_env(item) for item in value] + if isinstance(value, dict): + return {key: render_env(item) for key, item in value.items()} + return value + + +def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + merged = dict(base) + for key, value in override.items(): + if ( + key in merged + and isinstance(merged[key], dict) + and isinstance(value, dict) + ): + merged[key] = deep_merge(merged[key], value) + else: + merged[key] = value + return merged + + +def load_config(path: Path) -> dict[str, Any]: + path = path.expanduser().resolve() + with path.open("r", encoding="utf-8") as handle: + raw = yaml.safe_load(handle) or {} + if not isinstance(raw, dict): + raise ValueError(f"Config must be a mapping: {path}") + + parent_name = raw.pop("extends", None) + if parent_name: + parent_path = (path.parent / str(parent_name)).resolve() + parent = load_config(parent_path) + raw = deep_merge(parent, raw) + return render_env(raw) + + +def resolve_path(path_value: str | Path, *, base: Path | None = None) -> Path: + path = Path(path_value).expanduser() + if path.is_absolute(): + return path + return ((base or REPO_ROOT) / path).resolve() + + +def output_dir(config: dict[str, Any], configured_run_id: str) -> Path: + raw = config.get("paths", {}).get("output_dir", TAU2_DIR / "result") + return resolve_path(raw) / configured_run_id + + +def write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, 
ensure_ascii=False, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + + +def strategy_ids(config: dict[str, Any]) -> list[str]: + strategies = config.get("strategies") or [] + if not isinstance(strategies, list): + raise ValueError("strategies must be a list") + ids = [] + for item in strategies: + if not isinstance(item, dict) or not item.get("id"): + raise ValueError("each strategy must be a mapping with id") + ids.append(str(item["id"])) + if len(ids) != len(set(ids)): + raise ValueError(f"duplicate strategy ids: {ids}") + return ids + + +def domains(config: dict[str, Any]) -> list[str]: + values = config.get("benchmark", {}).get("domains") or [] + if not isinstance(values, list) or not values: + raise ValueError("benchmark.domains must be a non-empty list") + return [str(item) for item in values] + + +def tau2_repo(config: dict[str, Any]) -> Path: + raw = config.get("paths", {}).get("tau2_repo") + if not raw: + raise ValueError("paths.tau2_repo is required") + return resolve_path(raw) + + +def tau2_cli(config: dict[str, Any]) -> str: + return str(config.get("paths", {}).get("tau2_cli") or "tau2") + + +def _git_commit(path: Path) -> str | None: + if not path.exists(): + return None + completed = subprocess.run( + ["git", "-C", str(path), "rev-parse", "HEAD"], + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + check=False, + ) + if completed.returncode != 0: + return None + return completed.stdout.strip() or None + + +def tau2_context(config: dict[str, Any]) -> dict[str, Any]: + repo = tau2_repo(config) + cli = tau2_cli(config) + return { + "tau2_repo": str(repo), + "tau2_repo_exists": repo.exists(), + "tau2_commit": _git_commit(repo), + "tau2_cli": cli, + "tau2_cli_resolved": shutil.which(cli), + } + + +def _prompt_paths(repo: Path) -> list[Path]: + return [ + repo / "data" / "tau2" / "user_simulator" / "simulation_guidelines.md", + repo / "data" / "tau2" / "user_simulator" / "simulation_guidelines_tools.md", + ] + + +def _has_confirmation_aware_prompt(prompt_text: str) -> bool: + normalized = " ".join(prompt_text.split()) + return ( + "reply with the requested confirmation" in normalized + and "do not emit `###STOP###` in the same turn" in normalized + ) + + +def _ensure_confirmation_aware_prompt(repo: Path) -> bool: + patched = False + for path in _prompt_paths(repo): + if not path.is_file(): + continue + text = path.read_text(encoding="utf-8") + if _has_confirmation_aware_prompt(text): + continue + backup = path.with_suffix(path.suffix + ".openviking.bak") + if not backup.exists(): + backup.write_text(text, encoding="utf-8") + path.write_text(text.rstrip() + CONFIRMATION_AWARE_APPENDIX + "\n", encoding="utf-8") + patched = True + return patched + + +def user_simulator_policy(config: dict[str, Any]) -> str: + policy = config.get("eval", {}).get("user_simulator_policy", "official") + policy = str(policy) + if policy not in {"official", "confirmation_aware"}: + raise ValueError( + "eval.user_simulator_policy must be 'official' or 'confirmation_aware'" + ) + return policy + + +def simulator_policy_report(config: dict[str, Any]) -> dict[str, Any]: + policy = user_simulator_policy(config) + repo = tau2_repo(config) + patch_applied = policy == "confirmation_aware" and _ensure_confirmation_aware_prompt(repo) + patch_mode = "direct_prompt_append" if patch_applied else "none" + if policy == "confirmation_aware": + if not patch_applied: + patch_mode = "upstream_or_existing_prompt" + + prompt_paths = _prompt_paths(repo) + prompt_text = "\n".join( + 
path.read_text(encoding="utf-8") for path in prompt_paths if path.is_file() + ) + confirmation_aware_prompt = _has_confirmation_aware_prompt(prompt_text) + supported = policy == "official" or confirmation_aware_prompt + claim_boundary = "confirmation_aware_user_simulator_prompt" + if policy == "official": + claim_boundary = ( + "official_policy_with_confirmation_aware_checkout" + if confirmation_aware_prompt + else "official_tau2_user_simulator" + ) + return { + "user_simulator_policy": policy, + "supported": supported, + "confirmation_aware_prompt_detected": confirmation_aware_prompt, + "confirmation_aware_upstream_pr": CONFIRMATION_AWARE_UPSTREAM_PR, + "patch_applied": patch_applied, + "patch_mode": patch_mode, + "prompt_files": [str(path) for path in prompt_paths], + "backup_files": [ + str(path.with_suffix(path.suffix + ".openviking.bak")) + for path in prompt_paths + if path.with_suffix(path.suffix + ".openviking.bak").exists() + ], + "claim_boundary": claim_boundary, + } + + +def split_file(config: dict[str, Any], domain: str) -> Path: + return tau2_repo(config) / "data" / "tau2" / "domains" / domain / "split_tasks.json"