diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 6cee3264d..47473115c 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -2,9 +2,10 @@ from typing import Awaitable, Callable +from ._task_group import _TaskGroup from ._task_pool import _TaskPool -__all__ = ["_TaskPool"] +__all__ = ["_TaskGroup", "_TaskPool"] async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: @@ -16,14 +17,14 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - ExceptionGroup: If any stop function raises an exception. + RuntimeError: If any stop function raises an exception. """ exceptions = [] for func in funcs: try: await func() except Exception as exception: - exceptions.append(exception) + exceptions.append({"func_name": func.__name__, "exception": repr(exception)}) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) + raise RuntimeError(f"exceptions={exceptions} | failed stop sequence") diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py new file mode 100644 index 000000000..26c67326d --- /dev/null +++ b/src/strands/experimental/bidi/_async/_task_group.py @@ -0,0 +1,61 @@ +"""Manage a group of async tasks. + +This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11. + +- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups +""" + +import asyncio +from typing import Any, Coroutine + + +class _TaskGroup: + """Shim of asyncio.TaskGroup for use in Python 3.10. + + Attributes: + _tasks: List of tasks in group. + """ + + _tasks: list[asyncio.Task] + + def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: + """Create an async task and add to group. + + Returns: + The created task. + """ + task = asyncio.create_task(coro) + self._tasks.append(task) + return task + + async def __aenter__(self) -> "_TaskGroup": + """Setup self managed task group context.""" + self._tasks = [] + return self + + async def __aexit__(self, *_: Any) -> None: + """Execute tasks in group. + + The following execution rules are enforced: + - The context stops executing all tasks if at least one task raises an Exception or the context is cancelled. + - The context re-raises Exceptions to the caller. + - The context re-raises CancelledErrors to the caller only if the context itself was cancelled. + """ + try: + await asyncio.gather(*self._tasks) + + except (Exception, asyncio.CancelledError) as error: + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + + if not isinstance(error, asyncio.CancelledError): + raise + + context_task = asyncio.current_task() + if context_task and context_task.cancelling() > 0: # context itself was cancelled + raise + + finally: + self._tasks = [] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 4012d5e2d..5ddb181ea 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -30,7 +30,7 @@ from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider -from .._async import stop_all +from .._async import _TaskGroup, stop_all from ..models.model import BidiModel from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -390,7 +390,7 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - async with asyncio.TaskGroup() as task_group: + async with _TaskGroup() as task_group: inputs_task = task_group.create_task(run_inputs()) task_group.create_task(run_outputs(inputs_task)) diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index f8df25e14..a121ddecc 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -10,17 +10,19 @@ async def test_stop_exception(): func1 = AsyncMock() func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() + func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(ExceptionGroup) as exc_info: - await stop_all(func1, func2, func3) + with pytest.raises(Exception, match=r"failed stop sequence") as exc_info: + await stop_all(func1, func2, func3, func4) func1.assert_called_once() func2.assert_called_once() func3.assert_called_once() + func4.assert_called_once() - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] + tru_message = str(exc_info.value) + assert "ValueError('stop 2 failed')" in tru_message + assert "ValueError('stop 4 failed')" in tru_message @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py new file mode 100644 index 000000000..23ff821f9 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -0,0 +1,59 @@ +import asyncio +import unittest.mock + +import pytest + +from strands.experimental.bidi._async._task_group import _TaskGroup + + +@pytest.mark.asyncio +async def test_task_group__aexit__(): + coro = unittest.mock.AsyncMock() + + async with _TaskGroup() as task_group: + task_group.create_task(coro()) + + coro.assert_called_once() + + +@pytest.mark.asyncio +async def test_task_group__aexit__exception(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + async def fail(): + raise ValueError("test error") + + with pytest.raises(ValueError, match=r"test error"): + async with _TaskGroup() as task_group: + wait_task = task_group.create_task(wait()) + fail_task = task_group.create_task(fail()) + + assert wait_task.cancelled() + assert not fail_task.cancelled() + + +@pytest.mark.asyncio +async def test_task_group__aexit__cancelled(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + tasks = [] + + run_event = asyncio.Event() + async def run(): + async with _TaskGroup() as task_group: + tasks.append(task_group.create_task(wait())) + run_event.set() + + run_task = asyncio.create_task(run()) + await run_event.wait() + run_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await run_task + + wait_task = tasks[0] + assert wait_task.cancelled() diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 79bb29d41..3a9d7e3dc 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index a973e80aa..1cabbc92b 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -353,7 +353,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop()