diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index d1fc5f07d1..e843cbe06d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -13,7 +13,7 @@ from typing_extensions import TypeGuard, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext -from pydantic_graph._utils import get_event_loop +from pydantic_graph._utils import run_until_complete from . import ( _agent_graph, @@ -567,7 +567,7 @@ def run_sync( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return get_event_loop().run_until_complete( + return run_until_complete( self.run( user_prompt, result_type=result_type, diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index bc49ed6fcb..dfcff0fe34 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -1,7 +1,9 @@ from __future__ import annotations as _annotations import asyncio +import sys import types +from collections.abc import Coroutine from functools import partial from typing import Any, Callable, TypeVar @@ -10,15 +12,6 @@ from typing_inspection.introspection import is_union_origin -def get_event_loop(): - try: - event_loop = asyncio.get_event_loop() - except RuntimeError: - event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(event_loop) - return event_loop - - def get_union_args(tp: Any) -> tuple[Any, ...]: """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple.""" # similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args` @@ -100,3 +93,15 @@ async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.k return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs)) else: return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore + + +def run_until_complete(coro: Coroutine[None, None, _R]) -> _R: + if sys.version_info < (3, 11): + try: + loop = asyncio.new_event_loop() + return loop.run_until_complete(coro) + finally: + loop.close() + else: + with asyncio.runners.Runner(loop_factory=asyncio.new_event_loop) as runner: + return runner.run(coro) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 9c026ea38e..e1238c7ea1 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -202,7 +202,7 @@ def run_sync( if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return _utils.get_event_loop().run_until_complete( + return _utils.run_until_complete( self.run(start_node, state=state, deps=deps, persistence=persistence, infer_name=False) ) diff --git a/tests/graph/test_utils.py b/tests/graph/test_utils.py index e5987feedb..2358db7513 100644 --- a/tests/graph/test_utils.py +++ b/tests/graph/test_utils.py @@ -1,12 +1,19 @@ from threading import Thread -from pydantic_graph._utils import get_event_loop +from pydantic_graph._utils import run_until_complete -def test_get_event_loop_in_thread(): +def test_run_until_complete_in_main_thread(): + async def run(): ... + + run_until_complete(run()) + + +def test_run_until_complete_in_thread(): + async def run(): ... + def get_and_close_event_loop(): - event_loop = get_event_loop() - event_loop.close() + run_until_complete(run()) thread = Thread(target=get_and_close_event_loop) thread.start()