Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 8 additions & 60 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,18 @@ a2a = [
"starlette>=0.46.2,<1.0.0",
]

bidi = [
"aws_sdk_bedrock_runtime; python_version>='3.12'",
bidi-io = [
"prompt_toolkit>=3.0.0,<4.0.0",
"pyaudio>=0.2.13,<1.0.0",
"smithy-aws-core>=0.0.1; python_version>='3.12'",
]
bidi-gemini = ["google-genai>=1.32.0,<2.0.0"]
bidi-nova = [
"aws_sdk_bedrock_runtime; python_version>='3.12'",
"smithy-aws-core>=0.0.1; python_version>='3.12'",
]
bidi-openai = ["websockets>=15.0.0,<16.0.0"]

all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]
bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"]
all = ["strands-agents[a2a,anthropic,bidi-io,bidi-gemini,bidi-openai,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]

dev = [
"commitizen>=4.4.0,<5.0.0",
Expand Down Expand Up @@ -130,7 +131,7 @@ format-fix = [
]
lint-check = [
"ruff check",
"mypy ./src"
"mypy -p src"
]
lint-fix = [
"ruff check --fix"
Expand Down Expand Up @@ -204,16 +205,10 @@ warn_no_return = true
warn_unreachable = true
follow_untyped_imports = true
ignore_missing_imports = false
exclude = ["src/strands/experimental/bidi"]

[[tool.mypy.overrides]]
module = ["strands.experimental.bidi.*"]
follow_imports = "skip"

[tool.ruff]
line-length = 120
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]
exclude = ["src/strands/experimental/bidi/**/*.py", "tests/strands/experimental/bidi/**/*.py", "tests_integ/bidi/**/*.py"]

[tool.ruff.lint]
select = [
Expand All @@ -236,16 +231,14 @@ convention = "google"
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_default_fixture_loop_scope = "function"
addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi"

addopts = "--ignore=tests/strands/experimental/bidi/models/test_nova_sonic.py --ignore=tests_integ/bidi"

[tool.coverage.run]
branch = true
source = ["src"]
context = "thread"
parallel = true
concurrency = ["thread", "multiprocessing"]
omit = ["src/strands/experimental/bidi/*"]

[tool.coverage.report]
show_missing = true
Expand Down Expand Up @@ -275,48 +268,3 @@ style = [
["text", ""],
["disabled", "fg:#858585 italic"]
]

# =========================
# Bidi development configs
# =========================

[tool.hatch.envs.bidi]
dev-mode = true
features = ["dev", "bidi-all"]
installer = "uv"

[tool.hatch.envs.bidi.scripts]
prepare = [
"hatch run bidi-lint:format-fix",
"hatch run bidi-lint:quality-fix",
"hatch run bidi-lint:type-check",
"hatch run bidi-test:test-cov",
]

[tools.hatch.envs.bidi-lint]
template = "bidi"

[tool.hatch.envs.bidi-lint.scripts]
format-check = "format-fix --check"
format-fix = "ruff format {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
quality-check = "ruff check {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
quality-fix = "quality-check --fix"
type-check = "mypy {args} --python-version 3.12 ./src/strands/experimental/bidi/**/*.py"

[tool.hatch.envs.bidi-test]
template = "bidi"

[tool.hatch.envs.bidi-test.scripts]
test = "pytest {args} tests/strands/experimental/bidi"
test-cov = """
test \
--cov=strands.experimental.bidi \
--cov-config= \
--cov-branch \
--cov-report=term-missing \
--cov-report=xml:build/coverage/bidi-coverage.xml \
--cov-report=html:build/coverage/bidi-html
"""

[[tool.hatch.envs.bidi-test.matrix]]
python = ["3.13", "3.12"]
4 changes: 2 additions & 2 deletions src/strands/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This module implements experimental features that are subject to change in future revisions without notice.
"""

from . import steering, tools
from . import bidi, steering, tools
from .agent_config import config_to_agent

__all__ = ["config_to_agent", "tools", "steering"]
__all__ = ["bidi", "config_to_agent", "tools", "steering"]
8 changes: 0 additions & 8 deletions src/strands/experimental/bidi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
"""Bidirectional streaming package."""

import sys

if sys.version_info < (3, 12):
raise ImportError("bidi only supported for >= Python 3.12")

# Main components - Primary user interface
# Re-export standard agent events for tool handling
from ...types._events import (
Expand All @@ -19,7 +14,6 @@

# Model interface (for custom implementations)
from .models.model import BidiModel
from .models.nova_sonic import BidiNovaSonicModel

# Built-in tools
from .tools import stop_conversation
Expand Down Expand Up @@ -48,8 +42,6 @@
"BidiAgent",
# IO channels
"BidiAudioIO",
# Model providers
"BidiNovaSonicModel",
# Built-in tools
"stop_conversation",
# Input Event types
Expand Down
8 changes: 6 additions & 2 deletions src/strands/experimental/bidi/_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
funcs: Stop functions to call in sequence.

Raises:
ExceptionGroup: If any stop function raises an exception.
RuntimeError: If any stop function raises an exception.
"""
exceptions = []
for func in funcs:
Expand All @@ -26,4 +26,8 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
exceptions.append(exception)

if exceptions:
raise ExceptionGroup("failed stop sequence", exceptions)
exceptions.append(RuntimeError("failed stop sequence"))
for i in range(1, len(exceptions)):
exceptions[i].__cause__ = exceptions[i - 1]

raise exceptions[-1]
28 changes: 17 additions & 11 deletions src/strands/experimental/bidi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ...tools import ToolProvider
from .._async import stop_all
from ..models.model import BidiModel
from ..models.nova_sonic import BidiNovaSonicModel
from ..types.agent import BidiAgentInput
from ..types.events import (
BidiAudioInputEvent,
Expand Down Expand Up @@ -100,13 +99,13 @@ def __init__(
ValueError: If model configuration is invalid or state is invalid type.
TypeError: If model type is unsupported.
"""
self.model = (
BidiNovaSonicModel()
if not model
else BidiNovaSonicModel(model_id=model)
if isinstance(model, str)
else model
)
if isinstance(model, BidiModel):
self.model = model
else:
from ..models.nova_sonic import BidiNovaSonicModel

self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel()

self.system_prompt = system_prompt
self.messages = messages or []

Expand Down Expand Up @@ -390,9 +389,16 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
for start in [*input_starts, *output_starts]:
await start(self)

async with asyncio.TaskGroup() as task_group:
inputs_task = task_group.create_task(run_inputs())
task_group.create_task(run_outputs(inputs_task))
inputs_task = asyncio.create_task(run_inputs())
outputs_task = asyncio.create_task(run_outputs(inputs_task))

try:
await asyncio.gather(inputs_task, outputs_task)
except (Exception, asyncio.CancelledError):
inputs_task.cancel()
outputs_task.cancel()
await asyncio.gather(inputs_task, outputs_task, return_exceptions=True)
raise

finally:
input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)]
Expand Down
2 changes: 0 additions & 2 deletions src/strands/experimental/bidi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Bidirectional model interfaces and implementations."""

from .model import BidiModel, BidiModelTimeoutError
from .nova_sonic import BidiNovaSonicModel

__all__ = [
"BidiModel",
"BidiModelTimeoutError",
"BidiNovaSonicModel",
]
3 changes: 2 additions & 1 deletion src/strands/experimental/bidi/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

import logging
from typing import Any, AsyncIterable, Protocol
from typing import Any, AsyncIterable, Protocol, runtime_checkable

from ....types._events import ToolResultEvent
from ....types.content import Messages
Expand All @@ -27,6 +27,7 @@
logger = logging.getLogger(__name__)


@runtime_checkable
class BidiModel(Protocol):
"""Protocol for bidirectional streaming models.

Expand Down
30 changes: 23 additions & 7 deletions src/strands/experimental/bidi/models/nova_sonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,41 @@
- Tool execution with content containers and identifier tracking
- 8-minute connection limits with proper cleanup sequences
- Interruption detection through stopReason events

Note, BidiNovaSonicModel is only supported for Python 3.12+
"""

import asyncio
import sys

if sys.version_info < (3, 12):
raise ImportError("BidiNovaSonicModel is only supported for Python 3.12+")

import asyncio # type: ignore[unreachable]
import base64
import json
import logging
import uuid
from typing import Any, AsyncGenerator, cast

import boto3
from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput
from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
from aws_sdk_bedrock_runtime.models import (
from aws_sdk_bedrock_runtime.client import ( # type: ignore[import-not-found]
BedrockRuntimeClient,
InvokeModelWithBidirectionalStreamOperationInput,
)
from aws_sdk_bedrock_runtime.config import ( # type: ignore[import-not-found]
Config,
HTTPAuthSchemeResolver,
SigV4AuthScheme,
)
from aws_sdk_bedrock_runtime.models import ( # type: ignore[import-not-found]
BidirectionalInputPayloadPart,
InvokeModelWithBidirectionalStreamInputChunk,
ModelTimeoutException,
ValidationException,
)
from smithy_aws_core.identity.static import StaticCredentialsResolver
from smithy_core.aio.eventstream import DuplexEventStream
from smithy_core.shapes import ShapeID
from smithy_aws_core.identity.static import StaticCredentialsResolver # type: ignore[import-not-found]
from smithy_core.aio.eventstream import DuplexEventStream # type: ignore[import-not-found]
from smithy_core.shapes import ShapeID # type: ignore[import-not-found]

from ....types._events import ToolResultEvent, ToolUseStreamEvent
from ....types.content import Messages
Expand Down Expand Up @@ -93,6 +107,8 @@ class BidiNovaSonicModel(BidiModel):
Manages Nova Sonic's complex event sequencing, audio format conversion, and
tool execution patterns while providing the standard BidiModel interface.

Note, BidiNovaSonicModel is only supported for Python 3.12+.

Attributes:
_stream: open bedrock stream to nova sonic.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/strands/tools/_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import json
import random
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, cast

from .._async import run_async
from ..tools.executors._executor import ToolExecutor
Expand Down Expand Up @@ -108,7 +108,7 @@ async def acall() -> ToolResult:

# Apply conversation management if agent supports it (traditional agents)
if hasattr(self._agent, "conversation_manager"):
self._agent.conversation_manager.apply_management(self._agent)
self._agent.conversation_manager.apply_management(cast("Agent", self._agent))

return tool_result

Expand Down
Loading
Loading