5 changes: 3 additions & 2 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -206,7 +206,7 @@ def run_sync(
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt synchronously.

This is a convenience method that wraps `self.run` with `asyncio.run()`.
This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.

Args:
user_prompt: User input to start/continue the conversation.
Expand All @@ -217,7 +217,8 @@ def run_sync(
Returns:
The result of the run.
"""
return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps))

@asynccontextmanager
async def run_stream(
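The substantive change is in `Agent.run_sync`: instead of `asyncio.run()`, which creates a new event loop and closes it when the coroutine finishes, `run_sync` now reuses the thread's current loop via `run_until_complete()`. Below is a minimal sketch of the failure mode this avoids, not code from the PR: the module-level cache, helper names, and endpoint are stand-ins, and it assumes `httpx` is installed. A client cached at module level (like the one returned by `cached_async_http_client()`) stays bound to the loop that first used it, so once `asyncio.run()` closes that loop, a later call can fail with `RuntimeError: Event loop is closed`.

import asyncio

import httpx

_client: httpx.AsyncClient | None = None


def cached_client() -> httpx.AsyncClient:
    """Module-level cache, a stand-in for pydantic_ai.models.cached_async_http_client()."""
    global _client
    if _client is None:
        _client = httpx.AsyncClient()
    return _client


async def fetch_status() -> int:
    # Any awaitable that uses the cached client; the URL is only illustrative.
    response = await cached_client().get('https://cloudflare.com/cdn-cgi/trace')
    return response.status_code


# With asyncio.run(), every call spins up and then closes a fresh loop, so the
# cached client's connections outlive their loop and a second call can raise
# "RuntimeError: Event loop is closed":
#
#     asyncio.run(fetch_status())
#     asyncio.run(fetch_status())  # can fail
#
# Reusing one loop, as run_sync now does, keeps the cached client usable:
loop = asyncio.get_event_loop()
print(loop.run_until_complete(fetch_status()))
print(loop.run_until_complete(fetch_status()))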
9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import asyncio
import importlib.util
import os
import re
@@ -162,3 +163,11 @@ def check_import() -> bool:
pass
else:
import_success = True


@pytest.fixture
def set_event_loop() -> Iterator[None]:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
yield
new_loop.close()
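Since `run_sync` now relies on `asyncio.get_event_loop()` rather than creating and closing a loop per call, every synchronous test needs an open event loop set for its thread. The `set_event_loop` fixture above provides exactly that: a fresh loop per test, closed on teardown so loops don't leak across tests. The signature-only changes in the test files below simply request this fixture.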
20 changes: 10 additions & 10 deletions tests/models/test_model_function.py
@@ -38,7 +38,7 @@ async def return_last(messages: list[Message], _: AgentInfo) -> ModelAnyResponse
return ModelTextResponse(' '.join(f'{k}={v!r}' for k, v in response.items()))


def test_simple():
def test_simple(set_event_loop: None):
agent = Agent(FunctionModel(return_last))
result = agent.run_sync('Hello')
assert result.data == snapshot("content='Hello' role='user' message_count=1")
@@ -129,7 +129,7 @@ async def get_weather(_: RunContext[None], lat: int, lng: int):
return 'Sunny'


def test_weather():
def test_weather(set_event_loop: None):
result = weather_agent.run_sync('London')
assert result.data == 'Raining in London'
assert result.all_messages() == snapshot(
@@ -206,7 +206,7 @@ def get_var_args(ctx: RunContext[int], *args: int):
return json.dumps({'args': args})


def test_var_args():
def test_var_args(set_event_loop: None):
result = var_args_agent.run_sync('{"function": "get_var_args", "arguments": {"args": [1, 2, 3]}}', deps=123)
response_data = json.loads(result.data)
# Can't parse ISO timestamps with trailing 'Z' in older versions of python:
@@ -231,7 +231,7 @@ async def call_tool(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
return ModelTextResponse('final response')


def test_deps_none():
def test_deps_none(set_event_loop: None):
agent = Agent(FunctionModel(call_tool))

@agent.tool
@@ -251,7 +251,7 @@ async def get_none(ctx: RunContext[None]):
assert called


def test_deps_init():
def test_deps_init(set_event_loop: None):
def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str:
nonlocal called

@@ -266,7 +266,7 @@ def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str:
assert called


def test_model_arg():
def test_model_arg(set_event_loop: None):
agent = Agent()
result = agent.run_sync('Hello', model=FunctionModel(return_last))
assert result.data == snapshot("content='Hello' role='user' message_count=1")
@@ -308,7 +308,7 @@ def spam() -> str:
return 'foobar'


def test_register_all():
def test_register_all(set_event_loop: None):
async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
return ModelTextResponse(
f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}'
@@ -318,7 +318,7 @@ async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
assert result.data == snapshot('messages=2 allow_text_result=True tools=5')


def test_call_all():
def test_call_all(set_event_loop: None):
result = agent_all.run_sync('Hello', model=TestModel())
assert result.data == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')
assert result.all_messages() == snapshot(
@@ -347,7 +347,7 @@ def test_call_all():
)


def test_retry_str():
def test_retry_str(set_event_loop: None):
call_count = 0

async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse:
@@ -369,7 +369,7 @@ async def validate_result(r: str) -> str:
assert result.data == snapshot('2')


def test_retry_result_type():
def test_retry_result_type(set_event_loop: None):
call_count = 0

async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse:
12 changes: 6 additions & 6 deletions tests/models/test_model_test.py
@@ -24,7 +24,7 @@
from ..conftest import IsNow


def test_call_one():
def test_call_one(set_event_loop: None):
agent = Agent()
calls: list[str] = []

@@ -43,7 +43,7 @@ async def ret_b(x: str) -> str: # pragma: no cover
assert calls == ['a']


def test_custom_result_text():
def test_custom_result_text(set_event_loop: None):
agent = Agent()
result = agent.run_sync('x', model=TestModel(custom_result_text='custom'))
assert result.data == snapshot('custom')
@@ -52,13 +52,13 @@ def test_custom_result_text():
agent.run_sync('x', model=TestModel(custom_result_text='custom'))


def test_custom_result_args():
def test_custom_result_args(set_event_loop: None):
agent = Agent(result_type=tuple[str, str])
result = agent.run_sync('x', model=TestModel(custom_result_args=['a', 'b']))
assert result.data == ('a', 'b')


def test_custom_result_args_model():
def test_custom_result_args_model(set_event_loop: None):
class Foo(BaseModel):
foo: str
bar: int
@@ -68,13 +68,13 @@ class Foo(BaseModel):
assert result.data == Foo(foo='a', bar=1)


def test_result_type():
def test_result_type(set_event_loop: None):
agent = Agent(result_type=tuple[str, str])
result = agent.run_sync('x', model=TestModel())
assert result.data == ('a', 'a')


def test_tool_retry():
def test_tool_retry(set_event_loop: None):
agent = Agent()
call_count = 0

47 changes: 32 additions & 15 deletions tests/test_agent.py
@@ -20,14 +20,15 @@
ToolReturn,
UserPrompt,
)
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.result import Cost, RunResult

from .conftest import IsNow, TestEnv


def test_result_tuple():
def test_result_tuple(set_event_loop: None):
def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
assert info.result_tools is not None
args_json = '{"response": ["foo", "bar"]}'
@@ -44,7 +44,7 @@ class Foo(BaseModel):
b: str


def test_result_pydantic_model():
def test_result_pydantic_model(set_event_loop: None):
def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
assert info.result_tools is not None
args_json = '{"a": 1, "b": "foo"}'
@@ -57,7 +57,7 @@ def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
assert result.data.model_dump() == {'a': 1, 'b': 'foo'}


def test_result_pydantic_model_retry():
def test_result_pydantic_model_retry(set_event_loop: None):
def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
assert info.result_tools is not None
if len(messages) == 1:
@@ -99,7 +100,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
assert result.all_messages_json().startswith(b'[{"content":"Hello"')


def test_result_validator():
def test_result_validator(set_event_loop: None):
def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
assert info.result_tools is not None
if len(messages) == 1:
@@ -135,7 +136,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
)


def test_plain_response():
def test_plain_response(set_event_loop: None):
call_index = 0

def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
@@ -170,7 +171,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
)


def test_response_tuple():
def test_response_tuple(set_event_loop: None):
m = TestModel()

agent = Agent(m, result_type=tuple[str, str])
@@ -215,7 +216,7 @@ def test_response_tuple():
[lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str],
ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'],
)
def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
def test_response_union_allow_str(set_event_loop: None, input_union_callable: Callable[[], Any]):
try:
union = input_union_callable()
except TypeError:
@@ -277,7 +278,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
),
],
)
def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
def test_response_multiple_return_tools(set_event_loop: None, create_module: Callable[[str], Any], union_code: str):
module_code = f'''
from pydantic import BaseModel
from typing import Union
@@ -356,7 +357,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
assert got_tool_call_name == snapshot('final_result_Bar')


def test_run_with_history_new():
def test_run_with_history_new(set_event_loop: None):
m = TestModel()

agent = Agent(m, system_prompt='Foobar')
@@ -440,7 +441,7 @@ async def ret_a(x: str) -> str:
)


def test_empty_tool_calls():
def test_empty_tool_calls(set_event_loop: None):
def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
return ModelStructuredResponse(calls=[])

@@ -450,7 +451,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
agent.run_sync('Hello')


def test_unknown_tool():
def test_unknown_tool(set_event_loop: None):
def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
return ModelStructuredResponse(calls=[ToolCall.from_json('foobar', '{}')])

@@ -472,7 +473,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
)


def test_unknown_tool_fix():
def test_unknown_tool_fix(set_event_loop: None):
def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse:
if len(m) > 1:
return ModelTextResponse(content='success')
@@ -495,15 +496,15 @@ def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse:
)


def test_model_requests_blocked(env: TestEnv):
def test_model_requests_blocked(env: TestEnv, set_event_loop: None):
env.set('GEMINI_API_KEY', 'foobar')
agent = Agent('gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)

with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'):
agent.run_sync('Hello')


def test_override_model(env: TestEnv):
def test_override_model(env: TestEnv, set_event_loop: None):
env.set('GEMINI_API_KEY', 'foobar')
agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)

@@ -512,9 +513,25 @@ def test_override_model(env: TestEnv):
assert result.data == snapshot((0, 'a'))


def test_override_model_no_model():
def test_override_model_no_model(set_event_loop: None):
agent = Agent()

with pytest.raises(UserError, match=r'`model` must be set either.+Even when `override\(model=...\)` is customiz'):
with agent.override(model='test'):
agent.run_sync('Hello')


def test_run_sync_multiple(set_event_loop: None):
agent = Agent('test')

@agent.tool_plain
async def make_request() -> str:
# raised a `RuntimeError: Event loop is closed` on repeat runs when we used `asyncio.run()`
client = cached_async_http_client()
# use this as I suspect it's about the fastest globally available endpoint
response = await client.get('https://cloudflare.com/cdn-cgi/trace')
return str(response.status_code)

for _ in range(2):
result = agent.run_sync('Hello')
assert result.data == '{"make_request":"200"}'
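`test_run_sync_multiple` is the regression test for the change itself: it runs the same agent twice with a tool that goes through the shared `cached_async_http_client()`, a scenario that previously raised `RuntimeError: Event loop is closed` on the second run when `run_sync` used `asyncio.run()`.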
4 changes: 2 additions & 2 deletions tests/test_deps.py
@@ -18,12 +18,12 @@ async def example_tool(ctx: RunContext[MyDeps]) -> str:
return f'{ctx.deps}'


def test_deps_used():
def test_deps_used(set_event_loop: None):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}'


def test_deps_override():
def test_deps_override(set_event_loop: None):
with agent.override(deps=MyDeps(foo=3, bar=4)):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}'
1 change: 1 addition & 0 deletions tests/test_examples.py
@@ -61,6 +61,7 @@ def test_docs_examples(
client_with_handler: ClientWithHandler,
env: TestEnv,
tmp_path: Path,
set_event_loop: None,
):
mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model)
mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal)
2 changes: 1 addition & 1 deletion tests/test_logfire.py
@@ -59,7 +59,7 @@ def get_summary() -> LogfireSummary:


@pytest.mark.skipif(not logfire_installed, reason='logfire not installed')
def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
def test_logfire(get_logfire_summary: Callable[[], LogfireSummary], set_event_loop: None) -> None:
agent = Agent(model=TestModel())

@agent.tool_plain