Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/strands/experimental/bidi/_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from typing import Awaitable, Callable

from ._task_group import _TaskGroup
from ._task_pool import _TaskPool

__all__ = ["_TaskPool"]
__all__ = ["_TaskGroup", "_TaskPool"]


async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
Expand All @@ -16,14 +17,14 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
funcs: Stop functions to call in sequence.

Raises:
ExceptionGroup: If any stop function raises an exception.
RuntimeError: If any stop function raises an exception.
"""
exceptions = []
for func in funcs:
try:
await func()
except Exception as exception:
exceptions.append(exception)
exceptions.append({"func_name": func.__name__, "exception": repr(exception)})

if exceptions:
raise ExceptionGroup("failed stop sequence", exceptions)
raise RuntimeError(f"exceptions={exceptions} | failed stop sequence")
61 changes: 61 additions & 0 deletions src/strands/experimental/bidi/_async/_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Manage a group of async tasks.
This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11.
- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups
"""

import asyncio
from typing import Any, Coroutine


class _TaskGroup:
"""Shim of asyncio.TaskGroup for use in Python 3.10.
Attributes:
_tasks: List of tasks in group.
"""

_tasks: list[asyncio.Task]

def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
"""Create an async task and add to group.
Returns:
The created task.
"""
task = asyncio.create_task(coro)
self._tasks.append(task)
return task

async def __aenter__(self) -> "_TaskGroup":
"""Setup self managed task group context."""
self._tasks = []
return self

async def __aexit__(self, *_: Any) -> None:
"""Execute tasks in group.
The following execution rules are enforced:
- The context stops executing all tasks if at least one task raises an Exception or the context is cancelled.
- The context re-raises Exceptions to the caller.
- The context re-raises CancelledErrors to the caller only if the context itself was cancelled.
"""
try:
await asyncio.gather(*self._tasks)

except (Exception, asyncio.CancelledError) as error:
for task in self._tasks:
task.cancel()

await asyncio.gather(*self._tasks, return_exceptions=True)

if not isinstance(error, asyncio.CancelledError):
raise

context_task = asyncio.current_task()
if context_task and context_task.cancelling() > 0: # context itself was cancelled
raise

finally:
self._tasks = []
4 changes: 2 additions & 2 deletions src/strands/experimental/bidi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ....types.tools import AgentTool
from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent
from ...tools import ToolProvider
from .._async import stop_all
from .._async import _TaskGroup, stop_all
from ..models.model import BidiModel
from ..models.nova_sonic import BidiNovaSonicModel
from ..types.agent import BidiAgentInput
Expand Down Expand Up @@ -390,7 +390,7 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
for start in [*input_starts, *output_starts]:
await start(self)

async with asyncio.TaskGroup() as task_group:
async with _TaskGroup() as task_group:
inputs_task = task_group.create_task(run_inputs())
task_group.create_task(run_outputs(inputs_task))

Expand Down
12 changes: 7 additions & 5 deletions tests/strands/experimental/bidi/_async/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,19 @@ async def test_stop_exception():
func1 = AsyncMock()
func2 = AsyncMock(side_effect=ValueError("stop 2 failed"))
func3 = AsyncMock()
func4 = AsyncMock(side_effect=ValueError("stop 4 failed"))

with pytest.raises(ExceptionGroup) as exc_info:
await stop_all(func1, func2, func3)
with pytest.raises(Exception, match=r"failed stop sequence") as exc_info:
await stop_all(func1, func2, func3, func4)

func1.assert_called_once()
func2.assert_called_once()
func3.assert_called_once()
func4.assert_called_once()

assert len(exc_info.value.exceptions) == 1
with pytest.raises(ValueError, match=r"stop 2 failed"):
raise exc_info.value.exceptions[0]
tru_message = str(exc_info.value)
assert "ValueError('stop 2 failed')" in tru_message
assert "ValueError('stop 4 failed')" in tru_message


@pytest.mark.asyncio
Expand Down
59 changes: 59 additions & 0 deletions tests/strands/experimental/bidi/_async/test_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio
import unittest.mock

import pytest

from strands.experimental.bidi._async._task_group import _TaskGroup


@pytest.mark.asyncio
async def test_task_group__aexit__():
coro = unittest.mock.AsyncMock()

async with _TaskGroup() as task_group:
task_group.create_task(coro())

coro.assert_called_once()


@pytest.mark.asyncio
async def test_task_group__aexit__exception():
wait_event = asyncio.Event()
async def wait():
await wait_event.wait()

async def fail():
raise ValueError("test error")

with pytest.raises(ValueError, match=r"test error"):
async with _TaskGroup() as task_group:
wait_task = task_group.create_task(wait())
fail_task = task_group.create_task(fail())

assert wait_task.cancelled()
assert not fail_task.cancelled()


@pytest.mark.asyncio
async def test_task_group__aexit__cancelled():
wait_event = asyncio.Event()
async def wait():
await wait_event.wait()

tasks = []

run_event = asyncio.Event()
async def run():
async with _TaskGroup() as task_group:
tasks.append(task_group.create_task(wait()))
run_event.set()

run_task = asyncio.create_task(run())
await run_event.wait()
run_task.cancel()

with pytest.raises(asyncio.CancelledError):
await run_task

wait_task = tasks[0]
assert wait_task.cancelled()
2 changes: 1 addition & 1 deletion tests/strands/experimental/bidi/models/test_gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id):
model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key})
await model4.start()
mock_live_session_cm.__aexit__.side_effect = Exception("Close failed")
with pytest.raises(ExceptionGroup):
with pytest.raises(Exception, match=r"failed stop sequence"):
await model4.stop()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ async def async_connect(*args, **kwargs):
model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key})
await model4.start()
mock_ws.close.side_effect = Exception("Close failed")
with pytest.raises(ExceptionGroup):
with pytest.raises(Exception, match=r"failed stop sequence"):
await model4.stop()


Expand Down
Loading