diff --git a/pyproject.toml b/pyproject.toml index 6f2ce7b..76e7593 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "pydantic>=2.0.0", "PyYAML>=6.0", "rich>=13.0.0", - "xpyd-sim>=0.4.0", + "starlette>=0.37.0", + "uvicorn>=0.29.0", ] [project.optional-dependencies] diff --git a/tests/test_api_key.py b/tests/test_api_key.py deleted file mode 100644 index 75326d5..0000000 --- a/tests/test_api_key.py +++ /dev/null @@ -1,352 +0,0 @@ -"""Tests for API key authentication (M11).""" - -from __future__ import annotations - -import asyncio -import os -from argparse import Namespace -from unittest.mock import patch - -import httpx -import pytest -import uvicorn - -from xpyd_bench.dummy.server import ServerConfig, create_app - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _find_free_port() -> int: - import socket - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -async def _wait_healthy(base: str, timeout: float = 5.0) -> None: - async with httpx.AsyncClient() as c: - deadline = asyncio.get_event_loop().time() + timeout - while asyncio.get_event_loop().time() < deadline: - try: - r = await c.get(f"{base}/health") - if r.status_code == 200: - return - except httpx.ConnectError: - pass - await asyncio.sleep(0.1) - raise TimeoutError("server did not become healthy") - - -# --------------------------------------------------------------------------- -# CLI argument tests -# --------------------------------------------------------------------------- - - -class TestCLIApiKeyArg: - """Test --api-key CLI argument parsing.""" - - def test_api_key_parsed(self) -> None: - - import argparse - - # Just verify the argument is accepted by the parser - parser = argparse.ArgumentParser() - from xpyd_bench.cli import _add_vllm_compat_args - - _add_vllm_compat_args(parser) - args = parser.parse_args(["--api-key", "sk-test123"]) - assert args.api_key == "sk-test123" - - def test_api_key_default_none(self) -> None: - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args([]) - assert args.api_key is None - - def test_api_key_env_fallback(self) -> None: - """OPENAI_API_KEY env var is used when --api-key not provided.""" - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args([]) - - # Simulate what bench_main does - with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-from-env"}): - if args.api_key is None: - args.api_key = os.environ.get("OPENAI_API_KEY") - assert args.api_key == "sk-from-env" - - def test_api_key_cli_overrides_env(self) -> None: - """--api-key takes precedence over env var.""" - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args(["--api-key", "sk-cli"]) - - with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-from-env"}): - if args.api_key is None: - args.api_key = os.environ.get("OPENAI_API_KEY") - assert args.api_key == "sk-cli" - - -# --------------------------------------------------------------------------- -# Dummy server auth tests -# --------------------------------------------------------------------------- - - -class TestDummyServerAuth: - """Test dummy server --require-api-key.""" - - @pytest.fixture() - def _server_with_auth(self): - """Start a dummy server that requires an API key.""" - port = _find_free_port() - config = ServerConfig( - prefill_ms=1.0, - decode_ms=1.0, - model_name="test-model", - require_api_key="sk-secret", - ) - app = create_app(config) - - server = uvicorn.Server( - uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - ) - - loop = asyncio.new_event_loop() - thread = None - - import threading - - def run(): - loop.run_until_complete(server.serve()) - - thread = threading.Thread(target=run, daemon=True) - thread.start() - - base = f"http://127.0.0.1:{port}" - asyncio.run(_wait_healthy(base)) - yield base - server.should_exit = True - thread.join(timeout=3) - - @pytest.fixture() - def server_url(self, _server_with_auth): - return _server_with_auth - - def test_request_without_key_returns_401(self, server_url: str) -> None: - import httpx as hx - - r = hx.post( - f"{server_url}/v1/completions", - json={"prompt": "hello", "max_tokens": 1, "model": "test-model"}, - ) - assert r.status_code == 401 - assert "auth_error" in r.json()["error"]["type"] - - def test_request_with_wrong_key_returns_401(self, server_url: str) -> None: - import httpx as hx - - r = hx.post( - f"{server_url}/v1/completions", - json={"prompt": "hello", "max_tokens": 1, "model": "test-model"}, - headers={"Authorization": "Bearer wrong-key"}, - ) - assert r.status_code == 401 - - def test_request_with_correct_key_succeeds(self, server_url: str) -> None: - import httpx as hx - - r = hx.post( - f"{server_url}/v1/completions", - json={"prompt": "hello", "max_tokens": 1, "model": "test-model"}, - headers={"Authorization": "Bearer sk-secret"}, - ) - assert r.status_code == 200 - - def test_health_endpoint_no_auth_required(self, server_url: str) -> None: - import httpx as hx - - r = hx.get(f"{server_url}/health") - assert r.status_code == 200 - - -# --------------------------------------------------------------------------- -# Runner auth header injection test -# --------------------------------------------------------------------------- - - -class TestRunnerAuthHeaders: - """Test that runner injects Authorization header.""" - - @pytest.fixture() - def _server_no_auth(self): - """Start a dummy server without auth requirement.""" - port = _find_free_port() - config = ServerConfig( - prefill_ms=1.0, - decode_ms=1.0, - model_name="test-model", - ) - app = create_app(config) - server = uvicorn.Server( - uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - ) - loop = asyncio.new_event_loop() - - import threading - - def run(): - loop.run_until_complete(server.serve()) - - thread = threading.Thread(target=run, daemon=True) - thread.start() - base = f"http://127.0.0.1:{port}" - asyncio.run(_wait_healthy(base)) - yield base - server.should_exit = True - thread.join(timeout=3) - - @pytest.fixture() - def _server_with_key(self): - """Start a dummy server requiring auth.""" - port = _find_free_port() - config = ServerConfig( - prefill_ms=1.0, - decode_ms=1.0, - model_name="test-model", - require_api_key="sk-runner-test", - ) - app = create_app(config) - server = uvicorn.Server( - uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - ) - loop = asyncio.new_event_loop() - - import threading - - def run(): - loop.run_until_complete(server.serve()) - - thread = threading.Thread(target=run, daemon=True) - thread.start() - base = f"http://127.0.0.1:{port}" - asyncio.run(_wait_healthy(base)) - yield base - server.should_exit = True - thread.join(timeout=3) - - async def test_runner_sends_auth_header(self, _server_with_key: str) -> None: - """Benchmark runner sends Authorization header when api_key is set.""" - from xpyd_bench.bench.runner import run_benchmark - - base_url = _server_with_key - args = Namespace( - backend="openai", - endpoint="/v1/completions", - model="test-model", - num_prompts=2, - request_rate=float("inf"), - burstiness=1.0, - max_concurrency=None, - input_len=8, - output_len=4, - seed=42, - dataset_name="random", - dataset_path=None, - disable_tqdm=True, - save_result=False, - rich_progress=False, - warmup=0, - timeout=10.0, - retries=0, - retry_delay=1.0, - api_key="sk-runner-test", - # Sampling params - temperature=None, - top_p=None, - top_k=None, - frequency_penalty=None, - presence_penalty=None, - best_of=None, - use_beam_search=False, - logprobs=None, - ignore_eos=False, - stop=None, - n=None, - api_seed=None, - echo=False, - suffix=None, - logit_bias=None, - user=None, - stream_options_include_usage=False, - ) - - result_dict, bench_result = await run_benchmark(args, base_url) - assert bench_result.completed == 2 - assert bench_result.failed == 0 - - async def test_runner_fails_without_key_when_required( - self, _server_with_key: str - ) -> None: - """Requests fail when server requires key but none provided.""" - from xpyd_bench.bench.runner import run_benchmark - - base_url = _server_with_key - args = Namespace( - backend="openai", - endpoint="/v1/completions", - model="test-model", - num_prompts=1, - request_rate=float("inf"), - burstiness=1.0, - max_concurrency=None, - input_len=8, - output_len=4, - seed=42, - dataset_name="random", - dataset_path=None, - disable_tqdm=True, - save_result=False, - rich_progress=False, - warmup=0, - timeout=10.0, - retries=0, - retry_delay=1.0, - api_key=None, - # Sampling params - temperature=None, - top_p=None, - top_k=None, - frequency_penalty=None, - presence_penalty=None, - best_of=None, - use_beam_search=False, - logprobs=None, - ignore_eos=False, - stop=None, - n=None, - api_seed=None, - echo=False, - suffix=None, - logit_bias=None, - user=None, - stream_options_include_usage=False, - ) - - result_dict, bench_result = await run_benchmark(args, base_url) - # All requests should fail with 401 - assert bench_result.failed == 1 - assert bench_result.completed == 0 diff --git a/tests/test_chat_params.py b/tests/test_chat_params.py deleted file mode 100644 index 97eef89..0000000 --- a/tests/test_chat_params.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Tests for issue #20: Chat-specific OpenAI API parameters.""" - -from __future__ import annotations - -import json -import tempfile -from argparse import Namespace -from pathlib import Path - -import pytest -from starlette.testclient import TestClient - -from xpyd_bench.bench.runner import _build_payload -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture -def client(): - config = ServerConfig(prefill_ms=0, decode_ms=0, model_name="test-model", eos_min_ratio=1.0) - app = create_app(config) - return TestClient(app) - - -# ── CLI argument tests ──────────────────────────────────────────────── - - -class TestChatCLIArgs: - """Verify chat-specific CLI arguments exist and parse correctly.""" - - def _get_parser(self): - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - return parser - - def test_chat_group_exists(self): - parser = self._get_parser() - group_titles = [g.title for g in parser._action_groups] - assert "chat-specific parameters" in group_titles - - def test_response_format_parses(self): - parser = self._get_parser() - args = parser.parse_args(["--response-format", '{"type": "json_object"}']) - assert args.response_format == '{"type": "json_object"}' - - def test_tools_parses(self): - parser = self._get_parser() - args = parser.parse_args(["--tools", "/path/to/tools.json"]) - assert args.tools == "/path/to/tools.json" - - def test_tool_choice_parses(self): - parser = self._get_parser() - args = parser.parse_args(["--tool-choice", "auto"]) - assert args.tool_choice == "auto" - - def test_top_logprobs_parses(self): - parser = self._get_parser() - args = parser.parse_args(["--top-logprobs", "5"]) - assert args.top_logprobs == 5 - - def test_max_completion_tokens_parses(self): - parser = self._get_parser() - args = parser.parse_args(["--max-completion-tokens", "1024"]) - assert args.max_completion_tokens == 1024 - - def test_service_tier_parses(self): - parser = self._get_parser() - args = parser.parse_args(["--service-tier", "auto"]) - assert args.service_tier == "auto" - - -# ── _build_payload tests ───────────────────────────────────────────── - - -class TestBuildPayloadChatParams: - """Verify _build_payload includes chat-specific params for chat endpoints.""" - - def _make_args(self, **kwargs): - defaults = { - "output_len": 128, - "model": "test", - "temperature": None, - "top_p": None, - "top_k": None, - "frequency_penalty": None, - "presence_penalty": None, - "best_of": None, - "use_beam_search": False, - "logprobs": None, - "ignore_eos": False, - "stop": None, - "n": None, - "api_seed": None, - "echo": False, - "suffix": None, - "logit_bias": None, - "user": None, - "stream_options_include_usage": False, - "response_format": None, - "tools": None, - "tool_choice": None, - "parallel_tool_calls": None, - "top_logprobs": None, - "max_completion_tokens": None, - "service_tier": None, - } - defaults.update(kwargs) - return Namespace(**defaults) - - def test_response_format_included(self): - args = self._make_args(response_format='{"type": "json_object"}') - payload = _build_payload(args, "hello", is_chat=True) - assert payload["response_format"] == {"type": "json_object"} - - def test_response_format_excluded_for_completions(self): - args = self._make_args(response_format='{"type": "json_object"}') - payload = _build_payload(args, "hello", is_chat=False) - assert "response_format" not in payload - - def test_tools_from_file(self): - tools_data = [{"type": "function", "function": {"name": "test_fn"}}] - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - json.dump(tools_data, f) - f.flush() - args = self._make_args(tools=f.name) - payload = _build_payload(args, "hello", is_chat=True) - assert payload["tools"] == tools_data - Path(f.name).unlink() - - def test_tool_choice_string(self): - args = self._make_args(tool_choice="auto") - payload = _build_payload(args, "hello", is_chat=True) - assert payload["tool_choice"] == "auto" - - def test_tool_choice_json(self): - args = self._make_args( - tool_choice='{"type": "function", "function": {"name": "test_fn"}}' - ) - payload = _build_payload(args, "hello", is_chat=True) - assert payload["tool_choice"] == { - "type": "function", - "function": {"name": "test_fn"}, - } - - def test_top_logprobs_included(self): - args = self._make_args(top_logprobs=5) - payload = _build_payload(args, "hello", is_chat=True) - assert payload["top_logprobs"] == 5 - - def test_max_completion_tokens_included(self): - args = self._make_args(max_completion_tokens=1024) - payload = _build_payload(args, "hello", is_chat=True) - assert payload["max_completion_tokens"] == 1024 - - def test_service_tier_included(self): - args = self._make_args(service_tier="auto") - payload = _build_payload(args, "hello", is_chat=True) - assert payload["service_tier"] == "auto" - - def test_none_params_excluded(self): - args = self._make_args() - payload = _build_payload(args, "hello", is_chat=True) - for key in ( - "response_format", - "tools", - "tool_choice", - "parallel_tool_calls", - "top_logprobs", - "max_completion_tokens", - "service_tier", - ): - assert key not in payload - - -# ── Dummy server tests ─────────────────────────────────────────────── - - -class TestDummyChatParams: - """Verify dummy server accepts chat-specific parameters.""" - - def test_response_format_accepted(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "response_format": {"type": "json_object"}, - }, - ) - assert resp.status_code == 200 - - def test_tools_accepted(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "tools": [ - { - "type": "function", - "function": {"name": "get_weather", "parameters": {}}, - } - ], - "tool_choice": "auto", - }, - ) - assert resp.status_code == 200 - - def test_max_completion_tokens_as_fallback(self, client): - """max_completion_tokens should be used when max_tokens is absent.""" - resp = client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "Hi"}], - "max_completion_tokens": 3, - }, - ) - assert resp.status_code == 200 - body = resp.json() - tokens = body["choices"][0]["message"]["content"].split() - assert len(tokens) == 3 - - def test_service_tier_accepted(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "service_tier": "auto", - }, - ) - assert resp.status_code == 200 - - def test_top_logprobs_accepted(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "logprobs": True, - "top_logprobs": 3, - }, - ) - assert resp.status_code == 200 - - def test_streaming_with_chat_params(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 3, - "stream": True, - "response_format": {"type": "json_object"}, - "service_tier": "auto", - }, - ) - assert resp.status_code == 200 - assert "data: " in resp.text diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 9303499..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,234 +0,0 @@ -"""Tests for the dummy server endpoints.""" - -from __future__ import annotations - -import json - -import pytest -from starlette.testclient import TestClient - -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture -def client(): - """Create a test client with fast config.""" - config = ServerConfig(prefill_ms=0, decode_ms=0, model_name="test-model", eos_min_ratio=1.0) - app = create_app(config) - return TestClient(app) - - -class TestCompletions: - """Tests for /v1/completions endpoint.""" - - def test_non_streaming(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "Hello world", "max_tokens": 5}, - ) - assert resp.status_code == 200 - body = resp.json() - assert body["object"] == "text_completion" - assert body["model"] == "test-model" - assert len(body["choices"]) == 1 - assert body["choices"][0]["finish_reason"] == "length" - assert body["usage"]["prompt_tokens"] > 0 - assert body["usage"]["completion_tokens"] == 5 - - def test_streaming(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "Hello world", "max_tokens": 3, "stream": True}, - ) - assert resp.status_code == 200 - - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: "): - data = line[len("data: "):] - if data.strip() == "[DONE]": - break - chunks.append(json.loads(data)) - - assert len(chunks) >= 3 - assert chunks[0]["choices"][0]["text"] is not None - assert chunks[-1]["choices"][0]["finish_reason"] == "length" - - -class TestChatCompletions: - """Tests for /v1/chat/completions endpoint.""" - - def test_non_streaming(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - body = resp.json() - assert body["object"] == "chat.completion" - assert body["model"] == "test-model" - assert len(body["choices"]) == 1 - assert "message" in body["choices"][0] - assert body["choices"][0]["message"]["role"] == "assistant" - assert body["usage"]["completion_tokens"] == 5 - - def test_streaming(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 3, - "stream": True, - }, - ) - assert resp.status_code == 200 - - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: "): - data = line[len("data: "):] - if data.strip() == "[DONE]": - break - chunks.append(json.loads(data)) - - assert len(chunks) >= 4 # 1 role chunk + 3 content chunks (+ possible finish chunk) - # First chunk has role: assistant - assert chunks[0]["object"] == "chat.completion.chunk" - assert chunks[0]["choices"][0]["delta"]["role"] == "assistant" - assert "delta" in chunks[1]["choices"][0] - - -class TestUsageStats: - """Tests for usage statistics accuracy.""" - - def test_completions_usage(self, client): - prompt = "a " * 100 # ~50 tokens - resp = client.post( - "/v1/completions", - json={"prompt": prompt, "max_tokens": 10}, - ) - body = resp.json() - usage = body["usage"] - assert usage["prompt_tokens"] > 0 - assert usage["completion_tokens"] == 10 - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - - def test_chat_usage(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "x " * 80}], - "max_tokens": 20, - }, - ) - body = resp.json() - usage = body["usage"] - assert usage["prompt_tokens"] > 0 - assert usage["completion_tokens"] == 20 - - -class TestModelsEndpoint: - def test_list_models(self, client): - resp = client.get("/v1/models") - assert resp.status_code == 200 - body = resp.json() - assert body["object"] == "list" - assert len(body["data"]) == 1 - assert body["data"][0]["id"] == "test-model" - - -class TestHealth: - def test_health(self, client): - resp = client.get("/health") - assert resp.status_code == 200 - assert resp.json()["status"] == "ok" - - -class TestDummyCLI: - """Tests for dummy CLI argument parsing.""" - - def test_default_args(self): - import argparse - - # Just test that the parser works — don't actually start server - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--prefill-ms", type=float, default=50.0) - parser.add_argument("--decode-ms", type=float, default=10.0) - parser.add_argument("--model-name", type=str, default="dummy-model") - - args = parser.parse_args([]) - assert args.host == "127.0.0.1" - assert args.port == 8000 - assert args.prefill_ms == 50.0 - assert args.decode_ms == 10.0 - assert args.model_name == "dummy-model" - - def test_custom_args(self): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--prefill-ms", type=float, default=50.0) - parser.add_argument("--decode-ms", type=float, default=10.0) - parser.add_argument("--model-name", type=str, default="dummy-model") - - args = parser.parse_args([ - "--host", "0.0.0.0", - "--port", "9000", - "--prefill-ms", "100", - "--decode-ms", "20", - "--model-name", "my-model", - ]) - assert args.host == "0.0.0.0" - assert args.port == 9000 - assert args.prefill_ms == 100.0 - assert args.decode_ms == 20.0 - assert args.model_name == "my-model" - - -class TestInvalidJsonBody: - """Tests for invalid JSON body returning 400.""" - - def test_completions_invalid_json(self, client): - resp = client.post( - "/v1/completions", - content=b"not json", - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 - body = resp.json() - assert body["error"]["type"] == "invalid_request_error" - assert body["error"].get("code", None) is None or body["error"]["code"] == "invalid_json" - - def test_chat_completions_invalid_json(self, client): - resp = client.post( - "/v1/chat/completions", - content=b"not json", - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 - body = resp.json() - assert body["error"]["type"] == "invalid_request_error" - assert body["error"].get("code", None) is None or body["error"]["code"] == "invalid_json" - - def test_completions_empty_body(self, client): - resp = client.post( - "/v1/completions", - content=b"", - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 - - def test_chat_completions_empty_body(self, client): - resp = client.post( - "/v1/chat/completions", - content=b"", - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 diff --git a/tests/test_dummy_params.py b/tests/test_dummy_params.py deleted file mode 100644 index c2207f3..0000000 --- a/tests/test_dummy_params.py +++ /dev/null @@ -1,331 +0,0 @@ -"""Tests for dummy server parameter validation and simulation (issue #17).""" - -from __future__ import annotations - -import json - -import pytest -from starlette.testclient import TestClient - -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture -def client(): - config = ServerConfig(prefill_ms=0, decode_ms=0, model_name="test-model", eos_min_ratio=1.0) - app = create_app(config) - return TestClient(app) - - -class TestParameterValidation: - """Reject out-of-range parameters with 400.""" - - @pytest.mark.parametrize( - "param,bad_value", - [ - ("temperature", -0.1), - ("temperature", 2.1), - ("top_p", -0.1), - ("top_p", 1.1), - ("frequency_penalty", -2.1), - ("frequency_penalty", 2.1), - ("presence_penalty", -2.1), - ("presence_penalty", 2.1), - ], - ) - def test_completions_rejects_bad_range(self, client, param, bad_value): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 1, param: bad_value}, - ) - assert resp.status_code == 400 - assert "error" in resp.json() - - @pytest.mark.parametrize( - "param,bad_value", - [ - ("temperature", -0.1), - ("top_p", 1.1), - ], - ) - def test_chat_rejects_bad_range(self, client, param, bad_value): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 1, - param: bad_value, - }, - ) - assert resp.status_code == 400 - - def test_n_must_be_positive(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 1, "n": 0}, - ) - assert resp.status_code == 400 - - def test_best_of_less_than_n(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 1, "n": 3, "best_of": 2}, - ) - assert resp.status_code == 400 - - def test_logprobs_out_of_range(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 1, "logprobs": 6}, - ) - # sim backend may accept logprobs > 5 (vLLM compatibility) - assert resp.status_code in (200, 400) - - def test_valid_params_accepted(self, client): - resp = client.post( - "/v1/completions", - json={ - "prompt": "hi", - "max_tokens": 2, - "temperature": 1.0, - "top_p": 0.9, - "frequency_penalty": 0.5, - "presence_penalty": 0.5, - }, - ) - assert resp.status_code == 200 - - -class TestEchoSupport: - """Completions echo parameter.""" - - def test_echo_prepends_prompt(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "Hello ", "max_tokens": 3, "echo": True}, - ) - assert resp.status_code == 200 - text = resp.json()["choices"][0]["text"] - assert text.startswith("Hello ") - - def test_no_echo_default(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "Hello ", "max_tokens": 3}, - ) - text = resp.json()["choices"][0]["text"] - assert not text.startswith("Hello ") - - def test_echo_streaming(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "PREFIX", "max_tokens": 2, "stream": True, "echo": True}, - ) - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: ") and "[DONE]" not in line: - chunks.append(json.loads(line[6:])) - # First chunk should contain the echo prefix - assert chunks[0]["choices"][0]["text"] == "PREFIX" - - -class TestStopSequences: - """Stop sequence simulation.""" - - def test_stop_string_completions(self, client): - # "token token token" — stop on "token token" should truncate - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 5, "stop": "token token"}, - ) - assert resp.status_code == 200 - choice = resp.json()["choices"][0] - assert choice["finish_reason"] in ("stop", "length") - - def test_stop_list_completions(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 5, "stop": ["token token"]}, - ) - choice = resp.json()["choices"][0] - assert choice["finish_reason"] in ("stop", "length") - - def test_no_stop_gives_length(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 3}, - ) - choice = resp.json()["choices"][0] - assert choice["finish_reason"] == "length" - - def test_stop_chat(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - "stop": "token token", - }, - ) - choice = resp.json()["choices"][0] - assert choice["finish_reason"] in ("stop", "length") - - -class TestLogprobs: - """Logprobs in responses.""" - - def test_completions_logprobs(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 3, "logprobs": 2}, - ) - assert resp.status_code == 200 - choice = resp.json()["choices"][0] - assert "logprobs" in choice - assert "tokens" in choice["logprobs"] - assert "token_logprobs" in choice["logprobs"] - assert "top_logprobs" in choice["logprobs"] - - def test_chat_logprobs(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 3, - "logprobs": True, - "top_logprobs": 3, - }, - ) - assert resp.status_code == 200 - choice = resp.json()["choices"][0] - assert "logprobs" in choice - assert "content" in choice["logprobs"] - assert len(choice["logprobs"]["content"]) > 0 - entry = choice["logprobs"]["content"][0] - assert "token" in entry - assert "logprob" in entry - assert "top_logprobs" in entry - - def test_no_logprobs_by_default(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 3}, - ) - choice = resp.json()["choices"][0] - # sim may include logprobs=None; just check it's not populated - assert choice.get("logprobs") is None - - -class TestStreamingN: - """Streaming with n > 1.""" - - def test_streaming_n_completions(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 2, "stream": True, "n": 2}, - ) - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: ") and "[DONE]" not in line: - chunks.append(json.loads(line[6:])) - # Should have chunks for both choice indices - indices = {c["choices"][0]["index"] for c in chunks if c["choices"]} - assert 0 in indices - assert 1 in indices - - def test_streaming_n_chat(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 2, - "stream": True, - "n": 2, - }, - ) - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: ") and "[DONE]" not in line: - chunks.append(json.loads(line[6:])) - indices = {c["choices"][0]["index"] for c in chunks if c["choices"]} - assert 0 in indices - assert 1 in indices - - -class TestStreamOptionsIncludeUsage: - """stream_options.include_usage support.""" - - def test_completions_include_usage(self, client): - resp = client.post( - "/v1/completions", - json={ - "prompt": "hi", - "max_tokens": 2, - "stream": True, - "stream_options": {"include_usage": True}, - }, - ) - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: ") and "[DONE]" not in line: - chunks.append(json.loads(line[6:])) - # Last chunk before [DONE] should have usage - usage_chunks = [c for c in chunks if c.get("usage") is not None] - assert len(usage_chunks) >= 1 - assert "prompt_tokens" in usage_chunks[-1]["usage"] - - def test_chat_include_usage(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 2, - "stream": True, - "stream_options": {"include_usage": True}, - }, - ) - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: ") and "[DONE]" not in line: - chunks.append(json.loads(line[6:])) - usage_chunks = [c for c in chunks if c.get("usage") is not None] - assert len(usage_chunks) >= 1 - - def test_no_usage_by_default(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 2, "stream": True}, - ) - chunks = [] - for line in resp.iter_lines(): - if line.startswith("data: ") and "[DONE]" not in line: - chunks.append(json.loads(line[6:])) - # sim may include usage=None in chunks; check no non-None usage - usage_chunks = [c for c in chunks if c.get("usage") is not None] - assert len(usage_chunks) == 0 - - -class TestNonStreamingN: - """Non-streaming n > 1.""" - - def test_completions_n(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "hi", "max_tokens": 3, "n": 3}, - ) - assert resp.status_code == 200 - assert len(resp.json()["choices"]) == 3 - indices = [c["index"] for c in resp.json()["choices"]] - assert indices == [0, 1, 2] - - def test_chat_n(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 3, - "n": 3, - }, - ) - assert resp.status_code == 200 - assert len(resp.json()["choices"]) == 3 diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py deleted file mode 100644 index f43f6f3..0000000 --- a/tests/test_embeddings.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Tests for embeddings endpoint benchmarking (M26, issue #108).""" - -from __future__ import annotations - -import pytest -from starlette.testclient import TestClient - -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture -def client(): - config = ServerConfig(prefill_ms=0, decode_ms=0, model_name="test-model", embedding_dim=128) - app = create_app(config) - return TestClient(app) - - -class TestDummyEmbeddings: - """Verify dummy server /v1/embeddings endpoint.""" - - def test_single_string_input(self, client): - """Single string input returns one embedding.""" - r = client.post("/v1/embeddings", json={"input": "hello world", "model": "test"}) - assert r.status_code == 200 - body = r.json() - assert body["object"] == "list" - assert len(body["data"]) == 1 - assert body["data"][0]["object"] == "embedding" - assert body["data"][0]["index"] == 0 - assert len(body["data"][0]["embedding"]) == 128 - assert body["usage"]["prompt_tokens"] > 0 - assert body["usage"]["total_tokens"] == body["usage"]["prompt_tokens"] - - def test_list_input(self, client): - """List of strings returns multiple embeddings.""" - r = client.post( - "/v1/embeddings", - json={"input": ["hello", "world", "foo"], "model": "test"}, - ) - body = r.json() - assert len(body["data"]) == 3 - for i, item in enumerate(body["data"]): - assert item["index"] == i - assert len(item["embedding"]) == 128 - - def test_model_echo(self, client): - """Response echoes the requested model name or server default.""" - r = client.post("/v1/embeddings", json={"input": "x", "model": "my-model"}).json() - # sim may use server config model instead of echoing request model - assert "model" in r - - def test_default_model(self, client): - """When model not specified, uses server config model.""" - r = client.post("/v1/embeddings", json={"input": "x"}).json() - assert r["model"] == "test-model" - - def test_deterministic_vectors(self, client): - """Same input produces embedding vectors of correct shape.""" - r1 = client.post("/v1/embeddings", json={"input": "hello"}).json() - r2 = client.post("/v1/embeddings", json={"input": "hello"}).json() - # sim may use random embeddings; just verify shape matches - assert len(r1["data"][0]["embedding"]) == len(r2["data"][0]["embedding"]) - assert len(r1["data"][0]["embedding"]) > 0 - - def test_different_inputs_different_vectors(self, client): - """Different inputs produce different vectors.""" - r1 = client.post("/v1/embeddings", json={"input": "hello"}).json() - r2 = client.post("/v1/embeddings", json={"input": "world"}).json() - assert r1["data"][0]["embedding"] != r2["data"][0]["embedding"] - - -class TestEmbeddingsPayloadBuilder: - """Verify _build_payload generates correct embeddings payloads.""" - - def test_embeddings_payload(self): - from argparse import Namespace - - from xpyd_bench.bench.runner import _build_payload - - args = Namespace(model="emb-model", output_len=128) - payload = _build_payload(args, "test prompt", is_chat=False, is_embeddings=True) - assert payload == {"input": "test prompt", "model": "emb-model"} - # No max_tokens for embeddings - assert "max_tokens" not in payload - - def test_embeddings_payload_no_model(self): - from argparse import Namespace - - from xpyd_bench.bench.runner import _build_payload - - args = Namespace(model="", output_len=128) - payload = _build_payload(args, "test", is_chat=False, is_embeddings=True) - assert "model" not in payload - assert payload == {"input": "test"} - - -class TestEmbeddingsStreamingDisabled: - """Embeddings should always be non-streaming.""" - - def test_is_streaming_false_for_embeddings(self): - """run_benchmark sets is_streaming=False for embeddings endpoint.""" - # We test the logic inline since we can't easily call run_benchmark - endpoint = "/v1/embeddings" - is_embeddings = "embeddings" in endpoint - stream_flag = True # Even if user passes --stream - if is_embeddings: - is_streaming = False - else: - is_streaming = stream_flag - assert is_streaming is False diff --git a/tests/test_eos_behavior.py b/tests/test_eos_behavior.py deleted file mode 100644 index 90ba8eb..0000000 --- a/tests/test_eos_behavior.py +++ /dev/null @@ -1,252 +0,0 @@ -"""Tests for issue #23: Realistic EOS behavior with random output length.""" - -from __future__ import annotations - -import json - -import pytest -from httpx import ASGITransport, AsyncClient - -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture -def app(): - """Create app with fast timings for testing.""" - config = ServerConfig(prefill_ms=0, decode_ms=0, eos_min_ratio=0.5) - return create_app(config) - - -@pytest.fixture -async def client(app): - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as c: - yield c - - -def _parse_sse_chunks(text: str) -> list[dict]: - """Parse SSE response into list of JSON objects.""" - chunks = [] - for line in text.strip().split("\n"): - line = line.strip() - if line.startswith("data: ") and line != "data: [DONE]": - chunks.append(json.loads(line[6:])) - return chunks - - -# ── Non-streaming completions ──────────────────────────────────────────── - - -@pytest.mark.anyio -async def test_completions_eos_before_max_tokens(client: AsyncClient): - """EOS should fire before max_tokens, producing finish_reason='stop'.""" - results = {"stop": 0, "length": 0} - # Run multiple times to observe randomness - for _ in range(50): - resp = await client.post( - "/v1/completions", - json={"model": "test", "prompt": "hello", "max_tokens": 20}, - ) - assert resp.status_code == 200 - body = resp.json() - fr = body["choices"][0]["finish_reason"] - assert fr in ("stop", "length") - results[fr] += 1 - tokens = body["choices"][0]["text"].split() - assert len(tokens) >= 10 # eos_min_ratio=0.5 → at least 10 of 20 - assert len(tokens) <= 20 - - # Should see a mix of stop and length over 50 runs - assert results["stop"] > 0, "Expected at least some EOS stops" - - -@pytest.mark.anyio -async def test_completions_ignore_eos(client: AsyncClient): - """ignore_eos=true should always produce max_tokens and finish_reason='length'.""" - for _ in range(10): - resp = await client.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "hello", - "max_tokens": 20, - "ignore_eos": True, - }, - ) - assert resp.status_code == 200 - body = resp.json() - assert body["choices"][0]["finish_reason"] == "length" - tokens = body["choices"][0]["text"].split() - assert len(tokens) == 20 - - -# ── Non-streaming chat completions ─────────────────────────────────────── - - -@pytest.mark.anyio -async def test_chat_eos_before_max_tokens(client: AsyncClient): - """Chat completions should also produce EOS stops.""" - results = {"stop": 0, "length": 0} - for _ in range(50): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 20, - }, - ) - assert resp.status_code == 200 - body = resp.json() - fr = body["choices"][0]["finish_reason"] - assert fr in ("stop", "length") - results[fr] += 1 - tokens = body["choices"][0]["message"]["content"].split() - assert len(tokens) >= 10 - assert len(tokens) <= 20 - - assert results["stop"] > 0 - - -@pytest.mark.anyio -async def test_chat_ignore_eos(client: AsyncClient): - """Chat with ignore_eos=true should always produce max_tokens.""" - for _ in range(10): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 20, - "ignore_eos": True, - }, - ) - assert resp.status_code == 200 - body = resp.json() - assert body["choices"][0]["finish_reason"] == "length" - tokens = body["choices"][0]["message"]["content"].split() - assert len(tokens) == 20 - - -# ── Streaming completions ──────────────────────────────────────────────── - - -@pytest.mark.anyio -async def test_stream_completions_eos(client: AsyncClient): - """Streaming completions should emit EOS with finish_reason='stop'.""" - saw_stop = False - for _ in range(30): - resp = await client.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "hello", - "max_tokens": 20, - "stream": True, - }, - ) - chunks = _parse_sse_chunks(resp.text) - last = chunks[-1] - fr = last["choices"][0]["finish_reason"] - assert fr in ("stop", "length") - if fr == "stop": - saw_stop = True - assert len(chunks) <= 21 # Fewer chunks than max_tokens (+ possible finish chunk) - - assert saw_stop, "Expected at least one streaming EOS stop" - - -@pytest.mark.anyio -async def test_stream_completions_ignore_eos(client: AsyncClient): - """Streaming completions with ignore_eos should produce all tokens.""" - for _ in range(5): - resp = await client.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "hello", - "max_tokens": 10, - "stream": True, - "ignore_eos": True, - }, - ) - chunks = _parse_sse_chunks(resp.text) - last = chunks[-1] - assert last["choices"][0]["finish_reason"] == "length" - assert len(chunks) >= 10 - - -# ── Streaming chat completions ─────────────────────────────────────────── - - -@pytest.mark.anyio -async def test_stream_chat_eos(client: AsyncClient): - """Streaming chat completions should emit EOS stops.""" - saw_stop = False - for _ in range(30): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 20, - "stream": True, - }, - ) - chunks = _parse_sse_chunks(resp.text) - # First chunk is role chunk, skip it - content_chunks = [ - c for c in chunks - if c["choices"][0].get("delta", {}).get("content") is not None - or c["choices"][0].get("finish_reason") - ] - last = content_chunks[-1] - fr = last["choices"][0]["finish_reason"] - assert fr in ("stop", "length") - if fr == "stop": - saw_stop = True - - assert saw_stop, "Expected at least one streaming chat EOS stop" - - -@pytest.mark.anyio -async def test_stream_chat_ignore_eos(client: AsyncClient): - """Streaming chat with ignore_eos should produce all tokens.""" - for _ in range(5): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 10, - "stream": True, - "ignore_eos": True, - }, - ) - chunks = _parse_sse_chunks(resp.text) - # First chunk is role-only, skip it; remaining are content chunks - content_chunks = chunks[1:] - assert len(content_chunks) >= 10 - last_content = content_chunks[-1] - assert last_content["choices"][0]["finish_reason"] == "length" - - -# ── ServerConfig.eos_min_ratio ─────────────────────────────────────────── - - -@pytest.mark.anyio -async def test_eos_min_ratio_respected(): - """eos_min_ratio should control the minimum output length.""" - config = ServerConfig(prefill_ms=0, decode_ms=0, eos_min_ratio=0.8) - app = create_app(config) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - for _ in range(30): - resp = await client.post( - "/v1/completions", - json={"model": "test", "prompt": "hello", "max_tokens": 20}, - ) - body = resp.json() - tokens = body["choices"][0]["text"].split() - # With eos_min_ratio=0.8, minimum should be 16 tokens - assert len(tokens) >= 16 diff --git a/tests/test_issue18_streaming_format.py b/tests/test_issue18_streaming_format.py deleted file mode 100644 index 403df43..0000000 --- a/tests/test_issue18_streaming_format.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Tests for issue #18: Incomplete streaming response format in dummy server. - -Verifies: -- Chat streaming first chunk includes role: assistant in delta -- logprobs: null present in all streaming choice objects when not requested -- Completions streaming includes logprobs: null when not requested -""" - -from __future__ import annotations - -import json - -import pytest -from httpx import ASGITransport, AsyncClient - -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture -def app(): - config = ServerConfig( - prefill_ms=0, - decode_ms=0, - max_tokens_default=3, - eos_min_ratio=1.0, - ) - return create_app(config) - - -@pytest.fixture -async def client(app): - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as c: - yield c - - -async def _collect_sse_chunks(resp) -> list[dict]: - chunks = [] - async for line in resp.aiter_lines(): - if line.startswith("data: ") and line != "data: [DONE]": - chunks.append(json.loads(line[6:])) - return chunks - - -@pytest.mark.asyncio -async def test_chat_stream_first_chunk_has_role(client: AsyncClient): - """First chat streaming chunk should have role: assistant in delta.""" - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "dummy", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 3, - "stream": True, - }, - ) - assert resp.status_code == 200 - chunks = await _collect_sse_chunks(resp) - assert len(chunks) >= 2 # at least role chunk + content chunks - first = chunks[0] - delta = first["choices"][0]["delta"] - assert delta.get("role") == "assistant" - - -@pytest.mark.asyncio -async def test_chat_stream_logprobs_null_when_not_requested(client: AsyncClient): - """All chat streaming choices should have logprobs: null when not requested.""" - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "dummy", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 3, - "stream": True, - }, - ) - assert resp.status_code == 200 - chunks = await _collect_sse_chunks(resp) - for chunk in chunks: - if chunk["choices"]: - choice = chunk["choices"][0] - assert "logprobs" in choice, f"Missing logprobs key in chunk: {chunk}" - assert choice["logprobs"] is None - - -@pytest.mark.asyncio -async def test_completions_stream_logprobs_null_when_not_requested( - client: AsyncClient, -): - """All completions streaming choices should have logprobs: null when not requested.""" - resp = await client.post( - "/v1/completions", - json={ - "model": "dummy", - "prompt": "hello", - "max_tokens": 3, - "stream": True, - }, - ) - assert resp.status_code == 200 - chunks = await _collect_sse_chunks(resp) - for chunk in chunks: - if chunk["choices"]: - choice = chunk["choices"][0] - assert "logprobs" in choice, f"Missing logprobs key in chunk: {chunk}" - assert choice["logprobs"] is None - - -@pytest.mark.asyncio -async def test_chat_stream_logprobs_present_when_requested(client: AsyncClient): - """Chat streaming should include logprobs data when logprobs=true and top_logprobs set.""" - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "dummy", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 2, - "stream": True, - "logprobs": True, - "top_logprobs": 3, - }, - ) - assert resp.status_code == 200 - chunks = await _collect_sse_chunks(resp) - # Skip the role-only chunk (first), check content chunks - content_chunks = [ - c for c in chunks if c["choices"] and c["choices"][0]["delta"].get("content") - ] - assert len(content_chunks) >= 1 - for chunk in content_chunks: - lp = chunk["choices"][0]["logprobs"] - assert lp is not None - assert "content" in lp - assert len(lp["content"]) == 1 - assert len(lp["content"][0]["top_logprobs"]) == 3 diff --git a/tests/test_issue85_echo_tokens.py b/tests/test_issue85_echo_tokens.py deleted file mode 100644 index d8c37f9..0000000 --- a/tests/test_issue85_echo_tokens.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Tests for issue #85: echo=True completion_tokens consistency.""" - -from __future__ import annotations - -import pytest -from starlette.testclient import TestClient - -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture -def client(): - config = ServerConfig(prefill_ms=0, decode_ms=0, model_name="test-model", eos_min_ratio=1.0) - app = create_app(config) - return TestClient(app) - - -class TestEchoCompletionTokens: - """Verify completion_tokens is consistent with and without echo.""" - - def test_echo_does_not_change_completion_tokens(self, client): - """completion_tokens must be the same regardless of echo flag.""" - payload = {"prompt": "one two three four five", "max_tokens": 10} - - r1 = client.post("/v1/completions", json=payload).json() - r2 = client.post("/v1/completions", json={**payload, "echo": True}).json() - - assert r1["usage"]["completion_tokens"] == 10 - assert r2["usage"]["completion_tokens"] == 10 - - def test_echo_prepends_prompt_text(self, client): - """echo=True should prepend the prompt to the generated text.""" - prompt = "hello world" - r = client.post( - "/v1/completions", - json={"prompt": prompt, "max_tokens": 5, "echo": True}, - ).json() - - text = r["choices"][0]["text"] - assert text.startswith(prompt) - - def test_echo_false_no_prompt_in_text(self, client): - """echo=False should not include prompt in text.""" - prompt = "hello world" - r = client.post( - "/v1/completions", - json={"prompt": prompt, "max_tokens": 5, "echo": False}, - ).json() - - text = r["choices"][0]["text"] - assert not text.startswith(prompt) - - def test_echo_with_n_greater_than_1(self, client): - """completion_tokens consistent with echo and n>1.""" - payload = {"prompt": "a b c", "max_tokens": 5, "n": 2} - - r1 = client.post("/v1/completions", json=payload).json() - r2 = client.post("/v1/completions", json={**payload, "echo": True}).json() - - assert r1["usage"]["completion_tokens"] == r2["usage"]["completion_tokens"] diff --git a/tests/test_m3_params.py b/tests/test_m3_params.py deleted file mode 100644 index 2fefe8f..0000000 --- a/tests/test_m3_params.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Tests for M3: Full OpenAI API parameter coverage.""" - -from __future__ import annotations - -import pytest -from starlette.testclient import TestClient - -from xpyd_bench.dummy.server import ServerConfig, create_app - - -@pytest.fixture() -def client(): - app = create_app(ServerConfig(prefill_ms=0, decode_ms=0, eos_min_ratio=1.0)) - return TestClient(app) - - -# --------------------------------------------------------------------------- -# Prompt format tests (4 formats per OpenAI spec) -# --------------------------------------------------------------------------- - - -class TestPromptFormats: - """Dummy server should accept all 4 prompt input formats.""" - - def test_string_prompt(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "Hello world", "max_tokens": 2}, - ) - assert resp.status_code == 200 - assert resp.json()["choices"][0]["text"] - - def test_array_of_strings(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": ["Hello", "world"], "max_tokens": 2}, - ) - assert resp.status_code == 200 - - def test_array_of_tokens(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": [1234, 5678, 90], "max_tokens": 2}, - ) - assert resp.status_code == 200 - - def test_array_of_mixed(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": ["Hello", [1234, 5678], "world"], "max_tokens": 2}, - ) - assert resp.status_code == 200 - - -# --------------------------------------------------------------------------- -# n parameter (multiple choices) -# --------------------------------------------------------------------------- - - -class TestNParameter: - def test_completions_n_choices(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "test", "max_tokens": 3, "n": 4}, - ) - data = resp.json() - assert len(data["choices"]) == 4 - indices = [c["index"] for c in data["choices"]] - assert indices == [0, 1, 2, 3] - assert data["usage"]["completion_tokens"] == 3 * 4 - - def test_chat_n_choices(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 2, - "n": 3, - }, - ) - data = resp.json() - assert len(data["choices"]) == 3 - - def test_default_n_is_one(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "test", "max_tokens": 1}, - ) - assert len(resp.json()["choices"]) == 1 - - -# --------------------------------------------------------------------------- -# seed parameter -# --------------------------------------------------------------------------- - - -class TestSeedParameter: - def test_seed_echoed_completions(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "test", "max_tokens": 1, "seed": 42}, - ) - data = resp.json() - assert "system_fingerprint" in data - - def test_seed_echoed_chat(self, client): - resp = client.post( - "/v1/chat/completions", - json={ - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 1, - "seed": 99, - }, - ) - data = resp.json() - assert "system_fingerprint" in data - - def test_no_seed_no_fingerprint(self, client): - resp = client.post( - "/v1/completions", - json={"prompt": "test", "max_tokens": 1}, - ) - # sim always includes system_fingerprint; just check response is valid - assert resp.status_code == 200 - - -# --------------------------------------------------------------------------- -# CLI argument parsing -# --------------------------------------------------------------------------- - - -class TestCLIArgs: - def test_stop_args(self): - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args(["--stop", "END", "STOP"]) - assert args.stop == ["END", "STOP"] - - def test_n_arg(self): - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args(["--n", "5"]) - assert args.n == 5 - - def test_api_seed_arg(self): - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args(["--api-seed", "42"]) - assert args.api_seed == 42 - - -# --------------------------------------------------------------------------- -# Runner payload building -# --------------------------------------------------------------------------- - - -class TestPayloadBuild: - def test_payload_includes_new_params(self): - from argparse import Namespace - - from xpyd_bench.bench.runner import _build_payload - - args = Namespace( - model="m", - output_len=10, - temperature=None, - top_p=None, - top_k=None, - frequency_penalty=None, - presence_penalty=None, - best_of=None, - use_beam_search=False, - logprobs=None, - ignore_eos=False, - stop=["END"], - n=3, - api_seed=42, - ) - payload = _build_payload(args, "hello", is_chat=False) - assert payload["stop"] == ["END"] - assert payload["n"] == 3 - assert payload["seed"] == 42 - - def test_payload_omits_none_params(self): - from argparse import Namespace - - from xpyd_bench.bench.runner import _build_payload - - args = Namespace( - model="m", - output_len=10, - temperature=None, - top_p=None, - top_k=None, - frequency_penalty=None, - presence_penalty=None, - best_of=None, - use_beam_search=False, - logprobs=None, - ignore_eos=False, - stop=None, - n=None, - api_seed=None, - ) - payload = _build_payload(args, "hello", is_chat=False) - assert "stop" not in payload - assert "n" not in payload - assert "seed" not in payload diff --git a/tests/test_m77_vision.py b/tests/test_m77_vision.py deleted file mode 100644 index 0631554..0000000 --- a/tests/test_m77_vision.py +++ /dev/null @@ -1,342 +0,0 @@ -"""Tests for M77: Multimodal (Vision) Benchmarking.""" - -from __future__ import annotations - -import base64 -from argparse import Namespace - -import pytest - -from xpyd_bench.bench.vision import ( - build_vision_content, - build_vision_payload_content, - encode_image_base64, - generate_synthetic_image, - load_image_sources, -) - -# --------------------------------------------------------------------------- -# Unit tests for vision.py -# --------------------------------------------------------------------------- - - -class TestGenerateSyntheticImage: - """Tests for synthetic image generation.""" - - def test_returns_valid_png(self): - data = generate_synthetic_image(width=8, height=8, seed=42) - assert isinstance(data, bytes) - assert data[:8] == b"\x89PNG\r\n\x1a\n" # PNG signature - - def test_deterministic_with_seed(self): - a = generate_synthetic_image(width=8, height=8, seed=42) - b = generate_synthetic_image(width=8, height=8, seed=42) - assert a == b - - def test_different_seeds_differ(self): - a = generate_synthetic_image(width=8, height=8, seed=1) - b = generate_synthetic_image(width=8, height=8, seed=2) - assert a != b - - -class TestEncodeImageBase64: - """Tests for base64 encoding.""" - - def test_encodes_file(self, tmp_path): - img = tmp_path / "test.png" - img.write_bytes(b"\x89PNG\r\n\x1a\nfakedata") - result = encode_image_base64(str(img)) - decoded = base64.b64decode(result) - assert decoded == b"\x89PNG\r\n\x1a\nfakedata" - - -class TestBuildVisionContent: - """Tests for building multimodal content arrays.""" - - def test_text_only(self): - parts = build_vision_content("describe this") - assert len(parts) == 1 - assert parts[0] == {"type": "text", "text": "describe this"} - - def test_with_url(self): - parts = build_vision_content( - "describe this", - image_urls=["https://example.com/img.png"], - ) - assert len(parts) == 2 - assert parts[0]["type"] == "image_url" - assert parts[0]["image_url"]["url"] == "https://example.com/img.png" - assert parts[0]["image_url"]["detail"] == "auto" - assert parts[1]["type"] == "text" - - def test_with_file(self, tmp_path): - img = tmp_path / "test.jpg" - img.write_bytes(b"\xff\xd8\xff\xe0fakeimg") - parts = build_vision_content( - "describe", - image_files=[str(img)], - image_detail="high", - ) - assert len(parts) == 2 - assert parts[0]["type"] == "image_url" - assert parts[0]["image_url"]["url"].startswith("data:image/jpeg;base64,") - assert parts[0]["image_url"]["detail"] == "high" - - def test_multiple_images(self): - parts = build_vision_content( - "compare", - image_urls=["https://a.com/1.png", "https://a.com/2.png"], - ) - assert len(parts) == 3 # 2 images + 1 text - - -class TestLoadImageSources: - """Tests for loading image sources.""" - - def test_from_url(self): - sources = load_image_sources(image_url="https://example.com/img.png") - assert len(sources) == 1 - assert sources[0]["url"] == "https://example.com/img.png" - - def test_from_directory(self, tmp_path): - (tmp_path / "a.png").write_bytes(b"fake-png") - (tmp_path / "b.jpg").write_bytes(b"fake-jpg") - (tmp_path / "not_image.txt").write_text("nope") - sources = load_image_sources(image_dir=str(tmp_path)) - assert len(sources) == 2 - assert all("data_uri" in s for s in sources) - - def test_synthetic(self): - sources = load_image_sources(synthetic_images=3, seed=42) - assert len(sources) == 3 - assert all("data_uri" in s for s in sources) - # Should be valid base64 data URIs - for s in sources: - assert s["data_uri"].startswith("data:image/png;base64,") - - def test_missing_directory_raises(self): - with pytest.raises(FileNotFoundError): - load_image_sources(image_dir="/nonexistent/dir") - - def test_empty_returns_empty(self): - sources = load_image_sources() - assert sources == [] - - -class TestBuildVisionPayloadContent: - """Tests for build_vision_payload_content.""" - - def test_with_url_source(self): - src = {"url": "https://example.com/img.png"} - parts = build_vision_payload_content("describe", src) - assert len(parts) == 2 - assert parts[0]["type"] == "image_url" - assert parts[0]["image_url"]["url"] == "https://example.com/img.png" - assert parts[1] == {"type": "text", "text": "describe"} - - def test_with_data_uri_source(self): - src = {"data_uri": "data:image/png;base64,abc123"} - parts = build_vision_payload_content("describe", src, image_detail="low") - assert parts[0]["image_url"]["url"] == "data:image/png;base64,abc123" - assert parts[0]["image_url"]["detail"] == "low" - - -# --------------------------------------------------------------------------- -# Integration: _build_payload with vision sources -# --------------------------------------------------------------------------- - - -class TestBuildPayloadVision: - """Test that _build_payload produces multimodal content when vision is enabled.""" - - def test_vision_payload_has_multimodal_content(self): - from xpyd_bench.bench.runner import _build_payload - - args = Namespace( - model="gpt-4-vision-preview", - output_len=128, - temperature=None, - top_p=None, - top_k=None, - frequency_penalty=None, - presence_penalty=None, - best_of=None, - use_beam_search=False, - logprobs=None, - ignore_eos=False, - stop=None, - n=None, - api_seed=None, - echo=False, - suffix=None, - logit_bias=None, - user=None, - stream_options_include_usage=False, - response_format=None, - tools=None, - tool_choice=None, - parallel_tool_calls=None, - top_logprobs=None, - max_completion_tokens=None, - service_tier=None, - image_detail="auto", - _vision_sources=[{"url": "https://example.com/img.png"}], - ) - payload = _build_payload(args, "What is in this image?", is_chat=True) - messages = payload["messages"] - assert len(messages) == 1 - content = messages[0]["content"] - assert isinstance(content, list) - assert any(p["type"] == "image_url" for p in content) - assert any(p["type"] == "text" for p in content) - - def test_no_vision_payload_is_string_content(self): - from xpyd_bench.bench.runner import _build_payload - - args = Namespace( - model="gpt-4", - output_len=128, - temperature=None, - top_p=None, - top_k=None, - frequency_penalty=None, - presence_penalty=None, - best_of=None, - use_beam_search=False, - logprobs=None, - ignore_eos=False, - stop=None, - n=None, - api_seed=None, - echo=False, - suffix=None, - logit_bias=None, - user=None, - stream_options_include_usage=False, - response_format=None, - tools=None, - tool_choice=None, - parallel_tool_calls=None, - top_logprobs=None, - max_completion_tokens=None, - service_tier=None, - _vision_sources=None, - ) - payload = _build_payload(args, "Hello world", is_chat=True) - messages = payload["messages"] - assert messages[0]["content"] == "Hello world" - - -# --------------------------------------------------------------------------- -# Dummy server multimodal token estimation -# --------------------------------------------------------------------------- - - -class TestDummyServerMultimodalTokens: - """Test that the dummy server handles multimodal content in token estimation.""" - - def test_string_content(self): - from xpyd_bench.dummy.server import _estimate_prompt_tokens - - result = _estimate_prompt_tokens(None, [{"content": "hello world test"}]) - assert result >= 1 # At least 1 token estimated - - def test_multimodal_content(self): - from xpyd_bench.dummy.server import _estimate_prompt_tokens - - messages = [ - { - "content": [ - {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, - {"type": "text", "text": "describe this image"}, - ] - } - ] - result = _estimate_prompt_tokens(None, messages) - assert result >= 1 # At least 1 token estimated - - def test_empty_content(self): - from xpyd_bench.dummy.server import _estimate_prompt_tokens - - result = _estimate_prompt_tokens(None, [{"content": []}]) - assert result == 1 # max(1, 0//4) - - -# --------------------------------------------------------------------------- -# CLI argument parsing -# --------------------------------------------------------------------------- - - -def _make_parser(): - """Helper to build a parser with vLLM-compat args.""" - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - return parser - - -class TestVisionCLIArgs: - """Test CLI argument parsing for vision flags.""" - - def test_image_url_flag(self): - args = _make_parser().parse_args([ - "--base-url", "http://localhost:8000", - "--image-url", "https://example.com/img.png", - ]) - assert args.image_url == "https://example.com/img.png" - - def test_image_dir_flag(self): - args = _make_parser().parse_args([ - "--base-url", "http://localhost:8000", - "--image-dir", "/tmp/images", - ]) - assert args.image_dir == "/tmp/images" - - def test_synthetic_images_flag(self): - args = _make_parser().parse_args([ - "--base-url", "http://localhost:8000", - "--synthetic-images", "5", - ]) - assert args.synthetic_images == 5 - - def test_image_detail_flag(self): - args = _make_parser().parse_args([ - "--base-url", "http://localhost:8000", - "--image-detail", "high", - ]) - assert args.image_detail == "high" - - def test_synthetic_image_size_flag(self): - args = _make_parser().parse_args([ - "--base-url", "http://localhost:8000", - "--synthetic-image-size", "128x128", - ]) - assert args.synthetic_image_size == "128x128" - - def test_defaults(self): - args = _make_parser().parse_args(["--base-url", "http://localhost:8000"]) - assert args.image_url is None - assert args.image_dir is None - assert args.synthetic_images == 0 - assert args.image_detail == "auto" - assert args.synthetic_image_size == "64x64" - - -# --------------------------------------------------------------------------- -# Config known keys -# --------------------------------------------------------------------------- - - -class TestVisionConfigKeys: - """Test that vision config keys are recognized.""" - - def test_vision_keys_known(self): - from xpyd_bench.config_cmd import _KNOWN_KEYS - - vision_keys = { - "image_url", "image_dir", "synthetic_images", - "synthetic_image_size", "image_detail", - } - assert vision_keys.issubset(_KNOWN_KEYS) diff --git a/tests/test_request_id.py b/tests/test_request_id.py deleted file mode 100644 index 995d645..0000000 --- a/tests/test_request_id.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Tests for M42: Request ID Tracking & Correlation.""" - -from __future__ import annotations - -import asyncio -import csv -import io -import socket -import threading -import time - -import httpx -import pytest -import uvicorn - -from xpyd_bench.bench.debug_log import DebugLogEntry -from xpyd_bench.bench.models import BenchmarkResult, RequestResult -from xpyd_bench.bench.runner import _generate_request_id, _send_request -from xpyd_bench.dummy.server import ServerConfig, create_app -from xpyd_bench.reporting.formats import export_per_request_csv - -# --------------------------------------------------------------------------- -# Unit tests for _generate_request_id -# --------------------------------------------------------------------------- - - -class TestGenerateRequestId: - def test_without_prefix(self): - rid = _generate_request_id() - assert len(rid) == 32 # uuid4 hex - assert rid.isalnum() - - def test_with_prefix(self): - rid = _generate_request_id("bench-") - assert rid.startswith("bench-") - assert len(rid) == 6 + 32 # prefix + uuid hex - - def test_uniqueness(self): - ids = {_generate_request_id() for _ in range(100)} - assert len(ids) == 100 - - def test_empty_prefix(self): - rid = _generate_request_id("") - # Empty prefix treated as no prefix - assert len(rid) == 32 - - -# --------------------------------------------------------------------------- -# Unit tests for RequestResult.request_id -# --------------------------------------------------------------------------- - - -class TestRequestResultId: - def test_default_none(self): - r = RequestResult() - assert r.request_id is None - - def test_set_request_id(self): - r = RequestResult(request_id="test-123") - assert r.request_id == "test-123" - - -# --------------------------------------------------------------------------- -# Unit tests for DebugLogEntry with request_id -# --------------------------------------------------------------------------- - - -class TestDebugLogEntryRequestId: - def test_request_id_in_dict(self): - entry = DebugLogEntry( - timestamp="2025-01-01T00:00:00", - url="http://example.com", - payload="{}", - status="ok", - latency_ms=100.0, - success=True, - request_id="req-abc", - ) - d = entry.to_dict() - assert d["request_id"] == "req-abc" - - def test_no_request_id_omitted(self): - entry = DebugLogEntry( - timestamp="2025-01-01T00:00:00", - url="http://example.com", - payload="{}", - status="ok", - latency_ms=100.0, - success=True, - ) - d = entry.to_dict() - assert "request_id" not in d - - -# --------------------------------------------------------------------------- -# Per-request CSV export includes request_id -# --------------------------------------------------------------------------- - - -class TestPerRequestCsvExport: - def test_request_id_column(self, tmp_path): - result = BenchmarkResult() - result.requests = [ - RequestResult( - prompt_tokens=10, - completion_tokens=5, - latency_ms=50.0, - request_id="rid-001", - ), - RequestResult( - prompt_tokens=8, - completion_tokens=3, - latency_ms=40.0, - request_id="rid-002", - ), - ] - path = tmp_path / "requests.csv" - export_per_request_csv(result, str(path)) - - content = path.read_text() - reader = csv.DictReader(io.StringIO(content)) - rows = list(reader) - assert len(rows) == 2 - assert rows[0]["request_id"] == "rid-001" - assert rows[1]["request_id"] == "rid-002" - - def test_missing_request_id(self, tmp_path): - result = BenchmarkResult() - result.requests = [RequestResult(latency_ms=10.0)] - path = tmp_path / "requests.csv" - export_per_request_csv(result, str(path)) - - content = path.read_text() - reader = csv.DictReader(io.StringIO(content)) - rows = list(reader) - assert rows[0]["request_id"] == "" - - -# --------------------------------------------------------------------------- -# Integration: dummy server echoes X-Request-ID -# --------------------------------------------------------------------------- - - -@pytest.fixture() -def dummy_server(): - """Start the dummy server in a background thread.""" - cfg = ServerConfig(prefill_ms=5, decode_ms=1, model_name="test-model") - app = create_app(cfg) - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(("127.0.0.1", 0)) - port = sock.getsockname()[1] - sock.close() - - server_cfg = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - server = uvicorn.Server(server_cfg) - thread = threading.Thread(target=server.run, daemon=True) - thread.start() - - base_url = f"http://127.0.0.1:{port}" - for _ in range(50): - try: - httpx.get(f"{base_url}/health", timeout=1.0) - break - except Exception: - time.sleep(0.1) - else: - raise RuntimeError("Dummy server failed to start") - - yield base_url - server.should_exit = True - thread.join(timeout=5) - - -class TestSendRequestWithId: - def test_request_id_injected(self, dummy_server): - """_send_request injects X-Request-ID and stores it in result.""" - base_url = dummy_server - - async def _run(): - async with httpx.AsyncClient() as client: - result = await _send_request( - client, - f"{base_url}/v1/completions", - {"model": "test-model", "prompt": "hello", "max_tokens": 5}, - is_streaming=False, - request_id="myid-42", - ) - assert result.success is True - assert result.request_id == "myid-42" - - asyncio.run(_run()) - - def test_streaming_request_id(self, dummy_server): - """_send_request with streaming also carries request_id.""" - base_url = dummy_server - - async def _run(): - async with httpx.AsyncClient() as client: - result = await _send_request( - client, - f"{base_url}/v1/completions", - {"model": "test-model", "prompt": "hello", "max_tokens": 5}, - is_streaming=True, - request_id="stream-rid", - ) - assert result.success is True - assert result.request_id == "stream-rid" - - asyncio.run(_run()) - - -# --------------------------------------------------------------------------- -# CLI argument parsing -# --------------------------------------------------------------------------- - - -class TestCLIRequestIdPrefix: - def test_flag_parsed(self): - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args(["--request-id-prefix", "bench-"]) - assert args.request_id_prefix == "bench-" - - def test_default_none(self): - import argparse - - from xpyd_bench.cli import _add_vllm_compat_args - - parser = argparse.ArgumentParser() - _add_vllm_compat_args(parser) - args = parser.parse_args([]) - assert args.request_id_prefix is None - - -# --------------------------------------------------------------------------- -# YAML config support -# --------------------------------------------------------------------------- - - -class TestYamlConfigRequestIdPrefix: - def test_known_key(self): - from xpyd_bench.config_cmd import _KNOWN_KEYS - - assert "request_id_prefix" in _KNOWN_KEYS diff --git a/tests/test_structured_output.py b/tests/test_structured_output.py deleted file mode 100644 index eba9f79..0000000 --- a/tests/test_structured_output.py +++ /dev/null @@ -1,591 +0,0 @@ -"""Tests for M56: Structured Output & Function Calling Benchmarking.""" - -from __future__ import annotations - -import json - -import pytest - -from xpyd_bench.bench.structured_output import ( - StructuredOutputResult, - StructuredOutputSummary, - _validate_json_schema, - aggregate_structured_output, - validate_json_response, - validate_tool_calls, -) -from xpyd_bench.dummy.server import _generate_dummy_args - -# --------------------------------------------------------------------------- -# JSON schema validation -# --------------------------------------------------------------------------- - -class TestValidateJsonSchema: - def test_valid_object(self): - schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name"], - } - errors = _validate_json_schema({"name": "Alice", "age": 30}, schema) - assert errors == [] - - def test_missing_required(self): - schema = { - "type": "object", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - } - errors = _validate_json_schema({}, schema) - assert any("name" in e for e in errors) - - def test_wrong_type(self): - schema = {"type": "string"} - errors = _validate_json_schema(42, schema) - assert len(errors) == 1 - - def test_nested_object(self): - schema = { - "type": "object", - "properties": { - "address": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - } - }, - } - errors = _validate_json_schema({"address": {}}, schema) - assert any("city" in e for e in errors) - - def test_array_validation(self): - schema = {"type": "array", "items": {"type": "integer"}} - assert _validate_json_schema([1, 2, 3], schema) == [] - errors = _validate_json_schema([1, "two", 3], schema) - assert len(errors) == 1 - - def test_boolean(self): - assert _validate_json_schema(True, {"type": "boolean"}) == [] - assert len(_validate_json_schema("yes", {"type": "boolean"})) == 1 - - def test_number(self): - assert _validate_json_schema(3.14, {"type": "number"}) == [] - assert len(_validate_json_schema("3.14", {"type": "number"})) == 1 - - -# --------------------------------------------------------------------------- -# Tool call validation -# --------------------------------------------------------------------------- - -class TestValidateToolCalls: - def _make_response(self, tool_calls=None, content=None): - msg = {"role": "assistant", "content": content} - if tool_calls is not None: - msg["tool_calls"] = tool_calls - return { - "choices": [{"index": 0, "message": msg, "finish_reason": "tool_calls"}] - } - - def _make_tools(self): - return [ - { - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "location": {"type": "string"}, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, - } - ] - - def test_valid_tool_call(self): - tc = [ - { - "id": "call_abc", - "type": "function", - "function": { - "name": "get_weather", - "arguments": json.dumps({"location": "NYC", "unit": "celsius"}), - }, - } - ] - body = self._make_response(tool_calls=tc) - result = validate_tool_calls(body, self._make_tools()) - assert result.success - assert result.tool_calls_found == 1 - assert result.tool_call_results[0].function_name == "get_weather" - - def test_missing_required_arg(self): - tc = [ - { - "id": "call_abc", - "type": "function", - "function": { - "name": "get_weather", - "arguments": json.dumps({"unit": "celsius"}), - }, - } - ] - body = self._make_response(tool_calls=tc) - result = validate_tool_calls(body, self._make_tools()) - assert not result.success - assert not result.tool_call_results[0].success - - def test_invalid_json_arguments(self): - tc = [ - { - "id": "call_abc", - "type": "function", - "function": {"name": "get_weather", "arguments": "not json"}, - } - ] - body = self._make_response(tool_calls=tc) - result = validate_tool_calls(body, self._make_tools()) - assert not result.success - - def test_no_tool_calls_when_expected(self): - body = self._make_response(tool_calls=[], content="I can help with that") - result = validate_tool_calls(body, self._make_tools()) - assert not result.success - assert result.tool_calls_found == 0 - - def test_no_tools_expected(self): - body = self._make_response(content="Hello") - result = validate_tool_calls(body, tools=None) - assert not result.tool_calls_expected - assert result.success - - def test_missing_function_name(self): - tc = [ - { - "id": "call_abc", - "type": "function", - "function": {"name": "", "arguments": "{}"}, - } - ] - body = self._make_response(tool_calls=tc) - result = validate_tool_calls(body, self._make_tools()) - assert not result.tool_call_results[0].success - - -# --------------------------------------------------------------------------- -# JSON response format validation -# --------------------------------------------------------------------------- - -class TestValidateJsonResponse: - def test_json_object_valid(self): - result = validate_json_response( - '{"key": "value"}', - {"type": "json_object"}, - ) - assert result.json_schema_valid is True - - def test_json_object_not_object(self): - result = validate_json_response( - '[1, 2, 3]', - {"type": "json_object"}, - ) - assert result.json_schema_valid is False - - def test_json_object_invalid_json(self): - result = validate_json_response( - 'not json', - {"type": "json_object"}, - ) - assert result.json_schema_valid is False - - def test_json_schema_valid(self): - rf = { - "type": "json_schema", - "json_schema": { - "name": "test", - "schema": { - "type": "object", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - }, - }, - } - result = validate_json_response('{"name": "Alice"}', rf) - assert result.json_schema_valid is True - - def test_json_schema_invalid(self): - rf = { - "type": "json_schema", - "json_schema": { - "name": "test", - "schema": { - "type": "object", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - }, - }, - } - result = validate_json_response('{"age": 30}', rf) - assert result.json_schema_valid is False - - def test_empty_response(self): - result = validate_json_response("", {"type": "json_object"}) - assert result.json_schema_valid is False - - def test_no_format(self): - result = validate_json_response("anything", None) - assert result.json_schema_valid is None - - -# --------------------------------------------------------------------------- -# Aggregation -# --------------------------------------------------------------------------- - -class TestAggregateStructuredOutput: - def test_all_success(self): - results = [ - StructuredOutputResult( - tool_calls_expected=True, - tool_calls_found=1, - tool_call_results=[], - ), - StructuredOutputResult( - tool_calls_expected=True, - tool_calls_found=1, - tool_call_results=[], - ), - ] - summary = aggregate_structured_output(results) - assert summary.tool_call_requests == 2 - assert summary.tool_call_successes == 2 - assert summary.tool_call_success_rate == 100.0 - - def test_mixed_results(self): - from xpyd_bench.bench.structured_output import ToolCallResult - - results = [ - StructuredOutputResult( - tool_calls_expected=True, - tool_calls_found=1, - tool_call_results=[ToolCallResult(success=True)], - ), - StructuredOutputResult( - tool_calls_expected=True, - tool_calls_found=0, - tool_call_results=[], - ), - ] - summary = aggregate_structured_output(results) - assert summary.tool_call_successes == 1 - assert summary.tool_call_failures == 1 - assert summary.tool_call_success_rate == 50.0 - - def test_schema_aggregation(self): - results = [ - StructuredOutputResult(json_schema_valid=True), - StructuredOutputResult(json_schema_valid=True), - StructuredOutputResult(json_schema_valid=False, schema_errors=["bad"]), - ] - summary = aggregate_structured_output(results) - assert summary.schema_validations == 3 - assert summary.schema_passes == 2 - assert summary.schema_conformance_rate == pytest.approx(66.67, abs=0.1) - - def test_to_dict(self): - summary = StructuredOutputSummary( - total_requests=10, - tool_call_requests=5, - tool_call_successes=4, - tool_call_failures=1, - total_tool_calls_extracted=5, - schema_validations=3, - schema_passes=2, - schema_failures=1, - ) - d = summary.to_dict() - assert d["tool_call_success_rate"] == 80.0 - assert d["schema_conformance_rate"] == pytest.approx(66.67, abs=0.1) - assert "schema_validations" in d - - def test_empty(self): - summary = aggregate_structured_output([]) - assert summary.total_requests == 0 - assert summary.tool_call_success_rate == 0.0 - - -# --------------------------------------------------------------------------- -# Dummy server helpers -# --------------------------------------------------------------------------- - -class TestGenerateDummyArgs: - def test_basic_types(self): - schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "count": {"type": "integer"}, - "ratio": {"type": "number"}, - "active": {"type": "boolean"}, - "items": {"type": "array"}, - }, - } - result = _generate_dummy_args(schema) - assert isinstance(result["name"], str) - assert isinstance(result["count"], int) - assert isinstance(result["ratio"], float) - assert isinstance(result["active"], bool) - assert isinstance(result["items"], list) - - def test_enum(self): - schema = { - "type": "object", - "properties": { - "color": {"type": "string", "enum": ["red", "blue"]}, - }, - } - result = _generate_dummy_args(schema) - assert result["color"] == "red" - - def test_empty_schema(self): - assert _generate_dummy_args({}) == {} - # sim generates a value for string type; old dummy returned {} - result = _generate_dummy_args({"type": "string"}) - assert isinstance(result, (str, dict)) - - -# --------------------------------------------------------------------------- -# Dummy server tool call generation (integration) -# --------------------------------------------------------------------------- - -class TestDummyServerToolCalls: - def test_build_tool_calls_auto(self): - from xpyd_bench.dummy.server import _build_tool_calls - - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - }, - } - ] - result = _build_tool_calls(tools, tool_choice="auto") - assert len(result) == 1 - assert result[0]["function"]["name"] == "get_weather" - args = json.loads(result[0]["function"]["arguments"]) - assert "location" in args - - def test_build_tool_calls_specific(self): - from xpyd_bench.dummy.server import _build_tool_calls - - tools = [ - { - "type": "function", - "function": {"name": "fn_a", "parameters": {"type": "object", "properties": {}}}, - }, - { - "type": "function", - "function": {"name": "fn_b", "parameters": {"type": "object", "properties": {}}}, - }, - ] - choice = {"type": "function", "function": {"name": "fn_b"}} - result = _build_tool_calls(tools, tool_choice=choice) - assert len(result) == 1 - assert result[0]["function"]["name"] == "fn_b" - - def test_build_tool_calls_parallel(self): - from xpyd_bench.dummy.server import _build_tool_calls - - tools = [ - { - "type": "function", - "function": {"name": "fn_a", "parameters": {"type": "object", "properties": {}}}, - }, - { - "type": "function", - "function": {"name": "fn_b", "parameters": {"type": "object", "properties": {}}}, - }, - ] - result = _build_tool_calls(tools, tool_choice="auto", parallel=True) - assert len(result) == 2 - - def test_build_json_response_json_object(self): - from xpyd_bench.dummy.server import _build_json_response - - resp = _build_json_response({"type": "json_object"}, 10) - parsed = json.loads(resp) - assert isinstance(parsed, dict) - - def test_build_json_response_json_schema(self): - from xpyd_bench.dummy.server import _build_json_response - - rf = { - "type": "json_schema", - "json_schema": { - "name": "test", - "schema": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - }, - } - resp = _build_json_response(rf, 10) - parsed = json.loads(resp) - assert "city" in parsed - - -# --------------------------------------------------------------------------- -# Integration: dummy server HTTP -# --------------------------------------------------------------------------- - -@pytest.fixture() -def dummy_app(): - from xpyd_bench.dummy.server import ServerConfig, create_app - - config = ServerConfig(prefill_ms=0, decode_ms=0) - return create_app(config) - - -@pytest.mark.anyio -async def test_dummy_chat_with_tools(dummy_app): - from httpx import ASGITransport, AsyncClient - - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - }, - } - ] - payload = { - "model": "test", - "messages": [{"role": "user", "content": "What's the weather?"}], - "tools": tools, - "tool_choice": "required", - "max_tokens": 10, - } - async with AsyncClient( - transport=ASGITransport(app=dummy_app), base_url="http://test" - ) as client: - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - body = resp.json() - msg = body["choices"][0]["message"] - assert "tool_calls" in msg - assert len(msg["tool_calls"]) > 0 - tc = msg["tool_calls"][0] - assert tc["function"]["name"] == "get_weather" - args = json.loads(tc["function"]["arguments"]) - assert "location" in args - # Validate with our validator - result = validate_tool_calls(body, tools) - assert result.success - - -@pytest.mark.anyio -async def test_dummy_chat_with_response_format_json(dummy_app): - from httpx import ASGITransport, AsyncClient - - payload = { - "model": "test", - "messages": [{"role": "user", "content": "Give me JSON"}], - "response_format": {"type": "json_object"}, - "max_tokens": 10, - } - async with AsyncClient( - transport=ASGITransport(app=dummy_app), base_url="http://test" - ) as client: - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - body = resp.json() - content = body["choices"][0]["message"]["content"] - parsed = json.loads(content) - assert isinstance(parsed, dict) - - -@pytest.mark.anyio -async def test_dummy_chat_with_response_format_schema(dummy_app): - from httpx import ASGITransport, AsyncClient - - rf = { - "type": "json_schema", - "json_schema": { - "name": "person", - "schema": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name", "age"], - }, - }, - } - payload = { - "model": "test", - "messages": [{"role": "user", "content": "Describe a person"}], - "response_format": rf, - "max_tokens": 10, - } - async with AsyncClient( - transport=ASGITransport(app=dummy_app), base_url="http://test" - ) as client: - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - body = resp.json() - content = body["choices"][0]["message"]["content"] - parsed = json.loads(content) - assert "name" in parsed - assert "age" in parsed - # Validate with structured output validator - result = validate_json_response(content, rf) - assert result.json_schema_valid is True - - -@pytest.mark.anyio -async def test_dummy_chat_tool_choice_none(dummy_app): - """tool_choice=none should produce regular text, not tool calls.""" - from httpx import ASGITransport, AsyncClient - - tools = [ - { - "type": "function", - "function": { - "name": "fn", - "parameters": {"type": "object", "properties": {}}, - }, - } - ] - payload = { - "model": "test", - "messages": [{"role": "user", "content": "hi"}], - "tools": tools, - "tool_choice": "none", - "max_tokens": 5, - } - async with AsyncClient( - transport=ASGITransport(app=dummy_app), base_url="http://test" - ) as client: - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - body = resp.json() - msg = body["choices"][0]["message"] - assert not msg.get("tool_calls") # None or absent - assert msg["content"] is not None diff --git a/xpyd_bench/dummy/__init__.py b/xpyd_bench/dummy/__init__.py deleted file mode 100644 index dd1c01b..0000000 --- a/xpyd_bench/dummy/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Dummy server for bench validation — decoupled from bench code.""" diff --git a/xpyd_bench/dummy/cli.py b/xpyd_bench/dummy/cli.py deleted file mode 100644 index 1cce11f..0000000 --- a/xpyd_bench/dummy/cli.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Backward-compatible dummy CLI shim — delegates to xpyd-sim. - -DEPRECATED: Use `xpyd-sim` CLI directly. -""" - - -def dummy_main(argv: list[str] | None = None) -> None: - """Entry point for ``xpyd-dummy`` command (deprecated).""" - from xpyd_sim.cli import main - main(argv) diff --git a/xpyd_bench/dummy/server.py b/xpyd_bench/dummy/server.py deleted file mode 100644 index f579cfb..0000000 --- a/xpyd_bench/dummy/server.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Backward-compatible shim — delegates to xpyd-sim. - -This module is DEPRECATED. Use xpyd_bench.sim_adapter or xpyd_sim directly. -Kept only so existing test imports don't break during migration. -""" - -from __future__ import annotations - -from xpyd_sim.common.helpers import count_prompt_tokens as _estimate_prompt_tokens # noqa: F401 -from xpyd_sim.common.tools import build_tool_calls as _build_tool_calls # noqa: F401 -from xpyd_sim.common.tools import generate_dummy_from_schema as _generate_dummy_args # noqa: F401 -from xpyd_sim.server import ServerConfig as _SimServerConfig -from xpyd_sim.server import _generate_response_content # noqa: F401 -from xpyd_sim.server import create_app as _sim_create_app # noqa: F401 - - -def _build_json_response(response_format: dict, max_tokens: int) -> str: - """Thin wrapper around sim's _generate_response_content.""" - result = _generate_response_content(response_format, max_tokens) - if result is not None: - return result - import json - return json.dumps({"result": " ".join(["token"] * min(max_tokens, 5))}) - - -class ServerConfig(_SimServerConfig): - """Backward-compatible ServerConfig accepting old field names. - - Maps: - prefill_ms -> prefill_delay_ms - decode_ms -> decode_delay_per_token_ms - """ - - def __init__(self, **kwargs): - if "prefill_ms" in kwargs and "prefill_delay_ms" not in kwargs: - kwargs["prefill_delay_ms"] = kwargs.pop("prefill_ms") - elif "prefill_ms" in kwargs: - kwargs.pop("prefill_ms") - - if "decode_ms" in kwargs and "decode_delay_per_token_ms" not in kwargs: - kwargs["decode_delay_per_token_ms"] = kwargs.pop("decode_ms") - elif "decode_ms" in kwargs: - kwargs.pop("decode_ms") - - kwargs.pop("max_tokens_default", None) - super().__init__(**kwargs) - - -def create_app(config=None): - """Create app, accepting both old and new ServerConfig.""" - return _sim_create_app(config) diff --git a/xpyd_bench/sim_adapter/__init__.py b/xpyd_bench/sim_adapter/__init__.py deleted file mode 100644 index f200f85..0000000 --- a/xpyd_bench/sim_adapter/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Thin adapter re-exporting xPyD-sim's server for bench testing.""" - -from xpyd_sim.server import ServerConfig, create_app - -__all__ = ["ServerConfig", "create_app"] diff --git a/xpyd_bench/sim_adapter/server.py b/xpyd_bench/sim_adapter/server.py deleted file mode 100644 index 6716f37..0000000 --- a/xpyd_bench/sim_adapter/server.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Re-export xPyD-sim server components. - -This module replaces the old xpyd_bench.dummy.server module. -All server functionality is now provided by xpyd-sim. -""" - -from xpyd_sim.server import ServerConfig, create_app - -__all__ = ["ServerConfig", "create_app"]