From 5a4390069d483c7ea2eeecf52a4da70f9ed5cde9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 3 Dec 2024 09:08:36 +0000 Subject: [PATCH 1/2] run_sync repeat runs, fix #122 --- pydantic_ai_slim/pydantic_ai/agent.py | 5 ++-- tests/conftest.py | 9 +++++++ tests/models/test_model_function.py | 20 +++++++-------- tests/models/test_model_test.py | 12 ++++----- tests/test_agent.py | 37 ++++++++++++++++----------- tests/test_deps.py | 4 +-- tests/test_examples.py | 1 + tests/test_logfire.py | 2 +- tests/test_retrievers.py | 14 +++++----- 9 files changed, 61 insertions(+), 43 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 791ee6bb38..13cb4c4959 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -206,7 +206,7 @@ def run_sync( ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. - This is a convenience method that wraps `self.run` with `asyncio.run()`. + This is a convenience method that wraps `self.run` with `loop.run_until_complete()`. Args: user_prompt: User input to start/continue the conversation. @@ -217,7 +217,8 @@ def run_sync( Returns: The result of the run. """ - return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps)) + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps)) @asynccontextmanager async def run_stream( diff --git a/tests/conftest.py b/tests/conftest.py index f12cf67bc9..4a219ce00d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import asyncio import importlib.util import os import re @@ -162,3 +163,11 @@ def check_import() -> bool: pass else: import_success = True + + +@pytest.fixture +def set_event_loop() -> Iterator[None]: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + yield + new_loop.close() diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index e9e5761b88..4fdb422e49 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -38,7 +38,7 @@ async def return_last(messages: list[Message], _: AgentInfo) -> ModelAnyResponse return ModelTextResponse(' '.join(f'{k}={v!r}' for k, v in response.items())) -def test_simple(): +def test_simple(set_event_loop: None): agent = Agent(FunctionModel(return_last)) result = agent.run_sync('Hello') assert result.data == snapshot("content='Hello' role='user' message_count=1") @@ -129,7 +129,7 @@ async def get_weather(_: RunContext[None], lat: int, lng: int): return 'Sunny' -def test_weather(): +def test_weather(set_event_loop: None): result = weather_agent.run_sync('London') assert result.data == 'Raining in London' assert result.all_messages() == snapshot( @@ -206,7 +206,7 @@ def get_var_args(ctx: RunContext[int], *args: int): return json.dumps({'args': args}) -def test_var_args(): +def test_var_args(set_event_loop: None): result = var_args_agent.run_sync('{"function": "get_var_args", "arguments": {"args": [1, 2, 3]}}', deps=123) response_data = json.loads(result.data) # Can't parse ISO timestamps with trailing 'Z' in older versions of python: @@ -231,7 +231,7 @@ async def call_tool(messages: list[Message], info: AgentInfo) -> ModelAnyRespons return ModelTextResponse('final response') -def test_deps_none(): +def test_deps_none(set_event_loop: None): agent = Agent(FunctionModel(call_tool)) 
@agent.tool @@ -251,7 +251,7 @@ async def get_none(ctx: RunContext[None]): assert called -def test_deps_init(): +def test_deps_init(set_event_loop: None): def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str: nonlocal called @@ -266,7 +266,7 @@ def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str: assert called -def test_model_arg(): +def test_model_arg(set_event_loop: None): agent = Agent() result = agent.run_sync('Hello', model=FunctionModel(return_last)) assert result.data == snapshot("content='Hello' role='user' message_count=1") @@ -308,7 +308,7 @@ def spam() -> str: return 'foobar' -def test_register_all(): +def test_register_all(set_event_loop: None): async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: return ModelTextResponse( f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}' @@ -318,7 +318,7 @@ async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: assert result.data == snapshot('messages=2 allow_text_result=True tools=5') -def test_call_all(): +def test_call_all(set_event_loop: None): result = agent_all.run_sync('Hello', model=TestModel()) assert result.data == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}') assert result.all_messages() == snapshot( @@ -347,7 +347,7 @@ def test_call_all(): ) -def test_retry_str(): +def test_retry_str(set_event_loop: None): call_count = 0 async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: @@ -369,7 +369,7 @@ async def validate_result(r: str) -> str: assert result.data == snapshot('2') -def test_retry_result_type(): +def test_retry_result_type(set_event_loop: None): call_count = 0 async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index ae47d223b6..38d7933308 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -24,7 +24,7 @@ from ..conftest import IsNow -def test_call_one(): +def test_call_one(set_event_loop: None): agent = Agent() calls: list[str] = [] @@ -43,7 +43,7 @@ async def ret_b(x: str) -> str: # pragma: no cover assert calls == ['a'] -def test_custom_result_text(): +def test_custom_result_text(set_event_loop: None): agent = Agent() result = agent.run_sync('x', model=TestModel(custom_result_text='custom')) assert result.data == snapshot('custom') @@ -52,13 +52,13 @@ def test_custom_result_text(): agent.run_sync('x', model=TestModel(custom_result_text='custom')) -def test_custom_result_args(): +def test_custom_result_args(set_event_loop: None): agent = Agent(result_type=tuple[str, str]) result = agent.run_sync('x', model=TestModel(custom_result_args=['a', 'b'])) assert result.data == ('a', 'b') -def test_custom_result_args_model(): +def test_custom_result_args_model(set_event_loop: None): class Foo(BaseModel): foo: str bar: int @@ -68,13 +68,13 @@ class Foo(BaseModel): assert result.data == Foo(foo='a', bar=1) -def test_result_type(): +def test_result_type(set_event_loop: None): agent = Agent(result_type=tuple[str, str]) result = agent.run_sync('x', model=TestModel()) assert result.data == ('a', 'a') -def test_tool_retry(): +def test_tool_retry(set_event_loop: None): agent = Agent() call_count = 0 diff --git a/tests/test_agent.py b/tests/test_agent.py index e416bd4fb0..6d5f242025 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -27,7 +27,7 @@ from .conftest import IsNow, TestEnv -def test_result_tuple(): +def 
test_result_tuple(set_event_loop: None): def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse: assert info.result_tools is not None args_json = '{"response": ["foo", "bar"]}' @@ -44,7 +44,7 @@ class Foo(BaseModel): b: str -def test_result_pydantic_model(): +def test_result_pydantic_model(set_event_loop: None): def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse: assert info.result_tools is not None args_json = '{"a": 1, "b": "foo"}' @@ -57,7 +57,7 @@ def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse: assert result.data.model_dump() == {'a': 1, 'b': 'foo'} -def test_result_pydantic_model_retry(): +def test_result_pydantic_model_retry(set_event_loop: None): def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: assert info.result_tools is not None if len(messages) == 1: @@ -99,7 +99,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: assert result.all_messages_json().startswith(b'[{"content":"Hello"') -def test_result_validator(): +def test_result_validator(set_event_loop: None): def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: assert info.result_tools is not None if len(messages) == 1: @@ -135,7 +135,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo: ) -def test_plain_response(): +def test_plain_response(set_event_loop: None): call_index = 0 def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse: @@ -170,7 +170,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse: ) -def test_response_tuple(): +def test_response_tuple(set_event_loop: None): m = TestModel() agent = Agent(m, result_type=tuple[str, str]) @@ -215,7 +215,7 @@ def test_response_tuple(): [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str], ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'], ) -def test_response_union_allow_str(input_union_callable: Callable[[], Any]): +def test_response_union_allow_str(set_event_loop: None, input_union_callable: Callable[[], Any]): try: union = input_union_callable() except TypeError: @@ -277,7 +277,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any: ), ], ) -def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str): +def test_response_multiple_return_tools(set_event_loop: None, create_module: Callable[[str], Any], union_code: str): module_code = f''' from pydantic import BaseModel from typing import Union @@ -356,7 +356,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any: assert got_tool_call_name == snapshot('final_result_Bar') -def test_run_with_history_new(): +def test_run_with_history_new(set_event_loop: None): m = TestModel() agent = Agent(m, system_prompt='Foobar') @@ -440,7 +440,7 @@ async def ret_a(x: str) -> str: ) -def test_empty_tool_calls(): +def test_empty_tool_calls(set_event_loop: None): def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: return ModelStructuredResponse(calls=[]) @@ -450,7 +450,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: agent.run_sync('Hello') -def test_unknown_tool(): +def test_unknown_tool(set_event_loop: None): def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: return ModelStructuredResponse(calls=[ToolCall.from_json('foobar', '{}')]) @@ -472,7 +472,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: ) -def test_unknown_tool_fix(): +def 
test_unknown_tool_fix(set_event_loop: None): def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse: if len(m) > 1: return ModelTextResponse(content='success') @@ -495,7 +495,7 @@ def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse: ) -def test_model_requests_blocked(env: TestEnv): +def test_model_requests_blocked(env: TestEnv, set_event_loop: None): env.set('GEMINI_API_KEY', 'foobar') agent = Agent('gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True) @@ -503,7 +503,7 @@ def test_model_requests_blocked(env: TestEnv): agent.run_sync('Hello') -def test_override_model(env: TestEnv): +def test_override_model(env: TestEnv, set_event_loop: None): env.set('GEMINI_API_KEY', 'foobar') agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True) @@ -512,9 +512,16 @@ def test_override_model(env: TestEnv): assert result.data == snapshot((0, 'a')) -def test_override_model_no_model(): +def test_override_model_no_model(set_event_loop: None): agent = Agent() with pytest.raises(UserError, match=r'`model` must be set either.+Even when `override\(model=...\)` is customiz'): with agent.override(model='test'): agent.run_sync('Hello') + + +def test_run_sync_multiple(set_event_loop: None): + agent = Agent('test') + for _ in range(3): + result = agent.run_sync('Hello') + assert result.data == 'success (no tool calls)' diff --git a/tests/test_deps.py b/tests/test_deps.py index 2a62d4b9ea..7e5778d0f1 100644 --- a/tests/test_deps.py +++ b/tests/test_deps.py @@ -18,12 +18,12 @@ async def example_tool(ctx: RunContext[MyDeps]) -> str: return f'{ctx.deps}' -def test_deps_used(): +def test_deps_used(set_event_loop: None): result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}' -def test_deps_override(): +def test_deps_override(set_event_loop: None): with agent.override(deps=MyDeps(foo=3, bar=4)): result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}' diff --git a/tests/test_examples.py b/tests/test_examples.py index 9a7d3b170d..b7856f1c63 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -61,6 +61,7 @@ def test_docs_examples( client_with_handler: ClientWithHandler, env: TestEnv, tmp_path: Path, + set_event_loop: None, ): mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal) diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 1102ee4827..193cc558dc 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -59,7 +59,7 @@ def get_summary() -> LogfireSummary: @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') -def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None: +def test_logfire(get_logfire_summary: Callable[[], LogfireSummary], set_event_loop: None) -> None: agent = Agent(model=TestModel()) @agent.tool_plain diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index c411c1cce2..32191d27ec 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -73,7 +73,7 @@ async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAny return ModelTextResponse(json.dumps(r.json_schema)) -def test_docstring_google(): +def test_docstring_google(set_event_loop: None): agent = Agent(FunctionModel(get_json_schema)) agent.tool_plain(google_style_docstring) @@ -104,7 +104,7 @@ def 
sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover return str(foo) -def test_docstring_sphinx(): +def test_docstring_sphinx(set_event_loop: None): agent = Agent(FunctionModel(get_json_schema)) agent.tool_plain(sphinx_style_docstring) @@ -136,7 +136,7 @@ def numpy_style_docstring(*, foo: int, bar: str) -> str: # pragma: no cover return f'{foo} {bar}' -def test_docstring_numpy(): +def test_docstring_numpy(set_event_loop: None): agent = Agent(FunctionModel(get_json_schema)) agent.tool_plain(numpy_style_docstring) @@ -161,7 +161,7 @@ def unknown_docstring(**kwargs: int) -> str: # pragma: no cover return str(kwargs) -def test_docstring_unknown(): +def test_docstring_unknown(set_event_loop: None): agent = Agent(FunctionModel(get_json_schema)) agent.tool_plain(unknown_docstring) @@ -190,7 +190,7 @@ async def google_style_docstring_no_body( return f'{foo} {bar}' -def test_docstring_google_no_body(): +def test_docstring_google_no_body(set_event_loop: None): agent = Agent(FunctionModel(get_json_schema)) agent.tool_plain(google_style_docstring_no_body) @@ -209,7 +209,7 @@ def test_docstring_google_no_body(): ) -def test_takes_just_model(): +def test_takes_just_model(set_event_loop: None): agent = Agent() class Foo(BaseModel): @@ -235,7 +235,7 @@ def takes_just_model(model: Foo) -> str: assert result.data == snapshot('{"takes_just_model":"0 a"}') -def test_takes_model_and_int(): +def test_takes_model_and_int(set_event_loop: None): agent = Agent() class Foo(BaseModel): From c0159031840003aba109dac765c3816a109b712d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 3 Dec 2024 09:42:52 +0000 Subject: [PATCH 2/2] fix test so it would fail with asyncio.run --- tests/test_agent.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index 6d5f242025..d4fc1ccf46 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -20,6 +20,7 @@ ToolReturn, UserPrompt, ) +from pydantic_ai.models import cached_async_http_client from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.result import Cost, RunResult @@ -522,6 +523,15 @@ def test_override_model_no_model(set_event_loop: None): def test_run_sync_multiple(set_event_loop: None): agent = Agent('test') - for _ in range(3): + + @agent.tool_plain + async def make_request() -> str: + # raised a `RuntimeError: Event loop is closed` on repeat runs when we used `asyncio.run()` + client = cached_async_http_client() + # use this as I suspect it's about the fastest globally available endpoint + response = await client.get('https://cloudflare.com/cdn-cgi/trace') + return str(response.status_code) + + for _ in range(2): result = agent.run_sync('Hello') - assert result.data == 'success (no tool calls)' + assert result.data == '{"make_request":"200"}'
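
The change above matters because `cached_async_http_client()` returns a cached `httpx.AsyncClient` that stays bound to the event loop it first ran on, while `asyncio.run()` creates a fresh event loop on every call and closes it on exit, so a repeat `run_sync` call could find the cached client attached to a closed loop. Below is a minimal, standard-library-only sketch of that mechanism; `_cached_loop`, `use_cached_resource` and the two `run_sync_*` wrappers are illustrative stand-ins, not pydantic-ai APIs.

from __future__ import annotations

import asyncio

# Stand-in for a module-level resource (like the cached httpx.AsyncClient)
# that stays bound to the event loop it was first used on.
_cached_loop: asyncio.AbstractEventLoop | None = None


async def use_cached_resource() -> str:
    global _cached_loop
    if _cached_loop is None:
        _cached_loop = asyncio.get_running_loop()
    # Scheduling work on the loop the resource is bound to raises
    # "RuntimeError: Event loop is closed" once that loop has been closed.
    _cached_loop.call_soon(lambda: None)
    return 'ok'


def run_sync_with_asyncio_run() -> str:
    # Previous behaviour: asyncio.run() creates a fresh loop and closes it on
    # exit, so anything cached during the first call is left on a dead loop.
    return asyncio.run(use_cached_resource())


def run_sync_with_run_until_complete() -> str:
    # Behaviour after this patch: reuse the current loop, which stays open
    # between calls. Newer Python versions may emit a DeprecationWarning if
    # get_event_loop() has to create a loop implicitly.
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(use_cached_resource())


if __name__ == '__main__':
    print(run_sync_with_run_until_complete())  # ok
    print(run_sync_with_run_until_complete())  # ok, same loop, still open

    _cached_loop = None  # start over, as a fresh process would
    print(run_sync_with_asyncio_run())  # ok, but its loop is closed afterwards
    try:
        run_sync_with_asyncio_run()  # repeat run touches the closed loop
    except RuntimeError as exc:
        print(f'repeat run failed: {exc}')  # Event loop is closed

This is also why the tests gain the `set_event_loop` fixture: `asyncio.get_event_loop()` needs an open, current loop to reuse, so each `run_sync` test creates and sets a fresh loop and closes it once the test finishes.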