From 107f03533c2abde1586109ec4d4709a556801c21 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 25 Sep 2025 10:47:46 -0400 Subject: [PATCH 001/242] feat(bidirectional_streaming): Add experimental bidirectional streaming MVP POC implementation --- pyproject.toml | 9 +- .../bidirectional_streaming/agent/__init__.py | 2 + .../bidirectional_streaming/agent/agent.py | 167 ++++ .../event_loop/__init__.py | 2 + .../event_loop/bidirectional_event_loop.py | 539 ++++++++++++ .../models/__init__.py | 2 + .../models/bidirectional_model.py | 115 +++ .../models/novasonic.py | 777 ++++++++++++++++++ .../tests/test_bidirectional_streaming.py | 203 +++++ .../bidirectional_streaming/types/__init__.py | 3 + .../types/bidirectional_streaming.py | 167 ++++ .../bidirectional_streaming/utils/debug.py | 45 + 12 files changed, 2030 insertions(+), 1 deletion(-) create mode 100644 src/strands/experimental/bidirectional_streaming/agent/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/agent/agent.py create mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/novasonic.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py create mode 100644 src/strands/experimental/bidirectional_streaming/types/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py create mode 100644 src/strands/experimental/bidirectional_streaming/utils/debug.py diff --git a/pyproject.toml b/pyproject.toml index 3c2243299..d4f7e6eee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,13 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +bidirectional-streaming = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", +] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -68,7 +75,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py new file mode 100644 index 000000000..bbd2c91f3 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming agent package.""" +# Agent package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py new file mode 100644 index 000000000..cfc005576 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -0,0 +1,167 @@ +"""Bidirectional Agent for real-time streaming conversations. + +AGENT PURPOSE: +------------- +Provides type-safe constructor and session management for real-time audio/text +interaction. Serves as the bidirectional equivalent to invoke_async() → stream_async() +but establishes sessions that continue indefinitely with concurrent task management. + +ARCHITECTURAL APPROACH: +---------------------- +While invoke_async() creates single request-response cycles that terminate after +stop_reason: "end_turn" with sequential tool processing, start_conversation() +establishes persistent sessions with concurrent processing of model events, tool +execution, and user input without session termination. + +DESIGN CHOICE: +------------- +Uses dedicated BidirectionalAgent class (Option 1 from design document) for: +- Type safety with no conditional behavior based on model type +- Separation of concerns - solely focused on bidirectional streaming +- Future proofing - allows changes without implications to existing Agent class +""" + +import asyncio +import logging +from typing import AsyncIterable, List, Optional + +from strands.tools.registry import ToolRegistry +from strands.types.content import Messages + +from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection +from ..models.bidirectional_model import BidirectionalModel +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..utils.debug import log_event, log_flow + +logger = logging.getLogger(__name__) + + +class BidirectionalAgent: + """Agent for bidirectional streaming conversations. + + Provides type-safe constructor and session management for real-time + audio/text interaction with concurrent processing capabilities. + """ + + def __init__( + self, + model: BidirectionalModel, + tools: Optional[List] = None, + system_prompt: Optional[str] = None, + messages: Optional[Messages] = None + ): + """Initialize bidirectional agent with required model and optional configuration. + + Args: + model: BidirectionalModel instance supporting streaming sessions. + tools: Optional list of tools available to the model. + system_prompt: Optional system prompt for conversations. + messages: Optional conversation history to initialize with. + """ + self.model = model + self.system_prompt = system_prompt + self.messages = messages or [] + + # Initialize tool registry using existing Strands infrastructure + self.tool_registry = ToolRegistry() + if tools: + self.tool_registry.process_tools(tools) + self.tool_registry.initialize_tools() + + # Initialize tool executor for concurrent execution + from strands.tools.executors import ConcurrentToolExecutor + self.tool_executor = ConcurrentToolExecutor() + + # Session management + self._session = None + self._output_queue = asyncio.Queue() + + async def start_conversation(self) -> None: + """Initialize persistent bidirectional session for real-time interaction. + + Creates provider-specific session and starts concurrent background tasks + for model events, tool execution, and session lifecycle management. + + Raises: + ValueError: If conversation already active. + ConnectionError: If session creation fails. + """ + if self._session and self._session.active: + raise ValueError("Conversation already active. Call end_conversation() first.") + + log_flow("conversation_start", "initializing session") + self._session = await start_bidirectional_connection(self) + log_event("conversation_ready") + + async def send_text(self, text: str) -> None: + """Send text input during active session without interrupting model generation. + + Args: + text: Text message to send to the model. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + log_event("text_sent", length=len(text)) + await self._session.model_session.send_text_content(text) + + async def send_audio(self, audio_input: AudioInputEvent) -> None: + """Send audio input during active session for real-time speech interaction. + + Args: + audio_input: AudioInputEvent containing audio data and configuration. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_audio_content(audio_input) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive output events from the model including audio, text. + + Provides access to model output events processed by background tasks. + Events include audio output, text responses, tool calls, and session updates. + + Yields: + BidirectionalStreamEvent: Events from the model session. + """ + while self._session and self._session.active: + try: + event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) + yield event + except asyncio.TimeoutError: + continue + + async def interrupt(self) -> None: + """Interrupt current model generation and switch to listening mode. + + Sends interruption signal to immediately stop generation and clear + pending audio output for responsive conversational experience. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_interrupt() + + async def end_conversation(self) -> None: + """End session and cleanup resources including background tasks. + + Performs graceful session termination with proper resource cleanup + including background task cancellation and connection closure. + """ + if self._session: + await stop_bidirectional_connection(self._session) + self._session = None + + def _validate_active_session(self) -> None: + """Validate that an active session exists. + + Raises: + ValueError: If no active session. + """ + if not self._session or not self._session.active: + raise ValueError("No active conversation. Call start_conversation() first.") + diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py new file mode 100644 index 000000000..24080b703 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming event loop package.""" +# Event Loop package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py new file mode 100644 index 000000000..2164115d8 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -0,0 +1,539 @@ +"""Bidirectional session management for concurrent streaming conversations. + +SESSION PURPOSE: +--------------- +Session wrapper for bidirectional communication that manages concurrent tasks for +model events, tool execution, and audio processing while providing simple interface +for Agent interaction. + +CONCURRENT ARCHITECTURE: +----------------------- +Unlike existing event_loop_cycle() that processes events sequentially where tool +execution blocks conversation, this module coordinates concurrent tasks through +asyncio queues and background task management. +""" + +import asyncio +import json +import logging +import traceback +import uuid +from typing import Any, Dict + +from strands.tools._validator import validate_and_prepare_tools +from strands.types.content import Message +from strands.types.tools import ToolResult, ToolUse + +from ..models.bidirectional_model import BidirectionalModelSession +from ..utils.debug import log_event, log_flow + +logger = logging.getLogger(__name__) + +# Session constants +TOOL_QUEUE_TIMEOUT = 0.5 +SUPERVISION_INTERVAL = 0.1 + + +class BidirectionalConnection: + """Session wrapper for bidirectional communication. + + Manages concurrent tasks for model events, tool execution, and audio processing + while providing simple interface for Agent interaction. + """ + + def __init__(self, model_session: BidirectionalModelSession, agent): + """Initialize session with model session and agent reference. + + Args: + model_session: Provider-specific bidirectional model session. + agent: BidirectionalAgent instance for tool registry access. + """ + self.model_session = model_session + self.agent = agent + self.active = True + + # Background processing coordination + self.background_tasks = [] + self.tool_queue = asyncio.Queue() + self.audio_output_queue = asyncio.Queue() + + # Task management for cleanup + self.pending_tool_tasks: Dict[str, asyncio.Task] = {} + + # Interruption handling (model-agnostic) + self.interrupted = False + +async def start_bidirectional_connection(agent) -> BidirectionalConnection: + """Initialize bidirectional session with concurrent background tasks. + + Creates provider-specific session and starts concurrent tasks for model events, + tool execution, and session lifecycle management. + + Args: + agent: BidirectionalAgent instance. + + Returns: + BidirectionalConnection: Active session with background tasks running. + """ + log_flow("session_start", "initializing model session") + + # Create provider-specific session + model_session = await agent.model.create_bidirectional_connection( + system_prompt=agent.system_prompt, + tools=agent.tool_registry.get_all_tool_specs(), + messages=agent.messages + ) + + # Create session wrapper for background processing + session = BidirectionalConnection(model_session=model_session, agent=agent) + + # Start concurrent background processors IMMEDIATELY after session creation + # This is critical - Nova Sonic needs response processing during initialization + log_flow("background_tasks", "starting processors") + session.background_tasks = [ + asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_tool_execution(session)) # Execute tools concurrently + ] + + # Start main coordination cycle + session.main_cycle_task = asyncio.create_task( + bidirectional_event_loop_cycle(session) + ) + + # Give background tasks a moment to start + await asyncio.sleep(0.1) + log_event("session_ready", tasks=len(session.background_tasks)) + + return session + + +async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: + """End session and cleanup resources including background tasks. + + Args: + session: BidirectionalConnection to cleanup. + """ + if not session.active: + return + + log_flow("session_cleanup", "starting") + session.active = False + + # Cancel pending tool tasks + for _, task in session.pending_tool_tasks.items(): + if not task.done(): + task.cancel() + + # Cancel background tasks + for task in session.background_tasks: + if not task.done(): + task.cancel() + + # Cancel main cycle task + if hasattr(session, 'main_cycle_task') and not session.main_cycle_task.done(): + session.main_cycle_task.cancel() + + # Wait for tasks to complete + all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) + if hasattr(session, 'main_cycle_task'): + all_tasks.append(session.main_cycle_task) + + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) + + # Close model session + await session.model_session.close() + log_event("session_closed") + + +async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: + """Main bidirectional event loop coordinator - runs continuously during session. + + Coordinates background tasks and manages session lifecycle. Unlike the + sequential event_loop_cycle() that processes events one by one, this coordinator + manages concurrent tasks and session state. + + Args: + session: BidirectionalConnection to coordinate. + """ + while session.active: + try: + # Check if background processors are still running + if all(task.done() for task in session.background_tasks): + log_event("session_end", reason="all_processors_completed") + session.active = False + break + + # Check for failed background tasks + for i, task in enumerate(session.background_tasks): + if task.done() and not task.cancelled(): + exception = task.exception() + if exception: + log_event("session_error", processor=i, error=str(exception)) + session.active = False + raise exception + + # Brief pause before next supervision check + await asyncio.sleep(SUPERVISION_INTERVAL) + + except asyncio.CancelledError: + break + except Exception as e: + log_event("event_loop_error", error=str(e)) + session.active = False + raise + + +async def _handle_interruption(session: BidirectionalConnection) -> None: + """Handle interruption detection with comprehensive task cancellation. + + Sets interruption flag, cancels pending tool tasks, and aggressively + clears audio output queue following Nova Sonic example patterns. + + Args: + session: BidirectionalConnection to handle interruption for. + """ + log_event("interruption_detected") + session.interrupted = True + + # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) + cancelled_tools = 0 + for task_id, task in list(session.pending_tool_tasks.items()): + if not task.done(): + task.cancel() + cancelled_tools += 1 + log_event("tool_task_cancelled", task_id=task_id) + + if cancelled_tools > 0: + log_event("tool_tasks_cancelled", count=cancelled_tools) + + # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) + cleared_count = 0 + while True: + try: + session.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break + + # Also clear the agent's audio output queue if it exists + if hasattr(session.agent, '_output_queue'): + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) + + if cleared_count > 0: + log_event("session_audio_queue_cleared", count=cleared_count) + + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) + await asyncio.sleep(0.05) + + # Reset interruption flag after clearing (automatic recovery) + session.interrupted = False + log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + + +async def _process_model_events(session: BidirectionalConnection) -> None: + """Process model events using existing Strands event types. + + This background task handles all model responses and converts + them to existing StreamEvent format for integration with Strands. + + Args: + session: BidirectionalConnection containing model session. + """ + log_flow("model_events", "processor started") + try: + async for provider_event in session.model_session.receive_events(): + if not session.active: + break + + # Convert provider events to Strands format + strands_event = _convert_to_strands_event(provider_event) + + # Handle interruption detection (multiple patterns) + if strands_event.get("interruptionDetected"): + log_event("interruption_forwarded") + await _handle_interruption(session) + # Forward interruption event to agent for application-level handling + await session.agent._output_queue.put(strands_event) + continue + + # Check for text-based interruption (Nova Sonic pattern) + if strands_event.get("textOutput"): + text_content = strands_event["textOutput"].get("content", "") + if '{ "interrupted" : true }' in text_content: + log_event("text_interruption_detected") + await _handle_interruption(session) + # Still forward the text event + await session.agent._output_queue.put(strands_event) + continue + + # Queue tool requests for concurrent execution + if strands_event.get("toolUse"): + log_event("tool_queued", name=strands_event["toolUse"].get("name")) + await session.tool_queue.put(strands_event["toolUse"]) + continue + + # Send output events to Agent for receive() method + if strands_event.get("audioOutput") or strands_event.get("textOutput"): + await session.agent._output_queue.put(strands_event) + + # Update Agent conversation history using existing patterns + if strands_event.get("messageStop"): + log_event("message_added_to_history") + session.agent.messages.append(strands_event["messageStop"]["message"]) + + except Exception as e: + log_event("model_events_error", error=str(e)) + traceback.print_exc() + finally: + log_flow("model_events", "processor stopped") + + +async def _process_tool_execution(session: BidirectionalConnection) -> None: + """Execute tools concurrently using existing Strands infrastructure with barge-in support. + + This background task manages tool execution without blocking + model event processing or user interaction. Includes proper + task cleanup and cancellation handling. + + Args: + session: BidirectionalConnection containing tool queue. + """ + log_flow("tool_execution", "processor started") + while session.active: + try: + tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) + log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) + + if not session.active: + break + + task_id = str(uuid.uuid4()) + task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) + session.pending_tool_tasks[task_id] = task + + # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) + def cleanup_task(completed_task): + try: + # Remove from pending tasks + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + # Log completion status + if completed_task.cancelled(): + log_event("tool_task_cleanup_cancelled", task_id=task_id) + elif completed_task.exception(): + log_event("tool_task_cleanup_error", task_id=task_id, + error=str(completed_task.exception())) + else: + log_event("tool_task_cleanup_success", task_id=task_id) + except Exception as e: + log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) + + task.add_done_callback(cleanup_task) + + except asyncio.TimeoutError: + if not session.active: + break + # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS + completed_tasks = [ + task_id for task_id, task in session.pending_tool_tasks.items() + if task.done() + ] + for task_id in completed_tasks: + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + if completed_tasks: + log_event("periodic_task_cleanup", count=len(completed_tasks)) + + continue + except Exception as e: + log_event("tool_execution_error", error=str(e)) + if not session.active: + break + + log_flow("tool_execution", "processor stopped") + + +def _convert_to_strands_event(provider_event: Dict) -> Dict: + """Pass-through for events already normalized by provider sessions. + + Providers convert their raw events to standard format before reaching here. + This just validates and passes through the normalized events. + + Args: + provider_event: Already normalized event from provider session. + + Returns: + Dict: The same event, validated and passed through. + """ + # Basic validation - ensure we have a dict + if not isinstance(provider_event, dict): + return {} + + # Pass through - conversion already done by provider session + return provider_event + + +async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: + """Execute tool using existing Strands infrastructure with barge-in support. + + Model-agnostic tool execution that uses existing Strands tool system, + handles interruption during execution, and delegates result formatting + to provider-specific session. + + Args: + session: BidirectionalConnection for context. + tool_use: Tool use event to execute. + """ + tool_name = tool_use.get('name') + tool_id = tool_use.get('toolUseId') + + try: + # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) + if session.interrupted or not session.active: + log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) + return + + # Create message structure for existing tool system + tool_message: Message = { + "role": "assistant", + "content": [{"toolUse": tool_use}] + } + + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + + # Validate using existing Strands validation + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) + + # Filter valid tool uses + valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] + + if not valid_tool_uses: + log_event("tool_validation_failed", name=tool_name, id=tool_id) + return + + # Execute tools directly (simpler approach for bidirectional) + for tool_use in valid_tool_uses: + # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION + if session.interrupted or not session.active: + log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) + return + + tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) + + if tool_func: + try: + actual_func = _extract_callable_function(tool_func) + + # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK + # For async tools, we could wrap with asyncio.wait_for with cancellation + # For sync tools, we execute directly but check interruption after + result = actual_func(**tool_use.get("input", {})) + + # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION + if session.interrupted or not session.active: + log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) + return + + tool_result = _create_success_result(tool_use["toolUseId"], result) + tool_results.append(tool_result) + + except asyncio.CancelledError: + # Tool was cancelled due to interruption + log_event("tool_execution_cancelled", name=tool_name, id=tool_id) + return + except Exception as e: + # 🔥 CHECK FOR INTERRUPTION EVEN ON ERROR + if session.interrupted or not session.active: + log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) + return + + log_event("tool_execution_failed", name=tool_name, error=str(e)) + tool_result = _create_error_result(tool_use["toolUseId"], str(e)) + tool_results.append(tool_result) + else: + log_event("tool_not_found", name=tool_name) + + # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS + if session.interrupted or not session.active: + log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) + return + + # Send results through provider-specific session + for result in tool_results: + await session.model_session.send_tool_result( + tool_use.get("toolUseId"), + result + ) + + log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) + + except asyncio.CancelledError: + # Task was cancelled due to interruption - this is expected behavior + log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) + raise # Re-raise to properly handle cancellation + except Exception as e: + log_event("tool_execution_error", name=tool_use.get('name'), error=str(e)) + + # Only send error if not interrupted + if not session.interrupted and session.active: + try: + await session.model_session.send_tool_error( + tool_use.get("toolUseId"), + str(e) + ) + except Exception as send_error: + log_event("tool_error_send_failed", error=str(send_error)) + + +def _extract_callable_function(tool_func): + """Extract the callable function from different tool object types.""" + if hasattr(tool_func, '_tool_func'): + return tool_func._tool_func + elif hasattr(tool_func, 'func'): + return tool_func.func + elif callable(tool_func): + return tool_func + else: + raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") + + +def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: + """Create a successful tool result.""" + return { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": json.dumps(result)}] + } + + +def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: + """Create an error tool result.""" + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error}"}] + } \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py new file mode 100644 index 000000000..b2b10a5f2 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming models package.""" +# Models package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py new file mode 100644 index 000000000..32727105d --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -0,0 +1,115 @@ +"""Bidirectional model interface for real-time streaming conversations. + +INTERFACE PURPOSE: +----------------- +Declares bidirectional capabilities separate from existing Model hierarchy to maintain +clean separation of concerns. Models choose to implement this interface explicitly +for bidirectional streaming support. + +PROVIDER ABSTRACTION: +-------------------- +Abstracts incompatible initialization patterns: Nova Sonic's event-driven sequences, +Google's WebSocket setup, OpenAI's dual protocol support. Normalizes different tool +calling approaches and handles provider-specific session management with varying +time limits and connection patterns. + +SESSION-BASED APPROACH: +---------------------- +Unlike existing Model interface's stateless request-response pattern where each +stream() call processes complete messages independently, BidirectionalModel introduces +session-based approach where create_bidirectional_connection() establishes persistent +connections supporting real-time bidirectional communication during active generation. +""" + +import abc +import logging +from typing import Any, AsyncIterable, Dict, List, Optional + +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.bidirectional_streaming import AudioInputEvent + +logger = logging.getLogger(__name__) + +class BidirectionalModelSession(abc.ABC): + """Model-specific session interface for bidirectional communication.""" + + @abc.abstractmethod + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive events from model in provider-agnostic format. + + Normalizes different provider event formats so the event loop + can process all providers uniformly. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to model during session. + + Manages complex audio encoding and provider-specific event sequences + while presenting simple AudioInputEvent interface to Agent. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content processed concurrently with ongoing generation. + + Enables natural interruption and follow-up questions without session restart. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_interrupt(self) -> None: + """Send interruption signal to immediately stop generation. + + Critical for responsive conversational experiences where users + can naturally interrupt mid-response. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool execution result to model in provider-specific format. + + Each provider handles result formatting according to their protocol: + - Nova Sonic: toolResult events with JSON content + - Google Live API: toolResponse with specific structure + - OpenAI Realtime: function call responses with call_id correlation + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool execution error to model in provider-specific format.""" + raise NotImplementedError + + @abc.abstractmethod + async def close(self) -> None: + """Close session and cleanup resources with graceful termination.""" + raise NotImplementedError + + +class BidirectionalModel(abc.ABC): + """Interface for models that support bidirectional streaming. + + Separate from Model to maintain clean separation of concerns. + Models choose to implement this interface explicitly. + """ + + @abc.abstractmethod + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create bidirectional session with model-specific implementation. + + Abstracts complex provider-specific initialization while presenting + uniform interface to Agent. + """ + raise NotImplementedError + diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py new file mode 100644 index 000000000..ba71cd4d3 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -0,0 +1,777 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +PROVIDER PURPOSE: +---------------- +Implements BidirectionalModel and BidirectionalModelSession interfaces for Nova Sonic, +handling the complex three-tier event management and structured event cleanup sequences +required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. + +NOVA SONIC SPECIFICS: +-------------------- +- Requires hierarchical event sequences: sessionStart → promptStart → content streaming +- Uses hex-encoded base64 audio format that needs conversion to raw bytes +- Implements toolUse/toolResult with content containers and identifier tracking +- Manages 8-minute session limits with proper cleanup sequences +- Handles stopReason: "INTERRUPTED" events for interruption detection + +INTEGRATION APPROACH: +-------------------- +Adapts existing Nova Sonic sample patterns to work with Strands bidirectional +infrastructure while maintaining provider-specific protocol requirements. +""" + +import asyncio +import base64 +import json +import logging +import time +import traceback +import uuid +from typing import Any, AsyncIterable, Dict, List, Optional + +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk +from smithy_aws_core.credentials_resolvers.environment import EnvironmentCredentialsResolver + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + InterruptionDetectedEvent, + TextOutputEvent, +) +from ..utils.debug import log_event, log_flow, time_it_async +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = { + "maxTokens": 1024, + "topP": 0.9, + "temperature": 0.7 +} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64" +} + +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 24000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH" +} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + +# Timing constants +SILENCE_THRESHOLD = 2.0 +EVENT_DELAY = 0.1 +RESPONSE_TIMEOUT = 1.0 + + +class NovaSonicSession(BidirectionalModelSession): + """Nova Sonic session handling protocol-specific details.""" + + def __init__(self, stream, config: Dict[str, Any]): + """Initialize Nova Sonic session. + + Args: + stream: Nova Sonic bidirectional stream. + config: Model configuration. + """ + self.stream = stream + self.config = config + self.prompt_name = str(uuid.uuid4()) + self._active = True + + # Nova Sonic requires unique content names + self.audio_content_name = str(uuid.uuid4()) + self.text_content_name = str(uuid.uuid4()) + + # Audio session state + self.audio_session_active = False + self.last_audio_time = None + self.silence_threshold = SILENCE_THRESHOLD + self.silence_task = None + + # Validate stream + if not stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + + logger.debug("Nova Sonic session initialized with prompt: %s", self.prompt_name) + + async def initialize( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None + ) -> None: + """Initialize Nova Sonic session with required protocol sequence.""" + try: + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." + + init_events = self._build_initialization_events(system_prompt, tools or [], messages) + + log_flow("nova_init", f"sending {len(init_events)} events") + await self._send_initialization_events(init_events) + + log_event("nova_session_initialized") + self._response_task = asyncio.create_task(self._process_responses()) + + except Exception as e: + logger.error("Error during Nova Sonic initialization: %s", e) + raise + + def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec], + messages: Optional[Messages]) -> List[str]: + """Build the sequence of initialization events.""" + events = [ + self._get_session_start_event(), + self._get_prompt_start_event(tools) + ] + + events.extend(self._get_system_prompt_events(system_prompt)) + + # Message history would be processed here if needed in the future + # Currently not implemented as it's not used in the existing test cases + + return events + + async def _send_initialization_events(self, events: List[str]) -> None: + """Send initialization events with required delays.""" + for i, event in enumerate(events): + await time_it_async(f"send_init_event_{i+1}", lambda: self._send_nova_event(event)) + await asyncio.sleep(EVENT_DELAY) + + async def _process_responses(self) -> None: + """Process Nova Sonic responses continuously.""" + log_flow("nova_responses", "processor started") + + try: + while self._active: + try: + output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) + result = await output[1].receive() + + if result.value and result.value.bytes_: + await self._handle_response_data(result.value.bytes_.decode('utf-8')) + + except asyncio.TimeoutError: + await asyncio.sleep(0.1) + continue + except Exception as e: + log_event("nova_response_error", error=str(e)) + await asyncio.sleep(0.1) + continue + + except Exception as e: + log_event("nova_fatal_error", error=str(e)) + finally: + log_flow("nova_responses", "processor stopped") + + async def _handle_response_data(self, response_data: str) -> None: + """Handle decoded response data from Nova Sonic.""" + try: + json_data = json.loads(response_data) + + if 'event' in json_data: + nova_event = json_data['event'] + self._log_event_type(nova_event) + + if not hasattr(self, '_event_queue'): + self._event_queue = asyncio.Queue() + + await self._event_queue.put(nova_event) + except json.JSONDecodeError as e: + log_event("nova_json_error", error=str(e)) + + def _log_event_type(self, nova_event: Dict[str, Any]) -> None: + """Log specific Nova Sonic event types for debugging.""" + if 'usageEvent' in nova_event: + log_event("nova_usage", usage=nova_event['usageEvent']) + elif 'textOutput' in nova_event: + log_event("nova_text_output") + elif 'toolUse' in nova_event: + tool_use = nova_event['toolUse'] + log_event("nova_tool_use", name=tool_use['toolName'], id=tool_use['toolUseId']) + elif 'audioOutput' in nova_event: + audio_content = nova_event['audioOutput']['content'] + audio_bytes = base64.b64decode(audio_content) + log_event("nova_audio_output", bytes=len(audio_bytes)) + + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive Nova Sonic events and convert to provider-agnostic format.""" + if not self.stream: + logger.error("Stream is None") + return + + log_flow("nova_events", "starting event stream") + + # Emit session start event to Strands event system + session_start: BidirectionalConnectionStartEvent = { + "sessionId": self.prompt_name, + "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} + } + yield { + "BidirectionalConnectionStart": session_start + } + + # Initialize event queue if not already done + if not hasattr(self, '_event_queue'): + self._event_queue = asyncio.Queue() + + try: + while self._active: + try: + # Get events from the queue populated by _process_responses + nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + + # Convert to provider-agnostic format + provider_event = self._convert_nova_event(nova_event) + if provider_event: + yield provider_event + + except asyncio.TimeoutError: + # No events in queue - continue waiting + continue + + except Exception as e: + logger.error("Error receiving Nova Sonic event: %s", e) + logger.error(traceback.format_exc()) + finally: + # Emit session end event when exiting + session_end: BidirectionalConnectionEndEvent = { + "sessionId": self.prompt_name, + "reason": "session_complete", + "metadata": {"provider": "nova_sonic"} + } + yield { + "BidirectionalConnectionEnd": session_end + } + + async def start_audio_session(self) -> None: + """Start audio input session (call once before sending audio chunks).""" + if self.audio_session_active: + return + + log_event("nova_audio_session_start") + + audio_content_start = json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG + } + } + }) + + await self._send_nova_event(audio_content_start) + self.audio_session_active = True + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio using Nova Sonic protocol-specific format.""" + if not self._active: + return + + # Start audio session if not already active + if not self.audio_session_active: + await self.start_audio_session() + + # Update last audio time and cancel any pending silence task + self.last_audio_time = time.time() + if self.silence_task and not self.silence_task.done(): + self.silence_task.cancel() + + # Convert audio to Nova Sonic base64 format + nova_audio_data = base64.b64encode(audio_input["audioData"]).decode('utf-8') + + # Send audio input event + audio_event = json.dumps({ + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "content": nova_audio_data + } + } + }) + + await self._send_nova_event(audio_event) + + # Start silence detection task + self.silence_task = asyncio.create_task(self._check_silence()) + + async def _check_silence(self): + """Check for silence and automatically end audio session.""" + try: + await asyncio.sleep(self.silence_threshold) + if self.audio_session_active and self.last_audio_time: + elapsed = time.time() - self.last_audio_time + if elapsed >= self.silence_threshold: + log_event("nova_silence_detected", elapsed=elapsed) + await self.end_audio_input() + except asyncio.CancelledError: + pass + + async def end_audio_input(self) -> None: + """End current audio input session to trigger Nova Sonic processing.""" + if not self.audio_session_active: + return + + log_event("nova_audio_session_end") + + audio_content_end = json.dumps({ + "event": { + "contentEnd": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name + } + } + }) + + await self._send_nova_event(audio_content_end) + self.audio_session_active = False + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content using Nova Sonic format.""" + if not self._active: + return + + content_name = str(uuid.uuid4()) + events = [ + self._get_text_content_start_event(content_name), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name) + ] + + for event in events: + await self._send_nova_event(event) + + async def send_interrupt(self) -> None: + """Send interruption signal to Nova Sonic.""" + if not self._active: + return + + # Nova Sonic handles interruption through special input events + interrupt_event = { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "stopReason": "INTERRUPTED" + } + } + } + await self._send_nova_event(interrupt_event) + + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool result using Nova Sonic toolResult format.""" + if not self._active: + return + + log_event("nova_tool_result_send", id=tool_use_id) + content_name = str(uuid.uuid4()) + events = [ + self._get_tool_content_start_event(content_name, tool_use_id), + self._get_tool_result_event(content_name, result), + self._get_content_end_event(content_name) + ] + + for i, event in enumerate(events): + await time_it_async(f"send_tool_event_{i+1}", lambda: self._send_nova_event(event)) + + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool error using Nova Sonic format.""" + log_event("nova_tool_error_send", id=tool_use_id, error=error) + error_result = {"error": error} + await self.send_tool_result(tool_use_id, error_result) + + async def close(self) -> None: + """Close Nova Sonic session with proper cleanup sequence.""" + if not self._active: + return + + log_flow("nova_cleanup", "starting session close") + self._active = False + + # Cancel response processing task if running + if hasattr(self, '_response_task') and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + # End audio session if active + if self.audio_session_active: + await self.end_audio_input() + + # Send cleanup events + cleanup_events = [ + self._get_prompt_end_event(), + self._get_session_end_event() + ] + + for event in cleanup_events: + try: + await self._send_nova_event(event) + except Exception as e: + logger.warning("Error during Nova Sonic cleanup: %s", e) + + # Close stream + try: + await self.stream.input_stream.close() + except Exception as e: + logger.warning("Error closing Nova Sonic stream: %s", e) + + except Exception as e: + log_event("nova_cleanup_error", error=str(e)) + finally: + log_event("nova_session_closed") + + def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Convert Nova Sonic events to provider-agnostic format.""" + # Handle audio output + if "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + + audio_output: AudioOutputEvent = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": "base64" + } + + return { + "audioOutput": audio_output + } + + # Handle text output + elif "textOutput" in nova_event: + text_content = nova_event["textOutput"]["content"] + # Use stored role from contentStart event, fallback to event role + role = getattr(self, '_current_role', nova_event["textOutput"].get("role", "assistant")) + + # Check for Nova Sonic interruption pattern (matches working sample) + if '{ "interrupted" : true }' in text_content: + log_event("nova_interruption_in_text") + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + return { + "interruptionDetected": interruption + } + + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag + if role == "USER": + print(f"User: {text_content}") + elif role == "ASSISTANT": + print(f"Assistant: {text_content}") + + text_output: TextOutputEvent = { + "text": text_content, + "role": role.lower() + } + + return { + "textOutput": text_output + } + + # Handle tool use + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]) + } + + return { + "toolUse": tool_use_event + } + + # Handle interruption + elif nova_event.get("stopReason") == "INTERRUPTED": + log_event("nova_interruption_stop_reason") + + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + + return { + "interruptionDetected": interruption + } + + # Handle usage events (ignore) + elif "usageEvent" in nova_event: + return None + + # Handle content start events (track role) + elif "contentStart" in nova_event: + role = nova_event["contentStart"].get("role", "unknown") + # Store role for subsequent text output events + self._current_role = role + return None + + # Handle other events + else: + return None + + # Nova Sonic event template methods + def _get_session_start_event(self) -> str: + """Generate Nova Sonic session start event.""" + return json.dumps({ + "event": { + "sessionStart": { + "inferenceConfiguration": NOVA_INFERENCE_CONFIG + } + } + }) + + def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + prompt_start_event = { + "event": { + "promptStart": { + "promptName": self.prompt_name, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: + """Build tool configuration from tool specs.""" + tool_config = [] + for tool in tools: + input_schema = ({"json": json.dumps(tool['inputSchema']['json'])} + if 'json' in tool['inputSchema'] + else {"json": json.dumps(tool['inputSchema'])}) + + tool_config.append({ + "toolSpec": { + "name": tool["name"], + "description": tool["description"], + "inputSchema": input_schema + } + }) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str) -> List[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt), + self._get_content_end_event(content_name) + ] + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG + } + } + }) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: + """Generate tool content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG + } + } + } + }) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps({ + "event": { + "textInput": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": text + } + } + }) + + def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: + """Generate tool result event.""" + return json.dumps({ + "event": { + "toolResult": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": json.dumps(result) + } + } + }) + + def _get_content_end_event(self, content_name: str) -> str: + """Generate content end event.""" + return json.dumps({ + "event": { + "contentEnd": { + "promptName": self.prompt_name, + "contentName": content_name + } + } + }) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({ + "event": { + "promptEnd": { + "promptName": self.prompt_name + } + } + }) + + def _get_session_end_event(self) -> str: + """Generate session end event.""" + return json.dumps({ + "event": { + "sessionEnd": {} + } + }) + + async def _send_nova_event(self, event: str) -> None: + """Send event JSON string to Nova Sonic stream.""" + try: + + # Event is already a JSON string + bytes_data = event.encode('utf-8') + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self.stream.input_stream.send(chunk) + logger.debug("Successfully sent Nova Sonic event") + + except Exception as e: + logger.error("Error sending Nova Sonic event: %s", e) + logger.error("Event was: %s", event) + raise + + +class NovaSonicBidirectionalModel(BidirectionalModel): + """Nova Sonic model implementing bidirectional capabilities.""" + + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): + """Initialize Nova Sonic bidirectional model. + + Args: + model_id: Nova Sonic model identifier. + region: AWS region. + **config: Additional configuration. + """ + self.model_id = model_id + self.region = region + self.config = config + self._client = None + + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) + + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create Nova Sonic bidirectional session.""" + log_flow("nova_session_create", "starting") + + # Initialize client if needed + if not self._client: + await time_it_async("initialize_client", lambda: self._initialize_client()) + + # Start Nova Sonic bidirectional stream + try: + stream = await time_it_async("invoke_model_with_bidirectional_stream", + lambda: self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + )) + + # Create and initialize session + session = NovaSonicSession(stream, self.config) + await time_it_async("initialize_session", + lambda: session.initialize(system_prompt, tools, messages)) + + log_event("nova_session_created") + return session + except Exception as e: + log_event("nova_session_create_error", error=str(e)) + logger.error("Failed to create Nova Sonic session: %s", e) + raise + + async def _initialize_client(self) -> None: + """Initialize Nova Sonic client.""" + try: + + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + http_auth_scheme_resolver=HTTPAuthSchemeResolver(), + http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()} + ) + + self._client = BedrockRuntimeClient(config=config) + logger.debug("Nova Sonic client initialized") + + except ImportError as e: + logger.error("Nova Sonic dependencies not available: %s", e) + raise + except Exception as e: + logger.error("Error initializing Nova Sonic client: %s", e) + raise + diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py new file mode 100644 index 000000000..f35fd4462 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -0,0 +1,203 @@ +"""Simple bidirectional streaming test with enhanced interruption support.""" + +import asyncio +import time +import pyaudio + +from src.strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from src.strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands_tools import calculator + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for( + context["audio_out"].get(), + timeout=0.1 + ) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + # Handle audio output + if "audioOutput" in event: + if not context.get("interrupted", False): + context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + + # Handle interruption events + elif "interruptionDetected" in event: + context["interrupted"] = True + elif "interrupted" in event: + context["interrupted"] = True + + # Handle text output with interruption detection + elif "textOutput" in event: + text_content = event["textOutput"].get("content", "") + role = event["textOutput"].get("role", "unknown") + + # Check for text-based interruption patterns + if '{ "interrupted" : true }' in text_content: + context["interrupted"] = True + elif "interrupted" in text_content.lower(): + context["interrupted"] = True + + # Log text output + if role.upper() == "USER": + print(f"User: {text_content}") + elif role.upper() == "ASSISTANT": + print(f"Assistant: {text_content}") + + except asyncio.CancelledError: + pass + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 16000 + } + await agent.send_audio(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) # Restored to working timing + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for bidirectional streaming test.""" + print("Starting bidirectional streaming test...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + + # Initialize model and agent + model = NovaSonicBidirectionalModel(region="us-east-1") + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant." + ) + + await agent.start_conversation() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "session": agent._session, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + } + + print("Speak into microphone. Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.end_conversation() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py new file mode 100644 index 000000000..f6441d2f0 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -0,0 +1,3 @@ +"""Bidirectional streaming types package.""" +# Types package + diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py new file mode 100644 index 000000000..2b1480e62 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -0,0 +1,167 @@ +"""Bidirectional streaming types for real-time audio/text conversations. + +PROBLEM ADDRESSED: +----------------- +Strands currently uses a request-response architecture without bidirectional streaming +support. Users cannot interrupt ongoing responses, provide additional context during +processing, or engage in real-time conversations. Each interaction requires a complete +request-response cycle. + +ARCHITECTURAL TRANSFORMATION: +---------------------------- +Current Limitations: Strands' unidirectional architecture follows sequential +request-response cycles that prevent real-time interaction. This represents a +pull-based architecture where the model receives the request, processes it, and +sends a response back. + +Bidirectional Solution: Uses persistent session-based connections with continuous +input and output flow. This implements a push-based architecture where the model +sends updates to the client as soon as response becomes available, without explicit +client requests. + +KEY CHARACTERISTICS: +------------------- +- Persistent Sessions: Connections remain open for extended periods (Nova Sonic: 8 minutes, + Google Live API: 15 minutes, OpenAI Realtime: 30 minutes) maintaining conversation context +- Bidirectional Communication: Users can send input while models generate responses +- Interruption Handling: Users can interrupt ongoing model responses in real-time without + terminating the session +- Tool Execution: Tools execute concurrently within the conversation flow rather than + requiring requests rebuilding + +PROVIDER NORMALIZATION: +---------------------- +Must normalize incompatible audio formats: Nova Sonic's hex-encoded base64, Google's +LINEAR16 PCM, OpenAI's Base64-encoded PCM16. Requires unified interruption event types +to handle Nova Sonic's stopReason = INTERRUPTED events, Google's VAD cancellation, and +OpenAI's conversation.item.truncate. + +This module extends existing StreamEvent types while maintaining backward compatibility +with existing Strands streaming patterns. +""" + +from typing import Any, Dict, Literal, Optional + +from strands.types.content import Role +from strands.types.streaming import StreamEvent +from typing_extensions import TypedDict + +# Audio format constants +SUPPORTED_AUDIO_FORMATS = ['pcm', 'wav', 'opus', 'mp3'] +SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] +SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo +DEFAULT_SAMPLE_RATE = 16000 +DEFAULT_CHANNELS = 1 + +class AudioOutputEvent(TypedDict): + """Audio output event from the model. + + Standardizes audio output across different providers using raw bytes + instead of provider-specific encodings (base64, hex, etc.). + + Attributes: + audioData: Raw audio bytes (not base64 or hex encoded). + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. + encoding: Original provider encoding for debugging purposes. + """ + + audioData: bytes + format: Literal['pcm', 'wav', 'opus', 'mp3'] + sampleRate: Literal[16000, 24000, 48000] + channels: Literal[1, 2] + encoding: Optional[str] + + +class AudioInputEvent(TypedDict): + """Audio input event for sending audio to the model. + + Used when sending audio data through send_audio() method. + + Attributes: + audioData: Raw audio bytes to send to model. + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. + """ + + audioData: bytes + format: Literal['pcm', 'wav', 'opus', 'mp3'] + sampleRate: Literal[16000, 24000, 48000] + channels: Literal[1, 2] + + +class TextOutputEvent(TypedDict): + """Text output event from the model during bidirectional streaming. + + Attributes: + text: The text content from the model. + role: The role of the message sender. + """ + + text: str + role: Role + + +class InterruptionDetectedEvent(TypedDict): + """Interruption detection event. + + Signals when user interruption is detected during model generation. + + Attributes: + reason: Interruption reason from predefined set. + """ + + reason: Literal['user_input', 'vad_detected', 'manual'] + + +class BidirectionalConnectionStartEvent(TypedDict, total=False): + """Session start event for bidirectional streaming. + + Attributes: + sessionId: Unique session identifier. + metadata: Provider-specific session metadata. + """ + + sessionId: Optional[str] + metadata: Optional[Dict[str, Any]] + + +class BidirectionalConnectionEndEvent(TypedDict): + """Session end event for bidirectional streaming. + + Attributes: + reason: Reason for session end from predefined set. + sessionId: Unique session identifier. + metadata: Provider-specific session metadata. + """ + + reason: Literal['user_request', 'timeout', 'error'] + sessionId: Optional[str] + metadata: Optional[Dict[str, Any]] + + +class BidirectionalStreamEvent(StreamEvent, total=False): + """Bidirectional stream event extending existing StreamEvent. + + Inherits all existing StreamEvent fields (contentBlockDelta, toolUse, + messageStart, etc.) while adding bidirectional-specific events. + Maintains full backward compatibility with existing Strands streaming. + + Attributes: + audioOutput: Audio output from the model. + audioInput: Audio input sent to the model. + textOutput: Text output from the model. + interruptionDetected: User interruption detection. + BidirectionalConnectionStart: Session start event. + BidirectionalConnectionEnd: Session end event. + """ + + audioOutput: AudioOutputEvent + audioInput: AudioInputEvent + textOutput: TextOutputEvent + interruptionDetected: InterruptionDetectedEvent + BidirectionalConnectionStart: BidirectionalConnectionStartEvent + BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py new file mode 100644 index 000000000..1e88b6ead --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/utils/debug.py @@ -0,0 +1,45 @@ +"""Debug utilities for Strands bidirectional streaming. + +Provides consistent debug logging across all bidirectional streaming components +with configurable output control matching the Nova Sonic tool use example. +""" + +import datetime +import inspect +import time + +# Debug logging system matching successful tool use example +DEBUG = False # Disable debug logging for clean output like tool use example + +def debug_print(message): + """Print debug message with timestamp and function name.""" + if DEBUG: + function_name = inspect.stack()[1].function + if function_name == 'time_it_async': + function_name = inspect.stack()[2].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + print(f"{timestamp} {function_name} {message}") + +def log_event(event_type, **context): + """Log important events with structured context.""" + if DEBUG: + function_name = inspect.stack()[1].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" + print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") + +def log_flow(step, details=""): + """Log important flow steps without excessive detail.""" + if DEBUG: + function_name = inspect.stack()[1].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + print(f"{timestamp} {function_name} FLOW: {step} {details}") + +async def time_it_async(label, method_to_run): + """Time asynchronous method execution.""" + start_time = time.perf_counter() + result = await method_to_run() + end_time = time.perf_counter() + debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") + return result + From 9165a2074eaa3a35f1e7df01ddfdd04c7d6e523a Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 30 Sep 2025 10:41:16 -0400 Subject: [PATCH 002/242] Updated doc strings, updated method from send_text() and send_audio() to send(), Updated imports --- pyproject.toml | 2 +- .../bidirectional_streaming/agent/agent.py | 105 +++++++------ .../event_loop/bidirectional_event_loop.py | 62 ++++---- .../models/bidirectional_model.py | 75 +++++----- .../models/novasonic.py | 141 +++++++++--------- .../tests/test_bidirectional_streaming.py | 26 +++- .../types/bidirectional_streaming.py | 86 ++++------- 7 files changed, 234 insertions(+), 263 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d4f7e6eee..dd01ebde3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index cfc005576..023997551 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -1,30 +1,22 @@ """Bidirectional Agent for real-time streaming conversations. -AGENT PURPOSE: -------------- -Provides type-safe constructor and session management for real-time audio/text -interaction. Serves as the bidirectional equivalent to invoke_async() → stream_async() -but establishes sessions that continue indefinitely with concurrent task management. +Provides real-time audio and text interaction through persistent streaming sessions. +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive +continuous responses including audio output. -ARCHITECTURAL APPROACH: ----------------------- -While invoke_async() creates single request-response cycles that terminate after -stop_reason: "end_turn" with sequential tool processing, start_conversation() -establishes persistent sessions with concurrent processing of model events, tool -execution, and user input without session termination. - -DESIGN CHOICE: -------------- -Uses dedicated BidirectionalAgent class (Option 1 from design document) for: -- Type safety with no conditional behavior based on model type -- Separation of concerns - solely focused on bidirectional streaming -- Future proofing - allows changes without implications to existing Agent class +Key capabilities: +- Persistent conversation sessions with concurrent processing +- Real-time audio input/output streaming +- Mid-conversation interruption and tool execution +- Event-driven communication with model providers """ import asyncio import logging -from typing import AsyncIterable, List, Optional +from typing import AsyncIterable, List, Optional, Union +from strands.tools.executors import ConcurrentToolExecutor from strands.tools.registry import ToolRegistry from strands.types.content import Messages @@ -39,8 +31,8 @@ class BidirectionalAgent: """Agent for bidirectional streaming conversations. - Provides type-safe constructor and session management for real-time - audio/text interaction with concurrent processing capabilities. + Enables real-time audio and text interaction with AI models through persistent + sessions. Supports concurrent tool execution and interruption handling. """ def __init__( @@ -69,60 +61,63 @@ def __init__( self.tool_registry.initialize_tools() # Initialize tool executor for concurrent execution - from strands.tools.executors import ConcurrentToolExecutor self.tool_executor = ConcurrentToolExecutor() # Session management self._session = None self._output_queue = asyncio.Queue() - async def start_conversation(self) -> None: - """Initialize persistent bidirectional session for real-time interaction. + async def start(self) -> None: + """Start a persistent bidirectional conversation session. - Creates provider-specific session and starts concurrent background tasks - for model events, tool execution, and session lifecycle management. + Initializes the streaming session and starts background tasks for processing + model events, tool execution, and session management. Raises: ValueError: If conversation already active. ConnectionError: If session creation fails. """ if self._session and self._session.active: - raise ValueError("Conversation already active. Call end_conversation() first.") + raise ValueError("Conversation already active. Call end() first.") log_flow("conversation_start", "initializing session") self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - async def send_text(self, text: str) -> None: - """Send text input during active session without interrupting model generation. + async def send(self, input_data: Union[str, AudioInputEvent]) -> None: + """Send input to the model (text or audio). - Args: - text: Text message to send to the model. - - Raises: - ValueError: If no active session. - """ - self._validate_active_session() - log_event("text_sent", length=len(text)) - await self._session.model_session.send_text_content(text) - - async def send_audio(self, audio_input: AudioInputEvent) -> None: - """Send audio input during active session for real-time speech interaction. + Unified method for sending both text and audio input to the model during + an active conversation session. Args: - audio_input: AudioInputEvent containing audio data and configuration. + input_data: Either a string for text input or AudioInputEvent for audio input. Raises: - ValueError: If no active session. + ValueError: If no active session or invalid input type. """ self._validate_active_session() - await self._session.model_session.send_audio_content(audio_input) + + if isinstance(input_data, str): + # Handle text input + log_event("text_sent", length=len(input_data)) + await self._session.model_session.send_text_content(input_data) + elif isinstance(input_data, dict) and "audioData" in input_data: + # Handle audio input (AudioInputEvent) + await self._session.model_session.send_audio_content(input_data) + else: + raise ValueError( + "Input must be either a string (text) or AudioInputEvent " + "(dict with audioData, format, sampleRate, channels)" + ) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive output events from the model including audio, text. + """Receive events from the model including audio, text, and tool calls. - Provides access to model output events processed by background tasks. - Events include audio output, text responses, tool calls, and session updates. + Yields model output events processed by background tasks including audio output, + text responses, tool calls, and session updates. Yields: BidirectionalStreamEvent: Events from the model session. @@ -135,10 +130,10 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: continue async def interrupt(self) -> None: - """Interrupt current model generation and switch to listening mode. + """Interrupt the current model generation and clear audio buffers. - Sends interruption signal to immediately stop generation and clear - pending audio output for responsive conversational experience. + Sends interruption signal to stop generation immediately and clears + pending audio output for responsive conversation flow. Raises: ValueError: If no active session. @@ -146,11 +141,11 @@ async def interrupt(self) -> None: self._validate_active_session() await self._session.model_session.send_interrupt() - async def end_conversation(self) -> None: - """End session and cleanup resources including background tasks. + async def end(self) -> None: + """End the conversation session and cleanup all resources. - Performs graceful session termination with proper resource cleanup - including background task cancellation and connection closure. + Terminates the streaming session, cancels background tasks, and + closes the connection to the model provider. """ if self._session: await stop_bidirectional_connection(self._session) @@ -163,5 +158,5 @@ def _validate_active_session(self) -> None: ValueError: If no active session. """ if not self._session or not self._session.active: - raise ValueError("No active conversation. Call start_conversation() first.") + raise ValueError("No active conversation. Call start() first.") diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 2164115d8..3884750d5 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -1,16 +1,14 @@ """Bidirectional session management for concurrent streaming conversations. -SESSION PURPOSE: ---------------- -Session wrapper for bidirectional communication that manages concurrent tasks for -model events, tool execution, and audio processing while providing simple interface -for Agent interaction. +Manages bidirectional communication sessions with concurrent processing of model events, +tool execution, and audio processing. Provides coordination between background tasks +while maintaining a simple interface for agent interaction. -CONCURRENT ARCHITECTURE: ------------------------ -Unlike existing event_loop_cycle() that processes events sequentially where tool -execution blocks conversation, this module coordinates concurrent tasks through -asyncio queues and background task management. +Features: +- Concurrent task management for model events and tool execution +- Interruption handling with audio buffer clearing +- Tool execution with cancellation support +- Session lifecycle management """ import asyncio @@ -35,10 +33,10 @@ class BidirectionalConnection: - """Session wrapper for bidirectional communication. + """Session wrapper for bidirectional communication with concurrent task management. - Manages concurrent tasks for model events, tool execution, and audio processing - while providing simple interface for Agent interaction. + Coordinates background tasks for model event processing, tool execution, and audio + handling while providing a simple interface for agent interactions. """ def __init__(self, model_session: BidirectionalModelSession, agent): @@ -66,8 +64,8 @@ def __init__(self, model_session: BidirectionalModelSession, agent): async def start_bidirectional_connection(agent) -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. - Creates provider-specific session and starts concurrent tasks for model events, - tool execution, and session lifecycle management. + Creates a model-specific session and starts background tasks for processing + model events, executing tools, and managing the session lifecycle. Args: agent: BidirectionalAgent instance. @@ -147,11 +145,10 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: - """Main bidirectional event loop coordinator - runs continuously during session. + """Main event loop coordinator that runs continuously during the session. - Coordinates background tasks and manages session lifecycle. Unlike the - sequential event_loop_cycle() that processes events one by one, this coordinator - manages concurrent tasks and session state. + Monitors background tasks, manages session state, and handles session lifecycle. + Provides supervision for concurrent model event processing and tool execution. Args: session: BidirectionalConnection to coordinate. @@ -185,10 +182,10 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No async def _handle_interruption(session: BidirectionalConnection) -> None: - """Handle interruption detection with comprehensive task cancellation. + """Handle interruption detection with task cancellation and audio buffer clearing. - Sets interruption flag, cancels pending tool tasks, and aggressively - clears audio output queue following Nova Sonic example patterns. + Cancels pending tool tasks and clears audio output queues to ensure responsive + interruption handling during conversations. Args: session: BidirectionalConnection to handle interruption for. @@ -251,10 +248,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async def _process_model_events(session: BidirectionalConnection) -> None: - """Process model events using existing Strands event types. + """Process model events and convert them to Strands format. - This background task handles all model responses and converts - them to existing StreamEvent format for integration with Strands. + Background task that handles all model responses, converts provider-specific + events to standardized formats, and manages interruption detection. Args: session: BidirectionalConnection containing model session. @@ -309,11 +306,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async def _process_tool_execution(session: BidirectionalConnection) -> None: - """Execute tools concurrently using existing Strands infrastructure with barge-in support. + """Execute tools concurrently with interruption support. - This background task manages tool execution without blocking - model event processing or user interaction. Includes proper - task cleanup and cancellation handling. + Background task that manages tool execution without blocking model event + processing or user interaction. Includes proper task cleanup and cancellation + handling for interruptions. Args: session: BidirectionalConnection containing tool queue. @@ -396,11 +393,10 @@ def _convert_to_strands_event(provider_event: Dict) -> Dict: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: - """Execute tool using existing Strands infrastructure with barge-in support. + """Execute tool using Strands infrastructure with interruption support. - Model-agnostic tool execution that uses existing Strands tool system, - handles interruption during execution, and delegates result formatting - to provider-specific session. + Executes tools using the existing Strands tool system, handles interruption + during execution, and sends results back to the model provider. Args: session: BidirectionalConnection for context. diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 32727105d..81e5cd9d6 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -1,24 +1,14 @@ """Bidirectional model interface for real-time streaming conversations. -INTERFACE PURPOSE: ------------------ -Declares bidirectional capabilities separate from existing Model hierarchy to maintain -clean separation of concerns. Models choose to implement this interface explicitly -for bidirectional streaming support. +Defines the interface for models that support bidirectional streaming capabilities. +Provides abstractions for different model providers with connection-based communication +patterns that support real-time audio and text interaction. -PROVIDER ABSTRACTION: --------------------- -Abstracts incompatible initialization patterns: Nova Sonic's event-driven sequences, -Google's WebSocket setup, OpenAI's dual protocol support. Normalizes different tool -calling approaches and handles provider-specific session management with varying -time limits and connection patterns. - -SESSION-BASED APPROACH: ----------------------- -Unlike existing Model interface's stateless request-response pattern where each -stream() call processes complete messages independently, BidirectionalModel introduces -session-based approach where create_bidirectional_connection() establishes persistent -connections supporting real-time bidirectional communication during active generation. +Features: +- connection-based persistent connections +- Real-time bidirectional communication +- Provider-agnostic event normalization +- Tool execution integration """ import abc @@ -32,51 +22,54 @@ logger = logging.getLogger(__name__) class BidirectionalModelSession(abc.ABC): - """Model-specific session interface for bidirectional communication.""" + """Abstract interface for model-specific bidirectional communication connections. + + Defines the contract for managing persistent streaming connections with individual + model providers, handling audio/text input, receiving events, and managing + tool execution results. + """ @abc.abstractmethod async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: - """Receive events from model in provider-agnostic format. + """Receive events from the model in standardized format. - Normalizes different provider event formats so the event loop - can process all providers uniformly. + Converts provider-specific events to a common format that can be + processed uniformly by the event loop. """ raise NotImplementedError @abc.abstractmethod async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to model during session. + """Send audio content to the model during an active connection. - Manages complex audio encoding and provider-specific event sequences - while presenting simple AudioInputEvent interface to Agent. + Handles audio encoding and provider-specific formatting while presenting + a simple AudioInputEvent interface. """ raise NotImplementedError @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content processed concurrently with ongoing generation. + """Send text content to the model during ongoing generation. - Enables natural interruption and follow-up questions without session restart. + Allows natural interruption and follow-up questions without requiring + connection restart. """ raise NotImplementedError @abc.abstractmethod async def send_interrupt(self) -> None: - """Send interruption signal to immediately stop generation. + """Send interruption signal to stop generation immediately. - Critical for responsive conversational experiences where users - can naturally interrupt mid-response. + Enables responsive conversational experiences where users can + naturally interrupt during model responses. """ raise NotImplementedError @abc.abstractmethod async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: - """Send tool execution result to model in provider-specific format. + """Send tool execution result to the model. - Each provider handles result formatting according to their protocol: - - Nova Sonic: toolResult events with JSON content - - Google Live API: toolResponse with specific structure - - OpenAI Realtime: function call responses with call_id correlation + Formats and sends tool results according to the provider's specific protocol. """ raise NotImplementedError @@ -87,15 +80,15 @@ async def send_tool_error(self, tool_use_id: str, error: str) -> None: @abc.abstractmethod async def close(self) -> None: - """Close session and cleanup resources with graceful termination.""" + """Close the connection and cleanup resources.""" raise NotImplementedError class BidirectionalModel(abc.ABC): """Interface for models that support bidirectional streaming. - Separate from Model to maintain clean separation of concerns. - Models choose to implement this interface explicitly. + Defines the contract for creating persistent streaming connections that support + real-time audio and text communication with AI models. """ @abc.abstractmethod @@ -106,10 +99,10 @@ async def create_bidirectional_connection( messages: Optional[Messages] = None, **kwargs ) -> BidirectionalModelSession: - """Create bidirectional session with model-specific implementation. + """Create a bidirectional connection with the model. - Abstracts complex provider-specific initialization while presenting - uniform interface to Agent. + Establishes a persistent connection for real-time communication while + abstracting provider-specific initialization requirements. """ raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ba71cd4d3..4332181b5 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,23 +1,15 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -PROVIDER PURPOSE: ----------------- -Implements BidirectionalModel and BidirectionalModelSession interfaces for Nova Sonic, -handling the complex three-tier event management and structured event cleanup sequences -required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. +Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol. -NOVA SONIC SPECIFICS: --------------------- -- Requires hierarchical event sequences: sessionStart → promptStart → content streaming -- Uses hex-encoded base64 audio format that needs conversion to raw bytes -- Implements toolUse/toolResult with content containers and identifier tracking -- Manages 8-minute session limits with proper cleanup sequences -- Handles stopReason: "INTERRUPTED" events for interruption detection - -INTEGRATION APPROACH: --------------------- -Adapts existing Nova Sonic sample patterns to work with Strands bidirectional -infrastructure while maintaining provider-specific protocol requirements. +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events """ import asyncio @@ -85,10 +77,15 @@ class NovaSonicSession(BidirectionalModelSession): - """Nova Sonic session handling protocol-specific details.""" + """Nova Sonic connection implementation handling the provider's specific protocol. + + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidirectionalModelSession + interface. + """ def __init__(self, stream, config: Dict[str, Any]): - """Initialize Nova Sonic session. + """Initialize Nova Sonic connection. Args: stream: Nova Sonic bidirectional stream. @@ -103,8 +100,8 @@ def __init__(self, stream, config: Dict[str, Any]): self.audio_content_name = str(uuid.uuid4()) self.text_content_name = str(uuid.uuid4()) - # Audio session state - self.audio_session_active = False + # Audio connection state + self.audio_connection_active = False self.last_audio_time = None self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None @@ -114,7 +111,7 @@ def __init__(self, stream, config: Dict[str, Any]): logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic session initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) async def initialize( self, @@ -122,7 +119,7 @@ async def initialize( tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None ) -> None: - """Initialize Nova Sonic session with required protocol sequence.""" + """Initialize Nova Sonic connection with required protocol sequence.""" try: system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." @@ -131,7 +128,7 @@ async def initialize( log_flow("nova_init", f"sending {len(init_events)} events") await self._send_initialization_events(init_events) - log_event("nova_session_initialized") + log_event("nova_connection_initialized") self._response_task = asyncio.create_task(self._process_responses()) except Exception as e: @@ -142,7 +139,7 @@ def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec] messages: Optional[Messages]) -> List[str]: """Build the sequence of initialization events.""" events = [ - self._get_session_start_event(), + self._get_connection_start_event(), self._get_prompt_start_event(tools) ] @@ -223,13 +220,13 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: log_flow("nova_events", "starting event stream") - # Emit session start event to Strands event system - session_start: BidirectionalConnectionStartEvent = { - "sessionId": self.prompt_name, + # Emit connection start event to Strands event system + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.prompt_name, "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} } yield { - "BidirectionalConnectionStart": session_start + "BidirectionalConnectionStart": connection_start } # Initialize event queue if not already done @@ -255,22 +252,22 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) finally: - # Emit session end event when exiting - session_end: BidirectionalConnectionEndEvent = { - "sessionId": self.prompt_name, - "reason": "session_complete", + # Emit connection end event when exiting + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.prompt_name, + "reason": "connection_complete", "metadata": {"provider": "nova_sonic"} } yield { - "BidirectionalConnectionEnd": session_end + "BidirectionalConnectionEnd": connection_end } - async def start_audio_session(self) -> None: - """Start audio input session (call once before sending audio chunks).""" - if self.audio_session_active: + async def start_audio_connection(self) -> None: + """Start audio input connection (call once before sending audio chunks).""" + if self.audio_connection_active: return - log_event("nova_audio_session_start") + log_event("nova_audio_connection_start") audio_content_start = json.dumps({ "event": { @@ -286,16 +283,16 @@ async def start_audio_session(self) -> None: }) await self._send_nova_event(audio_content_start) - self.audio_session_active = True + self.audio_connection_active = True async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio using Nova Sonic protocol-specific format.""" if not self._active: return - # Start audio session if not already active - if not self.audio_session_active: - await self.start_audio_session() + # Start audio connection if not already active + if not self.audio_connection_active: + await self.start_audio_connection() # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() @@ -322,10 +319,10 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task = asyncio.create_task(self._check_silence()) async def _check_silence(self): - """Check for silence and automatically end audio session.""" + """Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) - if self.audio_session_active and self.last_audio_time: + if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: log_event("nova_silence_detected", elapsed=elapsed) @@ -334,11 +331,11 @@ async def _check_silence(self): pass async def end_audio_input(self) -> None: - """End current audio input session to trigger Nova Sonic processing.""" - if not self.audio_session_active: + """End current audio input connection to trigger Nova Sonic processing.""" + if not self.audio_connection_active: return - log_event("nova_audio_session_end") + log_event("nova_audio_connection_end") audio_content_end = json.dumps({ "event": { @@ -350,7 +347,7 @@ async def end_audio_input(self) -> None: }) await self._send_nova_event(audio_content_end) - self.audio_session_active = False + self.audio_connection_active = False async def send_text_content(self, text: str, **kwargs) -> None: """Send text content using Nova Sonic format.""" @@ -407,11 +404,11 @@ async def send_tool_error(self, tool_use_id: str, error: str) -> None: await self.send_tool_result(tool_use_id, error_result) async def close(self) -> None: - """Close Nova Sonic session with proper cleanup sequence.""" + """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - log_flow("nova_cleanup", "starting session close") + log_flow("nova_cleanup", "starting connection close") self._active = False # Cancel response processing task if running @@ -423,14 +420,14 @@ async def close(self) -> None: pass try: - # End audio session if active - if self.audio_session_active: + # End audio connection if active + if self.audio_connection_active: await self.end_audio_input() # Send cleanup events cleanup_events = [ self._get_prompt_end_event(), - self._get_session_end_event() + self._get_connection_end_event() ] for event in cleanup_events: @@ -448,7 +445,7 @@ async def close(self) -> None: except Exception as e: log_event("nova_cleanup_error", error=str(e)) finally: - log_event("nova_session_closed") + log_event("nova_connection_closed") def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Convert Nova Sonic events to provider-agnostic format.""" @@ -542,8 +539,8 @@ def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, return None # Nova Sonic event template methods - def _get_session_start_event(self) -> str: - """Generate Nova Sonic session start event.""" + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" return json.dumps({ "event": { "sessionStart": { @@ -676,11 +673,11 @@ def _get_prompt_end_event(self) -> str: } }) - def _get_session_end_event(self) -> str: - """Generate session end event.""" + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" return json.dumps({ "event": { - "sessionEnd": {} + "connectionEnd": {} } }) @@ -703,7 +700,11 @@ async def _send_nova_event(self, event: str) -> None: class NovaSonicBidirectionalModel(BidirectionalModel): - """Nova Sonic model implementing bidirectional capabilities.""" + """Nova Sonic model implementation for bidirectional streaming. + + Provides access to Amazon's Nova Sonic model through the bidirectional + streaming interface, handling AWS authentication and connection management. + """ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): """Initialize Nova Sonic bidirectional model. @@ -727,8 +728,8 @@ async def create_bidirectional_connection( messages: Optional[Messages] = None, **kwargs ) -> BidirectionalModelSession: - """Create Nova Sonic bidirectional session.""" - log_flow("nova_session_create", "starting") + """Create Nova Sonic bidirectional connection.""" + log_flow("nova_connection_create", "starting") # Initialize client if needed if not self._client: @@ -741,16 +742,16 @@ async def create_bidirectional_connection( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) )) - # Create and initialize session - session = NovaSonicSession(stream, self.config) - await time_it_async("initialize_session", - lambda: session.initialize(system_prompt, tools, messages)) + # Create and initialize connection + connection = NovaSonicSession(stream, self.config) + await time_it_async("initialize_connection", + lambda: connection.initialize(system_prompt, tools, messages)) - log_event("nova_session_created") - return session + log_event("nova_connection_created") + return connection except Exception as e: - log_event("nova_session_create_error", error=str(e)) - logger.error("Failed to create Nova Sonic session: %s", e) + log_event("nova_connection_create_error", error=str(e)) + logger.error("Failed to create Nova Sonic connection: %s", e) raise async def _initialize_client(self) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index f35fd4462..d650aba9b 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -1,11 +1,20 @@ -"""Simple bidirectional streaming test with enhanced interruption support.""" +"""Test suite for bidirectional streaming with real-time audio interaction. + +Tests the complete bidirectional streaming system including audio input/output, +interruption handling, and concurrent tool execution using Nova Sonic. +""" import asyncio +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import time import pyaudio -from src.strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from src.strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel from strands_tools import calculator @@ -139,9 +148,10 @@ async def send(agent, context): audio_event = { "audioData": audio_bytes, "format": "pcm", - "sampleRate": 16000 + "sampleRate": 16000, + "channels": 1 } - await agent.send_audio(audio_event) + await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing except asyncio.CancelledError: @@ -165,14 +175,14 @@ async def main(duration=180): system_prompt="You are a helpful assistant." ) - await agent.start_conversation() + await agent.start() # Create shared context for all tasks context = { "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), - "session": agent._session, + "connection": agent._session, "duration": duration, "start_time": time.time(), "interrupted": False, @@ -196,7 +206,7 @@ async def main(duration=180): finally: print("Cleaning up...") context["active"] = False - await agent.end_conversation() + await agent.end() if __name__ == "__main__": diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 2b1480e62..fabe53ac9 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -1,43 +1,20 @@ """Bidirectional streaming types for real-time audio/text conversations. -PROBLEM ADDRESSED: ------------------ -Strands currently uses a request-response architecture without bidirectional streaming -support. Users cannot interrupt ongoing responses, provide additional context during -processing, or engage in real-time conversations. Each interaction requires a complete -request-response cycle. - -ARCHITECTURAL TRANSFORMATION: ----------------------------- -Current Limitations: Strands' unidirectional architecture follows sequential -request-response cycles that prevent real-time interaction. This represents a -pull-based architecture where the model receives the request, processes it, and -sends a response back. - -Bidirectional Solution: Uses persistent session-based connections with continuous -input and output flow. This implements a push-based architecture where the model -sends updates to the client as soon as response becomes available, without explicit -client requests. - -KEY CHARACTERISTICS: -------------------- -- Persistent Sessions: Connections remain open for extended periods (Nova Sonic: 8 minutes, - Google Live API: 15 minutes, OpenAI Realtime: 30 minutes) maintaining conversation context -- Bidirectional Communication: Users can send input while models generate responses -- Interruption Handling: Users can interrupt ongoing model responses in real-time without - terminating the session -- Tool Execution: Tools execute concurrently within the conversation flow rather than - requiring requests rebuilding - -PROVIDER NORMALIZATION: ----------------------- -Must normalize incompatible audio formats: Nova Sonic's hex-encoded base64, Google's -LINEAR16 PCM, OpenAI's Base64-encoded PCM16. Requires unified interruption event types -to handle Nova Sonic's stopReason = INTERRUPTED events, Google's VAD cancellation, and -OpenAI's conversation.item.truncate. - -This module extends existing StreamEvent types while maintaining backward compatibility -with existing Strands streaming patterns. +Type definitions for bidirectional streaming that extends Strands' existing streaming +capabilities with real-time audio and persistent connection support. + +Key features: +- Audio input/output events with standardized formats +- Interruption detection and handling +- connection lifecycle management +- Provider-agnostic event types +- Backwards compatibility with existing StreamEvent types + +Audio format normalization: +- Supports PCM, WAV, Opus, and MP3 formats +- Standardizes sample rates (16kHz, 24kHz, 48kHz) +- Normalizes channel configurations (mono/stereo) +- Abstracts provider-specific encodings """ from typing import Any, Dict, Literal, Optional @@ -56,8 +33,8 @@ class AudioOutputEvent(TypedDict): """Audio output event from the model. - Standardizes audio output across different providers using raw bytes - instead of provider-specific encodings (base64, hex, etc.). + Provides standardized audio output format across different providers using + raw bytes instead of provider-specific encodings. Attributes: audioData: Raw audio bytes (not base64 or hex encoded). @@ -77,7 +54,7 @@ class AudioOutputEvent(TypedDict): class AudioInputEvent(TypedDict): """Audio input event for sending audio to the model. - Used when sending audio data through send_audio() method. + Used for sending audio data through the send() method. Attributes: audioData: Raw audio bytes to send to model. @@ -117,45 +94,44 @@ class InterruptionDetectedEvent(TypedDict): class BidirectionalConnectionStartEvent(TypedDict, total=False): - """Session start event for bidirectional streaming. + """connection start event for bidirectional streaming. Attributes: - sessionId: Unique session identifier. - metadata: Provider-specific session metadata. + connectionId: Unique connection identifier. + metadata: Provider-specific connection metadata. """ - sessionId: Optional[str] + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalConnectionEndEvent(TypedDict): - """Session end event for bidirectional streaming. + """connection end event for bidirectional streaming. Attributes: - reason: Reason for session end from predefined set. - sessionId: Unique session identifier. - metadata: Provider-specific session metadata. + reason: Reason for connection end from predefined set. + connectionId: Unique connection identifier. + metadata: Provider-specific connection metadata. """ reason: Literal['user_request', 'timeout', 'error'] - sessionId: Optional[str] + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. - Inherits all existing StreamEvent fields (contentBlockDelta, toolUse, - messageStart, etc.) while adding bidirectional-specific events. - Maintains full backward compatibility with existing Strands streaming. + Extends the existing StreamEvent type with bidirectional-specific events + while maintaining full backward compatibility with existing Strands streaming. Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. textOutput: Text output from the model. interruptionDetected: User interruption detection. - BidirectionalConnectionStart: Session start event. - BidirectionalConnectionEnd: Session end event. + BidirectionalConnectionStart: connection start event. + BidirectionalConnectionEnd: connection end event. """ audioOutput: AudioOutputEvent From 15df9f9c06748c06376b596c7186e3712192e3cd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 30 Sep 2025 10:45:29 -0400 Subject: [PATCH 003/242] Updated minimum python runtime dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index dd01ebde3..f45794d12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", + "python>=3.12" ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ From 3a0e7d5c360107ea4a0c890bf1c9f18ee3f1c603 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 1 Oct 2025 23:54:05 -0400 Subject: [PATCH 004/242] fix imports --- .../bidirectional_streaming/__init__.py | 5 + .../bidirectional_streaming/agent/__init__.py | 7 +- .../bidirectional_streaming/agent/agent.py | 70 ++- .../event_loop/__init__.py | 17 +- .../event_loop/bidirectional_event_loop.py | 243 ++++---- .../models/__init__.py | 8 +- .../models/bidirectional_model.py | 38 +- .../models/novasonic.py | 546 ++++++++---------- .../tests/test_bidirectional_streaming.py | 65 +-- .../bidirectional_streaming/types/__init__.py | 32 +- .../types/bidirectional_streaming.py | 53 +- .../bidirectional_streaming/utils/__init__.py | 5 + .../bidirectional_streaming/utils/debug.py | 13 +- 13 files changed, 530 insertions(+), 572 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/utils/__init__.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..f6a3b41bf --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1,5 @@ +"""Bidirectional streaming package for real-time audio/text conversations.""" + +from .utils import log_event, log_flow, time_it_async + +__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py index bbd2c91f3..c490e001d 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -1,2 +1,5 @@ -"""Bidirectional streaming agent package.""" -# Agent package \ No newline at end of file +"""Bidirectional agent for real-time streaming conversations.""" + +from .agent import BidirectionalAgent + +__all__ = ["BidirectionalAgent"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 023997551..d7a5f17a3 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -1,13 +1,13 @@ """Bidirectional Agent for real-time streaming conversations. Provides real-time audio and text interaction through persistent streaming sessions. -Unlike traditional request-response patterns, this agent maintains long-running -conversations where users can interrupt, provide additional input, and receive +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive continuous responses including audio output. Key capabilities: - Persistent conversation sessions with concurrent processing -- Real-time audio input/output streaming +- Real-time audio input/output streaming - Mid-conversation interruption and tool execution - Event-driven communication with model providers """ @@ -16,10 +16,9 @@ import logging from typing import AsyncIterable, List, Optional, Union -from strands.tools.executors import ConcurrentToolExecutor -from strands.tools.registry import ToolRegistry -from strands.types.content import Messages - +from ....tools.executors import ConcurrentToolExecutor +from ....tools.registry import ToolRegistry +from ....types.content import Messages from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent @@ -30,20 +29,20 @@ class BidirectionalAgent: """Agent for bidirectional streaming conversations. - + Enables real-time audio and text interaction with AI models through persistent sessions. Supports concurrent tool execution and interruption handling. """ - + def __init__( self, model: BidirectionalModel, tools: Optional[List] = None, system_prompt: Optional[str] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, ): """Initialize bidirectional agent with required model and optional configuration. - + Args: model: BidirectionalModel instance supporting streaming sessions. tools: Optional list of tools available to the model. @@ -53,51 +52,51 @@ def __init__( self.model = model self.system_prompt = system_prompt self.messages = messages or [] - + # Initialize tool registry using existing Strands infrastructure self.tool_registry = ToolRegistry() if tools: self.tool_registry.process_tools(tools) self.tool_registry.initialize_tools() - + # Initialize tool executor for concurrent execution self.tool_executor = ConcurrentToolExecutor() - + # Session management self._session = None self._output_queue = asyncio.Queue() - + async def start(self) -> None: """Start a persistent bidirectional conversation session. - + Initializes the streaming session and starts background tasks for processing model events, tool execution, and session management. - + Raises: ValueError: If conversation already active. ConnectionError: If session creation fails. """ if self._session and self._session.active: raise ValueError("Conversation already active. Call end() first.") - + log_flow("conversation_start", "initializing session") self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - + async def send(self, input_data: Union[str, AudioInputEvent]) -> None: """Send input to the model (text or audio). - + Unified method for sending both text and audio input to the model during an active conversation session. - + Args: input_data: Either a string for text input or AudioInputEvent for audio input. - + Raises: ValueError: If no active session or invalid input type. """ self._validate_active_session() - + if isinstance(input_data, str): # Handle text input log_event("text_sent", length=len(input_data)) @@ -110,15 +109,13 @@ async def send(self, input_data: Union[str, AudioInputEvent]) -> None: "Input must be either a string (text) or AudioInputEvent " "(dict with audioData, format, sampleRate, channels)" ) - - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model including audio, text, and tool calls. - + Yields model output events processed by background tasks including audio output, text responses, tool calls, and session updates. - + Yields: BidirectionalStreamEvent: Events from the model session. """ @@ -128,35 +125,34 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: yield event except asyncio.TimeoutError: continue - + async def interrupt(self) -> None: """Interrupt the current model generation and clear audio buffers. - - Sends interruption signal to stop generation immediately and clears + + Sends interruption signal to stop generation immediately and clears pending audio output for responsive conversation flow. - + Raises: ValueError: If no active session. """ self._validate_active_session() await self._session.model_session.send_interrupt() - + async def end(self) -> None: """End the conversation session and cleanup all resources. - - Terminates the streaming session, cancels background tasks, and + + Terminates the streaming session, cancels background tasks, and closes the connection to the model provider. """ if self._session: await stop_bidirectional_connection(self._session) self._session = None - + def _validate_active_session(self) -> None: """Validate that an active session exists. - + Raises: ValueError: If no active session. """ if not self._session or not self._session.active: raise ValueError("No active conversation. Call start() first.") - diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py index 24080b703..af8c4e1e1 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -1,2 +1,15 @@ -"""Bidirectional streaming event loop package.""" -# Event Loop package \ No newline at end of file +"""Event loop management for bidirectional streaming.""" + +from .bidirectional_event_loop import ( + BidirectionalConnection, + bidirectional_event_loop_cycle, + start_bidirectional_connection, + stop_bidirectional_connection, +) + +__all__ = [ + "BidirectionalConnection", + "start_bidirectional_connection", + "stop_bidirectional_connection", + "bidirectional_event_loop_cycle", +] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 3884750d5..c90d118ff 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -18,10 +18,9 @@ import uuid from typing import Any, Dict -from strands.tools._validator import validate_and_prepare_tools -from strands.types.content import Message -from strands.types.tools import ToolResult, ToolUse - +from ....tools._validator import validate_and_prepare_tools +from ....types.content import Message +from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -34,14 +33,14 @@ class BidirectionalConnection: """Session wrapper for bidirectional communication with concurrent task management. - + Coordinates background tasks for model event processing, tool execution, and audio handling while providing a simple interface for agent interactions. """ - + def __init__(self, model_session: BidirectionalModelSession, agent): """Initialize session with model session and agent reference. - + Args: model_session: Provider-specific bidirectional model session. agent: BidirectionalAgent instance for tool registry access. @@ -49,96 +48,93 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.model_session = model_session self.agent = agent self.active = True - + # Background processing coordination self.background_tasks = [] self.tool_queue = asyncio.Queue() self.audio_output_queue = asyncio.Queue() - + # Task management for cleanup self.pending_tool_tasks: Dict[str, asyncio.Task] = {} - + # Interruption handling (model-agnostic) self.interrupted = False -async def start_bidirectional_connection(agent) -> BidirectionalConnection: + +async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. - + Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. - + Args: agent: BidirectionalAgent instance. - + Returns: BidirectionalConnection: Active session with background tasks running. - """ + """ log_flow("session_start", "initializing model session") - + # Create provider-specific session model_session = await agent.model.create_bidirectional_connection( - system_prompt=agent.system_prompt, - tools=agent.tool_registry.get_all_tool_specs(), - messages=agent.messages + system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) - + # Create session wrapper for background processing session = BidirectionalConnection(model_session=model_session, agent=agent) - + # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization log_flow("background_tasks", "starting processors") session.background_tasks = [ - asyncio.create_task(_process_model_events(session)), # Handle model responses - asyncio.create_task(_process_tool_execution(session)) # Execute tools concurrently + asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently ] - + # Start main coordination cycle - session.main_cycle_task = asyncio.create_task( - bidirectional_event_loop_cycle(session) - ) - + session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) + # Give background tasks a moment to start await asyncio.sleep(0.1) log_event("session_ready", tasks=len(session.background_tasks)) - + return session async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: """End session and cleanup resources including background tasks. - + Args: session: BidirectionalConnection to cleanup. """ if not session.active: return - + log_flow("session_cleanup", "starting") session.active = False - + # Cancel pending tool tasks for _, task in session.pending_tool_tasks.items(): if not task.done(): task.cancel() - + # Cancel background tasks for task in session.background_tasks: if not task.done(): task.cancel() - + # Cancel main cycle task - if hasattr(session, 'main_cycle_task') and not session.main_cycle_task.done(): + if hasattr(session, "main_cycle_task") and not session.main_cycle_task.done(): session.main_cycle_task.cancel() - + # Wait for tasks to complete all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) - if hasattr(session, 'main_cycle_task'): + if hasattr(session, "main_cycle_task"): all_tasks.append(session.main_cycle_task) - + if all_tasks: await asyncio.gather(*all_tasks, return_exceptions=True) - + # Close model session await session.model_session.close() log_event("session_closed") @@ -146,10 +142,10 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: """Main event loop coordinator that runs continuously during the session. - + Monitors background tasks, manages session state, and handles session lifecycle. Provides supervision for concurrent model event processing and tool execution. - + Args: session: BidirectionalConnection to coordinate. """ @@ -160,7 +156,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No log_event("session_end", reason="all_processors_completed") session.active = False break - + # Check for failed background tasks for i, task in enumerate(session.background_tasks): if task.done() and not task.cancelled(): @@ -169,10 +165,10 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No log_event("session_error", processor=i, error=str(exception)) session.active = False raise exception - + # Brief pause before next supervision check await asyncio.sleep(SUPERVISION_INTERVAL) - + except asyncio.CancelledError: break except Exception as e: @@ -183,16 +179,16 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No async def _handle_interruption(session: BidirectionalConnection) -> None: """Handle interruption detection with task cancellation and audio buffer clearing. - + Cancels pending tool tasks and clears audio output queues to ensure responsive interruption handling during conversations. - + Args: session: BidirectionalConnection to handle interruption for. """ log_event("interruption_detected") session.interrupted = True - + # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) cancelled_tools = 0 for task_id, task in list(session.pending_tool_tasks.items()): @@ -200,10 +196,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: task.cancel() cancelled_tools += 1 log_event("tool_task_cancelled", task_id=task_id) - + if cancelled_tools > 0: log_event("tool_tasks_cancelled", count=cancelled_tools) - + # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) cleared_count = 0 while True: @@ -212,9 +208,9 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: cleared_count += 1 except asyncio.QueueEmpty: break - + # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, '_output_queue'): + if hasattr(session.agent, "_output_queue"): audio_cleared = 0 # Create a temporary list to hold non-audio events temp_events = [] @@ -228,20 +224,20 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: temp_events.append(event) except asyncio.QueueEmpty: pass - + # Put back non-audio events for event in temp_events: session.agent._output_queue.put_nowait(event) - + if audio_cleared > 0: log_event("agent_audio_queue_cleared", count=audio_cleared) - + if cleared_count > 0: log_event("session_audio_queue_cleared", count=cleared_count) - + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) await asyncio.sleep(0.05) - + # Reset interruption flag after clearing (automatic recovery) session.interrupted = False log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) @@ -249,10 +245,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async def _process_model_events(session: BidirectionalConnection) -> None: """Process model events and convert them to Strands format. - + Background task that handles all model responses, converts provider-specific events to standardized formats, and manages interruption detection. - + Args: session: BidirectionalConnection containing model session. """ @@ -261,10 +257,10 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async for provider_event in session.model_session.receive_events(): if not session.active: break - + # Convert provider events to Strands format strands_event = _convert_to_strands_event(provider_event) - + # Handle interruption detection (multiple patterns) if strands_event.get("interruptionDetected"): log_event("interruption_forwarded") @@ -272,7 +268,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Forward interruption event to agent for application-level handling await session.agent._output_queue.put(strands_event) continue - + # Check for text-based interruption (Nova Sonic pattern) if strands_event.get("textOutput"): text_content = strands_event["textOutput"].get("content", "") @@ -282,22 +278,22 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Still forward the text event await session.agent._output_queue.put(strands_event) continue - + # Queue tool requests for concurrent execution if strands_event.get("toolUse"): log_event("tool_queued", name=strands_event["toolUse"].get("name")) await session.tool_queue.put(strands_event["toolUse"]) continue - + # Send output events to Agent for receive() method if strands_event.get("audioOutput") or strands_event.get("textOutput"): await session.agent._output_queue.put(strands_event) - + # Update Agent conversation history using existing patterns if strands_event.get("messageStop"): log_event("message_added_to_history") session.agent.messages.append(strands_event["messageStop"]["message"]) - + except Exception as e: log_event("model_events_error", error=str(e)) traceback.print_exc() @@ -307,11 +303,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. - + Background task that manages tool execution without blocking model event processing or user interaction. Includes proper task cleanup and cancellation handling for interruptions. - + Args: session: BidirectionalConnection containing tool queue. """ @@ -320,143 +316,136 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) - + if not session.active: break - + task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task - + # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) def cleanup_task(completed_task): try: # Remove from pending tasks if task_id in session.pending_tool_tasks: del session.pending_tool_tasks[task_id] - + # Log completion status if completed_task.cancelled(): log_event("tool_task_cleanup_cancelled", task_id=task_id) elif completed_task.exception(): - log_event("tool_task_cleanup_error", task_id=task_id, - error=str(completed_task.exception())) + log_event("tool_task_cleanup_error", task_id=task_id, error=str(completed_task.exception())) else: log_event("tool_task_cleanup_success", task_id=task_id) except Exception as e: log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) - + task.add_done_callback(cleanup_task) - + except asyncio.TimeoutError: if not session.active: break # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS - completed_tasks = [ - task_id for task_id, task in session.pending_tool_tasks.items() - if task.done() - ] + completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] for task_id in completed_tasks: if task_id in session.pending_tool_tasks: del session.pending_tool_tasks[task_id] - + if completed_tasks: log_event("periodic_task_cleanup", count=len(completed_tasks)) - + continue except Exception as e: log_event("tool_execution_error", error=str(e)) if not session.active: break - + log_flow("tool_execution", "processor stopped") def _convert_to_strands_event(provider_event: Dict) -> Dict: """Pass-through for events already normalized by provider sessions. - + Providers convert their raw events to standard format before reaching here. This just validates and passes through the normalized events. - + Args: provider_event: Already normalized event from provider session. - + Returns: Dict: The same event, validated and passed through. """ # Basic validation - ensure we have a dict if not isinstance(provider_event, dict): return {} - + # Pass through - conversion already done by provider session return provider_event async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: """Execute tool using Strands infrastructure with interruption support. - + Executes tools using the existing Strands tool system, handles interruption during execution, and sends results back to the model provider. - + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. """ - tool_name = tool_use.get('name') - tool_id = tool_use.get('toolUseId') - + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + try: # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) if session.interrupted or not session.active: log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) return - + # Create message structure for existing tool system - tool_message: Message = { - "role": "assistant", - "content": [{"toolUse": tool_use}] - } - + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - + # Validate using existing Strands validation validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - + # Filter valid tool uses valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: log_event("tool_validation_failed", name=tool_name, id=tool_id) return - + # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION if session.interrupted or not session.active: log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) return - + tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) - + if tool_func: try: actual_func = _extract_callable_function(tool_func) - + # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK # For async tools, we could wrap with asyncio.wait_for with cancellation # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - + # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION if session.interrupted or not session.active: log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) return - + tool_result = _create_success_result(tool_use["toolUseId"], result) tool_results.append(tool_result) - + except asyncio.CancelledError: # Tool was cancelled due to interruption log_event("tool_execution_cancelled", name=tool_name, id=tool_id) @@ -466,50 +455,44 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: if session.interrupted or not session.active: log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) return - + log_event("tool_execution_failed", name=tool_name, error=str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: log_event("tool_not_found", name=tool_name) - + # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS if session.interrupted or not session.active: log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) return - + # Send results through provider-specific session for result in tool_results: - await session.model_session.send_tool_result( - tool_use.get("toolUseId"), - result - ) - + await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) + log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) - + except asyncio.CancelledError: # Task was cancelled due to interruption - this is expected behavior log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) raise # Re-raise to properly handle cancellation except Exception as e: - log_event("tool_execution_error", name=tool_use.get('name'), error=str(e)) - + log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) + # Only send error if not interrupted if not session.interrupted and session.active: try: - await session.model_session.send_tool_error( - tool_use.get("toolUseId"), - str(e) - ) + await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) except Exception as send_error: log_event("tool_error_send_failed", error=str(send_error)) def _extract_callable_function(tool_func): """Extract the callable function from different tool object types.""" - if hasattr(tool_func, '_tool_func'): + if hasattr(tool_func, "_tool_func"): return tool_func._tool_func - elif hasattr(tool_func, 'func'): + elif hasattr(tool_func, "func"): return tool_func.func elif callable(tool_func): return tool_func @@ -519,17 +502,9 @@ def _extract_callable_function(tool_func): def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: """Create a successful tool result.""" - return { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": json.dumps(result)}] - } + return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: """Create an error tool result.""" - return { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error}"}] - } \ No newline at end of file + return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index b2b10a5f2..6cba974e0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,2 +1,6 @@ -"""Bidirectional streaming models package.""" -# Models package \ No newline at end of file +"""Bidirectional model interfaces and implementations.""" + +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession + +__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 81e5cd9d6..cc803458b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -7,7 +7,7 @@ Features: - connection-based persistent connections - Real-time bidirectional communication -- Provider-agnostic event normalization +- Provider-agnostic event normalization - Tool execution integration """ @@ -21,63 +21,64 @@ logger = logging.getLogger(__name__) + class BidirectionalModelSession(abc.ABC): """Abstract interface for model-specific bidirectional communication connections. - + Defines the contract for managing persistent streaming connections with individual model providers, handling audio/text input, receiving events, and managing tool execution results. """ - + @abc.abstractmethod async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: """Receive events from the model in standardized format. - + Converts provider-specific events to a common format that can be processed uniformly by the event loop. """ raise NotImplementedError - + @abc.abstractmethod async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio content to the model during an active connection. - + Handles audio encoding and provider-specific formatting while presenting a simple AudioInputEvent interface. """ raise NotImplementedError - + @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: """Send text content to the model during ongoing generation. - + Allows natural interruption and follow-up questions without requiring connection restart. """ raise NotImplementedError - + @abc.abstractmethod async def send_interrupt(self) -> None: """Send interruption signal to stop generation immediately. - + Enables responsive conversational experiences where users can naturally interrupt during model responses. """ raise NotImplementedError - + @abc.abstractmethod async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: """Send tool execution result to the model. - + Formats and sends tool results according to the provider's specific protocol. """ raise NotImplementedError - + @abc.abstractmethod async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool execution error to model in provider-specific format.""" raise NotImplementedError - + @abc.abstractmethod async def close(self) -> None: """Close the connection and cleanup resources.""" @@ -86,23 +87,22 @@ async def close(self) -> None: class BidirectionalModel(abc.ABC): """Interface for models that support bidirectional streaming. - + Defines the contract for creating persistent streaming connections that support real-time audio and text communication with AI models. """ - + @abc.abstractmethod async def create_bidirectional_connection( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs + **kwargs, ) -> BidirectionalModelSession: """Create a bidirectional connection with the model. - + Establishes a persistent connection for real-time communication while abstracting provider-specific initialization requirements. """ raise NotImplementedError - diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 4332181b5..0efd2413c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,7 +1,7 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the -complex event sequencing and audio processing required by Nova Sonic's +complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. Nova Sonic specifics: @@ -42,11 +42,7 @@ logger = logging.getLogger(__name__) # Nova Sonic configuration constants -NOVA_INFERENCE_CONFIG = { - "maxTokens": 1024, - "topP": 0.9, - "temperature": 0.7 -} +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} NOVA_AUDIO_INPUT_CONFIG = { "mediaType": "audio/lpcm", @@ -54,7 +50,7 @@ "sampleSizeBits": 16, "channelCount": 1, "audioType": "SPEECH", - "encoding": "base64" + "encoding": "base64", } NOVA_AUDIO_OUTPUT_CONFIG = { @@ -64,7 +60,7 @@ "channelCount": 1, "voiceId": "matthew", "encoding": "base64", - "audioType": "SPEECH" + "audioType": "SPEECH", } NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} @@ -78,15 +74,15 @@ class NovaSonicSession(BidirectionalModelSession): """Nova Sonic connection implementation handling the provider's specific protocol. - + Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidirectionalModelSession interface. """ - + def __init__(self, stream, config: Dict[str, Any]): """Initialize Nova Sonic connection. - + Args: stream: Nova Sonic bidirectional stream. config: Model configuration. @@ -95,80 +91,78 @@ def __init__(self, stream, config: Dict[str, Any]): self.config = config self.prompt_name = str(uuid.uuid4()) self._active = True - + # Nova Sonic requires unique content names self.audio_content_name = str(uuid.uuid4()) self.text_content_name = str(uuid.uuid4()) - + # Audio connection state self.audio_connection_active = False self.last_audio_time = None self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None - + # Validate stream if not stream: logger.error("Stream is None") raise ValueError("Stream cannot be None") - + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) - + async def initialize( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, ) -> None: """Initialize Nova Sonic connection with required protocol sequence.""" try: system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." - + init_events = self._build_initialization_events(system_prompt, tools or [], messages) - + log_flow("nova_init", f"sending {len(init_events)} events") await self._send_initialization_events(init_events) - + log_event("nova_connection_initialized") self._response_task = asyncio.create_task(self._process_responses()) - + except Exception as e: logger.error("Error during Nova Sonic initialization: %s", e) raise - - def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec], - messages: Optional[Messages]) -> List[str]: + + def _build_initialization_events( + self, system_prompt: str, tools: List[ToolSpec], messages: Optional[Messages] + ) -> List[str]: """Build the sequence of initialization events.""" - events = [ - self._get_connection_start_event(), - self._get_prompt_start_event(tools) - ] - + events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] + events.extend(self._get_system_prompt_events(system_prompt)) - + # Message history would be processed here if needed in the future # Currently not implemented as it's not used in the existing test cases - + return events - + async def _send_initialization_events(self, events: List[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i+1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_init_event_{i + 1}", lambda: self._send_nova_event(event)) await asyncio.sleep(EVENT_DELAY) - + async def _process_responses(self) -> None: """Process Nova Sonic responses continuously.""" log_flow("nova_responses", "processor started") - + try: while self._active: try: output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) result = await output[1].receive() - + if result.value and result.value.bytes_: - await self._handle_response_data(result.value.bytes_.decode('utf-8')) - + await self._handle_response_data(result.value.bytes_.decode("utf-8")) + except asyncio.TimeoutError: await asyncio.sleep(0.1) continue @@ -176,39 +170,39 @@ async def _process_responses(self) -> None: log_event("nova_response_error", error=str(e)) await asyncio.sleep(0.1) continue - + except Exception as e: log_event("nova_fatal_error", error=str(e)) finally: log_flow("nova_responses", "processor stopped") - + async def _handle_response_data(self, response_data: str) -> None: """Handle decoded response data from Nova Sonic.""" try: json_data = json.loads(response_data) - - if 'event' in json_data: - nova_event = json_data['event'] + + if "event" in json_data: + nova_event = json_data["event"] self._log_event_type(nova_event) - - if not hasattr(self, '_event_queue'): + + if not hasattr(self, "_event_queue"): self._event_queue = asyncio.Queue() - + await self._event_queue.put(nova_event) except json.JSONDecodeError as e: log_event("nova_json_error", error=str(e)) - + def _log_event_type(self, nova_event: Dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" - if 'usageEvent' in nova_event: - log_event("nova_usage", usage=nova_event['usageEvent']) - elif 'textOutput' in nova_event: + if "usageEvent" in nova_event: + log_event("nova_usage", usage=nova_event["usageEvent"]) + elif "textOutput" in nova_event: log_event("nova_text_output") - elif 'toolUse' in nova_event: - tool_use = nova_event['toolUse'] - log_event("nova_tool_use", name=tool_use['toolName'], id=tool_use['toolUseId']) - elif 'audioOutput' in nova_event: - audio_content = nova_event['audioOutput']['content'] + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + log_event("nova_tool_use", name=tool_use["toolName"], id=tool_use["toolUseId"]) + elif "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) log_event("nova_audio_output", bytes=len(audio_bytes)) @@ -217,37 +211,35 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: if not self.stream: logger.error("Stream is None") return - + log_flow("nova_events", "starting event stream") - + # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} - } - yield { - "BidirectionalConnectionStart": connection_start + "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")}, } - + yield {"BidirectionalConnectionStart": connection_start} + # Initialize event queue if not already done - if not hasattr(self, '_event_queue'): + if not hasattr(self, "_event_queue"): self._event_queue = asyncio.Queue() - + try: while self._active: try: # Get events from the queue populated by _process_responses nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - + # Convert to provider-agnostic format provider_event = self._convert_nova_event(nova_event) if provider_event: yield provider_event - + except asyncio.TimeoutError: # No events in queue - continue waiting continue - + except Exception as e: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) @@ -256,68 +248,70 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: connection_end: BidirectionalConnectionEndEvent = { "connectionId": self.prompt_name, "reason": "connection_complete", - "metadata": {"provider": "nova_sonic"} + "metadata": {"provider": "nova_sonic"}, } - yield { - "BidirectionalConnectionEnd": connection_end - } - + yield {"BidirectionalConnectionEnd": connection_end} + async def start_audio_connection(self) -> None: """Start audio input connection (call once before sending audio chunks).""" if self.audio_connection_active: return - + log_event("nova_audio_connection_start") - - audio_content_start = json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "type": "AUDIO", - "interactive": True, - "role": "USER", - "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, + } } } - }) - + ) + await self._send_nova_event(audio_content_start) self.audio_connection_active = True - + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio using Nova Sonic protocol-specific format.""" if not self._active: return - + # Start audio connection if not already active if not self.audio_connection_active: await self.start_audio_connection() - + # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() if self.silence_task and not self.silence_task.done(): self.silence_task.cancel() - + # Convert audio to Nova Sonic base64 format - nova_audio_data = base64.b64encode(audio_input["audioData"]).decode('utf-8') - + nova_audio_data = base64.b64encode(audio_input["audioData"]).decode("utf-8") + # Send audio input event - audio_event = json.dumps({ - "event": { - "audioInput": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "content": nova_audio_data + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "content": nova_audio_data, + } } } - }) - + ) + await self._send_nova_event(audio_event) - + # Start silence detection task self.silence_task = asyncio.create_task(self._check_silence()) - + async def _check_silence(self): """Check for silence and automatically end audio connection.""" try: @@ -329,226 +323,195 @@ async def _check_silence(self): await self.end_audio_input() except asyncio.CancelledError: pass - + async def end_audio_input(self) -> None: """End current audio input connection to trigger Nova Sonic processing.""" if not self.audio_connection_active: return - + log_event("nova_audio_connection_end") - - audio_content_end = json.dumps({ - "event": { - "contentEnd": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name - } - } - }) - + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} + ) + await self._send_nova_event(audio_content_end) self.audio_connection_active = False - + async def send_text_content(self, text: str, **kwargs) -> None: """Send text content using Nova Sonic format.""" if not self._active: return - + content_name = str(uuid.uuid4()) events = [ self._get_text_content_start_event(content_name), self._get_text_input_event(content_name, text), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + for event in events: await self._send_nova_event(event) - + async def send_interrupt(self) -> None: """Send interruption signal to Nova Sonic.""" if not self._active: return - + # Nova Sonic handles interruption through special input events interrupt_event = { "event": { "audioInput": { "promptName": self.prompt_name, "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED" + "stopReason": "INTERRUPTED", } } } await self._send_nova_event(interrupt_event) - + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: """Send tool result using Nova Sonic toolResult format.""" if not self._active: return - + log_event("nova_tool_result_send", id=tool_use_id) content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), self._get_tool_result_event(content_name, result), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i+1}", lambda: self._send_nova_event(event)) - + await time_it_async(f"send_tool_event_{i + 1}", lambda: self._send_nova_event(event)) + async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool error using Nova Sonic format.""" log_event("nova_tool_error_send", id=tool_use_id, error=error) error_result = {"error": error} await self.send_tool_result(tool_use_id, error_result) - + async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - + log_flow("nova_cleanup", "starting connection close") self._active = False - + # Cancel response processing task if running - if hasattr(self, '_response_task') and not self._response_task.done(): + if hasattr(self, "_response_task") and not self._response_task.done(): self._response_task.cancel() try: await self._response_task except asyncio.CancelledError: pass - + try: # End audio connection if active if self.audio_connection_active: await self.end_audio_input() - + # Send cleanup events - cleanup_events = [ - self._get_prompt_end_event(), - self._get_connection_end_event() - ] - + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + for event in cleanup_events: try: await self._send_nova_event(event) except Exception as e: logger.warning("Error during Nova Sonic cleanup: %s", e) - + # Close stream try: await self.stream.input_stream.close() except Exception as e: logger.warning("Error closing Nova Sonic stream: %s", e) - + except Exception as e: log_event("nova_cleanup_error", error=str(e)) finally: log_event("nova_connection_closed") - + def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Convert Nova Sonic events to provider-agnostic format.""" # Handle audio output if "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) - + audio_output: AudioOutputEvent = { "audioData": audio_bytes, "format": "pcm", "sampleRate": 24000, "channels": 1, - "encoding": "base64" - } - - return { - "audioOutput": audio_output + "encoding": "base64", } - + + return {"audioOutput": audio_output} + # Handle text output elif "textOutput" in nova_event: text_content = nova_event["textOutput"]["content"] # Use stored role from contentStart event, fallback to event role - role = getattr(self, '_current_role', nova_event["textOutput"].get("role", "assistant")) - + role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) + # Check for Nova Sonic interruption pattern (matches working sample) if '{ "interrupted" : true }' in text_content: log_event("nova_interruption_in_text") - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - return { - "interruptionDetected": interruption - } - + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + return {"interruptionDetected": interruption} + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag if role == "USER": print(f"User: {text_content}") elif role == "ASSISTANT": print(f"Assistant: {text_content}") - - text_output: TextOutputEvent = { - "text": text_content, - "role": role.lower() - } - - return { - "textOutput": text_output - } - + + text_output: TextOutputEvent = {"text": text_content, "role": role.lower()} + + return {"textOutput": text_output} + # Handle tool use elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - + tool_use_event: ToolUse = { "toolUseId": tool_use["toolUseId"], "name": tool_use["toolName"], - "input": json.loads(tool_use["content"]) - } - - return { - "toolUse": tool_use_event + "input": json.loads(tool_use["content"]), } - + + return {"toolUse": tool_use_event} + # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": log_event("nova_interruption_stop_reason") - - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - - return { - "interruptionDetected": interruption - } - + + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + + return {"interruptionDetected": interruption} + # Handle usage events (ignore) elif "usageEvent" in nova_event: return None - + # Handle content start events (track role) elif "contentStart" in nova_event: role = nova_event["contentStart"].get("role", "unknown") # Store role for subsequent text output events self._current_role = role return None - + # Handle other events else: return None - + # Nova Sonic event template methods def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" - return json.dumps({ - "event": { - "sessionStart": { - "inferenceConfiguration": NOVA_INFERENCE_CONFIG - } - } - }) - + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" prompt_start_event = { @@ -556,143 +519,121 @@ def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: "promptStart": { "promptName": self.prompt_name, "textOutputConfiguration": NOVA_TEXT_CONFIG, - "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } } } - + if tools: tool_config = self._build_tool_configuration(tools) prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} - + return json.dumps(prompt_start_event) - + def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: """Build tool configuration from tool specs.""" tool_config = [] for tool in tools: - input_schema = ({"json": json.dumps(tool['inputSchema']['json'])} - if 'json' in tool['inputSchema'] - else {"json": json.dumps(tool['inputSchema'])}) - - tool_config.append({ - "toolSpec": { - "name": tool["name"], - "description": tool["description"], - "inputSchema": input_schema - } - }) + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} + ) return tool_config - + def _get_system_prompt_events(self, system_prompt: str) -> List[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ self._get_text_content_start_event(content_name, "SYSTEM"), self._get_text_input_event(content_name, system_prompt), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: """Generate text content start event.""" - return json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": content_name, - "type": "TEXT", - "role": role, - "interactive": True, - "textInputConfiguration": NOVA_TEXT_CONFIG + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } } } - }) - + ) + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: """Generate tool content start event.""" - return json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": content_name, - "interactive": False, - "type": "TOOL", - "role": "TOOL", - "toolResultInputConfiguration": { - "toolUseId": tool_use_id, - "type": "TEXT", - "textInputConfiguration": NOVA_TEXT_CONFIG + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, } } } - }) - + ) + def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" - return json.dumps({ - "event": { - "textInput": { - "promptName": self.prompt_name, - "contentName": content_name, - "content": text - } - } - }) - + return json.dumps( + {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} + ) + def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: """Generate tool result event.""" - return json.dumps({ - "event": { - "toolResult": { - "promptName": self.prompt_name, - "contentName": content_name, - "content": json.dumps(result) + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": json.dumps(result), + } } } - }) - + ) + def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({ - "event": { - "contentEnd": { - "promptName": self.prompt_name, - "contentName": content_name - } - } - }) - + return json.dumps({"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": content_name}}}) + def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({ - "event": { - "promptEnd": { - "promptName": self.prompt_name - } - } - }) - + return json.dumps({"event": {"promptEnd": {"promptName": self.prompt_name}}}) + def _get_connection_end_event(self) -> str: """Generate connection end event.""" - return json.dumps({ - "event": { - "connectionEnd": {} - } - }) - + return json.dumps({"event": {"connectionEnd": {}}}) + async def _send_nova_event(self, event: str) -> None: """Send event JSON string to Nova Sonic stream.""" try: - # Event is already a JSON string - bytes_data = event.encode('utf-8') - chunk = InvokeModelWithBidirectionalStreamInputChunk( - value=BidirectionalInputPayloadPart(bytes_=bytes_data) - ) + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) await self.stream.input_stream.send(chunk) logger.debug("Successfully sent Nova Sonic event") - + except Exception as e: logger.error("Error sending Nova Sonic event: %s", e) logger.error("Event was: %s", event) @@ -701,14 +642,14 @@ async def _send_nova_event(self, event: str) -> None: class NovaSonicBidirectionalModel(BidirectionalModel): """Nova Sonic model implementation for bidirectional streaming. - + Provides access to Amazon's Nova Sonic model through the bidirectional streaming interface, handling AWS authentication and connection management. """ - + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): """Initialize Nova Sonic bidirectional model. - + Args: model_id: Nova Sonic model identifier. region: AWS region. @@ -718,61 +659,60 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e self.region = region self.config = config self._client = None - + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - + async def create_bidirectional_connection( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs + **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" log_flow("nova_connection_create", "starting") - + # Initialize client if needed if not self._client: await time_it_async("initialize_client", lambda: self._initialize_client()) - + # Start Nova Sonic bidirectional stream try: - stream = await time_it_async("invoke_model_with_bidirectional_stream", + stream = await time_it_async( + "invoke_model_with_bidirectional_stream", lambda: self._client.invoke_model_with_bidirectional_stream( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - )) - + ), + ) + # Create and initialize connection connection = NovaSonicSession(stream, self.config) - await time_it_async("initialize_connection", - lambda: connection.initialize(system_prompt, tools, messages)) - + await time_it_async("initialize_connection", lambda: connection.initialize(system_prompt, tools, messages)) + log_event("nova_connection_created") return connection except Exception as e: log_event("nova_connection_create_error", error=str(e)) logger.error("Failed to create Nova Sonic connection: %s", e) raise - + async def _initialize_client(self) -> None: """Initialize Nova Sonic client.""" try: - config = Config( endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), http_auth_scheme_resolver=HTTPAuthSchemeResolver(), - http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()} + http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, ) - + self._client = BedrockRuntimeClient(config=config) logger.debug("Nova Sonic client initialized") - + except ImportError as e: logger.error("Nova Sonic dependencies not available: %s", e) raise except Exception as e: logger.error("Error initializing Nova Sonic client: %s", e) raise - diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index d650aba9b..6ef96f919 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -11,12 +11,13 @@ # Add the src directory to Python path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import time -import pyaudio -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +import pyaudio from strands_tools import calculator +from ..agent.agent import BidirectionalAgent +from ..models.novasonic import NovaSonicBidirectionalModel + async def play(context): """Play audio output with responsive interruption support.""" @@ -26,7 +27,7 @@ async def play(context): format=pyaudio.paInt16, output=True, rate=24000, - frames_per_buffer=1024, + frames_per_buffer=1024, ) try: @@ -40,36 +41,33 @@ async def play(context): context["audio_out"].get_nowait() except asyncio.QueueEmpty: break - + context["interrupted"] = False - await asyncio.sleep(0.05) + await asyncio.sleep(0.05) continue - + # Get next audio data - audio_data = await asyncio.wait_for( - context["audio_out"].get(), - timeout=0.1 - ) - + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + if audio_data and context["active"]: - chunk_size = 1024 + chunk_size = 1024 for i in range(0, len(audio_data), chunk_size): # Check for interruption before each chunk if context.get("interrupted", False) or not context["active"]: break - + end = min(i + chunk_size, len(audio_data)) chunk = audio_data[i:end] speaker.write(chunk) await asyncio.sleep(0.001) - + except asyncio.TimeoutError: continue # No audio available except asyncio.QueueEmpty: await asyncio.sleep(0.01) except asyncio.CancelledError: break - + except asyncio.CancelledError: pass finally: @@ -111,30 +109,30 @@ async def receive(agent, context): if "audioOutput" in event: if not context.get("interrupted", False): context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) - + # Handle interruption events elif "interruptionDetected" in event: context["interrupted"] = True elif "interrupted" in event: context["interrupted"] = True - + # Handle text output with interruption detection elif "textOutput" in event: text_content = event["textOutput"].get("content", "") role = event["textOutput"].get("role", "unknown") - + # Check for text-based interruption patterns if '{ "interrupted" : true }' in text_content: context["interrupted"] = True elif "interrupted" in text_content.lower(): context["interrupted"] = True - + # Log text output if role.upper() == "USER": print(f"User: {text_content}") elif role.upper() == "ASSISTANT": print(f"Assistant: {text_content}") - + except asyncio.CancelledError: pass @@ -145,18 +143,13 @@ async def send(agent, context): while time.time() - context["start_time"] < context["duration"]: try: audio_bytes = context["audio_in"].get_nowait() - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } + audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing except asyncio.CancelledError: break - + context["active"] = False except asyncio.CancelledError: pass @@ -166,14 +159,10 @@ async def main(duration=180): """Main function for bidirectional streaming test.""" print("Starting bidirectional streaming test...") print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") - + # Initialize model and agent model = NovaSonicBidirectionalModel(region="us-east-1") - agent = BidirectionalAgent( - model=model, - tools=[calculator], - system_prompt="You are a helpful assistant." - ) + agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() @@ -189,15 +178,11 @@ async def main(duration=180): } print("Speak into microphone. Press Ctrl+C to exit.") - + try: # Run all tasks concurrently await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True + play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True ) except KeyboardInterrupt: print("\nInterrupted by user") diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index f6441d2f0..510285f06 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,3 +1,31 @@ -"""Bidirectional streaming types package.""" -# Types package +"""Type definitions for bidirectional streaming.""" +from .bidirectional_streaming import ( + DEFAULT_CHANNELS, + DEFAULT_SAMPLE_RATE, + SUPPORTED_AUDIO_FORMATS, + SUPPORTED_CHANNELS, + SUPPORTED_SAMPLE_RATES, + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, +) + +__all__ = [ + "AudioInputEvent", + "AudioOutputEvent", + "BidirectionalConnectionEndEvent", + "BidirectionalConnectionStartEvent", + "BidirectionalStreamEvent", + "InterruptionDetectedEvent", + "TextOutputEvent", + "SUPPORTED_AUDIO_FORMATS", + "SUPPORTED_SAMPLE_RATES", + "SUPPORTED_CHANNELS", + "DEFAULT_SAMPLE_RATE", + "DEFAULT_CHANNELS", +] diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index fabe53ac9..01d72356a 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -19,23 +19,25 @@ from typing import Any, Dict, Literal, Optional -from strands.types.content import Role -from strands.types.streaming import StreamEvent from typing_extensions import TypedDict +from ....types.content import Role +from ....types.streaming import StreamEvent + # Audio format constants -SUPPORTED_AUDIO_FORMATS = ['pcm', 'wav', 'opus', 'mp3'] +SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo DEFAULT_SAMPLE_RATE = 16000 DEFAULT_CHANNELS = 1 + class AudioOutputEvent(TypedDict): """Audio output event from the model. - + Provides standardized audio output format across different providers using raw bytes instead of provider-specific encodings. - + Attributes: audioData: Raw audio bytes (not base64 or hex encoded). format: Audio format from SUPPORTED_AUDIO_FORMATS. @@ -43,9 +45,9 @@ class AudioOutputEvent(TypedDict): channels: Channel count from SUPPORTED_CHANNELS. encoding: Original provider encoding for debugging purposes. """ - + audioData: bytes - format: Literal['pcm', 'wav', 'opus', 'mp3'] + format: Literal["pcm", "wav", "opus", "mp3"] sampleRate: Literal[16000, 24000, 48000] channels: Literal[1, 2] encoding: Optional[str] @@ -53,78 +55,78 @@ class AudioOutputEvent(TypedDict): class AudioInputEvent(TypedDict): """Audio input event for sending audio to the model. - + Used for sending audio data through the send() method. - + Attributes: audioData: Raw audio bytes to send to model. format: Audio format from SUPPORTED_AUDIO_FORMATS. sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. channels: Channel count from SUPPORTED_CHANNELS. """ - + audioData: bytes - format: Literal['pcm', 'wav', 'opus', 'mp3'] + format: Literal["pcm", "wav", "opus", "mp3"] sampleRate: Literal[16000, 24000, 48000] channels: Literal[1, 2] class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. - + Attributes: text: The text content from the model. role: The role of the message sender. """ - + text: str role: Role class InterruptionDetectedEvent(TypedDict): """Interruption detection event. - + Signals when user interruption is detected during model generation. - + Attributes: reason: Interruption reason from predefined set. """ - - reason: Literal['user_input', 'vad_detected', 'manual'] + + reason: Literal["user_input", "vad_detected", "manual"] class BidirectionalConnectionStartEvent(TypedDict, total=False): """connection start event for bidirectional streaming. - + Attributes: connectionId: Unique connection identifier. metadata: Provider-specific connection metadata. """ - + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalConnectionEndEvent(TypedDict): """connection end event for bidirectional streaming. - + Attributes: reason: Reason for connection end from predefined set. connectionId: Unique connection identifier. metadata: Provider-specific connection metadata. """ - - reason: Literal['user_request', 'timeout', 'error'] + + reason: Literal["user_request", "timeout", "error"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. - + Extends the existing StreamEvent type with bidirectional-specific events while maintaining full backward compatibility with existing Strands streaming. - + Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. @@ -133,11 +135,10 @@ class BidirectionalStreamEvent(StreamEvent, total=False): BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. """ - + audioOutput: AudioOutputEvent audioInput: AudioInputEvent textOutput: TextOutputEvent interruptionDetected: InterruptionDetectedEvent BidirectionalConnectionStart: BidirectionalConnectionStartEvent BidirectionalConnectionEnd: BidirectionalConnectionEndEvent - diff --git a/src/strands/experimental/bidirectional_streaming/utils/__init__.py b/src/strands/experimental/bidirectional_streaming/utils/__init__.py new file mode 100644 index 000000000..579478436 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility functions for bidirectional streaming.""" + +from .debug import log_event, log_flow, time_it_async + +__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py index 1e88b6ead..6a7fc3982 100644 --- a/src/strands/experimental/bidirectional_streaming/utils/debug.py +++ b/src/strands/experimental/bidirectional_streaming/utils/debug.py @@ -11,30 +11,34 @@ # Debug logging system matching successful tool use example DEBUG = False # Disable debug logging for clean output like tool use example + def debug_print(message): """Print debug message with timestamp and function name.""" if DEBUG: function_name = inspect.stack()[1].function - if function_name == 'time_it_async': + if function_name == "time_it_async": function_name = inspect.stack()[2].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] print(f"{timestamp} {function_name} {message}") + def log_event(event_type, **context): """Log important events with structured context.""" if DEBUG: function_name = inspect.stack()[1].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") + def log_flow(step, details=""): """Log important flow steps without excessive detail.""" if DEBUG: function_name = inspect.stack()[1].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] print(f"{timestamp} {function_name} FLOW: {step} {details}") + async def time_it_async(label, method_to_run): """Time asynchronous method execution.""" start_time = time.perf_counter() @@ -42,4 +46,3 @@ async def time_it_async(label, method_to_run): end_time = time.perf_counter() debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") return result - From f7e67aec65640b9e262e88d4f82d020308143250 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 1 Oct 2025 23:59:44 -0400 Subject: [PATCH 005/242] fix linting issues --- pyproject.toml | 1 - .../event_loop/bidirectional_event_loop.py | 5 +++-- .../experimental/bidirectional_streaming/models/novasonic.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f45794d12..dd01ebde3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", - "python>=3.12" ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index c90d118ff..4fbae3992 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,6 +21,7 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse +from ..agent.agent import BidirectionalAgent from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -61,7 +62,7 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.interrupted = False -async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: +async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. Creates a model-specific session and starts background tasks for processing @@ -325,7 +326,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: session.pending_tool_tasks[task_id] = task # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) - def cleanup_task(completed_task): + def cleanup_task(completed_task, task_id=task_id): try: # Remove from pending tasks if task_id in session.pending_tool_tasks: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 0efd2413c..22912354d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -147,7 +147,7 @@ def _build_initialization_events( async def _send_initialization_events(self, events: List[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i + 1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) await asyncio.sleep(EVENT_DELAY) async def _process_responses(self) -> None: @@ -384,7 +384,7 @@ async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> No ] for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i + 1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool error using Nova Sonic format.""" From c654621d9c345316c90e6895a430e2f1918a9b8c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 00:07:34 -0400 Subject: [PATCH 006/242] Remove typing module and rely on python's built-in types --- .../bidirectional_streaming/agent/agent.py | 10 ++--- .../event_loop/bidirectional_event_loop.py | 13 +++---- .../models/bidirectional_model.py | 12 +++--- .../models/novasonic.py | 38 +++++++++---------- 4 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index d7a5f17a3..997a0d1df 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -14,7 +14,7 @@ import asyncio import logging -from typing import AsyncIterable, List, Optional, Union +from typing import AsyncIterable from ....tools.executors import ConcurrentToolExecutor from ....tools.registry import ToolRegistry @@ -37,9 +37,9 @@ class BidirectionalAgent: def __init__( self, model: BidirectionalModel, - tools: Optional[List] = None, - system_prompt: Optional[str] = None, - messages: Optional[Messages] = None, + tools: list | None = None, + system_prompt: str | None = None, + messages: Messages | None = None, ): """Initialize bidirectional agent with required model and optional configuration. @@ -83,7 +83,7 @@ async def start(self) -> None: self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - async def send(self, input_data: Union[str, AudioInputEvent]) -> None: + async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 4fbae3992..65ee6b905 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -16,7 +16,6 @@ import logging import traceback import uuid -from typing import Any, Dict from ....tools._validator import validate_and_prepare_tools from ....types.content import Message @@ -56,14 +55,14 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.audio_output_queue = asyncio.Queue() # Task management for cleanup - self.pending_tool_tasks: Dict[str, asyncio.Task] = {} + self.pending_tool_tasks: dict[str, asyncio.Task] = {} # Interruption handling (model-agnostic) self.interrupted = False async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: - """Initialize bidirectional session with concurrent background tasks. + """Initialize bidirectional session with conycurrent background tasks. Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. @@ -365,7 +364,7 @@ def cleanup_task(completed_task, task_id=task_id): log_flow("tool_execution", "processor stopped") -def _convert_to_strands_event(provider_event: Dict) -> Dict: +def _convert_to_strands_event(provider_event: dict) -> dict: """Pass-through for events already normalized by provider sessions. Providers convert their raw events to standard format before reaching here. @@ -385,7 +384,7 @@ def _convert_to_strands_event(provider_event: Dict) -> Dict: return provider_event -async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: +async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. Executes tools using the existing Strands tool system, handles interruption @@ -501,11 +500,11 @@ def _extract_callable_function(tool_func): raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") -def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: +def _create_success_result(tool_use_id: str, result) -> dict[str, any]: """Create a successful tool result.""" return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} -def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: +def _create_error_result(tool_use_id: str, error: str) -> dict[str, any]: """Create an error tool result.""" return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index cc803458b..1432b112a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -13,7 +13,7 @@ import abc import logging -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import AsyncIterable from ....types.content import Messages from ....types.tools import ToolSpec @@ -31,7 +31,7 @@ class BidirectionalModelSession(abc.ABC): """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive events from the model in standardized format. Converts provider-specific events to a common format that can be @@ -67,7 +67,7 @@ async def send_interrupt(self) -> None: raise NotImplementedError @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. @@ -95,9 +95,9 @@ class BidirectionalModel(abc.ABC): @abc.abstractmethod async def create_bidirectional_connection( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, **kwargs, ) -> BidirectionalModelSession: """Create a bidirectional connection with the model. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 22912354d..969cac159 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -19,7 +19,7 @@ import time import traceback import uuid -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import AsyncIterable from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme @@ -80,7 +80,7 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream, config: Dict[str, Any]): + def __init__(self, stream, config: dict[str, any]): """Initialize Nova Sonic connection. Args: @@ -111,9 +111,9 @@ def __init__(self, stream, config: Dict[str, Any]): async def initialize( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, ) -> None: """Initialize Nova Sonic connection with required protocol sequence.""" try: @@ -132,8 +132,8 @@ async def initialize( raise def _build_initialization_events( - self, system_prompt: str, tools: List[ToolSpec], messages: Optional[Messages] - ) -> List[str]: + self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None + ) -> list[str]: """Build the sequence of initialization events.""" events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] @@ -144,7 +144,7 @@ def _build_initialization_events( return events - async def _send_initialization_events(self, events: List[str]) -> None: + async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) @@ -192,7 +192,7 @@ async def _handle_response_data(self, response_data: str) -> None: except json.JSONDecodeError as e: log_event("nova_json_error", error=str(e)) - def _log_event_type(self, nova_event: Dict[str, Any]) -> None: + def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: log_event("nova_usage", usage=nova_event["usageEvent"]) @@ -206,7 +206,7 @@ def _log_event_type(self, nova_event: Dict[str, Any]) -> None: audio_bytes = base64.b64decode(audio_content) log_event("nova_audio_output", bytes=len(audio_bytes)) - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -370,7 +370,7 @@ async def send_interrupt(self) -> None: } await self._send_nova_event(interrupt_event) - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: """Send tool result using Nova Sonic toolResult format.""" if not self._active: return @@ -433,7 +433,7 @@ async def close(self) -> None: finally: log_event("nova_connection_closed") - def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: """Convert Nova Sonic events to provider-agnostic format.""" # Handle audio output if "audioOutput" in nova_event: @@ -512,7 +512,7 @@ def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) - def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" prompt_start_event = { "event": { @@ -531,7 +531,7 @@ def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: return json.dumps(prompt_start_event) - def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: """Build tool configuration from tool specs.""" tool_config = [] for tool in tools: @@ -546,7 +546,7 @@ def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: ) return tool_config - def _get_system_prompt_events(self, system_prompt: str) -> List[str]: + def _get_system_prompt_events(self, system_prompt: str) -> list[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ @@ -599,7 +599,7 @@ def _get_text_input_event(self, content_name: str, text: str) -> str: {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} ) - def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: + def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: """Generate tool result event.""" return json.dumps( { @@ -664,9 +664,9 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e async def create_bidirectional_connection( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" From 1f1abacd839cd6ed26ebd9a84bfa2e8aeb50be01 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 00:12:15 -0400 Subject: [PATCH 007/242] add typing to methods --- .../event_loop/bidirectional_event_loop.py | 8 ++++---- .../bidirectional_streaming/models/novasonic.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 65ee6b905..ea00468bb 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -38,7 +38,7 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent): + def __init__(self, model_session: BidirectionalModelSession, agent: BidirectionalAgent) -> None: """Initialize session with model session and agent reference. Args: @@ -325,7 +325,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: session.pending_tool_tasks[task_id] = task # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) - def cleanup_task(completed_task, task_id=task_id): + def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: try: # Remove from pending tasks if task_id in session.pending_tool_tasks: @@ -488,7 +488,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: log_event("tool_error_send_failed", error=str(send_error)) -def _extract_callable_function(tool_func): +def _extract_callable_function(tool_func: any) -> any: """Extract the callable function from different tool object types.""" if hasattr(tool_func, "_tool_func"): return tool_func._tool_func @@ -500,7 +500,7 @@ def _extract_callable_function(tool_func): raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") -def _create_success_result(tool_use_id: str, result) -> dict[str, any]: +def _create_success_result(tool_use_id: str, result: any) -> dict[str, any]: """Create a successful tool result.""" return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 969cac159..89472350b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -80,7 +80,7 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream, config: dict[str, any]): + def __init__(self, stream: any, config: dict[str, any]) -> None: """Initialize Nova Sonic connection. Args: @@ -312,7 +312,7 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: # Start silence detection task self.silence_task = asyncio.create_task(self._check_silence()) - async def _check_silence(self): + async def _check_silence(self) -> None: """Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) @@ -647,7 +647,7 @@ class NovaSonicBidirectionalModel(BidirectionalModel): streaming interface, handling AWS authentication and connection management. """ - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: """Initialize Nova Sonic bidirectional model. Args: From eb543b52434dbe6af1f2f309f77446a97ed08871 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:00:04 -0400 Subject: [PATCH 008/242] Improve comments and remove unused method _convert_to_strands_event --- .../bidirectional_streaming/agent/agent.py | 1 - .../event_loop/bidirectional_event_loop.py | 45 +++++++------------ 2 files changed, 15 insertions(+), 31 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 997a0d1df..e27885c7e 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -98,7 +98,6 @@ async def send(self, input_data: str | AudioInputEvent) -> None: self._validate_active_session() if isinstance(input_data, str): - # Handle text input log_event("text_sent", length=len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index ea00468bb..fddd1245a 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -189,7 +189,7 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: log_event("interruption_detected") session.interrupted = True - # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) + # Cancel all pending tool execution tasks cancelled_tools = 0 for task_id, task in list(session.pending_tool_tasks.items()): if not task.done(): @@ -200,7 +200,7 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: if cancelled_tools > 0: log_event("tool_tasks_cancelled", count=cancelled_tools) - # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) + # Clear all queued audio output events cleared_count = 0 while True: try: @@ -258,8 +258,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: if not session.active: break - # Convert provider events to Strands format - strands_event = _convert_to_strands_event(provider_event) + # Basic validation - skip invalid events + if not isinstance(provider_event, dict): + continue + + strands_event = provider_event # Handle interruption detection (multiple patterns) if strands_event.get("interruptionDetected"): @@ -269,7 +272,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: await session.agent._output_queue.put(strands_event) continue - # Check for text-based interruption (Nova Sonic pattern) + # Check for text-based interruption if strands_event.get("textOutput"): text_content = strands_event["textOutput"].get("content", "") if '{ "interrupted" : true }' in text_content: @@ -324,7 +327,6 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task - # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: try: # Remove from pending tasks @@ -346,7 +348,7 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: except asyncio.TimeoutError: if not session.active: break - # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS + # Remove completed tasks from tracking completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] for task_id in completed_tasks: if task_id in session.pending_tool_tasks: @@ -364,24 +366,7 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: log_flow("tool_execution", "processor stopped") -def _convert_to_strands_event(provider_event: dict) -> dict: - """Pass-through for events already normalized by provider sessions. - - Providers convert their raw events to standard format before reaching here. - This just validates and passes through the normalized events. - - Args: - provider_event: Already normalized event from provider session. - - Returns: - Dict: The same event, validated and passed through. - """ - # Basic validation - ensure we have a dict - if not isinstance(provider_event, dict): - return {} - # Pass through - conversion already done by provider session - return provider_event async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: @@ -398,7 +383,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_id = tool_use.get("toolUseId") try: - # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) + # Skip execution if session is interrupted or inactive if session.interrupted or not session.active: log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) return @@ -422,7 +407,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: - # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION + # Return early if session was interrupted during execution if session.interrupted or not session.active: log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) return @@ -433,12 +418,12 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: try: actual_func = _extract_callable_function(tool_func) - # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK + # Execute tool function with provided input # For async tools, we could wrap with asyncio.wait_for with cancellation # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION + # Discard result if session was interrupted during execution if session.interrupted or not session.active: log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) return @@ -451,7 +436,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: log_event("tool_execution_cancelled", name=tool_name, id=tool_id) return except Exception as e: - # 🔥 CHECK FOR INTERRUPTION EVEN ON ERROR + # Discard error result if session was interrupted if session.interrupted or not session.active: log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) return @@ -462,7 +447,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: else: log_event("tool_not_found", name=tool_name) - # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS + # Skip sending results if session was interrupted if session.interrupted or not session.active: log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) return From 5921f8bdb24740adb2b6ad2af609218674b4b4b5 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:23:21 -0400 Subject: [PATCH 009/242] Updated: fixed module imports baesd on the new smithy python release on 09-29, added a lock for interruption handling --- .../event_loop/bidirectional_event_loop.py | 118 ++++++++++-------- .../models/novasonic.py | 6 +- .../tests/test_bidirectional_streaming.py | 4 +- 3 files changed, 68 insertions(+), 60 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index fddd1245a..358fdcea3 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -20,7 +20,7 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..agent.agent import BidirectionalAgent + from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -38,7 +38,7 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent: BidirectionalAgent) -> None: + def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None: """Initialize session with model session and agent reference. Args: @@ -59,9 +59,10 @@ def __init__(self, model_session: BidirectionalModelSession, agent: Bidirectiona # Interruption handling (model-agnostic) self.interrupted = False + self.interruption_lock = asyncio.Lock() -async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: +async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: """Initialize bidirectional session with conycurrent background tasks. Creates a model-specific session and starts background tasks for processing @@ -181,66 +182,73 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: """Handle interruption detection with task cancellation and audio buffer clearing. Cancels pending tool tasks and clears audio output queues to ensure responsive - interruption handling during conversations. + interruption handling during conversations. Protected by async lock to prevent + concurrent execution and race conditions. Args: session: BidirectionalConnection to handle interruption for. """ - log_event("interruption_detected") - session.interrupted = True + async with session.interruption_lock: + # If already interrupted, skip duplicate processing + if session.interrupted: + log_event("interruption_already_in_progress") + return - # Cancel all pending tool execution tasks - cancelled_tools = 0 - for task_id, task in list(session.pending_tool_tasks.items()): - if not task.done(): - task.cancel() - cancelled_tools += 1 - log_event("tool_task_cancelled", task_id=task_id) + log_event("interruption_detected") + session.interrupted = True - if cancelled_tools > 0: - log_event("tool_tasks_cancelled", count=cancelled_tools) + # Cancel all pending tool execution tasks + cancelled_tools = 0 + for task_id, task in list(session.pending_tool_tasks.items()): + if not task.done(): + task.cancel() + cancelled_tools += 1 + log_event("tool_task_cancelled", task_id=task_id) - # Clear all queued audio output events - cleared_count = 0 - while True: - try: - session.audio_output_queue.get_nowait() - cleared_count += 1 - except asyncio.QueueEmpty: - break + if cancelled_tools > 0: + log_event("tool_tasks_cancelled", count=cancelled_tools) - # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, "_output_queue"): - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) - - if cleared_count > 0: - log_event("session_audio_queue_cleared", count=cleared_count) - - # Brief sleep to allow audio system to settle (matches Nova Sonic timing) - await asyncio.sleep(0.05) - - # Reset interruption flag after clearing (automatic recovery) - session.interrupted = False - log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + # Clear all queued audio output events + cleared_count = 0 + while True: + try: + session.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break + + # Also clear the agent's audio output queue if it exists + if hasattr(session.agent, "_output_queue"): + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) + + if cleared_count > 0: + log_event("session_audio_queue_cleared", count=cleared_count) + + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) + await asyncio.sleep(0.05) + + # Reset interruption flag after clearing (automatic recovery) + session.interrupted = False + log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) async def _process_model_events(session: BidirectionalConnection) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 89472350b..e79229623 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -24,7 +24,7 @@ from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk -from smithy_aws_core.credentials_resolvers.environment import EnvironmentCredentialsResolver +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages from ....types.tools import ToolSpec, ToolUse @@ -703,8 +703,8 @@ async def _initialize_client(self) -> None: endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), - http_auth_scheme_resolver=HTTPAuthSchemeResolver(), - http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, ) self._client = BedrockRuntimeClient(config=config) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index 6ef96f919..b31607966 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -15,8 +15,8 @@ import pyaudio from strands_tools import calculator -from ..agent.agent import BidirectionalAgent -from ..models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel async def play(context): From 8cb4d98ba035d021cdff1953cf9705cca114e270 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:37:33 -0400 Subject: [PATCH 010/242] Removed unnecessary _output_queue check as the queue will always be initialized, and removed asyncio.sleep() as they were mainly for defensive purposes and following the pattern of nova sonic samples. --- .../event_loop/bidirectional_event_loop.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 358fdcea3..b4395f38e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -20,7 +20,6 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse - from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -95,10 +94,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start main coordination cycle session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - # Give background tasks a moment to start - await asyncio.sleep(0.1) log_event("session_ready", tasks=len(session.background_tasks)) - return session @@ -217,35 +213,31 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: except asyncio.QueueEmpty: break - # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, "_output_queue"): - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) + # Also clear the agent's audio output queue + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) if cleared_count > 0: log_event("session_audio_queue_cleared", count=cleared_count) - # Brief sleep to allow audio system to settle (matches Nova Sonic timing) - await asyncio.sleep(0.05) - # Reset interruption flag after clearing (automatic recovery) session.interrupted = False log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) From 7a6e53efdf669352bd18f19531178d46589c214d Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 13:03:42 -0400 Subject: [PATCH 011/242] Remove redundant interruption checks --- .../event_loop/bidirectional_event_loop.py | 67 +++---------------- 1 file changed, 11 insertions(+), 56 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index b4395f38e..cc4f416b7 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -264,7 +264,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: strands_event = provider_event - # Handle interruption detection (multiple patterns) + # Handle interruption detection (provider converts raw patterns to interruptionDetected) if strands_event.get("interruptionDetected"): log_event("interruption_forwarded") await _handle_interruption(session) @@ -272,16 +272,6 @@ async def _process_model_events(session: BidirectionalConnection) -> None: await session.agent._output_queue.put(strands_event) continue - # Check for text-based interruption - if strands_event.get("textOutput"): - text_content = strands_event["textOutput"].get("content", "") - if '{ "interrupted" : true }' in text_content: - log_event("text_interruption_detected") - await _handle_interruption(session) - # Still forward the text event - await session.agent._output_queue.put(strands_event) - continue - # Queue tool requests for concurrent execution if strands_event.get("toolUse"): log_event("tool_queued", name=strands_event["toolUse"].get("name")) @@ -308,8 +298,8 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Includes proper task cleanup and cancellation - handling for interruptions. + processing or user interaction. Uses proper asyncio cancellation for + interruption handling rather than manual state checks. Args: session: BidirectionalConnection containing tool queue. @@ -320,9 +310,6 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) - if not session.active: - break - task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task @@ -372,8 +359,9 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. - Executes tools using the existing Strands tool system, handles interruption - during execution, and sends results back to the model provider. + Executes tools using the existing Strands tool system with proper asyncio + cancellation handling. Tool execution is stopped via task cancellation, + not manual state checks. Args: session: BidirectionalConnection for context. @@ -383,11 +371,6 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_id = tool_use.get("toolUseId") try: - # Skip execution if session is interrupted or inactive - if session.interrupted or not session.active: - log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) - return - # Create message structure for existing tool system tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} @@ -407,11 +390,6 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: - # Return early if session was interrupted during execution - if session.interrupted or not session.active: - log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) - return - tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) if tool_func: @@ -419,39 +397,18 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: actual_func = _extract_callable_function(tool_func) # Execute tool function with provided input - # For async tools, we could wrap with asyncio.wait_for with cancellation - # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - # Discard result if session was interrupted during execution - if session.interrupted or not session.active: - log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) - return - tool_result = _create_success_result(tool_use["toolUseId"], result) tool_results.append(tool_result) - except asyncio.CancelledError: - # Tool was cancelled due to interruption - log_event("tool_execution_cancelled", name=tool_name, id=tool_id) - return except Exception as e: - # Discard error result if session was interrupted - if session.interrupted or not session.active: - log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) - return - log_event("tool_execution_failed", name=tool_name, error=str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: log_event("tool_not_found", name=tool_name) - # Skip sending results if session was interrupted - if session.interrupted or not session.active: - log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) - return - # Send results through provider-specific session for result in tool_results: await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) @@ -464,13 +421,11 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: raise # Re-raise to properly handle cancellation except Exception as e: log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - - # Only send error if not interrupted - if not session.interrupted and session.active: - try: - await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) - except Exception as send_error: - log_event("tool_error_send_failed", error=str(send_error)) + + try: + await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) + except Exception as send_error: + log_event("tool_error_send_failed", error=str(send_error)) def _extract_callable_function(tool_func: any) -> any: From a58626107b21dad40a52bf27320f35e1af9a5df8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 13:25:51 -0400 Subject: [PATCH 012/242] Unified tool result and tool error methods, Added implementation to add user messages to the agent messages --- .../bidirectional_streaming/agent/agent.py | 8 ++++++-- .../event_loop/bidirectional_event_loop.py | 19 ++++++++++++------- .../models/bidirectional_model.py | 6 +----- .../models/novasonic.py | 6 ------ 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index e27885c7e..46bc38ef2 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -87,7 +87,8 @@ async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during - an active conversation session. + an active conversation session. User input is automatically added to + conversation history for complete message tracking. Args: input_data: Either a string for text input or AudioInputEvent for audio input. @@ -98,10 +99,13 @@ async def send(self, input_data: str | AudioInputEvent) -> None: self._validate_active_session() if isinstance(input_data, str): + # Add user text message to history + self.messages.append({"role": "user", "content": input_data}) + log_event("text_sent", length=len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input (AudioInputEvent) + # Handle audio input await self._session.model_session.send_audio_content(input_data) else: raise ValueError( diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index cc4f416b7..684c0037e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -261,7 +261,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Basic validation - skip invalid events if not isinstance(provider_event, dict): continue - + strands_event = provider_event # Handle interruption detection (provider converts raw patterns to interruptionDetected) @@ -287,6 +287,14 @@ async def _process_model_events(session: BidirectionalConnection) -> None: log_event("message_added_to_history") session.agent.messages.append(strands_event["messageStop"]["message"]) + # Handle user audio transcripts - add to message history + if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": + user_transcript = strands_event["textOutput"]["text"] + if user_transcript.strip(): # Only add non-empty transcripts + user_message = {"role": "user", "content": user_transcript} + session.agent.messages.append(user_message) + log_event("user_transcript_added_to_history") + except Exception as e: log_event("model_events_error", error=str(e)) traceback.print_exc() @@ -298,7 +306,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for + processing or user interaction. Uses proper asyncio cancellation for interruption handling rather than manual state checks. Args: @@ -353,9 +361,6 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: log_flow("tool_execution", "processor stopped") - - - async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. @@ -421,9 +426,9 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: raise # Re-raise to properly handle cancellation except Exception as e: log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - + try: - await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) + await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) except Exception as send_error: log_event("tool_error_send_failed", error=str(send_error)) diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 1432b112a..4cd9cc6b8 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -71,14 +71,10 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. + Handles both successful results and error cases. """ raise NotImplementedError - @abc.abstractmethod - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool execution error to model in provider-specific format.""" - raise NotImplementedError - @abc.abstractmethod async def close(self) -> None: """Close the connection and cleanup resources.""" diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index e79229623..dfd911172 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -386,12 +386,6 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No for i, event in enumerate(events): await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool error using Nova Sonic format.""" - log_event("nova_tool_error_send", id=tool_use_id, error=error) - error_result = {"error": error} - await self.send_tool_result(tool_use_id, error_result) - async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: From 16d9b461d187b45ee6d3305268ef23293accd3b0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:00:25 -0400 Subject: [PATCH 013/242] Modified logging to use python logger --- .../bidirectional_streaming/agent/agent.py | 8 +- .../event_loop/bidirectional_event_loop.py | 89 ++++++++++--------- .../models/novasonic.py | 67 +++++++------- 3 files changed, 83 insertions(+), 81 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 46bc38ef2..68d371a51 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -22,7 +22,7 @@ from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent -from ..utils.debug import log_event, log_flow + logger = logging.getLogger(__name__) @@ -79,9 +79,9 @@ async def start(self) -> None: if self._session and self._session.active: raise ValueError("Conversation already active. Call end() first.") - log_flow("conversation_start", "initializing session") + logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - log_event("conversation_ready") + logger.debug("Conversation ready") async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). @@ -102,7 +102,7 @@ async def send(self, input_data: str | AudioInputEvent) -> None: # Add user text message to history self.messages.append({"role": "user", "content": input_data}) - log_event("text_sent", length=len(input_data)) + logger.debug("Text sent: %d characters", len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 684c0037e..16be08aaf 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession -from ..utils.debug import log_event, log_flow + logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec Returns: BidirectionalConnection: Active session with background tasks running. """ - log_flow("session_start", "initializing model session") + logger.debug("Starting bidirectional session - initializing model session") # Create provider-specific session model_session = await agent.model.create_bidirectional_connection( @@ -85,7 +85,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization - log_flow("background_tasks", "starting processors") + logger.debug("Starting background processors for concurrent processing") session.background_tasks = [ asyncio.create_task(_process_model_events(session)), # Handle model responses asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently @@ -94,7 +94,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start main coordination cycle session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - log_event("session_ready", tasks=len(session.background_tasks)) + logger.debug("Session ready with %d background tasks", len(session.background_tasks)) return session @@ -107,7 +107,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non if not session.active: return - log_flow("session_cleanup", "starting") + logger.debug("Session cleanup starting") session.active = False # Cancel pending tool tasks @@ -134,7 +134,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non # Close model session await session.model_session.close() - log_event("session_closed") + logger.debug("Session closed") async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: @@ -150,7 +150,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No try: # Check if background processors are still running if all(task.done() for task in session.background_tasks): - log_event("session_end", reason="all_processors_completed") + logger.debug("Session end - all processors completed") session.active = False break @@ -159,7 +159,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No if task.done() and not task.cancelled(): exception = task.exception() if exception: - log_event("session_error", processor=i, error=str(exception)) + logger.error("Session error in processor %d: %s", i, str(exception)) session.active = False raise exception @@ -169,7 +169,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No except asyncio.CancelledError: break except Exception as e: - log_event("event_loop_error", error=str(e)) + logger.error("Event loop error: %s", str(e)) session.active = False raise @@ -187,10 +187,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async with session.interruption_lock: # If already interrupted, skip duplicate processing if session.interrupted: - log_event("interruption_already_in_progress") + logger.debug("Interruption already in progress") return - log_event("interruption_detected") + logger.debug("Interruption detected") session.interrupted = True # Cancel all pending tool execution tasks @@ -199,10 +199,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: if not task.done(): task.cancel() cancelled_tools += 1 - log_event("tool_task_cancelled", task_id=task_id) + logger.debug("Tool task cancelled: %s", task_id) if cancelled_tools > 0: - log_event("tool_tasks_cancelled", count=cancelled_tools) + logger.debug("Tool tasks cancelled: %d", cancelled_tools) # Clear all queued audio output events cleared_count = 0 @@ -233,14 +233,14 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: session.agent._output_queue.put_nowait(event) if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) + logger.debug("Agent audio queue cleared: %d events", audio_cleared) if cleared_count > 0: - log_event("session_audio_queue_cleared", count=cleared_count) + logger.debug("Session audio queue cleared: %d events", cleared_count) # Reset interruption flag after clearing (automatic recovery) session.interrupted = False - log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) async def _process_model_events(session: BidirectionalConnection) -> None: @@ -252,7 +252,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: Args: session: BidirectionalConnection containing model session. """ - log_flow("model_events", "processor started") + logger.debug("Model events processor started") try: async for provider_event in session.model_session.receive_events(): if not session.active: @@ -261,12 +261,12 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Basic validation - skip invalid events if not isinstance(provider_event, dict): continue - + strands_event = provider_event # Handle interruption detection (provider converts raw patterns to interruptionDetected) if strands_event.get("interruptionDetected"): - log_event("interruption_forwarded") + logger.debug("Interruption forwarded") await _handle_interruption(session) # Forward interruption event to agent for application-level handling await session.agent._output_queue.put(strands_event) @@ -274,7 +274,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution if strands_event.get("toolUse"): - log_event("tool_queued", name=strands_event["toolUse"].get("name")) + logger.debug("Tool queued: %s", strands_event["toolUse"].get("name")) await session.tool_queue.put(strands_event["toolUse"]) continue @@ -284,39 +284,39 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Update Agent conversation history using existing patterns if strands_event.get("messageStop"): - log_event("message_added_to_history") + logger.debug("Message added to history") session.agent.messages.append(strands_event["messageStop"]["message"]) - + # Handle user audio transcripts - add to message history if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": user_transcript = strands_event["textOutput"]["text"] if user_transcript.strip(): # Only add non-empty transcripts user_message = {"role": "user", "content": user_transcript} session.agent.messages.append(user_message) - log_event("user_transcript_added_to_history") + logger.debug("User transcript added to history") except Exception as e: - log_event("model_events_error", error=str(e)) + logger.error("Model events error: %s", str(e)) traceback.print_exc() finally: - log_flow("model_events", "processor stopped") + logger.debug("Model events processor stopped") async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for + processing or user interaction. Uses proper asyncio cancellation for interruption handling rather than manual state checks. Args: session: BidirectionalConnection containing tool queue. """ - log_flow("tool_execution", "processor started") + logger.debug("Tool execution processor started") while session.active: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) + logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId")) task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) @@ -330,13 +330,13 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: # Log completion status if completed_task.cancelled(): - log_event("tool_task_cleanup_cancelled", task_id=task_id) + logger.debug("Tool task cleanup cancelled: %s", task_id) elif completed_task.exception(): - log_event("tool_task_cleanup_error", task_id=task_id, error=str(completed_task.exception())) + logger.error("Tool task cleanup error: %s - %s", task_id, str(completed_task.exception())) else: - log_event("tool_task_cleanup_success", task_id=task_id) + logger.debug("Tool task cleanup success: %s", task_id) except Exception as e: - log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) + logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) task.add_done_callback(cleanup_task) @@ -350,15 +350,18 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: del session.pending_tool_tasks[task_id] if completed_tasks: - log_event("periodic_task_cleanup", count=len(completed_tasks)) + logger.debug("Periodic task cleanup: %d tasks", len(completed_tasks)) continue except Exception as e: - log_event("tool_execution_error", error=str(e)) + logger.error("Tool execution error: %s", str(e)) if not session.active: break - log_flow("tool_execution", "processor stopped") + logger.debug("Tool execution processor stopped") + + + async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: @@ -390,7 +393,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] if not valid_tool_uses: - log_event("tool_validation_failed", name=tool_name, id=tool_id) + logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id) return # Execute tools directly (simpler approach for bidirectional) @@ -408,29 +411,29 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_results.append(tool_result) except Exception as e: - log_event("tool_execution_failed", name=tool_name, error=str(e)) + logger.error("Tool execution failed: %s - %s", tool_name, str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: - log_event("tool_not_found", name=tool_name) + logger.warning("Tool not found: %s", tool_name) # Send results through provider-specific session for result in tool_results: await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) - log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) + logger.debug("Tool execution completed: %s (%d results)", tool_name, len(tool_results)) except asyncio.CancelledError: # Task was cancelled due to interruption - this is expected behavior - log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) + logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id) raise # Re-raise to properly handle cancellation except Exception as e: - log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - + logger.error("Tool execution error: %s - %s", tool_use.get("name"), str(e)) + try: await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) except Exception as send_error: - log_event("tool_error_send_failed", error=str(send_error)) + logger.error("Tool error send failed: %s", str(send_error)) def _extract_callable_function(tool_func: any) -> any: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index dfd911172..7f7937ef1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -36,7 +36,7 @@ InterruptionDetectedEvent, TextOutputEvent, ) -from ..utils.debug import log_event, log_flow, time_it_async + from .bidirectional_model import BidirectionalModel, BidirectionalModelSession logger = logging.getLogger(__name__) @@ -121,10 +121,10 @@ async def initialize( init_events = self._build_initialization_events(system_prompt, tools or [], messages) - log_flow("nova_init", f"sending {len(init_events)} events") + logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") await self._send_initialization_events(init_events) - log_event("nova_connection_initialized") + logger.info("Nova Sonic connection initialized successfully") self._response_task = asyncio.create_task(self._process_responses()) except Exception as e: @@ -147,12 +147,12 @@ def _build_initialization_events( async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) + await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) async def _process_responses(self) -> None: """Process Nova Sonic responses continuously.""" - log_flow("nova_responses", "processor started") + logger.debug("Nova Sonic response processor started") try: while self._active: @@ -167,14 +167,14 @@ async def _process_responses(self) -> None: await asyncio.sleep(0.1) continue except Exception as e: - log_event("nova_response_error", error=str(e)) + logger.warning(f"Nova Sonic response error: {e}") await asyncio.sleep(0.1) continue except Exception as e: - log_event("nova_fatal_error", error=str(e)) + logger.error(f"Nova Sonic fatal error: {e}") finally: - log_flow("nova_responses", "processor stopped") + logger.debug("Nova Sonic response processor stopped") async def _handle_response_data(self, response_data: str) -> None: """Handle decoded response data from Nova Sonic.""" @@ -190,21 +190,21 @@ async def _handle_response_data(self, response_data: str) -> None: await self._event_queue.put(nova_event) except json.JSONDecodeError as e: - log_event("nova_json_error", error=str(e)) + logger.warning(f"Nova Sonic JSON decode error: {e}") def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: - log_event("nova_usage", usage=nova_event["usageEvent"]) + logger.debug("Nova usage: %s", nova_event["usageEvent"]) elif "textOutput" in nova_event: - log_event("nova_text_output") + logger.debug("Nova text output") elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - log_event("nova_tool_use", name=tool_use["toolName"], id=tool_use["toolUseId"]) + logger.debug("Nova tool use: %s (id: %s)", tool_use["toolName"], tool_use["toolUseId"]) elif "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) - log_event("nova_audio_output", bytes=len(audio_bytes)) + logger.debug("Nova audio output: %d bytes", len(audio_bytes)) async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" @@ -212,7 +212,7 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: logger.error("Stream is None") return - log_flow("nova_events", "starting event stream") + logger.debug("Nova events - starting event stream") # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { @@ -257,7 +257,7 @@ async def start_audio_connection(self) -> None: if self.audio_connection_active: return - log_event("nova_audio_connection_start") + logger.debug("Nova audio connection start") audio_content_start = json.dumps( { @@ -319,7 +319,7 @@ async def _check_silence(self) -> None: if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: - log_event("nova_silence_detected", elapsed=elapsed) + logger.debug("Nova silence detected: %.2f seconds", elapsed) await self.end_audio_input() except asyncio.CancelledError: pass @@ -329,7 +329,7 @@ async def end_audio_input(self) -> None: if not self.audio_connection_active: return - log_event("nova_audio_connection_end") + logger.debug("Nova audio connection end") audio_content_end = json.dumps( {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} @@ -375,7 +375,7 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No if not self._active: return - log_event("nova_tool_result_send", id=tool_use_id) + logger.debug("Nova tool result send: %s", tool_use_id) content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), @@ -384,14 +384,16 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No ] for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) + await self._send_nova_event(event) + + async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - log_flow("nova_cleanup", "starting connection close") + logger.debug("Nova cleanup - starting connection close") self._active = False # Cancel response processing task if running @@ -423,9 +425,9 @@ async def close(self) -> None: logger.warning("Error closing Nova Sonic stream: %s", e) except Exception as e: - log_event("nova_cleanup_error", error=str(e)) + logger.error("Nova cleanup error: %s", str(e)) finally: - log_event("nova_connection_closed") + logger.debug("Nova connection closed") def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: """Convert Nova Sonic events to provider-agnostic format.""" @@ -452,7 +454,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Check for Nova Sonic interruption pattern (matches working sample) if '{ "interrupted" : true }' in text_content: - log_event("nova_interruption_in_text") + logger.debug("Nova interruption detected in text") interruption: InterruptionDetectedEvent = {"reason": "user_input"} return {"interruptionDetected": interruption} @@ -480,7 +482,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": - log_event("nova_interruption_stop_reason") + logger.debug("Nova interruption stop reason") interruption: InterruptionDetectedEvent = {"reason": "user_input"} @@ -664,29 +666,26 @@ async def create_bidirectional_connection( **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" - log_flow("nova_connection_create", "starting") + logger.debug("Nova connection create - starting") # Initialize client if needed if not self._client: - await time_it_async("initialize_client", lambda: self._initialize_client()) + await self._initialize_client() # Start Nova Sonic bidirectional stream try: - stream = await time_it_async( - "invoke_model_with_bidirectional_stream", - lambda: self._client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ), + stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) ) # Create and initialize connection connection = NovaSonicSession(stream, self.config) - await time_it_async("initialize_connection", lambda: connection.initialize(system_prompt, tools, messages)) + await connection.initialize(system_prompt, tools, messages) - log_event("nova_connection_created") + logger.debug("Nova connection created") return connection except Exception as e: - log_event("nova_connection_create_error", error=str(e)) + logger.error("Nova connection create error: %s", str(e)) logger.error("Failed to create Nova Sonic connection: %s", e) raise From 04265baa9267865fe9686dbe89440c552e77f2da Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:02:24 -0400 Subject: [PATCH 014/242] Removed logging utility --- .../bidirectional_streaming/utils/__init__.py | 5 -- .../bidirectional_streaming/utils/debug.py | 48 ------------------- 2 files changed, 53 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/utils/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/utils/debug.py diff --git a/src/strands/experimental/bidirectional_streaming/utils/__init__.py b/src/strands/experimental/bidirectional_streaming/utils/__init__.py deleted file mode 100644 index 579478436..000000000 --- a/src/strands/experimental/bidirectional_streaming/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Utility functions for bidirectional streaming.""" - -from .debug import log_event, log_flow, time_it_async - -__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py deleted file mode 100644 index 6a7fc3982..000000000 --- a/src/strands/experimental/bidirectional_streaming/utils/debug.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Debug utilities for Strands bidirectional streaming. - -Provides consistent debug logging across all bidirectional streaming components -with configurable output control matching the Nova Sonic tool use example. -""" - -import datetime -import inspect -import time - -# Debug logging system matching successful tool use example -DEBUG = False # Disable debug logging for clean output like tool use example - - -def debug_print(message): - """Print debug message with timestamp and function name.""" - if DEBUG: - function_name = inspect.stack()[1].function - if function_name == "time_it_async": - function_name = inspect.stack()[2].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - print(f"{timestamp} {function_name} {message}") - - -def log_event(event_type, **context): - """Log important events with structured context.""" - if DEBUG: - function_name = inspect.stack()[1].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" - print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") - - -def log_flow(step, details=""): - """Log important flow steps without excessive detail.""" - if DEBUG: - function_name = inspect.stack()[1].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - print(f"{timestamp} {function_name} FLOW: {step} {details}") - - -async def time_it_async(label, method_to_run): - """Time asynchronous method execution.""" - start_time = time.perf_counter() - result = await method_to_run() - end_time = time.perf_counter() - debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") - return result From 8a7396cf0715409b7fb35deb2c51b1164541a307 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:36:36 -0400 Subject: [PATCH 015/242] Updated types --- .../experimental/bidirectional_streaming/__init__.py | 3 --- .../bidirectional_streaming/models/bidirectional_model.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index f6a3b41bf..52822711a 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,5 +1,2 @@ """Bidirectional streaming package for real-time audio/text conversations.""" -from .utils import log_event, log_flow, time_it_async - -__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 4cd9cc6b8..d5c3c9b65 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -17,7 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class BidirectionalModelSession(abc.ABC): """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[dict[str, any]]: + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model in standardized format. Converts provider-specific events to a common format that can be @@ -71,7 +71,7 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. - Handles both successful results and error cases. + Handles both successful results and error cases through the result dictionary. """ raise NotImplementedError From 3107e6bac979c137cf575b1fbd45b57e1e3c87fd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 9 Oct 2025 12:53:41 -0400 Subject: [PATCH 016/242] (feat)bidirectional_streaming: add openai realtime model provider --- pyproject.toml | 12 + .../bidirectional_streaming/__init__.py | 46 +- .../models/__init__.py | 10 +- .../models/novasonic.py | 12 +- .../bidirectional_streaming/models/openai.py | 508 ++++++++++++++++++ ...al_streaming.py => test_bidi_novasonic.py} | 0 .../tests/test_bidi_openai.py | 285 ++++++++++ .../bidirectional_streaming/types/__init__.py | 4 + .../types/bidirectional_streaming.py | 50 +- 9 files changed, 916 insertions(+), 11 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py rename src/strands/experimental/bidirectional_streaming/tests/{test_bidirectional_streaming.py => test_bidi_novasonic.py} (100%) create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/pyproject.toml b/pyproject.toml index 3b8866f4a..2900719ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,12 +53,24 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +bidirectional-streaming-nova = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", +] +bidirectional-streaming-openai = [ + "pyaudio>=0.2.13", + "websockets>=12.0,<14.0", +] bidirectional-streaming = [ "pyaudio>=0.2.13", "rx>=3.2.0", "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", + "websockets>=12.0,<14.0", ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 52822711a..a6af41dff 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,2 +1,46 @@ -"""Bidirectional streaming package for real-time audio/text conversations.""" +""" +Bidirectional streaming package. +""" +# Main components - Primary user interface +from .agent.agent import BidirectionalAgent + +# Model providers - What users need to create models +from .models.novasonic import NovaSonicBidirectionalModel +from .models.openai import OpenAIRealtimeBidirectionalModel + +# Event types - For type hints and event handling +from .types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + TextOutputEvent, + InterruptionDetectedEvent, + BidirectionalStreamEvent, + VoiceActivityEvent, + UsageMetricsEvent, +) + +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + +__all__ = [ + # Main interface + "BidirectionalAgent", + + # Model providers + "NovaSonicBidirectionalModel", + "OpenAIRealtimeBidirectionalModel", + + # Event types + "AudioInputEvent", + "AudioOutputEvent", + "TextOutputEvent", + "InterruptionDetectedEvent", + "BidirectionalStreamEvent", + "VoiceActivityEvent", + "UsageMetricsEvent", + + # Model interface + "BidirectionalModel", + "BidirectionalModelSession", +] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 6cba974e0..4a11f9e4a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -2,5 +2,13 @@ from .bidirectional_model import BidirectionalModel, BidirectionalModelSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession +from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession -__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] +__all__ = [ + "BidirectionalModel", + "BidirectionalModelSession", + "NovaSonicBidirectionalModel", + "NovaSonicSession", + "OpenAIRealtimeBidirectionalModel", + "OpenAIRealtimeSession" +] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7f7937ef1..bc00b7e91 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -35,6 +35,7 @@ BidirectionalConnectionStartEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent, ) from .bidirectional_model import BidirectionalModel, BidirectionalModelSession @@ -488,9 +489,16 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No return {"interruptionDetected": interruption} - # Handle usage events (ignore) + # Handle usage events - convert to standardized format elif "usageEvent" in nova_event: - return None + usage_data = nova_event["usageEvent"] + usage_metrics: UsageMetricsEvent = { + "totalTokens": usage_data.get("totalTokens"), + "inputTokens": usage_data.get("totalInputTokens"), + "outputTokens": usage_data.get("totalOutputTokens"), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens") + } + return {"usageMetrics": usage_metrics} # Handle content start events (track role) elif "contentStart" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py new file mode 100644 index 000000000..0fa859db9 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -0,0 +1,508 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import AsyncIterable + +import websockets +from websockets.exceptions import ConnectionClosed +from websockets.client import WebSocketClientProtocol + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, + VoiceActivityEvent, +) +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# OpenAI Realtime API configuration +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" + +AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": AUDIO_FORMAT, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + } + }, + "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + }, +} + + +class OpenAIRealtimeSession(BidirectionalModelSession): + """OpenAI Realtime API session for real-time audio/text streaming. + + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: + """Initialize OpenAI Realtime session.""" + self.websocket = websocket + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + + self._event_queue = asyncio.Queue() + self._response_task = None + self._function_call_buffer = {} + + logger.debug("OpenAI Realtime session initialized: %s", self.session_id) + + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + + async def _create_conversation_item(self, item_data: dict) -> None: + """Create conversation item and trigger response.""" + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize session with configuration.""" + try: + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + if messages: + await self._add_conversation_history(messages) + + self._response_task = asyncio.create_task(self._process_responses()) + logger.info("OpenAI Realtime session initialized successfully") + + except Exception as e: + logger.error("Error during OpenAI Realtime initialization: %s", e) + raise + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + """Build session configuration for OpenAI Realtime API.""" + config = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + custom_config = self.config.get("session", {}) + supported_params = { + "type", "output_modalities", "instructions", "voice", "audio", + "tools", "tool_choice", "input_audio_format", "output_audio_format", + "input_audio_transcription", "turn_detection" + } + + for key, value in custom_config.items(): + if key in supported_params: + config[key] = value + else: + logger.warning("Ignoring unsupported session parameter: %s", key) + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI function format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + else: + schema = input_schema + + openai_tool = { + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": schema + } + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session.""" + for message in messages: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": message["role"], "content": []} + } + + content = message.get("content", "") + if isinstance(content, str): + conversation_item["item"]["content"].append({"type": "input_text", "text": content}) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + conversation_item["item"]["content"].append({"type": "input_text", "text": item.get("text", "")}) + + await self._send_event(conversation_item) + + async def _process_responses(self) -> None: + """Process incoming WebSocket messages.""" + logger.debug("OpenAI Realtime response processor started") + + try: + async for message in self.websocket: + if not self._active: + break + + try: + event = json.loads(message) + await self._event_queue.put(event) + except json.JSONDecodeError as e: + logger.warning("Failed to parse OpenAI event: %s", e) + continue + + except ConnectionClosed: + logger.debug("OpenAI Realtime WebSocket connection closed") + except Exception as e: + logger.error("Error in OpenAI Realtime response processing: %s", e) + finally: + self._active = False + logger.debug("OpenAI Realtime response processor stopped") + + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive OpenAI events and convert to Strands format.""" + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + while self._active: + try: + openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + provider_event = self._convert_openai_event(openai_event) + if provider_event: + yield provider_event + except asyncio.TimeoutError: + continue + + except Exception as e: + logger.error("Error receiving OpenAI Realtime event: %s", e) + finally: + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "openai_realtime"}, + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: + """Convert OpenAI events to Strands format.""" + event_type = openai_event.get("type") + + # Audio output + if event_type == "response.output_audio.delta": + audio_data = base64.b64decode(openai_event["delta"]) + audio_output: AudioOutputEvent = { + "audioData": audio_data, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": None, + } + return {"audioOutput": audio_output} + + # Text output using helper method + elif event_type == "response.output_text.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + elif event_type == "response.output_audio_transcript.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + # User transcription + elif event_type == "conversation.item.input_audio_transcription.delta": + transcript_delta = openai_event.get("delta", "") + return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.completed": + transcript = openai_event.get("transcript", "") + return self._create_text_event(transcript, "user") if transcript.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + return self._create_text_event(text, "user") if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + return {"toolUse": tool_use} + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Error parsing function arguments for %s: %s", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection using helper method + elif event_type == "input_audio_buffer.speech_started": + return self._create_voice_activity_event("speech_started") + elif event_type == "input_audio_buffer.speech_stopped": + return self._create_voice_activity_event("speech_stopped") + elif event_type == "input_audio_buffer.timeout_triggered": + return self._create_voice_activity_event("timeout") + + # Lifecycle events (log only) + elif event_type == "conversation.item.retrieve": + item = openai_event.get("item", {}) + logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) + return None + + elif event_type == "conversation.item.added": + logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) + + item = openai_event.get("item", {}) + if item.get("type") == "message" and item.get("role") == "assistant": + content_parts = item.get("content", []) + if content_parts: + message_content = [] + for content_part in content_parts: + if content_part.get("type") == "output_text": + message_content.append({"type": "text", "text": content_part.get("text", "")}) + elif content_part.get("type") == "output_audio": + transcript = content_part.get("transcript", "") + if transcript: + message_content.append({"type": "text", "text": transcript}) + + if message_content: + message = {"role": "assistant", "content": message_content} + return {"messageStop": {"message": message}} + return None + + elif event_type in ["response.output_item.added", "response.output_item.done", + "response.content_part.added", "response.content_part.done"]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": function_name, "arguments": ""} + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", + "session.created", "session.updated"]: + logger.debug("OpenAI %s event", event_type) + return None + + elif event_type == "error": + logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) + return None + + else: + logger.debug("Unhandled OpenAI event type: %s", event_type) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to OpenAI for processing.""" + if not self._require_active(): + return + + audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content to OpenAI for processing.""" + if not self._require_active(): + return + + item_data = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}] + } + await self._create_conversation_item(item_data) + + async def send_interrupt(self) -> None: + """Send interruption signal to OpenAI.""" + if not self._require_active(): + return + + await self._send_event({"type": "response.cancel"}) + + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: + """Send tool result back to OpenAI.""" + if not self._require_active(): + return + + logger.debug("OpenAI tool result send: %s", tool_use_id) + result_text = json.dumps(result) if not isinstance(result, str) else result + + item_data = { + "type": "function_call_output", + "call_id": tool_use_id, + "output": result_text + } + await self._create_conversation_item(item_data) + + async def close(self) -> None: + """Close session and cleanup resources.""" + if not self._active: + return + + logger.debug("OpenAI Realtime cleanup - starting connection close") + self._active = False + + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + await self.websocket.close() + except Exception as e: + logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + + logger.debug("OpenAI Realtime connection closed") + + async def _send_event(self, event: dict[str, any]) -> None: + """Send event to OpenAI via WebSocket.""" + try: + message = json.dumps(event) + await self.websocket.send(message) + logger.debug("Sent OpenAI event: %s", event.get("type")) + except Exception as e: + logger.error("Error sending OpenAI event: %s", e) + raise + + +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """OpenAI Realtime API provider for Strands bidirectional streaming. + + Provides real-time audio/text communication through OpenAI's Realtime API + with WebSocket connections, voice activity detection, and function calling. + """ + + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + **config: any + ) -> None: + """Initialize OpenAI Realtime bidirectional model.""" + self.model = model + self.api_key = api_key + self.config = config + + import os + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create bidirectional connection to OpenAI Realtime API.""" + logger.info("Creating OpenAI Realtime connection...") + + try: + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + session = OpenAIRealtimeSession(websocket, self.config) + await session.initialize(system_prompt, tools, messages) + + logger.info("OpenAI Realtime connection established") + return session + + except Exception as e: + logger.error("OpenAI connection error: %s", e) + raise \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py rename to src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py new file mode 100644 index 000000000..098ec4a39 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +"""Test OpenAI Realtime API speech-to-speech interaction.""" + +import asyncio +import os +import sys +import time +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +import pyaudio + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel + + +async def play(context): + """Handle audio playback with interruption support.""" + audio = pyaudio.PyAudio() + + try: + speaker = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # OpenAI Realtime uses 24kHz + output=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + # Check for interruption + if context.get("interrupted", False): + # Clear audio queue on interruption + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get audio data with timeout + try: + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + # Play in chunks to allow interruption + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if context.get("interrupted", False) or not context["active"]: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Brief pause for responsiveness + + except asyncio.TimeoutError: + continue + + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Audio playback error: {e}") + finally: + try: + speaker.close() + except: + pass + audio.terminate() + + +async def record(context): + """Handle microphone recording.""" + audio = pyaudio.PyAudio() + + try: + microphone = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # Match OpenAI's expected input rate + input=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + await context["audio_in"].put(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Microphone recording error: {e}") + finally: + try: + microphone.close() + except: + pass + audio.terminate() + + +async def receive(agent, context): + """Handle events from the agent.""" + try: + async for event in agent.receive(): + if not context["active"]: + break + + # Handle audio output + if "audioOutput" in event: + audio_data = event["audioOutput"]["audioData"] + + if not context.get("interrupted", False): + await context["audio_out"].put(audio_data) + + # Handle text output (transcripts) + elif "textOutput" in event: + text_output = event["textOutput"] + role = text_output.get("role", "assistant") + text = text_output.get("text", "").strip() + + if text: + if role == "user": + print(f"User: {text}") + elif role == "assistant": + print(f"Assistant: {text}") + + # Handle interruption detection + elif "interruptionDetected" in event: + context["interrupted"] = True + + # Handle connection events + elif "BidirectionalConnectionStart" in event: + pass # Silent connection start + elif "BidirectionalConnectionEnd" in event: + context["active"] = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receive handler error: {e}") + finally: + pass + + +async def send(agent, context): + """Send audio from microphone to agent.""" + try: + while context["active"]: + try: + audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) + + # Create audio event in expected format + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1 + } + + await agent.send(audio_event) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Send handler error: {e}") + finally: + pass + + +async def main(): + """Main test function for OpenAI voice chat.""" + print("Starting OpenAI Realtime API test...") + + # Check API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY environment variable not set") + return False + + # Check audio system + try: + audio = pyaudio.PyAudio() + audio.terminate() + except Exception as e: + print(f"Audio system error: {e}") + return False + + # Create OpenAI model + model = OpenAIRealtimeBidirectionalModel( + model="gpt-4o-realtime-preview", + api_key=api_key, + session={ + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700 + } + }, + "output": { + "format": {"type": "audio/pcm", "rate": 24000}, + "voice": "alloy" + } + } + } + ) + + # Create agent + agent = BidirectionalAgent( + model=model, + system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." + ) + + # Start the session + await agent.start() + + # Create shared context + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "interrupted": False, + "start_time": time.time() + } + + print("Speak into your microphone. Press Ctrl+C to stop.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + except Exception as e: + print(f"\nError during voice chat: {e}") + finally: + print("Cleaning up...") + context["active"] = False + + try: + await agent.end() + except Exception as e: + print(f"Cleanup error: {e}") + + return True + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Test error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 510285f06..412061146 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -13,6 +13,8 @@ BidirectionalStreamEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent, + VoiceActivityEvent, ) __all__ = [ @@ -23,6 +25,8 @@ "BidirectionalStreamEvent", "InterruptionDetectedEvent", "TextOutputEvent", + "UsageMetricsEvent", + "VoiceActivityEvent", "SUPPORTED_AUDIO_FORMATS", "SUPPORTED_SAMPLE_RATES", "SUPPORTED_CHANNELS", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 01d72356a..194698f29 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -116,11 +116,43 @@ class BidirectionalConnectionEndEvent(TypedDict): metadata: Provider-specific connection metadata. """ - reason: Literal["user_request", "timeout", "error"] + reason: Literal["user_request", "timeout", "error", "connection_complete"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] +class VoiceActivityEvent(TypedDict): + """Voice activity detection event for speech monitoring. + + Provides standardized voice activity detection events across providers + to enable speech-aware applications and better conversation flow. + + Attributes: + activityType: Type of voice activity detected. + """ + + activityType: Literal["speech_started", "speech_stopped", "timeout"] + + +class UsageMetricsEvent(TypedDict): + """Token usage and performance tracking. + + Provides standardized usage metrics across providers for cost monitoring + and performance optimization. + + Attributes: + totalTokens: Total tokens used in the interaction. + inputTokens: Tokens used for input processing. + outputTokens: Tokens used for output generation. + audioTokens: Tokens used specifically for audio processing. + """ + + totalTokens: Optional[int] + inputTokens: Optional[int] + outputTokens: Optional[int] + audioTokens: Optional[int] + + class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. @@ -134,11 +166,15 @@ class BidirectionalStreamEvent(StreamEvent, total=False): interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. + voiceActivity: Voice activity detection events. + usageMetrics: Token usage and performance metrics. """ - audioOutput: AudioOutputEvent - audioInput: AudioInputEvent - textOutput: TextOutputEvent - interruptionDetected: InterruptionDetectedEvent - BidirectionalConnectionStart: BidirectionalConnectionStartEvent - BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + audioOutput: Optional[AudioOutputEvent] + audioInput: Optional[AudioInputEvent] + textOutput: Optional[TextOutputEvent] + interruptionDetected: Optional[InterruptionDetectedEvent] + BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] + BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] + voiceActivity: Optional[VoiceActivityEvent] + usageMetrics: Optional[UsageMetricsEvent] From da8b86ca77e4f324b6cfc2a1d7ce756ec8a6d310 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 10 Oct 2025 10:29:08 -0400 Subject: [PATCH 017/242] fix function calling --- .../bidirectional_streaming/__init__.py | 15 +++++++-------- .../bidirectional_streaming/agent/agent.py | 1 - .../event_loop/bidirectional_event_loop.py | 1 - .../bidirectional_streaming/models/novasonic.py | 1 - .../bidirectional_streaming/models/openai.py | 14 ++++++-------- .../tests/test_bidi_openai.py | 2 ++ 6 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index a6af41dff..aeb335dea 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,10 +1,12 @@ -""" -Bidirectional streaming package. +"""Bidirectional streaming package. """ # Main components - Primary user interface from .agent.agent import BidirectionalAgent +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + # Model providers - What users need to create models from .models.novasonic import NovaSonicBidirectionalModel from .models.openai import OpenAIRealtimeBidirectionalModel @@ -13,16 +15,13 @@ from .types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, - TextOutputEvent, - InterruptionDetectedEvent, BidirectionalStreamEvent, - VoiceActivityEvent, + InterruptionDetectedEvent, + TextOutputEvent, UsageMetricsEvent, + VoiceActivityEvent, ) -# Advanced interfaces (for custom implementations) -from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession - __all__ = [ # Main interface "BidirectionalAgent", diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 68d371a51..0cd90063d 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -23,7 +23,6 @@ from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent - logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 16be08aaf..340cd9267 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -22,7 +22,6 @@ from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession - logger = logging.getLogger(__name__) # Session constants diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index bc00b7e91..4e4952fa9 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -37,7 +37,6 @@ TextOutputEvent, UsageMetricsEvent, ) - from .bidirectional_model import BidirectionalModel, BidirectionalModelSession logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0fa859db9..76bf9f50d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -12,8 +12,8 @@ from typing import AsyncIterable import websockets -from websockets.exceptions import ConnectionClosed from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosed from ....types.content import Messages from ....types.tools import ToolSpec, ToolUse @@ -23,7 +23,6 @@ BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, BidirectionalStreamEvent, - InterruptionDetectedEvent, TextOutputEvent, VoiceActivityEvent, ) @@ -142,7 +141,7 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] return config def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI function format.""" + """Convert Strands tool specifications to OpenAI Realtime API format.""" openai_tools = [] for tool in tools: @@ -152,13 +151,12 @@ def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: else: schema = input_schema + # OpenAI Realtime API expects flat structure, not nested under "function" openai_tool = { "type": "function", - "function": { - "name": tool["name"], - "description": tool["description"], - "parameters": schema - } + "name": tool["name"], + "description": tool["description"], + "parameters": schema } openai_tools.append(openai_tool) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index 098ec4a39..660040f3e 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -11,6 +11,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src")) import pyaudio +from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel @@ -229,6 +230,7 @@ async def main(): # Create agent agent = BidirectionalAgent( model=model, + tools=[calculator], system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." ) From 9368c82c76f6ca858d355c68624e078c8b95cf4e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 14 Oct 2025 08:30:23 -0400 Subject: [PATCH 018/242] feat(tool_executor): Plug tool executor into bidirectional streaming implementation --- .../bidirectional_streaming/__init__.py | 41 +- .../bidirectional_streaming/agent/agent.py | 297 +++++++++- .../event_loop/bidirectional_event_loop.py | 179 +++--- .../models/__init__.py | 9 +- .../models/novasonic.py | 23 +- .../bidirectional_streaming/models/openai.py | 522 ++++++++++++++++++ .../tests/test_bidi_openai.py | 317 +++++++++++ .../tests/test_bidirectional_streaming.py | 27 + 8 files changed, 1317 insertions(+), 98 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 52822711a..844a8a1f8 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,2 +1,41 @@ -"""Bidirectional streaming package for real-time audio/text conversations.""" +"""Bidirectional streaming package.""" +# Main components - Primary user interface +from .agent.agent import BidirectionalAgent + +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + +# Model providers - What users need to create models +from .models.novasonic import NovaSonicBidirectionalModel +from .models.openai import OpenAIRealtimeBidirectionalModel + +# Event types - For type hints and event handling +from .types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, + UsageMetricsEvent, + VoiceActivityEvent, +) + +__all__ = [ + # Main interface + "BidirectionalAgent", + # Model providers + "NovaSonicBidirectionalModel", + "OpenAIRealtimeBidirectionalModel", + # Event types + "AudioInputEvent", + "AudioOutputEvent", + "TextOutputEvent", + "InterruptionDetectedEvent", + "BidirectionalStreamEvent", + "VoiceActivityEvent", + "UsageMetricsEvent", + # Model interface + "BidirectionalModel", + "BidirectionalModelSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 68d371a51..26b964c53 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -13,12 +13,22 @@ """ import asyncio +import json import logging -from typing import AsyncIterable +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, AsyncIterable, Callable, Mapping, Optional +from .... import _identifier +from ....hooks import HookProvider, HookRegistry +from ....telemetry.metrics import EventLoopMetrics from ....tools.executors import ConcurrentToolExecutor +from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry -from ....types.content import Messages +from ....tools.watcher import ToolWatcher +from ....types.content import Message, Messages +from ....types.tools import ToolResult, ToolUse +from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent @@ -26,6 +36,9 @@ logger = logging.getLogger(__name__) +_DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" + class BidirectionalAgent: """Agent for bidirectional streaming conversations. @@ -34,12 +47,125 @@ class BidirectionalAgent: sessions. Supports concurrent tool execution and interruption handling. """ + class ToolCaller: + """Call tool as a function for bidirectional agent.""" + + def __init__(self, agent: "BidirectionalAgent") -> None: + """Initialize tool caller with agent reference.""" + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. + self._agent = agent + + def __getattr__(self, name: str) -> Callable[..., Any]: + """Call tool as a function. + + This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). + + Args: + name: The name of the attribute (tool) being accessed. + + Returns: + A function that when called will execute the named tool. + + Raises: + AttributeError: If no tool with the given name exists or if multiple tools match the given name. + """ + + def caller( + user_message_override: Optional[str] = None, + record_direct_tool_call: Optional[bool] = None, + **kwargs: Any, + ) -> Any: + """Call a tool directly by name. + + Args: + user_message_override: Optional custom message to record instead of default + record_direct_tool_call: Whether to record direct tool calls in message history. + For bidirectional agents, this is always True to maintain conversation history. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result returned by the tool. + + Raises: + AttributeError: If the tool doesn't exist. + """ + normalized_name = self._find_normalized_tool_name(name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs + + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + _ = event + + return tool_results[0] + + def tcall() -> ToolResult: + return asyncio.run(acall()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(tcall) + tool_result = future.result() + + # Always record direct tool calls for bidirectional agents to maintain conversation history + # Use agent's record_direct_tool_call setting if not overridden + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + + return tool_result + + return caller + + def _find_normalized_tool_name(self, name: str) -> str: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def __init__( self, model: BidirectionalModel, tools: list | None = None, system_prompt: str | None = None, messages: Messages | None = None, + record_direct_tool_call: bool = True, + load_tools_from_directory: bool = False, + agent_id: Optional[str] = None, + name: Optional[str] = None, + tool_executor: Optional[ToolExecutor] = None, + hooks: Optional[list[HookProvider]] = None, + trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + description: Optional[str] = None, ): """Initialize bidirectional agent with required model and optional configuration. @@ -48,24 +174,177 @@ def __init__( tools: Optional list of tools available to the model. system_prompt: Optional system prompt for conversations. messages: Optional conversation history to initialize with. + record_direct_tool_call: Whether to record direct tool calls in message history. + load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. + agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + name: Name of the Agent. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + hooks: Hooks to be added to the agent hook registry. + trace_attributes: Custom trace attributes to apply to the agent's trace span. + description: Description of what the Agent does. """ self.model = model self.system_prompt = system_prompt self.messages = messages or [] - - # Initialize tool registry using existing Strands infrastructure + + # Agent identification + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self.name = name or _DEFAULT_AGENT_NAME + self.description = description + + # Tool execution configuration + self.record_direct_tool_call = record_direct_tool_call + self.load_tools_from_directory = load_tools_from_directory + + # Process trace attributes to ensure they're of compatible types + self.trace_attributes: dict[str, AttributeValue] = {} + if trace_attributes: + for k, v in trace_attributes.items(): + if isinstance(v, (str, int, float, bool)) or ( + isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v) + ): + self.trace_attributes[k] = v + + # Initialize tool registry self.tool_registry = ToolRegistry() - if tools: + + if tools is not None: self.tool_registry.process_tools(tools) - self.tool_registry.initialize_tools() - - # Initialize tool executor for concurrent execution - self.tool_executor = ConcurrentToolExecutor() + + self.tool_registry.initialize_tools(self.load_tools_from_directory) + + # Initialize tool watcher if directory loading is enabled + if self.load_tools_from_directory: + self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) + + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + + # Initialize hooks system + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + + # Initialize other components + self.event_loop_metrics = EventLoopMetrics() + self.tool_caller = BidirectionalAgent.ToolCaller(self) # Session management self._session = None self._output_queue = asyncio.Queue() + @property + def tool(self) -> ToolCaller: + """Call tool as a function. + + Returns: + Tool caller through which user can invoke tool as a function. + + Example: + ``` + agent = BidirectionalAgent(model=model, tools=[calculator]) + agent.tool.calculator(expression="2+2") + ``` + """ + return self.tool_caller + + @property + def tool_names(self) -> list[str]: + """Get a list of all registered tool names. + + Returns: + Names of all tools available to this agent. + """ + all_tools = self.tool_registry.get_all_tools_config() + return list(all_tools.keys()) + + def _record_tool_execution( + self, + tool: ToolUse, + tool_result: ToolResult, + user_message_override: Optional[str], + ) -> None: + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. + """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + + # Create user message describing the tool call + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") + + user_msg_content = [ + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} + ] + + # Add override message if provided + if user_message_override: + user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + + # Create the message sequence + user_msg: Message = { + "role": "user", + "content": user_msg_content, + } + tool_use_msg: Message = { + "role": "assistant", + "content": [{"toolUse": filtered_tool}], + } + tool_result_msg: Message = { + "role": "user", + "content": [{"toolResult": tool_result}], + } + assistant_msg: Message = { + "role": "assistant", + "content": [{"text": f"agent.tool.{tool['name']} was called."}], + } + + # Add to message history + self.messages.append(user_msg) + self.messages.append(tool_use_msg) + self.messages.append(tool_result_msg) + self.messages.append(assistant_msg) + + logger.debug("Direct tool call recorded in message history: %s", tool["name"]) + + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + async def start(self) -> None: """Start a persistent bidirectional conversation session. diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 16be08aaf..69f5d759d 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -12,12 +12,13 @@ """ import asyncio -import json import logging import traceback import uuid from ....tools._validator import validate_and_prepare_tools +from ....telemetry.metrics import Trace +from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession @@ -59,6 +60,9 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection # Interruption handling (model-agnostic) self.interrupted = False self.interruption_lock = asyncio.Lock() + + # Tool execution tracking + self.tool_count = 0 async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: @@ -195,11 +199,11 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: # Cancel all pending tool execution tasks cancelled_tools = 0 - for task_id, task in list(session.pending_tool_tasks.items()): + for _task_id, task in list(session.pending_tool_tasks.items()): if not task.done(): task.cancel() cancelled_tools += 1 - logger.debug("Tool task cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", _task_id) if cancelled_tools > 0: logger.debug("Tool tasks cancelled: %d", cancelled_tools) @@ -274,7 +278,8 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution if strands_event.get("toolUse"): - logger.debug("Tool queued: %s", strands_event["toolUse"].get("name")) + tool_name = strands_event["toolUse"].get("name") + logger.debug("Tool usage detected: %s", tool_name) await session.tool_queue.put(strands_event["toolUse"]) continue @@ -316,7 +321,13 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: while session.active: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId")) + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + + session.tool_count += 1 + print(f"\nTool #{session.tool_count}: {tool_name}") + + logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) @@ -330,11 +341,11 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: # Log completion status if completed_task.cancelled(): - logger.debug("Tool task cleanup cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", task_id) elif completed_task.exception(): - logger.error("Tool task cleanup error: %s - %s", task_id, str(completed_task.exception())) + logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) else: - logger.debug("Tool task cleanup success: %s", task_id) + logger.debug("Tool task completed: %s", task_id) except Exception as e: logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) @@ -365,94 +376,106 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: - """Execute tool using Strands infrastructure with interruption support. - - Executes tools using the existing Strands tool system with proper asyncio - cancellation handling. Tool execution is stopped via task cancellation, - not manual state checks. - + """Execute tool using the complete Strands tool execution system. + + Uses proper Strands ToolExecutor system with validation, error handling, + and event streaming. + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. """ tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) + try: - # Create message structure for existing tool system + # Create message structure for validation tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - + + # Use Strands validation system tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - - # Validate using existing Strands validation + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - - # Filter valid tool uses + + # Filter valid tools valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: - logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id) + logger.warning("No valid tools after validation: %s", tool_name) return - - # Execute tools directly (simpler approach for bidirectional) - for tool_use in valid_tool_uses: - tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) - - if tool_func: - try: - actual_func = _extract_callable_function(tool_func) - - # Execute tool function with provided input - result = actual_func(**tool_use.get("input", {})) - - tool_result = _create_success_result(tool_use["toolUseId"], result) - tool_results.append(tool_result) - - except Exception as e: - logger.error("Tool execution failed: %s - %s", tool_name, str(e)) - tool_result = _create_error_result(tool_use["toolUseId"], str(e)) - tool_results.append(tool_result) - else: - logger.warning("Tool not found: %s", tool_name) - - # Send results through provider-specific session - for result in tool_results: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) - - logger.debug("Tool execution completed: %s (%d results)", tool_name, len(tool_results)) - + + # Create invocation state for tool execution + invocation_state = { + "agent": session.agent, + "model": session.agent.model, + "messages": session.agent.messages, + "system_prompt": session.agent.system_prompt, + } + + # Create cycle trace and span + cycle_trace = Trace("Bidirectional Tool Execution") + cycle_span = None + + tool_events = session.agent.tool_executor._execute( + session.agent, + valid_tool_uses, + tool_results, + cycle_trace, + cycle_span, + invocation_state + ) + + # Process tool events and send results to provider + async for tool_event in tool_events: + if isinstance(tool_event, ToolResultEvent): + tool_result = tool_event.tool_result + tool_use_id = tool_result.get("toolUseId") + + # Send result through provider-specific session + await session.model_session.send_tool_result(tool_use_id, tool_result) + logger.debug("Tool result sent: %s", tool_use_id) + + # Handle streaming events if needed later + elif isinstance(tool_event, ToolStreamEvent): + logger.debug("Tool stream event: %s", tool_event) + pass + + # Add tool result message to conversation history + if tool_results: + from ....hooks import MessageAddedEvent + + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + session.agent.messages.append(tool_result_message) + session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) + logger.debug("Tool result message added to history: %s", tool_name) + + logger.debug("Tool execution completed: %s", tool_name) + except asyncio.CancelledError: - # Task was cancelled due to interruption - this is expected behavior - logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id) - raise # Re-raise to properly handle cancellation + logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) + raise except Exception as e: - logger.error("Tool execution error: %s - %s", tool_use.get("name"), str(e)) + logger.error("Tool execution error: %s - %s", tool_name, str(e)) + # Send error result + error_result: ToolResult = { + "toolUseId": tool_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } try: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) - except Exception as send_error: - logger.error("Tool error send failed: %s", str(send_error)) - - -def _extract_callable_function(tool_func: any) -> any: - """Extract the callable function from different tool object types.""" - if hasattr(tool_func, "_tool_func"): - return tool_func._tool_func - elif hasattr(tool_func, "func"): - return tool_func.func - elif callable(tool_func): - return tool_func - else: - raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") - - -def _create_success_result(tool_use_id: str, result: any) -> dict[str, any]: - """Create a successful tool result.""" - return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} + await session.model_session.send_tool_result(tool_id, error_result) + logger.debug("Error result sent: %s", tool_id) + except Exception: + logger.error("Failed to send error result: %s", tool_id) + pass # Session might be closed -def _create_error_result(tool_use_id: str, error: str) -> dict[str, any]: - """Create an error tool result.""" - return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 6cba974e0..882f89eef 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -3,4 +3,11 @@ from .bidirectional_model import BidirectionalModel, BidirectionalModelSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession -__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] +__all__ = [ + "BidirectionalModel", + "BidirectionalModelSession", + "NovaSonicBidirectionalModel", + "NovaSonicSession", + "OpenAIRealtimeBidirectionalModel", + "OpenAIRealtimeSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7f7937ef1..a1d61e11a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -121,7 +121,7 @@ async def initialize( init_events = self._build_initialization_events(system_prompt, tools or [], messages) - logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") + logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) logger.info("Nova Sonic connection initialized successfully") @@ -146,7 +146,7 @@ def _build_initialization_events( async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" - for i, event in enumerate(events): + for _i, event in enumerate(events): await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) @@ -167,12 +167,12 @@ async def _process_responses(self) -> None: await asyncio.sleep(0.1) continue except Exception as e: - logger.warning(f"Nova Sonic response error: {e}") + logger.warning("Nova Sonic response error: %s", e) await asyncio.sleep(0.1) continue except Exception as e: - logger.error(f"Nova Sonic fatal error: {e}") + logger.error("Nova Sonic fatal error: %s", e) finally: logger.debug("Nova Sonic response processor stopped") @@ -190,7 +190,7 @@ async def _handle_response_data(self, response_data: str) -> None: await self._event_queue.put(nova_event) except json.JSONDecodeError as e: - logger.warning(f"Nova Sonic JSON decode error: {e}") + logger.warning("Nova Sonic JSON decode error: %s", e) def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" @@ -383,11 +383,9 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No self._get_content_end_event(content_name), ] - for i, event in enumerate(events): + for _i, event in enumerate(events): await self._send_nova_event(event) - - async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: @@ -490,7 +488,14 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Handle usage events (ignore) elif "usageEvent" in nova_event: - return None + usage_data = nova_event["usageEvent"] + usage_metrics: UsageMetricsEvent = { + "totalTokens": usage_data.get("totalTokens"), + "inputTokens": usage_data.get("totalInputTokens"), + "outputTokens": usage_data.get("totalOutputTokens"), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens"), + } + return {"usageMetrics": usage_metrics} # Handle content start events (track role) elif "contentStart" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py new file mode 100644 index 000000000..7c79e3e6c --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -0,0 +1,522 @@ +/Users/mehtarac/Desktop/sdk-python/src/strands/experimental/bidirectional_streaming/models/openai.py + +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import AsyncIterable + +import websockets +from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosed + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + TextOutputEvent, + VoiceActivityEvent, +) +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# OpenAI Realtime API configuration +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" + +AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": AUDIO_FORMAT, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + }, + }, + "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + }, +} + + +class OpenAIRealtimeSession(BidirectionalModelSession): + """OpenAI Realtime API session for real-time audio/text streaming. + + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: + """Initialize OpenAI Realtime session.""" + self.websocket = websocket + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + + self._event_queue = asyncio.Queue() + self._response_task = None + self._function_call_buffer = {} + + logger.debug("OpenAI Realtime session initialized: %s", self.session_id) + + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + + async def _create_conversation_item(self, item_data: dict) -> None: + """Create conversation item and trigger response.""" + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize session with configuration.""" + try: + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + if messages: + await self._add_conversation_history(messages) + + self._response_task = asyncio.create_task(self._process_responses()) + logger.info("OpenAI Realtime session initialized successfully") + + except Exception as e: + logger.error("Error during OpenAI Realtime initialization: %s", e) + raise + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + """Build session configuration for OpenAI Realtime API.""" + config = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + custom_config = self.config.get("session", {}) + supported_params = { + "type", + "output_modalities", + "instructions", + "voice", + "audio", + "tools", + "tool_choice", + "input_audio_format", + "output_audio_format", + "input_audio_transcription", + "turn_detection", + } + + for key, value in custom_config.items(): + if key in supported_params: + config[key] = value + else: + logger.warning("Ignoring unsupported session parameter: %s", key) + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = ( + json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + ) + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema, + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session.""" + for message in messages: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": message["role"], "content": []}, + } + + content = message.get("content", "") + if isinstance(content, str): + conversation_item["item"]["content"].append({"type": "input_text", "text": content}) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + conversation_item["item"]["content"].append( + {"type": "input_text", "text": item.get("text", "")} + ) + + await self._send_event(conversation_item) + + async def _process_responses(self) -> None: + """Process incoming WebSocket messages.""" + logger.debug("OpenAI Realtime response processor started") + + try: + async for message in self.websocket: + if not self._active: + break + + try: + event = json.loads(message) + await self._event_queue.put(event) + except json.JSONDecodeError as e: + logger.warning("Failed to parse OpenAI event: %s", e) + continue + + except ConnectionClosed: + logger.debug("OpenAI Realtime WebSocket connection closed") + except Exception as e: + logger.error("Error in OpenAI Realtime response processing: %s", e) + finally: + self._active = False + logger.debug("OpenAI Realtime response processor stopped") + + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive OpenAI events and convert to Strands format.""" + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + while self._active: + try: + openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + provider_event = self._convert_openai_event(openai_event) + if provider_event: + yield provider_event + except asyncio.TimeoutError: + continue + + except Exception as e: + logger.error("Error receiving OpenAI Realtime event: %s", e) + finally: + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "openai_realtime"}, + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: + """Convert OpenAI events to Strands format.""" + event_type = openai_event.get("type") + + # Audio output + if event_type == "response.output_audio.delta": + audio_data = base64.b64decode(openai_event["delta"]) + audio_output: AudioOutputEvent = { + "audioData": audio_data, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": None, + } + return {"audioOutput": audio_output} + + # Text output using helper method + elif event_type == "response.output_text.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + elif event_type == "response.output_audio_transcript.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + # User transcription + elif event_type == "conversation.item.input_audio_transcription.delta": + transcript_delta = openai_event.get("delta", "") + return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.completed": + transcript = openai_event.get("transcript", "") + return self._create_text_event(transcript, "user") if transcript.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + return self._create_text_event(text, "user") if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + return {"toolUse": tool_use} + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Error parsing function arguments for %s: %s", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection using helper method + elif event_type == "input_audio_buffer.speech_started": + return self._create_voice_activity_event("speech_started") + elif event_type == "input_audio_buffer.speech_stopped": + return self._create_voice_activity_event("speech_stopped") + elif event_type == "input_audio_buffer.timeout_triggered": + return self._create_voice_activity_event("timeout") + + # Lifecycle events (log only) + elif event_type == "conversation.item.retrieve": + item = openai_event.get("item", {}) + logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) + return None + + elif event_type == "conversation.item.added": + logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) + + item = openai_event.get("item", {}) + if item.get("type") == "message" and item.get("role") == "assistant": + content_parts = item.get("content", []) + if content_parts: + message_content = [] + for content_part in content_parts: + if content_part.get("type") == "output_text": + message_content.append({"type": "text", "text": content_part.get("text", "")}) + elif content_part.get("type") == "output_audio": + transcript = content_part.get("transcript", "") + if transcript: + message_content.append({"type": "text", "text": transcript}) + + if message_content: + message = {"role": "assistant", "content": message_content} + return {"messageStop": {"message": message}} + return None + + elif event_type in [ + "response.output_item.added", + "response.output_item.done", + "response.content_part.added", + "response.content_part.done", + ]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = { + "call_id": call_id, + "name": function_name, + "arguments": "", + } + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + elif event_type in [ + "input_audio_buffer.committed", + "input_audio_buffer.cleared", + "session.created", + "session.updated", + ]: + logger.debug("OpenAI %s event", event_type) + return None + + elif event_type == "error": + logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) + return None + + else: + logger.debug("Unhandled OpenAI event type: %s", event_type) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to OpenAI for processing.""" + if not self._require_active(): + return + + audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content to OpenAI for processing.""" + if not self._require_active(): + return + + item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} + await self._create_conversation_item(item_data) + + async def send_interrupt(self) -> None: + """Send interruption signal to OpenAI.""" + if not self._require_active(): + return + + await self._send_event({"type": "response.cancel"}) + + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: + """Send tool result back to OpenAI.""" + if not self._require_active(): + return + + logger.debug("OpenAI tool result send: %s", tool_use_id) + result_text = json.dumps(result) if not isinstance(result, str) else result + + item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} + await self._create_conversation_item(item_data) + + async def close(self) -> None: + """Close session and cleanup resources.""" + if not self._active: + return + + logger.debug("OpenAI Realtime cleanup - starting connection close") + self._active = False + + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + await self.websocket.close() + except Exception as e: + logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + + logger.debug("OpenAI Realtime connection closed") + + async def _send_event(self, event: dict[str, any]) -> None: + """Send event to OpenAI via WebSocket.""" + try: + message = json.dumps(event) + await self.websocket.send(message) + logger.debug("Sent OpenAI event: %s", event.get("type")) + except Exception as e: + logger.error("Error sending OpenAI event: %s", e) + raise + + +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """OpenAI Realtime API provider for Strands bidirectional streaming. + + Provides real-time audio/text communication through OpenAI's Realtime API + with WebSocket connections, voice activity detection, and function calling. + """ + + def __init__(self, model: str = DEFAULT_MODEL, api_key: str | None = None, **config: any) -> None: + """Initialize OpenAI Realtime bidirectional model.""" + self.model = model + self.api_key = api_key + self.config = config + + import os + + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError( + "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." + ) + + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create bidirectional connection to OpenAI Realtime API.""" + logger.info("Creating OpenAI Realtime connection...") + + try: + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + session = OpenAIRealtimeSession(websocket, self.config) + await session.initialize(system_prompt, tools, messages) + + logger.info("OpenAI Realtime connection established") + return session + + except Exception as e: + logger.error("OpenAI connection error: %s", e) + raise diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py new file mode 100644 index 000000000..5ce4b8cb2 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +"""Test OpenAI Realtime API speech-to-speech interaction.""" + +import asyncio +import os +import sys +import time +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel + + +def test_direct_tool_calling(): + """Test direct tool calling functionality.""" + print("Testing direct tool calling...") + + try: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY not set - skipping test") + return + + model = OpenAIRealtimeBidirectionalModel(model="gpt-4o-realtime-preview", api_key=api_key) + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + +async def play(context): + """Handle audio playback with interruption support.""" + audio = pyaudio.PyAudio() + + try: + speaker = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # OpenAI Realtime uses 24kHz + output=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + # Check for interruption + if context.get("interrupted", False): + # Clear audio queue on interruption + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get audio data with timeout + try: + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + # Play in chunks to allow interruption + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if context.get("interrupted", False) or not context["active"]: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Brief pause for responsiveness + + except asyncio.TimeoutError: + continue + + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Audio playback error: {e}") + finally: + try: + speaker.close() + except Exception: + pass + audio.terminate() + + +async def record(context): + """Handle microphone recording.""" + audio = pyaudio.PyAudio() + + try: + microphone = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # Match OpenAI's expected input rate + input=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + await context["audio_in"].put(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Microphone recording error: {e}") + finally: + try: + microphone.close() + except Exception: + pass + audio.terminate() + + +async def receive(agent, context): + """Handle events from the agent.""" + try: + async for event in agent.receive(): + if not context["active"]: + break + + # Handle audio output + if "audioOutput" in event: + audio_data = event["audioOutput"]["audioData"] + + if not context.get("interrupted", False): + await context["audio_out"].put(audio_data) + + # Handle text output (transcripts) + elif "textOutput" in event: + text_output = event["textOutput"] + role = text_output.get("role", "assistant") + text = text_output.get("text", "").strip() + + if text: + if role == "user": + print(f"User: {text}") + elif role == "assistant": + print(f"Assistant: {text}") + + # Handle interruption detection + elif "interruptionDetected" in event: + context["interrupted"] = True + + # Handle connection events + elif "BidirectionalConnectionStart" in event: + pass # Silent connection start + elif "BidirectionalConnectionEnd" in event: + context["active"] = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receive handler error: {e}") + finally: + pass + + +async def send(agent, context): + """Send audio from microphone to agent.""" + try: + while context["active"]: + try: + audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) + + # Create audio event in expected format + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1 + } + + await agent.send(audio_event) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Send handler error: {e}") + finally: + pass + + +async def main(): + """Main test function for OpenAI voice chat.""" + print("Starting OpenAI Realtime API test...") + + # Check API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY environment variable not set") + return False + + # Check audio system + try: + audio = pyaudio.PyAudio() + audio.terminate() + except Exception as e: + print(f"Audio system error: {e}") + return False + + # Create OpenAI model + model = OpenAIRealtimeBidirectionalModel( + model="gpt-4o-realtime-preview", + api_key=api_key, + session={ + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700 + } + }, + "output": { + "format": {"type": "audio/pcm", "rate": 24000}, + "voice": "alloy" + } + } + } + ) + + # Create agent + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt=( + "You are a helpful voice assistant. Keep your responses brief and natural. " + "Say hello when you first connect." + ) + ) + + # Start the session + await agent.start() + + # Create shared context + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "interrupted": False, + "start_time": time.time() + } + + print("Speak into your microphone. Press Ctrl+C to stop.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + except Exception as e: + print(f"\nError during voice chat: {e}") + finally: + print("Cleaning up...") + context["active"] = False + + try: + await agent.end() + except Exception as e: + print(f"Cleanup error: {e}") + + return True + + +if __name__ == "__main__": + # Test direct tool calling first + print("OpenAI Realtime API Test Suite") + test_direct_tool_calling() + + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Test error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index b31607966..8c3ae3b4c 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -10,6 +10,7 @@ # Add the src directory to Python path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import os import time import pyaudio @@ -19,6 +20,29 @@ from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +def test_direct_tools(): + """Test direct tool calling.""" + print("Testing direct tool calling...") + + # Check AWS credentials + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): + print("AWS credentials not set - skipping test") + return + + try: + model = NovaSonicBidirectionalModel() + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + async def play(context): """Play audio output with responsive interruption support.""" audio = pyaudio.PyAudio() @@ -195,4 +219,7 @@ async def main(duration=180): if __name__ == "__main__": + # Test direct tool calling first + test_direct_tools() + asyncio.run(main()) From ee12db36c34e786fef880d9699d6696d41ffa14c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 14 Oct 2025 08:41:45 -0400 Subject: [PATCH 019/242] feat(tool_executor): Plug tool executor into bidirectional streaming implementation --- .../bidirectional_streaming/__init__.py | 4 - .../models/__init__.py | 2 - .../models/novasonic.py | 1 + .../bidirectional_streaming/models/openai.py | 522 ------------------ ...al_streaming.py => test_bidi_novasonic.py} | 0 .../tests/test_bidi_openai.py | 317 ----------- .../types/bidirectional_streaming.py | 35 +- 7 files changed, 29 insertions(+), 852 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py rename src/strands/experimental/bidirectional_streaming/tests/{test_bidirectional_streaming.py => test_bidi_novasonic.py} (100%) delete mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 844a8a1f8..0f842ee9f 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -8,7 +8,6 @@ # Model providers - What users need to create models from .models.novasonic import NovaSonicBidirectionalModel -from .models.openai import OpenAIRealtimeBidirectionalModel # Event types - For type hints and event handling from .types.bidirectional_streaming import ( @@ -18,7 +17,6 @@ InterruptionDetectedEvent, TextOutputEvent, UsageMetricsEvent, - VoiceActivityEvent, ) __all__ = [ @@ -26,14 +24,12 @@ "BidirectionalAgent", # Model providers "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", # Event types "AudioInputEvent", "AudioOutputEvent", "TextOutputEvent", "InterruptionDetectedEvent", "BidirectionalStreamEvent", - "VoiceActivityEvent", "UsageMetricsEvent", # Model interface "BidirectionalModel", diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 882f89eef..3a785e98a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -8,6 +8,4 @@ "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession", - "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index a1d61e11a..7f35a3c1c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -35,6 +35,7 @@ BidirectionalConnectionStartEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent ) from .bidirectional_model import BidirectionalModel, BidirectionalModelSession diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py deleted file mode 100644 index 7c79e3e6c..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ /dev/null @@ -1,522 +0,0 @@ -/Users/mehtarac/Desktop/sdk-python/src/strands/experimental/bidirectional_streaming/models/openai.py - -"""OpenAI Realtime API provider for Strands bidirectional streaming. - -Provides real-time audio and text communication through OpenAI's Realtime API -with WebSocket connections, voice activity detection, and function calling. -""" - -import asyncio -import base64 -import json -import logging -import uuid -from typing import AsyncIterable - -import websockets -from websockets.client import WebSocketClientProtocol -from websockets.exceptions import ConnectionClosed - -from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse -from ..types.bidirectional_streaming import ( - AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, - BidirectionalStreamEvent, - TextOutputEvent, - VoiceActivityEvent, -) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession - -logger = logging.getLogger(__name__) - -# OpenAI Realtime API configuration -OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" -DEFAULT_MODEL = "gpt-realtime" - -AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} - -DEFAULT_SESSION_CONFIG = { - "type": "realtime", - "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", - "output_modalities": ["audio"], - "audio": { - "input": { - "format": AUDIO_FORMAT, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500, - }, - }, - "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, - }, -} - - -class OpenAIRealtimeSession(BidirectionalModelSession): - """OpenAI Realtime API session for real-time audio/text streaming. - - Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, - function calling, and event conversion to Strands format. - """ - - def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: - """Initialize OpenAI Realtime session.""" - self.websocket = websocket - self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True - - self._event_queue = asyncio.Queue() - self._response_task = None - self._function_call_buffer = {} - - logger.debug("OpenAI Realtime session initialized: %s", self.session_id) - - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active - - def _create_text_event(self, text: str, role: str) -> dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} - - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} - - async def _create_conversation_item(self, item_data: dict) -> None: - """Create conversation item and trigger response.""" - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def initialize( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - ) -> None: - """Initialize session with configuration.""" - try: - session_config = self._build_session_config(system_prompt, tools) - await self._send_event({"type": "session.update", "session": session_config}) - - if messages: - await self._add_conversation_history(messages) - - self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime session initialized successfully") - - except Exception as e: - logger.error("Error during OpenAI Realtime initialization: %s", e) - raise - - def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: - """Build session configuration for OpenAI Realtime API.""" - config = DEFAULT_SESSION_CONFIG.copy() - - if system_prompt: - config["instructions"] = system_prompt - - if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) - - custom_config = self.config.get("session", {}) - supported_params = { - "type", - "output_modalities", - "instructions", - "voice", - "audio", - "tools", - "tool_choice", - "input_audio_format", - "output_audio_format", - "input_audio_transcription", - "turn_detection", - } - - for key, value in custom_config.items(): - if key in supported_params: - config[key] = value - else: - logger.warning("Ignoring unsupported session parameter: %s", key) - - return config - - def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI Realtime API format.""" - openai_tools = [] - - for tool in tools: - input_schema = tool["inputSchema"] - if "json" in input_schema: - schema = ( - json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] - ) - else: - schema = input_schema - - # OpenAI Realtime API expects flat structure, not nested under "function" - openai_tool = { - "type": "function", - "name": tool["name"], - "description": tool["description"], - "parameters": schema, - } - openai_tools.append(openai_tool) - - return openai_tools - - async def _add_conversation_history(self, messages: Messages) -> None: - """Add conversation history to the session.""" - for message in messages: - conversation_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": message["role"], "content": []}, - } - - content = message.get("content", "") - if isinstance(content, str): - conversation_item["item"]["content"].append({"type": "input_text", "text": content}) - elif isinstance(content, list): - for item in content: - if isinstance(item, dict) and item.get("type") == "text": - conversation_item["item"]["content"].append( - {"type": "input_text", "text": item.get("text", "")} - ) - - await self._send_event(conversation_item) - - async def _process_responses(self) -> None: - """Process incoming WebSocket messages.""" - logger.debug("OpenAI Realtime response processor started") - - try: - async for message in self.websocket: - if not self._active: - break - - try: - event = json.loads(message) - await self._event_queue.put(event) - except json.JSONDecodeError as e: - logger.warning("Failed to parse OpenAI event: %s", e) - continue - - except ConnectionClosed: - logger.debug("OpenAI Realtime WebSocket connection closed") - except Exception as e: - logger.error("Error in OpenAI Realtime response processing: %s", e) - finally: - self._active = False - logger.debug("OpenAI Realtime response processor stopped") - - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive OpenAI events and convert to Strands format.""" - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, - } - yield {"BidirectionalConnectionStart": connection_start} - - try: - while self._active: - try: - openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - provider_event = self._convert_openai_event(openai_event) - if provider_event: - yield provider_event - except asyncio.TimeoutError: - continue - - except Exception as e: - logger.error("Error receiving OpenAI Realtime event: %s", e) - finally: - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.session_id, - "reason": "connection_complete", - "metadata": {"provider": "openai_realtime"}, - } - yield {"BidirectionalConnectionEnd": connection_end} - - def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: - """Convert OpenAI events to Strands format.""" - event_type = openai_event.get("type") - - # Audio output - if event_type == "response.output_audio.delta": - audio_data = base64.b64decode(openai_event["delta"]) - audio_output: AudioOutputEvent = { - "audioData": audio_data, - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - "encoding": None, - } - return {"audioOutput": audio_output} - - # Text output using helper method - elif event_type == "response.output_text.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - elif event_type == "response.output_audio_transcript.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - # User transcription - elif event_type == "conversation.item.input_audio_transcription.delta": - transcript_delta = openai_event.get("delta", "") - return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.completed": - transcript = openai_event.get("transcript", "") - return self._create_text_event(transcript, "user") if transcript.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.segment": - segment_data = openai_event.get("segment", {}) - text = segment_data.get("text", "") - return self._create_text_event(text, "user") if text.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.failed": - error_info = openai_event.get("error", {}) - logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) - return None - - # Function call processing - elif event_type == "response.function_call_arguments.delta": - call_id = openai_event.get("call_id") - delta = openai_event.get("delta", "") - if call_id: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} - else: - self._function_call_buffer[call_id]["arguments"] += delta - return None - - elif event_type == "response.function_call_arguments.done": - call_id = openai_event.get("call_id") - if call_id and call_id in self._function_call_buffer: - function_call = self._function_call_buffer[call_id] - try: - tool_use: ToolUse = { - "toolUseId": call_id, - "name": function_call["name"], - "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, - } - del self._function_call_buffer[call_id] - return {"toolUse": tool_use} - except (json.JSONDecodeError, KeyError) as e: - logger.warning("Error parsing function arguments for %s: %s", call_id, e) - del self._function_call_buffer[call_id] - return None - - # Voice activity detection using helper method - elif event_type == "input_audio_buffer.speech_started": - return self._create_voice_activity_event("speech_started") - elif event_type == "input_audio_buffer.speech_stopped": - return self._create_voice_activity_event("speech_stopped") - elif event_type == "input_audio_buffer.timeout_triggered": - return self._create_voice_activity_event("timeout") - - # Lifecycle events (log only) - elif event_type == "conversation.item.retrieve": - item = openai_event.get("item", {}) - logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) - return None - - elif event_type == "conversation.item.added": - logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) - return None - - elif event_type == "conversation.item.done": - logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - - item = openai_event.get("item", {}) - if item.get("type") == "message" and item.get("role") == "assistant": - content_parts = item.get("content", []) - if content_parts: - message_content = [] - for content_part in content_parts: - if content_part.get("type") == "output_text": - message_content.append({"type": "text", "text": content_part.get("text", "")}) - elif content_part.get("type") == "output_audio": - transcript = content_part.get("transcript", "") - if transcript: - message_content.append({"type": "text", "text": transcript}) - - if message_content: - message = {"role": "assistant", "content": message_content} - return {"messageStop": {"message": message}} - return None - - elif event_type in [ - "response.output_item.added", - "response.output_item.done", - "response.content_part.added", - "response.content_part.done", - ]: - item_data = openai_event.get("item") or openai_event.get("part") - logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") - - # Track function call names from response.output_item.added - if event_type == "response.output_item.added": - item = openai_event.get("item", {}) - if item.get("type") == "function_call": - call_id = item.get("call_id") - function_name = item.get("name") - if call_id and function_name: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = { - "call_id": call_id, - "name": function_name, - "arguments": "", - } - else: - self._function_call_buffer[call_id]["name"] = function_name - return None - - elif event_type in [ - "input_audio_buffer.committed", - "input_audio_buffer.cleared", - "session.created", - "session.updated", - ]: - logger.debug("OpenAI %s event", event_type) - return None - - elif event_type == "error": - logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) - return None - - else: - logger.debug("Unhandled OpenAI event type: %s", event_type) - return None - - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to OpenAI for processing.""" - if not self._require_active(): - return - - audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to OpenAI for processing.""" - if not self._require_active(): - return - - item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} - await self._create_conversation_item(item_data) - - async def send_interrupt(self) -> None: - """Send interruption signal to OpenAI.""" - if not self._require_active(): - return - - await self._send_event({"type": "response.cancel"}) - - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result back to OpenAI.""" - if not self._require_active(): - return - - logger.debug("OpenAI tool result send: %s", tool_use_id) - result_text = json.dumps(result) if not isinstance(result, str) else result - - item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} - await self._create_conversation_item(item_data) - - async def close(self) -> None: - """Close session and cleanup resources.""" - if not self._active: - return - - logger.debug("OpenAI Realtime cleanup - starting connection close") - self._active = False - - if self._response_task and not self._response_task.done(): - self._response_task.cancel() - try: - await self._response_task - except asyncio.CancelledError: - pass - - try: - await self.websocket.close() - except Exception as e: - logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) - - logger.debug("OpenAI Realtime connection closed") - - async def _send_event(self, event: dict[str, any]) -> None: - """Send event to OpenAI via WebSocket.""" - try: - message = json.dumps(event) - await self.websocket.send(message) - logger.debug("Sent OpenAI event: %s", event.get("type")) - except Exception as e: - logger.error("Error sending OpenAI event: %s", e) - raise - - -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """OpenAI Realtime API provider for Strands bidirectional streaming. - - Provides real-time audio/text communication through OpenAI's Realtime API - with WebSocket connections, voice activity detection, and function calling. - """ - - def __init__(self, model: str = DEFAULT_MODEL, api_key: str | None = None, **config: any) -> None: - """Initialize OpenAI Realtime bidirectional model.""" - self.model = model - self.api_key = api_key - self.config = config - - import os - - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError( - "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." - ) - - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create bidirectional connection to OpenAI Realtime API.""" - logger.info("Creating OpenAI Realtime connection...") - - try: - url = f"{OPENAI_REALTIME_URL}?model={self.model}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) - - websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") - - session = OpenAIRealtimeSession(websocket, self.config) - await session.initialize(system_prompt, tools, messages) - - logger.info("OpenAI Realtime connection established") - return session - - except Exception as e: - logger.error("OpenAI connection error: %s", e) - raise diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py rename to src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py deleted file mode 100644 index 5ce4b8cb2..000000000 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -"""Test OpenAI Realtime API speech-to-speech interaction.""" - -import asyncio -import os -import sys -import time -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent / "src")) - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel - - -def test_direct_tool_calling(): - """Test direct tool calling functionality.""" - print("Testing direct tool calling...") - - try: - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY not set - skipping test") - return - - model = OpenAIRealtimeBidirectionalModel(model="gpt-4o-realtime-preview", api_key=api_key) - agent = BidirectionalAgent(model=model, tools=[calculator]) - - # Test calculator - result = agent.tool.calculator(expression="2 * 3") - content = result.get("content", [{}])[0].get("text", "") - print(f"Result: {content}") - print("Test completed") - - except Exception as e: - print(f"Test failed: {e}") - - -async def play(context): - """Handle audio playback with interruption support.""" - audio = pyaudio.PyAudio() - - try: - speaker = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # OpenAI Realtime uses 24kHz - output=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - # Check for interruption - if context.get("interrupted", False): - # Clear audio queue on interruption - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get audio data with timeout - try: - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - # Play in chunks to allow interruption - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - if context.get("interrupted", False) or not context["active"]: - break - - chunk = audio_data[i:i + chunk_size] - speaker.write(chunk) - await asyncio.sleep(0.001) # Brief pause for responsiveness - - except asyncio.TimeoutError: - continue - - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Audio playback error: {e}") - finally: - try: - speaker.close() - except Exception: - pass - audio.terminate() - - -async def record(context): - """Handle microphone recording.""" - audio = pyaudio.PyAudio() - - try: - microphone = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # Match OpenAI's expected input rate - input=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - await context["audio_in"].put(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Microphone recording error: {e}") - finally: - try: - microphone.close() - except Exception: - pass - audio.terminate() - - -async def receive(agent, context): - """Handle events from the agent.""" - try: - async for event in agent.receive(): - if not context["active"]: - break - - # Handle audio output - if "audioOutput" in event: - audio_data = event["audioOutput"]["audioData"] - - if not context.get("interrupted", False): - await context["audio_out"].put(audio_data) - - # Handle text output (transcripts) - elif "textOutput" in event: - text_output = event["textOutput"] - role = text_output.get("role", "assistant") - text = text_output.get("text", "").strip() - - if text: - if role == "user": - print(f"User: {text}") - elif role == "assistant": - print(f"Assistant: {text}") - - # Handle interruption detection - elif "interruptionDetected" in event: - context["interrupted"] = True - - # Handle connection events - elif "BidirectionalConnectionStart" in event: - pass # Silent connection start - elif "BidirectionalConnectionEnd" in event: - context["active"] = False - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Receive handler error: {e}") - finally: - pass - - -async def send(agent, context): - """Send audio from microphone to agent.""" - try: - while context["active"]: - try: - audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - - # Create audio event in expected format - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 24000, - "channels": 1 - } - - await agent.send(audio_event) - - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Send handler error: {e}") - finally: - pass - - -async def main(): - """Main test function for OpenAI voice chat.""" - print("Starting OpenAI Realtime API test...") - - # Check API key - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY environment variable not set") - return False - - # Check audio system - try: - audio = pyaudio.PyAudio() - audio.terminate() - except Exception as e: - print(f"Audio system error: {e}") - return False - - # Create OpenAI model - model = OpenAIRealtimeBidirectionalModel( - model="gpt-4o-realtime-preview", - api_key=api_key, - session={ - "output_modalities": ["audio"], - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": 24000}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "silence_duration_ms": 700 - } - }, - "output": { - "format": {"type": "audio/pcm", "rate": 24000}, - "voice": "alloy" - } - } - } - ) - - # Create agent - agent = BidirectionalAgent( - model=model, - tools=[calculator], - system_prompt=( - "You are a helpful voice assistant. Keep your responses brief and natural. " - "Say hello when you first connect." - ) - ) - - # Start the session - await agent.start() - - # Create shared context - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "interrupted": False, - "start_time": time.time() - } - - print("Speak into your microphone. Press Ctrl+C to stop.") - - try: - # Run all tasks concurrently - await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True - ) - - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - except Exception as e: - print(f"\nError during voice chat: {e}") - finally: - print("Cleaning up...") - context["active"] = False - - try: - await agent.end() - except Exception as e: - print(f"Cleanup error: {e}") - - return True - - -if __name__ == "__main__": - # Test direct tool calling first - print("OpenAI Realtime API Test Suite") - test_direct_tool_calling() - - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Test error: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 01d72356a..c0f6eb209 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -116,10 +116,28 @@ class BidirectionalConnectionEndEvent(TypedDict): metadata: Provider-specific connection metadata. """ - reason: Literal["user_request", "timeout", "error"] + reason: Literal["user_request", "timeout", "error", "connection_complete"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] +class UsageMetricsEvent(TypedDict): + """Token usage and performance tracking. + + Provides standardized usage metrics across providers for cost monitoring + and performance optimization. + + Attributes: + totalTokens: Total tokens used in the interaction. + inputTokens: Tokens used for input processing. + outputTokens: Tokens used for output generation. + audioTokens: Tokens used specifically for audio processing. + """ + + totalTokens: Optional[int] + inputTokens: Optional[int] + outputTokens: Optional[int] + audioTokens: Optional[int] + class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. @@ -134,11 +152,14 @@ class BidirectionalStreamEvent(StreamEvent, total=False): interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. + usageMetrics: Token usage and performance metrics. """ - audioOutput: AudioOutputEvent - audioInput: AudioInputEvent - textOutput: TextOutputEvent - interruptionDetected: InterruptionDetectedEvent - BidirectionalConnectionStart: BidirectionalConnectionStartEvent - BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + audioOutput: Optional[AudioOutputEvent] + audioInput: Optional[AudioInputEvent] + textOutput: Optional[TextOutputEvent] + interruptionDetected: Optional[InterruptionDetectedEvent] + BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] + BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] + usageMetrics: Optional[UsageMetricsEvent] + From 4679e0c803d0e7b6fad7d32ef0866309fd8b55e4 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 20 Oct 2025 10:03:11 -0400 Subject: [PATCH 020/242] (feat)bidirectional_streaming: add openai realtime model provider #3 --- .../models/novasonic.py | 14 ++-- .../bidirectional_streaming/models/openai.py | 64 +++++++++---------- 2 files changed, 37 insertions(+), 41 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 4e4952fa9..db21fb967 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -23,7 +23,7 @@ from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, InvokeModelWithBidirectionalStreamOperationOutput from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages @@ -80,11 +80,11 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream: any, config: dict[str, any]) -> None: + def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, config: dict[str, any]) -> None: """Initialize Nova Sonic connection. Args: - stream: Nova Sonic bidirectional stream. + stream: Nova Sonic bidirectional stream operation output from AWS SDK. config: Model configuration. """ self.stream = stream @@ -492,10 +492,10 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No elif "usageEvent" in nova_event: usage_data = nova_event["usageEvent"] usage_metrics: UsageMetricsEvent = { - "totalTokens": usage_data.get("totalTokens"), - "inputTokens": usage_data.get("totalInputTokens"), - "outputTokens": usage_data.get("totalOutputTokens"), - "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens") + "totalTokens": usage_data.get("totalTokens", 0), + "inputTokens": usage_data.get("totalInputTokens", 0), + "outputTokens": usage_data.get("totalOutputTokens", 0), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens", 0) } return {"usageMetrics": usage_metrics} diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 76bf9f50d..7d009b1c7 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -89,10 +89,7 @@ def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: voice_activity: VoiceActivityEvent = {"activityType": activity_type} return {"voiceActivity": voice_activity} - async def _create_conversation_item(self, item_data: dict) -> None: - """Create conversation item and trigger response.""" - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) + async def initialize( self, @@ -248,21 +245,16 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] } return {"audioOutput": audio_output} - # Text output using helper method - elif event_type == "response.output_text.delta": + # Assistant text output events - combine multiple similar events + elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: return self._create_text_event(openai_event["delta"], "assistant") - elif event_type == "response.output_audio_transcript.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - # User transcription - elif event_type == "conversation.item.input_audio_transcription.delta": - transcript_delta = openai_event.get("delta", "") - return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.completed": - transcript = openai_event.get("transcript", "") - return self._create_text_event(transcript, "user") if transcript.strip() else None + # User transcription events - combine multiple similar events + elif event_type in ["conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed"]: + text_key = "delta" if "delta" in event_type else "transcript" + text = openai_event.get(text_key, "") + return self._create_text_event(text, "user") if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) @@ -302,22 +294,22 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] del self._function_call_buffer[call_id] return None - # Voice activity detection using helper method - elif event_type == "input_audio_buffer.speech_started": - return self._create_voice_activity_event("speech_started") - elif event_type == "input_audio_buffer.speech_stopped": - return self._create_voice_activity_event("speech_stopped") - elif event_type == "input_audio_buffer.timeout_triggered": - return self._create_voice_activity_event("timeout") + # Voice activity detection events - combine similar events using mapping + elif event_type in ["input_audio_buffer.speech_started", "input_audio_buffer.speech_stopped", + "input_audio_buffer.timeout_triggered"]: + # Map event types to activity types + activity_map = { + "input_audio_buffer.speech_started": "speech_started", + "input_audio_buffer.speech_stopped": "speech_stopped", + "input_audio_buffer.timeout_triggered": "timeout" + } + return self._create_voice_activity_event(activity_map[event_type]) - # Lifecycle events (log only) - elif event_type == "conversation.item.retrieve": + # Lifecycle events (log only) - combine multiple similar events + elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: item = openai_event.get("item", {}) - logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) - return None - - elif event_type == "conversation.item.added": - logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) + action = "retrieved" if "retrieve" in event_type else "added" + logger.debug("OpenAI conversation item %s: %s", action, item.get("id")) return None elif event_type == "conversation.item.done": @@ -341,6 +333,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] return {"messageStop": {"message": message}} return None + # Response output events - combine similar events elif event_type in ["response.output_item.added", "response.output_item.done", "response.content_part.added", "response.content_part.done"]: item_data = openai_event.get("item") or openai_event.get("part") @@ -359,6 +352,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] self._function_call_buffer[call_id]["name"] = function_name return None + # Session/buffer events - combine simple log-only events elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", "session.created", "session.updated"]: logger.debug("OpenAI %s event", event_type) @@ -380,7 +374,7 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - async def send_text_content(self, text: str, **kwargs) -> None: + async def send_text_content(self, text: str) -> None: """Send text content to OpenAI for processing.""" if not self._require_active(): return @@ -390,7 +384,8 @@ async def send_text_content(self, text: str, **kwargs) -> None: "role": "user", "content": [{"type": "input_text", "text": text}] } - await self._create_conversation_item(item_data) + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) async def send_interrupt(self) -> None: """Send interruption signal to OpenAI.""" @@ -412,7 +407,8 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No "call_id": tool_use_id, "output": result_text } - await self._create_conversation_item(item_data) + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) async def close(self) -> None: """Close session and cleanup resources.""" From 883f6fc2790a383bdb97a29a2d203463e730e446 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 24 Oct 2025 10:55:55 -0400 Subject: [PATCH 021/242] feat: (Agent): Finalize Bidirectional Agent class --- src/strands/agent/agent.py | 113 +----------- .../bidirectional_streaming/agent/agent.py | 171 +++++------------- src/strands/tools/caller.py | 166 +++++++++++++++++ 3 files changed, 212 insertions(+), 238 deletions(-) create mode 100644 src/strands/tools/caller.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..4f0ab5d6a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -46,6 +46,7 @@ from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer, serialize +from ..tools.caller import ToolCaller from ..tools.executors import ConcurrentToolExecutor from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry @@ -94,116 +95,6 @@ class Agent: 6. Produces a final response """ - class ToolCaller: - """Call tool as a function.""" - - def __init__(self, agent: "Agent") -> None: - """Initialize instance. - - Args: - agent: Agent reference that will accept tool results. - """ - # WARNING: Do not add any other member variables or methods as this could result in a name conflict with - # agent tools and thus break their execution. - self._agent = agent - - def __getattr__(self, name: str) -> Callable[..., Any]: - """Call tool as a function. - - This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). - It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). - - Args: - name: The name of the attribute (tool) being accessed. - - Returns: - A function that when called will execute the named tool. - - Raises: - AttributeError: If no tool with the given name exists or if multiple tools match the given name. - """ - - def caller( - user_message_override: Optional[str] = None, - record_direct_tool_call: Optional[bool] = None, - **kwargs: Any, - ) -> Any: - """Call a tool directly by name. - - Args: - user_message_override: Optional custom message to record instead of default - record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class - attribute if provided. - **kwargs: Keyword arguments to pass to the tool. - - Returns: - The result returned by the tool. - - Raises: - AttributeError: If the tool doesn't exist. - """ - normalized_name = self._find_normalized_tool_name(name) - - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs - - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - _ = event - - return tool_results[0] - - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() - - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call - - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - self._agent._record_tool_execution(tool_use, tool_result, user_message_override) - - # Apply window management - self._agent.conversation_manager.apply_management(self._agent) - - return tool_result - - return caller - - def _find_normalized_tool_name(self, name: str) -> str: - """Lookup the tool represented by name, replacing characters with underscores as necessary.""" - tool_registry = self._agent.tool_registry.registry - - if tool_registry.get(name, None): - return name - - # If the desired name contains underscores, it might be a placeholder for characters that can't be - # represented as python identifiers but are valid as tool names, such as dashes. In that case, find - # all tools that can be represented with the normalized name - if "_" in name: - filtered_tools = [ - tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name - ] - - # The registry itself defends against similar names, so we can just take the first match - if filtered_tools: - return filtered_tools[0] - - raise AttributeError(f"Tool '{name}' not found") - def __init__( self, model: Union[Model, str, None] = None, @@ -333,7 +224,7 @@ def __init__( else: self.state = AgentState() - self.tool_caller = Agent.ToolCaller(self) + self.tool_caller = ToolCaller(self) self.hooks = HookRegistry() diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 26b964c53..b1ca5edf4 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -15,13 +15,12 @@ import asyncio import json import logging -import random -from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncIterable, Callable, Mapping, Optional +from typing import Any, AsyncIterable, Mapping, Optional, Union, TYPE_CHECKING from .... import _identifier from ....hooks import HookProvider, HookRegistry from ....telemetry.metrics import EventLoopMetrics +from ....tools.caller import ToolCaller from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry @@ -32,6 +31,10 @@ from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..models.novasonic import NovaSonicBidirectionalModel + +if TYPE_CHECKING: + from ..event_loop.bidirectional_event_loop import BidirectionalEventLoop logger = logging.getLogger(__name__) @@ -47,117 +50,12 @@ class BidirectionalAgent: sessions. Supports concurrent tool execution and interruption handling. """ - class ToolCaller: - """Call tool as a function for bidirectional agent.""" - - def __init__(self, agent: "BidirectionalAgent") -> None: - """Initialize tool caller with agent reference.""" - # WARNING: Do not add any other member variables or methods as this could result in a name conflict with - # agent tools and thus break their execution. - self._agent = agent - - def __getattr__(self, name: str) -> Callable[..., Any]: - """Call tool as a function. - - This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). - It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). - - Args: - name: The name of the attribute (tool) being accessed. - - Returns: - A function that when called will execute the named tool. - - Raises: - AttributeError: If no tool with the given name exists or if multiple tools match the given name. - """ - - def caller( - user_message_override: Optional[str] = None, - record_direct_tool_call: Optional[bool] = None, - **kwargs: Any, - ) -> Any: - """Call a tool directly by name. - - Args: - user_message_override: Optional custom message to record instead of default - record_direct_tool_call: Whether to record direct tool calls in message history. - For bidirectional agents, this is always True to maintain conversation history. - **kwargs: Keyword arguments to pass to the tool. - - Returns: - The result returned by the tool. - - Raises: - AttributeError: If the tool doesn't exist. - """ - normalized_name = self._find_normalized_tool_name(name) - - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs - - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - _ = event - - return tool_results[0] - - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() - - # Always record direct tool calls for bidirectional agents to maintain conversation history - # Use agent's record_direct_tool_call setting if not overridden - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call - - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - self._agent._record_tool_execution(tool_use, tool_result, user_message_override) - - return tool_result - - return caller - - def _find_normalized_tool_name(self, name: str) -> str: - """Lookup the tool represented by name, replacing characters with underscores as necessary.""" - tool_registry = self._agent.tool_registry.registry - - if tool_registry.get(name, None): - return name - - # If the desired name contains underscores, it might be a placeholder for characters that can't be - # represented as python identifiers but are valid as tool names, such as dashes. In that case, find - # all tools that can be represented with the normalized name - if "_" in name: - filtered_tools = [ - tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name - ] - - # The registry itself defends against similar names, so we can just take the first match - if filtered_tools: - return filtered_tools[0] - - raise AttributeError(f"Tool '{name}' not found") - def __init__( self, - model: BidirectionalModel, - tools: list | None = None, - system_prompt: str | None = None, - messages: Messages | None = None, + model: Union[BidirectionalModel, str, None] = None, + tools: Optional[list[Union[str, dict[str, str], Any]]] = None, + system_prompt: Optional[str] = None, + messages: Optional[Messages] = None, record_direct_tool_call: bool = True, load_tools_from_directory: bool = False, agent_id: Optional[str] = None, @@ -166,12 +64,13 @@ def __init__( hooks: Optional[list[HookProvider]] = None, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, description: Optional[str] = None, + **kwargs: Any, ): - """Initialize bidirectional agent with required model and optional configuration. + """Initialize bidirectional agent with flexible model support and extensible configuration. Args: - model: BidirectionalModel instance supporting streaming sessions. - tools: Optional list of tools available to the model. + model: BidirectionalModel instance, string model_id, or None for default detection. + tools: Optional list of tools with flexible format support. system_prompt: Optional system prompt for conversations. messages: Optional conversation history to initialize with. record_direct_tool_call: Whether to record direct tool calls in message history. @@ -182,16 +81,28 @@ def __init__( hooks: Hooks to be added to the agent hook registry. trace_attributes: Custom trace attributes to apply to the agent's trace span. description: Description of what the Agent does. + **kwargs: Additional configuration for future extensibility. + + Raises: + ValueError: If model configuration is invalid. + TypeError: If model type is unsupported. """ - self.model = model + + self.model = ( + NovaSonicBidirectionalModel() + if not model + else NovaSonicBidirectionalModel(model_id=model) + if isinstance(model, str) + else model + ) self.system_prompt = system_prompt self.messages = messages or [] - + # Agent identification self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description - + # Tool execution configuration self.record_direct_tool_call = record_direct_tool_call self.load_tools_from_directory = load_tools_from_directory @@ -207,39 +118,42 @@ def __init__( # Initialize tool registry self.tool_registry = ToolRegistry() - + if tools is not None: self.tool_registry.process_tools(tools) - + self.tool_registry.initialize_tools(self.load_tools_from_directory) - + # Initialize tool watcher if directory loading is enabled if self.load_tools_from_directory: self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) # Initialize tool executor self.tool_executor = tool_executor or ConcurrentToolExecutor() - + # Initialize hooks system self.hooks = HookRegistry() if hooks: for hook in hooks: self.hooks.add_hook(hook) - + # Initialize other components self.event_loop_metrics = EventLoopMetrics() - self.tool_caller = BidirectionalAgent.ToolCaller(self) + self.tool_caller = ToolCaller(self) # Session management self._session = None self._output_queue = asyncio.Queue() + # Store extensibility kwargs for future use + self._config_kwargs = kwargs + @property def tool(self) -> ToolCaller: """Call tool as a function. Returns: - Tool caller through which user can invoke tool as a function. + ToolCaller for method-style tool execution. Example: ``` @@ -359,10 +273,11 @@ async def start(self) -> None: raise ValueError("Conversation already active. Call end() first.") logger.debug("Conversation start - initializing session") + self._session = await start_bidirectional_connection(self) logger.debug("Conversation ready") - async def send(self, input_data: str | AudioInputEvent) -> None: + async def send(self, input_data: Union[str, AudioInputEvent]) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during @@ -379,7 +294,9 @@ async def send(self, input_data: str | AudioInputEvent) -> None: if isinstance(input_data, str): # Add user text message to history - self.messages.append({"role": "user", "content": input_data}) + user_message: Message = {"role": "user", "content": [{"text": input_data}]} + + self.messages.append(user_message) logger.debug("Text sent: %d characters", len(input_data)) await self._session.model_session.send_text_content(input_data) diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py new file mode 100644 index 000000000..06e1f23b3 --- /dev/null +++ b/src/strands/tools/caller.py @@ -0,0 +1,166 @@ +"""Shared ToolCaller base class to eliminate duplication between agent implementations. + +Provides common tool calling functionality that can be used by both traditional +Agent and BidirectionalAgent classes with agent-specific customizations. +""" + +import asyncio +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Optional + +from ..tools.executors._executor import ToolExecutor +from ..types.tools import ToolResult, ToolUse + + +class ToolCaller: + """Universal tool caller that works with both traditional and bidirectional agents. + + Automatically detects agent type and applies appropriate behavior: + - Traditional agents: Uses conversation_manager.apply_management() + - Bidirectional agents: Skips conversation management (not needed for streaming) + """ + + def __init__(self, agent: Any) -> None: + """Initialize base tool caller. + + Args: + agent: Agent instance that will process tool results. + """ + # WARNING: Do not add other member variables to avoid conflicts with tool names + self._agent = agent + + def __getattr__(self, name: str) -> Callable[..., Any]: + """Enable method-style tool calling interface. + + This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). + + Args: + name: The name of the attribute (tool) being accessed. + + Returns: + A function that when called will execute the named tool. + + Raises: + AttributeError: If no tool with the given name exists or if multiple tools match the given name. + """ + + def caller( + user_message_override: Optional[str] = None, + record_direct_tool_call: Optional[bool] = None, + **kwargs: Any, + ) -> Any: + """Call a tool directly by name. + + Args: + user_message_override: Optional custom message to record instead of default. + record_direct_tool_call: Whether to record direct tool calls in message history. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result returned by the tool. + + Raises: + AttributeError: If the tool doesn't exist. + """ + normalized_name = self._find_normalized_tool_name(name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + + # Execute tool using shared execution pipeline + tool_result = self._execute_tool_sync(tool_use, kwargs) + + # Handle tool call recording with agent-specific behavior + self._handle_tool_call_recording(tool_use, tool_result, user_message_override, record_direct_tool_call) + + return tool_result + + return caller + + def _find_normalized_tool_name(self, name: str) -> str: + """Lookup the tool represented by name, replacing characters with underscores as necessary. + + Args: + name: Tool name to normalize. + + Returns: + Normalized tool name that exists in registry. + + Raises: + AttributeError: If tool not found. + """ + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # Handle underscore placeholder for characters that can't be python identifiers + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # Registry defends against similar names, so take first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + + def _execute_tool_sync(self, tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolResult: + """Execute tool synchronously using shared Strands pipeline. + + Args: + tool_use: Tool execution request. + invocation_state: Execution context. + + Returns: + Tool execution result. + """ + tool_results: list[ToolResult] = [] + + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + _ = event + return tool_results[0] + + def tcall() -> ToolResult: + return asyncio.run(acall()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(tcall) + return future.result() + + def _handle_tool_call_recording( + self, + tool_use: ToolUse, + tool_result: ToolResult, + user_message_override: Optional[str], + record_direct_tool_call: Optional[bool], + ) -> None: + """Handle tool call recording with agent-specific behavior. + + Args: + tool_use: Tool execution information. + tool_result: Tool result. + user_message_override: Optional message override. + record_direct_tool_call: Optional recording override. + """ + # Determine if we should record the tool call + should_record = ( + record_direct_tool_call if record_direct_tool_call is not None else self._agent.record_direct_tool_call + ) + + if should_record: + # Use agent's recording method + self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + + # Apply conversation management if agent supports it (traditional agents) + if hasattr(self._agent, "conversation_manager"): + self._agent.conversation_manager.apply_management(self._agent) From 23d8da85ac964a167d840be72b90bcda65f99798 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 28 Oct 2025 09:48:08 -0400 Subject: [PATCH 022/242] feat: (Agent): Finalize Bidirectional Agent class --- .../bidirectional_streaming/agent/agent.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index b1ca5edf4..b39ed10e5 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -8,14 +8,14 @@ Key capabilities: - Persistent conversation sessions with concurrent processing - Real-time audio input/output streaming -- Mid-conversation interruption and tool execution +- Automatic interruption detection and tool execution - Event-driven communication with model providers """ import asyncio import json import logging -from typing import Any, AsyncIterable, Mapping, Optional, Union, TYPE_CHECKING +from typing import Any, AsyncIterable, Mapping, Optional, Union from .... import _identifier from ....hooks import HookProvider, HookRegistry @@ -28,14 +28,10 @@ from ....types.content import Message, Messages from ....types.tools import ToolResult, ToolUse from ....types.traces import AttributeValue -from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection +from ..event_loop.bidirectional_event_loop import BidirectionalAgentLoop from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent from ..models.novasonic import NovaSonicBidirectionalModel - -if TYPE_CHECKING: - from ..event_loop.bidirectional_event_loop import BidirectionalEventLoop - +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent logger = logging.getLogger(__name__) @@ -87,7 +83,6 @@ def __init__( ValueError: If model configuration is invalid. TypeError: If model type is unsupported. """ - self.model = ( NovaSonicBidirectionalModel() if not model @@ -142,7 +137,7 @@ def __init__( self.tool_caller = ToolCaller(self) # Session management - self._session = None + self._session: Optional["BidirectionalAgentLoop"] = None self._output_queue = asyncio.Queue() # Store extensibility kwargs for future use @@ -274,7 +269,14 @@ async def start(self) -> None: logger.debug("Conversation start - initializing session") - self._session = await start_bidirectional_connection(self) + # Create model session and event loop directly + model_session = await self.model.create_bidirectional_connection( + system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages + ) + + self._session = BidirectionalAgentLoop(model_session=model_session, agent=self) + await self._session.start() + logger.debug("Conversation ready") async def send(self, input_data: Union[str, AudioInputEvent]) -> None: @@ -325,18 +327,6 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: except asyncio.TimeoutError: continue - async def interrupt(self) -> None: - """Interrupt the current model generation and clear audio buffers. - - Sends interruption signal to stop generation immediately and clears - pending audio output for responsive conversation flow. - - Raises: - ValueError: If no active session. - """ - self._validate_active_session() - await self._session.model_session.send_interrupt() - async def end(self) -> None: """End the conversation session and cleanup all resources. @@ -344,7 +334,7 @@ async def end(self) -> None: closes the connection to the model provider. """ if self._session: - await stop_bidirectional_connection(self._session) + await self._session.stop() self._session = None def _validate_active_session(self) -> None: From 4648327c2b4dce49b3f25561012bdf76ab7cb72d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 29 Oct 2025 15:52:11 +0100 Subject: [PATCH 023/242] feat(gemini): Add bidirectional gemini model --- .../bidirectional_streaming/agent/agent.py | 30 +- .../models/__init__.py | 11 +- .../models/bidirectional_model.py | 11 +- .../models/gemini_live.py | 499 ++++++++++++++++++ .../tests/test_gemini_live.py | 359 +++++++++++++ .../bidirectional_streaming/types/__init__.py | 4 + .../types/bidirectional_streaming.py | 38 ++ 7 files changed, 933 insertions(+), 19 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/gemini_live.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 6f8360ade..820a6c490 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,7 +31,7 @@ from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent logger = logging.getLogger(__name__) @@ -359,18 +359,16 @@ async def start(self) -> None: logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - logger.debug("Conversation ready") - - async def send(self, input_data: str | AudioInputEvent) -> None: - """Send input to the model (text or audio). - - Unified method for sending both text and audio input to the model during - an active conversation session. User input is automatically added to - conversation history for complete message tracking. - + + async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> None: + """Send input to the model (text, audio, or image). + + Unified method for sending text, audio, and image input to the model during + an active conversation session. + Args: - input_data: Either a string for text input or AudioInputEvent for audio input. - + input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images. + Raises: ValueError: If no active session or invalid input type. """ @@ -385,10 +383,14 @@ async def send(self, input_data: str | AudioInputEvent) -> None: elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input await self._session.model_session.send_audio_content(input_data) + elif isinstance(input_data, dict) and "imageData" in input_data: + # Handle image input (ImageInputEvent) + await self._session.model_session.send_image_content(input_data) else: raise ValueError( - "Input must be either a string (text) or AudioInputEvent " - "(dict with audioData, format, sampleRate, channels)" + "Input must be either a string (text), AudioInputEvent " + "(dict with audioData, format, sampleRate, channels), or ImageInputEvent " + "(dict with imageData, mimeType, encoding)" ) async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 67254d4fe..c5287d15d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,14 +1,17 @@ """Bidirectional model interfaces and implementations.""" from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .gemini_live import GeminiLiveBidirectionalModel, GeminiLiveSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession __all__ = [ - "BidirectionalModel", - "BidirectionalModelSession", - "NovaSonicBidirectionalModel", + "BidirectionalModel", + "BidirectionalModelSession", + "GeminiLiveBidirectionalModel", + "GeminiLiveSession", + "NovaSonicBidirectionalModel", "NovaSonicSession", "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession" + "OpenAIRealtimeSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index d5c3c9b65..42485561b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -17,7 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent logger = logging.getLogger(__name__) @@ -48,6 +48,15 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """ raise NotImplementedError + # TODO: remove with interface unification + async def send_image_content(self, image_input: ImageInputEvent) -> None: + """Send image content to the model during an active connection. + + Handles image encoding and provider-specific formatting while presenting + a simple ImageInputEvent interface. + """ + raise NotImplementedError + @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: """Send text content to the model during ongoing generation. diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py new file mode 100644 index 000000000..64c4d7348 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -0,0 +1,499 @@ +"""Gemini Live API bidirectional model provider using official Google GenAI SDK. + +Implements the BidirectionalModel interface for Google's Gemini Live API using the +official Google GenAI SDK for simplified and robust WebSocket communication. + +Key improvements over custom WebSocket implementation: +- Uses official google-genai SDK with native Live API support +- Simplified session management with client.aio.live.connect() +- Built-in tool integration and event handling +- Automatic WebSocket connection management and error handling +- Native support for audio/text streaming and interruption +""" + +import asyncio +import base64 +import logging +import uuid +from typing import Any, AsyncIterable, Dict, List, Optional + +from google import genai +from google.genai import types as genai_types +from google.genai.types import LiveServerMessage, LiveServerContent + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + ImageInputEvent, + InterruptionDetectedEvent, + TextOutputEvent, + TranscriptEvent, +) +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# Audio format constants +GEMINI_INPUT_SAMPLE_RATE = 16000 +GEMINI_OUTPUT_SAMPLE_RATE = 24000 +GEMINI_CHANNELS = 1 + + +class GeminiLiveSession(BidirectionalModelSession): + """Gemini Live API session using official Google GenAI SDK. + + Provides a clean interface to Gemini Live API using the official SDK, + eliminating custom WebSocket handling and providing robust error handling. + """ + + def __init__(self, client: genai.Client, model_id: str, config: Dict[str, Any]): + """Initialize Gemini Live API session. + + Args: + client: Gemini client instance + model_id: Model identifier + config: Model configuration including live config + """ + self.client = client + self.model_id = model_id + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + self.live_session = None + self.live_session_cm = None + + + + async def initialize( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None + ) -> None: + """Initialize Gemini Live API session by creating the connection.""" + + try: + # Build live config + live_config = self.config.get("live_config") + + if live_config is None: + raise ValueError("live_config is required but not found in session config") + + # Create the context manager + self.live_session_cm = self.client.aio.live.connect( + model=self.model_id, + config=live_config + ) + + # Enter the context manager + self.live_session = await self.live_session_cm.__aenter__() + + # Send initial message history if provided + if messages: + await self._send_message_history(messages) + + + except Exception as e: + logger.error("Error initializing Gemini Live session: %s", e) + raise + + async def _send_message_history(self, messages: Messages) -> None: + """Send conversation history to Gemini Live API. + + Sends each message as a separate turn with the correct role to maintain + proper conversation context. Follows the same pattern as the non-bidirectional + Gemini model implementation. + """ + if not messages: + return + + # Convert each message to Gemini format and send separately + for message in messages: + content_parts = [] + for content_block in message["content"]: + if "text" in content_block: + content_parts.append(genai_types.Part(text=content_block["text"])) + + if content_parts: + # Map role correctly - Gemini uses "user" and "model" roles + # "assistant" role from Messages format maps to "model" in Gemini + role = "model" if message["role"] == "assistant" else message["role"] + content = genai_types.Content(role=role, parts=content_parts) + await self.live_session.send_client_content(turns=content) + + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive Gemini Live API events and convert to provider-agnostic format.""" + + # Emit connection start event + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "gemini_live", "model_id": self.config.get("model_id")} + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + # Wrap in while loop to restart after turn_complete (SDK limitation workaround) + while self._active: + try: + async for message in self.live_session.receive(): + if not self._active: + break + + # Convert to provider-agnostic format + provider_event = self._convert_gemini_live_event(message) + if provider_event: + yield provider_event + + # SDK exits receive loop after turn_complete - restart automatically + if self._active: + logger.debug("Restarting receive loop after turn completion") + + except Exception as e: + logger.error("Error in receive iteration: %s", e) + # Small delay before retrying to avoid tight error loops + await asyncio.sleep(0.1) + + except Exception as e: + logger.error("Fatal error in receive loop: %s", e) + finally: + # Emit connection end event when exiting + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "gemini_live"} + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: + """Convert Gemini Live API events to provider-agnostic format. + + Handles different types of text output: + - inputTranscription: User's speech transcribed to text (emitted as transcript event) + - outputTranscription: Model's audio transcribed to text (emitted as transcript event) + - modelTurn text: Actual text response from the model (emitted as textOutput) + """ + try: + # Handle interruption first (from server_content) + if message.server_content and message.server_content.interrupted: + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + return {"interruptionDetected": interruption} + + # Handle input transcription (user's speech) - emit as transcript event + if message.server_content and message.server_content.input_transcription: + input_transcript = message.server_content.input_transcription + # Check if the transcription object has text content + if hasattr(input_transcript, 'text') and input_transcript.text: + transcription_text = input_transcript.text + logger.debug(f"Input transcription detected: {transcription_text}") + transcript: TranscriptEvent = { + "text": transcription_text, + "role": "user", + "type": "input" + } + return {"transcript": transcript} + + # Handle output transcription (model's audio) - emit as transcript event + if message.server_content and message.server_content.output_transcription: + output_transcript = message.server_content.output_transcription + # Check if the transcription object has text content + if hasattr(output_transcript, 'text') and output_transcript.text: + transcription_text = output_transcript.text + logger.debug(f"Output transcription detected: {transcription_text}") + transcript: TranscriptEvent = { + "text": transcription_text, + "role": "assistant", + "type": "output" + } + return {"transcript": transcript} + + # Handle actual text output from model (not transcription) + # The SDK's message.text property accesses modelTurn.parts[].text + if message.text: + text_output: TextOutputEvent = { + "text": message.text, + "role": "assistant" + } + return {"textOutput": text_output} + + # Handle audio output using SDK's built-in data property + if message.data: + audio_output: AudioOutputEvent = { + "audioData": message.data, + "format": "pcm", + "sampleRate": GEMINI_OUTPUT_SAMPLE_RATE, + "channels": GEMINI_CHANNELS, + "encoding": "raw" + } + return {"audioOutput": audio_output} + + # Handle tool calls + if message.tool_call and message.tool_call.function_calls: + for func_call in message.tool_call.function_calls: + tool_use_event: ToolUse = { + "toolUseId": func_call.id, + "name": func_call.name, + "input": func_call.args or {} + } + return {"toolUse": tool_use_event} + + # Silently ignore setup_complete, turn_complete, generation_complete, and usage_metadata messages + return None + + except Exception as e: + logger.error("Error converting Gemini Live event: %s", e) + logger.error("Message type: %s", type(message).__name__) + logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content using Gemini Live API. + + Gemini Live expects continuous audio streaming via send_realtime_input. + This automatically triggers VAD and can interrupt ongoing responses. + """ + if not self._active: + return + + try: + # Create audio blob for the SDK + audio_blob = genai_types.Blob( + data=audio_input["audioData"], + mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" + ) + + # Send real-time audio input - this automatically handles VAD and interruption + await self.live_session.send_realtime_input(audio=audio_blob) + + except Exception as e: + logger.error("Error sending audio content: %s", e) + + async def send_image_content(self, image_input: ImageInputEvent) -> None: + """Send image content using Gemini Live API. + + Sends image frames following the same pattern as the GitHub example. + Images are sent as base64-encoded data with MIME type. + """ + if not self._active: + return + + try: + # Prepare the message based on encoding + if image_input["encoding"] == "base64": + # Data is already base64 encoded + if isinstance(image_input["imageData"], bytes): + data_str = image_input["imageData"].decode() + else: + data_str = image_input["imageData"] + else: + # Raw bytes - need to base64 encode + data_str = base64.b64encode(image_input["imageData"]).decode() + + # Create the message in the format expected by Gemini Live + msg = { + "mime_type": image_input["mimeType"], + "data": data_str + } + + # Send using the same method as the GitHub example + await self.live_session.send(input=msg) + + except Exception as e: + logger.error("Error sending image content: %s", e) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content using Gemini Live API.""" + if not self._active: + return + + try: + # Create content with text + content = genai_types.Content( + role="user", + parts=[genai_types.Part(text=text)] + ) + + # Send as client content + await self.live_session.send_client_content(turns=content) + + except Exception as e: + logger.error("Error sending text content: %s", e) + + async def send_interrupt(self) -> None: + """Send interruption signal to Gemini Live API. + + Gemini Live uses automatic VAD-based interruption. When new audio input + is detected, it automatically interrupts the ongoing generation. + We don't need to send explicit interrupt signals like Nova Sonic. + """ + if not self._active: + return + + try: + # Gemini Live handles interruption automatically through VAD + # When new audio input is sent via send_realtime_input, it automatically + # interrupts any ongoing generation. No explicit interrupt signal needed. + logger.debug("Interrupt requested - Gemini Live handles this automatically via VAD") + + except Exception as e: + logger.error("Error in interrupt handling: %s", e) + + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool result using Gemini Live API.""" + if not self._active: + return + + try: + # Create function response + func_response = genai_types.FunctionResponse( + id=tool_use_id, + name=tool_use_id, # Gemini uses name as identifier + response=result + ) + + # Send tool response + await self.live_session.send_tool_response(function_responses=[func_response]) + except Exception as e: + logger.error("Error sending tool result: %s", e) + + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool error using Gemini Live API.""" + error_result = {"error": error} + await self.send_tool_result(tool_use_id, error_result) + + async def close(self) -> None: + """Close Gemini Live API connection.""" + if not self._active: + return + + self._active = False + + try: + # Exit the context manager properly + if self.live_session_cm: + await self.live_session_cm.__aexit__(None, None, None) + except Exception as e: + logger.error("Error closing Gemini Live session: %s", e) + raise + + +class GeminiLiveBidirectionalModel(BidirectionalModel): + """Gemini Live API model implementation using official Google GenAI SDK. + + Provides access to Google's Gemini Live API through the bidirectional + streaming interface, using the official SDK for robust and simple integration. + """ + + def __init__( + self, + model_id: str = "models/gemini-2.0-flash-live-preview-04-09", + api_key: Optional[str] = None, + **config + ): + """Initialize Gemini Live API bidirectional model. + + Args: + model_id: Gemini Live model identifier. + api_key: Google AI API key for authentication. + **config: Additional configuration. + """ + self.model_id = model_id + self.api_key = api_key + self.config = config + + # Create Gemini client with proper API version + client_kwargs = {} + if api_key: + client_kwargs["api_key"] = api_key + + # Use v1alpha for Live API as it has better model support + client_kwargs["http_options"] = {"api_version": "v1alpha"} + + self.client = genai.Client(**client_kwargs) + + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create Gemini Live API bidirectional connection using official SDK.""" + + try: + # Build configuration + live_config = self._build_live_config(system_prompt, tools, **kwargs) + + # Create session config + session_config = self._get_session_config() + session_config["live_config"] = live_config + + # Create and initialize session wrapper + session = GeminiLiveSession(self.client, self.model_id, session_config) + await session.initialize(system_prompt, tools, messages) + + return session + + except Exception as e: + logger.error("Failed to create Gemini Live connection: %s", e) + raise + + def _build_live_config( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + **kwargs + ) -> Dict[str, Any]: + """Build LiveConnectConfig for the official SDK. + + Simply passes through all config parameters from params, allowing users + to configure any Gemini Live API parameter directly. + """ + # Start with user config from params + config_dict = {} + if "params" in self.config: + config_dict.update(self.config["params"]) + + # Override with any kwargs + config_dict.update(kwargs) + + # Add system instruction if provided + if system_prompt: + config_dict["system_instruction"] = system_prompt + + # Add tools if provided + if tools: + config_dict["tools"] = self._format_tools_for_live_api(tools) + + return config_dict + + def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: + """Format tool specs for Gemini Live API.""" + if not tool_specs: + return [] + + return [ + genai_types.Tool( + function_declarations=[ + genai_types.FunctionDeclaration( + description=tool_spec["description"], + name=tool_spec["name"], + parameters_json_schema=tool_spec["inputSchema"]["json"], + ) + for tool_spec in tool_specs + ], + ), + ] + + def _get_session_config(self) -> Dict[str, Any]: + """Get session configuration for Gemini Live API.""" + return { + "model_id": self.model_id, + "params": self.config.get("params"), + **self.config + } \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py new file mode 100644 index 000000000..4469e819a --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -0,0 +1,359 @@ +"""Test suite for Gemini Live bidirectional streaming with camera support. + +Tests the Gemini Live API with real-time audio and video interaction including: +- Audio input/output streaming +- Camera frame capture and transmission +- Interruption handling +- Concurrent tool execution +- Transcript events + +Requirements: +- pip install opencv-python pillow pyaudio google-genai +- Camera access permissions +- GOOGLE_AI_API_KEY environment variable +""" + +import asyncio +import base64 +import io +import logging +import os +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import time + +try: + import cv2 + import PIL.Image + CAMERA_AVAILABLE = True +except ImportError as e: + print(f"Camera dependencies not available: {e}") + print("Install with: pip install opencv-python pillow") + CAMERA_AVAILABLE = False + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel + +# Configure logging - debug only for Gemini Live, info for everything else +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') +gemini_logger.setLevel(logging.DEBUG) +logger = logging.getLogger(__name__) + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + + # List all available audio devices + print("Available audio devices:") + for i in range(audio.get_device_count()): + device_info = audio.get_device_info_by_index(i) + if device_info['maxInputChannels'] > 0: # Only show input devices + print(f" Device {i}: {device_info['name']} (inputs: {device_info['maxInputChannels']})") + + # Get default input device info + default_device = audio.get_default_input_device_info() + print(f"\nUsing default input device: {default_device['name']} (Device {default_device['index']})") + + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + # Debug: Log all event types + event_types = [k for k in event.keys() if not k.startswith('_')] + if event_types: + logger.debug(f"Received event types: {event_types}") + + # Handle audio output + if "audioOutput" in event: + if not context.get("interrupted", False): + context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + + # Handle interruption events + elif "interruptionDetected" in event: + context["interrupted"] = True + elif "interrupted" in event: + context["interrupted"] = True + + # Handle text output + elif "textOutput" in event: + text_content = event["textOutput"].get("text", "") + role = event["textOutput"].get("role", "unknown") + + # Check for text-based interruption patterns + if '{ "interrupted" : true }' in text_content: + context["interrupted"] = True + elif "interrupted" in text_content.lower(): + context["interrupted"] = True + + # Log text output + if role.upper() == "USER": + print(f"User: {text_content}") + elif role.upper() == "ASSISTANT": + print(f"Assistant: {text_content}") + + # Handle transcript events (audio transcriptions) + elif "transcript" in event: + transcript_text = event["transcript"].get("text", "") + transcript_role = event["transcript"].get("role", "unknown") + transcript_type = event["transcript"].get("type", "unknown") + + # Print transcripts with special formatting to distinguish from text output + if transcript_role.upper() == "USER": + print(f"🎤 User (transcript): {transcript_text}") + elif transcript_role.upper() == "ASSISTANT": + print(f"🔊 Assistant (transcript): {transcript_text}") + + # Handle turn complete events + elif "turnComplete" in event: + logger.debug("Turn complete event received - model ready for next input") + # Reset interrupted state since the turn is complete + context["interrupted"] = False + + except asyncio.CancelledError: + pass + + +def _get_frame(cap): + """Capture and process a frame from camera.""" + if not CAMERA_AVAILABLE: + return None + + # Read the frame + ret, frame = cap.read() + # Check if the frame was read successfully + if not ret: + return None + # Convert BGR to RGB color space + # OpenCV captures in BGR but PIL expects RGB format + # This prevents the blue tint in the video feed + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = PIL.Image.fromarray(frame_rgb) + img.thumbnail([1024, 1024]) + + image_io = io.BytesIO() + img.save(image_io, format="jpeg") + image_io.seek(0) + + mime_type = "image/jpeg" + image_bytes = image_io.read() + return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} + + +async def get_frames(context): + """Capture frames from camera and send to agent.""" + if not CAMERA_AVAILABLE: + print("Camera not available - skipping video capture") + return + + # This takes about a second, and will block the whole program + # causing the audio pipeline to overflow if you don't to_thread it. + cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera + + print("Camera initialized. Starting video capture...") + + try: + while context["active"] and time.time() - context["start_time"] < context["duration"]: + frame = await asyncio.to_thread(_get_frame, cap) + if frame is None: + break + + # Send frame to agent as image input + try: + image_event = { + "imageData": frame["data"], + "mimeType": frame["mime_type"], + "encoding": "base64" + } + await context["agent"].send(image_event) + print("📸 Frame sent to model") + except Exception as e: + logger.error(f"Error sending frame: {e}") + + # Wait 1 second between frames (1 FPS) + await asyncio.sleep(1.0) + + except asyncio.CancelledError: + pass + finally: + # Release the VideoCapture object + cap.release() + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + await agent.send(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for Gemini Live bidirectional streaming test with camera support.""" + print("Starting Gemini Live bidirectional streaming test with camera...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + print("Video: Camera frames sent at 1 FPS to model") + + # Get API key from environment variable + api_key = os.getenv("GOOGLE_AI_API_KEY") + + if not api_key: + print("ERROR: GOOGLE_AI_API_KEY environment variable not set") + print("Please set it with: export GOOGLE_AI_API_KEY=your_api_key") + return + + # Initialize Gemini Live model with proper configuration + logger.info("Initializing Gemini Live model with API key") + + model = GeminiLiveBidirectionalModel( + model_id="gemini-2.5-flash-native-audio-preview-09-2025", + api_key=api_key, + params={ + "response_modalities": ["AUDIO"], + "output_audio_transcription": {}, # Enable output transcription + "input_audio_transcription": {} # Enable input transcription + } + ) + logger.info("Gemini Live model initialized successfully") + print("Using Gemini Live model") + + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant." + ) + + await agent.start() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "connection": agent._session, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + "agent": agent, # Add agent reference for camera task + } + + print("Speak into microphone and show things to camera. Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently including camera + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + get_frames(context), # Add camera task + return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.end() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 412061146..d040ee436 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -11,8 +11,10 @@ BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, BidirectionalStreamEvent, + ImageInputEvent, InterruptionDetectedEvent, TextOutputEvent, + TranscriptEvent, UsageMetricsEvent, VoiceActivityEvent, ) @@ -23,8 +25,10 @@ "BidirectionalConnectionEndEvent", "BidirectionalConnectionStartEvent", "BidirectionalStreamEvent", + "ImageInputEvent", "InterruptionDetectedEvent", "TextOutputEvent", + "TranscriptEvent", "UsageMetricsEvent", "VoiceActivityEvent", "SUPPORTED_AUDIO_FORMATS", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 4aa720b20..4b215d74e 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -71,6 +71,23 @@ class AudioInputEvent(TypedDict): channels: Literal[1, 2] +class ImageInputEvent(TypedDict): + """Image input event for sending images/video frames to the model. + + Used for sending image data through the send() method. Supports both + raw image bytes and base64-encoded data. + + Attributes: + imageData: Image bytes (raw or base64-encoded string). + mimeType: MIME type (e.g., "image/jpeg", "image/png"). + encoding: How the imageData is encoded. + """ + + imageData: bytes | str + mimeType: str + encoding: Literal["base64", "raw"] + + class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. @@ -83,6 +100,23 @@ class TextOutputEvent(TypedDict): role: Role +class TranscriptEvent(TypedDict): + """Transcript event for audio transcriptions. + + Used for both input transcriptions (user speech) and output transcriptions + (model audio). These are informational and separate from actual text responses. + + Attributes: + text: The transcribed text. + role: The role of the speaker ("user" or "assistant"). + type: Type of transcription ("input" or "output"). + """ + + text: str + role: Role + type: Literal["input", "output"] + + class InterruptionDetectedEvent(TypedDict): """Interruption detection event. @@ -180,7 +214,9 @@ class BidirectionalStreamEvent(StreamEvent, total=False): Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. + imageInput: Image input sent to the model. textOutput: Text output from the model. + transcript: Audio transcription (input or output). interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. @@ -190,7 +226,9 @@ class BidirectionalStreamEvent(StreamEvent, total=False): audioOutput: Optional[AudioOutputEvent] audioInput: Optional[AudioInputEvent] + imageInput: Optional[ImageInputEvent] textOutput: Optional[TextOutputEvent] + transcript: Optional[TranscriptEvent] interruptionDetected: Optional[InterruptionDetectedEvent] BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] From 886892ea30869bfe33d687059efbd1ac84c0dd2f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 29 Oct 2025 15:20:22 +0100 Subject: [PATCH 024/242] test(bidirectional): Add integ test --- .gitignore | 1 + .../bidirectional_streaming/__init__.py | 1 + .../bidirectional_streaming/conftest.py | 28 ++ .../test_bidirectional_agent.py | 87 +++++ .../bidirectional_streaming/utils/__init__.py | 1 + .../utils/audio_generator.py | 154 +++++++++ .../utils/test_context.py | 304 ++++++++++++++++++ 7 files changed, 576 insertions(+) create mode 100644 tests_integ/bidirectional_streaming/__init__.py create mode 100644 tests_integ/bidirectional_streaming/conftest.py create mode 100644 tests_integ/bidirectional_streaming/test_bidirectional_agent.py create mode 100644 tests_integ/bidirectional_streaming/utils/__init__.py create mode 100644 tests_integ/bidirectional_streaming/utils/audio_generator.py create mode 100644 tests_integ/bidirectional_streaming/utils/test_context.py diff --git a/.gitignore b/.gitignore index e92a233f8..8b0fd989c 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist repl_state .kiro uv.lock +.audio_cache diff --git a/tests_integ/bidirectional_streaming/__init__.py b/tests_integ/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..05da9afcb --- /dev/null +++ b/tests_integ/bidirectional_streaming/__init__.py @@ -0,0 +1 @@ +"""Integration tests for bidirectional streaming agents.""" diff --git a/tests_integ/bidirectional_streaming/conftest.py b/tests_integ/bidirectional_streaming/conftest.py new file mode 100644 index 000000000..52f6a2a19 --- /dev/null +++ b/tests_integ/bidirectional_streaming/conftest.py @@ -0,0 +1,28 @@ +"""Pytest fixtures for bidirectional streaming integration tests.""" + +import logging + +import pytest + +from .utils.audio_generator import AudioGenerator + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session") +def audio_generator(): + """Provide AudioGenerator instance for tests.""" + return AudioGenerator(region="us-east-1") + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Configure logging for tests.""" + logging.basicConfig( + level=logging.DEBUG, + format="%(levelname)s | %(name)s | %(message)s", + ) + # Reduce noise from some loggers + logging.getLogger("boto3").setLevel(logging.WARNING) + logging.getLogger("botocore").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py new file mode 100644 index 000000000..46887652b --- /dev/null +++ b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py @@ -0,0 +1,87 @@ +"""Basic integration tests for Nova Sonic bidirectional streaming. + +Tests fundamental functionality including multi-turn conversations, audio I/O, +text transcription, and tool execution using the new context manager approach. +""" + +import logging + +import pytest +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel + +from .utils.test_context import BidirectionalTestContext + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def agent_with_calculator(): + """Provide bidirectional agent with calculator tool. + + Note: Session lifecycle (start/end) is handled by BidirectionalTestContext. + """ + model = NovaSonicBidirectionalModel(region="us-east-1") + return BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant with access to a calculator tool.", + ) + +@pytest.mark.asyncio +async def test_bidirectional_agent(agent_with_calculator, audio_generator): + """Test multi-turn conversation with follow-up questions. + + Validates: + - Session lifecycle (start/end via context manager) + - Audio input streaming + - Speech-to-text transcription + - Tool execution (calculator) + - Multi-turn conversation flow + - Text-to-speech audio output + """ + async with BidirectionalTestContext(agent_with_calculator, audio_generator) as ctx: + # Turn 1: Initial question + await ctx.say("What is five plus three?") + await ctx.wait_for_response() + + text_outputs_turn1 = ctx.get_text_outputs() + all_text_turn1 = " ".join(text_outputs_turn1).lower() + + # Validate turn 1 + assert "8" in all_text_turn1 or "eight" in all_text_turn1, ( + f"Answer '8' not found in turn 1: {text_outputs_turn1}" + ) + logger.info(f"✓ Turn 1 complete: {len(ctx.get_events())} events") + + # Turn 2: Follow-up question + await ctx.say("Now multiply that by two") + await ctx.wait_for_response() + + text_outputs_turn2 = ctx.get_text_outputs() + all_text_turn2 = " ".join(text_outputs_turn2).lower() + + # Validate turn 2 + assert "16" in all_text_turn2 or "sixteen" in all_text_turn2, ( + f"Answer '16' not found in turn 2: {text_outputs_turn2}" + ) + logger.info(f"✓ Turn 2 complete: {len(ctx.get_events())} total events") + + # Validate full conversation + assert len(text_outputs_turn2) > len(text_outputs_turn1), "No new text outputs in turn 2" + + # Validate audio outputs + audio_outputs = ctx.get_audio_outputs() + assert len(audio_outputs) > 0, "No audio output received" + total_audio_bytes = sum(len(audio) for audio in audio_outputs) + logger.info(f"✓ Audio output: {len(audio_outputs)} chunks, {total_audio_bytes} bytes") + + # Summary + logger.info("=" * 60) + logger.info("✓ Multi-turn conversation test passed") + logger.info(f" Total events: {len(ctx.get_events())}") + logger.info(f" Text outputs: {len(text_outputs_turn2)}") + logger.info(f" Audio chunks: {len(audio_outputs)}") + logger.info("=" * 60) diff --git a/tests_integ/bidirectional_streaming/utils/__init__.py b/tests_integ/bidirectional_streaming/utils/__init__.py new file mode 100644 index 000000000..fb9bdf2e9 --- /dev/null +++ b/tests_integ/bidirectional_streaming/utils/__init__.py @@ -0,0 +1 @@ +"""Utilities for bidirectional streaming integration tests.""" diff --git a/tests_integ/bidirectional_streaming/utils/audio_generator.py b/tests_integ/bidirectional_streaming/utils/audio_generator.py new file mode 100644 index 000000000..605a2aaa9 --- /dev/null +++ b/tests_integ/bidirectional_streaming/utils/audio_generator.py @@ -0,0 +1,154 @@ +"""Audio generation utilities using Amazon Polly for test audio input. + +Provides text-to-speech conversion for generating realistic audio test data +without requiring physical audio devices or pre-recorded files. +""" + +import hashlib +import logging +from pathlib import Path +from typing import Literal + +import boto3 + +logger = logging.getLogger(__name__) + +# Audio format constants matching Nova Sonic requirements +NOVA_SONIC_SAMPLE_RATE = 16000 +NOVA_SONIC_CHANNELS = 1 +NOVA_SONIC_FORMAT = "pcm" + +# Polly configuration +POLLY_VOICE_ID = "Matthew" # US English male voice +POLLY_ENGINE = "neural" # Higher quality neural engine + +# Cache directory for generated audio +CACHE_DIR = Path(__file__).parent.parent / ".audio_cache" + + +class AudioGenerator: + """Generate test audio using Amazon Polly with caching.""" + + def __init__(self, region: str = "us-east-1"): + """Initialize audio generator with Polly client. + + Args: + region: AWS region for Polly service. + """ + self.polly_client = boto3.client("polly", region_name=region) + self._ensure_cache_dir() + + def _ensure_cache_dir(self) -> None: + """Create cache directory if it doesn't exist.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + def _get_cache_key(self, text: str, voice_id: str) -> str: + """Generate cache key from text and voice.""" + content = f"{text}:{voice_id}".encode("utf-8") + return hashlib.md5(content).hexdigest() + + def _get_cache_path(self, cache_key: str) -> Path: + """Get cache file path for given key.""" + return CACHE_DIR / f"{cache_key}.pcm" + + async def generate_audio( + self, + text: str, + voice_id: str = POLLY_VOICE_ID, + use_cache: bool = True, + ) -> bytes: + """Generate audio from text using Polly with caching. + + Args: + text: Text to convert to speech. + voice_id: Polly voice ID to use. + use_cache: Whether to use cached audio if available. + + Returns: + Raw PCM audio bytes at 16kHz mono (Nova Sonic format). + """ + # Check cache first + if use_cache: + cache_key = self._get_cache_key(text, voice_id) + cache_path = self._get_cache_path(cache_key) + + if cache_path.exists(): + logger.debug(f"Using cached audio for: {text[:50]}...") + return cache_path.read_bytes() + + # Generate audio with Polly + logger.debug(f"Generating audio with Polly: {text[:50]}...") + + try: + response = self.polly_client.synthesize_speech( + Text=text, + OutputFormat="pcm", # Raw PCM format + VoiceId=voice_id, + Engine=POLLY_ENGINE, + SampleRate=str(NOVA_SONIC_SAMPLE_RATE), + ) + + # Read audio data + audio_data = response["AudioStream"].read() + + # Cache for future use + if use_cache: + cache_path.write_bytes(audio_data) + logger.debug(f"Cached audio: {cache_path}") + + return audio_data + + except Exception as e: + logger.error(f"Polly audio generation failed: {e}") + raise + + def create_audio_input_event( + self, + audio_data: bytes, + format: Literal["pcm", "wav", "opus", "mp3"] = NOVA_SONIC_FORMAT, + sample_rate: int = NOVA_SONIC_SAMPLE_RATE, + channels: int = NOVA_SONIC_CHANNELS, + ) -> dict: + """Create AudioInputEvent from raw audio data. + + Args: + audio_data: Raw audio bytes. + format: Audio format. + sample_rate: Sample rate in Hz. + channels: Number of audio channels. + + Returns: + AudioInputEvent dict ready for agent.send(). + """ + return { + "audioData": audio_data, + "format": format, + "sampleRate": sample_rate, + "channels": channels, + } + + def clear_cache(self) -> None: + """Clear all cached audio files.""" + if CACHE_DIR.exists(): + for cache_file in CACHE_DIR.glob("*.pcm"): + cache_file.unlink() + logger.info("Audio cache cleared") + + +# Convenience function for quick audio generation +async def generate_test_audio(text: str, use_cache: bool = True) -> dict: + """Generate test audio input event from text. + + Convenience function that creates an AudioGenerator and returns + a ready-to-use AudioInputEvent. + + Args: + text: Text to convert to speech. + use_cache: Whether to use cached audio. + + Returns: + AudioInputEvent dict ready for agent.send(). + """ + generator = AudioGenerator() + audio_data = await generator.generate_audio(text, use_cache=use_cache) + return generator.create_audio_input_event(audio_data) diff --git a/tests_integ/bidirectional_streaming/utils/test_context.py b/tests_integ/bidirectional_streaming/utils/test_context.py new file mode 100644 index 000000000..f669e12ca --- /dev/null +++ b/tests_integ/bidirectional_streaming/utils/test_context.py @@ -0,0 +1,304 @@ +"""Test context manager for bidirectional streaming tests. + +Provides a high-level interface for testing bidirectional streaming agents +with continuous background threads that mimic real-world usage patterns. +""" + +import asyncio +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent + from .audio_generator import AudioGenerator + +logger = logging.getLogger(__name__) + + +class BidirectionalTestContext: + """Manages threads and generators for bidirectional streaming tests. + + Mimics real-world usage with continuous background threads: + - Audio input thread (microphone simulation with silence padding) + - Event collection thread (captures all model outputs) + + Generators feed data into threads via queues for natural conversation flow. + + Example: + async with BidirectionalTestContext(agent, audio_generator) as ctx: + await ctx.say("What is 5 plus 3?") + await ctx.wait_for_response() + assert "8" in " ".join(ctx.get_text_outputs()) + """ + + def __init__( + self, + agent: "BidirectionalAgent", + audio_generator: "AudioGenerator | None" = None, + silence_chunk_size: int = 1024, + audio_chunk_size: int = 1024, + ): + """Initialize test context. + + Args: + agent: BidirectionalAgent instance. + audio_generator: AudioGenerator for text-to-speech. + silence_chunk_size: Size of silence chunks in bytes. + audio_chunk_size: Size of audio chunks for streaming. + """ + self.agent = agent + self.audio_generator = audio_generator + self.silence_chunk_size = silence_chunk_size + self.audio_chunk_size = audio_chunk_size + + # Queue for thread communication + self.input_queue = asyncio.Queue() # Handles both audio and text input + + # Event storage + self.events = [] # All collected events + self.last_event_time = None + + # Control flags + self.active = False + self.threads = [] + + async def __aenter__(self): + """Start context manager, agent session, and background threads.""" + # Start agent session + await self.agent.start() + logger.debug("Agent session started") + + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Stop context manager, cleanup threads, and end agent session.""" + await self.stop() + + # End agent session + if self.agent._session and self.agent._session.active: + await self.agent.end() + logger.debug("Agent session ended") + + return False + + async def start(self): + """Start all background threads.""" + self.active = True + self.last_event_time = asyncio.get_event_loop().time() + + self.threads = [ + asyncio.create_task(self._input_thread()), + asyncio.create_task(self._event_collection_thread()), + ] + + logger.debug("Test context started with %d threads", len(self.threads)) + + async def stop(self): + """Stop all threads gracefully.""" + self.active = False + + # Cancel all threads + for task in self.threads: + if not task.done(): + task.cancel() + + # Wait for cancellation + await asyncio.gather(*self.threads, return_exceptions=True) + + logger.debug("Test context stopped") + + # === User-facing methods === + + async def say(self, text: str): + """Queue text to be converted to audio and sent to model. + + Args: + text: Text to convert to speech and send as audio. + """ + await self.input_queue.put({"type": "audio", "text": text}) + logger.debug(f"Queued speech: {text[:50]}...") + + async def send(self, data: str | dict) -> None: + """Send data directly to model (text, image, etc.). + + Args: + data: Data to send to model. Can be: + - str: Text input + - dict: Custom event (e.g., image, audio) + """ + await self.input_queue.put({"type": "direct", "data": data}) + logger.debug(f"Queued direct send: {type(data).__name__}") + + async def wait_for_response( + self, + timeout: float = 15.0, + silence_threshold: float = 2.0, + min_events: int = 1, + ): + """Wait for model to finish responding. + + Uses silence detection (no events for silence_threshold seconds) + combined with minimum event count to determine response completion. + + Args: + timeout: Maximum time to wait in seconds. + silence_threshold: Seconds of silence to consider response complete. + min_events: Minimum events before silence detection activates. + """ + start_time = asyncio.get_event_loop().time() + initial_event_count = len(self.events) + + while asyncio.get_event_loop().time() - start_time < timeout: + # Check if we have minimum events + if len(self.events) - initial_event_count >= min_events: + # Check silence + elapsed_since_event = asyncio.get_event_loop().time() - self.last_event_time + if elapsed_since_event >= silence_threshold: + logger.debug( + f"Response complete: {len(self.events) - initial_event_count} events, " + f"{elapsed_since_event:.1f}s silence" + ) + return + + await asyncio.sleep(0.1) + + logger.warning(f"Response timeout after {timeout}s") + + def get_events(self, event_type: str | None = None) -> list[dict]: + """Get collected events, optionally filtered by type. + + Args: + event_type: Optional event type to filter by (e.g., "textOutput"). + + Returns: + List of events, filtered if event_type specified. + """ + if event_type: + return [e for e in self.events if event_type in e] + return self.events.copy() + + def get_text_outputs(self) -> list[str]: + """Extract text outputs from collected events. + + Returns: + List of text content strings. + """ + texts = [] + for event in self.events: + if "textOutput" in event: + text = event["textOutput"].get("text", "") + if text: + texts.append(text) + return texts + + def get_audio_outputs(self) -> list[bytes]: + """Extract audio outputs from collected events. + + Returns: + List of audio data bytes. + """ + audio_data = [] + for event in self.events: + if "audioOutput" in event: + data = event["audioOutput"].get("audioData") + if data: + audio_data.append(data) + return audio_data + + def get_tool_uses(self) -> list[dict]: + """Extract tool use events from collected events. + + Returns: + List of tool use events. + """ + return [event["toolUse"] for event in self.events if "toolUse" in event] + + def has_interruption(self) -> bool: + """Check if any interruption was detected. + + Returns: + True if interruption detected in events. + """ + return any("interruptionDetected" in event for event in self.events) + + def clear_events(self): + """Clear collected events (useful for multi-turn tests).""" + self.events.clear() + logger.debug("Events cleared") + + # === Background threads === + + async def _input_thread(self): + """Continuously handle input to model. + + - Sends silence by default (background noise) if audio generator available + - Converts queued text to audio via Polly (for "audio" type) + - Sends text directly to model (for "text" type) + """ + try: + while self.active: + try: + # Check for queued input (non-blocking) + input_item = await asyncio.wait_for(self.input_queue.get(), timeout=0.01) + + if input_item["type"] == "audio": + # Generate and send audio + if self.audio_generator: + audio_data = await self.audio_generator.generate_audio(input_item["text"]) + + # Send audio in chunks + for i in range(0, len(audio_data), self.audio_chunk_size): + if not self.active: + break + chunk = audio_data[i : i + self.audio_chunk_size] + chunk_event = self.audio_generator.create_audio_input_event(chunk) + await self.agent.send(chunk_event) + await asyncio.sleep(0.01) + + logger.debug(f"Sent audio: {len(audio_data)} bytes") + else: + logger.warning("Audio requested but no generator available") + + elif input_item["type"] == "direct": + # Send data directly to agent + await self.agent.send(input_item["data"]) + data_repr = str(input_item["data"])[:50] if isinstance(input_item["data"], str) else type(input_item["data"]).__name__ + logger.debug(f"Sent direct: {data_repr}") + + except asyncio.TimeoutError: + # No input queued - send silence if audio generator available + if self.audio_generator: + silence = self._generate_silence_chunk() + await self.agent.send(silence) + await asyncio.sleep(0.01) + + except asyncio.CancelledError: + logger.debug("Input thread cancelled") + except Exception as e: + logger.error(f"Input thread error: {e}") + + async def _event_collection_thread(self): + """Continuously collect events from model.""" + try: + async for event in self.agent.receive(): + if not self.active: + break + + self.events.append(event) + self.last_event_time = asyncio.get_event_loop().time() + logger.debug(f"Event collected: {list(event.keys())}") + + except asyncio.CancelledError: + logger.debug("Event collection thread cancelled") + except Exception as e: + logger.error(f"Event collection thread error: {e}") + + def _generate_silence_chunk(self) -> dict: + """Generate silence chunk for background audio. + + Returns: + AudioInputEvent with silence data. + """ + silence = b"\x00" * self.silence_chunk_size + return self.audio_generator.create_audio_input_event(silence) From 1ab25c6d5203b6bb5c7a669436b3e8d26b8bfa75 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:03:00 +0100 Subject: [PATCH 025/242] refactor(model): unify bidirectional model and session --- .../bidirectional_streaming/__init__.py | 14 +- .../bidirectional_streaming/agent/agent.py | 16 +- .../event_loop/bidirectional_event_loop.py | 25 +- .../models/__init__.py | 11 +- .../models/bidirectional_model.py | 125 ++-- .../models/gemini_live.py | 249 +++---- .../models/novasonic.py | 272 ++++---- .../bidirectional_streaming/models/openai.py | 246 ++++--- .../types/bidirectional_streaming.py | 14 + tests/strands/experimental/__init__.py | 1 + .../bidirectional_streaming/__init__.py | 1 + .../models/__init__.py | 1 + .../models/test_gemini_live.py | 500 ++++++++++++++ .../models/test_novasonic.py | 551 +++++++++++++++ .../models/test_openai_realtime.py | 625 ++++++++++++++++++ 15 files changed, 2176 insertions(+), 475 deletions(-) create mode 100644 tests/strands/experimental/bidirectional_streaming/__init__.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/__init__.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 3c47dd957..e31bc670e 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -3,10 +3,11 @@ # Main components - Primary user interface from .agent.agent import BidirectionalAgent -# Advanced interfaces (for custom implementations) +# Unified model interface (for custom implementations) from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession # Model providers - What users need to create models +from .models.gemini_live import GeminiLiveBidirectionalModel from .models.novasonic import NovaSonicBidirectionalModel from .models.openai import OpenAIRealtimeBidirectionalModel @@ -15,7 +16,9 @@ AudioInputEvent, AudioOutputEvent, BidirectionalStreamEvent, + ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, UsageMetricsEvent, VoiceActivityEvent, @@ -26,19 +29,22 @@ "BidirectionalAgent", # Model providers + "GeminiLiveBidirectionalModel", "NovaSonicBidirectionalModel", "OpenAIRealtimeBidirectionalModel", # Event types "AudioInputEvent", - "AudioOutputEvent", + "AudioOutputEvent", + "ImageInputEvent", + "TextInputEvent", "TextOutputEvent", "InterruptionDetectedEvent", "BidirectionalStreamEvent", "VoiceActivityEvent", "UsageMetricsEvent", - # Model interface + # Unified model interface "BidirectionalModel", - "BidirectionalModelSession", + "BidirectionalModelSession", # Backwards compatibility alias ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 820a6c490..62528d472 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -379,13 +379,15 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> Non self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - await self._session.model_session.send_text_content(input_data) + # Create TextInputEvent for unified send() + text_event = {"text": input_data, "role": "user"} + await self._session.model_session.send(text_event) elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input - await self._session.model_session.send_audio_content(input_data) + # Handle audio input - already in AudioInputEvent format + await self._session.model_session.send(input_data) elif isinstance(input_data, dict) and "imageData" in input_data: - # Handle image input (ImageInputEvent) - await self._session.model_session.send_image_content(input_data) + # Handle image input - already in ImageInputEvent format + await self._session.model_session.send(input_data) else: raise ValueError( "Input must be either a string (text), AudioInputEvent " @@ -419,7 +421,9 @@ async def interrupt(self) -> None: ValueError: If no active session. """ self._validate_active_session() - await self._session.model_session.send_interrupt() + # Interruption is now handled internally by models through audio/event processing + # No explicit interrupt method needed in unified interface + logger.debug("Interrupt requested - handled by model's audio processing") async def end(self) -> None: """End the conversation session and cleanup all resources. diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index bbf5fb425..521ebc0dd 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..models.bidirectional_model import BidirectionalModelSession +from ..models.bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -37,11 +37,11 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None: - """Initialize session with model session and agent reference. + def __init__(self, model_session: BidirectionalModel, agent: "BidirectionalAgent") -> None: + """Initialize session with model and agent reference. Args: - model_session: Provider-specific bidirectional model session. + model_session: Bidirectional model instance (unified interface). agent: BidirectionalAgent instance for tool registry access. """ self.model_session = model_session @@ -76,12 +76,15 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec Returns: BidirectionalConnection: Active session with background tasks running. """ - logger.debug("Starting bidirectional session - initializing model session") + logger.debug("Starting bidirectional session - initializing model connection") - # Create provider-specific session - model_session = await agent.model.create_bidirectional_connection( + # Connect to model using unified interface + await agent.model.connect( system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) + + # Use the model directly (unified interface - no separate session) + model_session = agent.model # Create session wrapper for background processing session = BidirectionalConnection(model_session=model_session, agent=agent) @@ -257,7 +260,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: """ logger.debug("Model events processor started") try: - async for provider_event in session.model_session.receive_events(): + async for provider_event in session.model_session.receive(): if not session.active: break @@ -434,8 +437,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send result through provider-specific session - await session.model_session.send_tool_result(tool_use_id, tool_result) + # Send result through unified send() method + await session.model_session.send(tool_result) logger.debug("Tool result sent: %s", tool_use_id) # Handle streaming events if needed later @@ -471,7 +474,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: "content": [{"text": f"Error: {str(e)}"}] } try: - await session.model_session.send_tool_result(tool_id, error_result) + await session.model_session.send(error_result) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index c5287d15d..e2745310c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,17 +1,14 @@ """Bidirectional model interfaces and implementations.""" from .bidirectional_model import BidirectionalModel, BidirectionalModelSession -from .gemini_live import GeminiLiveBidirectionalModel, GeminiLiveSession -from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession -from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession +from .gemini_live import GeminiLiveBidirectionalModel +from .novasonic import NovaSonicBidirectionalModel +from .openai import OpenAIRealtimeBidirectionalModel __all__ = [ "BidirectionalModel", - "BidirectionalModelSession", + "BidirectionalModelSession", # Backwards compatibility alias "GeminiLiveBidirectionalModel", - "GeminiLiveSession", "NovaSonicBidirectionalModel", - "NovaSonicSession", "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 42485561b..3af05e113 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -1,11 +1,10 @@ -"""Bidirectional model interface for real-time streaming conversations. +"""Unified bidirectional streaming interface. -Defines the interface for models that support bidirectional streaming capabilities. -Provides abstractions for different model providers with connection-based communication -patterns that support real-time audio and text interaction. +Single layer combining model and session abstractions for simpler implementation. +Providers implement this directly without separate model/session classes. Features: -- connection-based persistent connections +- Unified model interface (no separate session class) - Real-time bidirectional communication - Provider-agnostic event normalization - Tool execution integration @@ -13,101 +12,85 @@ import abc import logging -from typing import AsyncIterable +from typing import AsyncIterable, Union from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent +from ....types.tools import ToolResult, ToolSpec +from ..types.bidirectional_streaming import ( + AudioInputEvent, + BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, +) logger = logging.getLogger(__name__) -class BidirectionalModelSession(abc.ABC): - """Abstract interface for model-specific bidirectional communication connections. +class BidirectionalModel(abc.ABC): + """Unified interface for bidirectional streaming models. - Defines the contract for managing persistent streaming connections with individual - model providers, handling audio/text input, receiving events, and managing - tool execution results. + Combines model configuration and session communication in a single abstraction. + Providers implement this directly without separate model/session classes. """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive events from the model in standardized format. - - Converts provider-specific events to a common format that can be - processed uniformly by the event loop. - """ - raise NotImplementedError + async def connect( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish bidirectional connection with the model. - @abc.abstractmethod - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to the model during an active connection. + Initializes the connection state and prepares for real-time communication. + This replaces the old create_bidirectional_connection pattern. - Handles audio encoding and provider-specific formatting while presenting - a simple AudioInputEvent interface. + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Provider-specific configuration options. """ raise NotImplementedError - # TODO: remove with interface unification - async def send_image_content(self, image_input: ImageInputEvent) -> None: - """Send image content to the model during an active connection. - - Handles image encoding and provider-specific formatting while presenting - a simple ImageInputEvent interface. - """ - raise NotImplementedError - @abc.abstractmethod - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to the model during ongoing generation. + async def close(self) -> None: + """Close connection and cleanup resources. - Allows natural interruption and follow-up questions without requiring - connection restart. + Terminates the active connection and releases any held resources. """ raise NotImplementedError @abc.abstractmethod - async def send_interrupt(self) -> None: - """Send interruption signal to stop generation immediately. - - Enables responsive conversational experiences where users can - naturally interrupt during model responses. - """ - raise NotImplementedError + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive events from the model in standardized format. - @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool execution result to the model. + Yields provider-agnostic events that can be processed uniformly + by the event loop. Converts provider-specific events to common format. - Formats and sends tool results according to the provider's specific protocol. - Handles both successful results and error cases through the result dictionary. + Yields: + BidirectionalStreamEvent: Standardized event dictionaries. """ raise NotImplementedError @abc.abstractmethod - async def close(self) -> None: - """Close the connection and cleanup resources.""" - raise NotImplementedError + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Send structured content to the model. + Unified method for sending all types of content. Implementations should + dispatch to appropriate internal handlers based on content type. -class BidirectionalModel(abc.ABC): - """Interface for models that support bidirectional streaming. - - Defines the contract for creating persistent streaming connections that support - real-time audio and text communication with AI models. - """ - - @abc.abstractmethod - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create a bidirectional connection with the model. + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). - Establishes a persistent connection for real-time communication while - abstracting provider-specific initialization requirements. + Example: + await model.send(TextInputEvent(text="Hello", role="user")) + await model.send(AudioInputEvent(audioData=bytes, format="pcm", ...)) + await model.send(ToolResult(toolUseId="123", status="success", ...)) """ raise NotImplementedError + + +# Backwards compatibility alias - will be removed in future version +BidirectionalModelSession = BidirectionalModel diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 64c4d7348..578de5a2b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -1,11 +1,11 @@ """Gemini Live API bidirectional model provider using official Google GenAI SDK. -Implements the BidirectionalModel interface for Google's Gemini Live API using the +Implements the unified BidirectionalModel interface for Google's Gemini Live API using the official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: - Uses official google-genai SDK with native Live API support -- Simplified session management with client.aio.live.connect() +- Unified model interface (no separate session class) - Built-in tool integration and event handling - Automatic WebSocket connection management and error handling - Native support for audio/text streaming and interruption @@ -15,14 +15,14 @@ import base64 import logging import uuid -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import Any, AsyncIterable, Dict, List, Optional, Union from google import genai from google.genai import types as genai_types from google.genai.types import LiveServerMessage, LiveServerContent from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, @@ -30,10 +30,11 @@ BidirectionalConnectionStartEvent, ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, TranscriptEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -43,45 +44,70 @@ GEMINI_CHANNELS = 1 -class GeminiLiveSession(BidirectionalModelSession): - """Gemini Live API session using official Google GenAI SDK. +class GeminiLiveBidirectionalModel(BidirectionalModel): + """Unified Gemini Live API implementation using official Google GenAI SDK. + Combines model configuration and connection state in a single class. Provides a clean interface to Gemini Live API using the official SDK, eliminating custom WebSocket handling and providing robust error handling. """ - def __init__(self, client: genai.Client, model_id: str, config: Dict[str, Any]): - """Initialize Gemini Live API session. + def __init__( + self, + model_id: str = "models/gemini-2.0-flash-live-preview-04-09", + api_key: Optional[str] = None, + **config + ): + """Initialize Gemini Live API bidirectional model. Args: - client: Gemini client instance - model_id: Model identifier - config: Model configuration including live config + model_id: Gemini Live model identifier. + api_key: Google AI API key for authentication. + **config: Additional configuration. """ - self.client = client + # Model configuration self.model_id = model_id + self.api_key = api_key self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True + + # Create Gemini client with proper API version + client_kwargs = {} + if api_key: + client_kwargs["api_key"] = api_key + + # Use v1alpha for Live API as it has better model support + client_kwargs["http_options"] = {"api_version": "v1alpha"} + + self.client = genai.Client(**client_kwargs) + + # Connection state (initialized in connect()) self.live_session = None self.live_session_cm = None - - + self.session_id = None + self._active = False - async def initialize( + async def connect( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, + **kwargs ) -> None: - """Initialize Gemini Live API session by creating the connection.""" + """Establish bidirectional connection with Gemini Live API. + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ try: - # Build live config - live_config = self.config.get("live_config") + # Initialize connection state + self.session_id = str(uuid.uuid4()) + self._active = True - if live_config is None: - raise ValueError("live_config is required but not found in session config") + # Build live config + live_config = self._build_live_config(system_prompt, tools, **kwargs) # Create the context manager self.live_session_cm = self.client.aio.live.connect( @@ -96,9 +122,8 @@ async def initialize( if messages: await self._send_message_history(messages) - except Exception as e: - logger.error("Error initializing Gemini Live session: %s", e) + logger.error("Error connecting to Gemini Live: %s", e) raise async def _send_message_history(self, messages: Messages) -> None: @@ -125,13 +150,13 @@ async def _send_message_history(self, messages: Messages) -> None: content = genai_types.Content(role=role, parts=content_parts) await self.live_session.send_client_content(turns=content) - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" # Emit connection start event connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, - "metadata": {"provider": "gemini_live", "model_id": self.config.get("model_id")} + "metadata": {"provider": "gemini_live", "model_id": self.model_id} } yield {"BidirectionalConnectionStart": connection_start} @@ -251,15 +276,43 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) return None - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content using Gemini Live API. + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. - Gemini Live expects continuous audio streaming via send_realtime_input. - This automatically triggers VAD and can interrupt ongoing responses. + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). """ if not self._active: return + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent + await self._send_image_content(content) + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio content using Gemini Live API. + + Gemini Live expects continuous audio streaming via send_realtime_input. + This automatically triggers VAD and can interrupt ongoing responses. + """ try: # Create audio blob for the SDK audio_blob = genai_types.Blob( @@ -273,18 +326,15 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: except Exception as e: logger.error("Error sending audio content: %s", e) - async def send_image_content(self, image_input: ImageInputEvent) -> None: - """Send image content using Gemini Live API. + async def _send_image_content(self, image_input: ImageInputEvent) -> None: + """Internal: Send image content using Gemini Live API. Sends image frames following the same pattern as the GitHub example. Images are sent as base64-encoded data with MIME type. """ - if not self._active: - return - try: # Prepare the message based on encoding - if image_input["encoding"] == "base64": + if image_input.get("encoding") == "base64": # Data is already base64 encoded if isinstance(image_input["imageData"], bytes): data_str = image_input["imageData"].decode() @@ -306,11 +356,8 @@ async def send_image_content(self, image_input: ImageInputEvent) -> None: except Exception as e: logger.error("Error sending image content: %s", e) - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content using Gemini Live API.""" - if not self._active: - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Gemini Live API.""" try: # Create content with text content = genai_types.Content( @@ -324,36 +371,25 @@ async def send_text_content(self, text: str, **kwargs) -> None: except Exception as e: logger.error("Error sending text content: %s", e) - async def send_interrupt(self) -> None: - """Send interruption signal to Gemini Live API. - - Gemini Live uses automatic VAD-based interruption. When new audio input - is detected, it automatically interrupts the ongoing generation. - We don't need to send explicit interrupt signals like Nova Sonic. - """ - if not self._active: - return - + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Gemini Live API.""" try: - # Gemini Live handles interruption automatically through VAD - # When new audio input is sent via send_realtime_input, it automatically - # interrupts any ongoing generation. No explicit interrupt signal needed. - logger.debug("Interrupt requested - Gemini Live handles this automatically via VAD") + tool_use_id = tool_result.get("toolUseId") + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break - except Exception as e: - logger.error("Error in interrupt handling: %s", e) - - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: - """Send tool result using Gemini Live API.""" - if not self._active: - return - - try: # Create function response func_response = genai_types.FunctionResponse( id=tool_use_id, name=tool_use_id, # Gemini uses name as identifier - response=result + response=result_data ) # Send tool response @@ -361,11 +397,6 @@ async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> No except Exception as e: logger.error("Error sending tool result: %s", e) - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool error using Gemini Live API.""" - error_result = {"error": error} - await self.send_tool_result(tool_use_id, error_result) - async def close(self) -> None: """Close Gemini Live API connection.""" if not self._active: @@ -378,69 +409,7 @@ async def close(self) -> None: if self.live_session_cm: await self.live_session_cm.__aexit__(None, None, None) except Exception as e: - logger.error("Error closing Gemini Live session: %s", e) - raise - - -class GeminiLiveBidirectionalModel(BidirectionalModel): - """Gemini Live API model implementation using official Google GenAI SDK. - - Provides access to Google's Gemini Live API through the bidirectional - streaming interface, using the official SDK for robust and simple integration. - """ - - def __init__( - self, - model_id: str = "models/gemini-2.0-flash-live-preview-04-09", - api_key: Optional[str] = None, - **config - ): - """Initialize Gemini Live API bidirectional model. - - Args: - model_id: Gemini Live model identifier. - api_key: Google AI API key for authentication. - **config: Additional configuration. - """ - self.model_id = model_id - self.api_key = api_key - self.config = config - - # Create Gemini client with proper API version - client_kwargs = {} - if api_key: - client_kwargs["api_key"] = api_key - - # Use v1alpha for Live API as it has better model support - client_kwargs["http_options"] = {"api_version": "v1alpha"} - - self.client = genai.Client(**client_kwargs) - - async def create_bidirectional_connection( - self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, - **kwargs - ) -> BidirectionalModelSession: - """Create Gemini Live API bidirectional connection using official SDK.""" - - try: - # Build configuration - live_config = self._build_live_config(system_prompt, tools, **kwargs) - - # Create session config - session_config = self._get_session_config() - session_config["live_config"] = live_config - - # Create and initialize session wrapper - session = GeminiLiveSession(self.client, self.model_id, session_config) - await session.initialize(system_prompt, tools, messages) - - return session - - except Exception as e: - logger.error("Failed to create Gemini Live connection: %s", e) + logger.error("Error closing Gemini Live connection: %s", e) raise def _build_live_config( @@ -488,12 +457,4 @@ def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_t for tool_spec in tool_specs ], ), - ] - - def _get_session_config(self) -> Dict[str, Any]: - """Get session configuration for Gemini Live API.""" - return { - "model_id": self.model_id, - "params": self.config.get("params"), - **self.config - } \ No newline at end of file + ] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 134ff73fd..62b53a127 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,9 +1,11 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +Implements the unified BidirectionalModel interface for Amazon's Nova Sonic, handling the complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. +Unified model interface - combines configuration and connection state in single class. + Nova Sonic specifics: - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding @@ -19,25 +21,31 @@ import time import traceback import uuid -from typing import AsyncIterable +from typing import AsyncIterable, Union from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, InvokeModelWithBidirectionalStreamOperationOutput +from aws_sdk_bedrock_runtime.models import ( + BidirectionalInputPayloadPart, + InvokeModelWithBidirectionalStreamInputChunk, + InvokeModelWithBidirectionalStreamOperationOutput, +) from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, + ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, UsageMetricsEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -72,29 +80,36 @@ RESPONSE_TIMEOUT = 1.0 -class NovaSonicSession(BidirectionalModelSession): - """Nova Sonic connection implementation handling the provider's specific protocol. +class NovaSonicBidirectionalModel(BidirectionalModel): + """Unified Nova Sonic implementation for bidirectional streaming. + Combines model configuration and connection state in a single class. Manages Nova Sonic's complex event sequencing, audio format conversion, and - tool execution patterns while providing the standard BidirectionalModelSession - interface. + tool execution patterns while providing the standard BidirectionalModel interface. """ - def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, config: dict[str, any]) -> None: - """Initialize Nova Sonic connection. + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: + """Initialize Nova Sonic bidirectional model. Args: - stream: Nova Sonic bidirectional stream operation output from AWS SDK. - config: Model configuration. + model_id: Nova Sonic model identifier. + region: AWS region. + **config: Additional configuration. """ - self.stream = stream + # Model configuration + self.model_id = model_id + self.region = region self.config = config - self.prompt_name = str(uuid.uuid4()) - self._active = True + self._client = None + + # Connection state (initialized in connect()) + self.stream = None + self.prompt_name = None + self._active = False # Nova Sonic requires unique content names - self.audio_content_name = str(uuid.uuid4()) - self.text_content_name = str(uuid.uuid4()) + self.audio_content_name = None + self.text_content_name = None # Audio connection state self.audio_connection_active = False @@ -102,33 +117,67 @@ def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, co self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None - # Validate stream - if not stream: - logger.error("Stream is None") - raise ValueError("Stream cannot be None") + # Background task and event queue + self._response_task = None + self._event_queue = None - logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - async def initialize( + async def connect( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, + **kwargs, ) -> None: - """Initialize Nova Sonic connection with required protocol sequence.""" + """Establish bidirectional connection to Nova Sonic. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + logger.debug("Nova connection create - starting") + try: - system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." + # Initialize client if needed + if not self._client: + await self._initialize_client() + + # Initialize connection state + self.prompt_name = str(uuid.uuid4()) + self._active = True + self.audio_content_name = str(uuid.uuid4()) + self.text_content_name = str(uuid.uuid4()) + self._event_queue = asyncio.Queue() + + # Start Nova Sonic bidirectional stream + self.stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + + # Validate stream + if not self.stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + + # Send initialization events + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." init_events = self._build_initialization_events(system_prompt, tools or [], messages) logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) - logger.info("Nova Sonic connection initialized successfully") + # Start background response processor self._response_task = asyncio.create_task(self._process_responses()) + logger.info("Nova Sonic connection established successfully") + except Exception as e: - logger.error("Error during Nova Sonic initialization: %s", e) + logger.error("Nova connection create error: %s", str(e)) raise def _build_initialization_events( @@ -206,7 +255,7 @@ def _log_event_type(self, nova_event: dict[str, any]) -> None: audio_bytes = base64.b64decode(audio_content) logger.debug("Nova audio output: %d bytes", len(audio_bytes)) - async def receive_events(self) -> AsyncIterable[dict[str, any]]: + async def receive(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -217,14 +266,10 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")}, + "metadata": {"provider": "nova_sonic", "model_id": self.model_id}, } yield {"BidirectionalConnectionStart": connection_start} - # Initialize event queue if not already done - if not hasattr(self, "_event_queue"): - self._event_queue = asyncio.Queue() - try: while self._active: try: @@ -252,8 +297,39 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: } yield {"BidirectionalConnectionEnd": connection_end} - async def start_audio_connection(self) -> None: - """Start audio input connection (call once before sending audio chunks).""" + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + """ + if not self._active: + return + + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent - not supported by Nova Sonic + logger.warning("Image input not supported by Nova Sonic") + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + + async def _start_audio_connection(self) -> None: + """Internal: Start audio input connection (call once before sending audio chunks).""" if self.audio_connection_active: return @@ -277,14 +353,11 @@ async def start_audio_connection(self) -> None: await self._send_nova_event(audio_content_start) self.audio_connection_active = True - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio using Nova Sonic protocol-specific format.""" - if not self._active: - return - + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio using Nova Sonic protocol-specific format.""" # Start audio connection if not already active if not self.audio_connection_active: - await self.start_audio_connection() + await self._start_audio_connection() # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() @@ -313,19 +386,19 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task = asyncio.create_task(self._check_silence()) async def _check_silence(self) -> None: - """Check for silence and automatically end audio connection.""" + """Internal: Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: logger.debug("Nova silence detected: %.2f seconds", elapsed) - await self.end_audio_input() + await self._end_audio_input() except asyncio.CancelledError: pass - async def end_audio_input(self) -> None: - """End current audio input connection to trigger Nova Sonic processing.""" + async def _end_audio_input(self) -> None: + """Internal: End current audio input connection to trigger Nova Sonic processing.""" if not self.audio_connection_active: return @@ -338,11 +411,8 @@ async def end_audio_input(self) -> None: await self._send_nova_event(audio_content_end) self.audio_connection_active = False - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content using Nova Sonic format.""" - if not self._active: - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Nova Sonic format.""" content_name = str(uuid.uuid4()) events = [ self._get_text_content_start_event(content_name), @@ -353,37 +423,45 @@ async def send_text_content(self, text: str, **kwargs) -> None: for event in events: await self._send_nova_event(event) - async def send_interrupt(self) -> None: - """Send interruption signal to Nova Sonic.""" - if not self._active: - return - + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to Nova Sonic.""" # Nova Sonic handles interruption through special input events - interrupt_event = { - "event": { - "audioInput": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED", + interrupt_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "stopReason": "INTERRUPTED", + } } } - } + ) await self._send_nova_event(interrupt_event) - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result using Nova Sonic toolResult format.""" - if not self._active: - return + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Nova Sonic toolResult format.""" + tool_use_id = tool_result.get("toolUseId") logger.debug("Nova tool result send: %s", tool_use_id) + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break + content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), - self._get_tool_result_event(content_name, result), + self._get_tool_result_event(content_name, result_data), self._get_content_end_event(content_name), ] - for _i, event in enumerate(events): + for event in events: await self._send_nova_event(event) async def close(self) -> None: @@ -405,7 +483,7 @@ async def close(self) -> None: try: # End audio connection if active if self.audio_connection_active: - await self.end_audio_input() + await self._end_audio_input() # Send cleanup events cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] @@ -640,60 +718,6 @@ async def _send_nova_event(self, event: str) -> None: logger.error("Event was: %s", event) raise - -class NovaSonicBidirectionalModel(BidirectionalModel): - """Nova Sonic model implementation for bidirectional streaming. - - Provides access to Amazon's Nova Sonic model through the bidirectional - streaming interface, handling AWS authentication and connection management. - """ - - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: - """Initialize Nova Sonic bidirectional model. - - Args: - model_id: Nova Sonic model identifier. - region: AWS region. - **config: Additional configuration. - """ - self.model_id = model_id - self.region = region - self.config = config - self._client = None - - logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create Nova Sonic bidirectional connection.""" - logger.debug("Nova connection create - starting") - - # Initialize client if needed - if not self._client: - await self._initialize_client() - - # Start Nova Sonic bidirectional stream - try: - stream = await self._client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ) - - # Create and initialize connection - connection = NovaSonicSession(stream, self.config) - await connection.initialize(system_prompt, tools, messages) - - logger.debug("Nova connection created") - return connection - except Exception as e: - logger.error("Nova connection create error: %s", str(e)) - logger.error("Failed to create Nova Sonic connection: %s", e) - raise - async def _initialize_client(self) -> None: """Initialize Nova Sonic client.""" try: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 7d009b1c7..0208ee162 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -2,6 +2,8 @@ Provides real-time audio and text communication through OpenAI's Realtime API with WebSocket connections, voice activity detection, and function calling. + +Unified model interface - combines configuration and connection state in single class. """ import asyncio @@ -9,24 +11,26 @@ import json import logging import uuid -from typing import AsyncIterable +from typing import AsyncIterable, Union import websockets from websockets.client import WebSocketClientProtocol from websockets.exceptions import ConnectionClosed from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, TextOutputEvent, VoiceActivityEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -55,63 +59,115 @@ } -class OpenAIRealtimeSession(BidirectionalModelSession): - """OpenAI Realtime API session for real-time audio/text streaming. +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """Unified OpenAI Realtime API implementation for bidirectional streaming. + Combines model configuration and connection state in a single class. Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, function calling, and event conversion to Strands format. """ - def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: - """Initialize OpenAI Realtime session.""" - self.websocket = websocket + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + **config: any + ) -> None: + """Initialize OpenAI Realtime bidirectional model. + + Args: + model: OpenAI model identifier (default: gpt-realtime). + api_key: OpenAI API key for authentication. + **config: Additional configuration (organization, project, session params). + """ + # Model configuration + self.model = model + self.api_key = api_key self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True - self._event_queue = asyncio.Queue() + import os + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + # Connection state (initialized in connect()) + self.websocket = None + self.session_id = None + self._active = False + + self._event_queue = None self._response_task = None self._function_call_buffer = {} - logger.debug("OpenAI Realtime session initialized: %s", self.session_id) - - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active - - def _create_text_event(self, text: str, role: str) -> dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} - - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} - - + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - async def initialize( + async def connect( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, + **kwargs, ) -> None: - """Initialize session with configuration.""" + """Establish bidirectional connection to OpenAI Realtime API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + logger.info("Creating OpenAI Realtime connection...") + try: + # Initialize connection state + self.session_id = str(uuid.uuid4()) + self._active = True + self._event_queue = asyncio.Queue() + self._function_call_buffer = {} + + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + self.websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + # Configure session session_config = self._build_session_config(system_prompt, tools) await self._send_event({"type": "session.update", "session": session_config}) + # Add conversation history if provided if messages: await self._add_conversation_history(messages) + # Start background response processor self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime session initialized successfully") + logger.info("OpenAI Realtime connection established") except Exception as e: - logger.error("Error during OpenAI Realtime initialization: %s", e) + logger.error("OpenAI connection error: %s", e) raise + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: """Build session configuration for OpenAI Realtime API.""" config = DEFAULT_SESSION_CONFIG.copy() @@ -201,11 +257,11 @@ async def _process_responses(self) -> None: self._active = False logger.debug("OpenAI Realtime response processor stopped") - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive OpenAI events and convert to Strands format.""" connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + "metadata": {"provider": "openai_realtime", "model": self.model}, } yield {"BidirectionalConnectionStart": connection_start} @@ -366,19 +422,44 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] logger.debug("Unhandled OpenAI event type: %s", event_type) return None - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to OpenAI for processing.""" + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + """ if not self._require_active(): return + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent - not supported by OpenAI Realtime yet + logger.warning("Image input not supported by OpenAI Realtime API") + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio content to OpenAI for processing.""" audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - async def send_text_content(self, text: str) -> None: - """Send text content to OpenAI for processing.""" - if not self._require_active(): - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content to OpenAI for processing.""" item_data = { "type": "message", "role": "user", @@ -387,20 +468,26 @@ async def send_text_content(self, text: str) -> None: await self._send_event({"type": "conversation.item.create", "item": item_data}) await self._send_event({"type": "response.create"}) - async def send_interrupt(self) -> None: - """Send interruption signal to OpenAI.""" - if not self._require_active(): - return - + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to OpenAI.""" await self._send_event({"type": "response.cancel"}) - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result back to OpenAI.""" - if not self._require_active(): - return + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result back to OpenAI.""" + tool_use_id = tool_result.get("toolUseId") logger.debug("OpenAI tool result send: %s", tool_use_id) - result_text = json.dumps(result) if not isinstance(result, str) else result + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = block["text"] + break + + result_text = json.dumps(result_data) if not isinstance(result_data, str) else result_data item_data = { "type": "function_call_output", @@ -443,60 +530,3 @@ async def _send_event(self, event: dict[str, any]) -> None: raise -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """OpenAI Realtime API provider for Strands bidirectional streaming. - - Provides real-time audio/text communication through OpenAI's Realtime API - with WebSocket connections, voice activity detection, and function calling. - """ - - def __init__( - self, - model: str = DEFAULT_MODEL, - api_key: str | None = None, - **config: any - ) -> None: - """Initialize OpenAI Realtime bidirectional model.""" - self.model = model - self.api_key = api_key - self.config = config - - import os - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") - - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create bidirectional connection to OpenAI Realtime API.""" - logger.info("Creating OpenAI Realtime connection...") - - try: - url = f"{OPENAI_REALTIME_URL}?model={self.model}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) - - websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") - - session = OpenAIRealtimeSession(websocket, self.config) - await session.initialize(system_prompt, tools, messages) - - logger.info("OpenAI Realtime connection established") - return session - - except Exception as e: - logger.error("OpenAI connection error: %s", e) - raise \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 4b215d74e..145710c3c 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -88,6 +88,20 @@ class ImageInputEvent(TypedDict): encoding: Literal["base64", "raw"] +class TextInputEvent(TypedDict): + """Text input event for sending text to the model. + + Used for sending text content through the send() method. + + Attributes: + text: The text content to send to the model. + role: The role of the message sender (typically "user"). + """ + + text: str + role: Role + + class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. diff --git a/tests/strands/experimental/__init__.py b/tests/strands/experimental/__init__.py index e69de29bb..ac8db1d74 100644 --- a/tests/strands/experimental/__init__.py +++ b/tests/strands/experimental/__init__.py @@ -0,0 +1 @@ +"""Experimental features tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/__init__.py b/tests/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..ea37091cc --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/__init__.py b/tests/strands/experimental/bidirectional_streaming/models/__init__.py new file mode 100644 index 000000000..ea9fbb2d0 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming model tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py new file mode 100644 index 000000000..a5baaa522 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -0,0 +1,500 @@ +"""Unit tests for Gemini Live bidirectional streaming model. + +Tests the unified GeminiLiveBidirectionalModel interface including: +- Model initialization and configuration +- Connection establishment +- Unified send() method with different content types +- Event receiving and conversion +- Connection lifecycle management +""" + +import unittest.mock +import uuid + +import pytest +from google import genai +from google.genai import types as genai_types + +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + AudioInputEvent, + ImageInputEvent, + TextInputEvent, +) +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_genai_client(): + """Mock the Google GenAI client.""" + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client") as mock_client_cls: + mock_client = mock_client_cls.return_value + mock_client.aio = unittest.mock.MagicMock() + + # Mock the live session + mock_live_session = unittest.mock.AsyncMock() + + # Mock the context manager + mock_live_session_cm = unittest.mock.MagicMock() + mock_live_session_cm.__aenter__ = unittest.mock.AsyncMock(return_value=mock_live_session) + mock_live_session_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) + + # Make connect return the context manager + mock_client.aio.live.connect = unittest.mock.MagicMock(return_value=mock_live_session_cm) + + yield mock_client, mock_live_session, mock_live_session_cm + + +@pytest.fixture +def model_id(): + return "models/gemini-2.0-flash-live-preview-04-09" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(mock_genai_client, model_id, api_key): + """Create a GeminiLiveBidirectionalModel instance.""" + _ = mock_genai_client + return GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_init_default_config(mock_genai_client): + """Test model initialization with default configuration.""" + _ = mock_genai_client + + model = GeminiLiveBidirectionalModel() + + assert model.model_id == "models/gemini-2.0-flash-live-preview-04-09" + assert model.api_key is None + assert model._active is False + assert model.live_session is None + + +def test_init_with_api_key(mock_genai_client, model_id, api_key): + """Test model initialization with API key.""" + mock_client, _, _ = mock_genai_client + + model = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + + assert model.model_id == model_id + assert model.api_key == api_key + + # Verify client was created with correct parameters + mock_client_cls = unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client").start() + GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + mock_client_cls.assert_called() + + +def test_init_with_custom_config(mock_genai_client, model_id): + """Test model initialization with custom configuration.""" + _ = mock_genai_client + + custom_config = {"temperature": 0.7, "top_p": 0.9} + model = GeminiLiveBidirectionalModel(model_id=model_id, **custom_config) + + assert model.config == custom_config + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connect_basic(mock_genai_client, model): + """Test basic connection establishment.""" + mock_client, mock_live_session, _ = mock_genai_client + + await model.connect() + + assert model._active is True + assert model.session_id is not None + assert model.live_session == mock_live_session + mock_client.aio.live.connect.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_with_system_prompt(mock_genai_client, model, system_prompt): + """Test connection with system prompt.""" + mock_client, _, _ = mock_genai_client + + await model.connect(system_prompt=system_prompt) + + # Verify system prompt was included in config + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert config.get("system_instruction") == system_prompt + + +@pytest.mark.asyncio +async def test_connect_with_tools(mock_genai_client, model, tool_spec): + """Test connection with tools.""" + mock_client, _, _ = mock_genai_client + + await model.connect(tools=[tool_spec]) + + # Verify tools were formatted and included + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert "tools" in config + assert len(config["tools"]) > 0 + + +@pytest.mark.asyncio +async def test_connect_with_messages(mock_genai_client, model, messages): + """Test connection with message history.""" + _, mock_live_session, _ = mock_genai_client + + await model.connect(messages=messages) + + # Verify message history was sent + mock_live_session.send_client_content.assert_called() + + +@pytest.mark.asyncio +async def test_connect_error_handling(mock_genai_client, model): + """Test connection error handling.""" + mock_client, _, _ = mock_genai_client + mock_client.aio.live.connect.side_effect = Exception("Connection failed") + + with pytest.raises(Exception, match="Connection failed"): + await model.connect() + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_text_input(mock_genai_client, model): + """Test sending text input through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.connect() + + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + + # Verify text was sent via send_client_content + mock_live_session.send_client_content.assert_called_once() + call_args = mock_live_session.send_client_content.call_args + content = call_args.kwargs.get("turns") + assert content.role == "user" + assert content.parts[0].text == "Hello" + + +@pytest.mark.asyncio +async def test_send_audio_input(mock_genai_client, model): + """Test sending audio input through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.connect() + + audio_input: AudioInputEvent = { + "audioData": b"audio_bytes", + "format": "pcm", + "sampleRate": 16000, + "channels": 1, + } + await model.send(audio_input) + + # Verify audio was sent via send_realtime_input + mock_live_session.send_realtime_input.assert_called_once() + + +@pytest.mark.asyncio +async def test_send_image_input(mock_genai_client, model): + """Test sending image input through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.connect() + + image_input: ImageInputEvent = { + "imageData": b"image_bytes", + "mimeType": "image/jpeg", + "encoding": "raw", + } + await model.send(image_input) + + # Verify image was sent + mock_live_session.send.assert_called_once() + + +@pytest.mark.asyncio +async def test_send_tool_result(mock_genai_client, model): + """Test sending tool result through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.connect() + + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(tool_result) + + # Verify tool result was sent + mock_live_session.send_tool_response.assert_called_once() + + +@pytest.mark.asyncio +async def test_send_when_inactive(mock_genai_client, model): + """Test that send() does nothing when connection is inactive.""" + _, mock_live_session, _ = mock_genai_client + + # Don't connect, so _active is False + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + + # Verify nothing was sent + mock_live_session.send_client_content.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_unknown_content_type(mock_genai_client, model): + """Test sending unknown content type logs warning.""" + _, _, _ = mock_genai_client + await model.connect() + + unknown_content = {"unknown_field": "value"} + + # Should not raise, just log warning + await model.send(unknown_content) + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_connection_start_event(mock_genai_client, model, agenerator): + """Test that receive() emits connection start event.""" + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive.return_value = agenerator([]) + + await model.connect() + + # Get first event + receive_gen = model.receive() + first_event = await anext(receive_gen) + + # First event should be connection start + assert "BidirectionalConnectionStart" in first_event + assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id + + # Close to stop the loop + await model.close() + + +@pytest.mark.asyncio +async def test_receive_connection_end_event(mock_genai_client, model, agenerator): + """Test that receive() emits connection end event.""" + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive.return_value = agenerator([]) + + await model.connect() + + # Collect events until connection ends + events = [] + async for event in model.receive(): + events.append(event) + # Close after first event to trigger connection end + if len(events) == 1: + await model.close() + + # Last event should be connection end + assert "BidirectionalConnectionEnd" in events[-1] + + +@pytest.mark.asyncio +async def test_receive_text_output(mock_genai_client, model): + """Test receiving text output from model.""" + _, mock_live_session, _ = mock_genai_client + + mock_message = unittest.mock.Mock() + mock_message.text = "Hello from Gemini" + mock_message.data = None + mock_message.tool_call = None + mock_message.server_content = None + + await model.connect() + + # Test the conversion method directly + converted_event = model._convert_gemini_live_event(mock_message) + + assert "textOutput" in converted_event + assert converted_event["textOutput"]["text"] == "Hello from Gemini" + assert converted_event["textOutput"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_receive_audio_output(mock_genai_client, model): + """Test receiving audio output from model.""" + _, mock_live_session, _ = mock_genai_client + + mock_message = unittest.mock.Mock() + mock_message.text = None + mock_message.data = b"audio_data" + mock_message.tool_call = None + mock_message.server_content = None + + await model.connect() + + # Test the conversion method directly + converted_event = model._convert_gemini_live_event(mock_message) + + assert "audioOutput" in converted_event + assert converted_event["audioOutput"]["audioData"] == b"audio_data" + assert converted_event["audioOutput"]["format"] == "pcm" + + +@pytest.mark.asyncio +async def test_receive_tool_call(mock_genai_client, model): + """Test receiving tool call from model.""" + _, mock_live_session, _ = mock_genai_client + + mock_func_call = unittest.mock.Mock() + mock_func_call.id = "tool-123" + mock_func_call.name = "calculator" + mock_func_call.args = {"expression": "2+2"} + + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function_calls = [mock_func_call] + + mock_message = unittest.mock.Mock() + mock_message.text = None + mock_message.data = None + mock_message.tool_call = mock_tool_call + mock_message.server_content = None + + await model.connect() + + # Test the conversion method directly + converted_event = model._convert_gemini_live_event(mock_message) + + assert "toolUse" in converted_event + assert converted_event["toolUse"]["toolUseId"] == "tool-123" + assert converted_event["toolUse"]["name"] == "calculator" + + +@pytest.mark.asyncio +async def test_receive_interruption(mock_genai_client, model): + """Test receiving interruption event.""" + _, mock_live_session, _ = mock_genai_client + + mock_server_content = unittest.mock.Mock() + mock_server_content.interrupted = True + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + + mock_message = unittest.mock.Mock() + mock_message.text = None + mock_message.data = None + mock_message.tool_call = None + mock_message.server_content = mock_server_content + + await model.connect() + + # Test the conversion method directly + converted_event = model._convert_gemini_live_event(mock_message) + + assert "interruptionDetected" in converted_event + assert converted_event["interruptionDetected"]["reason"] == "user_input" + + +# Close Method Tests + + +@pytest.mark.asyncio +async def test_close_connection(mock_genai_client, model): + """Test closing connection.""" + _, _, mock_live_session_cm = mock_genai_client + + await model.connect() + await model.close() + + assert model._active is False + mock_live_session_cm.__aexit__.assert_called_once() + + +@pytest.mark.asyncio +async def test_close_when_not_connected(mock_genai_client, model): + """Test closing when not connected does nothing.""" + _, _, mock_live_session_cm = mock_genai_client + + # Don't connect + await model.close() + + # Should not raise, and __aexit__ should not be called + mock_live_session_cm.__aexit__.assert_not_called() + + +@pytest.mark.asyncio +async def test_close_error_handling(mock_genai_client, model): + """Test close error handling.""" + _, _, mock_live_session_cm = mock_genai_client + mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") + + await model.connect() + + with pytest.raises(Exception, match="Close failed"): + await model.close() + + +# Helper Method Tests + + +def test_build_live_config_basic(model): + """Test building basic live config.""" + config = model._build_live_config() + + assert isinstance(config, dict) + + +def test_build_live_config_with_system_prompt(model, system_prompt): + """Test building config with system prompt.""" + config = model._build_live_config(system_prompt=system_prompt) + + assert config["system_instruction"] == system_prompt + + +def test_build_live_config_with_tools(model, tool_spec): + """Test building config with tools.""" + config = model._build_live_config(tools=[tool_spec]) + + assert "tools" in config + assert len(config["tools"]) > 0 + + +def test_format_tools_for_live_api(model, tool_spec): + """Test tool formatting for Gemini Live API.""" + formatted_tools = model._format_tools_for_live_api([tool_spec]) + + assert len(formatted_tools) == 1 + assert isinstance(formatted_tools[0], genai_types.Tool) + + +def test_format_tools_empty_list(model): + """Test formatting empty tool list.""" + formatted_tools = model._format_tools_for_live_api([]) + + assert formatted_tools == [] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py new file mode 100644 index 000000000..451a98aa2 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -0,0 +1,551 @@ +"""Unit tests for Nova Sonic bidirectional model implementation. + +Tests the unified BidirectionalModel interface implementation for Amazon Nova Sonic, +covering connection lifecycle, event conversion, audio streaming, and tool execution. +""" + +import asyncio +import base64 +import json +import uuid +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +import pytest_asyncio + +from strands.experimental.bidirectional_streaming.models.novasonic import ( + NovaSonicBidirectionalModel, +) +from strands.types.tools import ToolResult, ToolSpec + + +# Test fixtures +@pytest.fixture +def model_id(): + """Nova Sonic model identifier.""" + return "amazon.nova-sonic-v1:0" + + +@pytest.fixture +def region(): + """AWS region.""" + return "us-east-1" + + +@pytest.fixture +def mock_stream(): + """Mock Nova Sonic bidirectional stream.""" + stream = AsyncMock() + stream.input_stream = AsyncMock() + stream.input_stream.send = AsyncMock() + stream.input_stream.close = AsyncMock() + stream.await_output = AsyncMock() + return stream + + +@pytest.fixture +def mock_client(mock_stream): + """Mock Bedrock Runtime client.""" + client = AsyncMock() + client.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) + return client + + +@pytest_asyncio.fixture +async def nova_model(model_id, region): + """Create Nova Sonic model instance.""" + model = NovaSonicBidirectionalModel(model_id=model_id, region=region) + yield model + # Cleanup + if model._active: + await model.close() + + +# Connection lifecycle tests +@pytest.mark.asyncio +async def test_model_initialization(model_id, region): + """Test model initialization with configuration.""" + model = NovaSonicBidirectionalModel(model_id=model_id, region=region) + + assert model.model_id == model_id + assert model.region == region + assert model.stream is None + assert not model._active + assert model.prompt_name is None + + +@pytest.mark.asyncio +async def test_connect_establishes_connection(nova_model, mock_client, mock_stream): + """Test that connect() establishes bidirectional connection.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + await nova_model.connect(system_prompt="Test system prompt") + + assert nova_model._active + assert nova_model.stream == mock_stream + assert nova_model.prompt_name is not None + assert mock_client.invoke_model_with_bidirectional_stream.called + + +@pytest.mark.asyncio +async def test_connect_sends_initialization_events(nova_model, mock_client, mock_stream): + """Test that connect() sends proper initialization sequence.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + system_prompt = "You are a helpful assistant" + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} + } + ] + + await nova_model.connect(system_prompt=system_prompt, tools=tools) + + # Verify initialization events were sent + assert mock_stream.input_stream.send.call_count >= 3 # connectionStart, promptStart, system prompt + + +@pytest.mark.asyncio +async def test_close_cleanup(nova_model, mock_client, mock_stream): + """Test that close() properly cleans up resources.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + await nova_model.connect() + await nova_model.close() + + assert not nova_model._active + assert mock_stream.input_stream.close.called + + +# Event conversion tests +@pytest.mark.asyncio +async def test_receive_emits_connection_start(nova_model, mock_client, mock_stream): + """Test that receive() emits connection start event.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + # Setup mock to return no events and then stop + async def mock_wait_for(*args, **kwargs): + await asyncio.sleep(0.1) + nova_model._active = False + raise asyncio.TimeoutError() + + with patch("asyncio.wait_for", side_effect=mock_wait_for): + await nova_model.connect() + + events = [] + async for event in nova_model.receive(): + events.append(event) + + # Should have connection start and end + assert len(events) >= 2 + assert "BidirectionalConnectionStart" in events[0] + assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name + + +@pytest.mark.asyncio +async def test_convert_audio_output_event(nova_model): + """Test conversion of Nova Sonic audio output to standard format.""" + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + + nova_event = { + "audioOutput": { + "content": audio_base64 + } + } + + result = nova_model._convert_nova_event(nova_event) + + assert result is not None + assert "audioOutput" in result + assert result["audioOutput"]["audioData"] == audio_bytes + assert result["audioOutput"]["format"] == "pcm" + assert result["audioOutput"]["sampleRate"] == 24000 + + +@pytest.mark.asyncio +async def test_convert_text_output_event(nova_model): + """Test conversion of Nova Sonic text output to standard format.""" + nova_event = { + "textOutput": { + "content": "Hello, world!", + "role": "ASSISTANT" + } + } + + result = nova_model._convert_nova_event(nova_event) + + assert result is not None + assert "textOutput" in result + assert result["textOutput"]["text"] == "Hello, world!" + assert result["textOutput"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_convert_tool_use_event(nova_model): + """Test conversion of Nova Sonic tool use to standard format.""" + tool_input = {"location": "Seattle"} + nova_event = { + "toolUse": { + "toolUseId": "tool-123", + "toolName": "get_weather", + "content": json.dumps(tool_input) + } + } + + result = nova_model._convert_nova_event(nova_event) + + assert result is not None + assert "toolUse" in result + assert result["toolUse"]["toolUseId"] == "tool-123" + assert result["toolUse"]["name"] == "get_weather" + assert result["toolUse"]["input"] == tool_input + + +@pytest.mark.asyncio +async def test_convert_interruption_event(nova_model): + """Test conversion of Nova Sonic interruption to standard format.""" + nova_event = { + "stopReason": "INTERRUPTED" + } + + result = nova_model._convert_nova_event(nova_event) + + assert result is not None + assert "interruptionDetected" in result + assert result["interruptionDetected"]["reason"] == "user_input" + + +@pytest.mark.asyncio +async def test_convert_usage_metrics_event(nova_model): + """Test conversion of Nova Sonic usage event to standard format.""" + nova_event = { + "usageEvent": { + "totalTokens": 100, + "totalInputTokens": 40, + "totalOutputTokens": 60, + "details": { + "total": { + "output": { + "speechTokens": 30 + } + } + } + } + } + + result = nova_model._convert_nova_event(nova_event) + + assert result is not None + assert "usageMetrics" in result + assert result["usageMetrics"]["totalTokens"] == 100 + assert result["usageMetrics"]["inputTokens"] == 40 + assert result["usageMetrics"]["outputTokens"] == 60 + assert result["usageMetrics"]["audioTokens"] == 30 + + +@pytest.mark.asyncio +async def test_convert_content_start_tracks_role(nova_model): + """Test that contentStart events track role for subsequent text output.""" + nova_event = { + "contentStart": { + "role": "USER" + } + } + + result = nova_model._convert_nova_event(nova_event) + + # contentStart doesn't emit an event but stores role + assert result is None + assert nova_model._current_role == "USER" + + +# Send method tests +@pytest.mark.asyncio +async def test_send_text_content(nova_model, mock_client, mock_stream): + """Test sending text content through unified send() method.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + await nova_model.connect() + + text_event = { + "text": "Hello, Nova!", + "role": "user" + } + + await nova_model.send(text_event) + + # Should send contentStart, textInput, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + +@pytest.mark.asyncio +async def test_send_audio_content(nova_model, mock_client, mock_stream): + """Test sending audio content through unified send() method.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + await nova_model.connect() + + audio_event = { + "audioData": b"audio data", + "format": "pcm", + "sampleRate": 16000, + "channels": 1 + } + + await nova_model.send(audio_event) + + # Should start audio connection and send audio + assert nova_model.audio_connection_active + assert mock_stream.input_stream.send.called + + +@pytest.mark.asyncio +async def test_send_tool_result(nova_model, mock_client, mock_stream): + """Test sending tool result through unified send() method.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + await nova_model.connect() + + tool_result = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Weather is sunny"}] + } + + await nova_model.send(tool_result) + + # Should send contentStart, toolResult, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + +@pytest.mark.asyncio +async def test_send_image_content_not_supported(nova_model, mock_client, mock_stream, caplog): + """Test that image content logs warning (not supported by Nova Sonic).""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + await nova_model.connect() + + image_event = { + "imageData": b"image data", + "mimeType": "image/jpeg" + } + + await nova_model.send(image_event) + + # Should log warning about unsupported image input + assert any("not supported" in record.message.lower() for record in caplog.records) + + +# Audio streaming tests +@pytest.mark.asyncio +async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test audio connection start and end lifecycle.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + await nova_model.connect() + + # Start audio connection + await nova_model._start_audio_connection() + assert nova_model.audio_connection_active + + # End audio connection + await nova_model._end_audio_input() + assert not nova_model.audio_connection_active + + +@pytest.mark.asyncio +async def test_silence_detection_ends_audio(nova_model, mock_client, mock_stream): + """Test that silence detection automatically ends audio input.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + nova_model.silence_threshold = 0.1 # Short threshold for testing + + await nova_model.connect() + + # Send audio to start connection + audio_event = { + "audioData": b"audio data", + "format": "pcm", + "sampleRate": 16000, + "channels": 1 + } + + await nova_model.send(audio_event) + assert nova_model.audio_connection_active + + # Wait for silence detection + await asyncio.sleep(0.2) + + # Audio connection should be ended + assert not nova_model.audio_connection_active + + +# Tool configuration tests +@pytest.mark.asyncio +async def test_build_tool_configuration(nova_model): + """Test building tool configuration from tool specs.""" + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": { + "json": json.dumps({ + "type": "object", + "properties": { + "location": {"type": "string"} + } + }) + } + } + ] + + tool_config = nova_model._build_tool_configuration(tools) + + assert len(tool_config) == 1 + assert tool_config[0]["toolSpec"]["name"] == "get_weather" + assert tool_config[0]["toolSpec"]["description"] == "Get weather information" + assert "inputSchema" in tool_config[0]["toolSpec"] + + +# Event template tests +@pytest.mark.asyncio +async def test_get_connection_start_event(nova_model): + """Test connection start event generation.""" + event_json = nova_model._get_connection_start_event() + event = json.loads(event_json) + + assert "event" in event + assert "sessionStart" in event["event"] + assert "inferenceConfiguration" in event["event"]["sessionStart"] + + +@pytest.mark.asyncio +async def test_get_prompt_start_event(nova_model): + """Test prompt start event generation.""" + nova_model.prompt_name = "test-prompt" + + event_json = nova_model._get_prompt_start_event([]) + event = json.loads(event_json) + + assert "event" in event + assert "promptStart" in event["event"] + assert event["event"]["promptStart"]["promptName"] == "test-prompt" + + +@pytest.mark.asyncio +async def test_get_text_input_event(nova_model): + """Test text input event generation.""" + nova_model.prompt_name = "test-prompt" + content_name = "test-content" + + event_json = nova_model._get_text_input_event(content_name, "Hello") + event = json.loads(event_json) + + assert "event" in event + assert "textInput" in event["event"] + assert event["event"]["textInput"]["content"] == "Hello" + + +@pytest.mark.asyncio +async def test_get_tool_result_event(nova_model): + """Test tool result event generation.""" + nova_model.prompt_name = "test-prompt" + content_name = "test-content" + result = {"result": "Success"} + + event_json = nova_model._get_tool_result_event(content_name, result) + event = json.loads(event_json) + + assert "event" in event + assert "toolResult" in event["event"] + assert json.loads(event["event"]["toolResult"]["content"]) == result + + +# Error handling tests +@pytest.mark.asyncio +async def test_send_when_inactive(nova_model): + """Test that send() handles inactive connection gracefully.""" + text_event = { + "text": "Hello", + "role": "user" + } + + # Should not raise error when inactive + await nova_model.send(text_event) + + +@pytest.mark.asyncio +async def test_close_when_already_closed(nova_model): + """Test that close() handles already closed connection.""" + # Should not raise error when already inactive + await nova_model.close() + await nova_model.close() # Second call should be safe + + +@pytest.mark.asyncio +async def test_response_processor_handles_errors(nova_model, mock_client, mock_stream): + """Test that response processor handles errors gracefully.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + # Setup mock to raise error + async def mock_error(*args, **kwargs): + raise Exception("Test error") + + mock_stream.await_output.side_effect = mock_error + + await nova_model.connect() + + # Wait a bit for response processor to handle error + await asyncio.sleep(0.1) + + # Should still be able to close cleanly + await nova_model.close() + + +# Integration-style tests +@pytest.mark.asyncio +async def test_full_conversation_flow(nova_model, mock_client, mock_stream): + """Test a complete conversation flow with text and audio.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + # Connect + await nova_model.connect(system_prompt="You are helpful") + + # Send text + await nova_model.send({"text": "Hello", "role": "user"}) + + # Send audio + await nova_model.send({ + "audioData": b"audio", + "format": "pcm", + "sampleRate": 16000, + "channels": 1 + }) + + # Send tool result + await nova_model.send({ + "toolUseId": "tool-1", + "status": "success", + "content": [{"text": "Result"}] + }) + + # Close + await nova_model.close() + + # Verify all operations completed + assert not nova_model._active diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py new file mode 100644 index 000000000..be69929cd --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -0,0 +1,625 @@ +"""Unit tests for OpenAI Realtime bidirectional streaming model. + +Tests the unified OpenAIRealtimeBidirectionalModel interface including: +- Model initialization and configuration +- Connection establishment with WebSocket +- Unified send() method with different content types +- Event receiving and conversion +- Connection lifecycle management +- Background task management +""" + +import asyncio +import base64 +import json +import unittest.mock + +import pytest + +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel +from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + AudioInputEvent, + ImageInputEvent, + TextInputEvent, +) +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_websocket(): + """Mock WebSocket connection.""" + mock_ws = unittest.mock.AsyncMock() + mock_ws.send = unittest.mock.AsyncMock() + mock_ws.close = unittest.mock.AsyncMock() + return mock_ws + + +@pytest.fixture +def mock_websockets_connect(mock_websocket): + """Mock websockets.connect function.""" + async def async_connect(*args, **kwargs): + return mock_websocket + + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect: + mock_connect.side_effect = async_connect + yield mock_connect, mock_websocket + + +@pytest.fixture +def model_name(): + return "gpt-realtime" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(api_key, model_name): + """Create an OpenAIRealtimeBidirectionalModel instance.""" + return OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_init_default_config(): + """Test model initialization with default configuration.""" + model = OpenAIRealtimeBidirectionalModel(api_key="test-key") + + assert model.model == "gpt-realtime" + assert model.api_key == "test-key" + assert model._active is False + assert model.websocket is None + + +def test_init_with_api_key(api_key, model_name): + """Test model initialization with API key.""" + model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + + assert model.model == model_name + assert model.api_key == api_key + + +def test_init_with_custom_config(model_name, api_key): + """Test model initialization with custom configuration.""" + custom_config = {"organization": "org-123", "project": "proj-456"} + model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key, **custom_config) + + assert model.config == custom_config + + +def test_init_without_api_key_raises(): + """Test that initialization without API key raises error.""" + with unittest.mock.patch.dict("os.environ", {}, clear=True): + with pytest.raises(ValueError, match="OpenAI API key is required"): + OpenAIRealtimeBidirectionalModel() + + +def test_init_with_env_api_key(): + """Test initialization with API key from environment.""" + with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): + model = OpenAIRealtimeBidirectionalModel() + assert model.api_key == "env-key" + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connect_basic(mock_websockets_connect, model): + """Test basic connection establishment.""" + mock_connect, mock_ws = mock_websockets_connect + + await model.connect() + + assert model._active is True + assert model.session_id is not None + assert model.websocket == mock_ws + assert model._event_queue is not None + mock_connect.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_with_system_prompt(mock_websockets_connect, model, system_prompt): + """Test connection with system prompt.""" + _, mock_ws = mock_websockets_connect + + await model.connect(system_prompt=system_prompt) + + # Verify session.update was sent with system prompt + calls = mock_ws.send.call_args_list + session_update_call = None + for call in calls: + message = json.loads(call[0][0]) + if message.get("type") == "session.update": + session_update_call = message + break + + assert session_update_call is not None + assert session_update_call["session"]["instructions"] == system_prompt + + +@pytest.mark.asyncio +async def test_connect_with_tools(mock_websockets_connect, model, tool_spec): + """Test connection with tools.""" + _, mock_ws = mock_websockets_connect + + await model.connect(tools=[tool_spec]) + + # Verify tools were included in session config + calls = mock_ws.send.call_args_list + session_update_call = None + for call in calls: + message = json.loads(call[0][0]) + if message.get("type") == "session.update": + session_update_call = message + break + + assert session_update_call is not None + assert "tools" in session_update_call["session"] + + +@pytest.mark.asyncio +async def test_connect_with_messages(mock_websockets_connect, model, messages): + """Test connection with message history.""" + _, mock_ws = mock_websockets_connect + + await model.connect(messages=messages) + + # Verify conversation items were created + calls = mock_ws.send.call_args_list + item_create_calls = [ + json.loads(call[0][0]) for call in calls + if json.loads(call[0][0]).get("type") == "conversation.item.create" + ] + + assert len(item_create_calls) > 0 + + +@pytest.mark.asyncio +async def test_connect_error_handling(mock_websockets_connect, model): + """Test connection error handling.""" + mock_connect, _ = mock_websockets_connect + mock_connect.side_effect = Exception("Connection failed") + + with pytest.raises(Exception, match="Connection failed"): + await model.connect() + + +@pytest.mark.asyncio +async def test_connect_with_organization_header(mock_websockets_connect, api_key): + """Test connection includes organization header.""" + mock_connect, _ = mock_websockets_connect + + model = OpenAIRealtimeBidirectionalModel( + api_key=api_key, + organization="org-123" + ) + await model.connect() + + # Verify headers were passed + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("additional_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_text_input(mock_websockets_connect, model): + """Test sending text input through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + + # Verify conversation.item.create and response.create were sent + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + response_create = [m for m in messages if m.get("type") == "response.create"] + + assert len(item_create) > 0 + assert len(response_create) > 0 + + +@pytest.mark.asyncio +async def test_send_audio_input(mock_websockets_connect, model): + """Test sending audio input through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + audio_input: AudioInputEvent = { + "audioData": b"audio_bytes", + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + } + await model.send(audio_input) + + # Verify input_audio_buffer.append was sent + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + + audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] + assert len(audio_append) > 0 + + # Verify audio was base64 encoded + assert "audio" in audio_append[0] + decoded = base64.b64decode(audio_append[0]["audio"]) + assert decoded == b"audio_bytes" + + +@pytest.mark.asyncio +async def test_send_image_input(mock_websockets_connect, model): + """Test sending image input logs warning (not supported).""" + _, mock_ws = mock_websockets_connect + await model.connect() + + image_input: ImageInputEvent = { + "imageData": b"image_bytes", + "mimeType": "image/jpeg", + "encoding": "raw", + } + + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(image_input) + mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") + + +@pytest.mark.asyncio +async def test_send_tool_result(mock_websockets_connect, model): + """Test sending tool result through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(tool_result) + + # Verify function_call_output was created + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + assert len(item_create) > 0 + + # Verify it's a function_call_output + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-123" + + +@pytest.mark.asyncio +async def test_send_when_inactive(mock_websockets_connect, model): + """Test that send() does nothing when connection is inactive.""" + _, mock_ws = mock_websockets_connect + + # Don't connect, so _active is False + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + + # Verify nothing was sent + mock_ws.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_unknown_content_type(mock_websockets_connect, model): + """Test sending unknown content type logs warning.""" + _, _ = mock_websockets_connect + await model.connect() + + unknown_content = {"unknown_field": "value"} + + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(unknown_content) + # Should log warning about unknown content + assert mock_logger.warning.called + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_connection_start_event(mock_websockets_connect, model): + """Test that receive() emits connection start event.""" + _, _ = mock_websockets_connect + + await model.connect() + + # Get first event + receive_gen = model.receive() + first_event = await anext(receive_gen) + + # First event should be connection start + assert "BidirectionalConnectionStart" in first_event + assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id + + # Close to stop the loop + await model.close() + + +@pytest.mark.asyncio +async def test_receive_connection_end_event(mock_websockets_connect, model): + """Test that receive() emits connection end event.""" + _, _ = mock_websockets_connect + + await model.connect() + + # Collect events until connection ends + events = [] + async for event in model.receive(): + events.append(event) + # Close after first event to trigger connection end + if len(events) == 1: + await model.close() + + # Last event should be connection end + assert "BidirectionalConnectionEnd" in events[-1] + + +@pytest.mark.asyncio +async def test_receive_audio_output(mock_websockets_connect, model): + """Test receiving audio output from model.""" + _, _ = mock_websockets_connect + await model.connect() + + # Create mock OpenAI event + openai_event = { + "type": "response.output_audio.delta", + "delta": base64.b64encode(b"audio_data").decode() + } + + # Test conversion directly + converted_event = model._convert_openai_event(openai_event) + + assert "audioOutput" in converted_event + assert converted_event["audioOutput"]["audioData"] == b"audio_data" + assert converted_event["audioOutput"]["format"] == "pcm" + + +@pytest.mark.asyncio +async def test_receive_text_output(mock_websockets_connect, model): + """Test receiving text output from model.""" + _, _ = mock_websockets_connect + await model.connect() + + # Create mock OpenAI event + openai_event = { + "type": "response.output_text.delta", + "delta": "Hello from OpenAI" + } + + # Test conversion directly + converted_event = model._convert_openai_event(openai_event) + + assert "textOutput" in converted_event + assert converted_event["textOutput"]["text"] == "Hello from OpenAI" + assert converted_event["textOutput"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_receive_function_call(mock_websockets_connect, model): + """Test receiving function call from model.""" + _, _ = mock_websockets_connect + await model.connect() + + # Simulate function call sequence + # First: output_item.added with function name + item_added = { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "call_id": "call-123", + "name": "calculator" + } + } + model._convert_openai_event(item_added) + + # Second: function_call_arguments.delta + args_delta = { + "type": "response.function_call_arguments.delta", + "call_id": "call-123", + "delta": '{"expression": "2+2"}' + } + model._convert_openai_event(args_delta) + + # Third: function_call_arguments.done + args_done = { + "type": "response.function_call_arguments.done", + "call_id": "call-123" + } + converted_event = model._convert_openai_event(args_done) + + assert "toolUse" in converted_event + assert converted_event["toolUse"]["toolUseId"] == "call-123" + assert converted_event["toolUse"]["name"] == "calculator" + assert converted_event["toolUse"]["input"]["expression"] == "2+2" + + +@pytest.mark.asyncio +async def test_receive_voice_activity(mock_websockets_connect, model): + """Test receiving voice activity events.""" + _, _ = mock_websockets_connect + await model.connect() + + # Test speech started + speech_started = { + "type": "input_audio_buffer.speech_started" + } + converted_event = model._convert_openai_event(speech_started) + + assert "voiceActivity" in converted_event + assert converted_event["voiceActivity"]["activityType"] == "speech_started" + + +# Close Method Tests + + +@pytest.mark.asyncio +async def test_close_connection(mock_websockets_connect, model): + """Test closing connection.""" + _, mock_ws = mock_websockets_connect + + await model.connect() + await model.close() + + assert model._active is False + mock_ws.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_close_when_not_connected(mock_websockets_connect, model): + """Test closing when not connected does nothing.""" + _, mock_ws = mock_websockets_connect + + # Don't connect + await model.close() + + # Should not raise, and close should not be called + mock_ws.close.assert_not_called() + + +@pytest.mark.asyncio +async def test_close_error_handling(mock_websockets_connect, model): + """Test close error handling.""" + _, mock_ws = mock_websockets_connect + mock_ws.close.side_effect = Exception("Close failed") + + await model.connect() + + # Should not raise, just log warning + await model.close() + assert model._active is False + + +@pytest.mark.asyncio +async def test_close_cancels_response_task(mock_websockets_connect, model): + """Test that close cancels the background response task.""" + _, _ = mock_websockets_connect + + await model.connect() + + # Verify response task is running + assert model._response_task is not None + assert not model._response_task.done() + + await model.close() + + # Task should be cancelled + assert model._response_task.cancelled() or model._response_task.done() + + +# Helper Method Tests + + +def test_build_session_config_basic(model): + """Test building basic session config.""" + config = model._build_session_config(None, None) + + assert isinstance(config, dict) + assert "instructions" in config + assert "audio" in config + + +def test_build_session_config_with_system_prompt(model, system_prompt): + """Test building config with system prompt.""" + config = model._build_session_config(system_prompt, None) + + assert config["instructions"] == system_prompt + + +def test_build_session_config_with_tools(model, tool_spec): + """Test building config with tools.""" + config = model._build_session_config(None, [tool_spec]) + + assert "tools" in config + assert len(config["tools"]) > 0 + + +def test_convert_tools_to_openai_format(model, tool_spec): + """Test tool conversion to OpenAI format.""" + openai_tools = model._convert_tools_to_openai_format([tool_spec]) + + assert len(openai_tools) == 1 + assert openai_tools[0]["type"] == "function" + assert openai_tools[0]["name"] == "calculator" + assert openai_tools[0]["description"] == "Calculate mathematical expressions" + + +def test_convert_tools_empty_list(model): + """Test converting empty tool list.""" + openai_tools = model._convert_tools_to_openai_format([]) + + assert openai_tools == [] + + +@pytest.mark.asyncio +async def test_send_event(mock_websockets_connect, model): + """Test sending event to WebSocket.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + test_event = {"type": "test.event", "data": "test"} + await model._send_event(test_event) + + # Verify event was sent as JSON + calls = mock_ws.send.call_args_list + last_call = calls[-1] + sent_message = json.loads(last_call[0][0]) + + assert sent_message == test_event + + +def test_require_active(model): + """Test _require_active method.""" + assert model._require_active() is False + + model._active = True + assert model._require_active() is True + + +def test_create_text_event(model): + """Test creating text event.""" + event = model._create_text_event("Hello", "user") + + assert "textOutput" in event + assert event["textOutput"]["text"] == "Hello" + assert event["textOutput"]["role"] == "user" + + +def test_create_voice_activity_event(model): + """Test creating voice activity event.""" + event = model._create_voice_activity_event("speech_started") + + assert "voiceActivity" in event + assert event["voiceActivity"]["activityType"] == "speech_started" From 261e25fad104455447f8e4715820c16c8882456d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:05:38 +0100 Subject: [PATCH 026/242] fix: update bidirectional model docstrings --- .../models/bidirectional_model.py | 65 ++++++++++++------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 3af05e113..75a4ab5f0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -1,13 +1,15 @@ -"""Unified bidirectional streaming interface. +"""Bidirectional streaming model interface. -Single layer combining model and session abstractions for simpler implementation. -Providers implement this directly without separate model/session classes. +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. Features: -- Unified model interface (no separate session class) -- Real-time bidirectional communication +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) - Provider-agnostic event normalization -- Tool execution integration +- Support for audio, text, image, and tool result streaming """ import abc @@ -27,10 +29,11 @@ class BidirectionalModel(abc.ABC): - """Unified interface for bidirectional streaming models. + """Abstract base class for bidirectional streaming models. - Combines model configuration and session communication in a single abstraction. - Providers implement this directly without separate model/session classes. + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. """ @abc.abstractmethod @@ -41,48 +44,60 @@ async def connect( messages: Messages | None = None, **kwargs, ) -> None: - """Establish bidirectional connection with the model. + """Establish a persistent streaming connection with the model. - Initializes the connection state and prepares for real-time communication. - This replaces the old create_bidirectional_connection pattern. + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. + system_prompt: System instructions to configure model behavior. + tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. **kwargs: Provider-specific configuration options. """ raise NotImplementedError @abc.abstractmethod async def close(self) -> None: - """Close connection and cleanup resources. + """Close the streaming connection and release resources. - Terminates the active connection and releases any held resources. + Terminates the active bidirectional connection and cleans up any associated + resources such as network connections, buffers, or background tasks. After + calling close(), the model instance cannot be used until connect() is called again. """ raise NotImplementedError @abc.abstractmethod async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive events from the model in standardized format. + """Receive streaming events from the model. - Yields provider-agnostic events that can be processed uniformly - by the event loop. Converts provider-specific events to common format. + Continuously yields events from the model as they arrive over the connection. + Events are normalized to a provider-agnostic format for uniform processing. + This method should be called in a loop or async task to process model responses. + + The stream continues until the connection is closed or an error occurs. Yields: - BidirectionalStreamEvent: Standardized event dictionaries. + BidirectionalStreamEvent: Standardized event dictionaries containing + audio output, text responses, tool calls, or control signals. """ raise NotImplementedError @abc.abstractmethod async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Send structured content to the model. + """Send content to the model over the active connection. - Unified method for sending all types of content. Implementations should - dispatch to appropriate internal handlers based on content type. + Transmits user input or tool results to the model during an active streaming + session. Supports multiple content types including text, audio, images, and + tool execution results. Can be called multiple times during a conversation. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + content: The content to send. Must be one of: + - TextInputEvent: Text message from the user + - ImageInputEvent: Image data for visual understanding + - AudioInputEvent: Audio data for speech input + - ToolResult: Result from a tool execution Example: await model.send(TextInputEvent(text="Hello", role="user")) From a9d8c88376ffe2cf8ab43217fc71a0bdafb96d80 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:23:37 +0100 Subject: [PATCH 027/242] fix: remove base session references --- .../bidirectional_streaming/__init__.py | 7 ++-- .../bidirectional_streaming/agent/agent.py | 8 ++--- .../event_loop/bidirectional_event_loop.py | 35 +++++++++---------- .../models/__init__.py | 3 +- .../models/bidirectional_model.py | 4 --- 5 files changed, 24 insertions(+), 33 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index e31bc670e..d855ba038 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -3,8 +3,8 @@ # Main components - Primary user interface from .agent.agent import BidirectionalAgent -# Unified model interface (for custom implementations) -from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession +# Model interface (for custom implementations) +from .models.bidirectional_model import BidirectionalModel # Model providers - What users need to create models from .models.gemini_live import GeminiLiveBidirectionalModel @@ -44,7 +44,6 @@ "VoiceActivityEvent", "UsageMetricsEvent", - # Unified model interface + # Model interface "BidirectionalModel", - "BidirectionalModelSession", # Backwards compatibility alias ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 62528d472..c9d7292b8 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -379,15 +379,15 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> Non self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - # Create TextInputEvent for unified send() + # Create TextInputEvent for send() text_event = {"text": input_data, "role": "user"} - await self._session.model_session.send(text_event) + await self._session.model.send(text_event) elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input - already in AudioInputEvent format - await self._session.model_session.send(input_data) + await self._session.model.send(input_data) elif isinstance(input_data, dict) and "imageData" in input_data: # Handle image input - already in ImageInputEvent format - await self._session.model_session.send(input_data) + await self._session.model.send(input_data) else: raise ValueError( "Input must be either a string (text), AudioInputEvent " diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 521ebc0dd..d1d6e90b3 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -37,14 +37,14 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModel, agent: "BidirectionalAgent") -> None: - """Initialize session with model and agent reference. + def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> None: + """Initialize connection with model and agent reference. Args: - model_session: Bidirectional model instance (unified interface). + model: Bidirectional model instance. agent: BidirectionalAgent instance for tool registry access. """ - self.model_session = model_session + self.model = model self.agent = agent self.active = True @@ -78,16 +78,13 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec """ logger.debug("Starting bidirectional session - initializing model connection") - # Connect to model using unified interface + # Connect to model await agent.model.connect( system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) - - # Use the model directly (unified interface - no separate session) - model_session = agent.model - # Create session wrapper for background processing - session = BidirectionalConnection(model_session=model_session, agent=agent) + # Create connection wrapper for background processing + session = BidirectionalConnection(model=agent.model, agent=agent) # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization @@ -138,9 +135,9 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non if all_tasks: await asyncio.gather(*all_tasks, return_exceptions=True) - # Close model session - await session.model_session.close() - logger.debug("Session closed") + # Close model connection + await session.model.close() + logger.debug("Connection closed") async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: @@ -256,11 +253,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: events to standardized formats, and manages interruption detection. Args: - session: BidirectionalConnection containing model session. + session: BidirectionalConnection containing model. """ logger.debug("Model events processor started") try: - async for provider_event in session.model_session.receive(): + async for provider_event in session.model.receive(): if not session.active: break @@ -437,8 +434,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send result through unified send() method - await session.model_session.send(tool_result) + # Send result through send() method + await session.model.send(tool_result) logger.debug("Tool result sent: %s", tool_use_id) # Handle streaming events if needed later @@ -474,10 +471,10 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: "content": [{"text": f"Error: {str(e)}"}] } try: - await session.model_session.send(error_result) + await session.model.send(error_result) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) - pass # Session might be closed + pass # Connection might be closed diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index e2745310c..12fe6c271 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,13 +1,12 @@ """Bidirectional model interfaces and implementations.""" -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel from .gemini_live import GeminiLiveBidirectionalModel from .novasonic import NovaSonicBidirectionalModel from .openai import OpenAIRealtimeBidirectionalModel __all__ = [ "BidirectionalModel", - "BidirectionalModelSession", # Backwards compatibility alias "GeminiLiveBidirectionalModel", "NovaSonicBidirectionalModel", "OpenAIRealtimeBidirectionalModel", diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 75a4ab5f0..5b7091dcd 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -105,7 +105,3 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE await model.send(ToolResult(toolUseId="123", status="success", ...)) """ raise NotImplementedError - - -# Backwards compatibility alias - will be removed in future version -BidirectionalModelSession = BidirectionalModel From 17686d489679f4a3a61b7819841195bba5b02e6d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:38:20 +0100 Subject: [PATCH 028/242] feat: throw exceptions when connect is called on already active connection --- .../bidirectional_streaming/models/gemini_live.py | 11 +++++++---- .../bidirectional_streaming/models/novasonic.py | 3 +++ .../bidirectional_streaming/models/openai.py | 3 +++ .../models/test_gemini_live.py | 13 +++++++++++++ .../models/test_novasonic.py | 14 ++++++++++++++ .../models/test_openai_realtime.py | 13 +++++++++++++ 6 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 578de5a2b..dabd1174b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -1,11 +1,11 @@ """Gemini Live API bidirectional model provider using official Google GenAI SDK. -Implements the unified BidirectionalModel interface for Google's Gemini Live API using the +Implements the BidirectionalModel interface for Google's Gemini Live API using the official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: - Uses official google-genai SDK with native Live API support -- Unified model interface (no separate session class) +- Simplified session management with client.aio.live.connect() - Built-in tool integration and event handling - Automatic WebSocket connection management and error handling - Native support for audio/text streaming and interruption @@ -45,7 +45,7 @@ class GeminiLiveBidirectionalModel(BidirectionalModel): - """Unified Gemini Live API implementation using official Google GenAI SDK. + """Gemini Live API implementation using official Google GenAI SDK. Combines model configuration and connection state in a single class. Provides a clean interface to Gemini Live API using the official SDK, @@ -101,6 +101,9 @@ async def connect( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + try: # Initialize connection state self.session_id = str(uuid.uuid4()) @@ -277,7 +280,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic return None async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Unified send method for all content types. + """Unified send method for all content types. Sends the given inputs to Google Live API Dispatches to appropriate internal handler based on content type. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 62b53a127..ee1bcb573 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -138,6 +138,9 @@ async def connect( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + logger.debug("Nova connection create - starting") try: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0208ee162..b62d4fa02 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -117,6 +117,9 @@ async def connect( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + logger.info("Creating OpenAI Realtime connection...") try: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index a5baaa522..de8fcfd56 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -185,6 +185,19 @@ async def test_connect_error_handling(mock_genai_client, model): await model.connect() +@pytest.mark.asyncio +async def test_connect_when_already_active(mock_genai_client, model): + """Test that connect() raises exception when already active.""" + mock_client, _, _ = mock_genai_client + + # First connection + await model.connect() + + # Second connection attempt should raise + with pytest.raises(RuntimeError, match="Connection already active"): + await model.connect() + + # Send Method Tests diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 451a98aa2..59c762b3e 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -110,6 +110,20 @@ async def test_connect_sends_initialization_events(nova_model, mock_client, mock assert mock_stream.input_stream.send.call_count >= 3 # connectionStart, promptStart, system prompt +@pytest.mark.asyncio +async def test_connect_when_already_active(nova_model, mock_client, mock_stream): + """Test that connect() raises exception when already active.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + # First connection + await nova_model.connect() + + # Second connection attempt should raise + with pytest.raises(RuntimeError, match="Connection already active"): + await nova_model.connect() + + @pytest.mark.asyncio async def test_close_cleanup(nova_model, mock_client, mock_stream): """Test that close() properly cleans up resources.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index be69929cd..6183765ae 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -207,6 +207,19 @@ async def test_connect_error_handling(mock_websockets_connect, model): await model.connect() +@pytest.mark.asyncio +async def test_connect_when_already_active(mock_websockets_connect, model): + """Test that connect() raises exception when already active.""" + mock_connect, _ = mock_websockets_connect + + # First connection + await model.connect() + + # Second connection attempt should raise + with pytest.raises(RuntimeError, match="Connection already active"): + await model.connect() + + @pytest.mark.asyncio async def test_connect_with_organization_header(mock_websockets_connect, api_key): """Test connection includes organization header.""" From c2f88f75404a5807b308ff6fa31d1ed0b85c6d8f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:50:42 +0100 Subject: [PATCH 029/242] feat: Add explicit init params for gemini and openai to free kwargs --- .../models/gemini_live.py | 18 +++++++------ .../models/novasonic.py | 10 ++++--- .../bidirectional_streaming/models/openai.py | 26 ++++++++++++------- .../models/test_gemini_live.py | 6 ++--- .../models/test_openai_realtime.py | 11 +++++--- 5 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index dabd1174b..639328c64 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -56,19 +56,21 @@ def __init__( self, model_id: str = "models/gemini-2.0-flash-live-preview-04-09", api_key: Optional[str] = None, - **config + live_config: Optional[Dict[str, Any]] = None, + **kwargs ): """Initialize Gemini Live API bidirectional model. Args: model_id: Gemini Live model identifier. api_key: Google AI API key for authentication. - **config: Additional configuration. + live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). + **kwargs: Reserved for future parameters. """ # Model configuration self.model_id = model_id self.api_key = api_key - self.config = config + self.live_config = live_config or {} # Create Gemini client with proper API version client_kwargs = {} @@ -423,15 +425,15 @@ def _build_live_config( ) -> Dict[str, Any]: """Build LiveConnectConfig for the official SDK. - Simply passes through all config parameters from params, allowing users + Simply passes through all config parameters from live_config, allowing users to configure any Gemini Live API parameter directly. """ - # Start with user config from params + # Start with user-provided live_config config_dict = {} - if "params" in self.config: - config_dict.update(self.config["params"]) + if self.live_config: + config_dict.update(self.live_config) - # Override with any kwargs + # Override with any kwargs from connect() config_dict.update(kwargs) # Add system instruction if provided diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ee1bcb573..5436b5ae7 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -88,18 +88,22 @@ class NovaSonicBidirectionalModel(BidirectionalModel): tool execution patterns while providing the standard BidirectionalModel interface. """ - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: + def __init__( + self, + model_id: str = "amazon.nova-sonic-v1:0", + region: str = "us-east-1", + **kwargs + ) -> None: """Initialize Nova Sonic bidirectional model. Args: model_id: Nova Sonic model identifier. region: AWS region. - **config: Additional configuration. + **kwargs: Reserved for future parameters. """ # Model configuration self.model_id = model_id self.region = region - self.config = config self._client = None # Connection state (initialized in connect()) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index b62d4fa02..8322eef4b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -71,19 +71,27 @@ def __init__( self, model: str = DEFAULT_MODEL, api_key: str | None = None, - **config: any + organization: str | None = None, + project: str | None = None, + session_config: dict[str, any] | None = None, + **kwargs ) -> None: """Initialize OpenAI Realtime bidirectional model. Args: model: OpenAI model identifier (default: gpt-realtime). api_key: OpenAI API key for authentication. - **config: Additional configuration (organization, project, session params). + organization: OpenAI organization ID for API requests. + project: OpenAI project ID for API requests. + session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). + **kwargs: Reserved for future parameters. """ # Model configuration self.model = model self.api_key = api_key - self.config = config + self.organization = organization + self.project = project + self.session_config = session_config or {} import os if not self.api_key: @@ -133,10 +141,10 @@ async def connect( url = f"{OPENAI_REALTIME_URL}?model={self.model}" headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) self.websocket = await websockets.connect(url, additional_headers=headers) logger.info("WebSocket connected successfully") @@ -181,14 +189,14 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] if tools: config["tools"] = self._convert_tools_to_openai_format(tools) - custom_config = self.config.get("session", {}) + # Apply user-provided session configuration supported_params = { "type", "output_modalities", "instructions", "voice", "audio", "tools", "tool_choice", "input_audio_format", "output_audio_format", "input_audio_transcription", "turn_detection" } - for key, value in custom_config.items(): + for key, value in self.session_config.items(): if key in supported_params: config[key] = value else: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index de8fcfd56..5dec7ca2d 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -115,10 +115,10 @@ def test_init_with_custom_config(mock_genai_client, model_id): """Test model initialization with custom configuration.""" _ = mock_genai_client - custom_config = {"temperature": 0.7, "top_p": 0.9} - model = GeminiLiveBidirectionalModel(model_id=model_id, **custom_config) + live_config = {"temperature": 0.7, "top_p": 0.9} + model = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) - assert model.config == custom_config + assert model.live_config == live_config # Connection Tests diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 6183765ae..ad0d3993a 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -103,10 +103,15 @@ def test_init_with_api_key(api_key, model_name): def test_init_with_custom_config(model_name, api_key): """Test model initialization with custom configuration.""" - custom_config = {"organization": "org-123", "project": "proj-456"} - model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key, **custom_config) + model = OpenAIRealtimeBidirectionalModel( + model=model_name, + api_key=api_key, + organization="org-123", + project="proj-456" + ) - assert model.config == custom_config + assert model.organization == "org-123" + assert model.project == "proj-456" def test_init_without_api_key_raises(): From 55554aac6338cb89be41690dd01cd8ef660020a3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 15:16:22 +0100 Subject: [PATCH 030/242] test: Consolidate bidi model tests --- .../models/test_gemini_live.py | 424 +++++-------- .../models/test_novasonic.py | 422 +++++-------- .../models/test_openai_realtime.py | 564 ++++++------------ 3 files changed, 481 insertions(+), 929 deletions(-) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 5dec7ca2d..b894509c9 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -2,14 +2,12 @@ Tests the unified GeminiLiveBidirectionalModel interface including: - Model initialization and configuration -- Connection establishment +- Connection establishment and lifecycle - Unified send() method with different content types - Event receiving and conversion -- Connection lifecycle management """ import unittest.mock -import uuid import pytest from google import genai @@ -84,146 +82,121 @@ def messages(): # Initialization Tests -def test_init_default_config(mock_genai_client): - """Test model initialization with default configuration.""" +def test_model_initialization(mock_genai_client, model_id, api_key): + """Test model initialization with various configurations.""" _ = mock_genai_client - model = GeminiLiveBidirectionalModel() + # Test default config + model_default = GeminiLiveBidirectionalModel() + assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09" + assert model_default.api_key is None + assert model_default._active is False + assert model_default.live_session is None - assert model.model_id == "models/gemini-2.0-flash-live-preview-04-09" - assert model.api_key is None - assert model._active is False - assert model.live_session is None - - -def test_init_with_api_key(mock_genai_client, model_id, api_key): - """Test model initialization with API key.""" - mock_client, _, _ = mock_genai_client - - model = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) - - assert model.model_id == model_id - assert model.api_key == api_key - - # Verify client was created with correct parameters - mock_client_cls = unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client").start() - GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) - mock_client_cls.assert_called() - - -def test_init_with_custom_config(mock_genai_client, model_id): - """Test model initialization with custom configuration.""" - _ = mock_genai_client + # Test with API key + model_with_key = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + assert model_with_key.model_id == model_id + assert model_with_key.api_key == api_key + # Test with custom config live_config = {"temperature": 0.7, "top_p": 0.9} - model = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) - - assert model.live_config == live_config + model_custom = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) + assert model_custom.live_config == live_config # Connection Tests @pytest.mark.asyncio -async def test_connect_basic(mock_genai_client, model): - """Test basic connection establishment.""" - mock_client, mock_live_session, _ = mock_genai_client +async def test_connection_lifecycle(mock_genai_client, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_client, mock_live_session, mock_live_session_cm = mock_genai_client + # Test basic connection await model.connect() - assert model._active is True assert model.session_id is not None assert model.live_session == mock_live_session mock_client.aio.live.connect.assert_called_once() - - -@pytest.mark.asyncio -async def test_connect_with_system_prompt(mock_genai_client, model, system_prompt): - """Test connection with system prompt.""" - mock_client, _, _ = mock_genai_client - await model.connect(system_prompt=system_prompt) + # Test close + await model.close() + assert model._active is False + mock_live_session_cm.__aexit__.assert_called_once() - # Verify system prompt was included in config + # Test connection with system prompt + await model.connect(system_prompt=system_prompt) call_args = mock_client.aio.live.connect.call_args config = call_args.kwargs.get("config", {}) assert config.get("system_instruction") == system_prompt - - -@pytest.mark.asyncio -async def test_connect_with_tools(mock_genai_client, model, tool_spec): - """Test connection with tools.""" - mock_client, _, _ = mock_genai_client + await model.close() + # Test connection with tools await model.connect(tools=[tool_spec]) - - # Verify tools were formatted and included call_args = mock_client.aio.live.connect.call_args config = call_args.kwargs.get("config", {}) assert "tools" in config assert len(config["tools"]) > 0 - - -@pytest.mark.asyncio -async def test_connect_with_messages(mock_genai_client, model, messages): - """Test connection with message history.""" - _, mock_live_session, _ = mock_genai_client + await model.close() + # Test connection with messages await model.connect(messages=messages) - - # Verify message history was sent mock_live_session.send_client_content.assert_called() + await model.close() @pytest.mark.asyncio -async def test_connect_error_handling(mock_genai_client, model): - """Test connection error handling.""" - mock_client, _, _ = mock_genai_client - mock_client.aio.live.connect.side_effect = Exception("Connection failed") +async def test_connection_edge_cases(mock_genai_client, api_key, model_id): + """Test connection error handling and edge cases.""" + mock_client, _, mock_live_session_cm = mock_genai_client + # Test connection error + model1 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await model.connect() - - -@pytest.mark.asyncio -async def test_connect_when_already_active(mock_genai_client, model): - """Test that connect() raises exception when already active.""" - mock_client, _, _ = mock_genai_client + await model1.connect() - # First connection - await model.connect() + # Reset mock for next tests + mock_client.aio.live.connect.side_effect = None - # Second connection attempt should raise + # Test double connection + model2 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): - await model.connect() + await model2.connect() + await model2.close() + + # Test close when not connected + model3 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + await model3.close() # Should not raise + + # Test close error handling + model4 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + await model4.connect() + mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") + with pytest.raises(Exception, match="Close failed"): + await model4.close() # Send Method Tests @pytest.mark.asyncio -async def test_send_text_input(mock_genai_client, model): - """Test sending text input through unified send() method.""" +async def test_send_all_content_types(mock_genai_client, model): + """Test sending all content types through unified send() method.""" _, mock_live_session, _ = mock_genai_client await model.connect() + # Test text input text_input: TextInputEvent = {"text": "Hello", "role": "user"} await model.send(text_input) - - # Verify text was sent via send_client_content mock_live_session.send_client_content.assert_called_once() call_args = mock_live_session.send_client_content.call_args content = call_args.kwargs.get("turns") assert content.role == "user" assert content.parts[0].text == "Hello" - - -@pytest.mark.asyncio -async def test_send_audio_input(mock_genai_client, model): - """Test sending audio input through unified send() method.""" - _, mock_live_session, _ = mock_genai_client - await model.connect() + # Test audio input audio_input: AudioInputEvent = { "audioData": b"audio_bytes", "format": "pcm", @@ -231,102 +204,59 @@ async def test_send_audio_input(mock_genai_client, model): "channels": 1, } await model.send(audio_input) - - # Verify audio was sent via send_realtime_input mock_live_session.send_realtime_input.assert_called_once() - - -@pytest.mark.asyncio -async def test_send_image_input(mock_genai_client, model): - """Test sending image input through unified send() method.""" - _, mock_live_session, _ = mock_genai_client - await model.connect() + # Test image input image_input: ImageInputEvent = { "imageData": b"image_bytes", "mimeType": "image/jpeg", "encoding": "raw", } await model.send(image_input) - - # Verify image was sent mock_live_session.send.assert_called_once() - - -@pytest.mark.asyncio -async def test_send_tool_result(mock_genai_client, model): - """Test sending tool result through unified send() method.""" - _, mock_live_session, _ = mock_genai_client - await model.connect() + # Test tool result tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Result: 42"}], } await model.send(tool_result) - - # Verify tool result was sent mock_live_session.send_tool_response.assert_called_once() + + await model.close() @pytest.mark.asyncio -async def test_send_when_inactive(mock_genai_client, model): - """Test that send() does nothing when connection is inactive.""" +async def test_send_edge_cases(mock_genai_client, model): + """Test send() edge cases and error handling.""" _, mock_live_session, _ = mock_genai_client - # Don't connect, so _active is False + # Test send when inactive text_input: TextInputEvent = {"text": "Hello", "role": "user"} await model.send(text_input) - - # Verify nothing was sent mock_live_session.send_client_content.assert_not_called() - - -@pytest.mark.asyncio -async def test_send_unknown_content_type(mock_genai_client, model): - """Test sending unknown content type logs warning.""" - _, _, _ = mock_genai_client - await model.connect() + # Test unknown content type + await model.connect() unknown_content = {"unknown_field": "value"} + await model.send(unknown_content) # Should not raise, just log warning - # Should not raise, just log warning - await model.send(unknown_content) + await model.close() # Receive Method Tests @pytest.mark.asyncio -async def test_receive_connection_start_event(mock_genai_client, model, agenerator): - """Test that receive() emits connection start event.""" +async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): + """Test that receive() emits connection start and end events.""" _, mock_live_session, _ = mock_genai_client mock_live_session.receive.return_value = agenerator([]) await model.connect() - # Get first event - receive_gen = model.receive() - first_event = await anext(receive_gen) - - # First event should be connection start - assert "BidirectionalConnectionStart" in first_event - assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id - - # Close to stop the loop - await model.close() - - -@pytest.mark.asyncio -async def test_receive_connection_end_event(mock_genai_client, model, agenerator): - """Test that receive() emits connection end event.""" - _, mock_live_session, _ = mock_genai_client - mock_live_session.receive.return_value = agenerator([]) - - await model.connect() - - # Collect events until connection ends + # Collect events events = [] async for event in model.receive(): events.append(event) @@ -334,57 +264,44 @@ async def test_receive_connection_end_event(mock_genai_client, model, agenerator if len(events) == 1: await model.close() - # Last event should be connection end + # Verify connection start and end + assert len(events) >= 2 + assert "BidirectionalConnectionStart" in events[0] + assert events[0]["BidirectionalConnectionStart"]["connectionId"] == model.session_id assert "BidirectionalConnectionEnd" in events[-1] @pytest.mark.asyncio -async def test_receive_text_output(mock_genai_client, model): - """Test receiving text output from model.""" - _, mock_live_session, _ = mock_genai_client - - mock_message = unittest.mock.Mock() - mock_message.text = "Hello from Gemini" - mock_message.data = None - mock_message.tool_call = None - mock_message.server_content = None - - await model.connect() - - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "textOutput" in converted_event - assert converted_event["textOutput"]["text"] == "Hello from Gemini" - assert converted_event["textOutput"]["role"] == "assistant" - - -@pytest.mark.asyncio -async def test_receive_audio_output(mock_genai_client, model): - """Test receiving audio output from model.""" - _, mock_live_session, _ = mock_genai_client - - mock_message = unittest.mock.Mock() - mock_message.text = None - mock_message.data = b"audio_data" - mock_message.tool_call = None - mock_message.server_content = None - +async def test_event_conversion(mock_genai_client, model): + """Test conversion of all Gemini Live event types to standard format.""" + _, _, _ = mock_genai_client await model.connect() - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "audioOutput" in converted_event - assert converted_event["audioOutput"]["audioData"] == b"audio_data" - assert converted_event["audioOutput"]["format"] == "pcm" - - -@pytest.mark.asyncio -async def test_receive_tool_call(mock_genai_client, model): - """Test receiving tool call from model.""" - _, mock_live_session, _ = mock_genai_client - + # Test text output + mock_text = unittest.mock.Mock() + mock_text.text = "Hello from Gemini" + mock_text.data = None + mock_text.tool_call = None + mock_text.server_content = None + + text_event = model._convert_gemini_live_event(mock_text) + assert "textOutput" in text_event + assert text_event["textOutput"]["text"] == "Hello from Gemini" + assert text_event["textOutput"]["role"] == "assistant" + + # Test audio output + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_event = model._convert_gemini_live_event(mock_audio) + assert "audioOutput" in audio_event + assert audio_event["audioOutput"]["audioData"] == b"audio_data" + assert audio_event["audioOutput"]["format"] == "pcm" + + # Test tool call mock_func_call = unittest.mock.Mock() mock_func_call.id = "tool-123" mock_func_call.name = "calculator" @@ -393,121 +310,62 @@ async def test_receive_tool_call(mock_genai_client, model): mock_tool_call = unittest.mock.Mock() mock_tool_call.function_calls = [mock_func_call] - mock_message = unittest.mock.Mock() - mock_message.text = None - mock_message.data = None - mock_message.tool_call = mock_tool_call - mock_message.server_content = None - - await model.connect() + mock_tool = unittest.mock.Mock() + mock_tool.text = None + mock_tool.data = None + mock_tool.tool_call = mock_tool_call + mock_tool.server_content = None - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "toolUse" in converted_event - assert converted_event["toolUse"]["toolUseId"] == "tool-123" - assert converted_event["toolUse"]["name"] == "calculator" - - -@pytest.mark.asyncio -async def test_receive_interruption(mock_genai_client, model): - """Test receiving interruption event.""" - _, mock_live_session, _ = mock_genai_client + tool_event = model._convert_gemini_live_event(mock_tool) + assert "toolUse" in tool_event + assert tool_event["toolUse"]["toolUseId"] == "tool-123" + assert tool_event["toolUse"]["name"] == "calculator" + # Test interruption mock_server_content = unittest.mock.Mock() mock_server_content.interrupted = True mock_server_content.input_transcription = None mock_server_content.output_transcription = None - mock_message = unittest.mock.Mock() - mock_message.text = None - mock_message.data = None - mock_message.tool_call = None - mock_message.server_content = mock_server_content + mock_interrupt = unittest.mock.Mock() + mock_interrupt.text = None + mock_interrupt.data = None + mock_interrupt.tool_call = None + mock_interrupt.server_content = mock_server_content - await model.connect() + interrupt_event = model._convert_gemini_live_event(mock_interrupt) + assert "interruptionDetected" in interrupt_event + assert interrupt_event["interruptionDetected"]["reason"] == "user_input" - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "interruptionDetected" in converted_event - assert converted_event["interruptionDetected"]["reason"] == "user_input" - - -# Close Method Tests - - -@pytest.mark.asyncio -async def test_close_connection(mock_genai_client, model): - """Test closing connection.""" - _, _, mock_live_session_cm = mock_genai_client - - await model.connect() - await model.close() - - assert model._active is False - mock_live_session_cm.__aexit__.assert_called_once() - - -@pytest.mark.asyncio -async def test_close_when_not_connected(mock_genai_client, model): - """Test closing when not connected does nothing.""" - _, _, mock_live_session_cm = mock_genai_client - - # Don't connect await model.close() - - # Should not raise, and __aexit__ should not be called - mock_live_session_cm.__aexit__.assert_not_called() - - -@pytest.mark.asyncio -async def test_close_error_handling(mock_genai_client, model): - """Test close error handling.""" - _, _, mock_live_session_cm = mock_genai_client - mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - - await model.connect() - - with pytest.raises(Exception, match="Close failed"): - await model.close() # Helper Method Tests -def test_build_live_config_basic(model): - """Test building basic live config.""" - config = model._build_live_config() +def test_config_building(model, system_prompt, tool_spec): + """Test building live config with various options.""" + # Test basic config + config_basic = model._build_live_config() + assert isinstance(config_basic, dict) - assert isinstance(config, dict) - - -def test_build_live_config_with_system_prompt(model, system_prompt): - """Test building config with system prompt.""" - config = model._build_live_config(system_prompt=system_prompt) + # Test with system prompt + config_prompt = model._build_live_config(system_prompt=system_prompt) + assert config_prompt["system_instruction"] == system_prompt - assert config["system_instruction"] == system_prompt + # Test with tools + config_tools = model._build_live_config(tools=[tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 -def test_build_live_config_with_tools(model, tool_spec): - """Test building config with tools.""" - config = model._build_live_config(tools=[tool_spec]) - - assert "tools" in config - assert len(config["tools"]) > 0 - - -def test_format_tools_for_live_api(model, tool_spec): +def test_tool_formatting(model, tool_spec): """Test tool formatting for Gemini Live API.""" + # Test with tools formatted_tools = model._format_tools_for_live_api([tool_spec]) - assert len(formatted_tools) == 1 assert isinstance(formatted_tools[0], genai_types.Tool) - - -def test_format_tools_empty_list(model): - """Test formatting empty tool list.""" - formatted_tools = model._format_tools_for_live_api([]) - assert formatted_tools == [] + # Test empty list + formatted_empty = model._format_tools_for_live_api([]) + assert formatted_empty == [] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 59c762b3e..10066a693 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -7,9 +7,7 @@ import asyncio import base64 import json -import uuid -from typing import Any, Dict -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio @@ -17,7 +15,7 @@ from strands.experimental.bidirectional_streaming.models.novasonic import ( NovaSonicBidirectionalModel, ) -from strands.types.tools import ToolResult, ToolSpec +from strands.types.tools import ToolResult # Test fixtures @@ -62,12 +60,14 @@ async def nova_model(model_id, region): await model.close() -# Connection lifecycle tests +# Initialization and Connection Tests + + @pytest.mark.asyncio async def test_model_initialization(model_id, region): """Test model initialization with configuration.""" model = NovaSonicBidirectionalModel(model_id=model_id, region=region) - + assert model.model_id == model_id assert model.region == region assert model.stream is None @@ -76,26 +76,24 @@ async def test_model_initialization(model_id, region): @pytest.mark.asyncio -async def test_connect_establishes_connection(nova_model, mock_client, mock_stream): - """Test that connect() establishes bidirectional connection.""" +async def test_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test complete connection lifecycle with various configurations.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + + # Test basic connection await nova_model.connect(system_prompt="Test system prompt") - assert nova_model._active assert nova_model.stream == mock_stream assert nova_model.prompt_name is not None assert mock_client.invoke_model_with_bidirectional_stream.called + # Test close + await nova_model.close() + assert not nova_model._active + assert mock_stream.input_stream.close.called -@pytest.mark.asyncio -async def test_connect_sends_initialization_events(nova_model, mock_client, mock_stream): - """Test that connect() sends proper initialization sequence.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - system_prompt = "You are a helpful assistant" + # Test connection with tools tools = [ { "name": "get_weather", @@ -103,108 +101,147 @@ async def test_connect_sends_initialization_events(nova_model, mock_client, mock "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} } ] - - await nova_model.connect(system_prompt=system_prompt, tools=tools) - - # Verify initialization events were sent - assert mock_stream.input_stream.send.call_count >= 3 # connectionStart, promptStart, system prompt + await nova_model.connect(system_prompt="You are helpful", tools=tools) + # Verify initialization events were sent (connectionStart, promptStart, system prompt) + assert mock_stream.input_stream.send.call_count >= 3 + await nova_model.close() @pytest.mark.asyncio -async def test_connect_when_already_active(nova_model, mock_client, mock_stream): - """Test that connect() raises exception when already active.""" +async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): + """Test connection error handling and edge cases.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - - # First connection + + # Test double connection await nova_model.connect() - - # Second connection attempt should raise with pytest.raises(RuntimeError, match="Connection already active"): await nova_model.connect() + await nova_model.close() + + # Test close when already closed + model2 = NovaSonicBidirectionalModel(model_id=model_id, region=region) + await model2.close() # Should not raise + await model2.close() # Second call should also be safe + + +# Send Method Tests @pytest.mark.asyncio -async def test_close_cleanup(nova_model, mock_client, mock_stream): - """Test that close() properly cleans up resources.""" +async def test_send_all_content_types(nova_model, mock_client, mock_stream): + """Test sending all content types through unified send() method.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + await nova_model.connect() + + # Test text content + text_event = {"text": "Hello, Nova!", "role": "user"} + await nova_model.send(text_event) + # Should send contentStart, textInput, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + # Test audio content + audio_event = { + "audioData": b"audio data", + "format": "pcm", + "sampleRate": 16000, + "channels": 1 + } + await nova_model.send(audio_event) + # Should start audio connection and send audio + assert nova_model.audio_connection_active + assert mock_stream.input_stream.send.called + + # Test tool result + tool_result = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Weather is sunny"}] + } + await nova_model.send(tool_result) + # Should send contentStart, toolResult, and contentEnd + assert mock_stream.input_stream.send.called + await nova_model.close() - - assert not nova_model._active - assert mock_stream.input_stream.close.called -# Event conversion tests @pytest.mark.asyncio -async def test_receive_emits_connection_start(nova_model, mock_client, mock_stream): - """Test that receive() emits connection start event.""" +async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): + """Test send() edge cases and error handling.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + + # Test send when inactive + text_event = {"text": "Hello", "role": "user"} + await nova_model.send(text_event) # Should not raise + + # Test image content (not supported) + await nova_model.connect() + image_event = { + "imageData": b"image data", + "mimeType": "image/jpeg" + } + await nova_model.send(image_event) + # Should log warning about unsupported image input + assert any("not supported" in record.message.lower() for record in caplog.records) + + await nova_model.close() + + +# Receive and Event Conversion Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): + """Test that receive() emits connection start and end events.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + # Setup mock to return no events and then stop async def mock_wait_for(*args, **kwargs): await asyncio.sleep(0.1) nova_model._active = False raise asyncio.TimeoutError() - + with patch("asyncio.wait_for", side_effect=mock_wait_for): await nova_model.connect() - + events = [] async for event in nova_model.receive(): events.append(event) - + # Should have connection start and end assert len(events) >= 2 assert "BidirectionalConnectionStart" in events[0] assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name + assert "BidirectionalConnectionEnd" in events[-1] @pytest.mark.asyncio -async def test_convert_audio_output_event(nova_model): - """Test conversion of Nova Sonic audio output to standard format.""" +async def test_event_conversion(nova_model): + """Test conversion of all Nova Sonic event types to standard format.""" + # Test audio output audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") - - nova_event = { - "audioOutput": { - "content": audio_base64 - } - } - + nova_event = {"audioOutput": {"content": audio_base64}} result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "audioOutput" in result assert result["audioOutput"]["audioData"] == audio_bytes assert result["audioOutput"]["format"] == "pcm" assert result["audioOutput"]["sampleRate"] == 24000 - -@pytest.mark.asyncio -async def test_convert_text_output_event(nova_model): - """Test conversion of Nova Sonic text output to standard format.""" - nova_event = { - "textOutput": { - "content": "Hello, world!", - "role": "ASSISTANT" - } - } - + # Test text output + nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "textOutput" in result assert result["textOutput"]["text"] == "Hello, world!" assert result["textOutput"]["role"] == "assistant" - -@pytest.mark.asyncio -async def test_convert_tool_use_event(nova_model): - """Test conversion of Nova Sonic tool use to standard format.""" + # Test tool use tool_input = {"location": "Seattle"} nova_event = { "toolUse": { @@ -213,33 +250,21 @@ async def test_convert_tool_use_event(nova_model): "content": json.dumps(tool_input) } } - result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "toolUse" in result assert result["toolUse"]["toolUseId"] == "tool-123" assert result["toolUse"]["name"] == "get_weather" assert result["toolUse"]["input"] == tool_input - -@pytest.mark.asyncio -async def test_convert_interruption_event(nova_model): - """Test conversion of Nova Sonic interruption to standard format.""" - nova_event = { - "stopReason": "INTERRUPTED" - } - + # Test interruption + nova_event = {"stopReason": "INTERRUPTED"} result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "interruptionDetected" in result assert result["interruptionDetected"]["reason"] == "user_input" - -@pytest.mark.asyncio -async def test_convert_usage_metrics_event(nova_model): - """Test conversion of Nova Sonic usage event to standard format.""" + # Test usage metrics nova_event = { "usageEvent": { "totalTokens": 100, @@ -254,9 +279,7 @@ async def test_convert_usage_metrics_event(nova_model): } } } - result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "usageMetrics" in result assert result["usageMetrics"]["totalTokens"] == 100 @@ -264,131 +287,44 @@ async def test_convert_usage_metrics_event(nova_model): assert result["usageMetrics"]["outputTokens"] == 60 assert result["usageMetrics"]["audioTokens"] == 30 - -@pytest.mark.asyncio -async def test_convert_content_start_tracks_role(nova_model): - """Test that contentStart events track role for subsequent text output.""" - nova_event = { - "contentStart": { - "role": "USER" - } - } - + # Test content start tracks role + nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) - - # contentStart doesn't emit an event but stores role - assert result is None + assert result is None # contentStart doesn't emit an event assert nova_model._current_role == "USER" -# Send method tests -@pytest.mark.asyncio -async def test_send_text_content(nova_model, mock_client, mock_stream): - """Test sending text content through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - text_event = { - "text": "Hello, Nova!", - "role": "user" - } - - await nova_model.send(text_event) - - # Should send contentStart, textInput, and contentEnd - assert mock_stream.input_stream.send.call_count >= 3 - - -@pytest.mark.asyncio -async def test_send_audio_content(nova_model, mock_client, mock_stream): - """Test sending audio content through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - audio_event = { - "audioData": b"audio data", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } - - await nova_model.send(audio_event) - - # Should start audio connection and send audio - assert nova_model.audio_connection_active - assert mock_stream.input_stream.send.called - - -@pytest.mark.asyncio -async def test_send_tool_result(nova_model, mock_client, mock_stream): - """Test sending tool result through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - tool_result = { - "toolUseId": "tool-123", - "status": "success", - "content": [{"text": "Weather is sunny"}] - } - - await nova_model.send(tool_result) - - # Should send contentStart, toolResult, and contentEnd - assert mock_stream.input_stream.send.call_count >= 3 - - -@pytest.mark.asyncio -async def test_send_image_content_not_supported(nova_model, mock_client, mock_stream, caplog): - """Test that image content logs warning (not supported by Nova Sonic).""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - image_event = { - "imageData": b"image data", - "mimeType": "image/jpeg" - } - - await nova_model.send(image_event) - - # Should log warning about unsupported image input - assert any("not supported" in record.message.lower() for record in caplog.records) +# Audio Streaming Tests -# Audio streaming tests @pytest.mark.asyncio async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): """Test audio connection start and end lifecycle.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + await nova_model.connect() - + # Start audio connection await nova_model._start_audio_connection() assert nova_model.audio_connection_active - + # End audio connection await nova_model._end_audio_input() assert not nova_model.audio_connection_active + await nova_model.close() + @pytest.mark.asyncio -async def test_silence_detection_ends_audio(nova_model, mock_client, mock_stream): +async def test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing - + await nova_model.connect() - + # Send audio to start connection audio_event = { "audioData": b"audio data", @@ -396,20 +332,24 @@ async def test_silence_detection_ends_audio(nova_model, mock_client, mock_stream "sampleRate": 16000, "channels": 1 } - + await nova_model.send(audio_event) assert nova_model.audio_connection_active - + # Wait for silence detection await asyncio.sleep(0.2) - + # Audio connection should be ended assert not nova_model.audio_connection_active + await nova_model.close() + + +# Helper Method Tests + -# Tool configuration tests @pytest.mark.asyncio -async def test_build_tool_configuration(nova_model): +async def test_tool_configuration(nova_model): """Test building tool configuration from tool specs.""" tools = [ { @@ -425,141 +365,69 @@ async def test_build_tool_configuration(nova_model): } } ] - + tool_config = nova_model._build_tool_configuration(tools) - + assert len(tool_config) == 1 assert tool_config[0]["toolSpec"]["name"] == "get_weather" assert tool_config[0]["toolSpec"]["description"] == "Get weather information" assert "inputSchema" in tool_config[0]["toolSpec"] -# Event template tests @pytest.mark.asyncio -async def test_get_connection_start_event(nova_model): - """Test connection start event generation.""" +async def test_event_templates(nova_model): + """Test event template generation.""" + # Test connection start event event_json = nova_model._get_connection_start_event() event = json.loads(event_json) - assert "event" in event assert "sessionStart" in event["event"] assert "inferenceConfiguration" in event["event"]["sessionStart"] - -@pytest.mark.asyncio -async def test_get_prompt_start_event(nova_model): - """Test prompt start event generation.""" + # Test prompt start event nova_model.prompt_name = "test-prompt" - event_json = nova_model._get_prompt_start_event([]) event = json.loads(event_json) - assert "event" in event assert "promptStart" in event["event"] assert event["event"]["promptStart"]["promptName"] == "test-prompt" - -@pytest.mark.asyncio -async def test_get_text_input_event(nova_model): - """Test text input event generation.""" - nova_model.prompt_name = "test-prompt" + # Test text input event content_name = "test-content" - event_json = nova_model._get_text_input_event(content_name, "Hello") event = json.loads(event_json) - assert "event" in event assert "textInput" in event["event"] assert event["event"]["textInput"]["content"] == "Hello" - -@pytest.mark.asyncio -async def test_get_tool_result_event(nova_model): - """Test tool result event generation.""" - nova_model.prompt_name = "test-prompt" - content_name = "test-content" + # Test tool result event result = {"result": "Success"} - event_json = nova_model._get_tool_result_event(content_name, result) event = json.loads(event_json) - assert "event" in event assert "toolResult" in event["event"] assert json.loads(event["event"]["toolResult"]["content"]) == result -# Error handling tests -@pytest.mark.asyncio -async def test_send_when_inactive(nova_model): - """Test that send() handles inactive connection gracefully.""" - text_event = { - "text": "Hello", - "role": "user" - } - - # Should not raise error when inactive - await nova_model.send(text_event) - - -@pytest.mark.asyncio -async def test_close_when_already_closed(nova_model): - """Test that close() handles already closed connection.""" - # Should not raise error when already inactive - await nova_model.close() - await nova_model.close() # Second call should be safe +# Error Handling Tests @pytest.mark.asyncio -async def test_response_processor_handles_errors(nova_model, mock_client, mock_stream): - """Test that response processor handles errors gracefully.""" +async def test_error_handling(nova_model, mock_client, mock_stream): + """Test error handling in various scenarios.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - - # Setup mock to raise error + + # Test response processor handles errors gracefully async def mock_error(*args, **kwargs): raise Exception("Test error") - + mock_stream.await_output.side_effect = mock_error - + await nova_model.connect() - + # Wait a bit for response processor to handle error await asyncio.sleep(0.1) - - # Should still be able to close cleanly - await nova_model.close() - -# Integration-style tests -@pytest.mark.asyncio -async def test_full_conversation_flow(nova_model, mock_client, mock_stream): - """Test a complete conversation flow with text and audio.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - # Connect - await nova_model.connect(system_prompt="You are helpful") - - # Send text - await nova_model.send({"text": "Hello", "role": "user"}) - - # Send audio - await nova_model.send({ - "audioData": b"audio", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - }) - - # Send tool result - await nova_model.send({ - "toolUseId": "tool-1", - "status": "success", - "content": [{"text": "Result"}] - }) - - # Close + # Should still be able to close cleanly await nova_model.close() - - # Verify all operations completed - assert not nova_model._active diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index ad0d3993a..1209150ba 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -6,7 +6,6 @@ - Unified send() method with different content types - Event receiving and conversion - Connection lifecycle management -- Background task management """ import asyncio @@ -39,7 +38,7 @@ def mock_websockets_connect(mock_websocket): """Mock websockets.connect function.""" async def async_connect(*args, **kwargs): return mock_websocket - + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect: mock_connect.side_effect = async_connect yield mock_connect, mock_websocket @@ -83,35 +82,34 @@ def messages(): # Initialization Tests -def test_init_default_config(): - """Test model initialization with default configuration.""" - model = OpenAIRealtimeBidirectionalModel(api_key="test-key") - - assert model.model == "gpt-realtime" - assert model.api_key == "test-key" - assert model._active is False - assert model.websocket is None - +def test_model_initialization(api_key, model_name): + """Test model initialization with various configurations.""" + # Test default config + model_default = OpenAIRealtimeBidirectionalModel(api_key="test-key") + assert model_default.model == "gpt-realtime" + assert model_default.api_key == "test-key" + assert model_default._active is False + assert model_default.websocket is None -def test_init_with_api_key(api_key, model_name): - """Test model initialization with API key.""" - model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) - - assert model.model == model_name - assert model.api_key == api_key + # Test with custom model + model_custom = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + assert model_custom.model == model_name + assert model_custom.api_key == api_key - -def test_init_with_custom_config(model_name, api_key): - """Test model initialization with custom configuration.""" - model = OpenAIRealtimeBidirectionalModel( + # Test with organization and project + model_org = OpenAIRealtimeBidirectionalModel( model=model_name, api_key=api_key, organization="org-123", project="proj-456" ) - - assert model.organization == "org-123" - assert model.project == "proj-456" + assert model_org.organization == "org-123" + assert model_org.project == "proj-456" + + # Test with env API key + with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): + model_env = OpenAIRealtimeBidirectionalModel() + assert model_env.api_key == "env-key" def test_init_without_api_key_raises(): @@ -121,158 +119,123 @@ def test_init_without_api_key_raises(): OpenAIRealtimeBidirectionalModel() -def test_init_with_env_api_key(): - """Test initialization with API key from environment.""" - with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): - model = OpenAIRealtimeBidirectionalModel() - assert model.api_key == "env-key" - - # Connection Tests @pytest.mark.asyncio -async def test_connect_basic(mock_websockets_connect, model): - """Test basic connection establishment.""" +async def test_connection_lifecycle(mock_websockets_connect, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" mock_connect, mock_ws = mock_websockets_connect - + + # Test basic connection await model.connect() - assert model._active is True assert model.session_id is not None assert model.websocket == mock_ws assert model._event_queue is not None + assert model._response_task is not None mock_connect.assert_called_once() + # Test close + await model.close() + assert model._active is False + mock_ws.close.assert_called_once() -@pytest.mark.asyncio -async def test_connect_with_system_prompt(mock_websockets_connect, model, system_prompt): - """Test connection with system prompt.""" - _, mock_ws = mock_websockets_connect - + # Test connection with system prompt await model.connect(system_prompt=system_prompt) - - # Verify session.update was sent with system prompt calls = mock_ws.send.call_args_list - session_update_call = None - for call in calls: - message = json.loads(call[0][0]) - if message.get("type") == "session.update": - session_update_call = message - break - - assert session_update_call is not None - assert session_update_call["session"]["instructions"] == system_prompt - + session_update = next( + (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), + None + ) + assert session_update is not None + assert system_prompt in session_update["session"]["instructions"] + await model.close() -@pytest.mark.asyncio -async def test_connect_with_tools(mock_websockets_connect, model, tool_spec): - """Test connection with tools.""" - _, mock_ws = mock_websockets_connect - + # Test connection with tools await model.connect(tools=[tool_spec]) - - # Verify tools were included in session config calls = mock_ws.send.call_args_list - session_update_call = None - for call in calls: - message = json.loads(call[0][0]) - if message.get("type") == "session.update": - session_update_call = message - break - - assert session_update_call is not None - assert "tools" in session_update_call["session"] - + # Tools are sent in a separate session.update after initial connection + session_updates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"] + assert len(session_updates) > 0 + # Check if any session update has tools + has_tools = any("tools" in update.get("session", {}) for update in session_updates) + assert has_tools + await model.close() -@pytest.mark.asyncio -async def test_connect_with_messages(mock_websockets_connect, model, messages): - """Test connection with message history.""" - _, mock_ws = mock_websockets_connect - + # Test connection with messages await model.connect(messages=messages) - - # Verify conversation items were created calls = mock_ws.send.call_args_list - item_create_calls = [ - json.loads(call[0][0]) for call in calls - if json.loads(call[0][0]).get("type") == "conversation.item.create" - ] - - assert len(item_create_calls) > 0 + item_creates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create"] + assert len(item_creates) > 0 + await model.close() + + # Test connection with organization header + model_org = OpenAIRealtimeBidirectionalModel(api_key="test-key", organization="org-123") + await model_org.connect() + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("additional_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + await model_org.close() @pytest.mark.asyncio -async def test_connect_error_handling(mock_websockets_connect, model): - """Test connection error handling.""" - mock_connect, _ = mock_websockets_connect +async def test_connection_edge_cases(mock_websockets_connect, api_key, model_name): + """Test connection error handling and edge cases.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test connection error + model1 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) mock_connect.side_effect = Exception("Connection failed") - with pytest.raises(Exception, match="Connection failed"): - await model.connect() + await model1.connect() + # Reset mock + async def async_connect(*args, **kwargs): + return mock_ws + mock_connect.side_effect = async_connect -@pytest.mark.asyncio -async def test_connect_when_already_active(mock_websockets_connect, model): - """Test that connect() raises exception when already active.""" - mock_connect, _ = mock_websockets_connect - - # First connection - await model.connect() - - # Second connection attempt should raise + # Test double connection + model2 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): - await model.connect() + await model2.connect() + await model2.close() + # Test close when not connected + model3 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + await model3.close() # Should not raise -@pytest.mark.asyncio -async def test_connect_with_organization_header(mock_websockets_connect, api_key): - """Test connection includes organization header.""" - mock_connect, _ = mock_websockets_connect - - model = OpenAIRealtimeBidirectionalModel( - api_key=api_key, - organization="org-123" - ) - await model.connect() - - # Verify headers were passed - call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("additional_headers", []) - org_header = [h for h in headers if h[0] == "OpenAI-Organization"] - assert len(org_header) == 1 - assert org_header[0][1] == "org-123" + # Test close error handling (should not raise, just log) + model4 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + await model4.connect() + mock_ws.close.side_effect = Exception("Close failed") + await model4.close() # Should not raise + assert model4._active is False # Send Method Tests @pytest.mark.asyncio -async def test_send_text_input(mock_websockets_connect, model): - """Test sending text input through unified send() method.""" +async def test_send_all_content_types(mock_websockets_connect, model): + """Test sending all content types through unified send() method.""" _, mock_ws = mock_websockets_connect await model.connect() - + + # Test text input text_input: TextInputEvent = {"text": "Hello", "role": "user"} await model.send(text_input) - - # Verify conversation.item.create and response.create were sent calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] - item_create = [m for m in messages if m.get("type") == "conversation.item.create"] response_create = [m for m in messages if m.get("type") == "response.create"] - assert len(item_create) > 0 assert len(response_create) > 0 - -@pytest.mark.asyncio -async def test_send_audio_input(mock_websockets_connect, model): - """Test sending audio input through unified send() method.""" - _, mock_ws = mock_websockets_connect - await model.connect() - + # Test audio input audio_input: AudioInputEvent = { "audioData": b"audio_bytes", "format": "pcm", @@ -280,179 +243,122 @@ async def test_send_audio_input(mock_websockets_connect, model): "channels": 1, } await model.send(audio_input) - - # Verify input_audio_buffer.append was sent calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] - audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] assert len(audio_append) > 0 - - # Verify audio was base64 encoded assert "audio" in audio_append[0] decoded = base64.b64decode(audio_append[0]["audio"]) assert decoded == b"audio_bytes" - -@pytest.mark.asyncio -async def test_send_image_input(mock_websockets_connect, model): - """Test sending image input logs warning (not supported).""" - _, mock_ws = mock_websockets_connect - await model.connect() - - image_input: ImageInputEvent = { - "imageData": b"image_bytes", - "mimeType": "image/jpeg", - "encoding": "raw", - } - - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: - await model.send(image_input) - mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") - - -@pytest.mark.asyncio -async def test_send_tool_result(mock_websockets_connect, model): - """Test sending tool result through unified send() method.""" - _, mock_ws = mock_websockets_connect - await model.connect() - + # Test tool result tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Result: 42"}], } await model.send(tool_result) - - # Verify function_call_output was created calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] - item_create = [m for m in messages if m.get("type") == "conversation.item.create"] assert len(item_create) > 0 - - # Verify it's a function_call_output item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "tool-123" + await model.close() + @pytest.mark.asyncio -async def test_send_when_inactive(mock_websockets_connect, model): - """Test that send() does nothing when connection is inactive.""" +async def test_send_edge_cases(mock_websockets_connect, model): + """Test send() edge cases and error handling.""" _, mock_ws = mock_websockets_connect - - # Don't connect, so _active is False + + # Test send when inactive text_input: TextInputEvent = {"text": "Hello", "role": "user"} await model.send(text_input) - - # Verify nothing was sent mock_ws.send.assert_not_called() - -@pytest.mark.asyncio -async def test_send_unknown_content_type(mock_websockets_connect, model): - """Test sending unknown content type logs warning.""" - _, _ = mock_websockets_connect + # Test image input (not supported) await model.connect() - + image_input: ImageInputEvent = { + "imageData": b"image_bytes", + "mimeType": "image/jpeg", + "encoding": "raw", + } + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(image_input) + mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") + + # Test unknown content type unknown_content = {"unknown_field": "value"} - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: await model.send(unknown_content) - # Should log warning about unknown content assert mock_logger.warning.called + await model.close() + # Receive Method Tests @pytest.mark.asyncio -async def test_receive_connection_start_event(mock_websockets_connect, model): - """Test that receive() emits connection start event.""" +async def test_receive_lifecycle_events(mock_websockets_connect, model): + """Test that receive() emits connection start and end events.""" _, _ = mock_websockets_connect - + await model.connect() - + # Get first event receive_gen = model.receive() first_event = await anext(receive_gen) - + # First event should be connection start assert "BidirectionalConnectionStart" in first_event assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id - - # Close to stop the loop + + # Close to trigger connection end await model.close() + # Collect remaining events + events = [first_event] + try: + async for event in receive_gen: + events.append(event) + except StopAsyncIteration: + pass -@pytest.mark.asyncio -async def test_receive_connection_end_event(mock_websockets_connect, model): - """Test that receive() emits connection end event.""" - _, _ = mock_websockets_connect - - await model.connect() - - # Collect events until connection ends - events = [] - async for event in model.receive(): - events.append(event) - # Close after first event to trigger connection end - if len(events) == 1: - await model.close() - # Last event should be connection end assert "BidirectionalConnectionEnd" in events[-1] @pytest.mark.asyncio -async def test_receive_audio_output(mock_websockets_connect, model): - """Test receiving audio output from model.""" +async def test_event_conversion(mock_websockets_connect, model): + """Test conversion of all OpenAI event types to standard format.""" _, _ = mock_websockets_connect await model.connect() - - # Create mock OpenAI event - openai_event = { + + # Test audio output + audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() } - - # Test conversion directly - converted_event = model._convert_openai_event(openai_event) - - assert "audioOutput" in converted_event - assert converted_event["audioOutput"]["audioData"] == b"audio_data" - assert converted_event["audioOutput"]["format"] == "pcm" - + converted = model._convert_openai_event(audio_event) + assert "audioOutput" in converted + assert converted["audioOutput"]["audioData"] == b"audio_data" + assert converted["audioOutput"]["format"] == "pcm" -@pytest.mark.asyncio -async def test_receive_text_output(mock_websockets_connect, model): - """Test receiving text output from model.""" - _, _ = mock_websockets_connect - await model.connect() - - # Create mock OpenAI event - openai_event = { + # Test text output + text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" } - - # Test conversion directly - converted_event = model._convert_openai_event(openai_event) - - assert "textOutput" in converted_event - assert converted_event["textOutput"]["text"] == "Hello from OpenAI" - assert converted_event["textOutput"]["role"] == "assistant" + converted = model._convert_openai_event(text_event) + assert "textOutput" in converted + assert converted["textOutput"]["text"] == "Hello from OpenAI" + assert converted["textOutput"]["role"] == "assistant" - -@pytest.mark.asyncio -async def test_receive_function_call(mock_websockets_connect, model): - """Test receiving function call from model.""" - _, _ = mock_websockets_connect - await model.connect() - - # Simulate function call sequence - # First: output_item.added with function name + # Test function call sequence item_added = { "type": "response.output_item.added", "item": { @@ -462,182 +368,102 @@ async def test_receive_function_call(mock_websockets_connect, model): } } model._convert_openai_event(item_added) - - # Second: function_call_arguments.delta + args_delta = { "type": "response.function_call_arguments.delta", "call_id": "call-123", "delta": '{"expression": "2+2"}' } model._convert_openai_event(args_delta) - - # Third: function_call_arguments.done + args_done = { "type": "response.function_call_arguments.done", "call_id": "call-123" } - converted_event = model._convert_openai_event(args_done) - - assert "toolUse" in converted_event - assert converted_event["toolUse"]["toolUseId"] == "call-123" - assert converted_event["toolUse"]["name"] == "calculator" - assert converted_event["toolUse"]["input"]["expression"] == "2+2" - + converted = model._convert_openai_event(args_done) + assert "toolUse" in converted + assert converted["toolUse"]["toolUseId"] == "call-123" + assert converted["toolUse"]["name"] == "calculator" + assert converted["toolUse"]["input"]["expression"] == "2+2" -@pytest.mark.asyncio -async def test_receive_voice_activity(mock_websockets_connect, model): - """Test receiving voice activity events.""" - _, _ = mock_websockets_connect - await model.connect() - - # Test speech started + # Test voice activity speech_started = { "type": "input_audio_buffer.speech_started" } - converted_event = model._convert_openai_event(speech_started) - - assert "voiceActivity" in converted_event - assert converted_event["voiceActivity"]["activityType"] == "speech_started" + converted = model._convert_openai_event(speech_started) + assert "voiceActivity" in converted + assert converted["voiceActivity"]["activityType"] == "speech_started" - -# Close Method Tests - - -@pytest.mark.asyncio -async def test_close_connection(mock_websockets_connect, model): - """Test closing connection.""" - _, mock_ws = mock_websockets_connect - - await model.connect() await model.close() - - assert model._active is False - mock_ws.close.assert_called_once() - - -@pytest.mark.asyncio -async def test_close_when_not_connected(mock_websockets_connect, model): - """Test closing when not connected does nothing.""" - _, mock_ws = mock_websockets_connect - - # Don't connect - await model.close() - - # Should not raise, and close should not be called - mock_ws.close.assert_not_called() - - -@pytest.mark.asyncio -async def test_close_error_handling(mock_websockets_connect, model): - """Test close error handling.""" - _, mock_ws = mock_websockets_connect - mock_ws.close.side_effect = Exception("Close failed") - - await model.connect() - - # Should not raise, just log warning - await model.close() - assert model._active is False - - -@pytest.mark.asyncio -async def test_close_cancels_response_task(mock_websockets_connect, model): - """Test that close cancels the background response task.""" - _, _ = mock_websockets_connect - - await model.connect() - - # Verify response task is running - assert model._response_task is not None - assert not model._response_task.done() - - await model.close() - - # Task should be cancelled - assert model._response_task.cancelled() or model._response_task.done() # Helper Method Tests -def test_build_session_config_basic(model): - """Test building basic session config.""" - config = model._build_session_config(None, None) - - assert isinstance(config, dict) - assert "instructions" in config - assert "audio" in config +def test_config_building(model, system_prompt, tool_spec): + """Test building session config with various options.""" + # Test basic config + config_basic = model._build_session_config(None, None) + assert isinstance(config_basic, dict) + assert "instructions" in config_basic + assert "audio" in config_basic + # Test with system prompt + config_prompt = model._build_session_config(system_prompt, None) + assert config_prompt["instructions"] == system_prompt -def test_build_session_config_with_system_prompt(model, system_prompt): - """Test building config with system prompt.""" - config = model._build_session_config(system_prompt, None) - - assert config["instructions"] == system_prompt + # Test with tools + config_tools = model._build_session_config(None, [tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 -def test_build_session_config_with_tools(model, tool_spec): - """Test building config with tools.""" - config = model._build_session_config(None, [tool_spec]) - - assert "tools" in config - assert len(config["tools"]) > 0 - - -def test_convert_tools_to_openai_format(model, tool_spec): +def test_tool_conversion(model, tool_spec): """Test tool conversion to OpenAI format.""" + # Test with tools openai_tools = model._convert_tools_to_openai_format([tool_spec]) - assert len(openai_tools) == 1 assert openai_tools[0]["type"] == "function" assert openai_tools[0]["name"] == "calculator" assert openai_tools[0]["description"] == "Calculate mathematical expressions" + # Test empty list + openai_empty = model._convert_tools_to_openai_format([]) + assert openai_empty == [] + + +def test_helper_methods(model): + """Test various helper methods.""" + # Test _require_active + assert model._require_active() is False + model._active = True + assert model._require_active() is True + model._active = False + + # Test _create_text_event + text_event = model._create_text_event("Hello", "user") + assert "textOutput" in text_event + assert text_event["textOutput"]["text"] == "Hello" + assert text_event["textOutput"]["role"] == "user" -def test_convert_tools_empty_list(model): - """Test converting empty tool list.""" - openai_tools = model._convert_tools_to_openai_format([]) - - assert openai_tools == [] + # Test _create_voice_activity_event + voice_event = model._create_voice_activity_event("speech_started") + assert "voiceActivity" in voice_event + assert voice_event["voiceActivity"]["activityType"] == "speech_started" @pytest.mark.asyncio -async def test_send_event(mock_websockets_connect, model): - """Test sending event to WebSocket.""" +async def test_send_event_helper(mock_websockets_connect, model): + """Test _send_event helper method.""" _, mock_ws = mock_websockets_connect await model.connect() - + test_event = {"type": "test.event", "data": "test"} await model._send_event(test_event) - - # Verify event was sent as JSON + calls = mock_ws.send.call_args_list last_call = calls[-1] sent_message = json.loads(last_call[0][0]) - assert sent_message == test_event - -def test_require_active(model): - """Test _require_active method.""" - assert model._require_active() is False - - model._active = True - assert model._require_active() is True - - -def test_create_text_event(model): - """Test creating text event.""" - event = model._create_text_event("Hello", "user") - - assert "textOutput" in event - assert event["textOutput"]["text"] == "Hello" - assert event["textOutput"]["role"] == "user" - - -def test_create_voice_activity_event(model): - """Test creating voice activity event.""" - event = model._create_voice_activity_event("speech_started") - - assert "voiceActivity" in event - assert event["voiceActivity"]["activityType"] == "speech_started" + await model.close() From 28fe471b099dfd149d08d4e8dd0e567e8916017b Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 15:18:12 +0100 Subject: [PATCH 031/242] fix: update comments --- .../bidirectional_streaming/models/novasonic.py | 8 +++----- .../experimental/bidirectional_streaming/models/openai.py | 6 ++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 5436b5ae7..b9c5060ba 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,11 +1,9 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -Implements the unified BidirectionalModel interface for Amazon's Nova Sonic, handling the +Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. -Unified model interface - combines configuration and connection state in single class. - Nova Sonic specifics: - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding @@ -81,7 +79,7 @@ class NovaSonicBidirectionalModel(BidirectionalModel): - """Unified Nova Sonic implementation for bidirectional streaming. + """Nova Sonic implementation for bidirectional streaming. Combines model configuration and connection state in a single class. Manages Nova Sonic's complex event sequencing, audio format conversion, and @@ -305,7 +303,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: yield {"BidirectionalConnectionEnd": connection_end} async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Unified send method for all content types. + """Unified send method for all content types. Sends the given content to Nova Sonic. Dispatches to appropriate internal handler based on content type. diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 8322eef4b..16f3ac4a3 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -2,8 +2,6 @@ Provides real-time audio and text communication through OpenAI's Realtime API with WebSocket connections, voice activity detection, and function calling. - -Unified model interface - combines configuration and connection state in single class. """ import asyncio @@ -60,7 +58,7 @@ class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """Unified OpenAI Realtime API implementation for bidirectional streaming. + """OpenAI Realtime API implementation for bidirectional streaming. Combines model configuration and connection state in a single class. Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, @@ -434,7 +432,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] return None async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Unified send method for all content types. + """Unified send method for all content types. Sends the given content to OpenAI. Dispatches to appropriate internal handler based on content type. From 60eb493319256f80a1a50d62ba1411219c719d1d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 15:38:10 +0100 Subject: [PATCH 032/242] fix: move import to top --- .../experimental/bidirectional_streaming/models/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 16f3ac4a3..a542ec894 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -8,6 +8,7 @@ import base64 import json import logging +import os import uuid from typing import AsyncIterable, Union @@ -91,7 +92,6 @@ def __init__( self.project = project self.session_config = session_config or {} - import os if not self.api_key: self.api_key = os.getenv("OPENAI_API_KEY") if not self.api_key: From 70185c5a8c1ba3e7805751c266628141150e142f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 11:37:06 +0100 Subject: [PATCH 033/242] refactor: Update bidirectional event types --- .../bidirectional_streaming/__init__.py | 44 +- .../bidirectional_streaming/agent/agent.py | 13 +- .../models/bidirectional_model.py | 23 +- .../models/gemini_live.py | 149 +++-- .../models/novasonic.py | 54 +- .../bidirectional_streaming/models/openai.py | 57 +- .../bidirectional_streaming/types/__init__.py | 49 +- .../types/bidirectional_streaming.py | 564 +++++++++++++----- .../models/test_gemini_live.py | 66 +- .../models/test_novasonic.py | 54 +- .../models/test_openai_realtime.py | 30 +- 11 files changed, 706 insertions(+), 397 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index d855ba038..041359314 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -14,36 +14,46 @@ # Event types - For type hints and event handling from .types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalStreamEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, + InputEvent, + InterruptionEvent, + ModalityUsage, + MultimodalUsage, + OutputEvent, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - UsageMetricsEvent, - VoiceActivityEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) __all__ = [ # Main interface "BidirectionalAgent", - # Model providers "GeminiLiveBidirectionalModel", "NovaSonicBidirectionalModel", "OpenAIRealtimeBidirectionalModel", - - # Event types + # Input Event types + "TextInputEvent", "AudioInputEvent", - "AudioOutputEvent", "ImageInputEvent", - "TextInputEvent", - "TextOutputEvent", - "InterruptionDetectedEvent", - "BidirectionalStreamEvent", - "VoiceActivityEvent", - "UsageMetricsEvent", - + "InputEvent", + # Output Event types + "SessionStartEvent", + "TurnStartEvent", + "AudioStreamEvent", + "TranscriptStreamEvent", + "InterruptionEvent", + "TurnCompleteEvent", + "MultimodalUsage", + "ModalityUsage", + "SessionEndEvent", + "ErrorEvent", + "OutputEvent", # Model interface "BidirectionalModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index c9d7292b8..d74860222 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,7 +31,7 @@ from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent +from ..types.bidirectional_streaming import AudioInputEvent, ImageInputEvent, OutputEvent logger = logging.getLogger(__name__) @@ -395,19 +395,24 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> Non "(dict with imageData, mimeType, encoding)" ) - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[dict[str, Any]]: """Receive events from the model including audio, text, and tool calls. Yields model output events processed by background tasks including audio output, text responses, tool calls, and session updates. Yields: - BidirectionalStreamEvent: Events from the model session. + dict: Event dictionaries from the model session. Each event is a TypedEvent + converted to a dictionary for consistency with the standard Agent API. """ while self._session and self._session.active: try: event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) - yield event + # Convert TypedEvent to dict for consistency with Agent.stream_async + if hasattr(event, 'as_dict'): + yield event.as_dict() + else: + yield event except asyncio.TimeoutError: continue diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 5b7091dcd..28a6f77ce 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -16,12 +16,13 @@ import logging from typing import AsyncIterable, Union +from ....types._events import ToolResultEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec from ..types.bidirectional_streaming import ( AudioInputEvent, - BidirectionalStreamEvent, ImageInputEvent, + OutputEvent, TextInputEvent, ) @@ -69,7 +70,7 @@ async def close(self) -> None: raise NotImplementedError @abc.abstractmethod - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[OutputEvent]: """Receive streaming events from the model. Continuously yields events from the model as they arrive over the connection. @@ -79,13 +80,16 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: The stream continues until the connection is closed or an error occurs. Yields: - BidirectionalStreamEvent: Standardized event dictionaries containing - audio output, text responses, tool calls, or control signals. + OutputEvent: Standardized event objects containing audio output, + transcripts, tool calls, or control signals. """ raise NotImplementedError @abc.abstractmethod - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Send content to the model over the active connection. Transmits user input or tool results to the model during an active streaming @@ -95,13 +99,14 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE Args: content: The content to send. Must be one of: - TextInputEvent: Text message from the user - - ImageInputEvent: Image data for visual understanding - AudioInputEvent: Audio data for speech input - - ToolResult: Result from a tool execution + - ImageInputEvent: Image data for visual understanding + - ToolResultEvent: Result from a tool execution Example: await model.send(TextInputEvent(text="Hello", role="user")) - await model.send(AudioInputEvent(audioData=bytes, format="pcm", ...)) - await model.send(ToolResult(toolUseId="123", status="success", ...)) + await model.send(AudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) + await model.send(ImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) + await model.send(ToolResultEvent(tool_result)) """ raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 639328c64..fe495f426 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -23,16 +23,19 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent from ..types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, + InterruptionEvent, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - TranscriptEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -158,12 +161,12 @@ async def _send_message_history(self, messages: Messages) -> None: async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" - # Emit connection start event - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.session_id, - "metadata": {"provider": "gemini_live", "model_id": self.model_id} - } - yield {"BidirectionalConnectionStart": connection_start} + # Emit session start event + yield SessionStartEvent( + session_id=self.session_id, + model=self.model_id, + capabilities=["audio", "tools", "images"] + ) try: # Wrap in while loop to restart after turn_complete (SDK limitation workaround) @@ -189,30 +192,23 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: except Exception as e: logger.error("Fatal error in receive loop: %s", e) + yield ErrorEvent(error=e) finally: - # Emit connection end event when exiting - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.session_id, - "reason": "connection_complete", - "metadata": {"provider": "gemini_live"} - } - yield {"BidirectionalConnectionEnd": connection_end} + # Emit session end event when exiting + yield SessionEndEvent(reason="complete") def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: """Convert Gemini Live API events to provider-agnostic format. - Handles different types of text output: - - inputTranscription: User's speech transcribed to text (emitted as transcript event) - - outputTranscription: Model's audio transcribed to text (emitted as transcript event) - - modelTurn text: Actual text response from the model (emitted as textOutput) + Handles different types of content: + - inputTranscription: User's speech transcribed to text + - outputTranscription: Model's audio transcribed to text + - modelTurn text: Text response from the model """ try: # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - return {"interruptionDetected": interruption} + return InterruptionEvent(reason="user_speech", turn_id=None) # Handle input transcription (user's speech) - emit as transcript event if message.server_content and message.server_content.input_transcription: @@ -221,12 +217,11 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if hasattr(input_transcript, 'text') and input_transcript.text: transcription_text = input_transcript.text logger.debug(f"Input transcription detected: {transcription_text}") - transcript: TranscriptEvent = { - "text": transcription_text, - "role": "user", - "type": "input" - } - return {"transcript": transcript} + return TranscriptStreamEvent( + text=transcription_text, + source="user", + is_final=True + ) # Handle output transcription (model's audio) - emit as transcript event if message.server_content and message.server_content.output_transcription: @@ -235,32 +230,29 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if hasattr(output_transcript, 'text') and output_transcript.text: transcription_text = output_transcript.text logger.debug(f"Output transcription detected: {transcription_text}") - transcript: TranscriptEvent = { - "text": transcription_text, - "role": "assistant", - "type": "output" - } - return {"transcript": transcript} + return TranscriptStreamEvent( + text=transcription_text, + source="assistant", + is_final=True + ) - # Handle actual text output from model (not transcription) - # The SDK's message.text property accesses modelTurn.parts[].text + # Handle text output from model if message.text: - text_output: TextOutputEvent = { - "text": message.text, - "role": "assistant" - } - return {"textOutput": text_output} + logger.debug(f"Text output as transcript: {message.text}") + return TranscriptStreamEvent( + text=message.text, + source="assistant", + is_final=True + ) # Handle audio output using SDK's built-in data property if message.data: - audio_output: AudioOutputEvent = { - "audioData": message.data, - "format": "pcm", - "sampleRate": GEMINI_OUTPUT_SAMPLE_RATE, - "channels": GEMINI_CHANNELS, - "encoding": "raw" - } - return {"audioOutput": audio_output} + return AudioStreamEvent( + audio=message.data, + format="pcm", + sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, + channels=GEMINI_CHANNELS + ) # Handle tool calls if message.tool_call and message.tool_call.function_calls: @@ -281,34 +273,33 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) return None - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Unified send method for all content types. Sends the given inputs to Google Live API Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). """ if not self._active: return try: - if isinstance(content, dict): - # Dispatch based on content structure - if "text" in content and "role" in content: - # TextInputEvent - await self._send_text_content(content["text"]) - elif "audioData" in content: - # AudioInputEvent - await self._send_audio_content(content) - elif "imageData" in content or "image_url" in content: - # ImageInputEvent - await self._send_image_content(content) - elif "toolUseId" in content and "status" in content: - # ToolResult - await self._send_tool_result(content) - else: - logger.warning(f"Unknown content type with keys: {content.keys()}") + if isinstance(content, TextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, AudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ImageInputEvent): + await self._send_image_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") except Exception as e: logger.error(f"Error sending content: {e}") @@ -321,7 +312,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: try: # Create audio blob for the SDK audio_blob = genai_types.Blob( - data=audio_input["audioData"], + data=audio_input.audio, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" ) @@ -339,19 +330,19 @@ async def _send_image_content(self, image_input: ImageInputEvent) -> None: """ try: # Prepare the message based on encoding - if image_input.get("encoding") == "base64": + if image_input.encoding == "base64": # Data is already base64 encoded - if isinstance(image_input["imageData"], bytes): - data_str = image_input["imageData"].decode() + if isinstance(image_input.image, bytes): + data_str = image_input.image.decode() else: - data_str = image_input["imageData"] + data_str = image_input.image else: # Raw bytes - need to base64 encode - data_str = base64.b64encode(image_input["imageData"]).decode() + data_str = base64.b64encode(image_input.image).decode() # Create the message in the format expected by Gemini Live msg = { - "mime_type": image_input["mimeType"], + "mime_type": image_input.mime_type, "data": data_str } diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index b9c5060ba..f66a71377 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -32,16 +32,20 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent from ..types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, + InterruptionEvent, + MultimodalUsage, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - UsageMetricsEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -302,34 +306,34 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: } yield {"BidirectionalConnectionEnd": connection_end} - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). """ if not self._active: return try: - if isinstance(content, dict): - # Dispatch based on content structure - if "text" in content and "role" in content: - # TextInputEvent - await self._send_text_content(content["text"]) - elif "audioData" in content: - # AudioInputEvent - await self._send_audio_content(content) - elif "imageData" in content or "image_url" in content: - # ImageInputEvent - not supported by Nova Sonic - logger.warning("Image input not supported by Nova Sonic") - elif "toolUseId" in content and "status" in content: - # ToolResult - await self._send_tool_result(content) - else: - logger.warning(f"Unknown content type with keys: {content.keys()}") + if isinstance(content, TextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, AudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ImageInputEvent): + # ImageInputEvent - not supported by Nova Sonic + logger.warning("Image input not supported by Nova Sonic") + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") except Exception as e: logger.error(f"Error sending content: {e}") @@ -370,7 +374,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task.cancel() # Convert audio to Nova Sonic base64 format - nova_audio_data = base64.b64encode(audio_input["audioData"]).decode("utf-8") + nova_audio_data = base64.b64encode(audio_input.audio).decode("utf-8") # Send audio input event audio_event = json.dumps( diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index a542ec894..ae7de4c83 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -18,16 +18,21 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent from ..types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, - BidirectionalStreamEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, + InterruptionEvent, + MultimodalUsage, + OutputEvent, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - VoiceActivityEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -266,7 +271,7 @@ async def _process_responses(self) -> None: self._active = False logger.debug("OpenAI Realtime response processor stopped") - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[OutputEvent]: """Receive OpenAI events and convert to Strands format.""" connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, @@ -431,40 +436,40 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] logger.debug("Unhandled OpenAI event type: %s", event_type) return None - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Unified send method for all content types. Sends the given content to OpenAI. Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). """ if not self._require_active(): return try: - if isinstance(content, dict): - # Dispatch based on content structure - if "text" in content and "role" in content: - # TextInputEvent - await self._send_text_content(content["text"]) - elif "audioData" in content: - # AudioInputEvent - await self._send_audio_content(content) - elif "imageData" in content or "image_url" in content: - # ImageInputEvent - not supported by OpenAI Realtime yet - logger.warning("Image input not supported by OpenAI Realtime API") - elif "toolUseId" in content and "status" in content: - # ToolResult - await self._send_tool_result(content) - else: - logger.warning(f"Unknown content type with keys: {content.keys()}") + if isinstance(content, TextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, AudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ImageInputEvent): + # ImageInputEvent - not supported by OpenAI Realtime yet + logger.warning("Image input not supported by OpenAI Realtime API") + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") except Exception as e: logger.error(f"Error sending content: {e}") async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" - audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + audio_base64 = base64.b64encode(audio_input.audio).decode("utf-8") await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) async def _send_text_content(self, text: str) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index d040ee436..52034db1b 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -2,38 +2,51 @@ from .bidirectional_streaming import ( DEFAULT_CHANNELS, + DEFAULT_FORMAT, DEFAULT_SAMPLE_RATE, SUPPORTED_AUDIO_FORMATS, SUPPORTED_CHANNELS, SUPPORTED_SAMPLE_RATES, AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, - BidirectionalStreamEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, - TextOutputEvent, - TranscriptEvent, - UsageMetricsEvent, - VoiceActivityEvent, + InputEvent, + InterruptionEvent, + ModalityUsage, + MultimodalUsage, + OutputEvent, + SessionEndEvent, + SessionStartEvent, + TextInputEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) __all__ = [ + # Input Events + "TextInputEvent", "AudioInputEvent", - "AudioOutputEvent", - "BidirectionalConnectionEndEvent", - "BidirectionalConnectionStartEvent", - "BidirectionalStreamEvent", "ImageInputEvent", - "InterruptionDetectedEvent", - "TextOutputEvent", - "TranscriptEvent", - "UsageMetricsEvent", - "VoiceActivityEvent", + "InputEvent", + # Output Events + "SessionStartEvent", + "TurnStartEvent", + "AudioStreamEvent", + "TranscriptStreamEvent", + "InterruptionEvent", + "TurnCompleteEvent", + "MultimodalUsage", + "ModalityUsage", + "SessionEndEvent", + "ErrorEvent", + "OutputEvent", + # Constants "SUPPORTED_AUDIO_FORMATS", "SUPPORTED_SAMPLE_RATES", "SUPPORTED_CHANNELS", "DEFAULT_SAMPLE_RATE", "DEFAULT_CHANNELS", + "DEFAULT_FORMAT", ] diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 145710c3c..e7af3ad43 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -6,9 +6,9 @@ Key features: - Audio input/output events with standardized formats - Interruption detection and handling -- connection lifecycle management +- Session lifecycle management - Provider-agnostic event types -- Backwards compatibility with existing StreamEvent types +- Type-safe discriminated unions with TypedEvent Audio format normalization: - Supports PCM, WAV, Opus, and MP3 formats @@ -17,12 +17,9 @@ - Abstracts provider-specific encodings """ -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union, cast -from typing_extensions import TypedDict - -from ....types.content import Role -from ....types.streaming import StreamEvent +from ....types._events import TypedEvent # Audio format constants SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] @@ -30,221 +27,470 @@ SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo DEFAULT_SAMPLE_RATE = 16000 DEFAULT_CHANNELS = 1 +DEFAULT_FORMAT = "pcm" -class AudioOutputEvent(TypedDict): - """Audio output event from the model. +# ============================================================================ +# Input Events (sent via session.send()) +# ============================================================================ - Provides standardized audio output format across different providers using - raw bytes instead of provider-specific encodings. - Attributes: - audioData: Raw audio bytes (not base64 or hex encoded). - format: Audio format from SUPPORTED_AUDIO_FORMATS. - sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. - channels: Channel count from SUPPORTED_CHANNELS. - encoding: Original provider encoding for debugging purposes. +class TextInputEvent(TypedEvent): + """Text input event for sending text to the model. + + Used for sending text content through the send() method. + + Parameters: + text: The text content to send to the model. + role: The role of the message sender (typically "user"). """ - audioData: bytes - format: Literal["pcm", "wav", "opus", "mp3"] - sampleRate: Literal[16000, 24000, 48000] - channels: Literal[1, 2] - encoding: Optional[str] + def __init__(self, text: str, role: str): + super().__init__( + { + "type": "bidirectional_text_input", + "text": text, + "role": role, + } + ) + + @property + def text(self) -> str: + return cast(str, self.get("text")) + @property + def role(self) -> str: + return cast(str, self.get("role")) -class AudioInputEvent(TypedDict): + +class AudioInputEvent(TypedEvent): """Audio input event for sending audio to the model. Used for sending audio data through the send() method. - Attributes: - audioData: Raw audio bytes to send to model. + Parameters: + audio: Raw audio bytes to send to model (not base64 encoded). format: Audio format from SUPPORTED_AUDIO_FORMATS. - sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + sample_rate: Sample rate from SUPPORTED_SAMPLE_RATES. channels: Channel count from SUPPORTED_CHANNELS. """ - audioData: bytes - format: Literal["pcm", "wav", "opus", "mp3"] - sampleRate: Literal[16000, 24000, 48000] - channels: Literal[1, 2] - - -class ImageInputEvent(TypedDict): + def __init__( + self, + audio: bytes, + format: Literal["pcm", "wav", "opus", "mp3"], + sample_rate: Literal[16000, 24000, 48000], + channels: Literal[1, 2], + ): + super().__init__( + { + "type": "bidirectional_audio_input", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> bytes: + return cast(bytes, self.get("audio")) + + @property + def format(self) -> str: + return cast(str, self.get("format")) + + @property + def sample_rate(self) -> int: + return cast(int, self.get("sample_rate")) + + @property + def channels(self) -> int: + return cast(int, self.get("channels")) + + +class ImageInputEvent(TypedEvent): """Image input event for sending images/video frames to the model. - + Used for sending image data through the send() method. Supports both raw image bytes and base64-encoded data. - - Attributes: - imageData: Image bytes (raw or base64-encoded string). - mimeType: MIME type (e.g., "image/jpeg", "image/png"). - encoding: How the imageData is encoded. + + Parameters: + image: Image bytes (raw or base64-encoded string). + mime_type: MIME type (e.g., "image/jpeg", "image/png"). + encoding: How the image data is encoded. """ - - imageData: bytes | str - mimeType: str - encoding: Literal["base64", "raw"] + def __init__( + self, + image: Union[bytes, str], + mime_type: str, + encoding: Literal["base64", "raw"], + ): + super().__init__( + { + "type": "bidirectional_image_input", + "image": image, + "mime_type": mime_type, + "encoding": encoding, + } + ) + + @property + def image(self) -> Union[bytes, str]: + return cast(Union[bytes, str], self.get("image")) + + @property + def mime_type(self) -> str: + return cast(str, self.get("mime_type")) + + @property + def encoding(self) -> str: + return cast(str, self.get("encoding")) + + +# ============================================================================ +# Output Events (received via session.receive_events()) +# ============================================================================ + + +class SessionStartEvent(TypedEvent): + """Session established and ready for interaction. + + Parameters: + session_id: Unique identifier for this session. + model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). + capabilities: List of supported features (e.g., ["audio", "tools", "images"]). + """ -class TextInputEvent(TypedDict): - """Text input event for sending text to the model. + def __init__(self, session_id: str, model: str, capabilities: List[str]): + super().__init__( + { + "type": "bidirectional_session_start", + "session_id": session_id, + "model": model, + "capabilities": capabilities, + } + ) - Used for sending text content through the send() method. + @property + def session_id(self) -> str: + return cast(str, self.get("session_id")) - Attributes: - text: The text content to send to the model. - role: The role of the message sender (typically "user"). - """ + @property + def model(self) -> str: + return cast(str, self.get("model")) - text: str - role: Role + @property + def capabilities(self) -> List[str]: + return cast(List[str], self.get("capabilities")) -class TextOutputEvent(TypedDict): - """Text output event from the model during bidirectional streaming. +class TurnStartEvent(TypedEvent): + """Model starts generating a response. - Attributes: - text: The text content from the model. - role: The role of the message sender. + Parameters: + turn_id: Unique identifier for this turn (used in turn.complete). """ - text: str - role: Role + def __init__(self, turn_id: str): + super().__init__({"type": "bidirectional_turn_start", "turn_id": turn_id}) + @property + def turn_id(self) -> str: + return cast(str, self.get("turn_id")) -class TranscriptEvent(TypedDict): - """Transcript event for audio transcriptions. - - Used for both input transcriptions (user speech) and output transcriptions - (model audio). These are informational and separate from actual text responses. - - Attributes: - text: The transcribed text. - role: The role of the speaker ("user" or "assistant"). - type: Type of transcription ("input" or "output"). - """ - - text: str - role: Role - type: Literal["input", "output"] +class AudioStreamEvent(TypedEvent): + """Streaming audio output from the model. -class InterruptionDetectedEvent(TypedDict): - """Interruption detection event. + Parameters: + audio: Raw audio data as bytes (not base64 encoded). + format: Audio encoding format. + sample_rate: Number of audio samples per second in Hz. + channels: Number of audio channels (1=mono, 2=stereo). + """ - Signals when user interruption is detected during model generation. + def __init__( + self, + audio: bytes, + format: Literal["pcm", "wav", "opus", "mp3"], + sample_rate: Literal[16000, 24000, 48000], + channels: Literal[1, 2], + ): + super().__init__( + { + "type": "bidirectional_audio_stream", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> bytes: + return cast(bytes, self.get("audio")) + + @property + def format(self) -> str: + return cast(str, self.get("format")) + + @property + def sample_rate(self) -> int: + return cast(int, self.get("sample_rate")) + + @property + def channels(self) -> int: + return cast(int, self.get("channels")) + + +class TranscriptStreamEvent(TypedEvent): + """Audio transcription of speech (user or assistant). + + Parameters: + text: Transcribed text from audio. + source: Who is speaking ("user" or "assistant"). + is_final: Whether this is the final/complete transcript. + """ - Attributes: - reason: Interruption reason from predefined set. + def __init__( + self, text: str, source: Literal["user", "assistant"], is_final: bool + ): + super().__init__( + { + "type": "bidirectional_transcript_stream", + "text": text, + "source": source, + "is_final": is_final, + } + ) + + @property + def text(self) -> str: + return cast(str, self.get("text")) + + @property + def source(self) -> str: + return cast(str, self.get("source")) + + @property + def is_final(self) -> bool: + return cast(bool, self.get("is_final")) + + +class InterruptionEvent(TypedEvent): + """Model generation was interrupted. + + Parameters: + reason: Why the interruption occurred. + turn_id: ID of the turn that was interrupted (may be None). """ - reason: Literal["user_input", "vad_detected", "manual"] + def __init__( + self, reason: Literal["user_speech", "error"], turn_id: Optional[str] = None + ): + super().__init__( + { + "type": "bidirectional_interruption", + "reason": reason, + "turn_id": turn_id, + } + ) + @property + def reason(self) -> str: + return cast(str, self.get("reason")) -class BidirectionalConnectionStartEvent(TypedDict, total=False): - """connection start event for bidirectional streaming. + @property + def turn_id(self) -> Optional[str]: + return cast(Optional[str], self.get("turn_id")) - Attributes: - connectionId: Unique connection identifier. - metadata: Provider-specific connection metadata. - """ - connectionId: Optional[str] - metadata: Optional[Dict[str, Any]] +class TurnCompleteEvent(TypedEvent): + """Model finished generating response. + Parameters: + turn_id: ID of the turn that completed (matches turn.start). + stop_reason: Why the turn ended. + """ -class BidirectionalConnectionEndEvent(TypedDict): - """connection end event for bidirectional streaming. + def __init__( + self, + turn_id: str, + stop_reason: Literal["complete", "interrupted", "tool_use", "error"], + ): + super().__init__( + { + "type": "bidirectional_turn_complete", + "turn_id": turn_id, + "stop_reason": stop_reason, + } + ) - Attributes: - reason: Reason for connection end from predefined set. - connectionId: Unique connection identifier. - metadata: Provider-specific connection metadata. - """ + @property + def turn_id(self) -> str: + return cast(str, self.get("turn_id")) - reason: Literal["user_request", "timeout", "error", "connection_complete"] - connectionId: Optional[str] - metadata: Optional[Dict[str, Any]] + @property + def stop_reason(self) -> str: + return cast(str, self.get("stop_reason")) -class UsageMetricsEvent(TypedDict): - """Token usage and performance tracking. - Provides standardized usage metrics across providers for cost monitoring - and performance optimization. +class ModalityUsage(dict): + """Token usage for a specific modality. Attributes: - totalTokens: Total tokens used in the interaction. - inputTokens: Tokens used for input processing. - outputTokens: Tokens used for output generation. - audioTokens: Tokens used specifically for audio processing. + modality: Type of content. + input_tokens: Tokens used for this modality's input. + output_tokens: Tokens used for this modality's output. """ - totalTokens: Optional[int] - inputTokens: Optional[int] - outputTokens: Optional[int] - audioTokens: Optional[int] + modality: Literal["text", "audio", "image", "cached"] + input_tokens: int + output_tokens: int -class VoiceActivityEvent(TypedDict): - """Voice activity detection event for speech monitoring. +class MultimodalUsage(TypedEvent): + """Token usage event with modality breakdown for multimodal streaming. - Provides standardized voice activity detection events across providers - to enable speech-aware applications and better conversation flow. + Combines TypedEvent behavior with Usage fields for a unified event type. - Attributes: - activityType: Type of voice activity detected. + Parameters: + input_tokens: Total tokens used for all input modalities. + output_tokens: Total tokens used for all output modalities. + total_tokens: Sum of input and output tokens. + modality_details: Optional list of token usage per modality. + cache_read_input_tokens: Optional tokens read from cache. + cache_write_input_tokens: Optional tokens written to cache. """ - activityType: Literal["speech_started", "speech_stopped", "timeout"] - - -class UsageMetricsEvent(TypedDict): - """Token usage and performance tracking. - - Provides standardized usage metrics across providers for cost monitoring - and performance optimization. - - Attributes: - totalTokens: Total tokens used in the interaction. - inputTokens: Tokens used for input processing. - outputTokens: Tokens used for output generation. - audioTokens: Tokens used specifically for audio processing. + def __init__( + self, + input_tokens: int, + output_tokens: int, + total_tokens: int, + modality_details: Optional[List[ModalityUsage]] = None, + cache_read_input_tokens: Optional[int] = None, + cache_write_input_tokens: Optional[int] = None, + ): + data: Dict[str, Any] = { + "type": "multimodal_usage", + "inputTokens": input_tokens, + "outputTokens": output_tokens, + "totalTokens": total_tokens, + } + if modality_details is not None: + data["modality_details"] = modality_details + if cache_read_input_tokens is not None: + data["cacheReadInputTokens"] = cache_read_input_tokens + if cache_write_input_tokens is not None: + data["cacheWriteInputTokens"] = cache_write_input_tokens + super().__init__(data) + + @property + def input_tokens(self) -> int: + return cast(int, self.get("inputTokens")) + + @property + def output_tokens(self) -> int: + return cast(int, self.get("outputTokens")) + + @property + def total_tokens(self) -> int: + return cast(int, self.get("totalTokens")) + + @property + def modality_details(self) -> List[ModalityUsage]: + return cast(List[ModalityUsage], self.get("modality_details", [])) + + @property + def cache_read_input_tokens(self) -> Optional[int]: + return cast(Optional[int], self.get("cacheReadInputTokens")) + + @property + def cache_write_input_tokens(self) -> Optional[int]: + return cast(Optional[int], self.get("cacheWriteInputTokens")) + + +class SessionEndEvent(TypedEvent): + """Session terminated. + + Parameters: + reason: Why the session ended. """ - totalTokens: Optional[int] - inputTokens: Optional[int] - outputTokens: Optional[int] - audioTokens: Optional[int] + def __init__( + self, reason: Literal["client_disconnect", "timeout", "error", "complete"] + ): + super().__init__({"type": "bidirectional_session_end", "reason": reason}) + @property + def reason(self) -> str: + return cast(str, self.get("reason")) -class BidirectionalStreamEvent(StreamEvent, total=False): - """Bidirectional stream event extending existing StreamEvent. - Extends the existing StreamEvent type with bidirectional-specific events - while maintaining full backward compatibility with existing Strands streaming. +class ErrorEvent(TypedEvent): + """Error occurred during the session. - Attributes: - audioOutput: Audio output from the model. - audioInput: Audio input sent to the model. - imageInput: Image input sent to the model. - textOutput: Text output from the model. - transcript: Audio transcription (input or output). - interruptionDetected: User interruption detection. - BidirectionalConnectionStart: connection start event. - BidirectionalConnectionEnd: connection end event. - voiceActivity: Voice activity detection events. - usageMetrics: Token usage and performance metrics. + Similar to strands.types._events.ForceStopEvent, this event wraps exceptions + that occur during bidirectional streaming sessions. + + Parameters: + error: The exception that occurred. + code: Optional error code for programmatic handling (defaults to exception class name). + details: Optional additional error information. """ - audioOutput: Optional[AudioOutputEvent] - audioInput: Optional[AudioInputEvent] - imageInput: Optional[ImageInputEvent] - textOutput: Optional[TextOutputEvent] - transcript: Optional[TranscriptEvent] - interruptionDetected: Optional[InterruptionDetectedEvent] - BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] - BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] - voiceActivity: Optional[VoiceActivityEvent] - usageMetrics: Optional[UsageMetricsEvent] + def __init__( + self, + error: Exception, + code: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, + ): + super().__init__( + { + "bidirectional_error": True, + "error": error, + "error_message": str(error), + "error_code": code or type(error).__name__, + "error_details": details, + } + ) + + @property + def error(self) -> Exception: + return cast(Exception, self.get("error")) + + @property + def code(self) -> str: + return cast(str, self.get("error_code")) + + @property + def message(self) -> str: + return cast(str, self.get("error_message")) + + @property + def details(self) -> Optional[Dict[str, Any]]: + return cast(Optional[Dict[str, Any]], self.get("error_details")) + + +# ============================================================================ +# Type Unions +# ============================================================================ + +# Note: ToolResultEvent and ToolUseStreamEvent are reused from strands.types._events + +InputEvent = Union[TextInputEvent, AudioInputEvent, ImageInputEvent] + +OutputEvent = Union[ + SessionStartEvent, + TurnStartEvent, + AudioStreamEvent, + TranscriptStreamEvent, + InterruptionEvent, + TurnCompleteEvent, + MultimodalUsage, + SessionEndEvent, + ErrorEvent, +] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index b894509c9..f13e2cf04 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -19,6 +19,7 @@ ImageInputEvent, TextInputEvent, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolResult @@ -188,7 +189,7 @@ async def test_send_all_content_types(mock_genai_client, model): await model.connect() # Test text input - text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) mock_live_session.send_client_content.assert_called_once() call_args = mock_live_session.send_client_content.call_args @@ -197,31 +198,32 @@ async def test_send_all_content_types(mock_genai_client, model): assert content.parts[0].text == "Hello" # Test audio input - audio_input: AudioInputEvent = { - "audioData": b"audio_bytes", - "format": "pcm", - "sampleRate": 16000, - "channels": 1, - } + audio_input = AudioInputEvent( + audio=b"audio_bytes", + format="pcm", + sample_rate=16000, + channels=1, + ) await model.send(audio_input) mock_live_session.send_realtime_input.assert_called_once() # Test image input - image_input: ImageInputEvent = { - "imageData": b"image_bytes", - "mimeType": "image/jpeg", - "encoding": "raw", - } + image_input = ImageInputEvent( + image=b"image_bytes", + mime_type="image/jpeg", + encoding="raw", + ) await model.send(image_input) mock_live_session.send.assert_called_once() # Test tool result + from strands.types._events import ToolResultEvent tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Result: 42"}], } - await model.send(tool_result) + await model.send(ToolResultEvent(tool_result)) mock_live_session.send_tool_response.assert_called_once() await model.close() @@ -233,7 +235,7 @@ async def test_send_edge_cases(mock_genai_client, model): _, mock_live_session, _ = mock_genai_client # Test send when inactive - text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) mock_live_session.send_client_content.assert_not_called() @@ -251,6 +253,11 @@ async def test_send_edge_cases(mock_genai_client, model): @pytest.mark.asyncio async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): """Test that receive() emits connection start and end events.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + SessionStartEvent, + SessionEndEvent, + ) + _, mock_live_session, _ = mock_genai_client mock_live_session.receive.return_value = agenerator([]) @@ -266,18 +273,24 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): # Verify connection start and end assert len(events) >= 2 - assert "BidirectionalConnectionStart" in events[0] - assert events[0]["BidirectionalConnectionStart"]["connectionId"] == model.session_id - assert "BidirectionalConnectionEnd" in events[-1] + assert isinstance(events[0], SessionStartEvent) + assert events[0].session_id == model.session_id + assert isinstance(events[-1], SessionEndEvent) @pytest.mark.asyncio async def test_event_conversion(mock_genai_client, model): """Test conversion of all Gemini Live event types to standard format.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + TranscriptStreamEvent, + AudioStreamEvent, + InterruptionEvent, + ) + _, _, _ = mock_genai_client await model.connect() - # Test text output + # Test text output (now converted to transcript) mock_text = unittest.mock.Mock() mock_text.text = "Hello from Gemini" mock_text.data = None @@ -285,9 +298,10 @@ async def test_event_conversion(mock_genai_client, model): mock_text.server_content = None text_event = model._convert_gemini_live_event(mock_text) - assert "textOutput" in text_event - assert text_event["textOutput"]["text"] == "Hello from Gemini" - assert text_event["textOutput"]["role"] == "assistant" + assert isinstance(text_event, TranscriptStreamEvent) + assert text_event.text == "Hello from Gemini" + assert text_event.source == "assistant" + assert text_event.is_final is True # Test audio output mock_audio = unittest.mock.Mock() @@ -297,9 +311,9 @@ async def test_event_conversion(mock_genai_client, model): mock_audio.server_content = None audio_event = model._convert_gemini_live_event(mock_audio) - assert "audioOutput" in audio_event - assert audio_event["audioOutput"]["audioData"] == b"audio_data" - assert audio_event["audioOutput"]["format"] == "pcm" + assert isinstance(audio_event, AudioStreamEvent) + assert audio_event.audio == b"audio_data" + assert audio_event.format == "pcm" # Test tool call mock_func_call = unittest.mock.Mock() @@ -334,8 +348,8 @@ async def test_event_conversion(mock_genai_client, model): mock_interrupt.server_content = mock_server_content interrupt_event = model._convert_gemini_live_event(mock_interrupt) - assert "interruptionDetected" in interrupt_event - assert interrupt_event["interruptionDetected"]["reason"] == "user_input" + assert isinstance(interrupt_event, InterruptionEvent) + assert interrupt_event.reason == "user_speech" await model.close() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 10066a693..bc2b0961c 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -131,36 +131,42 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model @pytest.mark.asyncio async def test_send_all_content_types(nova_model, mock_client, mock_stream): """Test sending all content types through unified send() method.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + TextInputEvent, + AudioInputEvent, + ) + from strands.types._events import ToolResultEvent + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client await nova_model.connect() # Test text content - text_event = {"text": "Hello, Nova!", "role": "user"} + text_event = TextInputEvent(text="Hello, Nova!", role="user") await nova_model.send(text_event) # Should send contentStart, textInput, and contentEnd assert mock_stream.input_stream.send.call_count >= 3 # Test audio content - audio_event = { - "audioData": b"audio data", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } + audio_event = AudioInputEvent( + audio=b"audio data", + format="pcm", + sample_rate=16000, + channels=1 + ) await nova_model.send(audio_event) # Should start audio connection and send audio assert nova_model.audio_connection_active assert mock_stream.input_stream.send.called # Test tool result - tool_result = { + tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Weather is sunny"}] } - await nova_model.send(tool_result) + await nova_model.send(ToolResultEvent(tool_result)) # Should send contentStart, toolResult, and contentEnd assert mock_stream.input_stream.send.called @@ -170,19 +176,25 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): """Test send() edge cases and error handling.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + TextInputEvent, + ImageInputEvent, + ) + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client # Test send when inactive - text_event = {"text": "Hello", "role": "user"} + text_event = TextInputEvent(text="Hello", role="user") await nova_model.send(text_event) # Should not raise # Test image content (not supported) await nova_model.connect() - image_event = { - "imageData": b"image data", - "mimeType": "image/jpeg" - } + image_event = ImageInputEvent( + image=b"image data", + mime_type="image/jpeg", + encoding="raw" + ) await nova_model.send(image_event) # Should log warning about unsupported image input assert any("not supported" in record.message.lower() for record in caplog.records) @@ -319,6 +331,8 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing @@ -326,12 +340,12 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): await nova_model.connect() # Send audio to start connection - audio_event = { - "audioData": b"audio data", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } + audio_event = AudioInputEvent( + audio=b"audio data", + format="pcm", + sample_rate=16000, + channels=1 + ) await nova_model.send(audio_event) assert nova_model.audio_connection_active diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 1209150ba..7495a4489 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -222,11 +222,13 @@ async def async_connect(*args, **kwargs): @pytest.mark.asyncio async def test_send_all_content_types(mock_websockets_connect, model): """Test sending all content types through unified send() method.""" + from strands.types._events import ToolResultEvent + _, mock_ws = mock_websockets_connect await model.connect() # Test text input - text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] @@ -236,12 +238,12 @@ async def test_send_all_content_types(mock_websockets_connect, model): assert len(response_create) > 0 # Test audio input - audio_input: AudioInputEvent = { - "audioData": b"audio_bytes", - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - } + audio_input = AudioInputEvent( + audio=b"audio_bytes", + format="pcm", + sample_rate=24000, + channels=1, + ) await model.send(audio_input) calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] @@ -257,7 +259,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): "status": "success", "content": [{"text": "Result: 42"}], } - await model.send(tool_result) + await model.send(ToolResultEvent(tool_result)) calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] item_create = [m for m in messages if m.get("type") == "conversation.item.create"] @@ -275,17 +277,17 @@ async def test_send_edge_cases(mock_websockets_connect, model): _, mock_ws = mock_websockets_connect # Test send when inactive - text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) mock_ws.send.assert_not_called() # Test image input (not supported) await model.connect() - image_input: ImageInputEvent = { - "imageData": b"image_bytes", - "mimeType": "image/jpeg", - "encoding": "raw", - } + image_input = ImageInputEvent( + image=b"image_bytes", + mime_type="image/jpeg", + encoding="raw", + ) with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: await model.send(image_input) mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") From 6529187994558056c91f41cecef053abac68c955 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 13:51:47 +0100 Subject: [PATCH 034/242] fix: use protocol and improve _active handling --- .../models/bidirectional_model.py | 19 +++++++------------ .../models/gemini_live.py | 1 + .../models/novasonic.py | 1 + .../bidirectional_streaming/models/openai.py | 1 + 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 5b7091dcd..05fb19e0f 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -12,9 +12,8 @@ - Support for audio, text, image, and tool result streaming """ -import abc import logging -from typing import AsyncIterable, Union +from typing import AsyncIterable, Protocol, Union from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec @@ -28,15 +27,14 @@ logger = logging.getLogger(__name__) -class BidirectionalModel(abc.ABC): - """Abstract base class for bidirectional streaming models. +class BidirectionalModel(Protocol): + """Protocol for bidirectional streaming models. This interface defines the contract for models that support persistent streaming connections with real-time audio and text communication. Implementations handle provider-specific protocols while exposing a standardized event-based API. """ - @abc.abstractmethod async def connect( self, system_prompt: str | None = None, @@ -56,9 +54,8 @@ async def connect( messages: Initial conversation history to provide context. **kwargs: Provider-specific configuration options. """ - raise NotImplementedError + ... - @abc.abstractmethod async def close(self) -> None: """Close the streaming connection and release resources. @@ -66,9 +63,8 @@ async def close(self) -> None: resources such as network connections, buffers, or background tasks. After calling close(), the model instance cannot be used until connect() is called again. """ - raise NotImplementedError + ... - @abc.abstractmethod async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive streaming events from the model. @@ -82,9 +78,8 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: BidirectionalStreamEvent: Standardized event dictionaries containing audio output, text responses, tool calls, or control signals. """ - raise NotImplementedError + ... - @abc.abstractmethod async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: """Send content to the model over the active connection. @@ -104,4 +99,4 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE await model.send(AudioInputEvent(audioData=bytes, format="pcm", ...)) await model.send(ToolResult(toolUseId="123", status="success", ...)) """ - raise NotImplementedError + ... diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 639328c64..cef8135eb 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -128,6 +128,7 @@ async def connect( await self._send_message_history(messages) except Exception as e: + self._active = False logger.error("Error connecting to Gemini Live: %s", e) raise diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index b9c5060ba..ddb0540f6 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -182,6 +182,7 @@ async def connect( logger.info("Nova Sonic connection established successfully") except Exception as e: + self._active = False logger.error("Nova connection create error: %s", str(e)) raise diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index a542ec894..e64508db7 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -160,6 +160,7 @@ async def connect( logger.info("OpenAI Realtime connection established") except Exception as e: + self._active = False logger.error("OpenAI connection error: %s", e) raise From 990d905a27d64734ed24ca36a09b78c951e0ce39 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 14:03:17 +0100 Subject: [PATCH 035/242] refactor: simplify model names --- .../bidirectional_streaming/__init__.py | 12 ++++---- .../models/__init__.py | 12 ++++---- .../models/gemini_live.py | 2 +- .../models/novasonic.py | 4 +-- .../bidirectional_streaming/models/openai.py | 2 +- .../tests/test_bidi_novasonic.py | 6 ++-- .../tests/test_bidi_openai.py | 4 +-- .../tests/test_gemini_live.py | 4 +-- .../models/test_gemini_live.py | 22 +++++++-------- .../models/test_novasonic.py | 8 +++--- .../models/test_openai_realtime.py | 28 +++++++++---------- 11 files changed, 51 insertions(+), 53 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index d855ba038..caee4715a 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -7,9 +7,9 @@ from .models.bidirectional_model import BidirectionalModel # Model providers - What users need to create models -from .models.gemini_live import GeminiLiveBidirectionalModel -from .models.novasonic import NovaSonicBidirectionalModel -from .models.openai import OpenAIRealtimeBidirectionalModel +from .models.gemini_live import GeminiLiveModel +from .models.novasonic import NovaSonicModel +from .models.openai import OpenAIRealtimeModel # Event types - For type hints and event handling from .types.bidirectional_streaming import ( @@ -29,9 +29,9 @@ "BidirectionalAgent", # Model providers - "GeminiLiveBidirectionalModel", - "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", + "GeminiLiveModel", + "NovaSonicModel", + "OpenAIRealtimeModel", # Event types "AudioInputEvent", diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 12fe6c271..5b0d50687 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,13 +1,13 @@ """Bidirectional model interfaces and implementations.""" from .bidirectional_model import BidirectionalModel -from .gemini_live import GeminiLiveBidirectionalModel -from .novasonic import NovaSonicBidirectionalModel -from .openai import OpenAIRealtimeBidirectionalModel +from .gemini_live import GeminiLiveModel +from .novasonic import NovaSonicModel +from .openai import OpenAIRealtimeModel __all__ = [ "BidirectionalModel", - "GeminiLiveBidirectionalModel", - "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", + "GeminiLiveModel", + "NovaSonicModel", + "OpenAIRealtimeModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index cef8135eb..9f0cfe6c0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -44,7 +44,7 @@ GEMINI_CHANNELS = 1 -class GeminiLiveBidirectionalModel(BidirectionalModel): +class GeminiLiveModel(BidirectionalModel): """Gemini Live API implementation using official Google GenAI SDK. Combines model configuration and connection state in a single class. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ddb0540f6..c9e5805db 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -78,7 +78,7 @@ RESPONSE_TIMEOUT = 1.0 -class NovaSonicBidirectionalModel(BidirectionalModel): +class NovaSonicModel(BidirectionalModel): """Nova Sonic implementation for bidirectional streaming. Combines model configuration and connection state in a single class. @@ -111,7 +111,6 @@ def __init__( # Nova Sonic requires unique content names self.audio_content_name = None - self.text_content_name = None # Audio connection state self.audio_connection_active = False @@ -154,7 +153,6 @@ async def connect( self.prompt_name = str(uuid.uuid4()) self._active = True self.audio_content_name = str(uuid.uuid4()) - self.text_content_name = str(uuid.uuid4()) self._event_queue = asyncio.Queue() # Start Nova Sonic bidirectional stream diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index e64508db7..0810b7b21 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -58,7 +58,7 @@ } -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): +class OpenAIRealtimeModel(BidirectionalModel): """OpenAI Realtime API implementation for bidirectional streaming. Combines model configuration and connection state in a single class. diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index 8c3ae3b4c..b0a41f20d 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -17,7 +17,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel def test_direct_tools(): @@ -30,7 +30,7 @@ def test_direct_tools(): return try: - model = NovaSonicBidirectionalModel() + model = NovaSonicModel() agent = BidirectionalAgent(model=model, tools=[calculator]) # Test calculator @@ -185,7 +185,7 @@ async def main(duration=180): print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") # Initialize model and agent - model = NovaSonicBidirectionalModel(region="us-east-1") + model = NovaSonicModel(region="us-east-1") agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index 660040f3e..90e82c2bc 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -14,7 +14,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel async def play(context): @@ -205,7 +205,7 @@ async def main(): return False # Create OpenAI model - model = OpenAIRealtimeBidirectionalModel( + model = OpenAIRealtimeModel( model="gpt-4o-realtime-preview", api_key=api_key, session={ diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 4469e819a..23e97bd5d 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -38,7 +38,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel # Configure logging - debug only for Gemini Live, info for everything else logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -301,7 +301,7 @@ async def main(duration=180): # Initialize Gemini Live model with proper configuration logger.info("Initializing Gemini Live model with API key") - model = GeminiLiveBidirectionalModel( + model = GeminiLiveModel( model_id="gemini-2.5-flash-native-audio-preview-09-2025", api_key=api_key, params={ diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index b894509c9..8c0a61b4b 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -1,6 +1,6 @@ """Unit tests for Gemini Live bidirectional streaming model. -Tests the unified GeminiLiveBidirectionalModel interface including: +Tests the unified GeminiLiveModel interface including: - Model initialization and configuration - Connection establishment and lifecycle - Unified send() method with different content types @@ -13,7 +13,7 @@ from google import genai from google.genai import types as genai_types -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( AudioInputEvent, ImageInputEvent, @@ -55,9 +55,9 @@ def api_key(): @pytest.fixture def model(mock_genai_client, model_id, api_key): - """Create a GeminiLiveBidirectionalModel instance.""" + """Create a GeminiLiveModel instance.""" _ = mock_genai_client - return GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + return GeminiLiveModel(model_id=model_id, api_key=api_key) @pytest.fixture @@ -87,20 +87,20 @@ def test_model_initialization(mock_genai_client, model_id, api_key): _ = mock_genai_client # Test default config - model_default = GeminiLiveBidirectionalModel() + model_default = GeminiLiveModel() assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09" assert model_default.api_key is None assert model_default._active is False assert model_default.live_session is None # Test with API key - model_with_key = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model_with_key = GeminiLiveModel(model_id=model_id, api_key=api_key) assert model_with_key.model_id == model_id assert model_with_key.api_key == api_key # Test with custom config live_config = {"temperature": 0.7, "top_p": 0.9} - model_custom = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) + model_custom = GeminiLiveModel(model_id=model_id, live_config=live_config) assert model_custom.live_config == live_config @@ -151,7 +151,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client, _, mock_live_session_cm = mock_genai_client # Test connection error - model1 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model1 = GeminiLiveModel(model_id=model_id, api_key=api_key) mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.connect() @@ -160,18 +160,18 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client.aio.live.connect.side_effect = None # Test double connection - model2 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model2 = GeminiLiveModel(model_id=model_id, api_key=api_key) await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): await model2.connect() await model2.close() # Test close when not connected - model3 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model3 = GeminiLiveModel(model_id=model_id, api_key=api_key) await model3.close() # Should not raise # Test close error handling - model4 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model4 = GeminiLiveModel(model_id=model_id, api_key=api_key) await model4.connect() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") with pytest.raises(Exception, match="Close failed"): diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 10066a693..5601e23b8 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -13,7 +13,7 @@ import pytest_asyncio from strands.experimental.bidirectional_streaming.models.novasonic import ( - NovaSonicBidirectionalModel, + NovaSonicModel, ) from strands.types.tools import ToolResult @@ -53,7 +53,7 @@ def mock_client(mock_stream): @pytest_asyncio.fixture async def nova_model(model_id, region): """Create Nova Sonic model instance.""" - model = NovaSonicBidirectionalModel(model_id=model_id, region=region) + model = NovaSonicModel(model_id=model_id, region=region) yield model # Cleanup if model._active: @@ -66,7 +66,7 @@ async def nova_model(model_id, region): @pytest.mark.asyncio async def test_model_initialization(model_id, region): """Test model initialization with configuration.""" - model = NovaSonicBidirectionalModel(model_id=model_id, region=region) + model = NovaSonicModel(model_id=model_id, region=region) assert model.model_id == model_id assert model.region == region @@ -120,7 +120,7 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model await nova_model.close() # Test close when already closed - model2 = NovaSonicBidirectionalModel(model_id=model_id, region=region) + model2 = NovaSonicModel(model_id=model_id, region=region) await model2.close() # Should not raise await model2.close() # Second call should also be safe diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 1209150ba..388fc95cc 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -1,6 +1,6 @@ """Unit tests for OpenAI Realtime bidirectional streaming model. -Tests the unified OpenAIRealtimeBidirectionalModel interface including: +Tests the unified OpenAIRealtimeModel interface including: - Model initialization and configuration - Connection establishment with WebSocket - Unified send() method with different content types @@ -15,7 +15,7 @@ import pytest -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( AudioInputEvent, ImageInputEvent, @@ -56,8 +56,8 @@ def api_key(): @pytest.fixture def model(api_key, model_name): - """Create an OpenAIRealtimeBidirectionalModel instance.""" - return OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + """Create an OpenAIRealtimeModel instance.""" + return OpenAIRealtimeModel(model=model_name, api_key=api_key) @pytest.fixture @@ -85,19 +85,19 @@ def messages(): def test_model_initialization(api_key, model_name): """Test model initialization with various configurations.""" # Test default config - model_default = OpenAIRealtimeBidirectionalModel(api_key="test-key") + model_default = OpenAIRealtimeModel(api_key="test-key") assert model_default.model == "gpt-realtime" assert model_default.api_key == "test-key" assert model_default._active is False assert model_default.websocket is None # Test with custom model - model_custom = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model_custom = OpenAIRealtimeModel(model=model_name, api_key=api_key) assert model_custom.model == model_name assert model_custom.api_key == api_key # Test with organization and project - model_org = OpenAIRealtimeBidirectionalModel( + model_org = OpenAIRealtimeModel( model=model_name, api_key=api_key, organization="org-123", @@ -108,7 +108,7 @@ def test_model_initialization(api_key, model_name): # Test with env API key with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): - model_env = OpenAIRealtimeBidirectionalModel() + model_env = OpenAIRealtimeModel() assert model_env.api_key == "env-key" @@ -116,7 +116,7 @@ def test_init_without_api_key_raises(): """Test that initialization without API key raises error.""" with unittest.mock.patch.dict("os.environ", {}, clear=True): with pytest.raises(ValueError, match="OpenAI API key is required"): - OpenAIRealtimeBidirectionalModel() + OpenAIRealtimeModel() # Connection Tests @@ -171,7 +171,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp await model.close() # Test connection with organization header - model_org = OpenAIRealtimeBidirectionalModel(api_key="test-key", organization="org-123") + model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") await model_org.connect() call_kwargs = mock_connect.call_args.kwargs headers = call_kwargs.get("additional_headers", []) @@ -187,7 +187,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam mock_connect, mock_ws = mock_websockets_connect # Test connection error - model1 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model1 = OpenAIRealtimeModel(model=model_name, api_key=api_key) mock_connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.connect() @@ -198,18 +198,18 @@ async def async_connect(*args, **kwargs): mock_connect.side_effect = async_connect # Test double connection - model2 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model2 = OpenAIRealtimeModel(model=model_name, api_key=api_key) await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): await model2.connect() await model2.close() # Test close when not connected - model3 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model3 = OpenAIRealtimeModel(model=model_name, api_key=api_key) await model3.close() # Should not raise # Test close error handling (should not raise, just log) - model4 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model4 = OpenAIRealtimeModel(model=model_name, api_key=api_key) await model4.connect() mock_ws.close.side_effect = Exception("Close failed") await model4.close() # Should not raise From b9aa1f898a6d9d4c00b19fcaecc28d7b8f2bd131 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 17:33:56 +0100 Subject: [PATCH 036/242] fix: get tests working --- .../bidirectional_streaming/agent/agent.py | 58 ++++++--- .../event_loop/bidirectional_event_loop.py | 50 ++++---- .../models/gemini_live.py | 24 ++-- .../models/novasonic.py | 110 ++++++++---------- .../bidirectional_streaming/models/openai.py | 95 +++++++-------- .../tests/test_bidi_novasonic.py | 57 +++++---- .../tests/test_bidi_openai.py | 68 +++++++---- .../tests/test_gemini_live.py | 110 ++++++++++-------- .../types/bidirectional_streaming.py | 48 ++++---- .../models/test_gemini_live.py | 19 +-- .../models/test_novasonic.py | 94 ++++++++------- .../models/test_openai_realtime.py | 90 ++++++++------ 12 files changed, 454 insertions(+), 369 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index d74860222..ab08978fb 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -360,39 +360,67 @@ async def start(self) -> None: logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> None: - """Send input to the model (text, audio, or image). + async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) -> None: + """Send input to the model (text, audio, image, or event dict). Unified method for sending text, audio, and image input to the model during - an active conversation session. + an active conversation session. Accepts TypedEvent instances or plain dicts + (e.g., from WebSocket clients) which are automatically reconstructed. Args: - input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images. + input_data: Can be: + - str: Text message from user + - AudioInputEvent: Audio data with format/sample rate + - ImageInputEvent: Image data with MIME type + - dict: Event dictionary (will be reconstructed to TypedEvent) Raises: ValueError: If no active session or invalid input type. + + Example: + await agent.send("Hello") + await agent.send(AudioInputEvent(audio="base64...", format="pcm", ...)) + await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"}) """ self._validate_active_session() + # Handle string input if isinstance(input_data, str): # Add user text message to history self.messages.append({"role": "user", "content": input_data}) - logger.debug("Text sent: %d characters", len(input_data)) - # Create TextInputEvent for send() - text_event = {"text": input_data, "role": "user"} + from ..types.bidirectional_streaming import TextInputEvent + text_event = TextInputEvent(text=input_data, role="user") await self._session.model.send(text_event) - elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input - already in AudioInputEvent format - await self._session.model.send(input_data) - elif isinstance(input_data, dict) and "imageData" in input_data: - # Handle image input - already in ImageInputEvent format + return + + # Handle dict - reconstruct TypedEvent for WebSocket integration + if isinstance(input_data, dict) and "type" in input_data: + from ..types.bidirectional_streaming import TextInputEvent + event_type = input_data["type"] + if event_type == "bidirectional_text_input": + input_data = TextInputEvent(text=input_data["text"], role=input_data["role"]) + elif event_type == "bidirectional_audio_input": + input_data = AudioInputEvent( + audio=input_data["audio"], + format=input_data["format"], + sample_rate=input_data["sample_rate"], + channels=input_data["channels"] + ) + elif event_type == "bidirectional_image_input": + input_data = ImageInputEvent( + image=input_data["image"], + mime_type=input_data["mime_type"] + ) + else: + raise ValueError(f"Unknown event type: {event_type}") + + # Handle TypedEvent instances + if isinstance(input_data, (AudioInputEvent, ImageInputEvent, TextInputEvent)): await self._session.model.send(input_data) else: raise ValueError( - "Input must be either a string (text), AudioInputEvent " - "(dict with audioData, format, sampleRate, channels), or ImageInputEvent " - "(dict with imageData, mimeType, encoding)" + f"Input must be a string, TypedEvent, or event dict, got: {type(input_data)}" ) async def receive(self) -> AsyncIterable[dict[str, Any]]: diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index d1d6e90b3..27732294a 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -223,7 +223,9 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: try: while True: event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): + # Check for audio events + event_type = event.get("type", "") + if event_type == "bidirectional_audio_stream": audio_cleared += 1 else: # Keep non-audio events @@ -267,8 +269,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: strands_event = provider_event - # Handle interruption detection (provider converts raw patterns to interruptionDetected) - if strands_event.get("interruptionDetected"): + # Get event type + event_type = strands_event.get("type", "") + + # Handle interruption detection + if event_type == "bidirectional_interruption": logger.debug("Interruption forwarded") await _handle_interruption(session) # Forward interruption event to agent for application-level handling @@ -276,26 +281,23 @@ async def _process_model_events(session: BidirectionalConnection) -> None: continue # Queue tool requests for concurrent execution - if strands_event.get("toolUse"): - tool_name = strands_event["toolUse"].get("name") - logger.debug("Tool usage detected: %s", tool_name) - await session.tool_queue.put(strands_event["toolUse"]) + if event_type == "tool_use": + tool_use = strands_event.get("tool_use") + if tool_use: + tool_name = tool_use.get("name") + logger.debug("Tool usage detected: %s", tool_name) + await session.tool_queue.put(tool_use) continue - # Send output events to Agent for receive() method - if strands_event.get("audioOutput") or strands_event.get("textOutput"): - await session.agent._output_queue.put(strands_event) + # Send all output events to Agent for receive() method + await session.agent._output_queue.put(strands_event) - # Update Agent conversation history using existing patterns - if strands_event.get("messageStop"): - logger.debug("Message added to history") - session.agent.messages.append(strands_event["messageStop"]["message"]) - - # Handle user audio transcripts - add to message history - if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": - user_transcript = strands_event["textOutput"]["text"] - if user_transcript.strip(): # Only add non-empty transcripts - user_message = {"role": "user", "content": user_transcript} + # Update Agent conversation history for user transcripts + if event_type == "bidirectional_transcript_stream": + source = strands_event.get("source") + text = strands_event.get("text", "") + if source == "user" and text.strip(): + user_message = {"role": "user", "content": text} session.agent.messages.append(user_message) logger.debug("User transcript added to history") @@ -434,8 +436,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send result through send() method - await session.model.send(tool_result) + # Send ToolResultEvent through send() method + await session.model.send(tool_event) logger.debug("Tool result sent: %s", tool_use_id) # Handle streaming events if needed later @@ -464,14 +466,14 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: except Exception as e: logger.error("Tool execution error: %s - %s", tool_name, str(e)) - # Send error result + # Send error result wrapped in ToolResultEvent error_result: ToolResult = { "toolUseId": tool_id, "status": "error", "content": [{"text": f"Error: {str(e)}"}] } try: - await session.model.send(error_result) + await session.model.send(ToolResultEvent(error_result)) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index f7bfebac8..02044125f 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -248,8 +248,10 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic # Handle audio output using SDK's built-in data property if message.data: + # Convert bytes to base64 string for JSON serializability + audio_b64 = base64.b64encode(message.data).decode('utf-8') return AudioStreamEvent( - audio=message.data, + audio=audio_b64, format="pcm", sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, channels=GEMINI_CHANNELS @@ -311,9 +313,12 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: This automatically triggers VAD and can interrupt ongoing responses. """ try: + # Decode base64 audio to bytes for SDK + audio_bytes = base64.b64decode(audio_input.audio) + # Create audio blob for the SDK audio_blob = genai_types.Blob( - data=audio_input.audio, + data=audio_bytes, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" ) @@ -330,21 +335,10 @@ async def _send_image_content(self, image_input: ImageInputEvent) -> None: Images are sent as base64-encoded data with MIME type. """ try: - # Prepare the message based on encoding - if image_input.encoding == "base64": - # Data is already base64 encoded - if isinstance(image_input.image, bytes): - data_str = image_input.image.decode() - else: - data_str = image_input.image - else: - # Raw bytes - need to base64 encode - data_str = base64.b64encode(image_input.image).decode() - - # Create the message in the format expected by Gemini Live + # Image is already base64 encoded in the event msg = { "mime_type": image_input.mime_type, - "data": data_str + "data": image_input.image } # Send using the same method as the GitHub example diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 63afc3378..fa7465454 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -40,6 +40,7 @@ ImageInputEvent, InterruptionEvent, MultimodalUsage, + OutputEvent, SessionEndEvent, SessionStartEvent, TextInputEvent, @@ -271,12 +272,12 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.debug("Nova events - starting event stream") - # Emit connection start event to Strands event system - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.model_id}, - } - yield {"BidirectionalConnectionStart": connection_start} + # Emit session start event + yield SessionStartEvent( + session_id=self.prompt_name, + model=self.model_id, + capabilities=["audio", "tools"] + ) try: while self._active: @@ -296,14 +297,10 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: except Exception as e: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) + yield ErrorEvent(error=e) finally: - # Emit connection end event when exiting - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.prompt_name, - "reason": "connection_complete", - "metadata": {"provider": "nova_sonic"}, - } - yield {"BidirectionalConnectionEnd": connection_end} + # Emit session end event + yield SessionEndEvent(reason="complete") async def send( self, @@ -372,9 +369,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: if self.silence_task and not self.silence_task.done(): self.silence_task.cancel() - # Convert audio to Nova Sonic base64 format - nova_audio_data = base64.b64encode(audio_input.audio).decode("utf-8") - + # Audio is already base64 encoded in the event # Send audio input event audio_event = json.dumps( { @@ -382,7 +377,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: "audioInput": { "promptName": self.prompt_name, "contentName": self.audio_content_name, - "content": nova_audio_data, + "content": audio_input.audio, } } } @@ -513,82 +508,79 @@ async def close(self) -> None: finally: logger.debug("Nova connection closed") - def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: - """Convert Nova Sonic events to provider-agnostic format.""" + def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: + """Convert Nova Sonic events to TypedEvent format.""" # Handle audio output if "audioOutput" in nova_event: + # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] - audio_bytes = base64.b64decode(audio_content) - - audio_output: AudioOutputEvent = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - "encoding": "base64", - } - - return {"audioOutput": audio_output} + return AudioStreamEvent( + audio=audio_content, + format="pcm", + sample_rate=24000, + channels=1 + ) - # Handle text output + # Handle text output (transcripts) elif "textOutput" in nova_event: text_content = nova_event["textOutput"]["content"] # Use stored role from contentStart event, fallback to event role role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) - # Check for Nova Sonic interruption pattern (matches working sample) + # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: logger.debug("Nova interruption detected in text") - interruption: InterruptionDetectedEvent = {"reason": "user_input"} - return {"interruptionDetected": interruption} - - # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag - if role == "USER": - print(f"User: {text_content}") - elif role == "ASSISTANT": - print(f"Assistant: {text_content}") + return InterruptionEvent(reason="user_speech", turn_id=None) - text_output: TextOutputEvent = {"text": text_content, "role": role.lower()} - - return {"textOutput": text_output} + return TranscriptStreamEvent( + text=text_content, + source="user" if role == "USER" else "assistant", + is_final=True + ) # Handle tool use elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - tool_use_event: ToolUse = { "toolUseId": tool_use["toolUseId"], "name": tool_use["toolName"], "input": json.loads(tool_use["content"]), } - - return {"toolUse": tool_use_event} + # Return dict with tool_use for event loop processing + return {"type": "tool_use", "tool_use": tool_use_event} # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": logger.debug("Nova interruption stop reason") + return InterruptionEvent(reason="user_speech", turn_id=None) - interruption: InterruptionDetectedEvent = {"reason": "user_input"} - - return {"interruptionDetected": interruption} - - # Handle usage events - convert to standardized format + # Handle usage events - convert to multimodal usage format elif "usageEvent" in nova_event: usage_data = nova_event["usageEvent"] - usage_metrics: UsageMetricsEvent = { - "totalTokens": usage_data.get("totalTokens", 0), - "inputTokens": usage_data.get("totalInputTokens", 0), - "outputTokens": usage_data.get("totalOutputTokens", 0), - "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens", 0) - } - return {"usageMetrics": usage_metrics} + total_input = usage_data.get("totalInputTokens", 0) + total_output = usage_data.get("totalOutputTokens", 0) + + return MultimodalUsage( + input_tokens=total_input, + output_tokens=total_output, + total_tokens=usage_data.get("totalTokens", total_input + total_output) + ) # Handle content start events (track role) elif "contentStart" in nova_event: role = nova_event["contentStart"].get("role", "unknown") # Store role for subsequent text output events self._current_role = role - return None + # Emit turn start event + return TurnStartEvent(turn_id=str(uuid.uuid4())) + + # Handle content stop events + elif "contentStop" in nova_event: + stop_reason = nova_event["contentStop"].get("stopReason", "complete") + return TurnCompleteEvent( + turn_id=str(uuid.uuid4()), + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" + ) # Handle other events else: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index be954dead..15d1bbf86 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -173,15 +173,21 @@ def _require_active(self) -> bool: """Check if session is active.""" return self._active - def _create_text_event(self, text: str, role: str) -> dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} + def _create_text_event(self, text: str, role: str) -> TranscriptStreamEvent: + """Create standardized transcript event.""" + return TranscriptStreamEvent( + text=text, + source="user" if role == "user" else "assistant", + is_final=True + ) - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} + def _create_voice_activity_event(self, activity_type: str) -> InterruptionEvent | None: + """Create standardized interruption event for voice activity.""" + # Only speech_started triggers interruption + if activity_type == "speech_started": + return InterruptionEvent(reason="user_speech", turn_id=None) + # Other voice activity events are logged but don't create events + return None def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: """Build session configuration for OpenAI Realtime API.""" @@ -273,12 +279,13 @@ async def _process_responses(self) -> None: logger.debug("OpenAI Realtime response processor stopped") async def receive(self) -> AsyncIterable[OutputEvent]: - """Receive OpenAI events and convert to Strands format.""" - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.model}, - } - yield {"BidirectionalConnectionStart": connection_start} + """Receive OpenAI events and convert to Strands TypedEvent format.""" + # Emit session start event + yield SessionStartEvent( + session_id=self.session_id, + model=self.model, + capabilities=["audio", "tools"] + ) try: while self._active: @@ -292,29 +299,24 @@ async def receive(self) -> AsyncIterable[OutputEvent]: except Exception as e: logger.error("Error receiving OpenAI Realtime event: %s", e) + yield ErrorEvent(error=e) finally: - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.session_id, - "reason": "connection_complete", - "metadata": {"provider": "openai_realtime"}, - } - yield {"BidirectionalConnectionEnd": connection_end} + # Emit session end event + yield SessionEndEvent(reason="complete") - def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: - """Convert OpenAI events to Strands format.""" + def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | None: + """Convert OpenAI events to Strands TypedEvent format.""" event_type = openai_event.get("type") # Audio output if event_type == "response.output_audio.delta": - audio_data = base64.b64decode(openai_event["delta"]) - audio_output: AudioOutputEvent = { - "audioData": audio_data, - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - "encoding": None, - } - return {"audioOutput": audio_output} + # Audio is already base64 string from OpenAI + return AudioStreamEvent( + audio=openai_event["delta"], + format="pcm", + sample_rate=24000, + channels=1 + ) # Assistant text output events - combine multiple similar events elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: @@ -359,7 +361,8 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, } del self._function_call_buffer[call_id] - return {"toolUse": tool_use} + # Return dict with tool_use for event loop processing + return {"type": "tool_use", "tool_use": tool_use} except (json.JSONDecodeError, KeyError) as e: logger.warning("Error parsing function arguments for %s: %s", call_id, e) del self._function_call_buffer[call_id] @@ -385,23 +388,14 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] elif event_type == "conversation.item.done": logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - + # This event signals turn completion - emit TurnCompleteEvent item = openai_event.get("item", {}) if item.get("type") == "message" and item.get("role") == "assistant": - content_parts = item.get("content", []) - if content_parts: - message_content = [] - for content_part in content_parts: - if content_part.get("type") == "output_text": - message_content.append({"type": "text", "text": content_part.get("text", "")}) - elif content_part.get("type") == "output_audio": - transcript = content_part.get("transcript", "") - if transcript: - message_content.append({"type": "text", "text": transcript}) - - if message_content: - message = {"role": "assistant", "content": message_content} - return {"messageStop": {"message": message}} + item_id = item.get("id", "unknown") + return TurnCompleteEvent( + turn_id=item_id, + stop_reason="complete" + ) return None # Response output events - combine similar events @@ -452,6 +446,7 @@ async def send( return try: + # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first if isinstance(content, TextInputEvent): await self._send_text_content(content.text) elif isinstance(content, AudioInputEvent): @@ -464,14 +459,14 @@ async def send( if tool_result: await self._send_tool_result(tool_result) else: - logger.warning(f"Unknown content type: {type(content)}") + logger.warning(f"Unknown content type: {type(content).__name__}") except Exception as e: logger.error(f"Error sending content: {e}") async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" - audio_base64 = base64.b64encode(audio_input.audio).decode("utf-8") - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + # Audio is already base64 encoded in the event + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) async def _send_text_content(self, text: str) -> None: """Internal: Send text content to OpenAI for processing.""" diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index b0a41f20d..b538fc023 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -5,6 +5,7 @@ """ import asyncio +import base64 import sys from pathlib import Path @@ -129,33 +130,36 @@ async def receive(agent, context): """Receive and process events from agent.""" try: async for event in agent.receive(): - # Handle audio output - if "audioOutput" in event: + # Get event type + event_type = event.get("type", "unknown") + + # Handle audio stream events (bidirectional_audio_stream) + if event_type == "bidirectional_audio_stream": if not context.get("interrupted", False): - context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) + context["audio_out"].put_nowait(audio_data) - # Handle interruption events - elif "interruptionDetected" in event: + # Handle interruption events (bidirectional_interruption) + elif event_type == "bidirectional_interruption": context["interrupted"] = True - elif "interrupted" in event: - context["interrupted"] = True - - # Handle text output with interruption detection - elif "textOutput" in event: - text_content = event["textOutput"].get("content", "") - role = event["textOutput"].get("role", "unknown") - - # Check for text-based interruption patterns - if '{ "interrupted" : true }' in text_content: - context["interrupted"] = True - elif "interrupted" in text_content.lower(): - context["interrupted"] = True - # Log text output - if role.upper() == "USER": + # Handle transcript events (bidirectional_transcript_stream) + elif event_type == "bidirectional_transcript_stream": + text_content = event.get("text", "") + source = event.get("source", "unknown") + + # Log transcript output + if source == "user": print(f"User: {text_content}") - elif role.upper() == "ASSISTANT": + elif source == "assistant": print(f"Assistant: {text_content}") + + # Handle turn complete events (bidirectional_turn_complete) + elif event_type == "bidirectional_turn_complete": + # Reset interrupted state since the turn is complete + context["interrupted"] = False except asyncio.CancelledError: pass @@ -167,7 +171,16 @@ async def send(agent, context): while time.time() - context["start_time"] < context["duration"]: try: audio_bytes = context["audio_in"].get_nowait() - audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + # Create audio event using TypedEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = AudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index 90e82c2bc..d270637be 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -2,6 +2,7 @@ """Test OpenAI Realtime API speech-to-speech interaction.""" import asyncio +import base64 import os import sys import time @@ -118,35 +119,48 @@ async def receive(agent, context): if not context["active"]: break - # Handle audio output - if "audioOutput" in event: - audio_data = event["audioOutput"]["audioData"] + # Get event type + event_type = event.get("type", "unknown") + + # Handle audio stream events (bidirectional_audio_stream) + if event_type == "bidirectional_audio_stream": + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) if not context.get("interrupted", False): await context["audio_out"].put(audio_data) - # Handle text output (transcripts) - elif "textOutput" in event: - text_output = event["textOutput"] - role = text_output.get("role", "assistant") - text = text_output.get("text", "").strip() + # Handle transcript events (bidirectional_transcript_stream) + elif event_type == "bidirectional_transcript_stream": + source = event.get("source", "assistant") + text = event.get("text", "").strip() if text: - if role == "user": - print(f"User: {text}") - elif role == "assistant": - print(f"Assistant: {text}") + if source == "user": + print(f"🎤 User: {text}") + elif source == "assistant": + print(f"🔊 Assistant: {text}") - # Handle interruption detection - elif "interruptionDetected" in event: + # Handle interruption events (bidirectional_interruption) + elif event_type == "bidirectional_interruption": context["interrupted"] = True + print("⚠️ Interruption detected") + + # Handle session start events (bidirectional_session_start) + elif event_type == "bidirectional_session_start": + print(f"✓ Session started: {event.get('model', 'unknown')}") - # Handle connection events - elif "BidirectionalConnectionStart" in event: - pass # Silent connection start - elif "BidirectionalConnectionEnd" in event: + # Handle session end events (bidirectional_session_end) + elif event_type == "bidirectional_session_end": + print(f"✓ Session ended: {event.get('reason', 'unknown')}") context["active"] = False break + + # Handle turn complete events (bidirectional_turn_complete) + elif event_type == "bidirectional_turn_complete": + # Reset interrupted state since the turn is complete + context["interrupted"] = False except asyncio.CancelledError: pass @@ -163,13 +177,17 @@ async def send(agent, context): try: audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - # Create audio event in expected format - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 24000, - "channels": 1 - } + # Create audio event using TypedEvent + # Encode audio bytes to base64 string for JSON serializability + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = AudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=24000, + channels=1 + ) await agent.send(audio_event) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 23e97bd5d..0bd283eb9 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -145,56 +145,58 @@ async def receive(agent, context): """Receive and process events from agent.""" try: async for event in agent.receive(): - # Debug: Log all event types - event_types = [k for k in event.keys() if not k.startswith('_')] - if event_types: - logger.debug(f"Received event types: {event_types}") + # Debug: Log event type and keys + event_type = event.get("type", "unknown") + event_keys = list(event.keys()) + logger.debug(f"Received event type: {event_type}, keys: {event_keys}") - # Handle audio output - if "audioOutput" in event: + # Handle audio stream events (bidirectional_audio_stream) + if event_type == "bidirectional_audio_stream": if not context.get("interrupted", False): - context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) - - # Handle interruption events - elif "interruptionDetected" in event: - context["interrupted"] = True - elif "interrupted" in event: + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) + context["audio_out"].put_nowait(audio_data) + logger.info(f"🔊 Audio queued for playback: {len(audio_data)} bytes") + + # Handle interruption events (bidirectional_interruption) + elif event_type == "bidirectional_interruption": context["interrupted"] = True + logger.info("Interruption detected") - # Handle text output - elif "textOutput" in event: - text_content = event["textOutput"].get("text", "") - role = event["textOutput"].get("role", "unknown") - - # Check for text-based interruption patterns - if '{ "interrupted" : true }' in text_content: - context["interrupted"] = True - elif "interrupted" in text_content.lower(): - context["interrupted"] = True - - # Log text output - if role.upper() == "USER": - print(f"User: {text_content}") - elif role.upper() == "ASSISTANT": - print(f"Assistant: {text_content}") - - # Handle transcript events (audio transcriptions) - elif "transcript" in event: - transcript_text = event["transcript"].get("text", "") - transcript_role = event["transcript"].get("role", "unknown") - transcript_type = event["transcript"].get("type", "unknown") + # Handle transcript events (bidirectional_transcript_stream) + elif event_type == "bidirectional_transcript_stream": + transcript_text = event.get("text", "") + transcript_source = event.get("source", "unknown") + is_final = event.get("is_final", False) - # Print transcripts with special formatting to distinguish from text output - if transcript_role.upper() == "USER": - print(f"🎤 User (transcript): {transcript_text}") - elif transcript_role.upper() == "ASSISTANT": - print(f"🔊 Assistant (transcript): {transcript_text}") + # Print transcripts with special formatting + if transcript_source == "user": + print(f"🎤 User: {transcript_text}") + elif transcript_source == "assistant": + print(f"🔊 Assistant: {transcript_text}") - # Handle turn complete events - elif "turnComplete" in event: - logger.debug("Turn complete event received - model ready for next input") + # Handle turn complete events (bidirectional_turn_complete) + elif event_type == "bidirectional_turn_complete": + logger.debug("Turn complete - model ready for next input") # Reset interrupted state since the turn is complete context["interrupted"] = False + + # Handle session start events (bidirectional_session_start) + elif event_type == "bidirectional_session_start": + logger.info(f"Session started: {event.get('model', 'unknown')}") + + # Handle session end events (bidirectional_session_end) + elif event_type == "bidirectional_session_end": + logger.info(f"Session ended: {event.get('reason', 'unknown')}") + + # Handle error events (bidirectional_error) + elif event_type == "bidirectional_error": + logger.error(f"Error: {event.get('error_message', 'unknown')}") + + # Handle turn start events (bidirectional_turn_start) + elif event_type == "bidirectional_turn_start": + logger.debug(f"Turn started: {event.get('turn_id', 'unknown')}") except asyncio.CancelledError: pass @@ -246,11 +248,12 @@ async def get_frames(context): # Send frame to agent as image input try: - image_event = { - "imageData": frame["data"], - "mimeType": frame["mime_type"], - "encoding": "base64" - } + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ImageInputEvent + + image_event = ImageInputEvent( + image=frame["data"], # Already base64 encoded + mime_type=frame["mime_type"] + ) await context["agent"].send(image_event) print("📸 Frame sent to model") except Exception as e: @@ -272,7 +275,16 @@ async def send(agent, context): while time.time() - context["start_time"] < context["duration"]: try: audio_bytes = context["audio_in"].get_nowait() - audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + # Create audio event using TypedEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = AudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) @@ -304,7 +316,7 @@ async def main(duration=180): model = GeminiLiveModel( model_id="gemini-2.5-flash-native-audio-preview-09-2025", api_key=api_key, - params={ + live_config={ "response_modalities": ["AUDIO"], "output_audio_transcription": {}, # Enable output transcription "input_audio_transcription": {} # Enable input transcription diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index e7af3ad43..160b15a27 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -9,12 +9,14 @@ - Session lifecycle management - Provider-agnostic event types - Type-safe discriminated unions with TypedEvent +- JSON-serializable events (audio/images stored as base64 strings) Audio format normalization: - Supports PCM, WAV, Opus, and MP3 formats - Standardizes sample rates (16kHz, 24kHz, 48kHz) - Normalizes channel configurations (mono/stereo) - Abstracts provider-specific encodings +- Audio data stored as base64-encoded strings for JSON compatibility """ from typing import Any, Dict, List, Literal, Optional, Union, cast @@ -69,7 +71,7 @@ class AudioInputEvent(TypedEvent): Used for sending audio data through the send() method. Parameters: - audio: Raw audio bytes to send to model (not base64 encoded). + audio: Base64-encoded audio string to send to model. format: Audio format from SUPPORTED_AUDIO_FORMATS. sample_rate: Sample rate from SUPPORTED_SAMPLE_RATES. channels: Channel count from SUPPORTED_CHANNELS. @@ -77,7 +79,7 @@ class AudioInputEvent(TypedEvent): def __init__( self, - audio: bytes, + audio: str, format: Literal["pcm", "wav", "opus", "mp3"], sample_rate: Literal[16000, 24000, 48000], channels: Literal[1, 2], @@ -93,8 +95,8 @@ def __init__( ) @property - def audio(self) -> bytes: - return cast(bytes, self.get("audio")) + def audio(self) -> str: + return cast(str, self.get("audio")) @property def format(self) -> str: @@ -112,42 +114,34 @@ def channels(self) -> int: class ImageInputEvent(TypedEvent): """Image input event for sending images/video frames to the model. - Used for sending image data through the send() method. Supports both - raw image bytes and base64-encoded data. + Used for sending image data through the send() method. Parameters: - image: Image bytes (raw or base64-encoded string). + image: Base64-encoded image string. mime_type: MIME type (e.g., "image/jpeg", "image/png"). - encoding: How the image data is encoded. """ def __init__( self, - image: Union[bytes, str], + image: str, mime_type: str, - encoding: Literal["base64", "raw"], ): super().__init__( { "type": "bidirectional_image_input", "image": image, "mime_type": mime_type, - "encoding": encoding, } ) @property - def image(self) -> Union[bytes, str]: - return cast(Union[bytes, str], self.get("image")) + def image(self) -> str: + return cast(str, self.get("image")) @property def mime_type(self) -> str: return cast(str, self.get("mime_type")) - @property - def encoding(self) -> str: - return cast(str, self.get("encoding")) - # ============================================================================ # Output Events (received via session.receive_events()) @@ -205,7 +199,7 @@ class AudioStreamEvent(TypedEvent): """Streaming audio output from the model. Parameters: - audio: Raw audio data as bytes (not base64 encoded). + audio: Base64-encoded audio string. format: Audio encoding format. sample_rate: Number of audio samples per second in Hz. channels: Number of audio channels (1=mono, 2=stereo). @@ -213,7 +207,7 @@ class AudioStreamEvent(TypedEvent): def __init__( self, - audio: bytes, + audio: str, format: Literal["pcm", "wav", "opus", "mp3"], sample_rate: Literal[16000, 24000, 48000], channels: Literal[1, 2], @@ -229,8 +223,8 @@ def __init__( ) @property - def audio(self) -> bytes: - return cast(bytes, self.get("audio")) + def audio(self) -> str: + return cast(str, self.get("audio")) @property def format(self) -> str: @@ -436,8 +430,11 @@ class ErrorEvent(TypedEvent): Similar to strands.types._events.ForceStopEvent, this event wraps exceptions that occur during bidirectional streaming sessions. + Note: The Exception object is not stored in the event data to maintain JSON + serializability. Only the error message, code, and details are stored. + Parameters: - error: The exception that occurred. + error: The exception that occurred (used to extract message and type). code: Optional error code for programmatic handling (defaults to exception class name). details: Optional additional error information. """ @@ -450,18 +447,13 @@ def __init__( ): super().__init__( { - "bidirectional_error": True, - "error": error, + "type": "bidirectional_error", "error_message": str(error), "error_code": code or type(error).__name__, "error_details": details, } ) - @property - def error(self) -> Exception: - return cast(Exception, self.get("error")) - @property def code(self) -> str: return cast(str, self.get("error_code")) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 86b75fd21..f48f04910 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -197,9 +197,11 @@ async def test_send_all_content_types(mock_genai_client, model): assert content.role == "user" assert content.parts[0].text == "Hello" - # Test audio input + # Test audio input (base64 encoded) + import base64 + audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') audio_input = AudioInputEvent( - audio=b"audio_bytes", + audio=audio_b64, format="pcm", sample_rate=16000, channels=1, @@ -207,11 +209,11 @@ async def test_send_all_content_types(mock_genai_client, model): await model.send(audio_input) mock_live_session.send_realtime_input.assert_called_once() - # Test image input + # Test image input (base64 encoded, no encoding parameter) + image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') image_input = ImageInputEvent( - image=b"image_bytes", + image=image_b64, mime_type="image/jpeg", - encoding="raw", ) await model.send(image_input) mock_live_session.send.assert_called_once() @@ -303,7 +305,8 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.source == "assistant" assert text_event.is_final is True - # Test audio output + # Test audio output (now returns base64 encoded string) + import base64 mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" @@ -312,7 +315,9 @@ async def test_event_conversion(mock_genai_client, model): audio_event = model._convert_gemini_live_event(mock_audio) assert isinstance(audio_event, AudioStreamEvent) - assert audio_event.audio == b"audio_data" + # Audio is now base64 encoded + expected_b64 = base64.b64encode(b"audio_data").decode('utf-8') + assert audio_event.audio == expected_b64 assert audio_event.format == "pcm" # Test tool call diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 17f6b8e57..83c5e2f82 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -148,9 +148,10 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): # Should send contentStart, textInput, and contentEnd assert mock_stream.input_stream.send.call_count >= 3 - # Test audio content + # Test audio content (base64 encoded) + audio_b64 = base64.b64encode(b"audio data").decode('utf-8') audio_event = AudioInputEvent( - audio=b"audio data", + audio=audio_b64, format="pcm", sample_rate=16000, channels=1 @@ -188,12 +189,13 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): text_event = TextInputEvent(text="Hello", role="user") await nova_model.send(text_event) # Should not raise - # Test image content (not supported) + # Test image content (not supported, base64 encoded, no encoding parameter) await nova_model.connect() + import base64 + image_b64 = base64.b64encode(b"image data").decode('utf-8') image_event = ImageInputEvent( - image=b"image data", + image=image_b64, mime_type="image/jpeg", - encoding="raw" ) await nova_model.send(image_event) # Should log warning about unsupported image input @@ -224,36 +226,41 @@ async def mock_wait_for(*args, **kwargs): async for event in nova_model.receive(): events.append(event) - # Should have connection start and end + # Should have session start and end (new TypedEvent format) assert len(events) >= 2 - assert "BidirectionalConnectionStart" in events[0] - assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name - assert "BidirectionalConnectionEnd" in events[-1] + assert events[0].get("type") == "bidirectional_session_start" + assert events[0].get("session_id") == nova_model.prompt_name + assert events[-1].get("type") == "bidirectional_session_end" @pytest.mark.asyncio async def test_event_conversion(nova_model): """Test conversion of all Nova Sonic event types to standard format.""" - # Test audio output + # Test audio output (now returns AudioStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "audioOutput" in result - assert result["audioOutput"]["audioData"] == audio_bytes - assert result["audioOutput"]["format"] == "pcm" - assert result["audioOutput"]["sampleRate"] == 24000 - - # Test text output + assert isinstance(result, AudioStreamEvent) + assert result.get("type") == "bidirectional_audio_stream" + # Audio is kept as base64 string + assert result.get("audio") == audio_base64 + assert result.get("format") == "pcm" + assert result.get("sample_rate") == 24000 + + # Test text output (now returns TranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "textOutput" in result - assert result["textOutput"]["text"] == "Hello, world!" - assert result["textOutput"]["role"] == "assistant" + assert isinstance(result, TranscriptStreamEvent) + assert result.get("type") == "bidirectional_transcript_stream" + assert result.get("text") == "Hello, world!" + assert result.get("source") == "assistant" - # Test tool use + # Test tool use (now returns dict with tool_use) tool_input = {"location": "Seattle"} nova_event = { "toolUse": { @@ -264,19 +271,23 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "toolUse" in result - assert result["toolUse"]["toolUseId"] == "tool-123" - assert result["toolUse"]["name"] == "get_weather" - assert result["toolUse"]["input"] == tool_input - - # Test interruption + assert result.get("type") == "tool_use" + tool_use = result.get("tool_use") + assert tool_use["toolUseId"] == "tool-123" + assert tool_use["name"] == "get_weather" + assert tool_use["input"] == tool_input + + # Test interruption (now returns InterruptionEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent nova_event = {"stopReason": "INTERRUPTED"} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "interruptionDetected" in result - assert result["interruptionDetected"]["reason"] == "user_input" + assert isinstance(result, InterruptionEvent) + assert result.get("type") == "bidirectional_interruption" + assert result.get("reason") == "user_speech" - # Test usage metrics + # Test usage metrics (now returns MultimodalUsage) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import MultimodalUsage nova_event = { "usageEvent": { "totalTokens": 100, @@ -293,16 +304,19 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "usageMetrics" in result - assert result["usageMetrics"]["totalTokens"] == 100 - assert result["usageMetrics"]["inputTokens"] == 40 - assert result["usageMetrics"]["outputTokens"] == 60 - assert result["usageMetrics"]["audioTokens"] == 30 - - # Test content start tracks role + assert isinstance(result, MultimodalUsage) + assert result.get("type") == "multimodal_usage" + assert result.get("totalTokens") == 100 + assert result.get("inputTokens") == 40 + assert result.get("outputTokens") == 60 + + # Test content start tracks role and emits TurnStartEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TurnStartEvent nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) - assert result is None # contentStart doesn't emit an event + assert result is not None + assert isinstance(result, TurnStartEvent) + assert result.get("type") == "bidirectional_turn_start" assert nova_model._current_role == "USER" @@ -339,9 +353,11 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): await nova_model.connect() - # Send audio to start connection + # Send audio to start connection (base64 encoded) + import base64 + audio_b64 = base64.b64encode(b"audio data").decode('utf-8') audio_event = AudioInputEvent( - audio=b"audio data", + audio=audio_b64, format="pcm", sample_rate=16000, channels=1 diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 7f799816a..8640f5833 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -237,9 +237,10 @@ async def test_send_all_content_types(mock_websockets_connect, model): assert len(item_create) > 0 assert len(response_create) > 0 - # Test audio input + # Test audio input (base64 encoded) + audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') audio_input = AudioInputEvent( - audio=b"audio_bytes", + audio=audio_b64, format="pcm", sample_rate=24000, channels=1, @@ -250,8 +251,8 @@ async def test_send_all_content_types(mock_websockets_connect, model): audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] assert len(audio_append) > 0 assert "audio" in audio_append[0] - decoded = base64.b64decode(audio_append[0]["audio"]) - assert decoded == b"audio_bytes" + # Audio should be passed through as base64 + assert audio_append[0]["audio"] == audio_b64 # Test tool result tool_result: ToolResult = { @@ -281,12 +282,12 @@ async def test_send_edge_cases(mock_websockets_connect, model): await model.send(text_input) mock_ws.send.assert_not_called() - # Test image input (not supported) + # Test image input (not supported, base64 encoded, no encoding parameter) await model.connect() + image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') image_input = ImageInputEvent( - image=b"image_bytes", + image=image_b64, mime_type="image/jpeg", - encoding="raw", ) with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: await model.send(image_input) @@ -315,11 +316,12 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): receive_gen = model.receive() first_event = await anext(receive_gen) - # First event should be connection start - assert "BidirectionalConnectionStart" in first_event - assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id + # First event should be session start (new TypedEvent format) + assert first_event.get("type") == "bidirectional_session_start" + assert first_event.get("session_id") == model.session_id + assert first_event.get("model") == model.model - # Close to trigger connection end + # Close to trigger session end await model.close() # Collect remaining events @@ -330,8 +332,8 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): except StopAsyncIteration: pass - # Last event should be connection end - assert "BidirectionalConnectionEnd" in events[-1] + # Last event should be session end (new TypedEvent format) + assert events[-1].get("type") == "bidirectional_session_end" @pytest.mark.asyncio @@ -340,25 +342,29 @@ async def test_event_conversion(mock_websockets_connect, model): _, _ = mock_websockets_connect await model.connect() - # Test audio output + # Test audio output (now returns AudioStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() } converted = model._convert_openai_event(audio_event) - assert "audioOutput" in converted - assert converted["audioOutput"]["audioData"] == b"audio_data" - assert converted["audioOutput"]["format"] == "pcm" + assert isinstance(converted, AudioStreamEvent) + assert converted.get("type") == "bidirectional_audio_stream" + assert converted.get("audio") == base64.b64encode(b"audio_data").decode() + assert converted.get("format") == "pcm" - # Test text output + # Test text output (now returns TranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" } converted = model._convert_openai_event(text_event) - assert "textOutput" in converted - assert converted["textOutput"]["text"] == "Hello from OpenAI" - assert converted["textOutput"]["role"] == "assistant" + assert isinstance(converted, TranscriptStreamEvent) + assert converted.get("type") == "bidirectional_transcript_stream" + assert converted.get("text") == "Hello from OpenAI" + assert converted.get("source") == "assistant" # Test function call sequence item_added = { @@ -383,18 +389,23 @@ async def test_event_conversion(mock_websockets_connect, model): "call_id": "call-123" } converted = model._convert_openai_event(args_done) - assert "toolUse" in converted - assert converted["toolUse"]["toolUseId"] == "call-123" - assert converted["toolUse"]["name"] == "calculator" - assert converted["toolUse"]["input"]["expression"] == "2+2" - - # Test voice activity + # Now returns dict with tool_use + assert isinstance(converted, dict) + assert converted.get("type") == "tool_use" + tool_use = converted.get("tool_use") + assert tool_use["toolUseId"] == "call-123" + assert tool_use["name"] == "calculator" + assert tool_use["input"]["expression"] == "2+2" + + # Test voice activity (now returns InterruptionEvent for speech_started) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent speech_started = { "type": "input_audio_buffer.speech_started" } converted = model._convert_openai_event(speech_started) - assert "voiceActivity" in converted - assert converted["voiceActivity"]["activityType"] == "speech_started" + assert isinstance(converted, InterruptionEvent) + assert converted.get("type") == "bidirectional_interruption" + assert converted.get("reason") == "user_speech" await model.close() @@ -442,16 +453,23 @@ def test_helper_methods(model): assert model._require_active() is True model._active = False - # Test _create_text_event + # Test _create_text_event (now returns TranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = model._create_text_event("Hello", "user") - assert "textOutput" in text_event - assert text_event["textOutput"]["text"] == "Hello" - assert text_event["textOutput"]["role"] == "user" + assert isinstance(text_event, TranscriptStreamEvent) + assert text_event.get("type") == "bidirectional_transcript_stream" + assert text_event.get("text") == "Hello" + assert text_event.get("source") == "user" - # Test _create_voice_activity_event + # Test _create_voice_activity_event (now returns InterruptionEvent for speech_started) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent voice_event = model._create_voice_activity_event("speech_started") - assert "voiceActivity" in voice_event - assert voice_event["voiceActivity"]["activityType"] == "speech_started" + assert isinstance(voice_event, InterruptionEvent) + assert voice_event.get("type") == "bidirectional_interruption" + assert voice_event.get("reason") == "user_speech" + + # Other voice activities return None + assert model._create_voice_activity_event("speech_stopped") is None @pytest.mark.asyncio From 0b1ab4bc0c3ae790e14e91865ecb67c17d57c114 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 17:46:17 +0100 Subject: [PATCH 037/242] feat(bidirectional): Add agent.run --- .../bidirectional_streaming/agent/agent.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 820a6c490..7e4f71a92 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -431,6 +431,75 @@ async def end(self) -> None: await stop_bidirectional_connection(self._session) self._session = None + async def run( + self, + send_callable: Callable[[Any], Any], + receive_callable: Callable[[], Any], + ) -> None: + """Run the agent with send/receive loop management. + + Starts the session, pipes events between the agent and transport layer, + and handles cleanup on disconnection. + + Args: + send_callable: Async callable that sends events to the client (e.g., websocket.send_json). + receive_callable: Async callable that receives events from the client (e.g., websocket.receive_json). + + Example: + ```python + # With WebSocket + agent = BidirectionalAgent(model=model, tools=[calculator]) + await agent.run(websocket.send_json, websocket.receive_json) + + # With custom transport + async def custom_send(event): + # Custom send logic + pass + + async def custom_receive(): + # Custom receive logic + return event + + await agent.run(custom_send, custom_receive) + ``` + + Raises: + Exception: Any exception from the transport layer (e.g., WebSocketDisconnect). + """ + await self.start() + + async def receive_from_agent(): + """Receive events from agent and send to client.""" + try: + async for event in self.receive(): + await send_callable(event) + except Exception as e: + logger.debug(f"Receive from agent stopped: {e}") + raise + + async def send_to_agent(): + """Receive events from client and send to agent.""" + try: + while self._session and self._session.active: + event = await receive_callable() + await self.send(event) + except Exception as e: + logger.debug(f"Send to agent stopped: {e}") + raise + + try: + # Run both loops concurrently + await asyncio.gather( + receive_from_agent(), + send_to_agent(), + return_exceptions=True + ) + finally: + try: + await self.end() + except Exception as e: + logger.debug(f"Error during cleanup: {e}") + def _validate_active_session(self) -> None: """Validate that an active session exists. From f8939f3e256db6f5abaa137ad9e41545f2fda3ff Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 17:55:40 +0100 Subject: [PATCH 038/242] Update src/strands/experimental/bidirectional_streaming/agent/agent.py Co-authored-by: Nick Clegg --- src/strands/experimental/bidirectional_streaming/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 7e4f71a92..5a43c7bcf 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -431,7 +431,7 @@ async def end(self) -> None: await stop_bidirectional_connection(self._session) self._session = None - async def run( + async def __call__( self, send_callable: Callable[[Any], Any], receive_callable: Callable[[], Any], From 0799be8f21f437ef0d3c09956eb17bf3c2f25046 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 3 Nov 2025 12:17:21 +0300 Subject: [PATCH 039/242] feat: add usage to openai --- .../bidirectional_streaming/models/openai.py | 114 +++++++++++++++--- .../models/test_openai_realtime.py | 46 ++++--- 2 files changed, 121 insertions(+), 39 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 15d1bbf86..181c01c27 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -291,9 +291,8 @@ async def receive(self) -> AsyncIterable[OutputEvent]: while self._active: try: openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - provider_event = self._convert_openai_event(openai_event) - if provider_event: - yield provider_event + for event in self._convert_openai_event(openai_event) or []: + yield event except asyncio.TimeoutError: continue @@ -304,35 +303,41 @@ async def receive(self) -> AsyncIterable[OutputEvent]: # Emit session end event yield SessionEndEvent(reason="complete") - def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | None: + def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" event_type = openai_event.get("type") + # Turn start - response begins + if event_type == "response.created": + response = openai_event.get("response", {}) + response_id = response.get("id", str(uuid.uuid4())) + return [TurnStartEvent(turn_id=response_id)] + # Audio output - if event_type == "response.output_audio.delta": + elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI - return AudioStreamEvent( + return [AudioStreamEvent( audio=openai_event["delta"], format="pcm", sample_rate=24000, channels=1 - ) + )] # Assistant text output events - combine multiple similar events elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: - return self._create_text_event(openai_event["delta"], "assistant") + return [self._create_text_event(openai_event["delta"], "assistant")] # User transcription events - combine multiple similar events elif event_type in ["conversation.item.input_audio_transcription.delta", "conversation.item.input_audio_transcription.completed"]: text_key = "delta" if "delta" in event_type else "transcript" text = openai_event.get(text_key, "") - return self._create_text_event(text, "user") if text.strip() else None + return [self._create_text_event(text, "user")] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) text = segment_data.get("text", "") - return self._create_text_event(text, "user") if text.strip() else None + return [self._create_text_event(text, "user")] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.failed": error_info = openai_event.get("error", {}) @@ -362,7 +367,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | N } del self._function_call_buffer[call_id] # Return dict with tool_use for event loop processing - return {"type": "tool_use", "tool_use": tool_use} + return [{"type": "tool_use", "tool_use": tool_use}] except (json.JSONDecodeError, KeyError) as e: logger.warning("Error parsing function arguments for %s: %s", call_id, e) del self._function_call_buffer[call_id] @@ -377,7 +382,84 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | N "input_audio_buffer.speech_stopped": "speech_stopped", "input_audio_buffer.timeout_triggered": "timeout" } - return self._create_voice_activity_event(activity_map[event_type]) + event = self._create_voice_activity_event(activity_map[event_type]) + return [event] if event else None + + # Turn complete and usage - response finished + elif event_type == "response.done": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + status = response.get("status", "completed") + usage = response.get("usage") + + # Map OpenAI status to our stop_reason + stop_reason_map = { + "completed": "complete", + "cancelled": "interrupted", + "failed": "error", + "incomplete": "interrupted" + } + + # Build list of events to return + events = [] + + # Always add turn complete event + events.append(TurnCompleteEvent( + turn_id=response_id, + stop_reason=stop_reason_map.get(status, "complete") + )) + + # Add usage event if available + if usage: + input_details = usage.get("input_token_details", {}) + output_details = usage.get("output_token_details", {}) + + # Build modality details + modality_details = [] + + # Text modality + text_input = input_details.get("text_tokens", 0) + text_output = output_details.get("text_tokens", 0) + if text_input > 0 or text_output > 0: + modality_details.append({ + "modality": "text", + "input_tokens": text_input, + "output_tokens": text_output + }) + + # Audio modality + audio_input = input_details.get("audio_tokens", 0) + audio_output = output_details.get("audio_tokens", 0) + if audio_input > 0 or audio_output > 0: + modality_details.append({ + "modality": "audio", + "input_tokens": audio_input, + "output_tokens": audio_output + }) + + # Image modality + image_input = input_details.get("image_tokens", 0) + if image_input > 0: + modality_details.append({ + "modality": "image", + "input_tokens": image_input, + "output_tokens": 0 + }) + + # Cached tokens + cached_tokens = input_details.get("cached_tokens", 0) + + # Add usage event + events.append(MultimodalUsage( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None + )) + + # Return list of events + return events # Lifecycle events (log only) - combine multiple similar events elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: @@ -388,14 +470,6 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | N elif event_type == "conversation.item.done": logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - # This event signals turn completion - emit TurnCompleteEvent - item = openai_event.get("item", {}) - if item.get("type") == "message" and item.get("role") == "assistant": - item_id = item.get("id", "unknown") - return TurnCompleteEvent( - turn_id=item_id, - stop_reason="complete" - ) return None # Response output events - combine similar events diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 8640f5833..60e88aa0f 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -342,29 +342,33 @@ async def test_event_conversion(mock_websockets_connect, model): _, _ = mock_websockets_connect await model.connect() - # Test audio output (now returns AudioStreamEvent) + # Test audio output (now returns list with AudioStreamEvent) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() } converted = model._convert_openai_event(audio_event) - assert isinstance(converted, AudioStreamEvent) - assert converted.get("type") == "bidirectional_audio_stream" - assert converted.get("audio") == base64.b64encode(b"audio_data").decode() - assert converted.get("format") == "pcm" - - # Test text output (now returns TranscriptStreamEvent) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], AudioStreamEvent) + assert converted[0].get("type") == "bidirectional_audio_stream" + assert converted[0].get("audio") == base64.b64encode(b"audio_data").decode() + assert converted[0].get("format") == "pcm" + + # Test text output (now returns list with TranscriptStreamEvent) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" } converted = model._convert_openai_event(text_event) - assert isinstance(converted, TranscriptStreamEvent) - assert converted.get("type") == "bidirectional_transcript_stream" - assert converted.get("text") == "Hello from OpenAI" - assert converted.get("source") == "assistant" + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], TranscriptStreamEvent) + assert converted[0].get("type") == "bidirectional_transcript_stream" + assert converted[0].get("text") == "Hello from OpenAI" + assert converted[0].get("source") == "assistant" # Test function call sequence item_added = { @@ -389,23 +393,27 @@ async def test_event_conversion(mock_websockets_connect, model): "call_id": "call-123" } converted = model._convert_openai_event(args_done) - # Now returns dict with tool_use - assert isinstance(converted, dict) - assert converted.get("type") == "tool_use" - tool_use = converted.get("tool_use") + # Now returns list with dict containing tool_use + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], dict) + assert converted[0].get("type") == "tool_use" + tool_use = converted[0].get("tool_use") assert tool_use["toolUseId"] == "call-123" assert tool_use["name"] == "calculator" assert tool_use["input"]["expression"] == "2+2" - # Test voice activity (now returns InterruptionEvent for speech_started) + # Test voice activity (now returns list with InterruptionEvent for speech_started) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent speech_started = { "type": "input_audio_buffer.speech_started" } converted = model._convert_openai_event(speech_started) - assert isinstance(converted, InterruptionEvent) - assert converted.get("type") == "bidirectional_interruption" - assert converted.get("reason") == "user_speech" + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], InterruptionEvent) + assert converted[0].get("type") == "bidirectional_interruption" + assert converted[0].get("reason") == "user_speech" await model.close() From a2f29b37fde4d566ec588bd615d8e200c96df54e Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 3 Nov 2025 12:17:58 +0300 Subject: [PATCH 040/242] feat: add usage to gemini --- .../models/gemini_live.py | 45 ++++++++++++++++++- .../models/test_gemini_live.py | 4 +- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 02044125f..39e6deed7 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -30,6 +30,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, + MultimodalUsage, SessionEndEvent, SessionStartEvent, TextInputEvent, @@ -205,6 +206,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model + - usageMetadata: Token usage information """ try: # Handle interruption first (from server_content) @@ -267,7 +269,48 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic } return {"toolUse": tool_use_event} - # Silently ignore setup_complete, turn_complete, generation_complete, and usage_metadata messages + # Handle usage metadata + if hasattr(message, 'usage_metadata') and message.usage_metadata: + usage = message.usage_metadata + + # Build modality details from token details + modality_details = [] + + # Process prompt tokens details + if usage.prompt_tokens_details: + for detail in usage.prompt_tokens_details: + if detail.modality and detail.token_count: + modality_details.append({ + "modality": str(detail.modality).lower(), + "input_tokens": detail.token_count, + "output_tokens": 0 + }) + + # Process response tokens details + if usage.response_tokens_details: + for detail in usage.response_tokens_details: + if detail.modality and detail.token_count: + # Find or create modality entry + modality_str = str(detail.modality).lower() + existing = next((m for m in modality_details if m["modality"] == modality_str), None) + if existing: + existing["output_tokens"] = detail.token_count + else: + modality_details.append({ + "modality": modality_str, + "input_tokens": 0, + "output_tokens": detail.token_count + }) + + return MultimodalUsage( + input_tokens=usage.prompt_token_count or 0, + output_tokens=usage.response_token_count or 0, + total_tokens=usage.total_token_count or 0, + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None + ) + + # Silently ignore setup_complete and generation_complete messages return None except Exception as e: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index f48f04910..d3bf965f4 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -292,7 +292,7 @@ async def test_event_conversion(mock_genai_client, model): _, _, _ = mock_genai_client await model.connect() - # Test text output (now converted to transcript) + # Test text output (converted to transcript) mock_text = unittest.mock.Mock() mock_text.text = "Hello from Gemini" mock_text.data = None @@ -305,7 +305,7 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.source == "assistant" assert text_event.is_final is True - # Test audio output (now returns base64 encoded string) + # Test audio output (base64 encoded) import base64 mock_audio = unittest.mock.Mock() mock_audio.text = None From 73e85dc68ef02fa22da4293e00a45860529f36aa Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 3 Nov 2025 12:23:15 +0300 Subject: [PATCH 041/242] fix: change call to run --- .../bidirectional_streaming/agent/agent.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 5a43c7bcf..4e8adfe7c 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -431,10 +431,11 @@ async def end(self) -> None: await stop_bidirectional_connection(self._session) self._session = None - async def __call__( + async def run( self, - send_callable: Callable[[Any], Any], - receive_callable: Callable[[], Any], + *, + sender: Callable[[Any], Any], + receiver: Callable[[], Any], ) -> None: """Run the agent with send/receive loop management. @@ -442,14 +443,14 @@ async def __call__( and handles cleanup on disconnection. Args: - send_callable: Async callable that sends events to the client (e.g., websocket.send_json). - receive_callable: Async callable that receives events from the client (e.g., websocket.receive_json). + sender: Async callable that sends events to the client (e.g., websocket.send_json). + receiver: Async callable that receives events from the client (e.g., websocket.receive_json). Example: ```python # With WebSocket agent = BidirectionalAgent(model=model, tools=[calculator]) - await agent.run(websocket.send_json, websocket.receive_json) + await agent.run(sender=websocket.send_json, receiver=websocket.receive_json) # With custom transport async def custom_send(event): @@ -460,7 +461,7 @@ async def custom_receive(): # Custom receive logic return event - await agent.run(custom_send, custom_receive) + await agent.run(sender=custom_send, receiver=custom_receive) ``` Raises: @@ -472,7 +473,7 @@ async def receive_from_agent(): """Receive events from agent and send to client.""" try: async for event in self.receive(): - await send_callable(event) + await sender(event) except Exception as e: logger.debug(f"Receive from agent stopped: {e}") raise @@ -481,7 +482,7 @@ async def send_to_agent(): """Receive events from client and send to agent.""" try: while self._session and self._session.active: - event = await receive_callable() + event = await receiver() await self.send(event) except Exception as e: logger.debug(f"Send to agent stopped: {e}") From e53615881a241fd6816335e4e82336601be2e785 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 3 Nov 2025 17:07:35 +0300 Subject: [PATCH 042/242] feat: refactor test context and add openai bidi agent test --- .../test_bidirectional_agent.py | 173 +++++++++++++++--- .../utils/test_context.py | 138 +++++++++----- 2 files changed, 237 insertions(+), 74 deletions(-) diff --git a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py index 46887652b..9ae64514d 100644 --- a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py +++ b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py @@ -1,38 +1,143 @@ -"""Basic integration tests for Nova Sonic bidirectional streaming. +"""Parameterized integration tests for bidirectional streaming. -Tests fundamental functionality including multi-turn conversations, audio I/O, -text transcription, and tool execution using the new context manager approach. +Tests fundamental functionality across multiple model providers (Nova Sonic, OpenAI, etc.) +including multi-turn conversations, audio I/O, text transcription, and tool execution. + +This demonstrates the provider-agnostic design of the bidirectional streaming system. """ +import asyncio import logging +import os import pytest from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel from .utils.test_context import BidirectionalTestContext logger = logging.getLogger(__name__) +# Provider configurations +PROVIDER_CONFIGS = { + "nova_sonic": { + "model_class": NovaSonicBidirectionalModel, + "model_kwargs": {"region": "us-east-1"}, + "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence + "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + "skip_reason": "AWS credentials not available", + }, + "openai": { + "model_class": OpenAIRealtimeBidirectionalModel, + "model_kwargs": { + "model": "gpt-4o-realtime-preview-2024-12-17", + "session": { + "output_modalities": ["audio"], # OpenAI only supports audio OR text, not both + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700, + }, + }, + "output": {"format": {"type": "audio/pcm", "rate": 24000}, "voice": "alloy"}, + }, + }, + }, + "silence_duration": 1.0, # OpenAI has faster VAD + "env_vars": ["OPENAI_API_KEY"], + "skip_reason": "OPENAI_API_KEY not available", + }, + # NOTE: Gemini Live is temporarily disabled in parameterized tests + # Issue: Transcript events are not being properly emitted alongside audio events + # The model responds with audio but the test infrastructure expects text/transcripts + # TODO: Fix Gemini Live event emission to yield both transcript and audio events + # "gemini_live": { + # "model_class": GeminiLiveBidirectionalModel, + # "model_kwargs": { + # "model_id": "gemini-2.5-flash-native-audio-preview-09-2025", + # "params": { + # "response_modalities": ["AUDIO"], + # "output_audio_transcription": {}, + # "input_audio_transcription": {}, + # }, + # }, + # "silence_duration": 3.0, + # "env_vars": ["GOOGLE_AI_API_KEY"], + # "skip_reason": "GOOGLE_AI_API_KEY not available", + # }, +} + + +def check_provider_available(provider_name: str) -> tuple[bool, str]: + """Check if a provider's credentials are available. + + Args: + provider_name: Name of the provider to check. + + Returns: + Tuple of (is_available, skip_reason). + """ + config = PROVIDER_CONFIGS[provider_name] + env_vars = config["env_vars"] + + missing_vars = [var for var in env_vars if not os.getenv(var)] + + if missing_vars: + return False, f"{config['skip_reason']}: {', '.join(missing_vars)}" + + return True, "" + + +@pytest.fixture(params=list(PROVIDER_CONFIGS.keys())) +def provider_config(request): + """Provide configuration for each model provider. + + This fixture is parameterized to run tests against all available providers. + """ + provider_name = request.param + config = PROVIDER_CONFIGS[provider_name] + + # Check if provider is available + is_available, skip_reason = check_provider_available(provider_name) + if not is_available: + pytest.skip(skip_reason) + + return { + "name": provider_name, + **config, + } + + @pytest.fixture -def agent_with_calculator(): - """Provide bidirectional agent with calculator tool. +def agent_with_calculator(provider_config): + """Provide bidirectional agent with calculator tool for the given provider. Note: Session lifecycle (start/end) is handled by BidirectionalTestContext. """ - model = NovaSonicBidirectionalModel(region="us-east-1") + model_class = provider_config["model_class"] + model_kwargs = provider_config["model_kwargs"] + + model = model_class(**model_kwargs) return BidirectionalAgent( model=model, tools=[calculator], - system_prompt="You are a helpful assistant with access to a calculator tool.", + system_prompt="You are a helpful assistant with access to a calculator tool. Keep responses brief.", ) @pytest.mark.asyncio -async def test_bidirectional_agent(agent_with_calculator, audio_generator): - """Test multi-turn conversation with follow-up questions. +async def test_bidirectional_agent(agent_with_calculator, audio_generator, provider_config): + """Test multi-turn conversation with follow-up questions across providers. + + This test runs against all configured providers (Nova Sonic, OpenAI, etc.) + to validate provider-agnostic functionality. Validates: - Session lifecycle (start/end via context manager) @@ -42,46 +147,56 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator): - Multi-turn conversation flow - Text-to-speech audio output """ + provider_name = provider_config["name"] + silence_duration = provider_config["silence_duration"] + + logger.info(f"Testing provider: {provider_name}") + async with BidirectionalTestContext(agent_with_calculator, audio_generator) as ctx: - # Turn 1: Initial question - await ctx.say("What is five plus three?") + # Turn 1: Simple greeting to test basic audio I/O + await ctx.say("Hello, can you hear me?") + # Wait for silence to trigger provider's VAD/silence detection + await asyncio.sleep(silence_duration) await ctx.wait_for_response() text_outputs_turn1 = ctx.get_text_outputs() all_text_turn1 = " ".join(text_outputs_turn1).lower() - # Validate turn 1 - assert "8" in all_text_turn1 or "eight" in all_text_turn1, ( - f"Answer '8' not found in turn 1: {text_outputs_turn1}" + # Validate turn 1 - just check we got a response + assert len(text_outputs_turn1) > 0, ( + f"[{provider_name}] No text output received in turn 1" ) - logger.info(f"✓ Turn 1 complete: {len(ctx.get_events())} events") + + logger.info(f"[{provider_name}] ✓ Turn 1 complete: received response") + logger.info(f"[{provider_name}] Response: {text_outputs_turn1[0][:100]}...") - # Turn 2: Follow-up question - await ctx.say("Now multiply that by two") + # Turn 2: Follow-up to test multi-turn conversation + await ctx.say("What's your name?") + # Wait for silence to trigger provider's VAD/silence detection + await asyncio.sleep(silence_duration) await ctx.wait_for_response() text_outputs_turn2 = ctx.get_text_outputs() - all_text_turn2 = " ".join(text_outputs_turn2).lower() - # Validate turn 2 - assert "16" in all_text_turn2 or "sixteen" in all_text_turn2, ( - f"Answer '16' not found in turn 2: {text_outputs_turn2}" + # Validate turn 2 - check we got more responses + assert len(text_outputs_turn2) > len(text_outputs_turn1), ( + f"[{provider_name}] No new text output in turn 2" ) - logger.info(f"✓ Turn 2 complete: {len(ctx.get_events())} total events") + + logger.info(f"[{provider_name}] ✓ Turn 2 complete: multi-turn conversation works") + logger.info(f"[{provider_name}] Total responses: {len(text_outputs_turn2)}") # Validate full conversation - assert len(text_outputs_turn2) > len(text_outputs_turn1), "No new text outputs in turn 2" - # Validate audio outputs audio_outputs = ctx.get_audio_outputs() - assert len(audio_outputs) > 0, "No audio output received" + assert len(audio_outputs) > 0, f"[{provider_name}] No audio output received" total_audio_bytes = sum(len(audio) for audio in audio_outputs) - logger.info(f"✓ Audio output: {len(audio_outputs)} chunks, {total_audio_bytes} bytes") # Summary logger.info("=" * 60) - logger.info("✓ Multi-turn conversation test passed") + logger.info(f"[{provider_name}] ✓ Multi-turn conversation test PASSED") + logger.info(f" Provider: {provider_name}") logger.info(f" Total events: {len(ctx.get_events())}") - logger.info(f" Text outputs: {len(text_outputs_turn2)}") - logger.info(f" Audio chunks: {len(audio_outputs)}") + logger.info(f" Text responses: {len(text_outputs_turn2)}") + logger.info(f" Audio chunks: {len(audio_outputs)} ({total_audio_bytes:,} bytes)") logger.info("=" * 60) diff --git a/tests_integ/bidirectional_streaming/utils/test_context.py b/tests_integ/bidirectional_streaming/utils/test_context.py index f669e12ca..4c91e2fc1 100644 --- a/tests_integ/bidirectional_streaming/utils/test_context.py +++ b/tests_integ/bidirectional_streaming/utils/test_context.py @@ -6,6 +6,7 @@ import asyncio import logging +import time from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -14,6 +15,12 @@ logger = logging.getLogger(__name__) +# Constants for timing and buffering +QUEUE_POLL_TIMEOUT = 0.05 # 50ms - balance between responsiveness and CPU usage +SILENCE_INTERVAL = 0.05 # 50ms - send silence every 50ms when queue empty +AUDIO_CHUNK_DELAY = 0.01 # 10ms - small delay between audio chunks +WAIT_POLL_INTERVAL = 0.1 # 100ms - how often to check for response completion + class BidirectionalTestContext: """Manages threads and generators for bidirectional streaming tests. @@ -54,8 +61,9 @@ def __init__( # Queue for thread communication self.input_queue = asyncio.Queue() # Handles both audio and text input - # Event storage - self.events = [] # All collected events + # Event storage (thread-safe) + self._event_queue = asyncio.Queue() # Events from collection thread + self.events = [] # Cached events for test access self.last_event_time = None # Control flags @@ -84,8 +92,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def start(self): """Start all background threads.""" + import time self.active = True - self.last_event_time = asyncio.get_event_loop().time() + self.last_event_time = time.monotonic() self.threads = [ asyncio.create_task(self._input_thread()), @@ -96,6 +105,11 @@ async def start(self): async def stop(self): """Stop all threads gracefully.""" + if not self.active: + logger.debug("stop() called but already stopped") + return + + logger.debug("stop() called - stopping threads") self.active = False # Cancel all threads @@ -111,13 +125,29 @@ async def stop(self): # === User-facing methods === async def say(self, text: str): - """Queue text to be converted to audio and sent to model. + """Convert text to audio and queue audio chunks to be sent to model. Args: text: Text to convert to speech and send as audio. + + Raises: + ValueError: If audio generator is not available. """ - await self.input_queue.put({"type": "audio", "text": text}) - logger.debug(f"Queued speech: {text[:50]}...") + if not self.audio_generator: + raise ValueError( + "Audio generator not available. Pass audio_generator to BidirectionalTestContext." + ) + + # Generate audio via Polly + audio_data = await self.audio_generator.generate_audio(text) + + # Split into chunks and queue each chunk + for i in range(0, len(audio_data), self.audio_chunk_size): + chunk = audio_data[i : i + self.audio_chunk_size] + chunk_event = self.audio_generator.create_audio_input_event(chunk) + await self.input_queue.put({"type": "audio_chunk", "data": chunk_event}) + + logger.debug(f"Queued {len(audio_data)} bytes of audio for: {text[:50]}...") async def send(self, data: str | dict) -> None: """Send data directly to model (text, image, etc.). @@ -146,27 +176,33 @@ async def wait_for_response( silence_threshold: Seconds of silence to consider response complete. min_events: Minimum events before silence detection activates. """ - start_time = asyncio.get_event_loop().time() - initial_event_count = len(self.events) - - while asyncio.get_event_loop().time() - start_time < timeout: + import time + start_time = time.monotonic() + initial_event_count = len(self.get_events()) # Drain queue + + while time.monotonic() - start_time < timeout: + # Drain queue to get latest events + current_events = self.get_events() + # Check if we have minimum events - if len(self.events) - initial_event_count >= min_events: + if len(current_events) - initial_event_count >= min_events: # Check silence - elapsed_since_event = asyncio.get_event_loop().time() - self.last_event_time + elapsed_since_event = time.monotonic() - self.last_event_time if elapsed_since_event >= silence_threshold: logger.debug( - f"Response complete: {len(self.events) - initial_event_count} events, " + f"Response complete: {len(current_events) - initial_event_count} events, " f"{elapsed_since_event:.1f}s silence" ) return - await asyncio.sleep(0.1) + await asyncio.sleep(WAIT_POLL_INTERVAL) logger.warning(f"Response timeout after {timeout}s") def get_events(self, event_type: str | None = None) -> list[dict]: """Get collected events, optionally filtered by type. + + Drains the event queue and caches events for subsequent calls. Args: event_type: Optional event type to filter by (e.g., "textOutput"). @@ -174,22 +210,40 @@ def get_events(self, event_type: str | None = None) -> list[dict]: Returns: List of events, filtered if event_type specified. """ + # Drain queue into cache (non-blocking) + while not self._event_queue.empty(): + try: + event = self._event_queue.get_nowait() + self.events.append(event) + import time + self.last_event_time = time.monotonic() + except asyncio.QueueEmpty: + break + if event_type: return [e for e in self.events if event_type in e] return self.events.copy() def get_text_outputs(self) -> list[str]: """Extract text outputs from collected events. + + Handles both textOutput events (Nova Sonic, OpenAI) and transcript events (Gemini Live). Returns: List of text content strings. """ texts = [] - for event in self.events: + for event in self.get_events(): # Drain queue first + # Handle textOutput events (Nova Sonic, OpenAI) if "textOutput" in event: text = event["textOutput"].get("text", "") if text: texts.append(text) + # Handle transcript events (Gemini Live) + elif "transcript" in event: + text = event["transcript"].get("text", "") + if text: + texts.append(text) return texts def get_audio_outputs(self) -> list[bytes]: @@ -198,8 +252,10 @@ def get_audio_outputs(self) -> list[bytes]: Returns: List of audio data bytes. """ + # Drain queue first to get latest events + events = self.get_events() audio_data = [] - for event in self.events: + for event in events: if "audioOutput" in event: data = event["audioOutput"].get("audioData") if data: @@ -212,7 +268,9 @@ def get_tool_uses(self) -> list[dict]: Returns: List of tool use events. """ - return [event["toolUse"] for event in self.events if "toolUse" in event] + # Drain queue first to get latest events + events = self.get_events() + return [event["toolUse"] for event in events if "toolUse" in event] def has_interruption(self) -> bool: """Check if any interruption was detected. @@ -232,33 +290,21 @@ def clear_events(self): async def _input_thread(self): """Continuously handle input to model. - - Sends silence by default (background noise) if audio generator available - - Converts queued text to audio via Polly (for "audio" type) - - Sends text directly to model (for "text" type) + - Sends queued audio chunks immediately + - Sends silence chunks periodically when queue is empty (simulates microphone) + - Sends direct data to model """ try: + logger.debug(f"Input thread starting, active={self.active}") while self.active: try: - # Check for queued input (non-blocking) - input_item = await asyncio.wait_for(self.input_queue.get(), timeout=0.01) - - if input_item["type"] == "audio": - # Generate and send audio - if self.audio_generator: - audio_data = await self.audio_generator.generate_audio(input_item["text"]) - - # Send audio in chunks - for i in range(0, len(audio_data), self.audio_chunk_size): - if not self.active: - break - chunk = audio_data[i : i + self.audio_chunk_size] - chunk_event = self.audio_generator.create_audio_input_event(chunk) - await self.agent.send(chunk_event) - await asyncio.sleep(0.01) - - logger.debug(f"Sent audio: {len(audio_data)} bytes") - else: - logger.warning("Audio requested but no generator available") + # Check for queued input (non-blocking with short timeout) + input_item = await asyncio.wait_for(self.input_queue.get(), timeout=QUEUE_POLL_TIMEOUT) + + if input_item["type"] == "audio_chunk": + # Send pre-generated audio chunk + await self.agent.send(input_item["data"]) + await asyncio.sleep(AUDIO_CHUNK_DELAY) elif input_item["type"] == "direct": # Send data directly to agent @@ -267,16 +313,18 @@ async def _input_thread(self): logger.debug(f"Sent direct: {data_repr}") except asyncio.TimeoutError: - # No input queued - send silence if audio generator available + # No input queued - send silence chunk to simulate continuous microphone input if self.audio_generator: silence = self._generate_silence_chunk() await self.agent.send(silence) - await asyncio.sleep(0.01) + await asyncio.sleep(SILENCE_INTERVAL) except asyncio.CancelledError: logger.debug("Input thread cancelled") except Exception as e: - logger.error(f"Input thread error: {e}") + logger.error(f"Input thread error: {e}", exc_info=True) + finally: + logger.debug(f"Input thread stopped, active={self.active}") async def _event_collection_thread(self): """Continuously collect events from model.""" @@ -285,8 +333,8 @@ async def _event_collection_thread(self): if not self.active: break - self.events.append(event) - self.last_event_time = asyncio.get_event_loop().time() + # Thread-safe: put in queue instead of direct append + await self._event_queue.put(event) logger.debug(f"Event collected: {list(event.keys())}") except asyncio.CancelledError: From 231f1718d8533fe78d4f98bca11a7cff73235084 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 15:27:07 +0300 Subject: [PATCH 043/242] fix: address comments in pr --- .../event_loop/bidirectional_event_loop.py | 6 +-- .../models/gemini_live.py | 11 ++--- .../models/novasonic.py | 41 ++++++++++--------- .../bidirectional_streaming/models/openai.py | 3 +- .../models/test_novasonic.py | 26 ++++++------ .../models/test_openai_realtime.py | 2 +- 6 files changed, 46 insertions(+), 43 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index d1d6e90b3..38d92aea8 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -473,8 +473,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: try: await session.model.send(error_result) logger.debug("Error result sent: %s", tool_id) - except Exception: - logger.error("Failed to send error result: %s", tool_id) - pass # Connection might be closed + except Exception as send_error: + logger.error("Failed to send error result: %s - %s", tool_id, str(send_error)) + raise # Propagate exception since this is experimental code diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 9f0cfe6c0..ffff98cf1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -84,7 +84,7 @@ def __init__( # Connection state (initialized in connect()) self.live_session = None - self.live_session_cm = None + self.live_session_context_manager = None self.session_id = None self._active = False @@ -115,13 +115,13 @@ async def connect( live_config = self._build_live_config(system_prompt, tools, **kwargs) # Create the context manager - self.live_session_cm = self.client.aio.live.connect( + self.live_session_context_manager = self.client.aio.live.connect( model=self.model_id, config=live_config ) # Enter the context manager - self.live_session = await self.live_session_cm.__aenter__() + self.live_session = await self.live_session_context_manager.__aenter__() # Send initial message history if provided if messages: @@ -312,6 +312,7 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE logger.warning(f"Unknown content type with keys: {content.keys()}") except Exception as e: logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: """Internal: Send audio content using Gemini Live API. @@ -412,8 +413,8 @@ async def close(self) -> None: try: # Exit the context manager properly - if self.live_session_cm: - await self.live_session_cm.__aexit__(None, None, None) + if self.live_session_context_manager: + await self.live_session_context_manager.__aexit__(None, None, None) except Exception as e: logger.error("Error closing Gemini Live connection: %s", e) raise diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index c9e5805db..e4c0d1565 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -102,11 +102,11 @@ def __init__( # Model configuration self.model_id = model_id self.region = region - self._client = None + self.client = None # Connection state (initialized in connect()) self.stream = None - self.prompt_name = None + self.session_id = None self._active = False # Nova Sonic requires unique content names @@ -146,17 +146,17 @@ async def connect( try: # Initialize client if needed - if not self._client: + if not self.client: await self._initialize_client() # Initialize connection state - self.prompt_name = str(uuid.uuid4()) + self.session_id = str(uuid.uuid4()) self._active = True self.audio_content_name = str(uuid.uuid4()) self._event_queue = asyncio.Queue() # Start Nova Sonic bidirectional stream - self.stream = await self._client.invoke_model_with_bidirectional_stream( + self.stream = await self.client.invoke_model_with_bidirectional_stream( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) ) @@ -165,7 +165,7 @@ async def connect( logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic connection initialized with session: %s", self.session_id) # Send initialization events system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." @@ -269,7 +269,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.prompt_name, + "connectionId": self.session_id, "metadata": {"provider": "nova_sonic", "model_id": self.model_id}, } yield {"BidirectionalConnectionStart": connection_start} @@ -295,7 +295,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: finally: # Emit connection end event when exiting connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.prompt_name, + "connectionId": self.session_id, "reason": "connection_complete", "metadata": {"provider": "nova_sonic"}, } @@ -331,6 +331,7 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE logger.warning(f"Unknown content type with keys: {content.keys()}") except Exception as e: logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code async def _start_audio_connection(self) -> None: """Internal: Start audio input connection (call once before sending audio chunks).""" @@ -343,7 +344,7 @@ async def _start_audio_connection(self) -> None: { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "type": "AUDIO", "interactive": True, @@ -376,7 +377,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: { "event": { "audioInput": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "content": nova_audio_data, } @@ -409,7 +410,7 @@ async def _end_audio_input(self) -> None: logger.debug("Nova audio connection end") audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} + {"event": {"contentEnd": {"promptName": self.session_id, "contentName": self.audio_content_name}}} ) await self._send_nova_event(audio_content_end) @@ -434,7 +435,7 @@ async def _send_interrupt(self) -> None: { "event": { "audioInput": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "stopReason": "INTERRUPTED", } @@ -600,7 +601,7 @@ def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: prompt_start_event = { "event": { "promptStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "textOutputConfiguration": NOVA_TEXT_CONFIG, "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } @@ -644,7 +645,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "type": "TEXT", "role": role, @@ -661,7 +662,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "interactive": False, "type": "TOOL", @@ -679,7 +680,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" return json.dumps( - {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} + {"event": {"textInput": {"promptName": self.session_id, "contentName": content_name, "content": text}}} ) def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: @@ -688,7 +689,7 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s { "event": { "toolResult": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "content": json.dumps(result), } @@ -698,11 +699,11 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": content_name}}}) + return json.dumps({"event": {"contentEnd": {"promptName": self.session_id, "contentName": content_name}}}) def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self.prompt_name}}}) + return json.dumps({"event": {"promptEnd": {"promptName": self.session_id}}}) def _get_connection_end_event(self) -> str: """Generate connection end event.""" @@ -733,7 +734,7 @@ async def _initialize_client(self) -> None: auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, ) - self._client = BedrockRuntimeClient(config=config) + self.client = BedrockRuntimeClient(config=config) logger.debug("Nova Sonic client initialized") except ImportError as e: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0810b7b21..4bf43b563 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -144,7 +144,7 @@ async def connect( if self.project: headers.append(("OpenAI-Project", self.project)) - self.websocket = await websockets.connect(url, additional_headers=headers) + self.websocket = await websockets.connect(url, extra_headers=headers) logger.info("WebSocket connected successfully") # Configure session @@ -462,6 +462,7 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE logger.warning(f"Unknown content type with keys: {content.keys()}") except Exception as e: logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 5601e23b8..7265bfacd 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -72,20 +72,20 @@ async def test_model_initialization(model_id, region): assert model.region == region assert model.stream is None assert not model._active - assert model.prompt_name is None + assert model.session_id is None @pytest.mark.asyncio async def test_connection_lifecycle(nova_model, mock_client, mock_stream): """Test complete connection lifecycle with various configurations.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test basic connection await nova_model.connect(system_prompt="Test system prompt") assert nova_model._active assert nova_model.stream == mock_stream - assert nova_model.prompt_name is not None + assert nova_model.session_id is not None assert mock_client.invoke_model_with_bidirectional_stream.called # Test close @@ -111,7 +111,7 @@ async def test_connection_lifecycle(nova_model, mock_client, mock_stream): async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): """Test connection error handling and edge cases.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test double connection await nova_model.connect() @@ -132,7 +132,7 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model async def test_send_all_content_types(nova_model, mock_client, mock_stream): """Test sending all content types through unified send() method.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client await nova_model.connect() @@ -171,7 +171,7 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): """Test send() edge cases and error handling.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test send when inactive text_event = {"text": "Hello", "role": "user"} @@ -197,7 +197,7 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): """Test that receive() emits connection start and end events.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Setup mock to return no events and then stop async def mock_wait_for(*args, **kwargs): @@ -215,7 +215,7 @@ async def mock_wait_for(*args, **kwargs): # Should have connection start and end assert len(events) >= 2 assert "BidirectionalConnectionStart" in events[0] - assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name + assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.session_id assert "BidirectionalConnectionEnd" in events[-1] @@ -301,7 +301,7 @@ async def test_event_conversion(nova_model): async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): """Test audio connection start and end lifecycle.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client await nova_model.connect() @@ -320,7 +320,7 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): async def test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing await nova_model.connect() @@ -385,12 +385,12 @@ async def test_event_templates(nova_model): assert "inferenceConfiguration" in event["event"]["sessionStart"] # Test prompt start event - nova_model.prompt_name = "test-prompt" + nova_model.session_id = "test-session" event_json = nova_model._get_prompt_start_event([]) event = json.loads(event_json) assert "event" in event assert "promptStart" in event["event"] - assert event["event"]["promptStart"]["promptName"] == "test-prompt" + assert event["event"]["promptStart"]["promptName"] == "test-session" # Test text input event content_name = "test-content" @@ -416,7 +416,7 @@ async def test_event_templates(nova_model): async def test_error_handling(nova_model, mock_client, mock_stream): """Test error handling in various scenarios.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test response processor handles errors gracefully async def mock_error(*args, **kwargs): diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 388fc95cc..1c0b949b0 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -174,7 +174,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") await model_org.connect() call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("additional_headers", []) + headers = call_kwargs.get("extra_headers", []) org_header = [h for h in headers if h[0] == "OpenAI-Organization"] assert len(org_header) == 1 assert org_header[0][1] == "org-123" From 44e5a631a015dcccbc60e38953acc121045352f8 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 17:17:02 +0300 Subject: [PATCH 044/242] fix: update websockets in openai --- pyproject.toml | 4 +-- .../bidirectional_streaming/agent/agent.py | 33 ++++++++++++------- .../bidirectional_streaming/models/openai.py | 3 +- .../models/test_openai_realtime.py | 2 +- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e079ec263..7810d09c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ bidirectional-streaming-nova = [ ] bidirectional-streaming-openai = [ "pyaudio>=0.2.13", - "websockets>=12.0,<14.0", + "websockets>=14.0,<16.0", ] bidirectional-streaming = [ "pyaudio>=0.2.13", @@ -71,7 +71,7 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", - "websockets>=12.0,<14.0", + "websockets>=14.0,<16.0", ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 0b62d87fc..bbe3f3da2 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,7 +31,13 @@ from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import AudioInputEvent, ImageInputEvent, OutputEvent +from ..types.bidirectional_streaming import ( + AudioInputEvent, + ImageInputEvent, + InputEvent, + OutputEvent, + TextInputEvent, +) logger = logging.getLogger(__name__) @@ -389,14 +395,18 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) # Add user text message to history self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - from ..types.bidirectional_streaming import TextInputEvent text_event = TextInputEvent(text=input_data, role="user") await self._session.model.send(text_event) return - # Handle dict - reconstruct TypedEvent for WebSocket integration + # Handle InputEvent instances (TextInputEvent, AudioInputEvent, ImageInputEvent) + # Check this before dict since TypedEvent inherits from dict + if isinstance(input_data, (TextInputEvent, AudioInputEvent, ImageInputEvent)): + await self._session.model.send(input_data) + return + + # Handle plain dict - reconstruct TypedEvent for WebSocket integration if isinstance(input_data, dict) and "type" in input_data: - from ..types.bidirectional_streaming import TextInputEvent event_type = input_data["type"] if event_type == "bidirectional_text_input": input_data = TextInputEvent(text=input_data["text"], role=input_data["role"]) @@ -414,14 +424,15 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) ) else: raise ValueError(f"Unknown event type: {event_type}") - - # Handle TypedEvent instances - if isinstance(input_data, (AudioInputEvent, ImageInputEvent, TextInputEvent)): + + # Send the reconstructed TypedEvent await self._session.model.send(input_data) - else: - raise ValueError( - f"Input must be a string, TypedEvent, or event dict, got: {type(input_data)}" - ) + return + + # If we get here, input type is invalid + raise ValueError( + f"Input must be a string, InputEvent (TextInputEvent/AudioInputEvent/ImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" + ) async def receive(self) -> AsyncIterable[dict[str, Any]]: """Receive events from the model including audio, text, and tool calls. diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 1c4318def..92da27729 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -13,7 +13,6 @@ from typing import AsyncIterable, Union import websockets -from websockets.client import WebSocketClientProtocol from websockets.exceptions import ConnectionClosed from ....types.content import Messages @@ -149,7 +148,7 @@ async def connect( if self.project: headers.append(("OpenAI-Project", self.project)) - self.websocket = await websockets.connect(url, extra_headers=headers) + self.websocket = await websockets.connect(url, additional_headers=headers) logger.info("WebSocket connected successfully") # Configure session diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 6909a1f5c..60e88aa0f 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -174,7 +174,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") await model_org.connect() call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("extra_headers", []) + headers = call_kwargs.get("additional_headers", []) org_header = [h for h in headers if h[0] == "OpenAI-Organization"] assert len(org_header) == 1 assert org_header[0][1] == "org-123" From 631b620ea85c0dcfcecd1f40b54cf9f50883a099 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 17:17:12 +0300 Subject: [PATCH 045/242] fix: fix integ test --- .../test_bidirectional_agent.py | 41 +++++++++++++++---- .../utils/audio_generator.py | 10 ++++- .../utils/test_context.py | 24 ++++++++--- 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py index 9ae64514d..80b32b178 100644 --- a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py +++ b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py @@ -11,29 +11,56 @@ import os import pytest -from strands_tools import calculator +from strands import tool from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel from .utils.test_context import BidirectionalTestContext logger = logging.getLogger(__name__) +# Simple calculator tool for testing +@tool +def calculator(operation: str, x: float, y: float) -> float: + """Perform basic arithmetic operations. + + Args: + operation: The operation to perform (add, subtract, multiply, divide) + x: First number + y: Second number + + Returns: + Result of the operation + """ + if operation == "add": + return x + y + elif operation == "subtract": + return x - y + elif operation == "multiply": + return x * y + elif operation == "divide": + if y == 0: + raise ValueError("Cannot divide by zero") + return x / y + else: + raise ValueError(f"Unknown operation: {operation}") + + # Provider configurations PROVIDER_CONFIGS = { "nova_sonic": { - "model_class": NovaSonicBidirectionalModel, + "model_class": NovaSonicModel, "model_kwargs": {"region": "us-east-1"}, "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], "skip_reason": "AWS credentials not available", }, "openai": { - "model_class": OpenAIRealtimeBidirectionalModel, + "model_class": OpenAIRealtimeModel, "model_kwargs": { "model": "gpt-4o-realtime-preview-2024-12-17", "session": { @@ -60,7 +87,7 @@ # The model responds with audio but the test infrastructure expects text/transcripts # TODO: Fix Gemini Live event emission to yield both transcript and audio events # "gemini_live": { - # "model_class": GeminiLiveBidirectionalModel, + # "model_class": GeminiLiveModel, # "model_kwargs": { # "model_id": "gemini-2.5-flash-native-audio-preview-09-2025", # "params": { diff --git a/tests_integ/bidirectional_streaming/utils/audio_generator.py b/tests_integ/bidirectional_streaming/utils/audio_generator.py index 605a2aaa9..c3ad3f965 100644 --- a/tests_integ/bidirectional_streaming/utils/audio_generator.py +++ b/tests_integ/bidirectional_streaming/utils/audio_generator.py @@ -120,10 +120,16 @@ def create_audio_input_event( Returns: AudioInputEvent dict ready for agent.send(). """ + import base64 + + # Convert bytes to base64 string for JSON compatibility + audio_b64 = base64.b64encode(audio_data).decode('utf-8') + return { - "audioData": audio_data, + "type": "bidirectional_audio_input", + "audio": audio_b64, "format": format, - "sampleRate": sample_rate, + "sample_rate": sample_rate, "channels": channels, } diff --git a/tests_integ/bidirectional_streaming/utils/test_context.py b/tests_integ/bidirectional_streaming/utils/test_context.py index 4c91e2fc1..687aef1b5 100644 --- a/tests_integ/bidirectional_streaming/utils/test_context.py +++ b/tests_integ/bidirectional_streaming/utils/test_context.py @@ -227,19 +227,24 @@ def get_events(self, event_type: str | None = None) -> list[dict]: def get_text_outputs(self) -> list[str]: """Extract text outputs from collected events. - Handles both textOutput events (Nova Sonic, OpenAI) and transcript events (Gemini Live). + Handles both new TypedEvent format and legacy event formats. Returns: List of text content strings. """ texts = [] for event in self.get_events(): # Drain queue first - # Handle textOutput events (Nova Sonic, OpenAI) - if "textOutput" in event: + # Handle new TypedEvent format (bidirectional_transcript_stream) + if event.get("type") == "bidirectional_transcript_stream": + text = event.get("text", "") + if text: + texts.append(text) + # Handle legacy textOutput events (Nova Sonic, OpenAI) + elif "textOutput" in event: text = event["textOutput"].get("text", "") if text: texts.append(text) - # Handle transcript events (Gemini Live) + # Handle legacy transcript events (Gemini Live) elif "transcript" in event: text = event["transcript"].get("text", "") if text: @@ -252,11 +257,20 @@ def get_audio_outputs(self) -> list[bytes]: Returns: List of audio data bytes. """ + import base64 + # Drain queue first to get latest events events = self.get_events() audio_data = [] for event in events: - if "audioOutput" in event: + # Handle new TypedEvent format (bidirectional_audio_stream) + if event.get("type") == "bidirectional_audio_stream": + audio_b64 = event.get("audio") + if audio_b64: + # Decode base64 to bytes + audio_data.append(base64.b64decode(audio_b64)) + # Handle legacy audioOutput events + elif "audioOutput" in event: data = event["audioOutput"].get("audioData") if data: audio_data.append(data) From 6e33d9a4bb3c28834816e6e34f23489410becd43 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 17:55:28 +0300 Subject: [PATCH 046/242] feat(types): add tool use types --- .../bidirectional_streaming/__init__.py | 12 ++++++++++++ .../event_loop/bidirectional_event_loop.py | 18 +++++++++++++----- .../models/gemini_live.py | 8 ++++++-- .../models/novasonic.py | 9 ++++++--- .../bidirectional_streaming/models/openai.py | 9 ++++++--- 5 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index f75834a76..678dfc0d4 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -30,6 +30,13 @@ TurnStartEvent, ) +# Re-export standard agent events for tool handling +from ...types._events import ( + ToolResultEvent, + ToolStreamEvent, + ToolUseStreamEvent, +) + __all__ = [ # Main interface "BidirectionalAgent", @@ -58,6 +65,11 @@ "ErrorEvent", "OutputEvent", + # Tool Event types (reused from standard agent) + "ToolUseStreamEvent", + "ToolResultEvent", + "ToolStreamEvent", + # Model interface "BidirectionalModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index b3b1ee8ca..e618245e1 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -281,12 +281,15 @@ async def _process_model_events(session: BidirectionalConnection) -> None: continue # Queue tool requests for concurrent execution - if event_type == "tool_use": - tool_use = strands_event.get("tool_use") + # Check for ToolUseStreamEvent (standard agent event) + if "current_tool_use" in strands_event: + tool_use = strands_event.get("current_tool_use") if tool_use: tool_name = tool_use.get("name") logger.debug("Tool usage detected: %s", tool_name) await session.tool_queue.put(tool_use) + # Forward ToolUseStreamEvent to output queue for client visibility + await session.agent._output_queue.put(strands_event) continue # Send all output events to Agent for receive() method @@ -436,14 +439,19 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send ToolResultEvent through send() method + # Send ToolResultEvent through send() method to model await session.model.send(tool_event) - logger.debug("Tool result sent: %s", tool_use_id) + logger.debug("Tool result sent to model: %s", tool_use_id) + + # Also forward ToolResultEvent to output queue for client visibility + await session.agent._output_queue.put(tool_event.as_dict()) + logger.debug("Tool result sent to client: %s", tool_use_id) # Handle streaming events if needed later elif isinstance(tool_event, ToolStreamEvent): logger.debug("Tool stream event: %s", tool_event) - pass + # Forward tool stream events to output queue + await session.agent._output_queue.put(tool_event.as_dict()) # Add tool result message to conversation history if tool_results: diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index ac546e010..1475edaac 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -23,7 +23,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, @@ -267,7 +267,11 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic "name": func_call.name, "input": func_call.args or {} } - return {"toolUse": tool_use_event} + # Return ToolUseStreamEvent for consistency with standard agent + return ToolUseStreamEvent( + delta={"toolUse": tool_use_event}, + current_tool_use=tool_use_event + ) # Handle usage metadata if hasattr(message, 'usage_metadata') and message.usage_metadata: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index d054142fb..033eff4e9 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -32,7 +32,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, @@ -547,8 +547,11 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: "name": tool_use["toolName"], "input": json.loads(tool_use["content"]), } - # Return dict with tool_use for event loop processing - return {"type": "tool_use", "tool_use": tool_use_event} + # Return ToolUseStreamEvent for consistency with standard agent + return ToolUseStreamEvent( + delta={"toolUse": tool_use_event}, + current_tool_use=tool_use_event + ) # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 92da27729..393deb0bd 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -17,7 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, @@ -365,8 +365,11 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, } del self._function_call_buffer[call_id] - # Return dict with tool_use for event loop processing - return [{"type": "tool_use", "tool_use": tool_use}] + # Return ToolUseStreamEvent for consistency with standard agent + return [ToolUseStreamEvent( + delta={"toolUse": tool_use}, + current_tool_use=tool_use + )] except (json.JSONDecodeError, KeyError) as e: logger.warning("Error parsing function arguments for %s: %s", call_id, e) del self._function_call_buffer[call_id] From 863f04fc966e099e1950bfa0eb362cfdf8443f6c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 4 Nov 2025 10:01:08 -0500 Subject: [PATCH 047/242] feat: (Agent): Finalize Bidirectional Agent class --- .../adapters/audio_adapter.py | 286 ++++++++++++++++++ .../bidirectional_streaming/agent/agent.py | 98 ++++-- .../tests/optimized_example.py | 34 +++ 3 files changed, 400 insertions(+), 18 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/optimized_example.py diff --git a/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py b/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py new file mode 100644 index 000000000..ceac7a64b --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py @@ -0,0 +1,286 @@ +"""AudioAdapter - Clean separation of audio functionality from core BidirectionalAgent. + +Provides audio input/output capabilities for BidirectionalAgent through the adapter pattern. +Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. +""" + +import asyncio +import base64 +import logging +from typing import Any, Callable, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from ..agent import BidirectionalAgent + +try: + import pyaudio +except ImportError: + pyaudio = None + +logger = logging.getLogger(__name__) + + +class AudioAdapter: + """Audio adapter for BidirectionalAgent with queue-based processing.""" + + def __init__( + self, + agent: "BidirectionalAgent", + audio_config: Optional[dict] = None, + ): + """Initialize AudioAdapter with clean audio configuration. + + Args: + agent: The BidirectionalAgent instance to wrap + audio_config: Dictionary containing audio configuration: + - input_sample_rate (int): Microphone sample rate (default: 24000) + - output_sample_rate (int): Speaker sample rate (default: 24000) + - chunk_size (int): Audio chunk size in bytes (default: 1024) + - input_device_index (int): Specific input device (optional) + - output_device_index (int): Specific output device (optional) + - input_channels (int): Input channels (default: 1) + - output_channels (int): Output channels (default: 1) + """ + if pyaudio is None: + raise ImportError("PyAudio is required for AudioAdapter. Install with: pip install pyaudio") + + self.agent = agent + + # Default audio configuration + default_config = { + "input_sample_rate": 24000, + "output_sample_rate": 24000, + "chunk_size": 1024, + "input_device_index": None, + "output_device_index": None, + "input_channels": 1, + "output_channels": 1 + } + + # Merge user config with defaults + if audio_config: + default_config.update(audio_config) + + # Set audio configuration attributes + self.input_sample_rate = default_config["input_sample_rate"] + self.output_sample_rate = default_config["output_sample_rate"] + self.chunk_size = default_config["chunk_size"] + self.input_device_index = default_config["input_device_index"] + self.output_device_index = default_config["output_device_index"] + self.input_channels = default_config["input_channels"] + self.output_channels = default_config["output_channels"] + + # Audio infrastructure (lazy initialization) + self.audio = None + self.input_stream = None + self.output_stream = None + self.interrupted = False + + # Audio output queue for background processing + self.audio_output_queue = asyncio.Queue() + + def _setup_audio(self) -> None: + """Setup PyAudio streams for input and output.""" + if self.audio: + return # Already setup + + self.audio = pyaudio.PyAudio() + + try: + # Input stream (microphone) + self.input_stream = self.audio.open( + format=pyaudio.paInt16, + channels=self.input_channels, + rate=self.input_sample_rate, + input=True, + frames_per_buffer=self.chunk_size, + input_device_index=self.input_device_index + ) + + # Output stream (speakers) + self.output_stream = self.audio.open( + format=pyaudio.paInt16, + channels=self.output_channels, + rate=self.output_sample_rate, + output=True, + frames_per_buffer=self.chunk_size, + output_device_index=self.output_device_index + ) + + # Start streams - required for audio to flow + self.input_stream.start_stream() + self.output_stream.start_stream() + + except Exception as e: + logger.error(f"AudioAdapter: Audio setup failed: {e}") + self._cleanup_audio() + raise + + def _cleanup_audio(self) -> None: + """Clean up PyAudio resources.""" + try: + if self.input_stream: + if self.input_stream.is_active(): + self.input_stream.stop_stream() + self.input_stream.close() + + if self.output_stream: + if self.output_stream.is_active(): + self.output_stream.stop_stream() + self.output_stream.close() + + if self.audio: + self.audio.terminate() + + self.input_stream = None + self.output_stream = None + self.audio = None + + except Exception as e: + logger.warning(f"Audio cleanup error: {e}") + + def create_input(self) -> Callable[[], dict]: + """Create audio input function for agent.run().""" + async def audio_receiver() -> dict: + """Read audio from microphone.""" + if not self.input_stream: + self._setup_audio() + + try: + audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) + return { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": self.input_sample_rate, + "channels": self.input_channels # Use configured channels + } + except Exception as e: + logger.warning(f"Audio input error: {e}") + return {"audioData": b"", "format": "pcm", "sampleRate": self.input_sample_rate, "channels": self.input_channels} + + return audio_receiver + + def create_output(self) -> Callable[[dict], None]: + """Create output function that queues audio for background processing.""" + + # Start background audio processor once + if not hasattr(self, '_audio_task') or self._audio_task.done(): + self._audio_task = asyncio.create_task(self._process_audio_queue()) + + events_queued = 0 + + async def audio_sender(event: dict) -> None: + """Queue audio events with minimal debug.""" + nonlocal events_queued + + if "audioOutput" in event: + if not self.interrupted: + audio_data = event["audioOutput"]["audioData"] + self.audio_output_queue.put_nowait(audio_data) + events_queued += 1 + + elif "interruptionDetected" in event or "interrupted" in event: + self.interrupted = True + cleared = 0 + while not self.audio_output_queue.empty(): + try: + self.audio_output_queue.get_nowait() + cleared += 1 + except asyncio.QueueEmpty: + break + logger.debug(f"Cleared {cleared} audio chunks on interruption") + self.interrupted = False + + elif "textOutput" in event: + text = event["textOutput"].get("text", "") + role = event["textOutput"].get("role", "") + if role.upper() == "ASSISTANT": + logger.info(f"Assistant: {text}") + elif role.upper() == "USER": + logger.info(f"User: {text}") + + return audio_sender + + async def _process_audio_queue(self): + """Audio processor without performance-killing delays.""" + logger.debug("Audio processor started - optimized") + + # Separate PyAudio instance for background processing + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=self.output_channels, + format=pyaudio.paInt16, + output=True, + rate=self.output_sample_rate, + frames_per_buffer=self.chunk_size, + output_device_index=self.output_device_index + ) + + try: + chunks = 0 + while True: + try: + # Get audio from queue + audio_data = await asyncio.wait_for(self.audio_output_queue.get(), timeout=0.1) + + if audio_data and not self.interrupted: + chunks += 1 + + # Use chunked playback like working test_bidi_openai.py + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if self.interrupted: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Same as working implementation + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + finally: + logger.debug(f"AudioAdapter finished processing {chunks} chunks") + speaker.close() + audio.terminate() + + async def chat(self, duration: Optional[float] = None) -> None: + """Start voice conversation using agent.run() pattern.""" + try: + self._setup_audio() + + if duration: + await asyncio.wait_for( + self.agent.run( + sender=self.create_output(), + receiver=self.create_input() + ), + timeout=duration + ) + else: + await self.agent.run( + sender=self.create_output(), + receiver=self.create_input() + ) + + except KeyboardInterrupt: + logger.info("Conversation ended by user") + except asyncio.TimeoutError: + logger.info(f"Conversation ended after {duration}s timeout") + finally: + if hasattr(self, '_audio_task'): + self._audio_task.cancel() + self._cleanup_audio() + + # Context manager support + async def __aenter__(self) -> "AudioAdapter": + """Async context manager entry.""" + self._setup_audio() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit with cleanup.""" + self._cleanup_audio() + diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 5a2c10c48..937cb34fd 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -15,7 +15,7 @@ import asyncio import json import logging -from typing import Any, AsyncIterable, Mapping, Optional, Union +from typing import Any, AsyncIterable, Mapping, Optional, Union, Callable from .... import _identifier from ....hooks import HookProvider, HookRegistry @@ -324,7 +324,7 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: Yields: BidirectionalStreamEvent: Events from the model session. """ - while self._session and self._session.active: + while self.active: try: event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) yield event @@ -341,6 +341,54 @@ async def end(self) -> None: await self._session.stop() self._session = None + async def __aenter__(self) -> "BidirectionalAgent": + """Async context manager entry point. + + Automatically starts the bidirectional session when entering the context. + + Returns: + Self for use in the context. + + Raises: + ValueError: If session is already active. + ConnectionError: If session creation fails. + """ + logger.debug("Entering async context manager - starting session") + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit point. + + Automatically ends the session and cleans up resources when exiting + the context, regardless of whether an exception occurred. + + Args: + exc_type: Exception type if an exception occurred, None otherwise. + exc_val: Exception value if an exception occurred, None otherwise. + exc_tb: Exception traceback if an exception occurred, None otherwise. + """ + try: + logger.debug("Exiting async context manager - ending session") + await self.end() + except Exception as cleanup_error: + if exc_type is None: + # No original exception, re-raise cleanup error + logger.error("Error during context manager cleanup: %s", cleanup_error) + raise + else: + # Original exception exists, log cleanup error but don't suppress original + logger.error("Error during context manager cleanup (suppressed due to original exception): %s", cleanup_error) + + @property + def active(self) -> bool: + """Check if the agent session is currently active. + + Returns: + True if session is active and ready for communication, False otherwise. + """ + return self._session is not None and self._session.active + async def run( self, *, @@ -377,8 +425,28 @@ async def custom_receive(): Raises: Exception: Any exception from the transport layer (e.g., WebSocketDisconnect). """ - await self.start() + # Check if session is already active + session_was_active = self.active + + if session_was_active: + # Use existing session + await self._run_with_session(sender, receiver) + else: + # Use async context manager for automatic lifecycle management + async with self: + await self._run_with_session(sender, receiver) + async def _run_with_session( + self, + sender: Callable[[Any], Any], + receiver: Callable[[], Any], + ) -> None: + """Internal method to run send/receive loops with an active session. + + Args: + sender: Async callable that sends events to the client. + receiver: Async callable that receives events from the client. + """ async def receive_from_agent(): """Receive events from agent and send to client.""" try: @@ -391,25 +459,19 @@ async def receive_from_agent(): async def send_to_agent(): """Receive events from client and send to agent.""" try: - while self._session and self._session.active: + while self.active: event = await receiver() await self.send(event) except Exception as e: logger.debug(f"Send to agent stopped: {e}") raise - try: - # Run both loops concurrently - await asyncio.gather( - receive_from_agent(), - send_to_agent(), - return_exceptions=True - ) - finally: - try: - await self.end() - except Exception as e: - logger.debug(f"Error during cleanup: {e}") + # Run both loops concurrently + await asyncio.gather( + receive_from_agent(), + send_to_agent(), + return_exceptions=True + ) def _validate_active_session(self) -> None: """Validate that an active session exists. @@ -417,5 +479,5 @@ def _validate_active_session(self) -> None: Raises: ValueError: If no active session. """ - if not self._session or not self._session.active: - raise ValueError("No active conversation. Call start() first.") + if not self.active: + raise ValueError("No active conversation. Call start() first or use async context manager.") diff --git a/src/strands/experimental/bidirectional_streaming/tests/optimized_example.py b/src/strands/experimental/bidirectional_streaming/tests/optimized_example.py new file mode 100644 index 000000000..1270f3e4c --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/optimized_example.py @@ -0,0 +1,34 @@ +"""Example using the OptimizedAudioAdapter - clean and simple.""" + +import asyncio +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) + +from strands.experimental.bidirectional_streaming.agent.clean_agent import CleanBidirectionalAgent +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent + +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.adapters.optimized_audio_adapter import OptimizedAudioAdapter +from strands_tools import calculator + + +async def main(): + """Test the optimized audio adapter.""" + # Nova Sonic model + model = NovaSonicBidirectionalModel() + + # Clean agent with tools + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Optimized audio adapter + adapter = OptimizedAudioAdapter(agent) + + # Simple chat using context manager for automatic cleanup + await agent.run(sender=adapter.create_output(), receiver=adapter.create_input()) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file From a91e41b6739d4de85f261210c3828bdca75e035b Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 5 Nov 2025 12:23:48 -0500 Subject: [PATCH 048/242] update dev experience and change verbage from session to connection --- .../adapters/__init__.py | 10 + .../adapters/audio_adapter.py | 173 +++++------------- .../bidirectional_streaming/agent/agent.py | 157 ++++++++-------- 3 files changed, 141 insertions(+), 199 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/adapters/__init__.py diff --git a/src/strands/experimental/bidirectional_streaming/adapters/__init__.py b/src/strands/experimental/bidirectional_streaming/adapters/__init__.py new file mode 100644 index 000000000..07d258a3e --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/adapters/__init__.py @@ -0,0 +1,10 @@ +"""Adapters for BidirectionalAgent. + +Provides clean separation of concerns by moving hardware-specific functionality +(audio, video, sensors, etc.) into separate adapter classes that work with +the core BidirectionalAgent through the run() pattern. +""" + +from .audio_adapter import AudioAdapter + +__all__ = ["AudioAdapter"] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py b/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py index ceac7a64b..1126b976b 100644 --- a/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py +++ b/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py @@ -7,10 +7,7 @@ import asyncio import base64 import logging -from typing import Any, Callable, Optional, TYPE_CHECKING - -if TYPE_CHECKING: - from ..agent import BidirectionalAgent +from typing import Any, Callable, Optional try: import pyaudio @@ -21,17 +18,15 @@ class AudioAdapter: - """Audio adapter for BidirectionalAgent with queue-based processing.""" + """Audio adapter for BidirectionalAgent with direct stream processing.""" def __init__( self, - agent: "BidirectionalAgent", audio_config: Optional[dict] = None, ): """Initialize AudioAdapter with clean audio configuration. Args: - agent: The BidirectionalAgent instance to wrap audio_config: Dictionary containing audio configuration: - input_sample_rate (int): Microphone sample rate (default: 24000) - output_sample_rate (int): Speaker sample rate (default: 24000) @@ -44,8 +39,6 @@ def __init__( if pyaudio is None: raise ImportError("PyAudio is required for AudioAdapter. Install with: pip install pyaudio") - self.agent = agent - # Default audio configuration default_config = { "input_sample_rate": 24000, @@ -70,24 +63,21 @@ def __init__( self.input_channels = default_config["input_channels"] self.output_channels = default_config["output_channels"] - # Audio infrastructure (lazy initialization) + # Audio infrastructure self.audio = None self.input_stream = None self.output_stream = None self.interrupted = False - - # Audio output queue for background processing - self.audio_output_queue = asyncio.Queue() def _setup_audio(self) -> None: """Setup PyAudio streams for input and output.""" if self.audio: - return # Already setup + return self.audio = pyaudio.PyAudio() try: - # Input stream (microphone) + # Input stream self.input_stream = self.audio.open( format=pyaudio.paInt16, channels=self.input_channels, @@ -97,7 +87,7 @@ def _setup_audio(self) -> None: input_device_index=self.input_device_index ) - # Output stream (speakers) + # Output stream self.output_stream = self.audio.open( format=pyaudio.paInt16, channels=self.output_channels, @@ -107,7 +97,7 @@ def _setup_audio(self) -> None: output_device_index=self.output_device_index ) - # Start streams - required for audio to flow + # Start streams self.input_stream.start_stream() self.output_stream.start_stream() @@ -152,7 +142,7 @@ async def audio_receiver() -> dict: "audioData": audio_bytes, "format": "pcm", "sampleRate": self.input_sample_rate, - "channels": self.input_channels # Use configured channels + "channels": self.input_channels } except Exception as e: logger.warning(f"Audio input error: {e}") @@ -161,126 +151,61 @@ async def audio_receiver() -> dict: return audio_receiver def create_output(self) -> Callable[[dict], None]: - """Create output function that queues audio for background processing.""" - - # Start background audio processor once - if not hasattr(self, '_audio_task') or self._audio_task.done(): - self._audio_task = asyncio.create_task(self._process_audio_queue()) - - events_queued = 0 + """Create audio output function with direct stream writing.""" async def audio_sender(event: dict) -> None: - """Queue audio events with minimal debug.""" - nonlocal events_queued + """Handle audio events with direct stream writing.""" + if not self.output_stream: + self._setup_audio() - if "audioOutput" in event: - if not self.interrupted: - audio_data = event["audioOutput"]["audioData"] - self.audio_output_queue.put_nowait(audio_data) - events_queued += 1 + # Handle audio output + if "audioOutput" in event and not self.interrupted: + audio_data = event["audioOutput"]["audioData"] + + # Handle both base64 and raw bytes + if isinstance(audio_data, str): + audio_data = base64.b64decode(audio_data) + + if audio_data: + chunk_size = 2048 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if self.interrupted: + break + + chunk = audio_data[i:i + chunk_size] + try: + self.output_stream.write(chunk, exception_on_underflow=False) + await asyncio.sleep(0) + except Exception as e: + logger.warning(f"Audio playback error: {e}") + break elif "interruptionDetected" in event or "interrupted" in event: self.interrupted = True - cleared = 0 - while not self.audio_output_queue.empty(): + logger.debug("Interruption detected") + + # Stop and restart stream for immediate interruption + if self.output_stream: try: - self.audio_output_queue.get_nowait() - cleared += 1 - except asyncio.QueueEmpty: - break - logger.debug(f"Cleared {cleared} audio chunks on interruption") + self.output_stream.stop_stream() + self.output_stream.start_stream() + except Exception as e: + logger.debug(f"Error clearing audio buffer: {e}") + self.interrupted = False elif "textOutput" in event: - text = event["textOutput"].get("text", "") + text = event["textOutput"].get("text", "").strip() role = event["textOutput"].get("role", "") - if role.upper() == "ASSISTANT": - logger.info(f"Assistant: {text}") - elif role.upper() == "USER": - logger.info(f"User: {text}") + if text: + if role.upper() == "ASSISTANT": + print(f"🤖 {text}") + elif role.upper() == "USER": + print(f"User: {text}") return audio_sender - async def _process_audio_queue(self): - """Audio processor without performance-killing delays.""" - logger.debug("Audio processor started - optimized") - - # Separate PyAudio instance for background processing - audio = pyaudio.PyAudio() - speaker = audio.open( - channels=self.output_channels, - format=pyaudio.paInt16, - output=True, - rate=self.output_sample_rate, - frames_per_buffer=self.chunk_size, - output_device_index=self.output_device_index - ) - try: - chunks = 0 - while True: - try: - # Get audio from queue - audio_data = await asyncio.wait_for(self.audio_output_queue.get(), timeout=0.1) - - if audio_data and not self.interrupted: - chunks += 1 - - # Use chunked playback like working test_bidi_openai.py - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - if self.interrupted: - break - - chunk = audio_data[i:i + chunk_size] - speaker.write(chunk) - await asyncio.sleep(0.001) # Same as working implementation - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - break - - finally: - logger.debug(f"AudioAdapter finished processing {chunks} chunks") - speaker.close() - audio.terminate() - - async def chat(self, duration: Optional[float] = None) -> None: - """Start voice conversation using agent.run() pattern.""" - try: - self._setup_audio() - - if duration: - await asyncio.wait_for( - self.agent.run( - sender=self.create_output(), - receiver=self.create_input() - ), - timeout=duration - ) - else: - await self.agent.run( - sender=self.create_output(), - receiver=self.create_input() - ) - - except KeyboardInterrupt: - logger.info("Conversation ended by user") - except asyncio.TimeoutError: - logger.info(f"Conversation ended after {duration}s timeout") - finally: - if hasattr(self, '_audio_task'): - self._audio_task.cancel() - self._cleanup_audio() - - # Context manager support - async def __aenter__(self) -> "AudioAdapter": - """Async context manager entry.""" - self._setup_audio() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - """Async context manager exit with cleanup.""" - self._cleanup_audio() diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 937cb34fd..9608b8080 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -1,12 +1,12 @@ """Bidirectional Agent for real-time streaming conversations. -Provides real-time audio and text interaction through persistent streaming sessions. +Provides real-time audio and text interaction through persistent streaming connections. Unlike traditional request-response patterns, this agent maintains long-running conversations where users can interrupt, provide additional input, and receive continuous responses including audio output. Key capabilities: -- Persistent conversation sessions with concurrent processing +- Persistent conversation connections with concurrent processing - Real-time audio input/output streaming - Automatic interruption detection and tool execution - Event-driven communication with model providers @@ -43,7 +43,7 @@ class BidirectionalAgent: """Agent for bidirectional streaming conversations. Enables real-time audio and text interaction with AI models through persistent - sessions. Supports concurrent tool execution and interruption handling. + connections. Supports concurrent tool execution and interruption handling. """ def __init__( @@ -60,6 +60,7 @@ def __init__( hooks: Optional[list[HookProvider]] = None, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, description: Optional[str] = None, + adapters: Optional[list[Any]] = None, **kwargs: Any, ): """Initialize bidirectional agent with flexible model support and extensible configuration. @@ -71,12 +72,13 @@ def __init__( messages: Optional conversation history to initialize with. record_direct_tool_call: Whether to record direct tool calls in message history. load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. - agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + agent_id: Optional ID for the agent, useful for connection management and multi-agent scenarios. name: Name of the Agent. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). hooks: Hooks to be added to the agent hook registry. trace_attributes: Custom trace attributes to apply to the agent's trace span. description: Description of what the Agent does. + adapters: Optional list of adapter instances (e.g., AudioAdapter) for hardware abstraction. **kwargs: Additional configuration for future extensibility. Raises: @@ -136,13 +138,16 @@ def __init__( self.event_loop_metrics = EventLoopMetrics() self.tool_caller = ToolCaller(self) - # Session management - self._session: Optional["BidirectionalAgentLoop"] = None + # connection management + self._agentloop: Optional["BidirectionalAgentLoop"] = None self._output_queue = asyncio.Queue() # Store extensibility kwargs for future use self._config_kwargs = kwargs + # Initialize adapters + self.adapters = adapters or [] + @property def tool(self) -> ToolCaller: """Call tool as a function. @@ -255,27 +260,27 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di return {k: v for k, v in input_params.items() if k in properties} async def start(self) -> None: - """Start a persistent bidirectional conversation session. + """Start a persistent bidirectional conversation connection. - Initializes the streaming session and starts background tasks for processing - model events, tool execution, and session management. + Initializes the streaming connection and starts background tasks for processing + model events, tool execution, and connection management. Raises: ValueError: If conversation already active. - ConnectionError: If session creation fails. + ConnectionError: If connection creation fails. """ - if self._session and self._session.active: + if self._agentloop and self._agentloop.active: raise ValueError("Conversation already active. Call end() first.") - logger.debug("Conversation start - initializing session") + logger.debug("Conversation start - initializing connection") # Create model session and event loop directly model_session = await self.model.create_bidirectional_connection( system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages ) - self._session = BidirectionalAgentLoop(model_session=model_session, agent=self) - await self._session.start() + self._agentloop = BidirectionalAgentLoop(model_session=model_session, agent=self) + await self._agentloop.start() logger.debug("Conversation ready") @@ -283,16 +288,16 @@ async def send(self, input_data: Union[str, AudioInputEvent]) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during - an active conversation session. User input is automatically added to + an active conversation connection. User input is automatically added to conversation history for complete message tracking. Args: input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images. Raises: - ValueError: If no active session or invalid input type. + ValueError: If no active connection or invalid input type. """ - self._validate_active_session() + self._validate_active_agentloop() if isinstance(input_data, str): # Add user text message to history @@ -301,13 +306,13 @@ async def send(self, input_data: Union[str, AudioInputEvent]) -> None: self.messages.append(user_message) logger.debug("Text sent: %d characters", len(input_data)) - await self._session.model_session.send_text_content(input_data) + await self._agentloop.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input - await self._session.model_session.send_audio_content(input_data) + await self._agentloop.model_session.send_audio_content(input_data) elif isinstance(input_data, dict) and "imageData" in input_data: # Handle image input (ImageInputEvent) - await self._session.model_session.send_image_content(input_data) + await self._agentloop.model_session.send_image_content(input_data) else: raise ValueError( "Input must be either a string (text), AudioInputEvent " @@ -319,7 +324,7 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model including audio, text, and tool calls. Yields model output events processed by background tasks including audio output, - text responses, tool calls, and session updates. + text responses, tool calls, and connection updates. Yields: BidirectionalStreamEvent: Events from the model session. @@ -332,36 +337,36 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: continue async def end(self) -> None: - """End the conversation session and cleanup all resources. + """End the conversation connection and cleanup all resources. - Terminates the streaming session, cancels background tasks, and + Terminates the streaming connection, cancels background tasks, and closes the connection to the model provider. """ - if self._session: - await self._session.stop() - self._session = None + if self._agentloop: + await self._agentloop.stop() + self._agentloop = None async def __aenter__(self) -> "BidirectionalAgent": """Async context manager entry point. - Automatically starts the bidirectional session when entering the context. + Automatically starts the bidirectional connection when entering the context. Returns: Self for use in the context. Raises: - ValueError: If session is already active. - ConnectionError: If session creation fails. + ValueError: If connection is already active. + ConnectionError: If connection creation fails. """ - logger.debug("Entering async context manager - starting session") + logger.debug("Entering async context manager - starting connection") await self.start() return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: """Async context manager exit point. - Automatically ends the session and cleans up resources when exiting - the context, regardless of whether an exception occurred. + Automatically ends the connection and cleans up resources including adapters + when exiting the context, regardless of whether an exception occurred. Args: exc_type: Exception type if an exception occurred, None otherwise. @@ -369,8 +374,20 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: exc_tb: Exception traceback if an exception occurred, None otherwise. """ try: - logger.debug("Exiting async context manager - ending session") + logger.debug("Exiting async context manager - cleaning up adapters and connection") + + # Cleanup adapters first + for adapter in self.adapters: + if hasattr(adapter, '_cleanup_audio'): + try: + adapter._cleanup_audio() + logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") + except Exception as adapter_error: + logger.warning(f"Error cleaning up adapter: {adapter_error}") + + # Then cleanup agent connection await self.end() + except Exception as cleanup_error: if exc_type is None: # No original exception, re-raise cleanup error @@ -382,66 +399,56 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: @property def active(self) -> bool: - """Check if the agent session is currently active. + """Check if the agent connection is currently active. Returns: - True if session is active and ready for communication, False otherwise. + True if connection is active and ready for communication, False otherwise. """ - return self._session is not None and self._session.active - - async def run( - self, - *, - sender: Callable[[Any], Any], - receiver: Callable[[], Any], - ) -> None: - """Run the agent with send/receive loop management. + return self._agentloop is not None and self._agentloop.active - Starts the session, pipes events between the agent and transport layer, - and handles cleanup on disconnection. + async def connect(self) -> None: + """Connect the agent using configured adapters for bidirectional communication. - Args: - sender: Async callable that sends events to the client (e.g., websocket.send_json). - receiver: Async callable that receives events from the client (e.g., websocket.receive_json). + Automatically uses configured adapters to establish bidirectional communication + with the model. Handles connection lifecycle and transport coordination. Example: ```python - # With WebSocket - agent = BidirectionalAgent(model=model, tools=[calculator]) - await agent.run(sender=websocket.send_json, receiver=websocket.receive_json) - - # With custom transport - async def custom_send(event): - # Custom send logic - pass - - async def custom_receive(): - # Custom receive logic - return event - - await agent.run(sender=custom_send, receiver=custom_receive) + # With AudioAdapter + adapter = AudioAdapter(audio_config={"input_sample_rate": 16000}) + agent = BidirectionalAgent(model=model, tools=[calculator], adapters=[adapter]) + await agent.connect() ``` Raises: - Exception: Any exception from the transport layer (e.g., WebSocketDisconnect). + ValueError: If no adapters are configured. + Exception: Any exception from the transport layer. """ - # Check if session is already active - session_was_active = self.active + if not self.adapters: + raise ValueError("No adapters configured. Add adapters to the agent constructor.") + + # Use first adapter + adapter = self.adapters[0] + sender = adapter.create_output() + receiver = adapter.create_input() + + # Check if connection is already active + connection_was_active = self.active - if session_was_active: - # Use existing session - await self._run_with_session(sender, receiver) + if connection_was_active: + # Use existing connection + await self._run_with_agentloop(sender, receiver) else: # Use async context manager for automatic lifecycle management async with self: - await self._run_with_session(sender, receiver) + await self._run_with_agentloop(sender, receiver) - async def _run_with_session( + async def _run_with_agentloop( self, sender: Callable[[Any], Any], receiver: Callable[[], Any], ) -> None: - """Internal method to run send/receive loops with an active session. + """Internal method to run send/receive loops with an active connection. Args: sender: Async callable that sends events to the client. @@ -473,11 +480,11 @@ async def send_to_agent(): return_exceptions=True ) - def _validate_active_session(self) -> None: - """Validate that an active session exists. + def _validate_active_agentloop(self) -> None: + """Validate that an active connection exists. Raises: - ValueError: If no active session. + ValueError: If no active connection. """ if not self.active: raise ValueError("No active conversation. Call start() first or use async context manager.") From 1e9d185d06849de607678f942d80b9ce840e2b6c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 5 Nov 2025 12:46:31 -0500 Subject: [PATCH 049/242] rename validate_connection --- .../bidirectional_streaming/agent/agent.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 9608b8080..9e6bc7349 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -297,7 +297,7 @@ async def send(self, input_data: Union[str, AudioInputEvent]) -> None: Raises: ValueError: If no active connection or invalid input type. """ - self._validate_active_agentloop() + self._validate_active_connection() if isinstance(input_data, str): # Add user text message to history @@ -431,19 +431,16 @@ async def connect(self) -> None: adapter = self.adapters[0] sender = adapter.create_output() receiver = adapter.create_input() - - # Check if connection is already active - connection_was_active = self.active - - if connection_was_active: + + if self.active: # Use existing connection - await self._run_with_agentloop(sender, receiver) + await self._run(sender, receiver) else: - # Use async context manager for automatic lifecycle management + # Use async context manager for automatic lifecycle management async with self: - await self._run_with_agentloop(sender, receiver) + await self._run(sender, receiver) - async def _run_with_agentloop( + async def _run( self, sender: Callable[[Any], Any], receiver: Callable[[], Any], @@ -480,7 +477,7 @@ async def send_to_agent(): return_exceptions=True ) - def _validate_active_agentloop(self) -> None: + def _validate_active_connection(self) -> None: """Validate that an active connection exists. Raises: From 6d8c3557a11af2045e1cc9bf7330a9dfdf61f894 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 5 Nov 2025 13:33:19 -0500 Subject: [PATCH 050/242] remove hooks and otel parameters from constructor for focused implementation. Will be added later when implementation is added. --- .../bidirectional_streaming/agent/agent.py | 31 +++---------------- src/strands/tools/caller.py | 8 ++--- 2 files changed, 8 insertions(+), 31 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 9e6bc7349..161420454 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -18,7 +18,6 @@ from typing import Any, AsyncIterable, Mapping, Optional, Union, Callable from .... import _identifier -from ....hooks import HookProvider, HookRegistry from ....telemetry.metrics import EventLoopMetrics from ....tools.caller import ToolCaller from ....tools.executors import ConcurrentToolExecutor @@ -37,6 +36,8 @@ _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" +# Type alias for cleaner send() method signature +BidirectionalInput = str | AudioInputEvent | ImageInputEvent class BidirectionalAgent: @@ -57,13 +58,11 @@ def __init__( agent_id: Optional[str] = None, name: Optional[str] = None, tool_executor: Optional[ToolExecutor] = None, - hooks: Optional[list[HookProvider]] = None, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, description: Optional[str] = None, adapters: Optional[list[Any]] = None, **kwargs: Any, ): - """Initialize bidirectional agent with flexible model support and extensible configuration. + """Initialize bidirectional agent. Args: model: BidirectionalModel instance, string model_id, or None for default detection. @@ -75,8 +74,6 @@ def __init__( agent_id: Optional ID for the agent, useful for connection management and multi-agent scenarios. name: Name of the Agent. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). - hooks: Hooks to be added to the agent hook registry. - trace_attributes: Custom trace attributes to apply to the agent's trace span. description: Description of what the Agent does. adapters: Optional list of adapter instances (e.g., AudioAdapter) for hardware abstraction. **kwargs: Additional configuration for future extensibility. @@ -104,15 +101,6 @@ def __init__( self.record_direct_tool_call = record_direct_tool_call self.load_tools_from_directory = load_tools_from_directory - # Process trace attributes to ensure they're of compatible types - self.trace_attributes: dict[str, AttributeValue] = {} - if trace_attributes: - for k, v in trace_attributes.items(): - if isinstance(v, (str, int, float, bool)) or ( - isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v) - ): - self.trace_attributes[k] = v - # Initialize tool registry self.tool_registry = ToolRegistry() @@ -128,12 +116,6 @@ def __init__( # Initialize tool executor self.tool_executor = tool_executor or ConcurrentToolExecutor() - # Initialize hooks system - self.hooks = HookRegistry() - if hooks: - for hook in hooks: - self.hooks.add_hook(hook) - # Initialize other components self.event_loop_metrics = EventLoopMetrics() self.tool_caller = ToolCaller(self) @@ -142,9 +124,6 @@ def __init__( self._agentloop: Optional["BidirectionalAgentLoop"] = None self._output_queue = asyncio.Queue() - # Store extensibility kwargs for future use - self._config_kwargs = kwargs - # Initialize adapters self.adapters = adapters or [] @@ -284,7 +263,7 @@ async def start(self) -> None: logger.debug("Conversation ready") - async def send(self, input_data: Union[str, AudioInputEvent]) -> None: + async def send(self, input_data: BidirectionalInput) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during @@ -331,7 +310,7 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """ while self.active: try: - event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) + event = await self._output_queue.get() yield event except asyncio.TimeoutError: continue diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py index 06e1f23b3..cc9dc5743 100644 --- a/src/strands/tools/caller.py +++ b/src/strands/tools/caller.py @@ -1,7 +1,5 @@ -"""Shared ToolCaller base class to eliminate duplication between agent implementations. +"""ToolCaller base class. -Provides common tool calling functionality that can be used by both traditional -Agent and BidirectionalAgent classes with agent-specific customizations. """ import asyncio @@ -14,11 +12,11 @@ class ToolCaller: - """Universal tool caller that works with both traditional and bidirectional agents. + """Provides common tool calling functionality that can be used by both traditional +Agent and BidirectionalAgent classes with agent-specific customizations. Automatically detects agent type and applies appropriate behavior: - Traditional agents: Uses conversation_manager.apply_management() - - Bidirectional agents: Skips conversation management (not needed for streaming) """ def __init__(self, agent: Any) -> None: From c5328e06cb1b689412b8bb1cca591d76ad783682 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 5 Nov 2025 14:53:34 -0500 Subject: [PATCH 051/242] hatch fmt --formatter --- .../adapters/audio_adapter.py | 90 ++++++++++--------- .../bidirectional_streaming/agent/agent.py | 41 +++++---- src/strands/tools/caller.py | 10 +-- 3 files changed, 70 insertions(+), 71 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py b/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py index 1126b976b..b093ae0dd 100644 --- a/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py +++ b/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py @@ -19,13 +19,13 @@ class AudioAdapter: """Audio adapter for BidirectionalAgent with direct stream processing.""" - + def __init__( self, audio_config: Optional[dict] = None, ): """Initialize AudioAdapter with clean audio configuration. - + Args: audio_config: Dictionary containing audio configuration: - input_sample_rate (int): Microphone sample rate (default: 24000) @@ -38,7 +38,7 @@ def __init__( """ if pyaudio is None: raise ImportError("PyAudio is required for AudioAdapter. Install with: pip install pyaudio") - + # Default audio configuration default_config = { "input_sample_rate": 24000, @@ -47,13 +47,13 @@ def __init__( "input_device_index": None, "output_device_index": None, "input_channels": 1, - "output_channels": 1 + "output_channels": 1, } - + # Merge user config with defaults if audio_config: default_config.update(audio_config) - + # Set audio configuration attributes self.input_sample_rate = default_config["input_sample_rate"] self.output_sample_rate = default_config["output_sample_rate"] @@ -62,7 +62,7 @@ def __init__( self.output_device_index = default_config["output_device_index"] self.input_channels = default_config["input_channels"] self.output_channels = default_config["output_channels"] - + # Audio infrastructure self.audio = None self.input_stream = None @@ -73,39 +73,39 @@ def _setup_audio(self) -> None: """Setup PyAudio streams for input and output.""" if self.audio: return - + self.audio = pyaudio.PyAudio() - + try: # Input stream self.input_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.input_channels, + format=pyaudio.paInt16, + channels=self.input_channels, rate=self.input_sample_rate, - input=True, + input=True, frames_per_buffer=self.chunk_size, - input_device_index=self.input_device_index + input_device_index=self.input_device_index, ) - + # Output stream self.output_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.output_channels, + format=pyaudio.paInt16, + channels=self.output_channels, rate=self.output_sample_rate, - output=True, + output=True, frames_per_buffer=self.chunk_size, - output_device_index=self.output_device_index + output_device_index=self.output_device_index, ) - + # Start streams self.input_stream.start_stream() self.output_stream.start_stream() - + except Exception as e: logger.error(f"AudioAdapter: Audio setup failed: {e}") self._cleanup_audio() raise - + def _cleanup_audio(self) -> None: """Clean up PyAudio resources.""" try: @@ -113,67 +113,73 @@ def _cleanup_audio(self) -> None: if self.input_stream.is_active(): self.input_stream.stop_stream() self.input_stream.close() - + if self.output_stream: if self.output_stream.is_active(): self.output_stream.stop_stream() self.output_stream.close() - + if self.audio: self.audio.terminate() - + self.input_stream = None self.output_stream = None self.audio = None - + except Exception as e: logger.warning(f"Audio cleanup error: {e}") def create_input(self) -> Callable[[], dict]: """Create audio input function for agent.run().""" + async def audio_receiver() -> dict: """Read audio from microphone.""" if not self.input_stream: self._setup_audio() - + try: audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) return { "audioData": audio_bytes, - "format": "pcm", + "format": "pcm", "sampleRate": self.input_sample_rate, - "channels": self.input_channels + "channels": self.input_channels, } except Exception as e: logger.warning(f"Audio input error: {e}") - return {"audioData": b"", "format": "pcm", "sampleRate": self.input_sample_rate, "channels": self.input_channels} - + return { + "audioData": b"", + "format": "pcm", + "sampleRate": self.input_sample_rate, + "channels": self.input_channels, + } + return audio_receiver - + def create_output(self) -> Callable[[dict], None]: """Create audio output function with direct stream writing.""" - + async def audio_sender(event: dict) -> None: """Handle audio events with direct stream writing.""" if not self.output_stream: self._setup_audio() - + # Handle audio output if "audioOutput" in event and not self.interrupted: audio_data = event["audioOutput"]["audioData"] - + # Handle both base64 and raw bytes if isinstance(audio_data, str): audio_data = base64.b64decode(audio_data) - + if audio_data: chunk_size = 2048 for i in range(0, len(audio_data), chunk_size): # Check for interruption before each chunk if self.interrupted: break - - chunk = audio_data[i:i + chunk_size] + + chunk = audio_data[i : i + chunk_size] try: self.output_stream.write(chunk, exception_on_underflow=False) await asyncio.sleep(0) @@ -184,7 +190,7 @@ async def audio_sender(event: dict) -> None: elif "interruptionDetected" in event or "interrupted" in event: self.interrupted = True logger.debug("Interruption detected") - + # Stop and restart stream for immediate interruption if self.output_stream: try: @@ -192,7 +198,7 @@ async def audio_sender(event: dict) -> None: self.output_stream.start_stream() except Exception as e: logger.debug(f"Error clearing audio buffer: {e}") - + self.interrupted = False elif "textOutput" in event: @@ -203,9 +209,5 @@ async def audio_sender(event: dict) -> None: print(f"🤖 {text}") elif role.upper() == "USER": print(f"User: {text}") - - return audio_sender - - - + return audio_sender diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 161420454..2378923fc 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -272,7 +272,7 @@ async def send(self, input_data: BidirectionalInput) -> None: Args: input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images. - + Raises: ValueError: If no active connection or invalid input type. """ @@ -327,12 +327,12 @@ async def end(self) -> None: async def __aenter__(self) -> "BidirectionalAgent": """Async context manager entry point. - + Automatically starts the bidirectional connection when entering the context. - + Returns: Self for use in the context. - + Raises: ValueError: If connection is already active. ConnectionError: If connection creation fails. @@ -343,10 +343,10 @@ async def __aenter__(self) -> "BidirectionalAgent": async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: """Async context manager exit point. - - Automatically ends the connection and cleans up resources including adapters + + Automatically ends the connection and cleans up resources including adapters when exiting the context, regardless of whether an exception occurred. - + Args: exc_type: Exception type if an exception occurred, None otherwise. exc_val: Exception value if an exception occurred, None otherwise. @@ -354,19 +354,19 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: """ try: logger.debug("Exiting async context manager - cleaning up adapters and connection") - + # Cleanup adapters first for adapter in self.adapters: - if hasattr(adapter, '_cleanup_audio'): + if hasattr(adapter, "_cleanup_audio"): try: adapter._cleanup_audio() logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") except Exception as adapter_error: logger.warning(f"Error cleaning up adapter: {adapter_error}") - + # Then cleanup agent connection await self.end() - + except Exception as cleanup_error: if exc_type is None: # No original exception, re-raise cleanup error @@ -374,12 +374,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: raise else: # Original exception exists, log cleanup error but don't suppress original - logger.error("Error during context manager cleanup (suppressed due to original exception): %s", cleanup_error) + logger.error( + "Error during context manager cleanup (suppressed due to original exception): %s", cleanup_error + ) @property def active(self) -> bool: """Check if the agent connection is currently active. - + Returns: True if connection is active and ready for communication, False otherwise. """ @@ -405,7 +407,7 @@ async def connect(self) -> None: """ if not self.adapters: raise ValueError("No adapters configured. Add adapters to the agent constructor.") - + # Use first adapter adapter = self.adapters[0] sender = adapter.create_output() @@ -415,7 +417,7 @@ async def connect(self) -> None: # Use existing connection await self._run(sender, receiver) else: - # Use async context manager for automatic lifecycle management + # Use async context manager for automatic lifecycle management async with self: await self._run(sender, receiver) @@ -425,11 +427,12 @@ async def _run( receiver: Callable[[], Any], ) -> None: """Internal method to run send/receive loops with an active connection. - + Args: sender: Async callable that sends events to the client. receiver: Async callable that receives events from the client. """ + async def receive_from_agent(): """Receive events from agent and send to client.""" try: @@ -450,11 +453,7 @@ async def send_to_agent(): raise # Run both loops concurrently - await asyncio.gather( - receive_from_agent(), - send_to_agent(), - return_exceptions=True - ) + await asyncio.gather(receive_from_agent(), send_to_agent(), return_exceptions=True) def _validate_active_connection(self) -> None: """Validate that an active connection exists. diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py index cc9dc5743..167789801 100644 --- a/src/strands/tools/caller.py +++ b/src/strands/tools/caller.py @@ -1,6 +1,4 @@ -"""ToolCaller base class. - -""" +"""ToolCaller base class.""" import asyncio import random @@ -13,10 +11,10 @@ class ToolCaller: """Provides common tool calling functionality that can be used by both traditional -Agent and BidirectionalAgent classes with agent-specific customizations. + Agent and BidirectionalAgent classes with agent-specific customizations. - Automatically detects agent type and applies appropriate behavior: - - Traditional agents: Uses conversation_manager.apply_management() + Automatically detects agent type and applies appropriate behavior: + - Traditional agents: Uses conversation_manager.apply_management() """ def __init__(self, agent: Any) -> None: From bd5401f081443dcb433f6be6bf1c85a69e9391a3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 13:44:50 +0300 Subject: [PATCH 052/242] Return typed dicts on agent and refactor error event --- .../bidirectional_streaming/agent/agent.py | 13 +++---- .../event_loop/bidirectional_event_loop.py | 4 +- .../types/bidirectional_streaming.py | 38 ++++++++++++------- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index bbe3f3da2..f0205f8a8 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -434,24 +434,21 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) f"Input must be a string, InputEvent (TextInputEvent/AudioInputEvent/ImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" ) - async def receive(self) -> AsyncIterable[dict[str, Any]]: + async def receive(self) -> AsyncIterable["OutputEvent"]: """Receive events from the model including audio, text, and tool calls. Yields model output events processed by background tasks including audio output, text responses, tool calls, and session updates. Yields: - dict: Event dictionaries from the model session. Each event is a TypedEvent - converted to a dictionary for consistency with the standard Agent API. + OutputEvent: TypedEvent objects from the model session. Events are + JSON-serializable by default (use json.dumps(event) for transport). """ while self._session and self._session.active: try: event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) - # Convert TypedEvent to dict for consistency with Agent.stream_async - if hasattr(event, 'as_dict'): - yield event.as_dict() - else: - yield event + # Return TypedEvent objects directly (JSON-serializable by default) + yield event except asyncio.TimeoutError: continue diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index e618245e1..8af2515ef 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -444,14 +444,14 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: logger.debug("Tool result sent to model: %s", tool_use_id) # Also forward ToolResultEvent to output queue for client visibility - await session.agent._output_queue.put(tool_event.as_dict()) + await session.agent._output_queue.put(tool_event) logger.debug("Tool result sent to client: %s", tool_use_id) # Handle streaming events if needed later elif isinstance(tool_event, ToolStreamEvent): logger.debug("Tool stream event: %s", tool_event) # Forward tool stream events to output queue - await session.agent._output_queue.put(tool_event.as_dict()) + await session.agent._output_queue.put(tool_event) # Add tool result message to conversation history if tool_results: diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 160b15a27..c7ca1515f 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -427,44 +427,54 @@ def reason(self) -> str: class ErrorEvent(TypedEvent): """Error occurred during the session. - Similar to strands.types._events.ForceStopEvent, this event wraps exceptions - that occur during bidirectional streaming sessions. - - Note: The Exception object is not stored in the event data to maintain JSON - serializability. Only the error message, code, and details are stored. + Stores the full Exception object as an instance attribute for debugging while + keeping the event dict JSON-serializable. The exception can be accessed via + the `error` property for re-raising or type-based error handling. Parameters: - error: The exception that occurred (used to extract message and type). - code: Optional error code for programmatic handling (defaults to exception class name). + error: The exception that occurred. details: Optional additional error information. """ def __init__( self, error: Exception, - code: Optional[str] = None, details: Optional[Dict[str, Any]] = None, ): + # Store serializable data in dict (for JSON serialization) super().__init__( { "type": "bidirectional_error", - "error_message": str(error), - "error_code": code or type(error).__name__, - "error_details": details, + "message": str(error), + "code": type(error).__name__, + "details": details, } ) + # Store exception as instance attribute (not serialized) + self._error = error + + @property + def error(self) -> Exception: + """The original exception that occurred. + + Can be used for re-raising or type-based error handling. + """ + return self._error @property def code(self) -> str: - return cast(str, self.get("error_code")) + """Error code derived from exception class name.""" + return cast(str, self.get("code")) @property def message(self) -> str: - return cast(str, self.get("error_message")) + """Human-readable error message from the exception.""" + return cast(str, self.get("message")) @property def details(self) -> Optional[Dict[str, Any]]: - return cast(Optional[Dict[str, Any]], self.get("error_details")) + """Additional error context beyond the exception itself.""" + return cast(Optional[Dict[str, Any]], self.get("details")) # ============================================================================ From 8a74a93c963f13419ac5a453551b36f72840c6b5 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 14:09:56 +0300 Subject: [PATCH 053/242] refactor: change multimodal usage event to usage event --- .../experimental/bidirectional_streaming/__init__.py | 4 ++-- .../bidirectional_streaming/models/gemini_live.py | 4 ++-- .../bidirectional_streaming/models/novasonic.py | 4 ++-- .../bidirectional_streaming/models/openai.py | 4 ++-- .../bidirectional_streaming/types/__init__.py | 4 ++-- .../types/bidirectional_streaming.py | 11 ++++++----- .../bidirectional_streaming/models/test_novasonic.py | 8 ++++---- 7 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 678dfc0d4..86a1139d0 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -20,7 +20,7 @@ InputEvent, InterruptionEvent, ModalityUsage, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -59,7 +59,7 @@ "TranscriptStreamEvent", "InterruptionEvent", "TurnCompleteEvent", - "MultimodalUsage", + "UsageEvent", "ModalityUsage", "SessionEndEvent", "ErrorEvent", diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 1475edaac..da3387e1d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -30,7 +30,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, - MultimodalUsage, + UsageEvent, SessionEndEvent, SessionStartEvent, TextInputEvent, @@ -306,7 +306,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic "output_tokens": detail.token_count }) - return MultimodalUsage( + return UsageEvent( input_tokens=usage.prompt_token_count or 0, output_tokens=usage.response_token_count or 0, total_tokens=usage.total_token_count or 0, diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 033eff4e9..c6790e506 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -39,7 +39,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -564,7 +564,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: total_input = usage_data.get("totalInputTokens", 0) total_output = usage_data.get("totalOutputTokens", 0) - return MultimodalUsage( + return UsageEvent( input_tokens=total_input, output_tokens=total_output, total_tokens=usage_data.get("totalTokens", total_input + total_output) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 393deb0bd..016ce32b1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -24,7 +24,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -452,7 +452,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven cached_tokens = input_details.get("cached_tokens", 0) # Add usage event - events.append(MultimodalUsage( + events.append(UsageEvent( input_tokens=usage.get("input_tokens", 0), output_tokens=usage.get("output_tokens", 0), total_tokens=usage.get("total_tokens", 0), diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 52034db1b..2721ceab9 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -14,7 +14,7 @@ InputEvent, InterruptionEvent, ModalityUsage, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -37,7 +37,7 @@ "TranscriptStreamEvent", "InterruptionEvent", "TurnCompleteEvent", - "MultimodalUsage", + "UsageEvent", "ModalityUsage", "SessionEndEvent", "ErrorEvent", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index c7ca1515f..fbec640b8 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -345,10 +345,11 @@ class ModalityUsage(dict): output_tokens: int -class MultimodalUsage(TypedEvent): - """Token usage event with modality breakdown for multimodal streaming. +class UsageEvent(TypedEvent): + """Token usage event with modality breakdown for bidirectional streaming. - Combines TypedEvent behavior with Usage fields for a unified event type. + Tracks token consumption across different modalities (audio, text, images) + during bidirectional streaming sessions. Parameters: input_tokens: Total tokens used for all input modalities. @@ -369,7 +370,7 @@ def __init__( cache_write_input_tokens: Optional[int] = None, ): data: Dict[str, Any] = { - "type": "multimodal_usage", + "type": "bidirectional_usage", "inputTokens": input_tokens, "outputTokens": output_tokens, "totalTokens": total_tokens, @@ -492,7 +493,7 @@ def details(self) -> Optional[Dict[str, Any]]: TranscriptStreamEvent, InterruptionEvent, TurnCompleteEvent, - MultimodalUsage, + UsageEvent, SessionEndEvent, ErrorEvent, ] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 6c77457c2..1e07eb449 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -286,8 +286,8 @@ async def test_event_conversion(nova_model): assert result.get("type") == "bidirectional_interruption" assert result.get("reason") == "user_speech" - # Test usage metrics (now returns MultimodalUsage) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import MultimodalUsage + # Test usage metrics (now returns UsageEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import UsageEvent nova_event = { "usageEvent": { "totalTokens": 100, @@ -304,8 +304,8 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, MultimodalUsage) - assert result.get("type") == "multimodal_usage" + assert isinstance(result, UsageEvent) + assert result.get("type") == "bidirectional_usage" assert result.get("totalTokens") == 100 assert result.get("inputTokens") == 40 assert result.get("outputTokens") == 60 From 433c610e188f725df95efe1c559026afd09dc655 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 14:39:04 +0300 Subject: [PATCH 054/242] refactor(bidi): change session terminology to connection --- .../bidirectional_streaming/__init__.py | 8 ++-- .../models/gemini_live.py | 18 ++++---- .../models/novasonic.py | 42 +++++++++--------- .../bidirectional_streaming/models/openai.py | 18 ++++---- .../bidirectional_streaming/types/__init__.py | 8 ++-- .../types/bidirectional_streaming.py | 43 ++++++++++++------- .../models/test_novasonic.py | 4 +- 7 files changed, 77 insertions(+), 64 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 86a1139d0..1b901b0a2 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -15,6 +15,8 @@ from .types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InputEvent, @@ -22,8 +24,6 @@ ModalityUsage, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -53,7 +53,8 @@ "InputEvent", # Output Event types - "SessionStartEvent", + "ConnectionStartEvent", + "ConnectionCloseEvent", "TurnStartEvent", "AudioStreamEvent", "TranscriptStreamEvent", @@ -61,7 +62,6 @@ "TurnCompleteEvent", "UsageEvent", "ModalityUsage", - "SessionEndEvent", "ErrorEvent", "OutputEvent", diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index da3387e1d..5819b84eb 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -27,12 +27,12 @@ from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InterruptionEvent, UsageEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -89,7 +89,7 @@ def __init__( # Connection state (initialized in connect()) self.live_session = None self.live_session_context_manager = None - self.session_id = None + self.connection_id = None self._active = False async def connect( @@ -112,7 +112,7 @@ async def connect( try: # Initialize connection state - self.session_id = str(uuid.uuid4()) + self.connection_id = str(uuid.uuid4()) self._active = True # Build live config @@ -163,9 +163,9 @@ async def _send_message_history(self, messages: Messages) -> None: async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" - # Emit session start event - yield SessionStartEvent( - session_id=self.session_id, + # Emit connection start event + yield ConnectionStartEvent( + connection_id=self.connection_id, model=self.model_id, capabilities=["audio", "tools", "images"] ) @@ -196,8 +196,8 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: logger.error("Fatal error in receive loop: %s", e) yield ErrorEvent(error=e) finally: - # Emit session end event when exiting - yield SessionEndEvent(reason="complete") + # Emit connection close event when exiting + yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: """Convert Gemini Live API events to provider-agnostic format. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index c6790e506..da57e0c57 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -36,13 +36,13 @@ from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InterruptionEvent, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -111,7 +111,7 @@ def __init__( # Connection state (initialized in connect()) self.stream = None - self.session_id = None + self.connection_id = None self._active = False # Nova Sonic requires unique content names @@ -155,7 +155,7 @@ async def connect( await self._initialize_client() # Initialize connection state - self.session_id = str(uuid.uuid4()) + self.connection_id = str(uuid.uuid4()) self._active = True self.audio_content_name = str(uuid.uuid4()) self._event_queue = asyncio.Queue() @@ -170,7 +170,7 @@ async def connect( logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic connection initialized with session: %s", self.session_id) + logger.debug("Nova Sonic connection initialized with connection_id: %s", self.connection_id) # Send initialization events system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." @@ -272,9 +272,9 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.debug("Nova events - starting event stream") - # Emit session start event - yield SessionStartEvent( - session_id=self.session_id, + # Emit connection start event + yield ConnectionStartEvent( + connection_id=self.connection_id, model=self.model_id, capabilities=["audio", "tools"] ) @@ -299,8 +299,8 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.error(traceback.format_exc()) yield ErrorEvent(error=e) finally: - # Emit session end event - yield SessionEndEvent(reason="complete") + # Emit connection close event + yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") async def send( self, @@ -345,7 +345,7 @@ async def _start_audio_connection(self) -> None: { "event": { "contentStart": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": self.audio_content_name, "type": "AUDIO", "interactive": True, @@ -376,7 +376,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: { "event": { "audioInput": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": self.audio_content_name, "content": audio_input.audio, } @@ -409,7 +409,7 @@ async def _end_audio_input(self) -> None: logger.debug("Nova audio connection end") audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self.session_id, "contentName": self.audio_content_name}}} + {"event": {"contentEnd": {"promptName": self.connection_id, "contentName": self.audio_content_name}}} ) await self._send_nova_event(audio_content_end) @@ -434,7 +434,7 @@ async def _send_interrupt(self) -> None: { "event": { "audioInput": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": self.audio_content_name, "stopReason": "INTERRUPTED", } @@ -600,7 +600,7 @@ def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: prompt_start_event = { "event": { "promptStart": { - "promptName": self.session_id, + "promptName": self.connection_id, "textOutputConfiguration": NOVA_TEXT_CONFIG, "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } @@ -644,7 +644,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - { "event": { "contentStart": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": content_name, "type": "TEXT", "role": role, @@ -661,7 +661,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> { "event": { "contentStart": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": content_name, "interactive": False, "type": "TOOL", @@ -679,7 +679,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" return json.dumps( - {"event": {"textInput": {"promptName": self.session_id, "contentName": content_name, "content": text}}} + {"event": {"textInput": {"promptName": self.connection_id, "contentName": content_name, "content": text}}} ) def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: @@ -688,7 +688,7 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s { "event": { "toolResult": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": content_name, "content": json.dumps(result), } @@ -698,11 +698,11 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self.session_id, "contentName": content_name}}}) + return json.dumps({"event": {"contentEnd": {"promptName": self.connection_id, "contentName": content_name}}}) def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self.session_id}}}) + return json.dumps({"event": {"promptEnd": {"promptName": self.connection_id}}}) def _get_connection_end_event(self) -> str: """Generate connection end event.""" diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 016ce32b1..52a3cdf79 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -21,13 +21,13 @@ from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InterruptionEvent, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -103,7 +103,7 @@ def __init__( # Connection state (initialized in connect()) self.websocket = None - self.session_id = None + self.connection_id = None self._active = False self._event_queue = None @@ -134,7 +134,7 @@ async def connect( try: # Initialize connection state - self.session_id = str(uuid.uuid4()) + self.connection_id = str(uuid.uuid4()) self._active = True self._event_queue = asyncio.Queue() self._function_call_buffer = {} @@ -279,9 +279,9 @@ async def _process_responses(self) -> None: async def receive(self) -> AsyncIterable[OutputEvent]: """Receive OpenAI events and convert to Strands TypedEvent format.""" - # Emit session start event - yield SessionStartEvent( - session_id=self.session_id, + # Emit connection start event + yield ConnectionStartEvent( + connection_id=self.connection_id, model=self.model, capabilities=["audio", "tools"] ) @@ -299,8 +299,8 @@ async def receive(self) -> AsyncIterable[OutputEvent]: logger.error("Error receiving OpenAI Realtime event: %s", e) yield ErrorEvent(error=e) finally: - # Emit session end event - yield SessionEndEvent(reason="complete") + # Emit connection close event + yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 2721ceab9..9ab16dd38 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -9,6 +9,8 @@ SUPPORTED_SAMPLE_RATES, AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InputEvent, @@ -16,8 +18,6 @@ ModalityUsage, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -31,7 +31,8 @@ "ImageInputEvent", "InputEvent", # Output Events - "SessionStartEvent", + "ConnectionStartEvent", + "ConnectionCloseEvent", "TurnStartEvent", "AudioStreamEvent", "TranscriptStreamEvent", @@ -39,7 +40,6 @@ "TurnCompleteEvent", "UsageEvent", "ModalityUsage", - "SessionEndEvent", "ErrorEvent", "OutputEvent", # Constants diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index fbec640b8..5dea738d9 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -148,28 +148,28 @@ def mime_type(self) -> str: # ============================================================================ -class SessionStartEvent(TypedEvent): - """Session established and ready for interaction. +class ConnectionStartEvent(TypedEvent): + """Streaming connection established and ready for interaction. Parameters: - session_id: Unique identifier for this session. + connection_id: Unique identifier for this streaming connection. model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). capabilities: List of supported features (e.g., ["audio", "tools", "images"]). """ - def __init__(self, session_id: str, model: str, capabilities: List[str]): + def __init__(self, connection_id: str, model: str, capabilities: List[str]): super().__init__( { - "type": "bidirectional_session_start", - "session_id": session_id, + "type": "bidirectional_connection_start", + "connection_id": connection_id, "model": model, "capabilities": capabilities, } ) @property - def session_id(self) -> str: - return cast(str, self.get("session_id")) + def connection_id(self) -> str: + return cast(str, self.get("connection_id")) @property def model(self) -> str: @@ -408,17 +408,30 @@ def cache_write_input_tokens(self) -> Optional[int]: return cast(Optional[int], self.get("cacheWriteInputTokens")) -class SessionEndEvent(TypedEvent): - """Session terminated. +class ConnectionCloseEvent(TypedEvent): + """Streaming connection closed. Parameters: - reason: Why the session ended. + connection_id: Unique identifier for this streaming connection (matches ConnectionStartEvent). + reason: Why the connection was closed. """ def __init__( - self, reason: Literal["client_disconnect", "timeout", "error", "complete"] + self, + connection_id: str, + reason: Literal["client_disconnect", "timeout", "error", "complete"], ): - super().__init__({"type": "bidirectional_session_end", "reason": reason}) + super().__init__( + { + "type": "bidirectional_connection_close", + "connection_id": connection_id, + "reason": reason, + } + ) + + @property + def connection_id(self) -> str: + return cast(str, self.get("connection_id")) @property def reason(self) -> str: @@ -487,13 +500,13 @@ def details(self) -> Optional[Dict[str, Any]]: InputEvent = Union[TextInputEvent, AudioInputEvent, ImageInputEvent] OutputEvent = Union[ - SessionStartEvent, + ConnectionStartEvent, TurnStartEvent, AudioStreamEvent, TranscriptStreamEvent, InterruptionEvent, TurnCompleteEvent, UsageEvent, - SessionEndEvent, + ConnectionCloseEvent, ErrorEvent, ] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 1e07eb449..851afd92a 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -415,12 +415,12 @@ async def test_event_templates(nova_model): assert "inferenceConfiguration" in event["event"]["sessionStart"] # Test prompt start event - nova_model.session_id = "test-session" + nova_model.connection_id = "test-connection" event_json = nova_model._get_prompt_start_event([]) event = json.loads(event_json) assert "event" in event assert "promptStart" in event["event"] - assert event["event"]["promptStart"]["promptName"] == "test-session" + assert event["event"]["promptStart"]["promptName"] == "test-connection" # Test text input event content_name = "test-content" From e3389fb4d2ff243833c25c83f4fab7383bee8f7c Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 15:01:38 +0300 Subject: [PATCH 055/242] refactor: rename turn events to response events --- .../bidirectional_streaming/__init__.py | 8 ++-- .../models/gemini_live.py | 10 ++--- .../models/novasonic.py | 14 +++---- .../bidirectional_streaming/models/openai.py | 14 +++---- .../bidirectional_streaming/types/__init__.py | 8 ++-- .../types/bidirectional_streaming.py | 42 +++++++++---------- .../models/test_gemini_live.py | 2 +- .../models/test_novasonic.py | 2 +- .../models/test_openai_realtime.py | 4 +- 9 files changed, 52 insertions(+), 52 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 1b901b0a2..31b9ead32 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -24,10 +24,10 @@ ModalityUsage, UsageEvent, OutputEvent, + ResponseCompleteEvent, + ResponseStartEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, ) # Re-export standard agent events for tool handling @@ -55,11 +55,11 @@ # Output Event types "ConnectionStartEvent", "ConnectionCloseEvent", - "TurnStartEvent", + "ResponseStartEvent", + "ResponseCompleteEvent", "AudioStreamEvent", "TranscriptStreamEvent", "InterruptionEvent", - "TurnCompleteEvent", "UsageEvent", "ModalityUsage", "ErrorEvent", diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 5819b84eb..ad2ca678d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -35,8 +35,8 @@ UsageEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, + ResponseCompleteEvent, + ResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -222,7 +222,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.debug(f"Input transcription detected: {transcription_text}") return TranscriptStreamEvent( text=transcription_text, - source="user", + role="user", is_final=True ) @@ -235,7 +235,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.debug(f"Output transcription detected: {transcription_text}") return TranscriptStreamEvent( text=transcription_text, - source="assistant", + role="assistant", is_final=True ) @@ -244,7 +244,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.debug(f"Text output as transcript: {message.text}") return TranscriptStreamEvent( text=message.text, - source="assistant", + role="assistant", is_final=True ) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index da57e0c57..f18085020 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -45,8 +45,8 @@ OutputEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, + ResponseCompleteEvent, + ResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -535,7 +535,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: return TranscriptStreamEvent( text=text_content, - source="user" if role == "USER" else "assistant", + role="user" if role == "USER" else "assistant", is_final=True ) @@ -575,14 +575,14 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: role = nova_event["contentStart"].get("role", "unknown") # Store role for subsequent text output events self._current_role = role - # Emit turn start event - return TurnStartEvent(turn_id=str(uuid.uuid4())) + # Emit response start event + return ResponseStartEvent(response_id=str(uuid.uuid4())) # Handle content stop events elif "contentStop" in nova_event: stop_reason = nova_event["contentStop"].get("stopReason", "complete") - return TurnCompleteEvent( - turn_id=str(uuid.uuid4()), + return ResponseCompleteEvent( + response_id=str(uuid.uuid4()), stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" ) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 52a3cdf79..33c89ba6c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -30,8 +30,8 @@ OutputEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, + ResponseCompleteEvent, + ResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -176,7 +176,7 @@ def _create_text_event(self, text: str, role: str) -> TranscriptStreamEvent: """Create standardized transcript event.""" return TranscriptStreamEvent( text=text, - source="user" if role == "user" else "assistant", + role="user" if role == "user" else "assistant", is_final=True ) @@ -310,7 +310,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven if event_type == "response.created": response = openai_event.get("response", {}) response_id = response.get("id", str(uuid.uuid4())) - return [TurnStartEvent(turn_id=response_id)] + return [ResponseStartEvent(response_id=response_id)] # Audio output elif event_type == "response.output_audio.delta": @@ -405,9 +405,9 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven # Build list of events to return events = [] - # Always add turn complete event - events.append(TurnCompleteEvent( - turn_id=response_id, + # Always add response complete event + events.append(ResponseCompleteEvent( + response_id=response_id, stop_reason=stop_reason_map.get(status, "complete") )) diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 9ab16dd38..0a2abb68f 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -18,10 +18,10 @@ ModalityUsage, UsageEvent, OutputEvent, + ResponseCompleteEvent, + ResponseStartEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, ) __all__ = [ @@ -33,11 +33,11 @@ # Output Events "ConnectionStartEvent", "ConnectionCloseEvent", - "TurnStartEvent", + "ResponseStartEvent", + "ResponseCompleteEvent", "AudioStreamEvent", "TranscriptStreamEvent", "InterruptionEvent", - "TurnCompleteEvent", "UsageEvent", "ModalityUsage", "ErrorEvent", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 5dea738d9..5641200e7 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -180,19 +180,19 @@ def capabilities(self) -> List[str]: return cast(List[str], self.get("capabilities")) -class TurnStartEvent(TypedEvent): +class ResponseStartEvent(TypedEvent): """Model starts generating a response. Parameters: - turn_id: Unique identifier for this turn (used in turn.complete). + response_id: Unique identifier for this response (used in response.complete). """ - def __init__(self, turn_id: str): - super().__init__({"type": "bidirectional_turn_start", "turn_id": turn_id}) + def __init__(self, response_id: str): + super().__init__({"type": "bidirectional_response_start", "response_id": response_id}) @property - def turn_id(self) -> str: - return cast(str, self.get("turn_id")) + def response_id(self) -> str: + return cast(str, self.get("response_id")) class AudioStreamEvent(TypedEvent): @@ -244,18 +244,18 @@ class TranscriptStreamEvent(TypedEvent): Parameters: text: Transcribed text from audio. - source: Who is speaking ("user" or "assistant"). + role: Who is speaking ("user" or "assistant"). Aligns with Message.role convention. is_final: Whether this is the final/complete transcript. """ def __init__( - self, text: str, source: Literal["user", "assistant"], is_final: bool + self, text: str, role: Literal["user", "assistant"], is_final: bool ): super().__init__( { "type": "bidirectional_transcript_stream", "text": text, - "source": source, + "role": role, "is_final": is_final, } ) @@ -265,8 +265,8 @@ def text(self) -> str: return cast(str, self.get("text")) @property - def source(self) -> str: - return cast(str, self.get("source")) + def role(self) -> str: + return cast(str, self.get("role")) @property def is_final(self) -> bool: @@ -301,30 +301,30 @@ def turn_id(self) -> Optional[str]: return cast(Optional[str], self.get("turn_id")) -class TurnCompleteEvent(TypedEvent): +class ResponseCompleteEvent(TypedEvent): """Model finished generating response. Parameters: - turn_id: ID of the turn that completed (matches turn.start). - stop_reason: Why the turn ended. + response_id: ID of the response that completed (matches response.start). + stop_reason: Why the response ended. """ def __init__( self, - turn_id: str, + response_id: str, stop_reason: Literal["complete", "interrupted", "tool_use", "error"], ): super().__init__( { - "type": "bidirectional_turn_complete", - "turn_id": turn_id, + "type": "bidirectional_response_complete", + "response_id": response_id, "stop_reason": stop_reason, } ) @property - def turn_id(self) -> str: - return cast(str, self.get("turn_id")) + def response_id(self) -> str: + return cast(str, self.get("response_id")) @property def stop_reason(self) -> str: @@ -501,11 +501,11 @@ def details(self) -> Optional[Dict[str, Any]]: OutputEvent = Union[ ConnectionStartEvent, - TurnStartEvent, + ResponseStartEvent, AudioStreamEvent, TranscriptStreamEvent, InterruptionEvent, - TurnCompleteEvent, + ResponseCompleteEvent, UsageEvent, ConnectionCloseEvent, ErrorEvent, diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index d3bf965f4..5f6319318 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -302,7 +302,7 @@ async def test_event_conversion(mock_genai_client, model): text_event = model._convert_gemini_live_event(mock_text) assert isinstance(text_event, TranscriptStreamEvent) assert text_event.text == "Hello from Gemini" - assert text_event.source == "assistant" + assert text_event.role == "assistant" assert text_event.is_final is True # Test audio output (base64 encoded) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 851afd92a..feb320d91 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -258,7 +258,7 @@ async def test_event_conversion(nova_model): assert isinstance(result, TranscriptStreamEvent) assert result.get("type") == "bidirectional_transcript_stream" assert result.get("text") == "Hello, world!" - assert result.get("source") == "assistant" + assert result.get("role") == "assistant" # Test tool use (now returns dict with tool_use) tool_input = {"location": "Seattle"} diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 60e88aa0f..98c520fdb 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -368,7 +368,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert isinstance(converted[0], TranscriptStreamEvent) assert converted[0].get("type") == "bidirectional_transcript_stream" assert converted[0].get("text") == "Hello from OpenAI" - assert converted[0].get("source") == "assistant" + assert converted[0].get("role") == "assistant" # Test function call sequence item_added = { @@ -467,7 +467,7 @@ def test_helper_methods(model): assert isinstance(text_event, TranscriptStreamEvent) assert text_event.get("type") == "bidirectional_transcript_stream" assert text_event.get("text") == "Hello" - assert text_event.get("source") == "user" + assert text_event.get("role") == "user" # Test _create_voice_activity_event (now returns InterruptionEvent for speech_started) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent From 5877c5fce07435d35ed24b9c42cbc0c10681ac66 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 15:22:10 +0300 Subject: [PATCH 056/242] fix: fix bidi tests --- .../models/test_gemini_live.py | 20 +++++++------- .../models/test_novasonic.py | 26 ++++++++++--------- .../models/test_openai_realtime.py | 21 ++++++++------- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 5f6319318..25f11c23c 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -116,7 +116,7 @@ async def test_connection_lifecycle(mock_genai_client, model, system_prompt, too # Test basic connection await model.connect() assert model._active is True - assert model.session_id is not None + assert model.connection_id is not None assert model.live_session == mock_live_session mock_client.aio.live.connect.assert_called_once() @@ -256,8 +256,8 @@ async def test_send_edge_cases(mock_genai_client, model): async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): """Test that receive() emits connection start and end events.""" from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - SessionStartEvent, - SessionEndEvent, + ConnectionStartEvent, + ConnectionCloseEvent, ) _, mock_live_session, _ = mock_genai_client @@ -275,9 +275,9 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): # Verify connection start and end assert len(events) >= 2 - assert isinstance(events[0], SessionStartEvent) - assert events[0].session_id == model.session_id - assert isinstance(events[-1], SessionEndEvent) + assert isinstance(events[0], ConnectionStartEvent) + assert events[0].connection_id == model.connection_id + assert isinstance(events[-1], ConnectionCloseEvent) @pytest.mark.asyncio @@ -336,9 +336,11 @@ async def test_event_conversion(mock_genai_client, model): mock_tool.server_content = None tool_event = model._convert_gemini_live_event(mock_tool) - assert "toolUse" in tool_event - assert tool_event["toolUse"]["toolUseId"] == "tool-123" - assert tool_event["toolUse"]["name"] == "calculator" + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in tool_event + assert "toolUse" in tool_event["delta"] + assert tool_event["delta"]["toolUse"]["toolUseId"] == "tool-123" + assert tool_event["delta"]["toolUse"]["name"] == "calculator" # Test interruption mock_server_content = unittest.mock.Mock() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index feb320d91..3865eb353 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -72,7 +72,7 @@ async def test_model_initialization(model_id, region): assert model.region == region assert model.stream is None assert not model._active - assert model.session_id is None + assert model.connection_id is None @pytest.mark.asyncio @@ -85,7 +85,7 @@ async def test_connection_lifecycle(nova_model, mock_client, mock_stream): await nova_model.connect(system_prompt="Test system prompt") assert nova_model._active assert nova_model.stream == mock_stream - assert nova_model.session_id is not None + assert nova_model.connection_id is not None assert mock_client.invoke_model_with_bidirectional_stream.called # Test close @@ -228,9 +228,9 @@ async def mock_wait_for(*args, **kwargs): # Should have session start and end (new TypedEvent format) assert len(events) >= 2 - assert events[0].get("type") == "bidirectional_session_start" - assert events[0].get("session_id") == nova_model.session_id - assert events[-1].get("type") == "bidirectional_session_end" + assert events[0].get("type") == "bidirectional_connection_start" + assert events[0].get("connection_id") == nova_model.connection_id + assert events[-1].get("type") == "bidirectional_connection_close" @pytest.mark.asyncio @@ -260,7 +260,7 @@ async def test_event_conversion(nova_model): assert result.get("text") == "Hello, world!" assert result.get("role") == "assistant" - # Test tool use (now returns dict with tool_use) + # Test tool use (now returns ToolUseStreamEvent from core strands) tool_input = {"location": "Seattle"} nova_event = { "toolUse": { @@ -271,8 +271,10 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert result.get("type") == "tool_use" - tool_use = result.get("tool_use") + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in result + assert "toolUse" in result["delta"] + tool_use = result["delta"]["toolUse"] assert tool_use["toolUseId"] == "tool-123" assert tool_use["name"] == "get_weather" assert tool_use["input"] == tool_input @@ -310,13 +312,13 @@ async def test_event_conversion(nova_model): assert result.get("inputTokens") == 40 assert result.get("outputTokens") == 60 - # Test content start tracks role and emits TurnStartEvent - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TurnStartEvent + # Test content start tracks role and emits ResponseStartEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ResponseStartEvent nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, TurnStartEvent) - assert result.get("type") == "bidirectional_turn_start" + assert isinstance(result, ResponseStartEvent) + assert result.get("type") == "bidirectional_response_start" assert nova_model._current_role == "USER" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 98c520fdb..a1c7e65cb 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -130,7 +130,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp # Test basic connection await model.connect() assert model._active is True - assert model.session_id is not None + assert model.connection_id is not None assert model.websocket == mock_ws assert model._event_queue is not None assert model._response_task is not None @@ -316,9 +316,9 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): receive_gen = model.receive() first_event = await anext(receive_gen) - # First event should be session start (new TypedEvent format) - assert first_event.get("type") == "bidirectional_session_start" - assert first_event.get("session_id") == model.session_id + # First event should be connection start (new TypedEvent format) + assert first_event.get("type") == "bidirectional_connection_start" + assert first_event.get("connection_id") == model.connection_id assert first_event.get("model") == model.model # Close to trigger session end @@ -332,8 +332,8 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): except StopAsyncIteration: pass - # Last event should be session end (new TypedEvent format) - assert events[-1].get("type") == "bidirectional_session_end" + # Last event should be connection close (new TypedEvent format) + assert events[-1].get("type") == "bidirectional_connection_close" @pytest.mark.asyncio @@ -393,12 +393,13 @@ async def test_event_conversion(mock_websockets_connect, model): "call_id": "call-123" } converted = model._convert_openai_event(args_done) - # Now returns list with dict containing tool_use + # Now returns list with ToolUseStreamEvent assert isinstance(converted, list) assert len(converted) == 1 - assert isinstance(converted[0], dict) - assert converted[0].get("type") == "tool_use" - tool_use = converted[0].get("tool_use") + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in converted[0] + assert "toolUse" in converted[0]["delta"] + tool_use = converted[0]["delta"]["toolUse"] assert tool_use["toolUseId"] == "call-123" assert tool_use["name"] == "calculator" assert tool_use["input"]["expression"] == "2+2" From 1e0e65ae8124b5d2985f72b3bcefb355d63593f7 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 15:38:34 +0300 Subject: [PATCH 057/242] feat: add json serialization tests for events --- .../bidirectional_streaming/types/__init__.py | 1 + .../types/test_bidirectional_streaming.py | 108 ++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 tests/strands/experimental/bidirectional_streaming/types/__init__.py create mode 100644 tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py diff --git a/tests/strands/experimental/bidirectional_streaming/types/__init__.py b/tests/strands/experimental/bidirectional_streaming/types/__init__.py new file mode 100644 index 000000000..a1330e552 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/types/__init__.py @@ -0,0 +1 @@ +"""Tests for bidirectional streaming types.""" diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py new file mode 100644 index 000000000..0efde8823 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py @@ -0,0 +1,108 @@ +"""Tests for bidirectional streaming event types. + +This module tests JSON serialization for all bidirectional streaming event types. +""" + +import base64 +import json + +import pytest + +from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + AudioInputEvent, + AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, + ErrorEvent, + ImageInputEvent, + InterruptionEvent, + ResponseCompleteEvent, + ResponseStartEvent, + TextInputEvent, + TranscriptStreamEvent, + UsageEvent, +) + + +@pytest.mark.parametrize( + "event_class,kwargs,expected_type", + [ + # Input events + (TextInputEvent, {"text": "Hello", "role": "user"}, "bidirectional_text_input"), + ( + AudioInputEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 16000, + "channels": 1, + }, + "bidirectional_audio_input", + ), + ( + ImageInputEvent, + {"image": base64.b64encode(b"image").decode("utf-8"), "mime_type": "image/jpeg"}, + "bidirectional_image_input", + ), + # Output events + ( + ConnectionStartEvent, + {"connection_id": "c1", "model": "m1", "capabilities": ["audio"]}, + "bidirectional_connection_start", + ), + (ResponseStartEvent, {"response_id": "r1"}, "bidirectional_response_start"), + ( + AudioStreamEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 24000, + "channels": 1, + }, + "bidirectional_audio_stream", + ), + ( + TranscriptStreamEvent, + {"text": "Hello", "role": "assistant", "is_final": True}, + "bidirectional_transcript_stream", + ), + (InterruptionEvent, {"reason": "user_speech", "turn_id": None}, "bidirectional_interruption"), + ( + ResponseCompleteEvent, + {"response_id": "r1", "stop_reason": "complete"}, + "bidirectional_response_complete", + ), + ( + UsageEvent, + {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, + "bidirectional_usage", + ), + ( + ConnectionCloseEvent, + {"connection_id": "c1", "reason": "complete"}, + "bidirectional_connection_close", + ), + (ErrorEvent, {"error": ValueError("test"), "details": None}, "bidirectional_error"), + ], +) +def test_event_json_serialization(event_class, kwargs, expected_type): + """Test that all event types are JSON serializable and deserializable.""" + # Create event + event = event_class(**kwargs) + + # Verify type field + assert event["type"] == expected_type + + # Serialize to JSON + json_str = json.dumps(event) + + # Deserialize back + data = json.loads(json_str) + + # Verify type preserved + assert data["type"] == expected_type + + # Verify all non-private keys preserved + for key in event.keys(): + if not key.startswith("_"): + assert key in data From 4d091a4659bca8658a7ae6a1416f8a7abdd2f6c3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 16:04:22 +0300 Subject: [PATCH 058/242] refactor: transcript events to extend model stream event --- .../models/gemini_live.py | 15 +++-- .../models/novasonic.py | 7 ++- .../bidirectional_streaming/models/openai.py | 12 ++-- .../types/bidirectional_streaming.py | 40 +++++++++---- .../models/test_gemini_live.py | 2 + .../models/test_novasonic.py | 2 + .../models/test_openai_realtime.py | 5 ++ .../types/test_bidirectional_streaming.py | 59 ++++++++++++++++++- 8 files changed, 115 insertions(+), 27 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index ad2ca678d..29b18da9e 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -166,8 +166,7 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: # Emit connection start event yield ConnectionStartEvent( connection_id=self.connection_id, - model=self.model_id, - capabilities=["audio", "tools", "images"] + model=self.model_id ) try: @@ -221,9 +220,11 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic transcription_text = input_transcript.text logger.debug(f"Input transcription detected: {transcription_text}") return TranscriptStreamEvent( + delta={"text": transcription_text}, text=transcription_text, role="user", - is_final=True + is_final=True, + current_transcript=transcription_text ) # Handle output transcription (model's audio) - emit as transcript event @@ -234,18 +235,22 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic transcription_text = output_transcript.text logger.debug(f"Output transcription detected: {transcription_text}") return TranscriptStreamEvent( + delta={"text": transcription_text}, text=transcription_text, role="assistant", - is_final=True + is_final=True, + current_transcript=transcription_text ) # Handle text output from model if message.text: logger.debug(f"Text output as transcript: {message.text}") return TranscriptStreamEvent( + delta={"text": message.text}, text=message.text, role="assistant", - is_final=True + is_final=True, + current_transcript=message.text ) # Handle audio output using SDK's built-in data property diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index f18085020..3b4419586 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -275,8 +275,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: # Emit connection start event yield ConnectionStartEvent( connection_id=self.connection_id, - model=self.model_id, - capabilities=["audio", "tools"] + model=self.model_id ) try: @@ -534,9 +533,11 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: return InterruptionEvent(reason="user_speech", turn_id=None) return TranscriptStreamEvent( + delta={"text": text_content}, text=text_content, role="user" if role == "USER" else "assistant", - is_final=True + is_final=True, + current_transcript=text_content ) # Handle tool use diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 33c89ba6c..1f072ac87 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -172,12 +172,14 @@ def _require_active(self) -> bool: """Check if session is active.""" return self._active - def _create_text_event(self, text: str, role: str) -> TranscriptStreamEvent: + def _create_text_event(self, text: str, role: str, is_final: bool = True) -> TranscriptStreamEvent: """Create standardized transcript event.""" return TranscriptStreamEvent( + delta={"text": text}, text=text, role="user" if role == "user" else "assistant", - is_final=True + is_final=is_final, + current_transcript=text if is_final else None ) def _create_voice_activity_event(self, activity_type: str) -> InterruptionEvent | None: @@ -282,8 +284,7 @@ async def receive(self) -> AsyncIterable[OutputEvent]: # Emit connection start event yield ConnectionStartEvent( connection_id=self.connection_id, - model=self.model, - capabilities=["audio", "tools"] + model=self.model ) try: @@ -331,7 +332,8 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven "conversation.item.input_audio_transcription.completed"]: text_key = "delta" if "delta" in event_type else "transcript" text = openai_event.get(text_key, "") - return [self._create_text_event(text, "user")] if text.strip() else None + is_final = "completed" in event_type + return [self._create_text_event(text, "user", is_final=is_final)] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 5641200e7..355e78c2f 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -21,7 +21,8 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast -from ....types._events import TypedEvent +from ....types._events import ModelStreamEvent, TypedEvent +from ....types.streaming import ContentBlockDelta # Audio format constants SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] @@ -154,16 +155,14 @@ class ConnectionStartEvent(TypedEvent): Parameters: connection_id: Unique identifier for this streaming connection. model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). - capabilities: List of supported features (e.g., ["audio", "tools", "images"]). """ - def __init__(self, connection_id: str, model: str, capabilities: List[str]): + def __init__(self, connection_id: str, model: str): super().__init__( { "type": "bidirectional_connection_start", "connection_id": connection_id, "model": model, - "capabilities": capabilities, } ) @@ -175,10 +174,6 @@ def connection_id(self) -> str: def model(self) -> str: return cast(str, self.get("model")) - @property - def capabilities(self) -> List[str]: - return cast(List[str], self.get("capabilities")) - class ResponseStartEvent(TypedEvent): """Model starts generating a response. @@ -239,27 +234,44 @@ def channels(self) -> int: return cast(int, self.get("channels")) -class TranscriptStreamEvent(TypedEvent): - """Audio transcription of speech (user or assistant). +class TranscriptStreamEvent(ModelStreamEvent): + """Audio transcription streaming (user or assistant speech). + + Follows the same delta + current state pattern as TextStreamEvent and ToolUseStreamEvent + from core Strands. Supports incremental transcript updates for providers like OpenAI + that send partial transcripts before the final version. Parameters: - text: Transcribed text from audio. + delta: The incremental transcript change (ContentBlockDelta). + text: The delta text (same as delta content for convenience). role: Who is speaking ("user" or "assistant"). Aligns with Message.role convention. is_final: Whether this is the final/complete transcript. + current_transcript: The accumulated transcript text so far (None for first delta). """ def __init__( - self, text: str, role: Literal["user", "assistant"], is_final: bool + self, + delta: ContentBlockDelta, + text: str, + role: Literal["user", "assistant"], + is_final: bool, + current_transcript: Optional[str] = None, ): super().__init__( { "type": "bidirectional_transcript_stream", + "delta": delta, "text": text, "role": role, "is_final": is_final, + "current_transcript": current_transcript, } ) + @property + def delta(self) -> ContentBlockDelta: + return cast(ContentBlockDelta, self.get("delta")) + @property def text(self) -> str: return cast(str, self.get("text")) @@ -272,6 +284,10 @@ def role(self) -> str: def is_final(self) -> bool: return cast(bool, self.get("is_final")) + @property + def current_transcript(self) -> Optional[str]: + return cast(Optional[str], self.get("current_transcript")) + class InterruptionEvent(TypedEvent): """Model generation was interrupted. diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 25f11c23c..107a8a84a 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -304,6 +304,8 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.text == "Hello from Gemini" assert text_event.role == "assistant" assert text_event.is_final is True + assert text_event.delta == {"text": "Hello from Gemini"} + assert text_event.current_transcript == "Hello from Gemini" # Test audio output (base64 encoded) import base64 diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 3865eb353..1a2fef426 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -259,6 +259,8 @@ async def test_event_conversion(nova_model): assert result.get("type") == "bidirectional_transcript_stream" assert result.get("text") == "Hello, world!" assert result.get("role") == "assistant" + assert result.delta == {"text": "Hello, world!"} + assert result.current_transcript == "Hello, world!" # Test tool use (now returns ToolUseStreamEvent from core strands) tool_input = {"location": "Seattle"} diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index a1c7e65cb..2045424e1 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -369,6 +369,8 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("type") == "bidirectional_transcript_stream" assert converted[0].get("text") == "Hello from OpenAI" assert converted[0].get("role") == "assistant" + assert converted[0].delta == {"text": "Hello from OpenAI"} + assert converted[0].is_final is True # Test function call sequence item_added = { @@ -469,6 +471,9 @@ def test_helper_methods(model): assert text_event.get("type") == "bidirectional_transcript_stream" assert text_event.get("text") == "Hello" assert text_event.get("role") == "user" + assert text_event.delta == {"text": "Hello"} + assert text_event.is_final is True + assert text_event.current_transcript == "Hello" # Test _create_voice_activity_event (now returns InterruptionEvent for speech_started) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py index 0efde8823..b6290cfcf 100644 --- a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py +++ b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py @@ -47,7 +47,7 @@ # Output events ( ConnectionStartEvent, - {"connection_id": "c1", "model": "m1", "capabilities": ["audio"]}, + {"connection_id": "c1", "model": "m1"}, "bidirectional_connection_start", ), (ResponseStartEvent, {"response_id": "r1"}, "bidirectional_response_start"), @@ -63,7 +63,13 @@ ), ( TranscriptStreamEvent, - {"text": "Hello", "role": "assistant", "is_final": True}, + { + "delta": {"text": "Hello"}, + "text": "Hello", + "role": "assistant", + "is_final": True, + "current_transcript": "Hello", + }, "bidirectional_transcript_stream", ), (InterruptionEvent, {"reason": "user_speech", "turn_id": None}, "bidirectional_interruption"), @@ -106,3 +112,52 @@ def test_event_json_serialization(event_class, kwargs, expected_type): for key in event.keys(): if not key.startswith("_"): assert key in data + + + +def test_transcript_stream_event_delta_pattern(): + """Test that TranscriptStreamEvent follows ModelStreamEvent delta pattern.""" + # Test partial transcript (delta) + partial_event = TranscriptStreamEvent( + delta={"text": "Hello"}, + text="Hello", + role="user", + is_final=False, + current_transcript=None, + ) + + assert partial_event.text == "Hello" + assert partial_event.role == "user" + assert partial_event.is_final is False + assert partial_event.current_transcript is None + assert partial_event.delta == {"text": "Hello"} + + # Test final transcript with accumulated text + final_event = TranscriptStreamEvent( + delta={"text": " world"}, + text=" world", + role="user", + is_final=True, + current_transcript="Hello world", + ) + + assert final_event.text == " world" + assert final_event.role == "user" + assert final_event.is_final is True + assert final_event.current_transcript == "Hello world" + assert final_event.delta == {"text": " world"} + + +def test_transcript_stream_event_extends_model_stream_event(): + """Test that TranscriptStreamEvent is a ModelStreamEvent.""" + from strands.types._events import ModelStreamEvent + + event = TranscriptStreamEvent( + delta={"text": "test"}, + text="test", + role="assistant", + is_final=True, + current_transcript="test", + ) + + assert isinstance(event, ModelStreamEvent) From 5e3b10ab7ad2637f3bb4e0d6b334b5e67c511bff Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 6 Nov 2025 09:16:02 -0500 Subject: [PATCH 059/242] Add default AudioAdapter when no adapter passed in --- .../adapters/__init__.py | 2 +- .../bidirectional_streaming/agent/agent.py | 27 ++++++++------ .../tests/optimized_example.py | 34 ------------------ .../tests/test_bidi.py | 35 +++++++++++++++++++ 4 files changed, 53 insertions(+), 45 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/tests/optimized_example.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi.py diff --git a/src/strands/experimental/bidirectional_streaming/adapters/__init__.py b/src/strands/experimental/bidirectional_streaming/adapters/__init__.py index 07d258a3e..6b192ef16 100644 --- a/src/strands/experimental/bidirectional_streaming/adapters/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/adapters/__init__.py @@ -7,4 +7,4 @@ from .audio_adapter import AudioAdapter -__all__ = ["AudioAdapter"] \ No newline at end of file +__all__ = ["AudioAdapter"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 2378923fc..49b09679f 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -27,6 +27,7 @@ from ....types.content import Message, Messages from ....types.tools import ToolResult, ToolUse from ....types.traces import AttributeValue +from ..adapters.audio_adapter import AudioAdapter from ..event_loop.bidirectional_event_loop import BidirectionalAgentLoop from ..models.bidirectional_model import BidirectionalModel from ..models.novasonic import NovaSonicBidirectionalModel @@ -76,6 +77,7 @@ def __init__( tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). description: Description of what the Agent does. adapters: Optional list of adapter instances (e.g., AudioAdapter) for hardware abstraction. + If None, automatically creates default AudioAdapter for basic audio functionality. **kwargs: Additional configuration for future extensibility. Raises: @@ -124,8 +126,13 @@ def __init__( self._agentloop: Optional["BidirectionalAgentLoop"] = None self._output_queue = asyncio.Queue() - # Initialize adapters - self.adapters = adapters or [] + # Initialize adapters - auto-create AudioAdapter as default + if adapters is None: + # Create default AudioAdapter for basic audio functionality + default_audio_adapter = AudioAdapter(audio_config={"input_sample_rate": 16000}) + self.adapters = [default_audio_adapter] + else: + self.adapters = adapters @property def tool(self) -> ToolCaller: @@ -391,24 +398,24 @@ async def connect(self) -> None: """Connect the agent using configured adapters for bidirectional communication. Automatically uses configured adapters to establish bidirectional communication - with the model. Handles connection lifecycle and transport coordination. + with the model. If no adapters are provided in constructor, uses default AudioAdapter. Example: ```python - # With AudioAdapter - adapter = AudioAdapter(audio_config={"input_sample_rate": 16000}) + # Simple - uses default AudioAdapter + agent = BidirectionalAgent(model=model, tools=[calculator]) + await agent.connect() + + # Custom adapter + adapter = AudioAdapter(audio_config={"input_sample_rate": 24000}) agent = BidirectionalAgent(model=model, tools=[calculator], adapters=[adapter]) await agent.connect() ``` Raises: - ValueError: If no adapters are configured. Exception: Any exception from the transport layer. """ - if not self.adapters: - raise ValueError("No adapters configured. Add adapters to the agent constructor.") - - # Use first adapter + # Use first adapter (always available due to default initialization) adapter = self.adapters[0] sender = adapter.create_output() receiver = adapter.create_input() diff --git a/src/strands/experimental/bidirectional_streaming/tests/optimized_example.py b/src/strands/experimental/bidirectional_streaming/tests/optimized_example.py deleted file mode 100644 index 1270f3e4c..000000000 --- a/src/strands/experimental/bidirectional_streaming/tests/optimized_example.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Example using the OptimizedAudioAdapter - clean and simple.""" - -import asyncio -import os -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) - -from strands.experimental.bidirectional_streaming.agent.clean_agent import CleanBidirectionalAgent -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent - -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel -from strands.experimental.bidirectional_streaming.adapters.optimized_audio_adapter import OptimizedAudioAdapter -from strands_tools import calculator - - -async def main(): - """Test the optimized audio adapter.""" - # Nova Sonic model - model = NovaSonicBidirectionalModel() - - # Clean agent with tools - agent = BidirectionalAgent(model=model, tools=[calculator]) - - # Optimized audio adapter - adapter = OptimizedAudioAdapter(agent) - - # Simple chat using context manager for automatic cleanup - await agent.run(sender=adapter.create_output(), receiver=adapter.create_input()) - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py new file mode 100644 index 000000000..b273ab7a0 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py @@ -0,0 +1,35 @@ +"""Test BidirectionalAgent with new ultra-simple developer experience.""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands_tools import calculator + + +async def main(): + """Test the BidirectionalAgent API.""" + + + # Nova Sonic model + model = NovaSonicBidirectionalModel() + + async with BidirectionalAgent(model=model, tools=[calculator]) as agent: + print("New BidirectionalAgent Experience") + print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") + await agent.connect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n⏹️ Conversation ended by user") + except Exception as e: + print(f"❌ Error: {e}") + import traceback + traceback.print_exc() From ce56f42ff08852dc42447d2b2245e636148c0e84 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 6 Nov 2025 09:19:55 -0500 Subject: [PATCH 060/242] Update test comment --- .../experimental/bidirectional_streaming/tests/test_bidi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py index b273ab7a0..a53cbeba6 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py @@ -1,4 +1,4 @@ -"""Test BidirectionalAgent with new ultra-simple developer experience.""" +"""Test BidirectionalAgent with simple developer experience.""" import asyncio import sys From e13d51f2a3b140d26c8b9e1b1f514964e46c13f1 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 17:20:20 +0300 Subject: [PATCH 061/242] fix novasonic example script --- .../bidirectional_streaming/tests/test_bidi_novasonic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index b538fc023..e5a2e7c46 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -148,12 +148,12 @@ async def receive(agent, context): # Handle transcript events (bidirectional_transcript_stream) elif event_type == "bidirectional_transcript_stream": text_content = event.get("text", "") - source = event.get("source", "unknown") + role = event.get("role", "unknown") # Log transcript output - if source == "user": + if role == "user": print(f"User: {text_content}") - elif source == "assistant": + elif role == "assistant": print(f"Assistant: {text_content}") # Handle turn complete events (bidirectional_turn_complete) From 9240bad28d893aeaf930992ead60679f06e5a383 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Thu, 6 Nov 2025 09:49:49 -0500 Subject: [PATCH 062/242] Update src/strands/experimental/bidirectional_streaming/agent/agent.py Co-authored-by: Nick Clegg --- src/strands/experimental/bidirectional_streaming/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index efe890c2b..d5ff27ce2 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -51,7 +51,7 @@ class BidirectionalAgent: def __init__( self, model: Union[BidirectionalModel, str, None] = None, - tools: Optional[list[Union[str, dict[str, str], Any]]] = None, + tools: list[string, AgentTool, ToolProvider] = None, system_prompt: Optional[str] = None, messages: Optional[Messages] = None, record_direct_tool_call: bool = True, From 2a2861b751c8adad2665a5b71c962db0f07cfe5b Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 6 Nov 2025 09:59:57 -0500 Subject: [PATCH 063/242] Update imports --- .../bidirectional_streaming/agent/agent.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index d5ff27ce2..6e586b73e 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -25,13 +25,13 @@ from ....tools.registry import ToolRegistry from ....tools.watcher import ToolWatcher from ....types.content import Message, Messages -from ....types.tools import ToolResult, ToolUse -from ....types.traces import AttributeValue +from ....types.tools import ToolResult, ToolUse, AgentTool from ..adapters.audio_adapter import AudioAdapter from ..event_loop.bidirectional_event_loop import BidirectionalAgentLoop from ..models.bidirectional_model import BidirectionalModel -from ..models.novasonic import NovaSonicBidirectionalModel +from ..models.novasonic import NovaSonicModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent +from ....experimental.tools import ToolProvider logger = logging.getLogger(__name__) @@ -51,7 +51,7 @@ class BidirectionalAgent: def __init__( self, model: Union[BidirectionalModel, str, None] = None, - tools: list[string, AgentTool, ToolProvider] = None, + tools: list[str, AgentTool, ToolProvider] = None, system_prompt: Optional[str] = None, messages: Optional[Messages] = None, record_direct_tool_call: bool = True, @@ -85,9 +85,9 @@ def __init__( TypeError: If model type is unsupported. """ self.model = ( - NovaSonicBidirectionalModel() + NovaSonicModel() if not model - else NovaSonicBidirectionalModel(model_id=model) + else NovaSonicModel(model_id=model) if isinstance(model, str) else model ) From 30c0a5d2bdcd63acfb33b29b0904566aa14c987f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 22:44:25 +0300 Subject: [PATCH 064/242] fix: remove turn id --- .../bidirectional_streaming/models/gemini_live.py | 2 +- .../bidirectional_streaming/models/novasonic.py | 4 ++-- .../bidirectional_streaming/models/openai.py | 2 +- .../bidirectional_streaming/tests/test_gemini_live.py | 2 +- .../types/bidirectional_streaming.py | 11 ++--------- .../types/test_bidirectional_streaming.py | 5 +++-- 6 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 29b18da9e..4337a6cfa 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -210,7 +210,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic try: # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: - return InterruptionEvent(reason="user_speech", turn_id=None) + return InterruptionEvent(reason="user_speech") # Handle input transcription (user's speech) - emit as transcript event if message.server_content and message.server_content.input_transcription: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 3b4419586..c5aa277ed 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -530,7 +530,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: logger.debug("Nova interruption detected in text") - return InterruptionEvent(reason="user_speech", turn_id=None) + return InterruptionEvent(reason="user_speech") return TranscriptStreamEvent( delta={"text": text_content}, @@ -557,7 +557,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": logger.debug("Nova interruption stop reason") - return InterruptionEvent(reason="user_speech", turn_id=None) + return InterruptionEvent(reason="user_speech", response_id=None) # Handle usage events - convert to multimodal usage format elif "usageEvent" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 1f072ac87..d923605e5 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -186,7 +186,7 @@ def _create_voice_activity_event(self, activity_type: str) -> InterruptionEvent """Create standardized interruption event for voice activity.""" # Only speech_started triggers interruption if activity_type == "speech_started": - return InterruptionEvent(reason="user_speech", turn_id=None) + return InterruptionEvent(reason="user_speech") # Other voice activity events are logged but don't create events return None diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 0bd283eb9..38791d9ed 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -196,7 +196,7 @@ async def receive(agent, context): # Handle turn start events (bidirectional_turn_start) elif event_type == "bidirectional_turn_start": - logger.debug(f"Turn started: {event.get('turn_id', 'unknown')}") + logger.debug(f"Turn started: {event.get('response_id', 'unknown')}") except asyncio.CancelledError: pass diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 355e78c2f..5069ccd5e 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -294,17 +294,14 @@ class InterruptionEvent(TypedEvent): Parameters: reason: Why the interruption occurred. - turn_id: ID of the turn that was interrupted (may be None). + response_id: ID of the response that was interrupted (may be None). """ - def __init__( - self, reason: Literal["user_speech", "error"], turn_id: Optional[str] = None - ): + def __init__(self, reason: Literal["user_speech", "error"]): super().__init__( { "type": "bidirectional_interruption", "reason": reason, - "turn_id": turn_id, } ) @@ -312,10 +309,6 @@ def __init__( def reason(self) -> str: return cast(str, self.get("reason")) - @property - def turn_id(self) -> Optional[str]: - return cast(Optional[str], self.get("turn_id")) - class ResponseCompleteEvent(TypedEvent): """Model finished generating response. diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py index b6290cfcf..45bcd2de4 100644 --- a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py +++ b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py @@ -72,7 +72,7 @@ }, "bidirectional_transcript_stream", ), - (InterruptionEvent, {"reason": "user_speech", "turn_id": None}, "bidirectional_interruption"), + (InterruptionEvent, {"reason": "user_speech"}, "bidirectional_interruption"), ( ResponseCompleteEvent, {"response_id": "r1", "stop_reason": "complete"}, @@ -101,7 +101,8 @@ def test_event_json_serialization(event_class, kwargs, expected_type): # Serialize to JSON json_str = json.dumps(event) - + print("event_class:", event_class) + print(json_str) # Deserialize back data = json.loads(json_str) From 6ad01209c13cbddf38ec427ebdd57cf90385e1b3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 22:47:17 +0300 Subject: [PATCH 065/242] fix: fix nova completion id tracking --- .../models/novasonic.py | 54 ++++++++++++++----- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index c5aa277ed..91d231a7a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -126,6 +126,10 @@ def __init__( # Background task and event queue self._response_task = None self._event_queue = None + + # Track API-provided identifiers + self._current_completion_id = None + self._current_role = None logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) @@ -510,6 +514,28 @@ async def close(self) -> None: def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: """Convert Nova Sonic events to TypedEvent format.""" + # Handle completion start - track completionId + if "completionStart" in nova_event: + completion_data = nova_event["completionStart"] + self._current_completion_id = completion_data.get("completionId") + logger.debug("Nova completion started: %s", self._current_completion_id) + return None + + # Handle completion end + if "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + completion_id = completion_data.get("completionId", self._current_completion_id) + stop_reason = completion_data.get("stopReason", "END_TURN") + + event = ResponseCompleteEvent( + response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" + ) + + # Clear completion tracking + self._current_completion_id = None + return event + # Handle audio output if "audioOutput" in nova_event: # Audio is already base64 string from Nova Sonic @@ -557,7 +583,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": logger.debug("Nova interruption stop reason") - return InterruptionEvent(reason="user_speech", response_id=None) + return InterruptionEvent(reason="user_speech") # Handle usage events - convert to multimodal usage format elif "usageEvent" in nova_event: @@ -571,22 +597,26 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: total_tokens=usage_data.get("totalTokens", total_input + total_output) ) - # Handle content start events (track role) + # Handle content start events (track role and emit response start) elif "contentStart" in nova_event: - role = nova_event["contentStart"].get("role", "unknown") + content_data = nova_event["contentStart"] + role = content_data.get("role", "unknown") # Store role for subsequent text output events self._current_role = role - # Emit response start event - return ResponseStartEvent(response_id=str(uuid.uuid4())) - - # Handle content stop events - elif "contentStop" in nova_event: - stop_reason = nova_event["contentStop"].get("stopReason", "complete") - return ResponseCompleteEvent( - response_id=str(uuid.uuid4()), - stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" + + # Emit response start event using API-provided completionId + # completionId should already be tracked from completionStart event + return ResponseStartEvent( + response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing ) + # Handle content end events + elif "contentEnd" in nova_event: + # contentEnd doesn't signal response completion in Nova Sonic + # Multiple content blocks can exist in a single response + # Only completionEnd signals the actual response completion + return None + # Handle other events else: return None From 774ab8605e9136350db8945cd54e006ee1c8b6d7 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 22:52:14 +0300 Subject: [PATCH 066/242] fix: remove unnecessary if condition --- .../bidirectional_streaming/models/novasonic.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 91d231a7a..944e45d4b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -610,14 +610,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing ) - # Handle content end events - elif "contentEnd" in nova_event: - # contentEnd doesn't signal response completion in Nova Sonic - # Multiple content blocks can exist in a single response - # Only completionEnd signals the actual response completion - return None - - # Handle other events + # Handle other events (contentEnd, etc.) else: return None From 0a63829d3b6e32e42251cb75e350137b659f24ca Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 7 Nov 2025 11:11:09 -0500 Subject: [PATCH 067/242] Update implementation based on bar-raising - Remove adapter from constructor - Implement BidirectionlIO interface - Add adapter the run() method --- .../adapters/__init__.py | 10 -- .../bidirectional_streaming/agent/agent.py | 130 +++++++------- .../tests/test_bidi.py | 8 +- .../bidirectional_streaming/types/__init__.py | 4 + .../audio_adapter.py => types/audio_io.py} | 166 +++++++++--------- .../types/bidirectional_io.py | 41 +++++ 6 files changed, 193 insertions(+), 166 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/adapters/__init__.py rename src/strands/experimental/bidirectional_streaming/{adapters/audio_adapter.py => types/audio_io.py} (53%) create mode 100644 src/strands/experimental/bidirectional_streaming/types/bidirectional_io.py diff --git a/src/strands/experimental/bidirectional_streaming/adapters/__init__.py b/src/strands/experimental/bidirectional_streaming/adapters/__init__.py deleted file mode 100644 index 6b192ef16..000000000 --- a/src/strands/experimental/bidirectional_streaming/adapters/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Adapters for BidirectionalAgent. - -Provides clean separation of concerns by moving hardware-specific functionality -(audio, video, sensors, etc.) into separate adapter classes that work with -the core BidirectionalAgent through the run() pattern. -""" - -from .audio_adapter import AudioAdapter - -__all__ = ["AudioAdapter"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 6e586b73e..ca772bf0c 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -26,11 +26,12 @@ from ....tools.watcher import ToolWatcher from ....types.content import Message, Messages from ....types.tools import ToolResult, ToolUse, AgentTool -from ..adapters.audio_adapter import AudioAdapter + from ..event_loop.bidirectional_event_loop import BidirectionalAgentLoop from ..models.bidirectional_model import BidirectionalModel from ..models.novasonic import NovaSonicModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent +from ..types import BidirectionalIO from ....experimental.tools import ToolProvider logger = logging.getLogger(__name__) @@ -60,7 +61,6 @@ def __init__( name: Optional[str] = None, tool_executor: Optional[ToolExecutor] = None, description: Optional[str] = None, - adapters: Optional[list[Any]] = None, **kwargs: Any, ): """Initialize bidirectional agent. @@ -76,8 +76,6 @@ def __init__( name: Name of the Agent. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). description: Description of what the Agent does. - adapters: Optional list of adapter instances (e.g., AudioAdapter) for hardware abstraction. - If None, automatically creates default AudioAdapter for basic audio functionality. **kwargs: Additional configuration for future extensibility. Raises: @@ -125,14 +123,7 @@ def __init__( # connection management self._agentloop: Optional["BidirectionalAgentLoop"] = None self._output_queue = asyncio.Queue() - - # Initialize adapters - auto-create AudioAdapter as default - if adapters is None: - # Create default AudioAdapter for basic audio functionality - default_audio_adapter = AudioAdapter(audio_config={"input_sample_rate": 16000}) - self.adapters = [default_audio_adapter] - else: - self.adapters = adapters + self._current_adapters = [] # Track adapters for cleanup @property def tool(self) -> ToolCaller: @@ -261,11 +252,11 @@ async def start(self) -> None: logger.debug("Conversation start - initializing connection") # Create model session and event loop directly - model_session = await self.model.create_bidirectional_connection( + model_session = await self.model.connect( system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages ) - self._agentloop = BidirectionalAgentLoop(model_session=model_session, agent=self) + self._agentloop = BidirectionalAgentLoop(model=self.model, agent=self) await self._agentloop.start() logger.debug("Conversation ready") @@ -294,13 +285,13 @@ async def send(self, input_data: BidirectionalInput) -> None: logger.debug("Text sent: %d characters", len(input_data)) # Create TextInputEvent for send() text_event = {"text": input_data, "role": "user"} - await self._agentloop.model_session.send(text_event) + await self._agentloop.model.send(text_event) elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input - await self._agentloop.model_session.send(input_data) + await self._agentloop.model.send(input_data) elif isinstance(input_data, dict) and "imageData" in input_data: # Handle image input (ImageInputEvent) - await self._agentloop.model_session.send(input_data) + await self._agentloop.model.send(input_data) else: raise ValueError( "Input must be either a string (text), AudioInputEvent " @@ -363,17 +354,20 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: """ try: logger.debug("Exiting async context manager - cleaning up adapters and connection") - - # Cleanup adapters first - for adapter in self.adapters: - if hasattr(adapter, "_cleanup_audio"): + + # Cleanup adapters if any are currently active + for adapter in self._current_adapters: + if hasattr(adapter, "cleanup"): try: - adapter._cleanup_audio() + adapter.cleanup() logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") except Exception as adapter_error: logger.warning(f"Error cleaning up adapter: {adapter_error}") - - # Then cleanup agent connection + + # Clear current adapters + self._current_adapters = [] + + # Cleanup agent connection await self.end() except Exception as cleanup_error: @@ -396,72 +390,72 @@ def active(self) -> bool: """ return self._agentloop is not None and self._agentloop.active - async def connect(self) -> None: - """Connect the agent using configured adapters for bidirectional communication. - - Automatically uses configured adapters to establish bidirectional communication - with the model. If no adapters are provided in constructor, uses default AudioAdapter. + async def run(self, io_channels: list[BidirectionalIO | tuple[Callable, Callable]]) -> None: + """Run the agent using provided IO channels or transport tuples for bidirectional communication. + Args: + io_channels: List containing either BidirectionalIO instances or (sender, receiver) tuples. + - BidirectionalIO: IO channel instance with input_channel(), output_channel(), and cleanup() methods + - tuple: (sender_callable, receiver_callable) for custom transport + Example: ```python - # Simple - uses default AudioAdapter + # With IO channel + audio_io = AudioIO(audio_config={"input_sample_rate": 16000}) agent = BidirectionalAgent(model=model, tools=[calculator]) - await agent.connect() + await agent.run(io_channels=[audio_io]) - # Custom adapter - adapter = AudioAdapter(audio_config={"input_sample_rate": 24000}) - agent = BidirectionalAgent(model=model, tools=[calculator], adapters=[adapter]) - await agent.connect() + # With tuple (backward compatibility) + await agent.run(io_channels=[(sender_function, receiver_function)]) ``` Raises: + ValueError: If io_channels list is empty or contains invalid items. Exception: Any exception from the transport layer. """ - # Use first adapter (always available due to default initialization) - adapter = self.adapters[0] - sender = adapter.create_output() - receiver = adapter.create_input() + if not io_channels: + raise ValueError("io_channels parameter cannot be empty. Provide either an IO channel or (sender, receiver) tuple.") + + transport = io_channels[0] + + # Set IO channel tracking for cleanup + if hasattr(transport, 'input_channel') and hasattr(transport, 'output_channel'): + self._current_adapters = [transport] # IO channel needs cleanup + elif isinstance(transport, tuple) and len(transport) == 2: + self._current_adapters = [] # Tuple needs no cleanup + else: + raise ValueError("io_channels list must contain either BidirectionalIO instances or (sender, receiver) tuples.") + # Auto-manage session lifecycle if self.active: - # Use existing connection - await self._run(sender, receiver) + await self._run_with_transport(transport) else: - # Use async context manager for automatic lifecycle management async with self: - await self._run(sender, receiver) + await self._run_with_transport(transport) - async def _run( + async def _run_with_transport( self, - sender: Callable[[Any], Any], - receiver: Callable[[], Any], + transport: BidirectionalIO | tuple[Callable, Callable], ) -> None: - """Internal method to run send/receive loops with an active connection. - - Args: - sender: Async callable that sends events to the client. - receiver: Async callable that receives events from the client. - """ + """Internal method to run send/receive loops with an active connection.""" async def receive_from_agent(): - """Receive events from agent and send to client.""" - try: - async for event in self.receive(): - await sender(event) - except Exception as e: - logger.debug(f"Receive from agent stopped: {e}") - raise + """Receive events from agent and send to transport.""" + async for event in self.receive(): + if hasattr(transport, 'output_channel'): + await transport.output_channel(event) + else: + await transport[0](event) async def send_to_agent(): - """Receive events from client and send to agent.""" - try: - while self.active: - event = await receiver() - await self.send(event) - except Exception as e: - logger.debug(f"Send to agent stopped: {e}") - raise + """Receive events from transport and send to agent.""" + while self.active: + if hasattr(transport, 'input_channel'): + event = await transport.input_channel() + else: + event = await transport[1]() + await self.send(event) - # Run both loops concurrently await asyncio.gather(receive_from_agent(), send_to_agent(), return_exceptions=True) def _validate_active_connection(self) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py index a53cbeba6..d3917771d 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py @@ -7,7 +7,8 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel +from strands.experimental.bidirectional_streaming.types.audio_io import AudioIO from strands_tools import calculator @@ -16,12 +17,13 @@ async def main(): # Nova Sonic model - model = NovaSonicBidirectionalModel() + adapter = AudioIO() + model = NovaSonicModel(region="us-east-1") async with BidirectionalAgent(model=model, tools=[calculator]) as agent: print("New BidirectionalAgent Experience") print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") - await agent.connect() + await agent.run(io_channels=[adapter]) if __name__ == "__main__": diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index d040ee436..62e293ef1 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,5 +1,7 @@ """Type definitions for bidirectional streaming.""" +from .audio_io import AudioIO +from .bidirectional_io import BidirectionalIO from .bidirectional_streaming import ( DEFAULT_CHANNELS, DEFAULT_SAMPLE_RATE, @@ -20,6 +22,8 @@ ) __all__ = [ + "AudioIO", + "BidirectionalIO", "AudioInputEvent", "AudioOutputEvent", "BidirectionalConnectionEndEvent", diff --git a/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py b/src/strands/experimental/bidirectional_streaming/types/audio_io.py similarity index 53% rename from src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py rename to src/strands/experimental/bidirectional_streaming/types/audio_io.py index b093ae0dd..58cd1b3ab 100644 --- a/src/strands/experimental/bidirectional_streaming/adapters/audio_adapter.py +++ b/src/strands/experimental/bidirectional_streaming/types/audio_io.py @@ -1,6 +1,6 @@ -"""AudioAdapter - Clean separation of audio functionality from core BidirectionalAgent. +"""AudioIO - Clean separation of audio functionality from core BidirectionalAgent. -Provides audio input/output capabilities for BidirectionalAgent through the adapter pattern. +Provides audio input/output capabilities for BidirectionalAgent through the BidirectionalIO protocol. Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. """ @@ -9,6 +9,8 @@ import logging from typing import Any, Callable, Optional +from .bidirectional_io import BidirectionalIO + try: import pyaudio except ImportError: @@ -17,14 +19,14 @@ logger = logging.getLogger(__name__) -class AudioAdapter: - """Audio adapter for BidirectionalAgent with direct stream processing.""" +class AudioIO(BidirectionalIO): + """Audio IO channel for BidirectionalAgent with direct stream processing.""" def __init__( self, audio_config: Optional[dict] = None, ): - """Initialize AudioAdapter with clean audio configuration. + """Initialize AudioIO with clean audio configuration. Args: audio_config: Dictionary containing audio configuration: @@ -37,7 +39,7 @@ def __init__( - output_channels (int): Output channels (default: 1) """ if pyaudio is None: - raise ImportError("PyAudio is required for AudioAdapter. Install with: pip install pyaudio") + raise ImportError("PyAudio is required for AudioIO. Install with: pip install pyaudio") # Default audio configuration default_config = { @@ -102,7 +104,7 @@ def _setup_audio(self) -> None: self.output_stream.start_stream() except Exception as e: - logger.error(f"AudioAdapter: Audio setup failed: {e}") + logger.error(f"AudioIO: Audio setup failed: {e}") self._cleanup_audio() raise @@ -129,85 +131,79 @@ def _cleanup_audio(self) -> None: except Exception as e: logger.warning(f"Audio cleanup error: {e}") - def create_input(self) -> Callable[[], dict]: - """Create audio input function for agent.run().""" - - async def audio_receiver() -> dict: - """Read audio from microphone.""" - if not self.input_stream: - self._setup_audio() - - try: - audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) - return { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": self.input_sample_rate, - "channels": self.input_channels, - } - except Exception as e: - logger.warning(f"Audio input error: {e}") - return { - "audioData": b"", - "format": "pcm", - "sampleRate": self.input_sample_rate, - "channels": self.input_channels, - } - - return audio_receiver - - def create_output(self) -> Callable[[dict], None]: - """Create audio output function with direct stream writing.""" - - async def audio_sender(event: dict) -> None: - """Handle audio events with direct stream writing.""" - if not self.output_stream: - self._setup_audio() - - # Handle audio output - if "audioOutput" in event and not self.interrupted: - audio_data = event["audioOutput"]["audioData"] - - # Handle both base64 and raw bytes - if isinstance(audio_data, str): - audio_data = base64.b64decode(audio_data) - - if audio_data: - chunk_size = 2048 - for i in range(0, len(audio_data), chunk_size): - # Check for interruption before each chunk - if self.interrupted: - break - - chunk = audio_data[i : i + chunk_size] - try: - self.output_stream.write(chunk, exception_on_underflow=False) - await asyncio.sleep(0) - except Exception as e: - logger.warning(f"Audio playback error: {e}") - break - - elif "interruptionDetected" in event or "interrupted" in event: - self.interrupted = True - logger.debug("Interruption detected") - - # Stop and restart stream for immediate interruption - if self.output_stream: + async def input_channel(self) -> dict: + """Read audio from microphone.""" + if not self.input_stream: + self._setup_audio() + + try: + audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) + return { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": self.input_sample_rate, + "channels": self.input_channels, + } + except Exception as e: + logger.warning(f"Audio input error: {e}") + return { + "audioData": b"", + "format": "pcm", + "sampleRate": self.input_sample_rate, + "channels": self.input_channels, + } + + async def output_channel(self, event: dict) -> None: + """Handle audio events with direct stream writing.""" + if not self.output_stream: + self._setup_audio() + + # Handle audio output + if "audioOutput" in event and not self.interrupted: + audio_data = event["audioOutput"]["audioData"] + + # Handle both base64 and raw bytes + if isinstance(audio_data, str): + audio_data = base64.b64decode(audio_data) + + if audio_data: + chunk_size = 2048 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if self.interrupted: + break + + chunk = audio_data[i : i + chunk_size] try: - self.output_stream.stop_stream() - self.output_stream.start_stream() + self.output_stream.write(chunk, exception_on_underflow=False) + await asyncio.sleep(0) except Exception as e: - logger.debug(f"Error clearing audio buffer: {e}") + logger.warning(f"Audio playback error: {e}") + break - self.interrupted = False + elif "interruptionDetected" in event or "interrupted" in event: + self.interrupted = True + logger.debug("Interruption detected") - elif "textOutput" in event: - text = event["textOutput"].get("text", "").strip() - role = event["textOutput"].get("role", "") - if text: - if role.upper() == "ASSISTANT": - print(f"🤖 {text}") - elif role.upper() == "USER": - print(f"User: {text}") - - return audio_sender + # Stop and restart stream for immediate interruption + if self.output_stream: + try: + self.output_stream.stop_stream() + self.output_stream.start_stream() + except Exception as e: + logger.debug(f"Error clearing audio buffer: {e}") + + self.interrupted = False + + elif "textOutput" in event: + text = event["textOutput"].get("text", "").strip() + role = event["textOutput"].get("role", "") + if text: + if role.upper() == "ASSISTANT": + print(f"🤖 {text}") + elif role.upper() == "USER": + print(f"User: {text}") + + def cleanup(self) -> None: + """Clean up IO channel resources.""" + self._cleanup_audio() diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_io.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_io.py new file mode 100644 index 000000000..d786e6d45 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_io.py @@ -0,0 +1,41 @@ +"""BidirectionalIO protocol for bidirectional streaming IO channels. + +Defines the standard interface that all bidirectional IO channels must implement +for integration with BidirectionalAgent. This protocol enables clean +separation between the agent's core logic and hardware-specific implementations. +""" + +from typing import Protocol + + +class BidirectionalIO(Protocol): + """Base protocol for bidirectional IO channels. + + Defines the interface that IO channels must implement to work with + BidirectionalAgent. IO channels handle hardware abstraction (audio, video, + WebSocket, etc.) while the agent handles model communication and logic. + """ + + async def input_channel(self) -> dict: + """Read input data from the IO channel source. + + Returns: + dict: Input event data to send to the model. + """ + ... + + async def output_channel(self, event: dict) -> None: + """Process output event from the model through the IO channel. + + Args: + event: Output event from the model to handle. + """ + ... + + def cleanup(self) -> None: + """Clean up IO channel resources. + + Called by the agent during shutdown to ensure proper + resource cleanup (streams, connections, etc.). + """ + ... \ No newline at end of file From a9784f043c4dae5916c26eaa89bdc5e803cdfafe Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Sun, 9 Nov 2025 19:02:33 +0300 Subject: [PATCH 068/242] temp commit message, review the changes --- .../models/gemini_live.py | 57 ++++++++++++++----- .../tests/test_gemini_live.py | 19 ++----- .../models/test_gemini_live.py | 56 ++++++++++++++++-- .../test_bidirectional_agent.py | 27 +++------ 4 files changed, 108 insertions(+), 51 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 4337a6cfa..465e8c9eb 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -58,7 +58,7 @@ class GeminiLiveModel(BidirectionalModel): def __init__( self, - model_id: str = "models/gemini-2.0-flash-live-preview-04-09", + model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", api_key: Optional[str] = None, live_config: Optional[Dict[str, Any]] = None, **kwargs @@ -74,7 +74,19 @@ def __init__( # Model configuration self.model_id = model_id self.api_key = api_key - self.live_config = live_config or {} + + # Set default live_config with transcription enabled + default_config = { + "response_modalities": ["AUDIO"], + "outputAudioTranscription": {}, # Enable output transcription by default + "inputAudioTranscription": {} # Enable input transcription by default + } + + # Merge user config with defaults (user config takes precedence) + if live_config: + default_config.update(live_config) + + self.live_config = default_config # Create Gemini client with proper API version client_kwargs = {} @@ -242,18 +254,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic current_transcript=transcription_text ) - # Handle text output from model - if message.text: - logger.debug(f"Text output as transcript: {message.text}") - return TranscriptStreamEvent( - delta={"text": message.text}, - text=message.text, - role="assistant", - is_final=True, - current_transcript=message.text - ) - # Handle audio output using SDK's built-in data property + # Check this BEFORE text to avoid triggering warning on mixed content if message.data: # Convert bytes to base64 string for JSON serializability audio_b64 = base64.b64encode(message.data).decode('utf-8') @@ -264,6 +266,32 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic channels=GEMINI_CHANNELS ) + # Handle text output from model_turn (avoids warning by checking parts directly) + if message.server_content and message.server_content.model_turn: + model_turn = message.server_content.model_turn + if model_turn.parts: + # Concatenate all text parts (Gemini may send multiple parts) + text_parts = [] + for part in model_turn.parts: + # Log all part types for debugging + part_attrs = {attr: getattr(part, attr, None) for attr in dir(part) if not attr.startswith('_')} + logger.debug(f"Model turn part attributes: {part_attrs}") + + # Check if part has text attribute and it's not empty + if hasattr(part, 'text') and part.text: + text_parts.append(part.text) + + if text_parts: + full_text = " ".join(text_parts) + logger.debug(f"Text output as transcript ({len(text_parts)} parts): {full_text}") + return TranscriptStreamEvent( + delta={"text": full_text}, + text=full_text, + role="assistant", + is_final=True, + current_transcript=full_text + ) + # Handle tool calls if message.tool_call and message.tool_call.function_calls: for func_call in message.tool_call.function_calls: @@ -326,7 +354,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.error("Error converting Gemini Live event: %s", e) logger.error("Message type: %s", type(message).__name__) logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) - return None + # Return ErrorEvent instead of None so caller can handle it + return ErrorEvent(error=e) async def send( self, diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 38791d9ed..e9d715b49 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -167,13 +167,13 @@ async def receive(agent, context): # Handle transcript events (bidirectional_transcript_stream) elif event_type == "bidirectional_transcript_stream": transcript_text = event.get("text", "") - transcript_source = event.get("source", "unknown") + transcript_role = event.get("role", "unknown") is_final = event.get("is_final", False) # Print transcripts with special formatting - if transcript_source == "user": + if transcript_role == "user": print(f"🎤 User: {transcript_text}") - elif transcript_source == "assistant": + elif transcript_role == "assistant": print(f"🔊 Assistant: {transcript_text}") # Handle turn complete events (bidirectional_turn_complete) @@ -313,17 +313,10 @@ async def main(duration=180): # Initialize Gemini Live model with proper configuration logger.info("Initializing Gemini Live model with API key") - model = GeminiLiveModel( - model_id="gemini-2.5-flash-native-audio-preview-09-2025", - api_key=api_key, - live_config={ - "response_modalities": ["AUDIO"], - "output_audio_transcription": {}, # Enable output transcription - "input_audio_transcription": {} # Enable input transcription - } - ) + # Use default model and config (includes transcription enabled by default) + model = GeminiLiveModel(api_key=api_key) logger.info("Gemini Live model initialized successfully") - print("Using Gemini Live model") + print("Using Gemini Live model with default config (audio output + transcription enabled)") agent = BidirectionalAgent( model=model, diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 107a8a84a..e890ddcbb 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -89,20 +89,28 @@ def test_model_initialization(mock_genai_client, model_id, api_key): # Test default config model_default = GeminiLiveModel() - assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09" + assert model_default.model_id == "gemini-2.5-flash-native-audio-preview-09-2025" assert model_default.api_key is None assert model_default._active is False assert model_default.live_session is None + # Check default config includes transcription + assert model_default.live_config["response_modalities"] == ["AUDIO"] + assert "outputAudioTranscription" in model_default.live_config + assert "inputAudioTranscription" in model_default.live_config # Test with API key model_with_key = GeminiLiveModel(model_id=model_id, api_key=api_key) assert model_with_key.model_id == model_id assert model_with_key.api_key == api_key - # Test with custom config + # Test with custom config (merges with defaults) live_config = {"temperature": 0.7, "top_p": 0.9} model_custom = GeminiLiveModel(model_id=model_id, live_config=live_config) - assert model_custom.live_config == live_config + # Custom config should be merged with defaults + assert model_custom.live_config["temperature"] == 0.7 + assert model_custom.live_config["top_p"] == 0.9 + # Defaults should still be present + assert "response_modalities" in model_custom.live_config # Connection Tests @@ -292,12 +300,24 @@ async def test_event_conversion(mock_genai_client, model): _, _, _ = mock_genai_client await model.connect() - # Test text output (converted to transcript) + # Test text output (converted to transcript via model_turn.parts) mock_text = unittest.mock.Mock() - mock_text.text = "Hello from Gemini" mock_text.data = None mock_text.tool_call = None - mock_text.server_content = None + + # Create proper server_content structure with model_turn + mock_server_content = unittest.mock.Mock() + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + + mock_model_turn = unittest.mock.Mock() + mock_part = unittest.mock.Mock() + mock_part.text = "Hello from Gemini" + mock_model_turn.parts = [mock_part] + mock_server_content.model_turn = mock_model_turn + + mock_text.server_content = mock_server_content text_event = model._convert_gemini_live_event(mock_text) assert isinstance(text_event, TranscriptStreamEvent) @@ -307,6 +327,30 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.delta == {"text": "Hello from Gemini"} assert text_event.current_transcript == "Hello from Gemini" + # Test multiple text parts (should concatenate) + mock_multi_text = unittest.mock.Mock() + mock_multi_text.data = None + mock_multi_text.tool_call = None + + mock_server_content_multi = unittest.mock.Mock() + mock_server_content_multi.interrupted = False + mock_server_content_multi.input_transcription = None + mock_server_content_multi.output_transcription = None + + mock_model_turn_multi = unittest.mock.Mock() + mock_part1 = unittest.mock.Mock() + mock_part1.text = "Hello" + mock_part2 = unittest.mock.Mock() + mock_part2.text = "from Gemini" + mock_model_turn_multi.parts = [mock_part1, mock_part2] + mock_server_content_multi.model_turn = mock_model_turn_multi + + mock_multi_text.server_content = mock_server_content_multi + + multi_text_event = model._convert_gemini_live_event(mock_multi_text) + assert isinstance(multi_text_event, TranscriptStreamEvent) + assert multi_text_event.text == "Hello from Gemini" # Concatenated with space + # Test audio output (base64 encoded) import base64 mock_audio = unittest.mock.Mock() diff --git a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py index 80b32b178..f23e6b84f 100644 --- a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py +++ b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py @@ -82,24 +82,15 @@ def calculator(operation: str, x: float, y: float) -> float: "env_vars": ["OPENAI_API_KEY"], "skip_reason": "OPENAI_API_KEY not available", }, - # NOTE: Gemini Live is temporarily disabled in parameterized tests - # Issue: Transcript events are not being properly emitted alongside audio events - # The model responds with audio but the test infrastructure expects text/transcripts - # TODO: Fix Gemini Live event emission to yield both transcript and audio events - # "gemini_live": { - # "model_class": GeminiLiveModel, - # "model_kwargs": { - # "model_id": "gemini-2.5-flash-native-audio-preview-09-2025", - # "params": { - # "response_modalities": ["AUDIO"], - # "output_audio_transcription": {}, - # "input_audio_transcription": {}, - # }, - # }, - # "silence_duration": 3.0, - # "env_vars": ["GOOGLE_AI_API_KEY"], - # "skip_reason": "GOOGLE_AI_API_KEY not available", - # }, + "gemini_live": { + "model_class": GeminiLiveModel, + "model_kwargs": { + # Uses default model and config (audio output + transcription enabled) + }, + "silence_duration": 1.5, # Gemini has good VAD, similar to OpenAI + "env_vars": ["GOOGLE_AI_API_KEY"], + "skip_reason": "GOOGLE_AI_API_KEY not available", + }, } From 8d9a29869d238628ba056542b918b834fa472d62 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 9 Nov 2025 16:55:18 -0500 Subject: [PATCH 069/242] Updates: make ToolCaller private, minor updates based on PR comments --- src/strands/agent/agent.py | 6 +- .../bidirectional_streaming/agent/agent.py | 61 ++++++++----------- 2 files changed, 29 insertions(+), 38 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 42cfb61e3..4273d737b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -51,7 +51,7 @@ from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer, serialize -from ..tools.caller import ToolCaller +from ..tools.caller import _ToolCaller from ..tools.executors import ConcurrentToolExecutor from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry @@ -240,7 +240,7 @@ def __init__( else: self.state = AgentState() - self.tool_caller = ToolCaller(self) + self.tool_caller = _ToolCaller(self) self.hooks = HookRegistry() @@ -259,7 +259,7 @@ def __init__( self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) @property - def tool(self) -> ToolCaller: + def tool(self) -> _ToolCaller: """Call tool as a function. Returns: diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index ca772bf0c..0d8925dbc 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -15,11 +15,11 @@ import asyncio import json import logging -from typing import Any, AsyncIterable, Mapping, Optional, Union, Callable +from typing import Any, AsyncIterable, Callable from .... import _identifier from ....telemetry.metrics import EventLoopMetrics -from ....tools.caller import ToolCaller +from ....tools.caller import _ToolCaller from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry @@ -51,16 +51,16 @@ class BidirectionalAgent: def __init__( self, - model: Union[BidirectionalModel, str, None] = None, - tools: list[str, AgentTool, ToolProvider] = None, - system_prompt: Optional[str] = None, - messages: Optional[Messages] = None, + model: BidirectionalModel| str | None = None, + tools: list[str| AgentTool| ToolProvider]| None = None, + system_prompt: str | None = None, + messages: Messages | None = None, record_direct_tool_call: bool = True, load_tools_from_directory: bool = False, - agent_id: Optional[str] = None, - name: Optional[str] = None, - tool_executor: Optional[ToolExecutor] = None, - description: Optional[str] = None, + agent_id: str | None = None, + name: str | None = None, + tool_executor: ToolExecutor | None = None, + description: str | None = None, **kwargs: Any, ): """Initialize bidirectional agent. @@ -118,15 +118,15 @@ def __init__( # Initialize other components self.event_loop_metrics = EventLoopMetrics() - self.tool_caller = ToolCaller(self) + self._tool_caller = _ToolCaller(self) # connection management - self._agentloop: Optional["BidirectionalAgentLoop"] = None + self._agent_loop: "BidirectionalAgentLoop" | None = None self._output_queue = asyncio.Queue() self._current_adapters = [] # Track adapters for cleanup @property - def tool(self) -> ToolCaller: + def tool(self) -> _ToolCaller: """Call tool as a function. Returns: @@ -138,7 +138,7 @@ def tool(self) -> ToolCaller: agent.tool.calculator(expression="2+2") ``` """ - return self.tool_caller + return self._tool_caller @property def tool_names(self) -> list[str]: @@ -154,7 +154,7 @@ def _record_tool_execution( self, tool: ToolUse, tool_result: ToolResult, - user_message_override: Optional[str], + user_message_override: str | None, ) -> None: """Record a tool execution in the message history. @@ -246,18 +246,18 @@ async def start(self) -> None: ValueError: If conversation already active. ConnectionError: If connection creation fails. """ - if self._agentloop and self._agentloop.active: + if self._agent_loop and self._agent_loop.active: raise ValueError("Conversation already active. Call end() first.") logger.debug("Conversation start - initializing connection") # Create model session and event loop directly - model_session = await self.model.connect( + await self.model.connect( system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages ) - self._agentloop = BidirectionalAgentLoop(model=self.model, agent=self) - await self._agentloop.start() + self._agent_loop = BidirectionalAgentLoop(model=self.model, agent=self) + await self._agent_loop.start() logger.debug("Conversation ready") @@ -285,19 +285,10 @@ async def send(self, input_data: BidirectionalInput) -> None: logger.debug("Text sent: %d characters", len(input_data)) # Create TextInputEvent for send() text_event = {"text": input_data, "role": "user"} - await self._agentloop.model.send(text_event) - elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input - await self._agentloop.model.send(input_data) - elif isinstance(input_data, dict) and "imageData" in input_data: - # Handle image input (ImageInputEvent) - await self._agentloop.model.send(input_data) + await self._agent_loop.model.send(text_event) else: - raise ValueError( - "Input must be either a string (text), AudioInputEvent " - "(dict with audioData, format, sampleRate, channels), or ImageInputEvent " - "(dict with imageData, mimeType, encoding)" - ) + # For audio, image, or any other input - let model handle it + await self._agent_loop.model.send(input_data) async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model including audio, text, and tool calls. @@ -321,9 +312,9 @@ async def end(self) -> None: Terminates the streaming connection, cancels background tasks, and closes the connection to the model provider. """ - if self._agentloop: - await self._agentloop.stop() - self._agentloop = None + if self._agent_loop: + await self._agent_loop.stop() + self._agent_loop = None async def __aenter__(self) -> "BidirectionalAgent": """Async context manager entry point. @@ -388,7 +379,7 @@ def active(self) -> bool: Returns: True if connection is active and ready for communication, False otherwise. """ - return self._agentloop is not None and self._agentloop.active + return self._agent_loop is not None and self._agent_loop.active async def run(self, io_channels: list[BidirectionalIO | tuple[Callable, Callable]]) -> None: """Run the agent using provided IO channels or transport tuples for bidirectional communication. From 73416d7e71fef30646b2f636ad189d0c2d531a36 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 9 Nov 2025 17:07:15 -0500 Subject: [PATCH 070/242] Update: file names, locations, and ToolCaller class name --- .../bidirectional_streaming/agent/agent.py | 12 ++++++------ .../bidirectional_streaming/io/__init__.py | 5 +++++ .../{types/audio_io.py => io/audio.py} | 15 +++++---------- .../bidirectional_streaming/tests/test_bidi.py | 2 +- .../types/{bidirectional_io.py => io.py} | 4 ++-- src/strands/tools/caller.py | 2 +- 6 files changed, 20 insertions(+), 20 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/io/__init__.py rename src/strands/experimental/bidirectional_streaming/{types/audio_io.py => io/audio.py} (96%) rename src/strands/experimental/bidirectional_streaming/types/{bidirectional_io.py => io.py} (92%) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 0d8925dbc..8bc372a2a 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,7 +31,7 @@ from ..models.bidirectional_model import BidirectionalModel from ..models.novasonic import NovaSonicModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent -from ..types import BidirectionalIO +from ..types import BidiIO from ....experimental.tools import ToolProvider logger = logging.getLogger(__name__) @@ -381,12 +381,12 @@ def active(self) -> bool: """ return self._agent_loop is not None and self._agent_loop.active - async def run(self, io_channels: list[BidirectionalIO | tuple[Callable, Callable]]) -> None: + async def run(self, io_channels: list[BidiIO | tuple[Callable, Callable]]) -> None: """Run the agent using provided IO channels or transport tuples for bidirectional communication. Args: - io_channels: List containing either BidirectionalIO instances or (sender, receiver) tuples. - - BidirectionalIO: IO channel instance with input_channel(), output_channel(), and cleanup() methods + io_channels: List containing either BidiIO instances or (sender, receiver) tuples. + - BidiIO: IO channel instance with input_channel(), output_channel(), and cleanup() methods - tuple: (sender_callable, receiver_callable) for custom transport Example: @@ -415,7 +415,7 @@ async def run(self, io_channels: list[BidirectionalIO | tuple[Callable, Callable elif isinstance(transport, tuple) and len(transport) == 2: self._current_adapters = [] # Tuple needs no cleanup else: - raise ValueError("io_channels list must contain either BidirectionalIO instances or (sender, receiver) tuples.") + raise ValueError("io_channels list must contain either BidiIO instances or (sender, receiver) tuples.") # Auto-manage session lifecycle if self.active: @@ -426,7 +426,7 @@ async def run(self, io_channels: list[BidirectionalIO | tuple[Callable, Callable async def _run_with_transport( self, - transport: BidirectionalIO | tuple[Callable, Callable], + transport: BidiIO | tuple[Callable, Callable], ) -> None: """Internal method to run send/receive loops with an active connection.""" diff --git a/src/strands/experimental/bidirectional_streaming/io/__init__.py b/src/strands/experimental/bidirectional_streaming/io/__init__.py new file mode 100644 index 000000000..0bf186777 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/io/__init__.py @@ -0,0 +1,5 @@ +"""IO channel implementations for bidirectional streaming.""" + +from .audio import AudioIO + +__all__ = ["AudioIO"] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/audio_io.py b/src/strands/experimental/bidirectional_streaming/io/audio.py similarity index 96% rename from src/strands/experimental/bidirectional_streaming/types/audio_io.py rename to src/strands/experimental/bidirectional_streaming/io/audio.py index 58cd1b3ab..4194b293c 100644 --- a/src/strands/experimental/bidirectional_streaming/types/audio_io.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -1,30 +1,25 @@ """AudioIO - Clean separation of audio functionality from core BidirectionalAgent. -Provides audio input/output capabilities for BidirectionalAgent through the BidirectionalIO protocol. +Provides audio input/output capabilities for BidirectionalAgent through the BidiIO protocol. Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. """ import asyncio import base64 import logging -from typing import Any, Callable, Optional +import pyaudio -from .bidirectional_io import BidirectionalIO - -try: - import pyaudio -except ImportError: - pyaudio = None +from ..types.io import BidiIO logger = logging.getLogger(__name__) -class AudioIO(BidirectionalIO): +class AudioIO(BidiIO): """Audio IO channel for BidirectionalAgent with direct stream processing.""" def __init__( self, - audio_config: Optional[dict] = None, + audio_config: dict | None = None, ): """Initialize AudioIO with clean audio configuration. diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py index d3917771d..57ce8b986 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py @@ -8,7 +8,7 @@ from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel -from strands.experimental.bidirectional_streaming.types.audio_io import AudioIO +from strands.experimental.bidirectional_streaming.io.audio import AudioIO from strands_tools import calculator diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_io.py b/src/strands/experimental/bidirectional_streaming/types/io.py similarity index 92% rename from src/strands/experimental/bidirectional_streaming/types/bidirectional_io.py rename to src/strands/experimental/bidirectional_streaming/types/io.py index d786e6d45..99639619f 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_io.py +++ b/src/strands/experimental/bidirectional_streaming/types/io.py @@ -1,4 +1,4 @@ -"""BidirectionalIO protocol for bidirectional streaming IO channels. +"""BidiIO protocol for bidirectional streaming IO channels. Defines the standard interface that all bidirectional IO channels must implement for integration with BidirectionalAgent. This protocol enables clean @@ -8,7 +8,7 @@ from typing import Protocol -class BidirectionalIO(Protocol): +class BidiIO(Protocol): """Base protocol for bidirectional IO channels. Defines the interface that IO channels must implement to work with diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py index 167789801..9fe213fec 100644 --- a/src/strands/tools/caller.py +++ b/src/strands/tools/caller.py @@ -9,7 +9,7 @@ from ..types.tools import ToolResult, ToolUse -class ToolCaller: +class _ToolCaller: """Provides common tool calling functionality that can be used by both traditional Agent and BidirectionalAgent classes with agent-specific customizations. From a49273b2fb3914166cd5269925f86e826b649710 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 9 Nov 2025 17:23:18 -0500 Subject: [PATCH 071/242] Update method names imports for io.py and audio.py and their dependencies --- .../bidirectional_streaming/__init__.py | 7 ++- .../bidirectional_streaming/agent/agent.py | 14 ++--- .../bidirectional_streaming/io/audio.py | 56 +++++++++---------- .../bidirectional_streaming/types/__init__.py | 6 +- .../bidirectional_streaming/types/io.py | 11 +++- 5 files changed, 48 insertions(+), 46 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index caee4715a..0955a8939 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -3,6 +3,9 @@ # Main components - Primary user interface from .agent.agent import BidirectionalAgent +# IO channels - Hardware abstraction +from .io.audio import AudioIO + # Model interface (for custom implementations) from .models.bidirectional_model import BidirectionalModel @@ -27,7 +30,8 @@ __all__ = [ # Main interface "BidirectionalAgent", - + # IO channels + "AudioIO", # Model providers "GeminiLiveModel", "NovaSonicModel", @@ -43,7 +47,6 @@ "BidirectionalStreamEvent", "VoiceActivityEvent", "UsageMetricsEvent", - # Model interface "BidirectionalModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 8bc372a2a..aae029ab1 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -350,7 +350,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: for adapter in self._current_adapters: if hasattr(adapter, "cleanup"): try: - adapter.cleanup() + adapter.end() logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") except Exception as adapter_error: logger.warning(f"Error cleaning up adapter: {adapter_error}") @@ -386,7 +386,7 @@ async def run(self, io_channels: list[BidiIO | tuple[Callable, Callable]]) -> No Args: io_channels: List containing either BidiIO instances or (sender, receiver) tuples. - - BidiIO: IO channel instance with input_channel(), output_channel(), and cleanup() methods + - BidiIO: IO channel instance with send(), receive(), and end() methods - tuple: (sender_callable, receiver_callable) for custom transport Example: @@ -410,7 +410,7 @@ async def run(self, io_channels: list[BidiIO | tuple[Callable, Callable]]) -> No transport = io_channels[0] # Set IO channel tracking for cleanup - if hasattr(transport, 'input_channel') and hasattr(transport, 'output_channel'): + if hasattr(transport, 'send') and hasattr(transport, 'receive'): self._current_adapters = [transport] # IO channel needs cleanup elif isinstance(transport, tuple) and len(transport) == 2: self._current_adapters = [] # Tuple needs no cleanup @@ -433,16 +433,16 @@ async def _run_with_transport( async def receive_from_agent(): """Receive events from agent and send to transport.""" async for event in self.receive(): - if hasattr(transport, 'output_channel'): - await transport.output_channel(event) + if hasattr(transport, 'receive'): + await transport.receive(event) else: await transport[0](event) async def send_to_agent(): """Receive events from transport and send to agent.""" while self.active: - if hasattr(transport, 'input_channel'): - event = await transport.input_channel() + if hasattr(transport, 'send'): + event = await transport.send() else: event = await transport[1]() await self.send(event) diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py index 4194b293c..4fb60a2b5 100644 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -66,7 +66,7 @@ def __init__( self.output_stream = None self.interrupted = False - def _setup_audio(self) -> None: + def start(self) -> None: """Setup PyAudio streams for input and output.""" if self.audio: return @@ -103,33 +103,10 @@ def _setup_audio(self) -> None: self._cleanup_audio() raise - def _cleanup_audio(self) -> None: - """Clean up PyAudio resources.""" - try: - if self.input_stream: - if self.input_stream.is_active(): - self.input_stream.stop_stream() - self.input_stream.close() - - if self.output_stream: - if self.output_stream.is_active(): - self.output_stream.stop_stream() - self.output_stream.close() - - if self.audio: - self.audio.terminate() - - self.input_stream = None - self.output_stream = None - self.audio = None - - except Exception as e: - logger.warning(f"Audio cleanup error: {e}") - - async def input_channel(self) -> dict: + async def send(self) -> dict: """Read audio from microphone.""" if not self.input_stream: - self._setup_audio() + self.start() try: audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) @@ -148,10 +125,10 @@ async def input_channel(self) -> dict: "channels": self.input_channels, } - async def output_channel(self, event: dict) -> None: + async def receive(self, event: dict) -> None: """Handle audio events with direct stream writing.""" if not self.output_stream: - self._setup_audio() + self.start() # Handle audio output if "audioOutput" in event and not self.interrupted: @@ -199,6 +176,25 @@ async def output_channel(self, event: dict) -> None: elif role.upper() == "USER": print(f"User: {text}") - def cleanup(self) -> None: + def end(self) -> None: """Clean up IO channel resources.""" - self._cleanup_audio() + try: + if self.input_stream: + if self.input_stream.is_active(): + self.input_stream.stop_stream() + self.input_stream.close() + + if self.output_stream: + if self.output_stream.is_active(): + self.output_stream.stop_stream() + self.output_stream.close() + + if self.audio: + self.audio.terminate() + + self.input_stream = None + self.output_stream = None + self.audio = None + + except Exception as e: + logger.warning(f"Audio cleanup error: {e}") \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 62e293ef1..5879c0505 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,7 +1,6 @@ """Type definitions for bidirectional streaming.""" -from .audio_io import AudioIO -from .bidirectional_io import BidirectionalIO +from .io import BidiIO from .bidirectional_streaming import ( DEFAULT_CHANNELS, DEFAULT_SAMPLE_RATE, @@ -22,8 +21,7 @@ ) __all__ = [ - "AudioIO", - "BidirectionalIO", + "BidiIO", "AudioInputEvent", "AudioOutputEvent", "BidirectionalConnectionEndEvent", diff --git a/src/strands/experimental/bidirectional_streaming/types/io.py b/src/strands/experimental/bidirectional_streaming/types/io.py index 99639619f..98b9b28bd 100644 --- a/src/strands/experimental/bidirectional_streaming/types/io.py +++ b/src/strands/experimental/bidirectional_streaming/types/io.py @@ -16,7 +16,12 @@ class BidiIO(Protocol): WebSocket, etc.) while the agent handles model communication and logic. """ - async def input_channel(self) -> dict: + async def start(self) -> dict: + + """Setup IO channels for input and output.""" + ... + + async def send(self) -> dict: """Read input data from the IO channel source. Returns: @@ -24,7 +29,7 @@ async def input_channel(self) -> dict: """ ... - async def output_channel(self, event: dict) -> None: + async def receive(self, event: dict) -> None: """Process output event from the model through the IO channel. Args: @@ -32,7 +37,7 @@ async def output_channel(self, event: dict) -> None: """ ... - def cleanup(self) -> None: + def end(self) -> None: """Clean up IO channel resources. Called by the agent during shutdown to ensure proper From 986fc45052f29edf1847187cb313d83a9fa4aa67 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 13:50:54 +0300 Subject: [PATCH 072/242] use input event in method signatures and update outdated comments --- .../bidirectional_streaming/agent/agent.py | 4 +- .../models/bidirectional_model.py | 3 +- .../models/gemini_live.py | 3 +- .../models/novasonic.py | 3 +- .../bidirectional_streaming/models/openai.py | 3 +- .../types/bidirectional_streaming.py | 44 +++++++++---------- 6 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index f0205f8a8..6b21ef86a 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -366,7 +366,7 @@ async def start(self) -> None: logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) -> None: + async def send(self, input_data: str | InputEvent | dict) -> None: """Send input to the model (text, audio, image, or event dict). Unified method for sending text, audio, and image input to the model during @@ -434,7 +434,7 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) f"Input must be a string, InputEvent (TextInputEvent/AudioInputEvent/ImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" ) - async def receive(self) -> AsyncIterable["OutputEvent"]: + async def receive(self) -> AsyncIterable[OutputEvent]: """Receive events from the model including audio, text, and tool calls. Yields model output events processed by background tasks including audio output, diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 5956494b0..0d2f6cb94 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -21,6 +21,7 @@ from ..types.bidirectional_streaming import ( AudioInputEvent, ImageInputEvent, + InputEvent, OutputEvent, TextInputEvent, ) @@ -83,7 +84,7 @@ async def receive(self) -> AsyncIterable[OutputEvent]: async def send( self, - content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + content: InputEvent | ToolResultEvent, ) -> None: """Send content to the model over the active connection. diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 4337a6cfa..8ee2e224f 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -31,6 +31,7 @@ ConnectionStartEvent, ErrorEvent, ImageInputEvent, + InputEvent, InterruptionEvent, UsageEvent, TextInputEvent, @@ -330,7 +331,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic async def send( self, - content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + content: InputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given inputs to Google Live API diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 944e45d4b..7ce2ce695 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -40,6 +40,7 @@ ConnectionStartEvent, ErrorEvent, ImageInputEvent, + InputEvent, InterruptionEvent, UsageEvent, OutputEvent, @@ -307,7 +308,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: async def send( self, - content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + content: InputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index d923605e5..0839d543c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -25,6 +25,7 @@ ConnectionStartEvent, ErrorEvent, ImageInputEvent, + InputEvent, InterruptionEvent, UsageEvent, OutputEvent, @@ -511,7 +512,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven async def send( self, - content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + content: InputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given content to OpenAI. diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 5069ccd5e..410a39032 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -6,7 +6,7 @@ Key features: - Audio input/output events with standardized formats - Interruption detection and handling -- Session lifecycle management +- Connection lifecycle management - Provider-agnostic event types - Type-safe discriminated unions with TypedEvent - JSON-serializable events (audio/images stored as base64 strings) @@ -34,7 +34,7 @@ # ============================================================================ -# Input Events (sent via session.send()) +# Input Events (sent via agent.send()) # ============================================================================ @@ -145,7 +145,7 @@ def mime_type(self) -> str: # ============================================================================ -# Output Events (received via session.receive_events()) +# Output Events (received via agent.receive()) # ============================================================================ @@ -237,14 +237,13 @@ def channels(self) -> int: class TranscriptStreamEvent(ModelStreamEvent): """Audio transcription streaming (user or assistant speech). - Follows the same delta + current state pattern as TextStreamEvent and ToolUseStreamEvent - from core Strands. Supports incremental transcript updates for providers like OpenAI - that send partial transcripts before the final version. + Supports incremental transcript updates for providers that send partial + transcripts before the final version. Parameters: delta: The incremental transcript change (ContentBlockDelta). text: The delta text (same as delta content for convenience). - role: Who is speaking ("user" or "assistant"). Aligns with Message.role convention. + role: Who is speaking ("user" or "assistant"). is_final: Whether this is the final/complete transcript. current_transcript: The accumulated transcript text so far (None for first delta). """ @@ -504,18 +503,19 @@ def details(self) -> Optional[Dict[str, Any]]: # Type Unions # ============================================================================ -# Note: ToolResultEvent and ToolUseStreamEvent are reused from strands.types._events - -InputEvent = Union[TextInputEvent, AudioInputEvent, ImageInputEvent] - -OutputEvent = Union[ - ConnectionStartEvent, - ResponseStartEvent, - AudioStreamEvent, - TranscriptStreamEvent, - InterruptionEvent, - ResponseCompleteEvent, - UsageEvent, - ConnectionCloseEvent, - ErrorEvent, -] +# Note: ToolResultEvent is imported from strands.types._events and used alongside +# InputEvent in send() methods for sending tool results back to the model. + +InputEvent = TextInputEvent | AudioInputEvent | ImageInputEvent + +OutputEvent = ( + ConnectionStartEvent + | ResponseStartEvent + | AudioStreamEvent + | TranscriptStreamEvent + | InterruptionEvent + | ResponseCompleteEvent + | UsageEvent + | ConnectionCloseEvent + | ErrorEvent +) From 69965d23bee1b9c7fd7dea0994a6477426e73537 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 16:27:57 +0300 Subject: [PATCH 073/242] fix(openai): Improve interruption handling --- .../bidirectional_streaming/models/openai.py | 38 +++++++++----- .../models/test_openai_realtime.py | 49 ++++++++++++++++--- 2 files changed, 68 insertions(+), 19 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0839d543c..8ebe294a0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -378,17 +378,20 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven del self._function_call_buffer[call_id] return None - # Voice activity detection events - combine similar events using mapping - elif event_type in ["input_audio_buffer.speech_started", "input_audio_buffer.speech_stopped", - "input_audio_buffer.timeout_triggered"]: - # Map event types to activity types - activity_map = { - "input_audio_buffer.speech_started": "speech_started", - "input_audio_buffer.speech_stopped": "speech_stopped", - "input_audio_buffer.timeout_triggered": "timeout" - } - event = self._create_voice_activity_event(activity_map[event_type]) - return [event] if event else None + # Voice activity detection - speech_started triggers interruption + elif event_type == "input_audio_buffer.speech_started": + # This is the primary interruption signal - handle it first + return [InterruptionEvent(reason="user_speech")] + + # Response cancelled - handle interruption + elif event_type == "response.cancelled": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + logger.debug("OpenAI response cancelled: %s", response_id) + return [ResponseCompleteEvent( + response_id=response_id, + stop_reason="interrupted" + )] # Turn complete and usage - response finished elif event_type == "response.done": @@ -503,7 +506,18 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven return None elif event_type == "error": - logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) + error_data = openai_event.get("error", {}) + error_code = error_data.get("code", "") + + # Suppress expected errors that don't affect session state + if error_code == "response_cancel_not_active": + # This happens when trying to cancel a response that's not active + # It's safe to ignore as the session remains functional + logger.debug("OpenAI response cancel attempted when no response active (safe to ignore)") + return None + + # Log other errors + logger.error("OpenAI Realtime error: %s", error_data) return None else: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 2045424e1..dc165f536 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -18,9 +18,14 @@ from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( AudioInputEvent, + AudioStreamEvent, ImageInputEvent, + InterruptionEvent, + ResponseCompleteEvent, TextInputEvent, + TranscriptStreamEvent, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolResult @@ -222,8 +227,6 @@ async def async_connect(*args, **kwargs): @pytest.mark.asyncio async def test_send_all_content_types(mock_websockets_connect, model): """Test sending all content types through unified send() method.""" - from strands.types._events import ToolResultEvent - _, mock_ws = mock_websockets_connect await model.connect() @@ -343,7 +346,6 @@ async def test_event_conversion(mock_websockets_connect, model): await model.connect() # Test audio output (now returns list with AudioStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() @@ -357,7 +359,6 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("format") == "pcm" # Test text output (now returns list with TranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" @@ -407,7 +408,6 @@ async def test_event_conversion(mock_websockets_connect, model): assert tool_use["input"]["expression"] == "2+2" # Test voice activity (now returns list with InterruptionEvent for speech_started) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent speech_started = { "type": "input_audio_buffer.speech_started" } @@ -418,6 +418,43 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("type") == "bidirectional_interruption" assert converted[0].get("reason") == "user_speech" + # Test response.cancelled event (should return ResponseCompleteEvent with interrupted reason) + response_cancelled = { + "type": "response.cancelled", + "response": { + "id": "resp_123" + } + } + converted = model._convert_openai_event(response_cancelled) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], ResponseCompleteEvent) + assert converted[0].get("type") == "bidirectional_response_complete" + assert converted[0].get("response_id") == "resp_123" + assert converted[0].get("stop_reason") == "interrupted" + + # Test error handling - response_cancel_not_active should be suppressed + error_cancel_not_active = { + "type": "error", + "error": { + "code": "response_cancel_not_active", + "message": "No active response to cancel" + } + } + converted = model._convert_openai_event(error_cancel_not_active) + assert converted is None # Should be suppressed + + # Test error handling - other errors should be logged but return None + error_other = { + "type": "error", + "error": { + "code": "some_other_error", + "message": "Something went wrong" + } + } + converted = model._convert_openai_event(error_other) + assert converted is None + await model.close() @@ -465,7 +502,6 @@ def test_helper_methods(model): model._active = False # Test _create_text_event (now returns TranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = model._create_text_event("Hello", "user") assert isinstance(text_event, TranscriptStreamEvent) assert text_event.get("type") == "bidirectional_transcript_stream" @@ -476,7 +512,6 @@ def test_helper_methods(model): assert text_event.current_transcript == "Hello" # Test _create_voice_activity_event (now returns InterruptionEvent for speech_started) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent voice_event = model._create_voice_activity_event("speech_started") assert isinstance(voice_event, InterruptionEvent) assert voice_event.get("type") == "bidirectional_interruption" From f7c18d46ca03d77bb048d7d45ab48e7ad46a1a01 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 16:57:54 +0300 Subject: [PATCH 074/242] fix: improve gemini test script to display interrupts and remove excessive logging --- .../bidirectional_streaming/tests/test_gemini_live.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index e9d715b49..bf427812c 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -41,9 +41,9 @@ from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel # Configure logging - debug only for Gemini Live, info for everything else -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.WARN, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') -gemini_logger.setLevel(logging.DEBUG) +gemini_logger.setLevel(logging.WARN) logger = logging.getLogger(__name__) @@ -162,7 +162,7 @@ async def receive(agent, context): # Handle interruption events (bidirectional_interruption) elif event_type == "bidirectional_interruption": context["interrupted"] = True - logger.info("Interruption detected") + print("⚠️ Interruption detected") # Handle transcript events (bidirectional_transcript_stream) elif event_type == "bidirectional_transcript_stream": From 843133b384a15ab2b3c65d17b8ebafd13b2df596 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 10 Nov 2025 11:56:34 -0500 Subject: [PATCH 075/242] Move test scripts into dedicated directory so tests directory only has unit tests and integ tests --- .../bidirectional_streaming/{tests => scripts}/test_bidi.py | 0 .../{tests => scripts}/test_bidi_novasonic.py | 0 .../{tests => scripts}/test_bidi_openai.py | 0 .../{tests => scripts}/test_gemini_live.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename src/strands/experimental/bidirectional_streaming/{tests => scripts}/test_bidi.py (100%) rename src/strands/experimental/bidirectional_streaming/{tests => scripts}/test_bidi_novasonic.py (100%) rename src/strands/experimental/bidirectional_streaming/{tests => scripts}/test_bidi_openai.py (100%) rename src/strands/experimental/bidirectional_streaming/{tests => scripts}/test_gemini_live.py (100%) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidi.py rename to src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py rename to src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py rename to src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py rename to src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py From f8ab2a0d685345e5ae0089c211b19aedb51abf04 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 20:26:03 +0300 Subject: [PATCH 076/242] refactor: rename events and files --- .../bidirectional_streaming/__init__.py | 50 ++++++++-------- .../bidirectional_streaming/agent/agent.py | 28 ++++----- .../models/bidirectional_model.py | 20 +++---- .../models/gemini_live.py | 56 ++++++++--------- .../models/novasonic.py | 58 +++++++++--------- .../bidirectional_streaming/models/openai.py | 60 +++++++++---------- .../tests/test_bidi_novasonic.py | 4 +- .../tests/test_bidi_openai.py | 4 +- .../tests/test_gemini_live.py | 8 +-- .../bidirectional_streaming/types/__init__.py | 50 ++++++++-------- .../{bidirectional_streaming.py => events.py} | 46 +++++++------- .../models/test_gemini_live.py | 40 ++++++------- .../models/test_novasonic.py | 54 ++++++++--------- .../models/test_openai_realtime.py | 46 +++++++------- ...irectional_streaming.py => test_events.py} | 60 +++++++++---------- .../bidirectional_streaming/conftest.py | 2 +- .../{utils/test_context.py => context.py} | 7 +-- .../generators/__init__.py | 1 + .../audio.py} | 8 +-- .../test_bidirectional_agent.py | 2 +- .../bidirectional_streaming/utils/__init__.py | 1 - .../wrappers/__init__.py | 4 ++ 22 files changed, 305 insertions(+), 304 deletions(-) rename src/strands/experimental/bidirectional_streaming/types/{bidirectional_streaming.py => events.py} (94%) rename tests/strands/experimental/bidirectional_streaming/types/{test_bidirectional_streaming.py => test_events.py} (73%) rename tests_integ/bidirectional_streaming/{utils/test_context.py => context.py} (98%) create mode 100644 tests_integ/bidirectional_streaming/generators/__init__.py rename tests_integ/bidirectional_streaming/{utils/audio_generator.py => generators/audio.py} (95%) delete mode 100644 tests_integ/bidirectional_streaming/utils/__init__.py create mode 100644 tests_integ/bidirectional_streaming/wrappers/__init__.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 31b9ead32..7030a3864 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -12,22 +12,22 @@ from .models.openai import OpenAIRealtimeModel # Event types - For type hints and event handling -from .types.bidirectional_streaming import ( - AudioInputEvent, - AudioStreamEvent, - ConnectionCloseEvent, - ConnectionStartEvent, - ErrorEvent, - ImageInputEvent, +from .types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, InputEvent, - InterruptionEvent, + BidiInterruptionEvent, ModalityUsage, - UsageEvent, + BidiUsageEvent, OutputEvent, - ResponseCompleteEvent, - ResponseStartEvent, - TextInputEvent, - TranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, ) # Re-export standard agent events for tool handling @@ -47,22 +47,22 @@ "OpenAIRealtimeModel", # Input Event types - "TextInputEvent", - "AudioInputEvent", - "ImageInputEvent", + "BidiTextInputEvent", + "BidiAudioInputEvent", + "BidiImageInputEvent", "InputEvent", # Output Event types - "ConnectionStartEvent", - "ConnectionCloseEvent", - "ResponseStartEvent", - "ResponseCompleteEvent", - "AudioStreamEvent", - "TranscriptStreamEvent", - "InterruptionEvent", - "UsageEvent", + "BidiConnectionStartEvent", + "BidiConnectionCloseEvent", + "BidiResponseStartEvent", + "BidiResponseCompleteEvent", + "BidiAudioStreamEvent", + "BidiTranscriptStreamEvent", + "BidiInterruptionEvent", + "BidiUsageEvent", "ModalityUsage", - "ErrorEvent", + "BidiErrorEvent", "OutputEvent", # Tool Event types (reused from standard agent) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 6b21ef86a..58ba5a070 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,12 +31,12 @@ from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import ( - AudioInputEvent, - ImageInputEvent, +from ..types.events import ( + BidiAudioInputEvent, + BidiImageInputEvent, InputEvent, OutputEvent, - TextInputEvent, + BidiTextInputEvent, ) logger = logging.getLogger(__name__) @@ -376,8 +376,8 @@ async def send(self, input_data: str | InputEvent | dict) -> None: Args: input_data: Can be: - str: Text message from user - - AudioInputEvent: Audio data with format/sample rate - - ImageInputEvent: Image data with MIME type + - BidiAudioInputEvent: Audio data with format/sample rate + - BidiImageInputEvent: Image data with MIME type - dict: Event dictionary (will be reconstructed to TypedEvent) Raises: @@ -385,7 +385,7 @@ async def send(self, input_data: str | InputEvent | dict) -> None: Example: await agent.send("Hello") - await agent.send(AudioInputEvent(audio="base64...", format="pcm", ...)) + await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...)) await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"}) """ self._validate_active_session() @@ -395,13 +395,13 @@ async def send(self, input_data: str | InputEvent | dict) -> None: # Add user text message to history self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - text_event = TextInputEvent(text=input_data, role="user") + text_event = BidiTextInputEvent(text=input_data, role="user") await self._session.model.send(text_event) return - # Handle InputEvent instances (TextInputEvent, AudioInputEvent, ImageInputEvent) + # Handle InputEvent instances (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent) # Check this before dict since TypedEvent inherits from dict - if isinstance(input_data, (TextInputEvent, AudioInputEvent, ImageInputEvent)): + if isinstance(input_data, (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent)): await self._session.model.send(input_data) return @@ -409,16 +409,16 @@ async def send(self, input_data: str | InputEvent | dict) -> None: if isinstance(input_data, dict) and "type" in input_data: event_type = input_data["type"] if event_type == "bidirectional_text_input": - input_data = TextInputEvent(text=input_data["text"], role=input_data["role"]) + input_data = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) elif event_type == "bidirectional_audio_input": - input_data = AudioInputEvent( + input_data = BidiAudioInputEvent( audio=input_data["audio"], format=input_data["format"], sample_rate=input_data["sample_rate"], channels=input_data["channels"] ) elif event_type == "bidirectional_image_input": - input_data = ImageInputEvent( + input_data = BidiImageInputEvent( image=input_data["image"], mime_type=input_data["mime_type"] ) @@ -431,7 +431,7 @@ async def send(self, input_data: str | InputEvent | dict) -> None: # If we get here, input type is invalid raise ValueError( - f"Input must be a string, InputEvent (TextInputEvent/AudioInputEvent/ImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" + f"Input must be a string, InputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" ) async def receive(self) -> AsyncIterable[OutputEvent]: diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 0d2f6cb94..ad385019e 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -18,12 +18,12 @@ from ....types._events import ToolResultEvent from ....types.content import Messages from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import ( - AudioInputEvent, - ImageInputEvent, +from ..types.events import ( + BidiAudioInputEvent, + BidiImageInputEvent, InputEvent, OutputEvent, - TextInputEvent, + BidiTextInputEvent, ) logger = logging.getLogger(__name__) @@ -94,15 +94,15 @@ async def send( Args: content: The content to send. Must be one of: - - TextInputEvent: Text message from the user - - AudioInputEvent: Audio data for speech input - - ImageInputEvent: Image data for visual understanding + - BidiTextInputEvent: Text message from the user + - BidiAudioInputEvent: Audio data for speech input + - BidiImageInputEvent: Image data for visual understanding - ToolResultEvent: Result from a tool execution Example: - await model.send(TextInputEvent(text="Hello", role="user")) - await model.send(AudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) - await model.send(ImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) + await model.send(BidiTextInputEvent(text="Hello", role="user")) + await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) + await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) await model.send(ToolResultEvent(tool_result)) """ ... diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 8ee2e224f..e298bf74d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -24,20 +24,20 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ..types.bidirectional_streaming import ( - AudioInputEvent, - AudioStreamEvent, - ConnectionCloseEvent, - ConnectionStartEvent, - ErrorEvent, - ImageInputEvent, +from ..types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, InputEvent, - InterruptionEvent, - UsageEvent, - TextInputEvent, - TranscriptStreamEvent, - ResponseCompleteEvent, - ResponseStartEvent, + BidiInterruptionEvent, + BidiUsageEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -165,7 +165,7 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" # Emit connection start event - yield ConnectionStartEvent( + yield BidiConnectionStartEvent( connection_id=self.connection_id, model=self.model_id ) @@ -194,10 +194,10 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: except Exception as e: logger.error("Fatal error in receive loop: %s", e) - yield ErrorEvent(error=e) + yield BidiErrorEvent(error=e) finally: # Emit connection close event when exiting - yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: """Convert Gemini Live API events to provider-agnostic format. @@ -211,7 +211,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic try: # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: - return InterruptionEvent(reason="user_speech") + return BidiInterruptionEvent(reason="user_speech") # Handle input transcription (user's speech) - emit as transcript event if message.server_content and message.server_content.input_transcription: @@ -220,7 +220,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if hasattr(input_transcript, 'text') and input_transcript.text: transcription_text = input_transcript.text logger.debug(f"Input transcription detected: {transcription_text}") - return TranscriptStreamEvent( + return BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, role="user", @@ -235,7 +235,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if hasattr(output_transcript, 'text') and output_transcript.text: transcription_text = output_transcript.text logger.debug(f"Output transcription detected: {transcription_text}") - return TranscriptStreamEvent( + return BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, role="assistant", @@ -246,7 +246,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic # Handle text output from model if message.text: logger.debug(f"Text output as transcript: {message.text}") - return TranscriptStreamEvent( + return BidiTranscriptStreamEvent( delta={"text": message.text}, text=message.text, role="assistant", @@ -258,7 +258,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if message.data: # Convert bytes to base64 string for JSON serializability audio_b64 = base64.b64encode(message.data).decode('utf-8') - return AudioStreamEvent( + return BidiAudioStreamEvent( audio=audio_b64, format="pcm", sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, @@ -312,7 +312,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic "output_tokens": detail.token_count }) - return UsageEvent( + return BidiUsageEvent( input_tokens=usage.prompt_token_count or 0, output_tokens=usage.response_token_count or 0, total_tokens=usage.total_token_count or 0, @@ -338,17 +338,17 @@ async def send( Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). """ if not self._active: return try: - if isinstance(content, TextInputEvent): + if isinstance(content, BidiTextInputEvent): await self._send_text_content(content.text) - elif isinstance(content, AudioInputEvent): + elif isinstance(content, BidiAudioInputEvent): await self._send_audio_content(content) - elif isinstance(content, ImageInputEvent): + elif isinstance(content, BidiImageInputEvent): await self._send_image_content(content) elif isinstance(content, ToolResultEvent): tool_result = content.get("tool_result") @@ -360,7 +360,7 @@ async def send( logger.error(f"Error sending content: {e}") raise # Propagate exception for debugging in experimental code - async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: """Internal: Send audio content using Gemini Live API. Gemini Live expects continuous audio streaming via send_realtime_input. @@ -382,7 +382,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: except Exception as e: logger.error("Error sending audio content: %s", e) - async def _send_image_content(self, image_input: ImageInputEvent) -> None: + async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: """Internal: Send image content using Gemini Live API. Sends image frames following the same pattern as the GitHub example. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7ce2ce695..d5054721f 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -33,21 +33,21 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ..types.bidirectional_streaming import ( - AudioInputEvent, - AudioStreamEvent, - ConnectionCloseEvent, - ConnectionStartEvent, - ErrorEvent, - ImageInputEvent, +from ..types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, InputEvent, - InterruptionEvent, - UsageEvent, + BidiInterruptionEvent, + BidiUsageEvent, OutputEvent, - TextInputEvent, - TranscriptStreamEvent, - ResponseCompleteEvent, - ResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -278,7 +278,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.debug("Nova events - starting event stream") # Emit connection start event - yield ConnectionStartEvent( + yield BidiConnectionStartEvent( connection_id=self.connection_id, model=self.model_id ) @@ -301,10 +301,10 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: except Exception as e: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) - yield ErrorEvent(error=e) + yield BidiErrorEvent(error=e) finally: # Emit connection close event - yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") async def send( self, @@ -315,18 +315,18 @@ async def send( Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). """ if not self._active: return try: - if isinstance(content, TextInputEvent): + if isinstance(content, BidiTextInputEvent): await self._send_text_content(content.text) - elif isinstance(content, AudioInputEvent): + elif isinstance(content, BidiAudioInputEvent): await self._send_audio_content(content) - elif isinstance(content, ImageInputEvent): - # ImageInputEvent - not supported by Nova Sonic + elif isinstance(content, BidiImageInputEvent): + # BidiImageInputEvent - not supported by Nova Sonic logger.warning("Image input not supported by Nova Sonic") elif isinstance(content, ToolResultEvent): tool_result = content.get("tool_result") @@ -363,7 +363,7 @@ async def _start_audio_connection(self) -> None: await self._send_nova_event(audio_content_start) self.audio_connection_active = True - async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: """Internal: Send audio using Nova Sonic protocol-specific format.""" # Start audio connection if not already active if not self.audio_connection_active: @@ -528,7 +528,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: completion_id = completion_data.get("completionId", self._current_completion_id) stop_reason = completion_data.get("stopReason", "END_TURN") - event = ResponseCompleteEvent( + event = BidiResponseCompleteEvent( response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" ) @@ -541,7 +541,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: if "audioOutput" in nova_event: # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] - return AudioStreamEvent( + return BidiAudioStreamEvent( audio=audio_content, format="pcm", sample_rate=24000, @@ -557,9 +557,9 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: logger.debug("Nova interruption detected in text") - return InterruptionEvent(reason="user_speech") + return BidiInterruptionEvent(reason="user_speech") - return TranscriptStreamEvent( + return BidiTranscriptStreamEvent( delta={"text": text_content}, text=text_content, role="user" if role == "USER" else "assistant", @@ -584,7 +584,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": logger.debug("Nova interruption stop reason") - return InterruptionEvent(reason="user_speech") + return BidiInterruptionEvent(reason="user_speech") # Handle usage events - convert to multimodal usage format elif "usageEvent" in nova_event: @@ -592,7 +592,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: total_input = usage_data.get("totalInputTokens", 0) total_output = usage_data.get("totalOutputTokens", 0) - return UsageEvent( + return BidiUsageEvent( input_tokens=total_input, output_tokens=total_output, total_tokens=usage_data.get("totalTokens", total_input + total_output) @@ -607,7 +607,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: # Emit response start event using API-provided completionId # completionId should already be tracked from completionStart event - return ResponseStartEvent( + return BidiResponseStartEvent( response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing ) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0839d543c..3fe6f5f78 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -18,21 +18,21 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ..types.bidirectional_streaming import ( - AudioInputEvent, - AudioStreamEvent, - ConnectionCloseEvent, - ConnectionStartEvent, - ErrorEvent, - ImageInputEvent, +from ..types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, InputEvent, - InterruptionEvent, - UsageEvent, + BidiInterruptionEvent, + BidiUsageEvent, OutputEvent, - TextInputEvent, - TranscriptStreamEvent, - ResponseCompleteEvent, - ResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -173,9 +173,9 @@ def _require_active(self) -> bool: """Check if session is active.""" return self._active - def _create_text_event(self, text: str, role: str, is_final: bool = True) -> TranscriptStreamEvent: + def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: """Create standardized transcript event.""" - return TranscriptStreamEvent( + return BidiTranscriptStreamEvent( delta={"text": text}, text=text, role="user" if role == "user" else "assistant", @@ -183,11 +183,11 @@ def _create_text_event(self, text: str, role: str, is_final: bool = True) -> Tra current_transcript=text if is_final else None ) - def _create_voice_activity_event(self, activity_type: str) -> InterruptionEvent | None: + def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: """Create standardized interruption event for voice activity.""" # Only speech_started triggers interruption if activity_type == "speech_started": - return InterruptionEvent(reason="user_speech") + return BidiInterruptionEvent(reason="user_speech") # Other voice activity events are logged but don't create events return None @@ -283,7 +283,7 @@ async def _process_responses(self) -> None: async def receive(self) -> AsyncIterable[OutputEvent]: """Receive OpenAI events and convert to Strands TypedEvent format.""" # Emit connection start event - yield ConnectionStartEvent( + yield BidiConnectionStartEvent( connection_id=self.connection_id, model=self.model ) @@ -299,10 +299,10 @@ async def receive(self) -> AsyncIterable[OutputEvent]: except Exception as e: logger.error("Error receiving OpenAI Realtime event: %s", e) - yield ErrorEvent(error=e) + yield BidiErrorEvent(error=e) finally: # Emit connection close event - yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" @@ -312,12 +312,12 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven if event_type == "response.created": response = openai_event.get("response", {}) response_id = response.get("id", str(uuid.uuid4())) - return [ResponseStartEvent(response_id=response_id)] + return [BidiResponseStartEvent(response_id=response_id)] # Audio output elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI - return [AudioStreamEvent( + return [BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", sample_rate=24000, @@ -409,7 +409,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven events = [] # Always add response complete event - events.append(ResponseCompleteEvent( + events.append(BidiResponseCompleteEvent( response_id=response_id, stop_reason=stop_reason_map.get(status, "complete") )) @@ -455,7 +455,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven cached_tokens = input_details.get("cached_tokens", 0) # Add usage event - events.append(UsageEvent( + events.append(BidiUsageEvent( input_tokens=usage.get("input_tokens", 0), output_tokens=usage.get("output_tokens", 0), total_tokens=usage.get("total_tokens", 0), @@ -519,19 +519,19 @@ async def send( Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). """ if not self._require_active(): return try: # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first - if isinstance(content, TextInputEvent): + if isinstance(content, BidiTextInputEvent): await self._send_text_content(content.text) - elif isinstance(content, AudioInputEvent): + elif isinstance(content, BidiAudioInputEvent): await self._send_audio_content(content) - elif isinstance(content, ImageInputEvent): - # ImageInputEvent - not supported by OpenAI Realtime yet + elif isinstance(content, BidiImageInputEvent): + # BidiImageInputEvent - not supported by OpenAI Realtime yet logger.warning("Image input not supported by OpenAI Realtime API") elif isinstance(content, ToolResultEvent): tool_result = content.get("tool_result") @@ -543,7 +543,7 @@ async def send( logger.error(f"Error sending content: {e}") raise # Propagate exception for debugging in experimental code - async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" # Audio is already base64 encoded in the event await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index e5a2e7c46..d210796f3 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -172,10 +172,10 @@ async def send(agent, context): try: audio_bytes = context["audio_in"].get_nowait() # Create audio event using TypedEvent - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = AudioInputEvent( + audio_event = BidiAudioInputEvent( audio=audio_b64, format="pcm", sample_rate=16000, diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index d270637be..c0d38131a 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -179,10 +179,10 @@ async def send(agent, context): # Create audio event using TypedEvent # Encode audio bytes to base64 string for JSON serializability - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = AudioInputEvent( + audio_event = BidiAudioInputEvent( audio=audio_b64, format="pcm", sample_rate=24000, diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 38791d9ed..bec9dd98f 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -248,9 +248,9 @@ async def get_frames(context): # Send frame to agent as image input try: - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ImageInputEvent + from strands.experimental.bidirectional_streaming.types.events import BidiImageInputEvent - image_event = ImageInputEvent( + image_event = BidiImageInputEvent( image=frame["data"], # Already base64 encoded mime_type=frame["mime_type"] ) @@ -276,10 +276,10 @@ async def send(agent, context): try: audio_bytes = context["audio_in"].get_nowait() # Create audio event using TypedEvent - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = AudioInputEvent( + audio_event = BidiAudioInputEvent( audio=audio_b64, format="pcm", sample_rate=16000, diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 0a2abb68f..01e16d224 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,46 +1,46 @@ """Type definitions for bidirectional streaming.""" -from .bidirectional_streaming import ( +from .events import ( DEFAULT_CHANNELS, DEFAULT_FORMAT, DEFAULT_SAMPLE_RATE, SUPPORTED_AUDIO_FORMATS, SUPPORTED_CHANNELS, SUPPORTED_SAMPLE_RATES, - AudioInputEvent, - AudioStreamEvent, - ConnectionCloseEvent, - ConnectionStartEvent, - ErrorEvent, - ImageInputEvent, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, InputEvent, - InterruptionEvent, + BidiInterruptionEvent, ModalityUsage, - UsageEvent, + BidiUsageEvent, OutputEvent, - ResponseCompleteEvent, - ResponseStartEvent, - TextInputEvent, - TranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, ) __all__ = [ # Input Events - "TextInputEvent", - "AudioInputEvent", - "ImageInputEvent", + "BidiTextInputEvent", + "BidiAudioInputEvent", + "BidiImageInputEvent", "InputEvent", # Output Events - "ConnectionStartEvent", - "ConnectionCloseEvent", - "ResponseStartEvent", - "ResponseCompleteEvent", - "AudioStreamEvent", - "TranscriptStreamEvent", - "InterruptionEvent", - "UsageEvent", + "BidiConnectionStartEvent", + "BidiConnectionCloseEvent", + "BidiResponseStartEvent", + "BidiResponseCompleteEvent", + "BidiAudioStreamEvent", + "BidiTranscriptStreamEvent", + "BidiInterruptionEvent", + "BidiUsageEvent", "ModalityUsage", - "ErrorEvent", + "BidiErrorEvent", "OutputEvent", # Constants "SUPPORTED_AUDIO_FORMATS", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/events.py similarity index 94% rename from src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py rename to src/strands/experimental/bidirectional_streaming/types/events.py index 410a39032..5275d8e58 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/events.py @@ -38,7 +38,7 @@ # ============================================================================ -class TextInputEvent(TypedEvent): +class BidiTextInputEvent(TypedEvent): """Text input event for sending text to the model. Used for sending text content through the send() method. @@ -66,7 +66,7 @@ def role(self) -> str: return cast(str, self.get("role")) -class AudioInputEvent(TypedEvent): +class BidiAudioInputEvent(TypedEvent): """Audio input event for sending audio to the model. Used for sending audio data through the send() method. @@ -112,7 +112,7 @@ def channels(self) -> int: return cast(int, self.get("channels")) -class ImageInputEvent(TypedEvent): +class BidiImageInputEvent(TypedEvent): """Image input event for sending images/video frames to the model. Used for sending image data through the send() method. @@ -149,7 +149,7 @@ def mime_type(self) -> str: # ============================================================================ -class ConnectionStartEvent(TypedEvent): +class BidiConnectionStartEvent(TypedEvent): """Streaming connection established and ready for interaction. Parameters: @@ -175,7 +175,7 @@ def model(self) -> str: return cast(str, self.get("model")) -class ResponseStartEvent(TypedEvent): +class BidiResponseStartEvent(TypedEvent): """Model starts generating a response. Parameters: @@ -190,7 +190,7 @@ def response_id(self) -> str: return cast(str, self.get("response_id")) -class AudioStreamEvent(TypedEvent): +class BidiAudioStreamEvent(TypedEvent): """Streaming audio output from the model. Parameters: @@ -234,7 +234,7 @@ def channels(self) -> int: return cast(int, self.get("channels")) -class TranscriptStreamEvent(ModelStreamEvent): +class BidiTranscriptStreamEvent(ModelStreamEvent): """Audio transcription streaming (user or assistant speech). Supports incremental transcript updates for providers that send partial @@ -288,7 +288,7 @@ def current_transcript(self) -> Optional[str]: return cast(Optional[str], self.get("current_transcript")) -class InterruptionEvent(TypedEvent): +class BidiInterruptionEvent(TypedEvent): """Model generation was interrupted. Parameters: @@ -309,7 +309,7 @@ def reason(self) -> str: return cast(str, self.get("reason")) -class ResponseCompleteEvent(TypedEvent): +class BidiResponseCompleteEvent(TypedEvent): """Model finished generating response. Parameters: @@ -353,7 +353,7 @@ class ModalityUsage(dict): output_tokens: int -class UsageEvent(TypedEvent): +class BidiUsageEvent(TypedEvent): """Token usage event with modality breakdown for bidirectional streaming. Tracks token consumption across different modalities (audio, text, images) @@ -416,11 +416,11 @@ def cache_write_input_tokens(self) -> Optional[int]: return cast(Optional[int], self.get("cacheWriteInputTokens")) -class ConnectionCloseEvent(TypedEvent): +class BidiConnectionCloseEvent(TypedEvent): """Streaming connection closed. Parameters: - connection_id: Unique identifier for this streaming connection (matches ConnectionStartEvent). + connection_id: Unique identifier for this streaming connection (matches BidiConnectionStartEvent). reason: Why the connection was closed. """ @@ -446,7 +446,7 @@ def reason(self) -> str: return cast(str, self.get("reason")) -class ErrorEvent(TypedEvent): +class BidiErrorEvent(TypedEvent): """Error occurred during the session. Stores the full Exception object as an instance attribute for debugging while @@ -506,16 +506,16 @@ def details(self) -> Optional[Dict[str, Any]]: # Note: ToolResultEvent is imported from strands.types._events and used alongside # InputEvent in send() methods for sending tool results back to the model. -InputEvent = TextInputEvent | AudioInputEvent | ImageInputEvent +InputEvent = BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent OutputEvent = ( - ConnectionStartEvent - | ResponseStartEvent - | AudioStreamEvent - | TranscriptStreamEvent - | InterruptionEvent - | ResponseCompleteEvent - | UsageEvent - | ConnectionCloseEvent - | ErrorEvent + BidiConnectionStartEvent + | BidiResponseStartEvent + | BidiAudioStreamEvent + | BidiTranscriptStreamEvent + | BidiInterruptionEvent + | BidiResponseCompleteEvent + | BidiUsageEvent + | BidiConnectionCloseEvent + | BidiErrorEvent ) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 107a8a84a..b6280b2ee 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -14,10 +14,10 @@ from google.genai import types as genai_types from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel -from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - AudioInputEvent, - ImageInputEvent, - TextInputEvent, +from strands.experimental.bidirectional_streaming.types.events import ( + BidiAudioInputEvent, + BidiImageInputEvent, + BidiTextInputEvent, ) from strands.types._events import ToolResultEvent from strands.types.tools import ToolResult @@ -189,7 +189,7 @@ async def test_send_all_content_types(mock_genai_client, model): await model.connect() # Test text input - text_input = TextInputEvent(text="Hello", role="user") + text_input = BidiTextInputEvent(text="Hello", role="user") await model.send(text_input) mock_live_session.send_client_content.assert_called_once() call_args = mock_live_session.send_client_content.call_args @@ -200,7 +200,7 @@ async def test_send_all_content_types(mock_genai_client, model): # Test audio input (base64 encoded) import base64 audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') - audio_input = AudioInputEvent( + audio_input = BidiAudioInputEvent( audio=audio_b64, format="pcm", sample_rate=16000, @@ -211,7 +211,7 @@ async def test_send_all_content_types(mock_genai_client, model): # Test image input (base64 encoded, no encoding parameter) image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') - image_input = ImageInputEvent( + image_input = BidiImageInputEvent( image=image_b64, mime_type="image/jpeg", ) @@ -237,7 +237,7 @@ async def test_send_edge_cases(mock_genai_client, model): _, mock_live_session, _ = mock_genai_client # Test send when inactive - text_input = TextInputEvent(text="Hello", role="user") + text_input = BidiTextInputEvent(text="Hello", role="user") await model.send(text_input) mock_live_session.send_client_content.assert_not_called() @@ -255,9 +255,9 @@ async def test_send_edge_cases(mock_genai_client, model): @pytest.mark.asyncio async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): """Test that receive() emits connection start and end events.""" - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - ConnectionStartEvent, - ConnectionCloseEvent, + from strands.experimental.bidirectional_streaming.types.events import ( + BidiConnectionStartEvent, + BidiConnectionCloseEvent, ) _, mock_live_session, _ = mock_genai_client @@ -275,18 +275,18 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): # Verify connection start and end assert len(events) >= 2 - assert isinstance(events[0], ConnectionStartEvent) + assert isinstance(events[0], BidiConnectionStartEvent) assert events[0].connection_id == model.connection_id - assert isinstance(events[-1], ConnectionCloseEvent) + assert isinstance(events[-1], BidiConnectionCloseEvent) @pytest.mark.asyncio async def test_event_conversion(mock_genai_client, model): """Test conversion of all Gemini Live event types to standard format.""" - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - TranscriptStreamEvent, - AudioStreamEvent, - InterruptionEvent, + from strands.experimental.bidirectional_streaming.types.events import ( + BidiTranscriptStreamEvent, + BidiAudioStreamEvent, + BidiInterruptionEvent, ) _, _, _ = mock_genai_client @@ -300,7 +300,7 @@ async def test_event_conversion(mock_genai_client, model): mock_text.server_content = None text_event = model._convert_gemini_live_event(mock_text) - assert isinstance(text_event, TranscriptStreamEvent) + assert isinstance(text_event, BidiTranscriptStreamEvent) assert text_event.text == "Hello from Gemini" assert text_event.role == "assistant" assert text_event.is_final is True @@ -316,7 +316,7 @@ async def test_event_conversion(mock_genai_client, model): mock_audio.server_content = None audio_event = model._convert_gemini_live_event(mock_audio) - assert isinstance(audio_event, AudioStreamEvent) + assert isinstance(audio_event, BidiAudioStreamEvent) # Audio is now base64 encoded expected_b64 = base64.b64encode(b"audio_data").decode('utf-8') assert audio_event.audio == expected_b64 @@ -357,7 +357,7 @@ async def test_event_conversion(mock_genai_client, model): mock_interrupt.server_content = mock_server_content interrupt_event = model._convert_gemini_live_event(mock_interrupt) - assert isinstance(interrupt_event, InterruptionEvent) + assert isinstance(interrupt_event, BidiInterruptionEvent) assert interrupt_event.reason == "user_speech" await model.close() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 1a2fef426..d410d32f1 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -131,9 +131,9 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model @pytest.mark.asyncio async def test_send_all_content_types(nova_model, mock_client, mock_stream): """Test sending all content types through unified send() method.""" - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - TextInputEvent, - AudioInputEvent, + from strands.experimental.bidirectional_streaming.types.events import ( + BidiTextInputEvent, + BidiAudioInputEvent, ) from strands.types._events import ToolResultEvent @@ -143,14 +143,14 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): await nova_model.connect() # Test text content - text_event = TextInputEvent(text="Hello, Nova!", role="user") + text_event = BidiTextInputEvent(text="Hello, Nova!", role="user") await nova_model.send(text_event) # Should send contentStart, textInput, and contentEnd assert mock_stream.input_stream.send.call_count >= 3 # Test audio content (base64 encoded) audio_b64 = base64.b64encode(b"audio data").decode('utf-8') - audio_event = AudioInputEvent( + audio_event = BidiAudioInputEvent( audio=audio_b64, format="pcm", sample_rate=16000, @@ -177,23 +177,23 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): """Test send() edge cases and error handling.""" - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - TextInputEvent, - ImageInputEvent, + from strands.experimental.bidirectional_streaming.types.events import ( + BidiTextInputEvent, + BidiImageInputEvent, ) with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model.client = mock_client # Test send when inactive - text_event = TextInputEvent(text="Hello", role="user") + text_event = BidiTextInputEvent(text="Hello", role="user") await nova_model.send(text_event) # Should not raise # Test image content (not supported, base64 encoded, no encoding parameter) await nova_model.connect() import base64 image_b64 = base64.b64encode(b"image data").decode('utf-8') - image_event = ImageInputEvent( + image_event = BidiImageInputEvent( image=image_b64, mime_type="image/jpeg", ) @@ -236,26 +236,26 @@ async def mock_wait_for(*args, **kwargs): @pytest.mark.asyncio async def test_event_conversion(nova_model): """Test conversion of all Nova Sonic event types to standard format.""" - # Test audio output (now returns AudioStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent + # Test audio output (now returns BidiAudioStreamEvent) + from strands.experimental.bidirectional_streaming.types.events import BidiAudioStreamEvent audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, AudioStreamEvent) + assert isinstance(result, BidiAudioStreamEvent) assert result.get("type") == "bidirectional_audio_stream" # Audio is kept as base64 string assert result.get("audio") == audio_base64 assert result.get("format") == "pcm" assert result.get("sample_rate") == 24000 - # Test text output (now returns TranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent + # Test text output (now returns BidiTranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, TranscriptStreamEvent) + assert isinstance(result, BidiTranscriptStreamEvent) assert result.get("type") == "bidirectional_transcript_stream" assert result.get("text") == "Hello, world!" assert result.get("role") == "assistant" @@ -281,17 +281,17 @@ async def test_event_conversion(nova_model): assert tool_use["name"] == "get_weather" assert tool_use["input"] == tool_input - # Test interruption (now returns InterruptionEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent + # Test interruption (now returns BidiInterruptionEvent) + from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent nova_event = {"stopReason": "INTERRUPTED"} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, InterruptionEvent) + assert isinstance(result, BidiInterruptionEvent) assert result.get("type") == "bidirectional_interruption" assert result.get("reason") == "user_speech" - # Test usage metrics (now returns UsageEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import UsageEvent + # Test usage metrics (now returns BidiUsageEvent) + from strands.experimental.bidirectional_streaming.types.events import BidiUsageEvent nova_event = { "usageEvent": { "totalTokens": 100, @@ -308,18 +308,18 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, UsageEvent) + assert isinstance(result, BidiUsageEvent) assert result.get("type") == "bidirectional_usage" assert result.get("totalTokens") == 100 assert result.get("inputTokens") == 40 assert result.get("outputTokens") == 60 - # Test content start tracks role and emits ResponseStartEvent - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ResponseStartEvent + # Test content start tracks role and emits BidiResponseStartEvent + from strands.experimental.bidirectional_streaming.types.events import BidiResponseStartEvent nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, ResponseStartEvent) + assert isinstance(result, BidiResponseStartEvent) assert result.get("type") == "bidirectional_response_start" assert nova_model._current_role == "USER" @@ -349,7 +349,7 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model.client = mock_client @@ -360,7 +360,7 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): # Send audio to start connection (base64 encoded) import base64 audio_b64 = base64.b64encode(b"audio data").decode('utf-8') - audio_event = AudioInputEvent( + audio_event = BidiAudioInputEvent( audio=audio_b64, format="pcm", sample_rate=16000, diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 2045424e1..0e8349091 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -16,10 +16,10 @@ import pytest from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel -from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - AudioInputEvent, - ImageInputEvent, - TextInputEvent, +from strands.experimental.bidirectional_streaming.types.events import ( + BidiAudioInputEvent, + BidiImageInputEvent, + BidiTextInputEvent, ) from strands.types.tools import ToolResult @@ -228,7 +228,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): await model.connect() # Test text input - text_input = TextInputEvent(text="Hello", role="user") + text_input = BidiTextInputEvent(text="Hello", role="user") await model.send(text_input) calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] @@ -239,7 +239,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): # Test audio input (base64 encoded) audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') - audio_input = AudioInputEvent( + audio_input = BidiAudioInputEvent( audio=audio_b64, format="pcm", sample_rate=24000, @@ -278,14 +278,14 @@ async def test_send_edge_cases(mock_websockets_connect, model): _, mock_ws = mock_websockets_connect # Test send when inactive - text_input = TextInputEvent(text="Hello", role="user") + text_input = BidiTextInputEvent(text="Hello", role="user") await model.send(text_input) mock_ws.send.assert_not_called() # Test image input (not supported, base64 encoded, no encoding parameter) await model.connect() image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') - image_input = ImageInputEvent( + image_input = BidiImageInputEvent( image=image_b64, mime_type="image/jpeg", ) @@ -342,8 +342,8 @@ async def test_event_conversion(mock_websockets_connect, model): _, _ = mock_websockets_connect await model.connect() - # Test audio output (now returns list with AudioStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent + # Test audio output (now returns list with BidiAudioStreamEvent) + from strands.experimental.bidirectional_streaming.types.events import BidiAudioStreamEvent audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() @@ -351,13 +351,13 @@ async def test_event_conversion(mock_websockets_connect, model): converted = model._convert_openai_event(audio_event) assert isinstance(converted, list) assert len(converted) == 1 - assert isinstance(converted[0], AudioStreamEvent) + assert isinstance(converted[0], BidiAudioStreamEvent) assert converted[0].get("type") == "bidirectional_audio_stream" assert converted[0].get("audio") == base64.b64encode(b"audio_data").decode() assert converted[0].get("format") == "pcm" - # Test text output (now returns list with TranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent + # Test text output (now returns list with BidiTranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" @@ -365,7 +365,7 @@ async def test_event_conversion(mock_websockets_connect, model): converted = model._convert_openai_event(text_event) assert isinstance(converted, list) assert len(converted) == 1 - assert isinstance(converted[0], TranscriptStreamEvent) + assert isinstance(converted[0], BidiTranscriptStreamEvent) assert converted[0].get("type") == "bidirectional_transcript_stream" assert converted[0].get("text") == "Hello from OpenAI" assert converted[0].get("role") == "assistant" @@ -406,15 +406,15 @@ async def test_event_conversion(mock_websockets_connect, model): assert tool_use["name"] == "calculator" assert tool_use["input"]["expression"] == "2+2" - # Test voice activity (now returns list with InterruptionEvent for speech_started) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent + # Test voice activity (now returns list with BidiInterruptionEvent for speech_started) + from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent speech_started = { "type": "input_audio_buffer.speech_started" } converted = model._convert_openai_event(speech_started) assert isinstance(converted, list) assert len(converted) == 1 - assert isinstance(converted[0], InterruptionEvent) + assert isinstance(converted[0], BidiInterruptionEvent) assert converted[0].get("type") == "bidirectional_interruption" assert converted[0].get("reason") == "user_speech" @@ -464,10 +464,10 @@ def test_helper_methods(model): assert model._require_active() is True model._active = False - # Test _create_text_event (now returns TranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent + # Test _create_text_event (now returns BidiTranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent text_event = model._create_text_event("Hello", "user") - assert isinstance(text_event, TranscriptStreamEvent) + assert isinstance(text_event, BidiTranscriptStreamEvent) assert text_event.get("type") == "bidirectional_transcript_stream" assert text_event.get("text") == "Hello" assert text_event.get("role") == "user" @@ -475,10 +475,10 @@ def test_helper_methods(model): assert text_event.is_final is True assert text_event.current_transcript == "Hello" - # Test _create_voice_activity_event (now returns InterruptionEvent for speech_started) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent + # Test _create_voice_activity_event (now returns BidiInterruptionEvent for speech_started) + from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent voice_event = model._create_voice_activity_event("speech_started") - assert isinstance(voice_event, InterruptionEvent) + assert isinstance(voice_event, BidiInterruptionEvent) assert voice_event.get("type") == "bidirectional_interruption" assert voice_event.get("reason") == "user_speech" diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py b/tests/strands/experimental/bidirectional_streaming/types/test_events.py similarity index 73% rename from tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py rename to tests/strands/experimental/bidirectional_streaming/types/test_events.py index 45bcd2de4..fd7639cc4 100644 --- a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py +++ b/tests/strands/experimental/bidirectional_streaming/types/test_events.py @@ -8,19 +8,19 @@ import pytest -from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - AudioInputEvent, - AudioStreamEvent, - ConnectionCloseEvent, - ConnectionStartEvent, - ErrorEvent, - ImageInputEvent, - InterruptionEvent, - ResponseCompleteEvent, - ResponseStartEvent, - TextInputEvent, - TranscriptStreamEvent, - UsageEvent, +from strands.experimental.bidirectional_streaming.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, ) @@ -28,9 +28,9 @@ "event_class,kwargs,expected_type", [ # Input events - (TextInputEvent, {"text": "Hello", "role": "user"}, "bidirectional_text_input"), + (BidiTextInputEvent, {"text": "Hello", "role": "user"}, "bidirectional_text_input"), ( - AudioInputEvent, + BidiAudioInputEvent, { "audio": base64.b64encode(b"audio").decode("utf-8"), "format": "pcm", @@ -40,19 +40,19 @@ "bidirectional_audio_input", ), ( - ImageInputEvent, + BidiImageInputEvent, {"image": base64.b64encode(b"image").decode("utf-8"), "mime_type": "image/jpeg"}, "bidirectional_image_input", ), # Output events ( - ConnectionStartEvent, + BidiConnectionStartEvent, {"connection_id": "c1", "model": "m1"}, "bidirectional_connection_start", ), - (ResponseStartEvent, {"response_id": "r1"}, "bidirectional_response_start"), + (BidiResponseStartEvent, {"response_id": "r1"}, "bidirectional_response_start"), ( - AudioStreamEvent, + BidiAudioStreamEvent, { "audio": base64.b64encode(b"audio").decode("utf-8"), "format": "pcm", @@ -62,7 +62,7 @@ "bidirectional_audio_stream", ), ( - TranscriptStreamEvent, + BidiTranscriptStreamEvent, { "delta": {"text": "Hello"}, "text": "Hello", @@ -72,23 +72,23 @@ }, "bidirectional_transcript_stream", ), - (InterruptionEvent, {"reason": "user_speech"}, "bidirectional_interruption"), + (BidiInterruptionEvent, {"reason": "user_speech"}, "bidirectional_interruption"), ( - ResponseCompleteEvent, + BidiResponseCompleteEvent, {"response_id": "r1", "stop_reason": "complete"}, "bidirectional_response_complete", ), ( - UsageEvent, + BidiUsageEvent, {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, "bidirectional_usage", ), ( - ConnectionCloseEvent, + BidiConnectionCloseEvent, {"connection_id": "c1", "reason": "complete"}, "bidirectional_connection_close", ), - (ErrorEvent, {"error": ValueError("test"), "details": None}, "bidirectional_error"), + (BidiErrorEvent, {"error": ValueError("test"), "details": None}, "bidirectional_error"), ], ) def test_event_json_serialization(event_class, kwargs, expected_type): @@ -117,9 +117,9 @@ def test_event_json_serialization(event_class, kwargs, expected_type): def test_transcript_stream_event_delta_pattern(): - """Test that TranscriptStreamEvent follows ModelStreamEvent delta pattern.""" + """Test that BidiTranscriptStreamEvent follows ModelStreamEvent delta pattern.""" # Test partial transcript (delta) - partial_event = TranscriptStreamEvent( + partial_event = BidiTranscriptStreamEvent( delta={"text": "Hello"}, text="Hello", role="user", @@ -134,7 +134,7 @@ def test_transcript_stream_event_delta_pattern(): assert partial_event.delta == {"text": "Hello"} # Test final transcript with accumulated text - final_event = TranscriptStreamEvent( + final_event = BidiTranscriptStreamEvent( delta={"text": " world"}, text=" world", role="user", @@ -150,10 +150,10 @@ def test_transcript_stream_event_delta_pattern(): def test_transcript_stream_event_extends_model_stream_event(): - """Test that TranscriptStreamEvent is a ModelStreamEvent.""" + """Test that BidiTranscriptStreamEvent is a ModelStreamEvent.""" from strands.types._events import ModelStreamEvent - event = TranscriptStreamEvent( + event = BidiTranscriptStreamEvent( delta={"text": "test"}, text="test", role="assistant", diff --git a/tests_integ/bidirectional_streaming/conftest.py b/tests_integ/bidirectional_streaming/conftest.py index 52f6a2a19..0d453818a 100644 --- a/tests_integ/bidirectional_streaming/conftest.py +++ b/tests_integ/bidirectional_streaming/conftest.py @@ -4,7 +4,7 @@ import pytest -from .utils.audio_generator import AudioGenerator +from .generators.audio import AudioGenerator logger = logging.getLogger(__name__) diff --git a/tests_integ/bidirectional_streaming/utils/test_context.py b/tests_integ/bidirectional_streaming/context.py similarity index 98% rename from tests_integ/bidirectional_streaming/utils/test_context.py rename to tests_integ/bidirectional_streaming/context.py index 687aef1b5..069c9c653 100644 --- a/tests_integ/bidirectional_streaming/utils/test_context.py +++ b/tests_integ/bidirectional_streaming/context.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent - from .audio_generator import AudioGenerator + from .generators.audio import AudioGenerator logger = logging.getLogger(__name__) @@ -92,7 +92,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def start(self): """Start all background threads.""" - import time self.active = True self.last_event_time = time.monotonic() @@ -176,7 +175,6 @@ async def wait_for_response( silence_threshold: Seconds of silence to consider response complete. min_events: Minimum events before silence detection activates. """ - import time start_time = time.monotonic() initial_event_count = len(self.get_events()) # Drain queue @@ -215,7 +213,6 @@ def get_events(self, event_type: str | None = None) -> list[dict]: try: event = self._event_queue.get_nowait() self.events.append(event) - import time self.last_event_time = time.monotonic() except asyncio.QueueEmpty: break @@ -360,7 +357,7 @@ def _generate_silence_chunk(self) -> dict: """Generate silence chunk for background audio. Returns: - AudioInputEvent with silence data. + BidiAudioInputEvent with silence data. """ silence = b"\x00" * self.silence_chunk_size return self.audio_generator.create_audio_input_event(silence) diff --git a/tests_integ/bidirectional_streaming/generators/__init__.py b/tests_integ/bidirectional_streaming/generators/__init__.py new file mode 100644 index 000000000..1f13f0564 --- /dev/null +++ b/tests_integ/bidirectional_streaming/generators/__init__.py @@ -0,0 +1 @@ +"""Test data generators for bidirectional streaming integration tests.""" diff --git a/tests_integ/bidirectional_streaming/utils/audio_generator.py b/tests_integ/bidirectional_streaming/generators/audio.py similarity index 95% rename from tests_integ/bidirectional_streaming/utils/audio_generator.py rename to tests_integ/bidirectional_streaming/generators/audio.py index c3ad3f965..0af0f9949 100644 --- a/tests_integ/bidirectional_streaming/utils/audio_generator.py +++ b/tests_integ/bidirectional_streaming/generators/audio.py @@ -109,7 +109,7 @@ def create_audio_input_event( sample_rate: int = NOVA_SONIC_SAMPLE_RATE, channels: int = NOVA_SONIC_CHANNELS, ) -> dict: - """Create AudioInputEvent from raw audio data. + """Create BidiAudioInputEvent from raw audio data. Args: audio_data: Raw audio bytes. @@ -118,7 +118,7 @@ def create_audio_input_event( channels: Number of audio channels. Returns: - AudioInputEvent dict ready for agent.send(). + BidiAudioInputEvent dict ready for agent.send(). """ import base64 @@ -146,14 +146,14 @@ async def generate_test_audio(text: str, use_cache: bool = True) -> dict: """Generate test audio input event from text. Convenience function that creates an AudioGenerator and returns - a ready-to-use AudioInputEvent. + a ready-to-use BidiAudioInputEvent. Args: text: Text to convert to speech. use_cache: Whether to use cached audio. Returns: - AudioInputEvent dict ready for agent.send(). + BidiAudioInputEvent dict ready for agent.send(). """ generator = AudioGenerator() audio_data = await generator.generate_audio(text, use_cache=use_cache) diff --git a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py index 80b32b178..ea87f9d84 100644 --- a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py +++ b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py @@ -18,7 +18,7 @@ from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel -from .utils.test_context import BidirectionalTestContext +from .context import BidirectionalTestContext logger = logging.getLogger(__name__) diff --git a/tests_integ/bidirectional_streaming/utils/__init__.py b/tests_integ/bidirectional_streaming/utils/__init__.py deleted file mode 100644 index fb9bdf2e9..000000000 --- a/tests_integ/bidirectional_streaming/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Utilities for bidirectional streaming integration tests.""" diff --git a/tests_integ/bidirectional_streaming/wrappers/__init__.py b/tests_integ/bidirectional_streaming/wrappers/__init__.py new file mode 100644 index 000000000..6b8a64984 --- /dev/null +++ b/tests_integ/bidirectional_streaming/wrappers/__init__.py @@ -0,0 +1,4 @@ +"""Wrappers for bidirectional streaming integration tests. + +Includes fault injection and other transparent wrappers around real implementations. +""" From b815706480d8f44d18e12aef1fac90f4e2c7acbc Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 10 Nov 2025 12:35:08 -0500 Subject: [PATCH 077/242] Rename bidirectional components --- .../bidirectional_streaming/__init__.py | 20 +++++------ .../bidirectional_streaming/agent/__init__.py | 4 +-- .../bidirectional_streaming/agent/agent.py | 34 +++++++++---------- .../event_loop/bidirectional_event_loop.py | 14 ++++---- .../bidirectional_streaming/io/audio.py | 2 +- .../models/__init__.py | 16 ++++----- .../models/bidirectional_model.py | 8 ++--- .../models/gemini_live.py | 14 ++++---- .../models/novasonic.py | 14 ++++---- .../bidirectional_streaming/models/openai.py | 10 +++--- .../scripts/test_bidi.py | 4 +-- .../scripts/test_bidi_novasonic.py | 8 ++--- .../scripts/test_bidi_openai.py | 6 ++-- .../scripts/test_gemini_live.py | 6 ++-- .../bidirectional_streaming/types/io.py | 2 +- 15 files changed, 81 insertions(+), 81 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 0955a8939..645869c55 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,18 +1,18 @@ """Bidirectional streaming package.""" # Main components - Primary user interface -from .agent.agent import BidirectionalAgent +from .agent.agent import BidiAgent # IO channels - Hardware abstraction from .io.audio import AudioIO # Model interface (for custom implementations) -from .models.bidirectional_model import BidirectionalModel +from .models.bidirectional_model import BidiModel # Model providers - What users need to create models -from .models.gemini_live import GeminiLiveModel -from .models.novasonic import NovaSonicModel -from .models.openai import OpenAIRealtimeModel +from .models.gemini_live import BidiGeminiLiveModel +from .models.novasonic import BidiNovaSonicModel +from .models.openai import BidiOpenAIRealtimeModel # Event types - For type hints and event handling from .types.bidirectional_streaming import ( @@ -29,13 +29,13 @@ __all__ = [ # Main interface - "BidirectionalAgent", + "BidiAgent", # IO channels "AudioIO", # Model providers - "GeminiLiveModel", - "NovaSonicModel", - "OpenAIRealtimeModel", + "BidiGeminiLiveModel", + "BidiNovaSonicModel", + "BidiOpenAIRealtimeModel", # Event types "AudioInputEvent", @@ -48,5 +48,5 @@ "VoiceActivityEvent", "UsageMetricsEvent", # Model interface - "BidirectionalModel", + "BidiModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py index c490e001d..564973099 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -1,5 +1,5 @@ """Bidirectional agent for real-time streaming conversations.""" -from .agent import BidirectionalAgent +from .agent import BidiAgent -__all__ = ["BidirectionalAgent"] +__all__ = ["BidiAgent"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index aae029ab1..b5dcd428b 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -27,9 +27,9 @@ from ....types.content import Message, Messages from ....types.tools import ToolResult, ToolUse, AgentTool -from ..event_loop.bidirectional_event_loop import BidirectionalAgentLoop -from ..models.bidirectional_model import BidirectionalModel -from ..models.novasonic import NovaSonicModel +from ..event_loop.bidirectional_event_loop import BidiAgentLoop +from ..models.bidirectional_model import BidiModel +from ..models.novasonic import BidiNovaSonicModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent from ..types import BidiIO from ....experimental.tools import ToolProvider @@ -42,7 +42,7 @@ BidirectionalInput = str | AudioInputEvent | ImageInputEvent -class BidirectionalAgent: +class BidiAgent: """Agent for bidirectional streaming conversations. Enables real-time audio and text interaction with AI models through persistent @@ -51,7 +51,7 @@ class BidirectionalAgent: def __init__( self, - model: BidirectionalModel| str | None = None, + model: BidiModel| str | None = None, tools: list[str| AgentTool| ToolProvider]| None = None, system_prompt: str | None = None, messages: Messages | None = None, @@ -66,7 +66,7 @@ def __init__( """Initialize bidirectional agent. Args: - model: BidirectionalModel instance, string model_id, or None for default detection. + model: BidiModel instance, string model_id, or None for default detection. tools: Optional list of tools with flexible format support. system_prompt: Optional system prompt for conversations. messages: Optional conversation history to initialize with. @@ -83,9 +83,9 @@ def __init__( TypeError: If model type is unsupported. """ self.model = ( - NovaSonicModel() + BidiNovaSonicModel() if not model - else NovaSonicModel(model_id=model) + else BidiNovaSonicModel(model_id=model) if isinstance(model, str) else model ) @@ -121,7 +121,7 @@ def __init__( self._tool_caller = _ToolCaller(self) # connection management - self._agent_loop: "BidirectionalAgentLoop" | None = None + self._agent_loop: "BidiAgentLoop" | None = None self._output_queue = asyncio.Queue() self._current_adapters = [] # Track adapters for cleanup @@ -134,7 +134,7 @@ def tool(self) -> _ToolCaller: Example: ``` - agent = BidirectionalAgent(model=model, tools=[calculator]) + agent = BidiAgent(model=model, tools=[calculator]) agent.tool.calculator(expression="2+2") ``` """ @@ -252,11 +252,11 @@ async def start(self) -> None: logger.debug("Conversation start - initializing connection") # Create model session and event loop directly - await self.model.connect( + await self.model.start( system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages ) - self._agent_loop = BidirectionalAgentLoop(model=self.model, agent=self) + self._agent_loop = BidiAgentLoop(model=self.model, agent=self) await self._agent_loop.start() logger.debug("Conversation ready") @@ -306,7 +306,7 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: except asyncio.TimeoutError: continue - async def end(self) -> None: + async def stop(self) -> None: """End the conversation connection and cleanup all resources. Terminates the streaming connection, cancels background tasks, and @@ -316,7 +316,7 @@ async def end(self) -> None: await self._agent_loop.stop() self._agent_loop = None - async def __aenter__(self) -> "BidirectionalAgent": + async def __aenter__(self) -> "BidiAgent": """Async context manager entry point. Automatically starts the bidirectional connection when entering the context. @@ -350,7 +350,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: for adapter in self._current_adapters: if hasattr(adapter, "cleanup"): try: - adapter.end() + adapter.stop() logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") except Exception as adapter_error: logger.warning(f"Error cleaning up adapter: {adapter_error}") @@ -359,7 +359,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: self._current_adapters = [] # Cleanup agent connection - await self.end() + await self.stop() except Exception as cleanup_error: if exc_type is None: @@ -393,7 +393,7 @@ async def run(self, io_channels: list[BidiIO | tuple[Callable, Callable]]) -> No ```python # With IO channel audio_io = AudioIO(audio_config={"input_sample_rate": 16000}) - agent = BidirectionalAgent(model=model, tools=[calculator]) + agent = BidiAgent(model=model, tools=[calculator]) await agent.run(io_channels=[audio_io]) # With tuple (backward compatibility) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 38d92aea8..d7f87f69e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..models.bidirectional_model import BidirectionalModel +from ..models.bidirectional_model import BidiModel logger = logging.getLogger(__name__) @@ -37,12 +37,12 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> None: + def __init__(self, model: BidiModel, agent: "BidiAgent") -> None: """Initialize connection with model and agent reference. Args: model: Bidirectional model instance. - agent: BidirectionalAgent instance for tool registry access. + agent: BidiAgent instance for tool registry access. """ self.model = model self.agent = agent @@ -64,14 +64,14 @@ def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> No self.tool_count = 0 -async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: +async def start_bidirectional_connection(agent: "BidiAgent") -> BidirectionalConnection: """Initialize bidirectional session with conycurrent background tasks. Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. Args: - agent: BidirectionalAgent instance. + agent: BidiAgent instance. Returns: BidirectionalConnection: Active session with background tasks running. @@ -79,7 +79,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec logger.debug("Starting bidirectional session - initializing model connection") # Connect to model - await agent.model.connect( + await agent.model.start( system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) @@ -136,7 +136,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non await asyncio.gather(*all_tasks, return_exceptions=True) # Close model connection - await session.model.close() + await session.model.stop() logger.debug("Connection closed") diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py index 4fb60a2b5..a16dce884 100644 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -176,7 +176,7 @@ async def receive(self, event: dict) -> None: elif role.upper() == "USER": print(f"User: {text}") - def end(self) -> None: + def stop(self) -> None: """Clean up IO channel resources.""" try: if self.input_stream: diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 5b0d50687..6d6d6590b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,13 +1,13 @@ """Bidirectional model interfaces and implementations.""" -from .bidirectional_model import BidirectionalModel -from .gemini_live import GeminiLiveModel -from .novasonic import NovaSonicModel -from .openai import OpenAIRealtimeModel +from .bidirectional_model import BidiModel +from .gemini_live import BidiGeminiLiveModel +from .novasonic import BidiNovaSonicModel +from .openai import BidiOpenAIRealtimeModel __all__ = [ - "BidirectionalModel", - "GeminiLiveModel", - "NovaSonicModel", - "OpenAIRealtimeModel", + "BidiModel", + "BidiGeminiLiveModel", + "BidiNovaSonicModel", + "BidiOpenAIRealtimeModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 05fb19e0f..8a4235d15 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -class BidirectionalModel(Protocol): +class BidiModel(Protocol): """Protocol for bidirectional streaming models. This interface defines the contract for models that support persistent streaming @@ -35,7 +35,7 @@ class BidirectionalModel(Protocol): provider-specific protocols while exposing a standardized event-based API. """ - async def connect( + async def start( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, @@ -56,12 +56,12 @@ async def connect( """ ... - async def close(self) -> None: + async def stop(self) -> None: """Close the streaming connection and release resources. Terminates the active bidirectional connection and cleans up any associated resources such as network connections, buffers, or background tasks. After - calling close(), the model instance cannot be used until connect() is called again. + calling close(), the model instance cannot be used until start() is called again. """ ... diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index ffff98cf1..f719fdac6 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -1,6 +1,6 @@ """Gemini Live API bidirectional model provider using official Google GenAI SDK. -Implements the BidirectionalModel interface for Google's Gemini Live API using the +Implements the BidiModel interface for Google's Gemini Live API using the official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: @@ -34,7 +34,7 @@ TextOutputEvent, TranscriptEvent, ) -from .bidirectional_model import BidirectionalModel +from .bidirectional_model import BidiModel logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ GEMINI_CHANNELS = 1 -class GeminiLiveModel(BidirectionalModel): +class BidiGeminiLiveModel(BidiModel): """Gemini Live API implementation using official Google GenAI SDK. Combines model configuration and connection state in a single class. @@ -82,13 +82,13 @@ def __init__( self.client = genai.Client(**client_kwargs) - # Connection state (initialized in connect()) + # Connection state (initialized in start()) self.live_session = None self.live_session_context_manager = None self.session_id = None self._active = False - async def connect( + async def start( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, @@ -404,7 +404,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: except Exception as e: logger.error("Error sending tool result: %s", e) - async def close(self) -> None: + async def stop(self) -> None: """Close Gemini Live API connection.""" if not self._active: return @@ -435,7 +435,7 @@ def _build_live_config( if self.live_config: config_dict.update(self.live_config) - # Override with any kwargs from connect() + # Override with any kwargs from start() config_dict.update(kwargs) # Add system instruction if provided diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index e4c0d1565..e8e048064 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,6 +1,6 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +Implements the BidiModel interface for Amazon's Nova Sonic, handling the complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. @@ -43,7 +43,7 @@ TextOutputEvent, UsageMetricsEvent, ) -from .bidirectional_model import BidirectionalModel +from .bidirectional_model import BidiModel logger = logging.getLogger(__name__) @@ -78,12 +78,12 @@ RESPONSE_TIMEOUT = 1.0 -class NovaSonicModel(BidirectionalModel): +class BidiNovaSonicModel(BidiModel): """Nova Sonic implementation for bidirectional streaming. Combines model configuration and connection state in a single class. Manages Nova Sonic's complex event sequencing, audio format conversion, and - tool execution patterns while providing the standard BidirectionalModel interface. + tool execution patterns while providing the standard BidiModel interface. """ def __init__( @@ -104,7 +104,7 @@ def __init__( self.region = region self.client = None - # Connection state (initialized in connect()) + # Connection state (initialized in start()) self.stream = None self.session_id = None self._active = False @@ -124,7 +124,7 @@ def __init__( logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - async def connect( + async def start( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, @@ -469,7 +469,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: for event in events: await self._send_nova_event(event) - async def close(self) -> None: + async def stop(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 4bf43b563..312077621 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -29,7 +29,7 @@ TextOutputEvent, VoiceActivityEvent, ) -from .bidirectional_model import BidirectionalModel +from .bidirectional_model import BidiModel logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ } -class OpenAIRealtimeModel(BidirectionalModel): +class BidiOpenAIRealtimeModel(BidiModel): """OpenAI Realtime API implementation for bidirectional streaming. Combines model configuration and connection state in a single class. @@ -97,7 +97,7 @@ def __init__( if not self.api_key: raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") - # Connection state (initialized in connect()) + # Connection state (initialized in start()) self.websocket = None self.session_id = None self._active = False @@ -108,7 +108,7 @@ def __init__( logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - async def connect( + async def start( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, @@ -508,7 +508,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: await self._send_event({"type": "conversation.item.create", "item": item_data}) await self._send_event({"type": "response.create"}) - async def close(self) -> None: + async def stop(self) -> None: """Close session and cleanup resources.""" if not self._active: return diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py index 57ce8b986..6df0063be 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py @@ -7,7 +7,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel +from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel from strands.experimental.bidirectional_streaming.io.audio import AudioIO from strands_tools import calculator @@ -18,7 +18,7 @@ async def main(): # Nova Sonic model adapter = AudioIO() - model = NovaSonicModel(region="us-east-1") + model = BidiNovaSonicModel(region="us-east-1") async with BidirectionalAgent(model=model, tools=[calculator]) as agent: print("New BidirectionalAgent Experience") diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py index b0a41f20d..173c091ac 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py @@ -17,7 +17,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel +from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel def test_direct_tools(): @@ -30,7 +30,7 @@ def test_direct_tools(): return try: - model = NovaSonicModel() + model = BidiNovaSonicModel() agent = BidirectionalAgent(model=model, tools=[calculator]) # Test calculator @@ -185,7 +185,7 @@ async def main(duration=180): print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") # Initialize model and agent - model = NovaSonicModel(region="us-east-1") + model = BidiNovaSonicModel(region="us-east-1") agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() @@ -215,7 +215,7 @@ async def main(duration=180): finally: print("Cleaning up...") context["active"] = False - await agent.end() + await agent.stop() if __name__ == "__main__": diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py index 90e82c2bc..a2a8efc90 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py @@ -14,7 +14,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel +from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel async def play(context): @@ -205,7 +205,7 @@ async def main(): return False # Create OpenAI model - model = OpenAIRealtimeModel( + model = BidiOpenAIRealtimeModel( model="gpt-4o-realtime-preview", api_key=api_key, session={ @@ -269,7 +269,7 @@ async def main(): context["active"] = False try: - await agent.end() + await agent.stop() except Exception as e: print(f"Cleanup error: {e}") diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py index 23e97bd5d..ca0e5f8ef 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py @@ -38,7 +38,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel +from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel # Configure logging - debug only for Gemini Live, info for everything else logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -301,7 +301,7 @@ async def main(duration=180): # Initialize Gemini Live model with proper configuration logger.info("Initializing Gemini Live model with API key") - model = GeminiLiveModel( + model = BidiGeminiLiveModel( model_id="gemini-2.5-flash-native-audio-preview-09-2025", api_key=api_key, params={ @@ -352,7 +352,7 @@ async def main(duration=180): finally: print("Cleaning up...") context["active"] = False - await agent.end() + await agent.stop() if __name__ == "__main__": diff --git a/src/strands/experimental/bidirectional_streaming/types/io.py b/src/strands/experimental/bidirectional_streaming/types/io.py index 98b9b28bd..2e113c74b 100644 --- a/src/strands/experimental/bidirectional_streaming/types/io.py +++ b/src/strands/experimental/bidirectional_streaming/types/io.py @@ -37,7 +37,7 @@ async def receive(self, event: dict) -> None: """ ... - def end(self) -> None: + def stop(self) -> None: """Clean up IO channel resources. Called by the agent during shutdown to ensure proper From 805aa3a55ef5d8ae8a32d098c7a2ba9add98ef36 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 10 Nov 2025 14:13:31 -0500 Subject: [PATCH 078/242] Fix main branch. Temporarily rename loop to original name --- .../experimental/bidirectional_streaming/agent/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index b5dcd428b..bc34d7180 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -27,7 +27,7 @@ from ....types.content import Message, Messages from ....types.tools import ToolResult, ToolUse, AgentTool -from ..event_loop.bidirectional_event_loop import BidiAgentLoop +from ..event_loop.bidirectional_event_loop import BidirectionalConnection from ..models.bidirectional_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent @@ -121,7 +121,7 @@ def __init__( self._tool_caller = _ToolCaller(self) # connection management - self._agent_loop: "BidiAgentLoop" | None = None + self._agent_loop: "BidirectionalConnection" | None = None self._output_queue = asyncio.Queue() self._current_adapters = [] # Track adapters for cleanup @@ -256,7 +256,7 @@ async def start(self) -> None: system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages ) - self._agent_loop = BidiAgentLoop(model=self.model, agent=self) + self._agent_loop = BidirectionalConnection(model=self.model, agent=self) await self._agent_loop.start() logger.debug("Conversation ready") From 873441b63e9a7c3b64867c5feed180390f0aedfd Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 22:54:56 +0300 Subject: [PATCH 079/242] Fix agent send() to convert dicts to TypedEvent instances MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The agent's send() method was passing plain dicts directly to models, but models expect TypedEvent instances for isinstance() checks to work. Added dict-to-TypedEvent conversion logic that was lost in merge: - Checks event 'type' field in dict - Reconstructs appropriate TypedEvent (BidiTextInputEvent, BidiAudioInputEvent, etc.) - Maintains backward compatibility with WebSocket/dict-based clients Tests: - ✅ 14/14 type tests passing - ✅ 2/2 integration tests passing (nova_sonic, openai) --- .../bidirectional_streaming/agent/agent.py | 65 +++++++++++++++---- 1 file changed, 54 insertions(+), 11 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 22dca4db6..8a15ee20e 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -257,20 +257,30 @@ async def start(self) -> None: self._agent_loop = await start_bidirectional_connection(self) async def send(self, input_data: BidirectionalInput) -> None: - """Send input to the model (text or audio). - - Unified method for sending both text and audio input to the model during - an active conversation connection. User input is automatically added to - conversation history for complete message tracking. - + """Send input to the model (text, audio, image, or event dict). + + Unified method for sending text, audio, and image input to the model during + an active conversation session. Accepts TypedEvent instances or plain dicts + (e.g., from WebSocket clients) which are automatically reconstructed. + Args: - input_data: String for text, BidiAudioInputEvent for audio, or BidiImageInputEvent for images. - + input_data: Can be: + - str: Text message from user + - BidiAudioInputEvent: Audio data with format/sample rate + - BidiImageInputEvent: Image data with MIME type + - dict: Event dictionary (will be reconstructed to TypedEvent) + Raises: - ValueError: If no active connection or invalid input type. + ValueError: If no active session or invalid input type. + + Example: + await agent.send("Hello") + await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...)) + await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"}) """ self._validate_active_connection() + # Handle string input if isinstance(input_data, str): # Add user text message to history user_message: Message = {"role": "user", "content": [{"text": input_data}]} @@ -281,9 +291,42 @@ async def send(self, input_data: BidirectionalInput) -> None: # Create BidiTextInputEvent for send() text_event = BidiTextInputEvent(text=input_data, role="user") await self._agent_loop.model.send(text_event) - else: - # For audio, image, or any other input - let model handle it + return + + # Handle InputEvent instances (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent) + # Check this before dict since TypedEvent inherits from dict + if isinstance(input_data, (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent)): await self._agent_loop.model.send(input_data) + return + + # Handle plain dict - reconstruct TypedEvent for WebSocket integration + if isinstance(input_data, dict) and "type" in input_data: + event_type = input_data["type"] + if event_type == "bidirectional_text_input": + input_data = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) + elif event_type == "bidirectional_audio_input": + input_data = BidiAudioInputEvent( + audio=input_data["audio"], + format=input_data["format"], + sample_rate=input_data["sample_rate"], + channels=input_data["channels"] + ) + elif event_type == "bidirectional_image_input": + input_data = BidiImageInputEvent( + image=input_data["image"], + mime_type=input_data["mime_type"] + ) + else: + raise ValueError(f"Unknown event type: {event_type}") + + # Send the reconstructed TypedEvent + await self._agent_loop.model.send(input_data) + return + + # If we get here, input type is invalid + raise ValueError( + f"Input must be a string, InputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" + ) async def receive(self) -> AsyncIterable[OutputEvent]: """Receive events from the model including audio, text, and tool calls. From 272a1fc6cd0728dba414887e1681440e6d38f051 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 23:29:30 +0300 Subject: [PATCH 080/242] Update model unit tests to use new class names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updated test imports and usages: - GeminiLiveModel → BidiGeminiLiveModel - NovaSonicModel → BidiNovaSonicModel - OpenAIRealtimeModel → BidiOpenAIRealtimeModel Note: 21 model tests still failing because they call .connect() but models now use .start(). This is a pre-existing issue that needs separate fix - tests need API update. --- .../models/test_gemini_live.py | 22 +++++++-------- .../models/test_novasonic.py | 8 +++--- .../models/test_openai_realtime.py | 28 +++++++++---------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index b6280b2ee..208fb3276 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -1,6 +1,6 @@ """Unit tests for Gemini Live bidirectional streaming model. -Tests the unified GeminiLiveModel interface including: +Tests the unified BidiGeminiLiveModel interface including: - Model initialization and configuration - Connection establishment and lifecycle - Unified send() method with different content types @@ -13,7 +13,7 @@ from google import genai from google.genai import types as genai_types -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel +from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidirectional_streaming.types.events import ( BidiAudioInputEvent, BidiImageInputEvent, @@ -56,9 +56,9 @@ def api_key(): @pytest.fixture def model(mock_genai_client, model_id, api_key): - """Create a GeminiLiveModel instance.""" + """Create a BidiGeminiLiveModel instance.""" _ = mock_genai_client - return GeminiLiveModel(model_id=model_id, api_key=api_key) + return BidiGeminiLiveModel(model_id=model_id, api_key=api_key) @pytest.fixture @@ -88,20 +88,20 @@ def test_model_initialization(mock_genai_client, model_id, api_key): _ = mock_genai_client # Test default config - model_default = GeminiLiveModel() + model_default = BidiGeminiLiveModel() assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09" assert model_default.api_key is None assert model_default._active is False assert model_default.live_session is None # Test with API key - model_with_key = GeminiLiveModel(model_id=model_id, api_key=api_key) + model_with_key = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) assert model_with_key.model_id == model_id assert model_with_key.api_key == api_key # Test with custom config live_config = {"temperature": 0.7, "top_p": 0.9} - model_custom = GeminiLiveModel(model_id=model_id, live_config=live_config) + model_custom = BidiGeminiLiveModel(model_id=model_id, live_config=live_config) assert model_custom.live_config == live_config @@ -152,7 +152,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client, _, mock_live_session_cm = mock_genai_client # Test connection error - model1 = GeminiLiveModel(model_id=model_id, api_key=api_key) + model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.connect() @@ -161,18 +161,18 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client.aio.live.connect.side_effect = None # Test double connection - model2 = GeminiLiveModel(model_id=model_id, api_key=api_key) + model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): await model2.connect() await model2.close() # Test close when not connected - model3 = GeminiLiveModel(model_id=model_id, api_key=api_key) + model3 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model3.close() # Should not raise # Test close error handling - model4 = GeminiLiveModel(model_id=model_id, api_key=api_key) + model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model4.connect() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") with pytest.raises(Exception, match="Close failed"): diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index d410d32f1..5e3a72ee6 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -13,7 +13,7 @@ import pytest_asyncio from strands.experimental.bidirectional_streaming.models.novasonic import ( - NovaSonicModel, + BidiNovaSonicModel, ) from strands.types.tools import ToolResult @@ -53,7 +53,7 @@ def mock_client(mock_stream): @pytest_asyncio.fixture async def nova_model(model_id, region): """Create Nova Sonic model instance.""" - model = NovaSonicModel(model_id=model_id, region=region) + model = BidiNovaSonicModel(model_id=model_id, region=region) yield model # Cleanup if model._active: @@ -66,7 +66,7 @@ async def nova_model(model_id, region): @pytest.mark.asyncio async def test_model_initialization(model_id, region): """Test model initialization with configuration.""" - model = NovaSonicModel(model_id=model_id, region=region) + model = BidiNovaSonicModel(model_id=model_id, region=region) assert model.model_id == model_id assert model.region == region @@ -120,7 +120,7 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model await nova_model.close() # Test close when already closed - model2 = NovaSonicModel(model_id=model_id, region=region) + model2 = BidiNovaSonicModel(model_id=model_id, region=region) await model2.close() # Should not raise await model2.close() # Second call should also be safe diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 0e8349091..99721e061 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -1,6 +1,6 @@ """Unit tests for OpenAI Realtime bidirectional streaming model. -Tests the unified OpenAIRealtimeModel interface including: +Tests the unified BidiOpenAIRealtimeModel interface including: - Model initialization and configuration - Connection establishment with WebSocket - Unified send() method with different content types @@ -15,7 +15,7 @@ import pytest -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel +from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel from strands.experimental.bidirectional_streaming.types.events import ( BidiAudioInputEvent, BidiImageInputEvent, @@ -56,8 +56,8 @@ def api_key(): @pytest.fixture def model(api_key, model_name): - """Create an OpenAIRealtimeModel instance.""" - return OpenAIRealtimeModel(model=model_name, api_key=api_key) + """Create an BidiOpenAIRealtimeModel instance.""" + return BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) @pytest.fixture @@ -85,19 +85,19 @@ def messages(): def test_model_initialization(api_key, model_name): """Test model initialization with various configurations.""" # Test default config - model_default = OpenAIRealtimeModel(api_key="test-key") + model_default = BidiOpenAIRealtimeModel(api_key="test-key") assert model_default.model == "gpt-realtime" assert model_default.api_key == "test-key" assert model_default._active is False assert model_default.websocket is None # Test with custom model - model_custom = OpenAIRealtimeModel(model=model_name, api_key=api_key) + model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) assert model_custom.model == model_name assert model_custom.api_key == api_key # Test with organization and project - model_org = OpenAIRealtimeModel( + model_org = BidiOpenAIRealtimeModel( model=model_name, api_key=api_key, organization="org-123", @@ -108,7 +108,7 @@ def test_model_initialization(api_key, model_name): # Test with env API key with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): - model_env = OpenAIRealtimeModel() + model_env = BidiOpenAIRealtimeModel() assert model_env.api_key == "env-key" @@ -116,7 +116,7 @@ def test_init_without_api_key_raises(): """Test that initialization without API key raises error.""" with unittest.mock.patch.dict("os.environ", {}, clear=True): with pytest.raises(ValueError, match="OpenAI API key is required"): - OpenAIRealtimeModel() + BidiOpenAIRealtimeModel() # Connection Tests @@ -171,7 +171,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp await model.close() # Test connection with organization header - model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") + model_org = BidiOpenAIRealtimeModel(api_key="test-key", organization="org-123") await model_org.connect() call_kwargs = mock_connect.call_args.kwargs headers = call_kwargs.get("additional_headers", []) @@ -187,7 +187,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam mock_connect, mock_ws = mock_websockets_connect # Test connection error - model1 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + model1 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) mock_connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.connect() @@ -198,18 +198,18 @@ async def async_connect(*args, **kwargs): mock_connect.side_effect = async_connect # Test double connection - model2 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): await model2.connect() await model2.close() # Test close when not connected - model3 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) await model3.close() # Should not raise # Test close error handling (should not raise, just log) - model4 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) await model4.connect() mock_ws.close.side_effect = Exception("Close failed") await model4.close() # Should not raise From 6c2cbf5ad615b07e77d0a739fa2f598e954868f0 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 23:33:54 +0300 Subject: [PATCH 081/242] Fix all model unit tests to use new API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updated all test calls from old API to new API: - .connect() → .start() - .close() → .stop() - Updated error message expectations to match actual errors All tests now passing: - ✅ 47/47 bidirectional streaming tests passing - ✅ 14/14 type tests - ✅ 33/33 model tests - ✅ 2/2 integration tests --- .../models/test_gemini_live.py | 46 ++++++++-------- .../models/test_novasonic.py | 42 +++++++-------- .../models/test_openai_realtime.py | 54 +++++++++---------- 3 files changed, 71 insertions(+), 71 deletions(-) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 208fb3276..f2821a7a0 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -114,36 +114,36 @@ async def test_connection_lifecycle(mock_genai_client, model, system_prompt, too mock_client, mock_live_session, mock_live_session_cm = mock_genai_client # Test basic connection - await model.connect() + await model.start() assert model._active is True assert model.connection_id is not None assert model.live_session == mock_live_session mock_client.aio.live.connect.assert_called_once() # Test close - await model.close() + await model.stop() assert model._active is False mock_live_session_cm.__aexit__.assert_called_once() # Test connection with system prompt - await model.connect(system_prompt=system_prompt) + await model.start(system_prompt=system_prompt) call_args = mock_client.aio.live.connect.call_args config = call_args.kwargs.get("config", {}) assert config.get("system_instruction") == system_prompt - await model.close() + await model.stop() # Test connection with tools - await model.connect(tools=[tool_spec]) + await model.start(tools=[tool_spec]) call_args = mock_client.aio.live.connect.call_args config = call_args.kwargs.get("config", {}) assert "tools" in config assert len(config["tools"]) > 0 - await model.close() + await model.stop() # Test connection with messages - await model.connect(messages=messages) + await model.start(messages=messages) mock_live_session.send_client_content.assert_called() - await model.close() + await model.stop() @pytest.mark.asyncio @@ -155,28 +155,28 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await model1.connect() + await model1.start() # Reset mock for next tests mock_client.aio.live.connect.side_effect = None # Test double connection model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - await model2.connect() + await model2.start() with pytest.raises(RuntimeError, match="Connection already active"): - await model2.connect() - await model2.close() + await model2.start() + await model2.stop() # Test close when not connected model3 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - await model3.close() # Should not raise + await model3.stop() # Should not raise # Test close error handling model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - await model4.connect() + await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") with pytest.raises(Exception, match="Close failed"): - await model4.close() + await model4.stop() # Send Method Tests @@ -186,7 +186,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): async def test_send_all_content_types(mock_genai_client, model): """Test sending all content types through unified send() method.""" _, mock_live_session, _ = mock_genai_client - await model.connect() + await model.start() # Test text input text_input = BidiTextInputEvent(text="Hello", role="user") @@ -228,7 +228,7 @@ async def test_send_all_content_types(mock_genai_client, model): await model.send(ToolResultEvent(tool_result)) mock_live_session.send_tool_response.assert_called_once() - await model.close() + await model.stop() @pytest.mark.asyncio @@ -242,11 +242,11 @@ async def test_send_edge_cases(mock_genai_client, model): mock_live_session.send_client_content.assert_not_called() # Test unknown content type - await model.connect() + await model.start() unknown_content = {"unknown_field": "value"} await model.send(unknown_content) # Should not raise, just log warning - await model.close() + await model.stop() # Receive Method Tests @@ -263,7 +263,7 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): _, mock_live_session, _ = mock_genai_client mock_live_session.receive.return_value = agenerator([]) - await model.connect() + await model.start() # Collect events events = [] @@ -271,7 +271,7 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): events.append(event) # Close after first event to trigger connection end if len(events) == 1: - await model.close() + await model.stop() # Verify connection start and end assert len(events) >= 2 @@ -290,7 +290,7 @@ async def test_event_conversion(mock_genai_client, model): ) _, _, _ = mock_genai_client - await model.connect() + await model.start() # Test text output (converted to transcript) mock_text = unittest.mock.Mock() @@ -360,7 +360,7 @@ async def test_event_conversion(mock_genai_client, model): assert isinstance(interrupt_event, BidiInterruptionEvent) assert interrupt_event.reason == "user_speech" - await model.close() + await model.stop() # Helper Method Tests diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 5e3a72ee6..9f0d35f88 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -57,7 +57,7 @@ async def nova_model(model_id, region): yield model # Cleanup if model._active: - await model.close() + await model.stop() # Initialization and Connection Tests @@ -82,14 +82,14 @@ async def test_connection_lifecycle(nova_model, mock_client, mock_stream): nova_model.client = mock_client # Test basic connection - await nova_model.connect(system_prompt="Test system prompt") + await nova_model.start(system_prompt="Test system prompt") assert nova_model._active assert nova_model.stream == mock_stream assert nova_model.connection_id is not None assert mock_client.invoke_model_with_bidirectional_stream.called # Test close - await nova_model.close() + await nova_model.stop() assert not nova_model._active assert mock_stream.input_stream.close.called @@ -101,10 +101,10 @@ async def test_connection_lifecycle(nova_model, mock_client, mock_stream): "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} } ] - await nova_model.connect(system_prompt="You are helpful", tools=tools) + await nova_model.start(system_prompt="You are helpful", tools=tools) # Verify initialization events were sent (connectionStart, promptStart, system prompt) assert mock_stream.input_stream.send.call_count >= 3 - await nova_model.close() + await nova_model.stop() @pytest.mark.asyncio @@ -114,15 +114,15 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model nova_model.client = mock_client # Test double connection - await nova_model.connect() + await nova_model.start() with pytest.raises(RuntimeError, match="Connection already active"): - await nova_model.connect() - await nova_model.close() + await nova_model.start() + await nova_model.stop() # Test close when already closed model2 = BidiNovaSonicModel(model_id=model_id, region=region) - await model2.close() # Should not raise - await model2.close() # Second call should also be safe + await model2.stop() # Should not raise + await model2.stop() # Second call should also be safe # Send Method Tests @@ -140,7 +140,7 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model.client = mock_client - await nova_model.connect() + await nova_model.start() # Test text content text_event = BidiTextInputEvent(text="Hello, Nova!", role="user") @@ -171,7 +171,7 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): # Should send contentStart, toolResult, and contentEnd assert mock_stream.input_stream.send.called - await nova_model.close() + await nova_model.stop() @pytest.mark.asyncio @@ -190,7 +190,7 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): await nova_model.send(text_event) # Should not raise # Test image content (not supported, base64 encoded, no encoding parameter) - await nova_model.connect() + await nova_model.start() import base64 image_b64 = base64.b64encode(b"image data").decode('utf-8') image_event = BidiImageInputEvent( @@ -201,7 +201,7 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): # Should log warning about unsupported image input assert any("not supported" in record.message.lower() for record in caplog.records) - await nova_model.close() + await nova_model.stop() # Receive and Event Conversion Tests @@ -220,7 +220,7 @@ async def mock_wait_for(*args, **kwargs): raise asyncio.TimeoutError() with patch("asyncio.wait_for", side_effect=mock_wait_for): - await nova_model.connect() + await nova_model.start() events = [] async for event in nova_model.receive(): @@ -333,7 +333,7 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model.client = mock_client - await nova_model.connect() + await nova_model.start() # Start audio connection await nova_model._start_audio_connection() @@ -343,7 +343,7 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): await nova_model._end_audio_input() assert not nova_model.audio_connection_active - await nova_model.close() + await nova_model.stop() @pytest.mark.asyncio @@ -355,7 +355,7 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): nova_model.client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing - await nova_model.connect() + await nova_model.start() # Send audio to start connection (base64 encoded) import base64 @@ -376,7 +376,7 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): # Audio connection should be ended assert not nova_model.audio_connection_active - await nova_model.close() + await nova_model.stop() # Helper Method Tests @@ -458,10 +458,10 @@ async def mock_error(*args, **kwargs): mock_stream.await_output.side_effect = mock_error - await nova_model.connect() + await nova_model.start() # Wait a bit for response processor to handle error await asyncio.sleep(0.1) # Should still be able to close cleanly - await nova_model.close() + await nova_model.stop() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 99721e061..36aaf9242 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -128,7 +128,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp mock_connect, mock_ws = mock_websockets_connect # Test basic connection - await model.connect() + await model.start() assert model._active is True assert model.connection_id is not None assert model.websocket == mock_ws @@ -137,12 +137,12 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp mock_connect.assert_called_once() # Test close - await model.close() + await model.stop() assert model._active is False mock_ws.close.assert_called_once() # Test connection with system prompt - await model.connect(system_prompt=system_prompt) + await model.start(system_prompt=system_prompt) calls = mock_ws.send.call_args_list session_update = next( (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), @@ -150,10 +150,10 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp ) assert session_update is not None assert system_prompt in session_update["session"]["instructions"] - await model.close() + await model.stop() # Test connection with tools - await model.connect(tools=[tool_spec]) + await model.start(tools=[tool_spec]) calls = mock_ws.send.call_args_list # Tools are sent in a separate session.update after initial connection session_updates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"] @@ -161,24 +161,24 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp # Check if any session update has tools has_tools = any("tools" in update.get("session", {}) for update in session_updates) assert has_tools - await model.close() + await model.stop() # Test connection with messages - await model.connect(messages=messages) + await model.start(messages=messages) calls = mock_ws.send.call_args_list item_creates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create"] assert len(item_creates) > 0 - await model.close() + await model.stop() # Test connection with organization header model_org = BidiOpenAIRealtimeModel(api_key="test-key", organization="org-123") - await model_org.connect() + await model_org.start() call_kwargs = mock_connect.call_args.kwargs headers = call_kwargs.get("additional_headers", []) org_header = [h for h in headers if h[0] == "OpenAI-Organization"] assert len(org_header) == 1 assert org_header[0][1] == "org-123" - await model_org.close() + await model_org.stop() @pytest.mark.asyncio @@ -190,7 +190,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam model1 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) mock_connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await model1.connect() + await model1.start() # Reset mock async def async_connect(*args, **kwargs): @@ -199,20 +199,20 @@ async def async_connect(*args, **kwargs): # Test double connection model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - await model2.connect() + await model2.start() with pytest.raises(RuntimeError, match="Connection already active"): - await model2.connect() - await model2.close() + await model2.start() + await model2.stop() # Test close when not connected model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - await model3.close() # Should not raise + await model3.stop() # Should not raise # Test close error handling (should not raise, just log) model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - await model4.connect() + await model4.start() mock_ws.close.side_effect = Exception("Close failed") - await model4.close() # Should not raise + await model4.stop() # Should not raise assert model4._active is False @@ -225,7 +225,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): from strands.types._events import ToolResultEvent _, mock_ws = mock_websockets_connect - await model.connect() + await model.start() # Test text input text_input = BidiTextInputEvent(text="Hello", role="user") @@ -269,7 +269,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): assert item.get("type") == "function_call_output" assert item.get("call_id") == "tool-123" - await model.close() + await model.stop() @pytest.mark.asyncio @@ -283,7 +283,7 @@ async def test_send_edge_cases(mock_websockets_connect, model): mock_ws.send.assert_not_called() # Test image input (not supported, base64 encoded, no encoding parameter) - await model.connect() + await model.start() image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') image_input = BidiImageInputEvent( image=image_b64, @@ -299,7 +299,7 @@ async def test_send_edge_cases(mock_websockets_connect, model): await model.send(unknown_content) assert mock_logger.warning.called - await model.close() + await model.stop() # Receive Method Tests @@ -310,7 +310,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): """Test that receive() emits connection start and end events.""" _, _ = mock_websockets_connect - await model.connect() + await model.start() # Get first event receive_gen = model.receive() @@ -322,7 +322,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): assert first_event.get("model") == model.model # Close to trigger session end - await model.close() + await model.stop() # Collect remaining events events = [first_event] @@ -340,7 +340,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): async def test_event_conversion(mock_websockets_connect, model): """Test conversion of all OpenAI event types to standard format.""" _, _ = mock_websockets_connect - await model.connect() + await model.start() # Test audio output (now returns list with BidiAudioStreamEvent) from strands.experimental.bidirectional_streaming.types.events import BidiAudioStreamEvent @@ -418,7 +418,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("type") == "bidirectional_interruption" assert converted[0].get("reason") == "user_speech" - await model.close() + await model.stop() # Helper Method Tests @@ -490,7 +490,7 @@ def test_helper_methods(model): async def test_send_event_helper(mock_websockets_connect, model): """Test _send_event helper method.""" _, mock_ws = mock_websockets_connect - await model.connect() + await model.start() test_event = {"type": "test.event", "data": "test"} await model._send_event(test_event) @@ -500,4 +500,4 @@ async def test_send_event_helper(mock_websockets_connect, model): sent_message = json.loads(last_call[0][0]) assert sent_message == test_event - await model.close() + await model.stop() From 8918757301c06b2360c6e449f366ee0f16f1b48c Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 00:01:49 +0300 Subject: [PATCH 082/242] fix: fix bidi tests --- .../event_loop/bidirectional_event_loop.py | 6 ++--- .../models/gemini_live.py | 9 ++++--- .../models/novasonic.py | 4 ++-- .../bidirectional_streaming/models/openai.py | 24 +++++++++++++++---- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index b131c72a3..4ca763fc5 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -282,7 +282,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution # Check for ToolUseStreamEvent (standard agent event) - if "current_tool_use" in strands_event: + if event_type == "tool_use_stream": tool_use = strands_event.get("current_tool_use") if tool_use: tool_name = tool_use.get("name") @@ -297,9 +297,9 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Update Agent conversation history for user transcripts if event_type == "bidirectional_transcript_stream": - source = strands_event.get("source") + role = strands_event.get("role") text = strands_event.get("text", "") - if source == "user" and text.strip(): + if role == "user" and text.strip(): user_message = {"role": "user", "content": text} session.agent.messages.append(user_message) logger.debug("User transcript added to history") diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 5161694ef..17c49f7de 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -219,11 +219,12 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic # Check if the transcription object has text content if hasattr(input_transcript, 'text') and input_transcript.text: transcription_text = input_transcript.text + role = getattr(input_transcript, 'role', 'user') logger.debug(f"Input transcription detected: {transcription_text}") return BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, - role="user", + role=role.lower() if isinstance(role, str) else "user", is_final=True, current_transcript=transcription_text ) @@ -234,22 +235,24 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic # Check if the transcription object has text content if hasattr(output_transcript, 'text') and output_transcript.text: transcription_text = output_transcript.text + role = getattr(output_transcript, 'role', 'assistant') logger.debug(f"Output transcription detected: {transcription_text}") return BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, - role="assistant", + role=role.lower() if isinstance(role, str) else "assistant", is_final=True, current_transcript=transcription_text ) # Handle text output from model if message.text: + role = getattr(message, 'role', 'assistant') logger.debug(f"Text output as transcript: {message.text}") return BidiTranscriptStreamEvent( delta={"text": message.text}, text=message.text, - role="assistant", + role=role.lower() if isinstance(role, str) else "assistant", is_final=True, current_transcript=message.text ) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ccd12f24e..b7e8d05b8 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -552,7 +552,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: elif "textOutput" in nova_event: text_content = nova_event["textOutput"]["content"] # Use stored role from contentStart event, fallback to event role - role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) + role = getattr(self, "_current_role", None) or nova_event["textOutput"].get("role", "assistant") # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: @@ -562,7 +562,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: return BidiTranscriptStreamEvent( delta={"text": text_content}, text=text_content, - role="user" if role == "USER" else "assistant", + role=role.lower() if isinstance(role, str) else "assistant", is_final=True, current_transcript=text_content ) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 6616a6b02..e21866ccc 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -174,11 +174,22 @@ def _require_active(self) -> bool: return self._active def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: - """Create standardized transcript event.""" + """Create standardized transcript event. + + Args: + text: The transcript text + role: The role (will be normalized to lowercase) + is_final: Whether this is the final transcript + """ + # Normalize role to lowercase and ensure it's either "user" or "assistant" + normalized_role = role.lower() if isinstance(role, str) else "assistant" + if normalized_role not in ["user", "assistant"]: + normalized_role = "assistant" + return BidiTranscriptStreamEvent( delta={"text": text}, text=text, - role="user" if role == "user" else "assistant", + role=normalized_role, is_final=is_final, current_transcript=text if is_final else None ) @@ -326,20 +337,23 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven # Assistant text output events - combine multiple similar events elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: - return [self._create_text_event(openai_event["delta"], "assistant")] + role = openai_event.get("role", "assistant") + return [self._create_text_event(openai_event["delta"], role.lower() if isinstance(role, str) else "assistant")] # User transcription events - combine multiple similar events elif event_type in ["conversation.item.input_audio_transcription.delta", "conversation.item.input_audio_transcription.completed"]: text_key = "delta" if "delta" in event_type else "transcript" text = openai_event.get(text_key, "") + role = openai_event.get("role", "user") is_final = "completed" in event_type - return [self._create_text_event(text, "user", is_final=is_final)] if text.strip() else None + return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) text = segment_data.get("text", "") - return [self._create_text_event(text, "user")] if text.strip() else None + role = segment_data.get("role", "user") + return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.failed": error_info = openai_event.get("error", {}) From 3a9f944fa73ff97589343591ef2d1b34b0cbd141 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 00:09:05 +0300 Subject: [PATCH 083/242] refactor: rename to bidi input and output events --- .../experimental/bidirectional_streaming/__init__.py | 8 ++++---- .../bidirectional_streaming/agent/agent.py | 6 +++--- .../models/bidirectional_model.py | 10 +++++----- .../bidirectional_streaming/models/gemini_live.py | 4 ++-- .../bidirectional_streaming/models/novasonic.py | 8 ++++---- .../bidirectional_streaming/models/openai.py | 10 +++++----- .../bidirectional_streaming/types/__init__.py | 8 ++++---- .../bidirectional_streaming/types/events.py | 6 +++--- 8 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 476b8e397..2dd38e172 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -22,11 +22,11 @@ BidiConnectionStartEvent, BidiErrorEvent, BidiImageInputEvent, - InputEvent, + BidiInputEvent, BidiInterruptionEvent, ModalityUsage, BidiUsageEvent, - OutputEvent, + BidiOutputEvent, BidiResponseCompleteEvent, BidiResponseStartEvent, BidiTextInputEvent, @@ -54,7 +54,7 @@ "BidiTextInputEvent", "BidiAudioInputEvent", "BidiImageInputEvent", - "InputEvent", + "BidiInputEvent", # Output Event types "BidiConnectionStartEvent", @@ -67,7 +67,7 @@ "BidiUsageEvent", "ModalityUsage", "BidiErrorEvent", - "OutputEvent", + "BidiOutputEvent", # Tool Event types (reused from standard agent) "ToolUseStreamEvent", diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 8a15ee20e..7fb2e21c1 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -34,7 +34,7 @@ ) from ..models.bidirectional_model import BidiModel from ..models.novasonic import BidiNovaSonicModel -from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, InputEvent, OutputEvent +from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent from ..types import BidiIO from ....experimental.tools import ToolProvider @@ -325,10 +325,10 @@ async def send(self, input_data: BidirectionalInput) -> None: # If we get here, input type is invalid raise ValueError( - f"Input must be a string, InputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" + f"Input must be a string, BidiInputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" ) - async def receive(self) -> AsyncIterable[OutputEvent]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive events from the model including audio, text, and tool calls. Yields model output events processed by background tasks including audio output, diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index c08891646..d3c3aa7ec 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -21,8 +21,8 @@ from ..types.events import ( BidiAudioInputEvent, BidiImageInputEvent, - InputEvent, - OutputEvent, + BidiInputEvent, + BidiOutputEvent, BidiTextInputEvent, ) @@ -67,7 +67,7 @@ async def stop(self) -> None: """ ... - async def receive(self) -> AsyncIterable[OutputEvent]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive streaming events from the model. Continuously yields events from the model as they arrive over the connection. @@ -77,14 +77,14 @@ async def receive(self) -> AsyncIterable[OutputEvent]: The stream continues until the connection is closed or an error occurs. Yields: - OutputEvent: Standardized event objects containing audio output, + BidiOutputEvent: Standardized event objects containing audio output, transcripts, tool calls, or control signals. """ ... async def send( self, - content: InputEvent | ToolResultEvent, + content: BidiInputEvent | ToolResultEvent, ) -> None: """Send content to the model over the active connection. diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 17c49f7de..5d1763bcd 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -31,7 +31,7 @@ BidiConnectionStartEvent, BidiErrorEvent, BidiImageInputEvent, - InputEvent, + BidiInputEvent, BidiInterruptionEvent, BidiUsageEvent, BidiTextInputEvent, @@ -334,7 +334,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic async def send( self, - content: InputEvent | ToolResultEvent, + content: BidiInputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given inputs to Google Live API diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index b7e8d05b8..06b267270 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -40,10 +40,10 @@ BidiConnectionStartEvent, BidiErrorEvent, BidiImageInputEvent, - InputEvent, + BidiInputEvent, BidiInterruptionEvent, BidiUsageEvent, - OutputEvent, + BidiOutputEvent, BidiTextInputEvent, BidiTranscriptStreamEvent, BidiResponseCompleteEvent, @@ -308,7 +308,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: async def send( self, - content: InputEvent | ToolResultEvent, + content: BidiInputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. @@ -513,7 +513,7 @@ async def stop(self) -> None: finally: logger.debug("Nova connection closed") - def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: + def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | None: """Convert Nova Sonic events to TypedEvent format.""" # Handle completion start - track completionId if "completionStart" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index e21866ccc..d3d36bf6d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -25,10 +25,10 @@ BidiConnectionStartEvent, BidiErrorEvent, BidiImageInputEvent, - InputEvent, + BidiInputEvent, BidiInterruptionEvent, BidiUsageEvent, - OutputEvent, + BidiOutputEvent, BidiTextInputEvent, BidiTranscriptStreamEvent, BidiResponseCompleteEvent, @@ -291,7 +291,7 @@ async def _process_responses(self) -> None: self._active = False logger.debug("OpenAI Realtime response processor stopped") - async def receive(self) -> AsyncIterable[OutputEvent]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive OpenAI events and convert to Strands TypedEvent format.""" # Emit connection start event yield BidiConnectionStartEvent( @@ -315,7 +315,7 @@ async def receive(self) -> AsyncIterable[OutputEvent]: # Emit connection close event yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") - def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEvent] | None: + def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" event_type = openai_event.get("type") @@ -526,7 +526,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven async def send( self, - content: InputEvent | ToolResultEvent, + content: BidiInputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given content to OpenAI. diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index bcc447b9b..704104c3c 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -14,11 +14,11 @@ BidiConnectionStartEvent, BidiErrorEvent, BidiImageInputEvent, - InputEvent, + BidiInputEvent, BidiInterruptionEvent, ModalityUsage, BidiUsageEvent, - OutputEvent, + BidiOutputEvent, BidiResponseCompleteEvent, BidiResponseStartEvent, BidiTextInputEvent, @@ -31,7 +31,7 @@ "BidiTextInputEvent", "BidiAudioInputEvent", "BidiImageInputEvent", - "InputEvent", + "BidiInputEvent", # Output Events "BidiConnectionStartEvent", "BidiConnectionCloseEvent", @@ -43,7 +43,7 @@ "BidiUsageEvent", "ModalityUsage", "BidiErrorEvent", - "OutputEvent", + "BidiOutputEvent", # Constants "SUPPORTED_AUDIO_FORMATS", "SUPPORTED_SAMPLE_RATES", diff --git a/src/strands/experimental/bidirectional_streaming/types/events.py b/src/strands/experimental/bidirectional_streaming/types/events.py index 5275d8e58..595ed9f38 100644 --- a/src/strands/experimental/bidirectional_streaming/types/events.py +++ b/src/strands/experimental/bidirectional_streaming/types/events.py @@ -504,11 +504,11 @@ def details(self) -> Optional[Dict[str, Any]]: # ============================================================================ # Note: ToolResultEvent is imported from strands.types._events and used alongside -# InputEvent in send() methods for sending tool results back to the model. +# BidiInputEvent in send() methods for sending tool results back to the model. -InputEvent = BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent +BidiInputEvent = BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent -OutputEvent = ( +BidiOutputEvent = ( BidiConnectionStartEvent | BidiResponseStartEvent | BidiAudioStreamEvent From 5eca8f98eeea68f0a6b4e0e38a09fe5204b34d82 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 00:57:23 +0300 Subject: [PATCH 084/242] refactor: change event type prefix to bidi --- .../bidirectional_streaming/agent/agent.py | 6 ++--- .../event_loop/bidirectional_event_loop.py | 6 ++--- .../bidirectional_streaming/types/events.py | 24 +++++++++---------- .../models/test_gemini_live.py | 5 ++++ .../models/test_novasonic.py | 14 +++++------ .../models/test_openai_realtime.py | 14 +++++------ .../types/test_events.py | 24 +++++++++---------- .../bidirectional_streaming/context.py | 8 +++---- .../generators/audio.py | 2 +- 9 files changed, 54 insertions(+), 49 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 7fb2e21c1..0730a044f 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -302,16 +302,16 @@ async def send(self, input_data: BidirectionalInput) -> None: # Handle plain dict - reconstruct TypedEvent for WebSocket integration if isinstance(input_data, dict) and "type" in input_data: event_type = input_data["type"] - if event_type == "bidirectional_text_input": + if event_type == "bidi_text_input": input_data = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) - elif event_type == "bidirectional_audio_input": + elif event_type == "bidi_audio_input": input_data = BidiAudioInputEvent( audio=input_data["audio"], format=input_data["format"], sample_rate=input_data["sample_rate"], channels=input_data["channels"] ) - elif event_type == "bidirectional_image_input": + elif event_type == "bidi_image_input": input_data = BidiImageInputEvent( image=input_data["image"], mime_type=input_data["mime_type"] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 4ca763fc5..8799f14ec 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -225,7 +225,7 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: event = session.agent._output_queue.get_nowait() # Check for audio events event_type = event.get("type", "") - if event_type == "bidirectional_audio_stream": + if event_type == "bidi_audio_stream": audio_cleared += 1 else: # Keep non-audio events @@ -273,7 +273,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: event_type = strands_event.get("type", "") # Handle interruption detection - if event_type == "bidirectional_interruption": + if event_type == "bidi_interruption": logger.debug("Interruption forwarded") await _handle_interruption(session) # Forward interruption event to agent for application-level handling @@ -296,7 +296,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: await session.agent._output_queue.put(strands_event) # Update Agent conversation history for user transcripts - if event_type == "bidirectional_transcript_stream": + if event_type == "bidi_transcript_stream": role = strands_event.get("role") text = strands_event.get("text", "") if role == "user" and text.strip(): diff --git a/src/strands/experimental/bidirectional_streaming/types/events.py b/src/strands/experimental/bidirectional_streaming/types/events.py index 595ed9f38..852950f5a 100644 --- a/src/strands/experimental/bidirectional_streaming/types/events.py +++ b/src/strands/experimental/bidirectional_streaming/types/events.py @@ -51,7 +51,7 @@ class BidiTextInputEvent(TypedEvent): def __init__(self, text: str, role: str): super().__init__( { - "type": "bidirectional_text_input", + "type": "bidi_text_input", "text": text, "role": role, } @@ -87,7 +87,7 @@ def __init__( ): super().__init__( { - "type": "bidirectional_audio_input", + "type": "bidi_audio_input", "audio": audio, "format": format, "sample_rate": sample_rate, @@ -129,7 +129,7 @@ def __init__( ): super().__init__( { - "type": "bidirectional_image_input", + "type": "bidi_image_input", "image": image, "mime_type": mime_type, } @@ -160,7 +160,7 @@ class BidiConnectionStartEvent(TypedEvent): def __init__(self, connection_id: str, model: str): super().__init__( { - "type": "bidirectional_connection_start", + "type": "bidi_connection_start", "connection_id": connection_id, "model": model, } @@ -183,7 +183,7 @@ class BidiResponseStartEvent(TypedEvent): """ def __init__(self, response_id: str): - super().__init__({"type": "bidirectional_response_start", "response_id": response_id}) + super().__init__({"type": "bidi_response_start", "response_id": response_id}) @property def response_id(self) -> str: @@ -209,7 +209,7 @@ def __init__( ): super().__init__( { - "type": "bidirectional_audio_stream", + "type": "bidi_audio_stream", "audio": audio, "format": format, "sample_rate": sample_rate, @@ -258,7 +258,7 @@ def __init__( ): super().__init__( { - "type": "bidirectional_transcript_stream", + "type": "bidi_transcript_stream", "delta": delta, "text": text, "role": role, @@ -299,7 +299,7 @@ class BidiInterruptionEvent(TypedEvent): def __init__(self, reason: Literal["user_speech", "error"]): super().__init__( { - "type": "bidirectional_interruption", + "type": "bidi_interruption", "reason": reason, } ) @@ -324,7 +324,7 @@ def __init__( ): super().__init__( { - "type": "bidirectional_response_complete", + "type": "bidi_response_complete", "response_id": response_id, "stop_reason": stop_reason, } @@ -378,7 +378,7 @@ def __init__( cache_write_input_tokens: Optional[int] = None, ): data: Dict[str, Any] = { - "type": "bidirectional_usage", + "type": "bidi_usage", "inputTokens": input_tokens, "outputTokens": output_tokens, "totalTokens": total_tokens, @@ -431,7 +431,7 @@ def __init__( ): super().__init__( { - "type": "bidirectional_connection_close", + "type": "bidi_connection_close", "connection_id": connection_id, "reason": reason, } @@ -466,7 +466,7 @@ def __init__( # Store serializable data in dict (for JSON serialization) super().__init__( { - "type": "bidirectional_error", + "type": "bidi_error", "message": str(error), "code": type(error).__name__, "details": details, diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index f2821a7a0..4e868ddae 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -276,8 +276,10 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): # Verify connection start and end assert len(events) >= 2 assert isinstance(events[0], BidiConnectionStartEvent) + assert events[0].get("type") == "bidi_connection_start" assert events[0].connection_id == model.connection_id assert isinstance(events[-1], BidiConnectionCloseEvent) + assert events[-1].get("type") == "bidi_connection_close" @pytest.mark.asyncio @@ -301,6 +303,7 @@ async def test_event_conversion(mock_genai_client, model): text_event = model._convert_gemini_live_event(mock_text) assert isinstance(text_event, BidiTranscriptStreamEvent) + assert text_event.get("type") == "bidi_transcript_stream" assert text_event.text == "Hello from Gemini" assert text_event.role == "assistant" assert text_event.is_final is True @@ -317,6 +320,7 @@ async def test_event_conversion(mock_genai_client, model): audio_event = model._convert_gemini_live_event(mock_audio) assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.get("type") == "bidi_audio_stream" # Audio is now base64 encoded expected_b64 = base64.b64encode(b"audio_data").decode('utf-8') assert audio_event.audio == expected_b64 @@ -358,6 +362,7 @@ async def test_event_conversion(mock_genai_client, model): interrupt_event = model._convert_gemini_live_event(mock_interrupt) assert isinstance(interrupt_event, BidiInterruptionEvent) + assert interrupt_event.get("type") == "bidi_interruption" assert interrupt_event.reason == "user_speech" await model.stop() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 9f0d35f88..2b2b54080 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -228,9 +228,9 @@ async def mock_wait_for(*args, **kwargs): # Should have session start and end (new TypedEvent format) assert len(events) >= 2 - assert events[0].get("type") == "bidirectional_connection_start" + assert events[0].get("type") == "bidi_connection_start" assert events[0].get("connection_id") == nova_model.connection_id - assert events[-1].get("type") == "bidirectional_connection_close" + assert events[-1].get("type") == "bidi_connection_close" @pytest.mark.asyncio @@ -244,7 +244,7 @@ async def test_event_conversion(nova_model): result = nova_model._convert_nova_event(nova_event) assert result is not None assert isinstance(result, BidiAudioStreamEvent) - assert result.get("type") == "bidirectional_audio_stream" + assert result.get("type") == "bidi_audio_stream" # Audio is kept as base64 string assert result.get("audio") == audio_base64 assert result.get("format") == "pcm" @@ -256,7 +256,7 @@ async def test_event_conversion(nova_model): result = nova_model._convert_nova_event(nova_event) assert result is not None assert isinstance(result, BidiTranscriptStreamEvent) - assert result.get("type") == "bidirectional_transcript_stream" + assert result.get("type") == "bidi_transcript_stream" assert result.get("text") == "Hello, world!" assert result.get("role") == "assistant" assert result.delta == {"text": "Hello, world!"} @@ -287,7 +287,7 @@ async def test_event_conversion(nova_model): result = nova_model._convert_nova_event(nova_event) assert result is not None assert isinstance(result, BidiInterruptionEvent) - assert result.get("type") == "bidirectional_interruption" + assert result.get("type") == "bidi_interruption" assert result.get("reason") == "user_speech" # Test usage metrics (now returns BidiUsageEvent) @@ -309,7 +309,7 @@ async def test_event_conversion(nova_model): result = nova_model._convert_nova_event(nova_event) assert result is not None assert isinstance(result, BidiUsageEvent) - assert result.get("type") == "bidirectional_usage" + assert result.get("type") == "bidi_usage" assert result.get("totalTokens") == 100 assert result.get("inputTokens") == 40 assert result.get("outputTokens") == 60 @@ -320,7 +320,7 @@ async def test_event_conversion(nova_model): result = nova_model._convert_nova_event(nova_event) assert result is not None assert isinstance(result, BidiResponseStartEvent) - assert result.get("type") == "bidirectional_response_start" + assert result.get("type") == "bidi_response_start" assert nova_model._current_role == "USER" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 36aaf9242..48ccf336f 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -317,7 +317,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): first_event = await anext(receive_gen) # First event should be connection start (new TypedEvent format) - assert first_event.get("type") == "bidirectional_connection_start" + assert first_event.get("type") == "bidi_connection_start" assert first_event.get("connection_id") == model.connection_id assert first_event.get("model") == model.model @@ -333,7 +333,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): pass # Last event should be connection close (new TypedEvent format) - assert events[-1].get("type") == "bidirectional_connection_close" + assert events[-1].get("type") == "bidi_connection_close" @pytest.mark.asyncio @@ -352,7 +352,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert isinstance(converted, list) assert len(converted) == 1 assert isinstance(converted[0], BidiAudioStreamEvent) - assert converted[0].get("type") == "bidirectional_audio_stream" + assert converted[0].get("type") == "bidi_audio_stream" assert converted[0].get("audio") == base64.b64encode(b"audio_data").decode() assert converted[0].get("format") == "pcm" @@ -366,7 +366,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert isinstance(converted, list) assert len(converted) == 1 assert isinstance(converted[0], BidiTranscriptStreamEvent) - assert converted[0].get("type") == "bidirectional_transcript_stream" + assert converted[0].get("type") == "bidi_transcript_stream" assert converted[0].get("text") == "Hello from OpenAI" assert converted[0].get("role") == "assistant" assert converted[0].delta == {"text": "Hello from OpenAI"} @@ -415,7 +415,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert isinstance(converted, list) assert len(converted) == 1 assert isinstance(converted[0], BidiInterruptionEvent) - assert converted[0].get("type") == "bidirectional_interruption" + assert converted[0].get("type") == "bidi_interruption" assert converted[0].get("reason") == "user_speech" await model.stop() @@ -468,7 +468,7 @@ def test_helper_methods(model): from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent text_event = model._create_text_event("Hello", "user") assert isinstance(text_event, BidiTranscriptStreamEvent) - assert text_event.get("type") == "bidirectional_transcript_stream" + assert text_event.get("type") == "bidi_transcript_stream" assert text_event.get("text") == "Hello" assert text_event.get("role") == "user" assert text_event.delta == {"text": "Hello"} @@ -479,7 +479,7 @@ def test_helper_methods(model): from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent voice_event = model._create_voice_activity_event("speech_started") assert isinstance(voice_event, BidiInterruptionEvent) - assert voice_event.get("type") == "bidirectional_interruption" + assert voice_event.get("type") == "bidi_interruption" assert voice_event.get("reason") == "user_speech" # Other voice activities return None diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_events.py b/tests/strands/experimental/bidirectional_streaming/types/test_events.py index fd7639cc4..bc7ec4844 100644 --- a/tests/strands/experimental/bidirectional_streaming/types/test_events.py +++ b/tests/strands/experimental/bidirectional_streaming/types/test_events.py @@ -28,7 +28,7 @@ "event_class,kwargs,expected_type", [ # Input events - (BidiTextInputEvent, {"text": "Hello", "role": "user"}, "bidirectional_text_input"), + (BidiTextInputEvent, {"text": "Hello", "role": "user"}, "bidi_text_input"), ( BidiAudioInputEvent, { @@ -37,20 +37,20 @@ "sample_rate": 16000, "channels": 1, }, - "bidirectional_audio_input", + "bidi_audio_input", ), ( BidiImageInputEvent, {"image": base64.b64encode(b"image").decode("utf-8"), "mime_type": "image/jpeg"}, - "bidirectional_image_input", + "bidi_image_input", ), # Output events ( BidiConnectionStartEvent, {"connection_id": "c1", "model": "m1"}, - "bidirectional_connection_start", + "bidi_connection_start", ), - (BidiResponseStartEvent, {"response_id": "r1"}, "bidirectional_response_start"), + (BidiResponseStartEvent, {"response_id": "r1"}, "bidi_response_start"), ( BidiAudioStreamEvent, { @@ -59,7 +59,7 @@ "sample_rate": 24000, "channels": 1, }, - "bidirectional_audio_stream", + "bidi_audio_stream", ), ( BidiTranscriptStreamEvent, @@ -70,25 +70,25 @@ "is_final": True, "current_transcript": "Hello", }, - "bidirectional_transcript_stream", + "bidi_transcript_stream", ), - (BidiInterruptionEvent, {"reason": "user_speech"}, "bidirectional_interruption"), + (BidiInterruptionEvent, {"reason": "user_speech"}, "bidi_interruption"), ( BidiResponseCompleteEvent, {"response_id": "r1", "stop_reason": "complete"}, - "bidirectional_response_complete", + "bidi_response_complete", ), ( BidiUsageEvent, {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - "bidirectional_usage", + "bidi_usage", ), ( BidiConnectionCloseEvent, {"connection_id": "c1", "reason": "complete"}, - "bidirectional_connection_close", + "bidi_connection_close", ), - (BidiErrorEvent, {"error": ValueError("test"), "details": None}, "bidirectional_error"), + (BidiErrorEvent, {"error": ValueError("test"), "details": None}, "bidi_error"), ], ) def test_event_json_serialization(event_class, kwargs, expected_type): diff --git a/tests_integ/bidirectional_streaming/context.py b/tests_integ/bidirectional_streaming/context.py index a8c784cc2..9553da699 100644 --- a/tests_integ/bidirectional_streaming/context.py +++ b/tests_integ/bidirectional_streaming/context.py @@ -231,8 +231,8 @@ def get_text_outputs(self) -> list[str]: """ texts = [] for event in self.get_events(): # Drain queue first - # Handle new TypedEvent format (bidirectional_transcript_stream) - if event.get("type") == "bidirectional_transcript_stream": + # Handle new TypedEvent format (bidi_transcript_stream) + if event.get("type") == "bidi_transcript_stream": text = event.get("text", "") if text: texts.append(text) @@ -260,8 +260,8 @@ def get_audio_outputs(self) -> list[bytes]: events = self.get_events() audio_data = [] for event in events: - # Handle new TypedEvent format (bidirectional_audio_stream) - if event.get("type") == "bidirectional_audio_stream": + # Handle new TypedEvent format (bidi_audio_stream) + if event.get("type") == "bidi_audio_stream": audio_b64 = event.get("audio") if audio_b64: # Decode base64 to bytes diff --git a/tests_integ/bidirectional_streaming/generators/audio.py b/tests_integ/bidirectional_streaming/generators/audio.py index 0af0f9949..ab90e304a 100644 --- a/tests_integ/bidirectional_streaming/generators/audio.py +++ b/tests_integ/bidirectional_streaming/generators/audio.py @@ -126,7 +126,7 @@ def create_audio_input_event( audio_b64 = base64.b64encode(audio_data).decode('utf-8') return { - "type": "bidirectional_audio_input", + "type": "bidi_audio_input", "audio": audio_b64, "format": format, "sample_rate": sample_rate, From 328b5dd219c508deae66ce2c4617c2e49476efcb Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 01:39:11 +0300 Subject: [PATCH 085/242] add timeout to agent.receive and fix integ tests --- .../bidirectional_streaming/agent/agent.py | 4 ++- .../models/test_gemini_live.py | 21 +++++-------- .../models/test_novasonic.py | 31 +++++++------------ .../models/test_openai_realtime.py | 11 +++---- .../bidirectional_streaming/context.py | 14 +++++---- .../generators/audio.py | 3 +- 6 files changed, 34 insertions(+), 50 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 0730a044f..a6551fb03 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -339,9 +339,11 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: """ while self.active: try: - event = await self._output_queue.get() + # Use a timeout to periodically check if we should stop + event = await asyncio.wait_for(self._output_queue.get(), timeout=0.5) yield event except asyncio.TimeoutError: + # Timeout allows us to check self.active periodically continue async def stop(self) -> None: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 4e868ddae..5e8e7a80d 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -7,6 +7,8 @@ - Event receiving and conversion """ +import base64 +import json import unittest.mock import pytest @@ -16,8 +18,13 @@ from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidirectional_streaming.types.events import ( BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, BidiImageInputEvent, + BidiInterruptionEvent, BidiTextInputEvent, + BidiTranscriptStreamEvent, ) from strands.types._events import ToolResultEvent from strands.types.tools import ToolResult @@ -198,7 +205,6 @@ async def test_send_all_content_types(mock_genai_client, model): assert content.parts[0].text == "Hello" # Test audio input (base64 encoded) - import base64 audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') audio_input = BidiAudioInputEvent( audio=audio_b64, @@ -219,7 +225,6 @@ async def test_send_all_content_types(mock_genai_client, model): mock_live_session.send.assert_called_once() # Test tool result - from strands.types._events import ToolResultEvent tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", @@ -255,11 +260,6 @@ async def test_send_edge_cases(mock_genai_client, model): @pytest.mark.asyncio async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): """Test that receive() emits connection start and end events.""" - from strands.experimental.bidirectional_streaming.types.events import ( - BidiConnectionStartEvent, - BidiConnectionCloseEvent, - ) - _, mock_live_session, _ = mock_genai_client mock_live_session.receive.return_value = agenerator([]) @@ -285,12 +285,6 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): @pytest.mark.asyncio async def test_event_conversion(mock_genai_client, model): """Test conversion of all Gemini Live event types to standard format.""" - from strands.experimental.bidirectional_streaming.types.events import ( - BidiTranscriptStreamEvent, - BidiAudioStreamEvent, - BidiInterruptionEvent, - ) - _, _, _ = mock_genai_client await model.start() @@ -311,7 +305,6 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.current_transcript == "Hello from Gemini" # Test audio output (base64 encoded) - import base64 mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 2b2b54080..c79e1d673 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -15,6 +15,17 @@ from strands.experimental.bidirectional_streaming.models.novasonic import ( BidiNovaSonicModel, ) +from strands.experimental.bidirectional_streaming.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, +) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolResult @@ -131,12 +142,6 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model @pytest.mark.asyncio async def test_send_all_content_types(nova_model, mock_client, mock_stream): """Test sending all content types through unified send() method.""" - from strands.experimental.bidirectional_streaming.types.events import ( - BidiTextInputEvent, - BidiAudioInputEvent, - ) - from strands.types._events import ToolResultEvent - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model.client = mock_client @@ -177,11 +182,6 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): """Test send() edge cases and error handling.""" - from strands.experimental.bidirectional_streaming.types.events import ( - BidiTextInputEvent, - BidiImageInputEvent, - ) - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model.client = mock_client @@ -191,7 +191,6 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): # Test image content (not supported, base64 encoded, no encoding parameter) await nova_model.start() - import base64 image_b64 = base64.b64encode(b"image data").decode('utf-8') image_event = BidiImageInputEvent( image=image_b64, @@ -237,7 +236,6 @@ async def mock_wait_for(*args, **kwargs): async def test_event_conversion(nova_model): """Test conversion of all Nova Sonic event types to standard format.""" # Test audio output (now returns BidiAudioStreamEvent) - from strands.experimental.bidirectional_streaming.types.events import BidiAudioStreamEvent audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} @@ -251,7 +249,6 @@ async def test_event_conversion(nova_model): assert result.get("sample_rate") == 24000 # Test text output (now returns BidiTranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} result = nova_model._convert_nova_event(nova_event) assert result is not None @@ -282,7 +279,6 @@ async def test_event_conversion(nova_model): assert tool_use["input"] == tool_input # Test interruption (now returns BidiInterruptionEvent) - from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent nova_event = {"stopReason": "INTERRUPTED"} result = nova_model._convert_nova_event(nova_event) assert result is not None @@ -291,7 +287,6 @@ async def test_event_conversion(nova_model): assert result.get("reason") == "user_speech" # Test usage metrics (now returns BidiUsageEvent) - from strands.experimental.bidirectional_streaming.types.events import BidiUsageEvent nova_event = { "usageEvent": { "totalTokens": 100, @@ -315,7 +310,6 @@ async def test_event_conversion(nova_model): assert result.get("outputTokens") == 60 # Test content start tracks role and emits BidiResponseStartEvent - from strands.experimental.bidirectional_streaming.types.events import BidiResponseStartEvent nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) assert result is not None @@ -349,8 +343,6 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" - from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model.client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing @@ -358,7 +350,6 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): await nova_model.start() # Send audio to start connection (base64 encoded) - import base64 audio_b64 = base64.b64encode(b"audio data").decode('utf-8') audio_event = BidiAudioInputEvent( audio=audio_b64, diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 48ccf336f..4873262ff 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -18,9 +18,13 @@ from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel from strands.experimental.bidirectional_streaming.types.events import ( BidiAudioInputEvent, + BidiAudioStreamEvent, BidiImageInputEvent, + BidiInterruptionEvent, BidiTextInputEvent, + BidiTranscriptStreamEvent, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolResult @@ -222,8 +226,6 @@ async def async_connect(*args, **kwargs): @pytest.mark.asyncio async def test_send_all_content_types(mock_websockets_connect, model): """Test sending all content types through unified send() method.""" - from strands.types._events import ToolResultEvent - _, mock_ws = mock_websockets_connect await model.start() @@ -343,7 +345,6 @@ async def test_event_conversion(mock_websockets_connect, model): await model.start() # Test audio output (now returns list with BidiAudioStreamEvent) - from strands.experimental.bidirectional_streaming.types.events import BidiAudioStreamEvent audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() @@ -357,7 +358,6 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("format") == "pcm" # Test text output (now returns list with BidiTranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" @@ -407,7 +407,6 @@ async def test_event_conversion(mock_websockets_connect, model): assert tool_use["input"]["expression"] == "2+2" # Test voice activity (now returns list with BidiInterruptionEvent for speech_started) - from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent speech_started = { "type": "input_audio_buffer.speech_started" } @@ -465,7 +464,6 @@ def test_helper_methods(model): model._active = False # Test _create_text_event (now returns BidiTranscriptStreamEvent) - from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent text_event = model._create_text_event("Hello", "user") assert isinstance(text_event, BidiTranscriptStreamEvent) assert text_event.get("type") == "bidi_transcript_stream" @@ -476,7 +474,6 @@ def test_helper_methods(model): assert text_event.current_transcript == "Hello" # Test _create_voice_activity_event (now returns BidiInterruptionEvent for speech_started) - from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent voice_event = model._create_voice_activity_event("speech_started") assert isinstance(voice_event, BidiInterruptionEvent) assert voice_event.get("type") == "bidi_interruption" diff --git a/tests_integ/bidirectional_streaming/context.py b/tests_integ/bidirectional_streaming/context.py index 9553da699..349ad0cb9 100644 --- a/tests_integ/bidirectional_streaming/context.py +++ b/tests_integ/bidirectional_streaming/context.py @@ -5,6 +5,7 @@ """ import asyncio +import base64 import logging import time from typing import TYPE_CHECKING @@ -81,12 +82,13 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): """Stop context manager, cleanup threads, and end agent session.""" - await self.stop() - - # End agent session + # End agent session FIRST - this will cause receive() to exit cleanly if self.agent._agent_loop and self.agent._agent_loop.active: await self.agent.stop() - logger.debug("Agent session ended") + logger.debug("Agent session stopped") + + # Then stop the context threads + await self.stop() return False @@ -254,8 +256,6 @@ def get_audio_outputs(self) -> list[bytes]: Returns: List of audio data bytes. """ - import base64 - # Drain queue first to get latest events events = self.get_events() audio_data = [] @@ -332,6 +332,7 @@ async def _input_thread(self): except asyncio.CancelledError: logger.debug("Input thread cancelled") + raise # Re-raise to properly propagate cancellation except Exception as e: logger.error(f"Input thread error: {e}", exc_info=True) finally: @@ -350,6 +351,7 @@ async def _event_collection_thread(self): except asyncio.CancelledError: logger.debug("Event collection thread cancelled") + raise # Re-raise to properly propagate cancellation except Exception as e: logger.error(f"Event collection thread error: {e}") diff --git a/tests_integ/bidirectional_streaming/generators/audio.py b/tests_integ/bidirectional_streaming/generators/audio.py index ab90e304a..75c17a1e3 100644 --- a/tests_integ/bidirectional_streaming/generators/audio.py +++ b/tests_integ/bidirectional_streaming/generators/audio.py @@ -4,6 +4,7 @@ without requiring physical audio devices or pre-recorded files. """ +import base64 import hashlib import logging from pathlib import Path @@ -120,8 +121,6 @@ def create_audio_input_event( Returns: BidiAudioInputEvent dict ready for agent.send(). """ - import base64 - # Convert bytes to base64 string for JSON compatibility audio_b64 = base64.b64encode(audio_data).decode('utf-8') From f7874fe922d773ec17db6f815e22eb179d21b6b9 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 01:45:57 +0300 Subject: [PATCH 086/242] add bidi agent input type alias --- .../bidirectional_streaming/agent/agent.py | 13 ++++++------- .../bidirectional_streaming/types/__init__.py | 2 ++ .../bidirectional_streaming/types/agent.py | 10 ++++++++++ 3 files changed, 18 insertions(+), 7 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/types/agent.py diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index a6551fb03..2c3b902a8 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -34,6 +34,7 @@ ) from ..models.bidirectional_model import BidiModel from ..models.novasonic import BidiNovaSonicModel +from ..types.agent import BidiAgentInput from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent from ..types import BidiIO from ....experimental.tools import ToolProvider @@ -42,8 +43,6 @@ _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" -# Type alias for cleaner send() method signature -BidirectionalInput = str | BidiAudioInputEvent | BidiImageInputEvent class BidiAgent: @@ -256,7 +255,7 @@ async def start(self) -> None: logger.debug("Conversation start - initializing connection") self._agent_loop = await start_bidirectional_connection(self) - async def send(self, input_data: BidirectionalInput) -> None: + async def send(self, input_data: BidiAgentInput) -> None: """Send input to the model (text, audio, image, or event dict). Unified method for sending text, audio, and image input to the model during @@ -303,16 +302,16 @@ async def send(self, input_data: BidirectionalInput) -> None: if isinstance(input_data, dict) and "type" in input_data: event_type = input_data["type"] if event_type == "bidi_text_input": - input_data = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) + input_event = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) elif event_type == "bidi_audio_input": - input_data = BidiAudioInputEvent( + input_event = BidiAudioInputEvent( audio=input_data["audio"], format=input_data["format"], sample_rate=input_data["sample_rate"], channels=input_data["channels"] ) elif event_type == "bidi_image_input": - input_data = BidiImageInputEvent( + input_event = BidiImageInputEvent( image=input_data["image"], mime_type=input_data["mime_type"] ) @@ -320,7 +319,7 @@ async def send(self, input_data: BidirectionalInput) -> None: raise ValueError(f"Unknown event type: {event_type}") # Send the reconstructed TypedEvent - await self._agent_loop.model.send(input_data) + await self._agent_loop.model.send(input_event) return # If we get here, input type is invalid diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 704104c3c..34dd2685b 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,5 +1,6 @@ """Type definitions for bidirectional streaming.""" +from .agent import BidiAgentInput from .io import BidiIO from .events import ( DEFAULT_CHANNELS, @@ -27,6 +28,7 @@ __all__ = [ "BidiIO", + "BidiAgentInput", # Input Events "BidiTextInputEvent", "BidiAudioInputEvent", diff --git a/src/strands/experimental/bidirectional_streaming/types/agent.py b/src/strands/experimental/bidirectional_streaming/types/agent.py new file mode 100644 index 000000000..8d1e9aab7 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/agent.py @@ -0,0 +1,10 @@ +"""Agent-related type definitions for bidirectional streaming. + +This module defines the types used for BidiAgent. +""" + +from typing import TypeAlias + +from .events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent + +BidiAgentInput: TypeAlias = str | BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent From 330831de62c3c84d1ed225d447287e2a8174f1e1 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 01:51:49 +0300 Subject: [PATCH 087/242] use bidi input event for type check --- .../experimental/bidirectional_streaming/agent/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 2c3b902a8..a5675e857 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -292,9 +292,9 @@ async def send(self, input_data: BidiAgentInput) -> None: await self._agent_loop.model.send(text_event) return - # Handle InputEvent instances (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent) + # Handle BidiInputEvent instances # Check this before dict since TypedEvent inherits from dict - if isinstance(input_data, (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent)): + if isinstance(input_data, BidiInputEvent): await self._agent_loop.model.send(input_data) return From c06ec6f930a812fd879fa28f45db8a757ad8c1c5 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 13:34:48 +0300 Subject: [PATCH 088/242] remove text logging --- .../experimental/bidirectional_streaming/models/gemini_live.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 6be0275f8..47b38e0eb 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -278,7 +278,6 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic for part in model_turn.parts: # Log all part types for debugging part_attrs = {attr: getattr(part, attr, None) for attr in dir(part) if not attr.startswith('_')} - logger.debug(f"Model turn part attributes: {part_attrs}") # Check if part has text attribute and it's not empty if hasattr(part, 'text') and part.text: @@ -286,7 +285,6 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if text_parts: full_text = " ".join(text_parts) - logger.debug(f"Text output as transcript ({len(text_parts)} parts): {full_text}") return BidiTranscriptStreamEvent( delta={"text": full_text}, text=full_text, From ef292d3f8400c542fa7f4dc7b6adaf01775c97e3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 14:36:04 +0300 Subject: [PATCH 089/242] add hooks to agent for tool execution --- src/strands/experimental/bidirectional_streaming/agent/agent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index a5675e857..05037cadd 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -18,6 +18,7 @@ from typing import Any, AsyncIterable, Callable from .... import _identifier +from ....hooks.registry import HookRegistry from ....telemetry.metrics import EventLoopMetrics from ....tools.caller import _ToolCaller from ....tools.executors import ConcurrentToolExecutor @@ -122,6 +123,7 @@ def __init__( # Initialize other components self.event_loop_metrics = EventLoopMetrics() self._tool_caller = _ToolCaller(self) + self.hooks = HookRegistry() # connection management self._agent_loop: "BidirectionalConnection" | None = None From 43fed3c51a7400918565b89b7277bf0b8c46c326 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 14:36:22 +0300 Subject: [PATCH 090/242] add event types for tool related events --- src/strands/types/_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index afce36f2b..e3bbe2316 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -145,7 +145,7 @@ class ToolUseStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None: """Initialize with delta and current tool use state.""" - super().__init__({"delta": delta, "current_tool_use": current_tool_use}) + super().__init__({"type": "tool_use_stream", "delta": delta, "current_tool_use": current_tool_use}) class TextStreamEvent(ModelStreamEvent): @@ -281,7 +281,7 @@ def __init__(self, tool_result: ToolResult) -> None: Args: tool_result: Final result from the tool execution """ - super().__init__({"tool_result": tool_result}) + super().__init__({"type": "tool_result", "tool_result": tool_result}) @property def tool_use_id(self) -> str: @@ -309,7 +309,7 @@ def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: tool_use: The tool invocation producing the stream tool_stream_data: The yielded event from the tool execution """ - super().__init__({"tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}}) + super().__init__({"type": "tool_stream", "tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}}) @property def tool_use_id(self) -> str: From f4f7e4da0a56d307ed14637dfb53d7e587c1d957 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 14:39:45 +0300 Subject: [PATCH 091/242] fix test scripts --- .../scripts/test_bidi.py | 6 +- .../scripts/test_bidi_novasonic.py | 42 ++++++++--- .../scripts/test_bidi_openai.py | 49 +++++++++---- .../scripts/test_gemini_live.py | 73 +++++++++---------- 4 files changed, 102 insertions(+), 68 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py index 6df0063be..359f04dbf 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py @@ -6,7 +6,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel from strands.experimental.bidirectional_streaming.io.audio import AudioIO from strands_tools import calculator @@ -20,8 +20,8 @@ async def main(): adapter = AudioIO() model = BidiNovaSonicModel(region="us-east-1") - async with BidirectionalAgent(model=model, tools=[calculator]) as agent: - print("New BidirectionalAgent Experience") + async with BidiAgent(model=model, tools=[calculator]) as agent: + print("New BidiAgent Experience") print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") await agent.run(io_channels=[adapter]) diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py index 9e698b35e..42d8d436e 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py @@ -17,7 +17,7 @@ import pyaudio from strands_tools import calculator -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel @@ -130,23 +130,22 @@ async def receive(agent, context): """Receive and process events from agent.""" try: async for event in agent.receive(): - # Get event type event_type = event.get("type", "unknown") - # Handle audio stream events (bidirectional_audio_stream) - if event_type == "bidirectional_audio_stream": + # Handle audio stream events (bidi_audio_stream) + if event_type == "bidi_audio_stream": if not context.get("interrupted", False): # Decode base64 audio string to bytes for playback audio_b64 = event["audio"] audio_data = base64.b64decode(audio_b64) context["audio_out"].put_nowait(audio_data) - # Handle interruption events (bidirectional_interruption) - elif event_type == "bidirectional_interruption": + # Handle interruption events (bidi_interruption) + elif event_type == "bidi_interruption": context["interrupted"] = True - # Handle transcript events (bidirectional_transcript_stream) - elif event_type == "bidirectional_transcript_stream": + # Handle transcript events (bidi_transcript_stream) + elif event_type == "bidi_transcript_stream": text_content = event.get("text", "") role = event.get("role", "unknown") @@ -156,10 +155,29 @@ async def receive(agent, context): elif role == "assistant": print(f"Assistant: {text_content}") - # Handle turn complete events (bidirectional_turn_complete) - elif event_type == "bidirectional_turn_complete": + # Handle response complete events (bidi_response_complete) + elif event_type == "bidi_response_complete": # Reset interrupted state since the turn is complete context["interrupted"] = False + + # Handle tool use events (tool_use_stream) + elif event_type == "tool_use_stream": + tool_use = event.get("current_tool_use", {}) + tool_name = tool_use.get("name", "unknown") + tool_input = tool_use.get("input", {}) + print(f"🔧 Tool called: {tool_name} with input: {tool_input}") + + # Handle tool result events (tool_result) + elif event_type == "tool_result": + tool_result = event.get("tool_result", {}) + tool_name = tool_result.get("name", "unknown") + result_content = tool_result.get("content", []) + result_text = "" + for block in result_content: + if isinstance(block, dict) and block.get("type") == "text": + result_text = block.get("text", "") + break + print(f"✅ Tool result from {tool_name}: {result_text}") except asyncio.CancelledError: pass @@ -199,7 +217,7 @@ async def main(duration=180): # Initialize model and agent model = BidiNovaSonicModel(region="us-east-1") - agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") + agent = BidiAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() @@ -208,7 +226,7 @@ async def main(duration=180): "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), - "connection": agent._session, + "connection": agent._agent_loop, "duration": duration, "start_time": time.time(), "interrupted": False, diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py index c2d07f170..dd19e958d 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py @@ -14,7 +14,7 @@ import pyaudio from strands_tools import calculator -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel @@ -122,8 +122,8 @@ async def receive(agent, context): # Get event type event_type = event.get("type", "unknown") - # Handle audio stream events (bidirectional_audio_stream) - if event_type == "bidirectional_audio_stream": + # Handle audio stream events (bidi_audio_stream) + if event_type == "bidi_audio_stream": # Decode base64 audio string to bytes for playback audio_b64 = event["audio"] audio_data = base64.b64decode(audio_b64) @@ -131,9 +131,9 @@ async def receive(agent, context): if not context.get("interrupted", False): await context["audio_out"].put(audio_data) - # Handle transcript events (bidirectional_transcript_stream) - elif event_type == "bidirectional_transcript_stream": - source = event.get("source", "assistant") + # Handle transcript events (bidi_transcript_stream) + elif event_type == "bidi_transcript_stream": + source = event.get("role", "assistant") text = event.get("text", "").strip() if text: @@ -142,25 +142,44 @@ async def receive(agent, context): elif source == "assistant": print(f"🔊 Assistant: {text}") - # Handle interruption events (bidirectional_interruption) - elif event_type == "bidirectional_interruption": + # Handle interruption events (bidi_interruption) + elif event_type == "bidi_interruption": context["interrupted"] = True print("⚠️ Interruption detected") - # Handle session start events (bidirectional_session_start) - elif event_type == "bidirectional_session_start": + # Handle connection start events (bidi_connection_start) + elif event_type == "bidi_connection_start": print(f"✓ Session started: {event.get('model', 'unknown')}") - # Handle session end events (bidirectional_session_end) - elif event_type == "bidirectional_session_end": + # Handle connection close events (bidi_connection_close) + elif event_type == "bidi_connection_close": print(f"✓ Session ended: {event.get('reason', 'unknown')}") context["active"] = False break - # Handle turn complete events (bidirectional_turn_complete) - elif event_type == "bidirectional_turn_complete": + # Handle response complete events (bidi_response_complete) + elif event_type == "bidi_response_complete": # Reset interrupted state since the turn is complete context["interrupted"] = False + + # Handle tool use events (tool_use_stream) + elif event_type == "tool_use_stream": + tool_use = event.get("current_tool_use", {}) + tool_name = tool_use.get("name", "unknown") + tool_input = tool_use.get("input", {}) + print(f"🔧 Tool called: {tool_name} with input: {tool_input}") + + # Handle tool result events (tool_result) + elif event_type == "tool_result": + tool_result = event.get("tool_result", {}) + tool_name = tool_result.get("name", "unknown") + result_content = tool_result.get("content", []) + result_text = "" + for block in result_content: + if isinstance(block, dict) and block.get("type") == "text": + result_text = block.get("text", "") + break + print(f"✅ Tool result from {tool_name}: {result_text}") except asyncio.CancelledError: pass @@ -246,7 +265,7 @@ async def main(): ) # Create agent - agent = BidirectionalAgent( + agent = BidiAgent( model=model, tools=[calculator], system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py index c50ab27b1..ba0d9edf7 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py @@ -37,13 +37,11 @@ import pyaudio from strands_tools import calculator -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel -# Configure logging - debug only for Gemini Live, info for everything else +# Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') -gemini_logger.setLevel(logging.DEBUG) logger = logging.getLogger(__name__) @@ -145,58 +143,57 @@ async def receive(agent, context): """Receive and process events from agent.""" try: async for event in agent.receive(): - # Debug: Log event type and keys event_type = event.get("type", "unknown") - event_keys = list(event.keys()) - logger.debug(f"Received event type: {event_type}, keys: {event_keys}") - # Handle audio stream events (bidirectional_audio_stream) - if event_type == "bidirectional_audio_stream": + # Handle audio stream events (bidi_audio_stream) + if event_type == "bidi_audio_stream": if not context.get("interrupted", False): # Decode base64 audio string to bytes for playback audio_b64 = event["audio"] audio_data = base64.b64decode(audio_b64) context["audio_out"].put_nowait(audio_data) - logger.info(f"🔊 Audio queued for playback: {len(audio_data)} bytes") - # Handle interruption events (bidirectional_interruption) - elif event_type == "bidirectional_interruption": + # Handle interruption events (bidi_interruption) + elif event_type == "bidi_interruption": context["interrupted"] = True - logger.info("Interruption detected") + print("⚠️ Interruption detected") - # Handle transcript events (bidirectional_transcript_stream) - elif event_type == "bidirectional_transcript_stream": + # Handle transcript events (bidi_transcript_stream) + elif event_type == "bidi_transcript_stream": transcript_text = event.get("text", "") - transcript_source = event.get("source", "unknown") + transcript_role = event.get("role", "unknown") is_final = event.get("is_final", False) # Print transcripts with special formatting - if transcript_source == "user": + if transcript_role == "user": print(f"🎤 User: {transcript_text}") - elif transcript_source == "assistant": + elif transcript_role == "assistant": print(f"🔊 Assistant: {transcript_text}") - # Handle turn complete events (bidirectional_turn_complete) - elif event_type == "bidirectional_turn_complete": - logger.debug("Turn complete - model ready for next input") - # Reset interrupted state since the turn is complete + # Handle response complete events (bidi_response_complete) + elif event_type == "bidi_response_complete": + # Reset interrupted state since the response is complete context["interrupted"] = False - # Handle session start events (bidirectional_session_start) - elif event_type == "bidirectional_session_start": - logger.info(f"Session started: {event.get('model', 'unknown')}") + # Handle tool use events (tool_use_stream) + elif event_type == "tool_use_stream": + tool_use = event.get("current_tool_use", {}) + tool_name = tool_use.get("name", "unknown") + tool_input = tool_use.get("input", {}) + print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - # Handle session end events (bidirectional_session_end) - elif event_type == "bidirectional_session_end": - logger.info(f"Session ended: {event.get('reason', 'unknown')}") - - # Handle error events (bidirectional_error) - elif event_type == "bidirectional_error": - logger.error(f"Error: {event.get('error_message', 'unknown')}") - - # Handle turn start events (bidirectional_turn_start) - elif event_type == "bidirectional_turn_start": - logger.debug(f"Turn started: {event.get('response_id', 'unknown')}") + # Handle tool result events (tool_result) + elif event_type == "tool_result": + tool_result = event.get("tool_result", {}) + tool_name = tool_result.get("name", "unknown") + result_content = tool_result.get("content", []) + # Extract text from content blocks + result_text = "" + for block in result_content: + if isinstance(block, dict) and block.get("type") == "text": + result_text = block.get("text", "") + break + print(f"✅ Tool result from {tool_name}: {result_text}") except asyncio.CancelledError: pass @@ -325,7 +322,7 @@ async def main(duration=180): logger.info("Gemini Live model initialized successfully") print("Using Gemini Live model") - agent = BidirectionalAgent( + agent = BidiAgent( model=model, tools=[calculator], system_prompt="You are a helpful assistant." @@ -338,7 +335,7 @@ async def main(duration=180): "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), - "connection": agent._session, + "connection": agent._agent_loop, "duration": duration, "start_time": time.time(), "interrupted": False, From 2d01d2d7220130376ccbdca2915fd65c363763fb Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 14:59:27 +0300 Subject: [PATCH 092/242] add user transcription to openai --- .../experimental/bidirectional_streaming/models/openai.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index f7c4bb5ac..74f1942ff 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -51,6 +51,9 @@ "audio": { "input": { "format": AUDIO_FORMAT, + "transcription": { + "model": "gpt-4o-transcribe" + }, "turn_detection": { "type": "server_vad", "threshold": 0.5, From 490ce1e29a4bec20495bb339840e6b65c17bcad5 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 11 Nov 2025 18:17:40 +0300 Subject: [PATCH 093/242] fix(gemini): return multiple tool use events --- .../models/gemini_live.py | 53 +++++++++------- .../models/test_gemini_live.py | 63 +++++++++++++++++-- 2 files changed, 86 insertions(+), 30 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 47b38e0eb..9bb5bba77 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -33,6 +33,7 @@ BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, + BidiOutputEvent, BidiUsageEvent, BidiTextInputEvent, BidiTranscriptStreamEvent, @@ -173,7 +174,7 @@ async def _send_message_history(self, messages: Messages) -> None: content = genai_types.Content(role=role, parts=content_parts) await self.live_session.send_client_content(turns=content) - async def receive(self) -> AsyncIterable[Dict[str, Any]]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive Gemini Live API events and convert to provider-agnostic format.""" # Emit connection start event @@ -190,10 +191,9 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: if not self._active: break - # Convert to provider-agnostic format - provider_event = self._convert_gemini_live_event(message) - if provider_event: - yield provider_event + # Convert to provider-agnostic format (always returns list) + for event in self._convert_gemini_live_event(message): + yield event # SDK exits receive loop after turn_complete - restart automatically if self._active: @@ -211,7 +211,7 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: # Emit connection close event when exiting yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") - def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: + def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOutputEvent]: """Convert Gemini Live API events to provider-agnostic format. Handles different types of content: @@ -219,11 +219,14 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model - usageMetadata: Token usage information + + Returns: + List of event dicts (empty list if no events to emit). """ try: # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: - return BidiInterruptionEvent(reason="user_speech") + return [BidiInterruptionEvent(reason="user_speech")] # Handle input transcription (user's speech) - emit as transcript event if message.server_content and message.server_content.input_transcription: @@ -233,13 +236,13 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic transcription_text = input_transcript.text role = getattr(input_transcript, 'role', 'user') logger.debug(f"Input transcription detected: {transcription_text}") - return BidiTranscriptStreamEvent( + return [BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, role=role.lower() if isinstance(role, str) else "user", is_final=True, current_transcript=transcription_text - ) + )] # Handle output transcription (model's audio) - emit as transcript event if message.server_content and message.server_content.output_transcription: @@ -249,25 +252,25 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic transcription_text = output_transcript.text role = getattr(output_transcript, 'role', 'assistant') logger.debug(f"Output transcription detected: {transcription_text}") - return BidiTranscriptStreamEvent( + return [BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, role=role.lower() if isinstance(role, str) else "assistant", is_final=True, current_transcript=transcription_text - ) + )] # Handle audio output using SDK's built-in data property # Check this BEFORE text to avoid triggering warning on mixed content if message.data: # Convert bytes to base64 string for JSON serializability audio_b64 = base64.b64encode(message.data).decode('utf-8') - return BidiAudioStreamEvent( + return [BidiAudioStreamEvent( audio=audio_b64, format="pcm", sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, channels=GEMINI_CHANNELS - ) + )] # Handle text output from model_turn (avoids warning by checking parts directly) if message.server_content and message.server_content.model_turn: @@ -285,27 +288,29 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if text_parts: full_text = " ".join(text_parts) - return BidiTranscriptStreamEvent( + return [BidiTranscriptStreamEvent( delta={"text": full_text}, text=full_text, role="assistant", is_final=True, current_transcript=full_text - ) + )] - # Handle tool calls + # Handle tool calls - return list to support multiple tool calls if message.tool_call and message.tool_call.function_calls: + tool_events = [] for func_call in message.tool_call.function_calls: tool_use_event: ToolUse = { "toolUseId": func_call.id, "name": func_call.name, "input": func_call.args or {} } - # Return ToolUseStreamEvent for consistency with standard agent - return ToolUseStreamEvent( + # Create ToolUseStreamEvent for consistency with standard agent + tool_events.append(ToolUseStreamEvent( delta={"toolUse": tool_use_event}, current_tool_use=tool_use_event - ) + )) + return tool_events # Handle usage metadata if hasattr(message, 'usage_metadata') and message.usage_metadata: @@ -340,23 +345,23 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic "output_tokens": detail.token_count }) - return BidiUsageEvent( + return [BidiUsageEvent( input_tokens=usage.prompt_token_count or 0, output_tokens=usage.response_token_count or 0, total_tokens=usage.total_token_count or 0, modality_details=modality_details if modality_details else None, cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None - ) + )] # Silently ignore setup_complete and generation_complete messages - return None + return [] except Exception as e: logger.error("Error converting Gemini Live event: %s", e) logger.error("Message type: %s", type(message).__name__) logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) - # Return ErrorEvent instead of None so caller can handle it - return BidiErrorEvent(error=e) + # Return ErrorEvent in list so caller can handle it + return [BidiErrorEvent(error=e)] async def send( self, diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 85e416164..272314272 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -315,7 +315,10 @@ async def test_event_conversion(mock_genai_client, model): mock_text.server_content = mock_server_content - text_event = model._convert_gemini_live_event(mock_text) + text_events = model._convert_gemini_live_event(mock_text) + assert isinstance(text_events, list) + assert len(text_events) == 1 + text_event = text_events[0] assert isinstance(text_event, BidiTranscriptStreamEvent) assert text_event.get("type") == "bidi_transcript_stream" assert text_event.text == "Hello from Gemini" @@ -344,7 +347,10 @@ async def test_event_conversion(mock_genai_client, model): mock_multi_text.server_content = mock_server_content_multi - multi_text_event = model._convert_gemini_live_event(mock_multi_text) + multi_text_events = model._convert_gemini_live_event(mock_multi_text) + assert isinstance(multi_text_events, list) + assert len(multi_text_events) == 1 + multi_text_event = multi_text_events[0] assert isinstance(multi_text_event, BidiTranscriptStreamEvent) assert multi_text_event.text == "Hello from Gemini" # Concatenated with space @@ -355,7 +361,10 @@ async def test_event_conversion(mock_genai_client, model): mock_audio.tool_call = None mock_audio.server_content = None - audio_event = model._convert_gemini_live_event(mock_audio) + audio_events = model._convert_gemini_live_event(mock_audio) + assert isinstance(audio_events, list) + assert len(audio_events) == 1 + audio_event = audio_events[0] assert isinstance(audio_event, BidiAudioStreamEvent) assert audio_event.get("type") == "bidi_audio_stream" # Audio is now base64 encoded @@ -363,7 +372,7 @@ async def test_event_conversion(mock_genai_client, model): assert audio_event.audio == expected_b64 assert audio_event.format == "pcm" - # Test tool call + # Test single tool call (returns list with one event) mock_func_call = unittest.mock.Mock() mock_func_call.id = "tool-123" mock_func_call.name = "calculator" @@ -378,13 +387,52 @@ async def test_event_conversion(mock_genai_client, model): mock_tool.tool_call = mock_tool_call mock_tool.server_content = None - tool_event = model._convert_gemini_live_event(mock_tool) + tool_events = model._convert_gemini_live_event(mock_tool) + # Should return a list of ToolUseStreamEvent + assert isinstance(tool_events, list) + assert len(tool_events) == 1 + tool_event = tool_events[0] # ToolUseStreamEvent has delta and current_tool_use, not a "type" field assert "delta" in tool_event assert "toolUse" in tool_event["delta"] assert tool_event["delta"]["toolUse"]["toolUseId"] == "tool-123" assert tool_event["delta"]["toolUse"]["name"] == "calculator" + # Test multiple tool calls (returns list with multiple events) + mock_func_call_1 = unittest.mock.Mock() + mock_func_call_1.id = "tool-123" + mock_func_call_1.name = "calculator" + mock_func_call_1.args = {"expression": "2+2"} + + mock_func_call_2 = unittest.mock.Mock() + mock_func_call_2.id = "tool-456" + mock_func_call_2.name = "weather" + mock_func_call_2.args = {"location": "Seattle"} + + mock_tool_call_multi = unittest.mock.Mock() + mock_tool_call_multi.function_calls = [mock_func_call_1, mock_func_call_2] + + mock_tool_multi = unittest.mock.Mock() + mock_tool_multi.text = None + mock_tool_multi.data = None + mock_tool_multi.tool_call = mock_tool_call_multi + mock_tool_multi.server_content = None + + tool_events_multi = model._convert_gemini_live_event(mock_tool_multi) + # Should return a list with two ToolUseStreamEvent + assert isinstance(tool_events_multi, list) + assert len(tool_events_multi) == 2 + + # Verify first tool call + assert tool_events_multi[0]["delta"]["toolUse"]["toolUseId"] == "tool-123" + assert tool_events_multi[0]["delta"]["toolUse"]["name"] == "calculator" + assert tool_events_multi[0]["delta"]["toolUse"]["input"] == {"expression": "2+2"} + + # Verify second tool call + assert tool_events_multi[1]["delta"]["toolUse"]["toolUseId"] == "tool-456" + assert tool_events_multi[1]["delta"]["toolUse"]["name"] == "weather" + assert tool_events_multi[1]["delta"]["toolUse"]["input"] == {"location": "Seattle"} + # Test interruption mock_server_content = unittest.mock.Mock() mock_server_content.interrupted = True @@ -397,7 +445,10 @@ async def test_event_conversion(mock_genai_client, model): mock_interrupt.tool_call = None mock_interrupt.server_content = mock_server_content - interrupt_event = model._convert_gemini_live_event(mock_interrupt) + interrupt_events = model._convert_gemini_live_event(mock_interrupt) + assert isinstance(interrupt_events, list) + assert len(interrupt_events) == 1 + interrupt_event = interrupt_events[0] assert isinstance(interrupt_event, BidiInterruptionEvent) assert interrupt_event.get("type") == "bidi_interruption" assert interrupt_event.reason == "user_speech" From 4e29b5abdd0104e8971cdb7bfc417b647560a20a Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 10 Nov 2025 18:25:48 -0500 Subject: [PATCH 094/242] agent - loop - refine --- .../bidirectional_streaming/agent/agent.py | 61 +-- .../bidirectional_streaming/agent/loop.py | 155 ++++++ .../event_loop/__init__.py | 15 - .../event_loop/bidirectional_event_loop.py | 490 ------------------ .../models/novasonic.py | 50 +- 5 files changed, 183 insertions(+), 588 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/agent/loop.py delete mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 05037cadd..03c18d8d0 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -18,8 +18,6 @@ from typing import Any, AsyncIterable, Callable from .... import _identifier -from ....hooks.registry import HookRegistry -from ....telemetry.metrics import EventLoopMetrics from ....tools.caller import _ToolCaller from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor @@ -28,17 +26,13 @@ from ....types.content import Message, Messages from ....types.tools import ToolResult, ToolUse, AgentTool -from ..event_loop.bidirectional_event_loop import ( - BidirectionalConnection, - start_bidirectional_connection, - stop_bidirectional_connection, -) +from .loop import BidiAgentLoop from ..models.bidirectional_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent from ..types import BidiIO -from ....experimental.tools import ToolProvider +from ...tools import ToolProvider logger = logging.getLogger(__name__) @@ -121,15 +115,12 @@ def __init__( self.tool_executor = tool_executor or ConcurrentToolExecutor() # Initialize other components - self.event_loop_metrics = EventLoopMetrics() self._tool_caller = _ToolCaller(self) - self.hooks = HookRegistry() - # connection management - self._agent_loop: "BidirectionalConnection" | None = None - self._output_queue = asyncio.Queue() self._current_adapters = [] # Track adapters for cleanup + self._loop = BidiAgentLoop(self) + @property def tool(self) -> _ToolCaller: """Call tool as a function. @@ -246,16 +237,10 @@ async def start(self) -> None: Initializes the streaming connection and starts background tasks for processing model events, tool execution, and connection management. - - Raises: - ValueError: If conversation already active. - ConnectionError: If connection creation fails. """ - if self._agent_loop and self._agent_loop.active: - raise ValueError("Conversation already active. Call end() first.") + logger.debug("starting agent") - logger.debug("Conversation start - initializing connection") - self._agent_loop = await start_bidirectional_connection(self) + await self._loop.start() async def send(self, input_data: BidiAgentInput) -> None: """Send input to the model (text, audio, image, or event dict). @@ -291,13 +276,13 @@ async def send(self, input_data: BidiAgentInput) -> None: logger.debug("Text sent: %d characters", len(input_data)) # Create BidiTextInputEvent for send() text_event = BidiTextInputEvent(text=input_data, role="user") - await self._agent_loop.model.send(text_event) + await self.model.send(text_event) return # Handle BidiInputEvent instances # Check this before dict since TypedEvent inherits from dict if isinstance(input_data, BidiInputEvent): - await self._agent_loop.model.send(input_data) + await self.model.send(input_data) return # Handle plain dict - reconstruct TypedEvent for WebSocket integration @@ -321,7 +306,7 @@ async def send(self, input_data: BidiAgentInput) -> None: raise ValueError(f"Unknown event type: {event_type}") # Send the reconstructed TypedEvent - await self._agent_loop.model.send(input_event) + await self.model.send(input_event) return # If we get here, input type is invalid @@ -336,16 +321,10 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: text responses, tool calls, and connection updates. Yields: - BidirectionalStreamEvent: Events from the model session. + Model and tool call events. """ - while self.active: - try: - # Use a timeout to periodically check if we should stop - event = await asyncio.wait_for(self._output_queue.get(), timeout=0.5) - yield event - except asyncio.TimeoutError: - # Timeout allows us to check self.active periodically - continue + async for event in self._loop.receive(): + yield event async def stop(self) -> None: """End the conversation connection and cleanup all resources. @@ -353,9 +332,7 @@ async def stop(self) -> None: Terminates the streaming connection, cancels background tasks, and closes the connection to the model provider. """ - if self._agent_loop: - await stop_bidirectional_connection(self._agent_loop) - self._agent_loop = None + await self._loop.stop() async def __aenter__(self) -> "BidiAgent": """Async context manager entry point. @@ -364,10 +341,6 @@ async def __aenter__(self) -> "BidiAgent": Returns: Self for use in the context. - - Raises: - ValueError: If connection is already active. - ConnectionError: If connection creation fails. """ logger.debug("Entering async context manager - starting connection") await self.start() @@ -415,12 +388,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: @property def active(self) -> bool: - """Check if the agent connection is currently active. - - Returns: - True if connection is active and ready for communication, False otherwise. - """ - return self._agent_loop is not None and self._agent_loop.active + """True if agent loop started, False otherwise.""" + return self._loop.active async def run(self, io_channels: list[BidiIO | tuple[Callable, Callable]]) -> None: """Run the agent using provided IO channels or transport tuples for bidirectional communication. diff --git a/src/strands/experimental/bidirectional_streaming/agent/loop.py b/src/strands/experimental/bidirectional_streaming/agent/loop.py new file mode 100644 index 000000000..5e9d4fc78 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/loop.py @@ -0,0 +1,155 @@ +"""Agent loop. + +The agent loop handles the events received from the model and executes tools when given a tool use request. +""" + +import asyncio +import logging +from typing import AsyncIterable, Awaitable, TYPE_CHECKING + +from ..types.events import BidiOutputEvent, BidiTranscriptStreamEvent +from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent +from ....types.content import Message +from ....types.tools import ToolResult, ToolUse + +if TYPE_CHECKING: + from .agent import BidiAgent + +logger = logging.getLogger(__name__) + + +class BidiAgentLoop: + """Agent loop.""" + + def __init__(self, agent: "BidiAgent") -> None: + """Initialize members of the agent loop. + + Note, before receiving events from the loop, the user must call `start`. + + Args: + agent: Bidirectional agent to loop over. + """ + self._agent = agent + self._event_queue = asyncio.Queue() # queue model and tool call events + self._tasks = set() # track active async tasks created in loop + self._active = False # flag if agent loop is started + + async def start(self) -> None: + """Start the agent loop. + + The agent model is started as part of this call. + """ + if self.active: + return + + logger.debug("starting agent loop") + + await self._agent.model.start( + system_prompt=self._agent.system_prompt, + tools=self._agent.tool_registry.get_all_tool_specs(), + messages=self._agent.messages, + ) + + self._create_task(self._run_model()) + + self._active = True + + async def stop(self) -> None: + """Stop the agent loop.""" + if not self.active: + return + + logger.debug("stopping agent loop") + + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + + await self._agent.model.stop() + + self._active = False + + async def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive model and tool call events.""" + while self.active: + try: + yield self._event_queue.get_nowait() + except asyncio.TimeoutError: + pass + + # unblock the event loop + await asyncio.sleep(0) + + @property + def active(self) -> bool: + """True if agent loop started, False otherwise.""" + return self._active + + def _create_task(self, coro: Awaitable[None]) -> None: + """Utilitly to create async task. + + Adds a clean up callback to run after task completes. + """ + task = asyncio.create_task(coro) + task.add_done_callback(lambda task: self._tasks.remove(task)) + + self._tasks.add(task) + + async def _run_model(self) -> None: + """Task for running the model. + + Events are streamed through the event queue. + """ + logger.debug("running model") + + async for event in self._agent.model.receive(): + if not self.active: + break + + self._event_queue.put_nowait(event) + + if isinstance(event, BidiTranscriptStreamEvent): + if event["is_final"]: + message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} + self._agent.messages.append(message) + + if isinstance(event, ToolUseStreamEvent): + self._create_task(self._run_tool(event["current_tool_use"])) + + async def _run_tool(self, tool_use: ToolUse) -> None: + """Task for running tool requested by the model.""" + logger.debug("running tool") + + result: ToolResult = None + + try: + tool = self._agent.tool_registry.registry[tool_use["name"]] + invocation_state = {} + + async for event in tool.stream(tool_use, invocation_state): + if isinstance(event, ToolResultEvent): + self._event_queue.put_nowait(event) + result = event.tool_result + break + + if isinstance(event, ToolStreamEvent): + self._event_queue.put_nowait(event) + else: + self._event_queue.put_nowait(ToolStreamEvent(tool_use, event)) + + except Exception as e: + result = { + "toolUseId": tool_use["toolUseId"], + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } + + await self._agent.model.send(ToolResultEvent(result)) + + message: Message = { + "role": "user", + "content": [{"toolResult": result}], + } + self._agent.messages.append(message) + self._event_queue.put_nowait(ToolResultMessageEvent(message)) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py deleted file mode 100644 index af8c4e1e1..000000000 --- a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Event loop management for bidirectional streaming.""" - -from .bidirectional_event_loop import ( - BidirectionalConnection, - bidirectional_event_loop_cycle, - start_bidirectional_connection, - stop_bidirectional_connection, -) - -__all__ = [ - "BidirectionalConnection", - "start_bidirectional_connection", - "stop_bidirectional_connection", - "bidirectional_event_loop_cycle", -] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py deleted file mode 100644 index 8799f14ec..000000000 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ /dev/null @@ -1,490 +0,0 @@ -"""Bidirectional session management for concurrent streaming conversations. - -Manages bidirectional communication sessions with concurrent processing of model events, -tool execution, and audio processing. Provides coordination between background tasks -while maintaining a simple interface for agent interaction. - -Features: -- Concurrent task management for model events and tool execution -- Interruption handling with audio buffer clearing -- Tool execution with cancellation support -- Session lifecycle management -""" - -import asyncio -import logging -import traceback -import uuid - -from ....tools._validator import validate_and_prepare_tools -from ....telemetry.metrics import Trace -from ....types._events import ToolResultEvent, ToolStreamEvent -from ....types.content import Message -from ....types.tools import ToolResult, ToolUse -from ..models.bidirectional_model import BidiModel - -logger = logging.getLogger(__name__) - -# Session constants -TOOL_QUEUE_TIMEOUT = 0.5 -SUPERVISION_INTERVAL = 0.1 - - -class BidirectionalConnection: - """Session wrapper for bidirectional communication with concurrent task management. - - Coordinates background tasks for model event processing, tool execution, and audio - handling while providing a simple interface for agent interactions. - """ - - def __init__(self, model: BidiModel, agent: "BidiAgent") -> None: - """Initialize connection with model and agent reference. - - Args: - model: Bidirectional model instance. - agent: BidiAgent instance for tool registry access. - """ - self.model = model - self.agent = agent - self.active = True - - # Background processing coordination - self.background_tasks = [] - self.tool_queue = asyncio.Queue() - self.audio_output_queue = asyncio.Queue() - - # Task management for cleanup - self.pending_tool_tasks: dict[str, asyncio.Task] = {} - - # Interruption handling (model-agnostic) - self.interrupted = False - self.interruption_lock = asyncio.Lock() - - # Tool execution tracking - self.tool_count = 0 - - -async def start_bidirectional_connection(agent: "BidiAgent") -> BidirectionalConnection: - """Initialize bidirectional session with conycurrent background tasks. - - Creates a model-specific session and starts background tasks for processing - model events, executing tools, and managing the session lifecycle. - - Args: - agent: BidiAgent instance. - - Returns: - BidirectionalConnection: Active session with background tasks running. - """ - logger.debug("Starting bidirectional session - initializing model connection") - - # Connect to model - await agent.model.start( - system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages - ) - - # Create connection wrapper for background processing - session = BidirectionalConnection(model=agent.model, agent=agent) - - # Start concurrent background processors IMMEDIATELY after session creation - # This is critical - Nova Sonic needs response processing during initialization - logger.debug("Starting background processors for concurrent processing") - session.background_tasks = [ - asyncio.create_task(_process_model_events(session)), # Handle model responses - asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently - ] - - # Start main coordination cycle - session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - - logger.debug("Session ready with %d background tasks", len(session.background_tasks)) - return session - - -async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: - """End session and cleanup resources including background tasks. - - Args: - session: BidirectionalConnection to cleanup. - """ - if not session.active: - return - - logger.debug("Session cleanup starting") - session.active = False - - # Cancel pending tool tasks - for _, task in session.pending_tool_tasks.items(): - if not task.done(): - task.cancel() - - # Cancel background tasks - for task in session.background_tasks: - if not task.done(): - task.cancel() - - # Cancel main cycle task - if hasattr(session, "main_cycle_task") and not session.main_cycle_task.done(): - session.main_cycle_task.cancel() - - # Wait for tasks to complete - all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) - if hasattr(session, "main_cycle_task"): - all_tasks.append(session.main_cycle_task) - - if all_tasks: - await asyncio.gather(*all_tasks, return_exceptions=True) - - # Close model connection - await session.model.stop() - logger.debug("Connection closed") - - -async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: - """Main event loop coordinator that runs continuously during the session. - - Monitors background tasks, manages session state, and handles session lifecycle. - Provides supervision for concurrent model event processing and tool execution. - - Args: - session: BidirectionalConnection to coordinate. - """ - while session.active: - try: - # Check if background processors are still running - if all(task.done() for task in session.background_tasks): - logger.debug("Session end - all processors completed") - session.active = False - break - - # Check for failed background tasks - for i, task in enumerate(session.background_tasks): - if task.done() and not task.cancelled(): - exception = task.exception() - if exception: - logger.error("Session error in processor %d: %s", i, str(exception)) - session.active = False - raise exception - - # Brief pause before next supervision check - await asyncio.sleep(SUPERVISION_INTERVAL) - - except asyncio.CancelledError: - break - except Exception as e: - logger.error("Event loop error: %s", str(e)) - session.active = False - raise - - -async def _handle_interruption(session: BidirectionalConnection) -> None: - """Handle interruption detection with task cancellation and audio buffer clearing. - - Cancels pending tool tasks and clears audio output queues to ensure responsive - interruption handling during conversations. Protected by async lock to prevent - concurrent execution and race conditions. - - Args: - session: BidirectionalConnection to handle interruption for. - """ - async with session.interruption_lock: - # If already interrupted, skip duplicate processing - if session.interrupted: - logger.debug("Interruption already in progress") - return - - logger.debug("Interruption detected") - session.interrupted = True - - # Cancel all pending tool execution tasks - cancelled_tools = 0 - for _task_id, task in list(session.pending_tool_tasks.items()): - if not task.done(): - task.cancel() - cancelled_tools += 1 - logger.debug("Tool task cancelled: %s", _task_id) - - if cancelled_tools > 0: - logger.debug("Tool tasks cancelled: %d", cancelled_tools) - - # Clear all queued audio output events - cleared_count = 0 - while True: - try: - session.audio_output_queue.get_nowait() - cleared_count += 1 - except asyncio.QueueEmpty: - break - - # Also clear the agent's audio output queue - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - # Check for audio events - event_type = event.get("type", "") - if event_type == "bidi_audio_stream": - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - logger.debug("Agent audio queue cleared: %d events", audio_cleared) - - if cleared_count > 0: - logger.debug("Session audio queue cleared: %d events", cleared_count) - - # Reset interruption flag after clearing (automatic recovery) - session.interrupted = False - logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) - - -async def _process_model_events(session: BidirectionalConnection) -> None: - """Process model events and convert them to Strands format. - - Background task that handles all model responses, converts provider-specific - events to standardized formats, and manages interruption detection. - - Args: - session: BidirectionalConnection containing model. - """ - logger.debug("Model events processor started") - try: - async for provider_event in session.model.receive(): - if not session.active: - break - - # Basic validation - skip invalid events - if not isinstance(provider_event, dict): - continue - - strands_event = provider_event - - # Get event type - event_type = strands_event.get("type", "") - - # Handle interruption detection - if event_type == "bidi_interruption": - logger.debug("Interruption forwarded") - await _handle_interruption(session) - # Forward interruption event to agent for application-level handling - await session.agent._output_queue.put(strands_event) - continue - - # Queue tool requests for concurrent execution - # Check for ToolUseStreamEvent (standard agent event) - if event_type == "tool_use_stream": - tool_use = strands_event.get("current_tool_use") - if tool_use: - tool_name = tool_use.get("name") - logger.debug("Tool usage detected: %s", tool_name) - await session.tool_queue.put(tool_use) - # Forward ToolUseStreamEvent to output queue for client visibility - await session.agent._output_queue.put(strands_event) - continue - - # Send all output events to Agent for receive() method - await session.agent._output_queue.put(strands_event) - - # Update Agent conversation history for user transcripts - if event_type == "bidi_transcript_stream": - role = strands_event.get("role") - text = strands_event.get("text", "") - if role == "user" and text.strip(): - user_message = {"role": "user", "content": text} - session.agent.messages.append(user_message) - logger.debug("User transcript added to history") - - except Exception as e: - logger.error("Model events error: %s", str(e)) - traceback.print_exc() - finally: - logger.debug("Model events processor stopped") - - -async def _process_tool_execution(session: BidirectionalConnection) -> None: - """Execute tools concurrently with interruption support. - - Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for - interruption handling rather than manual state checks. - - Args: - session: BidirectionalConnection containing tool queue. - """ - logger.debug("Tool execution processor started") - while session.active: - try: - tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - tool_name = tool_use.get("name") - tool_id = tool_use.get("toolUseId") - - session.tool_count += 1 - print(f"\nTool #{session.tool_count}: {tool_name}") - - logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) - - task_id = str(uuid.uuid4()) - task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) - session.pending_tool_tasks[task_id] = task - - def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: - try: - # Remove from pending tasks - if task_id in session.pending_tool_tasks: - del session.pending_tool_tasks[task_id] - - # Log completion status - if completed_task.cancelled(): - logger.debug("Tool task cancelled: %s", task_id) - elif completed_task.exception(): - logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) - else: - logger.debug("Tool task completed: %s", task_id) - except Exception as e: - logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) - - task.add_done_callback(cleanup_task) - - except asyncio.TimeoutError: - if not session.active: - break - # Remove completed tasks from tracking - completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] - for task_id in completed_tasks: - if task_id in session.pending_tool_tasks: - del session.pending_tool_tasks[task_id] - - if completed_tasks: - logger.debug("Periodic task cleanup: %d tasks", len(completed_tasks)) - - continue - except Exception as e: - logger.error("Tool execution error: %s", str(e)) - if not session.active: - break - - logger.debug("Tool execution processor stopped") - - - - - -async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: - """Execute tool using the complete Strands tool execution system. - - Uses proper Strands ToolExecutor system with validation, error handling, - and event streaming. - - Args: - session: BidirectionalConnection for context. - tool_use: Tool use event to execute. - """ - tool_name = tool_use.get("name") - tool_id = tool_use.get("toolUseId") - - logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) - - try: - # Create message structure for validation - tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - - # Use Strands validation system - tool_uses: list[ToolUse] = [] - tool_results: list[ToolResult] = [] - invalid_tool_use_ids: list[str] = [] - - validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - - # Filter valid tools - valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - - if not valid_tool_uses: - logger.warning("No valid tools after validation: %s", tool_name) - return - - # Create invocation state for tool execution - invocation_state = { - "agent": session.agent, - "model": session.agent.model, - "messages": session.agent.messages, - "system_prompt": session.agent.system_prompt, - } - - # Create cycle trace and span - cycle_trace = Trace("Bidirectional Tool Execution") - cycle_span = None - - tool_events = session.agent.tool_executor._execute( - session.agent, - valid_tool_uses, - tool_results, - cycle_trace, - cycle_span, - invocation_state - ) - - # Process tool events and send results to provider - async for tool_event in tool_events: - if isinstance(tool_event, ToolResultEvent): - tool_result = tool_event.tool_result - tool_use_id = tool_result.get("toolUseId") - - # Send ToolResultEvent through send() method to model - await session.model.send(tool_event) - logger.debug("Tool result sent to model: %s", tool_use_id) - - # Also forward ToolResultEvent to output queue for client visibility - await session.agent._output_queue.put(tool_event) - logger.debug("Tool result sent to client: %s", tool_use_id) - - # Handle streaming events if needed later - elif isinstance(tool_event, ToolStreamEvent): - logger.debug("Tool stream event: %s", tool_event) - # Forward tool stream events to output queue - await session.agent._output_queue.put(tool_event) - - # Add tool result message to conversation history - if tool_results: - from ....hooks import MessageAddedEvent - - tool_result_message: Message = { - "role": "user", - "content": [{"toolResult": result} for result in tool_results], - } - - session.agent.messages.append(tool_result_message) - session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) - logger.debug("Tool result message added to history: %s", tool_name) - - logger.debug("Tool execution completed: %s", tool_name) - - except asyncio.CancelledError: - logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) - raise - except Exception as e: - logger.error("Tool execution error: %s - %s", tool_name, str(e)) - - # Send error result wrapped in ToolResultEvent - error_result: ToolResult = { - "toolUseId": tool_id, - "status": "error", - "content": [{"text": f"Error: {str(e)}"}] - } - try: - await session.model.send(ToolResultEvent(error_result)) - logger.debug("Error result sent: %s", tool_id) - except Exception as send_error: - logger.error("Failed to send error result: %s - %s", tool_id, str(send_error)) - raise # Propagate exception since this is experimental code - - diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 06b267270..16c9b7970 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -79,7 +79,6 @@ NOVA_TOOL_CONFIG = {"mediaType": "application/json"} # Timing constants -SILENCE_THRESHOLD = 2.0 EVENT_DELAY = 0.1 RESPONSE_TIMEOUT = 1.0 @@ -120,9 +119,6 @@ def __init__( # Audio connection state self.audio_connection_active = False - self.last_audio_time = None - self.silence_threshold = SILENCE_THRESHOLD - self.silence_task = None # Background task and event queue self._response_task = None @@ -131,6 +127,7 @@ def __init__( # Track API-provided identifiers self._current_completion_id = None self._current_role = None + self._generation_stage = None logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) @@ -151,7 +148,7 @@ async def start( """ if self._active: raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - + logger.debug("Nova connection create - starting") try: @@ -369,11 +366,6 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: if not self.audio_connection_active: await self._start_audio_connection() - # Update last audio time and cancel any pending silence task - self.last_audio_time = time.time() - if self.silence_task and not self.silence_task.done(): - self.silence_task.cancel() - # Audio is already base64 encoded in the event # Send audio input event audio_event = json.dumps( @@ -390,21 +382,6 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: await self._send_nova_event(audio_event) - # Start silence detection task - self.silence_task = asyncio.create_task(self._check_silence()) - - async def _check_silence(self) -> None: - """Internal: Check for silence and automatically end audio connection.""" - try: - await asyncio.sleep(self.silence_threshold) - if self.audio_connection_active and self.last_audio_time: - elapsed = time.time() - self.last_audio_time - if elapsed >= self.silence_threshold: - logger.debug("Nova silence detected: %.2f seconds", elapsed) - await self._end_audio_input() - except asyncio.CancelledError: - pass - async def _end_audio_input(self) -> None: """Internal: End current audio input connection to trigger Nova Sonic processing.""" if not self.audio_connection_active: @@ -551,9 +528,6 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N # Handle text output (transcripts) elif "textOutput" in nova_event: text_content = nova_event["textOutput"]["content"] - # Use stored role from contentStart event, fallback to event role - role = getattr(self, "_current_role", None) or nova_event["textOutput"].get("role", "assistant") - # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: logger.debug("Nova interruption detected in text") @@ -562,13 +536,13 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N return BidiTranscriptStreamEvent( delta={"text": text_content}, text=text_content, - role=role.lower() if isinstance(role, str) else "assistant", - is_final=True, + role=self._current_role.lower() if self._current_role else "assistant", + is_final=self._generation_stage == "FINAL", current_transcript=text_content ) # Handle tool use - elif "toolUse" in nova_event: + if "toolUse" in nova_event: tool_use = nova_event["toolUse"] tool_use_event: ToolUse = { "toolUseId": tool_use["toolUseId"], @@ -582,12 +556,12 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N ) # Handle interruption - elif nova_event.get("stopReason") == "INTERRUPTED": + if nova_event.get("stopReason") == "INTERRUPTED": logger.debug("Nova interruption stop reason") return BidiInterruptionEvent(reason="user_speech") # Handle usage events - convert to multimodal usage format - elif "usageEvent" in nova_event: + if "usageEvent" in nova_event: usage_data = nova_event["usageEvent"] total_input = usage_data.get("totalInputTokens", 0) total_output = usage_data.get("totalOutputTokens", 0) @@ -599,21 +573,23 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N ) # Handle content start events (track role and emit response start) - elif "contentStart" in nova_event: + if "contentStart" in nova_event: content_data = nova_event["contentStart"] role = content_data.get("role", "unknown") # Store role for subsequent text output events self._current_role = role + if content_data["type"] == "TEXT": + self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] + # Emit response start event using API-provided completionId # completionId should already be tracked from completionStart event return BidiResponseStartEvent( response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing ) - # Handle other events (contentEnd, etc.) - else: - return None + # Ignore other events (contentEnd, etc.) + return # Nova Sonic event template methods def _get_connection_start_event(self) -> str: From 2ec5e57d84442e649ea02fca7513819819836cd1 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 11 Nov 2025 18:49:23 -0500 Subject: [PATCH 095/242] queue empty vs timeout --- src/strands/experimental/bidirectional_streaming/agent/loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/loop.py b/src/strands/experimental/bidirectional_streaming/agent/loop.py index 5e9d4fc78..d4f1f892b 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/loop.py +++ b/src/strands/experimental/bidirectional_streaming/agent/loop.py @@ -75,7 +75,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: while self.active: try: yield self._event_queue.get_nowait() - except asyncio.TimeoutError: + except asyncio.QueueEmpty: pass # unblock the event loop From 8bbfd3822914b697539f73c657eea702731bdafe Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 11 Nov 2025 21:22:56 -0500 Subject: [PATCH 096/242] bidi audio io - speech --- .../bidirectional_streaming/__init__.py | 4 +- .../bidirectional_streaming/agent/agent.py | 67 ++----- .../bidirectional_streaming/agent/loop.py | 2 +- .../bidirectional_streaming/io/__init__.py | 4 +- .../bidirectional_streaming/io/audio.py | 184 +++++------------- .../bidirectional_streaming/types/io.py | 2 +- 6 files changed, 74 insertions(+), 189 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 2dd38e172..033a4bb78 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import BidiAgent # IO channels - Hardware abstraction -from .io.audio import AudioIO +from .io.audio import BidiAudioIO # Model interface (for custom implementations) from .models.bidirectional_model import BidiModel @@ -44,7 +44,7 @@ # Main interface "BidiAgent", # IO channels - "AudioIO", + "BidiAudioIO", # Model providers "BidiGeminiLiveModel", "BidiNovaSonicModel", diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 03c18d8d0..ec2606d19 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -391,73 +391,36 @@ def active(self) -> bool: """True if agent loop started, False otherwise.""" return self._loop.active - async def run(self, io_channels: list[BidiIO | tuple[Callable, Callable]]) -> None: - """Run the agent using provided IO channels or transport tuples for bidirectional communication. + async def run(self, io_channel: BidiIO) -> None: + """Run the agent using provided IO channels for bidirectional communication. Args: - io_channels: List containing either BidiIO instances or (sender, receiver) tuples. + io_channels: List containing either BidiIO instances. - BidiIO: IO channel instance with send(), receive(), and end() methods - - tuple: (sender_callable, receiver_callable) for custom transport Example: ```python - # With IO channel audio_io = AudioIO(audio_config={"input_sample_rate": 16000}) agent = BidiAgent(model=model, tools=[calculator]) await agent.run(io_channels=[audio_io]) - - # With tuple (backward compatibility) - await agent.run(io_channels=[(sender_function, receiver_function)]) ``` - - Raises: - ValueError: If io_channels list is empty or contains invalid items. - Exception: Any exception from the transport layer. """ - if not io_channels: - raise ValueError("io_channels parameter cannot be empty. Provide either an IO channel or (sender, receiver) tuple.") - - transport = io_channels[0] - - # Set IO channel tracking for cleanup - if hasattr(transport, 'send') and hasattr(transport, 'receive'): - self._current_adapters = [transport] # IO channel needs cleanup - elif isinstance(transport, tuple) and len(transport) == 2: - self._current_adapters = [] # Tuple needs no cleanup - else: - raise ValueError("io_channels list must contain either BidiIO instances or (sender, receiver) tuples.") - - # Auto-manage session lifecycle - if self.active: - await self._run_with_transport(transport) - else: - async with self: - await self._run_with_transport(transport) - - async def _run_with_transport( - self, - transport: BidiIO | tuple[Callable, Callable], - ) -> None: - """Internal method to run send/receive loops with an active connection.""" + async def send(): + while self.active: + event = await io_channel.receive() + await self.send(event) - async def receive_from_agent(): - """Receive events from agent and send to transport.""" + async def receive(): async for event in self.receive(): - if hasattr(transport, 'receive'): - await transport.receive(event) - else: - await transport[0](event) + await io_channel.send(event) - async def send_to_agent(): - """Receive events from transport and send to agent.""" - while self.active: - if hasattr(transport, 'send'): - event = await transport.send() - else: - event = await transport[1]() - await self.send(event) + await io_channel.start() + + try: + await asyncio.gather(send(), receive(), return_exceptions=True) - await asyncio.gather(receive_from_agent(), send_to_agent(), return_exceptions=True) + finally: + io_channel.stop() def _validate_active_connection(self) -> None: """Validate that an active connection exists. diff --git a/src/strands/experimental/bidirectional_streaming/agent/loop.py b/src/strands/experimental/bidirectional_streaming/agent/loop.py index d4f1f892b..483ba5392 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/loop.py +++ b/src/strands/experimental/bidirectional_streaming/agent/loop.py @@ -104,7 +104,7 @@ async def _run_model(self) -> None: logger.debug("running model") async for event in self._agent.model.receive(): - if not self.active: + if not self.active: # TODO: maybe remove break self._event_queue.put_nowait(event) diff --git a/src/strands/experimental/bidirectional_streaming/io/__init__.py b/src/strands/experimental/bidirectional_streaming/io/__init__.py index 0bf186777..faa969168 100644 --- a/src/strands/experimental/bidirectional_streaming/io/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/io/__init__.py @@ -1,5 +1,5 @@ """IO channel implementations for bidirectional streaming.""" -from .audio import AudioIO +from .audio import BidiAudioIO -__all__ = ["AudioIO"] \ No newline at end of file +__all__ = ["BidiAudioIO"] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py index a16dce884..16b32a38e 100644 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -4,17 +4,18 @@ Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. """ -import asyncio import base64 import logging + import pyaudio from ..types.io import BidiIO +from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiOutputEvent, BidiTranscriptStreamEvent logger = logging.getLogger(__name__) -class AudioIO(BidiIO): +class BidiAudioIO(BidiIO): """Audio IO channel for BidirectionalAgent with direct stream processing.""" def __init__( @@ -33,10 +34,6 @@ def __init__( - input_channels (int): Input channels (default: 1) - output_channels (int): Output channels (default: 1) """ - if pyaudio is None: - raise ImportError("PyAudio is required for AudioIO. Install with: pip install pyaudio") - - # Default audio configuration default_config = { "input_sample_rate": 24000, "output_sample_rate": 24000, @@ -66,135 +63,60 @@ def __init__( self.output_stream = None self.interrupted = False - def start(self) -> None: + async def start(self) -> None: """Setup PyAudio streams for input and output.""" if self.audio: return self.audio = pyaudio.PyAudio() - try: - # Input stream - self.input_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.input_channels, - rate=self.input_sample_rate, - input=True, - frames_per_buffer=self.chunk_size, - input_device_index=self.input_device_index, - ) - - # Output stream - self.output_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.output_channels, - rate=self.output_sample_rate, - output=True, - frames_per_buffer=self.chunk_size, - output_device_index=self.output_device_index, - ) - - # Start streams - self.input_stream.start_stream() - self.output_stream.start_stream() - - except Exception as e: - logger.error(f"AudioIO: Audio setup failed: {e}") - self._cleanup_audio() - raise - - async def send(self) -> dict: - """Read audio from microphone.""" - if not self.input_stream: - self.start() - - try: - audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) - return { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": self.input_sample_rate, - "channels": self.input_channels, - } - except Exception as e: - logger.warning(f"Audio input error: {e}") - return { - "audioData": b"", - "format": "pcm", - "sampleRate": self.input_sample_rate, - "channels": self.input_channels, - } - - async def receive(self, event: dict) -> None: - """Handle audio events with direct stream writing.""" - if not self.output_stream: - self.start() - - # Handle audio output - if "audioOutput" in event and not self.interrupted: - audio_data = event["audioOutput"]["audioData"] - - # Handle both base64 and raw bytes - if isinstance(audio_data, str): - audio_data = base64.b64decode(audio_data) - - if audio_data: - chunk_size = 2048 - for i in range(0, len(audio_data), chunk_size): - # Check for interruption before each chunk - if self.interrupted: - break - - chunk = audio_data[i : i + chunk_size] - try: - self.output_stream.write(chunk, exception_on_underflow=False) - await asyncio.sleep(0) - except Exception as e: - logger.warning(f"Audio playback error: {e}") - break - - elif "interruptionDetected" in event or "interrupted" in event: - self.interrupted = True - logger.debug("Interruption detected") - - # Stop and restart stream for immediate interruption - if self.output_stream: - try: - self.output_stream.stop_stream() - self.output_stream.start_stream() - except Exception as e: - logger.debug(f"Error clearing audio buffer: {e}") - - self.interrupted = False - - elif "textOutput" in event: - text = event["textOutput"].get("text", "").strip() - role = event["textOutput"].get("role", "") - if text: - if role.upper() == "ASSISTANT": - print(f"🤖 {text}") - elif role.upper() == "USER": - print(f"User: {text}") - - def stop(self) -> None: + self.input_stream = self.audio.open( + format=pyaudio.paInt16, + channels=self.input_channels, + rate=self.input_sample_rate, + input=True, + frames_per_buffer=self.chunk_size, + input_device_index=self.input_device_index, + ) + + self.output_stream = self.audio.open( + format=pyaudio.paInt16, + channels=self.output_channels, + rate=self.output_sample_rate, + output=True, + frames_per_buffer=self.chunk_size, + output_device_index=self.output_device_index, + ) + + async def stop(self) -> None: """Clean up IO channel resources.""" - try: - if self.input_stream: - if self.input_stream.is_active(): - self.input_stream.stop_stream() - self.input_stream.close() - - if self.output_stream: - if self.output_stream.is_active(): - self.output_stream.stop_stream() - self.output_stream.close() - - if self.audio: - self.audio.terminate() - - self.input_stream = None - self.output_stream = None - self.audio = None - - except Exception as e: - logger.warning(f"Audio cleanup error: {e}") \ No newline at end of file + if not self.audio: + return + + self.input_stream.close() + self.output_stream.close() + self.audio.terminate() + + self.input_stream = None + self.output_stream = None + self.audio = None + + async def send(self, event: BidiOutputEvent) -> None: + """Handle audio events with direct stream writing.""" + + if isinstance(event, BidiAudioStreamEvent): + self.output_stream.write(base64.b64decode(event["audio"])) + + elif isinstance(event, BidiTranscriptStreamEvent): + print(event["current_transcript"]) + + async def receive(self) -> BidiAudioInputEvent: + """Read audio from microphone.""" + audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) + + return BidiAudioInputEvent( + audio=base64.b64encode(audio_bytes).decode("utf-8"), + format="pcm", + sample_rate=self.input_sample_rate, + channels=self.input_channels, + ) diff --git a/src/strands/experimental/bidirectional_streaming/types/io.py b/src/strands/experimental/bidirectional_streaming/types/io.py index 2e113c74b..25121e951 100644 --- a/src/strands/experimental/bidirectional_streaming/types/io.py +++ b/src/strands/experimental/bidirectional_streaming/types/io.py @@ -37,7 +37,7 @@ async def receive(self, event: dict) -> None: """ ... - def stop(self) -> None: + async def stop(self) -> None: """Clean up IO channel resources. Called by the agent during shutdown to ensure proper From cf6e41cd1bd62788c211324a502d793b12fc821d Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 11 Nov 2025 22:26:33 -0500 Subject: [PATCH 097/242] fix tool call run --- .../bidirectional_streaming/agent/agent.py | 3 ++- .../bidirectional_streaming/types/io.py | 20 ++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index ec2606d19..79bef5126 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -409,6 +409,7 @@ async def send(): while self.active: event = await io_channel.receive() await self.send(event) + await asyncio.sleep(0) async def receive(): async for event in self.receive(): @@ -420,7 +421,7 @@ async def receive(): await asyncio.gather(send(), receive(), return_exceptions=True) finally: - io_channel.stop() + await io_channel.stop() def _validate_active_connection(self) -> None: """Validate that an active connection exists. diff --git a/src/strands/experimental/bidirectional_streaming/types/io.py b/src/strands/experimental/bidirectional_streaming/types/io.py index 25121e951..c41e175b1 100644 --- a/src/strands/experimental/bidirectional_streaming/types/io.py +++ b/src/strands/experimental/bidirectional_streaming/types/io.py @@ -7,6 +7,8 @@ from typing import Protocol +from ..types.events import BidiInputEvent, BidiOutputEvent + class BidiIO(Protocol): """Base protocol for bidirectional IO channels. @@ -21,19 +23,19 @@ async def start(self) -> dict: """Setup IO channels for input and output.""" ... - async def send(self) -> dict: - """Read input data from the IO channel source. + async def send(self, event: BidiOutputEvent) -> None: + """Process output event from the model through the IO channel. - Returns: - dict: Input event data to send to the model. + Args: + event: Output event from the model to handle. """ ... - async def receive(self, event: dict) -> None: - """Process output event from the model through the IO channel. + async def receive(self) -> BidiInputEvent: + """Read input data from the IO channel source. - Args: - event: Output event from the model to handle. + Returns: + dict: Input event data to send to the model. """ ... @@ -43,4 +45,4 @@ async def stop(self) -> None: Called by the agent during shutdown to ensure proper resource cleanup (streams, connections, etc.). """ - ... \ No newline at end of file + ... From 7101f4a6e7571cff87dab02fc08fe82fef48ea39 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 11 Nov 2025 22:27:52 -0500 Subject: [PATCH 098/242] asyncio.sleep(0.01) --- src/strands/experimental/bidirectional_streaming/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 79bef5126..53080ce66 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -409,7 +409,7 @@ async def send(): while self.active: event = await io_channel.receive() await self.send(event) - await asyncio.sleep(0) + await asyncio.sleep(0.01) async def receive(): async for event in self.receive(): From 76f9983cbb69d32838101d6121200ce9be2b340f Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 11 Nov 2025 23:44:03 -0500 Subject: [PATCH 099/242] tidy --- .../bidirectional_streaming/agent/agent.py | 6 ++++-- .../bidirectional_streaming/agent/loop.py | 12 +++++++++--- .../experimental/bidirectional_streaming/io/audio.py | 11 +++++++---- .../bidirectional_streaming/models/novasonic.py | 6 ++---- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 53080ce66..cc1468741 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -15,7 +15,7 @@ import asyncio import json import logging -from typing import Any, AsyncIterable, Callable +from typing import Any, AsyncIterable from .... import _identifier from ....tools.caller import _ToolCaller @@ -409,11 +409,13 @@ async def send(): while self.active: event = await io_channel.receive() await self.send(event) - await asyncio.sleep(0.01) + + await asyncio.sleep(0.001) async def receive(): async for event in self.receive(): await io_channel.send(event) + await asyncio.sleep(0.01) await io_channel.start() diff --git a/src/strands/experimental/bidirectional_streaming/agent/loop.py b/src/strands/experimental/bidirectional_streaming/agent/loop.py index 483ba5392..86a65387b 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/loop.py +++ b/src/strands/experimental/bidirectional_streaming/agent/loop.py @@ -7,7 +7,7 @@ import logging from typing import AsyncIterable, Awaitable, TYPE_CHECKING -from ..types.events import BidiOutputEvent, BidiTranscriptStreamEvent +from ..types.events import BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse @@ -104,7 +104,7 @@ async def _run_model(self) -> None: logger.debug("running model") async for event in self._agent.model.receive(): - if not self.active: # TODO: maybe remove + if not self.active: break self._event_queue.put_nowait(event) @@ -114,9 +114,15 @@ async def _run_model(self) -> None: message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} self._agent.messages.append(message) - if isinstance(event, ToolUseStreamEvent): + elif isinstance(event, ToolUseStreamEvent): self._create_task(self._run_tool(event["current_tool_use"])) + elif isinstance(event, BidiInterruptionEvent): + for _ in range(self._event_queue.qsize()): + event = self._event_queue.get_nowait() + if not isinstance(event, BidiAudioStreamEvent): + self._event_queue.put_nowait(event) + async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model.""" logger.debug("running tool") diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py index 16b32a38e..ffa1ea4bd 100644 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -10,7 +10,7 @@ import pyaudio from ..types.io import BidiIO -from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiOutputEvent, BidiTranscriptStreamEvent +from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent logger = logging.getLogger(__name__) @@ -35,9 +35,9 @@ def __init__( - output_channels (int): Output channels (default: 1) """ default_config = { - "input_sample_rate": 24000, - "output_sample_rate": 24000, - "chunk_size": 1024, + "input_sample_rate": 16000, + "output_sample_rate": 16000, + "chunk_size": 512, "input_device_index": None, "output_device_index": None, "input_channels": 1, @@ -107,6 +107,9 @@ async def send(self, event: BidiOutputEvent) -> None: if isinstance(event, BidiAudioStreamEvent): self.output_stream.write(base64.b64decode(event["audio"])) + if isinstance(event, BidiInterruptionEvent): + print("interrupted") + elif isinstance(event, BidiTranscriptStreamEvent): print(event["current_transcript"]) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 16c9b7970..8c23aa0da 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -16,17 +16,15 @@ import base64 import json import logging -import time import traceback import uuid -from typing import AsyncIterable, Union +from typing import AsyncIterable from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme from aws_sdk_bedrock_runtime.models import ( BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, - InvokeModelWithBidirectionalStreamOperationOutput, ) from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver @@ -67,7 +65,7 @@ NOVA_AUDIO_OUTPUT_CONFIG = { "mediaType": "audio/lpcm", - "sampleRateHertz": 24000, + "sampleRateHertz": 16000, "sampleSizeBits": 16, "channelCount": 1, "voiceId": "matthew", From 566f952f452dd0c245690b74d1cb240dbd2bb1f3 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 12 Nov 2025 01:36:45 -0500 Subject: [PATCH 100/242] preview transcript --- .../experimental/bidirectional_streaming/io/audio.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py index ffa1ea4bd..05276552d 100644 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -111,7 +111,11 @@ async def send(self, event: BidiOutputEvent) -> None: print("interrupted") elif isinstance(event, BidiTranscriptStreamEvent): - print(event["current_transcript"]) + text = event["text"] + if not event["is_final"]: + text = f"Preview: {text}" + + print(text) async def receive(self) -> BidiAudioInputEvent: """Read audio from microphone.""" From e67e90b23e1da3fbc920f149dddc2cd8a9346d7a Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 12 Nov 2025 01:39:27 -0500 Subject: [PATCH 101/242] comment --- src/strands/experimental/bidirectional_streaming/agent/loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/strands/experimental/bidirectional_streaming/agent/loop.py b/src/strands/experimental/bidirectional_streaming/agent/loop.py index 86a65387b..21094982c 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/loop.py +++ b/src/strands/experimental/bidirectional_streaming/agent/loop.py @@ -118,6 +118,7 @@ async def _run_model(self) -> None: self._create_task(self._run_tool(event["current_tool_use"])) elif isinstance(event, BidiInterruptionEvent): + # clear the audio for _ in range(self._event_queue.qsize()): event = self._event_queue.get_nowait() if not isinstance(event, BidiAudioStreamEvent): From d76fefcf6efed71c091b33070e1c677e24199e6d Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 12 Nov 2025 01:46:52 -0500 Subject: [PATCH 102/242] sleep comment --- .../experimental/bidirectional_streaming/agent/agent.py | 3 ++- src/strands/experimental/bidirectional_streaming/io/audio.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index cc1468741..26a6313bb 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -410,12 +410,13 @@ async def send(): event = await io_channel.receive() await self.send(event) + # TODO: Need to make tool result send in Nova provider atomic. Audio input events end up interleaving + # and leading to failures. Adding a sleep here as a temporary solution. await asyncio.sleep(0.001) async def receive(): async for event in self.receive(): await io_channel.send(event) - await asyncio.sleep(0.01) await io_channel.start() diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py index 05276552d..e4480f753 100644 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -4,6 +4,7 @@ Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. """ +import asyncio import base64 import logging @@ -107,6 +108,10 @@ async def send(self, event: BidiOutputEvent) -> None: if isinstance(event, BidiAudioStreamEvent): self.output_stream.write(base64.b64decode(event["audio"])) + # TODO: Outputing audio to speakers is a sync operation. Adding sleep to prevent event loop hogging. Will + # follow up on identifying a cleaner approach. + await asyncio.sleep(0.01) + if isinstance(event, BidiInterruptionEvent): print("interrupted") From d225cd2efc4cc0fb3a80896fd7635e59497c30c4 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 12 Nov 2025 10:02:24 -0500 Subject: [PATCH 103/242] split BidiIO into BidiInput and BidiOutput --- .../bidirectional_streaming/agent/agent.py | 49 ++++---- .../bidirectional_streaming/agent/loop.py | 2 +- .../bidirectional_streaming/io/__init__.py | 3 +- .../bidirectional_streaming/io/audio.py | 105 +++++++++++------- .../bidirectional_streaming/io/text.py | 31 ++++++ .../bidirectional_streaming/types/__init__.py | 5 +- .../bidirectional_streaming/types/io.py | 61 +++++----- 7 files changed, 168 insertions(+), 88 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/io/text.py diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 26a6313bb..eab909449 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -26,12 +26,12 @@ from ....types.content import Message, Messages from ....types.tools import ToolResult, ToolUse, AgentTool -from .loop import BidiAgentLoop +from .loop import _BidiAgentLoop from ..models.bidirectional_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent -from ..types import BidiIO +from ..types.io import BidiInput, BidiOutput from ...tools import ToolProvider logger = logging.getLogger(__name__) @@ -119,7 +119,7 @@ def __init__( self._current_adapters = [] # Track adapters for cleanup - self._loop = BidiAgentLoop(self) + self._loop = _BidiAgentLoop(self) @property def tool(self) -> _ToolCaller: @@ -391,40 +391,51 @@ def active(self) -> bool: """True if agent loop started, False otherwise.""" return self._loop.active - async def run(self, io_channel: BidiIO) -> None: + async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: """Run the agent using provided IO channels for bidirectional communication. Args: - io_channels: List containing either BidiIO instances. - - BidiIO: IO channel instance with send(), receive(), and end() methods + inputs: Input callables to read data from a source + outputs: Output callables to receive events from the agent Example: ```python - audio_io = AudioIO(audio_config={"input_sample_rate": 16000}) + audio_io = BidiAudioIO(audio_config={"input_sample_rate": 16000}) + text_io = BidiTextIO() agent = BidiAgent(model=model, tools=[calculator]) - await agent.run(io_channels=[audio_io]) + await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) ``` """ - async def send(): + async def run_inputs(): while self.active: - event = await io_channel.receive() - await self.send(event) + for input_ in inputs: + event = await input_() + await self.send(event) - # TODO: Need to make tool result send in Nova provider atomic. Audio input events end up interleaving - # and leading to failures. Adding a sleep here as a temporary solution. - await asyncio.sleep(0.001) + # TODO: Need to make tool result send in Nova provider atomic. Audio input events end up interleaving + # and leading to failures. Adding a sleep here as a temporary solution. + await asyncio.sleep(0.001) - async def receive(): + async def run_outputs(): async for event in self.receive(): - await io_channel.send(event) + for output in outputs: + await output(event) - await io_channel.start() + for input_ in inputs: + await input_.start() + + for output in outputs: + await output.start() try: - await asyncio.gather(send(), receive(), return_exceptions=True) + await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True) finally: - await io_channel.stop() + for input_ in inputs: + await input_.stop() + + for output in outputs: + await output.stop() def _validate_active_connection(self) -> None: """Validate that an active connection exists. diff --git a/src/strands/experimental/bidirectional_streaming/agent/loop.py b/src/strands/experimental/bidirectional_streaming/agent/loop.py index 21094982c..e0bc02ef2 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/loop.py +++ b/src/strands/experimental/bidirectional_streaming/agent/loop.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -class BidiAgentLoop: +class _BidiAgentLoop: """Agent loop.""" def __init__(self, agent: "BidiAgent") -> None: diff --git a/src/strands/experimental/bidirectional_streaming/io/__init__.py b/src/strands/experimental/bidirectional_streaming/io/__init__.py index faa969168..d099cba2f 100644 --- a/src/strands/experimental/bidirectional_streaming/io/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/io/__init__.py @@ -1,5 +1,6 @@ """IO channel implementations for bidirectional streaming.""" from .audio import BidiAudioIO +from .text import BidiTextIO -__all__ = ["BidiAudioIO"] \ No newline at end of file +__all__ = ["BidiAudioIO", "BidiTextIO"] diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py index e4480f753..2ec167480 100644 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ b/src/strands/experimental/bidirectional_streaming/io/audio.py @@ -1,6 +1,6 @@ -"""AudioIO - Clean separation of audio functionality from core BidirectionalAgent. +"""AudioIO - Clean separation of audio functionality from core BidiAgent. -Provides audio input/output capabilities for BidirectionalAgent through the BidiIO protocol. +Provides audio input/output capabilities for BidiAgent through the BidiIO protocol. Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. """ @@ -10,14 +10,64 @@ import pyaudio -from ..types.io import BidiIO -from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent +from ..types.io import BidiInput, BidiOutput +from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiOutputEvent logger = logging.getLogger(__name__) -class BidiAudioIO(BidiIO): - """Audio IO channel for BidirectionalAgent with direct stream processing.""" +class _BidiAudioInput(BidiInput): + "Handle audio input from bidi agent." + def __init__(self, audio: "BidiAudioIO") -> None: + """Store reference to pyaudio instance.""" + self.audio = audio + + async def start(self) -> None: + """Start audio input.""" + self.audio._start() + + async def stop(self) -> None: + """Stop audio input.""" + self.audio._stop() + + async def __call__(self) -> BidiAudioInputEvent: + """Read audio from microphone.""" + audio_bytes = self.audio.input_stream.read(self.audio.chunk_size, exception_on_overflow=False) + + return BidiAudioInputEvent( + audio=base64.b64encode(audio_bytes).decode("utf-8"), + format="pcm", + sample_rate=self.audio.input_sample_rate, + channels=self.audio.input_channels, + ) + + +class _BidiAudioOutput(BidiOutput): + "Handle audio output from bidi agent." + def __init__(self, audio: "BidiAudioIO") -> None: + """Store reference to pyaudio instance.""" + self.audio = audio + + async def start(self) -> None: + """Start audio output.""" + self.audio._start() + + async def stop(self) -> None: + """Stop audio output.""" + self.audio._stop() + + async def __call__(self, event: BidiOutputEvent) -> None: + """Handle audio events with direct stream writing.""" + if isinstance(event, BidiAudioStreamEvent): + self.audio.output_stream.write(base64.b64decode(event["audio"])) + + # TODO: Outputing audio to speakers is a sync operation. Adding sleep to prevent event loop hogging. Will + # follow up on identifying a cleaner approach. + await asyncio.sleep(0.01) + + +class BidiAudioIO: + """Audio IO channel for BidiAgent with direct stream processing.""" def __init__( self, @@ -64,7 +114,15 @@ def __init__( self.output_stream = None self.interrupted = False - async def start(self) -> None: + def input(self) -> _BidiAudioInput: + "Return audio processing BidiInput" + return _BidiAudioInput(self) + + def output(self) -> _BidiAudioOutput: + "Return audio processing BidiOutput" + return _BidiAudioOutput(self) + + def _start(self) -> None: """Setup PyAudio streams for input and output.""" if self.audio: return @@ -89,7 +147,7 @@ async def start(self) -> None: output_device_index=self.output_device_index, ) - async def stop(self) -> None: + def _stop(self) -> None: """Clean up IO channel resources.""" if not self.audio: return @@ -101,34 +159,3 @@ async def stop(self) -> None: self.input_stream = None self.output_stream = None self.audio = None - - async def send(self, event: BidiOutputEvent) -> None: - """Handle audio events with direct stream writing.""" - - if isinstance(event, BidiAudioStreamEvent): - self.output_stream.write(base64.b64decode(event["audio"])) - - # TODO: Outputing audio to speakers is a sync operation. Adding sleep to prevent event loop hogging. Will - # follow up on identifying a cleaner approach. - await asyncio.sleep(0.01) - - if isinstance(event, BidiInterruptionEvent): - print("interrupted") - - elif isinstance(event, BidiTranscriptStreamEvent): - text = event["text"] - if not event["is_final"]: - text = f"Preview: {text}" - - print(text) - - async def receive(self) -> BidiAudioInputEvent: - """Read audio from microphone.""" - audio_bytes = self.input_stream.read(self.chunk_size, exception_on_overflow=False) - - return BidiAudioInputEvent( - audio=base64.b64encode(audio_bytes).decode("utf-8"), - format="pcm", - sample_rate=self.input_sample_rate, - channels=self.input_channels, - ) diff --git a/src/strands/experimental/bidirectional_streaming/io/text.py b/src/strands/experimental/bidirectional_streaming/io/text.py new file mode 100644 index 000000000..ba503f4e4 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/io/text.py @@ -0,0 +1,31 @@ +"""Handle text input and output from bidi agent.""" + +import logging + +from ..types.io import BidiOutput +from ..types.events import BidiOutputEvent, BidiInterruptionEvent, BidiTranscriptStreamEvent + +logger = logging.getLogger(__name__) + + +class _BidiTextOutput(BidiOutput): + "Handle text output from bidi agent." + async def __call__(self, event: BidiOutputEvent) -> None: + """Print text events to stdout.""" + + if isinstance(event, BidiInterruptionEvent): + print("interrupted") + + elif isinstance(event, BidiTranscriptStreamEvent): + text = event["text"] + if not event["is_final"]: + text = f"Preview: {text}" + + print(text) + + +class BidiTextIO: + "Handle text input and output from bidi agent." + def output(self) -> _BidiTextOutput: + "Return text processing BidiOutput" + return _BidiTextOutput() diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 34dd2685b..d5263bb28 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,7 +1,7 @@ """Type definitions for bidirectional streaming.""" from .agent import BidiAgentInput -from .io import BidiIO +from .io import BidiInput, BidiOutput from .events import ( DEFAULT_CHANNELS, DEFAULT_FORMAT, @@ -27,7 +27,8 @@ ) __all__ = [ - "BidiIO", + "BidiInput", + "BidiOutput", "BidiAgentInput", # Input Events "BidiTextInputEvent", diff --git a/src/strands/experimental/bidirectional_streaming/types/io.py b/src/strands/experimental/bidirectional_streaming/types/io.py index c41e175b1..8b79455ec 100644 --- a/src/strands/experimental/bidirectional_streaming/types/io.py +++ b/src/strands/experimental/bidirectional_streaming/types/io.py @@ -1,48 +1,57 @@ -"""BidiIO protocol for bidirectional streaming IO channels. +"""Protocol for bidirectional streaming IO channels. -Defines the standard interface that all bidirectional IO channels must implement -for integration with BidirectionalAgent. This protocol enables clean -separation between the agent's core logic and hardware-specific implementations. +Defines callable protocols for input and output channels that can be used +with BidiAgent. This approach provides better typing and flexibility +by separating input and output concerns into independent callables. """ -from typing import Protocol +from typing import Awaitable, Protocol from ..types.events import BidiInputEvent, BidiOutputEvent -class BidiIO(Protocol): - """Base protocol for bidirectional IO channels. - - Defines the interface that IO channels must implement to work with - BidirectionalAgent. IO channels handle hardware abstraction (audio, video, - WebSocket, etc.) while the agent handles model communication and logic. - """ +class BidiInput(Protocol): + """Protocol for bidirectional input callables. - async def start(self) -> dict: + Input callables read data from a source (microphone, camera, websocket, etc.) + and return events to be sent to the agent. + """ - """Setup IO channels for input and output.""" + async def start(self) -> None: + """Start input.""" ... - async def send(self, event: BidiOutputEvent) -> None: - """Process output event from the model through the IO channel. - - Args: - event: Output event from the model to handle. - """ + async def stop(self) -> None: + """Stop input.""" ... - async def receive(self) -> BidiInputEvent: - """Read input data from the IO channel source. + def __call__(self) -> Awaitable[BidiInputEvent]: + """Read input data from the source. Returns: - dict: Input event data to send to the model. + Awaitable that resolves to an input event (audio, text, image, etc.) """ ... +class BidiOutput(Protocol): + """Protocol for bidirectional output callables. + + Output callables receive events from the agent and handle them appropriately + (play audio, display text, send over websocket, etc.). + """ + + async def start(self) -> None: + """Start output.""" + ... + async def stop(self) -> None: - """Clean up IO channel resources. + """Stop output.""" + ... + + def __call__(self, event: BidiOutputEvent) -> Awaitable[None]: + """Process output events from the agent. - Called by the agent during shutdown to ensure proper - resource cleanup (streams, connections, etc.). + Args: + event: Output event from the agent (audio, text, tool calls, etc.) """ ... From aa289ba36524ad2f37a4b924deb18b178725d76f Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 12:38:44 -0500 Subject: [PATCH 104/242] Update test script --- .../bidirectional_streaming/scripts/test_bidi.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py index 359f04dbf..f04677635 100644 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py +++ b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py @@ -8,7 +8,7 @@ from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel -from strands.experimental.bidirectional_streaming.io.audio import AudioIO +from strands.experimental.bidirectional_streaming.io import BidiAudioIO, BidiTextIO from strands_tools import calculator @@ -17,13 +17,14 @@ async def main(): # Nova Sonic model - adapter = AudioIO() + audio_io = BidiAudioIO(audio_config={}) + text_io = BidiTextIO() model = BidiNovaSonicModel(region="us-east-1") async with BidiAgent(model=model, tools=[calculator]) as agent: print("New BidiAgent Experience") print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") - await agent.run(io_channels=[adapter]) + await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) if __name__ == "__main__": From 5148c350c93e35233953f819ef8573386378e3f8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 12:43:29 -0500 Subject: [PATCH 105/242] Update dependencies --- pyproject.toml | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7810d09c7..cc335b80b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,21 +57,30 @@ sagemaker = [ bidirectional-streaming-nova = [ "pyaudio>=0.2.13", "rx>=3.2.0", - "smithy-aws-core>=0.0.1", + "smithy-aws-core>=0.0.1; python_version>='3.12'", "pytz", - "aws_sdk_bedrock_runtime", + "aws_sdk_bedrock_runtime; python_version>='3.12'", ] bidirectional-streaming-openai = [ "pyaudio>=0.2.13", "websockets>=14.0,<16.0", ] +bidirectional-streaming-gemini = [ + "pyaudio>=0.2.13", + "google-genai>=1.32.0,<2.0.0", + "opencv-python>=4.8.0", + "pillow>=10.0.0", +] bidirectional-streaming = [ "pyaudio>=0.2.13", "rx>=3.2.0", - "smithy-aws-core>=0.0.1", + "smithy-aws-core>=0.0.1; python_version>='3.12'", "pytz", - "aws_sdk_bedrock_runtime", + "aws_sdk_bedrock_runtime; python_version>='3.12'", "websockets>=14.0,<16.0", + "google-genai>=1.32.0,<2.0.0", + "opencv-python>=4.8.0", + "pillow>=10.0.0", ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ @@ -88,7 +97,8 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +bidirectional = ["bidirectional-streaming-openai", "bidirectional-streaming-gemini", "bidirectional-streaming-nova", "bidirectional-streaming"] +all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", @@ -119,12 +129,13 @@ source = "vcs" # Use git tags for versioning [tool.hatch.envs.hatch-static-analysis] installer = "uv" -features = ["all"] +features = ["all", "bidirectional-streaming"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.13.0,<0.14.0", - # Include required pacakge dependencies for mypy + # Include required package dependencies for mypy "strands-agents @ {root:uri}", + "strands-agents-tools", ] # Define static-analysis scripts so we can include mypy as part of the linting check @@ -146,7 +157,7 @@ lint-fix = [ [tool.hatch.envs.hatch-test] installer = "uv" -features = ["all"] +features = ["all", "bidirectional-streaming"] extra-args = ["-n", "auto", "-vv"] dependencies = [ "pytest>=8.0.0,<9.0.0", @@ -169,7 +180,7 @@ cov-report = [] [tool.hatch.envs.default] installer = "uv" dev-mode = true -features = ["all"] +features = ["all", "bidirectional-streaming"] dependencies = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", From 00b379d500a93a50f23daeeb44e589b363a8adaf Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 12:46:09 -0500 Subject: [PATCH 106/242] Update dependencies --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cc335b80b..fe4f7edb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,7 +135,6 @@ dependencies = [ "ruff>=0.13.0,<0.14.0", # Include required package dependencies for mypy "strands-agents @ {root:uri}", - "strands-agents-tools", ] # Define static-analysis scripts so we can include mypy as part of the linting check From cfe7e260003c99ec4f199359f2ac6540e3989e6c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 12:50:06 -0500 Subject: [PATCH 107/242] Update dependencies --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fe4f7edb0..efdf1a5d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,24 +54,24 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] -bidirectional-streaming-nova = [ +bidi-novasonic = [ "pyaudio>=0.2.13", "rx>=3.2.0", "smithy-aws-core>=0.0.1; python_version>='3.12'", "pytz", "aws_sdk_bedrock_runtime; python_version>='3.12'", ] -bidirectional-streaming-openai = [ +bidi-openai = [ "pyaudio>=0.2.13", "websockets>=14.0,<16.0", ] -bidirectional-streaming-gemini = [ +bidi-gemini = [ "pyaudio>=0.2.13", "google-genai>=1.32.0,<2.0.0", "opencv-python>=4.8.0", "pillow>=10.0.0", ] -bidirectional-streaming = [ +bidi = [ "pyaudio>=0.2.13", "rx>=3.2.0", "smithy-aws-core>=0.0.1; python_version>='3.12'", @@ -97,7 +97,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -bidirectional = ["bidirectional-streaming-openai", "bidirectional-streaming-gemini", "bidirectional-streaming-nova", "bidirectional-streaming"] +bidirectional = ["bidi-openai", "bidi-gemini", "bidi-novasonic", "bidi"] all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ From ba32dda0d97df239fc5193228afe6c8056c21df4 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 12:51:25 -0500 Subject: [PATCH 108/242] Update dependencies --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index efdf1a5d5..df84586f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,7 @@ source = "vcs" # Use git tags for versioning [tool.hatch.envs.hatch-static-analysis] installer = "uv" -features = ["all", "bidirectional-streaming"] +features = ["all", "bidi"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.13.0,<0.14.0", @@ -156,7 +156,7 @@ lint-fix = [ [tool.hatch.envs.hatch-test] installer = "uv" -features = ["all", "bidirectional-streaming"] +features = ["all", "bidi"] extra-args = ["-n", "auto", "-vv"] dependencies = [ "pytest>=8.0.0,<9.0.0", @@ -179,7 +179,7 @@ cov-report = [] [tool.hatch.envs.default] installer = "uv" dev-mode = true -features = ["all", "bidirectional-streaming"] +features = ["all", "bidi"] dependencies = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", From 2fd1b4470482b2a9d934626c0c9c4c633a289a28 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 12:53:32 -0500 Subject: [PATCH 109/242] Update dependencies --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df84586f7..998059f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -bidirectional = ["bidi-openai", "bidi-gemini", "bidi-novasonic", "bidi"] +bidi-all = ["bidi-openai", "bidi-gemini", "bidi-novasonic", "bidi"] all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ @@ -129,7 +129,7 @@ source = "vcs" # Use git tags for versioning [tool.hatch.envs.hatch-static-analysis] installer = "uv" -features = ["all", "bidi"] +features = ["all", "bidi-all"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.13.0,<0.14.0", @@ -156,7 +156,7 @@ lint-fix = [ [tool.hatch.envs.hatch-test] installer = "uv" -features = ["all", "bidi"] +features = ["all", "bidi-all"] extra-args = ["-n", "auto", "-vv"] dependencies = [ "pytest>=8.0.0,<9.0.0", @@ -179,7 +179,7 @@ cov-report = [] [tool.hatch.envs.default] installer = "uv" dev-mode = true -features = ["all", "bidi"] +features = ["all", "bidi-all"] dependencies = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", From 9756570c9ccf03604012bde420a86344ea25f68f Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 13:17:29 -0500 Subject: [PATCH 110/242] Rename directories and files to use bidi instead of bidirectional --- src/strands/experimental/bidi/__init__.py | 79 ++ .../experimental/bidi/agent/__init__.py | 5 + src/strands/experimental/bidi/agent/agent.py | 447 +++++++++++ src/strands/experimental/bidi/agent/loop.py | 162 ++++ src/strands/experimental/bidi/io/__init__.py | 6 + src/strands/experimental/bidi/io/audio.py | 161 ++++ src/strands/experimental/bidi/io/text.py | 31 + .../experimental/bidi/models/__init__.py | 13 + .../experimental/bidi/models/bidi_model.py | 108 +++ .../bidi/models/bidirectional_model.py | 108 +++ .../experimental/bidi/models/gemini_live.py | 539 +++++++++++++ .../experimental/bidi/models/novasonic.py | 743 ++++++++++++++++++ .../experimental/bidi/models/openai.py | 653 +++++++++++++++ .../experimental/bidi/scripts/test_bidi.py | 38 + .../bidi/scripts/test_bidi_novasonic.py | 256 ++++++ .../bidi/scripts/test_bidi_openai.py | 324 ++++++++ .../bidi/scripts/test_gemini_live.py | 363 +++++++++ .../experimental/bidi/types/__init__.py | 57 ++ src/strands/experimental/bidi/types/agent.py | 10 + src/strands/experimental/bidi/types/events.py | 521 ++++++++++++ src/strands/experimental/bidi/types/io.py | 57 ++ tests/strands/experimental/bidi/__init__.py | 1 + .../experimental/bidi/models/__init__.py | 1 + .../bidi/models/test_gemini_live.py | 487 ++++++++++++ .../bidi/models/test_novasonic.py | 458 +++++++++++ .../bidi/models/test_openai_realtime.py | 538 +++++++++++++ .../experimental/bidi/types/__init__.py | 1 + .../experimental/bidi/types/test_events.py | 164 ++++ tests_integ/bidi/__init__.py | 1 + tests_integ/bidi/conftest.py | 28 + tests_integ/bidi/context.py | 365 +++++++++ tests_integ/bidi/generators/__init__.py | 1 + tests_integ/bidi/generators/audio.py | 159 ++++ tests_integ/bidi/test_bidirectional_agent.py | 220 ++++++ tests_integ/bidi/wrappers/__init__.py | 4 + 35 files changed, 7109 insertions(+) create mode 100644 src/strands/experimental/bidi/__init__.py create mode 100644 src/strands/experimental/bidi/agent/__init__.py create mode 100644 src/strands/experimental/bidi/agent/agent.py create mode 100644 src/strands/experimental/bidi/agent/loop.py create mode 100644 src/strands/experimental/bidi/io/__init__.py create mode 100644 src/strands/experimental/bidi/io/audio.py create mode 100644 src/strands/experimental/bidi/io/text.py create mode 100644 src/strands/experimental/bidi/models/__init__.py create mode 100644 src/strands/experimental/bidi/models/bidi_model.py create mode 100644 src/strands/experimental/bidi/models/bidirectional_model.py create mode 100644 src/strands/experimental/bidi/models/gemini_live.py create mode 100644 src/strands/experimental/bidi/models/novasonic.py create mode 100644 src/strands/experimental/bidi/models/openai.py create mode 100644 src/strands/experimental/bidi/scripts/test_bidi.py create mode 100644 src/strands/experimental/bidi/scripts/test_bidi_novasonic.py create mode 100644 src/strands/experimental/bidi/scripts/test_bidi_openai.py create mode 100644 src/strands/experimental/bidi/scripts/test_gemini_live.py create mode 100644 src/strands/experimental/bidi/types/__init__.py create mode 100644 src/strands/experimental/bidi/types/agent.py create mode 100644 src/strands/experimental/bidi/types/events.py create mode 100644 src/strands/experimental/bidi/types/io.py create mode 100644 tests/strands/experimental/bidi/__init__.py create mode 100644 tests/strands/experimental/bidi/models/__init__.py create mode 100644 tests/strands/experimental/bidi/models/test_gemini_live.py create mode 100644 tests/strands/experimental/bidi/models/test_novasonic.py create mode 100644 tests/strands/experimental/bidi/models/test_openai_realtime.py create mode 100644 tests/strands/experimental/bidi/types/__init__.py create mode 100644 tests/strands/experimental/bidi/types/test_events.py create mode 100644 tests_integ/bidi/__init__.py create mode 100644 tests_integ/bidi/conftest.py create mode 100644 tests_integ/bidi/context.py create mode 100644 tests_integ/bidi/generators/__init__.py create mode 100644 tests_integ/bidi/generators/audio.py create mode 100644 tests_integ/bidi/test_bidirectional_agent.py create mode 100644 tests_integ/bidi/wrappers/__init__.py diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py new file mode 100644 index 000000000..033a4bb78 --- /dev/null +++ b/src/strands/experimental/bidi/__init__.py @@ -0,0 +1,79 @@ +"""Bidirectional streaming package.""" + +# Main components - Primary user interface +from .agent.agent import BidiAgent + +# IO channels - Hardware abstraction +from .io.audio import BidiAudioIO + +# Model interface (for custom implementations) +from .models.bidirectional_model import BidiModel + +# Model providers - What users need to create models +from .models.gemini_live import BidiGeminiLiveModel +from .models.novasonic import BidiNovaSonicModel +from .models.openai import BidiOpenAIRealtimeModel + +# Event types - For type hints and event handling +from .types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + ModalityUsage, + BidiUsageEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) + +# Re-export standard agent events for tool handling +from ...types._events import ( + ToolResultEvent, + ToolStreamEvent, + ToolUseStreamEvent, +) + +__all__ = [ + # Main interface + "BidiAgent", + # IO channels + "BidiAudioIO", + # Model providers + "BidiGeminiLiveModel", + "BidiNovaSonicModel", + "BidiOpenAIRealtimeModel", + + # Input Event types + "BidiTextInputEvent", + "BidiAudioInputEvent", + "BidiImageInputEvent", + "BidiInputEvent", + + # Output Event types + "BidiConnectionStartEvent", + "BidiConnectionCloseEvent", + "BidiResponseStartEvent", + "BidiResponseCompleteEvent", + "BidiAudioStreamEvent", + "BidiTranscriptStreamEvent", + "BidiInterruptionEvent", + "BidiUsageEvent", + "ModalityUsage", + "BidiErrorEvent", + "BidiOutputEvent", + + # Tool Event types (reused from standard agent) + "ToolUseStreamEvent", + "ToolResultEvent", + "ToolStreamEvent", + + # Model interface + "BidiModel", +] diff --git a/src/strands/experimental/bidi/agent/__init__.py b/src/strands/experimental/bidi/agent/__init__.py new file mode 100644 index 000000000..564973099 --- /dev/null +++ b/src/strands/experimental/bidi/agent/__init__.py @@ -0,0 +1,5 @@ +"""Bidirectional agent for real-time streaming conversations.""" + +from .agent import BidiAgent + +__all__ = ["BidiAgent"] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py new file mode 100644 index 000000000..eab909449 --- /dev/null +++ b/src/strands/experimental/bidi/agent/agent.py @@ -0,0 +1,447 @@ +"""Bidirectional Agent for real-time streaming conversations. + +Provides real-time audio and text interaction through persistent streaming connections. +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive +continuous responses including audio output. + +Key capabilities: +- Persistent conversation connections with concurrent processing +- Real-time audio input/output streaming +- Automatic interruption detection and tool execution +- Event-driven communication with model providers +""" + +import asyncio +import json +import logging +from typing import Any, AsyncIterable + +from .... import _identifier +from ....tools.caller import _ToolCaller +from ....tools.executors import ConcurrentToolExecutor +from ....tools.executors._executor import ToolExecutor +from ....tools.registry import ToolRegistry +from ....tools.watcher import ToolWatcher +from ....types.content import Message, Messages +from ....types.tools import ToolResult, ToolUse, AgentTool + +from .loop import _BidiAgentLoop +from ..models.bidirectional_model import BidiModel +from ..models.novasonic import BidiNovaSonicModel +from ..types.agent import BidiAgentInput +from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent +from ..types.io import BidiInput, BidiOutput +from ...tools import ToolProvider + +logger = logging.getLogger(__name__) + +_DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" + + +class BidiAgent: + """Agent for bidirectional streaming conversations. + + Enables real-time audio and text interaction with AI models through persistent + connections. Supports concurrent tool execution and interruption handling. + """ + + def __init__( + self, + model: BidiModel| str | None = None, + tools: list[str| AgentTool| ToolProvider]| None = None, + system_prompt: str | None = None, + messages: Messages | None = None, + record_direct_tool_call: bool = True, + load_tools_from_directory: bool = False, + agent_id: str | None = None, + name: str | None = None, + tool_executor: ToolExecutor | None = None, + description: str | None = None, + **kwargs: Any, + ): + """Initialize bidirectional agent. + + Args: + model: BidiModel instance, string model_id, or None for default detection. + tools: Optional list of tools with flexible format support. + system_prompt: Optional system prompt for conversations. + messages: Optional conversation history to initialize with. + record_direct_tool_call: Whether to record direct tool calls in message history. + load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. + agent_id: Optional ID for the agent, useful for connection management and multi-agent scenarios. + name: Name of the Agent. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + description: Description of what the Agent does. + **kwargs: Additional configuration for future extensibility. + + Raises: + ValueError: If model configuration is invalid. + TypeError: If model type is unsupported. + """ + self.model = ( + BidiNovaSonicModel() + if not model + else BidiNovaSonicModel(model_id=model) + if isinstance(model, str) + else model + ) + self.system_prompt = system_prompt + self.messages = messages or [] + + # Agent identification + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self.name = name or _DEFAULT_AGENT_NAME + self.description = description + + # Tool execution configuration + self.record_direct_tool_call = record_direct_tool_call + self.load_tools_from_directory = load_tools_from_directory + + # Initialize tool registry + self.tool_registry = ToolRegistry() + + if tools is not None: + self.tool_registry.process_tools(tools) + + self.tool_registry.initialize_tools(self.load_tools_from_directory) + + # Initialize tool watcher if directory loading is enabled + if self.load_tools_from_directory: + self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) + + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + + # Initialize other components + self._tool_caller = _ToolCaller(self) + + self._current_adapters = [] # Track adapters for cleanup + + self._loop = _BidiAgentLoop(self) + + @property + def tool(self) -> _ToolCaller: + """Call tool as a function. + + Returns: + ToolCaller for method-style tool execution. + + Example: + ``` + agent = BidiAgent(model=model, tools=[calculator]) + agent.tool.calculator(expression="2+2") + ``` + """ + return self._tool_caller + + @property + def tool_names(self) -> list[str]: + """Get a list of all registered tool names. + + Returns: + Names of all tools available to this agent. + """ + all_tools = self.tool_registry.get_all_tools_config() + return list(all_tools.keys()) + + def _record_tool_execution( + self, + tool: ToolUse, + tool_result: ToolResult, + user_message_override: str | None, + ) -> None: + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. + """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + + # Create user message describing the tool call + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") + + user_msg_content = [ + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} + ] + + # Add override message if provided + if user_message_override: + user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + + # Create the message sequence + user_msg: Message = { + "role": "user", + "content": user_msg_content, + } + tool_use_msg: Message = { + "role": "assistant", + "content": [{"toolUse": filtered_tool}], + } + tool_result_msg: Message = { + "role": "user", + "content": [{"toolResult": tool_result}], + } + assistant_msg: Message = { + "role": "assistant", + "content": [{"text": f"agent.tool.{tool['name']} was called."}], + } + + # Add to message history + self.messages.append(user_msg) + self.messages.append(tool_use_msg) + self.messages.append(tool_result_msg) + self.messages.append(assistant_msg) + + logger.debug("Direct tool call recorded in message history: %s", tool["name"]) + + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + + async def start(self) -> None: + """Start a persistent bidirectional conversation connection. + + Initializes the streaming connection and starts background tasks for processing + model events, tool execution, and connection management. + """ + logger.debug("starting agent") + + await self._loop.start() + + async def send(self, input_data: BidiAgentInput) -> None: + """Send input to the model (text, audio, image, or event dict). + + Unified method for sending text, audio, and image input to the model during + an active conversation session. Accepts TypedEvent instances or plain dicts + (e.g., from WebSocket clients) which are automatically reconstructed. + + Args: + input_data: Can be: + - str: Text message from user + - BidiAudioInputEvent: Audio data with format/sample rate + - BidiImageInputEvent: Image data with MIME type + - dict: Event dictionary (will be reconstructed to TypedEvent) + + Raises: + ValueError: If no active session or invalid input type. + + Example: + await agent.send("Hello") + await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...)) + await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"}) + """ + self._validate_active_connection() + + # Handle string input + if isinstance(input_data, str): + # Add user text message to history + user_message: Message = {"role": "user", "content": [{"text": input_data}]} + + self.messages.append(user_message) + + logger.debug("Text sent: %d characters", len(input_data)) + # Create BidiTextInputEvent for send() + text_event = BidiTextInputEvent(text=input_data, role="user") + await self.model.send(text_event) + return + + # Handle BidiInputEvent instances + # Check this before dict since TypedEvent inherits from dict + if isinstance(input_data, BidiInputEvent): + await self.model.send(input_data) + return + + # Handle plain dict - reconstruct TypedEvent for WebSocket integration + if isinstance(input_data, dict) and "type" in input_data: + event_type = input_data["type"] + if event_type == "bidi_text_input": + input_event = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) + elif event_type == "bidi_audio_input": + input_event = BidiAudioInputEvent( + audio=input_data["audio"], + format=input_data["format"], + sample_rate=input_data["sample_rate"], + channels=input_data["channels"] + ) + elif event_type == "bidi_image_input": + input_event = BidiImageInputEvent( + image=input_data["image"], + mime_type=input_data["mime_type"] + ) + else: + raise ValueError(f"Unknown event type: {event_type}") + + # Send the reconstructed TypedEvent + await self.model.send(input_event) + return + + # If we get here, input type is invalid + raise ValueError( + f"Input must be a string, BidiInputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" + ) + + async def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive events from the model including audio, text, and tool calls. + + Yields model output events processed by background tasks including audio output, + text responses, tool calls, and connection updates. + + Yields: + Model and tool call events. + """ + async for event in self._loop.receive(): + yield event + + async def stop(self) -> None: + """End the conversation connection and cleanup all resources. + + Terminates the streaming connection, cancels background tasks, and + closes the connection to the model provider. + """ + await self._loop.stop() + + async def __aenter__(self) -> "BidiAgent": + """Async context manager entry point. + + Automatically starts the bidirectional connection when entering the context. + + Returns: + Self for use in the context. + """ + logger.debug("Entering async context manager - starting connection") + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit point. + + Automatically ends the connection and cleans up resources including adapters + when exiting the context, regardless of whether an exception occurred. + + Args: + exc_type: Exception type if an exception occurred, None otherwise. + exc_val: Exception value if an exception occurred, None otherwise. + exc_tb: Exception traceback if an exception occurred, None otherwise. + """ + try: + logger.debug("Exiting async context manager - cleaning up adapters and connection") + + # Cleanup adapters if any are currently active + for adapter in self._current_adapters: + if hasattr(adapter, "cleanup"): + try: + adapter.stop() + logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") + except Exception as adapter_error: + logger.warning(f"Error cleaning up adapter: {adapter_error}") + + # Clear current adapters + self._current_adapters = [] + + # Cleanup agent connection + await self.stop() + + except Exception as cleanup_error: + if exc_type is None: + # No original exception, re-raise cleanup error + logger.error("Error during context manager cleanup: %s", cleanup_error) + raise + else: + # Original exception exists, log cleanup error but don't suppress original + logger.error( + "Error during context manager cleanup (suppressed due to original exception): %s", cleanup_error + ) + + @property + def active(self) -> bool: + """True if agent loop started, False otherwise.""" + return self._loop.active + + async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: + """Run the agent using provided IO channels for bidirectional communication. + + Args: + inputs: Input callables to read data from a source + outputs: Output callables to receive events from the agent + + Example: + ```python + audio_io = BidiAudioIO(audio_config={"input_sample_rate": 16000}) + text_io = BidiTextIO() + agent = BidiAgent(model=model, tools=[calculator]) + await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) + ``` + """ + async def run_inputs(): + while self.active: + for input_ in inputs: + event = await input_() + await self.send(event) + + # TODO: Need to make tool result send in Nova provider atomic. Audio input events end up interleaving + # and leading to failures. Adding a sleep here as a temporary solution. + await asyncio.sleep(0.001) + + async def run_outputs(): + async for event in self.receive(): + for output in outputs: + await output(event) + + for input_ in inputs: + await input_.start() + + for output in outputs: + await output.start() + + try: + await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True) + + finally: + for input_ in inputs: + await input_.stop() + + for output in outputs: + await output.stop() + + def _validate_active_connection(self) -> None: + """Validate that an active connection exists. + + Raises: + ValueError: If no active connection. + """ + if not self.active: + raise ValueError("No active conversation. Call start() first or use async context manager.") diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py new file mode 100644 index 000000000..e0bc02ef2 --- /dev/null +++ b/src/strands/experimental/bidi/agent/loop.py @@ -0,0 +1,162 @@ +"""Agent loop. + +The agent loop handles the events received from the model and executes tools when given a tool use request. +""" + +import asyncio +import logging +from typing import AsyncIterable, Awaitable, TYPE_CHECKING + +from ..types.events import BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent +from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent +from ....types.content import Message +from ....types.tools import ToolResult, ToolUse + +if TYPE_CHECKING: + from .agent import BidiAgent + +logger = logging.getLogger(__name__) + + +class _BidiAgentLoop: + """Agent loop.""" + + def __init__(self, agent: "BidiAgent") -> None: + """Initialize members of the agent loop. + + Note, before receiving events from the loop, the user must call `start`. + + Args: + agent: Bidirectional agent to loop over. + """ + self._agent = agent + self._event_queue = asyncio.Queue() # queue model and tool call events + self._tasks = set() # track active async tasks created in loop + self._active = False # flag if agent loop is started + + async def start(self) -> None: + """Start the agent loop. + + The agent model is started as part of this call. + """ + if self.active: + return + + logger.debug("starting agent loop") + + await self._agent.model.start( + system_prompt=self._agent.system_prompt, + tools=self._agent.tool_registry.get_all_tool_specs(), + messages=self._agent.messages, + ) + + self._create_task(self._run_model()) + + self._active = True + + async def stop(self) -> None: + """Stop the agent loop.""" + if not self.active: + return + + logger.debug("stopping agent loop") + + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + + await self._agent.model.stop() + + self._active = False + + async def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive model and tool call events.""" + while self.active: + try: + yield self._event_queue.get_nowait() + except asyncio.QueueEmpty: + pass + + # unblock the event loop + await asyncio.sleep(0) + + @property + def active(self) -> bool: + """True if agent loop started, False otherwise.""" + return self._active + + def _create_task(self, coro: Awaitable[None]) -> None: + """Utilitly to create async task. + + Adds a clean up callback to run after task completes. + """ + task = asyncio.create_task(coro) + task.add_done_callback(lambda task: self._tasks.remove(task)) + + self._tasks.add(task) + + async def _run_model(self) -> None: + """Task for running the model. + + Events are streamed through the event queue. + """ + logger.debug("running model") + + async for event in self._agent.model.receive(): + if not self.active: + break + + self._event_queue.put_nowait(event) + + if isinstance(event, BidiTranscriptStreamEvent): + if event["is_final"]: + message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} + self._agent.messages.append(message) + + elif isinstance(event, ToolUseStreamEvent): + self._create_task(self._run_tool(event["current_tool_use"])) + + elif isinstance(event, BidiInterruptionEvent): + # clear the audio + for _ in range(self._event_queue.qsize()): + event = self._event_queue.get_nowait() + if not isinstance(event, BidiAudioStreamEvent): + self._event_queue.put_nowait(event) + + async def _run_tool(self, tool_use: ToolUse) -> None: + """Task for running tool requested by the model.""" + logger.debug("running tool") + + result: ToolResult = None + + try: + tool = self._agent.tool_registry.registry[tool_use["name"]] + invocation_state = {} + + async for event in tool.stream(tool_use, invocation_state): + if isinstance(event, ToolResultEvent): + self._event_queue.put_nowait(event) + result = event.tool_result + break + + if isinstance(event, ToolStreamEvent): + self._event_queue.put_nowait(event) + else: + self._event_queue.put_nowait(ToolStreamEvent(tool_use, event)) + + except Exception as e: + result = { + "toolUseId": tool_use["toolUseId"], + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } + + await self._agent.model.send(ToolResultEvent(result)) + + message: Message = { + "role": "user", + "content": [{"toolResult": result}], + } + self._agent.messages.append(message) + self._event_queue.put_nowait(ToolResultMessageEvent(message)) diff --git a/src/strands/experimental/bidi/io/__init__.py b/src/strands/experimental/bidi/io/__init__.py new file mode 100644 index 000000000..d099cba2f --- /dev/null +++ b/src/strands/experimental/bidi/io/__init__.py @@ -0,0 +1,6 @@ +"""IO channel implementations for bidirectional streaming.""" + +from .audio import BidiAudioIO +from .text import BidiTextIO + +__all__ = ["BidiAudioIO", "BidiTextIO"] diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py new file mode 100644 index 000000000..2ec167480 --- /dev/null +++ b/src/strands/experimental/bidi/io/audio.py @@ -0,0 +1,161 @@ +"""AudioIO - Clean separation of audio functionality from core BidiAgent. + +Provides audio input/output capabilities for BidiAgent through the BidiIO protocol. +Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. +""" + +import asyncio +import base64 +import logging + +import pyaudio + +from ..types.io import BidiInput, BidiOutput +from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiOutputEvent + +logger = logging.getLogger(__name__) + + +class _BidiAudioInput(BidiInput): + "Handle audio input from bidi agent." + def __init__(self, audio: "BidiAudioIO") -> None: + """Store reference to pyaudio instance.""" + self.audio = audio + + async def start(self) -> None: + """Start audio input.""" + self.audio._start() + + async def stop(self) -> None: + """Stop audio input.""" + self.audio._stop() + + async def __call__(self) -> BidiAudioInputEvent: + """Read audio from microphone.""" + audio_bytes = self.audio.input_stream.read(self.audio.chunk_size, exception_on_overflow=False) + + return BidiAudioInputEvent( + audio=base64.b64encode(audio_bytes).decode("utf-8"), + format="pcm", + sample_rate=self.audio.input_sample_rate, + channels=self.audio.input_channels, + ) + + +class _BidiAudioOutput(BidiOutput): + "Handle audio output from bidi agent." + def __init__(self, audio: "BidiAudioIO") -> None: + """Store reference to pyaudio instance.""" + self.audio = audio + + async def start(self) -> None: + """Start audio output.""" + self.audio._start() + + async def stop(self) -> None: + """Stop audio output.""" + self.audio._stop() + + async def __call__(self, event: BidiOutputEvent) -> None: + """Handle audio events with direct stream writing.""" + if isinstance(event, BidiAudioStreamEvent): + self.audio.output_stream.write(base64.b64decode(event["audio"])) + + # TODO: Outputing audio to speakers is a sync operation. Adding sleep to prevent event loop hogging. Will + # follow up on identifying a cleaner approach. + await asyncio.sleep(0.01) + + +class BidiAudioIO: + """Audio IO channel for BidiAgent with direct stream processing.""" + + def __init__( + self, + audio_config: dict | None = None, + ): + """Initialize AudioIO with clean audio configuration. + + Args: + audio_config: Dictionary containing audio configuration: + - input_sample_rate (int): Microphone sample rate (default: 24000) + - output_sample_rate (int): Speaker sample rate (default: 24000) + - chunk_size (int): Audio chunk size in bytes (default: 1024) + - input_device_index (int): Specific input device (optional) + - output_device_index (int): Specific output device (optional) + - input_channels (int): Input channels (default: 1) + - output_channels (int): Output channels (default: 1) + """ + default_config = { + "input_sample_rate": 16000, + "output_sample_rate": 16000, + "chunk_size": 512, + "input_device_index": None, + "output_device_index": None, + "input_channels": 1, + "output_channels": 1, + } + + # Merge user config with defaults + if audio_config: + default_config.update(audio_config) + + # Set audio configuration attributes + self.input_sample_rate = default_config["input_sample_rate"] + self.output_sample_rate = default_config["output_sample_rate"] + self.chunk_size = default_config["chunk_size"] + self.input_device_index = default_config["input_device_index"] + self.output_device_index = default_config["output_device_index"] + self.input_channels = default_config["input_channels"] + self.output_channels = default_config["output_channels"] + + # Audio infrastructure + self.audio = None + self.input_stream = None + self.output_stream = None + self.interrupted = False + + def input(self) -> _BidiAudioInput: + "Return audio processing BidiInput" + return _BidiAudioInput(self) + + def output(self) -> _BidiAudioOutput: + "Return audio processing BidiOutput" + return _BidiAudioOutput(self) + + def _start(self) -> None: + """Setup PyAudio streams for input and output.""" + if self.audio: + return + + self.audio = pyaudio.PyAudio() + + self.input_stream = self.audio.open( + format=pyaudio.paInt16, + channels=self.input_channels, + rate=self.input_sample_rate, + input=True, + frames_per_buffer=self.chunk_size, + input_device_index=self.input_device_index, + ) + + self.output_stream = self.audio.open( + format=pyaudio.paInt16, + channels=self.output_channels, + rate=self.output_sample_rate, + output=True, + frames_per_buffer=self.chunk_size, + output_device_index=self.output_device_index, + ) + + def _stop(self) -> None: + """Clean up IO channel resources.""" + if not self.audio: + return + + self.input_stream.close() + self.output_stream.close() + self.audio.terminate() + + self.input_stream = None + self.output_stream = None + self.audio = None diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py new file mode 100644 index 000000000..ba503f4e4 --- /dev/null +++ b/src/strands/experimental/bidi/io/text.py @@ -0,0 +1,31 @@ +"""Handle text input and output from bidi agent.""" + +import logging + +from ..types.io import BidiOutput +from ..types.events import BidiOutputEvent, BidiInterruptionEvent, BidiTranscriptStreamEvent + +logger = logging.getLogger(__name__) + + +class _BidiTextOutput(BidiOutput): + "Handle text output from bidi agent." + async def __call__(self, event: BidiOutputEvent) -> None: + """Print text events to stdout.""" + + if isinstance(event, BidiInterruptionEvent): + print("interrupted") + + elif isinstance(event, BidiTranscriptStreamEvent): + text = event["text"] + if not event["is_final"]: + text = f"Preview: {text}" + + print(text) + + +class BidiTextIO: + "Handle text input and output from bidi agent." + def output(self) -> _BidiTextOutput: + "Return text processing BidiOutput" + return _BidiTextOutput() diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py new file mode 100644 index 000000000..6d6d6590b --- /dev/null +++ b/src/strands/experimental/bidi/models/__init__.py @@ -0,0 +1,13 @@ +"""Bidirectional model interfaces and implementations.""" + +from .bidirectional_model import BidiModel +from .gemini_live import BidiGeminiLiveModel +from .novasonic import BidiNovaSonicModel +from .openai import BidiOpenAIRealtimeModel + +__all__ = [ + "BidiModel", + "BidiGeminiLiveModel", + "BidiNovaSonicModel", + "BidiOpenAIRealtimeModel", +] diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py new file mode 100644 index 000000000..d3c3aa7ec --- /dev/null +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -0,0 +1,108 @@ +"""Bidirectional streaming model interface. + +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. + +Features: +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) +- Provider-agnostic event normalization +- Support for audio, text, image, and tool result streaming +""" + +import logging +from typing import AsyncIterable, Protocol, Union + +from ....types._events import ToolResultEvent +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.events import ( + BidiAudioInputEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiOutputEvent, + BidiTextInputEvent, +) + +logger = logging.getLogger(__name__) + + +class BidiModel(Protocol): + """Protocol for bidirectional streaming models. + + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. + """ + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish a persistent streaming connection with the model. + + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. + + Args: + system_prompt: System instructions to configure model behavior. + tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. + **kwargs: Provider-specific configuration options. + """ + ... + + async def stop(self) -> None: + """Close the streaming connection and release resources. + + Terminates the active bidirectional connection and cleans up any associated + resources such as network connections, buffers, or background tasks. After + calling close(), the model instance cannot be used until start() is called again. + """ + ... + + async def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive streaming events from the model. + + Continuously yields events from the model as they arrive over the connection. + Events are normalized to a provider-agnostic format for uniform processing. + This method should be called in a loop or async task to process model responses. + + The stream continues until the connection is closed or an error occurs. + + Yields: + BidiOutputEvent: Standardized event objects containing audio output, + transcripts, tool calls, or control signals. + """ + ... + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Send content to the model over the active connection. + + Transmits user input or tool results to the model during an active streaming + session. Supports multiple content types including text, audio, images, and + tool execution results. Can be called multiple times during a conversation. + + Args: + content: The content to send. Must be one of: + - BidiTextInputEvent: Text message from the user + - BidiAudioInputEvent: Audio data for speech input + - BidiImageInputEvent: Image data for visual understanding + - ToolResultEvent: Result from a tool execution + + Example: + await model.send(BidiTextInputEvent(text="Hello", role="user")) + await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) + await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) + await model.send(ToolResultEvent(tool_result)) + """ + ... diff --git a/src/strands/experimental/bidi/models/bidirectional_model.py b/src/strands/experimental/bidi/models/bidirectional_model.py new file mode 100644 index 000000000..d3c3aa7ec --- /dev/null +++ b/src/strands/experimental/bidi/models/bidirectional_model.py @@ -0,0 +1,108 @@ +"""Bidirectional streaming model interface. + +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. + +Features: +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) +- Provider-agnostic event normalization +- Support for audio, text, image, and tool result streaming +""" + +import logging +from typing import AsyncIterable, Protocol, Union + +from ....types._events import ToolResultEvent +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.events import ( + BidiAudioInputEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiOutputEvent, + BidiTextInputEvent, +) + +logger = logging.getLogger(__name__) + + +class BidiModel(Protocol): + """Protocol for bidirectional streaming models. + + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. + """ + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish a persistent streaming connection with the model. + + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. + + Args: + system_prompt: System instructions to configure model behavior. + tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. + **kwargs: Provider-specific configuration options. + """ + ... + + async def stop(self) -> None: + """Close the streaming connection and release resources. + + Terminates the active bidirectional connection and cleans up any associated + resources such as network connections, buffers, or background tasks. After + calling close(), the model instance cannot be used until start() is called again. + """ + ... + + async def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive streaming events from the model. + + Continuously yields events from the model as they arrive over the connection. + Events are normalized to a provider-agnostic format for uniform processing. + This method should be called in a loop or async task to process model responses. + + The stream continues until the connection is closed or an error occurs. + + Yields: + BidiOutputEvent: Standardized event objects containing audio output, + transcripts, tool calls, or control signals. + """ + ... + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Send content to the model over the active connection. + + Transmits user input or tool results to the model during an active streaming + session. Supports multiple content types including text, audio, images, and + tool execution results. Can be called multiple times during a conversation. + + Args: + content: The content to send. Must be one of: + - BidiTextInputEvent: Text message from the user + - BidiAudioInputEvent: Audio data for speech input + - BidiImageInputEvent: Image data for visual understanding + - ToolResultEvent: Result from a tool execution + + Example: + await model.send(BidiTextInputEvent(text="Hello", role="user")) + await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) + await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) + await model.send(ToolResultEvent(tool_result)) + """ + ... diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py new file mode 100644 index 000000000..9bb5bba77 --- /dev/null +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -0,0 +1,539 @@ +"""Gemini Live API bidirectional model provider using official Google GenAI SDK. + +Implements the BidiModel interface for Google's Gemini Live API using the +official Google GenAI SDK for simplified and robust WebSocket communication. + +Key improvements over custom WebSocket implementation: +- Uses official google-genai SDK with native Live API support +- Simplified session management with client.aio.live.connect() +- Built-in tool integration and event handling +- Automatic WebSocket connection management and error handling +- Native support for audio/text streaming and interruption +""" + +import asyncio +import base64 +import logging +import uuid +from typing import Any, AsyncIterable, Dict, List, Optional, Union + +from google import genai +from google.genai import types as genai_types +from google.genai.types import LiveServerMessage, LiveServerContent + +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ..types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiUsageEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, +) +from .bidirectional_model import BidiModel + +logger = logging.getLogger(__name__) + +# Audio format constants +GEMINI_INPUT_SAMPLE_RATE = 16000 +GEMINI_OUTPUT_SAMPLE_RATE = 24000 +GEMINI_CHANNELS = 1 + + +class BidiGeminiLiveModel(BidiModel): + """Gemini Live API implementation using official Google GenAI SDK. + + Combines model configuration and connection state in a single class. + Provides a clean interface to Gemini Live API using the official SDK, + eliminating custom WebSocket handling and providing robust error handling. + """ + + def __init__( + self, + model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", + api_key: Optional[str] = None, + live_config: Optional[Dict[str, Any]] = None, + **kwargs + ): + """Initialize Gemini Live API bidirectional model. + + Args: + model_id: Gemini Live model identifier. + api_key: Google AI API key for authentication. + live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). + **kwargs: Reserved for future parameters. + """ + # Model configuration + self.model_id = model_id + self.api_key = api_key + + # Set default live_config with transcription enabled + default_config = { + "response_modalities": ["AUDIO"], + "outputAudioTranscription": {}, # Enable output transcription by default + "inputAudioTranscription": {} # Enable input transcription by default + } + + # Merge user config with defaults (user config takes precedence) + if live_config: + default_config.update(live_config) + + self.live_config = default_config + + # Create Gemini client with proper API version + client_kwargs = {} + if api_key: + client_kwargs["api_key"] = api_key + + # Use v1alpha for Live API as it has better model support + client_kwargs["http_options"] = {"api_version": "v1alpha"} + + self.client = genai.Client(**client_kwargs) + + # Connection state (initialized in start()) + self.live_session = None + self.live_session_context_manager = None + self.connection_id = None + self._active = False + + async def start( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> None: + """Establish bidirectional connection with Gemini Live API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + + try: + # Initialize connection state + self.connection_id = str(uuid.uuid4()) + self._active = True + + # Build live config + live_config = self._build_live_config(system_prompt, tools, **kwargs) + + # Create the context manager + self.live_session_context_manager = self.client.aio.live.connect( + model=self.model_id, + config=live_config + ) + + # Enter the context manager + self.live_session = await self.live_session_context_manager.__aenter__() + + # Send initial message history if provided + if messages: + await self._send_message_history(messages) + + except Exception as e: + self._active = False + logger.error("Error connecting to Gemini Live: %s", e) + raise + + async def _send_message_history(self, messages: Messages) -> None: + """Send conversation history to Gemini Live API. + + Sends each message as a separate turn with the correct role to maintain + proper conversation context. Follows the same pattern as the non-bidirectional + Gemini model implementation. + """ + if not messages: + return + + # Convert each message to Gemini format and send separately + for message in messages: + content_parts = [] + for content_block in message["content"]: + if "text" in content_block: + content_parts.append(genai_types.Part(text=content_block["text"])) + + if content_parts: + # Map role correctly - Gemini uses "user" and "model" roles + # "assistant" role from Messages format maps to "model" in Gemini + role = "model" if message["role"] == "assistant" else message["role"] + content = genai_types.Content(role=role, parts=content_parts) + await self.live_session.send_client_content(turns=content) + + async def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive Gemini Live API events and convert to provider-agnostic format.""" + + # Emit connection start event + yield BidiConnectionStartEvent( + connection_id=self.connection_id, + model=self.model_id + ) + + try: + # Wrap in while loop to restart after turn_complete (SDK limitation workaround) + while self._active: + try: + async for message in self.live_session.receive(): + if not self._active: + break + + # Convert to provider-agnostic format (always returns list) + for event in self._convert_gemini_live_event(message): + yield event + + # SDK exits receive loop after turn_complete - restart automatically + if self._active: + logger.debug("Restarting receive loop after turn completion") + + except Exception as e: + logger.error("Error in receive iteration: %s", e) + # Small delay before retrying to avoid tight error loops + await asyncio.sleep(0.1) + + except Exception as e: + logger.error("Fatal error in receive loop: %s", e) + yield BidiErrorEvent(error=e) + finally: + # Emit connection close event when exiting + yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + + def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOutputEvent]: + """Convert Gemini Live API events to provider-agnostic format. + + Handles different types of content: + - inputTranscription: User's speech transcribed to text + - outputTranscription: Model's audio transcribed to text + - modelTurn text: Text response from the model + - usageMetadata: Token usage information + + Returns: + List of event dicts (empty list if no events to emit). + """ + try: + # Handle interruption first (from server_content) + if message.server_content and message.server_content.interrupted: + return [BidiInterruptionEvent(reason="user_speech")] + + # Handle input transcription (user's speech) - emit as transcript event + if message.server_content and message.server_content.input_transcription: + input_transcript = message.server_content.input_transcription + # Check if the transcription object has text content + if hasattr(input_transcript, 'text') and input_transcript.text: + transcription_text = input_transcript.text + role = getattr(input_transcript, 'role', 'user') + logger.debug(f"Input transcription detected: {transcription_text}") + return [BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role=role.lower() if isinstance(role, str) else "user", + is_final=True, + current_transcript=transcription_text + )] + + # Handle output transcription (model's audio) - emit as transcript event + if message.server_content and message.server_content.output_transcription: + output_transcript = message.server_content.output_transcription + # Check if the transcription object has text content + if hasattr(output_transcript, 'text') and output_transcript.text: + transcription_text = output_transcript.text + role = getattr(output_transcript, 'role', 'assistant') + logger.debug(f"Output transcription detected: {transcription_text}") + return [BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role=role.lower() if isinstance(role, str) else "assistant", + is_final=True, + current_transcript=transcription_text + )] + + # Handle audio output using SDK's built-in data property + # Check this BEFORE text to avoid triggering warning on mixed content + if message.data: + # Convert bytes to base64 string for JSON serializability + audio_b64 = base64.b64encode(message.data).decode('utf-8') + return [BidiAudioStreamEvent( + audio=audio_b64, + format="pcm", + sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, + channels=GEMINI_CHANNELS + )] + + # Handle text output from model_turn (avoids warning by checking parts directly) + if message.server_content and message.server_content.model_turn: + model_turn = message.server_content.model_turn + if model_turn.parts: + # Concatenate all text parts (Gemini may send multiple parts) + text_parts = [] + for part in model_turn.parts: + # Log all part types for debugging + part_attrs = {attr: getattr(part, attr, None) for attr in dir(part) if not attr.startswith('_')} + + # Check if part has text attribute and it's not empty + if hasattr(part, 'text') and part.text: + text_parts.append(part.text) + + if text_parts: + full_text = " ".join(text_parts) + return [BidiTranscriptStreamEvent( + delta={"text": full_text}, + text=full_text, + role="assistant", + is_final=True, + current_transcript=full_text + )] + + # Handle tool calls - return list to support multiple tool calls + if message.tool_call and message.tool_call.function_calls: + tool_events = [] + for func_call in message.tool_call.function_calls: + tool_use_event: ToolUse = { + "toolUseId": func_call.id, + "name": func_call.name, + "input": func_call.args or {} + } + # Create ToolUseStreamEvent for consistency with standard agent + tool_events.append(ToolUseStreamEvent( + delta={"toolUse": tool_use_event}, + current_tool_use=tool_use_event + )) + return tool_events + + # Handle usage metadata + if hasattr(message, 'usage_metadata') and message.usage_metadata: + usage = message.usage_metadata + + # Build modality details from token details + modality_details = [] + + # Process prompt tokens details + if usage.prompt_tokens_details: + for detail in usage.prompt_tokens_details: + if detail.modality and detail.token_count: + modality_details.append({ + "modality": str(detail.modality).lower(), + "input_tokens": detail.token_count, + "output_tokens": 0 + }) + + # Process response tokens details + if usage.response_tokens_details: + for detail in usage.response_tokens_details: + if detail.modality and detail.token_count: + # Find or create modality entry + modality_str = str(detail.modality).lower() + existing = next((m for m in modality_details if m["modality"] == modality_str), None) + if existing: + existing["output_tokens"] = detail.token_count + else: + modality_details.append({ + "modality": modality_str, + "input_tokens": 0, + "output_tokens": detail.token_count + }) + + return [BidiUsageEvent( + input_tokens=usage.prompt_token_count or 0, + output_tokens=usage.response_token_count or 0, + total_tokens=usage.total_token_count or 0, + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None + )] + + # Silently ignore setup_complete and generation_complete messages + return [] + + except Exception as e: + logger.error("Error converting Gemini Live event: %s", e) + logger.error("Message type: %s", type(message).__name__) + logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) + # Return ErrorEvent in list so caller can handle it + return [BidiErrorEvent(error=e)] + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Unified send method for all content types. Sends the given inputs to Google Live API + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). + """ + if not self._active: + return + + try: + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, BidiImageInputEvent): + await self._send_image_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") + except Exception as e: + logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio content using Gemini Live API. + + Gemini Live expects continuous audio streaming via send_realtime_input. + This automatically triggers VAD and can interrupt ongoing responses. + """ + try: + # Decode base64 audio to bytes for SDK + audio_bytes = base64.b64decode(audio_input.audio) + + # Create audio blob for the SDK + audio_blob = genai_types.Blob( + data=audio_bytes, + mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" + ) + + # Send real-time audio input - this automatically handles VAD and interruption + await self.live_session.send_realtime_input(audio=audio_blob) + + except Exception as e: + logger.error("Error sending audio content: %s", e) + + async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: + """Internal: Send image content using Gemini Live API. + + Sends image frames following the same pattern as the GitHub example. + Images are sent as base64-encoded data with MIME type. + """ + try: + # Image is already base64 encoded in the event + msg = { + "mime_type": image_input.mime_type, + "data": image_input.image + } + + # Send using the same method as the GitHub example + await self.live_session.send(input=msg) + + except Exception as e: + logger.error("Error sending image content: %s", e) + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Gemini Live API.""" + try: + # Create content with text + content = genai_types.Content( + role="user", + parts=[genai_types.Part(text=text)] + ) + + # Send as client content + await self.live_session.send_client_content(turns=content) + + except Exception as e: + logger.error("Error sending text content: %s", e) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Gemini Live API.""" + try: + tool_use_id = tool_result.get("toolUseId") + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break + + # Create function response + func_response = genai_types.FunctionResponse( + id=tool_use_id, + name=tool_use_id, # Gemini uses name as identifier + response=result_data + ) + + # Send tool response + await self.live_session.send_tool_response(function_responses=[func_response]) + except Exception as e: + logger.error("Error sending tool result: %s", e) + + async def stop(self) -> None: + """Close Gemini Live API connection.""" + if not self._active: + return + + self._active = False + + try: + # Exit the context manager properly + if self.live_session_context_manager: + await self.live_session_context_manager.__aexit__(None, None, None) + except Exception as e: + logger.error("Error closing Gemini Live connection: %s", e) + raise + + def _build_live_config( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + **kwargs + ) -> Dict[str, Any]: + """Build LiveConnectConfig for the official SDK. + + Simply passes through all config parameters from live_config, allowing users + to configure any Gemini Live API parameter directly. + """ + # Start with user-provided live_config + config_dict = {} + if self.live_config: + config_dict.update(self.live_config) + + # Override with any kwargs from start() + config_dict.update(kwargs) + + # Add system instruction if provided + if system_prompt: + config_dict["system_instruction"] = system_prompt + + # Add tools if provided + if tools: + config_dict["tools"] = self._format_tools_for_live_api(tools) + + return config_dict + + def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: + """Format tool specs for Gemini Live API.""" + if not tool_specs: + return [] + + return [ + genai_types.Tool( + function_declarations=[ + genai_types.FunctionDeclaration( + description=tool_spec["description"], + name=tool_spec["name"], + parameters_json_schema=tool_spec["inputSchema"]["json"], + ) + for tool_spec in tool_specs + ], + ), + ] \ No newline at end of file diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py new file mode 100644 index 000000000..8c23aa0da --- /dev/null +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -0,0 +1,743 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +Implements the BidiModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol. + +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events +""" + +import asyncio +import base64 +import json +import logging +import traceback +import uuid +from typing import AsyncIterable + +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import ( + BidirectionalInputPayloadPart, + InvokeModelWithBidirectionalStreamInputChunk, +) +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver + +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ..types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiUsageEvent, + BidiOutputEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, +) +from .bidirectional_model import BidiModel + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64", +} + +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH", +} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + +# Timing constants +EVENT_DELAY = 0.1 +RESPONSE_TIMEOUT = 1.0 + + +class BidiNovaSonicModel(BidiModel): + """Nova Sonic implementation for bidirectional streaming. + + Combines model configuration and connection state in a single class. + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidiModel interface. + """ + + def __init__( + self, + model_id: str = "amazon.nova-sonic-v1:0", + region: str = "us-east-1", + **kwargs + ) -> None: + """Initialize Nova Sonic bidirectional model. + + Args: + model_id: Nova Sonic model identifier. + region: AWS region. + **kwargs: Reserved for future parameters. + """ + # Model configuration + self.model_id = model_id + self.region = region + self.client = None + + # Connection state (initialized in start()) + self.stream = None + self.connection_id = None + self._active = False + + # Nova Sonic requires unique content names + self.audio_content_name = None + + # Audio connection state + self.audio_connection_active = False + + # Background task and event queue + self._response_task = None + self._event_queue = None + + # Track API-provided identifiers + self._current_completion_id = None + self._current_role = None + self._generation_stage = None + + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish bidirectional connection to Nova Sonic. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + + logger.debug("Nova connection create - starting") + + try: + # Initialize client if needed + if not self.client: + await self._initialize_client() + + # Initialize connection state + self.connection_id = str(uuid.uuid4()) + self._active = True + self.audio_content_name = str(uuid.uuid4()) + self._event_queue = asyncio.Queue() + + # Start Nova Sonic bidirectional stream + self.stream = await self.client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + + # Validate stream + if not self.stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + + logger.debug("Nova Sonic connection initialized with connection_id: %s", self.connection_id) + + # Send initialization events + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." + init_events = self._build_initialization_events(system_prompt, tools or [], messages) + + logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) + await self._send_initialization_events(init_events) + + # Start background response processor + self._response_task = asyncio.create_task(self._process_responses()) + + logger.info("Nova Sonic connection established successfully") + + except Exception as e: + self._active = False + logger.error("Nova connection create error: %s", str(e)) + raise + + def _build_initialization_events( + self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None + ) -> list[str]: + """Build the sequence of initialization events.""" + events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] + + events.extend(self._get_system_prompt_events(system_prompt)) + + # Message history would be processed here if needed in the future + # Currently not implemented as it's not used in the existing test cases + + return events + + async def _send_initialization_events(self, events: list[str]) -> None: + """Send initialization events with required delays.""" + for _i, event in enumerate(events): + await self._send_nova_event(event) + await asyncio.sleep(EVENT_DELAY) + + async def _process_responses(self) -> None: + """Process Nova Sonic responses continuously.""" + logger.debug("Nova Sonic response processor started") + + try: + while self._active: + try: + output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) + result = await output[1].receive() + + if result.value and result.value.bytes_: + await self._handle_response_data(result.value.bytes_.decode("utf-8")) + + except asyncio.TimeoutError: + await asyncio.sleep(0.1) + continue + except Exception as e: + logger.warning("Nova Sonic response error: %s", e) + await asyncio.sleep(0.1) + continue + + except Exception as e: + logger.error("Nova Sonic fatal error: %s", e) + finally: + logger.debug("Nova Sonic response processor stopped") + + async def _handle_response_data(self, response_data: str) -> None: + """Handle decoded response data from Nova Sonic.""" + try: + json_data = json.loads(response_data) + + if "event" in json_data: + nova_event = json_data["event"] + self._log_event_type(nova_event) + + if not hasattr(self, "_event_queue"): + self._event_queue = asyncio.Queue() + + await self._event_queue.put(nova_event) + except json.JSONDecodeError as e: + logger.warning("Nova Sonic JSON decode error: %s", e) + + def _log_event_type(self, nova_event: dict[str, any]) -> None: + """Log specific Nova Sonic event types for debugging.""" + if "usageEvent" in nova_event: + logger.debug("Nova usage: %s", nova_event["usageEvent"]) + elif "textOutput" in nova_event: + logger.debug("Nova text output") + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + logger.debug("Nova tool use: %s (id: %s)", tool_use["toolName"], tool_use["toolUseId"]) + elif "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + logger.debug("Nova audio output: %d bytes", len(audio_bytes)) + + async def receive(self) -> AsyncIterable[dict[str, any]]: + """Receive Nova Sonic events and convert to provider-agnostic format.""" + if not self.stream: + logger.error("Stream is None") + return + + logger.debug("Nova events - starting event stream") + + # Emit connection start event + yield BidiConnectionStartEvent( + connection_id=self.connection_id, + model=self.model_id + ) + + try: + while self._active: + try: + # Get events from the queue populated by _process_responses + nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + + # Convert to provider-agnostic format + provider_event = self._convert_nova_event(nova_event) + if provider_event: + yield provider_event + + except asyncio.TimeoutError: + # No events in queue - continue waiting + continue + + except Exception as e: + logger.error("Error receiving Nova Sonic event: %s", e) + logger.error(traceback.format_exc()) + yield BidiErrorEvent(error=e) + finally: + # Emit connection close event + yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Unified send method for all content types. Sends the given content to Nova Sonic. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). + """ + if not self._active: + return + + try: + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, BidiImageInputEvent): + # BidiImageInputEvent - not supported by Nova Sonic + logger.warning("Image input not supported by Nova Sonic") + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") + except Exception as e: + logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code + + async def _start_audio_connection(self) -> None: + """Internal: Start audio input connection (call once before sending audio chunks).""" + if self.audio_connection_active: + return + + logger.debug("Nova audio connection start") + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self.connection_id, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, + } + } + } + ) + + await self._send_nova_event(audio_content_start) + self.audio_connection_active = True + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio using Nova Sonic protocol-specific format.""" + # Start audio connection if not already active + if not self.audio_connection_active: + await self._start_audio_connection() + + # Audio is already base64 encoded in the event + # Send audio input event + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.connection_id, + "contentName": self.audio_content_name, + "content": audio_input.audio, + } + } + } + ) + + await self._send_nova_event(audio_event) + + async def _end_audio_input(self) -> None: + """Internal: End current audio input connection to trigger Nova Sonic processing.""" + if not self.audio_connection_active: + return + + logger.debug("Nova audio connection end") + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": self.connection_id, "contentName": self.audio_content_name}}} + ) + + await self._send_nova_event(audio_content_end) + self.audio_connection_active = False + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Nova Sonic format.""" + content_name = str(uuid.uuid4()) + events = [ + self._get_text_content_start_event(content_name), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name), + ] + + for event in events: + await self._send_nova_event(event) + + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to Nova Sonic.""" + # Nova Sonic handles interruption through special input events + interrupt_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.connection_id, + "contentName": self.audio_content_name, + "stopReason": "INTERRUPTED", + } + } + } + ) + await self._send_nova_event(interrupt_event) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Nova Sonic toolResult format.""" + tool_use_id = tool_result.get("toolUseId") + + logger.debug("Nova tool result send: %s", tool_use_id) + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break + + content_name = str(uuid.uuid4()) + events = [ + self._get_tool_content_start_event(content_name, tool_use_id), + self._get_tool_result_event(content_name, result_data), + self._get_content_end_event(content_name), + ] + + for event in events: + await self._send_nova_event(event) + + async def stop(self) -> None: + """Close Nova Sonic connection with proper cleanup sequence.""" + if not self._active: + return + + logger.debug("Nova cleanup - starting connection close") + self._active = False + + # Cancel response processing task if running + if hasattr(self, "_response_task") and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + # End audio connection if active + if self.audio_connection_active: + await self._end_audio_input() + + # Send cleanup events + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + + for event in cleanup_events: + try: + await self._send_nova_event(event) + except Exception as e: + logger.warning("Error during Nova Sonic cleanup: %s", e) + + # Close stream + try: + await self.stream.input_stream.close() + except Exception as e: + logger.warning("Error closing Nova Sonic stream: %s", e) + + except Exception as e: + logger.error("Nova cleanup error: %s", str(e)) + finally: + logger.debug("Nova connection closed") + + def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | None: + """Convert Nova Sonic events to TypedEvent format.""" + # Handle completion start - track completionId + if "completionStart" in nova_event: + completion_data = nova_event["completionStart"] + self._current_completion_id = completion_data.get("completionId") + logger.debug("Nova completion started: %s", self._current_completion_id) + return None + + # Handle completion end + if "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + completion_id = completion_data.get("completionId", self._current_completion_id) + stop_reason = completion_data.get("stopReason", "END_TURN") + + event = BidiResponseCompleteEvent( + response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" + ) + + # Clear completion tracking + self._current_completion_id = None + return event + + # Handle audio output + if "audioOutput" in nova_event: + # Audio is already base64 string from Nova Sonic + audio_content = nova_event["audioOutput"]["content"] + return BidiAudioStreamEvent( + audio=audio_content, + format="pcm", + sample_rate=24000, + channels=1 + ) + + # Handle text output (transcripts) + elif "textOutput" in nova_event: + text_content = nova_event["textOutput"]["content"] + # Check for Nova Sonic interruption pattern + if '{ "interrupted" : true }' in text_content: + logger.debug("Nova interruption detected in text") + return BidiInterruptionEvent(reason="user_speech") + + return BidiTranscriptStreamEvent( + delta={"text": text_content}, + text=text_content, + role=self._current_role.lower() if self._current_role else "assistant", + is_final=self._generation_stage == "FINAL", + current_transcript=text_content + ) + + # Handle tool use + if "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]), + } + # Return ToolUseStreamEvent for consistency with standard agent + return ToolUseStreamEvent( + delta={"toolUse": tool_use_event}, + current_tool_use=tool_use_event + ) + + # Handle interruption + if nova_event.get("stopReason") == "INTERRUPTED": + logger.debug("Nova interruption stop reason") + return BidiInterruptionEvent(reason="user_speech") + + # Handle usage events - convert to multimodal usage format + if "usageEvent" in nova_event: + usage_data = nova_event["usageEvent"] + total_input = usage_data.get("totalInputTokens", 0) + total_output = usage_data.get("totalOutputTokens", 0) + + return BidiUsageEvent( + input_tokens=total_input, + output_tokens=total_output, + total_tokens=usage_data.get("totalTokens", total_input + total_output) + ) + + # Handle content start events (track role and emit response start) + if "contentStart" in nova_event: + content_data = nova_event["contentStart"] + role = content_data.get("role", "unknown") + # Store role for subsequent text output events + self._current_role = role + + if content_data["type"] == "TEXT": + self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] + + # Emit response start event using API-provided completionId + # completionId should already be tracked from completionStart event + return BidiResponseStartEvent( + response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing + ) + + # Ignore other events (contentEnd, etc.) + return + + # Nova Sonic event template methods + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + prompt_start_event = { + "event": { + "promptStart": { + "promptName": self.connection_id, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: + """Build tool configuration from tool specs.""" + tool_config = [] + for tool in tools: + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} + ) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str) -> list[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt), + self._get_content_end_event(content_name), + ] + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.connection_id, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } + } + } + ) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: + """Generate tool content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.connection_id, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, + } + } + } + ) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps( + {"event": {"textInput": {"promptName": self.connection_id, "contentName": content_name, "content": text}}} + ) + + def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: + """Generate tool result event.""" + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self.connection_id, + "contentName": content_name, + "content": json.dumps(result), + } + } + } + ) + + def _get_content_end_event(self, content_name: str) -> str: + """Generate content end event.""" + return json.dumps({"event": {"contentEnd": {"promptName": self.connection_id, "contentName": content_name}}}) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({"event": {"promptEnd": {"promptName": self.connection_id}}}) + + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" + return json.dumps({"event": {"connectionEnd": {}}}) + + async def _send_nova_event(self, event: str) -> None: + """Send event JSON string to Nova Sonic stream.""" + try: + # Event is already a JSON string + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) + await self.stream.input_stream.send(chunk) + logger.debug("Successfully sent Nova Sonic event") + + except Exception as e: + logger.error("Error sending Nova Sonic event: %s", e) + logger.error("Event was: %s", event) + raise + + async def _initialize_client(self) -> None: + """Initialize Nova Sonic client.""" + try: + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, + ) + + self.client = BedrockRuntimeClient(config=config) + logger.debug("Nova Sonic client initialized") + + except ImportError as e: + logger.error("Nova Sonic dependencies not available: %s", e) + raise + except Exception as e: + logger.error("Error initializing Nova Sonic client: %s", e) + raise diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py new file mode 100644 index 000000000..74f1942ff --- /dev/null +++ b/src/strands/experimental/bidi/models/openai.py @@ -0,0 +1,653 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import base64 +import json +import logging +import os +import uuid +from typing import AsyncIterable, Union + +import websockets +from websockets.exceptions import ConnectionClosed + +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ..types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiUsageEvent, + BidiOutputEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, +) +from .bidirectional_model import BidiModel + +logger = logging.getLogger(__name__) + +# OpenAI Realtime API configuration +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" + +AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": AUDIO_FORMAT, + "transcription": { + "model": "gpt-4o-transcribe" + }, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + } + }, + "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + }, +} + + +class BidiOpenAIRealtimeModel(BidiModel): + """OpenAI Realtime API implementation for bidirectional streaming. + + Combines model configuration and connection state in a single class. + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + organization: str | None = None, + project: str | None = None, + session_config: dict[str, any] | None = None, + **kwargs + ) -> None: + """Initialize OpenAI Realtime bidirectional model. + + Args: + model: OpenAI model identifier (default: gpt-realtime). + api_key: OpenAI API key for authentication. + organization: OpenAI organization ID for API requests. + project: OpenAI project ID for API requests. + session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). + **kwargs: Reserved for future parameters. + """ + # Model configuration + self.model = model + self.api_key = api_key + self.organization = organization + self.project = project + self.session_config = session_config or {} + + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + # Connection state (initialized in start()) + self.websocket = None + self.connection_id = None + self._active = False + + self._event_queue = None + self._response_task = None + self._function_call_buffer = {} + + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish bidirectional connection to OpenAI Realtime API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + + logger.info("Creating OpenAI Realtime connection...") + + try: + # Initialize connection state + self.connection_id = str(uuid.uuid4()) + self._active = True + self._event_queue = asyncio.Queue() + self._function_call_buffer = {} + + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) + + self.websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + # Configure session + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + # Add conversation history if provided + if messages: + await self._add_conversation_history(messages) + + # Start background response processor + self._response_task = asyncio.create_task(self._process_responses()) + logger.info("OpenAI Realtime connection established") + + except Exception as e: + self._active = False + logger.error("OpenAI connection error: %s", e) + raise + + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: + """Create standardized transcript event. + + Args: + text: The transcript text + role: The role (will be normalized to lowercase) + is_final: Whether this is the final transcript + """ + # Normalize role to lowercase and ensure it's either "user" or "assistant" + normalized_role = role.lower() if isinstance(role, str) else "assistant" + if normalized_role not in ["user", "assistant"]: + normalized_role = "assistant" + + return BidiTranscriptStreamEvent( + delta={"text": text}, + text=text, + role=normalized_role, + is_final=is_final, + current_transcript=text if is_final else None + ) + + def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: + """Create standardized interruption event for voice activity.""" + # Only speech_started triggers interruption + if activity_type == "speech_started": + return BidiInterruptionEvent(reason="user_speech") + # Other voice activity events are logged but don't create events + return None + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + """Build session configuration for OpenAI Realtime API.""" + config = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + # Apply user-provided session configuration + supported_params = { + "type", "output_modalities", "instructions", "voice", "audio", + "tools", "tool_choice", "input_audio_format", "output_audio_format", + "input_audio_transcription", "turn_detection" + } + + for key, value in self.session_config.items(): + if key in supported_params: + config[key] = value + else: + logger.warning("Ignoring unsupported session parameter: %s", key) + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session.""" + for message in messages: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": message["role"], "content": []} + } + + content = message.get("content", "") + if isinstance(content, str): + conversation_item["item"]["content"].append({"type": "input_text", "text": content}) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + conversation_item["item"]["content"].append({"type": "input_text", "text": item.get("text", "")}) + + await self._send_event(conversation_item) + + async def _process_responses(self) -> None: + """Process incoming WebSocket messages.""" + logger.debug("OpenAI Realtime response processor started") + + try: + async for message in self.websocket: + if not self._active: + break + + try: + event = json.loads(message) + await self._event_queue.put(event) + except json.JSONDecodeError as e: + logger.warning("Failed to parse OpenAI event: %s", e) + continue + + except ConnectionClosed: + logger.debug("OpenAI Realtime WebSocket connection closed") + except Exception as e: + logger.error("Error in OpenAI Realtime response processing: %s", e) + finally: + self._active = False + logger.debug("OpenAI Realtime response processor stopped") + + async def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive OpenAI events and convert to Strands TypedEvent format.""" + # Emit connection start event + yield BidiConnectionStartEvent( + connection_id=self.connection_id, + model=self.model + ) + + try: + while self._active: + try: + openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + for event in self._convert_openai_event(openai_event) or []: + yield event + except asyncio.TimeoutError: + continue + + except Exception as e: + logger.error("Error receiving OpenAI Realtime event: %s", e) + yield BidiErrorEvent(error=e) + finally: + # Emit connection close event + yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + + def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutputEvent] | None: + """Convert OpenAI events to Strands TypedEvent format.""" + event_type = openai_event.get("type") + + # Turn start - response begins + if event_type == "response.created": + response = openai_event.get("response", {}) + response_id = response.get("id", str(uuid.uuid4())) + return [BidiResponseStartEvent(response_id=response_id)] + + # Audio output + elif event_type == "response.output_audio.delta": + # Audio is already base64 string from OpenAI + return [BidiAudioStreamEvent( + audio=openai_event["delta"], + format="pcm", + sample_rate=24000, + channels=1 + )] + + # Assistant text output events - combine multiple similar events + elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: + role = openai_event.get("role", "assistant") + return [self._create_text_event(openai_event["delta"], role.lower() if isinstance(role, str) else "assistant")] + + # User transcription events - combine multiple similar events + elif event_type in ["conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed"]: + text_key = "delta" if "delta" in event_type else "transcript" + text = openai_event.get(text_key, "") + role = openai_event.get("role", "user") + is_final = "completed" in event_type + return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + role = segment_data.get("role", "user") + return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + # Return ToolUseStreamEvent for consistency with standard agent + return [ToolUseStreamEvent( + delta={"toolUse": tool_use}, + current_tool_use=tool_use + )] + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Error parsing function arguments for %s: %s", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection - speech_started triggers interruption + elif event_type == "input_audio_buffer.speech_started": + # This is the primary interruption signal - handle it first + return [BidiInterruptionEvent(reason="user_speech")] + + # Response cancelled - handle interruption + elif event_type == "response.cancelled": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + logger.debug("OpenAI response cancelled: %s", response_id) + return [BidiResponseCompleteEvent( + response_id=response_id, + stop_reason="interrupted" + )] + + # Turn complete and usage - response finished + elif event_type == "response.done": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + status = response.get("status", "completed") + usage = response.get("usage") + + # Map OpenAI status to our stop_reason + stop_reason_map = { + "completed": "complete", + "cancelled": "interrupted", + "failed": "error", + "incomplete": "interrupted" + } + + # Build list of events to return + events = [] + + # Always add response complete event + events.append(BidiResponseCompleteEvent( + response_id=response_id, + stop_reason=stop_reason_map.get(status, "complete") + )) + + # Add usage event if available + if usage: + input_details = usage.get("input_token_details", {}) + output_details = usage.get("output_token_details", {}) + + # Build modality details + modality_details = [] + + # Text modality + text_input = input_details.get("text_tokens", 0) + text_output = output_details.get("text_tokens", 0) + if text_input > 0 or text_output > 0: + modality_details.append({ + "modality": "text", + "input_tokens": text_input, + "output_tokens": text_output + }) + + # Audio modality + audio_input = input_details.get("audio_tokens", 0) + audio_output = output_details.get("audio_tokens", 0) + if audio_input > 0 or audio_output > 0: + modality_details.append({ + "modality": "audio", + "input_tokens": audio_input, + "output_tokens": audio_output + }) + + # Image modality + image_input = input_details.get("image_tokens", 0) + if image_input > 0: + modality_details.append({ + "modality": "image", + "input_tokens": image_input, + "output_tokens": 0 + }) + + # Cached tokens + cached_tokens = input_details.get("cached_tokens", 0) + + # Add usage event + events.append(BidiUsageEvent( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None + )) + + # Return list of events + return events + + # Lifecycle events (log only) - combine multiple similar events + elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: + item = openai_event.get("item", {}) + action = "retrieved" if "retrieve" in event_type else "added" + logger.debug("OpenAI conversation item %s: %s", action, item.get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) + return None + + # Response output events - combine similar events + elif event_type in ["response.output_item.added", "response.output_item.done", + "response.content_part.added", "response.content_part.done"]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": function_name, "arguments": ""} + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + # Session/buffer events - combine simple log-only events + elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", + "session.created", "session.updated"]: + logger.debug("OpenAI %s event", event_type) + return None + + elif event_type == "error": + error_data = openai_event.get("error", {}) + error_code = error_data.get("code", "") + + # Suppress expected errors that don't affect session state + if error_code == "response_cancel_not_active": + # This happens when trying to cancel a response that's not active + # It's safe to ignore as the session remains functional + logger.debug("OpenAI response cancel attempted when no response active (safe to ignore)") + return None + + # Log other errors + logger.error("OpenAI Realtime error: %s", error_data) + return None + + else: + logger.debug("Unhandled OpenAI event type: %s", event_type) + return None + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Unified send method for all content types. Sends the given content to OpenAI. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). + """ + if not self._require_active(): + return + + try: + # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, BidiImageInputEvent): + # BidiImageInputEvent - not supported by OpenAI Realtime yet + logger.warning("Image input not supported by OpenAI Realtime API") + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content).__name__}") + except Exception as e: + logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio content to OpenAI for processing.""" + # Audio is already base64 encoded in the event + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content to OpenAI for processing.""" + item_data = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}] + } + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to OpenAI.""" + await self._send_event({"type": "response.cancel"}) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result back to OpenAI.""" + tool_use_id = tool_result.get("toolUseId") + + logger.debug("OpenAI tool result send: %s", tool_use_id) + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = block["text"] + break + + result_text = json.dumps(result_data) if not isinstance(result_data, str) else result_data + + item_data = { + "type": "function_call_output", + "call_id": tool_use_id, + "output": result_text + } + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def stop(self) -> None: + """Close session and cleanup resources.""" + if not self._active: + return + + logger.debug("OpenAI Realtime cleanup - starting connection close") + self._active = False + + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + await self.websocket.close() + except Exception as e: + logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + + logger.debug("OpenAI Realtime connection closed") + + async def _send_event(self, event: dict[str, any]) -> None: + """Send event to OpenAI via WebSocket.""" + try: + message = json.dumps(event) + await self.websocket.send(message) + logger.debug("Sent OpenAI event: %s", event.get("type")) + except Exception as e: + logger.error("Error sending OpenAI event: %s", e) + raise + + diff --git a/src/strands/experimental/bidi/scripts/test_bidi.py b/src/strands/experimental/bidi/scripts/test_bidi.py new file mode 100644 index 000000000..abeb9fcf7 --- /dev/null +++ b/src/strands/experimental/bidi/scripts/test_bidi.py @@ -0,0 +1,38 @@ +"""Test BidirectionalAgent with simple developer experience.""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) + +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO +from strands_tools import calculator + + +async def main(): + """Test the BidirectionalAgent API.""" + + + # Nova Sonic model + audio_io = BidiAudioIO(audio_config={}) + text_io = BidiTextIO() + model = BidiOpenAIRealtimeModel(region="us-east-1") + + async with BidiAgent(model=model, tools=[calculator]) as agent: + print("New BidiAgent Experience") + print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") + await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n⏹️ Conversation ended by user") + except Exception as e: + print(f"❌ Error: {e}") + import traceback + traceback.print_exc() diff --git a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py new file mode 100644 index 000000000..38654f7fd --- /dev/null +++ b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py @@ -0,0 +1,256 @@ +"""Test suite for bidirectional streaming with real-time audio interaction. + +Tests the complete bidirectional streaming system including audio input/output, +interruption handling, and concurrent tool execution using Nova Sonic. +""" + +import asyncio +import base64 +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import os +import time + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel + + +def test_direct_tools(): + """Test direct tool calling.""" + print("Testing direct tool calling...") + + # Check AWS credentials + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): + print("AWS credentials not set - skipping test") + return + + try: + model = BidiNovaSonicModel() + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + event_type = event.get("type", "unknown") + + # Handle audio stream events (bidi_audio_stream) + if event_type == "bidi_audio_stream": + if not context.get("interrupted", False): + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) + context["audio_out"].put_nowait(audio_data) + + # Handle interruption events (bidi_interruption) + elif event_type == "bidi_interruption": + context["interrupted"] = True + + # Handle transcript events (bidi_transcript_stream) + elif event_type == "bidi_transcript_stream": + text_content = event.get("text", "") + role = event.get("role", "unknown") + + # Log transcript output + if role == "user": + print(f"User: {text_content}") + elif role == "assistant": + print(f"Assistant: {text_content}") + + # Handle response complete events (bidi_response_complete) + elif event_type == "bidi_response_complete": + # Reset interrupted state since the turn is complete + context["interrupted"] = False + + # Handle tool use events (tool_use_stream) + elif event_type == "tool_use_stream": + tool_use = event.get("current_tool_use", {}) + tool_name = tool_use.get("name", "unknown") + tool_input = tool_use.get("input", {}) + print(f"🔧 Tool called: {tool_name} with input: {tool_input}") + + # Handle tool result events (tool_result) + elif event_type == "tool_result": + tool_result = event.get("tool_result", {}) + tool_name = tool_result.get("name", "unknown") + result_content = tool_result.get("content", []) + result_text = "" + for block in result_content: + if isinstance(block, dict) and block.get("type") == "text": + result_text = block.get("text", "") + break + print(f"✅ Tool result from {tool_name}: {result_text}") + + except asyncio.CancelledError: + pass + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + # Create audio event using TypedEvent + from strands.experimental.bidi.types.events import BidiAudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) + await agent.send(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) # Restored to working timing + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for bidirectional streaming test.""" + print("Starting bidirectional streaming test...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + + # Initialize model and agent + model = BidiNovaSonicModel(region="us-east-1") + agent = BidiAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") + + await agent.start() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "connection": agent._agent_loop, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + } + + print("Speak into microphone. Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.stop() + + +if __name__ == "__main__": + # Test direct tool calling first + test_direct_tools() + + asyncio.run(main()) diff --git a/src/strands/experimental/bidi/scripts/test_bidi_openai.py b/src/strands/experimental/bidi/scripts/test_bidi_openai.py new file mode 100644 index 000000000..71e934fb7 --- /dev/null +++ b/src/strands/experimental/bidi/scripts/test_bidi_openai.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +"""Test OpenAI Realtime API speech-to-speech interaction.""" + +import asyncio +import base64 +import os +import sys +import time +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel + + +async def play(context): + """Handle audio playback with interruption support.""" + audio = pyaudio.PyAudio() + + try: + speaker = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # OpenAI Realtime uses 24kHz + output=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + # Check for interruption + if context.get("interrupted", False): + # Clear audio queue on interruption + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get audio data with timeout + try: + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + # Play in chunks to allow interruption + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if context.get("interrupted", False) or not context["active"]: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Brief pause for responsiveness + + except asyncio.TimeoutError: + continue + + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Audio playback error: {e}") + finally: + try: + speaker.close() + except: + pass + audio.terminate() + + +async def record(context): + """Handle microphone recording.""" + audio = pyaudio.PyAudio() + + try: + microphone = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # Match OpenAI's expected input rate + input=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + await context["audio_in"].put(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Microphone recording error: {e}") + finally: + try: + microphone.close() + except: + pass + audio.terminate() + + +async def receive(agent, context): + """Handle events from the agent.""" + try: + async for event in agent.receive(): + if not context["active"]: + break + + # Get event type + event_type = event.get("type", "unknown") + + # Handle audio stream events (bidi_audio_stream) + if event_type == "bidi_audio_stream": + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) + + if not context.get("interrupted", False): + await context["audio_out"].put(audio_data) + + # Handle transcript events (bidi_transcript_stream) + elif event_type == "bidi_transcript_stream": + source = event.get("role", "assistant") + text = event.get("text", "").strip() + + if text: + if source == "user": + print(f"🎤 User: {text}") + elif source == "assistant": + print(f"🔊 Assistant: {text}") + + # Handle interruption events (bidi_interruption) + elif event_type == "bidi_interruption": + context["interrupted"] = True + print("⚠️ Interruption detected") + + # Handle connection start events (bidi_connection_start) + elif event_type == "bidi_connection_start": + print(f"✓ Session started: {event.get('model', 'unknown')}") + + # Handle connection close events (bidi_connection_close) + elif event_type == "bidi_connection_close": + print(f"✓ Session ended: {event.get('reason', 'unknown')}") + context["active"] = False + break + + # Handle response complete events (bidi_response_complete) + elif event_type == "bidi_response_complete": + # Reset interrupted state since the turn is complete + context["interrupted"] = False + + # Handle tool use events (tool_use_stream) + elif event_type == "tool_use_stream": + tool_use = event.get("current_tool_use", {}) + tool_name = tool_use.get("name", "unknown") + tool_input = tool_use.get("input", {}) + print(f"🔧 Tool called: {tool_name} with input: {tool_input}") + + # Handle tool result events (tool_result) + elif event_type == "tool_result": + tool_result = event.get("tool_result", {}) + tool_name = tool_result.get("name", "unknown") + result_content = tool_result.get("content", []) + result_text = "" + for block in result_content: + if isinstance(block, dict) and block.get("type") == "text": + result_text = block.get("text", "") + break + print(f"✅ Tool result from {tool_name}: {result_text}") + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receive handler error: {e}") + finally: + pass + + +async def send(agent, context): + """Send audio from microphone to agent.""" + try: + while context["active"]: + try: + audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) + + # Create audio event using TypedEvent + # Encode audio bytes to base64 string for JSON serializability + from strands.experimental.bidi.types.events import BidiAudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=24000, + channels=1 + ) + + await agent.send(audio_event) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Send handler error: {e}") + finally: + pass + + +async def main(): + """Main test function for OpenAI voice chat.""" + print("Starting OpenAI Realtime API test...") + + # Check API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY environment variable not set") + return False + + # Check audio system + try: + audio = pyaudio.PyAudio() + audio.terminate() + except Exception as e: + print(f"Audio system error: {e}") + return False + + # Create OpenAI model + model = BidiOpenAIRealtimeModel( + model="gpt-4o-realtime-preview", + api_key=api_key, + session={ + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700 + } + }, + "output": { + "format": {"type": "audio/pcm", "rate": 24000}, + "voice": "alloy" + } + } + } + ) + + # Create agent + agent = BidiAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." + ) + + # Start the session + await agent.start() + + # Create shared context + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "interrupted": False, + "start_time": time.time() + } + + print("Speak into your microphone. Press Ctrl+C to stop.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + except Exception as e: + print(f"\nError during voice chat: {e}") + finally: + print("Cleaning up...") + context["active"] = False + + try: + await agent.stop() + except Exception as e: + print(f"Cleanup error: {e}") + + return True + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Test error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidi/scripts/test_gemini_live.py b/src/strands/experimental/bidi/scripts/test_gemini_live.py new file mode 100644 index 000000000..807a8da2b --- /dev/null +++ b/src/strands/experimental/bidi/scripts/test_gemini_live.py @@ -0,0 +1,363 @@ +"""Test suite for Gemini Live bidirectional streaming with camera support. + +Tests the Gemini Live API with real-time audio and video interaction including: +- Audio input/output streaming +- Camera frame capture and transmission +- Interruption handling +- Concurrent tool execution +- Transcript events + +Requirements: +- pip install opencv-python pillow pyaudio google-genai +- Camera access permissions +- GOOGLE_AI_API_KEY environment variable +""" + +import asyncio +import base64 +import io +import logging +import os +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import time + +try: + import cv2 + import PIL.Image + CAMERA_AVAILABLE = True +except ImportError as e: + print(f"Camera dependencies not available: {e}") + print("Install with: pip install opencv-python pillow") + CAMERA_AVAILABLE = False + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel + +# Configure logging - debug only for Gemini Live, info for everything else +logging.basicConfig(level=logging.WARN, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') +gemini_logger.setLevel(logging.WARN) +logger = logging.getLogger(__name__) + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + + # List all available audio devices + print("Available audio devices:") + for i in range(audio.get_device_count()): + device_info = audio.get_device_info_by_index(i) + if device_info['maxInputChannels'] > 0: # Only show input devices + print(f" Device {i}: {device_info['name']} (inputs: {device_info['maxInputChannels']})") + + # Get default input device info + default_device = audio.get_default_input_device_info() + print(f"\nUsing default input device: {default_device['name']} (Device {default_device['index']})") + + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + event_type = event.get("type", "unknown") + + # Handle audio stream events (bidi_audio_stream) + if event_type == "bidi_audio_stream": + if not context.get("interrupted", False): + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) + context["audio_out"].put_nowait(audio_data) + + # Handle interruption events (bidi_interruption) + elif event_type == "bidi_interruption": + context["interrupted"] = True + print("⚠️ Interruption detected") + + # Handle transcript events (bidi_transcript_stream) + elif event_type == "bidi_transcript_stream": + transcript_text = event.get("text", "") + transcript_role = event.get("role", "unknown") + is_final = event.get("is_final", False) + + # Print transcripts with special formatting + if transcript_role == "user": + print(f"🎤 User: {transcript_text}") + elif transcript_role == "assistant": + print(f"🔊 Assistant: {transcript_text}") + + # Handle response complete events (bidi_response_complete) + elif event_type == "bidi_response_complete": + # Reset interrupted state since the response is complete + context["interrupted"] = False + + # Handle tool use events (tool_use_stream) + elif event_type == "tool_use_stream": + tool_use = event.get("current_tool_use", {}) + tool_name = tool_use.get("name", "unknown") + tool_input = tool_use.get("input", {}) + print(f"🔧 Tool called: {tool_name} with input: {tool_input}") + + # Handle tool result events (tool_result) + elif event_type == "tool_result": + tool_result = event.get("tool_result", {}) + tool_name = tool_result.get("name", "unknown") + result_content = tool_result.get("content", []) + # Extract text from content blocks + result_text = "" + for block in result_content: + if isinstance(block, dict) and block.get("type") == "text": + result_text = block.get("text", "") + break + print(f"✅ Tool result from {tool_name}: {result_text}") + + except asyncio.CancelledError: + pass + + +def _get_frame(cap): + """Capture and process a frame from camera.""" + if not CAMERA_AVAILABLE: + return None + + # Read the frame + ret, frame = cap.read() + # Check if the frame was read successfully + if not ret: + return None + # Convert BGR to RGB color space + # OpenCV captures in BGR but PIL expects RGB format + # This prevents the blue tint in the video feed + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = PIL.Image.fromarray(frame_rgb) + img.thumbnail([1024, 1024]) + + image_io = io.BytesIO() + img.save(image_io, format="jpeg") + image_io.seek(0) + + mime_type = "image/jpeg" + image_bytes = image_io.read() + return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} + + +async def get_frames(context): + """Capture frames from camera and send to agent.""" + if not CAMERA_AVAILABLE: + print("Camera not available - skipping video capture") + return + + # This takes about a second, and will block the whole program + # causing the audio pipeline to overflow if you don't to_thread it. + cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera + + print("Camera initialized. Starting video capture...") + + try: + while context["active"] and time.time() - context["start_time"] < context["duration"]: + frame = await asyncio.to_thread(_get_frame, cap) + if frame is None: + break + + # Send frame to agent as image input + try: + from strands.experimental.bidi.types.events import BidiImageInputEvent + + image_event = BidiImageInputEvent( + image=frame["data"], # Already base64 encoded + mime_type=frame["mime_type"] + ) + await context["agent"].send(image_event) + print("📸 Frame sent to model") + except Exception as e: + logger.error(f"Error sending frame: {e}") + + # Wait 1 second between frames (1 FPS) + await asyncio.sleep(1.0) + + except asyncio.CancelledError: + pass + finally: + # Release the VideoCapture object + cap.release() + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + # Create audio event using TypedEvent + from strands.experimental.bidi.types.events import BidiAudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) + await agent.send(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for Gemini Live bidirectional streaming test with camera support.""" + print("Starting Gemini Live bidirectional streaming test with camera...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + print("Video: Camera frames sent at 1 FPS to model") + + # Get API key from environment variable + api_key = os.getenv("GOOGLE_AI_API_KEY") + + if not api_key: + print("ERROR: GOOGLE_AI_API_KEY environment variable not set") + print("Please set it with: export GOOGLE_AI_API_KEY=your_api_key") + return + + # Initialize Gemini Live model with proper configuration + logger.info("Initializing Gemini Live model with API key") + + # Use default model and config (includes transcription enabled by default) + model = BidiGeminiLiveModel(api_key=api_key) + logger.info("Gemini Live model initialized successfully") + print("Using Gemini Live model with default config (audio output + transcription enabled)") + + agent = BidiAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant." + ) + + await agent.start() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "connection": agent._agent_loop, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + "agent": agent, # Add agent reference for camera task + } + + print("Speak into microphone and show things to camera. Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently including camera + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + get_frames(context), # Add camera task + return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.stop() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/strands/experimental/bidi/types/__init__.py b/src/strands/experimental/bidi/types/__init__.py new file mode 100644 index 000000000..d5263bb28 --- /dev/null +++ b/src/strands/experimental/bidi/types/__init__.py @@ -0,0 +1,57 @@ +"""Type definitions for bidirectional streaming.""" + +from .agent import BidiAgentInput +from .io import BidiInput, BidiOutput +from .events import ( + DEFAULT_CHANNELS, + DEFAULT_FORMAT, + DEFAULT_SAMPLE_RATE, + SUPPORTED_AUDIO_FORMATS, + SUPPORTED_CHANNELS, + SUPPORTED_SAMPLE_RATES, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + ModalityUsage, + BidiUsageEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) + +__all__ = [ + "BidiInput", + "BidiOutput", + "BidiAgentInput", + # Input Events + "BidiTextInputEvent", + "BidiAudioInputEvent", + "BidiImageInputEvent", + "BidiInputEvent", + # Output Events + "BidiConnectionStartEvent", + "BidiConnectionCloseEvent", + "BidiResponseStartEvent", + "BidiResponseCompleteEvent", + "BidiAudioStreamEvent", + "BidiTranscriptStreamEvent", + "BidiInterruptionEvent", + "BidiUsageEvent", + "ModalityUsage", + "BidiErrorEvent", + "BidiOutputEvent", + # Constants + "SUPPORTED_AUDIO_FORMATS", + "SUPPORTED_SAMPLE_RATES", + "SUPPORTED_CHANNELS", + "DEFAULT_SAMPLE_RATE", + "DEFAULT_CHANNELS", + "DEFAULT_FORMAT", +] diff --git a/src/strands/experimental/bidi/types/agent.py b/src/strands/experimental/bidi/types/agent.py new file mode 100644 index 000000000..8d1e9aab7 --- /dev/null +++ b/src/strands/experimental/bidi/types/agent.py @@ -0,0 +1,10 @@ +"""Agent-related type definitions for bidirectional streaming. + +This module defines the types used for BidiAgent. +""" + +from typing import TypeAlias + +from .events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent + +BidiAgentInput: TypeAlias = str | BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py new file mode 100644 index 000000000..852950f5a --- /dev/null +++ b/src/strands/experimental/bidi/types/events.py @@ -0,0 +1,521 @@ +"""Bidirectional streaming types for real-time audio/text conversations. + +Type definitions for bidirectional streaming that extends Strands' existing streaming +capabilities with real-time audio and persistent connection support. + +Key features: +- Audio input/output events with standardized formats +- Interruption detection and handling +- Connection lifecycle management +- Provider-agnostic event types +- Type-safe discriminated unions with TypedEvent +- JSON-serializable events (audio/images stored as base64 strings) + +Audio format normalization: +- Supports PCM, WAV, Opus, and MP3 formats +- Standardizes sample rates (16kHz, 24kHz, 48kHz) +- Normalizes channel configurations (mono/stereo) +- Abstracts provider-specific encodings +- Audio data stored as base64-encoded strings for JSON compatibility +""" + +from typing import Any, Dict, List, Literal, Optional, Union, cast + +from ....types._events import ModelStreamEvent, TypedEvent +from ....types.streaming import ContentBlockDelta + +# Audio format constants +SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] +SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] +SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo +DEFAULT_SAMPLE_RATE = 16000 +DEFAULT_CHANNELS = 1 +DEFAULT_FORMAT = "pcm" + + +# ============================================================================ +# Input Events (sent via agent.send()) +# ============================================================================ + + +class BidiTextInputEvent(TypedEvent): + """Text input event for sending text to the model. + + Used for sending text content through the send() method. + + Parameters: + text: The text content to send to the model. + role: The role of the message sender (typically "user"). + """ + + def __init__(self, text: str, role: str): + super().__init__( + { + "type": "bidi_text_input", + "text": text, + "role": role, + } + ) + + @property + def text(self) -> str: + return cast(str, self.get("text")) + + @property + def role(self) -> str: + return cast(str, self.get("role")) + + +class BidiAudioInputEvent(TypedEvent): + """Audio input event for sending audio to the model. + + Used for sending audio data through the send() method. + + Parameters: + audio: Base64-encoded audio string to send to model. + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sample_rate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. + """ + + def __init__( + self, + audio: str, + format: Literal["pcm", "wav", "opus", "mp3"], + sample_rate: Literal[16000, 24000, 48000], + channels: Literal[1, 2], + ): + super().__init__( + { + "type": "bidi_audio_input", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> str: + return cast(str, self.get("audio")) + + @property + def format(self) -> str: + return cast(str, self.get("format")) + + @property + def sample_rate(self) -> int: + return cast(int, self.get("sample_rate")) + + @property + def channels(self) -> int: + return cast(int, self.get("channels")) + + +class BidiImageInputEvent(TypedEvent): + """Image input event for sending images/video frames to the model. + + Used for sending image data through the send() method. + + Parameters: + image: Base64-encoded image string. + mime_type: MIME type (e.g., "image/jpeg", "image/png"). + """ + + def __init__( + self, + image: str, + mime_type: str, + ): + super().__init__( + { + "type": "bidi_image_input", + "image": image, + "mime_type": mime_type, + } + ) + + @property + def image(self) -> str: + return cast(str, self.get("image")) + + @property + def mime_type(self) -> str: + return cast(str, self.get("mime_type")) + + +# ============================================================================ +# Output Events (received via agent.receive()) +# ============================================================================ + + +class BidiConnectionStartEvent(TypedEvent): + """Streaming connection established and ready for interaction. + + Parameters: + connection_id: Unique identifier for this streaming connection. + model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). + """ + + def __init__(self, connection_id: str, model: str): + super().__init__( + { + "type": "bidi_connection_start", + "connection_id": connection_id, + "model": model, + } + ) + + @property + def connection_id(self) -> str: + return cast(str, self.get("connection_id")) + + @property + def model(self) -> str: + return cast(str, self.get("model")) + + +class BidiResponseStartEvent(TypedEvent): + """Model starts generating a response. + + Parameters: + response_id: Unique identifier for this response (used in response.complete). + """ + + def __init__(self, response_id: str): + super().__init__({"type": "bidi_response_start", "response_id": response_id}) + + @property + def response_id(self) -> str: + return cast(str, self.get("response_id")) + + +class BidiAudioStreamEvent(TypedEvent): + """Streaming audio output from the model. + + Parameters: + audio: Base64-encoded audio string. + format: Audio encoding format. + sample_rate: Number of audio samples per second in Hz. + channels: Number of audio channels (1=mono, 2=stereo). + """ + + def __init__( + self, + audio: str, + format: Literal["pcm", "wav", "opus", "mp3"], + sample_rate: Literal[16000, 24000, 48000], + channels: Literal[1, 2], + ): + super().__init__( + { + "type": "bidi_audio_stream", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> str: + return cast(str, self.get("audio")) + + @property + def format(self) -> str: + return cast(str, self.get("format")) + + @property + def sample_rate(self) -> int: + return cast(int, self.get("sample_rate")) + + @property + def channels(self) -> int: + return cast(int, self.get("channels")) + + +class BidiTranscriptStreamEvent(ModelStreamEvent): + """Audio transcription streaming (user or assistant speech). + + Supports incremental transcript updates for providers that send partial + transcripts before the final version. + + Parameters: + delta: The incremental transcript change (ContentBlockDelta). + text: The delta text (same as delta content for convenience). + role: Who is speaking ("user" or "assistant"). + is_final: Whether this is the final/complete transcript. + current_transcript: The accumulated transcript text so far (None for first delta). + """ + + def __init__( + self, + delta: ContentBlockDelta, + text: str, + role: Literal["user", "assistant"], + is_final: bool, + current_transcript: Optional[str] = None, + ): + super().__init__( + { + "type": "bidi_transcript_stream", + "delta": delta, + "text": text, + "role": role, + "is_final": is_final, + "current_transcript": current_transcript, + } + ) + + @property + def delta(self) -> ContentBlockDelta: + return cast(ContentBlockDelta, self.get("delta")) + + @property + def text(self) -> str: + return cast(str, self.get("text")) + + @property + def role(self) -> str: + return cast(str, self.get("role")) + + @property + def is_final(self) -> bool: + return cast(bool, self.get("is_final")) + + @property + def current_transcript(self) -> Optional[str]: + return cast(Optional[str], self.get("current_transcript")) + + +class BidiInterruptionEvent(TypedEvent): + """Model generation was interrupted. + + Parameters: + reason: Why the interruption occurred. + response_id: ID of the response that was interrupted (may be None). + """ + + def __init__(self, reason: Literal["user_speech", "error"]): + super().__init__( + { + "type": "bidi_interruption", + "reason": reason, + } + ) + + @property + def reason(self) -> str: + return cast(str, self.get("reason")) + + +class BidiResponseCompleteEvent(TypedEvent): + """Model finished generating response. + + Parameters: + response_id: ID of the response that completed (matches response.start). + stop_reason: Why the response ended. + """ + + def __init__( + self, + response_id: str, + stop_reason: Literal["complete", "interrupted", "tool_use", "error"], + ): + super().__init__( + { + "type": "bidi_response_complete", + "response_id": response_id, + "stop_reason": stop_reason, + } + ) + + @property + def response_id(self) -> str: + return cast(str, self.get("response_id")) + + @property + def stop_reason(self) -> str: + return cast(str, self.get("stop_reason")) + + +class ModalityUsage(dict): + """Token usage for a specific modality. + + Attributes: + modality: Type of content. + input_tokens: Tokens used for this modality's input. + output_tokens: Tokens used for this modality's output. + """ + + modality: Literal["text", "audio", "image", "cached"] + input_tokens: int + output_tokens: int + + +class BidiUsageEvent(TypedEvent): + """Token usage event with modality breakdown for bidirectional streaming. + + Tracks token consumption across different modalities (audio, text, images) + during bidirectional streaming sessions. + + Parameters: + input_tokens: Total tokens used for all input modalities. + output_tokens: Total tokens used for all output modalities. + total_tokens: Sum of input and output tokens. + modality_details: Optional list of token usage per modality. + cache_read_input_tokens: Optional tokens read from cache. + cache_write_input_tokens: Optional tokens written to cache. + """ + + def __init__( + self, + input_tokens: int, + output_tokens: int, + total_tokens: int, + modality_details: Optional[List[ModalityUsage]] = None, + cache_read_input_tokens: Optional[int] = None, + cache_write_input_tokens: Optional[int] = None, + ): + data: Dict[str, Any] = { + "type": "bidi_usage", + "inputTokens": input_tokens, + "outputTokens": output_tokens, + "totalTokens": total_tokens, + } + if modality_details is not None: + data["modality_details"] = modality_details + if cache_read_input_tokens is not None: + data["cacheReadInputTokens"] = cache_read_input_tokens + if cache_write_input_tokens is not None: + data["cacheWriteInputTokens"] = cache_write_input_tokens + super().__init__(data) + + @property + def input_tokens(self) -> int: + return cast(int, self.get("inputTokens")) + + @property + def output_tokens(self) -> int: + return cast(int, self.get("outputTokens")) + + @property + def total_tokens(self) -> int: + return cast(int, self.get("totalTokens")) + + @property + def modality_details(self) -> List[ModalityUsage]: + return cast(List[ModalityUsage], self.get("modality_details", [])) + + @property + def cache_read_input_tokens(self) -> Optional[int]: + return cast(Optional[int], self.get("cacheReadInputTokens")) + + @property + def cache_write_input_tokens(self) -> Optional[int]: + return cast(Optional[int], self.get("cacheWriteInputTokens")) + + +class BidiConnectionCloseEvent(TypedEvent): + """Streaming connection closed. + + Parameters: + connection_id: Unique identifier for this streaming connection (matches BidiConnectionStartEvent). + reason: Why the connection was closed. + """ + + def __init__( + self, + connection_id: str, + reason: Literal["client_disconnect", "timeout", "error", "complete"], + ): + super().__init__( + { + "type": "bidi_connection_close", + "connection_id": connection_id, + "reason": reason, + } + ) + + @property + def connection_id(self) -> str: + return cast(str, self.get("connection_id")) + + @property + def reason(self) -> str: + return cast(str, self.get("reason")) + + +class BidiErrorEvent(TypedEvent): + """Error occurred during the session. + + Stores the full Exception object as an instance attribute for debugging while + keeping the event dict JSON-serializable. The exception can be accessed via + the `error` property for re-raising or type-based error handling. + + Parameters: + error: The exception that occurred. + details: Optional additional error information. + """ + + def __init__( + self, + error: Exception, + details: Optional[Dict[str, Any]] = None, + ): + # Store serializable data in dict (for JSON serialization) + super().__init__( + { + "type": "bidi_error", + "message": str(error), + "code": type(error).__name__, + "details": details, + } + ) + # Store exception as instance attribute (not serialized) + self._error = error + + @property + def error(self) -> Exception: + """The original exception that occurred. + + Can be used for re-raising or type-based error handling. + """ + return self._error + + @property + def code(self) -> str: + """Error code derived from exception class name.""" + return cast(str, self.get("code")) + + @property + def message(self) -> str: + """Human-readable error message from the exception.""" + return cast(str, self.get("message")) + + @property + def details(self) -> Optional[Dict[str, Any]]: + """Additional error context beyond the exception itself.""" + return cast(Optional[Dict[str, Any]], self.get("details")) + + +# ============================================================================ +# Type Unions +# ============================================================================ + +# Note: ToolResultEvent is imported from strands.types._events and used alongside +# BidiInputEvent in send() methods for sending tool results back to the model. + +BidiInputEvent = BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent + +BidiOutputEvent = ( + BidiConnectionStartEvent + | BidiResponseStartEvent + | BidiAudioStreamEvent + | BidiTranscriptStreamEvent + | BidiInterruptionEvent + | BidiResponseCompleteEvent + | BidiUsageEvent + | BidiConnectionCloseEvent + | BidiErrorEvent +) diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py new file mode 100644 index 000000000..8b79455ec --- /dev/null +++ b/src/strands/experimental/bidi/types/io.py @@ -0,0 +1,57 @@ +"""Protocol for bidirectional streaming IO channels. + +Defines callable protocols for input and output channels that can be used +with BidiAgent. This approach provides better typing and flexibility +by separating input and output concerns into independent callables. +""" + +from typing import Awaitable, Protocol + +from ..types.events import BidiInputEvent, BidiOutputEvent + + +class BidiInput(Protocol): + """Protocol for bidirectional input callables. + + Input callables read data from a source (microphone, camera, websocket, etc.) + and return events to be sent to the agent. + """ + + async def start(self) -> None: + """Start input.""" + ... + + async def stop(self) -> None: + """Stop input.""" + ... + + def __call__(self) -> Awaitable[BidiInputEvent]: + """Read input data from the source. + + Returns: + Awaitable that resolves to an input event (audio, text, image, etc.) + """ + ... + +class BidiOutput(Protocol): + """Protocol for bidirectional output callables. + + Output callables receive events from the agent and handle them appropriately + (play audio, display text, send over websocket, etc.). + """ + + async def start(self) -> None: + """Start output.""" + ... + + async def stop(self) -> None: + """Stop output.""" + ... + + def __call__(self, event: BidiOutputEvent) -> Awaitable[None]: + """Process output events from the agent. + + Args: + event: Output event from the agent (audio, text, tool calls, etc.) + """ + ... diff --git a/tests/strands/experimental/bidi/__init__.py b/tests/strands/experimental/bidi/__init__.py new file mode 100644 index 000000000..ea37091cc --- /dev/null +++ b/tests/strands/experimental/bidi/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming tests.""" diff --git a/tests/strands/experimental/bidi/models/__init__.py b/tests/strands/experimental/bidi/models/__init__.py new file mode 100644 index 000000000..ea9fbb2d0 --- /dev/null +++ b/tests/strands/experimental/bidi/models/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming model tests.""" diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py new file mode 100644 index 000000000..c575f1788 --- /dev/null +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -0,0 +1,487 @@ +"""Unit tests for Gemini Live bidirectional streaming model. + +Tests the unified BidiGeminiLiveModel interface including: +- Model initialization and configuration +- Connection establishment and lifecycle +- Unified send() method with different content types +- Event receiving and conversion +""" + +import base64 +import json +import unittest.mock + +import pytest +from google import genai +from google.genai import types as genai_types + +from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_genai_client(): + """Mock the Google GenAI client.""" + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client") as mock_client_cls: + mock_client = mock_client_cls.return_value + mock_client.aio = unittest.mock.MagicMock() + + # Mock the live session + mock_live_session = unittest.mock.AsyncMock() + + # Mock the context manager + mock_live_session_cm = unittest.mock.MagicMock() + mock_live_session_cm.__aenter__ = unittest.mock.AsyncMock(return_value=mock_live_session) + mock_live_session_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) + + # Make connect return the context manager + mock_client.aio.live.connect = unittest.mock.MagicMock(return_value=mock_live_session_cm) + + yield mock_client, mock_live_session, mock_live_session_cm + + +@pytest.fixture +def model_id(): + return "models/gemini-2.0-flash-live-preview-04-09" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(mock_genai_client, model_id, api_key): + """Create a BidiGeminiLiveModel instance.""" + _ = mock_genai_client + return BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_model_initialization(mock_genai_client, model_id, api_key): + """Test model initialization with various configurations.""" + _ = mock_genai_client + + # Test default config + model_default = BidiGeminiLiveModel() + assert model_default.model_id == "gemini-2.5-flash-native-audio-preview-09-2025" + assert model_default.api_key is None + assert model_default._active is False + assert model_default.live_session is None + # Check default config includes transcription + assert model_default.live_config["response_modalities"] == ["AUDIO"] + assert "outputAudioTranscription" in model_default.live_config + assert "inputAudioTranscription" in model_default.live_config + + # Test with API key + model_with_key = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + assert model_with_key.model_id == model_id + assert model_with_key.api_key == api_key + + # Test with custom config (merges with defaults) + live_config = {"temperature": 0.7, "top_p": 0.9} + model_custom = BidiGeminiLiveModel(model_id=model_id, live_config=live_config) + # Custom config should be merged with defaults + assert model_custom.live_config["temperature"] == 0.7 + assert model_custom.live_config["top_p"] == 0.9 + # Defaults should still be present + assert "response_modalities" in model_custom.live_config + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connection_lifecycle(mock_genai_client, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_client, mock_live_session, mock_live_session_cm = mock_genai_client + + # Test basic connection + await model.start() + assert model._active is True + assert model.connection_id is not None + assert model.live_session == mock_live_session + mock_client.aio.live.connect.assert_called_once() + + # Test close + await model.stop() + assert model._active is False + mock_live_session_cm.__aexit__.assert_called_once() + + # Test connection with system prompt + await model.start(system_prompt=system_prompt) + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert config.get("system_instruction") == system_prompt + await model.stop() + + # Test connection with tools + await model.start(tools=[tool_spec]) + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert "tools" in config + assert len(config["tools"]) > 0 + await model.stop() + + # Test connection with messages + await model.start(messages=messages) + mock_live_session.send_client_content.assert_called() + await model.stop() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(mock_genai_client, api_key, model_id): + """Test connection error handling and edge cases.""" + mock_client, _, mock_live_session_cm = mock_genai_client + + # Test connection error + model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + mock_client.aio.live.connect.side_effect = Exception("Connection failed") + with pytest.raises(Exception, match="Connection failed"): + await model1.start() + + # Reset mock for next tests + mock_client.aio.live.connect.side_effect = None + + # Test double connection + model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + await model2.start() + with pytest.raises(RuntimeError, match="Connection already active"): + await model2.start() + await model2.stop() + + # Test close when not connected + model3 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + await model3.stop() # Should not raise + + # Test close error handling + model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + await model4.start() + mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") + with pytest.raises(Exception, match="Close failed"): + await model4.stop() + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(mock_genai_client, model): + """Test sending all content types through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.start() + + # Test text input + text_input = BidiTextInputEvent(text="Hello", role="user") + await model.send(text_input) + mock_live_session.send_client_content.assert_called_once() + call_args = mock_live_session.send_client_content.call_args + content = call_args.kwargs.get("turns") + assert content.role == "user" + assert content.parts[0].text == "Hello" + + # Test audio input (base64 encoded) + audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') + audio_input = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1, + ) + await model.send(audio_input) + mock_live_session.send_realtime_input.assert_called_once() + + # Test image input (base64 encoded, no encoding parameter) + image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') + image_input = BidiImageInputEvent( + image=image_b64, + mime_type="image/jpeg", + ) + await model.send(image_input) + mock_live_session.send.assert_called_once() + + # Test tool result + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(ToolResultEvent(tool_result)) + mock_live_session.send_tool_response.assert_called_once() + + await model.stop() + + +@pytest.mark.asyncio +async def test_send_edge_cases(mock_genai_client, model): + """Test send() edge cases and error handling.""" + _, mock_live_session, _ = mock_genai_client + + # Test send when inactive + text_input = BidiTextInputEvent(text="Hello", role="user") + await model.send(text_input) + mock_live_session.send_client_content.assert_not_called() + + # Test unknown content type + await model.start() + unknown_content = {"unknown_field": "value"} + await model.send(unknown_content) # Should not raise, just log warning + + await model.stop() + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): + """Test that receive() emits connection start and end events.""" + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive.return_value = agenerator([]) + + await model.start() + + # Collect events + events = [] + async for event in model.receive(): + events.append(event) + # Close after first event to trigger connection end + if len(events) == 1: + await model.stop() + + # Verify connection start and end + assert len(events) >= 2 + assert isinstance(events[0], BidiConnectionStartEvent) + assert events[0].get("type") == "bidi_connection_start" + assert events[0].connection_id == model.connection_id + assert isinstance(events[-1], BidiConnectionCloseEvent) + assert events[-1].get("type") == "bidi_connection_close" + + +@pytest.mark.asyncio +async def test_event_conversion(mock_genai_client, model): + """Test conversion of all Gemini Live event types to standard format.""" + _, _, _ = mock_genai_client + await model.start() + + # Test text output (converted to transcript via model_turn.parts) + mock_text = unittest.mock.Mock() + mock_text.data = None + mock_text.tool_call = None + + # Create proper server_content structure with model_turn + mock_server_content = unittest.mock.Mock() + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + + mock_model_turn = unittest.mock.Mock() + mock_part = unittest.mock.Mock() + mock_part.text = "Hello from Gemini" + mock_model_turn.parts = [mock_part] + mock_server_content.model_turn = mock_model_turn + + mock_text.server_content = mock_server_content + + text_events = model._convert_gemini_live_event(mock_text) + assert isinstance(text_events, list) + assert len(text_events) == 1 + text_event = text_events[0] + assert isinstance(text_event, BidiTranscriptStreamEvent) + assert text_event.get("type") == "bidi_transcript_stream" + assert text_event.text == "Hello from Gemini" + assert text_event.role == "assistant" + assert text_event.is_final is True + assert text_event.delta == {"text": "Hello from Gemini"} + assert text_event.current_transcript == "Hello from Gemini" + + # Test multiple text parts (should concatenate) + mock_multi_text = unittest.mock.Mock() + mock_multi_text.data = None + mock_multi_text.tool_call = None + + mock_server_content_multi = unittest.mock.Mock() + mock_server_content_multi.interrupted = False + mock_server_content_multi.input_transcription = None + mock_server_content_multi.output_transcription = None + + mock_model_turn_multi = unittest.mock.Mock() + mock_part1 = unittest.mock.Mock() + mock_part1.text = "Hello" + mock_part2 = unittest.mock.Mock() + mock_part2.text = "from Gemini" + mock_model_turn_multi.parts = [mock_part1, mock_part2] + mock_server_content_multi.model_turn = mock_model_turn_multi + + mock_multi_text.server_content = mock_server_content_multi + + multi_text_events = model._convert_gemini_live_event(mock_multi_text) + assert isinstance(multi_text_events, list) + assert len(multi_text_events) == 1 + multi_text_event = multi_text_events[0] + assert isinstance(multi_text_event, BidiTranscriptStreamEvent) + assert multi_text_event.text == "Hello from Gemini" # Concatenated with space + + # Test audio output (base64 encoded) + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_events = model._convert_gemini_live_event(mock_audio) + assert isinstance(audio_events, list) + assert len(audio_events) == 1 + audio_event = audio_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.get("type") == "bidi_audio_stream" + # Audio is now base64 encoded + expected_b64 = base64.b64encode(b"audio_data").decode('utf-8') + assert audio_event.audio == expected_b64 + assert audio_event.format == "pcm" + + # Test single tool call (returns list with one event) + mock_func_call = unittest.mock.Mock() + mock_func_call.id = "tool-123" + mock_func_call.name = "calculator" + mock_func_call.args = {"expression": "2+2"} + + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function_calls = [mock_func_call] + + mock_tool = unittest.mock.Mock() + mock_tool.text = None + mock_tool.data = None + mock_tool.tool_call = mock_tool_call + mock_tool.server_content = None + + tool_events = model._convert_gemini_live_event(mock_tool) + # Should return a list of ToolUseStreamEvent + assert isinstance(tool_events, list) + assert len(tool_events) == 1 + tool_event = tool_events[0] + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in tool_event + assert "toolUse" in tool_event["delta"] + assert tool_event["delta"]["toolUse"]["toolUseId"] == "tool-123" + assert tool_event["delta"]["toolUse"]["name"] == "calculator" + + # Test multiple tool calls (returns list with multiple events) + mock_func_call_1 = unittest.mock.Mock() + mock_func_call_1.id = "tool-123" + mock_func_call_1.name = "calculator" + mock_func_call_1.args = {"expression": "2+2"} + + mock_func_call_2 = unittest.mock.Mock() + mock_func_call_2.id = "tool-456" + mock_func_call_2.name = "weather" + mock_func_call_2.args = {"location": "Seattle"} + + mock_tool_call_multi = unittest.mock.Mock() + mock_tool_call_multi.function_calls = [mock_func_call_1, mock_func_call_2] + + mock_tool_multi = unittest.mock.Mock() + mock_tool_multi.text = None + mock_tool_multi.data = None + mock_tool_multi.tool_call = mock_tool_call_multi + mock_tool_multi.server_content = None + + tool_events_multi = model._convert_gemini_live_event(mock_tool_multi) + # Should return a list with two ToolUseStreamEvent + assert isinstance(tool_events_multi, list) + assert len(tool_events_multi) == 2 + + # Verify first tool call + assert tool_events_multi[0]["delta"]["toolUse"]["toolUseId"] == "tool-123" + assert tool_events_multi[0]["delta"]["toolUse"]["name"] == "calculator" + assert tool_events_multi[0]["delta"]["toolUse"]["input"] == {"expression": "2+2"} + + # Verify second tool call + assert tool_events_multi[1]["delta"]["toolUse"]["toolUseId"] == "tool-456" + assert tool_events_multi[1]["delta"]["toolUse"]["name"] == "weather" + assert tool_events_multi[1]["delta"]["toolUse"]["input"] == {"location": "Seattle"} + + # Test interruption + mock_server_content = unittest.mock.Mock() + mock_server_content.interrupted = True + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + + mock_interrupt = unittest.mock.Mock() + mock_interrupt.text = None + mock_interrupt.data = None + mock_interrupt.tool_call = None + mock_interrupt.server_content = mock_server_content + + interrupt_events = model._convert_gemini_live_event(mock_interrupt) + assert isinstance(interrupt_events, list) + assert len(interrupt_events) == 1 + interrupt_event = interrupt_events[0] + assert isinstance(interrupt_event, BidiInterruptionEvent) + assert interrupt_event.get("type") == "bidi_interruption" + assert interrupt_event.reason == "user_speech" + + await model.stop() + + +# Helper Method Tests + + +def test_config_building(model, system_prompt, tool_spec): + """Test building live config with various options.""" + # Test basic config + config_basic = model._build_live_config() + assert isinstance(config_basic, dict) + + # Test with system prompt + config_prompt = model._build_live_config(system_prompt=system_prompt) + assert config_prompt["system_instruction"] == system_prompt + + # Test with tools + config_tools = model._build_live_config(tools=[tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 + + +def test_tool_formatting(model, tool_spec): + """Test tool formatting for Gemini Live API.""" + # Test with tools + formatted_tools = model._format_tools_for_live_api([tool_spec]) + assert len(formatted_tools) == 1 + assert isinstance(formatted_tools[0], genai_types.Tool) + + # Test empty list + formatted_empty = model._format_tools_for_live_api([]) + assert formatted_empty == [] diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py new file mode 100644 index 000000000..db61ed43e --- /dev/null +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -0,0 +1,458 @@ +"""Unit tests for Nova Sonic bidirectional model implementation. + +Tests the unified BidirectionalModel interface implementation for Amazon Nova Sonic, +covering connection lifecycle, event conversion, audio streaming, and tool execution. +""" + +import asyncio +import base64 +import json +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio + +from strands.experimental.bidi.models.novasonic import ( + BidiNovaSonicModel, +) +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, +) +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult + + +# Test fixtures +@pytest.fixture +def model_id(): + """Nova Sonic model identifier.""" + return "amazon.nova-sonic-v1:0" + + +@pytest.fixture +def region(): + """AWS region.""" + return "us-east-1" + + +@pytest.fixture +def mock_stream(): + """Mock Nova Sonic bidirectional stream.""" + stream = AsyncMock() + stream.input_stream = AsyncMock() + stream.input_stream.send = AsyncMock() + stream.input_stream.close = AsyncMock() + stream.await_output = AsyncMock() + return stream + + +@pytest.fixture +def mock_client(mock_stream): + """Mock Bedrock Runtime client.""" + client = AsyncMock() + client.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) + return client + + +@pytest_asyncio.fixture +async def nova_model(model_id, region): + """Create Nova Sonic model instance.""" + model = BidiNovaSonicModel(model_id=model_id, region=region) + yield model + # Cleanup + if model._active: + await model.stop() + + +# Initialization and Connection Tests + + +@pytest.mark.asyncio +async def test_model_initialization(model_id, region): + """Test model initialization with configuration.""" + model = BidiNovaSonicModel(model_id=model_id, region=region) + + assert model.model_id == model_id + assert model.region == region + assert model.stream is None + assert not model._active + assert model.connection_id is None + + +@pytest.mark.asyncio +async def test_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test complete connection lifecycle with various configurations.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test basic connection + await nova_model.start(system_prompt="Test system prompt") + assert nova_model._active + assert nova_model.stream == mock_stream + assert nova_model.connection_id is not None + assert mock_client.invoke_model_with_bidirectional_stream.called + + # Test close + await nova_model.stop() + assert not nova_model._active + assert mock_stream.input_stream.close.called + + # Test connection with tools + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} + } + ] + await nova_model.start(system_prompt="You are helpful", tools=tools) + # Verify initialization events were sent (connectionStart, promptStart, system prompt) + assert mock_stream.input_stream.send.call_count >= 3 + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): + """Test connection error handling and edge cases.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test double connection + await nova_model.start() + with pytest.raises(RuntimeError, match="Connection already active"): + await nova_model.start() + await nova_model.stop() + + # Test close when already closed + model2 = BidiNovaSonicModel(model_id=model_id, region=region) + await model2.stop() # Should not raise + await model2.stop() # Second call should also be safe + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(nova_model, mock_client, mock_stream): + """Test sending all content types through unified send() method.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + await nova_model.start() + + # Test text content + text_event = BidiTextInputEvent(text="Hello, Nova!", role="user") + await nova_model.send(text_event) + # Should send contentStart, textInput, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + # Test audio content (base64 encoded) + audio_b64 = base64.b64encode(b"audio data").decode('utf-8') + audio_event = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) + await nova_model.send(audio_event) + # Should start audio connection and send audio + assert nova_model.audio_connection_active + assert mock_stream.input_stream.send.called + + # Test tool result + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Weather is sunny"}] + } + await nova_model.send(ToolResultEvent(tool_result)) + # Should send contentStart, toolResult, and contentEnd + assert mock_stream.input_stream.send.called + + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): + """Test send() edge cases and error handling.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test send when inactive + text_event = BidiTextInputEvent(text="Hello", role="user") + await nova_model.send(text_event) # Should not raise + + # Test image content (not supported, base64 encoded, no encoding parameter) + await nova_model.start() + image_b64 = base64.b64encode(b"image data").decode('utf-8') + image_event = BidiImageInputEvent( + image=image_b64, + mime_type="image/jpeg", + ) + await nova_model.send(image_event) + # Should log warning about unsupported image input + assert any("not supported" in record.message.lower() for record in caplog.records) + + await nova_model.stop() + + +# Receive and Event Conversion Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): + """Test that receive() emits connection start and end events.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Setup mock to return no events and then stop + async def mock_wait_for(*args, **kwargs): + await asyncio.sleep(0.1) + nova_model._active = False + raise asyncio.TimeoutError() + + with patch("asyncio.wait_for", side_effect=mock_wait_for): + await nova_model.start() + + events = [] + async for event in nova_model.receive(): + events.append(event) + + # Should have session start and end (new TypedEvent format) + assert len(events) >= 2 + assert events[0].get("type") == "bidi_connection_start" + assert events[0].get("connection_id") == nova_model.connection_id + assert events[-1].get("type") == "bidi_connection_close" + + +@pytest.mark.asyncio +async def test_event_conversion(nova_model): + """Test conversion of all Nova Sonic event types to standard format.""" + # Test audio output (now returns BidiAudioStreamEvent) + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + nova_event = {"audioOutput": {"content": audio_base64}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiAudioStreamEvent) + assert result.get("type") == "bidi_audio_stream" + # Audio is kept as base64 string + assert result.get("audio") == audio_base64 + assert result.get("format") == "pcm" + assert result.get("sample_rate") == 24000 + + # Test text output (now returns BidiTranscriptStreamEvent) + nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiTranscriptStreamEvent) + assert result.get("type") == "bidi_transcript_stream" + assert result.get("text") == "Hello, world!" + assert result.get("role") == "assistant" + assert result.delta == {"text": "Hello, world!"} + assert result.current_transcript == "Hello, world!" + + # Test tool use (now returns ToolUseStreamEvent from core strands) + tool_input = {"location": "Seattle"} + nova_event = { + "toolUse": { + "toolUseId": "tool-123", + "toolName": "get_weather", + "content": json.dumps(tool_input) + } + } + result = nova_model._convert_nova_event(nova_event) + assert result is not None + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in result + assert "toolUse" in result["delta"] + tool_use = result["delta"]["toolUse"] + assert tool_use["toolUseId"] == "tool-123" + assert tool_use["name"] == "get_weather" + assert tool_use["input"] == tool_input + + # Test interruption (now returns BidiInterruptionEvent) + nova_event = {"stopReason": "INTERRUPTED"} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiInterruptionEvent) + assert result.get("type") == "bidi_interruption" + assert result.get("reason") == "user_speech" + + # Test usage metrics (now returns BidiUsageEvent) + nova_event = { + "usageEvent": { + "totalTokens": 100, + "totalInputTokens": 40, + "totalOutputTokens": 60, + "details": { + "total": { + "output": { + "speechTokens": 30 + } + } + } + } + } + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiUsageEvent) + assert result.get("type") == "bidi_usage" + assert result.get("totalTokens") == 100 + assert result.get("inputTokens") == 40 + assert result.get("outputTokens") == 60 + + # Test content start tracks role and emits BidiResponseStartEvent + nova_event = {"contentStart": {"role": "USER"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiResponseStartEvent) + assert result.get("type") == "bidi_response_start" + assert nova_model._current_role == "USER" + + +# Audio Streaming Tests + + +@pytest.mark.asyncio +async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test audio connection start and end lifecycle.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + await nova_model.start() + + # Start audio connection + await nova_model._start_audio_connection() + assert nova_model.audio_connection_active + + # End audio connection + await nova_model._end_audio_input() + assert not nova_model.audio_connection_active + + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_silence_detection(nova_model, mock_client, mock_stream): + """Test that silence detection automatically ends audio input.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + nova_model.silence_threshold = 0.1 # Short threshold for testing + + await nova_model.start() + + # Send audio to start connection (base64 encoded) + audio_b64 = base64.b64encode(b"audio data").decode('utf-8') + audio_event = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) + + await nova_model.send(audio_event) + assert nova_model.audio_connection_active + + # Wait for silence detection + await asyncio.sleep(0.2) + + # Audio connection should be ended + assert not nova_model.audio_connection_active + + await nova_model.stop() + + +# Helper Method Tests + + +@pytest.mark.asyncio +async def test_tool_configuration(nova_model): + """Test building tool configuration from tool specs.""" + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": { + "json": json.dumps({ + "type": "object", + "properties": { + "location": {"type": "string"} + } + }) + } + } + ] + + tool_config = nova_model._build_tool_configuration(tools) + + assert len(tool_config) == 1 + assert tool_config[0]["toolSpec"]["name"] == "get_weather" + assert tool_config[0]["toolSpec"]["description"] == "Get weather information" + assert "inputSchema" in tool_config[0]["toolSpec"] + + +@pytest.mark.asyncio +async def test_event_templates(nova_model): + """Test event template generation.""" + # Test connection start event + event_json = nova_model._get_connection_start_event() + event = json.loads(event_json) + assert "event" in event + assert "sessionStart" in event["event"] + assert "inferenceConfiguration" in event["event"]["sessionStart"] + + # Test prompt start event + nova_model.connection_id = "test-connection" + event_json = nova_model._get_prompt_start_event([]) + event = json.loads(event_json) + assert "event" in event + assert "promptStart" in event["event"] + assert event["event"]["promptStart"]["promptName"] == "test-connection" + + # Test text input event + content_name = "test-content" + event_json = nova_model._get_text_input_event(content_name, "Hello") + event = json.loads(event_json) + assert "event" in event + assert "textInput" in event["event"] + assert event["event"]["textInput"]["content"] == "Hello" + + # Test tool result event + result = {"result": "Success"} + event_json = nova_model._get_tool_result_event(content_name, result) + event = json.loads(event_json) + assert "event" in event + assert "toolResult" in event["event"] + assert json.loads(event["event"]["toolResult"]["content"]) == result + + +# Error Handling Tests + + +@pytest.mark.asyncio +async def test_error_handling(nova_model, mock_client, mock_stream): + """Test error handling in various scenarios.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test response processor handles errors gracefully + async def mock_error(*args, **kwargs): + raise Exception("Test error") + + mock_stream.await_output.side_effect = mock_error + + await nova_model.start() + + # Wait a bit for response processor to handle error + await asyncio.sleep(0.1) + + # Should still be able to close cleanly + await nova_model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py new file mode 100644 index 000000000..b9e844250 --- /dev/null +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -0,0 +1,538 @@ +"""Unit tests for OpenAI Realtime bidirectional streaming model. + +Tests the unified BidiOpenAIRealtimeModel interface including: +- Model initialization and configuration +- Connection establishment with WebSocket +- Unified send() method with different content types +- Event receiving and conversion +- Connection lifecycle management +""" + +import asyncio +import base64 +import json +import unittest.mock + +import pytest + +from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseCompleteEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_websocket(): + """Mock WebSocket connection.""" + mock_ws = unittest.mock.AsyncMock() + mock_ws.send = unittest.mock.AsyncMock() + mock_ws.close = unittest.mock.AsyncMock() + return mock_ws + + +@pytest.fixture +def mock_websockets_connect(mock_websocket): + """Mock websockets.connect function.""" + async def async_connect(*args, **kwargs): + return mock_websocket + + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect: + mock_connect.side_effect = async_connect + yield mock_connect, mock_websocket + + +@pytest.fixture +def model_name(): + return "gpt-realtime" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(api_key, model_name): + """Create an BidiOpenAIRealtimeModel instance.""" + return BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_model_initialization(api_key, model_name): + """Test model initialization with various configurations.""" + # Test default config + model_default = BidiOpenAIRealtimeModel(api_key="test-key") + assert model_default.model == "gpt-realtime" + assert model_default.api_key == "test-key" + assert model_default._active is False + assert model_default.websocket is None + + # Test with custom model + model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + assert model_custom.model == model_name + assert model_custom.api_key == api_key + + # Test with organization and project + model_org = BidiOpenAIRealtimeModel( + model=model_name, + api_key=api_key, + organization="org-123", + project="proj-456" + ) + assert model_org.organization == "org-123" + assert model_org.project == "proj-456" + + # Test with env API key + with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): + model_env = BidiOpenAIRealtimeModel() + assert model_env.api_key == "env-key" + + +def test_init_without_api_key_raises(): + """Test that initialization without API key raises error.""" + with unittest.mock.patch.dict("os.environ", {}, clear=True): + with pytest.raises(ValueError, match="OpenAI API key is required"): + BidiOpenAIRealtimeModel() + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connection_lifecycle(mock_websockets_connect, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test basic connection + await model.start() + assert model._active is True + assert model.connection_id is not None + assert model.websocket == mock_ws + assert model._event_queue is not None + assert model._response_task is not None + mock_connect.assert_called_once() + + # Test close + await model.stop() + assert model._active is False + mock_ws.close.assert_called_once() + + # Test connection with system prompt + await model.start(system_prompt=system_prompt) + calls = mock_ws.send.call_args_list + session_update = next( + (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), + None + ) + assert session_update is not None + assert system_prompt in session_update["session"]["instructions"] + await model.stop() + + # Test connection with tools + await model.start(tools=[tool_spec]) + calls = mock_ws.send.call_args_list + # Tools are sent in a separate session.update after initial connection + session_updates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"] + assert len(session_updates) > 0 + # Check if any session update has tools + has_tools = any("tools" in update.get("session", {}) for update in session_updates) + assert has_tools + await model.stop() + + # Test connection with messages + await model.start(messages=messages) + calls = mock_ws.send.call_args_list + item_creates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create"] + assert len(item_creates) > 0 + await model.stop() + + # Test connection with organization header + model_org = BidiOpenAIRealtimeModel(api_key="test-key", organization="org-123") + await model_org.start() + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("additional_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + await model_org.stop() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(mock_websockets_connect, api_key, model_name): + """Test connection error handling and edge cases.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test connection error + model1 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + mock_connect.side_effect = Exception("Connection failed") + with pytest.raises(Exception, match="Connection failed"): + await model1.start() + + # Reset mock + async def async_connect(*args, **kwargs): + return mock_ws + mock_connect.side_effect = async_connect + + # Test double connection + model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + await model2.start() + with pytest.raises(RuntimeError, match="Connection already active"): + await model2.start() + await model2.stop() + + # Test close when not connected + model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + await model3.stop() # Should not raise + + # Test close error handling (should not raise, just log) + model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + await model4.start() + mock_ws.close.side_effect = Exception("Close failed") + await model4.stop() # Should not raise + assert model4._active is False + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(mock_websockets_connect, model): + """Test sending all content types through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.start() + + # Test text input + text_input = BidiTextInputEvent(text="Hello", role="user") + await model.send(text_input) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + response_create = [m for m in messages if m.get("type") == "response.create"] + assert len(item_create) > 0 + assert len(response_create) > 0 + + # Test audio input (base64 encoded) + audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') + audio_input = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=24000, + channels=1, + ) + await model.send(audio_input) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] + assert len(audio_append) > 0 + assert "audio" in audio_append[0] + # Audio should be passed through as base64 + assert audio_append[0]["audio"] == audio_b64 + + # Test tool result + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(ToolResultEvent(tool_result)) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + assert len(item_create) > 0 + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-123" + + await model.stop() + + +@pytest.mark.asyncio +async def test_send_edge_cases(mock_websockets_connect, model): + """Test send() edge cases and error handling.""" + _, mock_ws = mock_websockets_connect + + # Test send when inactive + text_input = BidiTextInputEvent(text="Hello", role="user") + await model.send(text_input) + mock_ws.send.assert_not_called() + + # Test image input (not supported, base64 encoded, no encoding parameter) + await model.start() + image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') + image_input = BidiImageInputEvent( + image=image_b64, + mime_type="image/jpeg", + ) + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(image_input) + mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") + + # Test unknown content type + unknown_content = {"unknown_field": "value"} + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(unknown_content) + assert mock_logger.warning.called + + await model.stop() + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(mock_websockets_connect, model): + """Test that receive() emits connection start and end events.""" + _, _ = mock_websockets_connect + + await model.start() + + # Get first event + receive_gen = model.receive() + first_event = await anext(receive_gen) + + # First event should be connection start (new TypedEvent format) + assert first_event.get("type") == "bidi_connection_start" + assert first_event.get("connection_id") == model.connection_id + assert first_event.get("model") == model.model + + # Close to trigger session end + await model.stop() + + # Collect remaining events + events = [first_event] + try: + async for event in receive_gen: + events.append(event) + except StopAsyncIteration: + pass + + # Last event should be connection close (new TypedEvent format) + assert events[-1].get("type") == "bidi_connection_close" + + +@pytest.mark.asyncio +async def test_event_conversion(mock_websockets_connect, model): + """Test conversion of all OpenAI event types to standard format.""" + _, _ = mock_websockets_connect + await model.start() + + # Test audio output (now returns list with BidiAudioStreamEvent) + audio_event = { + "type": "response.output_audio.delta", + "delta": base64.b64encode(b"audio_data").decode() + } + converted = model._convert_openai_event(audio_event) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiAudioStreamEvent) + assert converted[0].get("type") == "bidi_audio_stream" + assert converted[0].get("audio") == base64.b64encode(b"audio_data").decode() + assert converted[0].get("format") == "pcm" + + # Test text output (now returns list with BidiTranscriptStreamEvent) + text_event = { + "type": "response.output_text.delta", + "delta": "Hello from OpenAI" + } + converted = model._convert_openai_event(text_event) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiTranscriptStreamEvent) + assert converted[0].get("type") == "bidi_transcript_stream" + assert converted[0].get("text") == "Hello from OpenAI" + assert converted[0].get("role") == "assistant" + assert converted[0].delta == {"text": "Hello from OpenAI"} + assert converted[0].is_final is True + + # Test function call sequence + item_added = { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "call_id": "call-123", + "name": "calculator" + } + } + model._convert_openai_event(item_added) + + args_delta = { + "type": "response.function_call_arguments.delta", + "call_id": "call-123", + "delta": '{"expression": "2+2"}' + } + model._convert_openai_event(args_delta) + + args_done = { + "type": "response.function_call_arguments.done", + "call_id": "call-123" + } + converted = model._convert_openai_event(args_done) + # Now returns list with ToolUseStreamEvent + assert isinstance(converted, list) + assert len(converted) == 1 + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in converted[0] + assert "toolUse" in converted[0]["delta"] + tool_use = converted[0]["delta"]["toolUse"] + assert tool_use["toolUseId"] == "call-123" + assert tool_use["name"] == "calculator" + assert tool_use["input"]["expression"] == "2+2" + + # Test voice activity (now returns list with BidiInterruptionEvent for speech_started) + speech_started = { + "type": "input_audio_buffer.speech_started" + } + converted = model._convert_openai_event(speech_started) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiInterruptionEvent) + assert converted[0].get("type") == "bidi_interruption" + assert converted[0].get("reason") == "user_speech" + + # Test response.cancelled event (should return ResponseCompleteEvent with interrupted reason) + response_cancelled = { + "type": "response.cancelled", + "response": { + "id": "resp_123" + } + } + converted = model._convert_openai_event(response_cancelled) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiResponseCompleteEvent) + assert converted[0].get("type") == "bidi_response_complete" + assert converted[0].get("response_id") == "resp_123" + assert converted[0].get("stop_reason") == "interrupted" + + # Test error handling - response_cancel_not_active should be suppressed + error_cancel_not_active = { + "type": "error", + "error": { + "code": "response_cancel_not_active", + "message": "No active response to cancel" + } + } + converted = model._convert_openai_event(error_cancel_not_active) + assert converted is None # Should be suppressed + + # Test error handling - other errors should be logged but return None + error_other = { + "type": "error", + "error": { + "code": "some_other_error", + "message": "Something went wrong" + } + } + converted = model._convert_openai_event(error_other) + assert converted is None + + await model.stop() + + +# Helper Method Tests + + +def test_config_building(model, system_prompt, tool_spec): + """Test building session config with various options.""" + # Test basic config + config_basic = model._build_session_config(None, None) + assert isinstance(config_basic, dict) + assert "instructions" in config_basic + assert "audio" in config_basic + + # Test with system prompt + config_prompt = model._build_session_config(system_prompt, None) + assert config_prompt["instructions"] == system_prompt + + # Test with tools + config_tools = model._build_session_config(None, [tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 + + +def test_tool_conversion(model, tool_spec): + """Test tool conversion to OpenAI format.""" + # Test with tools + openai_tools = model._convert_tools_to_openai_format([tool_spec]) + assert len(openai_tools) == 1 + assert openai_tools[0]["type"] == "function" + assert openai_tools[0]["name"] == "calculator" + assert openai_tools[0]["description"] == "Calculate mathematical expressions" + + # Test empty list + openai_empty = model._convert_tools_to_openai_format([]) + assert openai_empty == [] + + +def test_helper_methods(model): + """Test various helper methods.""" + # Test _require_active + assert model._require_active() is False + model._active = True + assert model._require_active() is True + model._active = False + + # Test _create_text_event (now returns BidiTranscriptStreamEvent) + text_event = model._create_text_event("Hello", "user") + assert isinstance(text_event, BidiTranscriptStreamEvent) + assert text_event.get("type") == "bidi_transcript_stream" + assert text_event.get("text") == "Hello" + assert text_event.get("role") == "user" + assert text_event.delta == {"text": "Hello"} + assert text_event.is_final is True + assert text_event.current_transcript == "Hello" + + # Test _create_voice_activity_event (now returns BidiInterruptionEvent for speech_started) + voice_event = model._create_voice_activity_event("speech_started") + assert isinstance(voice_event, BidiInterruptionEvent) + assert voice_event.get("type") == "bidi_interruption" + assert voice_event.get("reason") == "user_speech" + + # Other voice activities return None + assert model._create_voice_activity_event("speech_stopped") is None + + +@pytest.mark.asyncio +async def test_send_event_helper(mock_websockets_connect, model): + """Test _send_event helper method.""" + _, mock_ws = mock_websockets_connect + await model.start() + + test_event = {"type": "test.event", "data": "test"} + await model._send_event(test_event) + + calls = mock_ws.send.call_args_list + last_call = calls[-1] + sent_message = json.loads(last_call[0][0]) + assert sent_message == test_event + + await model.stop() diff --git a/tests/strands/experimental/bidi/types/__init__.py b/tests/strands/experimental/bidi/types/__init__.py new file mode 100644 index 000000000..a1330e552 --- /dev/null +++ b/tests/strands/experimental/bidi/types/__init__.py @@ -0,0 +1 @@ +"""Tests for bidirectional streaming types.""" diff --git a/tests/strands/experimental/bidi/types/test_events.py b/tests/strands/experimental/bidi/types/test_events.py new file mode 100644 index 000000000..0b6419719 --- /dev/null +++ b/tests/strands/experimental/bidi/types/test_events.py @@ -0,0 +1,164 @@ +"""Tests for bidirectional streaming event types. + +This module tests JSON serialization for all bidirectional streaming event types. +""" + +import base64 +import json + +import pytest + +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, +) + + +@pytest.mark.parametrize( + "event_class,kwargs,expected_type", + [ + # Input events + (BidiTextInputEvent, {"text": "Hello", "role": "user"}, "bidi_text_input"), + ( + BidiAudioInputEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 16000, + "channels": 1, + }, + "bidi_audio_input", + ), + ( + BidiImageInputEvent, + {"image": base64.b64encode(b"image").decode("utf-8"), "mime_type": "image/jpeg"}, + "bidi_image_input", + ), + # Output events + ( + BidiConnectionStartEvent, + {"connection_id": "c1", "model": "m1"}, + "bidi_connection_start", + ), + (BidiResponseStartEvent, {"response_id": "r1"}, "bidi_response_start"), + ( + BidiAudioStreamEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 24000, + "channels": 1, + }, + "bidi_audio_stream", + ), + ( + BidiTranscriptStreamEvent, + { + "delta": {"text": "Hello"}, + "text": "Hello", + "role": "assistant", + "is_final": True, + "current_transcript": "Hello", + }, + "bidi_transcript_stream", + ), + (BidiInterruptionEvent, {"reason": "user_speech"}, "bidi_interruption"), + ( + BidiResponseCompleteEvent, + {"response_id": "r1", "stop_reason": "complete"}, + "bidi_response_complete", + ), + ( + BidiUsageEvent, + {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, + "bidi_usage", + ), + ( + BidiConnectionCloseEvent, + {"connection_id": "c1", "reason": "complete"}, + "bidi_connection_close", + ), + (BidiErrorEvent, {"error": ValueError("test"), "details": None}, "bidi_error"), + ], +) +def test_event_json_serialization(event_class, kwargs, expected_type): + """Test that all event types are JSON serializable and deserializable.""" + # Create event + event = event_class(**kwargs) + + # Verify type field + assert event["type"] == expected_type + + # Serialize to JSON + json_str = json.dumps(event) + print("event_class:", event_class) + print(json_str) + # Deserialize back + data = json.loads(json_str) + + # Verify type preserved + assert data["type"] == expected_type + + # Verify all non-private keys preserved + for key in event.keys(): + if not key.startswith("_"): + assert key in data + + + +def test_transcript_stream_event_delta_pattern(): + """Test that BidiTranscriptStreamEvent follows ModelStreamEvent delta pattern.""" + # Test partial transcript (delta) + partial_event = BidiTranscriptStreamEvent( + delta={"text": "Hello"}, + text="Hello", + role="user", + is_final=False, + current_transcript=None, + ) + + assert partial_event.text == "Hello" + assert partial_event.role == "user" + assert partial_event.is_final is False + assert partial_event.current_transcript is None + assert partial_event.delta == {"text": "Hello"} + + # Test final transcript with accumulated text + final_event = BidiTranscriptStreamEvent( + delta={"text": " world"}, + text=" world", + role="user", + is_final=True, + current_transcript="Hello world", + ) + + assert final_event.text == " world" + assert final_event.role == "user" + assert final_event.is_final is True + assert final_event.current_transcript == "Hello world" + assert final_event.delta == {"text": " world"} + + +def test_transcript_stream_event_extends_model_stream_event(): + """Test that BidiTranscriptStreamEvent is a ModelStreamEvent.""" + from strands.types._events import ModelStreamEvent + + event = BidiTranscriptStreamEvent( + delta={"text": "test"}, + text="test", + role="assistant", + is_final=True, + current_transcript="test", + ) + + assert isinstance(event, ModelStreamEvent) diff --git a/tests_integ/bidi/__init__.py b/tests_integ/bidi/__init__.py new file mode 100644 index 000000000..05da9afcb --- /dev/null +++ b/tests_integ/bidi/__init__.py @@ -0,0 +1 @@ +"""Integration tests for bidirectional streaming agents.""" diff --git a/tests_integ/bidi/conftest.py b/tests_integ/bidi/conftest.py new file mode 100644 index 000000000..0d453818a --- /dev/null +++ b/tests_integ/bidi/conftest.py @@ -0,0 +1,28 @@ +"""Pytest fixtures for bidirectional streaming integration tests.""" + +import logging + +import pytest + +from .generators.audio import AudioGenerator + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session") +def audio_generator(): + """Provide AudioGenerator instance for tests.""" + return AudioGenerator(region="us-east-1") + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Configure logging for tests.""" + logging.basicConfig( + level=logging.DEBUG, + format="%(levelname)s | %(name)s | %(message)s", + ) + # Reduce noise from some loggers + logging.getLogger("boto3").setLevel(logging.WARNING) + logging.getLogger("botocore").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/tests_integ/bidi/context.py b/tests_integ/bidi/context.py new file mode 100644 index 000000000..4a5278a62 --- /dev/null +++ b/tests_integ/bidi/context.py @@ -0,0 +1,365 @@ +"""Test context manager for bidirectional streaming tests. + +Provides a high-level interface for testing bidirectional streaming agents +with continuous background threads that mimic real-world usage patterns. +""" + +import asyncio +import base64 +import logging +import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from strands.experimental.bidi.agent.agent import BidiAgent + from .generators.audio import AudioGenerator + +logger = logging.getLogger(__name__) + +# Constants for timing and buffering +QUEUE_POLL_TIMEOUT = 0.05 # 50ms - balance between responsiveness and CPU usage +SILENCE_INTERVAL = 0.05 # 50ms - send silence every 50ms when queue empty +AUDIO_CHUNK_DELAY = 0.01 # 10ms - small delay between audio chunks +WAIT_POLL_INTERVAL = 0.1 # 100ms - how often to check for response completion + + +class BidirectionalTestContext: + """Manages threads and generators for bidirectional streaming tests. + + Mimics real-world usage with continuous background threads: + - Audio input thread (microphone simulation with silence padding) + - Event collection thread (captures all model outputs) + + Generators feed data into threads via queues for natural conversation flow. + + Example: + async with BidirectionalTestContext(agent, audio_generator) as ctx: + await ctx.say("What is 5 plus 3?") + await ctx.wait_for_response() + assert "8" in " ".join(ctx.get_text_outputs()) + """ + + def __init__( + self, + agent: "BidiAgent", + audio_generator: "AudioGenerator | None" = None, + silence_chunk_size: int = 1024, + audio_chunk_size: int = 1024, + ): + """Initialize test context. + + Args: + agent: BidiAgent instance. + audio_generator: AudioGenerator for text-to-speech. + silence_chunk_size: Size of silence chunks in bytes. + audio_chunk_size: Size of audio chunks for streaming. + """ + self.agent = agent + self.audio_generator = audio_generator + self.silence_chunk_size = silence_chunk_size + self.audio_chunk_size = audio_chunk_size + + # Queue for thread communication + self.input_queue = asyncio.Queue() # Handles both audio and text input + + # Event storage (thread-safe) + self._event_queue = asyncio.Queue() # Events from collection thread + self.events = [] # Cached events for test access + self.last_event_time = None + + # Control flags + self.active = False + self.threads = [] + + async def __aenter__(self): + """Start context manager, agent session, and background threads.""" + # Start agent session + await self.agent.start() + logger.debug("Agent session started") + + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Stop context manager, cleanup threads, and end agent session.""" + # End agent session FIRST - this will cause receive() to exit cleanly + if self.agent._agent_loop and self.agent._agent_loop.active: + await self.agent.stop() + logger.debug("Agent session stopped") + + # Then stop the context threads + await self.stop() + + return False + + async def start(self): + """Start all background threads.""" + self.active = True + self.last_event_time = time.monotonic() + + self.threads = [ + asyncio.create_task(self._input_thread()), + asyncio.create_task(self._event_collection_thread()), + ] + + logger.debug("Test context started with %d threads", len(self.threads)) + + async def stop(self): + """Stop all threads gracefully.""" + if not self.active: + logger.debug("stop() called but already stopped") + return + + logger.debug("stop() called - stopping threads") + self.active = False + + # Cancel all threads + for task in self.threads: + if not task.done(): + task.cancel() + + # Wait for cancellation + await asyncio.gather(*self.threads, return_exceptions=True) + + logger.debug("Test context stopped") + + # === User-facing methods === + + async def say(self, text: str): + """Convert text to audio and queue audio chunks to be sent to model. + + Args: + text: Text to convert to speech and send as audio. + + Raises: + ValueError: If audio generator is not available. + """ + if not self.audio_generator: + raise ValueError( + "Audio generator not available. Pass audio_generator to BidirectionalTestContext." + ) + + # Generate audio via Polly + audio_data = await self.audio_generator.generate_audio(text) + + # Split into chunks and queue each chunk + for i in range(0, len(audio_data), self.audio_chunk_size): + chunk = audio_data[i : i + self.audio_chunk_size] + chunk_event = self.audio_generator.create_audio_input_event(chunk) + await self.input_queue.put({"type": "audio_chunk", "data": chunk_event}) + + logger.debug(f"Queued {len(audio_data)} bytes of audio for: {text[:50]}...") + + async def send(self, data: str | dict) -> None: + """Send data directly to model (text, image, etc.). + + Args: + data: Data to send to model. Can be: + - str: Text input + - dict: Custom event (e.g., image, audio) + """ + await self.input_queue.put({"type": "direct", "data": data}) + logger.debug(f"Queued direct send: {type(data).__name__}") + + async def wait_for_response( + self, + timeout: float = 15.0, + silence_threshold: float = 2.0, + min_events: int = 1, + ): + """Wait for model to finish responding. + + Uses silence detection (no events for silence_threshold seconds) + combined with minimum event count to determine response completion. + + Args: + timeout: Maximum time to wait in seconds. + silence_threshold: Seconds of silence to consider response complete. + min_events: Minimum events before silence detection activates. + """ + start_time = time.monotonic() + initial_event_count = len(self.get_events()) # Drain queue + + while time.monotonic() - start_time < timeout: + # Drain queue to get latest events + current_events = self.get_events() + + # Check if we have minimum events + if len(current_events) - initial_event_count >= min_events: + # Check silence + elapsed_since_event = time.monotonic() - self.last_event_time + if elapsed_since_event >= silence_threshold: + logger.debug( + f"Response complete: {len(current_events) - initial_event_count} events, " + f"{elapsed_since_event:.1f}s silence" + ) + return + + await asyncio.sleep(WAIT_POLL_INTERVAL) + + logger.warning(f"Response timeout after {timeout}s") + + def get_events(self, event_type: str | None = None) -> list[dict]: + """Get collected events, optionally filtered by type. + + Drains the event queue and caches events for subsequent calls. + + Args: + event_type: Optional event type to filter by (e.g., "textOutput"). + + Returns: + List of events, filtered if event_type specified. + """ + # Drain queue into cache (non-blocking) + while not self._event_queue.empty(): + try: + event = self._event_queue.get_nowait() + self.events.append(event) + self.last_event_time = time.monotonic() + except asyncio.QueueEmpty: + break + + if event_type: + return [e for e in self.events if event_type in e] + return self.events.copy() + + def get_text_outputs(self) -> list[str]: + """Extract text outputs from collected events. + + Handles both new TypedEvent format and legacy event formats. + + Returns: + List of text content strings. + """ + texts = [] + for event in self.get_events(): # Drain queue first + # Handle new TypedEvent format (bidi_transcript_stream) + if event.get("type") == "bidi_transcript_stream": + text = event.get("text", "") + if text: + texts.append(text) + # Handle legacy textOutput events (Nova Sonic, OpenAI) + elif "textOutput" in event: + text = event["textOutput"].get("text", "") + if text: + texts.append(text) + # Handle legacy transcript events (Gemini Live) + elif "transcript" in event: + text = event["transcript"].get("text", "") + if text: + texts.append(text) + return texts + + def get_audio_outputs(self) -> list[bytes]: + """Extract audio outputs from collected events. + + Returns: + List of audio data bytes. + """ + # Drain queue first to get latest events + events = self.get_events() + audio_data = [] + for event in events: + # Handle new TypedEvent format (bidi_audio_stream) + if event.get("type") == "bidi_audio_stream": + audio_b64 = event.get("audio") + if audio_b64: + # Decode base64 to bytes + audio_data.append(base64.b64decode(audio_b64)) + # Handle legacy audioOutput events + elif "audioOutput" in event: + data = event["audioOutput"].get("audioData") + if data: + audio_data.append(data) + return audio_data + + def get_tool_uses(self) -> list[dict]: + """Extract tool use events from collected events. + + Returns: + List of tool use events. + """ + # Drain queue first to get latest events + events = self.get_events() + return [event["toolUse"] for event in events if "toolUse" in event] + + def has_interruption(self) -> bool: + """Check if any interruption was detected. + + Returns: + True if interruption detected in events. + """ + return any("interruptionDetected" in event for event in self.events) + + def clear_events(self): + """Clear collected events (useful for multi-turn tests).""" + self.events.clear() + logger.debug("Events cleared") + + # === Background threads === + + async def _input_thread(self): + """Continuously handle input to model. + + - Sends queued audio chunks immediately + - Sends silence chunks periodically when queue is empty (simulates microphone) + - Sends direct data to model + """ + try: + logger.debug(f"Input thread starting, active={self.active}") + while self.active: + try: + # Check for queued input (non-blocking with short timeout) + input_item = await asyncio.wait_for(self.input_queue.get(), timeout=QUEUE_POLL_TIMEOUT) + + if input_item["type"] == "audio_chunk": + # Send pre-generated audio chunk + await self.agent.send(input_item["data"]) + await asyncio.sleep(AUDIO_CHUNK_DELAY) + + elif input_item["type"] == "direct": + # Send data directly to agent + await self.agent.send(input_item["data"]) + data_repr = str(input_item["data"])[:50] if isinstance(input_item["data"], str) else type(input_item["data"]).__name__ + logger.debug(f"Sent direct: {data_repr}") + + except asyncio.TimeoutError: + # No input queued - send silence chunk to simulate continuous microphone input + if self.audio_generator: + silence = self._generate_silence_chunk() + await self.agent.send(silence) + await asyncio.sleep(SILENCE_INTERVAL) + + except asyncio.CancelledError: + logger.debug("Input thread cancelled") + raise # Re-raise to properly propagate cancellation + except Exception as e: + logger.error(f"Input thread error: {e}", exc_info=True) + finally: + logger.debug(f"Input thread stopped, active={self.active}") + + async def _event_collection_thread(self): + """Continuously collect events from model.""" + try: + async for event in self.agent.receive(): + if not self.active: + break + + # Thread-safe: put in queue instead of direct append + await self._event_queue.put(event) + logger.debug(f"Event collected: {list(event.keys())}") + + except asyncio.CancelledError: + logger.debug("Event collection thread cancelled") + raise # Re-raise to properly propagate cancellation + except Exception as e: + logger.error(f"Event collection thread error: {e}") + + def _generate_silence_chunk(self) -> dict: + """Generate silence chunk for background audio. + + Returns: + BidiAudioInputEvent with silence data. + """ + silence = b"\x00" * self.silence_chunk_size + return self.audio_generator.create_audio_input_event(silence) diff --git a/tests_integ/bidi/generators/__init__.py b/tests_integ/bidi/generators/__init__.py new file mode 100644 index 000000000..1f13f0564 --- /dev/null +++ b/tests_integ/bidi/generators/__init__.py @@ -0,0 +1 @@ +"""Test data generators for bidirectional streaming integration tests.""" diff --git a/tests_integ/bidi/generators/audio.py b/tests_integ/bidi/generators/audio.py new file mode 100644 index 000000000..75c17a1e3 --- /dev/null +++ b/tests_integ/bidi/generators/audio.py @@ -0,0 +1,159 @@ +"""Audio generation utilities using Amazon Polly for test audio input. + +Provides text-to-speech conversion for generating realistic audio test data +without requiring physical audio devices or pre-recorded files. +""" + +import base64 +import hashlib +import logging +from pathlib import Path +from typing import Literal + +import boto3 + +logger = logging.getLogger(__name__) + +# Audio format constants matching Nova Sonic requirements +NOVA_SONIC_SAMPLE_RATE = 16000 +NOVA_SONIC_CHANNELS = 1 +NOVA_SONIC_FORMAT = "pcm" + +# Polly configuration +POLLY_VOICE_ID = "Matthew" # US English male voice +POLLY_ENGINE = "neural" # Higher quality neural engine + +# Cache directory for generated audio +CACHE_DIR = Path(__file__).parent.parent / ".audio_cache" + + +class AudioGenerator: + """Generate test audio using Amazon Polly with caching.""" + + def __init__(self, region: str = "us-east-1"): + """Initialize audio generator with Polly client. + + Args: + region: AWS region for Polly service. + """ + self.polly_client = boto3.client("polly", region_name=region) + self._ensure_cache_dir() + + def _ensure_cache_dir(self) -> None: + """Create cache directory if it doesn't exist.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + def _get_cache_key(self, text: str, voice_id: str) -> str: + """Generate cache key from text and voice.""" + content = f"{text}:{voice_id}".encode("utf-8") + return hashlib.md5(content).hexdigest() + + def _get_cache_path(self, cache_key: str) -> Path: + """Get cache file path for given key.""" + return CACHE_DIR / f"{cache_key}.pcm" + + async def generate_audio( + self, + text: str, + voice_id: str = POLLY_VOICE_ID, + use_cache: bool = True, + ) -> bytes: + """Generate audio from text using Polly with caching. + + Args: + text: Text to convert to speech. + voice_id: Polly voice ID to use. + use_cache: Whether to use cached audio if available. + + Returns: + Raw PCM audio bytes at 16kHz mono (Nova Sonic format). + """ + # Check cache first + if use_cache: + cache_key = self._get_cache_key(text, voice_id) + cache_path = self._get_cache_path(cache_key) + + if cache_path.exists(): + logger.debug(f"Using cached audio for: {text[:50]}...") + return cache_path.read_bytes() + + # Generate audio with Polly + logger.debug(f"Generating audio with Polly: {text[:50]}...") + + try: + response = self.polly_client.synthesize_speech( + Text=text, + OutputFormat="pcm", # Raw PCM format + VoiceId=voice_id, + Engine=POLLY_ENGINE, + SampleRate=str(NOVA_SONIC_SAMPLE_RATE), + ) + + # Read audio data + audio_data = response["AudioStream"].read() + + # Cache for future use + if use_cache: + cache_path.write_bytes(audio_data) + logger.debug(f"Cached audio: {cache_path}") + + return audio_data + + except Exception as e: + logger.error(f"Polly audio generation failed: {e}") + raise + + def create_audio_input_event( + self, + audio_data: bytes, + format: Literal["pcm", "wav", "opus", "mp3"] = NOVA_SONIC_FORMAT, + sample_rate: int = NOVA_SONIC_SAMPLE_RATE, + channels: int = NOVA_SONIC_CHANNELS, + ) -> dict: + """Create BidiAudioInputEvent from raw audio data. + + Args: + audio_data: Raw audio bytes. + format: Audio format. + sample_rate: Sample rate in Hz. + channels: Number of audio channels. + + Returns: + BidiAudioInputEvent dict ready for agent.send(). + """ + # Convert bytes to base64 string for JSON compatibility + audio_b64 = base64.b64encode(audio_data).decode('utf-8') + + return { + "type": "bidi_audio_input", + "audio": audio_b64, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + + def clear_cache(self) -> None: + """Clear all cached audio files.""" + if CACHE_DIR.exists(): + for cache_file in CACHE_DIR.glob("*.pcm"): + cache_file.unlink() + logger.info("Audio cache cleared") + + +# Convenience function for quick audio generation +async def generate_test_audio(text: str, use_cache: bool = True) -> dict: + """Generate test audio input event from text. + + Convenience function that creates an AudioGenerator and returns + a ready-to-use BidiAudioInputEvent. + + Args: + text: Text to convert to speech. + use_cache: Whether to use cached audio. + + Returns: + BidiAudioInputEvent dict ready for agent.send(). + """ + generator = AudioGenerator() + audio_data = await generator.generate_audio(text, use_cache=use_cache) + return generator.create_audio_input_event(audio_data) diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py new file mode 100644 index 000000000..594379b64 --- /dev/null +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -0,0 +1,220 @@ +"""Parameterized integration tests for bidirectional streaming. + +Tests fundamental functionality across multiple model providers (Nova Sonic, OpenAI, etc.) +including multi-turn conversations, audio I/O, text transcription, and tool execution. + +This demonstrates the provider-agnostic design of the bidirectional streaming system. +""" + +import asyncio +import logging +import os + +import pytest + +from strands import tool +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel +from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel + +from .context import BidirectionalTestContext + +logger = logging.getLogger(__name__) + + +# Simple calculator tool for testing +@tool +def calculator(operation: str, x: float, y: float) -> float: + """Perform basic arithmetic operations. + + Args: + operation: The operation to perform (add, subtract, multiply, divide) + x: First number + y: Second number + + Returns: + Result of the operation + """ + if operation == "add": + return x + y + elif operation == "subtract": + return x - y + elif operation == "multiply": + return x * y + elif operation == "divide": + if y == 0: + raise ValueError("Cannot divide by zero") + return x / y + else: + raise ValueError(f"Unknown operation: {operation}") + + +# Provider configurations +PROVIDER_CONFIGS = { + "nova_sonic": { + "model_class": BidiNovaSonicModel, + "model_kwargs": {"region": "us-east-1"}, + "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence + "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + "skip_reason": "AWS credentials not available", + }, + "openai": { + "model_class": BidiOpenAIRealtimeModel, + "model_kwargs": { + "model": "gpt-4o-realtime-preview-2024-12-17", + "session": { + "output_modalities": ["audio"], # OpenAI only supports audio OR text, not both + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700, + }, + }, + "output": {"format": {"type": "audio/pcm", "rate": 24000}, "voice": "alloy"}, + }, + }, + }, + "silence_duration": 1.0, # OpenAI has faster VAD + "env_vars": ["OPENAI_API_KEY"], + "skip_reason": "OPENAI_API_KEY not available", + }, + "gemini_live": { + "model_class": BidiGeminiLiveModel, + "model_kwargs": { + # Uses default model and config (audio output + transcription enabled) + }, + "silence_duration": 1.5, # Gemini has good VAD, similar to OpenAI + "env_vars": ["GOOGLE_AI_API_KEY"], + "skip_reason": "GOOGLE_AI_API_KEY not available", + }, +} + + +def check_provider_available(provider_name: str) -> tuple[bool, str]: + """Check if a provider's credentials are available. + + Args: + provider_name: Name of the provider to check. + + Returns: + Tuple of (is_available, skip_reason). + """ + config = PROVIDER_CONFIGS[provider_name] + env_vars = config["env_vars"] + + missing_vars = [var for var in env_vars if not os.getenv(var)] + + if missing_vars: + return False, f"{config['skip_reason']}: {', '.join(missing_vars)}" + + return True, "" + + +@pytest.fixture(params=list(PROVIDER_CONFIGS.keys())) +def provider_config(request): + """Provide configuration for each model provider. + + This fixture is parameterized to run tests against all available providers. + """ + provider_name = request.param + config = PROVIDER_CONFIGS[provider_name] + + # Check if provider is available + is_available, skip_reason = check_provider_available(provider_name) + if not is_available: + pytest.skip(skip_reason) + + return { + "name": provider_name, + **config, + } + + +@pytest.fixture +def agent_with_calculator(provider_config): + """Provide bidirectional agent with calculator tool for the given provider. + + Note: Session lifecycle (start/end) is handled by BidirectionalTestContext. + """ + model_class = provider_config["model_class"] + model_kwargs = provider_config["model_kwargs"] + + model = model_class(**model_kwargs) + return BidiAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant with access to a calculator tool. Keep responses brief.", + ) + +@pytest.mark.asyncio +async def test_bidirectional_agent(agent_with_calculator, audio_generator, provider_config): + """Test multi-turn conversation with follow-up questions across providers. + + This test runs against all configured providers (Nova Sonic, OpenAI, etc.) + to validate provider-agnostic functionality. + + Validates: + - Session lifecycle (start/end via context manager) + - Audio input streaming + - Speech-to-text transcription + - Tool execution (calculator) + - Multi-turn conversation flow + - Text-to-speech audio output + """ + provider_name = provider_config["name"] + silence_duration = provider_config["silence_duration"] + + logger.info(f"Testing provider: {provider_name}") + + async with BidirectionalTestContext(agent_with_calculator, audio_generator) as ctx: + # Turn 1: Simple greeting to test basic audio I/O + await ctx.say("Hello, can you hear me?") + # Wait for silence to trigger provider's VAD/silence detection + await asyncio.sleep(silence_duration) + await ctx.wait_for_response() + + text_outputs_turn1 = ctx.get_text_outputs() + all_text_turn1 = " ".join(text_outputs_turn1).lower() + + # Validate turn 1 - just check we got a response + assert len(text_outputs_turn1) > 0, ( + f"[{provider_name}] No text output received in turn 1" + ) + + logger.info(f"[{provider_name}] ✓ Turn 1 complete: received response") + logger.info(f"[{provider_name}] Response: {text_outputs_turn1[0][:100]}...") + + # Turn 2: Follow-up to test multi-turn conversation + await ctx.say("What's your name?") + # Wait for silence to trigger provider's VAD/silence detection + await asyncio.sleep(silence_duration) + await ctx.wait_for_response() + + text_outputs_turn2 = ctx.get_text_outputs() + + # Validate turn 2 - check we got more responses + assert len(text_outputs_turn2) > len(text_outputs_turn1), ( + f"[{provider_name}] No new text output in turn 2" + ) + + logger.info(f"[{provider_name}] ✓ Turn 2 complete: multi-turn conversation works") + logger.info(f"[{provider_name}] Total responses: {len(text_outputs_turn2)}") + + # Validate full conversation + # Validate audio outputs + audio_outputs = ctx.get_audio_outputs() + assert len(audio_outputs) > 0, f"[{provider_name}] No audio output received" + total_audio_bytes = sum(len(audio) for audio in audio_outputs) + + # Summary + logger.info("=" * 60) + logger.info(f"[{provider_name}] ✓ Multi-turn conversation test PASSED") + logger.info(f" Provider: {provider_name}") + logger.info(f" Total events: {len(ctx.get_events())}") + logger.info(f" Text responses: {len(text_outputs_turn2)}") + logger.info(f" Audio chunks: {len(audio_outputs)} ({total_audio_bytes:,} bytes)") + logger.info("=" * 60) diff --git a/tests_integ/bidi/wrappers/__init__.py b/tests_integ/bidi/wrappers/__init__.py new file mode 100644 index 000000000..6b8a64984 --- /dev/null +++ b/tests_integ/bidi/wrappers/__init__.py @@ -0,0 +1,4 @@ +"""Wrappers for bidirectional streaming integration tests. + +Includes fault injection and other transparent wrappers around real implementations. +""" From fca4a528a7a332a42a75854e515dc8be46cc79fc Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 13:19:05 -0500 Subject: [PATCH 111/242] Update naming of directory and file --- .../bidi/models/bidirectional_model.py | 108 --- .../bidirectional_streaming/__init__.py | 79 -- .../bidirectional_streaming/agent/__init__.py | 5 - .../bidirectional_streaming/agent/agent.py | 447 ----------- .../bidirectional_streaming/agent/loop.py | 162 ---- .../bidirectional_streaming/io/__init__.py | 6 - .../bidirectional_streaming/io/audio.py | 161 ---- .../bidirectional_streaming/io/text.py | 31 - .../models/__init__.py | 13 - .../models/bidirectional_model.py | 108 --- .../models/gemini_live.py | 539 ------------- .../models/novasonic.py | 743 ------------------ .../bidirectional_streaming/models/openai.py | 653 --------------- .../scripts/test_bidi.py | 38 - .../scripts/test_bidi_novasonic.py | 256 ------ .../scripts/test_bidi_openai.py | 324 -------- .../scripts/test_gemini_live.py | 363 --------- .../bidirectional_streaming/types/__init__.py | 57 -- .../bidirectional_streaming/types/agent.py | 10 - .../bidirectional_streaming/types/events.py | 521 ------------ .../bidirectional_streaming/types/io.py | 57 -- .../bidirectional_streaming/__init__.py | 1 - .../models/__init__.py | 1 - .../models/test_gemini_live.py | 487 ------------ .../models/test_novasonic.py | 458 ----------- .../models/test_openai_realtime.py | 538 ------------- .../bidirectional_streaming/types/__init__.py | 1 - .../types/test_events.py | 164 ---- .../bidirectional_streaming/__init__.py | 1 - .../bidirectional_streaming/conftest.py | 28 - .../bidirectional_streaming/context.py | 365 --------- .../generators/__init__.py | 1 - .../generators/audio.py | 159 ---- .../test_bidirectional_agent.py | 220 ------ .../wrappers/__init__.py | 4 - 35 files changed, 7109 deletions(-) delete mode 100644 src/strands/experimental/bidi/models/bidirectional_model.py delete mode 100644 src/strands/experimental/bidirectional_streaming/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/agent/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/agent/agent.py delete mode 100644 src/strands/experimental/bidirectional_streaming/agent/loop.py delete mode 100644 src/strands/experimental/bidirectional_streaming/io/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/io/audio.py delete mode 100644 src/strands/experimental/bidirectional_streaming/io/text.py delete mode 100644 src/strands/experimental/bidirectional_streaming/models/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py delete mode 100644 src/strands/experimental/bidirectional_streaming/models/gemini_live.py delete mode 100644 src/strands/experimental/bidirectional_streaming/models/novasonic.py delete mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py delete mode 100644 src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py delete mode 100644 src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py delete mode 100644 src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py delete mode 100644 src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py delete mode 100644 src/strands/experimental/bidirectional_streaming/types/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/types/agent.py delete mode 100644 src/strands/experimental/bidirectional_streaming/types/events.py delete mode 100644 src/strands/experimental/bidirectional_streaming/types/io.py delete mode 100644 tests/strands/experimental/bidirectional_streaming/__init__.py delete mode 100644 tests/strands/experimental/bidirectional_streaming/models/__init__.py delete mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py delete mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py delete mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py delete mode 100644 tests/strands/experimental/bidirectional_streaming/types/__init__.py delete mode 100644 tests/strands/experimental/bidirectional_streaming/types/test_events.py delete mode 100644 tests_integ/bidirectional_streaming/__init__.py delete mode 100644 tests_integ/bidirectional_streaming/conftest.py delete mode 100644 tests_integ/bidirectional_streaming/context.py delete mode 100644 tests_integ/bidirectional_streaming/generators/__init__.py delete mode 100644 tests_integ/bidirectional_streaming/generators/audio.py delete mode 100644 tests_integ/bidirectional_streaming/test_bidirectional_agent.py delete mode 100644 tests_integ/bidirectional_streaming/wrappers/__init__.py diff --git a/src/strands/experimental/bidi/models/bidirectional_model.py b/src/strands/experimental/bidi/models/bidirectional_model.py deleted file mode 100644 index d3c3aa7ec..000000000 --- a/src/strands/experimental/bidi/models/bidirectional_model.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Bidirectional streaming model interface. - -Defines the abstract interface for models that support real-time bidirectional -communication with persistent connections. Unlike traditional request-response -models, bidirectional models maintain an open connection for streaming audio, -text, and tool interactions. - -Features: -- Persistent connection management with connect/close lifecycle -- Real-time bidirectional communication (send and receive simultaneously) -- Provider-agnostic event normalization -- Support for audio, text, image, and tool result streaming -""" - -import logging -from typing import AsyncIterable, Protocol, Union - -from ....types._events import ToolResultEvent -from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.events import ( - BidiAudioInputEvent, - BidiImageInputEvent, - BidiInputEvent, - BidiOutputEvent, - BidiTextInputEvent, -) - -logger = logging.getLogger(__name__) - - -class BidiModel(Protocol): - """Protocol for bidirectional streaming models. - - This interface defines the contract for models that support persistent streaming - connections with real-time audio and text communication. Implementations handle - provider-specific protocols while exposing a standardized event-based API. - """ - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> None: - """Establish a persistent streaming connection with the model. - - Opens a bidirectional connection that remains active for real-time communication. - The connection supports concurrent sending and receiving of events until explicitly - closed. Must be called before any send() or receive() operations. - - Args: - system_prompt: System instructions to configure model behavior. - tools: Tool specifications that the model can invoke during the conversation. - messages: Initial conversation history to provide context. - **kwargs: Provider-specific configuration options. - """ - ... - - async def stop(self) -> None: - """Close the streaming connection and release resources. - - Terminates the active bidirectional connection and cleans up any associated - resources such as network connections, buffers, or background tasks. After - calling close(), the model instance cannot be used until start() is called again. - """ - ... - - async def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive streaming events from the model. - - Continuously yields events from the model as they arrive over the connection. - Events are normalized to a provider-agnostic format for uniform processing. - This method should be called in a loop or async task to process model responses. - - The stream continues until the connection is closed or an error occurs. - - Yields: - BidiOutputEvent: Standardized event objects containing audio output, - transcripts, tool calls, or control signals. - """ - ... - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Send content to the model over the active connection. - - Transmits user input or tool results to the model during an active streaming - session. Supports multiple content types including text, audio, images, and - tool execution results. Can be called multiple times during a conversation. - - Args: - content: The content to send. Must be one of: - - BidiTextInputEvent: Text message from the user - - BidiAudioInputEvent: Audio data for speech input - - BidiImageInputEvent: Image data for visual understanding - - ToolResultEvent: Result from a tool execution - - Example: - await model.send(BidiTextInputEvent(text="Hello", role="user")) - await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) - await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) - await model.send(ToolResultEvent(tool_result)) - """ - ... diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py deleted file mode 100644 index 033a4bb78..000000000 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Bidirectional streaming package.""" - -# Main components - Primary user interface -from .agent.agent import BidiAgent - -# IO channels - Hardware abstraction -from .io.audio import BidiAudioIO - -# Model interface (for custom implementations) -from .models.bidirectional_model import BidiModel - -# Model providers - What users need to create models -from .models.gemini_live import BidiGeminiLiveModel -from .models.novasonic import BidiNovaSonicModel -from .models.openai import BidiOpenAIRealtimeModel - -# Event types - For type hints and event handling -from .types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionCloseEvent, - BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, - BidiInputEvent, - BidiInterruptionEvent, - ModalityUsage, - BidiUsageEvent, - BidiOutputEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, -) - -# Re-export standard agent events for tool handling -from ...types._events import ( - ToolResultEvent, - ToolStreamEvent, - ToolUseStreamEvent, -) - -__all__ = [ - # Main interface - "BidiAgent", - # IO channels - "BidiAudioIO", - # Model providers - "BidiGeminiLiveModel", - "BidiNovaSonicModel", - "BidiOpenAIRealtimeModel", - - # Input Event types - "BidiTextInputEvent", - "BidiAudioInputEvent", - "BidiImageInputEvent", - "BidiInputEvent", - - # Output Event types - "BidiConnectionStartEvent", - "BidiConnectionCloseEvent", - "BidiResponseStartEvent", - "BidiResponseCompleteEvent", - "BidiAudioStreamEvent", - "BidiTranscriptStreamEvent", - "BidiInterruptionEvent", - "BidiUsageEvent", - "ModalityUsage", - "BidiErrorEvent", - "BidiOutputEvent", - - # Tool Event types (reused from standard agent) - "ToolUseStreamEvent", - "ToolResultEvent", - "ToolStreamEvent", - - # Model interface - "BidiModel", -] diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py deleted file mode 100644 index 564973099..000000000 --- a/src/strands/experimental/bidirectional_streaming/agent/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Bidirectional agent for real-time streaming conversations.""" - -from .agent import BidiAgent - -__all__ = ["BidiAgent"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py deleted file mode 100644 index eab909449..000000000 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ /dev/null @@ -1,447 +0,0 @@ -"""Bidirectional Agent for real-time streaming conversations. - -Provides real-time audio and text interaction through persistent streaming connections. -Unlike traditional request-response patterns, this agent maintains long-running -conversations where users can interrupt, provide additional input, and receive -continuous responses including audio output. - -Key capabilities: -- Persistent conversation connections with concurrent processing -- Real-time audio input/output streaming -- Automatic interruption detection and tool execution -- Event-driven communication with model providers -""" - -import asyncio -import json -import logging -from typing import Any, AsyncIterable - -from .... import _identifier -from ....tools.caller import _ToolCaller -from ....tools.executors import ConcurrentToolExecutor -from ....tools.executors._executor import ToolExecutor -from ....tools.registry import ToolRegistry -from ....tools.watcher import ToolWatcher -from ....types.content import Message, Messages -from ....types.tools import ToolResult, ToolUse, AgentTool - -from .loop import _BidiAgentLoop -from ..models.bidirectional_model import BidiModel -from ..models.novasonic import BidiNovaSonicModel -from ..types.agent import BidiAgentInput -from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent -from ..types.io import BidiInput, BidiOutput -from ...tools import ToolProvider - -logger = logging.getLogger(__name__) - -_DEFAULT_AGENT_NAME = "Strands Agents" -_DEFAULT_AGENT_ID = "default" - - -class BidiAgent: - """Agent for bidirectional streaming conversations. - - Enables real-time audio and text interaction with AI models through persistent - connections. Supports concurrent tool execution and interruption handling. - """ - - def __init__( - self, - model: BidiModel| str | None = None, - tools: list[str| AgentTool| ToolProvider]| None = None, - system_prompt: str | None = None, - messages: Messages | None = None, - record_direct_tool_call: bool = True, - load_tools_from_directory: bool = False, - agent_id: str | None = None, - name: str | None = None, - tool_executor: ToolExecutor | None = None, - description: str | None = None, - **kwargs: Any, - ): - """Initialize bidirectional agent. - - Args: - model: BidiModel instance, string model_id, or None for default detection. - tools: Optional list of tools with flexible format support. - system_prompt: Optional system prompt for conversations. - messages: Optional conversation history to initialize with. - record_direct_tool_call: Whether to record direct tool calls in message history. - load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. - agent_id: Optional ID for the agent, useful for connection management and multi-agent scenarios. - name: Name of the Agent. - tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). - description: Description of what the Agent does. - **kwargs: Additional configuration for future extensibility. - - Raises: - ValueError: If model configuration is invalid. - TypeError: If model type is unsupported. - """ - self.model = ( - BidiNovaSonicModel() - if not model - else BidiNovaSonicModel(model_id=model) - if isinstance(model, str) - else model - ) - self.system_prompt = system_prompt - self.messages = messages or [] - - # Agent identification - self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) - self.name = name or _DEFAULT_AGENT_NAME - self.description = description - - # Tool execution configuration - self.record_direct_tool_call = record_direct_tool_call - self.load_tools_from_directory = load_tools_from_directory - - # Initialize tool registry - self.tool_registry = ToolRegistry() - - if tools is not None: - self.tool_registry.process_tools(tools) - - self.tool_registry.initialize_tools(self.load_tools_from_directory) - - # Initialize tool watcher if directory loading is enabled - if self.load_tools_from_directory: - self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) - - # Initialize tool executor - self.tool_executor = tool_executor or ConcurrentToolExecutor() - - # Initialize other components - self._tool_caller = _ToolCaller(self) - - self._current_adapters = [] # Track adapters for cleanup - - self._loop = _BidiAgentLoop(self) - - @property - def tool(self) -> _ToolCaller: - """Call tool as a function. - - Returns: - ToolCaller for method-style tool execution. - - Example: - ``` - agent = BidiAgent(model=model, tools=[calculator]) - agent.tool.calculator(expression="2+2") - ``` - """ - return self._tool_caller - - @property - def tool_names(self) -> list[str]: - """Get a list of all registered tool names. - - Returns: - Names of all tools available to this agent. - """ - all_tools = self.tool_registry.get_all_tools_config() - return list(all_tools.keys()) - - def _record_tool_execution( - self, - tool: ToolUse, - tool_result: ToolResult, - user_message_override: str | None, - ) -> None: - """Record a tool execution in the message history. - - Creates a sequence of messages that represent the tool execution: - - 1. A user message describing the tool call - 2. An assistant message with the tool use - 3. A user message with the tool result - 4. An assistant message acknowledging the tool call - - Args: - tool: The tool call information. - tool_result: The result returned by the tool. - user_message_override: Optional custom message to include. - """ - # Filter tool input parameters to only include those defined in tool spec - filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) - - # Create user message describing the tool call - input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") - - user_msg_content = [ - {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} - ] - - # Add override message if provided - if user_message_override: - user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) - - # Create filtered tool use for message history - filtered_tool: ToolUse = { - "toolUseId": tool["toolUseId"], - "name": tool["name"], - "input": filtered_input, - } - - # Create the message sequence - user_msg: Message = { - "role": "user", - "content": user_msg_content, - } - tool_use_msg: Message = { - "role": "assistant", - "content": [{"toolUse": filtered_tool}], - } - tool_result_msg: Message = { - "role": "user", - "content": [{"toolResult": tool_result}], - } - assistant_msg: Message = { - "role": "assistant", - "content": [{"text": f"agent.tool.{tool['name']} was called."}], - } - - # Add to message history - self.messages.append(user_msg) - self.messages.append(tool_use_msg) - self.messages.append(tool_result_msg) - self.messages.append(assistant_msg) - - logger.debug("Direct tool call recorded in message history: %s", tool["name"]) - - def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: - """Filter input parameters to only include those defined in the tool specification. - - Args: - tool_name: Name of the tool to get specification for - input_params: Original input parameters - - Returns: - Filtered parameters containing only those defined in tool spec - """ - all_tools_config = self.tool_registry.get_all_tools_config() - tool_spec = all_tools_config.get(tool_name) - - if not tool_spec or "inputSchema" not in tool_spec: - return input_params.copy() - - properties = tool_spec["inputSchema"]["json"]["properties"] - return {k: v for k, v in input_params.items() if k in properties} - - async def start(self) -> None: - """Start a persistent bidirectional conversation connection. - - Initializes the streaming connection and starts background tasks for processing - model events, tool execution, and connection management. - """ - logger.debug("starting agent") - - await self._loop.start() - - async def send(self, input_data: BidiAgentInput) -> None: - """Send input to the model (text, audio, image, or event dict). - - Unified method for sending text, audio, and image input to the model during - an active conversation session. Accepts TypedEvent instances or plain dicts - (e.g., from WebSocket clients) which are automatically reconstructed. - - Args: - input_data: Can be: - - str: Text message from user - - BidiAudioInputEvent: Audio data with format/sample rate - - BidiImageInputEvent: Image data with MIME type - - dict: Event dictionary (will be reconstructed to TypedEvent) - - Raises: - ValueError: If no active session or invalid input type. - - Example: - await agent.send("Hello") - await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...)) - await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"}) - """ - self._validate_active_connection() - - # Handle string input - if isinstance(input_data, str): - # Add user text message to history - user_message: Message = {"role": "user", "content": [{"text": input_data}]} - - self.messages.append(user_message) - - logger.debug("Text sent: %d characters", len(input_data)) - # Create BidiTextInputEvent for send() - text_event = BidiTextInputEvent(text=input_data, role="user") - await self.model.send(text_event) - return - - # Handle BidiInputEvent instances - # Check this before dict since TypedEvent inherits from dict - if isinstance(input_data, BidiInputEvent): - await self.model.send(input_data) - return - - # Handle plain dict - reconstruct TypedEvent for WebSocket integration - if isinstance(input_data, dict) and "type" in input_data: - event_type = input_data["type"] - if event_type == "bidi_text_input": - input_event = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) - elif event_type == "bidi_audio_input": - input_event = BidiAudioInputEvent( - audio=input_data["audio"], - format=input_data["format"], - sample_rate=input_data["sample_rate"], - channels=input_data["channels"] - ) - elif event_type == "bidi_image_input": - input_event = BidiImageInputEvent( - image=input_data["image"], - mime_type=input_data["mime_type"] - ) - else: - raise ValueError(f"Unknown event type: {event_type}") - - # Send the reconstructed TypedEvent - await self.model.send(input_event) - return - - # If we get here, input type is invalid - raise ValueError( - f"Input must be a string, BidiInputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" - ) - - async def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive events from the model including audio, text, and tool calls. - - Yields model output events processed by background tasks including audio output, - text responses, tool calls, and connection updates. - - Yields: - Model and tool call events. - """ - async for event in self._loop.receive(): - yield event - - async def stop(self) -> None: - """End the conversation connection and cleanup all resources. - - Terminates the streaming connection, cancels background tasks, and - closes the connection to the model provider. - """ - await self._loop.stop() - - async def __aenter__(self) -> "BidiAgent": - """Async context manager entry point. - - Automatically starts the bidirectional connection when entering the context. - - Returns: - Self for use in the context. - """ - logger.debug("Entering async context manager - starting connection") - await self.start() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - """Async context manager exit point. - - Automatically ends the connection and cleans up resources including adapters - when exiting the context, regardless of whether an exception occurred. - - Args: - exc_type: Exception type if an exception occurred, None otherwise. - exc_val: Exception value if an exception occurred, None otherwise. - exc_tb: Exception traceback if an exception occurred, None otherwise. - """ - try: - logger.debug("Exiting async context manager - cleaning up adapters and connection") - - # Cleanup adapters if any are currently active - for adapter in self._current_adapters: - if hasattr(adapter, "cleanup"): - try: - adapter.stop() - logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") - except Exception as adapter_error: - logger.warning(f"Error cleaning up adapter: {adapter_error}") - - # Clear current adapters - self._current_adapters = [] - - # Cleanup agent connection - await self.stop() - - except Exception as cleanup_error: - if exc_type is None: - # No original exception, re-raise cleanup error - logger.error("Error during context manager cleanup: %s", cleanup_error) - raise - else: - # Original exception exists, log cleanup error but don't suppress original - logger.error( - "Error during context manager cleanup (suppressed due to original exception): %s", cleanup_error - ) - - @property - def active(self) -> bool: - """True if agent loop started, False otherwise.""" - return self._loop.active - - async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: - """Run the agent using provided IO channels for bidirectional communication. - - Args: - inputs: Input callables to read data from a source - outputs: Output callables to receive events from the agent - - Example: - ```python - audio_io = BidiAudioIO(audio_config={"input_sample_rate": 16000}) - text_io = BidiTextIO() - agent = BidiAgent(model=model, tools=[calculator]) - await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) - ``` - """ - async def run_inputs(): - while self.active: - for input_ in inputs: - event = await input_() - await self.send(event) - - # TODO: Need to make tool result send in Nova provider atomic. Audio input events end up interleaving - # and leading to failures. Adding a sleep here as a temporary solution. - await asyncio.sleep(0.001) - - async def run_outputs(): - async for event in self.receive(): - for output in outputs: - await output(event) - - for input_ in inputs: - await input_.start() - - for output in outputs: - await output.start() - - try: - await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True) - - finally: - for input_ in inputs: - await input_.stop() - - for output in outputs: - await output.stop() - - def _validate_active_connection(self) -> None: - """Validate that an active connection exists. - - Raises: - ValueError: If no active connection. - """ - if not self.active: - raise ValueError("No active conversation. Call start() first or use async context manager.") diff --git a/src/strands/experimental/bidirectional_streaming/agent/loop.py b/src/strands/experimental/bidirectional_streaming/agent/loop.py deleted file mode 100644 index e0bc02ef2..000000000 --- a/src/strands/experimental/bidirectional_streaming/agent/loop.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Agent loop. - -The agent loop handles the events received from the model and executes tools when given a tool use request. -""" - -import asyncio -import logging -from typing import AsyncIterable, Awaitable, TYPE_CHECKING - -from ..types.events import BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent -from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent -from ....types.content import Message -from ....types.tools import ToolResult, ToolUse - -if TYPE_CHECKING: - from .agent import BidiAgent - -logger = logging.getLogger(__name__) - - -class _BidiAgentLoop: - """Agent loop.""" - - def __init__(self, agent: "BidiAgent") -> None: - """Initialize members of the agent loop. - - Note, before receiving events from the loop, the user must call `start`. - - Args: - agent: Bidirectional agent to loop over. - """ - self._agent = agent - self._event_queue = asyncio.Queue() # queue model and tool call events - self._tasks = set() # track active async tasks created in loop - self._active = False # flag if agent loop is started - - async def start(self) -> None: - """Start the agent loop. - - The agent model is started as part of this call. - """ - if self.active: - return - - logger.debug("starting agent loop") - - await self._agent.model.start( - system_prompt=self._agent.system_prompt, - tools=self._agent.tool_registry.get_all_tool_specs(), - messages=self._agent.messages, - ) - - self._create_task(self._run_model()) - - self._active = True - - async def stop(self) -> None: - """Stop the agent loop.""" - if not self.active: - return - - logger.debug("stopping agent loop") - - for task in self._tasks: - task.cancel() - - await asyncio.gather(*self._tasks, return_exceptions=True) - - await self._agent.model.stop() - - self._active = False - - async def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive model and tool call events.""" - while self.active: - try: - yield self._event_queue.get_nowait() - except asyncio.QueueEmpty: - pass - - # unblock the event loop - await asyncio.sleep(0) - - @property - def active(self) -> bool: - """True if agent loop started, False otherwise.""" - return self._active - - def _create_task(self, coro: Awaitable[None]) -> None: - """Utilitly to create async task. - - Adds a clean up callback to run after task completes. - """ - task = asyncio.create_task(coro) - task.add_done_callback(lambda task: self._tasks.remove(task)) - - self._tasks.add(task) - - async def _run_model(self) -> None: - """Task for running the model. - - Events are streamed through the event queue. - """ - logger.debug("running model") - - async for event in self._agent.model.receive(): - if not self.active: - break - - self._event_queue.put_nowait(event) - - if isinstance(event, BidiTranscriptStreamEvent): - if event["is_final"]: - message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} - self._agent.messages.append(message) - - elif isinstance(event, ToolUseStreamEvent): - self._create_task(self._run_tool(event["current_tool_use"])) - - elif isinstance(event, BidiInterruptionEvent): - # clear the audio - for _ in range(self._event_queue.qsize()): - event = self._event_queue.get_nowait() - if not isinstance(event, BidiAudioStreamEvent): - self._event_queue.put_nowait(event) - - async def _run_tool(self, tool_use: ToolUse) -> None: - """Task for running tool requested by the model.""" - logger.debug("running tool") - - result: ToolResult = None - - try: - tool = self._agent.tool_registry.registry[tool_use["name"]] - invocation_state = {} - - async for event in tool.stream(tool_use, invocation_state): - if isinstance(event, ToolResultEvent): - self._event_queue.put_nowait(event) - result = event.tool_result - break - - if isinstance(event, ToolStreamEvent): - self._event_queue.put_nowait(event) - else: - self._event_queue.put_nowait(ToolStreamEvent(tool_use, event)) - - except Exception as e: - result = { - "toolUseId": tool_use["toolUseId"], - "status": "error", - "content": [{"text": f"Error: {str(e)}"}] - } - - await self._agent.model.send(ToolResultEvent(result)) - - message: Message = { - "role": "user", - "content": [{"toolResult": result}], - } - self._agent.messages.append(message) - self._event_queue.put_nowait(ToolResultMessageEvent(message)) diff --git a/src/strands/experimental/bidirectional_streaming/io/__init__.py b/src/strands/experimental/bidirectional_streaming/io/__init__.py deleted file mode 100644 index d099cba2f..000000000 --- a/src/strands/experimental/bidirectional_streaming/io/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""IO channel implementations for bidirectional streaming.""" - -from .audio import BidiAudioIO -from .text import BidiTextIO - -__all__ = ["BidiAudioIO", "BidiTextIO"] diff --git a/src/strands/experimental/bidirectional_streaming/io/audio.py b/src/strands/experimental/bidirectional_streaming/io/audio.py deleted file mode 100644 index 2ec167480..000000000 --- a/src/strands/experimental/bidirectional_streaming/io/audio.py +++ /dev/null @@ -1,161 +0,0 @@ -"""AudioIO - Clean separation of audio functionality from core BidiAgent. - -Provides audio input/output capabilities for BidiAgent through the BidiIO protocol. -Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. -""" - -import asyncio -import base64 -import logging - -import pyaudio - -from ..types.io import BidiInput, BidiOutput -from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiOutputEvent - -logger = logging.getLogger(__name__) - - -class _BidiAudioInput(BidiInput): - "Handle audio input from bidi agent." - def __init__(self, audio: "BidiAudioIO") -> None: - """Store reference to pyaudio instance.""" - self.audio = audio - - async def start(self) -> None: - """Start audio input.""" - self.audio._start() - - async def stop(self) -> None: - """Stop audio input.""" - self.audio._stop() - - async def __call__(self) -> BidiAudioInputEvent: - """Read audio from microphone.""" - audio_bytes = self.audio.input_stream.read(self.audio.chunk_size, exception_on_overflow=False) - - return BidiAudioInputEvent( - audio=base64.b64encode(audio_bytes).decode("utf-8"), - format="pcm", - sample_rate=self.audio.input_sample_rate, - channels=self.audio.input_channels, - ) - - -class _BidiAudioOutput(BidiOutput): - "Handle audio output from bidi agent." - def __init__(self, audio: "BidiAudioIO") -> None: - """Store reference to pyaudio instance.""" - self.audio = audio - - async def start(self) -> None: - """Start audio output.""" - self.audio._start() - - async def stop(self) -> None: - """Stop audio output.""" - self.audio._stop() - - async def __call__(self, event: BidiOutputEvent) -> None: - """Handle audio events with direct stream writing.""" - if isinstance(event, BidiAudioStreamEvent): - self.audio.output_stream.write(base64.b64decode(event["audio"])) - - # TODO: Outputing audio to speakers is a sync operation. Adding sleep to prevent event loop hogging. Will - # follow up on identifying a cleaner approach. - await asyncio.sleep(0.01) - - -class BidiAudioIO: - """Audio IO channel for BidiAgent with direct stream processing.""" - - def __init__( - self, - audio_config: dict | None = None, - ): - """Initialize AudioIO with clean audio configuration. - - Args: - audio_config: Dictionary containing audio configuration: - - input_sample_rate (int): Microphone sample rate (default: 24000) - - output_sample_rate (int): Speaker sample rate (default: 24000) - - chunk_size (int): Audio chunk size in bytes (default: 1024) - - input_device_index (int): Specific input device (optional) - - output_device_index (int): Specific output device (optional) - - input_channels (int): Input channels (default: 1) - - output_channels (int): Output channels (default: 1) - """ - default_config = { - "input_sample_rate": 16000, - "output_sample_rate": 16000, - "chunk_size": 512, - "input_device_index": None, - "output_device_index": None, - "input_channels": 1, - "output_channels": 1, - } - - # Merge user config with defaults - if audio_config: - default_config.update(audio_config) - - # Set audio configuration attributes - self.input_sample_rate = default_config["input_sample_rate"] - self.output_sample_rate = default_config["output_sample_rate"] - self.chunk_size = default_config["chunk_size"] - self.input_device_index = default_config["input_device_index"] - self.output_device_index = default_config["output_device_index"] - self.input_channels = default_config["input_channels"] - self.output_channels = default_config["output_channels"] - - # Audio infrastructure - self.audio = None - self.input_stream = None - self.output_stream = None - self.interrupted = False - - def input(self) -> _BidiAudioInput: - "Return audio processing BidiInput" - return _BidiAudioInput(self) - - def output(self) -> _BidiAudioOutput: - "Return audio processing BidiOutput" - return _BidiAudioOutput(self) - - def _start(self) -> None: - """Setup PyAudio streams for input and output.""" - if self.audio: - return - - self.audio = pyaudio.PyAudio() - - self.input_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.input_channels, - rate=self.input_sample_rate, - input=True, - frames_per_buffer=self.chunk_size, - input_device_index=self.input_device_index, - ) - - self.output_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.output_channels, - rate=self.output_sample_rate, - output=True, - frames_per_buffer=self.chunk_size, - output_device_index=self.output_device_index, - ) - - def _stop(self) -> None: - """Clean up IO channel resources.""" - if not self.audio: - return - - self.input_stream.close() - self.output_stream.close() - self.audio.terminate() - - self.input_stream = None - self.output_stream = None - self.audio = None diff --git a/src/strands/experimental/bidirectional_streaming/io/text.py b/src/strands/experimental/bidirectional_streaming/io/text.py deleted file mode 100644 index ba503f4e4..000000000 --- a/src/strands/experimental/bidirectional_streaming/io/text.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Handle text input and output from bidi agent.""" - -import logging - -from ..types.io import BidiOutput -from ..types.events import BidiOutputEvent, BidiInterruptionEvent, BidiTranscriptStreamEvent - -logger = logging.getLogger(__name__) - - -class _BidiTextOutput(BidiOutput): - "Handle text output from bidi agent." - async def __call__(self, event: BidiOutputEvent) -> None: - """Print text events to stdout.""" - - if isinstance(event, BidiInterruptionEvent): - print("interrupted") - - elif isinstance(event, BidiTranscriptStreamEvent): - text = event["text"] - if not event["is_final"]: - text = f"Preview: {text}" - - print(text) - - -class BidiTextIO: - "Handle text input and output from bidi agent." - def output(self) -> _BidiTextOutput: - "Return text processing BidiOutput" - return _BidiTextOutput() diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py deleted file mode 100644 index 6d6d6590b..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Bidirectional model interfaces and implementations.""" - -from .bidirectional_model import BidiModel -from .gemini_live import BidiGeminiLiveModel -from .novasonic import BidiNovaSonicModel -from .openai import BidiOpenAIRealtimeModel - -__all__ = [ - "BidiModel", - "BidiGeminiLiveModel", - "BidiNovaSonicModel", - "BidiOpenAIRealtimeModel", -] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py deleted file mode 100644 index d3c3aa7ec..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Bidirectional streaming model interface. - -Defines the abstract interface for models that support real-time bidirectional -communication with persistent connections. Unlike traditional request-response -models, bidirectional models maintain an open connection for streaming audio, -text, and tool interactions. - -Features: -- Persistent connection management with connect/close lifecycle -- Real-time bidirectional communication (send and receive simultaneously) -- Provider-agnostic event normalization -- Support for audio, text, image, and tool result streaming -""" - -import logging -from typing import AsyncIterable, Protocol, Union - -from ....types._events import ToolResultEvent -from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.events import ( - BidiAudioInputEvent, - BidiImageInputEvent, - BidiInputEvent, - BidiOutputEvent, - BidiTextInputEvent, -) - -logger = logging.getLogger(__name__) - - -class BidiModel(Protocol): - """Protocol for bidirectional streaming models. - - This interface defines the contract for models that support persistent streaming - connections with real-time audio and text communication. Implementations handle - provider-specific protocols while exposing a standardized event-based API. - """ - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> None: - """Establish a persistent streaming connection with the model. - - Opens a bidirectional connection that remains active for real-time communication. - The connection supports concurrent sending and receiving of events until explicitly - closed. Must be called before any send() or receive() operations. - - Args: - system_prompt: System instructions to configure model behavior. - tools: Tool specifications that the model can invoke during the conversation. - messages: Initial conversation history to provide context. - **kwargs: Provider-specific configuration options. - """ - ... - - async def stop(self) -> None: - """Close the streaming connection and release resources. - - Terminates the active bidirectional connection and cleans up any associated - resources such as network connections, buffers, or background tasks. After - calling close(), the model instance cannot be used until start() is called again. - """ - ... - - async def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive streaming events from the model. - - Continuously yields events from the model as they arrive over the connection. - Events are normalized to a provider-agnostic format for uniform processing. - This method should be called in a loop or async task to process model responses. - - The stream continues until the connection is closed or an error occurs. - - Yields: - BidiOutputEvent: Standardized event objects containing audio output, - transcripts, tool calls, or control signals. - """ - ... - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Send content to the model over the active connection. - - Transmits user input or tool results to the model during an active streaming - session. Supports multiple content types including text, audio, images, and - tool execution results. Can be called multiple times during a conversation. - - Args: - content: The content to send. Must be one of: - - BidiTextInputEvent: Text message from the user - - BidiAudioInputEvent: Audio data for speech input - - BidiImageInputEvent: Image data for visual understanding - - ToolResultEvent: Result from a tool execution - - Example: - await model.send(BidiTextInputEvent(text="Hello", role="user")) - await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) - await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) - await model.send(ToolResultEvent(tool_result)) - """ - ... diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py deleted file mode 100644 index 9bb5bba77..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Gemini Live API bidirectional model provider using official Google GenAI SDK. - -Implements the BidiModel interface for Google's Gemini Live API using the -official Google GenAI SDK for simplified and robust WebSocket communication. - -Key improvements over custom WebSocket implementation: -- Uses official google-genai SDK with native Live API support -- Simplified session management with client.aio.live.connect() -- Built-in tool integration and event handling -- Automatic WebSocket connection management and error handling -- Native support for audio/text streaming and interruption -""" - -import asyncio -import base64 -import logging -import uuid -from typing import Any, AsyncIterable, Dict, List, Optional, Union - -from google import genai -from google.genai import types as genai_types -from google.genai.types import LiveServerMessage, LiveServerContent - -from ....types.content import Messages -from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ..types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionCloseEvent, - BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, - BidiInputEvent, - BidiInterruptionEvent, - BidiOutputEvent, - BidiUsageEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, -) -from .bidirectional_model import BidiModel - -logger = logging.getLogger(__name__) - -# Audio format constants -GEMINI_INPUT_SAMPLE_RATE = 16000 -GEMINI_OUTPUT_SAMPLE_RATE = 24000 -GEMINI_CHANNELS = 1 - - -class BidiGeminiLiveModel(BidiModel): - """Gemini Live API implementation using official Google GenAI SDK. - - Combines model configuration and connection state in a single class. - Provides a clean interface to Gemini Live API using the official SDK, - eliminating custom WebSocket handling and providing robust error handling. - """ - - def __init__( - self, - model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", - api_key: Optional[str] = None, - live_config: Optional[Dict[str, Any]] = None, - **kwargs - ): - """Initialize Gemini Live API bidirectional model. - - Args: - model_id: Gemini Live model identifier. - api_key: Google AI API key for authentication. - live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). - **kwargs: Reserved for future parameters. - """ - # Model configuration - self.model_id = model_id - self.api_key = api_key - - # Set default live_config with transcription enabled - default_config = { - "response_modalities": ["AUDIO"], - "outputAudioTranscription": {}, # Enable output transcription by default - "inputAudioTranscription": {} # Enable input transcription by default - } - - # Merge user config with defaults (user config takes precedence) - if live_config: - default_config.update(live_config) - - self.live_config = default_config - - # Create Gemini client with proper API version - client_kwargs = {} - if api_key: - client_kwargs["api_key"] = api_key - - # Use v1alpha for Live API as it has better model support - client_kwargs["http_options"] = {"api_version": "v1alpha"} - - self.client = genai.Client(**client_kwargs) - - # Connection state (initialized in start()) - self.live_session = None - self.live_session_context_manager = None - self.connection_id = None - self._active = False - - async def start( - self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, - **kwargs - ) -> None: - """Establish bidirectional connection with Gemini Live API. - - Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. - **kwargs: Additional configuration options. - """ - if self._active: - raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - - try: - # Initialize connection state - self.connection_id = str(uuid.uuid4()) - self._active = True - - # Build live config - live_config = self._build_live_config(system_prompt, tools, **kwargs) - - # Create the context manager - self.live_session_context_manager = self.client.aio.live.connect( - model=self.model_id, - config=live_config - ) - - # Enter the context manager - self.live_session = await self.live_session_context_manager.__aenter__() - - # Send initial message history if provided - if messages: - await self._send_message_history(messages) - - except Exception as e: - self._active = False - logger.error("Error connecting to Gemini Live: %s", e) - raise - - async def _send_message_history(self, messages: Messages) -> None: - """Send conversation history to Gemini Live API. - - Sends each message as a separate turn with the correct role to maintain - proper conversation context. Follows the same pattern as the non-bidirectional - Gemini model implementation. - """ - if not messages: - return - - # Convert each message to Gemini format and send separately - for message in messages: - content_parts = [] - for content_block in message["content"]: - if "text" in content_block: - content_parts.append(genai_types.Part(text=content_block["text"])) - - if content_parts: - # Map role correctly - Gemini uses "user" and "model" roles - # "assistant" role from Messages format maps to "model" in Gemini - role = "model" if message["role"] == "assistant" else message["role"] - content = genai_types.Content(role=role, parts=content_parts) - await self.live_session.send_client_content(turns=content) - - async def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive Gemini Live API events and convert to provider-agnostic format.""" - - # Emit connection start event - yield BidiConnectionStartEvent( - connection_id=self.connection_id, - model=self.model_id - ) - - try: - # Wrap in while loop to restart after turn_complete (SDK limitation workaround) - while self._active: - try: - async for message in self.live_session.receive(): - if not self._active: - break - - # Convert to provider-agnostic format (always returns list) - for event in self._convert_gemini_live_event(message): - yield event - - # SDK exits receive loop after turn_complete - restart automatically - if self._active: - logger.debug("Restarting receive loop after turn completion") - - except Exception as e: - logger.error("Error in receive iteration: %s", e) - # Small delay before retrying to avoid tight error loops - await asyncio.sleep(0.1) - - except Exception as e: - logger.error("Fatal error in receive loop: %s", e) - yield BidiErrorEvent(error=e) - finally: - # Emit connection close event when exiting - yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") - - def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOutputEvent]: - """Convert Gemini Live API events to provider-agnostic format. - - Handles different types of content: - - inputTranscription: User's speech transcribed to text - - outputTranscription: Model's audio transcribed to text - - modelTurn text: Text response from the model - - usageMetadata: Token usage information - - Returns: - List of event dicts (empty list if no events to emit). - """ - try: - # Handle interruption first (from server_content) - if message.server_content and message.server_content.interrupted: - return [BidiInterruptionEvent(reason="user_speech")] - - # Handle input transcription (user's speech) - emit as transcript event - if message.server_content and message.server_content.input_transcription: - input_transcript = message.server_content.input_transcription - # Check if the transcription object has text content - if hasattr(input_transcript, 'text') and input_transcript.text: - transcription_text = input_transcript.text - role = getattr(input_transcript, 'role', 'user') - logger.debug(f"Input transcription detected: {transcription_text}") - return [BidiTranscriptStreamEvent( - delta={"text": transcription_text}, - text=transcription_text, - role=role.lower() if isinstance(role, str) else "user", - is_final=True, - current_transcript=transcription_text - )] - - # Handle output transcription (model's audio) - emit as transcript event - if message.server_content and message.server_content.output_transcription: - output_transcript = message.server_content.output_transcription - # Check if the transcription object has text content - if hasattr(output_transcript, 'text') and output_transcript.text: - transcription_text = output_transcript.text - role = getattr(output_transcript, 'role', 'assistant') - logger.debug(f"Output transcription detected: {transcription_text}") - return [BidiTranscriptStreamEvent( - delta={"text": transcription_text}, - text=transcription_text, - role=role.lower() if isinstance(role, str) else "assistant", - is_final=True, - current_transcript=transcription_text - )] - - # Handle audio output using SDK's built-in data property - # Check this BEFORE text to avoid triggering warning on mixed content - if message.data: - # Convert bytes to base64 string for JSON serializability - audio_b64 = base64.b64encode(message.data).decode('utf-8') - return [BidiAudioStreamEvent( - audio=audio_b64, - format="pcm", - sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, - channels=GEMINI_CHANNELS - )] - - # Handle text output from model_turn (avoids warning by checking parts directly) - if message.server_content and message.server_content.model_turn: - model_turn = message.server_content.model_turn - if model_turn.parts: - # Concatenate all text parts (Gemini may send multiple parts) - text_parts = [] - for part in model_turn.parts: - # Log all part types for debugging - part_attrs = {attr: getattr(part, attr, None) for attr in dir(part) if not attr.startswith('_')} - - # Check if part has text attribute and it's not empty - if hasattr(part, 'text') and part.text: - text_parts.append(part.text) - - if text_parts: - full_text = " ".join(text_parts) - return [BidiTranscriptStreamEvent( - delta={"text": full_text}, - text=full_text, - role="assistant", - is_final=True, - current_transcript=full_text - )] - - # Handle tool calls - return list to support multiple tool calls - if message.tool_call and message.tool_call.function_calls: - tool_events = [] - for func_call in message.tool_call.function_calls: - tool_use_event: ToolUse = { - "toolUseId": func_call.id, - "name": func_call.name, - "input": func_call.args or {} - } - # Create ToolUseStreamEvent for consistency with standard agent - tool_events.append(ToolUseStreamEvent( - delta={"toolUse": tool_use_event}, - current_tool_use=tool_use_event - )) - return tool_events - - # Handle usage metadata - if hasattr(message, 'usage_metadata') and message.usage_metadata: - usage = message.usage_metadata - - # Build modality details from token details - modality_details = [] - - # Process prompt tokens details - if usage.prompt_tokens_details: - for detail in usage.prompt_tokens_details: - if detail.modality and detail.token_count: - modality_details.append({ - "modality": str(detail.modality).lower(), - "input_tokens": detail.token_count, - "output_tokens": 0 - }) - - # Process response tokens details - if usage.response_tokens_details: - for detail in usage.response_tokens_details: - if detail.modality and detail.token_count: - # Find or create modality entry - modality_str = str(detail.modality).lower() - existing = next((m for m in modality_details if m["modality"] == modality_str), None) - if existing: - existing["output_tokens"] = detail.token_count - else: - modality_details.append({ - "modality": modality_str, - "input_tokens": 0, - "output_tokens": detail.token_count - }) - - return [BidiUsageEvent( - input_tokens=usage.prompt_token_count or 0, - output_tokens=usage.response_token_count or 0, - total_tokens=usage.total_token_count or 0, - modality_details=modality_details if modality_details else None, - cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None - )] - - # Silently ignore setup_complete and generation_complete messages - return [] - - except Exception as e: - logger.error("Error converting Gemini Live event: %s", e) - logger.error("Message type: %s", type(message).__name__) - logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) - # Return ErrorEvent in list so caller can handle it - return [BidiErrorEvent(error=e)] - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Unified send method for all content types. Sends the given inputs to Google Live API - - Dispatches to appropriate internal handler based on content type. - - Args: - content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). - """ - if not self._active: - return - - try: - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, BidiImageInputEvent): - await self._send_image_content(content) - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - else: - logger.warning(f"Unknown content type: {type(content)}") - except Exception as e: - logger.error(f"Error sending content: {e}") - raise # Propagate exception for debugging in experimental code - - async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: - """Internal: Send audio content using Gemini Live API. - - Gemini Live expects continuous audio streaming via send_realtime_input. - This automatically triggers VAD and can interrupt ongoing responses. - """ - try: - # Decode base64 audio to bytes for SDK - audio_bytes = base64.b64decode(audio_input.audio) - - # Create audio blob for the SDK - audio_blob = genai_types.Blob( - data=audio_bytes, - mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" - ) - - # Send real-time audio input - this automatically handles VAD and interruption - await self.live_session.send_realtime_input(audio=audio_blob) - - except Exception as e: - logger.error("Error sending audio content: %s", e) - - async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: - """Internal: Send image content using Gemini Live API. - - Sends image frames following the same pattern as the GitHub example. - Images are sent as base64-encoded data with MIME type. - """ - try: - # Image is already base64 encoded in the event - msg = { - "mime_type": image_input.mime_type, - "data": image_input.image - } - - # Send using the same method as the GitHub example - await self.live_session.send(input=msg) - - except Exception as e: - logger.error("Error sending image content: %s", e) - - async def _send_text_content(self, text: str) -> None: - """Internal: Send text content using Gemini Live API.""" - try: - # Create content with text - content = genai_types.Content( - role="user", - parts=[genai_types.Part(text=text)] - ) - - # Send as client content - await self.live_session.send_client_content(turns=content) - - except Exception as e: - logger.error("Error sending text content: %s", e) - - async def _send_tool_result(self, tool_result: ToolResult) -> None: - """Internal: Send tool result using Gemini Live API.""" - try: - tool_use_id = tool_result.get("toolUseId") - - # Extract result content - result_data = {} - if "content" in tool_result: - # Extract text from content blocks - for block in tool_result["content"]: - if "text" in block: - result_data = {"result": block["text"]} - break - - # Create function response - func_response = genai_types.FunctionResponse( - id=tool_use_id, - name=tool_use_id, # Gemini uses name as identifier - response=result_data - ) - - # Send tool response - await self.live_session.send_tool_response(function_responses=[func_response]) - except Exception as e: - logger.error("Error sending tool result: %s", e) - - async def stop(self) -> None: - """Close Gemini Live API connection.""" - if not self._active: - return - - self._active = False - - try: - # Exit the context manager properly - if self.live_session_context_manager: - await self.live_session_context_manager.__aexit__(None, None, None) - except Exception as e: - logger.error("Error closing Gemini Live connection: %s", e) - raise - - def _build_live_config( - self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - **kwargs - ) -> Dict[str, Any]: - """Build LiveConnectConfig for the official SDK. - - Simply passes through all config parameters from live_config, allowing users - to configure any Gemini Live API parameter directly. - """ - # Start with user-provided live_config - config_dict = {} - if self.live_config: - config_dict.update(self.live_config) - - # Override with any kwargs from start() - config_dict.update(kwargs) - - # Add system instruction if provided - if system_prompt: - config_dict["system_instruction"] = system_prompt - - # Add tools if provided - if tools: - config_dict["tools"] = self._format_tools_for_live_api(tools) - - return config_dict - - def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: - """Format tool specs for Gemini Live API.""" - if not tool_specs: - return [] - - return [ - genai_types.Tool( - function_declarations=[ - genai_types.FunctionDeclaration( - description=tool_spec["description"], - name=tool_spec["name"], - parameters_json_schema=tool_spec["inputSchema"]["json"], - ) - for tool_spec in tool_specs - ], - ), - ] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py deleted file mode 100644 index 8c23aa0da..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ /dev/null @@ -1,743 +0,0 @@ -"""Nova Sonic bidirectional model provider for real-time streaming conversations. - -Implements the BidiModel interface for Amazon's Nova Sonic, handling the -complex event sequencing and audio processing required by Nova Sonic's -InvokeModelWithBidirectionalStream protocol. - -Nova Sonic specifics: -- Hierarchical event sequences: connectionStart → promptStart → content streaming -- Base64-encoded audio format with hex encoding -- Tool execution with content containers and identifier tracking -- 8-minute connection limits with proper cleanup sequences -- Interruption detection through stopReason events -""" - -import asyncio -import base64 -import json -import logging -import traceback -import uuid -from typing import AsyncIterable - -from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput -from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import ( - BidirectionalInputPayloadPart, - InvokeModelWithBidirectionalStreamInputChunk, -) -from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver - -from ....types.content import Messages -from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ..types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionCloseEvent, - BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, - BidiInputEvent, - BidiInterruptionEvent, - BidiUsageEvent, - BidiOutputEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, -) -from .bidirectional_model import BidiModel - -logger = logging.getLogger(__name__) - -# Nova Sonic configuration constants -NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} - -NOVA_AUDIO_INPUT_CONFIG = { - "mediaType": "audio/lpcm", - "sampleRateHertz": 16000, - "sampleSizeBits": 16, - "channelCount": 1, - "audioType": "SPEECH", - "encoding": "base64", -} - -NOVA_AUDIO_OUTPUT_CONFIG = { - "mediaType": "audio/lpcm", - "sampleRateHertz": 16000, - "sampleSizeBits": 16, - "channelCount": 1, - "voiceId": "matthew", - "encoding": "base64", - "audioType": "SPEECH", -} - -NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} -NOVA_TOOL_CONFIG = {"mediaType": "application/json"} - -# Timing constants -EVENT_DELAY = 0.1 -RESPONSE_TIMEOUT = 1.0 - - -class BidiNovaSonicModel(BidiModel): - """Nova Sonic implementation for bidirectional streaming. - - Combines model configuration and connection state in a single class. - Manages Nova Sonic's complex event sequencing, audio format conversion, and - tool execution patterns while providing the standard BidiModel interface. - """ - - def __init__( - self, - model_id: str = "amazon.nova-sonic-v1:0", - region: str = "us-east-1", - **kwargs - ) -> None: - """Initialize Nova Sonic bidirectional model. - - Args: - model_id: Nova Sonic model identifier. - region: AWS region. - **kwargs: Reserved for future parameters. - """ - # Model configuration - self.model_id = model_id - self.region = region - self.client = None - - # Connection state (initialized in start()) - self.stream = None - self.connection_id = None - self._active = False - - # Nova Sonic requires unique content names - self.audio_content_name = None - - # Audio connection state - self.audio_connection_active = False - - # Background task and event queue - self._response_task = None - self._event_queue = None - - # Track API-provided identifiers - self._current_completion_id = None - self._current_role = None - self._generation_stage = None - - logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> None: - """Establish bidirectional connection to Nova Sonic. - - Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. - **kwargs: Additional configuration options. - """ - if self._active: - raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - - logger.debug("Nova connection create - starting") - - try: - # Initialize client if needed - if not self.client: - await self._initialize_client() - - # Initialize connection state - self.connection_id = str(uuid.uuid4()) - self._active = True - self.audio_content_name = str(uuid.uuid4()) - self._event_queue = asyncio.Queue() - - # Start Nova Sonic bidirectional stream - self.stream = await self.client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ) - - # Validate stream - if not self.stream: - logger.error("Stream is None") - raise ValueError("Stream cannot be None") - - logger.debug("Nova Sonic connection initialized with connection_id: %s", self.connection_id) - - # Send initialization events - system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." - init_events = self._build_initialization_events(system_prompt, tools or [], messages) - - logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) - await self._send_initialization_events(init_events) - - # Start background response processor - self._response_task = asyncio.create_task(self._process_responses()) - - logger.info("Nova Sonic connection established successfully") - - except Exception as e: - self._active = False - logger.error("Nova connection create error: %s", str(e)) - raise - - def _build_initialization_events( - self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None - ) -> list[str]: - """Build the sequence of initialization events.""" - events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] - - events.extend(self._get_system_prompt_events(system_prompt)) - - # Message history would be processed here if needed in the future - # Currently not implemented as it's not used in the existing test cases - - return events - - async def _send_initialization_events(self, events: list[str]) -> None: - """Send initialization events with required delays.""" - for _i, event in enumerate(events): - await self._send_nova_event(event) - await asyncio.sleep(EVENT_DELAY) - - async def _process_responses(self) -> None: - """Process Nova Sonic responses continuously.""" - logger.debug("Nova Sonic response processor started") - - try: - while self._active: - try: - output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) - result = await output[1].receive() - - if result.value and result.value.bytes_: - await self._handle_response_data(result.value.bytes_.decode("utf-8")) - - except asyncio.TimeoutError: - await asyncio.sleep(0.1) - continue - except Exception as e: - logger.warning("Nova Sonic response error: %s", e) - await asyncio.sleep(0.1) - continue - - except Exception as e: - logger.error("Nova Sonic fatal error: %s", e) - finally: - logger.debug("Nova Sonic response processor stopped") - - async def _handle_response_data(self, response_data: str) -> None: - """Handle decoded response data from Nova Sonic.""" - try: - json_data = json.loads(response_data) - - if "event" in json_data: - nova_event = json_data["event"] - self._log_event_type(nova_event) - - if not hasattr(self, "_event_queue"): - self._event_queue = asyncio.Queue() - - await self._event_queue.put(nova_event) - except json.JSONDecodeError as e: - logger.warning("Nova Sonic JSON decode error: %s", e) - - def _log_event_type(self, nova_event: dict[str, any]) -> None: - """Log specific Nova Sonic event types for debugging.""" - if "usageEvent" in nova_event: - logger.debug("Nova usage: %s", nova_event["usageEvent"]) - elif "textOutput" in nova_event: - logger.debug("Nova text output") - elif "toolUse" in nova_event: - tool_use = nova_event["toolUse"] - logger.debug("Nova tool use: %s (id: %s)", tool_use["toolName"], tool_use["toolUseId"]) - elif "audioOutput" in nova_event: - audio_content = nova_event["audioOutput"]["content"] - audio_bytes = base64.b64decode(audio_content) - logger.debug("Nova audio output: %d bytes", len(audio_bytes)) - - async def receive(self) -> AsyncIterable[dict[str, any]]: - """Receive Nova Sonic events and convert to provider-agnostic format.""" - if not self.stream: - logger.error("Stream is None") - return - - logger.debug("Nova events - starting event stream") - - # Emit connection start event - yield BidiConnectionStartEvent( - connection_id=self.connection_id, - model=self.model_id - ) - - try: - while self._active: - try: - # Get events from the queue populated by _process_responses - nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - - # Convert to provider-agnostic format - provider_event = self._convert_nova_event(nova_event) - if provider_event: - yield provider_event - - except asyncio.TimeoutError: - # No events in queue - continue waiting - continue - - except Exception as e: - logger.error("Error receiving Nova Sonic event: %s", e) - logger.error(traceback.format_exc()) - yield BidiErrorEvent(error=e) - finally: - # Emit connection close event - yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Unified send method for all content types. Sends the given content to Nova Sonic. - - Dispatches to appropriate internal handler based on content type. - - Args: - content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). - """ - if not self._active: - return - - try: - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, BidiImageInputEvent): - # BidiImageInputEvent - not supported by Nova Sonic - logger.warning("Image input not supported by Nova Sonic") - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - else: - logger.warning(f"Unknown content type: {type(content)}") - except Exception as e: - logger.error(f"Error sending content: {e}") - raise # Propagate exception for debugging in experimental code - - async def _start_audio_connection(self) -> None: - """Internal: Start audio input connection (call once before sending audio chunks).""" - if self.audio_connection_active: - return - - logger.debug("Nova audio connection start") - - audio_content_start = json.dumps( - { - "event": { - "contentStart": { - "promptName": self.connection_id, - "contentName": self.audio_content_name, - "type": "AUDIO", - "interactive": True, - "role": "USER", - "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, - } - } - } - ) - - await self._send_nova_event(audio_content_start) - self.audio_connection_active = True - - async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: - """Internal: Send audio using Nova Sonic protocol-specific format.""" - # Start audio connection if not already active - if not self.audio_connection_active: - await self._start_audio_connection() - - # Audio is already base64 encoded in the event - # Send audio input event - audio_event = json.dumps( - { - "event": { - "audioInput": { - "promptName": self.connection_id, - "contentName": self.audio_content_name, - "content": audio_input.audio, - } - } - } - ) - - await self._send_nova_event(audio_event) - - async def _end_audio_input(self) -> None: - """Internal: End current audio input connection to trigger Nova Sonic processing.""" - if not self.audio_connection_active: - return - - logger.debug("Nova audio connection end") - - audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self.connection_id, "contentName": self.audio_content_name}}} - ) - - await self._send_nova_event(audio_content_end) - self.audio_connection_active = False - - async def _send_text_content(self, text: str) -> None: - """Internal: Send text content using Nova Sonic format.""" - content_name = str(uuid.uuid4()) - events = [ - self._get_text_content_start_event(content_name), - self._get_text_input_event(content_name, text), - self._get_content_end_event(content_name), - ] - - for event in events: - await self._send_nova_event(event) - - async def _send_interrupt(self) -> None: - """Internal: Send interruption signal to Nova Sonic.""" - # Nova Sonic handles interruption through special input events - interrupt_event = json.dumps( - { - "event": { - "audioInput": { - "promptName": self.connection_id, - "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED", - } - } - } - ) - await self._send_nova_event(interrupt_event) - - async def _send_tool_result(self, tool_result: ToolResult) -> None: - """Internal: Send tool result using Nova Sonic toolResult format.""" - tool_use_id = tool_result.get("toolUseId") - - logger.debug("Nova tool result send: %s", tool_use_id) - - # Extract result content - result_data = {} - if "content" in tool_result: - # Extract text from content blocks - for block in tool_result["content"]: - if "text" in block: - result_data = {"result": block["text"]} - break - - content_name = str(uuid.uuid4()) - events = [ - self._get_tool_content_start_event(content_name, tool_use_id), - self._get_tool_result_event(content_name, result_data), - self._get_content_end_event(content_name), - ] - - for event in events: - await self._send_nova_event(event) - - async def stop(self) -> None: - """Close Nova Sonic connection with proper cleanup sequence.""" - if not self._active: - return - - logger.debug("Nova cleanup - starting connection close") - self._active = False - - # Cancel response processing task if running - if hasattr(self, "_response_task") and not self._response_task.done(): - self._response_task.cancel() - try: - await self._response_task - except asyncio.CancelledError: - pass - - try: - # End audio connection if active - if self.audio_connection_active: - await self._end_audio_input() - - # Send cleanup events - cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] - - for event in cleanup_events: - try: - await self._send_nova_event(event) - except Exception as e: - logger.warning("Error during Nova Sonic cleanup: %s", e) - - # Close stream - try: - await self.stream.input_stream.close() - except Exception as e: - logger.warning("Error closing Nova Sonic stream: %s", e) - - except Exception as e: - logger.error("Nova cleanup error: %s", str(e)) - finally: - logger.debug("Nova connection closed") - - def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | None: - """Convert Nova Sonic events to TypedEvent format.""" - # Handle completion start - track completionId - if "completionStart" in nova_event: - completion_data = nova_event["completionStart"] - self._current_completion_id = completion_data.get("completionId") - logger.debug("Nova completion started: %s", self._current_completion_id) - return None - - # Handle completion end - if "completionEnd" in nova_event: - completion_data = nova_event["completionEnd"] - completion_id = completion_data.get("completionId", self._current_completion_id) - stop_reason = completion_data.get("stopReason", "END_TURN") - - event = BidiResponseCompleteEvent( - response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing - stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" - ) - - # Clear completion tracking - self._current_completion_id = None - return event - - # Handle audio output - if "audioOutput" in nova_event: - # Audio is already base64 string from Nova Sonic - audio_content = nova_event["audioOutput"]["content"] - return BidiAudioStreamEvent( - audio=audio_content, - format="pcm", - sample_rate=24000, - channels=1 - ) - - # Handle text output (transcripts) - elif "textOutput" in nova_event: - text_content = nova_event["textOutput"]["content"] - # Check for Nova Sonic interruption pattern - if '{ "interrupted" : true }' in text_content: - logger.debug("Nova interruption detected in text") - return BidiInterruptionEvent(reason="user_speech") - - return BidiTranscriptStreamEvent( - delta={"text": text_content}, - text=text_content, - role=self._current_role.lower() if self._current_role else "assistant", - is_final=self._generation_stage == "FINAL", - current_transcript=text_content - ) - - # Handle tool use - if "toolUse" in nova_event: - tool_use = nova_event["toolUse"] - tool_use_event: ToolUse = { - "toolUseId": tool_use["toolUseId"], - "name": tool_use["toolName"], - "input": json.loads(tool_use["content"]), - } - # Return ToolUseStreamEvent for consistency with standard agent - return ToolUseStreamEvent( - delta={"toolUse": tool_use_event}, - current_tool_use=tool_use_event - ) - - # Handle interruption - if nova_event.get("stopReason") == "INTERRUPTED": - logger.debug("Nova interruption stop reason") - return BidiInterruptionEvent(reason="user_speech") - - # Handle usage events - convert to multimodal usage format - if "usageEvent" in nova_event: - usage_data = nova_event["usageEvent"] - total_input = usage_data.get("totalInputTokens", 0) - total_output = usage_data.get("totalOutputTokens", 0) - - return BidiUsageEvent( - input_tokens=total_input, - output_tokens=total_output, - total_tokens=usage_data.get("totalTokens", total_input + total_output) - ) - - # Handle content start events (track role and emit response start) - if "contentStart" in nova_event: - content_data = nova_event["contentStart"] - role = content_data.get("role", "unknown") - # Store role for subsequent text output events - self._current_role = role - - if content_data["type"] == "TEXT": - self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] - - # Emit response start event using API-provided completionId - # completionId should already be tracked from completionStart event - return BidiResponseStartEvent( - response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing - ) - - # Ignore other events (contentEnd, etc.) - return - - # Nova Sonic event template methods - def _get_connection_start_event(self) -> str: - """Generate Nova Sonic connection start event.""" - return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) - - def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: - """Generate Nova Sonic prompt start event with tool configuration.""" - prompt_start_event = { - "event": { - "promptStart": { - "promptName": self.connection_id, - "textOutputConfiguration": NOVA_TEXT_CONFIG, - "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, - } - } - } - - if tools: - tool_config = self._build_tool_configuration(tools) - prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG - prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} - - return json.dumps(prompt_start_event) - - def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: - """Build tool configuration from tool specs.""" - tool_config = [] - for tool in tools: - input_schema = ( - {"json": json.dumps(tool["inputSchema"]["json"])} - if "json" in tool["inputSchema"] - else {"json": json.dumps(tool["inputSchema"])} - ) - - tool_config.append( - {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} - ) - return tool_config - - def _get_system_prompt_events(self, system_prompt: str) -> list[str]: - """Generate system prompt events.""" - content_name = str(uuid.uuid4()) - return [ - self._get_text_content_start_event(content_name, "SYSTEM"), - self._get_text_input_event(content_name, system_prompt), - self._get_content_end_event(content_name), - ] - - def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: - """Generate text content start event.""" - return json.dumps( - { - "event": { - "contentStart": { - "promptName": self.connection_id, - "contentName": content_name, - "type": "TEXT", - "role": role, - "interactive": True, - "textInputConfiguration": NOVA_TEXT_CONFIG, - } - } - } - ) - - def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: - """Generate tool content start event.""" - return json.dumps( - { - "event": { - "contentStart": { - "promptName": self.connection_id, - "contentName": content_name, - "interactive": False, - "type": "TOOL", - "role": "TOOL", - "toolResultInputConfiguration": { - "toolUseId": tool_use_id, - "type": "TEXT", - "textInputConfiguration": NOVA_TEXT_CONFIG, - }, - } - } - } - ) - - def _get_text_input_event(self, content_name: str, text: str) -> str: - """Generate text input event.""" - return json.dumps( - {"event": {"textInput": {"promptName": self.connection_id, "contentName": content_name, "content": text}}} - ) - - def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: - """Generate tool result event.""" - return json.dumps( - { - "event": { - "toolResult": { - "promptName": self.connection_id, - "contentName": content_name, - "content": json.dumps(result), - } - } - } - ) - - def _get_content_end_event(self, content_name: str) -> str: - """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self.connection_id, "contentName": content_name}}}) - - def _get_prompt_end_event(self) -> str: - """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self.connection_id}}}) - - def _get_connection_end_event(self) -> str: - """Generate connection end event.""" - return json.dumps({"event": {"connectionEnd": {}}}) - - async def _send_nova_event(self, event: str) -> None: - """Send event JSON string to Nova Sonic stream.""" - try: - # Event is already a JSON string - bytes_data = event.encode("utf-8") - chunk = InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) - await self.stream.input_stream.send(chunk) - logger.debug("Successfully sent Nova Sonic event") - - except Exception as e: - logger.error("Error sending Nova Sonic event: %s", e) - logger.error("Event was: %s", event) - raise - - async def _initialize_client(self) -> None: - """Initialize Nova Sonic client.""" - try: - config = Config( - endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", - region=self.region, - aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), - auth_scheme_resolver=HTTPAuthSchemeResolver(), - auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, - ) - - self.client = BedrockRuntimeClient(config=config) - logger.debug("Nova Sonic client initialized") - - except ImportError as e: - logger.error("Nova Sonic dependencies not available: %s", e) - raise - except Exception as e: - logger.error("Error initializing Nova Sonic client: %s", e) - raise diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py deleted file mode 100644 index 74f1942ff..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ /dev/null @@ -1,653 +0,0 @@ -"""OpenAI Realtime API provider for Strands bidirectional streaming. - -Provides real-time audio and text communication through OpenAI's Realtime API -with WebSocket connections, voice activity detection, and function calling. -""" - -import asyncio -import base64 -import json -import logging -import os -import uuid -from typing import AsyncIterable, Union - -import websockets -from websockets.exceptions import ConnectionClosed - -from ....types.content import Messages -from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ..types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionCloseEvent, - BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, - BidiInputEvent, - BidiInterruptionEvent, - BidiUsageEvent, - BidiOutputEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, -) -from .bidirectional_model import BidiModel - -logger = logging.getLogger(__name__) - -# OpenAI Realtime API configuration -OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" -DEFAULT_MODEL = "gpt-realtime" - -AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} - -DEFAULT_SESSION_CONFIG = { - "type": "realtime", - "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", - "output_modalities": ["audio"], - "audio": { - "input": { - "format": AUDIO_FORMAT, - "transcription": { - "model": "gpt-4o-transcribe" - }, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500, - } - }, - "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, - }, -} - - -class BidiOpenAIRealtimeModel(BidiModel): - """OpenAI Realtime API implementation for bidirectional streaming. - - Combines model configuration and connection state in a single class. - Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, - function calling, and event conversion to Strands format. - """ - - def __init__( - self, - model: str = DEFAULT_MODEL, - api_key: str | None = None, - organization: str | None = None, - project: str | None = None, - session_config: dict[str, any] | None = None, - **kwargs - ) -> None: - """Initialize OpenAI Realtime bidirectional model. - - Args: - model: OpenAI model identifier (default: gpt-realtime). - api_key: OpenAI API key for authentication. - organization: OpenAI organization ID for API requests. - project: OpenAI project ID for API requests. - session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). - **kwargs: Reserved for future parameters. - """ - # Model configuration - self.model = model - self.api_key = api_key - self.organization = organization - self.project = project - self.session_config = session_config or {} - - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") - - # Connection state (initialized in start()) - self.websocket = None - self.connection_id = None - self._active = False - - self._event_queue = None - self._response_task = None - self._function_call_buffer = {} - - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> None: - """Establish bidirectional connection to OpenAI Realtime API. - - Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. - **kwargs: Additional configuration options. - """ - if self._active: - raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - - logger.info("Creating OpenAI Realtime connection...") - - try: - # Initialize connection state - self.connection_id = str(uuid.uuid4()) - self._active = True - self._event_queue = asyncio.Queue() - self._function_call_buffer = {} - - # Establish WebSocket connection - url = f"{OPENAI_REALTIME_URL}?model={self.model}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if self.organization: - headers.append(("OpenAI-Organization", self.organization)) - if self.project: - headers.append(("OpenAI-Project", self.project)) - - self.websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") - - # Configure session - session_config = self._build_session_config(system_prompt, tools) - await self._send_event({"type": "session.update", "session": session_config}) - - # Add conversation history if provided - if messages: - await self._add_conversation_history(messages) - - # Start background response processor - self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime connection established") - - except Exception as e: - self._active = False - logger.error("OpenAI connection error: %s", e) - raise - - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active - - def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: - """Create standardized transcript event. - - Args: - text: The transcript text - role: The role (will be normalized to lowercase) - is_final: Whether this is the final transcript - """ - # Normalize role to lowercase and ensure it's either "user" or "assistant" - normalized_role = role.lower() if isinstance(role, str) else "assistant" - if normalized_role not in ["user", "assistant"]: - normalized_role = "assistant" - - return BidiTranscriptStreamEvent( - delta={"text": text}, - text=text, - role=normalized_role, - is_final=is_final, - current_transcript=text if is_final else None - ) - - def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: - """Create standardized interruption event for voice activity.""" - # Only speech_started triggers interruption - if activity_type == "speech_started": - return BidiInterruptionEvent(reason="user_speech") - # Other voice activity events are logged but don't create events - return None - - def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: - """Build session configuration for OpenAI Realtime API.""" - config = DEFAULT_SESSION_CONFIG.copy() - - if system_prompt: - config["instructions"] = system_prompt - - if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) - - # Apply user-provided session configuration - supported_params = { - "type", "output_modalities", "instructions", "voice", "audio", - "tools", "tool_choice", "input_audio_format", "output_audio_format", - "input_audio_transcription", "turn_detection" - } - - for key, value in self.session_config.items(): - if key in supported_params: - config[key] = value - else: - logger.warning("Ignoring unsupported session parameter: %s", key) - - return config - - def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI Realtime API format.""" - openai_tools = [] - - for tool in tools: - input_schema = tool["inputSchema"] - if "json" in input_schema: - schema = json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] - else: - schema = input_schema - - # OpenAI Realtime API expects flat structure, not nested under "function" - openai_tool = { - "type": "function", - "name": tool["name"], - "description": tool["description"], - "parameters": schema - } - openai_tools.append(openai_tool) - - return openai_tools - - async def _add_conversation_history(self, messages: Messages) -> None: - """Add conversation history to the session.""" - for message in messages: - conversation_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": message["role"], "content": []} - } - - content = message.get("content", "") - if isinstance(content, str): - conversation_item["item"]["content"].append({"type": "input_text", "text": content}) - elif isinstance(content, list): - for item in content: - if isinstance(item, dict) and item.get("type") == "text": - conversation_item["item"]["content"].append({"type": "input_text", "text": item.get("text", "")}) - - await self._send_event(conversation_item) - - async def _process_responses(self) -> None: - """Process incoming WebSocket messages.""" - logger.debug("OpenAI Realtime response processor started") - - try: - async for message in self.websocket: - if not self._active: - break - - try: - event = json.loads(message) - await self._event_queue.put(event) - except json.JSONDecodeError as e: - logger.warning("Failed to parse OpenAI event: %s", e) - continue - - except ConnectionClosed: - logger.debug("OpenAI Realtime WebSocket connection closed") - except Exception as e: - logger.error("Error in OpenAI Realtime response processing: %s", e) - finally: - self._active = False - logger.debug("OpenAI Realtime response processor stopped") - - async def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive OpenAI events and convert to Strands TypedEvent format.""" - # Emit connection start event - yield BidiConnectionStartEvent( - connection_id=self.connection_id, - model=self.model - ) - - try: - while self._active: - try: - openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - for event in self._convert_openai_event(openai_event) or []: - yield event - except asyncio.TimeoutError: - continue - - except Exception as e: - logger.error("Error receiving OpenAI Realtime event: %s", e) - yield BidiErrorEvent(error=e) - finally: - # Emit connection close event - yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") - - def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutputEvent] | None: - """Convert OpenAI events to Strands TypedEvent format.""" - event_type = openai_event.get("type") - - # Turn start - response begins - if event_type == "response.created": - response = openai_event.get("response", {}) - response_id = response.get("id", str(uuid.uuid4())) - return [BidiResponseStartEvent(response_id=response_id)] - - # Audio output - elif event_type == "response.output_audio.delta": - # Audio is already base64 string from OpenAI - return [BidiAudioStreamEvent( - audio=openai_event["delta"], - format="pcm", - sample_rate=24000, - channels=1 - )] - - # Assistant text output events - combine multiple similar events - elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: - role = openai_event.get("role", "assistant") - return [self._create_text_event(openai_event["delta"], role.lower() if isinstance(role, str) else "assistant")] - - # User transcription events - combine multiple similar events - elif event_type in ["conversation.item.input_audio_transcription.delta", - "conversation.item.input_audio_transcription.completed"]: - text_key = "delta" if "delta" in event_type else "transcript" - text = openai_event.get(text_key, "") - role = openai_event.get("role", "user") - is_final = "completed" in event_type - return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] if text.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.segment": - segment_data = openai_event.get("segment", {}) - text = segment_data.get("text", "") - role = segment_data.get("role", "user") - return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] if text.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.failed": - error_info = openai_event.get("error", {}) - logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) - return None - - # Function call processing - elif event_type == "response.function_call_arguments.delta": - call_id = openai_event.get("call_id") - delta = openai_event.get("delta", "") - if call_id: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} - else: - self._function_call_buffer[call_id]["arguments"] += delta - return None - - elif event_type == "response.function_call_arguments.done": - call_id = openai_event.get("call_id") - if call_id and call_id in self._function_call_buffer: - function_call = self._function_call_buffer[call_id] - try: - tool_use: ToolUse = { - "toolUseId": call_id, - "name": function_call["name"], - "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, - } - del self._function_call_buffer[call_id] - # Return ToolUseStreamEvent for consistency with standard agent - return [ToolUseStreamEvent( - delta={"toolUse": tool_use}, - current_tool_use=tool_use - )] - except (json.JSONDecodeError, KeyError) as e: - logger.warning("Error parsing function arguments for %s: %s", call_id, e) - del self._function_call_buffer[call_id] - return None - - # Voice activity detection - speech_started triggers interruption - elif event_type == "input_audio_buffer.speech_started": - # This is the primary interruption signal - handle it first - return [BidiInterruptionEvent(reason="user_speech")] - - # Response cancelled - handle interruption - elif event_type == "response.cancelled": - response = openai_event.get("response", {}) - response_id = response.get("id", "unknown") - logger.debug("OpenAI response cancelled: %s", response_id) - return [BidiResponseCompleteEvent( - response_id=response_id, - stop_reason="interrupted" - )] - - # Turn complete and usage - response finished - elif event_type == "response.done": - response = openai_event.get("response", {}) - response_id = response.get("id", "unknown") - status = response.get("status", "completed") - usage = response.get("usage") - - # Map OpenAI status to our stop_reason - stop_reason_map = { - "completed": "complete", - "cancelled": "interrupted", - "failed": "error", - "incomplete": "interrupted" - } - - # Build list of events to return - events = [] - - # Always add response complete event - events.append(BidiResponseCompleteEvent( - response_id=response_id, - stop_reason=stop_reason_map.get(status, "complete") - )) - - # Add usage event if available - if usage: - input_details = usage.get("input_token_details", {}) - output_details = usage.get("output_token_details", {}) - - # Build modality details - modality_details = [] - - # Text modality - text_input = input_details.get("text_tokens", 0) - text_output = output_details.get("text_tokens", 0) - if text_input > 0 or text_output > 0: - modality_details.append({ - "modality": "text", - "input_tokens": text_input, - "output_tokens": text_output - }) - - # Audio modality - audio_input = input_details.get("audio_tokens", 0) - audio_output = output_details.get("audio_tokens", 0) - if audio_input > 0 or audio_output > 0: - modality_details.append({ - "modality": "audio", - "input_tokens": audio_input, - "output_tokens": audio_output - }) - - # Image modality - image_input = input_details.get("image_tokens", 0) - if image_input > 0: - modality_details.append({ - "modality": "image", - "input_tokens": image_input, - "output_tokens": 0 - }) - - # Cached tokens - cached_tokens = input_details.get("cached_tokens", 0) - - # Add usage event - events.append(BidiUsageEvent( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - total_tokens=usage.get("total_tokens", 0), - modality_details=modality_details if modality_details else None, - cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None - )) - - # Return list of events - return events - - # Lifecycle events (log only) - combine multiple similar events - elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: - item = openai_event.get("item", {}) - action = "retrieved" if "retrieve" in event_type else "added" - logger.debug("OpenAI conversation item %s: %s", action, item.get("id")) - return None - - elif event_type == "conversation.item.done": - logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - return None - - # Response output events - combine similar events - elif event_type in ["response.output_item.added", "response.output_item.done", - "response.content_part.added", "response.content_part.done"]: - item_data = openai_event.get("item") or openai_event.get("part") - logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") - - # Track function call names from response.output_item.added - if event_type == "response.output_item.added": - item = openai_event.get("item", {}) - if item.get("type") == "function_call": - call_id = item.get("call_id") - function_name = item.get("name") - if call_id and function_name: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": function_name, "arguments": ""} - else: - self._function_call_buffer[call_id]["name"] = function_name - return None - - # Session/buffer events - combine simple log-only events - elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", - "session.created", "session.updated"]: - logger.debug("OpenAI %s event", event_type) - return None - - elif event_type == "error": - error_data = openai_event.get("error", {}) - error_code = error_data.get("code", "") - - # Suppress expected errors that don't affect session state - if error_code == "response_cancel_not_active": - # This happens when trying to cancel a response that's not active - # It's safe to ignore as the session remains functional - logger.debug("OpenAI response cancel attempted when no response active (safe to ignore)") - return None - - # Log other errors - logger.error("OpenAI Realtime error: %s", error_data) - return None - - else: - logger.debug("Unhandled OpenAI event type: %s", event_type) - return None - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Unified send method for all content types. Sends the given content to OpenAI. - - Dispatches to appropriate internal handler based on content type. - - Args: - content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). - """ - if not self._require_active(): - return - - try: - # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, BidiImageInputEvent): - # BidiImageInputEvent - not supported by OpenAI Realtime yet - logger.warning("Image input not supported by OpenAI Realtime API") - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - else: - logger.warning(f"Unknown content type: {type(content).__name__}") - except Exception as e: - logger.error(f"Error sending content: {e}") - raise # Propagate exception for debugging in experimental code - - async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: - """Internal: Send audio content to OpenAI for processing.""" - # Audio is already base64 encoded in the event - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) - - async def _send_text_content(self, text: str) -> None: - """Internal: Send text content to OpenAI for processing.""" - item_data = { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": text}] - } - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def _send_interrupt(self) -> None: - """Internal: Send interruption signal to OpenAI.""" - await self._send_event({"type": "response.cancel"}) - - async def _send_tool_result(self, tool_result: ToolResult) -> None: - """Internal: Send tool result back to OpenAI.""" - tool_use_id = tool_result.get("toolUseId") - - logger.debug("OpenAI tool result send: %s", tool_use_id) - - # Extract result content - result_data = {} - if "content" in tool_result: - # Extract text from content blocks - for block in tool_result["content"]: - if "text" in block: - result_data = block["text"] - break - - result_text = json.dumps(result_data) if not isinstance(result_data, str) else result_data - - item_data = { - "type": "function_call_output", - "call_id": tool_use_id, - "output": result_text - } - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def stop(self) -> None: - """Close session and cleanup resources.""" - if not self._active: - return - - logger.debug("OpenAI Realtime cleanup - starting connection close") - self._active = False - - if self._response_task and not self._response_task.done(): - self._response_task.cancel() - try: - await self._response_task - except asyncio.CancelledError: - pass - - try: - await self.websocket.close() - except Exception as e: - logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) - - logger.debug("OpenAI Realtime connection closed") - - async def _send_event(self, event: dict[str, any]) -> None: - """Send event to OpenAI via WebSocket.""" - try: - message = json.dumps(event) - await self.websocket.send(message) - logger.debug("Sent OpenAI event: %s", event.get("type")) - except Exception as e: - logger.error("Error sending OpenAI event: %s", e) - raise - - diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py deleted file mode 100644 index f04677635..000000000 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Test BidirectionalAgent with simple developer experience.""" - -import asyncio -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) - -from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent -from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel -from strands.experimental.bidirectional_streaming.io import BidiAudioIO, BidiTextIO -from strands_tools import calculator - - -async def main(): - """Test the BidirectionalAgent API.""" - - - # Nova Sonic model - audio_io = BidiAudioIO(audio_config={}) - text_io = BidiTextIO() - model = BidiNovaSonicModel(region="us-east-1") - - async with BidiAgent(model=model, tools=[calculator]) as agent: - print("New BidiAgent Experience") - print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") - await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\n⏹️ Conversation ended by user") - except Exception as e: - print(f"❌ Error: {e}") - import traceback - traceback.print_exc() diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py deleted file mode 100644 index 42d8d436e..000000000 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py +++ /dev/null @@ -1,256 +0,0 @@ -"""Test suite for bidirectional streaming with real-time audio interaction. - -Tests the complete bidirectional streaming system including audio input/output, -interruption handling, and concurrent tool execution using Nova Sonic. -""" - -import asyncio -import base64 -import sys -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) -import os -import time - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent -from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel - - -def test_direct_tools(): - """Test direct tool calling.""" - print("Testing direct tool calling...") - - # Check AWS credentials - if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): - print("AWS credentials not set - skipping test") - return - - try: - model = BidiNovaSonicModel() - agent = BidirectionalAgent(model=model, tools=[calculator]) - - # Test calculator - result = agent.tool.calculator(expression="2 * 3") - content = result.get("content", [{}])[0].get("text", "") - print(f"Result: {content}") - print("Test completed") - - except Exception as e: - print(f"Test failed: {e}") - - -async def play(context): - """Play audio output with responsive interruption support.""" - audio = pyaudio.PyAudio() - speaker = audio.open( - channels=1, - format=pyaudio.paInt16, - output=True, - rate=24000, - frames_per_buffer=1024, - ) - - try: - while context["active"]: - try: - # Check for interruption first - if context.get("interrupted", False): - # Clear entire audio queue immediately - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get next audio data - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - # Check for interruption before each chunk - if context.get("interrupted", False) or not context["active"]: - break - - end = min(i + chunk_size, len(audio_data)) - chunk = audio_data[i:end] - speaker.write(chunk) - await asyncio.sleep(0.001) - - except asyncio.TimeoutError: - continue # No audio available - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - finally: - speaker.close() - audio.terminate() - - -async def record(context): - """Record audio input from microphone.""" - audio = pyaudio.PyAudio() - microphone = audio.open( - channels=1, - format=pyaudio.paInt16, - frames_per_buffer=1024, - input=True, - rate=16000, - ) - - try: - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - context["audio_in"].put_nowait(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - except asyncio.CancelledError: - pass - finally: - microphone.close() - audio.terminate() - - -async def receive(agent, context): - """Receive and process events from agent.""" - try: - async for event in agent.receive(): - event_type = event.get("type", "unknown") - - # Handle audio stream events (bidi_audio_stream) - if event_type == "bidi_audio_stream": - if not context.get("interrupted", False): - # Decode base64 audio string to bytes for playback - audio_b64 = event["audio"] - audio_data = base64.b64decode(audio_b64) - context["audio_out"].put_nowait(audio_data) - - # Handle interruption events (bidi_interruption) - elif event_type == "bidi_interruption": - context["interrupted"] = True - - # Handle transcript events (bidi_transcript_stream) - elif event_type == "bidi_transcript_stream": - text_content = event.get("text", "") - role = event.get("role", "unknown") - - # Log transcript output - if role == "user": - print(f"User: {text_content}") - elif role == "assistant": - print(f"Assistant: {text_content}") - - # Handle response complete events (bidi_response_complete) - elif event_type == "bidi_response_complete": - # Reset interrupted state since the turn is complete - context["interrupted"] = False - - # Handle tool use events (tool_use_stream) - elif event_type == "tool_use_stream": - tool_use = event.get("current_tool_use", {}) - tool_name = tool_use.get("name", "unknown") - tool_input = tool_use.get("input", {}) - print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - - # Handle tool result events (tool_result) - elif event_type == "tool_result": - tool_result = event.get("tool_result", {}) - tool_name = tool_result.get("name", "unknown") - result_content = tool_result.get("content", []) - result_text = "" - for block in result_content: - if isinstance(block, dict) and block.get("type") == "text": - result_text = block.get("text", "") - break - print(f"✅ Tool result from {tool_name}: {result_text}") - - except asyncio.CancelledError: - pass - - -async def send(agent, context): - """Send audio input to agent.""" - try: - while time.time() - context["start_time"] < context["duration"]: - try: - audio_bytes = context["audio_in"].get_nowait() - # Create audio event using TypedEvent - from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) - await agent.send(audio_event) - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) # Restored to working timing - except asyncio.CancelledError: - break - - context["active"] = False - except asyncio.CancelledError: - pass - - -async def main(duration=180): - """Main function for bidirectional streaming test.""" - print("Starting bidirectional streaming test...") - print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") - - # Initialize model and agent - model = BidiNovaSonicModel(region="us-east-1") - agent = BidiAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") - - await agent.start() - - # Create shared context for all tasks - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "connection": agent._agent_loop, - "duration": duration, - "start_time": time.time(), - "interrupted": False, - } - - print("Speak into microphone. Press Ctrl+C to exit.") - - try: - # Run all tasks concurrently - await asyncio.gather( - play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True - ) - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - finally: - print("Cleaning up...") - context["active"] = False - await agent.stop() - - -if __name__ == "__main__": - # Test direct tool calling first - test_direct_tools() - - asyncio.run(main()) diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py deleted file mode 100644 index dd19e958d..000000000 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py +++ /dev/null @@ -1,324 +0,0 @@ -#!/usr/bin/env python3 -"""Test OpenAI Realtime API speech-to-speech interaction.""" - -import asyncio -import base64 -import os -import sys -import time -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent / "src")) - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent -from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel - - -async def play(context): - """Handle audio playback with interruption support.""" - audio = pyaudio.PyAudio() - - try: - speaker = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # OpenAI Realtime uses 24kHz - output=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - # Check for interruption - if context.get("interrupted", False): - # Clear audio queue on interruption - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get audio data with timeout - try: - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - # Play in chunks to allow interruption - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - if context.get("interrupted", False) or not context["active"]: - break - - chunk = audio_data[i:i + chunk_size] - speaker.write(chunk) - await asyncio.sleep(0.001) # Brief pause for responsiveness - - except asyncio.TimeoutError: - continue - - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Audio playback error: {e}") - finally: - try: - speaker.close() - except: - pass - audio.terminate() - - -async def record(context): - """Handle microphone recording.""" - audio = pyaudio.PyAudio() - - try: - microphone = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # Match OpenAI's expected input rate - input=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - await context["audio_in"].put(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Microphone recording error: {e}") - finally: - try: - microphone.close() - except: - pass - audio.terminate() - - -async def receive(agent, context): - """Handle events from the agent.""" - try: - async for event in agent.receive(): - if not context["active"]: - break - - # Get event type - event_type = event.get("type", "unknown") - - # Handle audio stream events (bidi_audio_stream) - if event_type == "bidi_audio_stream": - # Decode base64 audio string to bytes for playback - audio_b64 = event["audio"] - audio_data = base64.b64decode(audio_b64) - - if not context.get("interrupted", False): - await context["audio_out"].put(audio_data) - - # Handle transcript events (bidi_transcript_stream) - elif event_type == "bidi_transcript_stream": - source = event.get("role", "assistant") - text = event.get("text", "").strip() - - if text: - if source == "user": - print(f"🎤 User: {text}") - elif source == "assistant": - print(f"🔊 Assistant: {text}") - - # Handle interruption events (bidi_interruption) - elif event_type == "bidi_interruption": - context["interrupted"] = True - print("⚠️ Interruption detected") - - # Handle connection start events (bidi_connection_start) - elif event_type == "bidi_connection_start": - print(f"✓ Session started: {event.get('model', 'unknown')}") - - # Handle connection close events (bidi_connection_close) - elif event_type == "bidi_connection_close": - print(f"✓ Session ended: {event.get('reason', 'unknown')}") - context["active"] = False - break - - # Handle response complete events (bidi_response_complete) - elif event_type == "bidi_response_complete": - # Reset interrupted state since the turn is complete - context["interrupted"] = False - - # Handle tool use events (tool_use_stream) - elif event_type == "tool_use_stream": - tool_use = event.get("current_tool_use", {}) - tool_name = tool_use.get("name", "unknown") - tool_input = tool_use.get("input", {}) - print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - - # Handle tool result events (tool_result) - elif event_type == "tool_result": - tool_result = event.get("tool_result", {}) - tool_name = tool_result.get("name", "unknown") - result_content = tool_result.get("content", []) - result_text = "" - for block in result_content: - if isinstance(block, dict) and block.get("type") == "text": - result_text = block.get("text", "") - break - print(f"✅ Tool result from {tool_name}: {result_text}") - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Receive handler error: {e}") - finally: - pass - - -async def send(agent, context): - """Send audio from microphone to agent.""" - try: - while context["active"]: - try: - audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - - # Create audio event using TypedEvent - # Encode audio bytes to base64 string for JSON serializability - from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=24000, - channels=1 - ) - - await agent.send(audio_event) - - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Send handler error: {e}") - finally: - pass - - -async def main(): - """Main test function for OpenAI voice chat.""" - print("Starting OpenAI Realtime API test...") - - # Check API key - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY environment variable not set") - return False - - # Check audio system - try: - audio = pyaudio.PyAudio() - audio.terminate() - except Exception as e: - print(f"Audio system error: {e}") - return False - - # Create OpenAI model - model = BidiOpenAIRealtimeModel( - model="gpt-4o-realtime-preview", - api_key=api_key, - session={ - "output_modalities": ["audio"], - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": 24000}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "silence_duration_ms": 700 - } - }, - "output": { - "format": {"type": "audio/pcm", "rate": 24000}, - "voice": "alloy" - } - } - } - ) - - # Create agent - agent = BidiAgent( - model=model, - tools=[calculator], - system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." - ) - - # Start the session - await agent.start() - - # Create shared context - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "interrupted": False, - "start_time": time.time() - } - - print("Speak into your microphone. Press Ctrl+C to stop.") - - try: - # Run all tasks concurrently - await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True - ) - - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - except Exception as e: - print(f"\nError during voice chat: {e}") - finally: - print("Cleaning up...") - context["active"] = False - - try: - await agent.stop() - except Exception as e: - print(f"Cleanup error: {e}") - - return True - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Test error: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py deleted file mode 100644 index 814586de1..000000000 --- a/src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py +++ /dev/null @@ -1,363 +0,0 @@ -"""Test suite for Gemini Live bidirectional streaming with camera support. - -Tests the Gemini Live API with real-time audio and video interaction including: -- Audio input/output streaming -- Camera frame capture and transmission -- Interruption handling -- Concurrent tool execution -- Transcript events - -Requirements: -- pip install opencv-python pillow pyaudio google-genai -- Camera access permissions -- GOOGLE_AI_API_KEY environment variable -""" - -import asyncio -import base64 -import io -import logging -import os -import sys -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) -import time - -try: - import cv2 - import PIL.Image - CAMERA_AVAILABLE = True -except ImportError as e: - print(f"Camera dependencies not available: {e}") - print("Install with: pip install opencv-python pillow") - CAMERA_AVAILABLE = False - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent -from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel - -# Configure logging - debug only for Gemini Live, info for everything else -logging.basicConfig(level=logging.WARN, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') -gemini_logger.setLevel(logging.WARN) -logger = logging.getLogger(__name__) - - -async def play(context): - """Play audio output with responsive interruption support.""" - audio = pyaudio.PyAudio() - speaker = audio.open( - channels=1, - format=pyaudio.paInt16, - output=True, - rate=24000, - frames_per_buffer=1024, - ) - - try: - while context["active"]: - try: - # Check for interruption first - if context.get("interrupted", False): - # Clear entire audio queue immediately - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get next audio data - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - # Check for interruption before each chunk - if context.get("interrupted", False) or not context["active"]: - break - - end = min(i + chunk_size, len(audio_data)) - chunk = audio_data[i:end] - speaker.write(chunk) - await asyncio.sleep(0.001) - - except asyncio.TimeoutError: - continue # No audio available - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - finally: - speaker.close() - audio.terminate() - - -async def record(context): - """Record audio input from microphone.""" - audio = pyaudio.PyAudio() - - # List all available audio devices - print("Available audio devices:") - for i in range(audio.get_device_count()): - device_info = audio.get_device_info_by_index(i) - if device_info['maxInputChannels'] > 0: # Only show input devices - print(f" Device {i}: {device_info['name']} (inputs: {device_info['maxInputChannels']})") - - # Get default input device info - default_device = audio.get_default_input_device_info() - print(f"\nUsing default input device: {default_device['name']} (Device {default_device['index']})") - - microphone = audio.open( - channels=1, - format=pyaudio.paInt16, - frames_per_buffer=1024, - input=True, - rate=16000, - ) - - try: - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - context["audio_in"].put_nowait(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - except asyncio.CancelledError: - pass - finally: - microphone.close() - audio.terminate() - - -async def receive(agent, context): - """Receive and process events from agent.""" - try: - async for event in agent.receive(): - event_type = event.get("type", "unknown") - - # Handle audio stream events (bidi_audio_stream) - if event_type == "bidi_audio_stream": - if not context.get("interrupted", False): - # Decode base64 audio string to bytes for playback - audio_b64 = event["audio"] - audio_data = base64.b64decode(audio_b64) - context["audio_out"].put_nowait(audio_data) - - # Handle interruption events (bidi_interruption) - elif event_type == "bidi_interruption": - context["interrupted"] = True - print("⚠️ Interruption detected") - - # Handle transcript events (bidi_transcript_stream) - elif event_type == "bidi_transcript_stream": - transcript_text = event.get("text", "") - transcript_role = event.get("role", "unknown") - is_final = event.get("is_final", False) - - # Print transcripts with special formatting - if transcript_role == "user": - print(f"🎤 User: {transcript_text}") - elif transcript_role == "assistant": - print(f"🔊 Assistant: {transcript_text}") - - # Handle response complete events (bidi_response_complete) - elif event_type == "bidi_response_complete": - # Reset interrupted state since the response is complete - context["interrupted"] = False - - # Handle tool use events (tool_use_stream) - elif event_type == "tool_use_stream": - tool_use = event.get("current_tool_use", {}) - tool_name = tool_use.get("name", "unknown") - tool_input = tool_use.get("input", {}) - print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - - # Handle tool result events (tool_result) - elif event_type == "tool_result": - tool_result = event.get("tool_result", {}) - tool_name = tool_result.get("name", "unknown") - result_content = tool_result.get("content", []) - # Extract text from content blocks - result_text = "" - for block in result_content: - if isinstance(block, dict) and block.get("type") == "text": - result_text = block.get("text", "") - break - print(f"✅ Tool result from {tool_name}: {result_text}") - - except asyncio.CancelledError: - pass - - -def _get_frame(cap): - """Capture and process a frame from camera.""" - if not CAMERA_AVAILABLE: - return None - - # Read the frame - ret, frame = cap.read() - # Check if the frame was read successfully - if not ret: - return None - # Convert BGR to RGB color space - # OpenCV captures in BGR but PIL expects RGB format - # This prevents the blue tint in the video feed - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - img = PIL.Image.fromarray(frame_rgb) - img.thumbnail([1024, 1024]) - - image_io = io.BytesIO() - img.save(image_io, format="jpeg") - image_io.seek(0) - - mime_type = "image/jpeg" - image_bytes = image_io.read() - return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} - - -async def get_frames(context): - """Capture frames from camera and send to agent.""" - if not CAMERA_AVAILABLE: - print("Camera not available - skipping video capture") - return - - # This takes about a second, and will block the whole program - # causing the audio pipeline to overflow if you don't to_thread it. - cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera - - print("Camera initialized. Starting video capture...") - - try: - while context["active"] and time.time() - context["start_time"] < context["duration"]: - frame = await asyncio.to_thread(_get_frame, cap) - if frame is None: - break - - # Send frame to agent as image input - try: - from strands.experimental.bidirectional_streaming.types.events import BidiImageInputEvent - - image_event = BidiImageInputEvent( - image=frame["data"], # Already base64 encoded - mime_type=frame["mime_type"] - ) - await context["agent"].send(image_event) - print("📸 Frame sent to model") - except Exception as e: - logger.error(f"Error sending frame: {e}") - - # Wait 1 second between frames (1 FPS) - await asyncio.sleep(1.0) - - except asyncio.CancelledError: - pass - finally: - # Release the VideoCapture object - cap.release() - - -async def send(agent, context): - """Send audio input to agent.""" - try: - while time.time() - context["start_time"] < context["duration"]: - try: - audio_bytes = context["audio_in"].get_nowait() - # Create audio event using TypedEvent - from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) - await agent.send(audio_event) - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - context["active"] = False - except asyncio.CancelledError: - pass - - -async def main(duration=180): - """Main function for Gemini Live bidirectional streaming test with camera support.""" - print("Starting Gemini Live bidirectional streaming test with camera...") - print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") - print("Video: Camera frames sent at 1 FPS to model") - - # Get API key from environment variable - api_key = os.getenv("GOOGLE_AI_API_KEY") - - if not api_key: - print("ERROR: GOOGLE_AI_API_KEY environment variable not set") - print("Please set it with: export GOOGLE_AI_API_KEY=your_api_key") - return - - # Initialize Gemini Live model with proper configuration - logger.info("Initializing Gemini Live model with API key") - - # Use default model and config (includes transcription enabled by default) - model = BidiGeminiLiveModel(api_key=api_key) - logger.info("Gemini Live model initialized successfully") - print("Using Gemini Live model with default config (audio output + transcription enabled)") - - agent = BidiAgent( - model=model, - tools=[calculator], - system_prompt="You are a helpful assistant." - ) - - await agent.start() - - # Create shared context for all tasks - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "connection": agent._agent_loop, - "duration": duration, - "start_time": time.time(), - "interrupted": False, - "agent": agent, # Add agent reference for camera task - } - - print("Speak into microphone and show things to camera. Press Ctrl+C to exit.") - - try: - # Run all tasks concurrently including camera - await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - get_frames(context), # Add camera task - return_exceptions=True - ) - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - finally: - print("Cleaning up...") - context["active"] = False - await agent.stop() - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py deleted file mode 100644 index d5263bb28..000000000 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Type definitions for bidirectional streaming.""" - -from .agent import BidiAgentInput -from .io import BidiInput, BidiOutput -from .events import ( - DEFAULT_CHANNELS, - DEFAULT_FORMAT, - DEFAULT_SAMPLE_RATE, - SUPPORTED_AUDIO_FORMATS, - SUPPORTED_CHANNELS, - SUPPORTED_SAMPLE_RATES, - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionCloseEvent, - BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, - BidiInputEvent, - BidiInterruptionEvent, - ModalityUsage, - BidiUsageEvent, - BidiOutputEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, -) - -__all__ = [ - "BidiInput", - "BidiOutput", - "BidiAgentInput", - # Input Events - "BidiTextInputEvent", - "BidiAudioInputEvent", - "BidiImageInputEvent", - "BidiInputEvent", - # Output Events - "BidiConnectionStartEvent", - "BidiConnectionCloseEvent", - "BidiResponseStartEvent", - "BidiResponseCompleteEvent", - "BidiAudioStreamEvent", - "BidiTranscriptStreamEvent", - "BidiInterruptionEvent", - "BidiUsageEvent", - "ModalityUsage", - "BidiErrorEvent", - "BidiOutputEvent", - # Constants - "SUPPORTED_AUDIO_FORMATS", - "SUPPORTED_SAMPLE_RATES", - "SUPPORTED_CHANNELS", - "DEFAULT_SAMPLE_RATE", - "DEFAULT_CHANNELS", - "DEFAULT_FORMAT", -] diff --git a/src/strands/experimental/bidirectional_streaming/types/agent.py b/src/strands/experimental/bidirectional_streaming/types/agent.py deleted file mode 100644 index 8d1e9aab7..000000000 --- a/src/strands/experimental/bidirectional_streaming/types/agent.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Agent-related type definitions for bidirectional streaming. - -This module defines the types used for BidiAgent. -""" - -from typing import TypeAlias - -from .events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent - -BidiAgentInput: TypeAlias = str | BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent diff --git a/src/strands/experimental/bidirectional_streaming/types/events.py b/src/strands/experimental/bidirectional_streaming/types/events.py deleted file mode 100644 index 852950f5a..000000000 --- a/src/strands/experimental/bidirectional_streaming/types/events.py +++ /dev/null @@ -1,521 +0,0 @@ -"""Bidirectional streaming types for real-time audio/text conversations. - -Type definitions for bidirectional streaming that extends Strands' existing streaming -capabilities with real-time audio and persistent connection support. - -Key features: -- Audio input/output events with standardized formats -- Interruption detection and handling -- Connection lifecycle management -- Provider-agnostic event types -- Type-safe discriminated unions with TypedEvent -- JSON-serializable events (audio/images stored as base64 strings) - -Audio format normalization: -- Supports PCM, WAV, Opus, and MP3 formats -- Standardizes sample rates (16kHz, 24kHz, 48kHz) -- Normalizes channel configurations (mono/stereo) -- Abstracts provider-specific encodings -- Audio data stored as base64-encoded strings for JSON compatibility -""" - -from typing import Any, Dict, List, Literal, Optional, Union, cast - -from ....types._events import ModelStreamEvent, TypedEvent -from ....types.streaming import ContentBlockDelta - -# Audio format constants -SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] -SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] -SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo -DEFAULT_SAMPLE_RATE = 16000 -DEFAULT_CHANNELS = 1 -DEFAULT_FORMAT = "pcm" - - -# ============================================================================ -# Input Events (sent via agent.send()) -# ============================================================================ - - -class BidiTextInputEvent(TypedEvent): - """Text input event for sending text to the model. - - Used for sending text content through the send() method. - - Parameters: - text: The text content to send to the model. - role: The role of the message sender (typically "user"). - """ - - def __init__(self, text: str, role: str): - super().__init__( - { - "type": "bidi_text_input", - "text": text, - "role": role, - } - ) - - @property - def text(self) -> str: - return cast(str, self.get("text")) - - @property - def role(self) -> str: - return cast(str, self.get("role")) - - -class BidiAudioInputEvent(TypedEvent): - """Audio input event for sending audio to the model. - - Used for sending audio data through the send() method. - - Parameters: - audio: Base64-encoded audio string to send to model. - format: Audio format from SUPPORTED_AUDIO_FORMATS. - sample_rate: Sample rate from SUPPORTED_SAMPLE_RATES. - channels: Channel count from SUPPORTED_CHANNELS. - """ - - def __init__( - self, - audio: str, - format: Literal["pcm", "wav", "opus", "mp3"], - sample_rate: Literal[16000, 24000, 48000], - channels: Literal[1, 2], - ): - super().__init__( - { - "type": "bidi_audio_input", - "audio": audio, - "format": format, - "sample_rate": sample_rate, - "channels": channels, - } - ) - - @property - def audio(self) -> str: - return cast(str, self.get("audio")) - - @property - def format(self) -> str: - return cast(str, self.get("format")) - - @property - def sample_rate(self) -> int: - return cast(int, self.get("sample_rate")) - - @property - def channels(self) -> int: - return cast(int, self.get("channels")) - - -class BidiImageInputEvent(TypedEvent): - """Image input event for sending images/video frames to the model. - - Used for sending image data through the send() method. - - Parameters: - image: Base64-encoded image string. - mime_type: MIME type (e.g., "image/jpeg", "image/png"). - """ - - def __init__( - self, - image: str, - mime_type: str, - ): - super().__init__( - { - "type": "bidi_image_input", - "image": image, - "mime_type": mime_type, - } - ) - - @property - def image(self) -> str: - return cast(str, self.get("image")) - - @property - def mime_type(self) -> str: - return cast(str, self.get("mime_type")) - - -# ============================================================================ -# Output Events (received via agent.receive()) -# ============================================================================ - - -class BidiConnectionStartEvent(TypedEvent): - """Streaming connection established and ready for interaction. - - Parameters: - connection_id: Unique identifier for this streaming connection. - model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). - """ - - def __init__(self, connection_id: str, model: str): - super().__init__( - { - "type": "bidi_connection_start", - "connection_id": connection_id, - "model": model, - } - ) - - @property - def connection_id(self) -> str: - return cast(str, self.get("connection_id")) - - @property - def model(self) -> str: - return cast(str, self.get("model")) - - -class BidiResponseStartEvent(TypedEvent): - """Model starts generating a response. - - Parameters: - response_id: Unique identifier for this response (used in response.complete). - """ - - def __init__(self, response_id: str): - super().__init__({"type": "bidi_response_start", "response_id": response_id}) - - @property - def response_id(self) -> str: - return cast(str, self.get("response_id")) - - -class BidiAudioStreamEvent(TypedEvent): - """Streaming audio output from the model. - - Parameters: - audio: Base64-encoded audio string. - format: Audio encoding format. - sample_rate: Number of audio samples per second in Hz. - channels: Number of audio channels (1=mono, 2=stereo). - """ - - def __init__( - self, - audio: str, - format: Literal["pcm", "wav", "opus", "mp3"], - sample_rate: Literal[16000, 24000, 48000], - channels: Literal[1, 2], - ): - super().__init__( - { - "type": "bidi_audio_stream", - "audio": audio, - "format": format, - "sample_rate": sample_rate, - "channels": channels, - } - ) - - @property - def audio(self) -> str: - return cast(str, self.get("audio")) - - @property - def format(self) -> str: - return cast(str, self.get("format")) - - @property - def sample_rate(self) -> int: - return cast(int, self.get("sample_rate")) - - @property - def channels(self) -> int: - return cast(int, self.get("channels")) - - -class BidiTranscriptStreamEvent(ModelStreamEvent): - """Audio transcription streaming (user or assistant speech). - - Supports incremental transcript updates for providers that send partial - transcripts before the final version. - - Parameters: - delta: The incremental transcript change (ContentBlockDelta). - text: The delta text (same as delta content for convenience). - role: Who is speaking ("user" or "assistant"). - is_final: Whether this is the final/complete transcript. - current_transcript: The accumulated transcript text so far (None for first delta). - """ - - def __init__( - self, - delta: ContentBlockDelta, - text: str, - role: Literal["user", "assistant"], - is_final: bool, - current_transcript: Optional[str] = None, - ): - super().__init__( - { - "type": "bidi_transcript_stream", - "delta": delta, - "text": text, - "role": role, - "is_final": is_final, - "current_transcript": current_transcript, - } - ) - - @property - def delta(self) -> ContentBlockDelta: - return cast(ContentBlockDelta, self.get("delta")) - - @property - def text(self) -> str: - return cast(str, self.get("text")) - - @property - def role(self) -> str: - return cast(str, self.get("role")) - - @property - def is_final(self) -> bool: - return cast(bool, self.get("is_final")) - - @property - def current_transcript(self) -> Optional[str]: - return cast(Optional[str], self.get("current_transcript")) - - -class BidiInterruptionEvent(TypedEvent): - """Model generation was interrupted. - - Parameters: - reason: Why the interruption occurred. - response_id: ID of the response that was interrupted (may be None). - """ - - def __init__(self, reason: Literal["user_speech", "error"]): - super().__init__( - { - "type": "bidi_interruption", - "reason": reason, - } - ) - - @property - def reason(self) -> str: - return cast(str, self.get("reason")) - - -class BidiResponseCompleteEvent(TypedEvent): - """Model finished generating response. - - Parameters: - response_id: ID of the response that completed (matches response.start). - stop_reason: Why the response ended. - """ - - def __init__( - self, - response_id: str, - stop_reason: Literal["complete", "interrupted", "tool_use", "error"], - ): - super().__init__( - { - "type": "bidi_response_complete", - "response_id": response_id, - "stop_reason": stop_reason, - } - ) - - @property - def response_id(self) -> str: - return cast(str, self.get("response_id")) - - @property - def stop_reason(self) -> str: - return cast(str, self.get("stop_reason")) - - -class ModalityUsage(dict): - """Token usage for a specific modality. - - Attributes: - modality: Type of content. - input_tokens: Tokens used for this modality's input. - output_tokens: Tokens used for this modality's output. - """ - - modality: Literal["text", "audio", "image", "cached"] - input_tokens: int - output_tokens: int - - -class BidiUsageEvent(TypedEvent): - """Token usage event with modality breakdown for bidirectional streaming. - - Tracks token consumption across different modalities (audio, text, images) - during bidirectional streaming sessions. - - Parameters: - input_tokens: Total tokens used for all input modalities. - output_tokens: Total tokens used for all output modalities. - total_tokens: Sum of input and output tokens. - modality_details: Optional list of token usage per modality. - cache_read_input_tokens: Optional tokens read from cache. - cache_write_input_tokens: Optional tokens written to cache. - """ - - def __init__( - self, - input_tokens: int, - output_tokens: int, - total_tokens: int, - modality_details: Optional[List[ModalityUsage]] = None, - cache_read_input_tokens: Optional[int] = None, - cache_write_input_tokens: Optional[int] = None, - ): - data: Dict[str, Any] = { - "type": "bidi_usage", - "inputTokens": input_tokens, - "outputTokens": output_tokens, - "totalTokens": total_tokens, - } - if modality_details is not None: - data["modality_details"] = modality_details - if cache_read_input_tokens is not None: - data["cacheReadInputTokens"] = cache_read_input_tokens - if cache_write_input_tokens is not None: - data["cacheWriteInputTokens"] = cache_write_input_tokens - super().__init__(data) - - @property - def input_tokens(self) -> int: - return cast(int, self.get("inputTokens")) - - @property - def output_tokens(self) -> int: - return cast(int, self.get("outputTokens")) - - @property - def total_tokens(self) -> int: - return cast(int, self.get("totalTokens")) - - @property - def modality_details(self) -> List[ModalityUsage]: - return cast(List[ModalityUsage], self.get("modality_details", [])) - - @property - def cache_read_input_tokens(self) -> Optional[int]: - return cast(Optional[int], self.get("cacheReadInputTokens")) - - @property - def cache_write_input_tokens(self) -> Optional[int]: - return cast(Optional[int], self.get("cacheWriteInputTokens")) - - -class BidiConnectionCloseEvent(TypedEvent): - """Streaming connection closed. - - Parameters: - connection_id: Unique identifier for this streaming connection (matches BidiConnectionStartEvent). - reason: Why the connection was closed. - """ - - def __init__( - self, - connection_id: str, - reason: Literal["client_disconnect", "timeout", "error", "complete"], - ): - super().__init__( - { - "type": "bidi_connection_close", - "connection_id": connection_id, - "reason": reason, - } - ) - - @property - def connection_id(self) -> str: - return cast(str, self.get("connection_id")) - - @property - def reason(self) -> str: - return cast(str, self.get("reason")) - - -class BidiErrorEvent(TypedEvent): - """Error occurred during the session. - - Stores the full Exception object as an instance attribute for debugging while - keeping the event dict JSON-serializable. The exception can be accessed via - the `error` property for re-raising or type-based error handling. - - Parameters: - error: The exception that occurred. - details: Optional additional error information. - """ - - def __init__( - self, - error: Exception, - details: Optional[Dict[str, Any]] = None, - ): - # Store serializable data in dict (for JSON serialization) - super().__init__( - { - "type": "bidi_error", - "message": str(error), - "code": type(error).__name__, - "details": details, - } - ) - # Store exception as instance attribute (not serialized) - self._error = error - - @property - def error(self) -> Exception: - """The original exception that occurred. - - Can be used for re-raising or type-based error handling. - """ - return self._error - - @property - def code(self) -> str: - """Error code derived from exception class name.""" - return cast(str, self.get("code")) - - @property - def message(self) -> str: - """Human-readable error message from the exception.""" - return cast(str, self.get("message")) - - @property - def details(self) -> Optional[Dict[str, Any]]: - """Additional error context beyond the exception itself.""" - return cast(Optional[Dict[str, Any]], self.get("details")) - - -# ============================================================================ -# Type Unions -# ============================================================================ - -# Note: ToolResultEvent is imported from strands.types._events and used alongside -# BidiInputEvent in send() methods for sending tool results back to the model. - -BidiInputEvent = BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent - -BidiOutputEvent = ( - BidiConnectionStartEvent - | BidiResponseStartEvent - | BidiAudioStreamEvent - | BidiTranscriptStreamEvent - | BidiInterruptionEvent - | BidiResponseCompleteEvent - | BidiUsageEvent - | BidiConnectionCloseEvent - | BidiErrorEvent -) diff --git a/src/strands/experimental/bidirectional_streaming/types/io.py b/src/strands/experimental/bidirectional_streaming/types/io.py deleted file mode 100644 index 8b79455ec..000000000 --- a/src/strands/experimental/bidirectional_streaming/types/io.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Protocol for bidirectional streaming IO channels. - -Defines callable protocols for input and output channels that can be used -with BidiAgent. This approach provides better typing and flexibility -by separating input and output concerns into independent callables. -""" - -from typing import Awaitable, Protocol - -from ..types.events import BidiInputEvent, BidiOutputEvent - - -class BidiInput(Protocol): - """Protocol for bidirectional input callables. - - Input callables read data from a source (microphone, camera, websocket, etc.) - and return events to be sent to the agent. - """ - - async def start(self) -> None: - """Start input.""" - ... - - async def stop(self) -> None: - """Stop input.""" - ... - - def __call__(self) -> Awaitable[BidiInputEvent]: - """Read input data from the source. - - Returns: - Awaitable that resolves to an input event (audio, text, image, etc.) - """ - ... - -class BidiOutput(Protocol): - """Protocol for bidirectional output callables. - - Output callables receive events from the agent and handle them appropriately - (play audio, display text, send over websocket, etc.). - """ - - async def start(self) -> None: - """Start output.""" - ... - - async def stop(self) -> None: - """Stop output.""" - ... - - def __call__(self, event: BidiOutputEvent) -> Awaitable[None]: - """Process output events from the agent. - - Args: - event: Output event from the agent (audio, text, tool calls, etc.) - """ - ... diff --git a/tests/strands/experimental/bidirectional_streaming/__init__.py b/tests/strands/experimental/bidirectional_streaming/__init__.py deleted file mode 100644 index ea37091cc..000000000 --- a/tests/strands/experimental/bidirectional_streaming/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Bidirectional streaming tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/__init__.py b/tests/strands/experimental/bidirectional_streaming/models/__init__.py deleted file mode 100644 index ea9fbb2d0..000000000 --- a/tests/strands/experimental/bidirectional_streaming/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Bidirectional streaming model tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py deleted file mode 100644 index 272314272..000000000 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ /dev/null @@ -1,487 +0,0 @@ -"""Unit tests for Gemini Live bidirectional streaming model. - -Tests the unified BidiGeminiLiveModel interface including: -- Model initialization and configuration -- Connection establishment and lifecycle -- Unified send() method with different content types -- Event receiving and conversion -""" - -import base64 -import json -import unittest.mock - -import pytest -from google import genai -from google.genai import types as genai_types - -from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel -from strands.experimental.bidirectional_streaming.types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionCloseEvent, - BidiConnectionStartEvent, - BidiImageInputEvent, - BidiInterruptionEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, -) -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolResult - - -@pytest.fixture -def mock_genai_client(): - """Mock the Google GenAI client.""" - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client") as mock_client_cls: - mock_client = mock_client_cls.return_value - mock_client.aio = unittest.mock.MagicMock() - - # Mock the live session - mock_live_session = unittest.mock.AsyncMock() - - # Mock the context manager - mock_live_session_cm = unittest.mock.MagicMock() - mock_live_session_cm.__aenter__ = unittest.mock.AsyncMock(return_value=mock_live_session) - mock_live_session_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) - - # Make connect return the context manager - mock_client.aio.live.connect = unittest.mock.MagicMock(return_value=mock_live_session_cm) - - yield mock_client, mock_live_session, mock_live_session_cm - - -@pytest.fixture -def model_id(): - return "models/gemini-2.0-flash-live-preview-04-09" - - -@pytest.fixture -def api_key(): - return "test-api-key" - - -@pytest.fixture -def model(mock_genai_client, model_id, api_key): - """Create a BidiGeminiLiveModel instance.""" - _ = mock_genai_client - return BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - - -@pytest.fixture -def tool_spec(): - return { - "description": "Calculate mathematical expressions", - "name": "calculator", - "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, - } - - -@pytest.fixture -def system_prompt(): - return "You are a helpful assistant" - - -@pytest.fixture -def messages(): - return [{"role": "user", "content": [{"text": "Hello"}]}] - - -# Initialization Tests - - -def test_model_initialization(mock_genai_client, model_id, api_key): - """Test model initialization with various configurations.""" - _ = mock_genai_client - - # Test default config - model_default = BidiGeminiLiveModel() - assert model_default.model_id == "gemini-2.5-flash-native-audio-preview-09-2025" - assert model_default.api_key is None - assert model_default._active is False - assert model_default.live_session is None - # Check default config includes transcription - assert model_default.live_config["response_modalities"] == ["AUDIO"] - assert "outputAudioTranscription" in model_default.live_config - assert "inputAudioTranscription" in model_default.live_config - - # Test with API key - model_with_key = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - assert model_with_key.model_id == model_id - assert model_with_key.api_key == api_key - - # Test with custom config (merges with defaults) - live_config = {"temperature": 0.7, "top_p": 0.9} - model_custom = BidiGeminiLiveModel(model_id=model_id, live_config=live_config) - # Custom config should be merged with defaults - assert model_custom.live_config["temperature"] == 0.7 - assert model_custom.live_config["top_p"] == 0.9 - # Defaults should still be present - assert "response_modalities" in model_custom.live_config - - -# Connection Tests - - -@pytest.mark.asyncio -async def test_connection_lifecycle(mock_genai_client, model, system_prompt, tool_spec, messages): - """Test complete connection lifecycle with various configurations.""" - mock_client, mock_live_session, mock_live_session_cm = mock_genai_client - - # Test basic connection - await model.start() - assert model._active is True - assert model.connection_id is not None - assert model.live_session == mock_live_session - mock_client.aio.live.connect.assert_called_once() - - # Test close - await model.stop() - assert model._active is False - mock_live_session_cm.__aexit__.assert_called_once() - - # Test connection with system prompt - await model.start(system_prompt=system_prompt) - call_args = mock_client.aio.live.connect.call_args - config = call_args.kwargs.get("config", {}) - assert config.get("system_instruction") == system_prompt - await model.stop() - - # Test connection with tools - await model.start(tools=[tool_spec]) - call_args = mock_client.aio.live.connect.call_args - config = call_args.kwargs.get("config", {}) - assert "tools" in config - assert len(config["tools"]) > 0 - await model.stop() - - # Test connection with messages - await model.start(messages=messages) - mock_live_session.send_client_content.assert_called() - await model.stop() - - -@pytest.mark.asyncio -async def test_connection_edge_cases(mock_genai_client, api_key, model_id): - """Test connection error handling and edge cases.""" - mock_client, _, mock_live_session_cm = mock_genai_client - - # Test connection error - model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - mock_client.aio.live.connect.side_effect = Exception("Connection failed") - with pytest.raises(Exception, match="Connection failed"): - await model1.start() - - # Reset mock for next tests - mock_client.aio.live.connect.side_effect = None - - # Test double connection - model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - await model2.start() - with pytest.raises(RuntimeError, match="Connection already active"): - await model2.start() - await model2.stop() - - # Test close when not connected - model3 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - await model3.stop() # Should not raise - - # Test close error handling - model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - await model4.start() - mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(Exception, match="Close failed"): - await model4.stop() - - -# Send Method Tests - - -@pytest.mark.asyncio -async def test_send_all_content_types(mock_genai_client, model): - """Test sending all content types through unified send() method.""" - _, mock_live_session, _ = mock_genai_client - await model.start() - - # Test text input - text_input = BidiTextInputEvent(text="Hello", role="user") - await model.send(text_input) - mock_live_session.send_client_content.assert_called_once() - call_args = mock_live_session.send_client_content.call_args - content = call_args.kwargs.get("turns") - assert content.role == "user" - assert content.parts[0].text == "Hello" - - # Test audio input (base64 encoded) - audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') - audio_input = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1, - ) - await model.send(audio_input) - mock_live_session.send_realtime_input.assert_called_once() - - # Test image input (base64 encoded, no encoding parameter) - image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') - image_input = BidiImageInputEvent( - image=image_b64, - mime_type="image/jpeg", - ) - await model.send(image_input) - mock_live_session.send.assert_called_once() - - # Test tool result - tool_result: ToolResult = { - "toolUseId": "tool-123", - "status": "success", - "content": [{"text": "Result: 42"}], - } - await model.send(ToolResultEvent(tool_result)) - mock_live_session.send_tool_response.assert_called_once() - - await model.stop() - - -@pytest.mark.asyncio -async def test_send_edge_cases(mock_genai_client, model): - """Test send() edge cases and error handling.""" - _, mock_live_session, _ = mock_genai_client - - # Test send when inactive - text_input = BidiTextInputEvent(text="Hello", role="user") - await model.send(text_input) - mock_live_session.send_client_content.assert_not_called() - - # Test unknown content type - await model.start() - unknown_content = {"unknown_field": "value"} - await model.send(unknown_content) # Should not raise, just log warning - - await model.stop() - - -# Receive Method Tests - - -@pytest.mark.asyncio -async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): - """Test that receive() emits connection start and end events.""" - _, mock_live_session, _ = mock_genai_client - mock_live_session.receive.return_value = agenerator([]) - - await model.start() - - # Collect events - events = [] - async for event in model.receive(): - events.append(event) - # Close after first event to trigger connection end - if len(events) == 1: - await model.stop() - - # Verify connection start and end - assert len(events) >= 2 - assert isinstance(events[0], BidiConnectionStartEvent) - assert events[0].get("type") == "bidi_connection_start" - assert events[0].connection_id == model.connection_id - assert isinstance(events[-1], BidiConnectionCloseEvent) - assert events[-1].get("type") == "bidi_connection_close" - - -@pytest.mark.asyncio -async def test_event_conversion(mock_genai_client, model): - """Test conversion of all Gemini Live event types to standard format.""" - _, _, _ = mock_genai_client - await model.start() - - # Test text output (converted to transcript via model_turn.parts) - mock_text = unittest.mock.Mock() - mock_text.data = None - mock_text.tool_call = None - - # Create proper server_content structure with model_turn - mock_server_content = unittest.mock.Mock() - mock_server_content.interrupted = False - mock_server_content.input_transcription = None - mock_server_content.output_transcription = None - - mock_model_turn = unittest.mock.Mock() - mock_part = unittest.mock.Mock() - mock_part.text = "Hello from Gemini" - mock_model_turn.parts = [mock_part] - mock_server_content.model_turn = mock_model_turn - - mock_text.server_content = mock_server_content - - text_events = model._convert_gemini_live_event(mock_text) - assert isinstance(text_events, list) - assert len(text_events) == 1 - text_event = text_events[0] - assert isinstance(text_event, BidiTranscriptStreamEvent) - assert text_event.get("type") == "bidi_transcript_stream" - assert text_event.text == "Hello from Gemini" - assert text_event.role == "assistant" - assert text_event.is_final is True - assert text_event.delta == {"text": "Hello from Gemini"} - assert text_event.current_transcript == "Hello from Gemini" - - # Test multiple text parts (should concatenate) - mock_multi_text = unittest.mock.Mock() - mock_multi_text.data = None - mock_multi_text.tool_call = None - - mock_server_content_multi = unittest.mock.Mock() - mock_server_content_multi.interrupted = False - mock_server_content_multi.input_transcription = None - mock_server_content_multi.output_transcription = None - - mock_model_turn_multi = unittest.mock.Mock() - mock_part1 = unittest.mock.Mock() - mock_part1.text = "Hello" - mock_part2 = unittest.mock.Mock() - mock_part2.text = "from Gemini" - mock_model_turn_multi.parts = [mock_part1, mock_part2] - mock_server_content_multi.model_turn = mock_model_turn_multi - - mock_multi_text.server_content = mock_server_content_multi - - multi_text_events = model._convert_gemini_live_event(mock_multi_text) - assert isinstance(multi_text_events, list) - assert len(multi_text_events) == 1 - multi_text_event = multi_text_events[0] - assert isinstance(multi_text_event, BidiTranscriptStreamEvent) - assert multi_text_event.text == "Hello from Gemini" # Concatenated with space - - # Test audio output (base64 encoded) - mock_audio = unittest.mock.Mock() - mock_audio.text = None - mock_audio.data = b"audio_data" - mock_audio.tool_call = None - mock_audio.server_content = None - - audio_events = model._convert_gemini_live_event(mock_audio) - assert isinstance(audio_events, list) - assert len(audio_events) == 1 - audio_event = audio_events[0] - assert isinstance(audio_event, BidiAudioStreamEvent) - assert audio_event.get("type") == "bidi_audio_stream" - # Audio is now base64 encoded - expected_b64 = base64.b64encode(b"audio_data").decode('utf-8') - assert audio_event.audio == expected_b64 - assert audio_event.format == "pcm" - - # Test single tool call (returns list with one event) - mock_func_call = unittest.mock.Mock() - mock_func_call.id = "tool-123" - mock_func_call.name = "calculator" - mock_func_call.args = {"expression": "2+2"} - - mock_tool_call = unittest.mock.Mock() - mock_tool_call.function_calls = [mock_func_call] - - mock_tool = unittest.mock.Mock() - mock_tool.text = None - mock_tool.data = None - mock_tool.tool_call = mock_tool_call - mock_tool.server_content = None - - tool_events = model._convert_gemini_live_event(mock_tool) - # Should return a list of ToolUseStreamEvent - assert isinstance(tool_events, list) - assert len(tool_events) == 1 - tool_event = tool_events[0] - # ToolUseStreamEvent has delta and current_tool_use, not a "type" field - assert "delta" in tool_event - assert "toolUse" in tool_event["delta"] - assert tool_event["delta"]["toolUse"]["toolUseId"] == "tool-123" - assert tool_event["delta"]["toolUse"]["name"] == "calculator" - - # Test multiple tool calls (returns list with multiple events) - mock_func_call_1 = unittest.mock.Mock() - mock_func_call_1.id = "tool-123" - mock_func_call_1.name = "calculator" - mock_func_call_1.args = {"expression": "2+2"} - - mock_func_call_2 = unittest.mock.Mock() - mock_func_call_2.id = "tool-456" - mock_func_call_2.name = "weather" - mock_func_call_2.args = {"location": "Seattle"} - - mock_tool_call_multi = unittest.mock.Mock() - mock_tool_call_multi.function_calls = [mock_func_call_1, mock_func_call_2] - - mock_tool_multi = unittest.mock.Mock() - mock_tool_multi.text = None - mock_tool_multi.data = None - mock_tool_multi.tool_call = mock_tool_call_multi - mock_tool_multi.server_content = None - - tool_events_multi = model._convert_gemini_live_event(mock_tool_multi) - # Should return a list with two ToolUseStreamEvent - assert isinstance(tool_events_multi, list) - assert len(tool_events_multi) == 2 - - # Verify first tool call - assert tool_events_multi[0]["delta"]["toolUse"]["toolUseId"] == "tool-123" - assert tool_events_multi[0]["delta"]["toolUse"]["name"] == "calculator" - assert tool_events_multi[0]["delta"]["toolUse"]["input"] == {"expression": "2+2"} - - # Verify second tool call - assert tool_events_multi[1]["delta"]["toolUse"]["toolUseId"] == "tool-456" - assert tool_events_multi[1]["delta"]["toolUse"]["name"] == "weather" - assert tool_events_multi[1]["delta"]["toolUse"]["input"] == {"location": "Seattle"} - - # Test interruption - mock_server_content = unittest.mock.Mock() - mock_server_content.interrupted = True - mock_server_content.input_transcription = None - mock_server_content.output_transcription = None - - mock_interrupt = unittest.mock.Mock() - mock_interrupt.text = None - mock_interrupt.data = None - mock_interrupt.tool_call = None - mock_interrupt.server_content = mock_server_content - - interrupt_events = model._convert_gemini_live_event(mock_interrupt) - assert isinstance(interrupt_events, list) - assert len(interrupt_events) == 1 - interrupt_event = interrupt_events[0] - assert isinstance(interrupt_event, BidiInterruptionEvent) - assert interrupt_event.get("type") == "bidi_interruption" - assert interrupt_event.reason == "user_speech" - - await model.stop() - - -# Helper Method Tests - - -def test_config_building(model, system_prompt, tool_spec): - """Test building live config with various options.""" - # Test basic config - config_basic = model._build_live_config() - assert isinstance(config_basic, dict) - - # Test with system prompt - config_prompt = model._build_live_config(system_prompt=system_prompt) - assert config_prompt["system_instruction"] == system_prompt - - # Test with tools - config_tools = model._build_live_config(tools=[tool_spec]) - assert "tools" in config_tools - assert len(config_tools["tools"]) > 0 - - -def test_tool_formatting(model, tool_spec): - """Test tool formatting for Gemini Live API.""" - # Test with tools - formatted_tools = model._format_tools_for_live_api([tool_spec]) - assert len(formatted_tools) == 1 - assert isinstance(formatted_tools[0], genai_types.Tool) - - # Test empty list - formatted_empty = model._format_tools_for_live_api([]) - assert formatted_empty == [] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py deleted file mode 100644 index c79e1d673..000000000 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ /dev/null @@ -1,458 +0,0 @@ -"""Unit tests for Nova Sonic bidirectional model implementation. - -Tests the unified BidirectionalModel interface implementation for Amazon Nova Sonic, -covering connection lifecycle, event conversion, audio streaming, and tool execution. -""" - -import asyncio -import base64 -import json -from unittest.mock import AsyncMock, patch - -import pytest -import pytest_asyncio - -from strands.experimental.bidirectional_streaming.models.novasonic import ( - BidiNovaSonicModel, -) -from strands.experimental.bidirectional_streaming.types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiImageInputEvent, - BidiInterruptionEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiUsageEvent, -) -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolResult - - -# Test fixtures -@pytest.fixture -def model_id(): - """Nova Sonic model identifier.""" - return "amazon.nova-sonic-v1:0" - - -@pytest.fixture -def region(): - """AWS region.""" - return "us-east-1" - - -@pytest.fixture -def mock_stream(): - """Mock Nova Sonic bidirectional stream.""" - stream = AsyncMock() - stream.input_stream = AsyncMock() - stream.input_stream.send = AsyncMock() - stream.input_stream.close = AsyncMock() - stream.await_output = AsyncMock() - return stream - - -@pytest.fixture -def mock_client(mock_stream): - """Mock Bedrock Runtime client.""" - client = AsyncMock() - client.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) - return client - - -@pytest_asyncio.fixture -async def nova_model(model_id, region): - """Create Nova Sonic model instance.""" - model = BidiNovaSonicModel(model_id=model_id, region=region) - yield model - # Cleanup - if model._active: - await model.stop() - - -# Initialization and Connection Tests - - -@pytest.mark.asyncio -async def test_model_initialization(model_id, region): - """Test model initialization with configuration.""" - model = BidiNovaSonicModel(model_id=model_id, region=region) - - assert model.model_id == model_id - assert model.region == region - assert model.stream is None - assert not model._active - assert model.connection_id is None - - -@pytest.mark.asyncio -async def test_connection_lifecycle(nova_model, mock_client, mock_stream): - """Test complete connection lifecycle with various configurations.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Test basic connection - await nova_model.start(system_prompt="Test system prompt") - assert nova_model._active - assert nova_model.stream == mock_stream - assert nova_model.connection_id is not None - assert mock_client.invoke_model_with_bidirectional_stream.called - - # Test close - await nova_model.stop() - assert not nova_model._active - assert mock_stream.input_stream.close.called - - # Test connection with tools - tools = [ - { - "name": "get_weather", - "description": "Get weather information", - "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} - } - ] - await nova_model.start(system_prompt="You are helpful", tools=tools) - # Verify initialization events were sent (connectionStart, promptStart, system prompt) - assert mock_stream.input_stream.send.call_count >= 3 - await nova_model.stop() - - -@pytest.mark.asyncio -async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): - """Test connection error handling and edge cases.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Test double connection - await nova_model.start() - with pytest.raises(RuntimeError, match="Connection already active"): - await nova_model.start() - await nova_model.stop() - - # Test close when already closed - model2 = BidiNovaSonicModel(model_id=model_id, region=region) - await model2.stop() # Should not raise - await model2.stop() # Second call should also be safe - - -# Send Method Tests - - -@pytest.mark.asyncio -async def test_send_all_content_types(nova_model, mock_client, mock_stream): - """Test sending all content types through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - await nova_model.start() - - # Test text content - text_event = BidiTextInputEvent(text="Hello, Nova!", role="user") - await nova_model.send(text_event) - # Should send contentStart, textInput, and contentEnd - assert mock_stream.input_stream.send.call_count >= 3 - - # Test audio content (base64 encoded) - audio_b64 = base64.b64encode(b"audio data").decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) - await nova_model.send(audio_event) - # Should start audio connection and send audio - assert nova_model.audio_connection_active - assert mock_stream.input_stream.send.called - - # Test tool result - tool_result: ToolResult = { - "toolUseId": "tool-123", - "status": "success", - "content": [{"text": "Weather is sunny"}] - } - await nova_model.send(ToolResultEvent(tool_result)) - # Should send contentStart, toolResult, and contentEnd - assert mock_stream.input_stream.send.called - - await nova_model.stop() - - -@pytest.mark.asyncio -async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): - """Test send() edge cases and error handling.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Test send when inactive - text_event = BidiTextInputEvent(text="Hello", role="user") - await nova_model.send(text_event) # Should not raise - - # Test image content (not supported, base64 encoded, no encoding parameter) - await nova_model.start() - image_b64 = base64.b64encode(b"image data").decode('utf-8') - image_event = BidiImageInputEvent( - image=image_b64, - mime_type="image/jpeg", - ) - await nova_model.send(image_event) - # Should log warning about unsupported image input - assert any("not supported" in record.message.lower() for record in caplog.records) - - await nova_model.stop() - - -# Receive and Event Conversion Tests - - -@pytest.mark.asyncio -async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): - """Test that receive() emits connection start and end events.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Setup mock to return no events and then stop - async def mock_wait_for(*args, **kwargs): - await asyncio.sleep(0.1) - nova_model._active = False - raise asyncio.TimeoutError() - - with patch("asyncio.wait_for", side_effect=mock_wait_for): - await nova_model.start() - - events = [] - async for event in nova_model.receive(): - events.append(event) - - # Should have session start and end (new TypedEvent format) - assert len(events) >= 2 - assert events[0].get("type") == "bidi_connection_start" - assert events[0].get("connection_id") == nova_model.connection_id - assert events[-1].get("type") == "bidi_connection_close" - - -@pytest.mark.asyncio -async def test_event_conversion(nova_model): - """Test conversion of all Nova Sonic event types to standard format.""" - # Test audio output (now returns BidiAudioStreamEvent) - audio_bytes = b"test audio data" - audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") - nova_event = {"audioOutput": {"content": audio_base64}} - result = nova_model._convert_nova_event(nova_event) - assert result is not None - assert isinstance(result, BidiAudioStreamEvent) - assert result.get("type") == "bidi_audio_stream" - # Audio is kept as base64 string - assert result.get("audio") == audio_base64 - assert result.get("format") == "pcm" - assert result.get("sample_rate") == 24000 - - # Test text output (now returns BidiTranscriptStreamEvent) - nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} - result = nova_model._convert_nova_event(nova_event) - assert result is not None - assert isinstance(result, BidiTranscriptStreamEvent) - assert result.get("type") == "bidi_transcript_stream" - assert result.get("text") == "Hello, world!" - assert result.get("role") == "assistant" - assert result.delta == {"text": "Hello, world!"} - assert result.current_transcript == "Hello, world!" - - # Test tool use (now returns ToolUseStreamEvent from core strands) - tool_input = {"location": "Seattle"} - nova_event = { - "toolUse": { - "toolUseId": "tool-123", - "toolName": "get_weather", - "content": json.dumps(tool_input) - } - } - result = nova_model._convert_nova_event(nova_event) - assert result is not None - # ToolUseStreamEvent has delta and current_tool_use, not a "type" field - assert "delta" in result - assert "toolUse" in result["delta"] - tool_use = result["delta"]["toolUse"] - assert tool_use["toolUseId"] == "tool-123" - assert tool_use["name"] == "get_weather" - assert tool_use["input"] == tool_input - - # Test interruption (now returns BidiInterruptionEvent) - nova_event = {"stopReason": "INTERRUPTED"} - result = nova_model._convert_nova_event(nova_event) - assert result is not None - assert isinstance(result, BidiInterruptionEvent) - assert result.get("type") == "bidi_interruption" - assert result.get("reason") == "user_speech" - - # Test usage metrics (now returns BidiUsageEvent) - nova_event = { - "usageEvent": { - "totalTokens": 100, - "totalInputTokens": 40, - "totalOutputTokens": 60, - "details": { - "total": { - "output": { - "speechTokens": 30 - } - } - } - } - } - result = nova_model._convert_nova_event(nova_event) - assert result is not None - assert isinstance(result, BidiUsageEvent) - assert result.get("type") == "bidi_usage" - assert result.get("totalTokens") == 100 - assert result.get("inputTokens") == 40 - assert result.get("outputTokens") == 60 - - # Test content start tracks role and emits BidiResponseStartEvent - nova_event = {"contentStart": {"role": "USER"}} - result = nova_model._convert_nova_event(nova_event) - assert result is not None - assert isinstance(result, BidiResponseStartEvent) - assert result.get("type") == "bidi_response_start" - assert nova_model._current_role == "USER" - - -# Audio Streaming Tests - - -@pytest.mark.asyncio -async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): - """Test audio connection start and end lifecycle.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - await nova_model.start() - - # Start audio connection - await nova_model._start_audio_connection() - assert nova_model.audio_connection_active - - # End audio connection - await nova_model._end_audio_input() - assert not nova_model.audio_connection_active - - await nova_model.stop() - - -@pytest.mark.asyncio -async def test_silence_detection(nova_model, mock_client, mock_stream): - """Test that silence detection automatically ends audio input.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - nova_model.silence_threshold = 0.1 # Short threshold for testing - - await nova_model.start() - - # Send audio to start connection (base64 encoded) - audio_b64 = base64.b64encode(b"audio data").decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) - - await nova_model.send(audio_event) - assert nova_model.audio_connection_active - - # Wait for silence detection - await asyncio.sleep(0.2) - - # Audio connection should be ended - assert not nova_model.audio_connection_active - - await nova_model.stop() - - -# Helper Method Tests - - -@pytest.mark.asyncio -async def test_tool_configuration(nova_model): - """Test building tool configuration from tool specs.""" - tools = [ - { - "name": "get_weather", - "description": "Get weather information", - "inputSchema": { - "json": json.dumps({ - "type": "object", - "properties": { - "location": {"type": "string"} - } - }) - } - } - ] - - tool_config = nova_model._build_tool_configuration(tools) - - assert len(tool_config) == 1 - assert tool_config[0]["toolSpec"]["name"] == "get_weather" - assert tool_config[0]["toolSpec"]["description"] == "Get weather information" - assert "inputSchema" in tool_config[0]["toolSpec"] - - -@pytest.mark.asyncio -async def test_event_templates(nova_model): - """Test event template generation.""" - # Test connection start event - event_json = nova_model._get_connection_start_event() - event = json.loads(event_json) - assert "event" in event - assert "sessionStart" in event["event"] - assert "inferenceConfiguration" in event["event"]["sessionStart"] - - # Test prompt start event - nova_model.connection_id = "test-connection" - event_json = nova_model._get_prompt_start_event([]) - event = json.loads(event_json) - assert "event" in event - assert "promptStart" in event["event"] - assert event["event"]["promptStart"]["promptName"] == "test-connection" - - # Test text input event - content_name = "test-content" - event_json = nova_model._get_text_input_event(content_name, "Hello") - event = json.loads(event_json) - assert "event" in event - assert "textInput" in event["event"] - assert event["event"]["textInput"]["content"] == "Hello" - - # Test tool result event - result = {"result": "Success"} - event_json = nova_model._get_tool_result_event(content_name, result) - event = json.loads(event_json) - assert "event" in event - assert "toolResult" in event["event"] - assert json.loads(event["event"]["toolResult"]["content"]) == result - - -# Error Handling Tests - - -@pytest.mark.asyncio -async def test_error_handling(nova_model, mock_client, mock_stream): - """Test error handling in various scenarios.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Test response processor handles errors gracefully - async def mock_error(*args, **kwargs): - raise Exception("Test error") - - mock_stream.await_output.side_effect = mock_error - - await nova_model.start() - - # Wait a bit for response processor to handle error - await asyncio.sleep(0.1) - - # Should still be able to close cleanly - await nova_model.stop() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py deleted file mode 100644 index 44fe20204..000000000 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ /dev/null @@ -1,538 +0,0 @@ -"""Unit tests for OpenAI Realtime bidirectional streaming model. - -Tests the unified BidiOpenAIRealtimeModel interface including: -- Model initialization and configuration -- Connection establishment with WebSocket -- Unified send() method with different content types -- Event receiving and conversion -- Connection lifecycle management -""" - -import asyncio -import base64 -import json -import unittest.mock - -import pytest - -from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel -from strands.experimental.bidirectional_streaming.types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiImageInputEvent, - BidiInterruptionEvent, - BidiResponseCompleteEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, -) -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolResult - - -@pytest.fixture -def mock_websocket(): - """Mock WebSocket connection.""" - mock_ws = unittest.mock.AsyncMock() - mock_ws.send = unittest.mock.AsyncMock() - mock_ws.close = unittest.mock.AsyncMock() - return mock_ws - - -@pytest.fixture -def mock_websockets_connect(mock_websocket): - """Mock websockets.connect function.""" - async def async_connect(*args, **kwargs): - return mock_websocket - - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect: - mock_connect.side_effect = async_connect - yield mock_connect, mock_websocket - - -@pytest.fixture -def model_name(): - return "gpt-realtime" - - -@pytest.fixture -def api_key(): - return "test-api-key" - - -@pytest.fixture -def model(api_key, model_name): - """Create an BidiOpenAIRealtimeModel instance.""" - return BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - - -@pytest.fixture -def tool_spec(): - return { - "description": "Calculate mathematical expressions", - "name": "calculator", - "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, - } - - -@pytest.fixture -def system_prompt(): - return "You are a helpful assistant" - - -@pytest.fixture -def messages(): - return [{"role": "user", "content": [{"text": "Hello"}]}] - - -# Initialization Tests - - -def test_model_initialization(api_key, model_name): - """Test model initialization with various configurations.""" - # Test default config - model_default = BidiOpenAIRealtimeModel(api_key="test-key") - assert model_default.model == "gpt-realtime" - assert model_default.api_key == "test-key" - assert model_default._active is False - assert model_default.websocket is None - - # Test with custom model - model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - assert model_custom.model == model_name - assert model_custom.api_key == api_key - - # Test with organization and project - model_org = BidiOpenAIRealtimeModel( - model=model_name, - api_key=api_key, - organization="org-123", - project="proj-456" - ) - assert model_org.organization == "org-123" - assert model_org.project == "proj-456" - - # Test with env API key - with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): - model_env = BidiOpenAIRealtimeModel() - assert model_env.api_key == "env-key" - - -def test_init_without_api_key_raises(): - """Test that initialization without API key raises error.""" - with unittest.mock.patch.dict("os.environ", {}, clear=True): - with pytest.raises(ValueError, match="OpenAI API key is required"): - BidiOpenAIRealtimeModel() - - -# Connection Tests - - -@pytest.mark.asyncio -async def test_connection_lifecycle(mock_websockets_connect, model, system_prompt, tool_spec, messages): - """Test complete connection lifecycle with various configurations.""" - mock_connect, mock_ws = mock_websockets_connect - - # Test basic connection - await model.start() - assert model._active is True - assert model.connection_id is not None - assert model.websocket == mock_ws - assert model._event_queue is not None - assert model._response_task is not None - mock_connect.assert_called_once() - - # Test close - await model.stop() - assert model._active is False - mock_ws.close.assert_called_once() - - # Test connection with system prompt - await model.start(system_prompt=system_prompt) - calls = mock_ws.send.call_args_list - session_update = next( - (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), - None - ) - assert session_update is not None - assert system_prompt in session_update["session"]["instructions"] - await model.stop() - - # Test connection with tools - await model.start(tools=[tool_spec]) - calls = mock_ws.send.call_args_list - # Tools are sent in a separate session.update after initial connection - session_updates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"] - assert len(session_updates) > 0 - # Check if any session update has tools - has_tools = any("tools" in update.get("session", {}) for update in session_updates) - assert has_tools - await model.stop() - - # Test connection with messages - await model.start(messages=messages) - calls = mock_ws.send.call_args_list - item_creates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create"] - assert len(item_creates) > 0 - await model.stop() - - # Test connection with organization header - model_org = BidiOpenAIRealtimeModel(api_key="test-key", organization="org-123") - await model_org.start() - call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("additional_headers", []) - org_header = [h for h in headers if h[0] == "OpenAI-Organization"] - assert len(org_header) == 1 - assert org_header[0][1] == "org-123" - await model_org.stop() - - -@pytest.mark.asyncio -async def test_connection_edge_cases(mock_websockets_connect, api_key, model_name): - """Test connection error handling and edge cases.""" - mock_connect, mock_ws = mock_websockets_connect - - # Test connection error - model1 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - mock_connect.side_effect = Exception("Connection failed") - with pytest.raises(Exception, match="Connection failed"): - await model1.start() - - # Reset mock - async def async_connect(*args, **kwargs): - return mock_ws - mock_connect.side_effect = async_connect - - # Test double connection - model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - await model2.start() - with pytest.raises(RuntimeError, match="Connection already active"): - await model2.start() - await model2.stop() - - # Test close when not connected - model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - await model3.stop() # Should not raise - - # Test close error handling (should not raise, just log) - model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - await model4.start() - mock_ws.close.side_effect = Exception("Close failed") - await model4.stop() # Should not raise - assert model4._active is False - - -# Send Method Tests - - -@pytest.mark.asyncio -async def test_send_all_content_types(mock_websockets_connect, model): - """Test sending all content types through unified send() method.""" - _, mock_ws = mock_websockets_connect - await model.start() - - # Test text input - text_input = BidiTextInputEvent(text="Hello", role="user") - await model.send(text_input) - calls = mock_ws.send.call_args_list - messages = [json.loads(call[0][0]) for call in calls] - item_create = [m for m in messages if m.get("type") == "conversation.item.create"] - response_create = [m for m in messages if m.get("type") == "response.create"] - assert len(item_create) > 0 - assert len(response_create) > 0 - - # Test audio input (base64 encoded) - audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') - audio_input = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=24000, - channels=1, - ) - await model.send(audio_input) - calls = mock_ws.send.call_args_list - messages = [json.loads(call[0][0]) for call in calls] - audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] - assert len(audio_append) > 0 - assert "audio" in audio_append[0] - # Audio should be passed through as base64 - assert audio_append[0]["audio"] == audio_b64 - - # Test tool result - tool_result: ToolResult = { - "toolUseId": "tool-123", - "status": "success", - "content": [{"text": "Result: 42"}], - } - await model.send(ToolResultEvent(tool_result)) - calls = mock_ws.send.call_args_list - messages = [json.loads(call[0][0]) for call in calls] - item_create = [m for m in messages if m.get("type") == "conversation.item.create"] - assert len(item_create) > 0 - item = item_create[-1].get("item", {}) - assert item.get("type") == "function_call_output" - assert item.get("call_id") == "tool-123" - - await model.stop() - - -@pytest.mark.asyncio -async def test_send_edge_cases(mock_websockets_connect, model): - """Test send() edge cases and error handling.""" - _, mock_ws = mock_websockets_connect - - # Test send when inactive - text_input = BidiTextInputEvent(text="Hello", role="user") - await model.send(text_input) - mock_ws.send.assert_not_called() - - # Test image input (not supported, base64 encoded, no encoding parameter) - await model.start() - image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') - image_input = BidiImageInputEvent( - image=image_b64, - mime_type="image/jpeg", - ) - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: - await model.send(image_input) - mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") - - # Test unknown content type - unknown_content = {"unknown_field": "value"} - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: - await model.send(unknown_content) - assert mock_logger.warning.called - - await model.stop() - - -# Receive Method Tests - - -@pytest.mark.asyncio -async def test_receive_lifecycle_events(mock_websockets_connect, model): - """Test that receive() emits connection start and end events.""" - _, _ = mock_websockets_connect - - await model.start() - - # Get first event - receive_gen = model.receive() - first_event = await anext(receive_gen) - - # First event should be connection start (new TypedEvent format) - assert first_event.get("type") == "bidi_connection_start" - assert first_event.get("connection_id") == model.connection_id - assert first_event.get("model") == model.model - - # Close to trigger session end - await model.stop() - - # Collect remaining events - events = [first_event] - try: - async for event in receive_gen: - events.append(event) - except StopAsyncIteration: - pass - - # Last event should be connection close (new TypedEvent format) - assert events[-1].get("type") == "bidi_connection_close" - - -@pytest.mark.asyncio -async def test_event_conversion(mock_websockets_connect, model): - """Test conversion of all OpenAI event types to standard format.""" - _, _ = mock_websockets_connect - await model.start() - - # Test audio output (now returns list with BidiAudioStreamEvent) - audio_event = { - "type": "response.output_audio.delta", - "delta": base64.b64encode(b"audio_data").decode() - } - converted = model._convert_openai_event(audio_event) - assert isinstance(converted, list) - assert len(converted) == 1 - assert isinstance(converted[0], BidiAudioStreamEvent) - assert converted[0].get("type") == "bidi_audio_stream" - assert converted[0].get("audio") == base64.b64encode(b"audio_data").decode() - assert converted[0].get("format") == "pcm" - - # Test text output (now returns list with BidiTranscriptStreamEvent) - text_event = { - "type": "response.output_text.delta", - "delta": "Hello from OpenAI" - } - converted = model._convert_openai_event(text_event) - assert isinstance(converted, list) - assert len(converted) == 1 - assert isinstance(converted[0], BidiTranscriptStreamEvent) - assert converted[0].get("type") == "bidi_transcript_stream" - assert converted[0].get("text") == "Hello from OpenAI" - assert converted[0].get("role") == "assistant" - assert converted[0].delta == {"text": "Hello from OpenAI"} - assert converted[0].is_final is True - - # Test function call sequence - item_added = { - "type": "response.output_item.added", - "item": { - "type": "function_call", - "call_id": "call-123", - "name": "calculator" - } - } - model._convert_openai_event(item_added) - - args_delta = { - "type": "response.function_call_arguments.delta", - "call_id": "call-123", - "delta": '{"expression": "2+2"}' - } - model._convert_openai_event(args_delta) - - args_done = { - "type": "response.function_call_arguments.done", - "call_id": "call-123" - } - converted = model._convert_openai_event(args_done) - # Now returns list with ToolUseStreamEvent - assert isinstance(converted, list) - assert len(converted) == 1 - # ToolUseStreamEvent has delta and current_tool_use, not a "type" field - assert "delta" in converted[0] - assert "toolUse" in converted[0]["delta"] - tool_use = converted[0]["delta"]["toolUse"] - assert tool_use["toolUseId"] == "call-123" - assert tool_use["name"] == "calculator" - assert tool_use["input"]["expression"] == "2+2" - - # Test voice activity (now returns list with BidiInterruptionEvent for speech_started) - speech_started = { - "type": "input_audio_buffer.speech_started" - } - converted = model._convert_openai_event(speech_started) - assert isinstance(converted, list) - assert len(converted) == 1 - assert isinstance(converted[0], BidiInterruptionEvent) - assert converted[0].get("type") == "bidi_interruption" - assert converted[0].get("reason") == "user_speech" - - # Test response.cancelled event (should return ResponseCompleteEvent with interrupted reason) - response_cancelled = { - "type": "response.cancelled", - "response": { - "id": "resp_123" - } - } - converted = model._convert_openai_event(response_cancelled) - assert isinstance(converted, list) - assert len(converted) == 1 - assert isinstance(converted[0], BidiResponseCompleteEvent) - assert converted[0].get("type") == "bidi_response_complete" - assert converted[0].get("response_id") == "resp_123" - assert converted[0].get("stop_reason") == "interrupted" - - # Test error handling - response_cancel_not_active should be suppressed - error_cancel_not_active = { - "type": "error", - "error": { - "code": "response_cancel_not_active", - "message": "No active response to cancel" - } - } - converted = model._convert_openai_event(error_cancel_not_active) - assert converted is None # Should be suppressed - - # Test error handling - other errors should be logged but return None - error_other = { - "type": "error", - "error": { - "code": "some_other_error", - "message": "Something went wrong" - } - } - converted = model._convert_openai_event(error_other) - assert converted is None - - await model.stop() - - -# Helper Method Tests - - -def test_config_building(model, system_prompt, tool_spec): - """Test building session config with various options.""" - # Test basic config - config_basic = model._build_session_config(None, None) - assert isinstance(config_basic, dict) - assert "instructions" in config_basic - assert "audio" in config_basic - - # Test with system prompt - config_prompt = model._build_session_config(system_prompt, None) - assert config_prompt["instructions"] == system_prompt - - # Test with tools - config_tools = model._build_session_config(None, [tool_spec]) - assert "tools" in config_tools - assert len(config_tools["tools"]) > 0 - - -def test_tool_conversion(model, tool_spec): - """Test tool conversion to OpenAI format.""" - # Test with tools - openai_tools = model._convert_tools_to_openai_format([tool_spec]) - assert len(openai_tools) == 1 - assert openai_tools[0]["type"] == "function" - assert openai_tools[0]["name"] == "calculator" - assert openai_tools[0]["description"] == "Calculate mathematical expressions" - - # Test empty list - openai_empty = model._convert_tools_to_openai_format([]) - assert openai_empty == [] - - -def test_helper_methods(model): - """Test various helper methods.""" - # Test _require_active - assert model._require_active() is False - model._active = True - assert model._require_active() is True - model._active = False - - # Test _create_text_event (now returns BidiTranscriptStreamEvent) - text_event = model._create_text_event("Hello", "user") - assert isinstance(text_event, BidiTranscriptStreamEvent) - assert text_event.get("type") == "bidi_transcript_stream" - assert text_event.get("text") == "Hello" - assert text_event.get("role") == "user" - assert text_event.delta == {"text": "Hello"} - assert text_event.is_final is True - assert text_event.current_transcript == "Hello" - - # Test _create_voice_activity_event (now returns BidiInterruptionEvent for speech_started) - voice_event = model._create_voice_activity_event("speech_started") - assert isinstance(voice_event, BidiInterruptionEvent) - assert voice_event.get("type") == "bidi_interruption" - assert voice_event.get("reason") == "user_speech" - - # Other voice activities return None - assert model._create_voice_activity_event("speech_stopped") is None - - -@pytest.mark.asyncio -async def test_send_event_helper(mock_websockets_connect, model): - """Test _send_event helper method.""" - _, mock_ws = mock_websockets_connect - await model.start() - - test_event = {"type": "test.event", "data": "test"} - await model._send_event(test_event) - - calls = mock_ws.send.call_args_list - last_call = calls[-1] - sent_message = json.loads(last_call[0][0]) - assert sent_message == test_event - - await model.stop() diff --git a/tests/strands/experimental/bidirectional_streaming/types/__init__.py b/tests/strands/experimental/bidirectional_streaming/types/__init__.py deleted file mode 100644 index a1330e552..000000000 --- a/tests/strands/experimental/bidirectional_streaming/types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for bidirectional streaming types.""" diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_events.py b/tests/strands/experimental/bidirectional_streaming/types/test_events.py deleted file mode 100644 index bc7ec4844..000000000 --- a/tests/strands/experimental/bidirectional_streaming/types/test_events.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Tests for bidirectional streaming event types. - -This module tests JSON serialization for all bidirectional streaming event types. -""" - -import base64 -import json - -import pytest - -from strands.experimental.bidirectional_streaming.types.events import ( - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionCloseEvent, - BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, - BidiInterruptionEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiUsageEvent, -) - - -@pytest.mark.parametrize( - "event_class,kwargs,expected_type", - [ - # Input events - (BidiTextInputEvent, {"text": "Hello", "role": "user"}, "bidi_text_input"), - ( - BidiAudioInputEvent, - { - "audio": base64.b64encode(b"audio").decode("utf-8"), - "format": "pcm", - "sample_rate": 16000, - "channels": 1, - }, - "bidi_audio_input", - ), - ( - BidiImageInputEvent, - {"image": base64.b64encode(b"image").decode("utf-8"), "mime_type": "image/jpeg"}, - "bidi_image_input", - ), - # Output events - ( - BidiConnectionStartEvent, - {"connection_id": "c1", "model": "m1"}, - "bidi_connection_start", - ), - (BidiResponseStartEvent, {"response_id": "r1"}, "bidi_response_start"), - ( - BidiAudioStreamEvent, - { - "audio": base64.b64encode(b"audio").decode("utf-8"), - "format": "pcm", - "sample_rate": 24000, - "channels": 1, - }, - "bidi_audio_stream", - ), - ( - BidiTranscriptStreamEvent, - { - "delta": {"text": "Hello"}, - "text": "Hello", - "role": "assistant", - "is_final": True, - "current_transcript": "Hello", - }, - "bidi_transcript_stream", - ), - (BidiInterruptionEvent, {"reason": "user_speech"}, "bidi_interruption"), - ( - BidiResponseCompleteEvent, - {"response_id": "r1", "stop_reason": "complete"}, - "bidi_response_complete", - ), - ( - BidiUsageEvent, - {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - "bidi_usage", - ), - ( - BidiConnectionCloseEvent, - {"connection_id": "c1", "reason": "complete"}, - "bidi_connection_close", - ), - (BidiErrorEvent, {"error": ValueError("test"), "details": None}, "bidi_error"), - ], -) -def test_event_json_serialization(event_class, kwargs, expected_type): - """Test that all event types are JSON serializable and deserializable.""" - # Create event - event = event_class(**kwargs) - - # Verify type field - assert event["type"] == expected_type - - # Serialize to JSON - json_str = json.dumps(event) - print("event_class:", event_class) - print(json_str) - # Deserialize back - data = json.loads(json_str) - - # Verify type preserved - assert data["type"] == expected_type - - # Verify all non-private keys preserved - for key in event.keys(): - if not key.startswith("_"): - assert key in data - - - -def test_transcript_stream_event_delta_pattern(): - """Test that BidiTranscriptStreamEvent follows ModelStreamEvent delta pattern.""" - # Test partial transcript (delta) - partial_event = BidiTranscriptStreamEvent( - delta={"text": "Hello"}, - text="Hello", - role="user", - is_final=False, - current_transcript=None, - ) - - assert partial_event.text == "Hello" - assert partial_event.role == "user" - assert partial_event.is_final is False - assert partial_event.current_transcript is None - assert partial_event.delta == {"text": "Hello"} - - # Test final transcript with accumulated text - final_event = BidiTranscriptStreamEvent( - delta={"text": " world"}, - text=" world", - role="user", - is_final=True, - current_transcript="Hello world", - ) - - assert final_event.text == " world" - assert final_event.role == "user" - assert final_event.is_final is True - assert final_event.current_transcript == "Hello world" - assert final_event.delta == {"text": " world"} - - -def test_transcript_stream_event_extends_model_stream_event(): - """Test that BidiTranscriptStreamEvent is a ModelStreamEvent.""" - from strands.types._events import ModelStreamEvent - - event = BidiTranscriptStreamEvent( - delta={"text": "test"}, - text="test", - role="assistant", - is_final=True, - current_transcript="test", - ) - - assert isinstance(event, ModelStreamEvent) diff --git a/tests_integ/bidirectional_streaming/__init__.py b/tests_integ/bidirectional_streaming/__init__.py deleted file mode 100644 index 05da9afcb..000000000 --- a/tests_integ/bidirectional_streaming/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Integration tests for bidirectional streaming agents.""" diff --git a/tests_integ/bidirectional_streaming/conftest.py b/tests_integ/bidirectional_streaming/conftest.py deleted file mode 100644 index 0d453818a..000000000 --- a/tests_integ/bidirectional_streaming/conftest.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Pytest fixtures for bidirectional streaming integration tests.""" - -import logging - -import pytest - -from .generators.audio import AudioGenerator - -logger = logging.getLogger(__name__) - - -@pytest.fixture(scope="session") -def audio_generator(): - """Provide AudioGenerator instance for tests.""" - return AudioGenerator(region="us-east-1") - - -@pytest.fixture(autouse=True) -def setup_logging(): - """Configure logging for tests.""" - logging.basicConfig( - level=logging.DEBUG, - format="%(levelname)s | %(name)s | %(message)s", - ) - # Reduce noise from some loggers - logging.getLogger("boto3").setLevel(logging.WARNING) - logging.getLogger("botocore").setLevel(logging.WARNING) - logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/tests_integ/bidirectional_streaming/context.py b/tests_integ/bidirectional_streaming/context.py deleted file mode 100644 index 349ad0cb9..000000000 --- a/tests_integ/bidirectional_streaming/context.py +++ /dev/null @@ -1,365 +0,0 @@ -"""Test context manager for bidirectional streaming tests. - -Provides a high-level interface for testing bidirectional streaming agents -with continuous background threads that mimic real-world usage patterns. -""" - -import asyncio -import base64 -import logging -import time -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent - from .generators.audio import AudioGenerator - -logger = logging.getLogger(__name__) - -# Constants for timing and buffering -QUEUE_POLL_TIMEOUT = 0.05 # 50ms - balance between responsiveness and CPU usage -SILENCE_INTERVAL = 0.05 # 50ms - send silence every 50ms when queue empty -AUDIO_CHUNK_DELAY = 0.01 # 10ms - small delay between audio chunks -WAIT_POLL_INTERVAL = 0.1 # 100ms - how often to check for response completion - - -class BidirectionalTestContext: - """Manages threads and generators for bidirectional streaming tests. - - Mimics real-world usage with continuous background threads: - - Audio input thread (microphone simulation with silence padding) - - Event collection thread (captures all model outputs) - - Generators feed data into threads via queues for natural conversation flow. - - Example: - async with BidirectionalTestContext(agent, audio_generator) as ctx: - await ctx.say("What is 5 plus 3?") - await ctx.wait_for_response() - assert "8" in " ".join(ctx.get_text_outputs()) - """ - - def __init__( - self, - agent: "BidiAgent", - audio_generator: "AudioGenerator | None" = None, - silence_chunk_size: int = 1024, - audio_chunk_size: int = 1024, - ): - """Initialize test context. - - Args: - agent: BidiAgent instance. - audio_generator: AudioGenerator for text-to-speech. - silence_chunk_size: Size of silence chunks in bytes. - audio_chunk_size: Size of audio chunks for streaming. - """ - self.agent = agent - self.audio_generator = audio_generator - self.silence_chunk_size = silence_chunk_size - self.audio_chunk_size = audio_chunk_size - - # Queue for thread communication - self.input_queue = asyncio.Queue() # Handles both audio and text input - - # Event storage (thread-safe) - self._event_queue = asyncio.Queue() # Events from collection thread - self.events = [] # Cached events for test access - self.last_event_time = None - - # Control flags - self.active = False - self.threads = [] - - async def __aenter__(self): - """Start context manager, agent session, and background threads.""" - # Start agent session - await self.agent.start() - logger.debug("Agent session started") - - await self.start() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Stop context manager, cleanup threads, and end agent session.""" - # End agent session FIRST - this will cause receive() to exit cleanly - if self.agent._agent_loop and self.agent._agent_loop.active: - await self.agent.stop() - logger.debug("Agent session stopped") - - # Then stop the context threads - await self.stop() - - return False - - async def start(self): - """Start all background threads.""" - self.active = True - self.last_event_time = time.monotonic() - - self.threads = [ - asyncio.create_task(self._input_thread()), - asyncio.create_task(self._event_collection_thread()), - ] - - logger.debug("Test context started with %d threads", len(self.threads)) - - async def stop(self): - """Stop all threads gracefully.""" - if not self.active: - logger.debug("stop() called but already stopped") - return - - logger.debug("stop() called - stopping threads") - self.active = False - - # Cancel all threads - for task in self.threads: - if not task.done(): - task.cancel() - - # Wait for cancellation - await asyncio.gather(*self.threads, return_exceptions=True) - - logger.debug("Test context stopped") - - # === User-facing methods === - - async def say(self, text: str): - """Convert text to audio and queue audio chunks to be sent to model. - - Args: - text: Text to convert to speech and send as audio. - - Raises: - ValueError: If audio generator is not available. - """ - if not self.audio_generator: - raise ValueError( - "Audio generator not available. Pass audio_generator to BidirectionalTestContext." - ) - - # Generate audio via Polly - audio_data = await self.audio_generator.generate_audio(text) - - # Split into chunks and queue each chunk - for i in range(0, len(audio_data), self.audio_chunk_size): - chunk = audio_data[i : i + self.audio_chunk_size] - chunk_event = self.audio_generator.create_audio_input_event(chunk) - await self.input_queue.put({"type": "audio_chunk", "data": chunk_event}) - - logger.debug(f"Queued {len(audio_data)} bytes of audio for: {text[:50]}...") - - async def send(self, data: str | dict) -> None: - """Send data directly to model (text, image, etc.). - - Args: - data: Data to send to model. Can be: - - str: Text input - - dict: Custom event (e.g., image, audio) - """ - await self.input_queue.put({"type": "direct", "data": data}) - logger.debug(f"Queued direct send: {type(data).__name__}") - - async def wait_for_response( - self, - timeout: float = 15.0, - silence_threshold: float = 2.0, - min_events: int = 1, - ): - """Wait for model to finish responding. - - Uses silence detection (no events for silence_threshold seconds) - combined with minimum event count to determine response completion. - - Args: - timeout: Maximum time to wait in seconds. - silence_threshold: Seconds of silence to consider response complete. - min_events: Minimum events before silence detection activates. - """ - start_time = time.monotonic() - initial_event_count = len(self.get_events()) # Drain queue - - while time.monotonic() - start_time < timeout: - # Drain queue to get latest events - current_events = self.get_events() - - # Check if we have minimum events - if len(current_events) - initial_event_count >= min_events: - # Check silence - elapsed_since_event = time.monotonic() - self.last_event_time - if elapsed_since_event >= silence_threshold: - logger.debug( - f"Response complete: {len(current_events) - initial_event_count} events, " - f"{elapsed_since_event:.1f}s silence" - ) - return - - await asyncio.sleep(WAIT_POLL_INTERVAL) - - logger.warning(f"Response timeout after {timeout}s") - - def get_events(self, event_type: str | None = None) -> list[dict]: - """Get collected events, optionally filtered by type. - - Drains the event queue and caches events for subsequent calls. - - Args: - event_type: Optional event type to filter by (e.g., "textOutput"). - - Returns: - List of events, filtered if event_type specified. - """ - # Drain queue into cache (non-blocking) - while not self._event_queue.empty(): - try: - event = self._event_queue.get_nowait() - self.events.append(event) - self.last_event_time = time.monotonic() - except asyncio.QueueEmpty: - break - - if event_type: - return [e for e in self.events if event_type in e] - return self.events.copy() - - def get_text_outputs(self) -> list[str]: - """Extract text outputs from collected events. - - Handles both new TypedEvent format and legacy event formats. - - Returns: - List of text content strings. - """ - texts = [] - for event in self.get_events(): # Drain queue first - # Handle new TypedEvent format (bidi_transcript_stream) - if event.get("type") == "bidi_transcript_stream": - text = event.get("text", "") - if text: - texts.append(text) - # Handle legacy textOutput events (Nova Sonic, OpenAI) - elif "textOutput" in event: - text = event["textOutput"].get("text", "") - if text: - texts.append(text) - # Handle legacy transcript events (Gemini Live) - elif "transcript" in event: - text = event["transcript"].get("text", "") - if text: - texts.append(text) - return texts - - def get_audio_outputs(self) -> list[bytes]: - """Extract audio outputs from collected events. - - Returns: - List of audio data bytes. - """ - # Drain queue first to get latest events - events = self.get_events() - audio_data = [] - for event in events: - # Handle new TypedEvent format (bidi_audio_stream) - if event.get("type") == "bidi_audio_stream": - audio_b64 = event.get("audio") - if audio_b64: - # Decode base64 to bytes - audio_data.append(base64.b64decode(audio_b64)) - # Handle legacy audioOutput events - elif "audioOutput" in event: - data = event["audioOutput"].get("audioData") - if data: - audio_data.append(data) - return audio_data - - def get_tool_uses(self) -> list[dict]: - """Extract tool use events from collected events. - - Returns: - List of tool use events. - """ - # Drain queue first to get latest events - events = self.get_events() - return [event["toolUse"] for event in events if "toolUse" in event] - - def has_interruption(self) -> bool: - """Check if any interruption was detected. - - Returns: - True if interruption detected in events. - """ - return any("interruptionDetected" in event for event in self.events) - - def clear_events(self): - """Clear collected events (useful for multi-turn tests).""" - self.events.clear() - logger.debug("Events cleared") - - # === Background threads === - - async def _input_thread(self): - """Continuously handle input to model. - - - Sends queued audio chunks immediately - - Sends silence chunks periodically when queue is empty (simulates microphone) - - Sends direct data to model - """ - try: - logger.debug(f"Input thread starting, active={self.active}") - while self.active: - try: - # Check for queued input (non-blocking with short timeout) - input_item = await asyncio.wait_for(self.input_queue.get(), timeout=QUEUE_POLL_TIMEOUT) - - if input_item["type"] == "audio_chunk": - # Send pre-generated audio chunk - await self.agent.send(input_item["data"]) - await asyncio.sleep(AUDIO_CHUNK_DELAY) - - elif input_item["type"] == "direct": - # Send data directly to agent - await self.agent.send(input_item["data"]) - data_repr = str(input_item["data"])[:50] if isinstance(input_item["data"], str) else type(input_item["data"]).__name__ - logger.debug(f"Sent direct: {data_repr}") - - except asyncio.TimeoutError: - # No input queued - send silence chunk to simulate continuous microphone input - if self.audio_generator: - silence = self._generate_silence_chunk() - await self.agent.send(silence) - await asyncio.sleep(SILENCE_INTERVAL) - - except asyncio.CancelledError: - logger.debug("Input thread cancelled") - raise # Re-raise to properly propagate cancellation - except Exception as e: - logger.error(f"Input thread error: {e}", exc_info=True) - finally: - logger.debug(f"Input thread stopped, active={self.active}") - - async def _event_collection_thread(self): - """Continuously collect events from model.""" - try: - async for event in self.agent.receive(): - if not self.active: - break - - # Thread-safe: put in queue instead of direct append - await self._event_queue.put(event) - logger.debug(f"Event collected: {list(event.keys())}") - - except asyncio.CancelledError: - logger.debug("Event collection thread cancelled") - raise # Re-raise to properly propagate cancellation - except Exception as e: - logger.error(f"Event collection thread error: {e}") - - def _generate_silence_chunk(self) -> dict: - """Generate silence chunk for background audio. - - Returns: - BidiAudioInputEvent with silence data. - """ - silence = b"\x00" * self.silence_chunk_size - return self.audio_generator.create_audio_input_event(silence) diff --git a/tests_integ/bidirectional_streaming/generators/__init__.py b/tests_integ/bidirectional_streaming/generators/__init__.py deleted file mode 100644 index 1f13f0564..000000000 --- a/tests_integ/bidirectional_streaming/generators/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Test data generators for bidirectional streaming integration tests.""" diff --git a/tests_integ/bidirectional_streaming/generators/audio.py b/tests_integ/bidirectional_streaming/generators/audio.py deleted file mode 100644 index 75c17a1e3..000000000 --- a/tests_integ/bidirectional_streaming/generators/audio.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Audio generation utilities using Amazon Polly for test audio input. - -Provides text-to-speech conversion for generating realistic audio test data -without requiring physical audio devices or pre-recorded files. -""" - -import base64 -import hashlib -import logging -from pathlib import Path -from typing import Literal - -import boto3 - -logger = logging.getLogger(__name__) - -# Audio format constants matching Nova Sonic requirements -NOVA_SONIC_SAMPLE_RATE = 16000 -NOVA_SONIC_CHANNELS = 1 -NOVA_SONIC_FORMAT = "pcm" - -# Polly configuration -POLLY_VOICE_ID = "Matthew" # US English male voice -POLLY_ENGINE = "neural" # Higher quality neural engine - -# Cache directory for generated audio -CACHE_DIR = Path(__file__).parent.parent / ".audio_cache" - - -class AudioGenerator: - """Generate test audio using Amazon Polly with caching.""" - - def __init__(self, region: str = "us-east-1"): - """Initialize audio generator with Polly client. - - Args: - region: AWS region for Polly service. - """ - self.polly_client = boto3.client("polly", region_name=region) - self._ensure_cache_dir() - - def _ensure_cache_dir(self) -> None: - """Create cache directory if it doesn't exist.""" - CACHE_DIR.mkdir(parents=True, exist_ok=True) - - def _get_cache_key(self, text: str, voice_id: str) -> str: - """Generate cache key from text and voice.""" - content = f"{text}:{voice_id}".encode("utf-8") - return hashlib.md5(content).hexdigest() - - def _get_cache_path(self, cache_key: str) -> Path: - """Get cache file path for given key.""" - return CACHE_DIR / f"{cache_key}.pcm" - - async def generate_audio( - self, - text: str, - voice_id: str = POLLY_VOICE_ID, - use_cache: bool = True, - ) -> bytes: - """Generate audio from text using Polly with caching. - - Args: - text: Text to convert to speech. - voice_id: Polly voice ID to use. - use_cache: Whether to use cached audio if available. - - Returns: - Raw PCM audio bytes at 16kHz mono (Nova Sonic format). - """ - # Check cache first - if use_cache: - cache_key = self._get_cache_key(text, voice_id) - cache_path = self._get_cache_path(cache_key) - - if cache_path.exists(): - logger.debug(f"Using cached audio for: {text[:50]}...") - return cache_path.read_bytes() - - # Generate audio with Polly - logger.debug(f"Generating audio with Polly: {text[:50]}...") - - try: - response = self.polly_client.synthesize_speech( - Text=text, - OutputFormat="pcm", # Raw PCM format - VoiceId=voice_id, - Engine=POLLY_ENGINE, - SampleRate=str(NOVA_SONIC_SAMPLE_RATE), - ) - - # Read audio data - audio_data = response["AudioStream"].read() - - # Cache for future use - if use_cache: - cache_path.write_bytes(audio_data) - logger.debug(f"Cached audio: {cache_path}") - - return audio_data - - except Exception as e: - logger.error(f"Polly audio generation failed: {e}") - raise - - def create_audio_input_event( - self, - audio_data: bytes, - format: Literal["pcm", "wav", "opus", "mp3"] = NOVA_SONIC_FORMAT, - sample_rate: int = NOVA_SONIC_SAMPLE_RATE, - channels: int = NOVA_SONIC_CHANNELS, - ) -> dict: - """Create BidiAudioInputEvent from raw audio data. - - Args: - audio_data: Raw audio bytes. - format: Audio format. - sample_rate: Sample rate in Hz. - channels: Number of audio channels. - - Returns: - BidiAudioInputEvent dict ready for agent.send(). - """ - # Convert bytes to base64 string for JSON compatibility - audio_b64 = base64.b64encode(audio_data).decode('utf-8') - - return { - "type": "bidi_audio_input", - "audio": audio_b64, - "format": format, - "sample_rate": sample_rate, - "channels": channels, - } - - def clear_cache(self) -> None: - """Clear all cached audio files.""" - if CACHE_DIR.exists(): - for cache_file in CACHE_DIR.glob("*.pcm"): - cache_file.unlink() - logger.info("Audio cache cleared") - - -# Convenience function for quick audio generation -async def generate_test_audio(text: str, use_cache: bool = True) -> dict: - """Generate test audio input event from text. - - Convenience function that creates an AudioGenerator and returns - a ready-to-use BidiAudioInputEvent. - - Args: - text: Text to convert to speech. - use_cache: Whether to use cached audio. - - Returns: - BidiAudioInputEvent dict ready for agent.send(). - """ - generator = AudioGenerator() - audio_data = await generator.generate_audio(text, use_cache=use_cache) - return generator.create_audio_input_event(audio_data) diff --git a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py deleted file mode 100644 index e93a267a0..000000000 --- a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Parameterized integration tests for bidirectional streaming. - -Tests fundamental functionality across multiple model providers (Nova Sonic, OpenAI, etc.) -including multi-turn conversations, audio I/O, text transcription, and tool execution. - -This demonstrates the provider-agnostic design of the bidirectional streaming system. -""" - -import asyncio -import logging -import os - -import pytest - -from strands import tool -from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent -from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel -from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel -from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel - -from .context import BidirectionalTestContext - -logger = logging.getLogger(__name__) - - -# Simple calculator tool for testing -@tool -def calculator(operation: str, x: float, y: float) -> float: - """Perform basic arithmetic operations. - - Args: - operation: The operation to perform (add, subtract, multiply, divide) - x: First number - y: Second number - - Returns: - Result of the operation - """ - if operation == "add": - return x + y - elif operation == "subtract": - return x - y - elif operation == "multiply": - return x * y - elif operation == "divide": - if y == 0: - raise ValueError("Cannot divide by zero") - return x / y - else: - raise ValueError(f"Unknown operation: {operation}") - - -# Provider configurations -PROVIDER_CONFIGS = { - "nova_sonic": { - "model_class": BidiNovaSonicModel, - "model_kwargs": {"region": "us-east-1"}, - "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence - "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], - "skip_reason": "AWS credentials not available", - }, - "openai": { - "model_class": BidiOpenAIRealtimeModel, - "model_kwargs": { - "model": "gpt-4o-realtime-preview-2024-12-17", - "session": { - "output_modalities": ["audio"], # OpenAI only supports audio OR text, not both - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": 24000}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "silence_duration_ms": 700, - }, - }, - "output": {"format": {"type": "audio/pcm", "rate": 24000}, "voice": "alloy"}, - }, - }, - }, - "silence_duration": 1.0, # OpenAI has faster VAD - "env_vars": ["OPENAI_API_KEY"], - "skip_reason": "OPENAI_API_KEY not available", - }, - "gemini_live": { - "model_class": BidiGeminiLiveModel, - "model_kwargs": { - # Uses default model and config (audio output + transcription enabled) - }, - "silence_duration": 1.5, # Gemini has good VAD, similar to OpenAI - "env_vars": ["GOOGLE_AI_API_KEY"], - "skip_reason": "GOOGLE_AI_API_KEY not available", - }, -} - - -def check_provider_available(provider_name: str) -> tuple[bool, str]: - """Check if a provider's credentials are available. - - Args: - provider_name: Name of the provider to check. - - Returns: - Tuple of (is_available, skip_reason). - """ - config = PROVIDER_CONFIGS[provider_name] - env_vars = config["env_vars"] - - missing_vars = [var for var in env_vars if not os.getenv(var)] - - if missing_vars: - return False, f"{config['skip_reason']}: {', '.join(missing_vars)}" - - return True, "" - - -@pytest.fixture(params=list(PROVIDER_CONFIGS.keys())) -def provider_config(request): - """Provide configuration for each model provider. - - This fixture is parameterized to run tests against all available providers. - """ - provider_name = request.param - config = PROVIDER_CONFIGS[provider_name] - - # Check if provider is available - is_available, skip_reason = check_provider_available(provider_name) - if not is_available: - pytest.skip(skip_reason) - - return { - "name": provider_name, - **config, - } - - -@pytest.fixture -def agent_with_calculator(provider_config): - """Provide bidirectional agent with calculator tool for the given provider. - - Note: Session lifecycle (start/end) is handled by BidirectionalTestContext. - """ - model_class = provider_config["model_class"] - model_kwargs = provider_config["model_kwargs"] - - model = model_class(**model_kwargs) - return BidiAgent( - model=model, - tools=[calculator], - system_prompt="You are a helpful assistant with access to a calculator tool. Keep responses brief.", - ) - -@pytest.mark.asyncio -async def test_bidirectional_agent(agent_with_calculator, audio_generator, provider_config): - """Test multi-turn conversation with follow-up questions across providers. - - This test runs against all configured providers (Nova Sonic, OpenAI, etc.) - to validate provider-agnostic functionality. - - Validates: - - Session lifecycle (start/end via context manager) - - Audio input streaming - - Speech-to-text transcription - - Tool execution (calculator) - - Multi-turn conversation flow - - Text-to-speech audio output - """ - provider_name = provider_config["name"] - silence_duration = provider_config["silence_duration"] - - logger.info(f"Testing provider: {provider_name}") - - async with BidirectionalTestContext(agent_with_calculator, audio_generator) as ctx: - # Turn 1: Simple greeting to test basic audio I/O - await ctx.say("Hello, can you hear me?") - # Wait for silence to trigger provider's VAD/silence detection - await asyncio.sleep(silence_duration) - await ctx.wait_for_response() - - text_outputs_turn1 = ctx.get_text_outputs() - all_text_turn1 = " ".join(text_outputs_turn1).lower() - - # Validate turn 1 - just check we got a response - assert len(text_outputs_turn1) > 0, ( - f"[{provider_name}] No text output received in turn 1" - ) - - logger.info(f"[{provider_name}] ✓ Turn 1 complete: received response") - logger.info(f"[{provider_name}] Response: {text_outputs_turn1[0][:100]}...") - - # Turn 2: Follow-up to test multi-turn conversation - await ctx.say("What's your name?") - # Wait for silence to trigger provider's VAD/silence detection - await asyncio.sleep(silence_duration) - await ctx.wait_for_response() - - text_outputs_turn2 = ctx.get_text_outputs() - - # Validate turn 2 - check we got more responses - assert len(text_outputs_turn2) > len(text_outputs_turn1), ( - f"[{provider_name}] No new text output in turn 2" - ) - - logger.info(f"[{provider_name}] ✓ Turn 2 complete: multi-turn conversation works") - logger.info(f"[{provider_name}] Total responses: {len(text_outputs_turn2)}") - - # Validate full conversation - # Validate audio outputs - audio_outputs = ctx.get_audio_outputs() - assert len(audio_outputs) > 0, f"[{provider_name}] No audio output received" - total_audio_bytes = sum(len(audio) for audio in audio_outputs) - - # Summary - logger.info("=" * 60) - logger.info(f"[{provider_name}] ✓ Multi-turn conversation test PASSED") - logger.info(f" Provider: {provider_name}") - logger.info(f" Total events: {len(ctx.get_events())}") - logger.info(f" Text responses: {len(text_outputs_turn2)}") - logger.info(f" Audio chunks: {len(audio_outputs)} ({total_audio_bytes:,} bytes)") - logger.info("=" * 60) diff --git a/tests_integ/bidirectional_streaming/wrappers/__init__.py b/tests_integ/bidirectional_streaming/wrappers/__init__.py deleted file mode 100644 index 6b8a64984..000000000 --- a/tests_integ/bidirectional_streaming/wrappers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Wrappers for bidirectional streaming integration tests. - -Includes fault injection and other transparent wrappers around real implementations. -""" From 24e189200f7ca3cd2ce39066cbb62a347db58314 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 13:22:58 -0500 Subject: [PATCH 112/242] Update import path --- src/strands/experimental/bidi/agent/agent.py | 2 +- src/strands/experimental/bidi/models/__init__.py | 2 +- src/strands/experimental/bidi/models/gemini_live.py | 2 +- src/strands/experimental/bidi/models/novasonic.py | 2 +- src/strands/experimental/bidi/models/openai.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index eab909449..aa3930f0c 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -27,7 +27,7 @@ from ....types.tools import ToolResult, ToolUse, AgentTool from .loop import _BidiAgentLoop -from ..models.bidirectional_model import BidiModel +from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index 6d6d6590b..13aaa9697 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,6 +1,6 @@ """Bidirectional model interfaces and implementations.""" -from .bidirectional_model import BidiModel +from .bidi_model import BidiModel from .gemini_live import BidiGeminiLiveModel from .novasonic import BidiNovaSonicModel from .openai import BidiOpenAIRealtimeModel diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 9bb5bba77..ef08233c3 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -40,7 +40,7 @@ BidiResponseCompleteEvent, BidiResponseStartEvent, ) -from .bidirectional_model import BidiModel +from .bidi_model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 8c23aa0da..13373c69c 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -47,7 +47,7 @@ BidiResponseCompleteEvent, BidiResponseStartEvent, ) -from .bidirectional_model import BidiModel +from .bidi_model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 74f1942ff..13a09f8d4 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -34,7 +34,7 @@ BidiResponseCompleteEvent, BidiResponseStartEvent, ) -from .bidirectional_model import BidiModel +from .bidi_model import BidiModel logger = logging.getLogger(__name__) From fbdcf11eff3ac3395fbf0c6aca87d5accbb4aa99 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 13:25:19 -0500 Subject: [PATCH 113/242] Update model in test script --- src/strands/experimental/bidi/scripts/test_bidi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/scripts/test_bidi.py b/src/strands/experimental/bidi/scripts/test_bidi.py index abeb9fcf7..8a59824d9 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi.py +++ b/src/strands/experimental/bidi/scripts/test_bidi.py @@ -7,7 +7,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO from strands_tools import calculator @@ -19,7 +19,7 @@ async def main(): # Nova Sonic model audio_io = BidiAudioIO(audio_config={}) text_io = BidiTextIO() - model = BidiOpenAIRealtimeModel(region="us-east-1") + model = BidiNovaSonicModel(region="us-east-1") async with BidiAgent(model=model, tools=[calculator]) as agent: print("New BidiAgent Experience") From c1b3bf5c30b1209086de1c900fda5d352d9e8072 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 12 Nov 2025 13:34:19 -0500 Subject: [PATCH 114/242] Update import --- src/strands/experimental/bidi/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 033a4bb78..a25655789 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -7,7 +7,7 @@ from .io.audio import BidiAudioIO # Model interface (for custom implementations) -from .models.bidirectional_model import BidiModel +from .models.bidi_model import BidiModel # Model providers - What users need to create models from .models.gemini_live import BidiGeminiLiveModel From 92e6822f14c2762d9f018bce9a09a032964571ab Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 13 Nov 2025 15:44:14 +0300 Subject: [PATCH 115/242] fix(pyproject): Fix bidi-all dependency group --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 998059f17..f5311c299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -bidi-all = ["bidi-openai", "bidi-gemini", "bidi-novasonic", "bidi"] +bidi-all = ["strands-agents[bidi,bidi-openai,bidi-gemini,bidi-novasonic]"] all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ From ef3d6595eb7f4400273759f907e6b9512b4a5270 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 13 Nov 2025 10:20:25 -0500 Subject: [PATCH 116/242] Merge pull request #43 from mehtarac/bidio_interface --- src/strands/experimental/bidi/agent/agent.py | 12 ++++++++---- src/strands/experimental/bidi/io/audio.py | 8 ++++---- src/strands/experimental/bidi/io/text.py | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index aa3930f0c..a92adc315 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -422,20 +422,24 @@ async def run_outputs(): await output(event) for input_ in inputs: - await input_.start() + if hasattr(input_, "start"): + await input_.start() for output in outputs: - await output.start() + if hasattr(output, "start"): + await output.start() try: await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True) finally: for input_ in inputs: - await input_.stop() + if hasattr(input_, "stop"): + await input_.stop() for output in outputs: - await output.stop() + if hasattr(output, "stop"): + await output.stop() def _validate_active_connection(self) -> None: """Validate that an active connection exists. diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 2ec167480..c99176de8 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -17,7 +17,7 @@ class _BidiAudioInput(BidiInput): - "Handle audio input from bidi agent." + """Handle audio input from bidi agent.""" def __init__(self, audio: "BidiAudioIO") -> None: """Store reference to pyaudio instance.""" self.audio = audio @@ -43,7 +43,7 @@ async def __call__(self) -> BidiAudioInputEvent: class _BidiAudioOutput(BidiOutput): - "Handle audio output from bidi agent." + """Handle audio output from bidi agent.""" def __init__(self, audio: "BidiAudioIO") -> None: """Store reference to pyaudio instance.""" self.audio = audio @@ -115,11 +115,11 @@ def __init__( self.interrupted = False def input(self) -> _BidiAudioInput: - "Return audio processing BidiInput" + """Return audio processing BidiInput""" return _BidiAudioInput(self) def output(self) -> _BidiAudioOutput: - "Return audio processing BidiOutput" + """Return audio processing BidiOutput""" return _BidiAudioOutput(self) def _start(self) -> None: diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index ba503f4e4..cd76de2c9 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -9,7 +9,7 @@ class _BidiTextOutput(BidiOutput): - "Handle text output from bidi agent." + """Handle text output from bidi agent.""" async def __call__(self, event: BidiOutputEvent) -> None: """Print text events to stdout.""" @@ -25,7 +25,7 @@ async def __call__(self, event: BidiOutputEvent) -> None: class BidiTextIO: - "Handle text input and output from bidi agent." + """Handle text input and output from bidi agent.""" def output(self) -> _BidiTextOutput: "Return text processing BidiOutput" return _BidiTextOutput() From aae5cf0f2a7df19473360c9fdb487f4465e23d9d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 13 Nov 2025 18:52:11 +0300 Subject: [PATCH 117/242] Add start/stop to bidi_agent.run --- src/strands/experimental/bidi/agent/agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index a92adc315..fa3b8fde5 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -428,11 +428,15 @@ async def run_outputs(): for output in outputs: if hasattr(output, "start"): await output.start() - + + # Start agent after all IO is ready + await self.start() try: await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True) finally: + await self.stop() + for input_ in inputs: if hasattr(input_, "stop"): await input_.stop() From 504509e6cd88d122c42091740ac13ac1936706f8 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 17 Nov 2025 11:19:52 +0100 Subject: [PATCH 118/242] feat: Add hook support to bidi agents --- src/strands/experimental/bidi/agent/agent.py | 14 + src/strands/experimental/bidi/agent/loop.py | 66 ++++- .../experimental/bidi/hooks/__init__.py | 55 ++++ src/strands/experimental/bidi/hooks/events.py | 193 +++++++++++++ .../experimental/bidi/hooks/__init__.py | 1 + .../bidi/hooks/test_bidi_hook_events.py | 170 ++++++++++++ tests_integ/bidi/test_bidi_hooks.py | 260 ++++++++++++++++++ 7 files changed, 753 insertions(+), 6 deletions(-) create mode 100644 src/strands/experimental/bidi/hooks/__init__.py create mode 100644 src/strands/experimental/bidi/hooks/events.py create mode 100644 tests/strands/experimental/bidi/hooks/__init__.py create mode 100644 tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py create mode 100644 tests_integ/bidi/test_bidi_hooks.py diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index fa3b8fde5..9f4bb0941 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -18,6 +18,7 @@ from typing import Any, AsyncIterable from .... import _identifier +from ....hooks import HookProvider, HookRegistry from ....tools.caller import _ToolCaller from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor @@ -27,6 +28,7 @@ from ....types.tools import ToolResult, ToolUse, AgentTool from .loop import _BidiAgentLoop +from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -59,6 +61,7 @@ def __init__( name: str | None = None, tool_executor: ToolExecutor | None = None, description: str | None = None, + hooks: list[HookProvider] | None = None, **kwargs: Any, ): """Initialize bidirectional agent. @@ -74,6 +77,7 @@ def __init__( name: Name of the Agent. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). description: Description of what the Agent does. + hooks: Optional list of hook providers to register for lifecycle events. **kwargs: Additional configuration for future extensibility. Raises: @@ -119,8 +123,17 @@ def __init__( self._current_adapters = [] # Track adapters for cleanup + # Initialize hooks registry + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + self._loop = _BidiAgentLoop(self) + # Emit initialization event + self.hooks.invoke_callbacks(BidiAgentInitializedEvent(agent=self)) + @property def tool(self) -> _ToolCaller: """Call tool as a function. @@ -272,6 +285,7 @@ async def send(self, input_data: BidiAgentInput) -> None: user_message: Message = {"role": "user", "content": [{"text": input_data}]} self.messages.append(user_message) + self.hooks.invoke_callbacks(BidiMessageAddedEvent(agent=self, message=user_message)) logger.debug("Text sent: %d characters", len(input_data)) # Create BidiTextInputEvent for send() diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index e0bc02ef2..d7c1d61d2 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -7,6 +7,14 @@ import logging from typing import AsyncIterable, Awaitable, TYPE_CHECKING +from ..hooks.events import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent as BidiInterruptionHookEvent, + BidiMessageAddedEvent, +) from ..types.events import BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent from ....types.content import Message @@ -44,6 +52,9 @@ async def start(self) -> None: logger.debug("starting agent loop") + # Emit before invocation event + self._agent.hooks.invoke_callbacks(BidiBeforeInvocationEvent(agent=self._agent)) + await self._agent.model.start( system_prompt=self._agent.system_prompt, tools=self._agent.tool_registry.get_all_tool_specs(), @@ -61,14 +72,19 @@ async def stop(self) -> None: logger.debug("stopping agent loop") - for task in self._tasks: - task.cancel() + try: + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) - await asyncio.gather(*self._tasks, return_exceptions=True) + await self._agent.model.stop() - await self._agent.model.stop() + self._active = False - self._active = False + finally: + # Emit after invocation event (reverse order for cleanup) + self._agent.hooks.invoke_callbacks(BidiAfterInvocationEvent(agent=self._agent)) async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive model and tool call events.""" @@ -113,11 +129,21 @@ async def _run_model(self) -> None: if event["is_final"]: message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} self._agent.messages.append(message) + self._agent.hooks.invoke_callbacks(BidiMessageAddedEvent(agent=self._agent, message=message)) elif isinstance(event, ToolUseStreamEvent): self._create_task(self._run_tool(event["current_tool_use"])) elif isinstance(event, BidiInterruptionEvent): + # Emit interruption hook event + self._agent.hooks.invoke_callbacks( + BidiInterruptionHookEvent( + agent=self._agent, + reason=event["reason"], + interrupted_response_id=event.get("interrupted_response_id"), + ) + ) + # clear the audio for _ in range(self._event_queue.qsize()): event = self._event_queue.get_nowait() @@ -129,10 +155,22 @@ async def _run_tool(self, tool_use: ToolUse) -> None: logger.debug("running tool") result: ToolResult = None + exception: Exception | None = None + tool = None + invocation_state = {} try: tool = self._agent.tool_registry.registry[tool_use["name"]] - invocation_state = {} + + # Emit before tool call event + self._agent.hooks.invoke_callbacks( + BidiBeforeToolCallEvent( + agent=self._agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=invocation_state, + ) + ) async for event in tool.stream(tool_use, invocation_state): if isinstance(event, ToolResultEvent): @@ -146,12 +184,27 @@ async def _run_tool(self, tool_use: ToolUse) -> None: self._event_queue.put_nowait(ToolStreamEvent(tool_use, event)) except Exception as e: + exception = e result = { "toolUseId": tool_use["toolUseId"], "status": "error", "content": [{"text": f"Error: {str(e)}"}] } + finally: + # Emit after tool call event (reverse order for cleanup) + if result: + self._agent.hooks.invoke_callbacks( + BidiAfterToolCallEvent( + agent=self._agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, + ) + ) + await self._agent.model.send(ToolResultEvent(result)) message: Message = { @@ -159,4 +212,5 @@ async def _run_tool(self, tool_use: ToolUse) -> None: "content": [{"toolResult": result}], } self._agent.messages.append(message) + self._agent.hooks.invoke_callbacks(BidiMessageAddedEvent(agent=self._agent, message=message)) self._event_queue.put_nowait(ToolResultMessageEvent(message)) diff --git a/src/strands/experimental/bidi/hooks/__init__.py b/src/strands/experimental/bidi/hooks/__init__.py new file mode 100644 index 000000000..6ed0e52cf --- /dev/null +++ b/src/strands/experimental/bidi/hooks/__init__.py @@ -0,0 +1,55 @@ +"""Typed hook system for BidiAgent. + +This module provides hook events specifically for BidiAgent, enabling +composable extension of bidirectional streaming agent functionality. + +Example Usage: + ```python + from strands.experimental.bidi.hooks import ( + BidiBeforeInvocationEvent, + BidiInterruptionEvent, + HookProvider, + HookRegistry + ) + + class BidiLoggingHooks(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BidiBeforeInvocationEvent, self.log_session_start) + registry.add_callback(BidiInterruptionEvent, self.log_interruption) + + def log_session_start(self, event: BidiBeforeInvocationEvent) -> None: + print(f"BidiAgent {event.agent.name} starting session") + + def log_interruption(self, event: BidiInterruptionEvent) -> None: + print(f"Interrupted: {event.reason}") + + # Use with BidiAgent + agent = BidiAgent(hooks=[BidiLoggingHooks()]) + ``` +""" + +from ....hooks import HookCallback, HookProvider, HookRegistry +from .events import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiHookEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, +) + +__all__ = [ + "BidiAgentInitializedEvent", + "BidiBeforeInvocationEvent", + "BidiAfterInvocationEvent", + "BidiBeforeToolCallEvent", + "BidiAfterToolCallEvent", + "BidiMessageAddedEvent", + "BidiInterruptionEvent", + "BidiHookEvent", + "HookProvider", + "HookCallback", + "HookRegistry", +] diff --git a/src/strands/experimental/bidi/hooks/events.py b/src/strands/experimental/bidi/hooks/events.py new file mode 100644 index 000000000..e5aa5b2bd --- /dev/null +++ b/src/strands/experimental/bidi/hooks/events.py @@ -0,0 +1,193 @@ +"""Hook events for BidiAgent. + +This module defines the events that are emitted as BidiAgent runs through +the lifecycle of a streaming session. +""" + +import uuid +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Optional + +from typing_extensions import override + +from ....hooks.registry import BaseHookEvent +from ....types.content import Message +from ....types.interrupt import _Interruptible +from ....types.tools import AgentTool, ToolResult, ToolUse + +if TYPE_CHECKING: + from ..agent.agent import BidiAgent + + +@dataclass +class BidiHookEvent(BaseHookEvent): + """Base class for BidiAgent hook events. + + Attributes: + agent: The BidiAgent instance that triggered this event. + """ + + agent: "BidiAgent" + + +@dataclass +class BidiAgentInitializedEvent(BidiHookEvent): + """Event triggered when a BidiAgent has finished initialization. + + This event is fired after the BidiAgent has been fully constructed and all + built-in components have been initialized. Hook providers can use this + event to perform setup tasks that require a fully initialized agent. + """ + + pass + + +@dataclass +class BidiBeforeInvocationEvent(BidiHookEvent): + """Event triggered when BidiAgent starts a streaming session. + + This event is fired before the BidiAgent begins a streaming session, + before any model connection or audio processing occurs. Hook providers can + use this event to perform session-level setup, logging, or validation. + + This event is triggered at the beginning of agent.start(). + """ + + pass + + +@dataclass +class BidiAfterInvocationEvent(BidiHookEvent): + """Event triggered when BidiAgent ends a streaming session. + + This event is fired after the BidiAgent has completed a streaming session, + regardless of whether it completed successfully or encountered an error. + Hook providers can use this event for cleanup, logging, or state persistence. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + This event is triggered at the end of agent.stop(). + """ + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BidiMessageAddedEvent(BidiHookEvent): + """Event triggered when BidiAgent adds a message to the conversation. + + This event is fired whenever the BidiAgent adds a new message to its internal + message history, including user messages (from transcripts), assistant responses, + and tool results. Hook providers can use this event for logging, monitoring, or + implementing custom message processing logic. + + Note: This event is only triggered for messages added by the framework + itself, not for messages manually added by tools or external code. + + Attributes: + message: The message that was added to the conversation history. + """ + + message: Message + + +@dataclass +class BidiBeforeToolCallEvent(BidiHookEvent, _Interruptible): + """Event triggered before BidiAgent executes a tool. + + This event is fired just before the BidiAgent executes a tool during a streaming + session, allowing hook providers to inspect, modify, or replace the tool that + will be executed. The selected_tool can be modified by hook callbacks to change + which tool gets executed. + + Attributes: + selected_tool: The tool that will be invoked. Can be modified by hooks + to change which tool gets executed. This may be None if tool lookup failed. + tool_use: The tool parameters that will be passed to selected_tool. + invocation_state: Keyword arguments that will be passed to the tool. + cancel_tool: A user defined message that when set, will cancel the tool call. + The message will be placed into a tool result with an error status. If set to `True`, + Strands will cancel the tool call and use a default cancel message. + """ + + selected_tool: Optional[AgentTool] + tool_use: ToolUse + invocation_state: dict[str, Any] + cancel_tool: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_tool", "selected_tool", "tool_use"] + + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + return f"v1:bidi_before_tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + + +@dataclass +class BidiAfterToolCallEvent(BidiHookEvent): + """Event triggered after BidiAgent executes a tool. + + This event is fired after the BidiAgent has finished executing a tool during + a streaming session, regardless of whether the execution was successful or + resulted in an error. Hook providers can use this event for cleanup, logging, + or post-processing. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Attributes: + selected_tool: The tool that was invoked. It may be None if tool lookup failed. + tool_use: The tool parameters that were passed to the tool invoked. + invocation_state: Keyword arguments that were passed to the tool. + result: The result of the tool invocation. Either a ToolResult on success + or an Exception if the tool execution failed. + exception: Exception if the tool execution failed, None if successful. + cancel_message: The cancellation message if the user cancelled the tool call. + """ + + selected_tool: Optional[AgentTool] + tool_use: ToolUse + invocation_state: dict[str, Any] + result: ToolResult + exception: Optional[Exception] = None + cancel_message: str | None = None + + def _can_write(self, name: str) -> bool: + return name == "result" + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BidiInterruptionEvent(BidiHookEvent): + """Event triggered when model generation is interrupted. + + This event is fired when the user interrupts the assistant (e.g., by speaking + during the assistant's response) or when an error causes interruption. This is + specific to bidirectional streaming and doesn't exist in standard agents. + + Hook providers can use this event to log interruptions, implement custom + interruption handling, or trigger cleanup logic. + + Attributes: + reason: The reason for the interruption ("user_speech" or "error"). + interrupted_response_id: Optional ID of the response that was interrupted. + """ + + reason: Literal["user_speech", "error"] + interrupted_response_id: Optional[str] = None diff --git a/tests/strands/experimental/bidi/hooks/__init__.py b/tests/strands/experimental/bidi/hooks/__init__.py new file mode 100644 index 000000000..20a078833 --- /dev/null +++ b/tests/strands/experimental/bidi/hooks/__init__.py @@ -0,0 +1 @@ +"""Tests for BidiAgent hooks.""" diff --git a/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py b/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py new file mode 100644 index 000000000..70550ee56 --- /dev/null +++ b/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py @@ -0,0 +1,170 @@ +"""Unit tests for BidiAgent hook events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.bidi.hooks import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, +) +from strands.types.tools import ToolResult, ToolUse + + +@pytest.fixture +def agent(): + return Mock() + + +@pytest.fixture +def tool(): + tool = Mock() + tool.tool_name = "test_tool" + return tool + + +@pytest.fixture +def tool_use(): + return ToolUse(name="test_tool", toolUseId="123", input={"param": "value"}) + + +@pytest.fixture +def tool_invocation_state(): + return {"param": "value"} + + +@pytest.fixture +def tool_result(): + return ToolResult(content=[{"text": "result"}], status="success", toolUseId="123") + + +@pytest.fixture +def message(): + return {"role": "user", "content": [{"text": "Hello"}]} + + +@pytest.fixture +def initialized_event(agent): + return BidiAgentInitializedEvent(agent=agent) + + +@pytest.fixture +def before_invocation_event(agent): + return BidiBeforeInvocationEvent(agent=agent) + + +@pytest.fixture +def after_invocation_event(agent): + return BidiAfterInvocationEvent(agent=agent) + + +@pytest.fixture +def message_added_event(agent, message): + return BidiMessageAddedEvent(agent=agent, message=message) + + +@pytest.fixture +def before_tool_event(agent, tool, tool_use, tool_invocation_state): + return BidiBeforeToolCallEvent( + agent=agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=tool_invocation_state, + ) + + +@pytest.fixture +def after_tool_event(agent, tool, tool_use, tool_invocation_state, tool_result): + return BidiAfterToolCallEvent( + agent=agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=tool_invocation_state, + result=tool_result, + ) + + +@pytest.fixture +def interruption_event(agent): + return BidiInterruptionEvent(agent=agent, reason="user_speech") + + +def test_event_should_reverse_callbacks( + initialized_event, + before_invocation_event, + after_invocation_event, + message_added_event, + before_tool_event, + after_tool_event, + interruption_event, +): + """Verify which events use reverse callback ordering.""" + # note that we ignore E712 (explicit booleans) for consistency/readability purposes + + assert initialized_event.should_reverse_callbacks == False # noqa: E712 + assert message_added_event.should_reverse_callbacks == False # noqa: E712 + assert interruption_event.should_reverse_callbacks == False # noqa: E712 + + assert before_invocation_event.should_reverse_callbacks == False # noqa: E712 + assert after_invocation_event.should_reverse_callbacks == True # noqa: E712 + + assert before_tool_event.should_reverse_callbacks == False # noqa: E712 + assert after_tool_event.should_reverse_callbacks == True # noqa: E712 + + +def test_interruption_event_with_response_id(agent): + """Verify BidiInterruptionEvent can include response ID.""" + event = BidiInterruptionEvent(agent=agent, reason="error", interrupted_response_id="resp_123") + + assert event.reason == "error" + assert event.interrupted_response_id == "resp_123" + + +def test_message_added_event_cannot_write_properties(message_added_event): + """Verify BidiMessageAddedEvent properties are read-only.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + message_added_event.agent = Mock() + with pytest.raises(AttributeError, match="Property message is not writable"): + message_added_event.message = {} + + +def test_before_tool_call_event_can_write_properties(before_tool_event): + """Verify BidiBeforeToolCallEvent allows writing specific properties.""" + new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) + before_tool_event.selected_tool = None # Should not raise + before_tool_event.tool_use = new_tool_use # Should not raise + before_tool_event.cancel_tool = True # Should not raise + + +def test_before_tool_call_event_cannot_write_properties(before_tool_event): + """Verify BidiBeforeToolCallEvent protects certain properties.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + before_tool_event.agent = Mock() + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + before_tool_event.invocation_state = {} + + +def test_after_tool_call_event_can_write_properties(after_tool_event): + """Verify BidiAfterToolCallEvent allows writing result property.""" + new_result = ToolResult(content=[{"text": "new result"}], status="success", toolUseId="456") + after_tool_event.result = new_result # Should not raise + + +def test_after_tool_call_event_cannot_write_properties(after_tool_event): + """Verify BidiAfterToolCallEvent protects certain properties.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + after_tool_event.agent = Mock() + with pytest.raises(AttributeError, match="Property selected_tool is not writable"): + after_tool_event.selected_tool = None + with pytest.raises(AttributeError, match="Property tool_use is not writable"): + after_tool_event.tool_use = ToolUse(name="new", toolUseId="456", input={}) + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + after_tool_event.invocation_state = {} + with pytest.raises(AttributeError, match="Property exception is not writable"): + after_tool_event.exception = Exception("test") + diff --git a/tests_integ/bidi/test_bidi_hooks.py b/tests_integ/bidi/test_bidi_hooks.py new file mode 100644 index 000000000..badfea384 --- /dev/null +++ b/tests_integ/bidi/test_bidi_hooks.py @@ -0,0 +1,260 @@ +"""Integration tests for BidiAgent hooks with real model providers.""" + +import asyncio + +import pytest + +from src.strands import tool +from src.strands.experimental.bidi.agent.agent import BidiAgent +from src.strands.experimental.bidi.hooks import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, + HookProvider, +) + + +class HookEventCollector(HookProvider): + """Hook provider that collects all emitted events for testing.""" + + def __init__(self): + self.events = [] + + def register_hooks(self, registry): + registry.add_callback(BidiAgentInitializedEvent, self.on_initialized) + registry.add_callback(BidiBeforeInvocationEvent, self.on_before_invocation) + registry.add_callback(BidiAfterInvocationEvent, self.on_after_invocation) + registry.add_callback(BidiBeforeToolCallEvent, self.on_before_tool_call) + registry.add_callback(BidiAfterToolCallEvent, self.on_after_tool_call) + registry.add_callback(BidiMessageAddedEvent, self.on_message_added) + registry.add_callback(BidiInterruptionEvent, self.on_interruption) + + def on_initialized(self, event: BidiAgentInitializedEvent): + self.events.append(("initialized", event)) + + def on_before_invocation(self, event: BidiBeforeInvocationEvent): + self.events.append(("before_invocation", event)) + + def on_after_invocation(self, event: BidiAfterInvocationEvent): + self.events.append(("after_invocation", event)) + + def on_before_tool_call(self, event: BidiBeforeToolCallEvent): + self.events.append(("before_tool_call", event)) + + def on_after_tool_call(self, event: BidiAfterToolCallEvent): + self.events.append(("after_tool_call", event)) + + def on_message_added(self, event: BidiMessageAddedEvent): + self.events.append(("message_added", event)) + + def on_interruption(self, event: BidiInterruptionEvent): + self.events.append(("interruption", event)) + + def get_event_types(self): + """Get list of event type names in order.""" + return [event_type for event_type, _ in self.events] + + def get_events_by_type(self, event_type): + """Get all events of a specific type.""" + return [event for et, event in self.events if et == event_type] + + +@pytest.mark.asyncio +class TestBidiAgentHooksLifecycle: + """Test BidiAgent hook lifecycle events.""" + + async def test_agent_initialization_emits_hook(self): + """Verify agent initialization emits BidiAgentInitializedEvent.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + # Should have emitted initialized event + assert "initialized" in collector.get_event_types() + init_events = collector.get_events_by_type("initialized") + assert len(init_events) == 1 + assert init_events[0].agent == agent + + async def test_session_lifecycle_emits_hooks(self): + """Verify session start/stop emits before/after invocation events.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + # Start session + await agent.start() + + # Should have emitted before_invocation + assert "before_invocation" in collector.get_event_types() + + # Stop session + await agent.stop() + + # Should have emitted after_invocation + assert "after_invocation" in collector.get_event_types() + + # Verify order: initialized -> before_invocation -> after_invocation + event_types = collector.get_event_types() + assert event_types.index("initialized") < event_types.index("before_invocation") + assert event_types.index("before_invocation") < event_types.index("after_invocation") + + async def test_message_added_hook_on_text_input(self): + """Verify sending text emits BidiMessageAddedEvent.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + await agent.start() + + # Send text message + await agent.send("Hello, agent!") + + await agent.stop() + + # Should have emitted message_added event + message_events = collector.get_events_by_type("message_added") + assert len(message_events) >= 1 + + # Find the user message event + user_messages = [e for e in message_events if e.message["role"] == "user"] + assert len(user_messages) >= 1 + assert user_messages[0].message["content"][0]["text"] == "Hello, agent!" + + +@pytest.mark.asyncio +class TestBidiAgentHooksWithTools: + """Test BidiAgent hook events with tool execution.""" + + async def test_tool_call_hooks_emitted(self): + """Verify tool execution emits before/after tool call events.""" + + @tool + def test_calculator(expression: str) -> str: + """Calculate a math expression.""" + return f"Result: {eval(expression)}" + + collector = HookEventCollector() + agent = BidiAgent(tools=[test_calculator], hooks=[collector]) + + # Note: This test verifies hook infrastructure is in place + # Actual tool execution would require model interaction + # which is tested in full integration tests + + # Verify hooks are registered + assert agent.hooks.has_callbacks() + + # Verify tool is registered + assert "test_calculator" in agent.tool_names + + +@pytest.mark.asyncio +class TestBidiAgentHooksEventData: + """Test BidiAgent hook event data integrity.""" + + async def test_hook_events_contain_agent_reference(self): + """Verify all hook events contain correct agent reference.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + await agent.start() + await agent.send("Test message") + await agent.stop() + + # All events should reference the same agent + for event_type, event in collector.events: + assert hasattr(event, "agent") + assert event.agent == agent + + async def test_message_added_event_contains_message(self): + """Verify BidiMessageAddedEvent contains the actual message.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + await agent.start() + test_text = "Test message content" + await agent.send(test_text) + await agent.stop() + + # Find message_added events + message_events = collector.get_events_by_type("message_added") + assert len(message_events) >= 1 + + # Verify message content + user_messages = [e for e in message_events if e.message["role"] == "user"] + assert len(user_messages) >= 1 + assert user_messages[0].message["content"][0]["text"] == test_text + + +@pytest.mark.asyncio +class TestBidiAgentHooksOrdering: + """Test BidiAgent hook callback ordering.""" + + async def test_multiple_hooks_fire_in_order(self): + """Verify multiple hook providers fire in registration order.""" + call_order = [] + + class FirstHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiBeforeInvocationEvent, lambda e: call_order.append("first")) + + class SecondHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiBeforeInvocationEvent, lambda e: call_order.append("second")) + + class ThirdHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiBeforeInvocationEvent, lambda e: call_order.append("third")) + + agent = BidiAgent(hooks=[FirstHook(), SecondHook(), ThirdHook()]) + + await agent.start() + await agent.stop() + + # Verify order + assert call_order == ["first", "second", "third"] + + async def test_after_invocation_fires_in_reverse_order(self): + """Verify after invocation hooks fire in reverse order (cleanup).""" + call_order = [] + + class FirstHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiAfterInvocationEvent, lambda e: call_order.append("first")) + + class SecondHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiAfterInvocationEvent, lambda e: call_order.append("second")) + + class ThirdHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiAfterInvocationEvent, lambda e: call_order.append("third")) + + agent = BidiAgent(hooks=[FirstHook(), SecondHook(), ThirdHook()]) + + await agent.start() + await agent.stop() + + # Verify reverse order for cleanup + assert call_order == ["third", "second", "first"] + + +@pytest.mark.asyncio +class TestBidiAgentHooksContextManager: + """Test BidiAgent hooks with async context manager.""" + + async def test_hooks_fire_with_context_manager(self): + """Verify hooks fire correctly when using async context manager.""" + collector = HookEventCollector() + + async with BidiAgent(hooks=[collector]) as agent: + await agent.send("Test message") + + # Verify lifecycle events + event_types = collector.get_event_types() + assert "initialized" in event_types + assert "before_invocation" in event_types + assert "after_invocation" in event_types + + # Verify order + assert event_types.index("before_invocation") < event_types.index("after_invocation") From 918b1af62821aebb116e5fc6552a9158f7fbc5ee Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 17 Nov 2025 09:59:06 -0500 Subject: [PATCH 119/242] bidi audio io - handle interruption (#45) --- src/strands/experimental/bidi/agent/agent.py | 4 +- src/strands/experimental/bidi/agent/loop.py | 9 +- src/strands/experimental/bidi/io/audio.py | 260 ++++++++++-------- .../experimental/bidi/models/novasonic.py | 2 +- .../experimental/bidi/scripts/test_bidi.py | 2 +- .../strands/experimental/bidi/io/__init__.py | 0 .../experimental/bidi/io/test_audio.py | 80 ++++++ 7 files changed, 235 insertions(+), 122 deletions(-) create mode 100644 tests/strands/experimental/bidi/io/__init__.py create mode 100644 tests/strands/experimental/bidi/io/test_audio.py diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index fa3b8fde5..b8503ceca 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -428,7 +428,7 @@ async def run_outputs(): for output in outputs: if hasattr(output, "start"): await output.start() - + # Start agent after all IO is ready await self.start() try: @@ -436,7 +436,7 @@ async def run_outputs(): finally: await self.stop() - + for input_ in inputs: if hasattr(input_, "stop"): await input_.stop() diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index e0bc02ef2..9ec978a45 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -7,7 +7,7 @@ import logging from typing import AsyncIterable, Awaitable, TYPE_CHECKING -from ..types.events import BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent +from ..types.events import BidiOutputEvent, BidiTranscriptStreamEvent from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse @@ -117,13 +117,6 @@ async def _run_model(self) -> None: elif isinstance(event, ToolUseStreamEvent): self._create_task(self._run_tool(event["current_tool_use"])) - elif isinstance(event, BidiInterruptionEvent): - # clear the audio - for _ in range(self._event_queue.qsize()): - event = self._event_queue.get_nowait() - if not isinstance(event, BidiAudioStreamEvent): - self._event_queue.put_nowait(event) - async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model.""" logger.debug("running tool") diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index c99176de8..34e8fd0e9 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -1,161 +1,201 @@ -"""AudioIO - Clean separation of audio functionality from core BidiAgent. +"""Send and receive audio data from devices. -Provides audio input/output capabilities for BidiAgent through the BidiIO protocol. -Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic. +Reads user audio from input device and sends agent audio to output device using PyAudio. If a user interrupts the agent, +the output buffer is cleared to stop playback. """ import asyncio import base64 import logging +from collections import deque +from typing import Any import pyaudio from ..types.io import BidiInput, BidiOutput -from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiOutputEvent +from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent logger = logging.getLogger(__name__) class _BidiAudioInput(BidiInput): - """Handle audio input from bidi agent.""" - def __init__(self, audio: "BidiAudioIO") -> None: - """Store reference to pyaudio instance.""" - self.audio = audio - + """Handle audio input from user. + + Attributes: + _audio: PyAudio instance for audio system access. + _stream: Audio input stream. + """ + + _audio: pyaudio.PyAudio + _stream: pyaudio.Stream + + _CHANNELS: int = 1 + _DEVICE_INDEX: int | None = None + _ENCODING: str = "pcm" + _FORMAT: int = pyaudio.paInt16 + _FRAMES_PER_BUFFER: int = 512 + _RATE: int = 16000 + + def __init__(self, config: dict[str, Any]) -> None: + """Extract configs.""" + self._channels = config.get("input_channels", _BidiAudioInput._CHANNELS) + self._device_index = config.get("input_device_index", _BidiAudioInput._DEVICE_INDEX) + self._format = config.get("input_format", _BidiAudioInput._FORMAT) + self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER) + self._rate = config.get("input_rate", _BidiAudioInput._RATE) + async def start(self) -> None: - """Start audio input.""" - self.audio._start() + """Start input stream.""" + self._audio = pyaudio.PyAudio() + self._stream = self._audio.open( + channels=self._channels, + format=self._format, + frames_per_buffer=self._frames_per_buffer, + input=True, + input_device_index=self._device_index, + rate=self._rate, + ) async def stop(self) -> None: - """Stop audio input.""" - self.audio._stop() + """Stop input stream.""" + # TODO: Provide time for streaming thread to exit cleanly to prevent conflicts with the Nova threads. + # See if we can remove after properly handling cancellation for agent. + await asyncio.sleep(0.1) + + self._stream.close() + self._audio.terminate() + + self._stream = None + self._audio = None async def __call__(self) -> BidiAudioInputEvent: - """Read audio from microphone.""" - audio_bytes = self.audio.input_stream.read(self.audio.chunk_size, exception_on_overflow=False) + """Read audio from input stream.""" + audio_bytes = await asyncio.to_thread( + self._stream.read, self._frames_per_buffer, exception_on_overflow=False + ) return BidiAudioInputEvent( audio=base64.b64encode(audio_bytes).decode("utf-8"), - format="pcm", - sample_rate=self.audio.input_sample_rate, - channels=self.audio.input_channels, + channels=self._channels, + format=_BidiAudioInput._ENCODING, + sample_rate=self._rate, ) class _BidiAudioOutput(BidiOutput): - """Handle audio output from bidi agent.""" - def __init__(self, audio: "BidiAudioIO") -> None: - """Store reference to pyaudio instance.""" - self.audio = audio + """Handle audio output from bidi agent. + + Attributes: + _audio: PyAudio instance for audio system access. + _stream: Audio output stream. + _buffer: Deque buffer for queuing audio data. + _buffer_event: Event to signal when buffer has data. + _output_task: Background task for processing audio output. + """ + + _audio: pyaudio.PyAudio + _stream: pyaudio.Stream + _buffer: deque + _buffer_event: asyncio.Event + _output_task: asyncio.Task + + _BUFFER_SIZE: int | None = None + _CHANNELS: int = 1 + _DEVICE_INDEX: int | None = None + _FORMAT: int = pyaudio.paInt16 + _FRAMES_PER_BUFFER: int = 512 + _RATE: int = 16000 + + def __init__(self, config: dict[str, Any]) -> None: + """Extract configs.""" + self._buffer_size = config.get("output_buffer_size", _BidiAudioOutput._BUFFER_SIZE) + self._channels = config.get("output_channels", _BidiAudioOutput._CHANNELS) + self._device_index = config.get("output_device_index", _BidiAudioOutput._DEVICE_INDEX) + self._format = config.get("output_format", _BidiAudioOutput._FORMAT) + self._frames_per_buffer = config.get("output_frames_per_buffer", _BidiAudioOutput._FRAMES_PER_BUFFER) + self._rate = config.get("output_rate", _BidiAudioOutput._RATE) async def start(self) -> None: - """Start audio output.""" - self.audio._start() + """Start output stream.""" + self._audio = pyaudio.PyAudio() + self._stream = self._audio.open( + channels=self._channels, + format=self._format, + frames_per_buffer=self._frames_per_buffer, + output=True, + output_device_index=self._device_index, + rate=self._rate, + ) + self._buffer = deque(maxlen=self._buffer_size) + self._buffer_event = asyncio.Event() + self._output_task = asyncio.create_task(self._output()) async def stop(self) -> None: - """Stop audio output.""" - self.audio._stop() + """Stop output stream.""" + self._buffer.clear() + self._buffer.append(None) + self._buffer_event.set() + await self._output_task + + self._stream.close() + self._audio.terminate() + + self._output_task = None + self._buffer = None + self._buffer_event = None + self._stream = None + self._audio = None async def __call__(self, event: BidiOutputEvent) -> None: """Handle audio events with direct stream writing.""" if isinstance(event, BidiAudioStreamEvent): - self.audio.output_stream.write(base64.b64decode(event["audio"])) + audio_bytes = base64.b64decode(event["audio"]) + self._buffer.append(audio_bytes) + self._buffer_event.set() + + elif isinstance(event, BidiInterruptionEvent): + self._buffer.clear() + self._buffer_event.clear() - # TODO: Outputing audio to speakers is a sync operation. Adding sleep to prevent event loop hogging. Will - # follow up on identifying a cleaner approach. - await asyncio.sleep(0.01) + async def _output(self) -> None: + while True: + await self._buffer_event.wait() + self._buffer_event.clear() + + while self._buffer: + audio_bytes = self._buffer.popleft() + if not audio_bytes: + return + + await asyncio.to_thread(self._stream.write, audio_bytes) class BidiAudioIO: - """Audio IO channel for BidiAgent with direct stream processing.""" + """Send and receive audio data from devices.""" - def __init__( - self, - audio_config: dict | None = None, - ): - """Initialize AudioIO with clean audio configuration. + def __init__(self, **config: Any) -> None: + """Initialize audio devices. Args: - audio_config: Dictionary containing audio configuration: - - input_sample_rate (int): Microphone sample rate (default: 24000) - - output_sample_rate (int): Speaker sample rate (default: 24000) - - chunk_size (int): Audio chunk size in bytes (default: 1024) - - input_device_index (int): Specific input device (optional) - - output_device_index (int): Specific output device (optional) + **config: Dictionary containing audio configuration: - input_channels (int): Input channels (default: 1) + - input_device_index (int): Specific input device (optional) + - input_format (int): Audio format (default: paInt16) + - input_frames_per_buffer (int): Frames per buffer (default: 512) + - input_rate (int): Input sample rate (default: 16000) + - output_buffer_size (int): Maximum output buffer size (default: None) - output_channels (int): Output channels (default: 1) + - output_device_index (int): Specific output device (optional) + - output_format (int): Audio format (default: paInt16) + - output_frames_per_buffer (int): Frames per buffer (default: 512) + - output_rate (int): Output sample rate (default: 16000) """ - default_config = { - "input_sample_rate": 16000, - "output_sample_rate": 16000, - "chunk_size": 512, - "input_device_index": None, - "output_device_index": None, - "input_channels": 1, - "output_channels": 1, - } - - # Merge user config with defaults - if audio_config: - default_config.update(audio_config) - - # Set audio configuration attributes - self.input_sample_rate = default_config["input_sample_rate"] - self.output_sample_rate = default_config["output_sample_rate"] - self.chunk_size = default_config["chunk_size"] - self.input_device_index = default_config["input_device_index"] - self.output_device_index = default_config["output_device_index"] - self.input_channels = default_config["input_channels"] - self.output_channels = default_config["output_channels"] - - # Audio infrastructure - self.audio = None - self.input_stream = None - self.output_stream = None - self.interrupted = False + self._config = config def input(self) -> _BidiAudioInput: """Return audio processing BidiInput""" - return _BidiAudioInput(self) + return _BidiAudioInput(self._config) def output(self) -> _BidiAudioOutput: """Return audio processing BidiOutput""" - return _BidiAudioOutput(self) - - def _start(self) -> None: - """Setup PyAudio streams for input and output.""" - if self.audio: - return - - self.audio = pyaudio.PyAudio() - - self.input_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.input_channels, - rate=self.input_sample_rate, - input=True, - frames_per_buffer=self.chunk_size, - input_device_index=self.input_device_index, - ) - - self.output_stream = self.audio.open( - format=pyaudio.paInt16, - channels=self.output_channels, - rate=self.output_sample_rate, - output=True, - frames_per_buffer=self.chunk_size, - output_device_index=self.output_device_index, - ) - - def _stop(self) -> None: - """Clean up IO channel resources.""" - if not self.audio: - return - - self.input_stream.close() - self.output_stream.close() - self.audio.terminate() - - self.input_stream = None - self.output_stream = None - self.audio = None + return _BidiAudioOutput(self._config) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 13373c69c..0840245d7 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -519,7 +519,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N return BidiAudioStreamEvent( audio=audio_content, format="pcm", - sample_rate=24000, + sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], channels=1 ) diff --git a/src/strands/experimental/bidi/scripts/test_bidi.py b/src/strands/experimental/bidi/scripts/test_bidi.py index 8a59824d9..f07ef1fc4 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi.py +++ b/src/strands/experimental/bidi/scripts/test_bidi.py @@ -17,7 +17,7 @@ async def main(): # Nova Sonic model - audio_io = BidiAudioIO(audio_config={}) + audio_io = BidiAudioIO() text_io = BidiTextIO() model = BidiNovaSonicModel(region="us-east-1") diff --git a/tests/strands/experimental/bidi/io/__init__.py b/tests/strands/experimental/bidi/io/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py new file mode 100644 index 000000000..9a3c2979c --- /dev/null +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -0,0 +1,80 @@ +import asyncio +import base64 +import unittest.mock + +import pytest + +from strands.experimental.bidi.io import BidiAudioIO +from strands.experimental.bidi.types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent + + +@pytest.fixture +def py_audio(): + with unittest.mock.patch("strands.experimental.bidi.io.audio.pyaudio") as mock: + yield mock.PyAudio() + + +@pytest.fixture +def audio_io(): + return BidiAudioIO() + + +@pytest.fixture +def audio_input(audio_io): + return audio_io.input() + + +@pytest.fixture +def audio_output(audio_io): + return audio_io.output() + + +@pytest.mark.asyncio +async def test_bidi_audio_io_input(py_audio, audio_input): + microphone = unittest.mock.Mock() + microphone.read.return_value = b"test-audio" + + py_audio.open.return_value = microphone + + await audio_input.start() + tru_event = await audio_input() + await audio_input.stop() + + exp_event = BidiAudioInputEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=1, + format="pcm", + sample_rate=16000, + ) + assert tru_event == exp_event + + microphone.read.assert_called_once_with(512, exception_on_overflow=False) + + +@pytest.mark.asyncio +async def test_bidi_audio_io_output(py_audio, audio_output): + write_future = asyncio.Future() + write_event = asyncio.Event() + def write(data): + write_future.set_result(data) + write_event.set() + + speaker = unittest.mock.Mock() + speaker.write.side_effect = write + + py_audio.open.return_value = speaker + + await audio_output.start() + + audio_event = BidiAudioStreamEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=1, + format="pcm", + sample_rate=1600, + ) + await audio_output(audio_event) + await write_event.wait() + + await audio_output.stop() + + speaker.write.assert_called_once_with(write_future.result()) From 5073fe4e45afc1f81559729c9d181338390d0501 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 17 Nov 2025 09:59:28 -0500 Subject: [PATCH 120/242] nova sonic - remove model task and events queue (#46) --- .../experimental/bidi/models/novasonic.py | 68 ++----------------- 1 file changed, 7 insertions(+), 61 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 0840245d7..74558f4f7 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -118,10 +118,6 @@ def __init__( # Audio connection state self.audio_connection_active = False - # Background task and event queue - self._response_task = None - self._event_queue = None - # Track API-provided identifiers self._current_completion_id = None self._current_role = None @@ -158,7 +154,6 @@ async def start( self.connection_id = str(uuid.uuid4()) self._active = True self.audio_content_name = str(uuid.uuid4()) - self._event_queue = asyncio.Queue() # Start Nova Sonic bidirectional stream self.stream = await self.client.invoke_model_with_bidirectional_stream( @@ -179,9 +174,6 @@ async def start( logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) - # Start background response processor - self._response_task = asyncio.create_task(self._process_responses()) - logger.info("Nova Sonic connection established successfully") except Exception as e: @@ -208,48 +200,6 @@ async def _send_initialization_events(self, events: list[str]) -> None: await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) - async def _process_responses(self) -> None: - """Process Nova Sonic responses continuously.""" - logger.debug("Nova Sonic response processor started") - - try: - while self._active: - try: - output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) - result = await output[1].receive() - - if result.value and result.value.bytes_: - await self._handle_response_data(result.value.bytes_.decode("utf-8")) - - except asyncio.TimeoutError: - await asyncio.sleep(0.1) - continue - except Exception as e: - logger.warning("Nova Sonic response error: %s", e) - await asyncio.sleep(0.1) - continue - - except Exception as e: - logger.error("Nova Sonic fatal error: %s", e) - finally: - logger.debug("Nova Sonic response processor stopped") - - async def _handle_response_data(self, response_data: str) -> None: - """Handle decoded response data from Nova Sonic.""" - try: - json_data = json.loads(response_data) - - if "event" in json_data: - nova_event = json_data["event"] - self._log_event_type(nova_event) - - if not hasattr(self, "_event_queue"): - self._event_queue = asyncio.Queue() - - await self._event_queue.put(nova_event) - except json.JSONDecodeError as e: - logger.warning("Nova Sonic JSON decode error: %s", e) - def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: @@ -281,8 +231,13 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: try: while self._active: try: - # Get events from the queue populated by _process_responses - nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) + result = await output[1].receive() + + response_data = result.value.bytes_.decode("utf-8") + json_data = json.loads(response_data) + nova_event = json_data["event"] + self._log_event_type(nova_event) # Convert to provider-agnostic format provider_event = self._convert_nova_event(nova_event) @@ -290,7 +245,6 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: yield provider_event except asyncio.TimeoutError: - # No events in queue - continue waiting continue except Exception as e: @@ -455,14 +409,6 @@ async def stop(self) -> None: logger.debug("Nova cleanup - starting connection close") self._active = False - # Cancel response processing task if running - if hasattr(self, "_response_task") and not self._response_task.done(): - self._response_task.cancel() - try: - await self._response_task - except asyncio.CancelledError: - pass - try: # End audio connection if active if self.audio_connection_active: From b186775b6484e7d309e7a497fd9c7dc1b258d927 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 17 Nov 2025 09:59:42 -0500 Subject: [PATCH 121/242] openai - remove model task and event queue (#47) --- .../experimental/bidi/models/openai.py | 64 ++++++------------- 1 file changed, 18 insertions(+), 46 deletions(-) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 13a09f8d4..d955155a4 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -5,12 +5,11 @@ """ import asyncio -import base64 import json import logging import os import uuid -from typing import AsyncIterable, Union +from typing import AsyncIterable import websockets from websockets.exceptions import ConnectionClosed @@ -110,8 +109,6 @@ def __init__( self.connection_id = None self._active = False - self._event_queue = None - self._response_task = None self._function_call_buffer = {} logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) @@ -140,7 +137,6 @@ async def start( # Initialize connection state self.connection_id = str(uuid.uuid4()) self._active = True - self._event_queue = asyncio.Queue() self._function_call_buffer = {} # Establish WebSocket connection @@ -163,10 +159,6 @@ async def start( if messages: await self._add_conversation_history(messages) - # Start background response processor - self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime connection established") - except Exception as e: self._active = False logger.error("OpenAI connection error: %s", e) @@ -270,30 +262,6 @@ async def _add_conversation_history(self, messages: Messages) -> None: await self._send_event(conversation_item) - async def _process_responses(self) -> None: - """Process incoming WebSocket messages.""" - logger.debug("OpenAI Realtime response processor started") - - try: - async for message in self.websocket: - if not self._active: - break - - try: - event = json.loads(message) - await self._event_queue.put(event) - except json.JSONDecodeError as e: - logger.warning("Failed to parse OpenAI event: %s", e) - continue - - except ConnectionClosed: - logger.debug("OpenAI Realtime WebSocket connection closed") - except Exception as e: - logger.error("Error in OpenAI Realtime response processing: %s", e) - finally: - self._active = False - logger.debug("OpenAI Realtime response processor stopped") - async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive OpenAI events and convert to Strands TypedEvent format.""" # Emit connection start event @@ -304,12 +272,14 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: try: while self._active: - try: - openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + async for message in self.websocket: + if not self._active: + break + + openai_event = json.loads(message) + for event in self._convert_openai_event(openai_event) or []: yield event - except asyncio.TimeoutError: - continue except Exception as e: logger.error("Error receiving OpenAI Realtime event: %s", e) @@ -317,6 +287,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: finally: # Emit connection close event yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + self._active = False def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" @@ -334,14 +305,22 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput return [BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", - sample_rate=24000, + sample_rate=AUDIO_FORMAT["rate"], channels=1 )] # Assistant text output events - combine multiple similar events elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: role = openai_event.get("role", "assistant") - return [self._create_text_event(openai_event["delta"], role.lower() if isinstance(role, str) else "assistant")] + return [self._create_text_event(openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False)] + + elif event_type in ["response.output_audio_transcript.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["transcript"], role)] + + elif event_type in ["response.output_text.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["text"], role)] # User transcription events - combine multiple similar events elif event_type in ["conversation.item.input_audio_transcription.delta", @@ -626,13 +605,6 @@ async def stop(self) -> None: logger.debug("OpenAI Realtime cleanup - starting connection close") self._active = False - if self._response_task and not self._response_task.done(): - self._response_task.cancel() - try: - await self._response_task - except asyncio.CancelledError: - pass - try: await self.websocket.close() except Exception as e: From 496c0255893c0192c1839eb1d6d40d51110dca59 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 17 Nov 2025 10:15:24 -0500 Subject: [PATCH 122/242] agent loop - bound event queue (#48) --- src/strands/experimental/bidi/agent/loop.py | 60 ++++++++++++++------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 9ec978a45..61790cecf 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -19,7 +19,19 @@ class _BidiAgentLoop: - """Agent loop.""" + """Agent loop. + + Attributes: + _agent: BidiAgent instance to loop. + _event_queue: Queue model and tool call events for receiver. + _stop_event: Sentinel to mark end of loop. + _tasks: Track active async tasks created in loop. + _active: Flag if agent loop is started. + """ + + _event_queue: asyncio.Queue + _stop_event: object + _tasks: set def __init__(self, agent: "BidiAgent") -> None: """Initialize members of the agent loop. @@ -30,9 +42,7 @@ def __init__(self, agent: "BidiAgent") -> None: agent: Bidirectional agent to loop over. """ self._agent = agent - self._event_queue = asyncio.Queue() # queue model and tool call events - self._tasks = set() # track active async tasks created in loop - self._active = False # flag if agent loop is started + self._active = False async def start(self) -> None: """Start the agent loop. @@ -44,6 +54,10 @@ async def start(self) -> None: logger.debug("starting agent loop") + self._event_queue = asyncio.Queue(maxsize=1) + self._stop_event = object() + self._tasks = set() + await self._agent.model.start( system_prompt=self._agent.system_prompt, tools=self._agent.tool_registry.get_all_tool_specs(), @@ -68,18 +82,23 @@ async def stop(self) -> None: await self._agent.model.stop() + if not self._event_queue.empty(): + self._event_queue.get_nowait() + self._event_queue.put_nowait(self._stop_event) + self._active = False + self._tasks = None + self._stop_event = None + self._event_queue = None async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive model and tool call events.""" - while self.active: - try: - yield self._event_queue.get_nowait() - except asyncio.QueueEmpty: - pass + while True: + event = await self._event_queue.get() + if event is self._stop_event: + break - # unblock the event loop - await asyncio.sleep(0) + yield event @property def active(self) -> bool: @@ -104,10 +123,7 @@ async def _run_model(self) -> None: logger.debug("running model") async for event in self._agent.model.receive(): - if not self.active: - break - - self._event_queue.put_nowait(event) + await self._event_queue.put(event) if isinstance(event, BidiTranscriptStreamEvent): if event["is_final"]: @@ -115,7 +131,11 @@ async def _run_model(self) -> None: self._agent.messages.append(message) elif isinstance(event, ToolUseStreamEvent): - self._create_task(self._run_tool(event["current_tool_use"])) + tool_use = event["current_tool_use"] + self._create_task(self._run_tool(tool_use)) + + message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + self._agent.messages.append(message) async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model.""" @@ -129,14 +149,14 @@ async def _run_tool(self, tool_use: ToolUse) -> None: async for event in tool.stream(tool_use, invocation_state): if isinstance(event, ToolResultEvent): - self._event_queue.put_nowait(event) + await self._event_queue.put(event) result = event.tool_result break if isinstance(event, ToolStreamEvent): - self._event_queue.put_nowait(event) + await self._event_queue.put(event) else: - self._event_queue.put_nowait(ToolStreamEvent(tool_use, event)) + await self._event_queue.put(ToolStreamEvent(tool_use, event)) except Exception as e: result = { @@ -152,4 +172,4 @@ async def _run_tool(self, tool_use: ToolUse) -> None: "content": [{"toolResult": result}], } self._agent.messages.append(message) - self._event_queue.put_nowait(ToolResultMessageEvent(message)) + await self._event_queue.put(ToolResultMessageEvent(message)) From e7636293305c252f727361bab1447a1b0c06e51b Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 07:22:13 -0800 Subject: [PATCH 123/242] Format code --- src/strands/experimental/bidi/__init__.py | 21 +- src/strands/experimental/bidi/agent/agent.py | 45 ++- src/strands/experimental/bidi/agent/loop.py | 12 +- src/strands/experimental/bidi/io/audio.py | 14 +- src/strands/experimental/bidi/io/text.py | 7 +- .../experimental/bidi/models/bidi_model.py | 5 +- .../experimental/bidi/models/gemini_live.py | 327 +++++++++-------- .../experimental/bidi/models/novasonic.py | 50 +-- .../experimental/bidi/models/openai.py | 329 +++++++++--------- .../experimental/bidi/scripts/test_bidi.py | 8 +- .../bidi/scripts/test_bidi_novasonic.py | 21 +- .../bidi/scripts/test_bidi_openai.py | 119 +++---- .../bidi/scripts/test_gemini_live.py | 70 ++-- .../experimental/bidi/types/__init__.py | 6 +- src/strands/experimental/bidi/types/events.py | 6 +- src/strands/experimental/bidi/types/io.py | 5 +- 16 files changed, 492 insertions(+), 553 deletions(-) diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index a25655789..97de04684 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -1,6 +1,12 @@ """Bidirectional streaming package.""" # Main components - Primary user interface +# Re-export standard agent events for tool handling +from ...types._events import ( + ToolResultEvent, + ToolStreamEvent, + ToolUseStreamEvent, +) from .agent.agent import BidiAgent # IO channels - Hardware abstraction @@ -24,20 +30,13 @@ BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, - ModalityUsage, - BidiUsageEvent, BidiOutputEvent, BidiResponseCompleteEvent, BidiResponseStartEvent, BidiTextInputEvent, BidiTranscriptStreamEvent, -) - -# Re-export standard agent events for tool handling -from ...types._events import ( - ToolResultEvent, - ToolStreamEvent, - ToolUseStreamEvent, + BidiUsageEvent, + ModalityUsage, ) __all__ = [ @@ -49,13 +48,11 @@ "BidiGeminiLiveModel", "BidiNovaSonicModel", "BidiOpenAIRealtimeModel", - # Input Event types "BidiTextInputEvent", "BidiAudioInputEvent", "BidiImageInputEvent", "BidiInputEvent", - # Output Event types "BidiConnectionStartEvent", "BidiConnectionCloseEvent", @@ -68,12 +65,10 @@ "ModalityUsage", "BidiErrorEvent", "BidiOutputEvent", - # Tool Event types (reused from standard agent) "ToolUseStreamEvent", "ToolResultEvent", "ToolStreamEvent", - # Model interface "BidiModel", ] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index b8503ceca..846473e86 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -24,15 +24,14 @@ from ....tools.registry import ToolRegistry from ....tools.watcher import ToolWatcher from ....types.content import Message, Messages -from ....types.tools import ToolResult, ToolUse, AgentTool - -from .loop import _BidiAgentLoop +from ....types.tools import AgentTool, ToolResult, ToolUse +from ...tools import ToolProvider from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput -from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent +from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiInputEvent, BidiOutputEvent, BidiTextInputEvent from ..types.io import BidiInput, BidiOutput -from ...tools import ToolProvider +from .loop import _BidiAgentLoop logger = logging.getLogger(__name__) @@ -49,8 +48,8 @@ class BidiAgent: def __init__( self, - model: BidiModel| str | None = None, - tools: list[str| AgentTool| ToolProvider]| None = None, + model: BidiModel | str | None = None, + tools: list[str | AgentTool | ToolProvider] | None = None, system_prompt: str | None = None, messages: Messages | None = None, record_direct_tool_call: bool = True, @@ -244,21 +243,21 @@ async def start(self) -> None: async def send(self, input_data: BidiAgentInput) -> None: """Send input to the model (text, audio, image, or event dict). - + Unified method for sending text, audio, and image input to the model during an active conversation session. Accepts TypedEvent instances or plain dicts (e.g., from WebSocket clients) which are automatically reconstructed. - + Args: input_data: Can be: - str: Text message from user - BidiAudioInputEvent: Audio data with format/sample rate - BidiImageInputEvent: Image data with MIME type - dict: Event dictionary (will be reconstructed to TypedEvent) - + Raises: ValueError: If no active session or invalid input type. - + Example: await agent.send("Hello") await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...)) @@ -278,13 +277,13 @@ async def send(self, input_data: BidiAgentInput) -> None: text_event = BidiTextInputEvent(text=input_data, role="user") await self.model.send(text_event) return - + # Handle BidiInputEvent instances # Check this before dict since TypedEvent inherits from dict if isinstance(input_data, BidiInputEvent): await self.model.send(input_data) return - + # Handle plain dict - reconstruct TypedEvent for WebSocket integration if isinstance(input_data, dict) and "type" in input_data: event_type = input_data["type"] @@ -295,20 +294,17 @@ async def send(self, input_data: BidiAgentInput) -> None: audio=input_data["audio"], format=input_data["format"], sample_rate=input_data["sample_rate"], - channels=input_data["channels"] + channels=input_data["channels"], ) elif event_type == "bidi_image_input": - input_event = BidiImageInputEvent( - image=input_data["image"], - mime_type=input_data["mime_type"] - ) + input_event = BidiImageInputEvent(image=input_data["image"], mime_type=input_data["mime_type"]) else: raise ValueError(f"Unknown event type: {event_type}") - + # Send the reconstructed TypedEvent await self.model.send(input_event) return - + # If we get here, input type is invalid raise ValueError( f"Input must be a string, BidiInputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" @@ -359,7 +355,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: """ try: logger.debug("Exiting async context manager - cleaning up adapters and connection") - + # Cleanup adapters if any are currently active for adapter in self._current_adapters: if hasattr(adapter, "cleanup"): @@ -368,10 +364,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") except Exception as adapter_error: logger.warning(f"Error cleaning up adapter: {adapter_error}") - + # Clear current adapters self._current_adapters = [] - + # Cleanup agent connection await self.stop() @@ -397,7 +393,7 @@ async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: Args: inputs: Input callables to read data from a source outputs: Output callables to receive events from the agent - + Example: ```python audio_io = BidiAudioIO(audio_config={"input_sample_rate": 16000}) @@ -406,6 +402,7 @@ async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) ``` """ + async def run_inputs(): while self.active: for input_ in inputs: diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 61790cecf..bb4c3a55e 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -5,12 +5,12 @@ import asyncio import logging -from typing import AsyncIterable, Awaitable, TYPE_CHECKING +from typing import TYPE_CHECKING, AsyncIterable, Awaitable -from ..types.events import BidiOutputEvent, BidiTranscriptStreamEvent from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse +from ..types.events import BidiOutputEvent, BidiTranscriptStreamEvent if TYPE_CHECKING: from .agent import BidiAgent @@ -46,7 +46,7 @@ def __init__(self, agent: "BidiAgent") -> None: async def start(self) -> None: """Start the agent loop. - + The agent model is started as part of this call. """ if self.active: @@ -159,11 +159,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._event_queue.put(ToolStreamEvent(tool_use, event)) except Exception as e: - result = { - "toolUseId": tool_use["toolUseId"], - "status": "error", - "content": [{"text": f"Error: {str(e)}"}] - } + result = {"toolUseId": tool_use["toolUseId"], "status": "error", "content": [{"text": f"Error: {str(e)}"}]} await self._agent.model.send(ToolResultEvent(result)) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 34e8fd0e9..e0ebef070 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -12,15 +12,15 @@ import pyaudio -from ..types.io import BidiInput, BidiOutput from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent +from ..types.io import BidiInput, BidiOutput logger = logging.getLogger(__name__) class _BidiAudioInput(BidiInput): """Handle audio input from user. - + Attributes: _audio: PyAudio instance for audio system access. _stream: Audio input stream. @@ -43,7 +43,7 @@ def __init__(self, config: dict[str, Any]) -> None: self._format = config.get("input_format", _BidiAudioInput._FORMAT) self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER) self._rate = config.get("input_rate", _BidiAudioInput._RATE) - + async def start(self) -> None: """Start input stream.""" self._audio = pyaudio.PyAudio() @@ -70,9 +70,7 @@ async def stop(self) -> None: async def __call__(self) -> BidiAudioInputEvent: """Read audio from input stream.""" - audio_bytes = await asyncio.to_thread( - self._stream.read, self._frames_per_buffer, exception_on_overflow=False - ) + audio_bytes = await asyncio.to_thread(self._stream.read, self._frames_per_buffer, exception_on_overflow=False) return BidiAudioInputEvent( audio=base64.b64encode(audio_bytes).decode("utf-8"), @@ -84,7 +82,7 @@ async def __call__(self) -> BidiAudioInputEvent: class _BidiAudioOutput(BidiOutput): """Handle audio output from bidi agent. - + Attributes: _audio: PyAudio instance for audio system access. _stream: Audio output stream. @@ -161,7 +159,7 @@ async def _output(self) -> None: while True: await self._buffer_event.wait() self._buffer_event.clear() - + while self._buffer: audio_bytes = self._buffer.popleft() if not audio_bytes: diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index cd76de2c9..289003e02 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -2,17 +2,17 @@ import logging +from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent from ..types.io import BidiOutput -from ..types.events import BidiOutputEvent, BidiInterruptionEvent, BidiTranscriptStreamEvent logger = logging.getLogger(__name__) class _BidiTextOutput(BidiOutput): """Handle text output from bidi agent.""" + async def __call__(self, event: BidiOutputEvent) -> None: """Print text events to stdout.""" - if isinstance(event, BidiInterruptionEvent): print("interrupted") @@ -26,6 +26,7 @@ async def __call__(self, event: BidiOutputEvent) -> None: class BidiTextIO: """Handle text input and output from bidi agent.""" + def output(self) -> _BidiTextOutput: - "Return text processing BidiOutput" + """Return text processing BidiOutput""" return _BidiTextOutput() diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index d3c3aa7ec..e598498e1 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -13,17 +13,14 @@ """ import logging -from typing import AsyncIterable, Protocol, Union +from typing import AsyncIterable, Protocol from ....types._events import ToolResultEvent from ....types.content import Messages from ....types.tools import ToolSpec from ..types.events import ( - BidiAudioInputEvent, - BidiImageInputEvent, BidiInputEvent, BidiOutputEvent, - BidiTextInputEvent, ) logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index ef08233c3..2f7c523ec 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -15,15 +15,15 @@ import base64 import logging import uuid -from typing import Any, AsyncIterable, Dict, List, Optional, Union +from typing import Any, AsyncIterable, Dict, List, Optional from google import genai from google.genai import types as genai_types -from google.genai.types import LiveServerMessage, LiveServerContent +from google.genai.types import LiveServerMessage +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -34,11 +34,9 @@ BidiInputEvent, BidiInterruptionEvent, BidiOutputEvent, - BidiUsageEvent, BidiTextInputEvent, BidiTranscriptStreamEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, + BidiUsageEvent, ) from .bidi_model import BidiModel @@ -52,21 +50,21 @@ class BidiGeminiLiveModel(BidiModel): """Gemini Live API implementation using official Google GenAI SDK. - + Combines model configuration and connection state in a single class. Provides a clean interface to Gemini Live API using the official SDK, eliminating custom WebSocket handling and providing robust error handling. """ - + def __init__( self, model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", api_key: Optional[str] = None, live_config: Optional[Dict[str, Any]] = None, - **kwargs + **kwargs, ): """Initialize Gemini Live API bidirectional model. - + Args: model_id: Gemini Live model identifier. api_key: Google AI API key for authentication. @@ -76,45 +74,45 @@ def __init__( # Model configuration self.model_id = model_id self.api_key = api_key - + # Set default live_config with transcription enabled default_config = { "response_modalities": ["AUDIO"], "outputAudioTranscription": {}, # Enable output transcription by default - "inputAudioTranscription": {} # Enable input transcription by default + "inputAudioTranscription": {}, # Enable input transcription by default } - + # Merge user config with defaults (user config takes precedence) if live_config: default_config.update(live_config) - + self.live_config = default_config - + # Create Gemini client with proper API version client_kwargs = {} if api_key: client_kwargs["api_key"] = api_key - + # Use v1alpha for Live API as it has better model support client_kwargs["http_options"] = {"api_version": "v1alpha"} - + self.client = genai.Client(**client_kwargs) - + # Connection state (initialized in start()) self.live_session = None self.live_session_context_manager = None self.connection_id = None self._active = False - + async def start( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs + **kwargs, ) -> None: """Establish bidirectional connection with Gemini Live API. - + Args: system_prompt: System instructions for the model. tools: List of tools available to the model. @@ -123,66 +121,59 @@ async def start( """ if self._active: raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - + try: # Initialize connection state self.connection_id = str(uuid.uuid4()) self._active = True - + # Build live config live_config = self._build_live_config(system_prompt, tools, **kwargs) - + # Create the context manager - self.live_session_context_manager = self.client.aio.live.connect( - model=self.model_id, - config=live_config - ) - + self.live_session_context_manager = self.client.aio.live.connect(model=self.model_id, config=live_config) + # Enter the context manager self.live_session = await self.live_session_context_manager.__aenter__() - + # Send initial message history if provided if messages: await self._send_message_history(messages) - + except Exception as e: self._active = False logger.error("Error connecting to Gemini Live: %s", e) raise - + async def _send_message_history(self, messages: Messages) -> None: """Send conversation history to Gemini Live API. - + Sends each message as a separate turn with the correct role to maintain proper conversation context. Follows the same pattern as the non-bidirectional Gemini model implementation. """ if not messages: return - + # Convert each message to Gemini format and send separately for message in messages: content_parts = [] for content_block in message["content"]: if "text" in content_block: content_parts.append(genai_types.Part(text=content_block["text"])) - + if content_parts: # Map role correctly - Gemini uses "user" and "model" roles # "assistant" role from Messages format maps to "model" in Gemini role = "model" if message["role"] == "assistant" else message["role"] content = genai_types.Content(role=role, parts=content_parts) await self.live_session.send_client_content(turns=content) - + async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive Gemini Live API events and convert to provider-agnostic format.""" - # Emit connection start event - yield BidiConnectionStartEvent( - connection_id=self.connection_id, - model=self.model_id - ) - + yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model_id) + try: # Wrap in while loop to restart after turn_complete (SDK limitation workaround) while self._active: @@ -190,36 +181,36 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: async for message in self.live_session.receive(): if not self._active: break - + # Convert to provider-agnostic format (always returns list) for event in self._convert_gemini_live_event(message): yield event - + # SDK exits receive loop after turn_complete - restart automatically if self._active: logger.debug("Restarting receive loop after turn completion") - + except Exception as e: logger.error("Error in receive iteration: %s", e) # Small delay before retrying to avoid tight error loops await asyncio.sleep(0.1) - + except Exception as e: logger.error("Fatal error in receive loop: %s", e) yield BidiErrorEvent(error=e) finally: # Emit connection close event when exiting yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") - + def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOutputEvent]: """Convert Gemini Live API events to provider-agnostic format. - + Handles different types of content: - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model - usageMetadata: Token usage information - + Returns: List of event dicts (empty list if no events to emit). """ @@ -227,51 +218,54 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: return [BidiInterruptionEvent(reason="user_speech")] - + # Handle input transcription (user's speech) - emit as transcript event if message.server_content and message.server_content.input_transcription: input_transcript = message.server_content.input_transcription # Check if the transcription object has text content - if hasattr(input_transcript, 'text') and input_transcript.text: + if hasattr(input_transcript, "text") and input_transcript.text: transcription_text = input_transcript.text - role = getattr(input_transcript, 'role', 'user') + role = getattr(input_transcript, "role", "user") logger.debug(f"Input transcription detected: {transcription_text}") - return [BidiTranscriptStreamEvent( - delta={"text": transcription_text}, - text=transcription_text, - role=role.lower() if isinstance(role, str) else "user", - is_final=True, - current_transcript=transcription_text - )] - + return [ + BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role=role.lower() if isinstance(role, str) else "user", + is_final=True, + current_transcript=transcription_text, + ) + ] + # Handle output transcription (model's audio) - emit as transcript event if message.server_content and message.server_content.output_transcription: output_transcript = message.server_content.output_transcription # Check if the transcription object has text content - if hasattr(output_transcript, 'text') and output_transcript.text: + if hasattr(output_transcript, "text") and output_transcript.text: transcription_text = output_transcript.text - role = getattr(output_transcript, 'role', 'assistant') + role = getattr(output_transcript, "role", "assistant") logger.debug(f"Output transcription detected: {transcription_text}") - return [BidiTranscriptStreamEvent( - delta={"text": transcription_text}, - text=transcription_text, - role=role.lower() if isinstance(role, str) else "assistant", - is_final=True, - current_transcript=transcription_text - )] - + return [ + BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role=role.lower() if isinstance(role, str) else "assistant", + is_final=True, + current_transcript=transcription_text, + ) + ] + # Handle audio output using SDK's built-in data property # Check this BEFORE text to avoid triggering warning on mixed content if message.data: # Convert bytes to base64 string for JSON serializability - audio_b64 = base64.b64encode(message.data).decode('utf-8') - return [BidiAudioStreamEvent( - audio=audio_b64, - format="pcm", - sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, - channels=GEMINI_CHANNELS - )] - + audio_b64 = base64.b64encode(message.data).decode("utf-8") + return [ + BidiAudioStreamEvent( + audio=audio_b64, format="pcm", sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, channels=GEMINI_CHANNELS + ) + ] + # Handle text output from model_turn (avoids warning by checking parts directly) if message.server_content and message.server_content.model_turn: model_turn = message.server_content.model_turn @@ -280,22 +274,24 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut text_parts = [] for part in model_turn.parts: # Log all part types for debugging - part_attrs = {attr: getattr(part, attr, None) for attr in dir(part) if not attr.startswith('_')} - + part_attrs = {attr: getattr(part, attr, None) for attr in dir(part) if not attr.startswith("_")} + # Check if part has text attribute and it's not empty - if hasattr(part, 'text') and part.text: + if hasattr(part, "text") and part.text: text_parts.append(part.text) - + if text_parts: full_text = " ".join(text_parts) - return [BidiTranscriptStreamEvent( - delta={"text": full_text}, - text=full_text, - role="assistant", - is_final=True, - current_transcript=full_text - )] - + return [ + BidiTranscriptStreamEvent( + delta={"text": full_text}, + text=full_text, + role="assistant", + is_final=True, + current_transcript=full_text, + ) + ] + # Handle tool calls - return list to support multiple tool calls if message.tool_call and message.tool_call.function_calls: tool_events = [] @@ -303,32 +299,33 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut tool_use_event: ToolUse = { "toolUseId": func_call.id, "name": func_call.name, - "input": func_call.args or {} + "input": func_call.args or {}, } # Create ToolUseStreamEvent for consistency with standard agent - tool_events.append(ToolUseStreamEvent( - delta={"toolUse": tool_use_event}, - current_tool_use=tool_use_event - )) + tool_events.append( + ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=tool_use_event) + ) return tool_events - + # Handle usage metadata - if hasattr(message, 'usage_metadata') and message.usage_metadata: + if hasattr(message, "usage_metadata") and message.usage_metadata: usage = message.usage_metadata - + # Build modality details from token details modality_details = [] - + # Process prompt tokens details if usage.prompt_tokens_details: for detail in usage.prompt_tokens_details: if detail.modality and detail.token_count: - modality_details.append({ - "modality": str(detail.modality).lower(), - "input_tokens": detail.token_count, - "output_tokens": 0 - }) - + modality_details.append( + { + "modality": str(detail.modality).lower(), + "input_tokens": detail.token_count, + "output_tokens": 0, + } + ) + # Process response tokens details if usage.response_tokens_details: for detail in usage.response_tokens_details: @@ -339,44 +336,46 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut if existing: existing["output_tokens"] = detail.token_count else: - modality_details.append({ - "modality": modality_str, - "input_tokens": 0, - "output_tokens": detail.token_count - }) - - return [BidiUsageEvent( - input_tokens=usage.prompt_token_count or 0, - output_tokens=usage.response_token_count or 0, - total_tokens=usage.total_token_count or 0, - modality_details=modality_details if modality_details else None, - cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None - )] - + modality_details.append( + {"modality": modality_str, "input_tokens": 0, "output_tokens": detail.token_count} + ) + + return [ + BidiUsageEvent( + input_tokens=usage.prompt_token_count or 0, + output_tokens=usage.response_token_count or 0, + total_tokens=usage.total_token_count or 0, + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=usage.cached_content_token_count + if usage.cached_content_token_count + else None, + ) + ] + # Silently ignore setup_complete and generation_complete messages return [] - + except Exception as e: logger.error("Error converting Gemini Live event: %s", e) logger.error("Message type: %s", type(message).__name__) - logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) + logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith("_")]) # Return ErrorEvent in list so caller can handle it return [BidiErrorEvent(error=e)] - + async def send( self, content: BidiInputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given inputs to Google Live API - + Dispatches to appropriate internal handler based on content type. - + Args: content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). """ if not self._active: return - + try: if isinstance(content, BidiTextInputEvent): await self._send_text_content(content.text) @@ -393,68 +392,59 @@ async def send( except Exception as e: logger.error(f"Error sending content: {e}") raise # Propagate exception for debugging in experimental code - + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: """Internal: Send audio content using Gemini Live API. - + Gemini Live expects continuous audio streaming via send_realtime_input. This automatically triggers VAD and can interrupt ongoing responses. """ try: # Decode base64 audio to bytes for SDK audio_bytes = base64.b64decode(audio_input.audio) - + # Create audio blob for the SDK - audio_blob = genai_types.Blob( - data=audio_bytes, - mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" - ) - + audio_blob = genai_types.Blob(data=audio_bytes, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}") + # Send real-time audio input - this automatically handles VAD and interruption await self.live_session.send_realtime_input(audio=audio_blob) - + except Exception as e: logger.error("Error sending audio content: %s", e) - + async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: """Internal: Send image content using Gemini Live API. - + Sends image frames following the same pattern as the GitHub example. Images are sent as base64-encoded data with MIME type. """ try: # Image is already base64 encoded in the event - msg = { - "mime_type": image_input.mime_type, - "data": image_input.image - } - + msg = {"mime_type": image_input.mime_type, "data": image_input.image} + # Send using the same method as the GitHub example await self.live_session.send(input=msg) - + except Exception as e: logger.error("Error sending image content: %s", e) - + async def _send_text_content(self, text: str) -> None: """Internal: Send text content using Gemini Live API.""" try: # Create content with text - content = genai_types.Content( - role="user", - parts=[genai_types.Part(text=text)] - ) - + content = genai_types.Content(role="user", parts=[genai_types.Part(text=text)]) + # Send as client content await self.live_session.send_client_content(turns=content) - + except Exception as e: logger.error("Error sending text content: %s", e) - + async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Gemini Live API.""" try: tool_use_id = tool_result.get("toolUseId") - + # Extract result content result_data = {} if "content" in tool_result: @@ -463,26 +453,26 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" in block: result_data = {"result": block["text"]} break - + # Create function response func_response = genai_types.FunctionResponse( id=tool_use_id, name=tool_use_id, # Gemini uses name as identifier - response=result_data + response=result_data, ) - + # Send tool response await self.live_session.send_tool_response(function_responses=[func_response]) except Exception as e: logger.error("Error sending tool result: %s", e) - + async def stop(self) -> None: """Close Gemini Live API connection.""" if not self._active: return - + self._active = False - + try: # Exit the context manager properly if self.live_session_context_manager: @@ -490,15 +480,12 @@ async def stop(self) -> None: except Exception as e: logger.error("Error closing Gemini Live connection: %s", e) raise - + def _build_live_config( - self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - **kwargs + self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, **kwargs ) -> Dict[str, Any]: """Build LiveConnectConfig for the official SDK. - + Simply passes through all config parameters from live_config, allowing users to configure any Gemini Live API parameter directly. """ @@ -506,25 +493,25 @@ def _build_live_config( config_dict = {} if self.live_config: config_dict.update(self.live_config) - + # Override with any kwargs from start() config_dict.update(kwargs) - + # Add system instruction if provided if system_prompt: config_dict["system_instruction"] = system_prompt - + # Add tools if provided if tools: config_dict["tools"] = self._format_tools_for_live_api(tools) - + return config_dict - + def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: """Format tool specs for Gemini Live API.""" if not tool_specs: return [] - + return [ genai_types.Tool( function_declarations=[ @@ -536,4 +523,4 @@ def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_t for tool_spec in tool_specs ], ), - ] \ No newline at end of file + ] diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 74558f4f7..22be7edb7 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -28,9 +28,9 @@ ) from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -40,12 +40,12 @@ BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, - BidiUsageEvent, BidiOutputEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, BidiResponseCompleteEvent, BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, ) from .bidi_model import BidiModel @@ -89,12 +89,7 @@ class BidiNovaSonicModel(BidiModel): tool execution patterns while providing the standard BidiModel interface. """ - def __init__( - self, - model_id: str = "amazon.nova-sonic-v1:0", - region: str = "us-east-1", - **kwargs - ) -> None: + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **kwargs) -> None: """Initialize Nova Sonic bidirectional model. Args: @@ -223,10 +218,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.debug("Nova events - starting event stream") # Emit connection start event - yield BidiConnectionStartEvent( - connection_id=self.connection_id, - model=self.model_id - ) + yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model_id) try: while self._active: @@ -442,31 +434,28 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N self._current_completion_id = completion_data.get("completionId") logger.debug("Nova completion started: %s", self._current_completion_id) return None - + # Handle completion end if "completionEnd" in nova_event: completion_data = nova_event["completionEnd"] completion_id = completion_data.get("completionId", self._current_completion_id) stop_reason = completion_data.get("stopReason", "END_TURN") - + event = BidiResponseCompleteEvent( response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing - stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete", ) - + # Clear completion tracking self._current_completion_id = None return event - + # Handle audio output if "audioOutput" in nova_event: # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] return BidiAudioStreamEvent( - audio=audio_content, - format="pcm", - sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], - channels=1 + audio=audio_content, format="pcm", sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], channels=1 ) # Handle text output (transcripts) @@ -482,7 +471,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N text=text_content, role=self._current_role.lower() if self._current_role else "assistant", is_final=self._generation_stage == "FINAL", - current_transcript=text_content + current_transcript=text_content, ) # Handle tool use @@ -494,10 +483,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N "input": json.loads(tool_use["content"]), } # Return ToolUseStreamEvent for consistency with standard agent - return ToolUseStreamEvent( - delta={"toolUse": tool_use_event}, - current_tool_use=tool_use_event - ) + return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=tool_use_event) # Handle interruption if nova_event.get("stopReason") == "INTERRUPTED": @@ -509,11 +495,11 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N usage_data = nova_event["usageEvent"] total_input = usage_data.get("totalInputTokens", 0) total_output = usage_data.get("totalOutputTokens", 0) - + return BidiUsageEvent( input_tokens=total_input, output_tokens=total_output, - total_tokens=usage_data.get("totalTokens", total_input + total_output) + total_tokens=usage_data.get("totalTokens", total_input + total_output), ) # Handle content start events (track role and emit response start) @@ -522,10 +508,10 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N role = content_data.get("role", "unknown") # Store role for subsequent text output events self._current_role = role - + if content_data["type"] == "TEXT": self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] - + # Emit response start event using API-provided completionId # completionId should already be tracked from completionStart event return BidiResponseStartEvent( diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index d955155a4..7f7ce2eb6 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -4,7 +4,6 @@ with WebSocket connections, voice activity detection, and function calling. """ -import asyncio import json import logging import os @@ -12,11 +11,10 @@ from typing import AsyncIterable import websockets -from websockets.exceptions import ConnectionClosed +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -26,12 +24,12 @@ BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, - BidiUsageEvent, BidiOutputEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, BidiResponseCompleteEvent, BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, ) from .bidi_model import BidiModel @@ -50,15 +48,13 @@ "audio": { "input": { "format": AUDIO_FORMAT, - "transcription": { - "model": "gpt-4o-transcribe" - }, + "transcription": {"model": "gpt-4o-transcribe"}, "turn_detection": { "type": "server_vad", "threshold": 0.5, "prefix_padding_ms": 300, "silence_duration_ms": 500, - } + }, }, "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, }, @@ -67,23 +63,23 @@ class BidiOpenAIRealtimeModel(BidiModel): """OpenAI Realtime API implementation for bidirectional streaming. - + Combines model configuration and connection state in a single class. Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, function calling, and event conversion to Strands format. """ def __init__( - self, + self, model: str = DEFAULT_MODEL, api_key: str | None = None, organization: str | None = None, project: str | None = None, session_config: dict[str, any] | None = None, - **kwargs + **kwargs, ) -> None: """Initialize OpenAI Realtime bidirectional model. - + Args: model: OpenAI model identifier (default: gpt-realtime). api_key: OpenAI API key for authentication. @@ -98,19 +94,21 @@ def __init__( self.organization = organization self.project = project self.session_config = session_config or {} - + if not self.api_key: self.api_key = os.getenv("OPENAI_API_KEY") if not self.api_key: - raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") - + raise ValueError( + "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." + ) + # Connection state (initialized in start()) self.websocket = None self.connection_id = None self._active = False - + self._function_call_buffer = {} - + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) async def start( @@ -121,7 +119,7 @@ async def start( **kwargs, ) -> None: """Establish bidirectional connection to OpenAI Realtime API. - + Args: system_prompt: System instructions for the model. tools: List of tools available to the model. @@ -130,35 +128,35 @@ async def start( """ if self._active: raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - + logger.info("Creating OpenAI Realtime connection...") - + try: # Initialize connection state self.connection_id = str(uuid.uuid4()) self._active = True self._function_call_buffer = {} - + # Establish WebSocket connection url = f"{OPENAI_REALTIME_URL}?model={self.model}" - + headers = [("Authorization", f"Bearer {self.api_key}")] if self.organization: headers.append(("OpenAI-Organization", self.organization)) if self.project: headers.append(("OpenAI-Project", self.project)) - + self.websocket = await websockets.connect(url, additional_headers=headers) logger.info("WebSocket connected successfully") - + # Configure session session_config = self._build_session_config(system_prompt, tools) await self._send_event({"type": "session.update", "session": session_config}) - + # Add conversation history if provided if messages: await self._add_conversation_history(messages) - + except Exception as e: self._active = False logger.error("OpenAI connection error: %s", e) @@ -170,7 +168,7 @@ def _require_active(self) -> bool: def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: """Create standardized transcript event. - + Args: text: The transcript text role: The role (will be normalized to lowercase) @@ -180,13 +178,13 @@ def _create_text_event(self, text: str, role: str, is_final: bool = True) -> Bid normalized_role = role.lower() if isinstance(role, str) else "assistant" if normalized_role not in ["user", "assistant"]: normalized_role = "assistant" - + return BidiTranscriptStreamEvent( delta={"text": text}, text=text, role=normalized_role, is_final=is_final, - current_transcript=text if is_final else None + current_transcript=text if is_final else None, ) def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: @@ -200,48 +198,58 @@ def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEv def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: """Build session configuration for OpenAI Realtime API.""" config = DEFAULT_SESSION_CONFIG.copy() - + if system_prompt: config["instructions"] = system_prompt - + if tools: config["tools"] = self._convert_tools_to_openai_format(tools) - + # Apply user-provided session configuration supported_params = { - "type", "output_modalities", "instructions", "voice", "audio", - "tools", "tool_choice", "input_audio_format", "output_audio_format", - "input_audio_transcription", "turn_detection" + "type", + "output_modalities", + "instructions", + "voice", + "audio", + "tools", + "tool_choice", + "input_audio_format", + "output_audio_format", + "input_audio_transcription", + "turn_detection", } - + for key, value in self.session_config.items(): if key in supported_params: config[key] = value else: logger.warning("Ignoring unsupported session parameter: %s", key) - + return config def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: """Convert Strands tool specifications to OpenAI Realtime API format.""" openai_tools = [] - + for tool in tools: input_schema = tool["inputSchema"] if "json" in input_schema: - schema = json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + schema = ( + json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + ) else: schema = input_schema - + # OpenAI Realtime API expects flat structure, not nested under "function" openai_tool = { "type": "function", "name": tool["name"], "description": tool["description"], - "parameters": schema + "parameters": schema, } openai_tools.append(openai_tool) - + return openai_tools async def _add_conversation_history(self, messages: Messages) -> None: @@ -249,38 +257,37 @@ async def _add_conversation_history(self, messages: Messages) -> None: for message in messages: conversation_item = { "type": "conversation.item.create", - "item": {"type": "message", "role": message["role"], "content": []} + "item": {"type": "message", "role": message["role"], "content": []}, } - + content = message.get("content", "") if isinstance(content, str): conversation_item["item"]["content"].append({"type": "input_text", "text": content}) elif isinstance(content, list): for item in content: if isinstance(item, dict) and item.get("type") == "text": - conversation_item["item"]["content"].append({"type": "input_text", "text": item.get("text", "")}) - + conversation_item["item"]["content"].append( + {"type": "input_text", "text": item.get("text", "")} + ) + await self._send_event(conversation_item) async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive OpenAI events and convert to Strands TypedEvent format.""" # Emit connection start event - yield BidiConnectionStartEvent( - connection_id=self.connection_id, - model=self.model - ) - + yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model) + try: while self._active: async for message in self.websocket: if not self._active: break - + openai_event = json.loads(message) - for event in self._convert_openai_event(openai_event) or []: + for event in self._convert_openai_event(openai_event) or []: yield event - + except Exception as e: logger.error("Error receiving OpenAI Realtime event: %s", e) yield BidiErrorEvent(error=e) @@ -292,27 +299,30 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" event_type = openai_event.get("type") - + # Turn start - response begins if event_type == "response.created": response = openai_event.get("response", {}) response_id = response.get("id", str(uuid.uuid4())) return [BidiResponseStartEvent(response_id=response_id)] - + # Audio output elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI - return [BidiAudioStreamEvent( - audio=openai_event["delta"], - format="pcm", - sample_rate=AUDIO_FORMAT["rate"], - channels=1 - )] - + return [ + BidiAudioStreamEvent( + audio=openai_event["delta"], format="pcm", sample_rate=AUDIO_FORMAT["rate"], channels=1 + ) + ] + # Assistant text output events - combine multiple similar events elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: role = openai_event.get("role", "assistant") - return [self._create_text_event(openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False)] + return [ + self._create_text_event( + openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False + ) + ] elif event_type in ["response.output_audio_transcript.done"]: role = openai_event.get("role", "assistant").lower() @@ -321,27 +331,37 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput elif event_type in ["response.output_text.done"]: role = openai_event.get("role", "assistant").lower() return [self._create_text_event(openai_event["text"], role)] - + # User transcription events - combine multiple similar events - elif event_type in ["conversation.item.input_audio_transcription.delta", - "conversation.item.input_audio_transcription.completed"]: + elif event_type in [ + "conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed", + ]: text_key = "delta" if "delta" in event_type else "transcript" text = openai_event.get(text_key, "") role = openai_event.get("role", "user") is_final = "completed" in event_type - return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] if text.strip() else None - + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] + if text.strip() + else None + ) + elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) text = segment_data.get("text", "") role = segment_data.get("role", "user") - return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] if text.strip() else None - + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] + if text.strip() + else None + ) + elif event_type == "conversation.item.input_audio_transcription.failed": error_info = openai_event.get("error", {}) logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) return None - + # Function call processing elif event_type == "response.function_call_arguments.delta": call_id = openai_event.get("call_id") @@ -352,7 +372,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput else: self._function_call_buffer[call_id]["arguments"] += delta return None - + elif event_type == "response.function_call_arguments.done": call_id = openai_event.get("call_id") if call_id and call_id in self._function_call_buffer: @@ -365,123 +385,114 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput } del self._function_call_buffer[call_id] # Return ToolUseStreamEvent for consistency with standard agent - return [ToolUseStreamEvent( - delta={"toolUse": tool_use}, - current_tool_use=tool_use - )] + return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=tool_use)] except (json.JSONDecodeError, KeyError) as e: logger.warning("Error parsing function arguments for %s: %s", call_id, e) del self._function_call_buffer[call_id] return None - + # Voice activity detection - speech_started triggers interruption elif event_type == "input_audio_buffer.speech_started": # This is the primary interruption signal - handle it first return [BidiInterruptionEvent(reason="user_speech")] - + # Response cancelled - handle interruption elif event_type == "response.cancelled": response = openai_event.get("response", {}) response_id = response.get("id", "unknown") logger.debug("OpenAI response cancelled: %s", response_id) - return [BidiResponseCompleteEvent( - response_id=response_id, - stop_reason="interrupted" - )] - + return [BidiResponseCompleteEvent(response_id=response_id, stop_reason="interrupted")] + # Turn complete and usage - response finished elif event_type == "response.done": response = openai_event.get("response", {}) response_id = response.get("id", "unknown") status = response.get("status", "completed") usage = response.get("usage") - + # Map OpenAI status to our stop_reason stop_reason_map = { "completed": "complete", "cancelled": "interrupted", "failed": "error", - "incomplete": "interrupted" + "incomplete": "interrupted", } - + # Build list of events to return events = [] - + # Always add response complete event - events.append(BidiResponseCompleteEvent( - response_id=response_id, - stop_reason=stop_reason_map.get(status, "complete") - )) - + events.append( + BidiResponseCompleteEvent(response_id=response_id, stop_reason=stop_reason_map.get(status, "complete")) + ) + # Add usage event if available if usage: input_details = usage.get("input_token_details", {}) output_details = usage.get("output_token_details", {}) - + # Build modality details modality_details = [] - + # Text modality text_input = input_details.get("text_tokens", 0) text_output = output_details.get("text_tokens", 0) if text_input > 0 or text_output > 0: - modality_details.append({ - "modality": "text", - "input_tokens": text_input, - "output_tokens": text_output - }) - + modality_details.append( + {"modality": "text", "input_tokens": text_input, "output_tokens": text_output} + ) + # Audio modality audio_input = input_details.get("audio_tokens", 0) audio_output = output_details.get("audio_tokens", 0) if audio_input > 0 or audio_output > 0: - modality_details.append({ - "modality": "audio", - "input_tokens": audio_input, - "output_tokens": audio_output - }) - + modality_details.append( + {"modality": "audio", "input_tokens": audio_input, "output_tokens": audio_output} + ) + # Image modality image_input = input_details.get("image_tokens", 0) if image_input > 0: - modality_details.append({ - "modality": "image", - "input_tokens": image_input, - "output_tokens": 0 - }) - + modality_details.append({"modality": "image", "input_tokens": image_input, "output_tokens": 0}) + # Cached tokens cached_tokens = input_details.get("cached_tokens", 0) - + # Add usage event - events.append(BidiUsageEvent( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - total_tokens=usage.get("total_tokens", 0), - modality_details=modality_details if modality_details else None, - cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None - )) - + events.append( + BidiUsageEvent( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, + ) + ) + # Return list of events return events - + # Lifecycle events (log only) - combine multiple similar events elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: item = openai_event.get("item", {}) action = "retrieved" if "retrieve" in event_type else "added" logger.debug("OpenAI conversation item %s: %s", action, item.get("id")) return None - + elif event_type == "conversation.item.done": logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) return None - + # Response output events - combine similar events - elif event_type in ["response.output_item.added", "response.output_item.done", - "response.content_part.added", "response.content_part.done"]: + elif event_type in [ + "response.output_item.added", + "response.output_item.done", + "response.content_part.added", + "response.content_part.done", + ]: item_data = openai_event.get("item") or openai_event.get("part") logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") - + # Track function call names from response.output_item.added if event_type == "response.output_item.added": item = openai_event.get("item", {}) @@ -490,32 +501,40 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput function_name = item.get("name") if call_id and function_name: if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": function_name, "arguments": ""} + self._function_call_buffer[call_id] = { + "call_id": call_id, + "name": function_name, + "arguments": "", + } else: self._function_call_buffer[call_id]["name"] = function_name return None - + # Session/buffer events - combine simple log-only events - elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", - "session.created", "session.updated"]: + elif event_type in [ + "input_audio_buffer.committed", + "input_audio_buffer.cleared", + "session.created", + "session.updated", + ]: logger.debug("OpenAI %s event", event_type) return None - + elif event_type == "error": error_data = openai_event.get("error", {}) error_code = error_data.get("code", "") - + # Suppress expected errors that don't affect session state if error_code == "response_cancel_not_active": # This happens when trying to cancel a response that's not active # It's safe to ignore as the session remains functional logger.debug("OpenAI response cancel attempted when no response active (safe to ignore)") return None - + # Log other errors logger.error("OpenAI Realtime error: %s", error_data) return None - + else: logger.debug("Unhandled OpenAI event type: %s", event_type) return None @@ -525,15 +544,15 @@ async def send( content: BidiInputEvent | ToolResultEvent, ) -> None: """Unified send method for all content types. Sends the given content to OpenAI. - + Dispatches to appropriate internal handler based on content type. - + Args: content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). """ if not self._require_active(): return - + try: # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first if isinstance(content, BidiTextInputEvent): @@ -560,11 +579,7 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: async def _send_text_content(self, text: str) -> None: """Internal: Send text content to OpenAI for processing.""" - item_data = { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": text}] - } + item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} await self._send_event({"type": "conversation.item.create", "item": item_data}) await self._send_event({"type": "response.create"}) @@ -575,9 +590,9 @@ async def _send_interrupt(self) -> None: async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result back to OpenAI.""" tool_use_id = tool_result.get("toolUseId") - + logger.debug("OpenAI tool result send: %s", tool_use_id) - + # Extract result content result_data = {} if "content" in tool_result: @@ -586,14 +601,10 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" in block: result_data = block["text"] break - + result_text = json.dumps(result_data) if not isinstance(result_data, str) else result_data - - item_data = { - "type": "function_call_output", - "call_id": tool_use_id, - "output": result_text - } + + item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} await self._send_event({"type": "conversation.item.create", "item": item_data}) await self._send_event({"type": "response.create"}) @@ -601,15 +612,15 @@ async def stop(self) -> None: """Close session and cleanup resources.""" if not self._active: return - + logger.debug("OpenAI Realtime cleanup - starting connection close") self._active = False - + try: await self.websocket.close() except Exception as e: logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) - + logger.debug("OpenAI Realtime connection closed") async def _send_event(self, event: dict[str, any]) -> None: @@ -621,5 +632,3 @@ async def _send_event(self, event: dict[str, any]) -> None: except Exception as e: logger.error("Error sending OpenAI event: %s", e) raise - - diff --git a/src/strands/experimental/bidi/scripts/test_bidi.py b/src/strands/experimental/bidi/scripts/test_bidi.py index f07ef1fc4..85480bfaa 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi.py +++ b/src/strands/experimental/bidi/scripts/test_bidi.py @@ -6,16 +6,15 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +from strands_tools import calculator + from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO -from strands_tools import calculator +from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel async def main(): """Test the BidirectionalAgent API.""" - - # Nova Sonic model audio_io = BidiAudioIO() text_io = BidiTextIO() @@ -35,4 +34,5 @@ async def main(): except Exception as e: print(f"❌ Error: {e}") import traceback + traceback.print_exc() diff --git a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py index 38654f7fd..d6ce5f0c7 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py +++ b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py @@ -131,7 +131,7 @@ async def receive(agent, context): try: async for event in agent.receive(): event_type = event.get("type", "unknown") - + # Handle audio stream events (bidi_audio_stream) if event_type == "bidi_audio_stream": if not context.get("interrupted", False): @@ -148,25 +148,25 @@ async def receive(agent, context): elif event_type == "bidi_transcript_stream": text_content = event.get("text", "") role = event.get("role", "unknown") - + # Log transcript output if role == "user": print(f"User: {text_content}") elif role == "assistant": print(f"Assistant: {text_content}") - + # Handle response complete events (bidi_response_complete) elif event_type == "bidi_response_complete": # Reset interrupted state since the turn is complete context["interrupted"] = False - + # Handle tool use events (tool_use_stream) elif event_type == "tool_use_stream": tool_use = event.get("current_tool_use", {}) tool_name = tool_use.get("name", "unknown") tool_input = tool_use.get("input", {}) print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - + # Handle tool result events (tool_result) elif event_type == "tool_result": tool_result = event.get("tool_result", {}) @@ -191,14 +191,9 @@ async def send(agent, context): audio_bytes = context["audio_in"].get_nowait() # Create audio event using TypedEvent from strands.experimental.bidi.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) + + audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") + audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing diff --git a/src/strands/experimental/bidi/scripts/test_bidi_openai.py b/src/strands/experimental/bidi/scripts/test_bidi_openai.py index 71e934fb7..2243ac84d 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi_openai.py +++ b/src/strands/experimental/bidi/scripts/test_bidi_openai.py @@ -21,7 +21,7 @@ async def play(context): """Handle audio playback with interruption support.""" audio = pyaudio.PyAudio() - + try: speaker = audio.open( format=pyaudio.paInt16, @@ -30,7 +30,7 @@ async def play(context): output=True, frames_per_buffer=1024, ) - + while context["active"]: try: # Check for interruption @@ -41,32 +41,32 @@ async def play(context): context["audio_out"].get_nowait() except asyncio.QueueEmpty: break - + context["interrupted"] = False await asyncio.sleep(0.05) continue - + # Get audio data with timeout try: audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - + if audio_data and context["active"]: # Play in chunks to allow interruption chunk_size = 1024 for i in range(0, len(audio_data), chunk_size): if context.get("interrupted", False) or not context["active"]: break - - chunk = audio_data[i:i + chunk_size] + + chunk = audio_data[i : i + chunk_size] speaker.write(chunk) await asyncio.sleep(0.001) # Brief pause for responsiveness - + except asyncio.TimeoutError: continue - + except asyncio.CancelledError: break - + except asyncio.CancelledError: pass except Exception as e: @@ -82,7 +82,7 @@ async def play(context): async def record(context): """Handle microphone recording.""" audio = pyaudio.PyAudio() - + try: microphone = audio.open( format=pyaudio.paInt16, @@ -91,7 +91,7 @@ async def record(context): input=True, frames_per_buffer=1024, ) - + while context["active"]: try: audio_bytes = microphone.read(1024, exception_on_overflow=False) @@ -99,7 +99,7 @@ async def record(context): await asyncio.sleep(0.01) except asyncio.CancelledError: break - + except asyncio.CancelledError: pass except Exception as e: @@ -118,57 +118,57 @@ async def receive(agent, context): async for event in agent.receive(): if not context["active"]: break - + # Get event type event_type = event.get("type", "unknown") - + # Handle audio stream events (bidi_audio_stream) if event_type == "bidi_audio_stream": # Decode base64 audio string to bytes for playback audio_b64 = event["audio"] audio_data = base64.b64decode(audio_b64) - + if not context.get("interrupted", False): await context["audio_out"].put(audio_data) - + # Handle transcript events (bidi_transcript_stream) elif event_type == "bidi_transcript_stream": source = event.get("role", "assistant") text = event.get("text", "").strip() - + if text: if source == "user": print(f"🎤 User: {text}") elif source == "assistant": print(f"🔊 Assistant: {text}") - + # Handle interruption events (bidi_interruption) elif event_type == "bidi_interruption": context["interrupted"] = True print("⚠️ Interruption detected") - + # Handle connection start events (bidi_connection_start) elif event_type == "bidi_connection_start": print(f"✓ Session started: {event.get('model', 'unknown')}") - + # Handle connection close events (bidi_connection_close) elif event_type == "bidi_connection_close": print(f"✓ Session ended: {event.get('reason', 'unknown')}") context["active"] = False break - + # Handle response complete events (bidi_response_complete) elif event_type == "bidi_response_complete": # Reset interrupted state since the turn is complete context["interrupted"] = False - + # Handle tool use events (tool_use_stream) elif event_type == "tool_use_stream": tool_use = event.get("current_tool_use", {}) tool_name = tool_use.get("name", "unknown") tool_input = tool_use.get("input", {}) print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - + # Handle tool result events (tool_result) elif event_type == "tool_result": tool_result = event.get("tool_result", {}) @@ -180,7 +180,7 @@ async def receive(agent, context): result_text = block.get("text", "") break print(f"✅ Tool result from {tool_name}: {result_text}") - + except asyncio.CancelledError: pass except Exception as e: @@ -195,26 +195,21 @@ async def send(agent, context): while context["active"]: try: audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - + # Create audio event using TypedEvent # Encode audio bytes to base64 string for JSON serializability from strands.experimental.bidi.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=24000, - channels=1 - ) - + + audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") + audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=24000, channels=1) + await agent.send(audio_event) - + except asyncio.TimeoutError: continue except asyncio.CancelledError: break - + except asyncio.CancelledError: pass except Exception as e: @@ -226,13 +221,13 @@ async def send(agent, context): async def main(): """Main test function for OpenAI voice chat.""" print("Starting OpenAI Realtime API test...") - + # Check API key api_key = os.getenv("OPENAI_API_KEY") if not api_key: print("OPENAI_API_KEY environment variable not set") return False - + # Check audio system try: audio = pyaudio.PyAudio() @@ -240,7 +235,7 @@ async def main(): except Exception as e: print(f"Audio system error: {e}") return False - + # Create OpenAI model model = BidiOpenAIRealtimeModel( model="gpt-4o-realtime-preview", @@ -250,51 +245,40 @@ async def main(): "audio": { "input": { "format": {"type": "audio/pcm", "rate": 24000}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "silence_duration_ms": 700 - } + "turn_detection": {"type": "server_vad", "threshold": 0.5, "silence_duration_ms": 700}, }, - "output": { - "format": {"type": "audio/pcm", "rate": 24000}, - "voice": "alloy" - } - } - } + "output": {"format": {"type": "audio/pcm", "rate": 24000}, "voice": "alloy"}, + }, + }, ) - + # Create agent agent = BidiAgent( model=model, tools=[calculator], - system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." + system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect.", ) - + # Start the session await agent.start() - + # Create shared context context = { "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), "interrupted": False, - "start_time": time.time() + "start_time": time.time(), } - + print("Speak into your microphone. Press Ctrl+C to stop.") - + try: # Run all tasks concurrently await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True + play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True ) - + except KeyboardInterrupt: print("\nInterrupted by user") except asyncio.CancelledError: @@ -304,12 +288,12 @@ async def main(): finally: print("Cleaning up...") context["active"] = False - + try: await agent.stop() except Exception as e: print(f"Cleanup error: {e}") - + return True @@ -321,4 +305,5 @@ async def main(): except Exception as e: print(f"Test error: {e}") import traceback - traceback.print_exc() \ No newline at end of file + + traceback.print_exc() diff --git a/src/strands/experimental/bidi/scripts/test_gemini_live.py b/src/strands/experimental/bidi/scripts/test_gemini_live.py index 807a8da2b..ee2010c6f 100644 --- a/src/strands/experimental/bidi/scripts/test_gemini_live.py +++ b/src/strands/experimental/bidi/scripts/test_gemini_live.py @@ -28,6 +28,7 @@ try: import cv2 import PIL.Image + CAMERA_AVAILABLE = True except ImportError as e: print(f"Camera dependencies not available: {e}") @@ -41,9 +42,9 @@ from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel # Configure logging - debug only for Gemini Live, info for everything else -logging.basicConfig(level=logging.WARN, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') -gemini_logger.setLevel(logging.WARN) +logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +gemini_logger = logging.getLogger("strands.experimental.bidirectional_streaming.models.gemini_live") +gemini_logger.setLevel(logging.WARNING) logger = logging.getLogger(__name__) @@ -106,18 +107,18 @@ async def play(context): async def record(context): """Record audio input from microphone.""" audio = pyaudio.PyAudio() - + # List all available audio devices print("Available audio devices:") for i in range(audio.get_device_count()): device_info = audio.get_device_info_by_index(i) - if device_info['maxInputChannels'] > 0: # Only show input devices + if device_info["maxInputChannels"] > 0: # Only show input devices print(f" Device {i}: {device_info['name']} (inputs: {device_info['maxInputChannels']})") - + # Get default input device info default_device = audio.get_default_input_device_info() print(f"\nUsing default input device: {default_device['name']} (Device {default_device['index']})") - + microphone = audio.open( channels=1, format=pyaudio.paInt16, @@ -146,7 +147,7 @@ async def receive(agent, context): try: async for event in agent.receive(): event_type = event.get("type", "unknown") - + # Handle audio stream events (bidi_audio_stream) if event_type == "bidi_audio_stream": if not context.get("interrupted", False): @@ -165,25 +166,25 @@ async def receive(agent, context): transcript_text = event.get("text", "") transcript_role = event.get("role", "unknown") is_final = event.get("is_final", False) - + # Print transcripts with special formatting if transcript_role == "user": print(f"🎤 User: {transcript_text}") elif transcript_role == "assistant": print(f"🔊 Assistant: {transcript_text}") - + # Handle response complete events (bidi_response_complete) elif event_type == "bidi_response_complete": # Reset interrupted state since the response is complete context["interrupted"] = False - + # Handle tool use events (tool_use_stream) elif event_type == "tool_use_stream": tool_use = event.get("current_tool_use", {}) tool_name = tool_use.get("name", "unknown") tool_input = tool_use.get("input", {}) print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - + # Handle tool result events (tool_result) elif event_type == "tool_result": tool_result = event.get("tool_result", {}) @@ -205,7 +206,7 @@ def _get_frame(cap): """Capture and process a frame from camera.""" if not CAMERA_AVAILABLE: return None - + # Read the frame ret, frame = cap.read() # Check if the frame was read successfully @@ -232,11 +233,11 @@ async def get_frames(context): if not CAMERA_AVAILABLE: print("Camera not available - skipping video capture") return - + # This takes about a second, and will block the whole program # causing the audio pipeline to overflow if you don't to_thread it. cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera - + print("Camera initialized. Starting video capture...") try: @@ -248,10 +249,10 @@ async def get_frames(context): # Send frame to agent as image input try: from strands.experimental.bidi.types.events import BidiImageInputEvent - + image_event = BidiImageInputEvent( image=frame["data"], # Already base64 encoded - mime_type=frame["mime_type"] + mime_type=frame["mime_type"], ) await context["agent"].send(image_event) print("📸 Frame sent to model") @@ -276,14 +277,9 @@ async def send(agent, context): audio_bytes = context["audio_in"].get_nowait() # Create audio event using TypedEvent from strands.experimental.bidi.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) + + audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") + audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) @@ -303,25 +299,21 @@ async def main(duration=180): # Get API key from environment variable api_key = os.getenv("GOOGLE_AI_API_KEY") - + if not api_key: print("ERROR: GOOGLE_AI_API_KEY environment variable not set") print("Please set it with: export GOOGLE_AI_API_KEY=your_api_key") return - + # Initialize Gemini Live model with proper configuration logger.info("Initializing Gemini Live model with API key") - + # Use default model and config (includes transcription enabled by default) model = BidiGeminiLiveModel(api_key=api_key) logger.info("Gemini Live model initialized successfully") print("Using Gemini Live model with default config (audio output + transcription enabled)") - - agent = BidiAgent( - model=model, - tools=[calculator], - system_prompt="You are a helpful assistant." - ) + + agent = BidiAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() @@ -342,12 +334,12 @@ async def main(duration=180): try: # Run all tasks concurrently including camera await asyncio.gather( - play(context), - record(context), - receive(agent, context), + play(context), + record(context), + receive(agent, context), send(agent, context), get_frames(context), # Add camera task - return_exceptions=True + return_exceptions=True, ) except KeyboardInterrupt: print("\nInterrupted by user") @@ -360,4 +352,4 @@ async def main(duration=180): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/strands/experimental/bidi/types/__init__.py b/src/strands/experimental/bidi/types/__init__.py index d5263bb28..d8525e23a 100644 --- a/src/strands/experimental/bidi/types/__init__.py +++ b/src/strands/experimental/bidi/types/__init__.py @@ -1,7 +1,6 @@ """Type definitions for bidirectional streaming.""" from .agent import BidiAgentInput -from .io import BidiInput, BidiOutput from .events import ( DEFAULT_CHANNELS, DEFAULT_FORMAT, @@ -17,14 +16,15 @@ BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, - ModalityUsage, - BidiUsageEvent, BidiOutputEvent, BidiResponseCompleteEvent, BidiResponseStartEvent, BidiTextInputEvent, BidiTranscriptStreamEvent, + BidiUsageEvent, + ModalityUsage, ) +from .io import BidiInput, BidiOutput __all__ = [ "BidiInput", diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 852950f5a..ee2114bc2 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -19,7 +19,7 @@ - Audio data stored as base64-encoded strings for JSON compatibility """ -from typing import Any, Dict, List, Literal, Optional, Union, cast +from typing import Any, Dict, List, Literal, Optional, cast from ....types._events import ModelStreamEvent, TypedEvent from ....types.streaming import ContentBlockDelta @@ -236,7 +236,7 @@ def channels(self) -> int: class BidiTranscriptStreamEvent(ModelStreamEvent): """Audio transcription streaming (user or assistant speech). - + Supports incremental transcript updates for providers that send partial transcripts before the final version. @@ -478,7 +478,7 @@ def __init__( @property def error(self) -> Exception: """The original exception that occurred. - + Can be used for re-raising or type-based error handling. """ return self._error diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py index 8b79455ec..10ae5db77 100644 --- a/src/strands/experimental/bidi/types/io.py +++ b/src/strands/experimental/bidi/types/io.py @@ -27,12 +27,13 @@ async def stop(self) -> None: def __call__(self) -> Awaitable[BidiInputEvent]: """Read input data from the source. - + Returns: Awaitable that resolves to an input event (audio, text, image, etc.) """ ... + class BidiOutput(Protocol): """Protocol for bidirectional output callables. @@ -50,7 +51,7 @@ async def stop(self) -> None: def __call__(self, event: BidiOutputEvent) -> Awaitable[None]: """Process output events from the agent. - + Args: event: Output event from the agent (audio, text, tool calls, etc.) """ From 2f956bfe4eca736f51eb52a2238ff1dd3fb7841d Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 07:50:12 -0800 Subject: [PATCH 124/242] Format tests and fix linting errors - D102, D107 --- src/strands/experimental/bidi/types/events.py | 31 +++++ .../experimental/bidi/io/test_audio.py | 13 +- .../bidi/models/test_gemini_live.py | 128 +++++++++--------- .../bidi/models/test_novasonic.py | 49 ++----- .../bidi/models/test_openai_realtime.py | 78 ++++------- .../experimental/bidi/types/test_events.py | 11 +- tests_integ/bidi/context.py | 35 ++--- tests_integ/bidi/generators/audio.py | 4 +- tests_integ/bidi/test_bidirectional_agent.py | 53 ++++---- 9 files changed, 188 insertions(+), 214 deletions(-) diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index ee2114bc2..39df53a0a 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -59,10 +59,12 @@ def __init__(self, text: str, role: str): @property def text(self) -> str: + """The text content to send to the model.""" return cast(str, self.get("text")) @property def role(self) -> str: + """The role of the message sender.""" return cast(str, self.get("role")) @@ -97,18 +99,22 @@ def __init__( @property def audio(self) -> str: + """Base64-encoded audio string.""" return cast(str, self.get("audio")) @property def format(self) -> str: + """Audio encoding format.""" return cast(str, self.get("format")) @property def sample_rate(self) -> int: + """Number of audio samples per second in Hz.""" return cast(int, self.get("sample_rate")) @property def channels(self) -> int: + """Number of audio channels (1=mono, 2=stereo).""" return cast(int, self.get("channels")) @@ -137,10 +143,12 @@ def __init__( @property def image(self) -> str: + """Base64-encoded image string.""" return cast(str, self.get("image")) @property def mime_type(self) -> str: + """MIME type of the image (e.g., "image/jpeg", "image/png").""" return cast(str, self.get("mime_type")) @@ -168,10 +176,12 @@ def __init__(self, connection_id: str, model: str): @property def connection_id(self) -> str: + """Unique identifier for this streaming connection.""" return cast(str, self.get("connection_id")) @property def model(self) -> str: + """Model identifier (e.g., 'gpt-realtime', 'gemini-2.0-flash-live').""" return cast(str, self.get("model")) @@ -187,6 +197,7 @@ def __init__(self, response_id: str): @property def response_id(self) -> str: + """Unique identifier for this response.""" return cast(str, self.get("response_id")) @@ -219,18 +230,22 @@ def __init__( @property def audio(self) -> str: + """Base64-encoded audio string.""" return cast(str, self.get("audio")) @property def format(self) -> str: + """Audio encoding format.""" return cast(str, self.get("format")) @property def sample_rate(self) -> int: + """Number of audio samples per second in Hz.""" return cast(int, self.get("sample_rate")) @property def channels(self) -> int: + """Number of audio channels (1=mono, 2=stereo).""" return cast(int, self.get("channels")) @@ -269,22 +284,27 @@ def __init__( @property def delta(self) -> ContentBlockDelta: + """The incremental transcript change.""" return cast(ContentBlockDelta, self.get("delta")) @property def text(self) -> str: + """The text content to send to the model.""" return cast(str, self.get("text")) @property def role(self) -> str: + """The role of the message sender.""" return cast(str, self.get("role")) @property def is_final(self) -> bool: + """Whether this is the final/complete transcript.""" return cast(bool, self.get("is_final")) @property def current_transcript(self) -> Optional[str]: + """The accumulated transcript text so far.""" return cast(Optional[str], self.get("current_transcript")) @@ -306,6 +326,7 @@ def __init__(self, reason: Literal["user_speech", "error"]): @property def reason(self) -> str: + """Why the interruption occurred.""" return cast(str, self.get("reason")) @@ -332,10 +353,12 @@ def __init__( @property def response_id(self) -> str: + """Unique identifier for this response.""" return cast(str, self.get("response_id")) @property def stop_reason(self) -> str: + """Why the response ended.""" return cast(str, self.get("stop_reason")) @@ -393,26 +416,32 @@ def __init__( @property def input_tokens(self) -> int: + """Total tokens used for all input modalities.""" return cast(int, self.get("inputTokens")) @property def output_tokens(self) -> int: + """Total tokens used for all output modalities.""" return cast(int, self.get("outputTokens")) @property def total_tokens(self) -> int: + """Sum of input and output tokens.""" return cast(int, self.get("totalTokens")) @property def modality_details(self) -> List[ModalityUsage]: + """Optional list of token usage per modality.""" return cast(List[ModalityUsage], self.get("modality_details", [])) @property def cache_read_input_tokens(self) -> Optional[int]: + """Optional tokens read from cache.""" return cast(Optional[int], self.get("cacheReadInputTokens")) @property def cache_write_input_tokens(self) -> Optional[int]: + """Optional tokens written to cache.""" return cast(Optional[int], self.get("cacheWriteInputTokens")) @@ -439,10 +468,12 @@ def __init__( @property def connection_id(self) -> str: + """Unique identifier for this streaming connection.""" return cast(str, self.get("connection_id")) @property def reason(self) -> str: + """Why the interruption occurred.""" return cast(str, self.get("reason")) diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py index 9a3c2979c..e5e710b98 100644 --- a/tests/strands/experimental/bidi/io/test_audio.py +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -5,7 +5,7 @@ import pytest from strands.experimental.bidi.io import BidiAudioIO -from strands.experimental.bidi.types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent +from strands.experimental.bidi.types.events import BidiAudioInputEvent, BidiAudioStreamEvent @pytest.fixture @@ -35,7 +35,7 @@ async def test_bidi_audio_io_input(py_audio, audio_input): microphone.read.return_value = b"test-audio" py_audio.open.return_value = microphone - + await audio_input.start() tru_event = await audio_input() await audio_input.stop() @@ -55,17 +55,18 @@ async def test_bidi_audio_io_input(py_audio, audio_input): async def test_bidi_audio_io_output(py_audio, audio_output): write_future = asyncio.Future() write_event = asyncio.Event() + def write(data): write_future.set_result(data) write_event.set() - + speaker = unittest.mock.Mock() speaker.write.side_effect = write py_audio.open.return_value = speaker - + await audio_output.start() - + audio_event = BidiAudioStreamEvent( audio=base64.b64encode(b"test-audio").decode("utf-8"), channels=1, @@ -76,5 +77,5 @@ def write(data): await write_event.wait() await audio_output.stop() - + speaker.write.assert_called_once_with(write_future.result()) diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index c575f1788..8cf875598 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -8,11 +8,9 @@ """ import base64 -import json import unittest.mock import pytest -from google import genai from google.genai import types as genai_types from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel @@ -33,21 +31,23 @@ @pytest.fixture def mock_genai_client(): """Mock the Google GenAI client.""" - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client") as mock_client_cls: + with unittest.mock.patch( + "strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client" + ) as mock_client_cls: mock_client = mock_client_cls.return_value mock_client.aio = unittest.mock.MagicMock() - + # Mock the live session mock_live_session = unittest.mock.AsyncMock() - + # Mock the context manager mock_live_session_cm = unittest.mock.MagicMock() mock_live_session_cm.__aenter__ = unittest.mock.AsyncMock(return_value=mock_live_session) mock_live_session_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) - + # Make connect return the context manager mock_client.aio.live.connect = unittest.mock.MagicMock(return_value=mock_live_session_cm) - + yield mock_client, mock_live_session, mock_live_session_cm @@ -93,7 +93,7 @@ def messages(): def test_model_initialization(mock_genai_client, model_id, api_key): """Test model initialization with various configurations.""" _ = mock_genai_client - + # Test default config model_default = BidiGeminiLiveModel() assert model_default.model_id == "gemini-2.5-flash-native-audio-preview-09-2025" @@ -104,12 +104,12 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert model_default.live_config["response_modalities"] == ["AUDIO"] assert "outputAudioTranscription" in model_default.live_config assert "inputAudioTranscription" in model_default.live_config - + # Test with API key model_with_key = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) assert model_with_key.model_id == model_id assert model_with_key.api_key == api_key - + # Test with custom config (merges with defaults) live_config = {"temperature": 0.7, "top_p": 0.9} model_custom = BidiGeminiLiveModel(model_id=model_id, live_config=live_config) @@ -127,26 +127,26 @@ def test_model_initialization(mock_genai_client, model_id, api_key): async def test_connection_lifecycle(mock_genai_client, model, system_prompt, tool_spec, messages): """Test complete connection lifecycle with various configurations.""" mock_client, mock_live_session, mock_live_session_cm = mock_genai_client - + # Test basic connection await model.start() assert model._active is True assert model.connection_id is not None assert model.live_session == mock_live_session mock_client.aio.live.connect.assert_called_once() - + # Test close await model.stop() assert model._active is False mock_live_session_cm.__aexit__.assert_called_once() - + # Test connection with system prompt await model.start(system_prompt=system_prompt) call_args = mock_client.aio.live.connect.call_args config = call_args.kwargs.get("config", {}) assert config.get("system_instruction") == system_prompt await model.stop() - + # Test connection with tools await model.start(tools=[tool_spec]) call_args = mock_client.aio.live.connect.call_args @@ -154,7 +154,7 @@ async def test_connection_lifecycle(mock_genai_client, model, system_prompt, too assert "tools" in config assert len(config["tools"]) > 0 await model.stop() - + # Test connection with messages await model.start(messages=messages) mock_live_session.send_client_content.assert_called() @@ -165,27 +165,27 @@ async def test_connection_lifecycle(mock_genai_client, model, system_prompt, too async def test_connection_edge_cases(mock_genai_client, api_key, model_id): """Test connection error handling and edge cases.""" mock_client, _, mock_live_session_cm = mock_genai_client - + # Test connection error model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.start() - + # Reset mock for next tests mock_client.aio.live.connect.side_effect = None - + # Test double connection model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model2.start() with pytest.raises(RuntimeError, match="Connection already active"): await model2.start() await model2.stop() - + # Test close when not connected model3 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model3.stop() # Should not raise - + # Test close error handling model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model4.start() @@ -202,7 +202,7 @@ async def test_send_all_content_types(mock_genai_client, model): """Test sending all content types through unified send() method.""" _, mock_live_session, _ = mock_genai_client await model.start() - + # Test text input text_input = BidiTextInputEvent(text="Hello", role="user") await model.send(text_input) @@ -211,9 +211,9 @@ async def test_send_all_content_types(mock_genai_client, model): content = call_args.kwargs.get("turns") assert content.role == "user" assert content.parts[0].text == "Hello" - + # Test audio input (base64 encoded) - audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') + audio_b64 = base64.b64encode(b"audio_bytes").decode("utf-8") audio_input = BidiAudioInputEvent( audio=audio_b64, format="pcm", @@ -222,16 +222,16 @@ async def test_send_all_content_types(mock_genai_client, model): ) await model.send(audio_input) mock_live_session.send_realtime_input.assert_called_once() - + # Test image input (base64 encoded, no encoding parameter) - image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') + image_b64 = base64.b64encode(b"image_bytes").decode("utf-8") image_input = BidiImageInputEvent( image=image_b64, mime_type="image/jpeg", ) await model.send(image_input) mock_live_session.send.assert_called_once() - + # Test tool result tool_result: ToolResult = { "toolUseId": "tool-123", @@ -240,7 +240,7 @@ async def test_send_all_content_types(mock_genai_client, model): } await model.send(ToolResultEvent(tool_result)) mock_live_session.send_tool_response.assert_called_once() - + await model.stop() @@ -248,17 +248,17 @@ async def test_send_all_content_types(mock_genai_client, model): async def test_send_edge_cases(mock_genai_client, model): """Test send() edge cases and error handling.""" _, mock_live_session, _ = mock_genai_client - + # Test send when inactive text_input = BidiTextInputEvent(text="Hello", role="user") await model.send(text_input) mock_live_session.send_client_content.assert_not_called() - + # Test unknown content type await model.start() unknown_content = {"unknown_field": "value"} await model.send(unknown_content) # Should not raise, just log warning - + await model.stop() @@ -270,9 +270,9 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): """Test that receive() emits connection start and end events.""" _, mock_live_session, _ = mock_genai_client mock_live_session.receive.return_value = agenerator([]) - + await model.start() - + # Collect events events = [] async for event in model.receive(): @@ -280,7 +280,7 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): # Close after first event to trigger connection end if len(events) == 1: await model.stop() - + # Verify connection start and end assert len(events) >= 2 assert isinstance(events[0], BidiConnectionStartEvent) @@ -295,26 +295,26 @@ async def test_event_conversion(mock_genai_client, model): """Test conversion of all Gemini Live event types to standard format.""" _, _, _ = mock_genai_client await model.start() - + # Test text output (converted to transcript via model_turn.parts) mock_text = unittest.mock.Mock() mock_text.data = None mock_text.tool_call = None - + # Create proper server_content structure with model_turn mock_server_content = unittest.mock.Mock() mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None - + mock_model_turn = unittest.mock.Mock() mock_part = unittest.mock.Mock() mock_part.text = "Hello from Gemini" mock_model_turn.parts = [mock_part] mock_server_content.model_turn = mock_model_turn - + mock_text.server_content = mock_server_content - + text_events = model._convert_gemini_live_event(mock_text) assert isinstance(text_events, list) assert len(text_events) == 1 @@ -326,17 +326,17 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.is_final is True assert text_event.delta == {"text": "Hello from Gemini"} assert text_event.current_transcript == "Hello from Gemini" - + # Test multiple text parts (should concatenate) mock_multi_text = unittest.mock.Mock() mock_multi_text.data = None mock_multi_text.tool_call = None - + mock_server_content_multi = unittest.mock.Mock() mock_server_content_multi.interrupted = False mock_server_content_multi.input_transcription = None mock_server_content_multi.output_transcription = None - + mock_model_turn_multi = unittest.mock.Mock() mock_part1 = unittest.mock.Mock() mock_part1.text = "Hello" @@ -344,23 +344,23 @@ async def test_event_conversion(mock_genai_client, model): mock_part2.text = "from Gemini" mock_model_turn_multi.parts = [mock_part1, mock_part2] mock_server_content_multi.model_turn = mock_model_turn_multi - + mock_multi_text.server_content = mock_server_content_multi - + multi_text_events = model._convert_gemini_live_event(mock_multi_text) assert isinstance(multi_text_events, list) assert len(multi_text_events) == 1 multi_text_event = multi_text_events[0] assert isinstance(multi_text_event, BidiTranscriptStreamEvent) assert multi_text_event.text == "Hello from Gemini" # Concatenated with space - + # Test audio output (base64 encoded) mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" mock_audio.tool_call = None mock_audio.server_content = None - + audio_events = model._convert_gemini_live_event(mock_audio) assert isinstance(audio_events, list) assert len(audio_events) == 1 @@ -368,25 +368,25 @@ async def test_event_conversion(mock_genai_client, model): assert isinstance(audio_event, BidiAudioStreamEvent) assert audio_event.get("type") == "bidi_audio_stream" # Audio is now base64 encoded - expected_b64 = base64.b64encode(b"audio_data").decode('utf-8') + expected_b64 = base64.b64encode(b"audio_data").decode("utf-8") assert audio_event.audio == expected_b64 assert audio_event.format == "pcm" - + # Test single tool call (returns list with one event) mock_func_call = unittest.mock.Mock() mock_func_call.id = "tool-123" mock_func_call.name = "calculator" mock_func_call.args = {"expression": "2+2"} - + mock_tool_call = unittest.mock.Mock() mock_tool_call.function_calls = [mock_func_call] - + mock_tool = unittest.mock.Mock() mock_tool.text = None mock_tool.data = None mock_tool.tool_call = mock_tool_call mock_tool.server_content = None - + tool_events = model._convert_gemini_live_event(mock_tool) # Should return a list of ToolUseStreamEvent assert isinstance(tool_events, list) @@ -397,54 +397,54 @@ async def test_event_conversion(mock_genai_client, model): assert "toolUse" in tool_event["delta"] assert tool_event["delta"]["toolUse"]["toolUseId"] == "tool-123" assert tool_event["delta"]["toolUse"]["name"] == "calculator" - + # Test multiple tool calls (returns list with multiple events) mock_func_call_1 = unittest.mock.Mock() mock_func_call_1.id = "tool-123" mock_func_call_1.name = "calculator" mock_func_call_1.args = {"expression": "2+2"} - + mock_func_call_2 = unittest.mock.Mock() mock_func_call_2.id = "tool-456" mock_func_call_2.name = "weather" mock_func_call_2.args = {"location": "Seattle"} - + mock_tool_call_multi = unittest.mock.Mock() mock_tool_call_multi.function_calls = [mock_func_call_1, mock_func_call_2] - + mock_tool_multi = unittest.mock.Mock() mock_tool_multi.text = None mock_tool_multi.data = None mock_tool_multi.tool_call = mock_tool_call_multi mock_tool_multi.server_content = None - + tool_events_multi = model._convert_gemini_live_event(mock_tool_multi) # Should return a list with two ToolUseStreamEvent assert isinstance(tool_events_multi, list) assert len(tool_events_multi) == 2 - + # Verify first tool call assert tool_events_multi[0]["delta"]["toolUse"]["toolUseId"] == "tool-123" assert tool_events_multi[0]["delta"]["toolUse"]["name"] == "calculator" assert tool_events_multi[0]["delta"]["toolUse"]["input"] == {"expression": "2+2"} - + # Verify second tool call assert tool_events_multi[1]["delta"]["toolUse"]["toolUseId"] == "tool-456" assert tool_events_multi[1]["delta"]["toolUse"]["name"] == "weather" assert tool_events_multi[1]["delta"]["toolUse"]["input"] == {"location": "Seattle"} - + # Test interruption mock_server_content = unittest.mock.Mock() mock_server_content.interrupted = True mock_server_content.input_transcription = None mock_server_content.output_transcription = None - + mock_interrupt = unittest.mock.Mock() mock_interrupt.text = None mock_interrupt.data = None mock_interrupt.tool_call = None mock_interrupt.server_content = mock_server_content - + interrupt_events = model._convert_gemini_live_event(mock_interrupt) assert isinstance(interrupt_events, list) assert len(interrupt_events) == 1 @@ -452,7 +452,7 @@ async def test_event_conversion(mock_genai_client, model): assert isinstance(interrupt_event, BidiInterruptionEvent) assert interrupt_event.get("type") == "bidi_interruption" assert interrupt_event.reason == "user_speech" - + await model.stop() @@ -464,11 +464,11 @@ def test_config_building(model, system_prompt, tool_spec): # Test basic config config_basic = model._build_live_config() assert isinstance(config_basic, dict) - + # Test with system prompt config_prompt = model._build_live_config(system_prompt=system_prompt) assert config_prompt["system_instruction"] == system_prompt - + # Test with tools config_tools = model._build_live_config(tools=[tool_spec]) assert "tools" in config_tools @@ -481,7 +481,7 @@ def test_tool_formatting(model, tool_spec): formatted_tools = model._format_tools_for_live_api([tool_spec]) assert len(formatted_tools) == 1 assert isinstance(formatted_tools[0], genai_types.Tool) - + # Test empty list formatted_empty = model._format_tools_for_live_api([]) assert formatted_empty == [] diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index db61ed43e..8e60b8fb5 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -109,7 +109,7 @@ async def test_connection_lifecycle(nova_model, mock_client, mock_stream): { "name": "get_weather", "description": "Get weather information", - "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} + "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})}, } ] await nova_model.start(system_prompt="You are helpful", tools=tools) @@ -154,13 +154,8 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): assert mock_stream.input_stream.send.call_count >= 3 # Test audio content (base64 encoded) - audio_b64 = base64.b64encode(b"audio data").decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) + audio_b64 = base64.b64encode(b"audio data").decode("utf-8") + audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) await nova_model.send(audio_event) # Should start audio connection and send audio assert nova_model.audio_connection_active @@ -170,7 +165,7 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", - "content": [{"text": "Weather is sunny"}] + "content": [{"text": "Weather is sunny"}], } await nova_model.send(ToolResultEvent(tool_result)) # Should send contentStart, toolResult, and contentEnd @@ -191,7 +186,7 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): # Test image content (not supported, base64 encoded, no encoding parameter) await nova_model.start() - image_b64 = base64.b64encode(b"image data").decode('utf-8') + image_b64 = base64.b64encode(b"image data").decode("utf-8") image_event = BidiImageInputEvent( image=image_b64, mime_type="image/jpeg", @@ -261,13 +256,7 @@ async def test_event_conversion(nova_model): # Test tool use (now returns ToolUseStreamEvent from core strands) tool_input = {"location": "Seattle"} - nova_event = { - "toolUse": { - "toolUseId": "tool-123", - "toolName": "get_weather", - "content": json.dumps(tool_input) - } - } + nova_event = {"toolUse": {"toolUseId": "tool-123", "toolName": "get_weather", "content": json.dumps(tool_input)}} result = nova_model._convert_nova_event(nova_event) assert result is not None # ToolUseStreamEvent has delta and current_tool_use, not a "type" field @@ -292,13 +281,7 @@ async def test_event_conversion(nova_model): "totalTokens": 100, "totalInputTokens": 40, "totalOutputTokens": 60, - "details": { - "total": { - "output": { - "speechTokens": 30 - } - } - } + "details": {"total": {"output": {"speechTokens": 30}}}, } } result = nova_model._convert_nova_event(nova_event) @@ -350,13 +333,8 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): await nova_model.start() # Send audio to start connection (base64 encoded) - audio_b64 = base64.b64encode(b"audio data").decode('utf-8') - audio_event = BidiAudioInputEvent( - audio=audio_b64, - format="pcm", - sample_rate=16000, - channels=1 - ) + audio_b64 = base64.b64encode(b"audio data").decode("utf-8") + audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) await nova_model.send(audio_event) assert nova_model.audio_connection_active @@ -380,14 +358,7 @@ async def test_tool_configuration(nova_model): { "name": "get_weather", "description": "Get weather information", - "inputSchema": { - "json": json.dumps({ - "type": "object", - "properties": { - "location": {"type": "string"} - } - }) - } + "inputSchema": {"json": json.dumps({"type": "object", "properties": {"location": {"type": "string"}}})}, } ] diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index b9e844250..badc52031 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -8,7 +8,6 @@ - Connection lifecycle management """ -import asyncio import base64 import json import unittest.mock @@ -41,10 +40,13 @@ def mock_websocket(): @pytest.fixture def mock_websockets_connect(mock_websocket): """Mock websockets.connect function.""" + async def async_connect(*args, **kwargs): return mock_websocket - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect: + with unittest.mock.patch( + "strands.experimental.bidirectional_streaming.models.openai.websockets.connect" + ) as mock_connect: mock_connect.side_effect = async_connect yield mock_connect, mock_websocket @@ -102,12 +104,7 @@ def test_model_initialization(api_key, model_name): assert model_custom.api_key == api_key # Test with organization and project - model_org = BidiOpenAIRealtimeModel( - model=model_name, - api_key=api_key, - organization="org-123", - project="proj-456" - ) + model_org = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, organization="org-123", project="proj-456") assert model_org.organization == "org-123" assert model_org.project == "proj-456" @@ -150,8 +147,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp await model.start(system_prompt=system_prompt) calls = mock_ws.send.call_args_list session_update = next( - (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), - None + (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), None ) assert session_update is not None assert system_prompt in session_update["session"]["instructions"] @@ -161,7 +157,9 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp await model.start(tools=[tool_spec]) calls = mock_ws.send.call_args_list # Tools are sent in a separate session.update after initial connection - session_updates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"] + session_updates = [ + json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update" + ] assert len(session_updates) > 0 # Check if any session update has tools has_tools = any("tools" in update.get("session", {}) for update in session_updates) @@ -171,7 +169,9 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp # Test connection with messages await model.start(messages=messages) calls = mock_ws.send.call_args_list - item_creates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create"] + item_creates = [ + json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create" + ] assert len(item_creates) > 0 await model.stop() @@ -200,6 +200,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam # Reset mock async def async_connect(*args, **kwargs): return mock_ws + mock_connect.side_effect = async_connect # Test double connection @@ -241,7 +242,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): assert len(response_create) > 0 # Test audio input (base64 encoded) - audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') + audio_b64 = base64.b64encode(b"audio_bytes").decode("utf-8") audio_input = BidiAudioInputEvent( audio=audio_b64, format="pcm", @@ -287,7 +288,7 @@ async def test_send_edge_cases(mock_websockets_connect, model): # Test image input (not supported, base64 encoded, no encoding parameter) await model.start() - image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') + image_b64 = base64.b64encode(b"image_bytes").decode("utf-8") image_input = BidiImageInputEvent( image=image_b64, mime_type="image/jpeg", @@ -346,10 +347,7 @@ async def test_event_conversion(mock_websockets_connect, model): await model.start() # Test audio output (now returns list with BidiAudioStreamEvent) - audio_event = { - "type": "response.output_audio.delta", - "delta": base64.b64encode(b"audio_data").decode() - } + audio_event = {"type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode()} converted = model._convert_openai_event(audio_event) assert isinstance(converted, list) assert len(converted) == 1 @@ -359,10 +357,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("format") == "pcm" # Test text output (now returns list with BidiTranscriptStreamEvent) - text_event = { - "type": "response.output_text.delta", - "delta": "Hello from OpenAI" - } + text_event = {"type": "response.output_text.delta", "delta": "Hello from OpenAI"} converted = model._convert_openai_event(text_event) assert isinstance(converted, list) assert len(converted) == 1 @@ -376,25 +371,18 @@ async def test_event_conversion(mock_websockets_connect, model): # Test function call sequence item_added = { "type": "response.output_item.added", - "item": { - "type": "function_call", - "call_id": "call-123", - "name": "calculator" - } + "item": {"type": "function_call", "call_id": "call-123", "name": "calculator"}, } model._convert_openai_event(item_added) args_delta = { "type": "response.function_call_arguments.delta", "call_id": "call-123", - "delta": '{"expression": "2+2"}' + "delta": '{"expression": "2+2"}', } model._convert_openai_event(args_delta) - args_done = { - "type": "response.function_call_arguments.done", - "call_id": "call-123" - } + args_done = {"type": "response.function_call_arguments.done", "call_id": "call-123"} converted = model._convert_openai_event(args_done) # Now returns list with ToolUseStreamEvent assert isinstance(converted, list) @@ -408,9 +396,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert tool_use["input"]["expression"] == "2+2" # Test voice activity (now returns list with BidiInterruptionEvent for speech_started) - speech_started = { - "type": "input_audio_buffer.speech_started" - } + speech_started = {"type": "input_audio_buffer.speech_started"} converted = model._convert_openai_event(speech_started) assert isinstance(converted, list) assert len(converted) == 1 @@ -419,12 +405,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("reason") == "user_speech" # Test response.cancelled event (should return ResponseCompleteEvent with interrupted reason) - response_cancelled = { - "type": "response.cancelled", - "response": { - "id": "resp_123" - } - } + response_cancelled = {"type": "response.cancelled", "response": {"id": "resp_123"}} converted = model._convert_openai_event(response_cancelled) assert isinstance(converted, list) assert len(converted) == 1 @@ -436,22 +417,13 @@ async def test_event_conversion(mock_websockets_connect, model): # Test error handling - response_cancel_not_active should be suppressed error_cancel_not_active = { "type": "error", - "error": { - "code": "response_cancel_not_active", - "message": "No active response to cancel" - } + "error": {"code": "response_cancel_not_active", "message": "No active response to cancel"}, } converted = model._convert_openai_event(error_cancel_not_active) assert converted is None # Should be suppressed # Test error handling - other errors should be logged but return None - error_other = { - "type": "error", - "error": { - "code": "some_other_error", - "message": "Something went wrong" - } - } + error_other = {"type": "error", "error": {"code": "some_other_error", "message": "Something went wrong"}} converted = model._convert_openai_event(error_other) assert converted is None @@ -516,7 +488,7 @@ def test_helper_methods(model): assert isinstance(voice_event, BidiInterruptionEvent) assert voice_event.get("type") == "bidi_interruption" assert voice_event.get("reason") == "user_speech" - + # Other voice activities return None assert model._create_voice_activity_event("speech_stopped") is None diff --git a/tests/strands/experimental/bidi/types/test_events.py b/tests/strands/experimental/bidi/types/test_events.py index 0b6419719..1e609bd36 100644 --- a/tests/strands/experimental/bidi/types/test_events.py +++ b/tests/strands/experimental/bidi/types/test_events.py @@ -115,7 +115,6 @@ def test_event_json_serialization(event_class, kwargs, expected_type): assert key in data - def test_transcript_stream_event_delta_pattern(): """Test that BidiTranscriptStreamEvent follows ModelStreamEvent delta pattern.""" # Test partial transcript (delta) @@ -126,13 +125,13 @@ def test_transcript_stream_event_delta_pattern(): is_final=False, current_transcript=None, ) - + assert partial_event.text == "Hello" assert partial_event.role == "user" assert partial_event.is_final is False assert partial_event.current_transcript is None assert partial_event.delta == {"text": "Hello"} - + # Test final transcript with accumulated text final_event = BidiTranscriptStreamEvent( delta={"text": " world"}, @@ -141,7 +140,7 @@ def test_transcript_stream_event_delta_pattern(): is_final=True, current_transcript="Hello world", ) - + assert final_event.text == " world" assert final_event.role == "user" assert final_event.is_final is True @@ -152,7 +151,7 @@ def test_transcript_stream_event_delta_pattern(): def test_transcript_stream_event_extends_model_stream_event(): """Test that BidiTranscriptStreamEvent is a ModelStreamEvent.""" from strands.types._events import ModelStreamEvent - + event = BidiTranscriptStreamEvent( delta={"text": "test"}, text="test", @@ -160,5 +159,5 @@ def test_transcript_stream_event_extends_model_stream_event(): is_final=True, current_transcript="test", ) - + assert isinstance(event, ModelStreamEvent) diff --git a/tests_integ/bidi/context.py b/tests_integ/bidi/context.py index 4a5278a62..978878bea 100644 --- a/tests_integ/bidi/context.py +++ b/tests_integ/bidi/context.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from strands.experimental.bidi.agent.agent import BidiAgent + from .generators.audio import AudioGenerator logger = logging.getLogger(__name__) @@ -76,7 +77,7 @@ async def __aenter__(self): # Start agent session await self.agent.start() logger.debug("Agent session started") - + await self.start() return self @@ -86,10 +87,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): if self.agent._agent_loop and self.agent._agent_loop.active: await self.agent.stop() logger.debug("Agent session stopped") - + # Then stop the context threads await self.stop() - + return False async def start(self): @@ -109,7 +110,7 @@ async def stop(self): if not self.active: logger.debug("stop() called but already stopped") return - + logger.debug("stop() called - stopping threads") self.active = False @@ -130,24 +131,22 @@ async def say(self, text: str): Args: text: Text to convert to speech and send as audio. - + Raises: ValueError: If audio generator is not available. """ if not self.audio_generator: - raise ValueError( - "Audio generator not available. Pass audio_generator to BidirectionalTestContext." - ) - + raise ValueError("Audio generator not available. Pass audio_generator to BidirectionalTestContext.") + # Generate audio via Polly audio_data = await self.audio_generator.generate_audio(text) - + # Split into chunks and queue each chunk for i in range(0, len(audio_data), self.audio_chunk_size): chunk = audio_data[i : i + self.audio_chunk_size] chunk_event = self.audio_generator.create_audio_input_event(chunk) await self.input_queue.put({"type": "audio_chunk", "data": chunk_event}) - + logger.debug(f"Queued {len(audio_data)} bytes of audio for: {text[:50]}...") async def send(self, data: str | dict) -> None: @@ -183,7 +182,7 @@ async def wait_for_response( while time.monotonic() - start_time < timeout: # Drain queue to get latest events current_events = self.get_events() - + # Check if we have minimum events if len(current_events) - initial_event_count >= min_events: # Check silence @@ -201,7 +200,7 @@ async def wait_for_response( def get_events(self, event_type: str | None = None) -> list[dict]: """Get collected events, optionally filtered by type. - + Drains the event queue and caches events for subsequent calls. Args: @@ -218,14 +217,14 @@ def get_events(self, event_type: str | None = None) -> list[dict]: self.last_event_time = time.monotonic() except asyncio.QueueEmpty: break - + if event_type: return [e for e in self.events if event_type in e] return self.events.copy() def get_text_outputs(self) -> list[str]: """Extract text outputs from collected events. - + Handles both new TypedEvent format and legacy event formats. Returns: @@ -320,7 +319,11 @@ async def _input_thread(self): elif input_item["type"] == "direct": # Send data directly to agent await self.agent.send(input_item["data"]) - data_repr = str(input_item["data"])[:50] if isinstance(input_item["data"], str) else type(input_item["data"]).__name__ + data_repr = ( + str(input_item["data"])[:50] + if isinstance(input_item["data"], str) + else type(input_item["data"]).__name__ + ) logger.debug(f"Sent direct: {data_repr}") except asyncio.TimeoutError: diff --git a/tests_integ/bidi/generators/audio.py b/tests_integ/bidi/generators/audio.py index 75c17a1e3..8f2a9929f 100644 --- a/tests_integ/bidi/generators/audio.py +++ b/tests_integ/bidi/generators/audio.py @@ -122,8 +122,8 @@ def create_audio_input_event( BidiAudioInputEvent dict ready for agent.send(). """ # Convert bytes to base64 string for JSON compatibility - audio_b64 = base64.b64encode(audio_data).decode('utf-8') - + audio_b64 = base64.b64encode(audio_data).decode("utf-8") + return { "type": "bidi_audio_input", "audio": audio_b64, diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index 594379b64..ed2f8d5f1 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -14,9 +14,9 @@ from strands import tool from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel -from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel from .context import BidirectionalTestContext @@ -27,12 +27,12 @@ @tool def calculator(operation: str, x: float, y: float) -> float: """Perform basic arithmetic operations. - + Args: operation: The operation to perform (add, subtract, multiply, divide) x: First number y: Second number - + Returns: Result of the operation """ @@ -96,38 +96,38 @@ def calculator(operation: str, x: float, y: float) -> float: def check_provider_available(provider_name: str) -> tuple[bool, str]: """Check if a provider's credentials are available. - + Args: provider_name: Name of the provider to check. - + Returns: Tuple of (is_available, skip_reason). """ config = PROVIDER_CONFIGS[provider_name] env_vars = config["env_vars"] - + missing_vars = [var for var in env_vars if not os.getenv(var)] - + if missing_vars: return False, f"{config['skip_reason']}: {', '.join(missing_vars)}" - + return True, "" @pytest.fixture(params=list(PROVIDER_CONFIGS.keys())) def provider_config(request): """Provide configuration for each model provider. - + This fixture is parameterized to run tests against all available providers. """ provider_name = request.param config = PROVIDER_CONFIGS[provider_name] - + # Check if provider is available is_available, skip_reason = check_provider_available(provider_name) if not is_available: pytest.skip(skip_reason) - + return { "name": provider_name, **config, @@ -137,12 +137,12 @@ def provider_config(request): @pytest.fixture def agent_with_calculator(provider_config): """Provide bidirectional agent with calculator tool for the given provider. - + Note: Session lifecycle (start/end) is handled by BidirectionalTestContext. """ model_class = provider_config["model_class"] model_kwargs = provider_config["model_kwargs"] - + model = model_class(**model_kwargs) return BidiAgent( model=model, @@ -150,13 +150,14 @@ def agent_with_calculator(provider_config): system_prompt="You are a helpful assistant with access to a calculator tool. Keep responses brief.", ) + @pytest.mark.asyncio async def test_bidirectional_agent(agent_with_calculator, audio_generator, provider_config): """Test multi-turn conversation with follow-up questions across providers. - + This test runs against all configured providers (Nova Sonic, OpenAI, etc.) to validate provider-agnostic functionality. - + Validates: - Session lifecycle (start/end via context manager) - Audio input streaming @@ -167,9 +168,9 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi """ provider_name = provider_config["name"] silence_duration = provider_config["silence_duration"] - + logger.info(f"Testing provider: {provider_name}") - + async with BidirectionalTestContext(agent_with_calculator, audio_generator) as ctx: # Turn 1: Simple greeting to test basic audio I/O await ctx.say("Hello, can you hear me?") @@ -179,12 +180,10 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi text_outputs_turn1 = ctx.get_text_outputs() all_text_turn1 = " ".join(text_outputs_turn1).lower() - + # Validate turn 1 - just check we got a response - assert len(text_outputs_turn1) > 0, ( - f"[{provider_name}] No text output received in turn 1" - ) - + assert len(text_outputs_turn1) > 0, f"[{provider_name}] No text output received in turn 1" + logger.info(f"[{provider_name}] ✓ Turn 1 complete: received response") logger.info(f"[{provider_name}] Response: {text_outputs_turn1[0][:100]}...") @@ -195,12 +194,10 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi await ctx.wait_for_response() text_outputs_turn2 = ctx.get_text_outputs() - + # Validate turn 2 - check we got more responses - assert len(text_outputs_turn2) > len(text_outputs_turn1), ( - f"[{provider_name}] No new text output in turn 2" - ) - + assert len(text_outputs_turn2) > len(text_outputs_turn1), f"[{provider_name}] No new text output in turn 2" + logger.info(f"[{provider_name}] ✓ Turn 2 complete: multi-turn conversation works") logger.info(f"[{provider_name}] Total responses: {len(text_outputs_turn2)}") @@ -209,7 +206,7 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi audio_outputs = ctx.get_audio_outputs() assert len(audio_outputs) > 0, f"[{provider_name}] No audio output received" total_audio_bytes = sum(len(audio) for audio in audio_outputs) - + # Summary logger.info("=" * 60) logger.info(f"[{provider_name}] ✓ Multi-turn conversation test PASSED") From e597ca267f55f7633637a3055189434f7f83691a Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 08:13:46 -0800 Subject: [PATCH 125/242] fix E501 errors --- src/strands/experimental/bidi/agent/agent.py | 9 ++++++--- .../experimental/bidi/scripts/test_bidi_openai.py | 6 +++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 846473e86..aa7a2105d 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -307,7 +307,9 @@ async def send(self, input_data: BidiAgentInput) -> None: # If we get here, input type is invalid raise ValueError( - f"Input must be a string, BidiInputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" + f"Input must be a string, BidiInputEvent " + f"(BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), " + f"or event dict with 'type' field, got: {type(input_data)}" ) async def receive(self) -> AsyncIterable[BidiOutputEvent]: @@ -409,8 +411,9 @@ async def run_inputs(): event = await input_() await self.send(event) - # TODO: Need to make tool result send in Nova provider atomic. Audio input events end up interleaving - # and leading to failures. Adding a sleep here as a temporary solution. + # TODO: Need to make tool result send in Nova provider atomic. + # Audio input events end up interleaving and leading to failures. + # Adding a sleep here as a temporary solution. await asyncio.sleep(0.001) async def run_outputs(): diff --git a/src/strands/experimental/bidi/scripts/test_bidi_openai.py b/src/strands/experimental/bidi/scripts/test_bidi_openai.py index 2243ac84d..6e90aee32 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi_openai.py +++ b/src/strands/experimental/bidi/scripts/test_bidi_openai.py @@ -256,7 +256,11 @@ async def main(): agent = BidiAgent( model=model, tools=[calculator], - system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect.", + system_prompt=( + "You are a helpful voice assistant. " + "Keep your responses brief and natural. " + "Say hello when you first connect." + ), ) # Start the session From 818072f7d5605a7ed7a3a195987c7a8cbd6a9515 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 08:48:00 -0800 Subject: [PATCH 126/242] fix: D107 errors --- src/strands/experimental/bidi/types/events.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 39df53a0a..547919494 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -49,6 +49,7 @@ class BidiTextInputEvent(TypedEvent): """ def __init__(self, text: str, role: str): + """Initialize text input event.""" super().__init__( { "type": "bidi_text_input", @@ -87,6 +88,7 @@ def __init__( sample_rate: Literal[16000, 24000, 48000], channels: Literal[1, 2], ): + """Initialize audio input event.""" super().__init__( { "type": "bidi_audio_input", @@ -133,6 +135,7 @@ def __init__( image: str, mime_type: str, ): + """Initialize image input event.""" super().__init__( { "type": "bidi_image_input", @@ -166,6 +169,7 @@ class BidiConnectionStartEvent(TypedEvent): """ def __init__(self, connection_id: str, model: str): + """Initialize connection start event.""" super().__init__( { "type": "bidi_connection_start", @@ -193,6 +197,7 @@ class BidiResponseStartEvent(TypedEvent): """ def __init__(self, response_id: str): + """Initialize response start event.""" super().__init__({"type": "bidi_response_start", "response_id": response_id}) @property @@ -218,6 +223,7 @@ def __init__( sample_rate: Literal[16000, 24000, 48000], channels: Literal[1, 2], ): + """Initialize audio stream event.""" super().__init__( { "type": "bidi_audio_stream", @@ -271,6 +277,7 @@ def __init__( is_final: bool, current_transcript: Optional[str] = None, ): + """Initialize transcript stream event.""" super().__init__( { "type": "bidi_transcript_stream", @@ -317,6 +324,7 @@ class BidiInterruptionEvent(TypedEvent): """ def __init__(self, reason: Literal["user_speech", "error"]): + """Initialize interruption event.""" super().__init__( { "type": "bidi_interruption", @@ -343,6 +351,7 @@ def __init__( response_id: str, stop_reason: Literal["complete", "interrupted", "tool_use", "error"], ): + """Initialize response complete event.""" super().__init__( { "type": "bidi_response_complete", @@ -400,6 +409,7 @@ def __init__( cache_read_input_tokens: Optional[int] = None, cache_write_input_tokens: Optional[int] = None, ): + """Initialize usage event.""" data: Dict[str, Any] = { "type": "bidi_usage", "inputTokens": input_tokens, @@ -458,6 +468,7 @@ def __init__( connection_id: str, reason: Literal["client_disconnect", "timeout", "error", "complete"], ): + """Initialize connection close event.""" super().__init__( { "type": "bidi_connection_close", @@ -494,6 +505,7 @@ def __init__( error: Exception, details: Optional[Dict[str, Any]] = None, ): + """Initialize error event.""" # Store serializable data in dict (for JSON serialization) super().__init__( { From f343ea6a73087fc60ecb9f82acdc6ec21b6cb65f Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 08:56:59 -0800 Subject: [PATCH 127/242] fix linting error codes - D205, G201, F841, D415 --- src/strands/experimental/bidi/io/audio.py | 4 ++-- src/strands/experimental/bidi/io/text.py | 2 +- src/strands/experimental/bidi/models/gemini_live.py | 5 +---- src/strands/experimental/bidi/scripts/test_gemini_live.py | 1 - src/strands/tools/caller.py | 8 +++++--- tests_integ/bidi/context.py | 2 +- tests_integ/bidi/test_bidirectional_agent.py | 1 - 7 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index e0ebef070..68871f558 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -191,9 +191,9 @@ def __init__(self, **config: Any) -> None: self._config = config def input(self) -> _BidiAudioInput: - """Return audio processing BidiInput""" + """Return audio processing BidiInput.""" return _BidiAudioInput(self._config) def output(self) -> _BidiAudioOutput: - """Return audio processing BidiOutput""" + """Return audio processing BidiOutput.""" return _BidiAudioOutput(self._config) diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 289003e02..21214e4a5 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -28,5 +28,5 @@ class BidiTextIO: """Handle text input and output from bidi agent.""" def output(self) -> _BidiTextOutput: - """Return text processing BidiOutput""" + """Return text processing BidiOutput.""" return _BidiTextOutput() diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 2f7c523ec..bff958f26 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -273,9 +273,6 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut # Concatenate all text parts (Gemini may send multiple parts) text_parts = [] for part in model_turn.parts: - # Log all part types for debugging - part_attrs = {attr: getattr(part, attr, None) for attr in dir(part) if not attr.startswith("_")} - # Check if part has text attribute and it's not empty if hasattr(part, "text") and part.text: text_parts.append(part.text) @@ -366,7 +363,7 @@ async def send( self, content: BidiInputEvent | ToolResultEvent, ) -> None: - """Unified send method for all content types. Sends the given inputs to Google Live API + """Unified send method for all content types. Sends the given inputs to Google Live API. Dispatches to appropriate internal handler based on content type. diff --git a/src/strands/experimental/bidi/scripts/test_gemini_live.py b/src/strands/experimental/bidi/scripts/test_gemini_live.py index ee2010c6f..09f00abbf 100644 --- a/src/strands/experimental/bidi/scripts/test_gemini_live.py +++ b/src/strands/experimental/bidi/scripts/test_gemini_live.py @@ -165,7 +165,6 @@ async def receive(agent, context): elif event_type == "bidi_transcript_stream": transcript_text = event.get("text", "") transcript_role = event.get("role", "unknown") - is_final = event.get("is_final", False) # Print transcripts with special formatting if transcript_role == "user": diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py index 9fe213fec..9c1c116b5 100644 --- a/src/strands/tools/caller.py +++ b/src/strands/tools/caller.py @@ -10,10 +10,12 @@ class _ToolCaller: - """Provides common tool calling functionality that can be used by both traditional - Agent and BidirectionalAgent classes with agent-specific customizations. + """Provides common tool calling functionality for Agent classes. - Automatically detects agent type and applies appropriate behavior: + Can be used by both traditional Agent and BidirectionalAgent classes with + agent-specific customizations. + + Automatically detects agent type and applies appropriate behavior: - Traditional agents: Uses conversation_manager.apply_management() """ diff --git a/tests_integ/bidi/context.py b/tests_integ/bidi/context.py index 978878bea..99760704b 100644 --- a/tests_integ/bidi/context.py +++ b/tests_integ/bidi/context.py @@ -337,7 +337,7 @@ async def _input_thread(self): logger.debug("Input thread cancelled") raise # Re-raise to properly propagate cancellation except Exception as e: - logger.error(f"Input thread error: {e}", exc_info=True) + logger.exception(f"Input thread error: {e}") finally: logger.debug(f"Input thread stopped, active={self.active}") diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index ed2f8d5f1..d948d615a 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -179,7 +179,6 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi await ctx.wait_for_response() text_outputs_turn1 = ctx.get_text_outputs() - all_text_turn1 = " ".join(text_outputs_turn1).lower() # Validate turn 1 - just check we got a response assert len(text_outputs_turn1) > 0, f"[{provider_name}] No text output received in turn 1" From 7020875adcf123d874c03d134101699b3e8872d6 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 09:04:30 -0800 Subject: [PATCH 128/242] fix formatting errors - B012, F821, E722 --- .../experimental/bidi/scripts/test_bidi_novasonic.py | 2 +- src/strands/experimental/bidi/scripts/test_bidi_openai.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py index d6ce5f0c7..1a95303ec 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py +++ b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py @@ -32,7 +32,7 @@ def test_direct_tools(): try: model = BidiNovaSonicModel() - agent = BidirectionalAgent(model=model, tools=[calculator]) + agent = BidiAgent(model=model, tools=[calculator]) # Test calculator result = agent.tool.calculator(expression="2 * 3") diff --git a/src/strands/experimental/bidi/scripts/test_bidi_openai.py b/src/strands/experimental/bidi/scripts/test_bidi_openai.py index 6e90aee32..6c29d1dfc 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi_openai.py +++ b/src/strands/experimental/bidi/scripts/test_bidi_openai.py @@ -74,7 +74,7 @@ async def play(context): finally: try: speaker.close() - except: + except Exception: pass audio.terminate() @@ -107,7 +107,7 @@ async def record(context): finally: try: microphone.close() - except: + except Exception: pass audio.terminate() @@ -298,7 +298,7 @@ async def main(): except Exception as e: print(f"Cleanup error: {e}") - return True + return True if __name__ == "__main__": From 397408e14341c1796c32b731989f410cc2462486 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 09:36:14 -0800 Subject: [PATCH 129/242] fix logging and error code G004 --- src/strands/experimental/bidi/agent/agent.py | 19 +++--- src/strands/experimental/bidi/agent/loop.py | 8 +-- src/strands/experimental/bidi/io/audio.py | 21 ++++++ src/strands/experimental/bidi/io/text.py | 13 +++- .../experimental/bidi/models/gemini_live.py | 34 +++++----- .../experimental/bidi/models/novasonic.py | 65 ++++++++++--------- .../experimental/bidi/models/openai.py | 52 ++++++++------- .../bidi/scripts/test_gemini_live.py | 2 +- tests_integ/bidi/context.py | 23 +++---- tests_integ/bidi/generators/audio.py | 8 +-- tests_integ/bidi/test_bidirectional_agent.py | 24 ++++--- 11 files changed, 158 insertions(+), 111 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index aa7a2105d..10de35f74 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -210,7 +210,7 @@ def _record_tool_execution( self.messages.append(tool_result_msg) self.messages.append(assistant_msg) - logger.debug("Direct tool call recorded in message history: %s", tool["name"]) + logger.debug("tool_name=<%s> | direct tool call recorded in message history", tool["name"]) def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: """Filter input parameters to only include those defined in the tool specification. @@ -237,7 +237,7 @@ async def start(self) -> None: Initializes the streaming connection and starts background tasks for processing model events, tool execution, and connection management. """ - logger.debug("starting agent") + logger.debug("agent starting") await self._loop.start() @@ -272,7 +272,7 @@ async def send(self, input_data: BidiAgentInput) -> None: self.messages.append(user_message) - logger.debug("Text sent: %d characters", len(input_data)) + logger.debug("text_length=<%d> | text sent to model", len(input_data)) # Create BidiTextInputEvent for send() text_event = BidiTextInputEvent(text=input_data, role="user") await self.model.send(text_event) @@ -340,7 +340,7 @@ async def __aenter__(self) -> "BidiAgent": Returns: Self for use in the context. """ - logger.debug("Entering async context manager - starting connection") + logger.debug("context_manager= | starting connection") await self.start() return self @@ -356,16 +356,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: exc_tb: Exception traceback if an exception occurred, None otherwise. """ try: - logger.debug("Exiting async context manager - cleaning up adapters and connection") + logger.debug("context_manager= | cleaning up adapters and connection") # Cleanup adapters if any are currently active for adapter in self._current_adapters: if hasattr(adapter, "cleanup"): try: adapter.stop() - logger.debug(f"Cleaned up adapter: {type(adapter).__name__}") + logger.debug("adapter_type=<%s> | adapter cleaned up", type(adapter).__name__) except Exception as adapter_error: - logger.warning(f"Error cleaning up adapter: {adapter_error}") + logger.warning("adapter_error=<%s> | error cleaning up adapter", adapter_error) # Clear current adapters self._current_adapters = [] @@ -376,12 +376,13 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: except Exception as cleanup_error: if exc_type is None: # No original exception, re-raise cleanup error - logger.error("Error during context manager cleanup: %s", cleanup_error) + logger.error("cleanup_error=<%s> | error during context manager cleanup", cleanup_error) raise else: # Original exception exists, log cleanup error but don't suppress original logger.error( - "Error during context manager cleanup (suppressed due to original exception): %s", cleanup_error + "cleanup_error=<%s> | error during context manager cleanup suppressed due to original exception", + cleanup_error, ) @property diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index bb4c3a55e..5fac824c2 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -52,7 +52,7 @@ async def start(self) -> None: if self.active: return - logger.debug("starting agent loop") + logger.debug("agent loop starting") self._event_queue = asyncio.Queue(maxsize=1) self._stop_event = object() @@ -73,7 +73,7 @@ async def stop(self) -> None: if not self.active: return - logger.debug("stopping agent loop") + logger.debug("agent loop stopping") for task in self._tasks: task.cancel() @@ -120,7 +120,7 @@ async def _run_model(self) -> None: Events are streamed through the event queue. """ - logger.debug("running model") + logger.debug("model task starting") async for event in self._agent.model.receive(): await self._event_queue.put(event) @@ -139,7 +139,7 @@ async def _run_model(self) -> None: async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model.""" - logger.debug("running tool") + logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) result: ToolResult = None diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 68871f558..9a798a537 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -46,6 +46,12 @@ def __init__(self, config: dict[str, Any]) -> None: async def start(self) -> None: """Start input stream.""" + logger.debug( + "rate=<%d>, channels=<%d>, device_index=<%s> | starting audio input stream", + self._rate, + self._channels, + self._device_index, + ) self._audio = pyaudio.PyAudio() self._stream = self._audio.open( channels=self._channels, @@ -55,9 +61,11 @@ async def start(self) -> None: input_device_index=self._device_index, rate=self._rate, ) + logger.info("rate=<%d>, channels=<%d> | audio input stream started", self._rate, self._channels) async def stop(self) -> None: """Stop input stream.""" + logger.debug("stopping audio input stream") # TODO: Provide time for streaming thread to exit cleanly to prevent conflicts with the Nova threads. # See if we can remove after properly handling cancellation for agent. await asyncio.sleep(0.1) @@ -67,6 +75,7 @@ async def stop(self) -> None: self._stream = None self._audio = None + logger.debug("audio input stream stopped") async def __call__(self) -> BidiAudioInputEvent: """Read audio from input stream.""" @@ -115,6 +124,13 @@ def __init__(self, config: dict[str, Any]) -> None: async def start(self) -> None: """Start output stream.""" + logger.debug( + "rate=<%d>, channels=<%d>, device_index=<%s>, buffer_size=<%s> | starting audio output stream", + self._rate, + self._channels, + self._device_index, + self._buffer_size, + ) self._audio = pyaudio.PyAudio() self._stream = self._audio.open( channels=self._channels, @@ -127,9 +143,11 @@ async def start(self) -> None: self._buffer = deque(maxlen=self._buffer_size) self._buffer_event = asyncio.Event() self._output_task = asyncio.create_task(self._output()) + logger.info("rate=<%d>, channels=<%d> | audio output stream started", self._rate, self._channels) async def stop(self) -> None: """Stop output stream.""" + logger.debug("stopping audio output stream") self._buffer.clear() self._buffer.append(None) self._buffer_event.set() @@ -143,6 +161,7 @@ async def stop(self) -> None: self._buffer_event = None self._stream = None self._audio = None + logger.debug("audio output stream stopped") async def __call__(self, event: BidiOutputEvent) -> None: """Handle audio events with direct stream writing.""" @@ -150,8 +169,10 @@ async def __call__(self, event: BidiOutputEvent) -> None: audio_bytes = base64.b64decode(event["audio"]) self._buffer.append(audio_bytes) self._buffer_event.set() + logger.debug("audio_bytes=<%d> | audio chunk buffered for playback", len(audio_bytes)) elif isinstance(event, BidiInterruptionEvent): + logger.debug("reason=<%s> | clearing audio buffer due to interruption", event["reason"]) self._buffer.clear() self._buffer_event.clear() diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 21214e4a5..18b39819d 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -14,11 +14,22 @@ class _BidiTextOutput(BidiOutput): async def __call__(self, event: BidiOutputEvent) -> None: """Print text events to stdout.""" if isinstance(event, BidiInterruptionEvent): + logger.debug("reason=<%s> | text output interrupted", event["reason"]) print("interrupted") elif isinstance(event, BidiTranscriptStreamEvent): text = event["text"] - if not event["is_final"]: + is_final = event["is_final"] + role = event["role"] + + logger.debug( + "role=<%s>, is_final=<%s>, text_length=<%d> | text transcript received", + role, + is_final, + len(text), + ) + + if not is_final: text = f"Preview: {text}" print(text) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index bff958f26..e4d83ab2e 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -142,7 +142,7 @@ async def start( except Exception as e: self._active = False - logger.error("Error connecting to Gemini Live: %s", e) + logger.error("error=<%s> | error connecting to gemini live", e) raise async def _send_message_history(self, messages: Messages) -> None: @@ -188,15 +188,15 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: # SDK exits receive loop after turn_complete - restart automatically if self._active: - logger.debug("Restarting receive loop after turn completion") + logger.debug("gemini receive loop restarting after turn completion") except Exception as e: - logger.error("Error in receive iteration: %s", e) + logger.error("error=<%s> | error in gemini receive iteration", e) # Small delay before retrying to avoid tight error loops await asyncio.sleep(0.1) except Exception as e: - logger.error("Fatal error in receive loop: %s", e) + logger.error("error=<%s> | fatal error in gemini receive loop", e) yield BidiErrorEvent(error=e) finally: # Emit connection close event when exiting @@ -226,7 +226,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut if hasattr(input_transcript, "text") and input_transcript.text: transcription_text = input_transcript.text role = getattr(input_transcript, "role", "user") - logger.debug(f"Input transcription detected: {transcription_text}") + logger.debug("text_length=<%d> | gemini input transcription detected", len(transcription_text)) return [ BidiTranscriptStreamEvent( delta={"text": transcription_text}, @@ -244,7 +244,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut if hasattr(output_transcript, "text") and output_transcript.text: transcription_text = output_transcript.text role = getattr(output_transcript, "role", "assistant") - logger.debug(f"Output transcription detected: {transcription_text}") + logger.debug("text_length=<%d> | gemini output transcription detected", len(transcription_text)) return [ BidiTranscriptStreamEvent( delta={"text": transcription_text}, @@ -353,9 +353,11 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut return [] except Exception as e: - logger.error("Error converting Gemini Live event: %s", e) - logger.error("Message type: %s", type(message).__name__) - logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith("_")]) + logger.error( + "error=<%s>, message_type=<%s> | error converting gemini live event", + e, + type(message).__name__, + ) # Return ErrorEvent in list so caller can handle it return [BidiErrorEvent(error=e)] @@ -385,9 +387,9 @@ async def send( if tool_result: await self._send_tool_result(tool_result) else: - logger.warning(f"Unknown content type: {type(content)}") + logger.warning("content_type=<%s> | unknown content type", type(content).__name__) except Exception as e: - logger.error(f"Error sending content: {e}") + logger.error("error=<%s> | error sending content to gemini live", e) raise # Propagate exception for debugging in experimental code async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: @@ -407,7 +409,7 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: await self.live_session.send_realtime_input(audio=audio_blob) except Exception as e: - logger.error("Error sending audio content: %s", e) + logger.error("error=<%s> | error sending audio content to gemini live", e) async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: """Internal: Send image content using Gemini Live API. @@ -423,7 +425,7 @@ async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: await self.live_session.send(input=msg) except Exception as e: - logger.error("Error sending image content: %s", e) + logger.error("error=<%s> | error sending image content to gemini live", e) async def _send_text_content(self, text: str) -> None: """Internal: Send text content using Gemini Live API.""" @@ -435,7 +437,7 @@ async def _send_text_content(self, text: str) -> None: await self.live_session.send_client_content(turns=content) except Exception as e: - logger.error("Error sending text content: %s", e) + logger.error("error=<%s> | error sending text content to gemini live", e) async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Gemini Live API.""" @@ -461,7 +463,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: # Send tool response await self.live_session.send_tool_response(function_responses=[func_response]) except Exception as e: - logger.error("Error sending tool result: %s", e) + logger.error("error=<%s> | error sending tool result to gemini live", e) async def stop(self) -> None: """Close Gemini Live API connection.""" @@ -475,7 +477,7 @@ async def stop(self) -> None: if self.live_session_context_manager: await self.live_session_context_manager.__aexit__(None, None, None) except Exception as e: - logger.error("Error closing Gemini Live connection: %s", e) + logger.error("error=<%s> | error closing gemini live connection", e) raise def _build_live_config( diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 22be7edb7..7051f4447 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -118,7 +118,7 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e self._current_role = None self._generation_stage = None - logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) + logger.debug("model_id=<%s> | nova sonic model initialized", model_id) async def start( self, @@ -138,7 +138,7 @@ async def start( if self._active: raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - logger.debug("Nova connection create - starting") + logger.debug("nova connection starting") try: # Initialize client if needed @@ -160,20 +160,20 @@ async def start( logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic connection initialized with connection_id: %s", self.connection_id) + logger.debug("connection_id=<%s> | nova sonic connection initialized", self.connection_id) # Send initialization events system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." init_events = self._build_initialization_events(system_prompt, tools or [], messages) - logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) + logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events)) await self._send_initialization_events(init_events) - logger.info("Nova Sonic connection established successfully") + logger.info("connection_id=<%s> | nova sonic connection established", self.connection_id) except Exception as e: self._active = False - logger.error("Nova connection create error: %s", str(e)) + logger.error("error=<%s> | nova connection create failed", str(e)) raise def _build_initialization_events( @@ -198,16 +198,20 @@ async def _send_initialization_events(self, events: list[str]) -> None: def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: - logger.debug("Nova usage: %s", nova_event["usageEvent"]) + logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) elif "textOutput" in nova_event: - logger.debug("Nova text output") + logger.debug("nova text output received") elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - logger.debug("Nova tool use: %s (id: %s)", tool_use["toolName"], tool_use["toolUseId"]) + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", + tool_use["toolName"], + tool_use["toolUseId"], + ) elif "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) - logger.debug("Nova audio output: %d bytes", len(audio_bytes)) + logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) async def receive(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" @@ -215,7 +219,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.error("Stream is None") return - logger.debug("Nova events - starting event stream") + logger.debug("nova event stream starting") # Emit connection start event yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model_id) @@ -240,7 +244,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: continue except Exception as e: - logger.error("Error receiving Nova Sonic event: %s", e) + logger.error("error=<%s> | error receiving nova sonic event", e) logger.error(traceback.format_exc()) yield BidiErrorEvent(error=e) finally: @@ -274,9 +278,9 @@ async def send( if tool_result: await self._send_tool_result(tool_result) else: - logger.warning(f"Unknown content type: {type(content)}") + logger.warning("content_type=<%s> | unknown content type", type(content).__name__) except Exception as e: - logger.error(f"Error sending content: {e}") + logger.error("error=<%s> | error sending content to nova sonic", e) raise # Propagate exception for debugging in experimental code async def _start_audio_connection(self) -> None: @@ -284,7 +288,7 @@ async def _start_audio_connection(self) -> None: if self.audio_connection_active: return - logger.debug("Nova audio connection start") + logger.debug("nova audio connection starting") audio_content_start = json.dumps( { @@ -331,7 +335,7 @@ async def _end_audio_input(self) -> None: if not self.audio_connection_active: return - logger.debug("Nova audio connection end") + logger.debug("nova audio connection ending") audio_content_end = json.dumps( {"event": {"contentEnd": {"promptName": self.connection_id, "contentName": self.audio_content_name}}} @@ -372,7 +376,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Nova Sonic toolResult format.""" tool_use_id = tool_result.get("toolUseId") - logger.debug("Nova tool result send: %s", tool_use_id) + logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) # Extract result content result_data = {} @@ -398,7 +402,7 @@ async def stop(self) -> None: if not self._active: return - logger.debug("Nova cleanup - starting connection close") + logger.debug("nova connection cleanup starting") self._active = False try: @@ -413,18 +417,18 @@ async def stop(self) -> None: try: await self._send_nova_event(event) except Exception as e: - logger.warning("Error during Nova Sonic cleanup: %s", e) + logger.warning("error=<%s> | error during nova sonic cleanup", e) # Close stream try: await self.stream.input_stream.close() except Exception as e: - logger.warning("Error closing Nova Sonic stream: %s", e) + logger.warning("error=<%s> | error closing nova sonic stream", e) except Exception as e: - logger.error("Nova cleanup error: %s", str(e)) + logger.error("error=<%s> | nova cleanup failed", str(e)) finally: - logger.debug("Nova connection closed") + logger.debug("nova connection closed") def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | None: """Convert Nova Sonic events to TypedEvent format.""" @@ -432,7 +436,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N if "completionStart" in nova_event: completion_data = nova_event["completionStart"] self._current_completion_id = completion_data.get("completionId") - logger.debug("Nova completion started: %s", self._current_completion_id) + logger.debug("completion_id=<%s> | nova completion started", self._current_completion_id) return None # Handle completion end @@ -463,7 +467,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N text_content = nova_event["textOutput"]["content"] # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: - logger.debug("Nova interruption detected in text") + logger.debug("nova interruption detected in text output") return BidiInterruptionEvent(reason="user_speech") return BidiTranscriptStreamEvent( @@ -487,7 +491,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N # Handle interruption if nova_event.get("stopReason") == "INTERRUPTED": - logger.debug("Nova interruption stop reason") + logger.debug("nova interruption detected via stop reason") return BidiInterruptionEvent(reason="user_speech") # Handle usage events - convert to multimodal usage format @@ -646,11 +650,10 @@ async def _send_nova_event(self, event: str) -> None: bytes_data = event.encode("utf-8") chunk = InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) await self.stream.input_stream.send(chunk) - logger.debug("Successfully sent Nova Sonic event") + logger.debug("nova sonic event sent successfully") except Exception as e: - logger.error("Error sending Nova Sonic event: %s", e) - logger.error("Event was: %s", event) + logger.error("error=<%s>, event=<%s> | error sending nova sonic event", e, event[:100]) raise async def _initialize_client(self) -> None: @@ -665,11 +668,11 @@ async def _initialize_client(self) -> None: ) self.client = BedrockRuntimeClient(config=config) - logger.debug("Nova Sonic client initialized") + logger.debug("region=<%s> | nova sonic client initialized", self.region) except ImportError as e: - logger.error("Nova Sonic dependencies not available: %s", e) + logger.error("error=<%s> | nova sonic dependencies not available", e) raise except Exception as e: - logger.error("Error initializing Nova Sonic client: %s", e) + logger.error("error=<%s> | error initializing nova sonic client", e) raise diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 7f7ce2eb6..1564e5d68 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -109,7 +109,7 @@ def __init__( self._function_call_buffer = {} - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + logger.debug("model=<%s> | openai realtime model initialized", model) async def start( self, @@ -129,7 +129,7 @@ async def start( if self._active: raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") - logger.info("Creating OpenAI Realtime connection...") + logger.info("openai realtime connection starting") try: # Initialize connection state @@ -147,7 +147,7 @@ async def start( headers.append(("OpenAI-Project", self.project)) self.websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") + logger.info("connection_id=<%s> | websocket connected successfully", self.connection_id) # Configure session session_config = self._build_session_config(system_prompt, tools) @@ -159,7 +159,7 @@ async def start( except Exception as e: self._active = False - logger.error("OpenAI connection error: %s", e) + logger.error("error=<%s> | openai connection failed", e) raise def _require_active(self) -> bool: @@ -224,7 +224,7 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] if key in supported_params: config[key] = value else: - logger.warning("Ignoring unsupported session parameter: %s", key) + logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) return config @@ -289,7 +289,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: yield event except Exception as e: - logger.error("Error receiving OpenAI Realtime event: %s", e) + logger.error("error=<%s> | error receiving openai realtime event", e) yield BidiErrorEvent(error=e) finally: # Emit connection close event @@ -359,7 +359,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput elif event_type == "conversation.item.input_audio_transcription.failed": error_info = openai_event.get("error", {}) - logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + logger.warning("error=<%s> | openai transcription failed", error_info.get("message", "unknown error")) return None # Function call processing @@ -387,7 +387,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput # Return ToolUseStreamEvent for consistency with standard agent return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=tool_use)] except (json.JSONDecodeError, KeyError) as e: - logger.warning("Error parsing function arguments for %s: %s", call_id, e) + logger.warning("call_id=<%s>, error=<%s> | error parsing function arguments", call_id, e) del self._function_call_buffer[call_id] return None @@ -400,7 +400,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput elif event_type == "response.cancelled": response = openai_event.get("response", {}) response_id = response.get("id", "unknown") - logger.debug("OpenAI response cancelled: %s", response_id) + logger.debug("response_id=<%s> | openai response cancelled", response_id) return [BidiResponseCompleteEvent(response_id=response_id, stop_reason="interrupted")] # Turn complete and usage - response finished @@ -476,11 +476,11 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: item = openai_event.get("item", {}) action = "retrieved" if "retrieve" in event_type else "added" - logger.debug("OpenAI conversation item %s: %s", action, item.get("id")) + logger.debug("action=<%s>, item_id=<%s> | openai conversation item event", action, item.get("id")) return None elif event_type == "conversation.item.done": - logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) + logger.debug("item_id=<%s> | openai conversation item done", openai_event.get("item", {}).get("id")) return None # Response output events - combine similar events @@ -491,7 +491,11 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput "response.content_part.done", ]: item_data = openai_event.get("item") or openai_event.get("part") - logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + logger.debug( + "event_type=<%s>, item_id=<%s> | openai output event", + event_type, + item_data.get("id") if item_data else "unknown", + ) # Track function call names from response.output_item.added if event_type == "response.output_item.added": @@ -517,7 +521,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput "session.created", "session.updated", ]: - logger.debug("OpenAI %s event", event_type) + logger.debug("event_type=<%s> | openai event received", event_type) return None elif event_type == "error": @@ -528,15 +532,15 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutput if error_code == "response_cancel_not_active": # This happens when trying to cancel a response that's not active # It's safe to ignore as the session remains functional - logger.debug("OpenAI response cancel attempted when no response active (safe to ignore)") + logger.debug("openai response cancel attempted when no response active") return None # Log other errors - logger.error("OpenAI Realtime error: %s", error_data) + logger.error("error=<%s> | openai realtime error", error_data) return None else: - logger.debug("Unhandled OpenAI event type: %s", event_type) + logger.debug("event_type=<%s> | unhandled openai event type", event_type) return None async def send( @@ -567,9 +571,9 @@ async def send( if tool_result: await self._send_tool_result(tool_result) else: - logger.warning(f"Unknown content type: {type(content).__name__}") + logger.warning("content_type=<%s> | unknown content type", type(content).__name__) except Exception as e: - logger.error(f"Error sending content: {e}") + logger.error("error=<%s> | error sending content to openai", e) raise # Propagate exception for debugging in experimental code async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: @@ -591,7 +595,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result back to OpenAI.""" tool_use_id = tool_result.get("toolUseId") - logger.debug("OpenAI tool result send: %s", tool_use_id) + logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) # Extract result content result_data = {} @@ -613,22 +617,22 @@ async def stop(self) -> None: if not self._active: return - logger.debug("OpenAI Realtime cleanup - starting connection close") + logger.debug("openai realtime connection cleanup starting") self._active = False try: await self.websocket.close() except Exception as e: - logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + logger.warning("error=<%s> | error closing openai realtime websocket", e) - logger.debug("OpenAI Realtime connection closed") + logger.debug("openai realtime connection closed") async def _send_event(self, event: dict[str, any]) -> None: """Send event to OpenAI via WebSocket.""" try: message = json.dumps(event) await self.websocket.send(message) - logger.debug("Sent OpenAI event: %s", event.get("type")) + logger.debug("event_type=<%s> | openai event sent", event.get("type")) except Exception as e: - logger.error("Error sending OpenAI event: %s", e) + logger.error("error=<%s> | error sending openai event", e) raise diff --git a/src/strands/experimental/bidi/scripts/test_gemini_live.py b/src/strands/experimental/bidi/scripts/test_gemini_live.py index 09f00abbf..489bcd6ac 100644 --- a/src/strands/experimental/bidi/scripts/test_gemini_live.py +++ b/src/strands/experimental/bidi/scripts/test_gemini_live.py @@ -256,7 +256,7 @@ async def get_frames(context): await context["agent"].send(image_event) print("📸 Frame sent to model") except Exception as e: - logger.error(f"Error sending frame: {e}") + logger.error("error=<%s> | error sending frame", e) # Wait 1 second between frames (1 FPS) await asyncio.sleep(1.0) diff --git a/tests_integ/bidi/context.py b/tests_integ/bidi/context.py index 99760704b..adca6ee5b 100644 --- a/tests_integ/bidi/context.py +++ b/tests_integ/bidi/context.py @@ -147,7 +147,7 @@ async def say(self, text: str): chunk_event = self.audio_generator.create_audio_input_event(chunk) await self.input_queue.put({"type": "audio_chunk", "data": chunk_event}) - logger.debug(f"Queued {len(audio_data)} bytes of audio for: {text[:50]}...") + logger.debug("audio_bytes=<%d>, text_preview=<%s> | queued audio for text", len(audio_data), text[:50]) async def send(self, data: str | dict) -> None: """Send data directly to model (text, image, etc.). @@ -158,7 +158,7 @@ async def send(self, data: str | dict) -> None: - dict: Custom event (e.g., image, audio) """ await self.input_queue.put({"type": "direct", "data": data}) - logger.debug(f"Queued direct send: {type(data).__name__}") + logger.debug("data_type=<%s> | queued direct send", type(data).__name__) async def wait_for_response( self, @@ -189,14 +189,15 @@ async def wait_for_response( elapsed_since_event = time.monotonic() - self.last_event_time if elapsed_since_event >= silence_threshold: logger.debug( - f"Response complete: {len(current_events) - initial_event_count} events, " - f"{elapsed_since_event:.1f}s silence" + "event_count=<%d>, silence_duration=<%.1f> | response complete", + len(current_events) - initial_event_count, + elapsed_since_event, ) return await asyncio.sleep(WAIT_POLL_INTERVAL) - logger.warning(f"Response timeout after {timeout}s") + logger.warning("timeout=<%s> | response timeout", timeout) def get_events(self, event_type: str | None = None) -> list[dict]: """Get collected events, optionally filtered by type. @@ -305,7 +306,7 @@ async def _input_thread(self): - Sends direct data to model """ try: - logger.debug(f"Input thread starting, active={self.active}") + logger.debug("active=<%s> | input thread starting", self.active) while self.active: try: # Check for queued input (non-blocking with short timeout) @@ -324,7 +325,7 @@ async def _input_thread(self): if isinstance(input_item["data"], str) else type(input_item["data"]).__name__ ) - logger.debug(f"Sent direct: {data_repr}") + logger.debug("data=<%s> | sent direct data", data_repr) except asyncio.TimeoutError: # No input queued - send silence chunk to simulate continuous microphone input @@ -337,9 +338,9 @@ async def _input_thread(self): logger.debug("Input thread cancelled") raise # Re-raise to properly propagate cancellation except Exception as e: - logger.exception(f"Input thread error: {e}") + logger.exception("error=<%s> | input thread error", e) finally: - logger.debug(f"Input thread stopped, active={self.active}") + logger.debug("active=<%s> | input thread stopped", self.active) async def _event_collection_thread(self): """Continuously collect events from model.""" @@ -350,13 +351,13 @@ async def _event_collection_thread(self): # Thread-safe: put in queue instead of direct append await self._event_queue.put(event) - logger.debug(f"Event collected: {list(event.keys())}") + logger.debug("event_type=<%s> | event collected", event.get("type", "unknown")) except asyncio.CancelledError: logger.debug("Event collection thread cancelled") raise # Re-raise to properly propagate cancellation except Exception as e: - logger.error(f"Event collection thread error: {e}") + logger.error("error=<%s> | event collection thread error", e) def _generate_silence_chunk(self) -> dict: """Generate silence chunk for background audio. diff --git a/tests_integ/bidi/generators/audio.py b/tests_integ/bidi/generators/audio.py index 8f2a9929f..4598817fd 100644 --- a/tests_integ/bidi/generators/audio.py +++ b/tests_integ/bidi/generators/audio.py @@ -74,11 +74,11 @@ async def generate_audio( cache_path = self._get_cache_path(cache_key) if cache_path.exists(): - logger.debug(f"Using cached audio for: {text[:50]}...") + logger.debug("text_preview=<%s> | using cached audio", text[:50]) return cache_path.read_bytes() # Generate audio with Polly - logger.debug(f"Generating audio with Polly: {text[:50]}...") + logger.debug("text_preview=<%s> | generating audio with polly", text[:50]) try: response = self.polly_client.synthesize_speech( @@ -95,12 +95,12 @@ async def generate_audio( # Cache for future use if use_cache: cache_path.write_bytes(audio_data) - logger.debug(f"Cached audio: {cache_path}") + logger.debug("cache_path=<%s> | cached audio", cache_path) return audio_data except Exception as e: - logger.error(f"Polly audio generation failed: {e}") + logger.error("error=<%s> | polly audio generation failed", e) raise def create_audio_input_event( diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index d948d615a..0d3b41607 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -169,7 +169,7 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi provider_name = provider_config["name"] silence_duration = provider_config["silence_duration"] - logger.info(f"Testing provider: {provider_name}") + logger.info("provider=<%s> | testing provider", provider_name) async with BidirectionalTestContext(agent_with_calculator, audio_generator) as ctx: # Turn 1: Simple greeting to test basic audio I/O @@ -183,8 +183,8 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi # Validate turn 1 - just check we got a response assert len(text_outputs_turn1) > 0, f"[{provider_name}] No text output received in turn 1" - logger.info(f"[{provider_name}] ✓ Turn 1 complete: received response") - logger.info(f"[{provider_name}] Response: {text_outputs_turn1[0][:100]}...") + logger.info("provider=<%s> | turn 1 complete received response", provider_name) + logger.info("provider=<%s>, response=<%s> | turn 1 response", provider_name, text_outputs_turn1[0][:100]) # Turn 2: Follow-up to test multi-turn conversation await ctx.say("What's your name?") @@ -197,8 +197,8 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi # Validate turn 2 - check we got more responses assert len(text_outputs_turn2) > len(text_outputs_turn1), f"[{provider_name}] No new text output in turn 2" - logger.info(f"[{provider_name}] ✓ Turn 2 complete: multi-turn conversation works") - logger.info(f"[{provider_name}] Total responses: {len(text_outputs_turn2)}") + logger.info("provider=<%s> | turn 2 complete multi-turn conversation works", provider_name) + logger.info("provider=<%s>, response_count=<%d> | total responses", provider_name, len(text_outputs_turn2)) # Validate full conversation # Validate audio outputs @@ -208,9 +208,13 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi # Summary logger.info("=" * 60) - logger.info(f"[{provider_name}] ✓ Multi-turn conversation test PASSED") - logger.info(f" Provider: {provider_name}") - logger.info(f" Total events: {len(ctx.get_events())}") - logger.info(f" Text responses: {len(text_outputs_turn2)}") - logger.info(f" Audio chunks: {len(audio_outputs)} ({total_audio_bytes:,} bytes)") + logger.info("provider=<%s> | multi-turn conversation test passed", provider_name) + logger.info("provider=<%s> | test summary", provider_name) + logger.info("event_count=<%d> | total events", len(ctx.get_events())) + logger.info("text_response_count=<%d> | text responses", len(text_outputs_turn2)) + logger.info( + "audio_chunk_count=<%d>, audio_bytes=<%d> | audio chunks", + len(audio_outputs), + total_audio_bytes, + ) logger.info("=" * 60) From 0e6221a2e5c73f6a50f36f9967d3e7cdb23eb27f Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 17 Nov 2025 10:40:53 -0800 Subject: [PATCH 130/242] fix hatch-static-analysis in github workflow --- .github/workflows/test-lint.yml | 20 +++++++++++++++++++ pyproject.toml | 12 ++++++++++- .../experimental/bidi/scripts/test_bidi.py | 4 ---- .../bidi/scripts/test_bidi_novasonic.py | 5 ----- .../bidi/scripts/test_bidi_openai.py | 5 ----- .../bidi/scripts/test_gemini_live.py | 5 ----- 6 files changed, 31 insertions(+), 20 deletions(-) diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index e38942b2c..4986acf1f 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -59,6 +59,20 @@ jobs: uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Install system audio dependencies (Linux) + if: matrix.os-name == 'linux' + run: | + sudo apt-get update + sudo apt-get install -y portaudio19-dev libasound2-dev + - name: Install system audio dependencies (macOS) + if: matrix.os-name == 'macOS' + run: | + brew install portaudio + - name: Install system audio dependencies (Windows) + if: matrix.os-name == 'windows' + run: | + # Windows typically has audio libraries available by default + echo "Windows audio dependencies handled by PyAudio wheels" - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -89,6 +103,11 @@ jobs: python-version: '3.10' cache: 'pip' + - name: Install system audio dependencies (Linux) + run: | + sudo apt-get update + sudo apt-get install -y portaudio19-dev libasound2-dev + - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -97,3 +116,4 @@ jobs: id: lint run: hatch fmt --linter --check continue-on-error: false + diff --git a/pyproject.toml b/pyproject.toml index f5311c299..97b9cf5f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,9 @@ source = "vcs" # Use git tags for versioning [tool.hatch.envs.hatch-static-analysis] installer = "uv" -features = ["all", "bidi-all"] +# Only install 'all' features, not 'bidi-all' which requires Python 3.12+ +# The bidi code will still be type-checked, but without installing its runtime dependencies +features = ["all"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.13.0,<0.14.0", @@ -221,6 +223,14 @@ warn_no_return = true warn_unreachable = true follow_untyped_imports = true ignore_missing_imports = false +# Ignore missing imports for optional bidi dependencies (not installed in lint environment) +[[tool.mypy.overrides]] +module = [ + "smithy_core.*", + "smithy_aws_core.*", + "aws_sdk_bedrock_runtime.*", +] +ignore_missing_imports = true [tool.ruff] diff --git a/src/strands/experimental/bidi/scripts/test_bidi.py b/src/strands/experimental/bidi/scripts/test_bidi.py index 85480bfaa..f7447871c 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi.py +++ b/src/strands/experimental/bidi/scripts/test_bidi.py @@ -1,10 +1,6 @@ """Test BidirectionalAgent with simple developer experience.""" import asyncio -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) from strands_tools import calculator diff --git a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py index 1a95303ec..2ed62e455 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py +++ b/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py @@ -6,11 +6,6 @@ import asyncio import base64 -import sys -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import os import time diff --git a/src/strands/experimental/bidi/scripts/test_bidi_openai.py b/src/strands/experimental/bidi/scripts/test_bidi_openai.py index 6c29d1dfc..807629feb 100644 --- a/src/strands/experimental/bidi/scripts/test_bidi_openai.py +++ b/src/strands/experimental/bidi/scripts/test_bidi_openai.py @@ -4,12 +4,7 @@ import asyncio import base64 import os -import sys import time -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent / "src")) import pyaudio from strands_tools import calculator diff --git a/src/strands/experimental/bidi/scripts/test_gemini_live.py b/src/strands/experimental/bidi/scripts/test_gemini_live.py index 489bcd6ac..31dfc6af0 100644 --- a/src/strands/experimental/bidi/scripts/test_gemini_live.py +++ b/src/strands/experimental/bidi/scripts/test_gemini_live.py @@ -18,11 +18,6 @@ import io import logging import os -import sys -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import time try: From 46319e521b73e21f0f8e694b8bfd242a0b91752b Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 17 Nov 2025 15:30:01 -0500 Subject: [PATCH 131/242] nova - send event lock (#52) --- src/strands/agent/agent.py | 3 +- src/strands/experimental/bidi/agent/agent.py | 13 ++-- .../experimental/bidi/models/novasonic.py | 59 ++++++++++--------- tests/strands/agent/test_agent.py | 4 +- 4 files changed, 38 insertions(+), 41 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4273d737b..d3d45013a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -11,7 +11,6 @@ import json import logging -import random import warnings from typing import ( TYPE_CHECKING, @@ -57,7 +56,7 @@ from ..tools.registry import ToolRegistry from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher -from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, ToolInterruptEvent, TypedEvent +from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 10de35f74..f38414fd5 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -412,16 +412,13 @@ async def run_inputs(): event = await input_() await self.send(event) - # TODO: Need to make tool result send in Nova provider atomic. - # Audio input events end up interleaving and leading to failures. - # Adding a sleep here as a temporary solution. - await asyncio.sleep(0.001) - async def run_outputs(): async for event in self.receive(): for output in outputs: await output(event) + await self.start() + for input_ in inputs: if hasattr(input_, "start"): await input_.start() @@ -430,14 +427,10 @@ async def run_outputs(): if hasattr(output, "start"): await output.start() - # Start agent after all IO is ready - await self.start() try: await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True) finally: - await self.stop() - for input_ in inputs: if hasattr(input_, "stop"): await input_.stop() @@ -446,6 +439,8 @@ async def run_outputs(): if hasattr(output, "stop"): await output.stop() + await self.stop() + def _validate_active_connection(self) -> None: """Validate that an active connection exists. diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 7051f4447..eeaa1d659 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -77,7 +77,6 @@ NOVA_TOOL_CONFIG = {"mediaType": "application/json"} # Timing constants -EVENT_DELAY = 0.1 RESPONSE_TIMEOUT = 1.0 @@ -118,6 +117,9 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e self._current_role = None self._generation_stage = None + # Ensure certain events are sent in sequence when required + self._send_lock = asyncio.Lock() + logger.debug("model_id=<%s> | nova sonic model initialized", model_id) async def start( @@ -190,10 +192,8 @@ def _build_initialization_events( return events async def _send_initialization_events(self, events: list[str]) -> None: - """Send initialization events with required delays.""" - for _i, event in enumerate(events): - await self._send_nova_event(event) - await asyncio.sleep(EVENT_DELAY) + """Send initialization events.""" + await self._send_nova_event(events) def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" @@ -305,7 +305,7 @@ async def _start_audio_connection(self) -> None: } ) - await self._send_nova_event(audio_content_start) + await self._send_nova_event([audio_content_start]) self.audio_connection_active = True async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: @@ -328,7 +328,7 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: } ) - await self._send_nova_event(audio_event) + await self._send_nova_event([audio_event]) async def _end_audio_input(self) -> None: """Internal: End current audio input connection to trigger Nova Sonic processing.""" @@ -341,7 +341,7 @@ async def _end_audio_input(self) -> None: {"event": {"contentEnd": {"promptName": self.connection_id, "contentName": self.audio_content_name}}} ) - await self._send_nova_event(audio_content_end) + await self._send_nova_event([audio_content_end]) self.audio_connection_active = False async def _send_text_content(self, text: str) -> None: @@ -352,9 +352,7 @@ async def _send_text_content(self, text: str) -> None: self._get_text_input_event(content_name, text), self._get_content_end_event(content_name), ] - - for event in events: - await self._send_nova_event(event) + await self._send_nova_event(events) async def _send_interrupt(self) -> None: """Internal: Send interruption signal to Nova Sonic.""" @@ -370,7 +368,7 @@ async def _send_interrupt(self) -> None: } } ) - await self._send_nova_event(interrupt_event) + await self._send_nova_event([interrupt_event]) async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Nova Sonic toolResult format.""" @@ -393,9 +391,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: self._get_tool_result_event(content_name, result_data), self._get_content_end_event(content_name), ] - - for event in events: - await self._send_nova_event(event) + await self._send_nova_event(events) async def stop(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" @@ -412,12 +408,10 @@ async def stop(self) -> None: # Send cleanup events cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] - - for event in cleanup_events: - try: - await self._send_nova_event(event) - except Exception as e: - logger.warning("error=<%s> | error during nova sonic cleanup", e) + try: + await self._send_nova_event(cleanup_events) + except Exception as e: + logger.warning("error=<%s> | error during nova sonic cleanup", e) # Close stream try: @@ -643,14 +637,23 @@ def _get_connection_end_event(self) -> str: """Generate connection end event.""" return json.dumps({"event": {"connectionEnd": {}}}) - async def _send_nova_event(self, event: str) -> None: - """Send event JSON string to Nova Sonic stream.""" + async def _send_nova_event(self, events: list[str]) -> None: + """Send event JSON string to Nova Sonic stream. + + A lock is used to send events in sequence when required (e.g., tool result start, content, and end). + + Args: + events: Jsonified event. + """ try: - # Event is already a JSON string - bytes_data = event.encode("utf-8") - chunk = InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) - await self.stream.input_stream.send(chunk) - logger.debug("nova sonic event sent successfully") + async with self._send_lock: + for event in events: + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self.stream.input_stream.send(chunk) + logger.debug("nova sonic event sent successfully") except Exception as e: logger.error("error=<%s>, event=<%s> | error sending nova sonic event", e, event[:100]) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 3a0bc2dfb..550422cfe 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2240,8 +2240,8 @@ def test_agent_backwards_compatibility_single_text_block(): # Should extract text for backwards compatibility assert agent.system_prompt == text - - + + @pytest.mark.parametrize( "content, expected", [ From 315ca240150aa4a1d1026be0f3eae1b3c1ec5db7 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 17 Nov 2025 15:51:58 -0500 Subject: [PATCH 132/242] agent - run - run input/output concurrently (#49) --- src/strands/experimental/bidi/agent/agent.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index f38414fd5..ac015e43a 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -399,23 +399,26 @@ async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: Example: ```python - audio_io = BidiAudioIO(audio_config={"input_sample_rate": 16000}) + audio_io = BidiAudioIO(input_rate=16000) text_io = BidiTextIO() agent = BidiAgent(model=model, tools=[calculator]) await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) ``` """ - async def run_inputs(): - while self.active: - for input_ in inputs: + async def run_inputs() -> None: + async def task(input_: BidiInput) -> None: + while self.active: event = await input_() await self.send(event) - async def run_outputs(): + tasks = [task(input_) for input_ in inputs] + await asyncio.gather(*tasks) + + async def run_outputs() -> None: async for event in self.receive(): - for output in outputs: - await output(event) + tasks = [output(event) for output in outputs] + await asyncio.gather(*tasks) await self.start() @@ -428,7 +431,7 @@ async def run_outputs(): await output.start() try: - await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True) + await asyncio.gather(run_inputs(), run_outputs()) finally: for input_ in inputs: From 4f3f55867049fb6096a42ad9520bce028a3888ac Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 10:13:21 +0100 Subject: [PATCH 133/242] feat(bidi): Add agent state --- src/strands/experimental/bidi/agent/agent.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index ac015e43a..c58262910 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -22,6 +22,7 @@ from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry +from ....agent.state import AgentState from ....tools.watcher import ToolWatcher from ....types.content import Message, Messages from ....types.tools import AgentTool, ToolResult, ToolUse @@ -58,6 +59,7 @@ def __init__( name: str | None = None, tool_executor: ToolExecutor | None = None, description: str | None = None, + state: AgentState | dict | None = None, **kwargs: Any, ): """Initialize bidirectional agent. @@ -73,10 +75,11 @@ def __init__( name: Name of the Agent. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). description: Description of what the Agent does. + state: Stateful information for the agent. Can be either an AgentState object, or a json serializable dict. **kwargs: Additional configuration for future extensibility. Raises: - ValueError: If model configuration is invalid. + ValueError: If model configuration is invalid or state is invalid type. TypeError: If model type is unsupported. """ self.model = ( @@ -113,6 +116,17 @@ def __init__( # Initialize tool executor self.tool_executor = tool_executor or ConcurrentToolExecutor() + # Initialize agent state management + if state is not None: + if isinstance(state, dict): + self.state = AgentState(state) + elif isinstance(state, AgentState): + self.state = state + else: + raise ValueError("state must be an AgentState object or a dict") + else: + self.state = AgentState() + # Initialize other components self._tool_caller = _ToolCaller(self) From d31bc1b494692c9fe78eb94fcaef01ab235d109c Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 13:18:53 +0100 Subject: [PATCH 134/242] Merge remote-tracking branch 'upstream/main' --- src/strands/_async.py | 4 +- src/strands/agent/agent.py | 107 +++-- src/strands/agent/interrupt.py | 59 --- src/strands/event_loop/event_loop.py | 12 +- src/strands/event_loop/streaming.py | 7 +- src/strands/hooks/registry.py | 67 ++- src/strands/interrupt.py | 94 ++++- src/strands/models/anthropic.py | 1 + src/strands/models/litellm.py | 123 +++++- src/strands/models/openai.py | 92 +++- src/strands/models/sagemaker.py | 8 +- src/strands/multiagent/graph.py | 10 +- src/strands/multiagent/swarm.py | 65 +-- src/strands/telemetry/tracer.py | 44 +- src/strands/tools/caller.py | 78 ++-- src/strands/tools/decorator.py | 136 +++++- src/strands/tools/executors/_executor.py | 10 +- src/strands/tools/mcp/mcp_agent_tool.py | 12 +- src/strands/tools/mcp/mcp_client.py | 71 +++- src/strands/types/session.py | 4 +- .../strands/agent/hooks/test_hook_registry.py | 21 +- tests/strands/agent/test_agent.py | 35 ++ tests/strands/agent/test_interrupt.py | 61 --- tests/strands/event_loop/test_event_loop.py | 8 +- .../test_event_loop_structured_output.py | 6 +- tests/strands/event_loop/test_streaming.py | 37 ++ .../experimental/hooks/test_hook_aliases.py | 7 +- tests/strands/hooks/test_registry.py | 32 +- tests/strands/models/test_litellm.py | 76 ++++ tests/strands/models/test_openai.py | 42 ++ tests/strands/multiagent/test_swarm.py | 27 ++ .../test_repository_session_manager.py | 4 +- tests/strands/telemetry/test_tracer.py | 96 ++++- tests/strands/test_interrupt.py | 108 ++++- tests/strands/tools/executors/conftest.py | 4 +- .../strands/tools/mcp/test_mcp_agent_tool.py | 29 +- tests/strands/tools/test_decorator.py | 398 +++++++++++++++++- tests/strands/types/test_interrupt.py | 5 +- tests/strands/types/test_session.py | 6 +- tests_integ/hooks/__init__.py | 0 tests_integ/hooks/multiagent/__init__.py | 0 tests_integ/hooks/multiagent/test_events.py | 122 ++++++ tests_integ/hooks/test_events.py | 138 ++++++ tests_integ/mcp/test_mcp_client.py | 67 +++ tests_integ/models/test_model_anthropic.py | 30 ++ tests_integ/models/test_model_litellm.py | 22 + tests_integ/models/test_model_openai.py | 26 ++ tests_integ/tools/__init__.py | 0 tests_integ/tools/test_thread_context.py | 47 +++ 49 files changed, 2053 insertions(+), 405 deletions(-) delete mode 100644 src/strands/agent/interrupt.py delete mode 100644 tests/strands/agent/test_interrupt.py create mode 100644 tests_integ/hooks/__init__.py create mode 100644 tests_integ/hooks/multiagent/__init__.py create mode 100644 tests_integ/hooks/multiagent/test_events.py create mode 100644 tests_integ/hooks/test_events.py create mode 100644 tests_integ/tools/__init__.py create mode 100644 tests_integ/tools/test_thread_context.py diff --git a/src/strands/_async.py b/src/strands/_async.py index 976487c37..141ca71b7 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -1,6 +1,7 @@ """Private async execution utilities.""" import asyncio +import contextvars from concurrent.futures import ThreadPoolExecutor from typing import Awaitable, Callable, TypeVar @@ -27,5 +28,6 @@ def execute() -> T: return asyncio.run(execute_async()) with ThreadPoolExecutor() as executor: - future = executor.submit(execute) + context = contextvars.copy_context() + future = executor.submit(context.run, execute) return future.result() diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index d3d45013a..434c769c6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -45,6 +45,7 @@ HookRegistry, MessageAddedEvent, ) +from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model from ..session.session_manager import SessionManager @@ -60,7 +61,6 @@ from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException -from ..types.interrupt import InterruptResponseContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -68,7 +68,6 @@ ConversationManager, SlidingWindowConversationManager, ) -from .interrupt import InterruptState from .state import AgentState logger = logging.getLogger(__name__) @@ -179,8 +178,8 @@ def __init__( """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] - # initializing self.system_prompt for backwards compatibility - self.system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) + # initializing self._system_prompt for backwards compatibility + self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) self._default_structured_output_model = structured_output_model self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME @@ -243,7 +242,7 @@ def __init__( self.hooks = HookRegistry() - self._interrupt_state = InterruptState() + self._interrupt_state = _InterruptState() # Initialize session management functionality self._session_manager = session_manager @@ -258,7 +257,36 @@ def __init__( self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) @property - def tool(self) -> _ToolCaller: + def system_prompt(self) -> str | None: + """Get the system prompt as a string for backwards compatibility. + + Returns the system prompt as a concatenated string when it contains text content, + or None if no text content is present. This maintains backwards compatibility + with existing code that expects system_prompt to be a string. + + Returns: + The system prompt as a string, or None if no text content exists. + """ + return self._system_prompt + + @system_prompt.setter + def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: + """Set the system prompt and update internal content representation. + + Accepts either a string or list of SystemContentBlock objects. + When set, both the backwards-compatible string representation and the internal + content block representation are updated to maintain consistency. + + Args: + value: System prompt as string, list of SystemContentBlock objects, or None. + - str: Simple text prompt (most common use case) + - list[SystemContentBlock]: Content blocks with features like caching + - None: Clear the system prompt + """ + self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) + + @property + def tool(self) -> ToolCaller: """Call tool as a function. Returns: @@ -424,7 +452,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu category=DeprecationWarning, stacklevel=2, ) - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT ) as structured_output_span: @@ -432,7 +460,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu if not self.messages and not prompt: raise ValueError("No conversation history or prompt provided") - temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt) + temp_messages: Messages = self.messages + await self._convert_prompt_to_messages(prompt) structured_output_span.set_attributes( { @@ -465,7 +493,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu return event["output"] finally: - self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) def cleanup(self) -> None: """Clean up resources used by the agent. @@ -531,7 +559,7 @@ async def stream_async( yield event["data"] ``` """ - self._resume_interrupt(prompt) + self._interrupt_state.resume(prompt) merged_state = {} if kwargs: @@ -548,7 +576,7 @@ async def stream_async( callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) - messages = self._convert_prompt_to_messages(prompt) + messages = await self._convert_prompt_to_messages(prompt) self.trace_span = self._start_agent_trace_span(messages) @@ -574,38 +602,6 @@ async def stream_async( self._end_agent_trace_span(error=e) raise - def _resume_interrupt(self, prompt: AgentInput) -> None: - """Configure the interrupt state if resuming from an interrupt event. - - Args: - prompt: User responses if resuming from interrupt. - - Raises: - TypeError: If in interrupt state but user did not provide responses. - """ - if not self._interrupt_state.activated: - return - - if not isinstance(prompt, list): - raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") - - invalid_types = [ - content_type for content in prompt for content_type in content if content_type != "interruptResponse" - ] - if invalid_types: - raise TypeError( - f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" - ) - - for content in cast(list[InterruptResponseContent], prompt): - interrupt_id = content["interruptResponse"]["interruptId"] - interrupt_response = content["interruptResponse"]["response"] - - if interrupt_id not in self._interrupt_state.interrupts: - raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") - - self._interrupt_state.interrupts[interrupt_id].response = interrupt_response - async def _run_loop( self, messages: Messages, @@ -622,13 +618,13 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) try: yield InitEventLoopEvent() for message in messages: - self._append_message(message) + await self._append_message(message) structured_output_context = StructuredOutputContext( structured_output_model or self._default_structured_output_model @@ -654,7 +650,7 @@ async def _run_loop( finally: self.conversation_manager.apply_management(self) - self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None @@ -703,7 +699,7 @@ async def _execute_event_loop_cycle( if structured_output_context: structured_output_context.cleanup(self.tool_registry) - def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: + async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: if self._interrupt_state.activated: return [] @@ -718,7 +714,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: tool_use_ids = [ content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content ] - self._append_message( + await self._append_message( { "role": "user", "content": generate_missing_tool_result_content(tool_use_ids), @@ -749,7 +745,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") return messages - def _record_tool_execution( + async def _record_tool_execution( self, tool: ToolUse, tool_result: ToolResult, @@ -809,10 +805,10 @@ def _record_tool_execution( } # Add to message history - self._append_message(user_msg) - self._append_message(tool_use_msg) - self._append_message(tool_result_msg) - self._append_message(assistant_msg) + await self._append_message(user_msg) + await self._append_message(tool_use_msg) + await self._append_message(tool_result_msg) + await self._append_message(assistant_msg) def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: """Starts a trace span for the agent. @@ -828,6 +824,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: tools=self.tool_names, system_prompt=self.system_prompt, custom_trace_attributes=self.trace_attributes, + tools_config=self.tool_registry.get_all_tools_config(), ) def _end_agent_trace_span( @@ -897,10 +894,10 @@ def _initialize_system_prompt( else: return None, None - def _append_message(self, message: Message) -> None: + async def _append_message(self, message: Message) -> None: """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" self.messages.append(message) - self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message)) + await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]: """Redact user content preserving toolResult blocks. diff --git a/src/strands/agent/interrupt.py b/src/strands/agent/interrupt.py deleted file mode 100644 index 3cec1541b..000000000 --- a/src/strands/agent/interrupt.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Track the state of interrupt events raised by the user for human-in-the-loop workflows.""" - -from dataclasses import asdict, dataclass, field -from typing import Any - -from ..interrupt import Interrupt - - -@dataclass -class InterruptState: - """Track the state of interrupt events raised by the user. - - Note, interrupt state is cleared after resuming. - - Attributes: - interrupts: Interrupts raised by the user. - context: Additional context associated with an interrupt event. - activated: True if agent is in an interrupt state, False otherwise. - """ - - interrupts: dict[str, Interrupt] = field(default_factory=dict) - context: dict[str, Any] = field(default_factory=dict) - activated: bool = False - - def activate(self, context: dict[str, Any] | None = None) -> None: - """Activate the interrupt state. - - Args: - context: Context associated with the interrupt event. - """ - self.context = context or {} - self.activated = True - - def deactivate(self) -> None: - """Deacitvate the interrupt state. - - Interrupts and context are cleared. - """ - self.interrupts = {} - self.context = {} - self.activated = False - - def to_dict(self) -> dict[str, Any]: - """Serialize to dict for session management.""" - return asdict(self) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "InterruptState": - """Initiailize interrupt state from serialized interrupt state. - - Interrupt state can be serialized with the `to_dict` method. - """ - return cls( - interrupts={ - interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() - }, - context=data["context"], - activated=data["activated"], - ) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 66174c09f..562de24b8 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -227,7 +227,7 @@ async def event_loop_cycle( ) structured_output_context.set_forced_mode() logger.debug("Forcing structured output tool") - agent._append_message( + await agent._append_message( {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} ) @@ -322,7 +322,7 @@ async def _handle_model_execution( model_id=model_id, ) with trace_api.use_span(model_invoke_span): - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( BeforeModelCallEvent( agent=agent, ) @@ -347,7 +347,7 @@ async def _handle_model_execution( stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( AfterModelCallEvent( agent=agent, stop_response=AfterModelCallEvent.ModelStopResponse( @@ -368,7 +368,7 @@ async def _handle_model_execution( if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( AfterModelCallEvent( agent=agent, exception=e, @@ -402,7 +402,7 @@ async def _handle_model_execution( # Add the response message to the conversation agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message)) # Update metrics agent.event_loop_metrics.update_usage(usage) @@ -507,7 +507,7 @@ async def _handle_tool_execution( } agent.messages.append(tool_result_message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=tool_result_message)) yield ToolResultMessageEvent(message=tool_result_message) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index c7b0b2caa..43836fe34 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -350,8 +350,11 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non Returns: The extracted usage metrics and latency. """ - usage = Usage(**event["usage"]) - metrics = Metrics(**event["metrics"]) + # MetadataEvent has total=False, making all fields optional, but Usage and Metrics types + # have Required fields. Provide defaults to handle cases where custom models don't + # provide usage/metrics (e.g., when latency info is unavailable). + usage = Usage(**{"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, **event.get("usage", {})}) + metrics = Metrics(**{"latencyMs": 0, **event.get("metrics", {})}) if time_to_first_byte_ms: metrics["timeToFirstByteMs"] = time_to_first_byte_ms diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 564be85cb..1efc0bf5b 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -7,9 +7,10 @@ via hook provider objects. """ +import inspect import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar from ..interrupt import Interrupt, InterruptException @@ -122,10 +123,15 @@ class HookCallback(Protocol, Generic[TEvent]): ```python def my_callback(event: StartRequestEvent) -> None: print(f"Request started for agent: {event.agent.name}") + + # Or + + async def my_callback(event: StartRequestEvent) -> None: + # await an async operation ``` """ - def __call__(self, event: TEvent) -> None: + def __call__(self, event: TEvent) -> None | Awaitable[None]: """Handle a hook event. Args: @@ -164,6 +170,10 @@ def my_handler(event: StartRequestEvent): registry.add_callback(StartRequestEvent, my_handler) ``` """ + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 + if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") + callbacks = self._registered_callbacks.setdefault(event_type, []) callbacks.append(callback) @@ -189,6 +199,52 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) + async def invoke_callbacks_async(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: + """Invoke all registered callbacks for the given event. + + This method finds all callbacks registered for the event's type and + invokes them in the appropriate order. For events with should_reverse_callbacks=True, + callbacks are invoked in reverse registration order. Any exceptions raised by callback + functions will propagate to the caller. + + Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows. + + Args: + event: The event to dispatch to registered callbacks. + + Returns: + The event dispatched to registered callbacks and any interrupts raised by the user. + + Raises: + ValueError: If interrupt name is used more than once. + + Example: + ```python + event = StartRequestEvent(agent=my_agent) + await registry.invoke_callbacks_async(event) + ``` + """ + interrupts: dict[str, Interrupt] = {} + + for callback in self.get_callbacks_for(event): + try: + if inspect.iscoroutinefunction(callback): + await callback(event) + else: + callback(event) + + except InterruptException as exception: + interrupt = exception.interrupt + if interrupt.name in interrupts: + message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" + logger.error(message) + raise ValueError(message) from exception + + # Each callback is allowed to raise their own interrupt. + interrupts[interrupt.name] = interrupt + + return event, list(interrupts.values()) + def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: """Invoke all registered callbacks for the given event. @@ -206,6 +262,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte The event dispatched to registered callbacks and any interrupts raised by the user. Raises: + RuntimeError: If at least one callback is async. ValueError: If interrupt name is used more than once. Example: @@ -214,9 +271,13 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte registry.invoke_callbacks(event) ``` """ + callbacks = list(self.get_callbacks_for(event)) interrupts: dict[str, Interrupt] = {} - for callback in self.get_callbacks_for(event): + if any(inspect.iscoroutinefunction(callback) for callback in callbacks): + raise RuntimeError(f"event=<{event}> | use invoke_callbacks_async to invoke async callback") + + for callback in callbacks: try: callback(event) except InterruptException as exception: diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index f0ed52389..919927e1a 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -1,7 +1,11 @@ """Human-in-the-loop interrupt system for agent workflows.""" -from dataclasses import asdict, dataclass -from typing import Any +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from .types.agent import AgentInput + from .types.interrupt import InterruptResponseContent @dataclass @@ -31,3 +35,89 @@ class InterruptException(Exception): def __init__(self, interrupt: Interrupt) -> None: """Set the interrupt.""" self.interrupt = interrupt + + +@dataclass +class _InterruptState: + """Track the state of interrupt events raised by the user. + + Note, interrupt state is cleared after resuming. + + Attributes: + interrupts: Interrupts raised by the user. + context: Additional context associated with an interrupt event. + activated: True if agent is in an interrupt state, False otherwise. + """ + + interrupts: dict[str, Interrupt] = field(default_factory=dict) + context: dict[str, Any] = field(default_factory=dict) + activated: bool = False + + def activate(self, context: dict[str, Any] | None = None) -> None: + """Activate the interrupt state. + + Args: + context: Context associated with the interrupt event. + """ + self.context = context or {} + self.activated = True + + def deactivate(self) -> None: + """Deacitvate the interrupt state. + + Interrupts and context are cleared. + """ + self.interrupts = {} + self.context = {} + self.activated = False + + def resume(self, prompt: "AgentInput") -> None: + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ + if not self.activated: + return + + if not isinstance(prompt, list): + raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") + + invalid_types = [ + content_type for content in prompt for content_type in content if content_type != "interruptResponse" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" + ) + + contents = cast(list["InterruptResponseContent"], prompt) + for content in contents: + interrupt_id = content["interruptResponse"]["interruptId"] + interrupt_response = content["interruptResponse"]["response"] + + if interrupt_id not in self.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + self.interrupts[interrupt_id].response = interrupt_response + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "_InterruptState": + """Initiailize interrupt state from serialized interrupt state. + + Interrupt state can be serialized with the `to_dict` method. + """ + return cls( + interrupts={ + interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() + }, + context=data["context"], + activated=data["activated"], + ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 48351da19..68b234729 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -39,6 +39,7 @@ class AnthropicModel(Model): } OVERFLOW_MESSAGES = { + "prompt is too long:", "input is too long", "input length exceeds context window", "input and output tokens exceed your context limit", diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 7a8c0ae03..17f1bbb94 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,9 +14,10 @@ from typing_extensions import Unpack, override from ..tools import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException -from ..types.streaming import StreamEvent +from ..types.streaming import MetadataEvent, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys from .openai import OpenAIModel @@ -81,11 +82,12 @@ def get_config(self) -> LiteLLMConfig: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a LiteLLM content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: LiteLLM formatted content block. @@ -131,6 +133,113 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> return chunks, data_type + @override + @classmethod + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format system messages for LiteLLM with cache point support. + + Args: + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + system_content: list[dict[str, Any]] = [] + for block in system_prompt_content or []: + if "text" in block: + system_content.append({"type": "text", "text": block["text"]}) + elif "cachePoint" in block and block["cachePoint"].get("type") == "default": + # Apply cache control to the immediately preceding content block + # for LiteLLM/Anthropic compatibility + if system_content: + system_content[-1]["cache_control"] = {"type": "ephemeral"} + + # Create single system message with content array rather than mulitple system messages + return [{"role": "system", "content": system_content}] if system_content else [] + + @override + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format a LiteLLM compatible messages array with cache point support. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model (for legacy compatibility). + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + A LiteLLM compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: + """Format a LiteLLM response event into a standardized message chunk. + + This method overrides OpenAI's format_chunk to handle the metadata case + with prompt caching support. All other chunk types use the parent implementation. + + Args: + event: A response event from the LiteLLM model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + """ + # Handle metadata case with prompt caching support + if event["chunk_type"] == "metadata": + usage_data: Usage = { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + } + + # Only LiteLLM over Anthropic supports cache write tokens + # Waiting until a more general approach is available to set cacheWriteInputTokens + + if tokens_details := getattr(event["data"], "prompt_tokens_details", None): + if cached := getattr(tokens_details, "cached_tokens", None): + usage_data["cacheReadInputTokens"] = cached + if creation := getattr(tokens_details, "cache_creation_tokens", None): + usage_data["cacheWriteInputTokens"] = creation + + return StreamEvent( + metadata=MetadataEvent( + metrics={ + "latencyMs": 0, # TODO + }, + usage=usage_data, + ) + ) + # For all other cases, use the parent implementation + return super().format_chunk(event) + @override async def stream( self, @@ -139,6 +248,7 @@ async def stream( system_prompt: Optional[str] = None, *, tool_choice: ToolChoice | None = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -148,17 +258,22 @@ async def stream( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt, tool_choice) + request = self.format_request( + messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content + ) logger.debug("request=<%s>", request) logger.debug("invoking model") try: + if kwargs.get("stream") is False: + raise ValueError("stream parameter cannot be explicitly set to False") response = await litellm.acompletion(**self.client_args, **request) except ContextWindowExceededError as e: logger.warning("litellm client raised context window overflow") diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 1efe641e6..435c82cab 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -89,11 +89,12 @@ def get_config(self) -> OpenAIConfig: return cast(OpenAIModel.OpenAIConfig, self.config) @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible content block. @@ -131,11 +132,12 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @classmethod - def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool call. Args: tool_use: Tool use requested by the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool call. @@ -150,11 +152,12 @@ def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: } @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool message. @@ -198,18 +201,46 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str return {"tool_choice": "auto"} @classmethod - def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format an OpenAI compatible messages array. + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format system messages for OpenAI-compatible providers. Args: - messages: List of message objects to be processed by the model. system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: - An OpenAI compatible messages array. + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + # TODO: Handle caching blocks https://github.com/strands-agents/sdk-python/issues/1140 + return [ + {"role": "system", "content": content["text"]} + for content in system_prompt_content or [] + if "text" in content + ] + + @classmethod + def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dict[str, Any]]: + """Format regular messages for OpenAI-compatible providers. + + Args: + messages: List of message objects to be processed by the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + List of formatted messages. """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + formatted_messages = [] for message in messages: contents = message["content"] @@ -242,14 +273,42 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str formatted_messages.append(formatted_message) formatted_messages.extend(formatted_tool_messages) + return formatted_messages + + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, + *, + system_prompt_content: list[SystemContentBlock] | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an OpenAI compatible chat streaming request. @@ -258,6 +317,8 @@ def format_request( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: An OpenAI compatible chat streaming request. @@ -267,7 +328,9 @@ def format_request( format. """ return { - "messages": self.format_request_messages(messages, system_prompt), + "messages": self.format_request_messages( + messages, system_prompt, system_prompt_content=system_prompt_content + ), "model": self.config["model_id"], "stream": True, "stream_options": {"include_usage": True}, @@ -286,11 +349,12 @@ def format_request( **cast(dict[str, Any], self.config.get("params", {})), } - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format an OpenAI response event into a standardized message chunk. Args: event: A response event from the OpenAI compatible model. + **kwargs: Additional keyword arguments for future extensibility. Returns: The formatted chunk. diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 25b3ca7ce..7f8b8ff51 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -202,6 +202,7 @@ def format_request( tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, tool_choice: ToolChoice | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an Amazon SageMaker chat streaming request. @@ -211,6 +212,7 @@ def format_request( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for interface consistency but is currently ignored for this model provider.** + **kwargs: Additional keyword arguments for future extensibility. Returns: An Amazon SageMaker chat streaming request. @@ -501,11 +503,12 @@ async def stream( @override @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format a SageMaker compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: SageMaker compatible tool message with content as a string. @@ -531,11 +534,12 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: Formatted content block. diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index b421b70c1..9f28876bf 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -453,7 +453,7 @@ def __init__( self._resume_from_session = False self.id = id - self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) + run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -516,7 +516,7 @@ async def stream_async( if invocation_state is None: invocation_state = {} - self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("task=<%s> | starting graph execution", task) @@ -569,7 +569,7 @@ async def stream_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self)) + await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self)) self._resume_from_session = False self._resume_next_nodes.clear() @@ -776,7 +776,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" - self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state)) # Reset the node's state if reset_on_revisit is enabled, and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -920,7 +920,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) raise finally: - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) def _accumulate_metrics(self, node_result: NodeResult) -> None: """Accumulate metrics from a node result.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index accd56463..3913cd837 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -156,6 +156,7 @@ class SwarmState: # Total metrics across all agents accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_time: int = 0 # Total execution time in milliseconds + handoff_node: SwarmNode | None = None # The agent to execute next handoff_message: str | None = None # Message passed during agent handoff def should_continue( @@ -273,7 +274,7 @@ def __init__( self._setup_swarm(nodes) self._inject_swarm_tools() - self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) + run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -336,7 +337,7 @@ async def stream_async( if invocation_state is None: invocation_state = {} - self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("starting swarm execution") @@ -375,7 +376,7 @@ async def stream_async( raise finally: self.state.execution_time = round((time.time() - self.state.start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self, invocation_state)) self._resume_from_session = False # Yield final result after execution_time is set @@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No # Execute handoff swarm_ref._handle_handoff(target_node, message, context) - return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]} + return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]} except Exception as e: return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]} @@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st ) return - # Update swarm state - previous_agent = cast(SwarmNode, self.state.current_node) - self.state.current_node = target_node + current_node = cast(SwarmNode, self.state.current_node) - # Store handoff message for the target agent + self.state.handoff_node = target_node self.state.handoff_message = message # Store handoff context as shared context if context: for key, value in context.items(): - self.shared_context.add_context(previous_agent, key, value) + self.shared_context.add_context(current_node, key, value) logger.debug( - "from_node=<%s>, to_node=<%s> | handed off from agent to agent", - previous_agent.node_id, + "from_node=<%s>, to_node=<%s> | handing off from agent to agent", + current_node.node_id, target_node.node_id, ) @@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato logger.debug("reason=<%s> | stopping execution", reason) break - # Get current node current_node = self.state.current_node if not current_node or current_node.node_id not in self.nodes: logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None") @@ -680,14 +678,11 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato len(self.state.node_history) + 1, ) - # Store the current node before execution to detect handoffs - previous_node = current_node - - # Execute node with timeout protection # TODO: Implement cancellation token to stop _execute_node from continuing try: - # Execute with timeout wrapper for async generator streaming - self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async( + BeforeNodeCallEvent(self, current_node.node_id, invocation_state) + ) node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ -697,28 +692,33 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato yield event self.state.node_history.append(current_node) - - # After self.state add current node, swarm state finish updating, we persist here - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async( + AfterNodeCallEvent(self, current_node.node_id, invocation_state) + ) logger.debug("node=<%s> | node execution completed", current_node.node_id) - # Check if handoff occurred during execution - if self.state.current_node is not None and self.state.current_node != previous_node: - # Emit handoff event (single node transition in Swarm) + # Check if handoff requested during execution + if self.state.handoff_node: + previous_node = current_node + current_node = self.state.handoff_node + + self.state.handoff_node = None + self.state.current_node = current_node + handoff_event = MultiAgentHandoffEvent( from_node_ids=[previous_node.node_id], - to_node_ids=[self.state.current_node.node_id], + to_node_ids=[current_node.node_id], message=self.state.handoff_message or "Agent handoff occurred", ) yield handoff_event logger.debug( "from_node=<%s>, to_node=<%s> | handoff detected", previous_node.node_id, - self.state.current_node.node_id, + current_node.node_id, ) + else: - # No handoff occurred, mark swarm as complete logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) self.state.completion_status = Status.COMPLETED break @@ -862,11 +862,12 @@ def _build_result(self) -> SwarmResult: def serialize_state(self) -> dict[str, Any]: """Serialize the current swarm state to a dictionary.""" status_str = self.state.completion_status.value - next_nodes = ( - [self.state.current_node.node_id] - if self.state.completion_status == Status.EXECUTING and self.state.current_node - else [] - ) + if self.state.handoff_node: + next_nodes = [self.state.handoff_node.node_id] + elif self.state.completion_status == Status.EXECUTING and self.state.current_node: + next_nodes = [self.state.current_node.node_id] + else: + next_nodes = [] return { "type": "swarm", diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 9cefc6911..c47a10c3f 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -79,11 +79,12 @@ class Tracer: When the OTEL_EXPORTER_OTLP_ENDPOINT environment variable is set, traces are sent to the OTLP endpoint. + + Both attributes are controlled by including "gen_ai_latest_experimental" or "gen_ai_tool_definitions", + respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. """ - def __init__( - self, - ) -> None: + def __init__(self) -> None: """Initialize the tracer.""" self.service_name = __name__ self.tracer_provider: Optional[trace_api.TracerProvider] = None @@ -92,17 +93,19 @@ def __init__( ThreadingInstrumentor().instrument() # Read OTEL_SEMCONV_STABILITY_OPT_IN environment variable - self.use_latest_genai_conventions = self._parse_semconv_opt_in() + opt_in_values = self._parse_semconv_opt_in() + ## To-do: should not set below attributes directly, use env var instead + self.use_latest_genai_conventions = "gen_ai_latest_experimental" in opt_in_values + self._include_tool_definitions = "gen_ai_tool_definitions" in opt_in_values - def _parse_semconv_opt_in(self) -> bool: + def _parse_semconv_opt_in(self) -> set[str]: """Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. Returns: - Set of opt-in values from the environment variable + A set of opt-in values from the environment variable. """ opt_in_env = os.getenv("OTEL_SEMCONV_STABILITY_OPT_IN", "") - - return "gen_ai_latest_experimental" in opt_in_env + return {value.strip() for value in opt_in_env.split(",")} def _start_span( self, @@ -551,6 +554,7 @@ def start_agent_span( model_id: Optional[str] = None, tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + tools_config: Optional[dict] = None, **kwargs: Any, ) -> Span: """Start a new span for an agent invocation. @@ -561,6 +565,7 @@ def start_agent_span( model_id: Optional model identifier. tools: Optional list of tools being used. custom_trace_attributes: Optional mapping of custom trace attributes to include in the span. + tools_config: Optional dictionary of tool configurations. **kwargs: Additional attributes to add to the span. Returns: @@ -577,8 +582,15 @@ def start_agent_span( attributes["gen_ai.request.model"] = model_id if tools: - tools_json = serialize(tools) - attributes["gen_ai.agent.tools"] = tools_json + attributes["gen_ai.agent.tools"] = serialize(tools) + + if self._include_tool_definitions and tools_config: + try: + tool_definitions = self._construct_tool_definitions(tools_config) + attributes["gen_ai.tool.definitions"] = serialize(tool_definitions) + except Exception: + # A failure in telemetry should not crash the agent + logger.warning("failed to attach tool metadata to agent span", exc_info=True) # Add custom trace attributes if provided if custom_trace_attributes: @@ -649,6 +661,18 @@ def end_agent_span( self._end_span(span, attributes, error) + def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]]: + """Constructs a list of tool definitions from the provided tools_config.""" + return [ + { + "name": name, + "description": spec.get("description"), + "inputSchema": spec.get("inputSchema"), + "outputSchema": spec.get("outputSchema"), + } + for name, spec in tools_config.items() + ] + def start_multiagent_span( self, task: str | list[ContentBlock], diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py index 9c1c116b5..4663b662f 100644 --- a/src/strands/tools/caller.py +++ b/src/strands/tools/caller.py @@ -1,11 +1,11 @@ """ToolCaller base class.""" -import asyncio import random -from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Optional +from .._async import run_async from ..tools.executors._executor import ToolExecutor +from ..types._events import ToolInterruptEvent from ..types.tools import ToolResult, ToolUse @@ -61,7 +61,12 @@ def caller( Raises: AttributeError: If the tool doesn't exist. + RuntimeError: If called during an interrupt or if interrupt is raised. """ + # Check if agent has interrupt state and if it's activated + if hasattr(self._agent, "_interrupt_state") and self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + normalized_name = self._find_normalized_tool_name(name) # Create unique tool ID and set up the tool request @@ -73,10 +78,11 @@ def caller( } # Execute tool using shared execution pipeline - tool_result = self._execute_tool_sync(tool_use, kwargs) + tool_result = self._execute_tool_async(tool_use, kwargs, user_message_override, record_direct_tool_call) - # Handle tool call recording with agent-specific behavior - self._handle_tool_call_recording(tool_use, tool_result, user_message_override, record_direct_tool_call) + # Apply conversation management if agent supports it (traditional agents) + if hasattr(self._agent, "conversation_manager"): + self._agent.conversation_manager.apply_management(self._agent) return tool_result @@ -111,54 +117,48 @@ def _find_normalized_tool_name(self, name: str) -> str: raise AttributeError(f"Tool '{name}' not found") - def _execute_tool_sync(self, tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolResult: - """Execute tool synchronously using shared Strands pipeline. + def _execute_tool_async( + self, + tool_use: ToolUse, + invocation_state: dict[str, Any], + user_message_override: Optional[str], + record_direct_tool_call: Optional[bool], + ) -> ToolResult: + """Execute tool asynchronously using shared Strands pipeline. Args: tool_use: Tool execution request. invocation_state: Execution context. + user_message_override: Optional message override. + record_direct_tool_call: Optional recording override. Returns: Tool execution result. + + Raises: + RuntimeError: If interrupt is raised during tool execution. """ tool_results: list[ToolResult] = [] async def acall() -> ToolResult: async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - _ = event - return tool_results[0] + # Check for interrupt events + if isinstance(event, ToolInterruptEvent): + if hasattr(self._agent, "_interrupt_state"): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") - def tcall() -> ToolResult: - return asyncio.run(acall()) + tool_result = tool_results[0] - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - return future.result() + # Determine if we should record the tool call + should_record = ( + record_direct_tool_call if record_direct_tool_call is not None else self._agent.record_direct_tool_call + ) - def _handle_tool_call_recording( - self, - tool_use: ToolUse, - tool_result: ToolResult, - user_message_override: Optional[str], - record_direct_tool_call: Optional[bool], - ) -> None: - """Handle tool call recording with agent-specific behavior. - - Args: - tool_use: Tool execution information. - tool_result: Tool result. - user_message_override: Optional message override. - record_direct_tool_call: Optional recording override. - """ - # Determine if we should record the tool call - should_record = ( - record_direct_tool_call if record_direct_tool_call is not None else self._agent.record_direct_tool_call - ) + if should_record: + # Use agent's async recording method + await self._agent._record_tool_execution(tool_use, tool_result, user_message_override) - if should_record: - # Use agent's recording method - self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + return tool_result - # Apply conversation management if agent supports it (traditional agents) - if hasattr(self._agent, "conversation_manager"): - self._agent.conversation_manager.apply_management(self._agent) + return run_async(acall) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5c49f4b58..8dc933f51 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -45,6 +45,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import inspect import logging from typing import ( + Annotated, Any, Callable, Generic, @@ -54,12 +55,15 @@ def my_tool(param1: str, param2: int = 42) -> dict: TypeVar, Union, cast, + get_args, + get_origin, get_type_hints, overload, ) import docstring_parser from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo from typing_extensions import override from ..interrupt import InterruptException @@ -105,15 +109,66 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) - - # Get parameter descriptions from parsed docstring - self.param_descriptions = { + self.param_descriptions: dict[str, str] = { param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params } # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _extract_annotated_metadata( + self, annotation: Any, param_name: str, param_default: Any + ) -> tuple[Any, FieldInfo]: + """Extracts type and a simple string description from an Annotated type hint. + + Returns: + A tuple of (actual_type, field_info), where field_info is a new, simple + Pydantic FieldInfo instance created from the extracted metadata. + """ + actual_type = annotation + description: str | None = None + + if get_origin(annotation) is Annotated: + args = get_args(annotation) + actual_type = args[0] + + # Look through metadata for a string description or a FieldInfo object + for meta in args[1:]: + if isinstance(meta, str): + description = meta + elif isinstance(meta, FieldInfo): + # --- Future Contributor Note --- + # We are explicitly blocking the use of `pydantic.Field` within `Annotated` + # because of the complexities of Pydantic v2's immutable Core Schema. + # + # Once a Pydantic model's schema is built, its `FieldInfo` objects are + # effectively frozen. Attempts to mutate a `FieldInfo` object after + # creation (e.g., by copying it and setting `.description` or `.default`) + # are unreliable because the underlying Core Schema does not see these changes. + # + # The correct way to support this would be to reliably extract all + # constraints (ge, le, pattern, etc.) from the original FieldInfo and + # rebuild a new one from scratch. However, these constraints are not + # stored as public attributes, making them difficult to inspect reliably. + # + # Deferring this complexity until there is clear demand and a robust + # pattern for inspecting FieldInfo constraints is established. + raise NotImplementedError( + "Using pydantic.Field within Annotated is not yet supported for tool decorators. " + "Please use a simple string for the description, or define constraints in the function's " + "docstring." + ) + + # Determine the final description with a clear priority order + # Priority: 1. Annotated string -> 2. Docstring -> 3. Fallback + final_description = description + if final_description is None: + final_description = self.param_descriptions.get(param_name) or f"Parameter {param_name}" + # Create FieldInfo object from scratch + final_field = Field(default=param_default, description=final_description) + + return actual_type, final_field + def _validate_signature(self) -> None: """Verify that ToolContext is used correctly in the function signature.""" for param in self.signature.parameters.values(): @@ -146,24 +201,73 @@ def _create_input_model(self) -> Type[BaseModel]: if self._is_special_parameter(name): continue - # Get parameter type and default - param_type = self.type_hints.get(name, Any) + # Use param.annotation directly to get the raw type hint. Using get_type_hints() + # can cause inconsistent behavior across Python versions for complex Annotated types. + param_type = param.annotation + if param_type is inspect.Parameter.empty: + param_type = Any default = ... if param.default is inspect.Parameter.empty else param.default - description = self.param_descriptions.get(name, f"Parameter {name}") - # Create Field with description and default - field_definitions[name] = (param_type, Field(default=default, description=description)) + actual_type, field_info = self._extract_annotated_metadata(param_type, name, default) + field_definitions[name] = (actual_type, field_info) - # Create model name based on function name model_name = f"{self.func.__name__.capitalize()}Tool" - # Create and return the model if field_definitions: return create_model(model_name, **field_definitions) else: - # Handle case with no parameters return create_model(model_name) + def _extract_description_from_docstring(self) -> str: + """Extract the docstring excluding only the Args section. + + This method uses the parsed docstring to extract everything except + the Args/Arguments/Parameters section, preserving Returns, Raises, + Examples, and other sections. + + Returns: + The description text, or the function name if no description is available. + """ + func_name = self.func.__name__ + + # Fallback: try to extract manually from raw docstring + raw_docstring = inspect.getdoc(self.func) + if raw_docstring: + lines = raw_docstring.strip().split("\n") + result_lines = [] + skip_args_section = False + + for line in lines: + stripped_line = line.strip() + + # Check if we're starting the Args section + if stripped_line.lower().startswith(("args:", "arguments:", "parameters:", "param:", "params:")): + skip_args_section = True + continue + + # Check if we're starting a new section (not Args) + elif ( + stripped_line.lower().startswith(("returns:", "return:", "yields:", "yield:")) + or stripped_line.lower().startswith(("raises:", "raise:", "except:", "exceptions:")) + or stripped_line.lower().startswith(("examples:", "example:", "note:", "notes:")) + or stripped_line.lower().startswith(("see also:", "seealso:", "references:", "ref:")) + ): + skip_args_section = False + result_lines.append(line) + continue + + # If we're not in the Args section, include the line + if not skip_args_section: + result_lines.append(line) + + # Join and clean up the description + description = "\n".join(result_lines).strip() + if description: + return description + + # Final fallback: use function name + return func_name + def extract_metadata(self) -> ToolSpec: """Extract metadata from the function to create a tool specification. @@ -173,7 +277,7 @@ def extract_metadata(self) -> ToolSpec: The specification includes: - name: The function name (or custom override) - - description: The function's docstring + - description: The function's docstring description (excluding Args) - inputSchema: A JSON schema describing the expected parameters Returns: @@ -181,12 +285,8 @@ def extract_metadata(self) -> ToolSpec: """ func_name = self.func.__name__ - # Extract function description from docstring, preserving paragraph breaks - description = inspect.getdoc(self.func) - if description: - description = description.strip() - else: - description = func_name + # Extract function description from parsed docstring, excluding Args section and beyond + description = self._extract_description_from_docstring() # Get schema directly from the Pydantic model input_schema = self.input_model.model_json_schema() diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index f9a482558..87c38990d 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -85,7 +85,7 @@ async def _stream( } ) - before_event, interrupts = agent.hooks.invoke_callbacks( + before_event, interrupts = await agent.hooks.invoke_callbacks_async( BeforeToolCallEvent( agent=agent, selected_tool=tool_func, @@ -109,7 +109,7 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, tool_use=tool_use, @@ -147,7 +147,7 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -184,7 +184,7 @@ async def _stream( result = cast(ToolResult, event) - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -204,7 +204,7 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index af0c069a1..bedd93f24 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -6,6 +6,7 @@ """ import logging +from datetime import timedelta from typing import TYPE_CHECKING, Any from mcp.types import Tool as MCPTool @@ -28,7 +29,13 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: + def __init__( + self, + mcp_tool: MCPTool, + mcp_client: "MCPClient", + name_override: str | None = None, + timeout: timedelta | None = None, + ) -> None: """Initialize a new MCPAgentTool instance. Args: @@ -36,12 +43,14 @@ def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: st mcp_client: The MCP server connection to use for tool invocation name_override: Optional name to use for the agent tool (for disambiguation) If None, uses the original MCP tool name + timeout: Optional timeout duration for tool execution """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client self._agent_tool_name = name_override or mcp_tool.name + self.timeout = timeout @property def tool_name(self) -> str: @@ -105,5 +114,6 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw tool_use_id=tool_use["toolUseId"], name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], + read_timeout_seconds=self.timeout, ) yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 2fe006466..b16b9c2b4 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -119,10 +119,12 @@ def __init__( mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") - # Main thread blocks until future completesock + # Main thread blocks until future completes self._init_future: futures.Future[None] = futures.Future() + # Set within the inner loop as it needs the asyncio loop + self._close_future: asyncio.futures.Future[None] | None = None + self._close_exception: None | Exception = None # Do not want to block other threads while close event is false - self._close_event = asyncio.Event() self._transport_callable = transport_callable self._background_thread: threading.Thread | None = None @@ -288,11 +290,12 @@ def stop( - _background_thread: Thread running the async event loop - _background_thread_session: MCP ClientSession (auto-closed by context manager) - _background_thread_event_loop: AsyncIO event loop in background thread - - _close_event: AsyncIO event to signal thread shutdown + - _close_future: AsyncIO future to signal thread shutdown + - _close_exception: Exception that caused the background thread shutdown; None if a normal shutdown occurred. - _init_future: Future for initialization synchronization Cleanup order: - 1. Signal close event to background thread (if session initialized) + 1. Signal close future to background thread (if session initialized) 2. Wait for background thread to complete 3. Reset all state for reuse @@ -303,13 +306,14 @@ def stop( """ self._log_debug_with_thread("exiting MCPClient context") - # Only try to signal close event if we have a background thread + # Only try to signal close future if we have a background thread if self._background_thread is not None: - # Signal close event if event loop exists + # Signal close future if event loop exists if self._background_thread_event_loop is not None: async def _set_close_event() -> None: - self._close_event.set() + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) # Not calling _invoke_on_background_thread since the session does not need to exist # we only need the thread and event loop to exist. @@ -317,11 +321,11 @@ async def _set_close_event() -> None: self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() + self._log_debug_with_thread("background thread is closed, MCPClient context exited") # Reset fields to allow instance reuse self._init_future = futures.Future() - self._close_event = asyncio.Event() self._background_thread = None self._background_thread_session = None self._background_thread_event_loop = None @@ -330,6 +334,11 @@ async def _set_close_event() -> None: self._tool_provider_started = False self._consumers = set() + if self._close_exception: + exception = self._close_exception + self._close_exception = None + raise RuntimeError("Connection to the MCP server was closed") from exception + def list_tools_sync( self, pagination_token: str | None = None, @@ -563,6 +572,10 @@ async def _async_background_thread(self) -> None: signals readiness to the main thread, and waits for a close signal. """ self._log_debug_with_thread("starting async background thread for MCP connection") + + # Initialized here so that it has the asyncio loop + self._close_future = asyncio.Future() + try: async with self._transport_callable() as (read_stream, write_stream, *_): self._log_debug_with_thread("transport connection established") @@ -583,8 +596,9 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("waiting for close signal") # Keep background thread running until signaled to close. - # Thread is not blocked as this is an asyncio.Event not a threading.Event - await self._close_event.wait() + # Thread is not blocked as this a future + await self._close_future + self._log_debug_with_thread("close signal received") except Exception as e: # If we encounter an exception and the future is still running, @@ -592,6 +606,12 @@ async def _async_background_thread(self) -> None: if not self._init_future.done(): self._init_future.set_exception(e) else: + # _close_future is automatically cancelled by the framework which doesn't provide us with the useful + # exception, so instead we store the exception in a different field where stop() can read it + self._close_exception = e + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) + self._log_debug_with_thread( "encountered exception on background thread after initialization %s", str(e) ) @@ -601,7 +621,7 @@ def _background_task(self) -> None: This method creates a new event loop for the background thread, sets it as the current event loop, and runs the async_background_thread - coroutine until completion. In this case "until completion" means until the _close_event is set. + coroutine until completion. In this case "until completion" means until the _close_future is resolved. This allows for a long-running event loop. """ self._log_debug_with_thread("setting up background task event loop") @@ -699,9 +719,34 @@ def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None: ) def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]: - if self._background_thread_session is None or self._background_thread_event_loop is None: + # save a reference to this so that even if it's reset we have the original + close_future = self._close_future + + if ( + self._background_thread_session is None + or self._background_thread_event_loop is None + or close_future is None + ): raise MCPClientInitializationError("the client session was not initialized") - return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + + async def run_async() -> T: + # Fix for strands-agents/sdk-python/issues/995 - cancel all pending invocations if/when the session closes + invoke_event = asyncio.create_task(coro) + tasks: list[asyncio.Task | asyncio.Future] = [ + invoke_event, + close_future, + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if done.pop() == close_future: + self._log_debug_with_thread("event loop for the server closed before the invoke completed") + raise RuntimeError("Connection to the MCP server was closed") + else: + return await invoke_event + + invoke_future = asyncio.run_coroutine_threadsafe(coro=run_async(), loop=self._background_thread_event_loop) + return invoke_future def _should_include_tool(self, tool: MCPAgentTool) -> bool: """Check if a tool should be included based on constructor filters.""" diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 4e72a1468..8b78ab448 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -7,7 +7,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional -from ..agent.interrupt import InterruptState +from ..interrupt import _InterruptState from .content import Message if TYPE_CHECKING: @@ -148,7 +148,7 @@ def to_dict(self) -> dict[str, Any]: def initialize_internal_state(self, agent: "Agent") -> None: """Initialize internal state of agent.""" if "interrupt_state" in self._internal_state: - agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"]) + agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) @dataclass diff --git a/tests/strands/agent/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py index 680ded682..ad1415f22 100644 --- a/tests/strands/agent/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -113,29 +113,32 @@ def test_get_callbacks_for_after_event(hook_registry, after_event): assert callbacks[1] == callback1 # Reverse order -def test_invoke_callbacks(hook_registry, normal_event): - """Test that invoke_callbacks calls all registered callbacks for an event.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async(hook_registry, normal_event): + """Test that invoke_callbacks_async calls all registered callbacks for an event.""" callback1 = Mock() callback2 = Mock() hook_registry.add_callback(NormalTestEvent, callback1) hook_registry.add_callback(NormalTestEvent, callback2) - hook_registry.invoke_callbacks(normal_event) + await hook_registry.invoke_callbacks_async(normal_event) callback1.assert_called_once_with(normal_event) callback2.assert_called_once_with(normal_event) -def test_invoke_callbacks_no_registered_callbacks(hook_registry, normal_event): - """Test that invoke_callbacks doesn't fail when there are no registered callbacks.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async_no_registered_callbacks(hook_registry, normal_event): + """Test that invoke_callbacks_async doesn't fail when there are no registered callbacks.""" # No callbacks registered - hook_registry.invoke_callbacks(normal_event) + await hook_registry.invoke_callbacks_async(normal_event) # Test passes if no exception is raised -def test_invoke_callbacks_after_event(hook_registry, after_event): - """Test that invoke_callbacks calls callbacks in reverse order for after events.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async_after_event(hook_registry, after_event): + """Test that invoke_callbacks_async calls callbacks in reverse order for after events.""" call_order: List[str] = [] def callback1(_event): @@ -147,7 +150,7 @@ def callback2(_event): hook_registry.add_callback(AfterTestEvent, callback1) hook_registry.add_callback(AfterTestEvent, callback2) - hook_registry.invoke_callbacks(after_event) + await hook_registry.invoke_callbacks_async(after_event) assert call_order == ["callback2", "callback1"] # Reverse order diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 550422cfe..d04f57948 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1221,6 +1221,37 @@ async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, ali assert tru_message == exp_message +def test_system_prompt_setter_string(): + """Test that setting system_prompt with string updates both internal fields.""" + agent = Agent(system_prompt="initial prompt") + + agent.system_prompt = "updated prompt" + + assert agent.system_prompt == "updated prompt" + assert agent._system_prompt_content == [{"text": "updated prompt"}] + + +def test_system_prompt_setter_list(): + """Test that setting system_prompt with list updates both internal fields.""" + agent = Agent() + + content_blocks = [{"text": "You are helpful"}, {"cache_control": {"type": "ephemeral"}}] + agent.system_prompt = content_blocks + + assert agent.system_prompt == "You are helpful" + assert agent._system_prompt_content == content_blocks + + +def test_system_prompt_setter_none(): + """Test that setting system_prompt to None clears both internal fields.""" + agent = Agent(system_prompt="initial prompt") + + agent.system_prompt = None + + assert agent.system_prompt is None + assert agent._system_prompt_content is None + + @pytest.mark.asyncio async def test_stream_async_passes_invocation_state(agent, mock_model, mock_event_loop_cycle, agenerator, alist): mock_model.mock_stream.side_effect = [ @@ -1360,6 +1391,7 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the result @@ -1394,6 +1426,7 @@ async def test_event_loop(*args, **kwargs): tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) expected_response = AgentResult( @@ -1432,6 +1465,7 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the exception @@ -1468,6 +1502,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the exception diff --git a/tests/strands/agent/test_interrupt.py b/tests/strands/agent/test_interrupt.py deleted file mode 100644 index e248c29a6..000000000 --- a/tests/strands/agent/test_interrupt.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt - - -@pytest.fixture -def interrupt(): - return Interrupt(id="test_id", name="test_name", reason="test reason") - - -def test_interrupt_activate(): - interrupt_state = InterruptState() - - interrupt_state.activate(context={"test": "context"}) - - assert interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {"test": "context"} - assert tru_context == exp_context - - -def test_interrupt_deactivate(): - interrupt_state = InterruptState(context={"test": "context"}, activated=True) - - interrupt_state.deactivate() - - assert not interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {} - assert tru_context == exp_context - - -def test_interrupt_state_to_dict(interrupt): - interrupt_state = InterruptState(interrupts={"test_id": interrupt}, context={"test": "context"}, activated=True) - - tru_data = interrupt_state.to_dict() - exp_data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - assert tru_data == exp_data - - -def test_interrupt_state_from_dict(): - data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - - tru_state = InterruptState.from_dict(data) - exp_state = InterruptState( - interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, - context={"test": "context"}, - activated=True, - ) - assert tru_state == exp_state diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 72fe1b4bd..9335f91a8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,12 +1,11 @@ import concurrent import unittest.mock -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest import strands import strands.telemetry -from strands.agent.interrupt import InterruptState from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -14,7 +13,7 @@ HookRegistry, MessageAddedEvent, ) -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry @@ -143,7 +142,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor - mock._interrupt_state = InterruptState() + mock._interrupt_state = _InterruptState() return mock @@ -750,6 +749,7 @@ async def test_request_state_initialization(alist): # not setting this to False results in endless recursion mock_agent._interrupt_state.activated = False mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) + mock_agent.hooks.invoke_callbacks_async = AsyncMock() # Call without providing request_state stream = strands.event_loop.event_loop.event_loop_cycle( diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 6d3e3a9b5..886da2f0b 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -1,6 +1,6 @@ """Tests for structured output integration in the event loop.""" -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from pydantic import BaseModel @@ -38,10 +38,10 @@ def mock_agent(): agent.tool_registry = ToolRegistry() agent.event_loop_metrics = EventLoopMetrics() agent.hooks = Mock() - agent.hooks.invoke_callbacks = Mock() + agent.hooks.invoke_callbacks_async = AsyncMock() agent.trace_span = None agent.tool_executor = Mock() - agent._append_message = Mock() + agent._append_message = AsyncMock() # Set up _interrupt_state properly agent._interrupt_state = Mock() diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 714fbac27..3f5a6c998 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -421,6 +421,43 @@ def test_extract_usage_metrics_with_cache_tokens(): assert tru_usage == exp_usage and tru_metrics == exp_metrics +def test_extract_usage_metrics_without_metrics(): + """Test extract_usage_metrics when metrics field is missing.""" + event = { + "usage": {"inputTokens": 5, "outputTokens": 2, "totalTokens": 7}, + } + + tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event) + exp_usage = {"inputTokens": 5, "outputTokens": 2, "totalTokens": 7} + exp_metrics = {"latencyMs": 0} + + assert tru_usage == exp_usage and tru_metrics == exp_metrics + + +def test_extract_usage_metrics_without_usage(): + """Test extract_usage_metrics when usage field is missing.""" + event = { + "metrics": {"latencyMs": 100}, + } + + tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event) + exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} + exp_metrics = {"latencyMs": 100} + + assert tru_usage == exp_usage and tru_metrics == exp_metrics + + +def test_extract_usage_metrics_empty_metadata(): + """Test extract_usage_metrics when both fields are missing.""" + event = {} + + tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event) + exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} + exp_metrics = {"latencyMs": 0} + + assert tru_usage == exp_usage and tru_metrics == exp_metrics + + @pytest.mark.parametrize( ("response", "exp_events"), [ diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index db9cd3783..6744aa00c 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -9,6 +9,8 @@ import sys from unittest.mock import Mock +import pytest + from strands.experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, @@ -80,7 +82,8 @@ def test_after_model_call_event_type_equality(): assert isinstance(after_model_event, AfterModelCallEvent) -def test_experimental_aliases_in_hook_registry(): +@pytest.mark.asyncio +async def test_experimental_aliases_in_hook_registry(): """Verify that experimental aliases work with hook registry callbacks.""" hook_registry = HookRegistry() callback_called = False @@ -103,7 +106,7 @@ def experimental_callback(event: BeforeToolInvocationEvent): ) # Invoke callbacks - should work since alias points to same type - hook_registry.invoke_callbacks(test_event) + await hook_registry.invoke_callbacks_async(test_event) assert callback_called assert received_event is test_event diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 6918bd2ee..3daf41734 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -2,9 +2,8 @@ import pytest -from strands.agent.interrupt import InterruptState -from strands.hooks import BeforeToolCallEvent, HookRegistry -from strands.interrupt import Interrupt +from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -15,11 +14,19 @@ def registry(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance -def test_hook_registry_invoke_callbacks_interrupt(registry, agent): +def test_hook_registry_add_callback_agent_init_coroutine(registry): + callback = unittest.mock.AsyncMock() + + with pytest.raises(ValueError, match=r"AgentInitializedEvent can only be registered with a synchronous callback"): + registry.add_callback(AgentInitializedEvent, callback) + + +@pytest.mark.asyncio +async def test_hook_registry_invoke_callbacks_async_interrupt(registry, agent): event = BeforeToolCallEvent( agent=agent, selected_tool=None, @@ -35,7 +42,7 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent): registry.add_callback(BeforeToolCallEvent, callback2) registry.add_callback(BeforeToolCallEvent, callback3) - _, tru_interrupts = registry.invoke_callbacks(event) + _, tru_interrupts = await registry.invoke_callbacks_async(event) exp_interrupts = [ Interrupt( id="v1:before_tool_call:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", @@ -55,7 +62,8 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent): callback3.assert_called_once_with(event) -def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): +@pytest.mark.asyncio +async def test_hook_registry_invoke_callbacks_async_interrupt_name_clash(registry, agent): event = BeforeToolCallEvent( agent=agent, selected_tool=None, @@ -70,4 +78,12 @@ def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): registry.add_callback(BeforeToolCallEvent, callback2) with pytest.raises(ValueError, match="interrupt_name= | interrupt name used more than once"): - registry.invoke_callbacks(event) + await registry.invoke_callbacks_async(event) + + +def test_hook_registry_invoke_callbacks_coroutine(registry, agent): + callback = unittest.mock.AsyncMock() + registry.add_callback(BeforeInvocationEvent, callback) + + with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"): + registry.invoke_callbacks(BeforeInvocationEvent(agent=agent)) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 57a8593cd..aafee1d17 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -192,6 +192,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)]) mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)]) mock_event_9 = unittest.mock.Mock() + mock_event_9.usage.prompt_tokens_details.cached_tokens = 10 + mock_event_9.usage.prompt_tokens_details.cache_creation_tokens = 10 litellm_acompletion.side_effect = unittest.mock.AsyncMock( return_value=agenerator( @@ -252,6 +254,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, { "metadata": { "usage": { + "cacheReadInputTokens": mock_event_9.usage.prompt_tokens_details.cached_tokens, + "cacheWriteInputTokens": mock_event_9.usage.prompt_tokens_details.cache_creation_tokens, "inputTokens": mock_event_9.usage.prompt_tokens, "outputTokens": mock_event_9.usage.completion_tokens, "totalTokens": mock_event_9.usage.total_tokens, @@ -402,3 +406,75 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model with pytest.raises(ContextWindowOverflowException): async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]): pass + + +@pytest.mark.asyncio +async def test_stream_raises_error_when_stream_is_false(model): + """Test that stream raises ValueError when stream parameter is explicitly False.""" + messages = [{"role": "user", "content": [{"text": "test"}]}] + + with pytest.raises(ValueError, match="stream parameter cannot be explicitly set to False"): + async for _ in model.stream(messages, stream=False): + pass + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant.", "cache_control": {"type": "ephemeral"}} + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_backward_compatibility_system_prompt(): + """Test that system_prompt is converted to system_prompt_content when system_prompt_content is None.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant." + + result = LiteLLMModel.format_request_messages(messages, system_prompt=system_prompt) + + expected = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_cache_point_support(): + """Test that cache points are properly applied to preceding content blocks.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [ + {"text": "First instruction."}, + {"text": "Second instruction."}, + {"cachePoint": {"type": "default"}}, + {"text": "Third instruction."}, + ] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "First instruction."}, + {"type": "text", "text": "Second instruction.", "cache_control": {"type": "ephemeral"}}, + {"type": "text", "text": "Third instruction."}, + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index cc30b7420..0de0c4ebc 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -944,3 +944,45 @@ async def test_structured_output_rate_limit_as_throttle(openai_client, model, me # Verify the exception message contains the original error assert "tokens per min" in str(exc_info.value) assert exc_info.value.__cause__ == mock_error + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_with_none_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + result = OpenAIModel.format_request_messages(messages) + + expected = [{"role": "user", "content": [{"text": "Hello", "type": "text"}]}] + + assert result == expected + + +def test_format_request_messages_drops_cache_points(): + """Test that cache points are dropped in OpenAI format_request_messages.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + # Cache points should be dropped, only text content included + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index e8a6a5f79..008b2954d 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1149,3 +1149,30 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): assert final_state["status"] == "completed" assert len(final_state["node_history"]) == 1 assert "test_agent" in final_state["node_results"] + + +@pytest.mark.asyncio +async def test_swarm_handle_handoff(): + first_agent = create_mock_agent("first") + second_agent = create_mock_agent("second") + + swarm = Swarm([first_agent, second_agent]) + + async def handoff_stream(*args, **kwargs): + yield {"agent_start": True} + + swarm._handle_handoff(swarm.nodes["second"], "test message", {}) + + assert swarm.state.current_node.node_id == "first" + assert swarm.state.handoff_node.node_id == "second" + + yield {"result": first_agent.return_value} + + first_agent.stream_async = Mock(side_effect=handoff_stream) + + result = await swarm.invoke_async("test") + assert result.status == Status.COMPLETED + + tru_node_order = [node.node_id for node in result.node_history] + exp_node_order = ["first", "second"] + assert tru_node_order == exp_node_order diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index ed0ec9072..451d0dd09 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -7,7 +7,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager -from strands.agent.interrupt import InterruptState +from strands.interrupt import _InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException @@ -131,7 +131,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert len(agent.messages) == 1 assert agent.messages[0]["role"] == "user" assert agent.messages[0]["content"][0]["text"] == "Hello" - assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False) + assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False) def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 05dbe387f..98cfb459f 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -163,11 +163,11 @@ def test_start_model_invoke_span(mock_tracer): assert span is not None -def test_start_model_invoke_span_latest_conventions(mock_tracer): +def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): """Test starting a model invoke span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -244,11 +244,11 @@ def test_end_model_invoke_span(mock_span): mock_span.end.assert_called_once() -def test_end_model_invoke_span_latest_conventions(mock_span): +def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): """Test ending a model invoke span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) @@ -307,11 +307,11 @@ def test_start_tool_call_span(mock_tracer): assert span is not None -def test_start_tool_call_span_latest_conventions(mock_tracer): +def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): """Test starting a tool call span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -396,11 +396,11 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): assert span is not None -def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer): +def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, monkeypatch): """Test starting a swarm call span with task as list of contentBlock with latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -439,10 +439,10 @@ def test_end_swarm_span(mock_span): ) -def test_end_swarm_span_latest_conventions(mock_span): +def test_end_swarm_span_latest_conventions(mock_span, monkeypatch): """Test ending a tool call span with latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True swarm_final_reuslt = "foo bar bar" tracer.end_swarm_span(mock_span, swarm_final_reuslt) @@ -503,10 +503,10 @@ def test_end_tool_call_span(mock_span): mock_span.end.assert_called_once() -def test_end_tool_call_span_latest_conventions(mock_span): +def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch): """Test ending a tool call span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tool_result = {"status": "success", "content": [{"text": "Tool result"}, {"json": {"foo": "bar"}}]} tracer.end_tool_call_span(mock_span, tool_result) @@ -558,11 +558,11 @@ def test_start_event_loop_cycle_span(mock_tracer): assert span is not None -def test_start_event_loop_cycle_span_latest_conventions(mock_tracer): +def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch): """Test starting an event loop cycle span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -609,10 +609,10 @@ def test_end_event_loop_cycle_span(mock_span): mock_span.end.assert_called_once() -def test_end_event_loop_cycle_span_latest_conventions(mock_span): +def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): """Test ending an event loop cycle span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} tool_result_message = { "role": "assistant", @@ -679,11 +679,11 @@ def test_start_agent_span(mock_tracer): assert span is not None -def test_start_agent_span_latest_conventions(mock_tracer): +def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): """Test starting an agent span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -749,10 +749,10 @@ def test_end_agent_span(mock_span): mock_span.end.assert_called_once() -def test_end_agent_span_latest_conventions(mock_span): +def test_end_agent_span_latest_conventions(mock_span, monkeypatch): """Test ending an agent span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True # Mock AgentResult with metrics mock_metrics = mock.MagicMock() @@ -1324,3 +1324,59 @@ def test_start_event_loop_cycle_span_with_tool_result_message(mock_tracer): "gen_ai.tool.message", attributes={"content": json.dumps(messages[0]["content"])} ) assert span is not None + + +def test_start_agent_span_does_not_include_tool_definitions_by_default(): + """Verify that start_agent_span does not include tool definitions by default.""" + tracer = Tracer() + tracer._start_span = mock.MagicMock() + + tools_config = { + "my_tool": { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {}}, + "outputSchema": {"json": {}}, + } + } + + tracer.start_agent_span(messages=[], agent_name="TestAgent", tools_config=tools_config) + + tracer._start_span.assert_called_once() + _, call_kwargs = tracer._start_span.call_args + attributes = call_kwargs.get("attributes", {}) + assert "gen_ai.tool.definitions" not in attributes + + +def test_start_agent_span_includes_tool_definitions_when_enabled(monkeypatch): + """Verify that start_agent_span includes tool definitions when enabled.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_tool_definitions") + tracer = Tracer() + tracer._start_span = mock.MagicMock() + + tools_config = { + "my_tool": { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + "outputSchema": {"json": {"type": "object", "properties": {}}}, + } + } + + tracer.start_agent_span(messages=[], agent_name="TestAgent", tools_config=tools_config) + + tracer._start_span.assert_called_once() + _, call_kwargs = tracer._start_span.call_args + attributes = call_kwargs.get("attributes", {}) + + assert "gen_ai.tool.definitions" in attributes + expected_tool_details = [ + { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + "outputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + expected_json = serialize(expected_tool_details) + assert attributes["gen_ai.tool.definitions"] == expected_json diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py index 8ce972103..a45d524e4 100644 --- a/tests/strands/test_interrupt.py +++ b/tests/strands/test_interrupt.py @@ -1,6 +1,6 @@ import pytest -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -22,3 +22,109 @@ def test_interrupt_to_dict(interrupt): "response": {"response": "test"}, } assert tru_dict == exp_dict + + +def test_interrupt_state_activate(): + interrupt_state = _InterruptState() + + interrupt_state.activate(context={"test": "context"}) + + assert interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {"test": "context"} + assert tru_context == exp_context + + +def test_interrupt_state_deactivate(): + interrupt_state = _InterruptState(context={"test": "context"}, activated=True) + + interrupt_state.deactivate() + + assert not interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {} + assert tru_context == exp_context + + +def test_interrupt_state_to_dict(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + + tru_data = interrupt_state.to_dict() + exp_data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + assert tru_data == exp_data + + +def test_interrupt_state_from_dict(): + data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + + tru_state = _InterruptState.from_dict(data) + exp_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + assert tru_state == exp_state + + +def test_interrupt_state_resume(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + activated=True, + ) + + prompt = [ + { + "interruptResponse": { + "interruptId": "test_id", + "response": "test response", + } + } + ] + interrupt_state.resume(prompt) + + tru_response = interrupt_state.interrupts["test_id"].response + exp_response = "test response" + assert tru_response == exp_response + + +def test_interrupt_state_resumse_deactivated(): + interrupt_state = _InterruptState(activated=False) + interrupt_state.resume([]) + + +def test_interrupt_state_resume_invalid_prompt(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"prompt_type= \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume("invalid") + + +def test_interrupt_state_resume_invalid_content(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"content_types=<\['text'\]> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume([{"text": "invalid"}]) + + +def test_interrupt_resume_invalid_id(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"interrupt_id= \| no interrupt found" + with pytest.raises(KeyError, match=exp_message): + interrupt_state.resume([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index d25cf14bd..4d299a539 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,8 +4,8 @@ import pytest import strands -from strands.agent.interrupt import InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry +from strands.interrupt import _InterruptState from strands.tools.registry import ToolRegistry from strands.types.tools import ToolContext @@ -104,7 +104,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() return mock_agent diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 442a9919b..81a2d9afb 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -1,3 +1,4 @@ +from datetime import timedelta from unittest.mock import MagicMock import pytest @@ -88,5 +89,31 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( - tool_use_id="test-123", name="test_tool", arguments={"param": "value"} + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None + ) + + +def test_timeout_initialization(mock_mcp_tool, mock_mcp_client): + timeout = timedelta(seconds=30) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + assert agent_tool.timeout == timeout + + +def test_timeout_default_none(mock_mcp_tool, mock_mcp_client): + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) + assert agent_tool.timeout is None + + +@pytest.mark.asyncio +async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist): + timeout = timedelta(seconds=45) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}} + + tru_events = await alist(agent_tool.stream(tool_use, {})) + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] + + assert tru_events == exp_events + mock_mcp_client.call_tool_async.assert_called_once_with( + tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout ) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 25f9bc39e..a2a4c6213 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,15 +3,15 @@ """ from asyncio import Queue -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union from unittest.mock import MagicMock import pytest +from pydantic import Field import strands from strands import Agent -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -151,7 +151,7 @@ async def test_stream_interrupt(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() invocation_state = {"agent": mock_agent} @@ -178,7 +178,7 @@ async def test_stream_interrupt_resume(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState(interrupts={interrupt.id: interrupt}) + mock_agent._interrupt_state = _InterruptState(interrupts={interrupt.id: interrupt}) invocation_state = {"agent": mock_agent} @@ -221,14 +221,7 @@ def test_tool(param1: str, param2: int) -> str: # Check basic spec properties assert spec["name"] == "test_tool" - assert ( - spec["description"] - == """Test tool function. - -Args: - param1: First parameter - param2: Second parameter""" - ) + assert spec["description"] == "Test tool function." # Check input schema schema = spec["inputSchema"]["json"] @@ -310,6 +303,174 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: exp_events = [ ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) ] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_docstring_description_extraction(): + """Test that docstring descriptions are extracted correctly, excluding Args section.""" + + @strands.tool + def tool_with_full_docstring(param1: str, param2: int) -> str: + """This is the main description. + + This is more description text. + + Args: + param1: First parameter + param2: Second parameter + + Returns: + A string result + + Raises: + ValueError: If something goes wrong + """ + return f"{param1} {param2}" + + spec = tool_with_full_docstring.tool_spec + assert ( + spec["description"] + == """This is the main description. + +This is more description text. + +Returns: + A string result + +Raises: + ValueError: If something goes wrong""" + ) + + +def test_docstring_args_variations(): + """Test that various Args section formats are properly excluded.""" + + @strands.tool + def tool_with_args(param: str) -> str: + """Main description. + + Args: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_arguments(param: str) -> str: + """Main description. + + Arguments: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_parameters(param: str) -> str: + """Main description. + + Parameters: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_params(param: str) -> str: + """Main description. + + Params: + param: Parameter description + """ + return param + + for tool in [tool_with_args, tool_with_arguments, tool_with_parameters, tool_with_params]: + spec = tool.tool_spec + assert spec["description"] == "Main description." + + +def test_docstring_no_args_section(): + """Test docstring extraction when there's no Args section.""" + + @strands.tool + def tool_no_args(param: str) -> str: + """This is the complete description. + + Returns: + A string result + """ + return param + + spec = tool_no_args.tool_spec + expected_desc = """This is the complete description. + +Returns: + A string result""" + assert spec["description"] == expected_desc + + +def test_docstring_only_args_section(): + """Test docstring extraction when there's only an Args section.""" + + @strands.tool + def tool_only_args(param: str) -> str: + """Args: + param: Parameter description + """ + return param + + spec = tool_only_args.tool_spec + # Should fall back to function name when no description remains + assert spec["description"] == "tool_only_args" + + +def test_docstring_empty(): + """Test docstring extraction when docstring is empty.""" + + @strands.tool + def tool_empty_docstring(param: str) -> str: + return param + + spec = tool_empty_docstring.tool_spec + # Should fall back to function name + assert spec["description"] == "tool_empty_docstring" + + +def test_docstring_preserves_other_sections(): + """Test that non-Args sections are preserved in the description.""" + + @strands.tool + def tool_multiple_sections(param: str) -> str: + """Main description here. + + Args: + param: This should be excluded + + Returns: + This should be included + + Raises: + ValueError: This should be included + + Examples: + This should be included + + Note: + This should be included + """ + return param + + spec = tool_multiple_sections.tool_spec + description = spec["description"] + + # Should include main description and other sections + assert "Main description here." in description + assert "Returns:" in description + assert "This should be included" in description + assert "Raises:" in description + assert "Examples:" in description + assert "Note:" in description + + # Should exclude Args section + assert "This should be excluded" not in description @pytest.mark.asyncio @@ -1450,3 +1611,214 @@ def test_function_tool_metadata_validate_signature_missing_context_config(): @strands.tool def my_tool(tool_context: ToolContext): pass + + +def test_tool_decorator_annotated_string_description(): + """Test tool decorator with Annotated type hints for descriptions.""" + + @strands.tool + def annotated_tool( + name: Annotated[str, "The user's full name"], + age: Annotated[int, "The user's age in years"], + city: str, # No annotation - should use docstring or generic + ) -> str: + """Tool with annotated parameters. + + Args: + city: The user's city (from docstring) + """ + return f"{name}, {age}, {city}" + + spec = annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check that annotated descriptions are used + assert schema["properties"]["name"]["description"] == "The user's full name" + assert schema["properties"]["age"]["description"] == "The user's age in years" + + # Check that docstring is still used for non-annotated params + assert schema["properties"]["city"]["description"] == "The user's city (from docstring)" + + # Verify all are required + assert set(schema["required"]) == {"name", "age", "city"} + + +def test_tool_decorator_annotated_pydantic_field_constraints(): + """Test that using pydantic.Field in Annotated raises a NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def field_annotated_tool( + email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\\.w+$")], + score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50, + ) -> str: + """Tool with Pydantic Field annotations.""" + return f"{email}: {score}" + + +def test_tool_decorator_annotated_overrides_docstring(): + """Test that Annotated descriptions override docstring descriptions.""" + + @strands.tool + def override_tool(param: Annotated[str, "Description from annotation"]) -> str: + """Tool with both annotation and docstring. + + Args: + param: Description from docstring (should be overridden) + """ + return param + + spec = override_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Annotated description should win + assert schema["properties"]["param"]["description"] == "Description from annotation" + + +def test_tool_decorator_annotated_optional_type(): + """Test tool with Optional types in Annotated.""" + + @strands.tool + def optional_annotated_tool( + required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + ) -> str: + """Tool with optional annotated parameter.""" + return f"{required}, {optional}" + + spec = optional_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["required"]["description"] == "Required parameter" + assert schema["properties"]["optional"]["description"] == "Optional parameter" + + # Check required list + assert "required" in schema["required"] + assert "optional" not in schema["required"] + + +def test_tool_decorator_annotated_complex_types(): + """Test tool with complex types in Annotated.""" + + @strands.tool + def complex_annotated_tool( + tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + ) -> str: + """Tool with complex annotated types.""" + return f"Tags: {len(tags)}, Config: {len(config)}" + + spec = complex_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["tags"]["description"] == "List of tag strings" + assert schema["properties"]["config"]["description"] == "Configuration dictionary" + + # Check types are preserved + assert schema["properties"]["tags"]["type"] == "array" + assert schema["properties"]["config"]["type"] == "object" + + +def test_tool_decorator_annotated_mixed_styles(): + """Test that using pydantic.Field in a mixed-style annotation raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def mixed_tool( + plain: str, + annotated_str: Annotated[str, "String description"], + annotated_field: Annotated[int, Field(description="Field description", ge=0)], + docstring_only: int, + ) -> str: + """Tool with mixed parameter styles. + + Args: + plain: Plain parameter description + docstring_only: Docstring description for this param + """ + return "mixed" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_execution(alist): + """Test that annotated tools execute correctly.""" + + @strands.tool + def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str: + """Test execution with annotations.""" + return f"Hello {name} " * count + + # Test tool use + tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}} + stream = execution_test.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "success" + assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"] + + # Test direct call + direct_result = execution_test("Bob", 3) + assert direct_result == "Hello Bob Hello Bob Hello Bob " + + +def test_tool_decorator_annotated_no_description_fallback(): + """Test that Annotated with a Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def no_desc_annotated( + param: Annotated[str, Field()], # Field without description + ) -> str: + """Tool with Annotated but no description. + + Args: + param: Docstring description + """ + return param + + +def test_tool_decorator_annotated_empty_string_description(): + """Test handling of empty string descriptions in Annotated.""" + + @strands.tool + def empty_desc_tool( + param: Annotated[str, ""], # Empty string description + ) -> str: + """Tool with empty annotation description. + + Args: + param: Docstring description + """ + return param + + spec = empty_desc_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Empty string is still a valid description, should not fall back + assert schema["properties"]["param"]["description"] == "" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_validation_error(alist): + """Test that validation works correctly with annotated parameters.""" + + @strands.tool + def validation_tool(age: Annotated[int, "User age"]) -> str: + """Tool for validation testing.""" + return f"Age: {age}" + + # Test with wrong type + tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}} + stream = validation_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "error" + + +def test_tool_decorator_annotated_field_with_inner_default(): + """Test that a default value in an Annotated Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str: + return f"{name} is at level {level}" diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py index ade0fa5e8..ad31384b6 100644 --- a/tests/strands/types/test_interrupt.py +++ b/tests/strands/types/test_interrupt.py @@ -2,8 +2,7 @@ import pytest -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt, InterruptException +from strands.interrupt import Interrupt, InterruptException, _InterruptState from strands.types.interrupt import _Interruptible @@ -20,7 +19,7 @@ def interrupt(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index 26d4062e4..3e5360742 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -3,8 +3,8 @@ from uuid import uuid4 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager -from strands.agent.interrupt import InterruptState from strands.agent.state import AgentState +from strands.interrupt import _InterruptState from strands.types.session import ( Session, SessionAgent, @@ -101,7 +101,7 @@ def test_session_agent_from_agent(): agent.agent_id = "a1" agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) agent.state = AgentState({"test": "state"}) - agent._interrupt_state = InterruptState(interrupts={}, context={}, activated=False) + agent._interrupt_state = _InterruptState(interrupts={}, context={}, activated=False) tru_session_agent = SessionAgent.from_agent(agent) exp_session_agent = SessionAgent( @@ -127,5 +127,5 @@ def test_session_agent_initialize_internal_state(): session_agent.initialize_internal_state(agent) tru_interrupt_state = agent._interrupt_state - exp_interrupt_state = InterruptState(interrupts={}, context={"test": "init"}, activated=False) + exp_interrupt_state = _InterruptState(interrupts={}, context={"test": "init"}, activated=False) assert tru_interrupt_state == exp_interrupt_state diff --git a/tests_integ/hooks/__init__.py b/tests_integ/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/hooks/multiagent/__init__.py b/tests_integ/hooks/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/hooks/multiagent/test_events.py b/tests_integ/hooks/multiagent/test_events.py new file mode 100644 index 000000000..e8039444f --- /dev/null +++ b/tests_integ/hooks/multiagent/test_events.py @@ -0,0 +1,122 @@ +import pytest + +from strands import Agent +from strands.experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import HookProvider +from strands.multiagent import GraphBuilder, Swarm + + +@pytest.fixture +def callback_names(): + return [] + + +@pytest.fixture +def hook_provider(callback_names): + class TestHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation) + registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation_async) + registry.add_callback(AfterNodeCallEvent, self.after_node_call) + registry.add_callback(AfterNodeCallEvent, self.after_node_call_async) + registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation) + registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation_async) + registry.add_callback(BeforeNodeCallEvent, self.before_node_call) + registry.add_callback(BeforeNodeCallEvent, self.before_node_call_async) + registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event) + registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event_async) + + def after_multi_agent_invocation(self, _event): + callback_names.append("after_multi_agent_invocation") + + async def after_multi_agent_invocation_async(self, _event): + callback_names.append("after_multi_agent_invocation_async") + + def after_node_call(self, _event): + callback_names.append("after_node_call") + + async def after_node_call_async(self, _event): + callback_names.append("after_node_call_async") + + def before_multi_agent_invocation(self, _event): + callback_names.append("before_multi_agent_invocation") + + async def before_multi_agent_invocation_async(self, _event): + callback_names.append("before_multi_agent_invocation_async") + + def before_node_call(self, _event): + callback_names.append("before_node_call") + + async def before_node_call_async(self, _event): + callback_names.append("before_node_call_async") + + def multi_agent_initialized_event(self, _event): + callback_names.append("multi_agent_initialized_event") + + async def multi_agent_initialized_event_async(self, _event): + callback_names.append("multi_agent_initialized_event_async") + + return TestHook() + + +@pytest.fixture +def agent(): + return Agent() + + +@pytest.fixture +def graph(agent, hook_provider): + builder = GraphBuilder() + builder.add_node(agent, "agent") + builder.set_entry_point("agent") + builder.set_hook_providers([hook_provider]) + return builder.build() + + +@pytest.fixture +def swarm(agent, hook_provider): + return Swarm([agent], hooks=[hook_provider]) + + +def test_graph_events(graph, callback_names): + graph("Hello") + + tru_callback_names = callback_names + exp_callback_names = [ + "multi_agent_initialized_event", + "multi_agent_initialized_event_async", + "before_multi_agent_invocation", + "before_multi_agent_invocation_async", + "before_node_call", + "before_node_call_async", + "after_node_call_async", + "after_node_call", + "after_multi_agent_invocation_async", + "after_multi_agent_invocation", + ] + assert tru_callback_names == exp_callback_names + + +def test_swarm_events(swarm, callback_names): + swarm("Hello") + + tru_callback_names = callback_names + exp_callback_names = [ + "multi_agent_initialized_event", + "multi_agent_initialized_event_async", + "before_multi_agent_invocation", + "before_multi_agent_invocation_async", + "before_node_call", + "before_node_call_async", + "after_node_call_async", + "after_node_call", + "after_multi_agent_invocation_async", + "after_multi_agent_invocation", + ] + assert tru_callback_names == exp_callback_names diff --git a/tests_integ/hooks/test_events.py b/tests_integ/hooks/test_events.py new file mode 100644 index 000000000..25971ecb0 --- /dev/null +++ b/tests_integ/hooks/test_events.py @@ -0,0 +1,138 @@ +import pytest + +from strands import Agent, tool +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookProvider, + MessageAddedEvent, +) + + +@pytest.fixture +def callback_names(): + return [] + + +@pytest.fixture +def hook_provider(callback_names): + class TestHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(AfterInvocationEvent, self.after_invocation) + registry.add_callback(AfterInvocationEvent, self.after_invocation_async) + registry.add_callback(AfterModelCallEvent, self.after_model_call) + registry.add_callback(AfterModelCallEvent, self.after_model_call_async) + registry.add_callback(AfterToolCallEvent, self.after_tool_call) + registry.add_callback(AfterToolCallEvent, self.after_tool_call_async) + registry.add_callback(AgentInitializedEvent, self.agent_initialized) + registry.add_callback(BeforeInvocationEvent, self.before_invocation) + registry.add_callback(BeforeInvocationEvent, self.before_invocation_async) + registry.add_callback(BeforeModelCallEvent, self.before_model_call) + registry.add_callback(BeforeModelCallEvent, self.before_model_call_async) + registry.add_callback(BeforeToolCallEvent, self.before_tool_call) + registry.add_callback(BeforeToolCallEvent, self.before_tool_call_async) + registry.add_callback(MessageAddedEvent, self.message_added) + registry.add_callback(MessageAddedEvent, self.message_added_async) + + def after_invocation(self, _event): + callback_names.append("after_invocation") + + async def after_invocation_async(self, _event): + callback_names.append("after_invocation_async") + + def after_model_call(self, _event): + callback_names.append("after_model_call") + + async def after_model_call_async(self, _event): + callback_names.append("after_model_call_async") + + def after_tool_call(self, _event): + callback_names.append("after_tool_call") + + async def after_tool_call_async(self, _event): + callback_names.append("after_tool_call_async") + + def agent_initialized(self, _event): + callback_names.append("agent_initialized") + + async def agent_initialized_async(self, _event): + callback_names.append("agent_initialized_async") + + def before_invocation(self, _event): + callback_names.append("before_invocation") + + async def before_invocation_async(self, _event): + callback_names.append("before_invocation_async") + + def before_model_call(self, _event): + callback_names.append("before_model_call") + + async def before_model_call_async(self, _event): + callback_names.append("before_model_call_async") + + def before_tool_call(self, _event): + callback_names.append("before_tool_call") + + async def before_tool_call_async(self, _event): + callback_names.append("before_tool_call_async") + + def message_added(self, _event): + callback_names.append("message_added") + + async def message_added_async(self, _event): + callback_names.append("message_added_async") + + return TestHook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def tool_() -> str: + return "12:00" + + return tool_ + + +@pytest.fixture +def agent(hook_provider, time_tool): + return Agent(hooks=[hook_provider], tools=[time_tool]) + + +def test_events(agent, callback_names): + agent("What time is it?") + + tru_callback_names = callback_names + exp_callback_names = [ + "agent_initialized", + "before_invocation", + "before_invocation_async", + "message_added", + "message_added_async", + "before_model_call", + "before_model_call_async", + "after_model_call_async", + "after_model_call", + "message_added", + "message_added_async", + "before_tool_call", + "before_tool_call_async", + "after_tool_call_async", + "after_tool_call", + "message_added", + "message_added_async", + "before_model_call", + "before_model_call_async", + "after_model_call_async", + "after_model_call", + "message_added", + "message_added_async", + "after_invocation_async", + "after_invocation", + ] + assert tru_callback_names == exp_callback_names diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 2c9bb73e1..35cfd7e86 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -420,3 +420,70 @@ def transport_callback() -> MCPTransport: result = await streamable_http_client.call_tool_async(tool_use_id="123", name="timeout_tool") assert result["status"] == "error" assert result["content"][0]["text"] == "Tool execution failed: Connection closed" + + +def start_5xx_proxy_for_tool_calls(target_url: str, proxy_port: int): + """Starts a proxy that throws a 5XX when a tool call is invoked""" + import aiohttp + from aiohttp import web + + async def proxy_handler(request): + url = f"{target_url}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + data = await request.read() + + if "tools/call" in f"{data}": + return web.Response(status=500, text="Internal Server Error") + + async with session.request( + method=request.method, url=url, headers=request.headers, data=data, allow_redirects=False + ) as resp: + print(f"Got request to {url} {data}") + response = web.StreamResponse(status=resp.status, headers=resp.headers) + await response.prepare(request) + + async for chunk in resp.content.iter_chunked(8192): + await response.write(chunk) + + return response + + app = web.Application() + app.router.add_route("*", "/{path:.*}", proxy_handler) + + web.run_app(app, host="127.0.0.1", port=proxy_port) + + +@pytest.mark.asyncio +async def test_streamable_http_mcp_client_with_500_error(): + import asyncio + import multiprocessing + + server_thread = threading.Thread( + target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + ) + server_thread.start() + + proxy_process = multiprocessing.Process( + target=start_5xx_proxy_for_tool_calls, kwargs={"target_url": "http://127.0.0.1:8001", "proxy_port": 8002} + ) + proxy_process.start() + + try: + await asyncio.sleep(2) # wait for server to startup completely + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url="http://127.0.0.1:8002/mcp") + + streamable_http_client = MCPClient(transport_callback) + with pytest.raises(RuntimeError, match="Connection to the MCP server was closed"): + with streamable_http_client: + result = await streamable_http_client.call_tool_async( + tool_use_id="123", name="calculator", arguments={"x": 3, "y": 4} + ) + finally: + proxy_process.terminate() + proxy_process.join() + + assert result["status"] == "error" + assert result["content"][0]["text"] == "Tool execution failed: Connection to the MCP server was closed" diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 62a95d06d..9a0d19dff 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -5,7 +5,10 @@ import strands from strands import Agent +from strands.agent import NullConversationManager from strands.models.anthropic import AnthropicModel +from strands.types.content import ContentBlock, Message +from strands.types.exceptions import ContextWindowOverflowException """ These tests only run if we have the anthropic api key @@ -152,3 +155,30 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): tru_color = agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +@pytest.mark.asyncio +def test_input_and_max_tokens_exceed_context_limit(): + """Test that triggers 'input length and max_tokens exceed context limit' error.""" + + # Note that this test is written specifically in a style that allows us to swap out conversation_manager and + # verify behavior + + model = AnthropicModel( + model_id="claude-sonnet-4-20250514", + max_tokens=64000, + ) + + large_message = "This is a very long text. " * 10000 + + messages = [ + Message(role="user", content=[ContentBlock(text=large_message)]), + Message(role="assistant", content=[ContentBlock(text=large_message)]), + Message(role="user", content=[ContentBlock(text=large_message)]), + ] + + # NullConversationManager will propagate ContextWindowOverflowException directly instead of handling it + agent = Agent(model=model, conversation_manager=NullConversationManager()) + + with pytest.raises(ContextWindowOverflowException): + agent(messages) diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index b348c29f4..f177c08a4 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -211,3 +211,25 @@ def test_structured_output_unsupported_model(model, nested_weather): # Verify that the tool method was called and schema method was not mock_tool.assert_called_once() mock_schema.assert_not_called() + + +@pytest.mark.asyncio +async def test_cache_read_tokens_multi_turn(model): + """Integration test for cache read tokens in multi-turn conversation.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + # Caching only works when prompts are large + {"text": "You are a helpful assistant. Always be concise." * 200}, + {"cachePoint": {"type": "default"}}, + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + + # First turn - establishes cache + agent("Hello, what's 2+2?") + result = agent("What's 3+3?") + result.metrics.accumulated_usage["cacheReadInputTokens"] + + assert result.metrics.accumulated_usage["cacheReadInputTokens"] > 0 + assert result.metrics.accumulated_usage["cacheWriteInputTokens"] > 0 diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 7beb3013c..feb591d1a 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -231,3 +231,29 @@ def test_content_blocks_handling(model): result = agent(content) assert "4" in result.message["content"][0]["text"] + + +def test_system_prompt_content_integration(model): + """Integration test for system_prompt_content parameter.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + {"text": "You are a helpful assistant that always responds with 'SYSTEM_TEST_RESPONSE'."} + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "SYSTEM_TEST_RESPONSE" in result.message["content"][0]["text"] + + +def test_system_prompt_backward_compatibility_integration(model): + """Integration test for backward compatibility with system_prompt parameter.""" + system_prompt = "You are a helpful assistant that always responds with 'BACKWARD_COMPAT_TEST'." + + agent = Agent(model=model, system_prompt=system_prompt) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"] diff --git a/tests_integ/tools/__init__.py b/tests_integ/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/tools/test_thread_context.py b/tests_integ/tools/test_thread_context.py new file mode 100644 index 000000000..b86c9b2c0 --- /dev/null +++ b/tests_integ/tools/test_thread_context.py @@ -0,0 +1,47 @@ +import contextvars + +import pytest + +from strands import Agent, tool + + +@pytest.fixture +def result(): + return {} + + +@pytest.fixture +def contextvar(): + return contextvars.ContextVar("agent") + + +@pytest.fixture +def context_tool(result, contextvar): + @tool(name="context_tool") + def tool_(): + result["context_value"] = contextvar.get("local_context") + + return tool_ + + +@pytest.fixture +def agent(context_tool): + return Agent(tools=[context_tool]) + + +def test_agent_invoke_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent("Execute context_tool") + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context + + +def test_tool_call_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent.tool.context_tool() + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context From c15ccf552c776fe86ea6ed268a26e5c235f9bf26 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 14:26:29 +0100 Subject: [PATCH 135/242] fix(bidi): fix tests --- .../bidi/models/test_gemini_live.py | 2 +- .../bidi/models/test_novasonic.py | 53 +++++++++---------- .../bidi/models/test_openai_realtime.py | 12 ++--- tests_integ/bidi/context.py | 2 +- 4 files changed, 32 insertions(+), 37 deletions(-) diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 8cf875598..48f8befc9 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -32,7 +32,7 @@ def mock_genai_client(): """Mock the Google GenAI client.""" with unittest.mock.patch( - "strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client" + "strands.experimental.bidi.models.gemini_live.genai.Client" ) as mock_client_cls: mock_client = mock_client_cls.return_value mock_client.aio = unittest.mock.MagicMock() diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 8e60b8fb5..e0459fd51 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -241,7 +241,7 @@ async def test_event_conversion(nova_model): # Audio is kept as base64 string assert result.get("audio") == audio_base64 assert result.get("format") == "pcm" - assert result.get("sample_rate") == 24000 + assert result.get("sample_rate") == 16000 # Test text output (now returns BidiTranscriptStreamEvent) nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} @@ -293,12 +293,34 @@ async def test_event_conversion(nova_model): assert result.get("outputTokens") == 60 # Test content start tracks role and emits BidiResponseStartEvent - nova_event = {"contentStart": {"role": "USER"}} + # TEXT type contentStart (matches API spec) + nova_event = { + "contentStart": { + "role": "ASSISTANT", + "type": "TEXT", + "additionalModelFields": '{"generationStage":"FINAL"}', + "contentId": "content-123", + } + } result = nova_model._convert_nova_event(nova_event) assert result is not None assert isinstance(result, BidiResponseStartEvent) assert result.get("type") == "bidi_response_start" - assert nova_model._current_role == "USER" + assert nova_model._current_role == "ASSISTANT" + assert nova_model._generation_stage == "FINAL" + + # Test AUDIO type contentStart (no additionalModelFields) + nova_event = {"contentStart": {"role": "ASSISTANT", "type": "AUDIO", "contentId": "content-456"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiResponseStartEvent) + assert nova_model._current_role == "ASSISTANT" + + # Test TOOL type contentStart + nova_event = {"contentStart": {"role": "TOOL", "type": "TOOL", "contentId": "content-789"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiResponseStartEvent) # Audio Streaming Tests @@ -323,31 +345,6 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): await nova_model.stop() -@pytest.mark.asyncio -async def test_silence_detection(nova_model, mock_client, mock_stream): - """Test that silence detection automatically ends audio input.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - nova_model.silence_threshold = 0.1 # Short threshold for testing - - await nova_model.start() - - # Send audio to start connection (base64 encoded) - audio_b64 = base64.b64encode(b"audio data").decode("utf-8") - audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) - - await nova_model.send(audio_event) - assert nova_model.audio_connection_active - - # Wait for silence detection - await asyncio.sleep(0.2) - - # Audio connection should be ended - assert not nova_model.audio_connection_active - - await nova_model.stop() - - # Helper Method Tests diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index badc52031..079516735 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -45,7 +45,7 @@ async def async_connect(*args, **kwargs): return mock_websocket with unittest.mock.patch( - "strands.experimental.bidirectional_streaming.models.openai.websockets.connect" + "strands.experimental.bidi.models.openai.websockets.connect" ) as mock_connect: mock_connect.side_effect = async_connect yield mock_connect, mock_websocket @@ -134,8 +134,6 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp assert model._active is True assert model.connection_id is not None assert model.websocket == mock_ws - assert model._event_queue is not None - assert model._response_task is not None mock_connect.assert_called_once() # Test close @@ -293,13 +291,13 @@ async def test_send_edge_cases(mock_websockets_connect, model): image=image_b64, mime_type="image/jpeg", ) - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + with unittest.mock.patch("strands.experimental.bidi.models.openai.logger") as mock_logger: await model.send(image_input) mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") # Test unknown content type unknown_content = {"unknown_field": "value"} - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + with unittest.mock.patch("strands.experimental.bidi.models.openai.logger") as mock_logger: await model.send(unknown_content) assert mock_logger.warning.called @@ -366,7 +364,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("text") == "Hello from OpenAI" assert converted[0].get("role") == "assistant" assert converted[0].delta == {"text": "Hello from OpenAI"} - assert converted[0].is_final is True + assert converted[0].is_final is False # Delta events are not final # Test function call sequence item_added = { @@ -480,7 +478,7 @@ def test_helper_methods(model): assert text_event.get("text") == "Hello" assert text_event.get("role") == "user" assert text_event.delta == {"text": "Hello"} - assert text_event.is_final is True + assert text_event.is_final is True # Done events are final assert text_event.current_transcript == "Hello" # Test _create_voice_activity_event (now returns BidiInterruptionEvent for speech_started) diff --git a/tests_integ/bidi/context.py b/tests_integ/bidi/context.py index adca6ee5b..830857564 100644 --- a/tests_integ/bidi/context.py +++ b/tests_integ/bidi/context.py @@ -84,7 +84,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): """Stop context manager, cleanup threads, and end agent session.""" # End agent session FIRST - this will cause receive() to exit cleanly - if self.agent._agent_loop and self.agent._agent_loop.active: + if self.agent._loop and self.agent._loop.active: await self.agent.stop() logger.debug("Agent session stopped") From 037d184eebffb9de37a3b86af93a452d00bc9d7a Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 14:29:22 +0100 Subject: [PATCH 136/242] formatter --- tests/strands/experimental/bidi/models/test_gemini_live.py | 4 +--- .../strands/experimental/bidi/models/test_openai_realtime.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 48f8befc9..6a2c79ece 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -31,9 +31,7 @@ @pytest.fixture def mock_genai_client(): """Mock the Google GenAI client.""" - with unittest.mock.patch( - "strands.experimental.bidi.models.gemini_live.genai.Client" - ) as mock_client_cls: + with unittest.mock.patch("strands.experimental.bidi.models.gemini_live.genai.Client") as mock_client_cls: mock_client = mock_client_cls.return_value mock_client.aio = unittest.mock.MagicMock() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 079516735..2ffcac7ae 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -44,9 +44,7 @@ def mock_websockets_connect(mock_websocket): async def async_connect(*args, **kwargs): return mock_websocket - with unittest.mock.patch( - "strands.experimental.bidi.models.openai.websockets.connect" - ) as mock_connect: + with unittest.mock.patch("strands.experimental.bidi.models.openai.websockets.connect") as mock_connect: mock_connect.side_effect = async_connect yield mock_connect, mock_websocket From e50fb62287c726ee8786acf3a1500e31b149d244 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 14:34:21 +0100 Subject: [PATCH 137/242] fix toolcaller rename --- src/strands/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 434c769c6..aa79a94be 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -286,7 +286,7 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) @property - def tool(self) -> ToolCaller: + def tool(self) -> _ToolCaller: """Call tool as a function. Returns: From 16e539b2dbf455e38befb014bc78c0b2d81ae687 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 15:32:50 +0100 Subject: [PATCH 138/242] fix import issues and event loop hanging --- src/strands/experimental/bidi/agent/agent.py | 1 + src/strands/experimental/bidi/agent/loop.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 4fe7756a1..e005b384d 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -20,6 +20,7 @@ from .... import _identifier from ....hooks import HookProvider, HookRegistry from ....tools.caller import _ToolCaller +from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index ee458ca82..e60f89214 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -88,13 +88,17 @@ async def stop(self) -> None: logger.debug("agent loop stopping") try: + # Cancel all tasks for task in self._tasks: task.cancel() + # Wait briefly for tasks to finish their current operations await asyncio.gather(*self._tasks, return_exceptions=True) + # Stop the model await self._agent.model.stop() + # Clean up the event queue if not self._event_queue.empty(): self._event_queue.get_nowait() self._event_queue.put_nowait(self._stop_event) From 444fc34d2dbc989a48657368902d7a4208fb1f50 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 16:04:10 +0100 Subject: [PATCH 139/242] remove interruptable --- src/strands/experimental/bidi/hooks/events.py | 24 ++----------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/src/strands/experimental/bidi/hooks/events.py b/src/strands/experimental/bidi/hooks/events.py index e5aa5b2bd..d4add3200 100644 --- a/src/strands/experimental/bidi/hooks/events.py +++ b/src/strands/experimental/bidi/hooks/events.py @@ -4,15 +4,11 @@ the lifecycle of a streaming session. """ -import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Optional -from typing_extensions import override - from ....hooks.registry import BaseHookEvent from ....types.content import Message -from ....types.interrupt import _Interruptible from ....types.tools import AgentTool, ToolResult, ToolUse if TYPE_CHECKING: @@ -96,7 +92,7 @@ class BidiMessageAddedEvent(BidiHookEvent): @dataclass -class BidiBeforeToolCallEvent(BidiHookEvent, _Interruptible): +class BidiBeforeToolCallEvent(BidiHookEvent): """Event triggered before BidiAgent executes a tool. This event is fired just before the BidiAgent executes a tool during a streaming @@ -109,30 +105,14 @@ class BidiBeforeToolCallEvent(BidiHookEvent, _Interruptible): to change which tool gets executed. This may be None if tool lookup failed. tool_use: The tool parameters that will be passed to selected_tool. invocation_state: Keyword arguments that will be passed to the tool. - cancel_tool: A user defined message that when set, will cancel the tool call. - The message will be placed into a tool result with an error status. If set to `True`, - Strands will cancel the tool call and use a default cancel message. """ selected_tool: Optional[AgentTool] tool_use: ToolUse invocation_state: dict[str, Any] - cancel_tool: bool | str = False def _can_write(self, name: str) -> bool: - return name in ["cancel_tool", "selected_tool", "tool_use"] - - @override - def _interrupt_id(self, name: str) -> str: - """Unique id for the interrupt. - - Args: - name: User defined name for the interrupt. - - Returns: - Interrupt id. - """ - return f"v1:bidi_before_tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + return name in ["selected_tool", "tool_use"] @dataclass From 23b03854d47d2c664e2d3467cd4a24ce33eee9ca Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 18 Nov 2025 16:14:57 +0100 Subject: [PATCH 140/242] update hooks to use async invocation --- src/strands/experimental/bidi/agent/agent.py | 2 +- src/strands/experimental/bidi/agent/loop.py | 35 +++++++++---------- .../bidi/hooks/test_bidi_hook_events.py | 2 -- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index e005b384d..dda0e895e 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -284,7 +284,7 @@ async def send(self, input_data: BidiAgentInput) -> None: user_message: Message = {"role": "user", "content": [{"text": input_data}]} self.messages.append(user_message) - self.hooks.invoke_callbacks(BidiMessageAddedEvent(agent=self, message=user_message)) + await self.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self, message=user_message)) logger.debug("text_length=<%d> | text sent to model", len(input_data)) # Create BidiTextInputEvent for send() diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index e60f89214..62136f2c1 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -68,7 +68,7 @@ async def start(self) -> None: self._tasks = set() # Emit before invocation event - self._agent.hooks.invoke_callbacks(BidiBeforeInvocationEvent(agent=self._agent)) + await self._agent.hooks.invoke_callbacks_async(BidiBeforeInvocationEvent(agent=self._agent)) await self._agent.model.start( system_prompt=self._agent.system_prompt, @@ -109,7 +109,7 @@ async def stop(self) -> None: self._event_queue = None finally: # Emit after invocation event (reverse order for cleanup) - self._agent.hooks.invoke_callbacks(BidiAfterInvocationEvent(agent=self._agent)) + await self._agent.hooks.invoke_callbacks_async(BidiAfterInvocationEvent(agent=self._agent)) async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive model and tool call events.""" @@ -149,7 +149,9 @@ async def _run_model(self) -> None: if event["is_final"]: message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} self._agent.messages.append(message) - self._agent.hooks.invoke_callbacks(BidiMessageAddedEvent(agent=self._agent, message=message)) + await self._agent.hooks.invoke_callbacks_async( + BidiMessageAddedEvent(agent=self._agent, message=message) + ) elif isinstance(event, ToolUseStreamEvent): tool_use = event["current_tool_use"] @@ -160,7 +162,7 @@ async def _run_model(self) -> None: elif isinstance(event, BidiInterruptionEvent): # Emit interruption hook event - self._agent.hooks.invoke_callbacks( + await self._agent.hooks.invoke_callbacks_async( BidiInterruptionHookEvent( agent=self._agent, reason=event["reason"], @@ -181,7 +183,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: tool = self._agent.tool_registry.registry[tool_use["name"]] # Emit before tool call event - self._agent.hooks.invoke_callbacks( + await self._agent.hooks.invoke_callbacks_async( BidiBeforeToolCallEvent( agent=self._agent, selected_tool=tool, @@ -204,20 +206,18 @@ async def _run_tool(self, tool_use: ToolUse) -> None: except Exception as e: result = {"toolUseId": tool_use["toolUseId"], "status": "error", "content": [{"text": f"Error: {str(e)}"}]} - finally: # Emit after tool call event (reverse order for cleanup) - if result: - self._agent.hooks.invoke_callbacks( - BidiAfterToolCallEvent( - agent=self._agent, - selected_tool=tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - ) + await self._agent.hooks.invoke_callbacks_async( + BidiAfterToolCallEvent( + agent=self._agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, ) + ) await self._agent.model.send(ToolResultEvent(result)) @@ -226,6 +226,5 @@ async def _run_tool(self, tool_use: ToolUse) -> None: "content": [{"toolResult": result}], } self._agent.messages.append(message) - self._agent.hooks.invoke_callbacks(BidiMessageAddedEvent(agent=self._agent, message=message)) + await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) await self._event_queue.put(ToolResultMessageEvent(message)) - diff --git a/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py b/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py index 70550ee56..bf3710066 100644 --- a/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py +++ b/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py @@ -138,7 +138,6 @@ def test_before_tool_call_event_can_write_properties(before_tool_event): new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) before_tool_event.selected_tool = None # Should not raise before_tool_event.tool_use = new_tool_use # Should not raise - before_tool_event.cancel_tool = True # Should not raise def test_before_tool_call_event_cannot_write_properties(before_tool_event): @@ -167,4 +166,3 @@ def test_after_tool_call_event_cannot_write_properties(after_tool_event): after_tool_event.invocation_state = {} with pytest.raises(AttributeError, match="Property exception is not writable"): after_tool_event.exception = Exception("test") - From ae3394d9785903b53e91bc2818dac9ebadf37286 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 07:38:28 -0800 Subject: [PATCH 141/242] fix mypy errors - agent, loop --- src/strands/experimental/bidi/agent/agent.py | 24 ++++---------------- src/strands/experimental/bidi/agent/loop.py | 20 ++++++++-------- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index ac015e43a..fe58ae806 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -15,7 +15,7 @@ import asyncio import json import logging -from typing import Any, AsyncIterable +from typing import Any, AsyncIterable, cast from .... import _identifier from ....tools.caller import _ToolCaller @@ -116,8 +116,6 @@ def __init__( # Initialize other components self._tool_caller = _ToolCaller(self) - self._current_adapters = [] # Track adapters for cleanup - self._loop = _BidiAgentLoop(self) @property @@ -344,10 +342,10 @@ async def __aenter__(self) -> "BidiAgent": await self.start() return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: """Async context manager exit point. - Automatically ends the connection and cleans up resources including adapters + Automatically ends the connection and cleans up resources including when exiting the context, regardless of whether an exception occurred. Args: @@ -356,19 +354,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: exc_tb: Exception traceback if an exception occurred, None otherwise. """ try: - logger.debug("context_manager= | cleaning up adapters and connection") - - # Cleanup adapters if any are currently active - for adapter in self._current_adapters: - if hasattr(adapter, "cleanup"): - try: - adapter.stop() - logger.debug("adapter_type=<%s> | adapter cleaned up", type(adapter).__name__) - except Exception as adapter_error: - logger.warning("adapter_error=<%s> | error cleaning up adapter", adapter_error) - - # Clear current adapters - self._current_adapters = [] + logger.debug("context_manager= | cleaning up connection") # Cleanup agent connection await self.stop() @@ -388,7 +374,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: @property def active(self) -> bool: """True if agent loop started, False otherwise.""" - return self._loop.active + return cast(bool, self._loop.active) async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: """Run the agent using provided IO channels for bidirectional communication. diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 5fac824c2..0423df9de 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -5,7 +5,7 @@ import asyncio import logging -from typing import TYPE_CHECKING, AsyncIterable, Awaitable +from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent from ....types.content import Message @@ -32,6 +32,7 @@ class _BidiAgentLoop: _event_queue: asyncio.Queue _stop_event: object _tasks: set + _active: bool def __init__(self, agent: "BidiAgent") -> None: """Initialize members of the agent loop. @@ -42,7 +43,7 @@ def __init__(self, agent: "BidiAgent") -> None: agent: Bidirectional agent to loop over. """ self._agent = agent - self._active = False + self._active: bool = False async def start(self) -> None: """Start the agent loop. @@ -87,9 +88,6 @@ async def stop(self) -> None: self._event_queue.put_nowait(self._stop_event) self._active = False - self._tasks = None - self._stop_event = None - self._event_queue = None async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive model and tool call events.""" @@ -110,7 +108,7 @@ def _create_task(self, coro: Awaitable[None]) -> None: Adds a clean up callback to run after task completes. """ - task = asyncio.create_task(coro) + task: asyncio.Task[None] = asyncio.create_task(coro) # type: ignore task.add_done_callback(lambda task: self._tasks.remove(task)) self._tasks.add(task) @@ -127,15 +125,15 @@ async def _run_model(self) -> None: if isinstance(event, BidiTranscriptStreamEvent): if event["is_final"]: - message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} - self._agent.messages.append(message) + transcript_message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} + self._agent.messages.append(transcript_message) elif isinstance(event, ToolUseStreamEvent): tool_use = event["current_tool_use"] self._create_task(self._run_tool(tool_use)) - message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - self._agent.messages.append(message) + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + self._agent.messages.append(tool_message) async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model.""" @@ -145,7 +143,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: try: tool = self._agent.tool_registry.registry[tool_use["name"]] - invocation_state = {} + invocation_state: dict[str, Any] = {} async for event in tool.stream(tool_use, invocation_state): if isinstance(event, ToolResultEvent): From f3f59788378f8a3e8d020f5f580b300e083db8a3 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 08:40:52 -0800 Subject: [PATCH 142/242] fix mypy errors in agent, loop --- src/strands/experimental/bidi/agent/agent.py | 8 ++++---- src/strands/experimental/bidi/agent/loop.py | 19 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index fe58ae806..f313616e6 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -15,7 +15,7 @@ import asyncio import json import logging -from typing import Any, AsyncIterable, cast +from typing import Any, AsyncIterable from .... import _identifier from ....tools.caller import _ToolCaller @@ -23,7 +23,7 @@ from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry from ....tools.watcher import ToolWatcher -from ....types.content import Message, Messages +from ....types.content import ContentBlock, Message, Messages from ....types.tools import AgentTool, ToolResult, ToolUse from ...tools import ToolProvider from ..models.bidi_model import BidiModel @@ -169,7 +169,7 @@ def _record_tool_execution( # Create user message describing the tool call input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") - user_msg_content = [ + user_msg_content: list[ContentBlock] = [ {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} ] @@ -374,7 +374,7 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseExc @property def active(self) -> bool: """True if agent loop started, False otherwise.""" - return cast(bool, self._loop.active) + return self._loop.active async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: """Run the agent using provided IO channels for bidirectional communication. diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 0423df9de..e0007915a 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -108,7 +108,7 @@ def _create_task(self, coro: Awaitable[None]) -> None: Adds a clean up callback to run after task completes. """ - task: asyncio.Task[None] = asyncio.create_task(coro) # type: ignore + task: asyncio.Task[None] = asyncio.create_task(coro) # type: ignore task.add_done_callback(lambda task: self._tasks.remove(task)) self._tasks.add(task) @@ -139,7 +139,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model.""" logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) - result: ToolResult = None + result: ToolResult | None = None try: tool = self._agent.tool_registry.registry[tool_use["name"]] @@ -159,11 +159,12 @@ async def _run_tool(self, tool_use: ToolUse) -> None: except Exception as e: result = {"toolUseId": tool_use["toolUseId"], "status": "error", "content": [{"text": f"Error: {str(e)}"}]} - await self._agent.model.send(ToolResultEvent(result)) + if result is not None: + await self._agent.model.send(ToolResultEvent(result)) - message: Message = { - "role": "user", - "content": [{"toolResult": result}], - } - self._agent.messages.append(message) - await self._event_queue.put(ToolResultMessageEvent(message)) + message: Message = { + "role": "user", + "content": [{"toolResult": result}], + } + self._agent.messages.append(message) + await self._event_queue.put(ToolResultMessageEvent(message)) From 5cfbe4d094aafc556956ac4c81acbbf1747af9be Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 08:43:22 -0800 Subject: [PATCH 143/242] fix mypy errors in agent, loop files --- src/strands/experimental/bidi/agent/agent.py | 2 +- src/strands/experimental/bidi/agent/loop.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index f313616e6..58ab035ad 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -283,7 +283,7 @@ async def send(self, input_data: BidiAgentInput) -> None: return # Handle plain dict - reconstruct TypedEvent for WebSocket integration - if isinstance(input_data, dict) and "type" in input_data: + if isinstance(input_data, dict) and "type" in input_data: # type: ignore event_type = input_data["type"] if event_type == "bidi_text_input": input_event = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index e0007915a..37f5fd0c2 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -120,7 +120,7 @@ async def _run_model(self) -> None: """ logger.debug("model task starting") - async for event in self._agent.model.receive(): + async for event in self._agent.model.receive(): # type: ignore await self._event_queue.put(event) if isinstance(event, BidiTranscriptStreamEvent): From a20a012e2c6a545a8c0d11a3b91f8fa0313b9fd1 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 08:47:19 -0800 Subject: [PATCH 144/242] fix model provider mypy errors - novasonic, openai --- src/strands/experimental/bidi/models/novasonic.py | 10 +++++----- src/strands/experimental/bidi/models/openai.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index eeaa1d659..5e7e85d6c 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -18,7 +18,7 @@ import logging import traceback import uuid -from typing import AsyncIterable +from typing import Any, AsyncIterable from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme @@ -195,7 +195,7 @@ async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events.""" await self._send_nova_event(events) - def _log_event_type(self, nova_event: dict[str, any]) -> None: + def _log_event_type(self, nova_event: dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) @@ -213,7 +213,7 @@ def _log_event_type(self, nova_event: dict[str, any]) -> None: audio_bytes = base64.b64decode(audio_content) logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) - async def receive(self) -> AsyncIterable[dict[str, any]]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -424,7 +424,7 @@ async def stop(self) -> None: finally: logger.debug("nova connection closed") - def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | None: + def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | None: """Convert Nova Sonic events to TypedEvent format.""" # Handle completion start - track completionId if "completionStart" in nova_event: @@ -611,7 +611,7 @@ def _get_text_input_event(self, content_name: str, text: str) -> str: {"event": {"textInput": {"promptName": self.connection_id, "contentName": content_name, "content": text}}} ) - def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: + def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> str: """Generate tool result event.""" return json.dumps( { diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 1564e5d68..855e4ba75 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -8,7 +8,7 @@ import logging import os import uuid -from typing import AsyncIterable +from typing import Any, AsyncIterable import websockets @@ -75,7 +75,7 @@ def __init__( api_key: str | None = None, organization: str | None = None, project: str | None = None, - session_config: dict[str, any] | None = None, + session_config: dict[str, Any] | None = None, **kwargs, ) -> None: """Initialize OpenAI Realtime bidirectional model. @@ -296,7 +296,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") self._active = False - def _convert_openai_event(self, openai_event: dict[str, any]) -> list[BidiOutputEvent] | None: + def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" event_type = openai_event.get("type") @@ -627,7 +627,7 @@ async def stop(self) -> None: logger.debug("openai realtime connection closed") - async def _send_event(self, event: dict[str, any]) -> None: + async def _send_event(self, event: dict[str, Any]) -> None: """Send event to OpenAI via WebSocket.""" try: message = json.dumps(event) From f7a8e3e8195146d2ce488c596b6bc79df14254fa Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 08:51:21 -0800 Subject: [PATCH 145/242] temporarily exclude scripts since scripts will be removed --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 97b9cf5f9..d7de4e226 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -223,6 +223,11 @@ warn_no_return = true warn_unreachable = true follow_untyped_imports = true ignore_missing_imports = false + +exclude = [ + "src/strands/experimental/bidi/scripts/.*", +] + # Ignore missing imports for optional bidi dependencies (not installed in lint environment) [[tool.mypy.overrides]] module = [ From 6d1ebd0629fbc4895630bfa17561a6c9fb3a830e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 10:00:33 -0800 Subject: [PATCH 146/242] fix mypy errors - audio,text, bidi_model --- src/strands/experimental/bidi/io/audio.py | 9 +++++---- src/strands/experimental/bidi/io/text.py | 8 ++++++++ src/strands/experimental/bidi/models/bidi_model.py | 4 ++-- src/strands/experimental/bidi/types/events.py | 2 +- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 9a798a537..0e62b293b 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -10,7 +10,7 @@ from collections import deque from typing import Any -import pyaudio +import pyaudio # type: ignore[import-untyped] from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent from ..types.io import BidiInput, BidiOutput @@ -156,9 +156,10 @@ async def stop(self) -> None: self._stream.close() self._audio.terminate() - self._output_task = None - self._buffer = None - self._buffer_event = None + # Adding type ignore to adhere to mypy + self._output_task = None # type: ignore[assignment] + self._buffer = None # type: ignore[assignment] + self._buffer_event = None # type: ignore[assignment] self._stream = None self._audio = None logger.debug("audio output stream stopped") diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 18b39819d..8ecbae149 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -11,6 +11,14 @@ class _BidiTextOutput(BidiOutput): """Handle text output from bidi agent.""" + async def start(self) -> None: + """Start text output.""" + pass + + async def stop(self) -> None: + """Stop text output.""" + pass + async def __call__(self, event: BidiOutputEvent) -> None: """Print text events to stdout.""" if isinstance(event, BidiInterruptionEvent): diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index e598498e1..04e4a69e4 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -13,7 +13,7 @@ """ import logging -from typing import AsyncIterable, Protocol +from typing import Any, AsyncIterable, Protocol from ....types._events import ToolResultEvent from ....types.content import Messages @@ -39,7 +39,7 @@ async def start( system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, - **kwargs, + **kwargs: Any, ) -> None: """Establish a persistent streaming connection with the model. diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 547919494..933759b15 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -84,7 +84,7 @@ class BidiAudioInputEvent(TypedEvent): def __init__( self, audio: str, - format: Literal["pcm", "wav", "opus", "mp3"], + format: Literal["pcm", "wav", "opus", "mp3"] | str, sample_rate: Literal[16000, 24000, 48000], channels: Literal[1, 2], ): From 2b30c5b651b4e686cf55105bfbca05f45858b499 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 10:44:12 -0800 Subject: [PATCH 147/242] fix mypy errors - novasonic --- src/strands/experimental/bidi/io/audio.py | 2 +- .../experimental/bidi/models/novasonic.py | 61 +++++++++++-------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 0e62b293b..f21bb7beb 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -156,7 +156,7 @@ async def stop(self) -> None: self._stream.close() self._audio.terminate() - # Adding type ignore to adhere to mypy + # Adding type ignore to adhere to mypy self._output_task = None # type: ignore[assignment] self._buffer = None # type: ignore[assignment] self._buffer_event = None # type: ignore[assignment] diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 5e7e85d6c..c2434c78d 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -88,7 +88,7 @@ class BidiNovaSonicModel(BidiModel): tool execution patterns while providing the standard BidiModel interface. """ - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **kwargs) -> None: + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **kwargs: Any) -> None: """Initialize Nova Sonic bidirectional model. Args: @@ -99,23 +99,23 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e # Model configuration self.model_id = model_id self.region = region - self.client = None + self.client: Any = None # Connection state (initialized in start()) - self.stream = None - self.connection_id = None + self.stream: Any = None + self.connection_id: str = "" self._active = False # Nova Sonic requires unique content names - self.audio_content_name = None + self.audio_content_name: str | None = None # Audio connection state self.audio_connection_active = False # Track API-provided identifiers - self._current_completion_id = None - self._current_role = None - self._generation_stage = None + self._current_completion_id: str | None = None + self._current_role: str | None = None + self._generation_stage: str | None = None # Ensure certain events are sent in sequence when required self._send_lock = asyncio.Lock() @@ -127,7 +127,7 @@ async def start( system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, - **kwargs, + **kwargs: Any, ) -> None: """Establish bidirectional connection to Nova Sonic. @@ -179,7 +179,7 @@ async def start( raise def _build_initialization_events( - self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None + self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None = None ) -> list[str]: """Build the sequence of initialization events.""" events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] @@ -213,7 +213,7 @@ def _log_event_type(self, nova_event: dict[str, Any]) -> None: audio_bytes = base64.b64decode(audio_content) logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore[override] """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -225,7 +225,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model_id) try: - while self._active: + while self._active and self.stream: try: output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) result = await output[1].receive() @@ -277,8 +277,6 @@ async def send( tool_result = content.get("tool_result") if tool_result: await self._send_tool_result(tool_result) - else: - logger.warning("content_type=<%s> | unknown content type", type(content).__name__) except Exception as e: logger.error("error=<%s> | error sending content to nova sonic", e) raise # Propagate exception for debugging in experimental code @@ -373,6 +371,9 @@ async def _send_interrupt(self) -> None: async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Nova Sonic toolResult format.""" tool_use_id = tool_result.get("toolUseId") + if not tool_use_id: + logger.error("tool result missing toolUseId") + return logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) @@ -414,10 +415,11 @@ async def stop(self) -> None: logger.warning("error=<%s> | error during nova sonic cleanup", e) # Close stream - try: - await self.stream.input_stream.close() - except Exception as e: - logger.warning("error=<%s> | error closing nova sonic stream", e) + if self.stream: + try: + await self.stream.input_stream.close() + except Exception as e: + logger.warning("error=<%s> | error closing nova sonic stream", e) except Exception as e: logger.error("error=<%s> | nova cleanup failed", str(e)) @@ -453,7 +455,10 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] return BidiAudioStreamEvent( - audio=audio_content, format="pcm", sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], channels=1 + audio=audio_content, + format="pcm", + sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], # type: ignore + channels=1, ) # Handle text output (transcripts) @@ -467,7 +472,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N return BidiTranscriptStreamEvent( delta={"text": text_content}, text=text_content, - role=self._current_role.lower() if self._current_role else "assistant", + role=self._current_role.lower() if self._current_role else "assistant", # type: ignore is_final=self._generation_stage == "FINAL", current_transcript=text_content, ) @@ -480,8 +485,8 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N "name": tool_use["toolName"], "input": json.loads(tool_use["content"]), } - # Return ToolUseStreamEvent for consistency with standard agent - return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=tool_use_event) + # Return ToolUseStreamEvent - cast to dict for type compatibility + return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) # type: ignore[return-value] # Handle interruption if nova_event.get("stopReason") == "INTERRUPTED": @@ -517,7 +522,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N ) # Ignore other events (contentEnd, etc.) - return + return None # Nova Sonic event template methods def _get_connection_start_event(self) -> str: @@ -526,7 +531,7 @@ def _get_connection_start_event(self) -> str: def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" - prompt_start_event = { + prompt_start_event: dict[str, Any] = { "event": { "promptStart": { "promptName": self.connection_id, @@ -543,9 +548,9 @@ def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: return json.dumps(prompt_start_event) - def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict[str, Any]]: """Build tool configuration from tool specs.""" - tool_config = [] + tool_config: list[dict[str, Any]] = [] for tool in tools: input_schema = ( {"json": json.dumps(tool["inputSchema"]["json"])} @@ -645,6 +650,10 @@ async def _send_nova_event(self, events: list[str]) -> None: Args: events: Jsonified event. """ + if not self.stream: + logger.error("cannot send event: stream is None") + return + try: async with self._send_lock: for event in events: From 61f61ee42be4ee2307250e89196fa1d3d7a89d38 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 11:26:35 -0800 Subject: [PATCH 148/242] fix mypy errors - openai, caller, events --- .../experimental/bidi/models/novasonic.py | 2 +- .../experimental/bidi/models/openai.py | 40 ++++++++++--------- src/strands/experimental/bidi/types/events.py | 3 +- src/strands/tools/caller.py | 2 +- 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index c2434c78d..8e83a4947 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -486,7 +486,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N "input": json.loads(tool_use["content"]), } # Return ToolUseStreamEvent - cast to dict for type compatibility - return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) # type: ignore[return-value] + return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) # Handle interruption if nova_event.get("stopReason") == "INTERRUPTED": diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 855e4ba75..3cda4f738 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -11,6 +11,7 @@ from typing import Any, AsyncIterable import websockets +from websockets import ClientConnection from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages @@ -76,7 +77,7 @@ def __init__( organization: str | None = None, project: str | None = None, session_config: dict[str, Any] | None = None, - **kwargs, + **kwargs: Any, ) -> None: """Initialize OpenAI Realtime bidirectional model. @@ -103,11 +104,11 @@ def __init__( ) # Connection state (initialized in start()) - self.websocket = None - self.connection_id = None - self._active = False + self.websocket: ClientConnection + self.connection_id: str + self._active: bool = False - self._function_call_buffer = {} + self._function_call_buffer: dict[str, Any] = {} logger.debug("model=<%s> | openai realtime model initialized", model) @@ -116,7 +117,7 @@ async def start( system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, - **kwargs, + **kwargs: Any, ) -> None: """Establish bidirectional connection to OpenAI Realtime API. @@ -182,7 +183,7 @@ def _create_text_event(self, text: str, role: str, is_final: bool = True) -> Bid return BidiTranscriptStreamEvent( delta={"text": text}, text=text, - role=normalized_role, + role=normalized_role, # type: ignore is_final=is_final, current_transcript=text if is_final else None, ) @@ -203,7 +204,7 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] config["instructions"] = system_prompt if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) + config["tools"] = self._convert_tools_to_openai_format(tools) # type: ignore # Apply user-provided session configuration supported_params = { @@ -255,7 +256,7 @@ def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: async def _add_conversation_history(self, messages: Messages) -> None: """Add conversation history to the session.""" for message in messages: - conversation_item = { + conversation_item: dict[Any, Any] = { "type": "conversation.item.create", "item": {"type": "message", "role": message["role"], "content": []}, } @@ -272,7 +273,7 @@ async def _add_conversation_history(self, messages: Messages) -> None: await self._send_event(conversation_item) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore """Receive OpenAI events and convert to Strands TypedEvent format.""" # Emit connection start event yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model) @@ -281,7 +282,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: while self._active: async for message in self.websocket: if not self._active: - break + break # type: ignore openai_event = json.loads(message) @@ -311,7 +312,10 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput # Audio is already base64 string from OpenAI return [ BidiAudioStreamEvent( - audio=openai_event["delta"], format="pcm", sample_rate=AUDIO_FORMAT["rate"], channels=1 + audio=openai_event["delta"], + format="pcm", + sample_rate=AUDIO_FORMAT["rate"], # type: ignore + channels=1, ) ] @@ -385,7 +389,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput } del self._function_call_buffer[call_id] # Return ToolUseStreamEvent for consistency with standard agent - return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=tool_use)] + return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=dict(tool_use))] except (json.JSONDecodeError, KeyError) as e: logger.warning("call_id=<%s>, error=<%s> | error parsing function arguments", call_id, e) del self._function_call_buffer[call_id] @@ -419,11 +423,11 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput } # Build list of events to return - events = [] + events: list[Any] = [] # Always add response complete event events.append( - BidiResponseCompleteEvent(response_id=response_id, stop_reason=stop_reason_map.get(status, "complete")) + BidiResponseCompleteEvent(response_id=response_id, stop_reason=stop_reason_map.get(status, "complete")) # type: ignore ) # Add usage event if available @@ -464,7 +468,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput input_tokens=usage.get("input_tokens", 0), output_tokens=usage.get("output_tokens", 0), total_tokens=usage.get("total_tokens", 0), - modality_details=modality_details if modality_details else None, + modality_details=modality_details if modality_details else None, # type: ignore cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, ) ) @@ -570,8 +574,6 @@ async def send( tool_result = content.get("tool_result") if tool_result: await self._send_tool_result(tool_result) - else: - logger.warning("content_type=<%s> | unknown content type", type(content).__name__) except Exception as e: logger.error("error=<%s> | error sending content to openai", e) raise # Propagate exception for debugging in experimental code @@ -598,7 +600,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) # Extract result content - result_data = {} + result_data: dict[Any, Any] | str = {} if "content" in tool_result: # Extract text from content blocks for block in tool_result["content"]: diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 933759b15..8e6113ea3 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -21,7 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, cast -from ....types._events import ModelStreamEvent, TypedEvent +from ....types._events import ModelStreamEvent, ToolUseStreamEvent, TypedEvent from ....types.streaming import ContentBlockDelta # Audio format constants @@ -561,4 +561,5 @@ def details(self) -> Optional[Dict[str, Any]]: | BidiUsageEvent | BidiConnectionCloseEvent | BidiErrorEvent + | ToolUseStreamEvent ) diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py index 4663b662f..68357f266 100644 --- a/src/strands/tools/caller.py +++ b/src/strands/tools/caller.py @@ -113,7 +113,7 @@ def _find_normalized_tool_name(self, name: str) -> str: # Registry defends against similar names, so take first match if filtered_tools: - return filtered_tools[0] + return filtered_tools[0] # type: ignore raise AttributeError(f"Tool '{name}' not found") From f2207115f3322f7a12511cbf0ef1297f9c928915 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 18 Nov 2025 13:13:55 -0800 Subject: [PATCH 149/242] fix mypy errors - gemini_live --- .../experimental/bidi/models/gemini_live.py | 56 ++++++++++--------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index e4d83ab2e..5f7eb587f 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -15,11 +15,11 @@ import base64 import logging import uuid -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import Any, AsyncIterable, Dict, List, Optional, cast from google import genai from google.genai import types as genai_types -from google.genai.types import LiveServerMessage +from google.genai.types import LiveConnectConfigOrDict, LiveServerMessage from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages @@ -59,9 +59,9 @@ class BidiGeminiLiveModel(BidiModel): def __init__( self, model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", - api_key: Optional[str] = None, - live_config: Optional[Dict[str, Any]] = None, - **kwargs, + api_key: str | None = None, + live_config: Dict[str, Any] | None = None, + **kwargs: Any, ): """Initialize Gemini Live API bidirectional model. @@ -89,7 +89,7 @@ def __init__( self.live_config = default_config # Create Gemini client with proper API version - client_kwargs = {} + client_kwargs: dict[str, Any] = {} if api_key: client_kwargs["api_key"] = api_key @@ -99,17 +99,17 @@ def __init__( self.client = genai.Client(**client_kwargs) # Connection state (initialized in start()) - self.live_session = None + self.live_session: Any self.live_session_context_manager = None - self.connection_id = None - self._active = False + self.connection_id | str + self._active: bool = False async def start( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs, + **kwargs: Any, ) -> None: """Establish bidirectional connection with Gemini Live API. @@ -131,7 +131,9 @@ async def start( live_config = self._build_live_config(system_prompt, tools, **kwargs) # Create the context manager - self.live_session_context_manager = self.client.aio.live.connect(model=self.model_id, config=live_config) + self.live_session_context_manager = self.client.aio.live.connect( + model=self.model_id, config=cast(LiveConnectConfigOrDict, live_config) + ) # Enter the context manager self.live_session = await self.live_session_context_manager.__aenter__() @@ -167,9 +169,10 @@ async def _send_message_history(self, messages: Messages) -> None: # "assistant" role from Messages format maps to "model" in Gemini role = "model" if message["role"] == "assistant" else message["role"] content = genai_types.Content(role=role, parts=content_parts) - await self.live_session.send_client_content(turns=content) + if self.live_session: + await self.live_session.send_client_content(turns=content) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: + async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore """Receive Gemini Live API events and convert to provider-agnostic format.""" # Emit connection start event yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model_id) @@ -180,7 +183,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: try: async for message in self.live_session.receive(): if not self._active: - break + raise ValueError("connection is not active") # Convert to provider-agnostic format (always returns list) for event in self._convert_gemini_live_event(message): @@ -231,7 +234,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, - role=role.lower() if isinstance(role, str) else "user", + role=role.lower() if isinstance(role, str) else "user", # type: ignore is_final=True, current_transcript=transcription_text, ) @@ -249,7 +252,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut BidiTranscriptStreamEvent( delta={"text": transcription_text}, text=transcription_text, - role=role.lower() if isinstance(role, str) else "assistant", + role=role.lower() if isinstance(role, str) else "assistant", # type: ignore is_final=True, current_transcript=transcription_text, ) @@ -262,7 +265,10 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut audio_b64 = base64.b64encode(message.data).decode("utf-8") return [ BidiAudioStreamEvent( - audio=audio_b64, format="pcm", sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, channels=GEMINI_CHANNELS + audio=audio_b64, + format="pcm", + sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, # type: ignore + channels=GEMINI_CHANNELS, # type: ignore ) ] @@ -294,15 +300,15 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut tool_events = [] for func_call in message.tool_call.function_calls: tool_use_event: ToolUse = { - "toolUseId": func_call.id, - "name": func_call.name, + "toolUseId": func_call.id, # type: ignore + "name": func_call.name, # type: ignore "input": func_call.args or {}, } # Create ToolUseStreamEvent for consistency with standard agent tool_events.append( - ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=tool_use_event) + ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) ) - return tool_events + return tool_events # type: ignore # Handle usage metadata if hasattr(message, "usage_metadata") and message.usage_metadata: @@ -342,7 +348,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut input_tokens=usage.prompt_token_count or 0, output_tokens=usage.response_token_count or 0, total_tokens=usage.total_token_count or 0, - modality_details=modality_details if modality_details else None, + modality_details=modality_details if modality_details else None, # type: ignore cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None, @@ -386,8 +392,6 @@ async def send( tool_result = content.get("tool_result") if tool_result: await self._send_tool_result(tool_result) - else: - logger.warning("content_type=<%s> | unknown content type", type(content).__name__) except Exception as e: logger.error("error=<%s> | error sending content to gemini live", e) raise # Propagate exception for debugging in experimental code @@ -481,7 +485,7 @@ async def stop(self) -> None: raise def _build_live_config( - self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, **kwargs + self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, **kwargs: Any ) -> Dict[str, Any]: """Build LiveConnectConfig for the official SDK. @@ -489,7 +493,7 @@ def _build_live_config( to configure any Gemini Live API parameter directly. """ # Start with user-provided live_config - config_dict = {} + config_dict: dict[str, Any] = {} if self.live_config: config_dict.update(self.live_config) From 1a7732e3ee8023396a41509f351886d043566145 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 19 Nov 2025 11:29:26 +0100 Subject: [PATCH 150/242] feat(bidi): Implement session manager --- src/strands/experimental/bidi/agent/agent.py | 13 +- .../session/repository_session_manager.py | 85 ++++++++++ src/strands/session/session_manager.py | 55 +++++++ src/strands/types/session.py | 37 +++++ .../test_repository_session_manager.py | 150 ++++++++++++++++++ 5 files changed, 339 insertions(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index a922e99d7..caa22a126 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -15,9 +15,12 @@ import asyncio import json import logging -from typing import Any, AsyncIterable +from typing import TYPE_CHECKING, Any, AsyncIterable from .... import _identifier + +if TYPE_CHECKING: + from ....session.session_manager import SessionManager from ....hooks import HookProvider, HookRegistry from ....tools.caller import _ToolCaller from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent @@ -63,6 +66,7 @@ def __init__( description: str | None = None, hooks: list[HookProvider] | None = None, state: AgentState | dict | None = None, + session_manager: "SessionManager | None" = None, **kwargs: Any, ): """Initialize bidirectional agent. @@ -80,6 +84,8 @@ def __init__( description: Description of what the Agent does. hooks: Optional list of hook providers to register for lifecycle events. state: Stateful information for the agent. Can be either an AgentState object, or a json serializable dict. + session_manager: Manager for handling agent sessions including conversation history and state. + If provided, enables session-based persistence and state management. **kwargs: Additional configuration for future extensibility. Raises: @@ -142,6 +148,11 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) + # Initialize session management functionality + self._session_manager = session_manager + if self._session_manager: + self.hooks.add_hook(self._session_manager) + self._loop = _BidiAgentLoop(self) # Emit initialization event diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index a042452d3..ad4733a35 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..experimental.bidi.agent.agent import BidiAgent from ..multiagent.base import MultiAgentBase logger = logging.getLogger(__name__) @@ -226,3 +227,87 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non else: logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id) source.deserialize_state(state) + + def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Initialize a bidirectional agent with a session. + + Args: + agent: BidiAgent to initialize from the session + **kwargs: Additional keyword arguments for future extensibility. + """ + if agent.agent_id in self._latest_agent_message: + raise SessionException("The `agent_id` of an agent must be unique in a session.") + self._latest_agent_message[agent.agent_id] = None + + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + + if session_agent is None: + logger.debug( + "agent_id=<%s> | session_id=<%s> | creating bidi agent", + agent.agent_id, + self.session_id, + ) + + session_agent = SessionAgent.from_bidi_agent(agent) + self.session_repository.create_agent(self.session_id, session_agent) + # Initialize messages with sequential indices + session_message = None + for i, message in enumerate(agent.messages): + session_message = SessionMessage.from_message(message, i) + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + self._latest_agent_message[agent.agent_id] = session_message + else: + logger.debug( + "agent_id=<%s> | session_id=<%s> | restoring bidi agent", + agent.agent_id, + self.session_id, + ) + agent.state = AgentState(session_agent.state) + + session_agent.initialize_bidi_internal_state(agent) + + # BidiAgent has no conversation_manager, so no prepend_messages or removed_message_count + session_messages = self.session_repository.list_messages( + session_id=self.session_id, + agent_id=agent.agent_id, + offset=0, + ) + if len(session_messages) > 0: + self._latest_agent_message[agent.agent_id] = session_messages[-1] + + # Restore the agents messages array + agent.messages = [session_message.to_message() for session_message in session_messages] + + # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 + agent.messages = self._fix_broken_tool_use(agent.messages) + + def append_bidi_message(self, message: Message, agent: "BidiAgent", **kwargs: Any) -> None: + """Append a message to the bidirectional agent's session. + + Args: + message: Message to add to the agent in the session + agent: BidiAgent to append the message to + **kwargs: Additional keyword arguments for future extensibility. + """ + # Calculate the next index (0 if this is the first message, otherwise increment the previous index) + latest_agent_message = self._latest_agent_message[agent.agent_id] + if latest_agent_message: + next_index = latest_agent_message.message_id + 1 + else: + next_index = 0 + + session_message = SessionMessage.from_message(message, next_index) + self._latest_agent_message[agent.agent_id] = session_message + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + + def sync_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Serialize and update the bidirectional agent into the session repository. + + Args: + agent: BidiAgent to sync to the session. + **kwargs: Additional keyword arguments for future extensibility. + """ + self.session_repository.update_agent( + self.session_id, + SessionAgent.from_bidi_agent(agent), + ) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index fb9132828..4a9d32f31 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -47,6 +47,21 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) + # Register BidiAgent hooks if the experimental module is available + try: + from ..experimental.bidi.hooks.events import ( + BidiAfterInvocationEvent, + BidiAgentInitializedEvent, + BidiMessageAddedEvent, + ) + + registry.add_callback(BidiAgentInitializedEvent, lambda event: self.initialize_bidi_agent(event.agent)) + registry.add_callback(BidiMessageAddedEvent, lambda event: self.append_bidi_message(event.message, event.agent)) + registry.add_callback(BidiMessageAddedEvent, lambda event: self.sync_bidi_agent(event.agent)) + registry.add_callback(BidiAfterInvocationEvent, lambda event: self.sync_bidi_agent(event.agent)) + except ImportError: + pass + @abstractmethod def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: """Redact the message most recently appended to the agent in the session. @@ -114,3 +129,43 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non "(initialize_multi_agent). Provide an implementation or use a " "SessionManager with session_type=SessionType.MULTI_AGENT." ) + + def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Initialize a bidirectional agent with a session. + + Args: + agent: BidiAgent to initialize + **kwargs: Additional keyword arguments for future extensibility. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support bidirectional agent persistence " + "(initialize_bidi_agent). Provide an implementation or use a " + "SessionManager with bidirectional agent support." + ) + + def append_bidi_message(self, message: Message, agent: "BidiAgent", **kwargs: Any) -> None: + """Append a message to the bidirectional agent's session. + + Args: + message: Message to add to the agent in the session + agent: BidiAgent to append the message to + **kwargs: Additional keyword arguments for future extensibility. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support bidirectional agent persistence " + "(append_bidi_message). Provide an implementation or use a " + "SessionManager with bidirectional agent support." + ) + + def sync_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Serialize and sync the bidirectional agent with the session storage. + + Args: + agent: BidiAgent who should be synchronized with the session storage + **kwargs: Additional keyword arguments for future extensibility. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support bidirectional agent persistence " + "(sync_bidi_agent). Provide an implementation or use a " + "SessionManager with bidirectional agent support." + ) diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 8b78ab448..c13d84df3 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..experimental.bidi.agent.agent import BidiAgent class SessionType(str, Enum): @@ -136,6 +137,31 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": }, ) + @classmethod + def from_bidi_agent(cls, agent: "BidiAgent") -> "SessionAgent": + """Convert a BidiAgent to a SessionAgent. + + Args: + agent: BidiAgent to convert + + Returns: + SessionAgent with empty conversation_manager_state (BidiAgent doesn't use conversation manager) + """ + if agent.agent_id is None: + raise ValueError("agent_id needs to be defined.") + + # BidiAgent doesn't have _interrupt_state yet, so we use empty dict for internal state + internal_state = {} + if hasattr(agent, "_interrupt_state"): + internal_state["interrupt_state"] = agent._interrupt_state.to_dict() + + return cls( + agent_id=agent.agent_id, + conversation_manager_state={}, # BidiAgent has no conversation_manager + state=agent.state.get(), + _internal_state=internal_state, + ) + @classmethod def from_dict(cls, env: dict[str, Any]) -> "SessionAgent": """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters.""" @@ -150,6 +176,17 @@ def initialize_internal_state(self, agent: "Agent") -> None: if "interrupt_state" in self._internal_state: agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) + def initialize_bidi_internal_state(self, agent: "BidiAgent") -> None: + """Initialize internal state of BidiAgent. + + Args: + agent: BidiAgent to initialize internal state for + """ + # BidiAgent doesn't have _interrupt_state yet, so we skip interrupt state restoration + # When BidiAgent adds _interrupt_state support, this will automatically work + if "interrupt_state" in self._internal_state and hasattr(agent, "_interrupt_state"): + agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) + @dataclass class Session: diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 451d0dd09..b4b861857 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -413,3 +413,153 @@ def test_fix_broken_tool_use_does_not_change_valid_message(session_manager): # Should remain unchanged since toolUse is in last message assert fixed_messages == messages + + +# ============================================================================ +# BidiAgent Session Tests +# ============================================================================ + + +@pytest.fixture +def mock_bidi_agent(): + """Create a mock BidiAgent for testing.""" + from unittest.mock import Mock + from strands.agent.state import AgentState + + agent = Mock() + agent.agent_id = "bidi-agent-1" + agent.messages = [{"role": "user", "content": [{"text": "Hello from bidi!"}]}] + agent.state = AgentState({"key": "value"}) + # BidiAgent doesn't have _interrupt_state yet + return agent + + +def test_initialize_bidi_agent_creates_new(session_manager, mock_bidi_agent): + """Test initializing a new BidiAgent creates session data.""" + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify agent created in repository + agent_data = session_manager.session_repository.read_agent("test-session", "bidi-agent-1") + assert agent_data is not None + assert agent_data.agent_id == "bidi-agent-1" + assert agent_data.conversation_manager_state == {} # Empty for BidiAgent + assert agent_data.state == {"key": "value"} + + # Verify message created + messages = session_manager.session_repository.list_messages("test-session", "bidi-agent-1") + assert len(messages) == 1 + assert messages[0].message["role"] == "user" + + +def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agent): + """Test initializing BidiAgent restores from existing session.""" + from strands.types.session import SessionAgent, SessionMessage + + # Create existing session data + session_agent = SessionAgent( + agent_id="bidi-agent-1", + state={"restored": "state"}, + conversation_manager_state={}, # Empty for BidiAgent + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Add messages + msg1 = SessionMessage.from_message({"role": "user", "content": [{"text": "Message 1"}]}, 0) + msg2 = SessionMessage.from_message({"role": "assistant", "content": [{"text": "Response 1"}]}, 1) + session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) + session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) + + # Initialize agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify state restored + assert mock_bidi_agent.state.get() == {"restored": "state"} + + # Verify messages restored + assert len(mock_bidi_agent.messages) == 2 + assert mock_bidi_agent.messages[0]["role"] == "user" + assert mock_bidi_agent.messages[1]["role"] == "assistant" + + +def test_append_bidi_message(session_manager, mock_bidi_agent): + """Test appending messages to BidiAgent session.""" + # Initialize agent first + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Append new message + new_message = {"role": "assistant", "content": [{"text": "Response"}]} + session_manager.append_bidi_message(new_message, mock_bidi_agent) + + # Verify message stored + messages = session_manager.session_repository.list_messages("test-session", "bidi-agent-1") + assert len(messages) == 2 # Initial + new + assert messages[1].message["role"] == "assistant" + + +def test_sync_bidi_agent(session_manager, mock_bidi_agent): + """Test syncing BidiAgent state to session.""" + from strands.agent.state import AgentState + + # Initialize agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Update agent state + mock_bidi_agent.state = AgentState({"updated": "state"}) + + # Sync agent + session_manager.sync_bidi_agent(mock_bidi_agent) + + # Verify state updated in repository + agent_data = session_manager.session_repository.read_agent("test-session", "bidi-agent-1") + assert agent_data.state == {"updated": "state"} + + +def test_bidi_agent_no_conversation_manager(session_manager, mock_bidi_agent): + """Test that BidiAgent session doesn't use conversation_manager.""" + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify conversation_manager_state is empty + agent_data = session_manager.session_repository.read_agent("test-session", "bidi-agent-1") + assert agent_data.conversation_manager_state == {} + + +def test_bidi_agent_unique_id_constraint(session_manager, mock_bidi_agent): + """Test that BidiAgent agent_id must be unique in session.""" + # Initialize first agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Try to initialize another agent with same ID + from unittest.mock import Mock + from strands.agent.state import AgentState + + agent2 = Mock() + agent2.agent_id = "bidi-agent-1" # Same ID + agent2.messages = [] + agent2.state = AgentState({}) + + with pytest.raises(SessionException, match="The `agent_id` of an agent must be unique in a session."): + session_manager.initialize_bidi_agent(agent2) + + +def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): + """Test that BidiAgent uses offset=0 for message restoration (no conversation_manager).""" + from strands.types.session import SessionAgent, SessionMessage + + # Create session with messages + session_agent = SessionAgent( + agent_id="bidi-agent-1", + state={}, + conversation_manager_state={}, + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Add 5 messages + for i in range(5): + msg = SessionMessage.from_message({"role": "user", "content": [{"text": f"Message {i}"}]}, i) + session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) + + # Initialize agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify all messages restored (offset=0, no removed_message_count) + assert len(mock_bidi_agent.messages) == 5 From 610a2ddea8ef12e3efeceb9dd15c7b6368e79161 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 19 Nov 2025 11:41:59 +0100 Subject: [PATCH 151/242] feat(novasonic): Implement agent.messages injection on start --- .../experimental/bidi/models/novasonic.py | 57 ++++++++- .../bidi/models/test_novasonic.py | 114 ++++++++++++++++++ 2 files changed, 169 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index eeaa1d659..3b4638692 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -186,8 +186,10 @@ def _build_initialization_events( events.extend(self._get_system_prompt_events(system_prompt)) - # Message history would be processed here if needed in the future - # Currently not implemented as it's not used in the existing test cases + # Add conversation history if provided + if messages: + events.extend(self._get_message_history_events(messages)) + logger.debug("message_count=<%d> | conversation history added to initialization", len(messages)) return events @@ -567,6 +569,57 @@ def _get_system_prompt_events(self, system_prompt: str) -> list[str]: self._get_content_end_event(content_name), ] + def _get_message_history_events(self, messages: Messages) -> list[str]: + """Generate conversation history events from agent messages. + + Converts agent message history to Nova Sonic format following the + contentStart/textInput/contentEnd pattern for each message. + + Args: + messages: List of conversation messages with role and content. + + Returns: + List of JSON event strings for Nova Sonic. + """ + events = [] + + for message in messages: + role = message["role"].upper() # Convert to ASSISTANT or USER + content_blocks = message.get("content", []) + + # Extract text content from content blocks + text_parts = [] + for block in content_blocks: + if "text" in block: + text_parts.append(block["text"]) + elif "toolUse" in block: + # Include tool use information in text format for context + tool_use = block["toolUse"] + text_parts.append(f"[Tool: {tool_use['name']}]") + elif "toolResult" in block: + # Include tool result information in text format for context + tool_result = block["toolResult"] + if "content" in tool_result: + for result_block in tool_result["content"]: + if "text" in result_block: + text_parts.append(f"[Tool Result: {result_block['text']}]") + + # Combine all text parts + if text_parts: + combined_text = "\n".join(text_parts) + content_name = str(uuid.uuid4()) + + # Add contentStart, textInput, and contentEnd events + events.extend( + [ + self._get_text_content_start_event(content_name, role), + self._get_text_input_event(content_name, combined_text), + self._get_content_end_event(content_name), + ] + ) + + return events + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: """Generate text content start event.""" return json.dumps( diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index e0459fd51..b2b2ab015 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -136,6 +136,49 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model await model2.stop() # Second call should also be safe +@pytest.mark.asyncio +async def test_connection_with_message_history(nova_model, mock_client, mock_stream): + """Test connection initialization with conversation history.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Create message history + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + {"role": "assistant", "content": [{"text": "I'll check the weather for you."}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "tool-123", "name": "get_weather", "input": {}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "tool-123", "content": [{"text": "Sunny, 72°F"}]}}], + }, + {"role": "assistant", "content": [{"text": "It's sunny and 72 degrees."}]}, + ] + + # Start connection with message history + await nova_model.start(system_prompt="You are a helpful assistant", messages=messages) + + # Verify initialization events were sent + # Should include: sessionStart, promptStart, system prompt (3 events), + # and message history (5 messages * 3 events each = 15 events) + # Total: 1 + 1 + 3 + 15 = 20 events minimum + assert mock_stream.input_stream.send.call_count >= 18 + + # Verify the events contain proper role information + sent_events = [call.args[0].value.bytes_.decode("utf-8") for call in mock_stream.input_stream.send.call_args_list] + + # Check that USER and ASSISTANT roles are present in contentStart events + user_events = [e for e in sent_events if '"role": "USER"' in e] + assistant_events = [e for e in sent_events if '"role": "ASSISTANT"' in e] + + assert len(user_events) >= 2 # At least 2 user messages + assert len(assistant_events) >= 3 # At least 3 assistant messages + + await nova_model.stop() + + # Send Method Tests @@ -402,6 +445,77 @@ async def test_event_templates(nova_model): assert json.loads(event["event"]["toolResult"]["content"]) == result +@pytest.mark.asyncio +async def test_message_history_conversion(nova_model): + """Test conversion of agent messages to Nova Sonic history events.""" + nova_model.connection_id = "test-connection" + + # Test with various message types + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "tool-1", "name": "calculator", "input": {"expr": "2+2"}}}], + }, + {"role": "user", "content": [{"toolResult": {"toolUseId": "tool-1", "content": [{"text": "4"}]}}]}, + {"role": "assistant", "content": [{"text": "The answer is 4"}]}, + ] + + events = nova_model._get_message_history_events(messages) + + # Each message should generate 3 events: contentStart, textInput, contentEnd + assert len(events) == 15 # 5 messages * 3 events each + + # Parse and verify events + parsed_events = [json.loads(e) for e in events] + + # Check first message (user) + assert "contentStart" in parsed_events[0]["event"] + assert parsed_events[0]["event"]["contentStart"]["role"] == "USER" + assert "textInput" in parsed_events[1]["event"] + assert parsed_events[1]["event"]["textInput"]["content"] == "Hello" + assert "contentEnd" in parsed_events[2]["event"] + + # Check second message (assistant) + assert "contentStart" in parsed_events[3]["event"] + assert parsed_events[3]["event"]["contentStart"]["role"] == "ASSISTANT" + assert "textInput" in parsed_events[4]["event"] + assert parsed_events[4]["event"]["textInput"]["content"] == "Hi there!" + + # Check tool use message (should include tool name in text) + assert "textInput" in parsed_events[7]["event"] + assert "[Tool: calculator]" in parsed_events[7]["event"]["textInput"]["content"] + + # Check tool result message (should include result in text) + assert "textInput" in parsed_events[10]["event"] + assert "[Tool Result: 4]" in parsed_events[10]["event"]["textInput"]["content"] + + +@pytest.mark.asyncio +async def test_message_history_empty_and_edge_cases(nova_model): + """Test message history conversion with empty and edge cases.""" + nova_model.connection_id = "test-connection" + + # Test with empty messages + events = nova_model._get_message_history_events([]) + assert len(events) == 0 + + # Test with message containing no text content + messages = [{"role": "user", "content": []}] + events = nova_model._get_message_history_events(messages) + assert len(events) == 0 # No events generated for empty content + + # Test with multiple text blocks in one message + messages = [{"role": "user", "content": [{"text": "First part"}, {"text": "Second part"}]}] + events = nova_model._get_message_history_events(messages) + assert len(events) == 3 # contentStart, textInput, contentEnd + parsed = json.loads(events[1]) + content = parsed["event"]["textInput"]["content"] + assert "First part" in content + assert "Second part" in content + + # Error Handling Tests From 1c26f8d6f53a41afcf319f106ff0246d145575be Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 19 Nov 2025 12:04:32 +0100 Subject: [PATCH 152/242] feat: Invoke message added event on tool use --- src/strands/experimental/bidi/agent/loop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 62136f2c1..18f9e521c 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -159,6 +159,9 @@ async def _run_model(self) -> None: message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} self._agent.messages.append(message) + await self._agent.hooks.invoke_callbacks_async( + BidiMessageAddedEvent(agent=self._agent, message=message) + ) elif isinstance(event, BidiInterruptionEvent): # Emit interruption hook event From 4cdaf31831d3c90d8625604ea65b5650dcb75fba Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 19 Nov 2025 12:05:15 +0100 Subject: [PATCH 153/242] fix(openai): fix agent history init --- .../experimental/bidi/models/openai.py | 108 +++++++++++++++--- .../bidi/models/test_openai_realtime.py | 60 ++++++++++ 2 files changed, 152 insertions(+), 16 deletions(-) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 1564e5d68..eabdccc63 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -253,24 +253,100 @@ def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: return openai_tools async def _add_conversation_history(self, messages: Messages) -> None: - """Add conversation history to the session.""" + """Add conversation history to the session. + + Converts agent message history to OpenAI Realtime API format using + conversation.item.create events for each message. + + Note: OpenAI Realtime API has a 32-character limit on call_id, so we truncate + UUIDs consistently to ensure tool calls and their results match. + + Args: + messages: List of conversation messages with role and content. + """ + # Track tool call IDs to ensure consistency between calls and results + call_id_map = {} + + # First pass: collect all tool call IDs for message in messages: - conversation_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": message["role"], "content": []}, - } + for block in message.get("content", []): + if "toolUse" in block: + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + call_id = original_id[:32] + call_id_map[original_id] = call_id + + # Second pass: send messages + for message in messages: + role = message["role"] + content_blocks = message.get("content", []) + + # Build content array for OpenAI format + openai_content = [] - content = message.get("content", "") - if isinstance(content, str): - conversation_item["item"]["content"].append({"type": "input_text", "text": content}) - elif isinstance(content, list): - for item in content: - if isinstance(item, dict) and item.get("type") == "text": - conversation_item["item"]["content"].append( - {"type": "input_text", "text": item.get("text", "")} - ) - - await self._send_event(conversation_item) + for block in content_blocks: + if "text" in block: + # Text content - use appropriate type based on role + # User messages use "input_text", assistant messages use "output_text" + if role == "user": + openai_content.append({"type": "input_text", "text": block["text"]}) + else: # assistant + openai_content.append({"type": "output_text", "text": block["text"]}) + elif "toolUse" in block: + # Tool use - create as function_call item + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + # Use pre-mapped call_id + call_id = call_id_map[original_id] + + tool_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call", + "call_id": call_id, + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + } + await self._send_event(tool_item) + continue # Tool use is sent separately, not in message content + elif "toolResult" in block: + # Tool result - create as function_call_output item + tool_result = block["toolResult"] + result_text = "" + if "content" in tool_result: + for result_block in tool_result["content"]: + if "text" in result_block: + result_text = result_block["text"] + break + + original_id = tool_result["toolUseId"] + # Use mapped call_id if available, otherwise skip orphaned result + if original_id not in call_id_map: + continue # Skip this tool result since we don't have the call + + call_id = call_id_map[original_id] + + result_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result_text, + }, + } + await self._send_event(result_item) + continue # Tool result is sent separately, not in message content + + # Only create message item if there's text content + if openai_content: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": role, "content": openai_content}, + } + await self._send_event(conversation_item) + + logger.debug("message_count=<%d> | conversation history added to openai session", len(messages)) async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive OpenAI events and convert to Strands TypedEvent format.""" diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 2ffcac7ae..e22e62f82 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -182,6 +182,66 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp await model_org.stop() +@pytest.mark.asyncio +async def test_connection_with_message_history(mock_websockets_connect, model): + """Test connection initialization with conversation history including tool calls.""" + _, mock_ws = mock_websockets_connect + + # Create message history with various content types + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + {"role": "assistant", "content": [{"text": "I'll check the weather for you."}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "call-123", "name": "get_weather", "input": {"location": "Seattle"}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "call-123", "content": [{"text": "Sunny, 72°F"}]}}], + }, + {"role": "assistant", "content": [{"text": "It's sunny and 72 degrees."}]}, + ] + + # Start connection with message history + await model.start(messages=messages) + + # Get all sent events + calls = mock_ws.send.call_args_list + sent_events = [json.loads(call[0][0]) for call in calls] + + # Filter conversation.item.create events + item_creates = [e for e in sent_events if e.get("type") == "conversation.item.create"] + + # Should have 5 items: 2 messages, 1 function_call, 1 function_call_output, 1 message + assert len(item_creates) >= 5 + + # Verify message items + message_items = [e for e in item_creates if e.get("item", {}).get("type") == "message"] + assert len(message_items) >= 3 + + # Verify first user message + user_msg = message_items[0] + assert user_msg["item"]["role"] == "user" + assert user_msg["item"]["content"][0]["text"] == "What's the weather?" + + # Verify function call item + function_call_items = [e for e in item_creates if e.get("item", {}).get("type") == "function_call"] + assert len(function_call_items) >= 1 + func_call = function_call_items[0] + assert func_call["item"]["call_id"] == "call-123" + assert func_call["item"]["name"] == "get_weather" + assert json.loads(func_call["item"]["arguments"]) == {"location": "Seattle"} + + # Verify function call output item + function_output_items = [e for e in item_creates if e.get("item", {}).get("type") == "function_call_output"] + assert len(function_output_items) >= 1 + func_output = function_output_items[0] + assert func_output["item"]["call_id"] == "call-123" + assert "Sunny, 72°F" in func_output["item"]["output"] + + await model.stop() + + @pytest.mark.asyncio async def test_connection_edge_cases(mock_websockets_connect, api_key, model_name): """Test connection error handling and edge cases.""" From 5af9c9d55e41567ce413dd34063b14e0057d0851 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 19 Nov 2025 16:25:20 -0500 Subject: [PATCH 154/242] lint and mypy (#65) --- pyproject.toml | 3 ++- src/strands/experimental/bidi/agent/agent.py | 4 ++-- src/strands/experimental/bidi/agent/loop.py | 17 +++++++++-------- src/strands/experimental/bidi/io/audio.py | 10 +--------- tests_integ/bidi/test_bidi_hooks.py | 4 +--- 5 files changed, 15 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d7de4e226..6d977a236 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,9 +234,10 @@ module = [ "smithy_core.*", "smithy_aws_core.*", "aws_sdk_bedrock_runtime.*", + "pyaudio", ] ignore_missing_imports = true - +follow_imports = "skip" [tool.ruff] line-length = 120 diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index b1d12110c..7156b12be 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -18,17 +18,17 @@ from typing import Any, AsyncIterable from .... import _identifier +from ....agent.state import AgentState from ....hooks import HookProvider, HookRegistry from ....tools.caller import _ToolCaller -from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry -from ....agent.state import AgentState from ....tools.watcher import ToolWatcher from ....types.content import ContentBlock, Message, Messages from ....types.tools import AgentTool, ToolResult, ToolUse from ...tools import ToolProvider +from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index c0898f6fb..765727b45 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -7,19 +7,20 @@ import logging from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable +from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent +from ....types.content import Message +from ....types.tools import ToolResult, ToolUse from ..hooks.events import ( BidiAfterInvocationEvent, BidiAfterToolCallEvent, BidiBeforeInvocationEvent, BidiBeforeToolCallEvent, - BidiInterruptionEvent as BidiInterruptionHookEvent, BidiMessageAddedEvent, ) -from ..types.events import BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent -from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent -from ....types.content import Message -from ....types.tools import ToolResult, ToolUse -from ..types.events import BidiOutputEvent, BidiTranscriptStreamEvent +from ..hooks.events import ( + BidiInterruptionEvent as BidiInterruptionHookEvent, +) +from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent if TYPE_CHECKING: from .agent import BidiAgent @@ -173,10 +174,10 @@ async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model.""" logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) - result: ToolResult = None + result: ToolResult exception: Exception | None = None tool = None - invocation_state = {} + invocation_state: dict[str, Any] = {} try: tool = self._agent.tool_registry.registry[tool_use["name"]] diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index f21bb7beb..2f129481f 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -10,7 +10,7 @@ from collections import deque from typing import Any -import pyaudio # type: ignore[import-untyped] +import pyaudio from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent from ..types.io import BidiInput, BidiOutput @@ -73,8 +73,6 @@ async def stop(self) -> None: self._stream.close() self._audio.terminate() - self._stream = None - self._audio = None logger.debug("audio input stream stopped") async def __call__(self) -> BidiAudioInputEvent: @@ -156,12 +154,6 @@ async def stop(self) -> None: self._stream.close() self._audio.terminate() - # Adding type ignore to adhere to mypy - self._output_task = None # type: ignore[assignment] - self._buffer = None # type: ignore[assignment] - self._buffer_event = None # type: ignore[assignment] - self._stream = None - self._audio = None logger.debug("audio output stream stopped") async def __call__(self, event: BidiOutputEvent) -> None: diff --git a/tests_integ/bidi/test_bidi_hooks.py b/tests_integ/bidi/test_bidi_hooks.py index badfea384..f6cad162a 100644 --- a/tests_integ/bidi/test_bidi_hooks.py +++ b/tests_integ/bidi/test_bidi_hooks.py @@ -1,7 +1,5 @@ """Integration tests for BidiAgent hooks with real model providers.""" -import asyncio - import pytest from src.strands import tool @@ -162,7 +160,7 @@ async def test_hook_events_contain_agent_reference(self): await agent.stop() # All events should reference the same agent - for event_type, event in collector.events: + for _, event in collector.events: assert hasattr(event, "agent") assert event.agent == agent From 51c58351e0aaa732221290fd416e58e39b9b7126 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 20 Nov 2025 15:47:51 +0100 Subject: [PATCH 155/242] feat(bidi): Add tool executor --- src/strands/experimental/bidi/agent/agent.py | 12 +- src/strands/experimental/bidi/agent/loop.py | 66 ++----- .../experimental/bidi/hooks/__init__.py | 55 ------ src/strands/experimental/bidi/hooks/events.py | 173 ----------------- src/strands/experimental/hooks/__init__.py | 20 +- src/strands/experimental/hooks/events.py | 177 +++++++++++++++++- src/strands/tools/executors/_executor.py | 145 +++++++++----- src/strands/tools/executors/concurrent.py | 11 +- src/strands/tools/executors/sequential.py | 7 +- src/strands/types/tools.py | 4 +- .../experimental/bidi/hooks/__init__.py | 1 - .../bidi/hooks/test_bidi_hook_events.py | 168 ----------------- tests_integ/bidi/test_bidi_hooks.py | 58 +----- tests_integ/bidi/test_bidirectional_agent.py | 33 +++- 14 files changed, 364 insertions(+), 566 deletions(-) delete mode 100644 src/strands/experimental/bidi/hooks/__init__.py delete mode 100644 src/strands/experimental/bidi/hooks/events.py delete mode 100644 tests/strands/experimental/bidi/hooks/__init__.py delete mode 100644 tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 7156b12be..c766fc634 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -28,7 +28,7 @@ from ....types.content import ContentBlock, Message, Messages from ....types.tools import AgentTool, ToolResult, ToolUse from ...tools import ToolProvider -from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent +from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -59,10 +59,10 @@ def __init__( load_tools_from_directory: bool = False, agent_id: str | None = None, name: str | None = None, - tool_executor: ToolExecutor | None = None, description: str | None = None, hooks: list[HookProvider] | None = None, state: AgentState | dict | None = None, + tool_executor: ToolExecutor | None = None, **kwargs: Any, ): """Initialize bidirectional agent. @@ -76,10 +76,10 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. agent_id: Optional ID for the agent, useful for connection management and multi-agent scenarios. name: Name of the Agent. - tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). description: Description of what the Agent does. hooks: Optional list of hook providers to register for lifecycle events. state: Stateful information for the agent. Can be either an AgentState object, or a json serializable dict. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). **kwargs: Additional configuration for future extensibility. Raises: @@ -117,9 +117,6 @@ def __init__( if self.load_tools_from_directory: self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) - # Initialize tool executor - self.tool_executor = tool_executor or ConcurrentToolExecutor() - # Initialize agent state management if state is not None: if isinstance(state, dict): @@ -134,6 +131,9 @@ def __init__( # Initialize other components self._tool_caller = _ToolCaller(self) + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + # Initialize hooks registry self.hooks = HookRegistry() if hooks: diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 765727b45..5200e8d0a 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -10,14 +10,12 @@ from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..hooks.events import ( +from ...hooks.events import ( BidiAfterInvocationEvent, - BidiAfterToolCallEvent, BidiBeforeInvocationEvent, - BidiBeforeToolCallEvent, BidiMessageAddedEvent, ) -from ..hooks.events import ( +from ...hooks.events import ( BidiInterruptionEvent as BidiInterruptionHookEvent, ) from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent @@ -171,56 +169,30 @@ async def _run_model(self) -> None: ) async def _run_tool(self, tool_use: ToolUse) -> None: - """Task for running tool requested by the model.""" + """Task for running tool requested by the model using the tool executor.""" logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) - result: ToolResult - exception: Exception | None = None - tool = None - invocation_state: dict[str, Any] = {} + tool_results: list[ToolResult] = [] + invocation_state: dict[str, Any] = {"agent": self._agent} - try: - tool = self._agent.tool_registry.registry[tool_use["name"]] - - # Emit before tool call event - await self._agent.hooks.invoke_callbacks_async( - BidiBeforeToolCallEvent( - agent=self._agent, - selected_tool=tool, - tool_use=tool_use, - invocation_state=invocation_state, - ) - ) - - async for event in tool.stream(tool_use, invocation_state): - if isinstance(event, ToolResultEvent): - await self._event_queue.put(event) - result = event.tool_result - break - - if isinstance(event, ToolStreamEvent): - await self._event_queue.put(event) - else: - await self._event_queue.put(ToolStreamEvent(tool_use, event)) - - except Exception as e: - result = {"toolUseId": tool_use["toolUseId"], "status": "error", "content": [{"text": f"Error: {str(e)}"}]} + # Use the tool executor to run the tool (no tracing/metrics for BidiAgent yet) + tool_events = self._agent.tool_executor._stream( + self._agent, + tool_use, + tool_results, + invocation_state, + structured_output_context=None, # BidiAgent doesn't support structured output yet + ) - finally: - # Emit after tool call event (reverse order for cleanup) - await self._agent.hooks.invoke_callbacks_async( - BidiAfterToolCallEvent( - agent=self._agent, - selected_tool=tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - ) - ) + async for event in tool_events: + await self._event_queue.put(event) + if isinstance(event, ToolResultEvent): + result = event.tool_result + # Send tool result to model await self._agent.model.send(ToolResultEvent(result)) + # Add tool result message to conversation history message: Message = { "role": "user", "content": [{"toolResult": result}], diff --git a/src/strands/experimental/bidi/hooks/__init__.py b/src/strands/experimental/bidi/hooks/__init__.py deleted file mode 100644 index 6ed0e52cf..000000000 --- a/src/strands/experimental/bidi/hooks/__init__.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Typed hook system for BidiAgent. - -This module provides hook events specifically for BidiAgent, enabling -composable extension of bidirectional streaming agent functionality. - -Example Usage: - ```python - from strands.experimental.bidi.hooks import ( - BidiBeforeInvocationEvent, - BidiInterruptionEvent, - HookProvider, - HookRegistry - ) - - class BidiLoggingHooks(HookProvider): - def register_hooks(self, registry: HookRegistry) -> None: - registry.add_callback(BidiBeforeInvocationEvent, self.log_session_start) - registry.add_callback(BidiInterruptionEvent, self.log_interruption) - - def log_session_start(self, event: BidiBeforeInvocationEvent) -> None: - print(f"BidiAgent {event.agent.name} starting session") - - def log_interruption(self, event: BidiInterruptionEvent) -> None: - print(f"Interrupted: {event.reason}") - - # Use with BidiAgent - agent = BidiAgent(hooks=[BidiLoggingHooks()]) - ``` -""" - -from ....hooks import HookCallback, HookProvider, HookRegistry -from .events import ( - BidiAfterInvocationEvent, - BidiAfterToolCallEvent, - BidiAgentInitializedEvent, - BidiBeforeInvocationEvent, - BidiBeforeToolCallEvent, - BidiHookEvent, - BidiInterruptionEvent, - BidiMessageAddedEvent, -) - -__all__ = [ - "BidiAgentInitializedEvent", - "BidiBeforeInvocationEvent", - "BidiAfterInvocationEvent", - "BidiBeforeToolCallEvent", - "BidiAfterToolCallEvent", - "BidiMessageAddedEvent", - "BidiInterruptionEvent", - "BidiHookEvent", - "HookProvider", - "HookCallback", - "HookRegistry", -] diff --git a/src/strands/experimental/bidi/hooks/events.py b/src/strands/experimental/bidi/hooks/events.py deleted file mode 100644 index d4add3200..000000000 --- a/src/strands/experimental/bidi/hooks/events.py +++ /dev/null @@ -1,173 +0,0 @@ -"""Hook events for BidiAgent. - -This module defines the events that are emitted as BidiAgent runs through -the lifecycle of a streaming session. -""" - -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Optional - -from ....hooks.registry import BaseHookEvent -from ....types.content import Message -from ....types.tools import AgentTool, ToolResult, ToolUse - -if TYPE_CHECKING: - from ..agent.agent import BidiAgent - - -@dataclass -class BidiHookEvent(BaseHookEvent): - """Base class for BidiAgent hook events. - - Attributes: - agent: The BidiAgent instance that triggered this event. - """ - - agent: "BidiAgent" - - -@dataclass -class BidiAgentInitializedEvent(BidiHookEvent): - """Event triggered when a BidiAgent has finished initialization. - - This event is fired after the BidiAgent has been fully constructed and all - built-in components have been initialized. Hook providers can use this - event to perform setup tasks that require a fully initialized agent. - """ - - pass - - -@dataclass -class BidiBeforeInvocationEvent(BidiHookEvent): - """Event triggered when BidiAgent starts a streaming session. - - This event is fired before the BidiAgent begins a streaming session, - before any model connection or audio processing occurs. Hook providers can - use this event to perform session-level setup, logging, or validation. - - This event is triggered at the beginning of agent.start(). - """ - - pass - - -@dataclass -class BidiAfterInvocationEvent(BidiHookEvent): - """Event triggered when BidiAgent ends a streaming session. - - This event is fired after the BidiAgent has completed a streaming session, - regardless of whether it completed successfully or encountered an error. - Hook providers can use this event for cleanup, logging, or state persistence. - - Note: This event uses reverse callback ordering, meaning callbacks registered - later will be invoked first during cleanup. - - This event is triggered at the end of agent.stop(). - """ - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True - - -@dataclass -class BidiMessageAddedEvent(BidiHookEvent): - """Event triggered when BidiAgent adds a message to the conversation. - - This event is fired whenever the BidiAgent adds a new message to its internal - message history, including user messages (from transcripts), assistant responses, - and tool results. Hook providers can use this event for logging, monitoring, or - implementing custom message processing logic. - - Note: This event is only triggered for messages added by the framework - itself, not for messages manually added by tools or external code. - - Attributes: - message: The message that was added to the conversation history. - """ - - message: Message - - -@dataclass -class BidiBeforeToolCallEvent(BidiHookEvent): - """Event triggered before BidiAgent executes a tool. - - This event is fired just before the BidiAgent executes a tool during a streaming - session, allowing hook providers to inspect, modify, or replace the tool that - will be executed. The selected_tool can be modified by hook callbacks to change - which tool gets executed. - - Attributes: - selected_tool: The tool that will be invoked. Can be modified by hooks - to change which tool gets executed. This may be None if tool lookup failed. - tool_use: The tool parameters that will be passed to selected_tool. - invocation_state: Keyword arguments that will be passed to the tool. - """ - - selected_tool: Optional[AgentTool] - tool_use: ToolUse - invocation_state: dict[str, Any] - - def _can_write(self, name: str) -> bool: - return name in ["selected_tool", "tool_use"] - - -@dataclass -class BidiAfterToolCallEvent(BidiHookEvent): - """Event triggered after BidiAgent executes a tool. - - This event is fired after the BidiAgent has finished executing a tool during - a streaming session, regardless of whether the execution was successful or - resulted in an error. Hook providers can use this event for cleanup, logging, - or post-processing. - - Note: This event uses reverse callback ordering, meaning callbacks registered - later will be invoked first during cleanup. - - Attributes: - selected_tool: The tool that was invoked. It may be None if tool lookup failed. - tool_use: The tool parameters that were passed to the tool invoked. - invocation_state: Keyword arguments that were passed to the tool. - result: The result of the tool invocation. Either a ToolResult on success - or an Exception if the tool execution failed. - exception: Exception if the tool execution failed, None if successful. - cancel_message: The cancellation message if the user cancelled the tool call. - """ - - selected_tool: Optional[AgentTool] - tool_use: ToolUse - invocation_state: dict[str, Any] - result: ToolResult - exception: Optional[Exception] = None - cancel_message: str | None = None - - def _can_write(self, name: str) -> bool: - return name == "result" - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True - - -@dataclass -class BidiInterruptionEvent(BidiHookEvent): - """Event triggered when model generation is interrupted. - - This event is fired when the user interrupts the assistant (e.g., by speaking - during the assistant's response) or when an error causes interruption. This is - specific to bidirectional streaming and doesn't exist in standard agents. - - Hook providers can use this event to log interruptions, implement custom - interruption handling, or trigger cleanup logic. - - Attributes: - reason: The reason for the interruption ("user_speech" or "error"). - interrupted_response_id: Optional ID of the response that was interrupted. - """ - - reason: Literal["user_speech", "error"] - interrupted_response_id: Optional[str] = None diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 098d4cf0d..7c3c2b269 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -1,10 +1,20 @@ -"""Experimental hook functionality that has not yet reached stability.""" +"""Experimental hook functionality that has not yet reached stability. + +BidiAgent hooks are also available here to avoid circular imports. +""" from .events import ( AfterModelInvocationEvent, AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, ) __all__ = [ @@ -12,4 +22,12 @@ "AfterToolInvocationEvent", "BeforeModelInvocationEvent", "AfterModelInvocationEvent", + # BidiAgent hooks + "BidiAgentInitializedEvent", + "BidiBeforeInvocationEvent", + "BidiAfterInvocationEvent", + "BidiMessageAddedEvent", + "BidiBeforeToolCallEvent", + "BidiAfterToolCallEvent", + "BidiInterruptionEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index d711dd7ed..485e8d201 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -1,12 +1,21 @@ """Experimental hook events emitted as part of invoking Agents. This module defines the events that are emitted as Agents run through the lifecycle of a request. + +BidiAgent hook events are also defined here to avoid circular imports. """ import warnings -from typing import TypeAlias +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias from ...hooks.events import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, BeforeToolCallEvent +from ...hooks.registry import BaseHookEvent +from ...types.content import Message +from ...types.tools import AgentTool, ToolResult, ToolUse + +if TYPE_CHECKING: + from ..bidi.agent.agent import BidiAgent warnings.warn( "These events have been moved to production with updated names. Use BeforeModelCallEvent, " @@ -19,3 +28,169 @@ AfterToolInvocationEvent: TypeAlias = AfterToolCallEvent BeforeModelInvocationEvent: TypeAlias = BeforeModelCallEvent AfterModelInvocationEvent: TypeAlias = AfterModelCallEvent + + +# BidiAgent Hook Events +# These are defined here to avoid circular imports with the bidi package + + +@dataclass +class BidiHookEvent(BaseHookEvent): + """Base class for BidiAgent hook events. + + Attributes: + agent: The BidiAgent instance that triggered this event. + """ + + agent: "BidiAgent" + + +@dataclass +class BidiAgentInitializedEvent(BidiHookEvent): + """Event triggered when a BidiAgent has finished initialization. + + This event is fired after the BidiAgent has been fully constructed and all + built-in components have been initialized. Hook providers can use this + event to perform setup tasks that require a fully initialized agent. + """ + + pass + + +@dataclass +class BidiBeforeInvocationEvent(BidiHookEvent): + """Event triggered when BidiAgent starts a streaming session. + + This event is fired before the BidiAgent begins a streaming session, + before any model connection or audio processing occurs. Hook providers can + use this event to perform session-level setup, logging, or validation. + + This event is triggered at the beginning of agent.start(). + """ + + pass + + +@dataclass +class BidiAfterInvocationEvent(BidiHookEvent): + """Event triggered when BidiAgent ends a streaming session. + + This event is fired after the BidiAgent has completed a streaming session, + regardless of whether it completed successfully or encountered an error. + Hook providers can use this event for cleanup, logging, or state persistence. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + This event is triggered at the end of agent.stop(). + """ + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BidiMessageAddedEvent(BidiHookEvent): + """Event triggered when BidiAgent adds a message to the conversation. + + This event is fired whenever the BidiAgent adds a new message to its internal + message history, including user messages (from transcripts), assistant responses, + and tool results. Hook providers can use this event for logging, monitoring, or + implementing custom message processing logic. + + Note: This event is only triggered for messages added by the framework + itself, not for messages manually added by tools or external code. + + Attributes: + message: The message that was added to the conversation history. + """ + + message: Message + + +@dataclass +class BidiBeforeToolCallEvent(BidiHookEvent): + """Event triggered before BidiAgent executes a tool. + + This event is fired just before the BidiAgent executes a tool during a streaming + session, allowing hook providers to inspect, modify, or replace the tool that + will be executed. The selected_tool can be modified by hook callbacks to change + which tool gets executed. + + Attributes: + selected_tool: The tool that will be invoked. Can be modified by hooks + to change which tool gets executed. This may be None if tool lookup failed. + tool_use: The tool parameters that will be passed to selected_tool. + invocation_state: Keyword arguments that will be passed to the tool. + cancel_tool: A user defined message that when set, will cancel the tool call. + The message will be placed into a tool result with an error status. If set to `True`, Strands will cancel + the tool call and use a default cancel message. + """ + + selected_tool: Optional[AgentTool] + tool_use: ToolUse + invocation_state: dict[str, Any] + cancel_tool: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_tool", "selected_tool", "tool_use"] + + +@dataclass +class BidiAfterToolCallEvent(BidiHookEvent): + """Event triggered after BidiAgent executes a tool. + + This event is fired after the BidiAgent has finished executing a tool during + a streaming session, regardless of whether the execution was successful or + resulted in an error. Hook providers can use this event for cleanup, logging, + or post-processing. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Attributes: + selected_tool: The tool that was invoked. It may be None if tool lookup failed. + tool_use: The tool parameters that were passed to the tool invoked. + invocation_state: Keyword arguments that were passed to the tool. + result: The result of the tool invocation. Either a ToolResult on success + or an Exception if the tool execution failed. + exception: Exception if the tool execution failed, None if successful. + cancel_message: The cancellation message if the user cancelled the tool call. + """ + + selected_tool: Optional[AgentTool] + tool_use: ToolUse + invocation_state: dict[str, Any] + result: ToolResult + exception: Optional[Exception] = None + cancel_message: str | None = None + + def _can_write(self, name: str) -> bool: + return name == "result" + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BidiInterruptionEvent(BidiHookEvent): + """Event triggered when model generation is interrupted. + + This event is fired when the user interrupts the assistant (e.g., by speaking + during the assistant's response) or when an error causes interruption. This is + specific to bidirectional streaming and doesn't exist in standard agents. + + Hook providers can use this event to log interruptions, implement custom + interruption handling, or trigger cleanup logic. + + Attributes: + reason: The reason for the interruption ("user_speech" or "error"). + interrupted_response_id: Optional ID of the response that was interrupted. + """ + + reason: Literal["user_speech", "error"] + interrupted_response_id: Optional[str] = None diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 87c38990d..969be521b 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,10 +7,11 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, Union, cast from opentelemetry import trace as trace_api +from ...experimental.hooks.events import BidiAfterToolCallEvent, BidiBeforeToolCallEvent from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer, serialize @@ -21,6 +22,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ...experimental.bidi.agent.agent import BidiAgent logger = logging.getLogger(__name__) @@ -28,9 +30,83 @@ class ToolExecutor(abc.ABC): """Abstract base class for tool executors.""" + @staticmethod + def _is_bidi_agent(agent: Union["Agent", "BidiAgent"]) -> bool: + """Check if the agent is a BidiAgent instance. + + Uses isinstance() with runtime import to avoid circular imports. + """ + # Import at runtime to avoid circular dependency + from ...experimental.bidi.agent.agent import BidiAgent + + return isinstance(agent, BidiAgent) + + @staticmethod + async def _invoke_before_tool_call_hook( + agent: Union["Agent", "BidiAgent"], + tool_func: Any, + tool_use: ToolUse, + invocation_state: dict[str, Any], + ) -> tuple[Any, list]: + """Invoke the appropriate before tool call hook based on agent type.""" + if ToolExecutor._is_bidi_agent(agent): + return await agent.hooks.invoke_callbacks_async( + BidiBeforeToolCallEvent( + agent=agent, + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, + ) + ) + else: + return await agent.hooks.invoke_callbacks_async( + BeforeToolCallEvent( + agent=agent, + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, + ) + ) + + @staticmethod + async def _invoke_after_tool_call_hook( + agent: Union["Agent", "BidiAgent"], + selected_tool: Any, + tool_use: ToolUse, + invocation_state: dict[str, Any], + result: ToolResult, + exception: Exception | None = None, + cancel_message: str | None = None, + ) -> tuple[Any, list]: + """Invoke the appropriate after tool call hook based on agent type.""" + if ToolExecutor._is_bidi_agent(agent): + return await agent.hooks.invoke_callbacks_async( + BidiAfterToolCallEvent( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, + cancel_message=cancel_message, + ) + ) + else: + return await agent.hooks.invoke_callbacks_async( + AfterToolCallEvent( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, + cancel_message=cancel_message, + ) + ) + @staticmethod async def _stream( - agent: "Agent", + agent: "Agent | BidiAgent", tool_use: ToolUse, tool_results: list[ToolResult], invocation_state: dict[str, Any], @@ -48,7 +124,7 @@ async def _stream( - Interrupt handling for human-in-the-loop workflows Args: - agent: The agent for which the tool is being executed. + agent: The agent (Agent or BidiAgent) for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. invocation_state: Context for the tool invocation. @@ -85,13 +161,9 @@ async def _stream( } ) - before_event, interrupts = await agent.hooks.invoke_callbacks_async( - BeforeToolCallEvent( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) + # Invoke appropriate before tool call hook based on agent type + before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( + agent, tool_func, tool_use, invocation_state ) if interrupts: @@ -109,15 +181,9 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - tool_use=tool_use, - invocation_state=invocation_state, - selected_tool=None, - result=cancel_result, - cancel_message=cancel_message, - ) + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @@ -147,14 +213,9 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - ) + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @@ -184,14 +245,8 @@ async def _stream( result = cast(ToolResult, event) - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - ) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result ) yield ToolResultEvent(after_event.result) @@ -204,22 +259,16 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=error_result, - exception=e, - ) + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, error_result, exception=e ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @staticmethod async def _stream_with_trace( - agent: "Agent", + agent: Union["Agent", "BidiAgent"], tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -231,7 +280,7 @@ async def _stream_with_trace( """Execute tool with tracing and metrics collection. Args: - agent: The agent for which the tool is being executed. + agent: The agent (Agent or BidiAgent) for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -277,7 +326,7 @@ async def _stream_with_trace( # pragma: no cover def _execute( self, - agent: "Agent", + agent: Union["Agent", "BidiAgent"], tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -288,7 +337,7 @@ def _execute( """Execute the given tools according to this executor's strategy. Args: - agent: The agent for which tools are being executed. + agent: The agent (Agent or BidiAgent) for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 216eee379..1a586c589 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,7 +1,7 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator, Union from typing_extensions import override @@ -12,6 +12,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ...experimental.bidi.agent.agent import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -21,7 +22,7 @@ class ConcurrentToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent", + agent: Union["Agent", "BidiAgent"], tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -32,7 +33,7 @@ async def _execute( """Execute tools concurrently. Args: - agent: The agent for which tools are being executed. + agent: The agent (Agent or BidiAgent) for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -78,7 +79,7 @@ async def _execute( async def _task( self, - agent: "Agent", + agent: Union["Agent", "BidiAgent"], tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -93,7 +94,7 @@ async def _task( """Execute a single tool and put results in the task queue. Args: - agent: The agent executing the tool. + agent: The agent (Agent or BidiAgent) executing the tool. tool_use: Tool use metadata and inputs. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index f78e60872..e4ac0ecda 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,6 +1,6 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator, Union from typing_extensions import override @@ -11,6 +11,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ...experimental.bidi.agent.agent import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -20,7 +21,7 @@ class SequentialToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent", + agent: Union["Agent", "BidiAgent"], tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -33,7 +34,7 @@ async def _execute( Breaks early if an interrupt is raised by the user. Args: - agent: The agent for which tools are being executed. + agent: The agent (Agent or BidiAgent) for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 8343647b2..4091f48ea 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -136,7 +136,7 @@ class ToolContext(_Interruptible): Attributes: tool_use: The complete ToolUse object containing tool invocation details. - agent: The Agent instance executing this tool, providing access to conversation history, + agent: The Agent or BidiAgent instance executing this tool, providing access to conversation history, model configuration, and other agent state. invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), agent.invoke_async(), etc.). @@ -147,7 +147,7 @@ class ToolContext(_Interruptible): """ tool_use: ToolUse - agent: "Agent" + agent: Any # Agent or BidiAgent - using Any for backwards compatibility invocation_state: dict[str, Any] def _interrupt_id(self, name: str) -> str: diff --git a/tests/strands/experimental/bidi/hooks/__init__.py b/tests/strands/experimental/bidi/hooks/__init__.py deleted file mode 100644 index 20a078833..000000000 --- a/tests/strands/experimental/bidi/hooks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for BidiAgent hooks.""" diff --git a/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py b/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py deleted file mode 100644 index bf3710066..000000000 --- a/tests/strands/experimental/bidi/hooks/test_bidi_hook_events.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Unit tests for BidiAgent hook events.""" - -from unittest.mock import Mock - -import pytest - -from strands.experimental.bidi.hooks import ( - BidiAfterInvocationEvent, - BidiAfterToolCallEvent, - BidiAgentInitializedEvent, - BidiBeforeInvocationEvent, - BidiBeforeToolCallEvent, - BidiInterruptionEvent, - BidiMessageAddedEvent, -) -from strands.types.tools import ToolResult, ToolUse - - -@pytest.fixture -def agent(): - return Mock() - - -@pytest.fixture -def tool(): - tool = Mock() - tool.tool_name = "test_tool" - return tool - - -@pytest.fixture -def tool_use(): - return ToolUse(name="test_tool", toolUseId="123", input={"param": "value"}) - - -@pytest.fixture -def tool_invocation_state(): - return {"param": "value"} - - -@pytest.fixture -def tool_result(): - return ToolResult(content=[{"text": "result"}], status="success", toolUseId="123") - - -@pytest.fixture -def message(): - return {"role": "user", "content": [{"text": "Hello"}]} - - -@pytest.fixture -def initialized_event(agent): - return BidiAgentInitializedEvent(agent=agent) - - -@pytest.fixture -def before_invocation_event(agent): - return BidiBeforeInvocationEvent(agent=agent) - - -@pytest.fixture -def after_invocation_event(agent): - return BidiAfterInvocationEvent(agent=agent) - - -@pytest.fixture -def message_added_event(agent, message): - return BidiMessageAddedEvent(agent=agent, message=message) - - -@pytest.fixture -def before_tool_event(agent, tool, tool_use, tool_invocation_state): - return BidiBeforeToolCallEvent( - agent=agent, - selected_tool=tool, - tool_use=tool_use, - invocation_state=tool_invocation_state, - ) - - -@pytest.fixture -def after_tool_event(agent, tool, tool_use, tool_invocation_state, tool_result): - return BidiAfterToolCallEvent( - agent=agent, - selected_tool=tool, - tool_use=tool_use, - invocation_state=tool_invocation_state, - result=tool_result, - ) - - -@pytest.fixture -def interruption_event(agent): - return BidiInterruptionEvent(agent=agent, reason="user_speech") - - -def test_event_should_reverse_callbacks( - initialized_event, - before_invocation_event, - after_invocation_event, - message_added_event, - before_tool_event, - after_tool_event, - interruption_event, -): - """Verify which events use reverse callback ordering.""" - # note that we ignore E712 (explicit booleans) for consistency/readability purposes - - assert initialized_event.should_reverse_callbacks == False # noqa: E712 - assert message_added_event.should_reverse_callbacks == False # noqa: E712 - assert interruption_event.should_reverse_callbacks == False # noqa: E712 - - assert before_invocation_event.should_reverse_callbacks == False # noqa: E712 - assert after_invocation_event.should_reverse_callbacks == True # noqa: E712 - - assert before_tool_event.should_reverse_callbacks == False # noqa: E712 - assert after_tool_event.should_reverse_callbacks == True # noqa: E712 - - -def test_interruption_event_with_response_id(agent): - """Verify BidiInterruptionEvent can include response ID.""" - event = BidiInterruptionEvent(agent=agent, reason="error", interrupted_response_id="resp_123") - - assert event.reason == "error" - assert event.interrupted_response_id == "resp_123" - - -def test_message_added_event_cannot_write_properties(message_added_event): - """Verify BidiMessageAddedEvent properties are read-only.""" - with pytest.raises(AttributeError, match="Property agent is not writable"): - message_added_event.agent = Mock() - with pytest.raises(AttributeError, match="Property message is not writable"): - message_added_event.message = {} - - -def test_before_tool_call_event_can_write_properties(before_tool_event): - """Verify BidiBeforeToolCallEvent allows writing specific properties.""" - new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) - before_tool_event.selected_tool = None # Should not raise - before_tool_event.tool_use = new_tool_use # Should not raise - - -def test_before_tool_call_event_cannot_write_properties(before_tool_event): - """Verify BidiBeforeToolCallEvent protects certain properties.""" - with pytest.raises(AttributeError, match="Property agent is not writable"): - before_tool_event.agent = Mock() - with pytest.raises(AttributeError, match="Property invocation_state is not writable"): - before_tool_event.invocation_state = {} - - -def test_after_tool_call_event_can_write_properties(after_tool_event): - """Verify BidiAfterToolCallEvent allows writing result property.""" - new_result = ToolResult(content=[{"text": "new result"}], status="success", toolUseId="456") - after_tool_event.result = new_result # Should not raise - - -def test_after_tool_call_event_cannot_write_properties(after_tool_event): - """Verify BidiAfterToolCallEvent protects certain properties.""" - with pytest.raises(AttributeError, match="Property agent is not writable"): - after_tool_event.agent = Mock() - with pytest.raises(AttributeError, match="Property selected_tool is not writable"): - after_tool_event.selected_tool = None - with pytest.raises(AttributeError, match="Property tool_use is not writable"): - after_tool_event.tool_use = ToolUse(name="new", toolUseId="456", input={}) - with pytest.raises(AttributeError, match="Property invocation_state is not writable"): - after_tool_event.invocation_state = {} - with pytest.raises(AttributeError, match="Property exception is not writable"): - after_tool_event.exception = Exception("test") diff --git a/tests_integ/bidi/test_bidi_hooks.py b/tests_integ/bidi/test_bidi_hooks.py index f6cad162a..cb7def664 100644 --- a/tests_integ/bidi/test_bidi_hooks.py +++ b/tests_integ/bidi/test_bidi_hooks.py @@ -2,63 +2,15 @@ import pytest -from src.strands import tool -from src.strands.experimental.bidi.agent.agent import BidiAgent -from src.strands.experimental.bidi.hooks import ( +from strands import tool +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.hooks.events import ( BidiAfterInvocationEvent, - BidiAfterToolCallEvent, - BidiAgentInitializedEvent, BidiBeforeInvocationEvent, - BidiBeforeToolCallEvent, - BidiInterruptionEvent, - BidiMessageAddedEvent, - HookProvider, ) +from strands.hooks import HookProvider - -class HookEventCollector(HookProvider): - """Hook provider that collects all emitted events for testing.""" - - def __init__(self): - self.events = [] - - def register_hooks(self, registry): - registry.add_callback(BidiAgentInitializedEvent, self.on_initialized) - registry.add_callback(BidiBeforeInvocationEvent, self.on_before_invocation) - registry.add_callback(BidiAfterInvocationEvent, self.on_after_invocation) - registry.add_callback(BidiBeforeToolCallEvent, self.on_before_tool_call) - registry.add_callback(BidiAfterToolCallEvent, self.on_after_tool_call) - registry.add_callback(BidiMessageAddedEvent, self.on_message_added) - registry.add_callback(BidiInterruptionEvent, self.on_interruption) - - def on_initialized(self, event: BidiAgentInitializedEvent): - self.events.append(("initialized", event)) - - def on_before_invocation(self, event: BidiBeforeInvocationEvent): - self.events.append(("before_invocation", event)) - - def on_after_invocation(self, event: BidiAfterInvocationEvent): - self.events.append(("after_invocation", event)) - - def on_before_tool_call(self, event: BidiBeforeToolCallEvent): - self.events.append(("before_tool_call", event)) - - def on_after_tool_call(self, event: BidiAfterToolCallEvent): - self.events.append(("after_tool_call", event)) - - def on_message_added(self, event: BidiMessageAddedEvent): - self.events.append(("message_added", event)) - - def on_interruption(self, event: BidiInterruptionEvent): - self.events.append(("interruption", event)) - - def get_event_types(self): - """Get list of event type names in order.""" - return [event_type for event_type, _ in self.events] - - def get_events_by_type(self, event_type): - """Get all events of a specific type.""" - return [event for et, event in self.events if et == event_type] +from .hook_utils import HookEventCollector @pytest.mark.asyncio diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index 0d3b41607..ee6da01a5 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -19,6 +19,7 @@ from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel from .context import BidirectionalTestContext +from .hook_utils import HookEventCollector logger = logging.getLogger(__name__) @@ -135,7 +136,13 @@ def provider_config(request): @pytest.fixture -def agent_with_calculator(provider_config): +def hook_collector(): + """Provide a hook event collector for tracking all events.""" + return HookEventCollector() + + +@pytest.fixture +def agent_with_calculator(provider_config, hook_collector): """Provide bidirectional agent with calculator tool for the given provider. Note: Session lifecycle (start/end) is handled by BidirectionalTestContext. @@ -148,11 +155,12 @@ def agent_with_calculator(provider_config): model=model, tools=[calculator], system_prompt="You are a helpful assistant with access to a calculator tool. Keep responses brief.", + hooks=[hook_collector], ) @pytest.mark.asyncio -async def test_bidirectional_agent(agent_with_calculator, audio_generator, provider_config): +async def test_bidirectional_agent(agent_with_calculator, audio_generator, provider_config, hook_collector): """Test multi-turn conversation with follow-up questions across providers. This test runs against all configured providers (Nova Sonic, OpenAI, etc.) @@ -162,7 +170,7 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi - Session lifecycle (start/end via context manager) - Audio input streaming - Speech-to-text transcription - - Tool execution (calculator) + - Tool execution (calculator) with hook verification - Multi-turn conversation flow - Text-to-speech audio output """ @@ -206,6 +214,20 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi assert len(audio_outputs) > 0, f"[{provider_name}] No audio output received" total_audio_bytes = sum(len(audio) for audio in audio_outputs) + # Verify tool execution hooks if tools were called + tool_calls = hook_collector.get_tool_calls() + if len(tool_calls) > 0: + logger.info("provider=<%s> | tool execution detected", provider_name) + # Verify hooks are properly paired + verified_tools = hook_collector.verify_tool_execution() + logger.info( + "provider=<%s>, tools_called=<%s> | tool execution hooks verified", + provider_name, + verified_tools, + ) + else: + logger.info("provider=<%s> | no tools were called during conversation", provider_name) + # Summary logger.info("=" * 60) logger.info("provider=<%s> | multi-turn conversation test passed", provider_name) @@ -217,4 +239,9 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi len(audio_outputs), total_audio_bytes, ) + logger.info( + "tool_calls=<%d> | tool execution count", + len(tool_calls), + ) logger.info("=" * 60) + From 7a00e08286bd7afe24bc2b71cb7dace8962601ec Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 20 Nov 2025 15:48:11 +0100 Subject: [PATCH 156/242] fix: Initialize model param with none on init to fix tests --- src/strands/experimental/bidi/models/gemini_live.py | 4 ++-- src/strands/experimental/bidi/models/novasonic.py | 2 +- src/strands/experimental/bidi/models/openai.py | 6 ++++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 5f7eb587f..1b71ddb94 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -99,9 +99,9 @@ def __init__( self.client = genai.Client(**client_kwargs) # Connection state (initialized in start()) - self.live_session: Any + self.live_session: Any = None self.live_session_context_manager = None - self.connection_id | str + self.connection_id: str | None = None self._active: bool = False async def start( diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 8e83a4947..7ddc8f85b 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -103,7 +103,7 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e # Connection state (initialized in start()) self.stream: Any = None - self.connection_id: str = "" + self.connection_id: str | None = None self._active = False # Nova Sonic requires unique content names diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 3cda4f738..a0dae9237 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -104,8 +104,8 @@ def __init__( ) # Connection state (initialized in start()) - self.websocket: ClientConnection - self.connection_id: str + self.websocket: ClientConnection | None = None + self.connection_id: str | None = None self._active: bool = False self._function_call_buffer: dict[str, Any] = {} @@ -574,6 +574,8 @@ async def send( tool_result = content.get("tool_result") if tool_result: await self._send_tool_result(tool_result) + else: + logger.warning("Unknown content type: %s", type(content).__name__) except Exception as e: logger.error("error=<%s> | error sending content to openai", e) raise # Propagate exception for debugging in experimental code From d250c4c17ff49ed078ba94801d2782d56b912f0b Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 20 Nov 2025 16:00:13 +0100 Subject: [PATCH 157/242] feat: Add invocation state to bidi agent --- src/strands/experimental/bidi/agent/agent.py | 37 +++++++++++++++++--- src/strands/experimental/bidi/agent/loop.py | 10 +++++- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index c766fc634..cedebd66c 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -140,6 +140,9 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) + # Initialize invocation state (will be set in start()) + self._invocation_state: dict[str, Any] = {} + self._loop = _BidiAgentLoop(self) # Emit initialization event @@ -256,14 +259,31 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di properties = tool_spec["inputSchema"]["json"]["properties"] return {k: v for k, v in input_params.items() if k in properties} - async def start(self) -> None: + async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start a persistent bidirectional conversation connection. Initializes the streaming connection and starts background tasks for processing model events, tool execution, and connection management. + + Args: + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. + + Example: + ```python + await agent.start(invocation_state={ + "user_id": "user_123", + "session_id": "session_456", + "database": db_connection, + }) + ``` """ logger.debug("agent starting") + # Store invocation_state for use during tool execution + self._invocation_state = invocation_state or {} + await self._loop.start() async def send(self, input_data: BidiAgentInput) -> None: @@ -404,19 +424,28 @@ def active(self) -> bool: """True if agent loop started, False otherwise.""" return self._loop.active - async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: + async def run( + self, inputs: list[BidiInput], outputs: list[BidiOutput], invocation_state: dict[str, Any] | None = None + ) -> None: """Run the agent using provided IO channels for bidirectional communication. Args: inputs: Input callables to read data from a source outputs: Output callables to receive events from the agent + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. Example: ```python audio_io = BidiAudioIO(input_rate=16000) text_io = BidiTextIO() agent = BidiAgent(model=model, tools=[calculator]) - await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) + await agent.run( + inputs=[audio_io.input()], + outputs=[audio_io.output(), text_io.output()], + invocation_state={"user_id": "user_123"} + ) ``` """ @@ -434,7 +463,7 @@ async def run_outputs() -> None: tasks = [output(event) for output in outputs] await asyncio.gather(*tasks) - await self.start() + await self.start(invocation_state=invocation_state) for input_ in inputs: if hasattr(input_, "start"): diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 5200e8d0a..36e25f160 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -173,7 +173,15 @@ async def _run_tool(self, tool_use: ToolUse) -> None: logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) tool_results: list[ToolResult] = [] - invocation_state: dict[str, Any] = {"agent": self._agent} + + # Build invocation_state from stored state and current agent context + invocation_state: dict[str, Any] = { + **self._agent._invocation_state, # User-provided context + "agent": self._agent, # Always include agent reference + "model": self._agent.model, + "messages": self._agent.messages, + "system_prompt": self._agent.system_prompt, + } # Use the tool executor to run the tool (no tracing/metrics for BidiAgent yet) tool_events = self._agent.tool_executor._stream( From 516b43158fb795ba3f9e413d2e33a757b16d5fd2 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 20 Nov 2025 16:00:38 +0100 Subject: [PATCH 158/242] run formatter --- src/strands/experimental/bidi/agent/loop.py | 2 +- src/strands/tools/executors/_executor.py | 10 +++++----- tests_integ/bidi/test_bidirectional_agent.py | 1 - 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 36e25f160..935200070 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -173,7 +173,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) tool_results: list[ToolResult] = [] - + # Build invocation_state from stored state and current agent context invocation_state: dict[str, Any] = { **self._agent._invocation_state, # User-provided context diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 969be521b..562895038 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -33,12 +33,12 @@ class ToolExecutor(abc.ABC): @staticmethod def _is_bidi_agent(agent: Union["Agent", "BidiAgent"]) -> bool: """Check if the agent is a BidiAgent instance. - + Uses isinstance() with runtime import to avoid circular imports. """ # Import at runtime to avoid circular dependency from ...experimental.bidi.agent.agent import BidiAgent - + return isinstance(agent, BidiAgent) @staticmethod @@ -181,7 +181,7 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message ) @@ -213,7 +213,7 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( agent, selected_tool, tool_use, invocation_state, result ) @@ -259,7 +259,7 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( agent, selected_tool, tool_use, invocation_state, error_result, exception=e ) diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index ee6da01a5..ebc92c852 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -244,4 +244,3 @@ async def test_bidirectional_agent(agent_with_calculator, audio_generator, provi len(tool_calls), ) logger.info("=" * 60) - From 7254c283da0ae5c0586ee95db87ae2e22fd25751 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 20 Nov 2025 16:03:46 +0100 Subject: [PATCH 159/242] add hook utils for testing --- tests_integ/bidi/hook_utils.py | 76 ++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests_integ/bidi/hook_utils.py diff --git a/tests_integ/bidi/hook_utils.py b/tests_integ/bidi/hook_utils.py new file mode 100644 index 000000000..ea51a029e --- /dev/null +++ b/tests_integ/bidi/hook_utils.py @@ -0,0 +1,76 @@ +"""Shared utilities for testing BidiAgent hooks.""" + +from strands.experimental.hooks.events import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, +) +from strands.hooks import HookProvider + + +class HookEventCollector(HookProvider): + """Hook provider that collects all emitted events for testing.""" + + def __init__(self): + self.events = [] + + def register_hooks(self, registry): + registry.add_callback(BidiAgentInitializedEvent, self.on_initialized) + registry.add_callback(BidiBeforeInvocationEvent, self.on_before_invocation) + registry.add_callback(BidiAfterInvocationEvent, self.on_after_invocation) + registry.add_callback(BidiBeforeToolCallEvent, self.on_before_tool_call) + registry.add_callback(BidiAfterToolCallEvent, self.on_after_tool_call) + registry.add_callback(BidiMessageAddedEvent, self.on_message_added) + registry.add_callback(BidiInterruptionEvent, self.on_interruption) + + def on_initialized(self, event: BidiAgentInitializedEvent): + self.events.append(("initialized", event)) + + def on_before_invocation(self, event: BidiBeforeInvocationEvent): + self.events.append(("before_invocation", event)) + + def on_after_invocation(self, event: BidiAfterInvocationEvent): + self.events.append(("after_invocation", event)) + + def on_before_tool_call(self, event: BidiBeforeToolCallEvent): + self.events.append(("before_tool_call", event)) + + def on_after_tool_call(self, event: BidiAfterToolCallEvent): + self.events.append(("after_tool_call", event)) + + def on_message_added(self, event: BidiMessageAddedEvent): + self.events.append(("message_added", event)) + + def on_interruption(self, event: BidiInterruptionEvent): + self.events.append(("interruption", event)) + + def get_event_types(self): + """Get list of event type names in order.""" + return [event_type for event_type, _ in self.events] + + def get_events_by_type(self, event_type): + """Get all events of a specific type.""" + return [event for et, event in self.events if et == event_type] + + def get_tool_calls(self): + """Get list of tool names that were called.""" + before_calls = self.get_events_by_type("before_tool_call") + return [event.tool_use["name"] for event in before_calls] + + def verify_tool_execution(self): + """Verify that tool execution hooks were properly paired.""" + before_calls = self.get_events_by_type("before_tool_call") + after_calls = self.get_events_by_type("after_tool_call") + + assert len(before_calls) == len(after_calls), "Before and after tool call hooks should be paired" + + before_tools = [event.tool_use["name"] for event in before_calls] + after_tools = [event.tool_use["name"] for event in after_calls] + + assert before_tools == after_tools, "Tool call order should match between before and after hooks" + + return before_tools From dd666999e9853e0fab0dc8d22e36eb8b2b39e995 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 20 Nov 2025 16:20:29 +0100 Subject: [PATCH 160/242] test: Restore BidiAgent hook event tests in new location - Moved tests from tests/strands/experimental/bidi/hooks/ to tests/strands/experimental/hooks/ - Updated imports to use new hook location (strands.experimental.hooks) - Added test for new cancel_tool attribute in BidiBeforeToolCallEvent - All 7 tests passing --- .../hooks/test_bidi_hook_events.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 tests/strands/experimental/hooks/test_bidi_hook_events.py diff --git a/tests/strands/experimental/hooks/test_bidi_hook_events.py b/tests/strands/experimental/hooks/test_bidi_hook_events.py new file mode 100644 index 000000000..4d49243b2 --- /dev/null +++ b/tests/strands/experimental/hooks/test_bidi_hook_events.py @@ -0,0 +1,169 @@ +"""Unit tests for BidiAgent hook events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.hooks import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, +) +from strands.types.tools import ToolResult, ToolUse + + +@pytest.fixture +def agent(): + return Mock() + + +@pytest.fixture +def tool(): + tool = Mock() + tool.tool_name = "test_tool" + return tool + + +@pytest.fixture +def tool_use(): + return ToolUse(name="test_tool", toolUseId="123", input={"param": "value"}) + + +@pytest.fixture +def tool_invocation_state(): + return {"param": "value"} + + +@pytest.fixture +def tool_result(): + return ToolResult(content=[{"text": "result"}], status="success", toolUseId="123") + + +@pytest.fixture +def message(): + return {"role": "user", "content": [{"text": "Hello"}]} + + +@pytest.fixture +def initialized_event(agent): + return BidiAgentInitializedEvent(agent=agent) + + +@pytest.fixture +def before_invocation_event(agent): + return BidiBeforeInvocationEvent(agent=agent) + + +@pytest.fixture +def after_invocation_event(agent): + return BidiAfterInvocationEvent(agent=agent) + + +@pytest.fixture +def message_added_event(agent, message): + return BidiMessageAddedEvent(agent=agent, message=message) + + +@pytest.fixture +def before_tool_event(agent, tool, tool_use, tool_invocation_state): + return BidiBeforeToolCallEvent( + agent=agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=tool_invocation_state, + ) + + +@pytest.fixture +def after_tool_event(agent, tool, tool_use, tool_invocation_state, tool_result): + return BidiAfterToolCallEvent( + agent=agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=tool_invocation_state, + result=tool_result, + ) + + +@pytest.fixture +def interruption_event(agent): + return BidiInterruptionEvent(agent=agent, reason="user_speech") + + +def test_event_should_reverse_callbacks( + initialized_event, + before_invocation_event, + after_invocation_event, + message_added_event, + before_tool_event, + after_tool_event, + interruption_event, +): + """Verify which events use reverse callback ordering.""" + # note that we ignore E712 (explicit booleans) for consistency/readability purposes + + assert initialized_event.should_reverse_callbacks == False # noqa: E712 + assert message_added_event.should_reverse_callbacks == False # noqa: E712 + assert interruption_event.should_reverse_callbacks == False # noqa: E712 + + assert before_invocation_event.should_reverse_callbacks == False # noqa: E712 + assert after_invocation_event.should_reverse_callbacks == True # noqa: E712 + + assert before_tool_event.should_reverse_callbacks == False # noqa: E712 + assert after_tool_event.should_reverse_callbacks == True # noqa: E712 + + +def test_interruption_event_with_response_id(agent): + """Verify BidiInterruptionEvent can include response ID.""" + event = BidiInterruptionEvent(agent=agent, reason="error", interrupted_response_id="resp_123") + + assert event.reason == "error" + assert event.interrupted_response_id == "resp_123" + + +def test_message_added_event_cannot_write_properties(message_added_event): + """Verify BidiMessageAddedEvent properties are read-only.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + message_added_event.agent = Mock() + with pytest.raises(AttributeError, match="Property message is not writable"): + message_added_event.message = {} + + +def test_before_tool_call_event_can_write_properties(before_tool_event): + """Verify BidiBeforeToolCallEvent allows writing specific properties.""" + new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) + before_tool_event.selected_tool = None # Should not raise + before_tool_event.tool_use = new_tool_use # Should not raise + before_tool_event.cancel_tool = "Cancelled by user" # Should not raise + + +def test_before_tool_call_event_cannot_write_properties(before_tool_event): + """Verify BidiBeforeToolCallEvent protects certain properties.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + before_tool_event.agent = Mock() + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + before_tool_event.invocation_state = {} + + +def test_after_tool_call_event_can_write_properties(after_tool_event): + """Verify BidiAfterToolCallEvent allows writing result property.""" + new_result = ToolResult(content=[{"text": "new result"}], status="success", toolUseId="456") + after_tool_event.result = new_result # Should not raise + + +def test_after_tool_call_event_cannot_write_properties(after_tool_event): + """Verify BidiAfterToolCallEvent protects certain properties.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + after_tool_event.agent = Mock() + with pytest.raises(AttributeError, match="Property selected_tool is not writable"): + after_tool_event.selected_tool = None + with pytest.raises(AttributeError, match="Property tool_use is not writable"): + after_tool_event.tool_use = ToolUse(name="new", toolUseId="456", input={}) + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + after_tool_event.invocation_state = {} + with pytest.raises(AttributeError, match="Property exception is not writable"): + after_tool_event.exception = Exception("test") From 8d9b7b98ada0f8ba4a9dabb3fd11411e67f96711 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 20 Nov 2025 16:27:15 +0100 Subject: [PATCH 161/242] refactor: Use type name check instead of runtime import for BidiAgent detection - Changed from isinstance() with runtime import to type(agent).__name__ check - Avoids circular import issues without runtime overhead - Simpler and more explicit than duck typing - All tests passing --- src/strands/tools/executors/_executor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 562895038..467d22936 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -32,14 +32,12 @@ class ToolExecutor(abc.ABC): @staticmethod def _is_bidi_agent(agent: Union["Agent", "BidiAgent"]) -> bool: - """Check if the agent is a BidiAgent instance. + """Check if the agent is a BidiAgent by type name. - Uses isinstance() with runtime import to avoid circular imports. + Uses type name comparison to avoid circular imports while maintaining + type safety. This works because we control both Agent and BidiAgent types. """ - # Import at runtime to avoid circular dependency - from ...experimental.bidi.agent.agent import BidiAgent - - return isinstance(agent, BidiAgent) + return type(agent).__name__ == "BidiAgent" @staticmethod async def _invoke_before_tool_call_hook( From e1886048dd916988dd32367983f4923e711ea1d5 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 20 Nov 2025 11:15:50 -0500 Subject: [PATCH 162/242] cancellation - nova sonic (#64) --- src/strands/experimental/bidi/_async.py | 50 +++ .../experimental/bidi/models/bidi_model.py | 3 +- .../experimental/bidi/models/novasonic.py | 374 +++++++----------- src/strands/experimental/bidi/types/_async.py | 15 + .../bidi/models/test_novasonic.py | 244 +++++------- tests/strands/experimental/bidi/test_async.py | 57 +++ 6 files changed, 354 insertions(+), 389 deletions(-) create mode 100644 src/strands/experimental/bidi/_async.py create mode 100644 src/strands/experimental/bidi/types/_async.py create mode 100644 tests/strands/experimental/bidi/test_async.py diff --git a/src/strands/experimental/bidi/_async.py b/src/strands/experimental/bidi/_async.py new file mode 100644 index 000000000..a4a126c16 --- /dev/null +++ b/src/strands/experimental/bidi/_async.py @@ -0,0 +1,50 @@ +"""Utilities for async operations.""" + +from functools import wraps +from typing import Any, Awaitable, Callable, TypeVar, cast + +from .types._async import Startable + +F = TypeVar("F", bound=Callable[..., Awaitable[None]]) + + +def start(func: F) -> F: + """Call stop if pairing start call fails. + + Any resources that did successfully start will still have an opportunity to stop cleanly. + + Args: + func: Start function to wrap. + """ + + @wraps(func) + async def wrapper(self: Startable, *args: Any, **kwargs: Any) -> None: + try: + await func(self, *args, **kwargs) + except Exception: + await self.stop() + raise + + return cast(F, wrapper) + + +async def stop(*funcs: F) -> None: + """Call all stops in sequence and aggregate errors. + + A failure in one stop call will not block subsequent stop calls. + + Args: + funcs: Stop functions to call in sequence. + + Raises: + ExceptionGroup: If any stop function raises an exception. + """ + exceptions = [] + for func in funcs: + try: + await func() + except Exception as exception: + exceptions.append(exception) + + if exceptions: + raise ExceptionGroup("failed stop sequence", exceptions) # type: ignore # noqa: F821 diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 04e4a69e4..856a2edcf 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -18,6 +18,7 @@ from ....types._events import ToolResultEvent from ....types.content import Messages from ....types.tools import ToolSpec +from ..types._async import Startable from ..types.events import ( BidiInputEvent, BidiOutputEvent, @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) -class BidiModel(Protocol): +class BidiModel(Startable, Protocol): """Protocol for bidirectional streaming models. This interface defines the contract for models that support persistent streaming diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 8e83a4947..ccaa90dc8 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -16,7 +16,6 @@ import base64 import json import logging -import traceback import uuid from typing import Any, AsyncIterable @@ -27,17 +26,17 @@ InvokeModelWithBidirectionalStreamInputChunk, ) from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver +from smithy_core.aio.eventstream import DuplexEventStream from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import start, stop from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, BidiConnectionCloseEvent, BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, BidiOutputEvent, @@ -76,9 +75,6 @@ NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} NOVA_TOOL_CONFIG = {"mediaType": "application/json"} -# Timing constants -RESPONSE_TIMEOUT = 1.0 - class BidiNovaSonicModel(BidiModel): """Nova Sonic implementation for bidirectional streaming. @@ -86,8 +82,13 @@ class BidiNovaSonicModel(BidiModel): Combines model configuration and connection state in a single class. Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidiModel interface. + + Attributes: + _stream: open bedrock stream to nova sonic. """ + _stream: DuplexEventStream + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **kwargs: Any) -> None: """Initialize Nova Sonic bidirectional model. @@ -96,25 +97,15 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e region: AWS region. **kwargs: Reserved for future parameters. """ - # Model configuration self.model_id = model_id self.region = region - self.client: Any = None - - # Connection state (initialized in start()) - self.stream: Any = None - self.connection_id: str = "" - self._active = False - - # Nova Sonic requires unique content names - self.audio_content_name: str | None = None - - # Audio connection state - self.audio_connection_active = False # Track API-provided identifiers + self._connection_id: str | None = None + self._audio_content_name: str | None = None self._current_completion_id: str | None = None - self._current_role: str | None = None + + # Indicates if model is done generating transcript self._generation_stage: str | None = None # Ensure certain events are sent in sequence when required @@ -122,6 +113,7 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e logger.debug("model_id=<%s> | nova sonic model initialized", model_id) + @start async def start( self, system_prompt: str | None = None, @@ -136,65 +128,52 @@ async def start( tools: List of tools available to the model. messages: Conversation history to initialize with. **kwargs: Additional configuration options. + + Raises: + RuntimeError: If user calls start again without first stopping. """ - if self._active: - raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + if self._connection_id: + raise RuntimeError("call stop before starting again") logger.debug("nova connection starting") - try: - # Initialize client if needed - if not self.client: - await self._initialize_client() - - # Initialize connection state - self.connection_id = str(uuid.uuid4()) - self._active = True - self.audio_content_name = str(uuid.uuid4()) - - # Start Nova Sonic bidirectional stream - self.stream = await self.client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ) - - # Validate stream - if not self.stream: - logger.error("Stream is None") - raise ValueError("Stream cannot be None") - - logger.debug("connection_id=<%s> | nova sonic connection initialized", self.connection_id) + self._connection_id = str(uuid.uuid4()) - # Send initialization events - system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." - init_events = self._build_initialization_events(system_prompt, tools or [], messages) - - logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events)) - await self._send_initialization_events(init_events) + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, + ) + client = BedrockRuntimeClient(config=config) + self._stream = await client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + logger.debug("region=<%s> | nova sonic client initialized", self.region) - logger.info("connection_id=<%s> | nova sonic connection established", self.connection_id) + init_events = self._build_initialization_events(system_prompt, tools, messages) + logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events)) + await self._send_nova_events(init_events) - except Exception as e: - self._active = False - logger.error("error=<%s> | nova connection create failed", str(e)) - raise + logger.info("connection_id=<%s> | nova sonic connection established", self._connection_id) def _build_initialization_events( - self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None = None + self, system_prompt: str | None, tools: list[ToolSpec] | None, messages: Messages | None ) -> list[str]: """Build the sequence of initialization events.""" - events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] - - events.extend(self._get_system_prompt_events(system_prompt)) + tools = tools or [] + events = [ + self._get_connection_start_event(), + self._get_prompt_start_event(tools), + *self._get_system_prompt_events(system_prompt), + ] - # Message history would be processed here if needed in the future + # TODO: Message history would be processed here if needed in the future # Currently not implemented as it's not used in the existing test cases return events - async def _send_initialization_events(self, events: list[str]) -> None: - """Send initialization events.""" - await self._send_nova_event(events) - def _log_event_type(self, nova_event: dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: @@ -214,86 +193,66 @@ def _log_event_type(self, nova_event: dict[str, Any]) -> None: logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore[override] - """Receive Nova Sonic events and convert to provider-agnostic format.""" - if not self.stream: - logger.error("Stream is None") - return + """Receive Nova Sonic events and convert to provider-agnostic format. - logger.debug("nova event stream starting") + Raises: + RuntimeError: If start has not been called. + """ + if not self._connection_id: + raise RuntimeError("must call start") - # Emit connection start event - yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model_id) + logger.debug("nova event stream starting") + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) try: - while self._active and self.stream: - try: - output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) - result = await output[1].receive() - - response_data = result.value.bytes_.decode("utf-8") - json_data = json.loads(response_data) - nova_event = json_data["event"] - self._log_event_type(nova_event) - - # Convert to provider-agnostic format - provider_event = self._convert_nova_event(nova_event) - if provider_event: - yield provider_event - - except asyncio.TimeoutError: - continue - - except Exception as e: - logger.error("error=<%s> | error receiving nova sonic event", e) - logger.error(traceback.format_exc()) - yield BidiErrorEvent(error=e) + _, output = await self._stream.await_output() + while True: + event_data = await output.receive() + nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + self._log_event_type(nova_event) + + model_event = self._convert_nova_event(nova_event) + if model_event: + yield model_event finally: - # Emit connection close event - yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete") - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: + async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). - """ - if not self._active: - return + content: Input event. - try: - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, BidiImageInputEvent): - # BidiImageInputEvent - not supported by Nova Sonic - logger.warning("Image input not supported by Nova Sonic") - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - except Exception as e: - logger.error("error=<%s> | error sending content to nova sonic", e) - raise # Propagate exception for debugging in experimental code + Raises: + ValueError: If content type not supported (e.g., image content). + """ + if not self._connection_id: + raise RuntimeError("must call start") + + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported by nova sonic") async def _start_audio_connection(self) -> None: """Internal: Start audio input connection (call once before sending audio chunks).""" - if self.audio_connection_active: - return - logger.debug("nova audio connection starting") + self._audio_content_name = str(uuid.uuid4()) audio_content_start = json.dumps( { "event": { "contentStart": { - "promptName": self.connection_id, - "contentName": self.audio_content_name, + "promptName": self._connection_id, + "contentName": self._audio_content_name, "type": "AUDIO", "interactive": True, "role": "USER", @@ -303,13 +262,12 @@ async def _start_audio_connection(self) -> None: } ) - await self._send_nova_event([audio_content_start]) - self.audio_connection_active = True + await self._send_nova_events([audio_content_start]) async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: """Internal: Send audio using Nova Sonic protocol-specific format.""" # Start audio connection if not already active - if not self.audio_connection_active: + if not self._audio_content_name: await self._start_audio_connection() # Audio is already base64 encoded in the event @@ -318,29 +276,29 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: { "event": { "audioInput": { - "promptName": self.connection_id, - "contentName": self.audio_content_name, + "promptName": self._connection_id, + "contentName": self._audio_content_name, "content": audio_input.audio, } } } ) - await self._send_nova_event([audio_event]) + await self._send_nova_events([audio_event]) async def _end_audio_input(self) -> None: """Internal: End current audio input connection to trigger Nova Sonic processing.""" - if not self.audio_connection_active: + if not self._audio_content_name: return logger.debug("nova audio connection ending") audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self.connection_id, "contentName": self.audio_content_name}}} + {"event": {"contentEnd": {"promptName": self._connection_id, "contentName": self._audio_content_name}}} ) - await self._send_nova_event([audio_content_end]) - self.audio_connection_active = False + await self._send_nova_events([audio_content_end]) + self._audio_content_name = None async def _send_text_content(self, text: str) -> None: """Internal: Send text content using Nova Sonic format.""" @@ -350,34 +308,15 @@ async def _send_text_content(self, text: str) -> None: self._get_text_input_event(content_name, text), self._get_content_end_event(content_name), ] - await self._send_nova_event(events) - - async def _send_interrupt(self) -> None: - """Internal: Send interruption signal to Nova Sonic.""" - # Nova Sonic handles interruption through special input events - interrupt_event = json.dumps( - { - "event": { - "audioInput": { - "promptName": self.connection_id, - "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED", - } - } - } - ) - await self._send_nova_event([interrupt_event]) + await self._send_nova_events(events) async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Nova Sonic toolResult format.""" - tool_use_id = tool_result.get("toolUseId") - if not tool_use_id: - logger.error("tool result missing toolUseId") - return + tool_use_id = tool_result["toolUseId"] logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) - # Extract result content + # TODO: We need to extract all content and content types result_data = {} if "content" in tool_result: # Extract text from content blocks @@ -392,39 +331,32 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: self._get_tool_result_event(content_name, result_data), self._get_content_end_event(content_name), ] - await self._send_nova_event(events) + await self._send_nova_events(events) async def stop(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" - if not self._active: - return - logger.debug("nova connection cleanup starting") - self._active = False - try: - # End audio connection if active - if self.audio_connection_active: - await self._end_audio_input() + async def stop_events() -> None: + if not self._connection_id: + return - # Send cleanup events + await self._end_audio_input() cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] - try: - await self._send_nova_event(cleanup_events) - except Exception as e: - logger.warning("error=<%s> | error during nova sonic cleanup", e) - - # Close stream - if self.stream: - try: - await self.stream.input_stream.close() - except Exception as e: - logger.warning("error=<%s> | error closing nova sonic stream", e) - - except Exception as e: - logger.error("error=<%s> | nova cleanup failed", str(e)) - finally: - logger.debug("nova connection closed") + await self._send_nova_events(cleanup_events) + + async def stop_stream() -> None: + if not self._connection_id or not self._stream: + return + + await self._stream.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop(stop_events, stop_stream, stop_connection) + + logger.debug("nova connection closed") def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | None: """Convert Nova Sonic events to TypedEvent format.""" @@ -472,7 +404,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N return BidiTranscriptStreamEvent( delta={"text": text_content}, text=text_content, - role=self._current_role.lower() if self._current_role else "assistant", # type: ignore + role="assistant", is_final=self._generation_stage == "FINAL", current_transcript=text_content, ) @@ -505,13 +437,9 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N total_tokens=usage_data.get("totalTokens", total_input + total_output), ) - # Handle content start events (track role and emit response start) + # Handle content start events (emit response start) if "contentStart" in nova_event: content_data = nova_event["contentStart"] - role = content_data.get("role", "unknown") - # Store role for subsequent text output events - self._current_role = role - if content_data["type"] == "TEXT": self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] @@ -521,10 +449,12 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing ) - # Ignore other events (contentEnd, etc.) + if "contentEnd" in nova_event: + self._generation_stage = None + + # Ignore all other events return None - # Nova Sonic event template methods def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) @@ -534,7 +464,7 @@ def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: prompt_start_event: dict[str, Any] = { "event": { "promptStart": { - "promptName": self.connection_id, + "promptName": self._connection_id, "textOutputConfiguration": NOVA_TEXT_CONFIG, "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } @@ -563,12 +493,12 @@ def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict[str, Any ) return tool_config - def _get_system_prompt_events(self, system_prompt: str) -> list[str]: + def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ self._get_text_content_start_event(content_name, "SYSTEM"), - self._get_text_input_event(content_name, system_prompt), + self._get_text_input_event(content_name, system_prompt or ""), self._get_content_end_event(content_name), ] @@ -578,7 +508,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - { "event": { "contentStart": { - "promptName": self.connection_id, + "promptName": self._connection_id, "contentName": content_name, "type": "TEXT", "role": role, @@ -595,7 +525,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> { "event": { "contentStart": { - "promptName": self.connection_id, + "promptName": self._connection_id, "contentName": content_name, "interactive": False, "type": "TOOL", @@ -613,7 +543,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" return json.dumps( - {"event": {"textInput": {"promptName": self.connection_id, "contentName": content_name, "content": text}}} + {"event": {"textInput": {"promptName": self._connection_id, "contentName": content_name, "content": text}}} ) def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> str: @@ -622,7 +552,7 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> s { "event": { "toolResult": { - "promptName": self.connection_id, + "promptName": self._connection_id, "contentName": content_name, "content": json.dumps(result), } @@ -632,59 +562,29 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> s def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self.connection_id, "contentName": content_name}}}) + return json.dumps({"event": {"contentEnd": {"promptName": self._connection_id, "contentName": content_name}}}) def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self.connection_id}}}) + return json.dumps({"event": {"promptEnd": {"promptName": self._connection_id}}}) def _get_connection_end_event(self) -> str: """Generate connection end event.""" return json.dumps({"event": {"connectionEnd": {}}}) - async def _send_nova_event(self, events: list[str]) -> None: + async def _send_nova_events(self, events: list[str]) -> None: """Send event JSON string to Nova Sonic stream. A lock is used to send events in sequence when required (e.g., tool result start, content, and end). Args: - events: Jsonified event. + events: Jsonified events. """ - if not self.stream: - logger.error("cannot send event: stream is None") - return - - try: - async with self._send_lock: - for event in events: - bytes_data = event.encode("utf-8") - chunk = InvokeModelWithBidirectionalStreamInputChunk( - value=BidirectionalInputPayloadPart(bytes_=bytes_data) - ) - await self.stream.input_stream.send(chunk) - logger.debug("nova sonic event sent successfully") - - except Exception as e: - logger.error("error=<%s>, event=<%s> | error sending nova sonic event", e, event[:100]) - raise - - async def _initialize_client(self) -> None: - """Initialize Nova Sonic client.""" - try: - config = Config( - endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", - region=self.region, - aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), - auth_scheme_resolver=HTTPAuthSchemeResolver(), - auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, - ) - - self.client = BedrockRuntimeClient(config=config) - logger.debug("region=<%s> | nova sonic client initialized", self.region) - - except ImportError as e: - logger.error("error=<%s> | nova sonic dependencies not available", e) - raise - except Exception as e: - logger.error("error=<%s> | error initializing nova sonic client", e) - raise + async with self._send_lock: + for event in events: + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self._stream.input_stream.send(chunk) + logger.debug("nova sonic event sent successfully") diff --git a/src/strands/experimental/bidi/types/_async.py b/src/strands/experimental/bidi/types/_async.py new file mode 100644 index 000000000..0d4309aff --- /dev/null +++ b/src/strands/experimental/bidi/types/_async.py @@ -0,0 +1,15 @@ +"""Types for custom async constructs.""" + +from typing import Any, Awaitable, Protocol + + +class Startable(Protocol): + """A construct that must first be started before use.""" + + def start(self, *args: Any, **kwargs: Any) -> Awaitable[None]: + """Setup resources and start connections.""" + ... + + def stop(self) -> Awaitable[None]: + """Tear down resources and stop connections.""" + ... diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index e0459fd51..49c6ec8f7 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -56,19 +56,21 @@ def mock_stream(): @pytest.fixture def mock_client(mock_stream): """Mock Bedrock Runtime client.""" - client = AsyncMock() - client.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) - return client + with patch("strands.experimental.bidi.models.novasonic.BedrockRuntimeClient") as mock_cls: + mock_instance = AsyncMock() + mock_instance.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) + mock_cls.return_value = mock_instance + + yield mock_instance @pytest_asyncio.fixture -async def nova_model(model_id, region): +def nova_model(model_id, region, mock_client): """Create Nova Sonic model instance.""" + _ = mock_client + model = BidiNovaSonicModel(model_id=model_id, region=region) yield model - # Cleanup - if model._active: - await model.stop() # Initialization and Connection Tests @@ -81,150 +83,96 @@ async def test_model_initialization(model_id, region): assert model.model_id == model_id assert model.region == region - assert model.stream is None - assert not model._active - assert model.connection_id is None + assert model._connection_id is None @pytest.mark.asyncio async def test_connection_lifecycle(nova_model, mock_client, mock_stream): """Test complete connection lifecycle with various configurations.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Test basic connection - await nova_model.start(system_prompt="Test system prompt") - assert nova_model._active - assert nova_model.stream == mock_stream - assert nova_model.connection_id is not None - assert mock_client.invoke_model_with_bidirectional_stream.called - - # Test close - await nova_model.stop() - assert not nova_model._active - assert mock_stream.input_stream.close.called - - # Test connection with tools - tools = [ - { - "name": "get_weather", - "description": "Get weather information", - "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})}, - } - ] - await nova_model.start(system_prompt="You are helpful", tools=tools) - # Verify initialization events were sent (connectionStart, promptStart, system prompt) - assert mock_stream.input_stream.send.call_count >= 3 - await nova_model.stop() + # Test basic connection + await nova_model.start(system_prompt="Test system prompt") + assert nova_model._stream == mock_stream + assert nova_model._connection_id is not None + assert mock_client.invoke_model_with_bidirectional_stream.called -@pytest.mark.asyncio -async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): - """Test connection error handling and edge cases.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client + # Test close + await nova_model.stop() + assert mock_stream.close.called - # Test double connection - await nova_model.start() - with pytest.raises(RuntimeError, match="Connection already active"): - await nova_model.start() - await nova_model.stop() + # Test connection with tools + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})}, + } + ] + await nova_model.start(system_prompt="You are helpful", tools=tools) + # Verify initialization events were sent (connectionStart, promptStart, system prompt) + assert mock_stream.input_stream.send.call_count >= 3 + await nova_model.stop() - # Test close when already closed - model2 = BidiNovaSonicModel(model_id=model_id, region=region) - await model2.stop() # Should not raise - await model2.stop() # Second call should also be safe + +@pytest.mark.asyncio +async def test_model_stop_alone(nova_model): + await nova_model.stop() # Should not raise # Send Method Tests @pytest.mark.asyncio -async def test_send_all_content_types(nova_model, mock_client, mock_stream): +async def test_send_all_content_types(nova_model, mock_stream): """Test sending all content types through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - await nova_model.start() - - # Test text content - text_event = BidiTextInputEvent(text="Hello, Nova!", role="user") - await nova_model.send(text_event) - # Should send contentStart, textInput, and contentEnd - assert mock_stream.input_stream.send.call_count >= 3 - - # Test audio content (base64 encoded) - audio_b64 = base64.b64encode(b"audio data").decode("utf-8") - audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) - await nova_model.send(audio_event) - # Should start audio connection and send audio - assert nova_model.audio_connection_active - assert mock_stream.input_stream.send.called - - # Test tool result - tool_result: ToolResult = { - "toolUseId": "tool-123", - "status": "success", - "content": [{"text": "Weather is sunny"}], - } - await nova_model.send(ToolResultEvent(tool_result)) - # Should send contentStart, toolResult, and contentEnd - assert mock_stream.input_stream.send.called + await nova_model.start() + + # Test text content + text_event = BidiTextInputEvent(text="Hello, Nova!", role="user") + await nova_model.send(text_event) + # Should send contentStart, textInput, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + # Test audio content (base64 encoded) + audio_b64 = base64.b64encode(b"audio data").decode("utf-8") + audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) + await nova_model.send(audio_event) + # Should start audio connection and send audio + assert nova_model._audio_content_name + assert mock_stream.input_stream.send.called + + # Test tool result + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Weather is sunny"}], + } + await nova_model.send(ToolResultEvent(tool_result)) + # Should send contentStart, toolResult, and contentEnd + assert mock_stream.input_stream.send.called - await nova_model.stop() + await nova_model.stop() @pytest.mark.asyncio -async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): +async def test_send_edge_cases(nova_model): """Test send() edge cases and error handling.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Test send when inactive - text_event = BidiTextInputEvent(text="Hello", role="user") - await nova_model.send(text_event) # Should not raise - - # Test image content (not supported, base64 encoded, no encoding parameter) - await nova_model.start() - image_b64 = base64.b64encode(b"image data").decode("utf-8") - image_event = BidiImageInputEvent( - image=image_b64, - mime_type="image/jpeg", - ) - await nova_model.send(image_event) - # Should log warning about unsupported image input - assert any("not supported" in record.message.lower() for record in caplog.records) - await nova_model.stop() - - -# Receive and Event Conversion Tests + # Test image content (not supported, base64 encoded, no encoding parameter) + await nova_model.start() + image_b64 = base64.b64encode(b"image data").decode("utf-8") + image_event = BidiImageInputEvent( + image=image_b64, + mime_type="image/jpeg", + ) + with pytest.raises(ValueError, match=r"content not supported by nova sonic"): + await nova_model.send(image_event) -@pytest.mark.asyncio -async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): - """Test that receive() emits connection start and end events.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Setup mock to return no events and then stop - async def mock_wait_for(*args, **kwargs): - await asyncio.sleep(0.1) - nova_model._active = False - raise asyncio.TimeoutError() - - with patch("asyncio.wait_for", side_effect=mock_wait_for): - await nova_model.start() + await nova_model.stop() - events = [] - async for event in nova_model.receive(): - events.append(event) - # Should have session start and end (new TypedEvent format) - assert len(events) >= 2 - assert events[0].get("type") == "bidi_connection_start" - assert events[0].get("connection_id") == nova_model.connection_id - assert events[-1].get("type") == "bidi_connection_close" +# Receive and Event Conversion Tests @pytest.mark.asyncio @@ -306,7 +254,6 @@ async def test_event_conversion(nova_model): assert result is not None assert isinstance(result, BidiResponseStartEvent) assert result.get("type") == "bidi_response_start" - assert nova_model._current_role == "ASSISTANT" assert nova_model._generation_stage == "FINAL" # Test AUDIO type contentStart (no additionalModelFields) @@ -314,7 +261,6 @@ async def test_event_conversion(nova_model): result = nova_model._convert_nova_event(nova_event) assert result is not None assert isinstance(result, BidiResponseStartEvent) - assert nova_model._current_role == "ASSISTANT" # Test TOOL type contentStart nova_event = {"contentStart": {"role": "TOOL", "type": "TOOL", "contentId": "content-789"}} @@ -327,22 +273,20 @@ async def test_event_conversion(nova_model): @pytest.mark.asyncio -async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): +async def test_audio_connection_lifecycle(nova_model): """Test audio connection start and end lifecycle.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - await nova_model.start() + await nova_model.start() - # Start audio connection - await nova_model._start_audio_connection() - assert nova_model.audio_connection_active + # Start audio connection + await nova_model._start_audio_connection() + assert nova_model._audio_content_name - # End audio connection - await nova_model._end_audio_input() - assert not nova_model.audio_connection_active + # End audio connection + await nova_model._end_audio_input() + assert not nova_model._audio_content_name - await nova_model.stop() + await nova_model.stop() # Helper Method Tests @@ -378,7 +322,7 @@ async def test_event_templates(nova_model): assert "inferenceConfiguration" in event["event"]["sessionStart"] # Test prompt start event - nova_model.connection_id = "test-connection" + nova_model._connection_id = "test-connection" event_json = nova_model._get_prompt_start_event([]) event = json.loads(event_json) assert "event" in event @@ -406,21 +350,19 @@ async def test_event_templates(nova_model): @pytest.mark.asyncio -async def test_error_handling(nova_model, mock_client, mock_stream): +async def test_error_handling(nova_model, mock_stream): """Test error handling in various scenarios.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - # Test response processor handles errors gracefully - async def mock_error(*args, **kwargs): - raise Exception("Test error") + # Test response processor handles errors gracefully + async def mock_error(*args, **kwargs): + raise Exception("Test error") - mock_stream.await_output.side_effect = mock_error + mock_stream.await_output.side_effect = mock_error - await nova_model.start() + await nova_model.start() - # Wait a bit for response processor to handle error - await asyncio.sleep(0.1) + # Wait a bit for response processor to handle error + await asyncio.sleep(0.1) - # Should still be able to close cleanly - await nova_model.stop() + # Should still be able to close cleanly + await nova_model.stop() diff --git a/tests/strands/experimental/bidi/test_async.py b/tests/strands/experimental/bidi/test_async.py new file mode 100644 index 000000000..10b3016df --- /dev/null +++ b/tests/strands/experimental/bidi/test_async.py @@ -0,0 +1,57 @@ +from unittest.mock import AsyncMock, Mock + +import pytest + +from strands.experimental.bidi._async import start, stop + + +@pytest.fixture +def mock_startable(): + return Mock(start=AsyncMock(), stop=AsyncMock()) + + +@pytest.mark.asyncio +async def test_start_exception(mock_startable): + mock_startable.start.side_effect = ValueError("start failed") + + with pytest.raises(ValueError, match=r"start failed"): + await start(mock_startable.start)(mock_startable) + + mock_startable.stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_start_success(mock_startable): + await start(mock_startable.start)(mock_startable) + mock_startable.stop.assert_not_called() + + +@pytest.mark.asyncio +async def test_stop_exception(): + func1 = AsyncMock() + func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) + func3 = AsyncMock() + + with pytest.raises(ExceptionGroup) as exc_info: # type: ignore # noqa: F821 + await stop(func1, func2, func3) + + func1.assert_called_once() + func2.assert_called_once() + func3.assert_called_once() + + assert len(exc_info.value.exceptions) == 1 + with pytest.raises(ValueError, match=r"stop 2 failed"): + raise exc_info.value.exceptions[0] + + +@pytest.mark.asyncio +async def test_stop_success(): + func1 = AsyncMock() + func2 = AsyncMock() + func3 = AsyncMock() + + await stop(func1, func2, func3) + + func1.assert_called_once() + func2.assert_called_once() + func3.assert_called_once() From 16db749e7c8f0de88abca5e21ae4177fda1c22a9 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 20 Nov 2025 13:06:23 -0500 Subject: [PATCH 163/242] Update credentials resolver logic in NovaSonic --- .../experimental/bidi/models/novasonic.py | 50 +++++++++++++++++-- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index ccaa90dc8..d678428df 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -19,13 +19,14 @@ import uuid from typing import Any, AsyncIterable +import boto3 from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme from aws_sdk_bedrock_runtime.models import ( BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, ) -from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver +from smithy_aws_core.identity.static import StaticCredentialsResolver from smithy_core.aio.eventstream import DuplexEventStream from ....types._events import ToolResultEvent, ToolUseStreamEvent @@ -89,16 +90,32 @@ class BidiNovaSonicModel(BidiModel): _stream: DuplexEventStream - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **kwargs: Any) -> None: + def __init__( + self, + model_id: str = "amazon.nova-sonic-v1:0", + boto_session: boto3.Session | None = None, + region: str | None = None, + **kwargs: Any, + ) -> None: """Initialize Nova Sonic bidirectional model. Args: model_id: Nova Sonic model identifier. - region: AWS region. + boto_session: Boto Session to use when calling the Nova Sonic Model. + region_name: AWS region to use for the Nova Sonic service. + Defaults to the AWS_REGION environment variable if set, or "us-east-1" if not set. **kwargs: Reserved for future parameters. """ + if region and boto_session: + raise ValueError("Cannot specify both `region_name` and `boto_session`.") + + # Create session and resolve region + self._session = boto_session or boto3.Session() + resolved_region = region or self._session.region_name or "us-east-1" + + # Model configuration self.model_id = model_id - self.region = region + self.region = resolved_region # Track API-provided identifiers self._connection_id: str | None = None @@ -139,13 +156,36 @@ async def start( self._connection_id = str(uuid.uuid4()) + # Get credentials from boto3 session (full credential chain) + credentials = self._session.get_credentials() + + if not credentials: + raise RuntimeError( + "No AWS credentials found. Configure credentials via environment variables, " + "credential files, IAM roles, or SSO." + ) + + # Use static resolver with credentials configured as properties + resolver = StaticCredentialsResolver() + + print("🔍 CREDENTIAL DEBUG: Creating Config with credentials as properties...") config = Config( endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, - aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + aws_credentials_identity_resolver=resolver, auth_scheme_resolver=HTTPAuthSchemeResolver(), auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, + # Configure static credentials as properties + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token, ) + + print("🔍 CREDENTIAL DEBUG: Creating BedrockRuntimeClient...") + self.client = BedrockRuntimeClient(config=config) + print("✅ CREDENTIAL DEBUG: Nova Sonic client initialized successfully!") + logger.debug("region=<%s> | nova sonic client initialized", self.region) + client = BedrockRuntimeClient(config=config) self._stream = await client.invoke_model_with_bidirectional_stream( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) From 7e6148021d7c25f35ca86a7194169ab3a9468988 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 20 Nov 2025 13:09:39 -0500 Subject: [PATCH 164/242] minor update --- src/strands/experimental/bidi/models/novasonic.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index d678428df..eef6110d3 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -102,8 +102,7 @@ def __init__( Args: model_id: Nova Sonic model identifier. boto_session: Boto Session to use when calling the Nova Sonic Model. - region_name: AWS region to use for the Nova Sonic service. - Defaults to the AWS_REGION environment variable if set, or "us-east-1" if not set. + region: AWS region **kwargs: Reserved for future parameters. """ if region and boto_session: From bf06138d319f9c9200c99675518312c7373533fb Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 20 Nov 2025 13:10:43 -0500 Subject: [PATCH 165/242] minor update --- src/strands/experimental/bidi/models/novasonic.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index eef6110d3..315918169 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -167,7 +167,6 @@ async def start( # Use static resolver with credentials configured as properties resolver = StaticCredentialsResolver() - print("🔍 CREDENTIAL DEBUG: Creating Config with credentials as properties...") config = Config( endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, @@ -180,9 +179,7 @@ async def start( aws_session_token=credentials.token, ) - print("🔍 CREDENTIAL DEBUG: Creating BedrockRuntimeClient...") self.client = BedrockRuntimeClient(config=config) - print("✅ CREDENTIAL DEBUG: Nova Sonic client initialized successfully!") logger.debug("region=<%s> | nova sonic client initialized", self.region) client = BedrockRuntimeClient(config=config) From 96a24944fce4314fe3329f891a0bd1f03be22e29 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 20 Nov 2025 13:51:30 -0500 Subject: [PATCH 166/242] minor update --- src/strands/experimental/bidi/models/novasonic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 315918169..e28593d18 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -159,8 +159,8 @@ async def start( credentials = self._session.get_credentials() if not credentials: - raise RuntimeError( - "No AWS credentials found. Configure credentials via environment variables, " + raise ValueError( + "no AWS credentials found. configure credentials via environment variables, " "credential files, IAM roles, or SSO." ) From 5d7086fcf92bee3e7bb75df9b67561b45235b069 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 21 Nov 2025 12:00:05 +0100 Subject: [PATCH 167/242] fix mypy errors --- src/strands/experimental/bidi/agent/agent.py | 2 +- src/strands/experimental/bidi/agent/loop.py | 2 +- .../experimental/bidi/models/gemini_live.py | 4 +-- .../experimental/bidi/models/openai.py | 6 ++-- src/strands/tools/executors/_executor.py | 28 +++++++++++++------ src/strands/types/tools.py | 5 +--- .../strands/agent/hooks/test_agent_events.py | 13 +++++++-- tests/strands/agent/test_agent.py | 1 + 8 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index cedebd66c..d3ff77632 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -27,8 +27,8 @@ from ....tools.watcher import ToolWatcher from ....types.content import ContentBlock, Message, Messages from ....types.tools import AgentTool, ToolResult, ToolUse -from ...tools import ToolProvider from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent +from ...tools import ToolProvider from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 935200070..79ffbde00 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -7,7 +7,7 @@ import logging from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable -from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent +from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ...hooks.events import ( diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 1b71ddb94..5f7eb587f 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -99,9 +99,9 @@ def __init__( self.client = genai.Client(**client_kwargs) # Connection state (initialized in start()) - self.live_session: Any = None + self.live_session: Any self.live_session_context_manager = None - self.connection_id: str | None = None + self.connection_id | str self._active: bool = False async def start( diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index a0dae9237..3cda4f738 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -104,8 +104,8 @@ def __init__( ) # Connection state (initialized in start()) - self.websocket: ClientConnection | None = None - self.connection_id: str | None = None + self.websocket: ClientConnection + self.connection_id: str self._active: bool = False self._function_call_buffer: dict[str, Any] = {} @@ -574,8 +574,6 @@ async def send( tool_result = content.get("tool_result") if tool_result: await self._send_tool_result(tool_result) - else: - logger.warning("Unknown content type: %s", type(content).__name__) except Exception as e: logger.error("error=<%s> | error sending content to openai", e) raise # Propagate exception for debugging in experimental code diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 467d22936..140d842a0 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -32,12 +32,18 @@ class ToolExecutor(abc.ABC): @staticmethod def _is_bidi_agent(agent: Union["Agent", "BidiAgent"]) -> bool: - """Check if the agent is a BidiAgent by type name. + """Check if the agent is a BidiAgent using isinstance. - Uses type name comparison to avoid circular imports while maintaining - type safety. This works because we control both Agent and BidiAgent types. + Uses runtime import to avoid circular dependency at module load time. + This properly handles subclasses of BidiAgent. """ - return type(agent).__name__ == "BidiAgent" + try: + from ...experimental.bidi.agent.agent import BidiAgent + + return isinstance(agent, BidiAgent) + except ImportError: + # If BidiAgent is not available, it can't be a BidiAgent + return False @staticmethod async def _invoke_before_tool_call_hook( @@ -50,7 +56,7 @@ async def _invoke_before_tool_call_hook( if ToolExecutor._is_bidi_agent(agent): return await agent.hooks.invoke_callbacks_async( BidiBeforeToolCallEvent( - agent=agent, + agent=agent, # type: ignore[arg-type] selected_tool=tool_func, tool_use=tool_use, invocation_state=invocation_state, @@ -59,7 +65,7 @@ async def _invoke_before_tool_call_hook( else: return await agent.hooks.invoke_callbacks_async( BeforeToolCallEvent( - agent=agent, + agent=agent, # type: ignore[arg-type] selected_tool=tool_func, tool_use=tool_use, invocation_state=invocation_state, @@ -80,7 +86,7 @@ async def _invoke_after_tool_call_hook( if ToolExecutor._is_bidi_agent(agent): return await agent.hooks.invoke_callbacks_async( BidiAfterToolCallEvent( - agent=agent, + agent=agent, # type: ignore[arg-type] selected_tool=selected_tool, tool_use=tool_use, invocation_state=invocation_state, @@ -92,7 +98,7 @@ async def _invoke_after_tool_call_hook( else: return await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( - agent=agent, + agent=agent, # type: ignore[arg-type] selected_tool=selected_tool, tool_use=tool_use, invocation_state=invocation_state, @@ -315,7 +321,11 @@ async def _stream_with_trace( tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time message = Message(role="user", content=[{"toolResult": result}]) - agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) + # Only add tool usage metrics for regular Agent (not BidiAgent) + if not ToolExecutor._is_bidi_agent(agent): + agent.event_loop_metrics.add_tool_usage( # type: ignore[union-attr] + tool_use, tool_duration, tool_trace, tool_success, message + ) cycle_trace.add_child(tool_trace) tracer.end_tool_call_span(tool_call_span, result) diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 4091f48ea..8f4dba6b1 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -8,16 +8,13 @@ import uuid from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union from typing_extensions import NotRequired, TypedDict from .interrupt import _Interruptible from .media import DocumentContent, ImageContent -if TYPE_CHECKING: - from .. import Agent - JSONSchema = dict """Type alias for JSON Schema dictionaries.""" diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 4fef595f8..56a5999c0 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -268,13 +268,15 @@ async def test_stream_e2e_success(alist): "tool_stream_event": { "data": {"tool_streaming": True}, "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, - } + }, + "type": "tool_stream", }, { "tool_stream_event": { "data": "Final result", "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, - } + }, + "type": "tool_stream", }, { "message": { @@ -573,6 +575,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": ""}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -582,6 +585,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": '{"na'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -591,6 +595,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": 'me"'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -600,6 +605,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": ': "J'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -609,6 +615,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": 'ohn"'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -618,6 +625,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": ', "age": 3'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -627,6 +635,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": "1}"}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ea6b09b75..f133400a8 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -684,6 +684,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): unittest.mock.call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), unittest.mock.call( + type="tool_use_stream", agent=agent, current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, delta={"toolUse": {"input": '{"value"}'}}, From f652cc8b5a0e3fce6c85140f7b7ab06840993e45 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 21 Nov 2025 12:10:43 +0100 Subject: [PATCH 168/242] simplify code based on comments --- src/strands/experimental/bidi/agent/loop.py | 17 +++--- src/strands/tools/executors/_executor.py | 66 +++++++-------------- 2 files changed, 31 insertions(+), 52 deletions(-) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 79ffbde00..1297e9e62 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -7,7 +7,7 @@ import logging from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable -from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent +from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ...hooks.events import ( @@ -174,33 +174,34 @@ async def _run_tool(self, tool_use: ToolUse) -> None: tool_results: list[ToolResult] = [] - # Build invocation_state from stored state and current agent context invocation_state: dict[str, Any] = { - **self._agent._invocation_state, # User-provided context - "agent": self._agent, # Always include agent reference + **self._agent._invocation_state, + "agent": self._agent, "model": self._agent.model, "messages": self._agent.messages, "system_prompt": self._agent.system_prompt, } - # Use the tool executor to run the tool (no tracing/metrics for BidiAgent yet) tool_events = self._agent.tool_executor._stream( self._agent, tool_use, tool_results, invocation_state, - structured_output_context=None, # BidiAgent doesn't support structured output yet + structured_output_context=None, ) async for event in tool_events: + if isinstance(event, ToolInterruptEvent): + raise RuntimeError( + "Tool interruption is not yet supported in BidiAgent. " + "ToolInterruptEvent received but cannot be handled in bidirectional streaming context." + ) await self._event_queue.put(event) if isinstance(event, ToolResultEvent): result = event.tool_result - # Send tool result to model await self._agent.model.send(ToolResultEvent(result)) - # Add tool result message to conversation history message: Message = { "role": "user", "content": [{"toolResult": result}], diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 140d842a0..9acba3372 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -17,6 +17,7 @@ from ...telemetry.tracer import get_tracer, serialize from ...types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message +from ...types.interrupt import Interrupt from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse from ..structured_output._structured_output_context import StructuredOutputContext @@ -51,26 +52,17 @@ async def _invoke_before_tool_call_hook( tool_func: Any, tool_use: ToolUse, invocation_state: dict[str, Any], - ) -> tuple[Any, list]: + ) -> tuple[Union[BeforeToolCallEvent, BidiBeforeToolCallEvent], list[Interrupt]]: """Invoke the appropriate before tool call hook based on agent type.""" - if ToolExecutor._is_bidi_agent(agent): - return await agent.hooks.invoke_callbacks_async( - BidiBeforeToolCallEvent( - agent=agent, # type: ignore[arg-type] - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) - ) - else: - return await agent.hooks.invoke_callbacks_async( - BeforeToolCallEvent( - agent=agent, # type: ignore[arg-type] - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) + event_cls = BidiBeforeToolCallEvent if ToolExecutor._is_bidi_agent(agent) else BeforeToolCallEvent + return await agent.hooks.invoke_callbacks_async( + event_cls( + agent=agent, # type: ignore[arg-type] + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, ) + ) @staticmethod async def _invoke_after_tool_call_hook( @@ -81,32 +73,20 @@ async def _invoke_after_tool_call_hook( result: ToolResult, exception: Exception | None = None, cancel_message: str | None = None, - ) -> tuple[Any, list]: + ) -> tuple[Union[AfterToolCallEvent, BidiAfterToolCallEvent], list[Interrupt]]: """Invoke the appropriate after tool call hook based on agent type.""" - if ToolExecutor._is_bidi_agent(agent): - return await agent.hooks.invoke_callbacks_async( - BidiAfterToolCallEvent( - agent=agent, # type: ignore[arg-type] - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - cancel_message=cancel_message, - ) - ) - else: - return await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, # type: ignore[arg-type] - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - cancel_message=cancel_message, - ) + event_cls = BidiAfterToolCallEvent if ToolExecutor._is_bidi_agent(agent) else AfterToolCallEvent + return await agent.hooks.invoke_callbacks_async( + event_cls( + agent=agent, # type: ignore[arg-type] + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, + cancel_message=cancel_message, ) + ) @staticmethod async def _stream( @@ -165,7 +145,6 @@ async def _stream( } ) - # Invoke appropriate before tool call hook based on agent type before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( agent, tool_func, tool_use, invocation_state ) @@ -321,7 +300,6 @@ async def _stream_with_trace( tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time message = Message(role="user", content=[{"toolResult": result}]) - # Only add tool usage metrics for regular Agent (not BidiAgent) if not ToolExecutor._is_bidi_agent(agent): agent.event_loop_metrics.add_tool_usage( # type: ignore[union-attr] tool_use, tool_duration, tool_trace, tool_success, message From 888553237fa6a471e4d0bdfea35353f3af1bd71f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 21 Nov 2025 12:16:46 +0100 Subject: [PATCH 169/242] remove unnecessary comment --- src/strands/experimental/bidi/agent/agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index d3ff77632..5a29e9ad3 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -280,8 +280,6 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: ``` """ logger.debug("agent starting") - - # Store invocation_state for use during tool execution self._invocation_state = invocation_state or {} await self._loop.start() From 4532bb075bcfaca7c8c922dbeccb21c47dce1862 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 21 Nov 2025 08:47:21 -0500 Subject: [PATCH 170/242] nova - transcript - set role (#74) --- src/strands/experimental/bidi/models/novasonic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index e28593d18..9cf47b6b0 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -431,7 +431,8 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N # Handle text output (transcripts) elif "textOutput" in nova_event: - text_content = nova_event["textOutput"]["content"] + text_output = nova_event["textOutput"] + text_content = text_output["content"] # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: logger.debug("nova interruption detected in text output") @@ -440,7 +441,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N return BidiTranscriptStreamEvent( delta={"text": text_content}, text=text_content, - role="assistant", + role=text_output["role"].lower(), is_final=self._generation_stage == "FINAL", current_transcript=text_content, ) From 2c421cce8b5d5f65f9d794e72f59e218230013af Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 21 Nov 2025 14:57:26 +0100 Subject: [PATCH 171/242] fix merge mypy issuee --- src/strands/experimental/bidi/models/openai.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index cde84715d..fefde06bb 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -266,14 +266,16 @@ async def _add_conversation_history(self, messages: Messages) -> None: messages: List of conversation messages with role and content. """ # Track tool call IDs to ensure consistency between calls and results - call_id_map = {} + call_id_map: dict[str, str] = {} # First pass: collect all tool call IDs for message in messages: - conversation_item: dict[Any, Any] = { - "type": "conversation.item.create", - "item": {"type": "message", "role": message["role"], "content": []}, - } + for block in message.get("content", []): + if "toolUse" in block: + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + call_id = original_id[:32] + call_id_map[original_id] = call_id # Second pass: send messages for message in messages: From 61333687aa152ae7f5243a63f71bb953b98de164 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 21 Nov 2025 15:16:52 +0100 Subject: [PATCH 172/242] fix: fix openai sample rate output --- .../experimental/bidi/models/openai.py | 10 ++- .../bidi/models/test_openai_realtime.py | 87 +++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index fefde06bb..e7fdb9962 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -386,11 +386,19 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput # Audio output elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI + # Get sample rate from user's session config if provided, otherwise use default + sample_rate = ( + self.session_config.get("audio", {}) + .get("output", {}) + .get("format", {}) + .get("rate", AUDIO_FORMAT["rate"]) + ) + return [ BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", - sample_rate=AUDIO_FORMAT["rate"], # type: ignore + sample_rate=sample_rate, # type: ignore channels=1, ) ] diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index f1a465293..0012b5649 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -566,3 +566,90 @@ async def test_send_event_helper(mock_websockets_connect, model): assert sent_message == test_event await model.stop() + + +@pytest.mark.asyncio +async def test_custom_audio_sample_rate(mock_websockets_connect, api_key): + """Test that custom audio sample rate from session_config is used in audio events.""" + _, mock_ws = mock_websockets_connect + + # Create model with custom sample rate + custom_sample_rate = 48000 + session_config = {"audio": {"output": {"format": {"rate": custom_sample_rate}}}} + model = BidiOpenAIRealtimeModel(api_key=api_key, session_config=session_config) + + await model.start() + + # Simulate receiving an audio delta event from OpenAI + openai_audio_event = {"type": "response.output_audio.delta", "delta": "base64audiodata"} + + # Convert the event + converted_events = model._convert_openai_event(openai_audio_event) + + # Verify the audio event uses the custom sample rate + assert converted_events is not None + assert len(converted_events) == 1 + audio_event = converted_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.sample_rate == custom_sample_rate + assert audio_event.format == "pcm" + assert audio_event.channels == 1 + + await model.stop() + + +@pytest.mark.asyncio +async def test_default_audio_sample_rate(mock_websockets_connect, api_key): + """Test that default audio sample rate is used when no custom config is provided.""" + _, mock_ws = mock_websockets_connect + + # Create model without custom audio config + model = BidiOpenAIRealtimeModel(api_key=api_key) + + await model.start() + + # Simulate receiving an audio delta event from OpenAI + openai_audio_event = {"type": "response.output_audio.delta", "delta": "base64audiodata"} + + # Convert the event + converted_events = model._convert_openai_event(openai_audio_event) + + # Verify the audio event uses the default sample rate (24000) + assert converted_events is not None + assert len(converted_events) == 1 + audio_event = converted_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.sample_rate == 24000 # Default from AUDIO_FORMAT + assert audio_event.format == "pcm" + assert audio_event.channels == 1 + + await model.stop() + + +@pytest.mark.asyncio +async def test_partial_audio_config(mock_websockets_connect, api_key): + """Test that partial audio config doesn't break and falls back to defaults.""" + _, mock_ws = mock_websockets_connect + + # Create model with partial audio config (missing format.rate) + session_config = {"audio": {"output": {"voice": "alloy"}}} + model = BidiOpenAIRealtimeModel(api_key=api_key, session_config=session_config) + + await model.start() + + # Simulate receiving an audio delta event from OpenAI + openai_audio_event = {"type": "response.output_audio.delta", "delta": "base64audiodata"} + + # Convert the event + converted_events = model._convert_openai_event(openai_audio_event) + + # Verify the audio event uses the default sample rate + assert converted_events is not None + assert len(converted_events) == 1 + audio_event = converted_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.sample_rate == 24000 # Falls back to default + assert audio_event.format == "pcm" + assert audio_event.channels == 1 + + await model.stop() From 1026cfc3b7a5cbfb94cf3df4ea8501ff86de72c4 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 21 Nov 2025 10:59:47 -0500 Subject: [PATCH 173/242] add audio_config for user's to configure their audio settings --- src/strands/experimental/bidi/agent/agent.py | 25 +++- src/strands/experimental/bidi/io/audio.py | 54 +++++++- .../experimental/bidi/models/gemini_live.py | 51 +++++++- .../experimental/bidi/models/novasonic.py | 59 ++++++++- .../experimental/bidi/models/openai.py | 43 +++++- src/strands/experimental/bidi/types/io.py | 24 +++- .../experimental/bidi/io/test_audio.py | 122 ++++++++++++++++++ .../bidi/models/test_gemini_live.py | 67 ++++++++++ .../bidi/models/test_novasonic.py | 50 +++++++ .../bidi/models/test_openai_realtime.py | 71 ++++++++++ 10 files changed, 544 insertions(+), 22 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 7156b12be..4669447d7 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -29,6 +29,7 @@ from ....types.tools import AgentTool, ToolResult, ToolUse from ...tools import ToolProvider from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent +from ..io.audio import _BidiAudioInput, _BidiAudioOutput from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -419,6 +420,16 @@ async def run(self, inputs: list[BidiInput], outputs: list[BidiOutput]) -> None: await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) ``` """ + # Extract audio config from model if available + audio_config = getattr(self.model, "audio_config", None) + if audio_config: + logger.debug( + "audio_config | model provides: input_rate=%s, output_rate=%s, channels=%s, voice=%s", + audio_config.get("input_rate"), + audio_config.get("output_rate"), + audio_config.get("channels"), + audio_config.get("voice"), + ) async def run_inputs() -> None: async def task(input_: BidiInput) -> None: @@ -436,13 +447,23 @@ async def run_outputs() -> None: await self.start() + # Start inputs with audio config if applicable for input_ in inputs: if hasattr(input_, "start"): - await input_.start() + # Pass audio config to audio inputs + if audio_config and isinstance(input_, _BidiAudioInput): + await input_.start(audio_config=audio_config) + else: + await input_.start() + # Start outputs with audio config if applicable for output in outputs: if hasattr(output, "start"): - await output.start() + # Pass audio config to audio outputs + if audio_config and isinstance(output, _BidiAudioOutput): + await output.start(audio_config=audio_config) + else: + await output.start() try: await asyncio.gather(run_inputs(), run_outputs()) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 2f129481f..41a07be9e 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -13,7 +13,7 @@ import pyaudio from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent -from ..types.io import BidiInput, BidiOutput +from ..types.io import AudioConfig, BidiInput, BidiOutput logger = logging.getLogger(__name__) @@ -24,10 +24,12 @@ class _BidiAudioInput(BidiInput): Attributes: _audio: PyAudio instance for audio system access. _stream: Audio input stream. + _user_config_set: Track which config values were explicitly set by user. """ _audio: pyaudio.PyAudio _stream: pyaudio.Stream + _user_config_set: set[str] _CHANNELS: int = 1 _DEVICE_INDEX: int | None = None @@ -37,15 +39,33 @@ class _BidiAudioInput(BidiInput): _RATE: int = 16000 def __init__(self, config: dict[str, Any]) -> None: - """Extract configs.""" + """Extract configs and track which were explicitly set by user.""" + # Track which config values were explicitly provided by user + self._user_config_set = set(config.keys()) + self._channels = config.get("input_channels", _BidiAudioInput._CHANNELS) self._device_index = config.get("input_device_index", _BidiAudioInput._DEVICE_INDEX) self._format = config.get("input_format", _BidiAudioInput._FORMAT) self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER) self._rate = config.get("input_rate", _BidiAudioInput._RATE) - async def start(self) -> None: - """Start input stream.""" + async def start(self, audio_config: AudioConfig | None = None) -> None: + """Start input stream. + + Args: + audio_config: Optional audio configuration from model provider. + Only applied if user did not explicitly set the value + in the constructor. + """ + # Apply audio config overrides only if user didn't explicitly set them + if audio_config: + if "input_rate" in audio_config and "input_rate" not in self._user_config_set: + self._rate = audio_config["input_rate"] + logger.debug("audio_config | applying model input rate: %d Hz", self._rate) + if "channels" in audio_config and "input_channels" not in self._user_config_set: + self._channels = audio_config["channels"] + logger.debug("audio_config | applying model channels: %d", self._channels) + logger.debug( "rate=<%d>, channels=<%d>, device_index=<%s> | starting audio input stream", self._rate, @@ -96,6 +116,7 @@ class _BidiAudioOutput(BidiOutput): _buffer: Deque buffer for queuing audio data. _buffer_event: Event to signal when buffer has data. _output_task: Background task for processing audio output. + _user_config_set: Track which config values were explicitly set by user. """ _audio: pyaudio.PyAudio @@ -103,6 +124,7 @@ class _BidiAudioOutput(BidiOutput): _buffer: deque _buffer_event: asyncio.Event _output_task: asyncio.Task + _user_config_set: set[str] _BUFFER_SIZE: int | None = None _CHANNELS: int = 1 @@ -112,7 +134,10 @@ class _BidiAudioOutput(BidiOutput): _RATE: int = 16000 def __init__(self, config: dict[str, Any]) -> None: - """Extract configs.""" + """Extract configs and track which were explicitly set by user.""" + # Track which config values were explicitly provided by user + self._user_config_set = set(config.keys()) + self._buffer_size = config.get("output_buffer_size", _BidiAudioOutput._BUFFER_SIZE) self._channels = config.get("output_channels", _BidiAudioOutput._CHANNELS) self._device_index = config.get("output_device_index", _BidiAudioOutput._DEVICE_INDEX) @@ -120,8 +145,23 @@ def __init__(self, config: dict[str, Any]) -> None: self._frames_per_buffer = config.get("output_frames_per_buffer", _BidiAudioOutput._FRAMES_PER_BUFFER) self._rate = config.get("output_rate", _BidiAudioOutput._RATE) - async def start(self) -> None: - """Start output stream.""" + async def start(self, audio_config: AudioConfig | None = None) -> None: + """Start output stream. + + Args: + audio_config: Optional audio configuration from model provider. + Only applied if user did not explicitly set the value + in the constructor. + """ + # Apply audio config overrides only if user didn't explicitly set them + if audio_config: + if "output_rate" in audio_config and "output_rate" not in self._user_config_set: + self._rate = audio_config["output_rate"] + logger.debug("audio_config | applying model output rate: %d Hz", self._rate) + if "channels" in audio_config and "output_channels" not in self._user_config_set: + self._channels = audio_config["channels"] + logger.debug("audio_config | applying model channels: %d", self._channels) + logger.debug( "rate=<%d>, channels=<%d>, device_index=<%s>, buffer_size=<%s> | starting audio output stream", self._rate, diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 5f7eb587f..76dfdabd4 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -38,6 +38,7 @@ BidiTranscriptStreamEvent, BidiUsageEvent, ) +from ..types.io import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) @@ -61,6 +62,7 @@ def __init__( model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", api_key: str | None = None, live_config: Dict[str, Any] | None = None, + audio_config: AudioConfig | None = None, **kwargs: Any, ): """Initialize Gemini Live API bidirectional model. @@ -69,6 +71,8 @@ def __init__( model_id: Gemini Live model identifier. api_key: Google AI API key for authentication. live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). + audio_config: Optional audio configuration override. If not provided, + uses Gemini Live API's default configuration. **kwargs: Reserved for future parameters. """ # Model configuration @@ -104,6 +108,37 @@ def __init__( self.connection_id | str self._active: bool = False + # Extract voice from live_config if provided + default_voice = None + if self.live_config and "speech_config" in self.live_config: + speech_config = self.live_config["speech_config"] + if isinstance(speech_config, dict): + default_voice = speech_config.get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") + + # Build audio configuration - use provided values or defaults + config_dict: AudioConfig = { + "input_rate": audio_config.get("input_rate", GEMINI_INPUT_SAMPLE_RATE) + if audio_config + else GEMINI_INPUT_SAMPLE_RATE, + "output_rate": audio_config.get("output_rate", GEMINI_OUTPUT_SAMPLE_RATE) + if audio_config + else GEMINI_OUTPUT_SAMPLE_RATE, + "channels": audio_config.get("channels", GEMINI_CHANNELS) if audio_config else GEMINI_CHANNELS, + "format": audio_config.get("format", "pcm") if audio_config else "pcm", + } + + # Add voice if configured (either from user or live_config) + voice_value = audio_config.get("voice", default_voice) if audio_config else default_voice + if voice_value: + config_dict["voice"] = voice_value + + self.audio_config = config_dict + + if audio_config: + logger.debug("audio_config | merged user-provided config with defaults") + else: + logger.debug("audio_config | using default Gemini Live audio configuration") + async def start( self, system_prompt: Optional[str] = None, @@ -267,8 +302,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut BidiAudioStreamEvent( audio=audio_b64, format="pcm", - sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, # type: ignore - channels=GEMINI_CHANNELS, # type: ignore + sample_rate=self.audio_config["output_rate"], # type: ignore + channels=self.audio_config["channels"], # type: ignore ) ] @@ -406,8 +441,10 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: # Decode base64 audio to bytes for SDK audio_bytes = base64.b64decode(audio_input.audio) - # Create audio blob for the SDK - audio_blob = genai_types.Blob(data=audio_bytes, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}") + # Create audio blob for the SDK using audio_config + audio_blob = genai_types.Blob( + data=audio_bytes, mime_type=f"audio/pcm;rate={self.audio_config['input_rate']}" + ) # Send real-time audio input - this automatically handles VAD and interruption await self.live_session.send_realtime_input(audio=audio_blob) @@ -508,6 +545,12 @@ def _build_live_config( if tools: config_dict["tools"] = self._format_tools_for_live_api(tools) + # Override voice with audio_config value if present (audio_config takes precedence) + if "voice" in self.audio_config: + config_dict.setdefault("speech_config", {}).setdefault("voice_config", {}).setdefault( + "prebuilt_voice_config", {} + )["voice_name"] = self.audio_config["voice"] + return config_dict def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 9cf47b6b0..31e177c55 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, AsyncIterable +from typing import Any, AsyncIterable, Literal import boto3 from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput @@ -47,6 +47,7 @@ BidiTranscriptStreamEvent, BidiUsageEvent, ) +from ..types.io import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) @@ -95,6 +96,7 @@ def __init__( model_id: str = "amazon.nova-sonic-v1:0", boto_session: boto3.Session | None = None, region: str | None = None, + audio_config: AudioConfig | None = None, **kwargs: Any, ) -> None: """Initialize Nova Sonic bidirectional model. @@ -103,6 +105,8 @@ def __init__( model_id: Nova Sonic model identifier. boto_session: Boto Session to use when calling the Nova Sonic Model. region: AWS region + audio_config: Optional audio configuration override. If not provided, + uses Nova Sonic's default configuration. **kwargs: Reserved for future parameters. """ if region and boto_session: @@ -129,6 +133,28 @@ def __init__( logger.debug("model_id=<%s> | nova sonic model initialized", model_id) + # Build audio configuration - use provided values or defaults + self.audio_config: AudioConfig = { + "input_rate": audio_config.get("input_rate", NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]) + if audio_config + else NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"], # type: ignore[typeddict-item] + "output_rate": audio_config.get("output_rate", NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]) + if audio_config + else NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], # type: ignore[typeddict-item] + "channels": audio_config.get("channels", NOVA_AUDIO_INPUT_CONFIG["channelCount"]) + if audio_config + else NOVA_AUDIO_INPUT_CONFIG["channelCount"], # type: ignore[typeddict-item] + "format": audio_config.get("format", "pcm") if audio_config else "pcm", + "voice": audio_config.get("voice", NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]) + if audio_config + else NOVA_AUDIO_OUTPUT_CONFIG["voiceId"], # type: ignore[typeddict-item] + } + + if audio_config: + logger.debug("audio_config | merged user-provided config with defaults") + else: + logger.debug("audio_config | using default Nova Sonic audio configuration") + @start async def start( self, @@ -283,6 +309,16 @@ async def _start_audio_connection(self) -> None: logger.debug("nova audio connection starting") self._audio_content_name = str(uuid.uuid4()) + # Build audio input configuration from audio_config + audio_input_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.audio_config["input_rate"], + "sampleSizeBits": 16, + "channelCount": self.audio_config["channels"], + "audioType": "SPEECH", + "encoding": "base64", + } + audio_content_start = json.dumps( { "event": { @@ -292,7 +328,7 @@ async def _start_audio_connection(self) -> None: "type": "AUDIO", "interactive": True, "role": "USER", - "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, + "audioInputConfiguration": audio_input_config, } } } @@ -422,11 +458,13 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N if "audioOutput" in nova_event: # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] + # Channels from audio_config is guaranteed to be 1 or 2 + channels: Literal[1, 2] = self.audio_config["channels"] # type: ignore[assignment] return BidiAudioStreamEvent( audio=audio_content, format="pcm", - sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], # type: ignore - channels=1, + sample_rate=self.audio_config["output_rate"], # type: ignore + channels=channels, ) # Handle text output (transcripts) @@ -498,12 +536,23 @@ def _get_connection_start_event(self) -> str: def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" + # Build audio output configuration from audio_config + audio_output_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.audio_config["output_rate"], + "sampleSizeBits": 16, + "channelCount": self.audio_config["channels"], + "voiceId": self.audio_config.get("voice", "matthew"), + "encoding": "base64", + "audioType": "SPEECH", + } + prompt_start_event: dict[str, Any] = { "event": { "promptStart": { "promptName": self._connection_id, "textOutputConfiguration": NOVA_TEXT_CONFIG, - "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, + "audioOutputConfiguration": audio_output_config, } } } diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 3cda4f738..2e486cdcd 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -8,7 +8,7 @@ import logging import os import uuid -from typing import Any, AsyncIterable +from typing import Any, AsyncIterable, Literal import websockets from websockets import ClientConnection @@ -32,6 +32,7 @@ BidiTranscriptStreamEvent, BidiUsageEvent, ) +from ..types.io import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) @@ -77,6 +78,7 @@ def __init__( organization: str | None = None, project: str | None = None, session_config: dict[str, Any] | None = None, + audio_config: AudioConfig | None = None, **kwargs: Any, ) -> None: """Initialize OpenAI Realtime bidirectional model. @@ -87,6 +89,8 @@ def __init__( organization: OpenAI organization ID for API requests. project: OpenAI project ID for API requests. session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). + audio_config: Optional audio configuration override. If not provided, + uses OpenAI Realtime API's default configuration. **kwargs: Reserved for future parameters. """ # Model configuration @@ -112,6 +116,33 @@ def __init__( logger.debug("model=<%s> | openai realtime model initialized", model) + # Extract voice from session_config if provided, otherwise use default + default_voice = "alloy" + if self.session_config and "audio" in self.session_config: + audio_settings = self.session_config["audio"] + if isinstance(audio_settings, dict) and "output" in audio_settings: + output_settings = audio_settings["output"] + if isinstance(output_settings, dict): + default_voice = output_settings.get("voice", default_voice) + + # Build audio configuration - use provided values or defaults + self.audio_config: AudioConfig = { + "input_rate": audio_config.get("input_rate", AUDIO_FORMAT["rate"]) + if audio_config + else AUDIO_FORMAT["rate"], # type: ignore[typeddict-item] + "output_rate": audio_config.get("output_rate", AUDIO_FORMAT["rate"]) + if audio_config + else AUDIO_FORMAT["rate"], # type: ignore[typeddict-item] + "channels": audio_config.get("channels", 1) if audio_config else 1, + "format": audio_config.get("format", "pcm") if audio_config else "pcm", + "voice": audio_config.get("voice", default_voice) if audio_config else default_voice, + } + + if audio_config: + logger.debug("audio_config | merged user-provided config with defaults") + else: + logger.debug("audio_config | using default OpenAI Realtime audio configuration") + async def start( self, system_prompt: str | None = None, @@ -227,6 +258,10 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] else: logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) + # Override voice with audio_config value if present (audio_config takes precedence) + if "voice" in self.audio_config: + config.setdefault("audio", {}).setdefault("output", {})["voice"] = self.audio_config["voice"] # type: ignore + return config def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: @@ -310,12 +345,14 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput # Audio output elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI + # Channels from audio_config is guaranteed to be 1 or 2 + channels: Literal[1, 2] = self.audio_config["channels"] # type: ignore[assignment] return [ BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", - sample_rate=AUDIO_FORMAT["rate"], # type: ignore - channels=1, + sample_rate=self.audio_config["output_rate"], # type: ignore + channels=channels, ) ] diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py index 10ae5db77..7a696e85b 100644 --- a/src/strands/experimental/bidi/types/io.py +++ b/src/strands/experimental/bidi/types/io.py @@ -5,11 +5,33 @@ by separating input and output concerns into independent callables. """ -from typing import Awaitable, Protocol +from typing import Awaitable, Literal, Protocol, TypedDict from ..types.events import BidiInputEvent, BidiOutputEvent +class AudioConfig(TypedDict, total=False): + """Audio configuration for bidirectional streaming. + + Defines standard audio parameters shared between model providers + and audio I/O implementations. All fields are optional to support + models that may not use audio or only need specific parameters. + + Attributes: + input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) + output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) + channels: Number of audio channels (1=mono, 2=stereo) + format: Audio encoding format + voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") + """ + + input_rate: int + output_rate: int + channels: int + format: Literal["pcm", "wav", "opus", "mp3"] + voice: str + + class BidiInput(Protocol): """Protocol for bidirectional input callables. diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py index e5e710b98..9517ad108 100644 --- a/tests/strands/experimental/bidi/io/test_audio.py +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -79,3 +79,125 @@ def write(data): await audio_output.stop() speaker.write.assert_called_once_with(write_future.result()) + + +# Audio Configuration Tests + + +@pytest.mark.asyncio +async def test_audio_input_respects_user_config(py_audio): + """Test that user-provided config takes precedence over model config.""" + audio_io = BidiAudioIO(input_rate=48000, input_channels=2) + audio_input = audio_io.input() + + microphone = unittest.mock.Mock() + microphone.read.return_value = b"test-audio" + py_audio.open.return_value = microphone + + # Model provides different config + model_audio_config = {"input_rate": 16000, "channels": 1} + + await audio_input.start(audio_config=model_audio_config) + + # User config should be used + py_audio.open.assert_called_once() + call_kwargs = py_audio.open.call_args.kwargs + assert call_kwargs["rate"] == 48000 # User config + assert call_kwargs["channels"] == 2 # User config + + await audio_input.stop() + + +@pytest.mark.asyncio +async def test_audio_input_applies_model_config_when_user_not_set(py_audio): + """Test that model config is applied when user doesn't provide values.""" + audio_io = BidiAudioIO() # No user config + audio_input = audio_io.input() + + microphone = unittest.mock.Mock() + microphone.read.return_value = b"test-audio" + py_audio.open.return_value = microphone + + # Model provides config + model_audio_config = {"input_rate": 24000, "channels": 2} + + await audio_input.start(audio_config=model_audio_config) + + # Model config should be used + py_audio.open.assert_called_once() + call_kwargs = py_audio.open.call_args.kwargs + assert call_kwargs["rate"] == 24000 # Model config + assert call_kwargs["channels"] == 2 # Model config + + await audio_input.stop() + + +@pytest.mark.asyncio +async def test_audio_output_respects_user_config(py_audio): + """Test that user-provided config takes precedence over model config.""" + audio_io = BidiAudioIO(output_rate=48000, output_channels=2) + audio_output = audio_io.output() + + speaker = unittest.mock.Mock() + py_audio.open.return_value = speaker + + # Model provides different config + model_audio_config = {"output_rate": 16000, "channels": 1} + + await audio_output.start(audio_config=model_audio_config) + + # User config should be used + py_audio.open.assert_called_once() + call_kwargs = py_audio.open.call_args.kwargs + assert call_kwargs["rate"] == 48000 # User config + assert call_kwargs["channels"] == 2 # User config + + await audio_output.stop() + + +@pytest.mark.asyncio +async def test_audio_output_applies_model_config_when_user_not_set(py_audio): + """Test that model config is applied when user doesn't provide values.""" + audio_io = BidiAudioIO() # No user config + audio_output = audio_io.output() + + speaker = unittest.mock.Mock() + py_audio.open.return_value = speaker + + # Model provides config + model_audio_config = {"output_rate": 24000, "channels": 2} + + await audio_output.start(audio_config=model_audio_config) + + # Model config should be used + py_audio.open.assert_called_once() + call_kwargs = py_audio.open.call_args.kwargs + assert call_kwargs["rate"] == 24000 # Model config + assert call_kwargs["channels"] == 2 # Model config + + await audio_output.stop() + + +@pytest.mark.asyncio +async def test_audio_partial_user_config(py_audio): + """Test that partial user config works correctly.""" + # User only sets rate, not channels + audio_io = BidiAudioIO(input_rate=48000) + audio_input = audio_io.input() + + microphone = unittest.mock.Mock() + microphone.read.return_value = b"test-audio" + py_audio.open.return_value = microphone + + # Model provides both rate and channels + model_audio_config = {"input_rate": 16000, "channels": 2} + + await audio_input.start(audio_config=model_audio_config) + + # User rate should be used, model channels should be applied + py_audio.open.assert_called_once() + call_kwargs = py_audio.open.call_args.kwargs + assert call_kwargs["rate"] == 48000 # User config + assert call_kwargs["channels"] == 2 # Model config + + await audio_input.stop() diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 6a2c79ece..4134e4a2b 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -454,6 +454,73 @@ async def test_event_conversion(mock_genai_client, model): await model.stop() +# Audio Configuration Tests + + +def test_audio_config_defaults(mock_genai_client, model_id, api_key): + """Test default audio configuration.""" + _ = mock_genai_client + + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + + assert model.audio_config["input_rate"] == 16000 + assert model.audio_config["output_rate"] == 24000 + assert model.audio_config["channels"] == 1 + assert model.audio_config["format"] == "pcm" + assert "voice" not in model.audio_config # No default voice + + +def test_audio_config_partial_override(mock_genai_client, model_id, api_key): + """Test partial audio configuration override.""" + _ = mock_genai_client + + audio_config = {"output_rate": 48000, "voice": "Puck"} + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, audio_config=audio_config) + + # Overridden values + assert model.audio_config["output_rate"] == 48000 + assert model.audio_config["voice"] == "Puck" + + # Default values preserved + assert model.audio_config["input_rate"] == 16000 + assert model.audio_config["channels"] == 1 + assert model.audio_config["format"] == "pcm" + + +def test_audio_config_full_override(mock_genai_client, model_id, api_key): + """Test full audio configuration override.""" + _ = mock_genai_client + + audio_config = { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "Aoede", + } + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, audio_config=audio_config) + + assert model.audio_config["input_rate"] == 48000 + assert model.audio_config["output_rate"] == 48000 + assert model.audio_config["channels"] == 2 + assert model.audio_config["format"] == "pcm" + assert model.audio_config["voice"] == "Aoede" + + +def test_audio_config_voice_priority(mock_genai_client, model_id, api_key): + """Test that audio_config voice takes precedence over live_config voice.""" + _ = mock_genai_client + + live_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}} + audio_config = {"voice": "Aoede"} + + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, live_config=live_config, audio_config=audio_config) + + # Build config and verify audio_config voice takes precedence + config = model._build_live_config() + assert config["speech_config"]["voice_config"]["prebuilt_voice_config"]["voice_name"] == "Aoede" + + # Helper Method Tests diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 49c6ec8f7..079850bfd 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -86,6 +86,56 @@ async def test_model_initialization(model_id, region): assert model._connection_id is None +# Audio Configuration Tests + + +@pytest.mark.asyncio +async def test_audio_config_defaults(model_id, region): + """Test default audio configuration.""" + model = BidiNovaSonicModel(model_id=model_id, region=region) + + assert model.audio_config["input_rate"] == 16000 + assert model.audio_config["output_rate"] == 16000 + assert model.audio_config["channels"] == 1 + assert model.audio_config["format"] == "pcm" + assert model.audio_config["voice"] == "matthew" + + +@pytest.mark.asyncio +async def test_audio_config_partial_override(model_id, region): + """Test partial audio configuration override.""" + audio_config = {"output_rate": 24000, "voice": "ruth"} + model = BidiNovaSonicModel(model_id=model_id, region=region, audio_config=audio_config) + + # Overridden values + assert model.audio_config["output_rate"] == 24000 + assert model.audio_config["voice"] == "ruth" + + # Default values preserved + assert model.audio_config["input_rate"] == 16000 + assert model.audio_config["channels"] == 1 + assert model.audio_config["format"] == "pcm" + + +@pytest.mark.asyncio +async def test_audio_config_full_override(model_id, region): + """Test full audio configuration override.""" + audio_config = { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "stephen", + } + model = BidiNovaSonicModel(model_id=model_id, region=region, audio_config=audio_config) + + assert model.audio_config["input_rate"] == 48000 + assert model.audio_config["output_rate"] == 48000 + assert model.audio_config["channels"] == 2 + assert model.audio_config["format"] == "pcm" + assert model.audio_config["voice"] == "stephen" + + @pytest.mark.asyncio async def test_connection_lifecycle(nova_model, mock_client, mock_stream): """Test complete connection lifecycle with various configurations.""" diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 2ffcac7ae..242f769da 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -112,6 +112,77 @@ def test_model_initialization(api_key, model_name): assert model_env.api_key == "env-key" +# Audio Configuration Tests + + +def test_audio_config_defaults(api_key, model_name): + """Test default audio configuration.""" + model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + + assert model.audio_config["input_rate"] == 24000 + assert model.audio_config["output_rate"] == 24000 + assert model.audio_config["channels"] == 1 + assert model.audio_config["format"] == "pcm" + assert model.audio_config["voice"] == "alloy" + + +def test_audio_config_partial_override(api_key, model_name): + """Test partial audio configuration override.""" + audio_config = {"output_rate": 48000, "voice": "echo"} + model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, audio_config=audio_config) + + # Overridden values + assert model.audio_config["output_rate"] == 48000 + assert model.audio_config["voice"] == "echo" + + # Default values preserved + assert model.audio_config["input_rate"] == 24000 + assert model.audio_config["channels"] == 1 + assert model.audio_config["format"] == "pcm" + + +def test_audio_config_full_override(api_key, model_name): + """Test full audio configuration override.""" + audio_config = { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "shimmer", + } + model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, audio_config=audio_config) + + assert model.audio_config["input_rate"] == 48000 + assert model.audio_config["output_rate"] == 48000 + assert model.audio_config["channels"] == 2 + assert model.audio_config["format"] == "pcm" + assert model.audio_config["voice"] == "shimmer" + + +def test_audio_config_voice_priority(api_key, model_name): + """Test that audio_config voice takes precedence over session_config voice.""" + session_config = {"audio": {"output": {"voice": "alloy"}}} + audio_config = {"voice": "nova"} + + model = BidiOpenAIRealtimeModel( + model=model_name, api_key=api_key, session_config=session_config, audio_config=audio_config + ) + + # Build config and verify audio_config voice takes precedence + config = model._build_session_config(None, None) + assert config["audio"]["output"]["voice"] == "nova" + + +def test_audio_config_extracts_voice_from_session_config(api_key, model_name): + """Test that voice is extracted from session_config when audio_config not provided.""" + session_config = {"audio": {"output": {"voice": "fable"}}} + + model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, session_config=session_config) + + # Should extract voice from session_config + assert model.audio_config["voice"] == "fable" + + def test_init_without_api_key_raises(): """Test that initialization without API key raises error.""" with unittest.mock.patch.dict("os.environ", {}, clear=True): From a308eee6b721f0fa84b7d3d398269b898d933487 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 21 Nov 2025 14:02:32 -0500 Subject: [PATCH 174/242] async - task pool (#71) --- src/strands/experimental/bidi/_async.py | 50 ---------------- .../experimental/bidi/_async/__init__.py | 29 ++++++++++ .../experimental/bidi/_async/_task_pool.py | 43 ++++++++++++++ .../experimental/bidi/models/bidi_model.py | 3 +- .../experimental/bidi/models/novasonic.py | 27 ++++----- src/strands/experimental/bidi/types/_async.py | 15 ----- .../experimental/bidi/_async/__init__.py | 0 .../experimental/bidi/_async/test__init__.py | 36 ++++++++++++ .../bidi/_async/test_task_pool.py | 54 ++++++++++++++++++ tests/strands/experimental/bidi/test_async.py | 57 ------------------- 10 files changed, 174 insertions(+), 140 deletions(-) delete mode 100644 src/strands/experimental/bidi/_async.py create mode 100644 src/strands/experimental/bidi/_async/__init__.py create mode 100644 src/strands/experimental/bidi/_async/_task_pool.py delete mode 100644 src/strands/experimental/bidi/types/_async.py create mode 100644 tests/strands/experimental/bidi/_async/__init__.py create mode 100644 tests/strands/experimental/bidi/_async/test__init__.py create mode 100644 tests/strands/experimental/bidi/_async/test_task_pool.py delete mode 100644 tests/strands/experimental/bidi/test_async.py diff --git a/src/strands/experimental/bidi/_async.py b/src/strands/experimental/bidi/_async.py deleted file mode 100644 index a4a126c16..000000000 --- a/src/strands/experimental/bidi/_async.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Utilities for async operations.""" - -from functools import wraps -from typing import Any, Awaitable, Callable, TypeVar, cast - -from .types._async import Startable - -F = TypeVar("F", bound=Callable[..., Awaitable[None]]) - - -def start(func: F) -> F: - """Call stop if pairing start call fails. - - Any resources that did successfully start will still have an opportunity to stop cleanly. - - Args: - func: Start function to wrap. - """ - - @wraps(func) - async def wrapper(self: Startable, *args: Any, **kwargs: Any) -> None: - try: - await func(self, *args, **kwargs) - except Exception: - await self.stop() - raise - - return cast(F, wrapper) - - -async def stop(*funcs: F) -> None: - """Call all stops in sequence and aggregate errors. - - A failure in one stop call will not block subsequent stop calls. - - Args: - funcs: Stop functions to call in sequence. - - Raises: - ExceptionGroup: If any stop function raises an exception. - """ - exceptions = [] - for func in funcs: - try: - await func() - except Exception as exception: - exceptions.append(exception) - - if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) # type: ignore # noqa: F821 diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py new file mode 100644 index 000000000..2a2d5fb0e --- /dev/null +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -0,0 +1,29 @@ +"""Utilities for async operations.""" + +from typing import Awaitable, Callable + +from ._task_pool import _TaskPool + +__all__ = ["_TaskPool"] + + +async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: + """Call all stops in sequence and aggregate errors. + + A failure in one stop call will not block subsequent stop calls. + + Args: + funcs: Stop functions to call in sequence. + + Raises: + ExceptionGroup: If any stop function raises an exception. + """ + exceptions = [] + for func in funcs: + try: + await func() + except Exception as exception: + exceptions.append(exception) + + if exceptions: + raise ExceptionGroup("failed stop sequence", exceptions) # type: ignore # noqa: F821 diff --git a/src/strands/experimental/bidi/_async/_task_pool.py b/src/strands/experimental/bidi/_async/_task_pool.py new file mode 100644 index 000000000..83146fd5f --- /dev/null +++ b/src/strands/experimental/bidi/_async/_task_pool.py @@ -0,0 +1,43 @@ +"""Manage pool of active async tasks. + +This is particularly useful for cancelling multiple tasks at once. +""" + +import asyncio +from typing import Any, Coroutine + + +class _TaskPool: + """Manage pool of active async tasks.""" + + def __init__(self) -> None: + """Setup task container.""" + self._tasks: set[asyncio.Task] = set() + + def __len__(self) -> int: + """Number of active tasks.""" + return len(self._tasks) + + def create(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: + """Create async task. + + Adds a clean up callback to run after task completes. + + Returns: + The created task. + """ + task = asyncio.create_task(coro) + task.add_done_callback(lambda task: self._tasks.remove(task)) + + self._tasks.add(task) + return task + + async def cancel(self) -> None: + """Cancel all active tasks in pool.""" + for task in self._tasks: + task.cancel() + + try: + await asyncio.gather(*self._tasks) + except asyncio.CancelledError: + pass diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 856a2edcf..04e4a69e4 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -18,7 +18,6 @@ from ....types._events import ToolResultEvent from ....types.content import Messages from ....types.tools import ToolSpec -from ..types._async import Startable from ..types.events import ( BidiInputEvent, BidiOutputEvent, @@ -27,7 +26,7 @@ logger = logging.getLogger(__name__) -class BidiModel(Startable, Protocol): +class BidiModel(Protocol): """Protocol for bidirectional streaming models. This interface defines the contract for models that support persistent streaming diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 9cf47b6b0..892cf4af5 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -32,11 +32,10 @@ from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from .._async import start, stop +from .._async import stop_all from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, - BidiConnectionCloseEvent, BidiConnectionStartEvent, BidiInputEvent, BidiInterruptionEvent, @@ -129,7 +128,6 @@ def __init__( logger.debug("model_id=<%s> | nova sonic model initialized", model_id) - @start async def start( self, system_prompt: str | None = None, @@ -240,18 +238,15 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore[overr logger.debug("nova event stream starting") yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - try: - _, output = await self._stream.await_output() - while True: - event_data = await output.receive() - nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] - self._log_event_type(nova_event) + _, output = await self._stream.await_output() + while True: + event_data = await output.receive() + nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + self._log_event_type(nova_event) - model_event = self._convert_nova_event(nova_event) - if model_event: - yield model_event - finally: - yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete") + model_event = self._convert_nova_event(nova_event) + if model_event: + yield model_event async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. @@ -382,7 +377,7 @@ async def stop_events() -> None: await self._send_nova_events(cleanup_events) async def stop_stream() -> None: - if not self._connection_id or not self._stream: + if not hasattr(self, "_stream"): return await self._stream.close() @@ -390,7 +385,7 @@ async def stop_stream() -> None: async def stop_connection() -> None: self._connection_id = None - await stop(stop_events, stop_stream, stop_connection) + await stop_all(stop_events, stop_stream, stop_connection) logger.debug("nova connection closed") diff --git a/src/strands/experimental/bidi/types/_async.py b/src/strands/experimental/bidi/types/_async.py deleted file mode 100644 index 0d4309aff..000000000 --- a/src/strands/experimental/bidi/types/_async.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Types for custom async constructs.""" - -from typing import Any, Awaitable, Protocol - - -class Startable(Protocol): - """A construct that must first be started before use.""" - - def start(self, *args: Any, **kwargs: Any) -> Awaitable[None]: - """Setup resources and start connections.""" - ... - - def stop(self) -> Awaitable[None]: - """Tear down resources and stop connections.""" - ... diff --git a/tests/strands/experimental/bidi/_async/__init__.py b/tests/strands/experimental/bidi/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py new file mode 100644 index 000000000..ac4b1ab61 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -0,0 +1,36 @@ +from unittest.mock import AsyncMock + +import pytest + +from strands.experimental.bidi._async import stop_all + + +@pytest.mark.asyncio +async def test_stop_exception(): + func1 = AsyncMock() + func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) + func3 = AsyncMock() + + with pytest.raises(ExceptionGroup) as exc_info: # type: ignore # noqa: F821 + await stop_all(func1, func2, func3) + + func1.assert_called_once() + func2.assert_called_once() + func3.assert_called_once() + + assert len(exc_info.value.exceptions) == 1 + with pytest.raises(ValueError, match=r"stop 2 failed"): + raise exc_info.value.exceptions[0] + + +@pytest.mark.asyncio +async def test_stop_success(): + func1 = AsyncMock() + func2 = AsyncMock() + func3 = AsyncMock() + + await stop_all(func1, func2, func3) + + func1.assert_called_once() + func2.assert_called_once() + func3.assert_called_once() diff --git a/tests/strands/experimental/bidi/_async/test_task_pool.py b/tests/strands/experimental/bidi/_async/test_task_pool.py new file mode 100644 index 000000000..35f817954 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test_task_pool.py @@ -0,0 +1,54 @@ +import asyncio + +import pytest + +from strands.experimental.bidi._async._task_pool import _TaskPool + + +@pytest.fixture +def task_pool() -> _TaskPool: + return _TaskPool() + + +def test_len(task_pool): + tru_len = len(task_pool) + exp_len = 0 + assert tru_len == exp_len + + +@pytest.mark.asyncio +async def test_create(task_pool: _TaskPool) -> None: + event = asyncio.Event() + + async def coro(): + await event.wait() + + task = task_pool.create(coro()) + + tru_len = len(task_pool) + exp_len = 1 + assert tru_len == exp_len + + event.set() + await task + + tru_len = len(task_pool) + exp_len = 0 + assert tru_len == exp_len + + +@pytest.mark.asyncio +async def test_cancel(task_pool: _TaskPool) -> None: + event = asyncio.Event() + + async def coro(): + await event.wait() + + task = task_pool.create(coro()) + await task_pool.cancel() + + tru_len = len(task_pool) + exp_len = 0 + assert tru_len == exp_len + + assert task.done() diff --git a/tests/strands/experimental/bidi/test_async.py b/tests/strands/experimental/bidi/test_async.py deleted file mode 100644 index 10b3016df..000000000 --- a/tests/strands/experimental/bidi/test_async.py +++ /dev/null @@ -1,57 +0,0 @@ -from unittest.mock import AsyncMock, Mock - -import pytest - -from strands.experimental.bidi._async import start, stop - - -@pytest.fixture -def mock_startable(): - return Mock(start=AsyncMock(), stop=AsyncMock()) - - -@pytest.mark.asyncio -async def test_start_exception(mock_startable): - mock_startable.start.side_effect = ValueError("start failed") - - with pytest.raises(ValueError, match=r"start failed"): - await start(mock_startable.start)(mock_startable) - - mock_startable.stop.assert_called_once() - - -@pytest.mark.asyncio -async def test_start_success(mock_startable): - await start(mock_startable.start)(mock_startable) - mock_startable.stop.assert_not_called() - - -@pytest.mark.asyncio -async def test_stop_exception(): - func1 = AsyncMock() - func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) - func3 = AsyncMock() - - with pytest.raises(ExceptionGroup) as exc_info: # type: ignore # noqa: F821 - await stop(func1, func2, func3) - - func1.assert_called_once() - func2.assert_called_once() - func3.assert_called_once() - - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] - - -@pytest.mark.asyncio -async def test_stop_success(): - func1 = AsyncMock() - func2 = AsyncMock() - func3 = AsyncMock() - - await stop(func1, func2, func3) - - func1.assert_called_once() - func2.assert_called_once() - func3.assert_called_once() From 1b5f9a61cdf769067a54a2dfe254644bbaf7f43d Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 21 Nov 2025 15:32:09 -0500 Subject: [PATCH 175/242] cancellation - openai (#73) --- .../experimental/bidi/models/novasonic.py | 8 +- .../experimental/bidi/models/openai.py | 151 +++++++----------- .../bidi/models/test_novasonic.py | 2 +- ...test_openai_realtime.py => test_openai.py} | 39 ++--- 4 files changed, 77 insertions(+), 123 deletions(-) rename tests/strands/experimental/bidi/models/{test_openai_realtime.py => test_openai.py} (92%) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 892cf4af5..d1821af6b 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -147,7 +147,7 @@ async def start( RuntimeError: If user calls start again without first stopping. """ if self._connection_id: - raise RuntimeError("call stop before starting again") + raise RuntimeError("model already started | call stop before starting again") logger.debug("nova connection starting") @@ -233,7 +233,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore[overr RuntimeError: If start has not been called. """ if not self._connection_id: - raise RuntimeError("must call start") + raise RuntimeError("model not started | call start before receiving") logger.debug("nova event stream starting") yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) @@ -260,7 +260,7 @@ async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: ValueError: If content type not supported (e.g., image content). """ if not self._connection_id: - raise RuntimeError("must call start") + raise RuntimeError("model not started | call start before sending") if isinstance(content, BidiTextInputEvent): await self._send_text_content(content.text) @@ -271,7 +271,7 @@ async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: if tool_result: await self._send_tool_result(tool_result) else: - raise ValueError(f"content_type={type(content)} | content not supported by nova sonic") + raise ValueError(f"content_type={type(content)} | content not supported") async def _start_audio_connection(self) -> None: """Internal: Start audio input connection (call once before sending audio chunks).""" diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 3cda4f738..cc82bbea8 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -16,13 +16,11 @@ from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, - BidiConnectionCloseEvent, BidiConnectionStartEvent, - BidiErrorEvent, - BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, BidiOutputEvent, @@ -70,6 +68,8 @@ class BidiOpenAIRealtimeModel(BidiModel): function calling, and event conversion to Strands format. """ + _websocket: ClientConnection + def __init__( self, model: str = DEFAULT_MODEL, @@ -104,9 +104,7 @@ def __init__( ) # Connection state (initialized in start()) - self.websocket: ClientConnection - self.connection_id: str - self._active: bool = False + self._connection_id: str | None = None self._function_call_buffer: dict[str, Any] = {} @@ -127,45 +125,35 @@ async def start( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ - if self._active: - raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") logger.info("openai realtime connection starting") - try: - # Initialize connection state - self.connection_id = str(uuid.uuid4()) - self._active = True - self._function_call_buffer = {} - - # Establish WebSocket connection - url = f"{OPENAI_REALTIME_URL}?model={self.model}" + # Initialize connection state + self._connection_id = str(uuid.uuid4()) - headers = [("Authorization", f"Bearer {self.api_key}")] - if self.organization: - headers.append(("OpenAI-Organization", self.organization)) - if self.project: - headers.append(("OpenAI-Project", self.project)) + self._function_call_buffer = {} - self.websocket = await websockets.connect(url, additional_headers=headers) - logger.info("connection_id=<%s> | websocket connected successfully", self.connection_id) + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model}" - # Configure session - session_config = self._build_session_config(system_prompt, tools) - await self._send_event({"type": "session.update", "session": session_config}) + headers = [("Authorization", f"Bearer {self.api_key}")] + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) - # Add conversation history if provided - if messages: - await self._add_conversation_history(messages) + self._websocket = await websockets.connect(url, additional_headers=headers) + logger.info("connection_id=<%s> | websocket connected successfully", self._connection_id) - except Exception as e: - self._active = False - logger.error("error=<%s> | openai connection failed", e) - raise + # Configure session + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active + # Add conversation history if provided + if messages: + await self._add_conversation_history(messages) def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: """Create standardized transcript event. @@ -275,27 +263,16 @@ async def _add_conversation_history(self, messages: Messages) -> None: async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore """Receive OpenAI events and convert to Strands TypedEvent format.""" - # Emit connection start event - yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model) - - try: - while self._active: - async for message in self.websocket: - if not self._active: - break # type: ignore + if not self._connection_id: + raise RuntimeError("model not started | call start before receiving") - openai_event = json.loads(message) + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model) - for event in self._convert_openai_event(openai_event) or []: - yield event + async for message in self._websocket: + openai_event = json.loads(message) - except Exception as e: - logger.error("error=<%s> | error receiving openai realtime event", e) - yield BidiErrorEvent(error=e) - finally: - # Emit connection close event - yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") - self._active = False + for event in self._convert_openai_event(openai_event) or []: + yield event def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" @@ -557,26 +534,24 @@ async def send( Args: content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). + + Raises: + ValueError: If content type not supported (e.g., image content). """ - if not self._require_active(): - return - - try: - # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, BidiImageInputEvent): - # BidiImageInputEvent - not supported by OpenAI Realtime yet - logger.warning("Image input not supported by OpenAI Realtime API") - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - except Exception as e: - logger.error("error=<%s> | error sending content to openai", e) - raise # Propagate exception for debugging in experimental code + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" @@ -599,7 +574,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) - # Extract result content + # TODO: We need to extract all content and content types result_data: dict[Any, Any] | str = {} if "content" in tool_result: # Extract text from content blocks @@ -616,25 +591,23 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: async def stop(self) -> None: """Close session and cleanup resources.""" - if not self._active: - return - logger.debug("openai realtime connection cleanup starting") - self._active = False - try: - await self.websocket.close() - except Exception as e: - logger.warning("error=<%s> | error closing openai realtime websocket", e) + async def stop_websocket() -> None: + if not hasattr(self, "_websocket"): + return + + await self._websocket.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_websocket, stop_connection) logger.debug("openai realtime connection closed") async def _send_event(self, event: dict[str, Any]) -> None: """Send event to OpenAI via WebSocket.""" - try: - message = json.dumps(event) - await self.websocket.send(message) - logger.debug("event_type=<%s> | openai event sent", event.get("type")) - except Exception as e: - logger.error("error=<%s> | error sending openai event", e) - raise + message = json.dumps(event) + await self._websocket.send(message) + logger.debug("event_type=<%s> | openai event sent", event.get("type")) diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 49c6ec8f7..3a6dd66d1 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -166,7 +166,7 @@ async def test_send_edge_cases(nova_model): mime_type="image/jpeg", ) - with pytest.raises(ValueError, match=r"content not supported by nova sonic"): + with pytest.raises(ValueError, match=r"content not supported"): await nova_model.send(image_event) await nova_model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai.py similarity index 92% rename from tests/strands/experimental/bidi/models/test_openai_realtime.py rename to tests/strands/experimental/bidi/models/test_openai.py index 2ffcac7ae..c0c38d1b2 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -93,8 +93,6 @@ def test_model_initialization(api_key, model_name): model_default = BidiOpenAIRealtimeModel(api_key="test-key") assert model_default.model == "gpt-realtime" assert model_default.api_key == "test-key" - assert model_default._active is False - assert model_default.websocket is None # Test with custom model model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) @@ -129,14 +127,12 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp # Test basic connection await model.start() - assert model._active is True - assert model.connection_id is not None - assert model.websocket == mock_ws + assert model._connection_id is not None + assert model._websocket == mock_ws mock_connect.assert_called_once() # Test close await model.stop() - assert model._active is False mock_ws.close.assert_called_once() # Test connection with system prompt @@ -202,7 +198,7 @@ async def async_connect(*args, **kwargs): # Test double connection model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) await model2.start() - with pytest.raises(RuntimeError, match="Connection already active"): + with pytest.raises(RuntimeError, match=r"call stop before starting again"): await model2.start() await model2.stop() @@ -210,12 +206,12 @@ async def async_connect(*args, **kwargs): model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) await model3.stop() # Should not raise - # Test close error handling (should not raise, just log) + # Test close error model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - await model4.stop() # Should not raise - assert model4._active is False + with pytest.raises(ExceptionGroup): # noqa: F821 + await model4.stop() # Send Method Tests @@ -279,7 +275,8 @@ async def test_send_edge_cases(mock_websockets_connect, model): # Test send when inactive text_input = BidiTextInputEvent(text="Hello", role="user") - await model.send(text_input) + with pytest.raises(RuntimeError, match=r"call start before sending"): + await model.send(text_input) mock_ws.send.assert_not_called() # Test image input (not supported, base64 encoded, no encoding parameter) @@ -289,15 +286,8 @@ async def test_send_edge_cases(mock_websockets_connect, model): image=image_b64, mime_type="image/jpeg", ) - with unittest.mock.patch("strands.experimental.bidi.models.openai.logger") as mock_logger: + with pytest.raises(ValueError, match=r"content not supported"): await model.send(image_input) - mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") - - # Test unknown content type - unknown_content = {"unknown_field": "value"} - with unittest.mock.patch("strands.experimental.bidi.models.openai.logger") as mock_logger: - await model.send(unknown_content) - assert mock_logger.warning.called await model.stop() @@ -318,7 +308,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): # First event should be connection start (new TypedEvent format) assert first_event.get("type") == "bidi_connection_start" - assert first_event.get("connection_id") == model.connection_id + assert first_event.get("connection_id") == model._connection_id assert first_event.get("model") == model.model # Close to trigger session end @@ -332,9 +322,6 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): except StopAsyncIteration: pass - # Last event should be connection close (new TypedEvent format) - assert events[-1].get("type") == "bidi_connection_close" - @pytest.mark.asyncio async def test_event_conversion(mock_websockets_connect, model): @@ -463,12 +450,6 @@ def test_tool_conversion(model, tool_spec): def test_helper_methods(model): """Test various helper methods.""" - # Test _require_active - assert model._require_active() is False - model._active = True - assert model._require_active() is True - model._active = False - # Test _create_text_event (now returns BidiTranscriptStreamEvent) text_event = model._create_text_event("Hello", "user") assert isinstance(text_event, BidiTranscriptStreamEvent) From a3c7d5efe157a9c90dbb3be49fcc4e3553670f68 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 21 Nov 2025 15:53:02 -0500 Subject: [PATCH 176/242] cancellation - gemini (#75) --- .../experimental/bidi/models/gemini_live.py | 495 ++++++++---------- .../bidi/models/test_gemini_live.py | 39 +- 2 files changed, 233 insertions(+), 301 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 5f7eb587f..fbf62aad5 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -11,7 +11,6 @@ - Native support for audio/text streaming and interruption """ -import asyncio import base64 import logging import uuid @@ -24,12 +23,11 @@ from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, - BidiConnectionCloseEvent, BidiConnectionStartEvent, - BidiErrorEvent, BidiImageInputEvent, BidiInputEvent, BidiInterruptionEvent, @@ -96,13 +94,12 @@ def __init__( # Use v1alpha for Live API as it has better model support client_kwargs["http_options"] = {"api_version": "v1alpha"} - self.client = genai.Client(**client_kwargs) + self._client = genai.Client(**client_kwargs) # Connection state (initialized in start()) - self.live_session: Any - self.live_session_context_manager = None - self.connection_id | str - self._active: bool = False + self._live_session: Any = None + self._live_session_context_manager: Any = None + self._connection_id: str | None = None async def start( self, @@ -119,33 +116,23 @@ async def start( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ - if self._active: - raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") - try: - # Initialize connection state - self.connection_id = str(uuid.uuid4()) - self._active = True + self._connection_id = str(uuid.uuid4()) - # Build live config - live_config = self._build_live_config(system_prompt, tools, **kwargs) + # Build live config + live_config = self._build_live_config(system_prompt, tools, **kwargs) - # Create the context manager - self.live_session_context_manager = self.client.aio.live.connect( - model=self.model_id, config=cast(LiveConnectConfigOrDict, live_config) - ) + # Create the context manager and session + self._live_session_context_manager = self._client.aio.live.connect( + model=self.model_id, config=cast(LiveConnectConfigOrDict, live_config) + ) + self._live_session = await self._live_session_context_manager.__aenter__() - # Enter the context manager - self.live_session = await self.live_session_context_manager.__aenter__() - - # Send initial message history if provided - if messages: - await self._send_message_history(messages) - - except Exception as e: - self._active = False - logger.error("error=<%s> | error connecting to gemini live", e) - raise + # Send initial message history if provided + if messages: + await self._send_message_history(messages) async def _send_message_history(self, messages: Messages) -> None: """Send conversation history to Gemini Live API. @@ -169,41 +156,20 @@ async def _send_message_history(self, messages: Messages) -> None: # "assistant" role from Messages format maps to "model" in Gemini role = "model" if message["role"] == "assistant" else message["role"] content = genai_types.Content(role=role, parts=content_parts) - if self.live_session: - await self.live_session.send_client_content(turns=content) + await self._live_session.send_client_content(turns=content) async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore """Receive Gemini Live API events and convert to provider-agnostic format.""" - # Emit connection start event - yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model_id) - - try: - # Wrap in while loop to restart after turn_complete (SDK limitation workaround) - while self._active: - try: - async for message in self.live_session.receive(): - if not self._active: - raise ValueError("connection is not active") - - # Convert to provider-agnostic format (always returns list) - for event in self._convert_gemini_live_event(message): - yield event - - # SDK exits receive loop after turn_complete - restart automatically - if self._active: - logger.debug("gemini receive loop restarting after turn completion") - - except Exception as e: - logger.error("error=<%s> | error in gemini receive iteration", e) - # Small delay before retrying to avoid tight error loops - await asyncio.sleep(0.1) - - except Exception as e: - logger.error("error=<%s> | fatal error in gemini receive loop", e) - yield BidiErrorEvent(error=e) - finally: - # Emit connection close event when exiting - yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete") + if not self._connection_id: + raise RuntimeError("model not started | call start before receiving") + + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + # Wrap in while loop to restart after turn_complete (SDK limitation workaround) + while True: + async for message in self._live_session.receive(): + for event in self._convert_gemini_live_event(message): + yield event def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOutputEvent]: """Convert Gemini Live API events to provider-agnostic format. @@ -217,155 +183,145 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut Returns: List of event dicts (empty list if no events to emit). """ - try: - # Handle interruption first (from server_content) - if message.server_content and message.server_content.interrupted: - return [BidiInterruptionEvent(reason="user_speech")] - - # Handle input transcription (user's speech) - emit as transcript event - if message.server_content and message.server_content.input_transcription: - input_transcript = message.server_content.input_transcription - # Check if the transcription object has text content - if hasattr(input_transcript, "text") and input_transcript.text: - transcription_text = input_transcript.text - role = getattr(input_transcript, "role", "user") - logger.debug("text_length=<%d> | gemini input transcription detected", len(transcription_text)) - return [ - BidiTranscriptStreamEvent( - delta={"text": transcription_text}, - text=transcription_text, - role=role.lower() if isinstance(role, str) else "user", # type: ignore - is_final=True, - current_transcript=transcription_text, - ) - ] - - # Handle output transcription (model's audio) - emit as transcript event - if message.server_content and message.server_content.output_transcription: - output_transcript = message.server_content.output_transcription - # Check if the transcription object has text content - if hasattr(output_transcript, "text") and output_transcript.text: - transcription_text = output_transcript.text - role = getattr(output_transcript, "role", "assistant") - logger.debug("text_length=<%d> | gemini output transcription detected", len(transcription_text)) - return [ - BidiTranscriptStreamEvent( - delta={"text": transcription_text}, - text=transcription_text, - role=role.lower() if isinstance(role, str) else "assistant", # type: ignore - is_final=True, - current_transcript=transcription_text, - ) - ] - - # Handle audio output using SDK's built-in data property - # Check this BEFORE text to avoid triggering warning on mixed content - if message.data: - # Convert bytes to base64 string for JSON serializability - audio_b64 = base64.b64encode(message.data).decode("utf-8") + # Handle interruption first (from server_content) + if message.server_content and message.server_content.interrupted: + return [BidiInterruptionEvent(reason="user_speech")] + + # Handle input transcription (user's speech) - emit as transcript event + if message.server_content and message.server_content.input_transcription: + input_transcript = message.server_content.input_transcription + # Check if the transcription object has text content + if hasattr(input_transcript, "text") and input_transcript.text: + transcription_text = input_transcript.text + logger.debug("text_length=<%d> | gemini input transcription detected", len(transcription_text)) return [ - BidiAudioStreamEvent( - audio=audio_b64, - format="pcm", - sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, # type: ignore - channels=GEMINI_CHANNELS, # type: ignore + BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role="user", + # TODO: https://github.com/googleapis/python-genai/issues/1504 + is_final=bool(input_transcript.finished), + current_transcript=transcription_text, ) ] - # Handle text output from model_turn (avoids warning by checking parts directly) - if message.server_content and message.server_content.model_turn: - model_turn = message.server_content.model_turn - if model_turn.parts: - # Concatenate all text parts (Gemini may send multiple parts) - text_parts = [] - for part in model_turn.parts: - # Check if part has text attribute and it's not empty - if hasattr(part, "text") and part.text: - text_parts.append(part.text) - - if text_parts: - full_text = " ".join(text_parts) - return [ - BidiTranscriptStreamEvent( - delta={"text": full_text}, - text=full_text, - role="assistant", - is_final=True, - current_transcript=full_text, - ) - ] - - # Handle tool calls - return list to support multiple tool calls - if message.tool_call and message.tool_call.function_calls: - tool_events = [] - for func_call in message.tool_call.function_calls: - tool_use_event: ToolUse = { - "toolUseId": func_call.id, # type: ignore - "name": func_call.name, # type: ignore - "input": func_call.args or {}, - } - # Create ToolUseStreamEvent for consistency with standard agent - tool_events.append( - ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) + # Handle output transcription (model's audio) - emit as transcript event + if message.server_content and message.server_content.output_transcription: + output_transcript = message.server_content.output_transcription + # Check if the transcription object has text content + if hasattr(output_transcript, "text") and output_transcript.text: + transcription_text = output_transcript.text + logger.debug("text_length=<%d> | gemini output transcription detected", len(transcription_text)) + return [ + BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role="assistant", + # TODO: https://github.com/googleapis/python-genai/issues/1504 + is_final=bool(output_transcript.finished), + current_transcript=transcription_text, ) - return tool_events # type: ignore + ] - # Handle usage metadata - if hasattr(message, "usage_metadata") and message.usage_metadata: - usage = message.usage_metadata + # Handle audio output using SDK's built-in data property + # Check this BEFORE text to avoid triggering warning on mixed content + if message.data: + # Convert bytes to base64 string for JSON serializability + audio_b64 = base64.b64encode(message.data).decode("utf-8") + return [ + BidiAudioStreamEvent( + audio=audio_b64, + format="pcm", + sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, # type: ignore + channels=GEMINI_CHANNELS, # type: ignore + ) + ] + + # Handle text output from model_turn (avoids warning by checking parts directly) + if message.server_content and message.server_content.model_turn: + model_turn = message.server_content.model_turn + if model_turn.parts: + # Concatenate all text parts (Gemini may send multiple parts) + text_parts = [] + for part in model_turn.parts: + # Check if part has text attribute and it's not empty + if hasattr(part, "text") and part.text: + text_parts.append(part.text) + + if text_parts: + full_text = " ".join(text_parts) + return [ + BidiTranscriptStreamEvent( + delta={"text": full_text}, + text=full_text, + role="assistant", + is_final=True, + current_transcript=full_text, + ) + ] - # Build modality details from token details - modality_details = [] + # Handle tool calls - return list to support multiple tool calls + if message.tool_call and message.tool_call.function_calls: + tool_events = [] + for func_call in message.tool_call.function_calls: + tool_use_event: ToolUse = { + "toolUseId": func_call.id, # type: ignore + "name": func_call.name, # type: ignore + "input": func_call.args or {}, + } + # Create ToolUseStreamEvent for consistency with standard agent + tool_events.append( + ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) + ) + return tool_events # type: ignore + + # Handle usage metadata + if hasattr(message, "usage_metadata") and message.usage_metadata: + usage = message.usage_metadata + + # Build modality details from token details + modality_details = [] + + # Process prompt tokens details + if usage.prompt_tokens_details: + for detail in usage.prompt_tokens_details: + if detail.modality and detail.token_count: + modality_details.append( + { + "modality": str(detail.modality).lower(), + "input_tokens": detail.token_count, + "output_tokens": 0, + } + ) - # Process prompt tokens details - if usage.prompt_tokens_details: - for detail in usage.prompt_tokens_details: - if detail.modality and detail.token_count: + # Process response tokens details + if usage.response_tokens_details: + for detail in usage.response_tokens_details: + if detail.modality and detail.token_count: + # Find or create modality entry + modality_str = str(detail.modality).lower() + existing = next((m for m in modality_details if m["modality"] == modality_str), None) + if existing: + existing["output_tokens"] = detail.token_count + else: modality_details.append( - { - "modality": str(detail.modality).lower(), - "input_tokens": detail.token_count, - "output_tokens": 0, - } + {"modality": modality_str, "input_tokens": 0, "output_tokens": detail.token_count} ) - # Process response tokens details - if usage.response_tokens_details: - for detail in usage.response_tokens_details: - if detail.modality and detail.token_count: - # Find or create modality entry - modality_str = str(detail.modality).lower() - existing = next((m for m in modality_details if m["modality"] == modality_str), None) - if existing: - existing["output_tokens"] = detail.token_count - else: - modality_details.append( - {"modality": modality_str, "input_tokens": 0, "output_tokens": detail.token_count} - ) - - return [ - BidiUsageEvent( - input_tokens=usage.prompt_token_count or 0, - output_tokens=usage.response_token_count or 0, - total_tokens=usage.total_token_count or 0, - modality_details=modality_details if modality_details else None, # type: ignore - cache_read_input_tokens=usage.cached_content_token_count - if usage.cached_content_token_count - else None, - ) - ] - - # Silently ignore setup_complete and generation_complete messages - return [] - - except Exception as e: - logger.error( - "error=<%s>, message_type=<%s> | error converting gemini live event", - e, - type(message).__name__, - ) - # Return ErrorEvent in list so caller can handle it - return [BidiErrorEvent(error=e)] + return [ + BidiUsageEvent( + input_tokens=usage.prompt_token_count or 0, + output_tokens=usage.response_token_count or 0, + total_tokens=usage.total_token_count or 0, + modality_details=modality_details if modality_details else None, # type: ignore + cache_read_input_tokens=usage.cached_content_token_count + if usage.cached_content_token_count + else None, + ) + ] + + # Silently ignore setup_complete and generation_complete messages + return [] async def send( self, @@ -377,24 +333,25 @@ async def send( Args: content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). - """ - if not self._active: - return - try: - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, BidiImageInputEvent): - await self._send_image_content(content) - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - except Exception as e: - logger.error("error=<%s> | error sending content to gemini live", e) - raise # Propagate exception for debugging in experimental code + Raises: + ValueError: If content type not supported (e.g., image content). + """ + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, BidiImageInputEvent): + await self._send_image_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: """Internal: Send audio content using Gemini Live API. @@ -402,18 +359,14 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: Gemini Live expects continuous audio streaming via send_realtime_input. This automatically triggers VAD and can interrupt ongoing responses. """ - try: - # Decode base64 audio to bytes for SDK - audio_bytes = base64.b64decode(audio_input.audio) - - # Create audio blob for the SDK - audio_blob = genai_types.Blob(data=audio_bytes, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}") + # Decode base64 audio to bytes for SDK + audio_bytes = base64.b64decode(audio_input.audio) - # Send real-time audio input - this automatically handles VAD and interruption - await self.live_session.send_realtime_input(audio=audio_blob) + # Create audio blob for the SDK + audio_blob = genai_types.Blob(data=audio_bytes, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}") - except Exception as e: - logger.error("error=<%s> | error sending audio content to gemini live", e) + # Send real-time audio input - this automatically handles VAD and interruption + await self._live_session.send_realtime_input(audio=audio_blob) async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: """Internal: Send image content using Gemini Live API. @@ -421,68 +374,56 @@ async def _send_image_content(self, image_input: BidiImageInputEvent) -> None: Sends image frames following the same pattern as the GitHub example. Images are sent as base64-encoded data with MIME type. """ - try: - # Image is already base64 encoded in the event - msg = {"mime_type": image_input.mime_type, "data": image_input.image} + # Image is already base64 encoded in the event + msg = {"mime_type": image_input.mime_type, "data": image_input.image} - # Send using the same method as the GitHub example - await self.live_session.send(input=msg) - - except Exception as e: - logger.error("error=<%s> | error sending image content to gemini live", e) + # Send using the same method as the GitHub example + await self._live_session.send(input=msg) async def _send_text_content(self, text: str) -> None: """Internal: Send text content using Gemini Live API.""" - try: - # Create content with text - content = genai_types.Content(role="user", parts=[genai_types.Part(text=text)]) - - # Send as client content - await self.live_session.send_client_content(turns=content) + # Create content with text + content = genai_types.Content(role="user", parts=[genai_types.Part(text=text)]) - except Exception as e: - logger.error("error=<%s> | error sending text content to gemini live", e) + # Send as client content + await self._live_session.send_client_content(turns=content) async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Gemini Live API.""" - try: - tool_use_id = tool_result.get("toolUseId") - - # Extract result content - result_data = {} - if "content" in tool_result: - # Extract text from content blocks - for block in tool_result["content"]: - if "text" in block: - result_data = {"result": block["text"]} - break - - # Create function response - func_response = genai_types.FunctionResponse( - id=tool_use_id, - name=tool_use_id, # Gemini uses name as identifier - response=result_data, - ) - - # Send tool response - await self.live_session.send_tool_response(function_responses=[func_response]) - except Exception as e: - logger.error("error=<%s> | error sending tool result to gemini live", e) + tool_use_id = tool_result.get("toolUseId") + + # TODO: We need to extract all content and content types + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break + + # Create function response + func_response = genai_types.FunctionResponse( + id=tool_use_id, + name=tool_use_id, # Gemini uses name as identifier + response=result_data, + ) + + # Send tool response + await self._live_session.send_tool_response(function_responses=[func_response]) async def stop(self) -> None: """Close Gemini Live API connection.""" - if not self._active: - return - self._active = False + async def stop_session() -> None: + if not self._live_session_context_manager: + return + + await self._live_session_context_manager.__aexit__(None, None, None) + + async def stop_connection() -> None: + self._connection_id = None - try: - # Exit the context manager properly - if self.live_session_context_manager: - await self.live_session_context_manager.__aexit__(None, None, None) - except Exception as e: - logger.error("error=<%s> | error closing gemini live connection", e) - raise + await stop_all(stop_session, stop_connection) def _build_live_config( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, **kwargs: Any diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 6a2c79ece..d32f88fe7 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -17,7 +17,6 @@ from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, - BidiConnectionCloseEvent, BidiConnectionStartEvent, BidiImageInputEvent, BidiInterruptionEvent, @@ -96,8 +95,7 @@ def test_model_initialization(mock_genai_client, model_id, api_key): model_default = BidiGeminiLiveModel() assert model_default.model_id == "gemini-2.5-flash-native-audio-preview-09-2025" assert model_default.api_key is None - assert model_default._active is False - assert model_default.live_session is None + assert model_default._live_session is None # Check default config includes transcription assert model_default.live_config["response_modalities"] == ["AUDIO"] assert "outputAudioTranscription" in model_default.live_config @@ -128,14 +126,12 @@ async def test_connection_lifecycle(mock_genai_client, model, system_prompt, too # Test basic connection await model.start() - assert model._active is True - assert model.connection_id is not None - assert model.live_session == mock_live_session + assert model._connection_id is not None + assert model._live_session == mock_live_session mock_client.aio.live.connect.assert_called_once() # Test close await model.stop() - assert model._active is False mock_live_session_cm.__aexit__.assert_called_once() # Test connection with system prompt @@ -167,7 +163,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): # Test connection error model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) mock_client.aio.live.connect.side_effect = Exception("Connection failed") - with pytest.raises(Exception, match="Connection failed"): + with pytest.raises(Exception, match=r"Connection failed"): await model1.start() # Reset mock for next tests @@ -176,7 +172,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): # Test double connection model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model2.start() - with pytest.raises(RuntimeError, match="Connection already active"): + with pytest.raises(RuntimeError, match="call stop before starting again"): await model2.start() await model2.stop() @@ -188,7 +184,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(Exception, match="Close failed"): + with pytest.raises(ExceptionGroup): # noqa: F821 await model4.stop() @@ -249,13 +245,15 @@ async def test_send_edge_cases(mock_genai_client, model): # Test send when inactive text_input = BidiTextInputEvent(text="Hello", role="user") - await model.send(text_input) + with pytest.raises(RuntimeError, match=r"call start before sending"): + await model.send(text_input) mock_live_session.send_client_content.assert_not_called() # Test unknown content type await model.start() unknown_content = {"unknown_field": "value"} - await model.send(unknown_content) # Should not raise, just log warning + with pytest.raises(ValueError, match=r"content not supported"): + await model.send(unknown_content) await model.stop() @@ -271,21 +269,14 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): await model.start() - # Collect events - events = [] async for event in model.receive(): - events.append(event) - # Close after first event to trigger connection end - if len(events) == 1: - await model.stop() + _ = event + break # Verify connection start and end - assert len(events) >= 2 - assert isinstance(events[0], BidiConnectionStartEvent) - assert events[0].get("type") == "bidi_connection_start" - assert events[0].connection_id == model.connection_id - assert isinstance(events[-1], BidiConnectionCloseEvent) - assert events[-1].get("type") == "bidi_connection_close" + assert isinstance(event, BidiConnectionStartEvent) + assert event.get("type") == "bidi_connection_start" + assert event.connection_id == model._connection_id @pytest.mark.asyncio From 6298c86cdc2e028e376958a46b3836d216e45cc8 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sat, 22 Nov 2025 13:21:40 -0500 Subject: [PATCH 177/242] cancellation - agent (#72) --- src/strands/experimental/bidi/agent/agent.py | 116 +++++++------------ src/strands/experimental/bidi/agent/loop.py | 20 +++- 2 files changed, 62 insertions(+), 74 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 5a29e9ad3..78fa9da05 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -29,6 +29,7 @@ from ....types.tools import AgentTool, ToolResult, ToolUse from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider +from .._async import stop_all from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -140,14 +141,13 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) - # Initialize invocation state (will be set in start()) - self._invocation_state: dict[str, Any] = {} - self._loop = _BidiAgentLoop(self) # Emit initialization event self.hooks.invoke_callbacks(BidiAgentInitializedEvent(agent=self)) + self._started = False + @property def tool(self) -> _ToolCaller: """Call tool as a function. @@ -279,10 +279,12 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: }) ``` """ - logger.debug("agent starting") - self._invocation_state = invocation_state or {} + if self._started: + raise RuntimeError("agent already started | call stop before starting again") - await self._loop.start() + logger.debug("agent starting") + await self._loop.start(invocation_state) + self._started = True async def send(self, input_data: BidiAgentInput) -> None: """Send input to the model (text, audio, image, or event dict). @@ -299,14 +301,16 @@ async def send(self, input_data: BidiAgentInput) -> None: - dict: Event dictionary (will be reconstructed to TypedEvent) Raises: - ValueError: If no active session or invalid input type. + RuntimeError: If start has not been called. + ValueError: If invalid input type. Example: await agent.send("Hello") await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...)) await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"}) """ - self._validate_active_connection() + if not self._started: + raise RuntimeError("agent not started | call start before sending") # Handle string input if isinstance(input_data, str): @@ -359,12 +363,16 @@ async def send(self, input_data: BidiAgentInput) -> None: async def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive events from the model including audio, text, and tool calls. - Yields model output events processed by background tasks including audio output, - text responses, tool calls, and connection updates. - Yields: - Model and tool call events. + Model output events processed by background tasks including audio output, + text responses, tool calls, and connection updates. + + Raises: + RuntimeError: If start has not been called. """ + if not self._started: + raise RuntimeError("agent not started | call start before receiving") + async for event in self._loop.receive(): yield event @@ -374,53 +382,34 @@ async def stop(self) -> None: Terminates the streaming connection, cancels background tasks, and closes the connection to the model provider. """ + self._started = False await self._loop.stop() - async def __aenter__(self) -> "BidiAgent": + async def __aenter__(self, invocation_state: dict[str, Any] | None = None) -> "BidiAgent": """Async context manager entry point. Automatically starts the bidirectional connection when entering the context. + Args: + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. + Returns: Self for use in the context. """ - logger.debug("context_manager= | starting connection") - await self.start() + logger.debug("context_manager= | starting agent") + await self.start(invocation_state) return self - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + async def __aexit__(self, *_: Any) -> None: """Async context manager exit point. Automatically ends the connection and cleans up resources including when exiting the context, regardless of whether an exception occurred. - - Args: - exc_type: Exception type if an exception occurred, None otherwise. - exc_val: Exception value if an exception occurred, None otherwise. - exc_tb: Exception traceback if an exception occurred, None otherwise. """ - try: - logger.debug("context_manager= | cleaning up connection") - - # Cleanup agent connection - await self.stop() - - except Exception as cleanup_error: - if exc_type is None: - # No original exception, re-raise cleanup error - logger.error("cleanup_error=<%s> | error during context manager cleanup", cleanup_error) - raise - else: - # Original exception exists, log cleanup error but don't suppress original - logger.error( - "cleanup_error=<%s> | error during context manager cleanup suppressed due to original exception", - cleanup_error, - ) - - @property - def active(self) -> bool: - """True if agent loop started, False otherwise.""" - return self._loop.active + logger.debug("context_manager= | stopping agent") + await self.stop() async def run( self, inputs: list[BidiInput], outputs: list[BidiOutput], invocation_state: dict[str, Any] | None = None @@ -449,7 +438,7 @@ async def run( async def run_inputs() -> None: async def task(input_: BidiInput) -> None: - while self.active: + while True: event = await input_() await self.send(event) @@ -461,35 +450,20 @@ async def run_outputs() -> None: tasks = [output(event) for output in outputs] await asyncio.gather(*tasks) - await self.start(invocation_state=invocation_state) - - for input_ in inputs: - if hasattr(input_, "start"): - await input_.start() - - for output in outputs: - if hasattr(output, "start"): - await output.start() - try: - await asyncio.gather(run_inputs(), run_outputs()) - - finally: - for input_ in inputs: - if hasattr(input_, "stop"): - await input_.stop() + await self.start(invocation_state) - for output in outputs: - if hasattr(output, "stop"): - await output.stop() + start_inputs = [input_.start for input_ in inputs if hasattr(input_, "start")] + start_outputs = [output.start for output in outputs if hasattr(output, "start")] + for start in [*start_inputs, *start_outputs]: + await start() - await self.stop() + async with asyncio.TaskGroup() as task_group: # type: ignore + task_group.create_task(run_inputs()) + task_group.create_task(run_outputs()) - def _validate_active_connection(self) -> None: - """Validate that an active connection exists. + finally: + stop_inputs = [input_.stop for input_ in inputs if hasattr(input_, "stop")] + stop_outputs = [output.stop for output in outputs if hasattr(output, "stop")] - Raises: - ValueError: If no active connection. - """ - if not self.active: - raise ValueError("No active conversation. Call start() first or use async context manager.") + await stop_all(*stop_inputs, *stop_outputs, self.stop) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 1297e9e62..2795a1e6c 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -35,12 +35,16 @@ class _BidiAgentLoop: _stop_event: Sentinel to mark end of loop. _tasks: Track active async tasks created in loop. _active: Flag if agent loop is started. + _invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. """ _event_queue: asyncio.Queue _stop_event: object _tasks: set _active: bool + _invocation_state: dict[str, Any] def __init__(self, agent: "BidiAgent") -> None: """Initialize members of the agent loop. @@ -51,18 +55,26 @@ def __init__(self, agent: "BidiAgent") -> None: agent: Bidirectional agent to loop over. """ self._agent = agent - self._active: bool = False + self._active = False + self._invocation_state = {} - async def start(self) -> None: + async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start the agent loop. The agent model is started as part of this call. + + Args: + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. """ if self.active: return logger.debug("agent loop starting") + self._invocation_state = invocation_state or {} + self._event_queue = asyncio.Queue(maxsize=1) self._stop_event = object() self._tasks = set() @@ -87,6 +99,8 @@ async def stop(self) -> None: logger.debug("agent loop stopping") + self._invocation_state = {} + try: # Cancel all tasks for task in self._tasks: @@ -175,7 +189,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: tool_results: list[ToolResult] = [] invocation_state: dict[str, Any] = { - **self._agent._invocation_state, + **self._invocation_state, "agent": self._agent, "model": self._agent.model, "messages": self._agent.messages, From c4181ec490f59b236767e5036d25844f5d178b27 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sun, 23 Nov 2025 09:06:59 -0500 Subject: [PATCH 178/242] cancellation - agent loop (#63) --- src/strands/experimental/bidi/agent/agent.py | 8 + src/strands/experimental/bidi/agent/loop.py | 207 +++++++++---------- 2 files changed, 104 insertions(+), 111 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 78fa9da05..933bd8e11 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -20,6 +20,7 @@ from .... import _identifier from ....agent.state import AgentState from ....hooks import HookProvider, HookRegistry +from ....interrupt import _InterruptState from ....tools.caller import _ToolCaller from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor @@ -146,6 +147,9 @@ def __init__( # Emit initialization event self.hooks.invoke_callbacks(BidiAgentInitializedEvent(agent=self)) + # TODO: Determine if full support is required + self._interrupt_state = _InterruptState() + self._started = False @property @@ -270,6 +274,10 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: This allows passing custom data (user_id, session_id, database connections, etc.) that tools can access via their invocation_state parameter. + Raises: + RuntimeError: + If agent already started. + Example: ```python await agent.start(invocation_state={ diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 2795a1e6c..3cfcf83aa 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -5,7 +5,7 @@ import asyncio import logging -from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable +from typing import TYPE_CHECKING, Any, AsyncIterable from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent from ....types.content import Message @@ -18,6 +18,7 @@ from ...hooks.events import ( BidiInterruptionEvent as BidiInterruptionHookEvent, ) +from .._async import _TaskPool, stop_all from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent if TYPE_CHECKING: @@ -31,21 +32,14 @@ class _BidiAgentLoop: Attributes: _agent: BidiAgent instance to loop. + _started: Flag if agent loop has started. + _task_pool: Track active async tasks created in loop. _event_queue: Queue model and tool call events for receiver. - _stop_event: Sentinel to mark end of loop. - _tasks: Track active async tasks created in loop. - _active: Flag if agent loop is started. _invocation_state: Optional context to pass to tools during execution. This allows passing custom data (user_id, session_id, database connections, etc.) that tools can access via their invocation_state parameter. """ - _event_queue: asyncio.Queue - _stop_event: object - _tasks: set - _active: bool - _invocation_state: dict[str, Any] - def __init__(self, agent: "BidiAgent") -> None: """Initialize members of the agent loop. @@ -55,8 +49,10 @@ def __init__(self, agent: "BidiAgent") -> None: agent: Bidirectional agent to loop over. """ self._agent = agent - self._active = False - self._invocation_state = {} + self._started = False + self._task_pool = _TaskPool() + self._event_queue: asyncio.Queue + self._invocation_state: dict[str, Any] async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start the agent loop. @@ -67,19 +63,15 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: invocation_state: Optional context to pass to tools during execution. This allows passing custom data (user_id, session_id, database connections, etc.) that tools can access via their invocation_state parameter. + + Raises: + RuntimeError: + If loop already started. """ - if self.active: - return + if self._started: + raise RuntimeError("loop already started | call stop before starting again") logger.debug("agent loop starting") - - self._invocation_state = invocation_state or {} - - self._event_queue = asyncio.Queue(maxsize=1) - self._stop_event = object() - self._tasks = set() - - # Emit before invocation event await self._agent.hooks.invoke_callbacks_async(BidiBeforeInvocationEvent(agent=self._agent)) await self._agent.model.start( @@ -88,65 +80,48 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: messages=self._agent.messages, ) - self._create_task(self._run_model()) + self._event_queue = asyncio.Queue(maxsize=1) + + self._task_pool = _TaskPool() + self._task_pool.create(self._run_model()) - self._active = True + self._invocation_state = invocation_state or {} + self._started = True async def stop(self) -> None: """Stop the agent loop.""" - if not self.active: - return - logger.debug("agent loop stopping") + self._started = False self._invocation_state = {} - try: - # Cancel all tasks - for task in self._tasks: - task.cancel() - - # Wait briefly for tasks to finish their current operations - await asyncio.gather(*self._tasks, return_exceptions=True) + async def stop_tasks() -> None: + await self._task_pool.cancel() - # Stop the model + async def stop_model() -> None: await self._agent.model.stop() - # Clean up the event queue - if not self._event_queue.empty(): - self._event_queue.get_nowait() - self._event_queue.put_nowait(self._stop_event) - - self._active = False - + try: + await stop_all(stop_tasks, stop_model) finally: - # Emit after invocation event (reverse order for cleanup) await self._agent.hooks.invoke_callbacks_async(BidiAfterInvocationEvent(agent=self._agent)) async def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive model and tool call events.""" + """Receive model and tool call events. + + Raises: + RuntimeError: If start has not been called. + """ + if not self._started: + raise RuntimeError("loop not started | call start before receiving") + while True: event = await self._event_queue.get() - if event is self._stop_event: - break + if isinstance(event, Exception): + raise event yield event - @property - def active(self) -> bool: - """True if agent loop started, False otherwise.""" - return self._active - - def _create_task(self, coro: Awaitable[None]) -> None: - """Utilitly to create async task. - - Adds a clean up callback to run after task completes. - """ - task: asyncio.Task[None] = asyncio.create_task(coro) # type: ignore - task.add_done_callback(lambda task: self._tasks.remove(task)) - - self._tasks.add(task) - async def _run_model(self) -> None: """Task for running the model. @@ -154,33 +129,39 @@ async def _run_model(self) -> None: """ logger.debug("model task starting") - async for event in self._agent.model.receive(): # type: ignore - await self._event_queue.put(event) - - if isinstance(event, BidiTranscriptStreamEvent): - if event["is_final"]: - message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} - self._agent.messages.append(message) + try: + async for event in self._agent.model.receive(): # type: ignore + await self._event_queue.put(event) + + if isinstance(event, BidiTranscriptStreamEvent): + if event["is_final"]: + message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} + self._agent.messages.append(message) + await self._agent.hooks.invoke_callbacks_async( + BidiMessageAddedEvent(agent=self._agent, message=message) + ) + + elif isinstance(event, ToolUseStreamEvent): + tool_use = event["current_tool_use"] + self._task_pool.create(self._run_tool(tool_use)) + + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + self._agent.messages.append(tool_message) await self._agent.hooks.invoke_callbacks_async( - BidiMessageAddedEvent(agent=self._agent, message=message) + BidiMessageAddedEvent(agent=self._agent, message=tool_message) ) - elif isinstance(event, ToolUseStreamEvent): - tool_use = event["current_tool_use"] - self._create_task(self._run_tool(tool_use)) - - tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - self._agent.messages.append(tool_message) - - elif isinstance(event, BidiInterruptionEvent): - # Emit interruption hook event - await self._agent.hooks.invoke_callbacks_async( - BidiInterruptionHookEvent( - agent=self._agent, - reason=event["reason"], - interrupted_response_id=event.get("interrupted_response_id"), + elif isinstance(event, BidiInterruptionEvent): + await self._agent.hooks.invoke_callbacks_async( + BidiInterruptionHookEvent( + agent=self._agent, + reason=event["reason"], + interrupted_response_id=event.get("interrupted_response_id"), + ) ) - ) + + except Exception as error: + await self._event_queue.put(error) async def _run_tool(self, tool_use: ToolUse) -> None: """Task for running tool requested by the model using the tool executor.""" @@ -196,30 +177,34 @@ async def _run_tool(self, tool_use: ToolUse) -> None: "system_prompt": self._agent.system_prompt, } - tool_events = self._agent.tool_executor._stream( - self._agent, - tool_use, - tool_results, - invocation_state, - structured_output_context=None, - ) - - async for event in tool_events: - if isinstance(event, ToolInterruptEvent): - raise RuntimeError( - "Tool interruption is not yet supported in BidiAgent. " - "ToolInterruptEvent received but cannot be handled in bidirectional streaming context." - ) - await self._event_queue.put(event) - if isinstance(event, ToolResultEvent): - result = event.tool_result - - await self._agent.model.send(ToolResultEvent(result)) - - message: Message = { - "role": "user", - "content": [{"toolResult": result}], - } - self._agent.messages.append(message) - await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) - await self._event_queue.put(ToolResultMessageEvent(message)) + try: + tool_events = self._agent.tool_executor._stream( + self._agent, + tool_use, + tool_results, + invocation_state, + structured_output_context=None, + ) + + async for event in tool_events: + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + interrupt_names = [interrupt.name for interrupt in event.interrupts] + raise RuntimeError(f"interrupts={interrupt_names} | tool interrupts are not supported in bidi") + + await self._event_queue.put(event) + if isinstance(event, ToolResultEvent): + result = event.tool_result + + await self._agent.model.send(ToolResultEvent(result)) + + message: Message = { + "role": "user", + "content": [{"toolResult": result}], + } + self._agent.messages.append(message) + await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) + await self._event_queue.put(ToolResultMessageEvent(message)) + + except Exception as error: + await self._event_queue.put(error) From c70d43cad42e20d988f66cb8a2860e687adf97a5 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sun, 23 Nov 2025 10:00:13 -0500 Subject: [PATCH 179/242] pyproject - bidi configs (#77) --- pyproject.toml | 109 ++++++++++-------- .../scripts => scripts/bidi}/test_bidi.py | 0 .../bidi}/test_bidi_novasonic.py | 0 .../bidi}/test_bidi_openai.py | 0 .../bidi}/test_gemini_live.py | 0 src/strands/experimental/bidi/__init__.py | 5 + .../experimental/bidi/_async/__init__.py | 2 +- src/strands/experimental/bidi/agent/agent.py | 11 +- src/strands/experimental/bidi/agent/loop.py | 6 +- .../experimental/bidi/models/bidi_model.py | 2 +- .../experimental/bidi/models/gemini_live.py | 37 +++--- .../experimental/bidi/models/novasonic.py | 13 ++- .../experimental/bidi/models/openai.py | 25 ++-- .../experimental/bidi/types/__init__.py | 13 --- src/strands/experimental/bidi/types/events.py | 28 +++-- src/strands/tools/executors/_executor.py | 8 +- .../strands/agent/hooks/test_agent_events.py | 3 + tests/strands/event_loop/test_streaming.py | 18 ++- .../experimental/bidi/_async/test__init__.py | 2 +- .../bidi/models/test_gemini_live.py | 2 +- .../experimental/bidi/models/test_openai.py | 2 +- 21 files changed, 161 insertions(+), 125 deletions(-) rename {src/strands/experimental/bidi/scripts => scripts/bidi}/test_bidi.py (100%) rename {src/strands/experimental/bidi/scripts => scripts/bidi}/test_bidi_novasonic.py (100%) rename {src/strands/experimental/bidi/scripts => scripts/bidi}/test_bidi_openai.py (100%) rename {src/strands/experimental/bidi/scripts => scripts/bidi}/test_gemini_live.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 6d977a236..eb61bf5b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,34 +54,6 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] -bidi-novasonic = [ - "pyaudio>=0.2.13", - "rx>=3.2.0", - "smithy-aws-core>=0.0.1; python_version>='3.12'", - "pytz", - "aws_sdk_bedrock_runtime; python_version>='3.12'", -] -bidi-openai = [ - "pyaudio>=0.2.13", - "websockets>=14.0,<16.0", -] -bidi-gemini = [ - "pyaudio>=0.2.13", - "google-genai>=1.32.0,<2.0.0", - "opencv-python>=4.8.0", - "pillow>=10.0.0", -] -bidi = [ - "pyaudio>=0.2.13", - "rx>=3.2.0", - "smithy-aws-core>=0.0.1; python_version>='3.12'", - "pytz", - "aws_sdk_bedrock_runtime; python_version>='3.12'", - "websockets>=14.0,<16.0", - "google-genai>=1.32.0,<2.0.0", - "opencv-python>=4.8.0", - "pillow>=10.0.0", -] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<9.0.0", @@ -97,8 +69,17 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -bidi-all = ["strands-agents[bidi,bidi-openai,bidi-gemini,bidi-novasonic]"] + +bidi = [ + "aws_sdk_bedrock_runtime; python_version>='3.12'", + "pyaudio>=0.2.13,<1.0.0", + "smithy-aws-core>=0.0.1; python_version>='3.12'", +] +bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] +bidi-openai = ["websockets>=15.0.0,<16.0.0"] + all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", @@ -129,8 +110,6 @@ source = "vcs" # Use git tags for versioning [tool.hatch.envs.hatch-static-analysis] installer = "uv" -# Only install 'all' features, not 'bidi-all' which requires Python 3.12+ -# The bidi code will still be type-checked, but without installing its runtime dependencies features = ["all"] dependencies = [ "mypy>=1.15.0,<2.0.0", @@ -149,7 +128,7 @@ format-fix = [ ] lint-check = [ "ruff check", - "mypy -p src" + "mypy ./src" ] lint-fix = [ "ruff check --fix" @@ -158,7 +137,7 @@ lint-fix = [ [tool.hatch.envs.hatch-test] installer = "uv" -features = ["all", "bidi-all"] +features = ["all"] extra-args = ["-n", "auto", "-vv"] dependencies = [ "pytest>=8.0.0,<9.0.0", @@ -181,7 +160,7 @@ cov-report = [] [tool.hatch.envs.default] installer = "uv" dev-mode = true -features = ["all", "bidi-all"] +features = ["all"] dependencies = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", @@ -223,25 +202,16 @@ warn_no_return = true warn_unreachable = true follow_untyped_imports = true ignore_missing_imports = false +exclude = ["src/strands/experimental/bidi"] -exclude = [ - "src/strands/experimental/bidi/scripts/.*", -] - -# Ignore missing imports for optional bidi dependencies (not installed in lint environment) [[tool.mypy.overrides]] -module = [ - "smithy_core.*", - "smithy_aws_core.*", - "aws_sdk_bedrock_runtime.*", - "pyaudio", -] -ignore_missing_imports = true +module = ["strands.experimental.bidi.*"] follow_imports = "skip" [tool.ruff] line-length = 120 include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"] +exclude = ["src/strands/experimental/bidi/**/*.py", "tests/strands/experimental/bidi/**/*.py", "tests_integ/bidi/**/*.py"] [tool.ruff.lint] select = [ @@ -264,6 +234,7 @@ convention = "google" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" +addopts = "--ignore=tests/strands/experimental/bidi" [tool.coverage.run] @@ -272,6 +243,7 @@ source = ["src"] context = "thread" parallel = true concurrency = ["thread", "multiprocessing"] +omit = ["src/strands/experimental/bidi/*"] [tool.coverage.report] show_missing = true @@ -301,3 +273,48 @@ style = [ ["text", ""], ["disabled", "fg:#858585 italic"] ] + +# ========================= +# Bidi development configs +# ========================= + +[tool.hatch.envs.bidi] +dev-mode = true +features = ["dev", "bidi-all"] +installer = "uv" + +[tool.hatch.envs.bidi.scripts] +prepare = [ + "hatch run bidi-lint:format-fix", + "hatch run bidi-lint:quality-fix", + "hatch run bidi-lint:type-check", + "hatch run bidi-test:test-cov", +] + +[tools.hatch.envs.bidi-lint] +template = "bidi" + +[tool.hatch.envs.bidi-lint.scripts] +format-check = "format-fix --check" +format-fix = "ruff format {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py" +quality-check = "ruff check {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py" +quality-fix = "quality-check --fix" +type-check = "mypy {args} --python-version 3.12 ./src/strands/experimental/bidi/**/*.py" + +[tool.hatch.envs.bidi-test] +template = "bidi" + +[tool.hatch.envs.bidi-test.scripts] +test = "pytest {args} tests/strands/experimental/bidi" +test-cov = """ +test \ + --cov=strands.experimental.bidi \ + --cov-config= \ + --cov-branch \ + --cov-report=term-missing \ + --cov-report=xml:build/coverage/bidi-coverage.xml \ + --cov-report=html:build/coverage/bidi-html +""" + +[[tool.hatch.envs.bidi-test.matrix]] +python = ["3.13", "3.12"] diff --git a/src/strands/experimental/bidi/scripts/test_bidi.py b/scripts/bidi/test_bidi.py similarity index 100% rename from src/strands/experimental/bidi/scripts/test_bidi.py rename to scripts/bidi/test_bidi.py diff --git a/src/strands/experimental/bidi/scripts/test_bidi_novasonic.py b/scripts/bidi/test_bidi_novasonic.py similarity index 100% rename from src/strands/experimental/bidi/scripts/test_bidi_novasonic.py rename to scripts/bidi/test_bidi_novasonic.py diff --git a/src/strands/experimental/bidi/scripts/test_bidi_openai.py b/scripts/bidi/test_bidi_openai.py similarity index 100% rename from src/strands/experimental/bidi/scripts/test_bidi_openai.py rename to scripts/bidi/test_bidi_openai.py diff --git a/src/strands/experimental/bidi/scripts/test_gemini_live.py b/scripts/bidi/test_gemini_live.py similarity index 100% rename from src/strands/experimental/bidi/scripts/test_gemini_live.py rename to scripts/bidi/test_gemini_live.py diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 97de04684..712451af9 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -1,5 +1,10 @@ """Bidirectional streaming package.""" +import sys + +if sys.version_info < (3, 12): + raise ImportError("bidi only supported for >= Python 3.12") + # Main components - Primary user interface # Re-export standard agent events for tool handling from ...types._events import ( diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 2a2d5fb0e..6cee3264d 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -26,4 +26,4 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: exceptions.append(exception) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) # type: ignore # noqa: F821 + raise ExceptionGroup("failed stop sequence", exceptions) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 933bd8e11..3ca0d2ad3 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -15,7 +15,7 @@ import asyncio import json import logging -from typing import Any, AsyncIterable +from typing import Any, AsyncGenerator from .... import _identifier from ....agent.state import AgentState @@ -294,7 +294,7 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: await self._loop.start(invocation_state) self._started = True - async def send(self, input_data: BidiAgentInput) -> None: + async def send(self, input_data: BidiAgentInput | dict[str, Any]) -> None: """Send input to the model (text, audio, image, or event dict). Unified method for sending text, audio, and image input to the model during @@ -341,8 +341,9 @@ async def send(self, input_data: BidiAgentInput) -> None: return # Handle plain dict - reconstruct TypedEvent for WebSocket integration - if isinstance(input_data, dict) and "type" in input_data: # type: ignore + if isinstance(input_data, dict) and "type" in input_data: event_type = input_data["type"] + input_event: BidiInputEvent if event_type == "bidi_text_input": input_event = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) elif event_type == "bidi_audio_input": @@ -368,7 +369,7 @@ async def send(self, input_data: BidiAgentInput) -> None: f"or event dict with 'type' field, got: {type(input_data)}" ) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive events from the model including audio, text, and tool calls. Yields: @@ -466,7 +467,7 @@ async def run_outputs() -> None: for start in [*start_inputs, *start_outputs]: await start() - async with asyncio.TaskGroup() as task_group: # type: ignore + async with asyncio.TaskGroup() as task_group: task_group.create_task(run_inputs()) task_group.create_task(run_outputs()) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 3cfcf83aa..1003dea8c 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -5,7 +5,7 @@ import asyncio import logging -from typing import TYPE_CHECKING, Any, AsyncIterable +from typing import TYPE_CHECKING, Any, AsyncGenerator from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent from ....types.content import Message @@ -106,7 +106,7 @@ async def stop_model() -> None: finally: await self._agent.hooks.invoke_callbacks_async(BidiAfterInvocationEvent(agent=self._agent)) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive model and tool call events. Raises: @@ -130,7 +130,7 @@ async def _run_model(self) -> None: logger.debug("model task starting") try: - async for event in self._agent.model.receive(): # type: ignore + async for event in self._agent.model.receive(): await self._event_queue.put(event) if isinstance(event, BidiTranscriptStreamEvent): diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 04e4a69e4..ad91a81b0 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -64,7 +64,7 @@ async def stop(self) -> None: """ ... - async def receive(self) -> AsyncIterable[BidiOutputEvent]: + def receive(self) -> AsyncIterable[BidiOutputEvent]: """Receive streaming events from the model. Continuously yields events from the model as they arrive over the connection. diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index fbf62aad5..a2b0e7335 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -14,7 +14,7 @@ import base64 import logging import uuid -from typing import Any, AsyncIterable, Dict, List, Optional, cast +from typing import Any, AsyncGenerator, cast from google import genai from google.genai import types as genai_types @@ -35,6 +35,9 @@ BidiTextInputEvent, BidiTranscriptStreamEvent, BidiUsageEvent, + Channel, + ModalityUsage, + SampleRate, ) from .bidi_model import BidiModel @@ -58,7 +61,7 @@ def __init__( self, model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", api_key: str | None = None, - live_config: Dict[str, Any] | None = None, + live_config: dict[str, Any] | None = None, **kwargs: Any, ): """Initialize Gemini Live API bidirectional model. @@ -103,9 +106,9 @@ def __init__( async def start( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, **kwargs: Any, ) -> None: """Establish bidirectional connection with Gemini Live API. @@ -158,7 +161,7 @@ async def _send_message_history(self, messages: Messages) -> None: content = genai_types.Content(role=role, parts=content_parts) await self._live_session.send_client_content(turns=content) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive Gemini Live API events and convert to provider-agnostic format.""" if not self._connection_id: raise RuntimeError("model not started | call start before receiving") @@ -171,7 +174,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore for event in self._convert_gemini_live_event(message): yield event - def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOutputEvent]: + def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOutputEvent]: """Convert Gemini Live API events to provider-agnostic format. Handles different types of content: @@ -232,8 +235,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut BidiAudioStreamEvent( audio=audio_b64, format="pcm", - sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, # type: ignore - channels=GEMINI_CHANNELS, # type: ignore + sample_rate=cast(SampleRate, GEMINI_OUTPUT_SAMPLE_RATE), + channels=cast(Channel, GEMINI_CHANNELS), ) ] @@ -262,18 +265,18 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut # Handle tool calls - return list to support multiple tool calls if message.tool_call and message.tool_call.function_calls: - tool_events = [] + tool_events: list[BidiOutputEvent] = [] for func_call in message.tool_call.function_calls: tool_use_event: ToolUse = { - "toolUseId": func_call.id, # type: ignore - "name": func_call.name, # type: ignore + "toolUseId": cast(str, func_call.id), + "name": cast(str, func_call.name), "input": func_call.args or {}, } # Create ToolUseStreamEvent for consistency with standard agent tool_events.append( ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) ) - return tool_events # type: ignore + return tool_events # Handle usage metadata if hasattr(message, "usage_metadata") and message.usage_metadata: @@ -313,7 +316,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> List[BidiOut input_tokens=usage.prompt_token_count or 0, output_tokens=usage.response_token_count or 0, total_tokens=usage.total_token_count or 0, - modality_details=modality_details if modality_details else None, # type: ignore + modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None, @@ -426,8 +429,8 @@ async def stop_connection() -> None: await stop_all(stop_session, stop_connection) def _build_live_config( - self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, **kwargs: Any - ) -> Dict[str, Any]: + self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, **kwargs: Any + ) -> dict[str, Any]: """Build LiveConnectConfig for the official SDK. Simply passes through all config parameters from live_config, allowing users @@ -451,7 +454,7 @@ def _build_live_config( return config_dict - def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: + def _format_tools_for_live_api(self, tool_specs: list[ToolSpec]) -> list[genai_types.Tool]: """Format tool specs for Gemini Live API.""" if not tool_specs: return [] diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index d1821af6b..52ada2840 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, AsyncIterable +from typing import Any, AsyncGenerator, cast import boto3 from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput @@ -28,6 +28,7 @@ ) from smithy_aws_core.identity.static import StaticCredentialsResolver from smithy_core.aio.eventstream import DuplexEventStream +from smithy_core.shapes import ShapeID from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages @@ -45,6 +46,7 @@ BidiTextInputEvent, BidiTranscriptStreamEvent, BidiUsageEvent, + SampleRate, ) from .bidi_model import BidiModel @@ -170,7 +172,7 @@ async def start( region=self.region, aws_credentials_identity_resolver=resolver, auth_scheme_resolver=HTTPAuthSchemeResolver(), - auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, + auth_schemes={ShapeID("aws.auth#sigv4"): SigV4AuthScheme(service="bedrock")}, # Configure static credentials as properties aws_access_key_id=credentials.access_key, aws_secret_access_key=credentials.secret_key, @@ -226,7 +228,7 @@ def _log_event_type(self, nova_event: dict[str, Any]) -> None: audio_bytes = base64.b64decode(audio_content) logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore[override] + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive Nova Sonic events and convert to provider-agnostic format. Raises: @@ -241,6 +243,9 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore[overr _, output = await self._stream.await_output() while True: event_data = await output.receive() + if not event_data: + continue + nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] self._log_event_type(nova_event) @@ -420,7 +425,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N return BidiAudioStreamEvent( audio=audio_content, format="pcm", - sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], # type: ignore + sample_rate=cast(SampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), channels=1, ) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index cc82bbea8..8e9e78416 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -8,7 +8,7 @@ import logging import os import uuid -from typing import Any, AsyncIterable +from typing import Any, AsyncGenerator, cast import websockets from websockets import ClientConnection @@ -29,6 +29,10 @@ BidiTextInputEvent, BidiTranscriptStreamEvent, BidiUsageEvent, + ModalityUsage, + Role, + SampleRate, + StopReason, ) from .bidi_model import BidiModel @@ -171,7 +175,7 @@ def _create_text_event(self, text: str, role: str, is_final: bool = True) -> Bid return BidiTranscriptStreamEvent( delta={"text": text}, text=text, - role=normalized_role, # type: ignore + role=cast(Role, normalized_role), is_final=is_final, current_transcript=text if is_final else None, ) @@ -184,15 +188,15 @@ def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEv # Other voice activity events are logged but don't create events return None - def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict[str, Any]: """Build session configuration for OpenAI Realtime API.""" - config = DEFAULT_SESSION_CONFIG.copy() + config: dict[str, Any] = DEFAULT_SESSION_CONFIG.copy() if system_prompt: config["instructions"] = system_prompt if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) # type: ignore + config["tools"] = self._convert_tools_to_openai_format(tools) # Apply user-provided session configuration supported_params = { @@ -261,7 +265,7 @@ async def _add_conversation_history(self, messages: Messages) -> None: await self._send_event(conversation_item) - async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive OpenAI events and convert to Strands TypedEvent format.""" if not self._connection_id: raise RuntimeError("model not started | call start before receiving") @@ -291,7 +295,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", - sample_rate=AUDIO_FORMAT["rate"], # type: ignore + sample_rate=cast(SampleRate, AUDIO_FORMAT["rate"]), channels=1, ) ] @@ -404,7 +408,10 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput # Always add response complete event events.append( - BidiResponseCompleteEvent(response_id=response_id, stop_reason=stop_reason_map.get(status, "complete")) # type: ignore + BidiResponseCompleteEvent( + response_id=response_id, + stop_reason=cast(StopReason, stop_reason_map.get(status, "complete")), + ), ) # Add usage event if available @@ -445,7 +452,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput input_tokens=usage.get("input_tokens", 0), output_tokens=usage.get("output_tokens", 0), total_tokens=usage.get("total_tokens", 0), - modality_details=modality_details if modality_details else None, # type: ignore + modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, ) ) diff --git a/src/strands/experimental/bidi/types/__init__.py b/src/strands/experimental/bidi/types/__init__.py index d8525e23a..1fa1d9048 100644 --- a/src/strands/experimental/bidi/types/__init__.py +++ b/src/strands/experimental/bidi/types/__init__.py @@ -2,12 +2,6 @@ from .agent import BidiAgentInput from .events import ( - DEFAULT_CHANNELS, - DEFAULT_FORMAT, - DEFAULT_SAMPLE_RATE, - SUPPORTED_AUDIO_FORMATS, - SUPPORTED_CHANNELS, - SUPPORTED_SAMPLE_RATES, BidiAudioInputEvent, BidiAudioStreamEvent, BidiConnectionCloseEvent, @@ -47,11 +41,4 @@ "ModalityUsage", "BidiErrorEvent", "BidiOutputEvent", - # Constants - "SUPPORTED_AUDIO_FORMATS", - "SUPPORTED_SAMPLE_RATES", - "SUPPORTED_CHANNELS", - "DEFAULT_SAMPLE_RATE", - "DEFAULT_CHANNELS", - "DEFAULT_FORMAT", ] diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 8e6113ea3..70d0f8f3d 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -25,13 +25,11 @@ from ....types.streaming import ContentBlockDelta # Audio format constants -SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] -SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] -SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo -DEFAULT_SAMPLE_RATE = 16000 -DEFAULT_CHANNELS = 1 -DEFAULT_FORMAT = "pcm" - +AudioFormat = Literal["pcm", "wav", "opus", "mp3"] +SampleRate = Literal[16000, 24000, 48000] +Channel = Literal[1, 2] # 1=mono, 2=stereo +Role = Literal["user", "assistant"] +StopReason = Literal["complete", "interrupted", "tool_use", "error"] # ============================================================================ # Input Events (sent via agent.send()) @@ -84,9 +82,9 @@ class BidiAudioInputEvent(TypedEvent): def __init__( self, audio: str, - format: Literal["pcm", "wav", "opus", "mp3"] | str, - sample_rate: Literal[16000, 24000, 48000], - channels: Literal[1, 2], + format: AudioFormat | str, + sample_rate: SampleRate, + channels: Channel, ): """Initialize audio input event.""" super().__init__( @@ -219,9 +217,9 @@ class BidiAudioStreamEvent(TypedEvent): def __init__( self, audio: str, - format: Literal["pcm", "wav", "opus", "mp3"], - sample_rate: Literal[16000, 24000, 48000], - channels: Literal[1, 2], + format: AudioFormat, + sample_rate: SampleRate, + channels: Channel, ): """Initialize audio stream event.""" super().__init__( @@ -273,7 +271,7 @@ def __init__( self, delta: ContentBlockDelta, text: str, - role: Literal["user", "assistant"], + role: Role, is_final: bool, current_transcript: Optional[str] = None, ): @@ -349,7 +347,7 @@ class BidiResponseCompleteEvent(TypedEvent): def __init__( self, response_id: str, - stop_reason: Literal["complete", "interrupted", "tool_use", "error"], + stop_reason: StopReason, ): """Initialize response complete event.""" super().__init__( diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 9acba3372..a0ce7c811 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -57,7 +57,7 @@ async def _invoke_before_tool_call_hook( event_cls = BidiBeforeToolCallEvent if ToolExecutor._is_bidi_agent(agent) else BeforeToolCallEvent return await agent.hooks.invoke_callbacks_async( event_cls( - agent=agent, # type: ignore[arg-type] + agent=agent, selected_tool=tool_func, tool_use=tool_use, invocation_state=invocation_state, @@ -78,7 +78,7 @@ async def _invoke_after_tool_call_hook( event_cls = BidiAfterToolCallEvent if ToolExecutor._is_bidi_agent(agent) else AfterToolCallEvent return await agent.hooks.invoke_callbacks_async( event_cls( - agent=agent, # type: ignore[arg-type] + agent=agent, selected_tool=selected_tool, tool_use=tool_use, invocation_state=invocation_state, @@ -301,9 +301,7 @@ async def _stream_with_trace( tool_duration = time.time() - tool_start_time message = Message(role="user", content=[{"toolResult": result}]) if not ToolExecutor._is_bidi_agent(agent): - agent.event_loop_metrics.add_tool_usage( # type: ignore[union-attr] - tool_use, tool_duration, tool_trace, tool_success, message - ) + agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) tracer.end_tool_call_span(tool_call_span, result) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 56a5999c0..7b189a5c6 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -138,6 +138,7 @@ async def test_stream_e2e_success(alist): "arg1": 1013, "current_tool_use": {"input": {}, "name": "normal_tool", "toolUseId": "123"}, "delta": {"toolUse": {"input": "{}"}}, + "type": "tool_use_stream", }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "tool_use"}}}, @@ -195,6 +196,7 @@ async def test_stream_e2e_success(alist): "model": ANY, "system_prompt": None, "tool_config": tool_config, + "type": "tool_use_stream", }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "tool_use"}}}, @@ -252,6 +254,7 @@ async def test_stream_e2e_success(alist): "model": ANY, "system_prompt": None, "tool_config": tool_config, + "type": "tool_use_stream", }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "tool_use"}}}, diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 3f5a6c998..02be400b1 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -133,11 +133,12 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) @pytest.mark.parametrize( - ("event", "state", "exp_updated_state", "callback_args"), + ("event", "event_type", "state", "exp_updated_state", "callback_args"), [ # Tool Use - Existing input ( {"delta": {"toolUse": {"input": '"value"}'}}}, + {"type": "tool_use_stream"}, {"current_tool_use": {"input": '{"key": '}}, {"current_tool_use": {"input": '{"key": "value"}'}}, {"current_tool_use": {"input": '{"key": "value"}'}}, @@ -145,6 +146,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) # Tool Use - New input ( {"delta": {"toolUse": {"input": '{"key": '}}}, + {"type": "tool_use_stream"}, {"current_tool_use": {}}, {"current_tool_use": {"input": '{"key": '}}, {"current_tool_use": {"input": '{"key": '}}, @@ -152,6 +154,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) # Text ( {"delta": {"text": " world"}}, + {}, {"text": "hello"}, {"text": "hello world"}, {"data": " world"}, @@ -159,6 +162,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) # Reasoning - Text - Existing ( {"delta": {"reasoningContent": {"text": "king"}}}, + {}, {"reasoningText": "thin"}, {"reasoningText": "thinking"}, {"reasoningText": "king", "reasoning": True}, @@ -167,12 +171,14 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ( {"delta": {"reasoningContent": {"text": "thin"}}}, {}, + {}, {"reasoningText": "thin"}, {"reasoningText": "thin", "reasoning": True}, ), # Reasoning - Signature - Existing ( {"delta": {"reasoningContent": {"signature": "ue"}}}, + {}, {"signature": "val"}, {"signature": "value"}, {"reasoning_signature": "ue", "reasoning": True}, @@ -181,6 +187,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ( {"delta": {"reasoningContent": {"signature": "val"}}}, {}, + {}, {"signature": "val"}, {"reasoning_signature": "val", "reasoning": True}, ), @@ -188,12 +195,14 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) pytest.param( {"delta": {"reasoningContent": {"redactedContent": b"encoded"}}}, {}, + {}, {"redactedContent": b"encoded"}, {"reasoningRedactedContent": b"encoded", "reasoning": True}, ), # Reasoning - redactedContent - Existing pytest.param( {"delta": {"reasoningContent": {"redactedContent": b"data"}}}, + {}, {"redactedContent": b"encoded_"}, {"redactedContent": b"encoded_data"}, {"reasoningRedactedContent": b"data", "reasoning": True}, @@ -204,6 +213,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, {}, + {}, ), # Empty ( @@ -211,11 +221,12 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, {}, + {}, ), ], ) -def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): - exp_callback_event = {**callback_args, "delta": event["delta"]} if callback_args else {} +def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, state, exp_updated_state, callback_args): + exp_callback_event = {**event_type, **callback_args, "delta": event["delta"]} if callback_args else {} tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) @@ -526,6 +537,7 @@ def test_extract_usage_metrics_empty_metadata(): "input": '{"key": "value"}', }, }, + "type": "tool_use_stream", }, { "event": { diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index ac4b1ab61..f8df25e14 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -11,7 +11,7 @@ async def test_stop_exception(): func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() - with pytest.raises(ExceptionGroup) as exc_info: # type: ignore # noqa: F821 + with pytest.raises(ExceptionGroup) as exc_info: await stop_all(func1, func2, func3) func1.assert_called_once() diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index d32f88fe7..5567af042 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -184,7 +184,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): # noqa: F821 + with pytest.raises(ExceptionGroup): await model4.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index c0c38d1b2..869ef3afd 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -210,7 +210,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): # noqa: F821 + with pytest.raises(ExceptionGroup): await model4.stop() From 7a690eef05b86abf93d2bae295ee47cae5d17c67 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 24 Nov 2025 10:45:51 +0100 Subject: [PATCH 180/242] fix mypy errors --- .../experimental/bidi/models/openai.py | 2 +- src/strands/session/session_manager.py | 2 +- .../bidi/models/test_novasonic.py | 77 +++++++++---------- .../bidi/models/test_openai_realtime.py | 7 +- 4 files changed, 41 insertions(+), 47 deletions(-) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index e7fdb9962..4d854886c 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -398,7 +398,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", - sample_rate=sample_rate, # type: ignore + sample_rate=sample_rate, channels=1, ) ] diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index dfa9b147c..78625aa8e 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -50,7 +50,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: # Register BidiAgent hooks if the experimental module is available try: - from ..experimental.bidi.hooks.events import ( + from ..experimental.hooks.events import ( BidiAfterInvocationEvent, BidiAgentInitializedEvent, BidiMessageAddedEvent, diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 4b976838d..90151c18a 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -122,46 +122,43 @@ async def test_model_stop_alone(nova_model): @pytest.mark.asyncio async def test_connection_with_message_history(nova_model, mock_client, mock_stream): """Test connection initialization with conversation history.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model.client = mock_client - - # Create message history - messages = [ - {"role": "user", "content": [{"text": "What's the weather?"}]}, - {"role": "assistant", "content": [{"text": "I'll check the weather for you."}]}, - { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "tool-123", "name": "get_weather", "input": {}}}], - }, - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "tool-123", "content": [{"text": "Sunny, 72°F"}]}}], - }, - {"role": "assistant", "content": [{"text": "It's sunny and 72 degrees."}]}, - ] - - # Start connection with message history - await nova_model.start(system_prompt="You are a helpful assistant", messages=messages) - - # Verify initialization events were sent - # Should include: sessionStart, promptStart, system prompt (3 events), - # and message history (5 messages * 3 events each = 15 events) - # Total: 1 + 1 + 3 + 15 = 20 events minimum - assert mock_stream.input_stream.send.call_count >= 18 - - # Verify the events contain proper role information - sent_events = [ - call.args[0].value.bytes_.decode("utf-8") for call in mock_stream.input_stream.send.call_args_list - ] - - # Check that USER and ASSISTANT roles are present in contentStart events - user_events = [e for e in sent_events if '"role": "USER"' in e] - assistant_events = [e for e in sent_events if '"role": "ASSISTANT"' in e] - - assert len(user_events) >= 2 # At least 2 user messages - assert len(assistant_events) >= 3 # At least 3 assistant messages - - await nova_model.stop() + nova_model.client = mock_client + + # Create message history + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + {"role": "assistant", "content": [{"text": "I'll check the weather for you."}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "tool-123", "name": "get_weather", "input": {}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "tool-123", "content": [{"text": "Sunny, 72°F"}]}}], + }, + {"role": "assistant", "content": [{"text": "It's sunny and 72 degrees."}]}, + ] + + # Start connection with message history + await nova_model.start(system_prompt="You are a helpful assistant", messages=messages) + + # Verify initialization events were sent + # Should include: sessionStart, promptStart, system prompt (3 events), + # and message history (5 messages * 3 events each = 15 events) + # Total: 1 + 1 + 3 + 15 = 20 events minimum + assert mock_stream.input_stream.send.call_count >= 18 + + # Verify the events contain proper role information + sent_events = [call.args[0].value.bytes_.decode("utf-8") for call in mock_stream.input_stream.send.call_args_list] + + # Check that USER and ASSISTANT roles are present in contentStart events + user_events = [e for e in sent_events if '"role": "USER"' in e] + assistant_events = [e for e in sent_events if '"role": "ASSISTANT"' in e] + + assert len(user_events) >= 2 # At least 2 user messages + assert len(assistant_events) >= 3 # At least 3 assistant messages + + await nova_model.stop() # Send Method Tests diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 0012b5649..58045b609 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -94,7 +94,6 @@ def test_model_initialization(api_key, model_name): assert model_default.model == "gpt-realtime" assert model_default.api_key == "test-key" assert model_default._active is False - assert model_default.websocket is None # Test with custom model model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) @@ -355,11 +354,9 @@ async def test_send_edge_cases(mock_websockets_connect, model): await model.send(image_input) mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") - # Test unknown content type + # Test unknown content type - should just be ignored without warning unknown_content = {"unknown_field": "value"} - with unittest.mock.patch("strands.experimental.bidi.models.openai.logger") as mock_logger: - await model.send(unknown_content) - assert mock_logger.warning.called + await model.send(unknown_content) await model.stop() From 46e14f02ca0a20c31d131b0d73f2c0ff2e712526 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 25 Nov 2025 12:25:56 +0100 Subject: [PATCH 181/242] fix tool result mapping --- .../experimental/bidi/models/novasonic.py | 28 +-- .../experimental/bidi/models/openai.py | 63 ++++-- .../bidi/models/test_novasonic.py | 154 ++++++++++++-- .../experimental/bidi/models/test_openai.py | 194 +++++++++++++++++- 4 files changed, 390 insertions(+), 49 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 4accc6369..c069da1c4 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -354,14 +354,15 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) - # TODO: We need to extract all content and content types - result_data = {} - if "content" in tool_result: - # Extract text from content blocks - for block in tool_result["content"]: - if "text" in block: - result_data = {"result": block["text"]} - break + # Nova Sonic expects stringified JSON in toolResult.content + content = tool_result.get("content", []) + + # Optimize for single content item - unwrap the array + if len(content) == 1: + result_data: dict[str, Any] = content[0] + else: + # Multiple items - send as array + result_data = {"content": content} content_name = str(uuid.uuid4()) events = [ @@ -564,17 +565,6 @@ def _get_message_history_events(self, messages: Messages) -> list[str]: for block in content_blocks: if "text" in block: text_parts.append(block["text"]) - elif "toolUse" in block: - # Include tool use information in text format for context - tool_use = block["toolUse"] - text_parts.append(f"[Tool: {tool_use['name']}]") - elif "toolResult" in block: - # Include tool result information in text format for context - tool_result = block["toolResult"] - if "content" in tool_result: - for result_block in tool_result["content"]: - if "text" in result_block: - text_parts.append(f"[Tool Result: {result_block['text']}]") # Combine all text parts if text_parts: diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 451c90d12..088f58c57 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -306,12 +306,31 @@ async def _add_conversation_history(self, messages: Messages) -> None: elif "toolResult" in block: # Tool result - create as function_call_output item tool_result = block["toolResult"] - result_text = "" + + # Serialize the entire tool result content, preserving all data types + result_output = "" if "content" in tool_result: + # Collect all content blocks + content_parts = [] for result_block in tool_result["content"]: if "text" in result_block: - result_text = result_block["text"] - break + content_parts.append(result_block["text"]) + elif "json" in result_block: + # Preserve JSON content + json_content = result_block["json"] + content_parts.append( + json.dumps(json_content) if not isinstance(json_content, str) else json_content + ) + elif "image" in result_block: + logger.warning("image content in tool results not supported by openai realtime api") + elif "document" in result_block: + logger.warning("document content in tool results not supported by openai realtime api") + + # Combine all parts - if single part, use as-is; if multiple, combine + if len(content_parts) == 1: + result_output = content_parts[0] + elif content_parts: + result_output = "\n".join(content_parts) original_id = tool_result["toolUseId"] # Use mapped call_id if available, otherwise skip orphaned result @@ -325,7 +344,7 @@ async def _add_conversation_history(self, messages: Messages) -> None: "item": { "type": "function_call_output", "call_id": call_id, - "output": result_text, + "output": result_output, }, } await self._send_event(result_item) @@ -665,18 +684,36 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) - # TODO: We need to extract all content and content types - result_data: dict[Any, Any] | str = {} + # Serialize the entire tool result content, preserving all data types + result_output = "" if "content" in tool_result: - # Extract text from content blocks + # Collect all content blocks + content_parts = [] for block in tool_result["content"]: if "text" in block: - result_data = block["text"] - break - - result_text = json.dumps(result_data) if not isinstance(result_data, str) else result_data - - item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} + content_parts.append(block["text"]) + elif "json" in block: + # Preserve JSON content + json_content = block["json"] + content_parts.append( + json.dumps(json_content) if not isinstance(json_content, str) else json_content + ) + elif "image" in block: + raise ValueError( + f"tool_use_id=<{tool_use_id}> | Image content in tool results is not supported by OpenAI Realtime API" + ) + elif "document" in block: + raise ValueError( + f"tool_use_id=<{tool_use_id}> | Document content in tool results is not supported by OpenAI Realtime API" + ) + + # Combine all parts - if single part, use as-is; if multiple, combine + if len(content_parts) == 1: + result_output = content_parts[0] + elif content_parts: + result_output = "\n".join(content_parts) + + item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_output} await self._send_event({"type": "conversation.item.create", "item": item_data}) await self._send_event({"type": "response.create"}) diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 575160e8d..23d1465bc 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -144,9 +144,10 @@ async def test_connection_with_message_history(nova_model, mock_client, mock_str # Verify initialization events were sent # Should include: sessionStart, promptStart, system prompt (3 events), - # and message history (5 messages * 3 events each = 15 events) - # Total: 1 + 1 + 3 + 15 = 20 events minimum - assert mock_stream.input_stream.send.call_count >= 18 + # and message history (only text messages: 3 messages * 3 events each = 9 events) + # Tool use/result messages are now skipped in history + # Total: 1 + 1 + 3 + 9 = 14 events minimum + assert mock_stream.input_stream.send.call_count >= 14 # Verify the events contain proper role information sent_events = [call.args[0].value.bytes_.decode("utf-8") for call in mock_stream.input_stream.send.call_args_list] @@ -155,8 +156,9 @@ async def test_connection_with_message_history(nova_model, mock_client, mock_str user_events = [e for e in sent_events if '"role": "USER"' in e] assistant_events = [e for e in sent_events if '"role": "ASSISTANT"' in e] - assert len(user_events) >= 2 # At least 2 user messages - assert len(assistant_events) >= 3 # At least 3 assistant messages + # Only text messages are sent, so we expect 1 user message and 2 assistant messages + assert len(user_events) >= 1 + assert len(assistant_events) >= 2 await nova_model.stop() @@ -183,16 +185,25 @@ async def test_send_all_content_types(nova_model, mock_stream): assert nova_model._audio_content_name assert mock_stream.input_stream.send.called - # Test tool result - tool_result: ToolResult = { + # Test tool result with single content item (should be unwrapped) + tool_result_single: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Weather is sunny"}], } - await nova_model.send(ToolResultEvent(tool_result)) + await nova_model.send(ToolResultEvent(tool_result_single)) # Should send contentStart, toolResult, and contentEnd assert mock_stream.input_stream.send.called + # Test tool result with multiple content items (should send as array) + tool_result_multi: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}], + } + await nova_model.send(ToolResultEvent(tool_result_multi)) + assert mock_stream.input_stream.send.called + await nova_model.stop() @@ -407,8 +418,9 @@ async def test_message_history_conversion(nova_model): events = nova_model._get_message_history_events(messages) - # Each message should generate 3 events: contentStart, textInput, contentEnd - assert len(events) == 15 # 5 messages * 3 events each + # Only text messages generate events (3 messages * 3 events each = 9 events) + # Tool use/result messages are now skipped in history + assert len(events) == 9 # Parse and verify events parsed_events = [json.loads(e) for e in events] @@ -426,13 +438,11 @@ async def test_message_history_conversion(nova_model): assert "textInput" in parsed_events[4]["event"] assert parsed_events[4]["event"]["textInput"]["content"] == "Hi there!" - # Check tool use message (should include tool name in text) + # Check third message (assistant - last text message) + assert "contentStart" in parsed_events[6]["event"] + assert parsed_events[6]["event"]["contentStart"]["role"] == "ASSISTANT" assert "textInput" in parsed_events[7]["event"] - assert "[Tool: calculator]" in parsed_events[7]["event"]["textInput"]["content"] - - # Check tool result message (should include result in text) - assert "textInput" in parsed_events[10]["event"] - assert "[Tool Result: 4]" in parsed_events[10]["event"]["textInput"]["content"] + assert parsed_events[7]["event"]["textInput"]["content"] == "The answer is 4" @pytest.mark.asyncio @@ -479,3 +489,115 @@ async def mock_error(*args, **kwargs): # Should still be able to close cleanly await nova_model.stop() + + +# Tool Result Content Tests + + +@pytest.mark.asyncio +async def test_tool_result_single_content_unwrapped(nova_model, mock_stream): + """Test that single content item is unwrapped (optimization).""" + await nova_model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Single result"}], + } + + await nova_model.send(ToolResultEvent(tool_result)) + + # Verify events were sent + assert mock_stream.input_stream.send.called + calls = mock_stream.input_stream.send.call_args_list + + # Find the toolResult event + tool_result_events = [] + for call in calls: + event_json = call.args[0].value.bytes_.decode("utf-8") + event = json.loads(event_json) + if "toolResult" in event.get("event", {}): + tool_result_events.append(event) + + assert len(tool_result_events) > 0 + tool_result_event = tool_result_events[0]["event"]["toolResult"] + + # Single content should be unwrapped (not in array) + content = json.loads(tool_result_event["content"]) + assert content == {"text": "Single result"} + + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_multiple_content_as_array(nova_model, mock_stream): + """Test that multiple content items are sent as array.""" + await nova_model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}], + } + + await nova_model.send(ToolResultEvent(tool_result)) + + # Verify events were sent + assert mock_stream.input_stream.send.called + calls = mock_stream.input_stream.send.call_args_list + + # Find the toolResult event + tool_result_events = [] + for call in calls: + event_json = call.args[0].value.bytes_.decode("utf-8") + event = json.loads(event_json) + if "toolResult" in event.get("event", {}): + tool_result_events.append(event) + + assert len(tool_result_events) > 0 + tool_result_event = tool_result_events[0]["event"]["toolResult"] + + # Multiple content should be in array format + content = json.loads(tool_result_event["content"]) + assert "content" in content + assert isinstance(content["content"], list) + assert len(content["content"]) == 2 + assert content["content"][0] == {"text": "Part 1"} + assert content["content"][1] == {"json": {"data": "value"}} + + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_empty_content(nova_model, mock_stream): + """Test that empty content is handled gracefully.""" + await nova_model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-789", + "status": "success", + "content": [], + } + + await nova_model.send(ToolResultEvent(tool_result)) + + # Verify events were sent + assert mock_stream.input_stream.send.called + calls = mock_stream.input_stream.send.call_args_list + + # Find the toolResult event + tool_result_events = [] + for call in calls: + event_json = call.args[0].value.bytes_.decode("utf-8") + event = json.loads(event_json) + if "toolResult" in event.get("event", {}): + tool_result_events.append(event) + + assert len(tool_result_events) > 0 + tool_result_event = tool_result_events[0]["event"]["toolResult"] + + # Empty content should result in empty array wrapped in content key + content = json.loads(tool_result_event["content"]) + assert content == {"content": []} + + await nova_model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index cfd298b8d..7a0da3a47 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -312,7 +312,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): # Audio should be passed through as base64 assert audio_append[0]["audio"] == audio_b64 - # Test tool result + # Test tool result with text content tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", @@ -326,6 +326,60 @@ async def test_send_all_content_types(mock_websockets_connect, model): item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "tool-123" + assert item.get("output") == "Result: 42" + + # Test tool result with JSON content + tool_result_json: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"json": {"result": 42, "status": "ok"}}], + } + await model.send(ToolResultEvent(tool_result_json)) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-456" + # JSON should be serialized + assert json.loads(item.get("output")) == {"result": 42, "status": "ok"} + + # Test tool result with multiple content blocks + tool_result_multi: ToolResult = { + "toolUseId": "tool-789", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}, {"text": "Part 2"}], + } + await model.send(ToolResultEvent(tool_result_multi)) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-789" + # Multiple parts should be joined with newlines + output = item.get("output") + assert "Part 1" in output + assert '"data": "value"' in output or "'data': 'value'" in output + assert "Part 2" in output + + # Test tool result with image content (should raise error) + tool_result_image: ToolResult = { + "toolUseId": "tool-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + with pytest.raises(ValueError, match=r"Image content.*not supported"): + await model.send(ToolResultEvent(tool_result_image)) + + # Test tool result with document content (should raise error) + tool_result_doc: ToolResult = { + "toolUseId": "tool-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"doc_data"}}}], + } + with pytest.raises(ValueError, match=r"Document content.*not supported"): + await model.send(ToolResultEvent(tool_result_doc)) await model.stop() @@ -634,3 +688,141 @@ async def test_partial_audio_config(mock_websockets_connect, api_key): assert audio_event.channels == 1 await model.stop() + + +# Tool Result Content Tests + + +@pytest.mark.asyncio +async def test_tool_result_single_text_content(mock_websockets_connect, api_key): + """Test tool result with single text content block.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(api_key=api_key) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-123", + "status": "success", + "content": [{"text": "Simple text result"}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the sent event + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + + assert len(item_create) > 0 + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "call-123" + assert item.get("output") == "Simple text result" + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_single_json_content(mock_websockets_connect, api_key): + """Test tool result with single JSON content block.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(api_key=api_key) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-456", + "status": "success", + "content": [{"json": {"temperature": 72, "condition": "sunny"}}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the sent event + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "call-456" + # JSON should be serialized as string + output = item.get("output") + parsed = json.loads(output) + assert parsed == {"temperature": 72, "condition": "sunny"} + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_multiple_content_blocks(mock_websockets_connect, api_key): + """Test tool result with multiple content blocks (text and json).""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(api_key=api_key) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-789", + "status": "success", + "content": [ + {"text": "Weather data:"}, + {"json": {"temp": 72, "humidity": 65}}, + {"text": "Forecast: sunny"}, + ], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the sent event + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "call-789" + # Multiple parts should be joined with newlines + output = item.get("output") + assert "Weather data:" in output + assert "temp" in output + assert "humidity" in output + assert "Forecast: sunny" in output + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_image_content_raises_error(mock_websockets_connect, api_key): + """Test that tool result with image content raises ValueError.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(api_key=api_key) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Image content.*not supported.*OpenAI Realtime API"): + await model.send(ToolResultEvent(tool_result)) + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_document_content_raises_error(mock_websockets_connect, api_key): + """Test that tool result with document content raises ValueError.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(api_key=api_key) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"fake_pdf_data"}}}], + } + + with pytest.raises(ValueError, match=r"Document content.*not supported.*OpenAI Realtime API"): + await model.send(ToolResultEvent(tool_result)) + + await model.stop() From a1a3fc606f5e0cc6644c417561e603a47b96af75 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 25 Nov 2025 12:26:05 +0100 Subject: [PATCH 182/242] update bidi hook imports --- src/strands/session/session_manager.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 78625aa8e..ba4356089 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -4,6 +4,11 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from ..experimental.hooks.events import ( + BidiAfterInvocationEvent, + BidiAgentInitializedEvent, + BidiMessageAddedEvent, +) from ..experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, @@ -48,22 +53,11 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) - # Register BidiAgent hooks if the experimental module is available - try: - from ..experimental.hooks.events import ( - BidiAfterInvocationEvent, - BidiAgentInitializedEvent, - BidiMessageAddedEvent, - ) - - registry.add_callback(BidiAgentInitializedEvent, lambda event: self.initialize_bidi_agent(event.agent)) - registry.add_callback( - BidiMessageAddedEvent, lambda event: self.append_bidi_message(event.message, event.agent) - ) - registry.add_callback(BidiMessageAddedEvent, lambda event: self.sync_bidi_agent(event.agent)) - registry.add_callback(BidiAfterInvocationEvent, lambda event: self.sync_bidi_agent(event.agent)) - except ImportError: - pass + # Register BidiAgent hooks + registry.add_callback(BidiAgentInitializedEvent, lambda event: self.initialize_bidi_agent(event.agent)) + registry.add_callback(BidiMessageAddedEvent, lambda event: self.append_bidi_message(event.message, event.agent)) + registry.add_callback(BidiMessageAddedEvent, lambda event: self.sync_bidi_agent(event.agent)) + registry.add_callback(BidiAfterInvocationEvent, lambda event: self.sync_bidi_agent(event.agent)) @abstractmethod def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: From 275f469f1dc6643d38f206655541e4c81f498fe9 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 07:44:32 -0500 Subject: [PATCH 183/242] pass agent to io's start instead of just model audio config --- src/strands/experimental/bidi/agent/agent.py | 27 +++----------------- src/strands/experimental/bidi/io/audio.py | 23 ++++++++++------- src/strands/experimental/bidi/io/text.py | 12 +++++++-- src/strands/experimental/bidi/types/io.py | 21 +++++++++++---- 4 files changed, 44 insertions(+), 39 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 9f99f9f9a..0e74da392 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -31,7 +31,6 @@ from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent -from ..io.audio import _BidiAudioInput, _BidiAudioOutput from .._async import stop_all from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel @@ -446,16 +445,6 @@ async def run( ) ``` """ - # Extract audio config from model if available - audio_config = getattr(self.model, "audio_config", None) - if audio_config: - logger.debug( - "audio_config | model provides: input_rate=%s, output_rate=%s, channels=%s, voice=%s", - audio_config.get("input_rate"), - audio_config.get("output_rate"), - audio_config.get("channels"), - audio_config.get("voice"), - ) async def run_inputs() -> None: async def task(input_: BidiInput) -> None: @@ -473,23 +462,15 @@ async def run_outputs() -> None: await self.start() - # Start inputs with audio config if applicable + # Start inputs, passing agent for configuration for input_ in inputs: if hasattr(input_, "start"): - # Pass audio config to audio inputs - if audio_config and isinstance(input_, _BidiAudioInput): - await input_.start(audio_config=audio_config) - else: - await input_.start() + await input_.start(self) - # Start outputs with audio config if applicable + # Start outputs, passing agent for configuration for output in outputs: if hasattr(output, "start"): - # Pass audio config to audio outputs - if audio_config and isinstance(output, _BidiAudioOutput): - await output.start(audio_config=audio_config) - else: - await output.start() + await output.start(self) try: await self.start(invocation_state) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 41a07be9e..98f64ca1e 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -8,13 +8,16 @@ import base64 import logging from collections import deque -from typing import Any +from typing import TYPE_CHECKING, Any import pyaudio from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent from ..types.io import AudioConfig, BidiInput, BidiOutput +if TYPE_CHECKING: + from ..agent.agent import BidiAgent + logger = logging.getLogger(__name__) @@ -49,14 +52,15 @@ def __init__(self, config: dict[str, Any]) -> None: self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER) self._rate = config.get("input_rate", _BidiAudioInput._RATE) - async def start(self, audio_config: AudioConfig | None = None) -> None: + async def start(self, agent: "BidiAgent") -> None: """Start input stream. Args: - audio_config: Optional audio configuration from model provider. - Only applied if user did not explicitly set the value - in the constructor. + agent: The BidiAgent instance, providing access to model configuration. """ + # Extract audio config from agent's model + audio_config = getattr(agent.model, "audio_config", None) + # Apply audio config overrides only if user didn't explicitly set them if audio_config: if "input_rate" in audio_config and "input_rate" not in self._user_config_set: @@ -145,14 +149,15 @@ def __init__(self, config: dict[str, Any]) -> None: self._frames_per_buffer = config.get("output_frames_per_buffer", _BidiAudioOutput._FRAMES_PER_BUFFER) self._rate = config.get("output_rate", _BidiAudioOutput._RATE) - async def start(self, audio_config: AudioConfig | None = None) -> None: + async def start(self, agent: "BidiAgent") -> None: """Start output stream. Args: - audio_config: Optional audio configuration from model provider. - Only applied if user did not explicitly set the value - in the constructor. + agent: The BidiAgent instance, providing access to model configuration. """ + # Extract audio config from agent's model + audio_config = getattr(agent.model, "audio_config", None) + # Apply audio config overrides only if user didn't explicitly set them if audio_config: if "output_rate" in audio_config and "output_rate" not in self._user_config_set: diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 8ecbae149..ea4e04e99 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -1,18 +1,26 @@ """Handle text input and output from bidi agent.""" import logging +from typing import TYPE_CHECKING from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent from ..types.io import BidiOutput +if TYPE_CHECKING: + from ..agent.agent import BidiAgent + logger = logging.getLogger(__name__) class _BidiTextOutput(BidiOutput): """Handle text output from bidi agent.""" - async def start(self) -> None: - """Start text output.""" + async def start(self, agent: "BidiAgent") -> None: + """Start text output. + + Args: + agent: The BidiAgent instance, providing access to model configuration. + """ pass async def stop(self) -> None: diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py index 7a696e85b..57b49a834 100644 --- a/src/strands/experimental/bidi/types/io.py +++ b/src/strands/experimental/bidi/types/io.py @@ -5,10 +5,13 @@ by separating input and output concerns into independent callables. """ -from typing import Awaitable, Literal, Protocol, TypedDict +from typing import TYPE_CHECKING, Awaitable, Literal, Protocol, TypedDict from ..types.events import BidiInputEvent, BidiOutputEvent +if TYPE_CHECKING: + from ..agent.agent import BidiAgent + class AudioConfig(TypedDict, total=False): """Audio configuration for bidirectional streaming. @@ -39,8 +42,12 @@ class BidiInput(Protocol): and return events to be sent to the agent. """ - async def start(self) -> None: - """Start input.""" + async def start(self, agent: "BidiAgent") -> None: + """Start input. + + Args: + agent: The BidiAgent instance, providing access to model configuration. + """ ... async def stop(self) -> None: @@ -63,8 +70,12 @@ class BidiOutput(Protocol): (play audio, display text, send over websocket, etc.). """ - async def start(self) -> None: - """Start output.""" + async def start(self, agent: "BidiAgent") -> None: + """Start output. + + Args: + agent: The BidiAgent instance, providing access to model configuration. + """ ... async def stop(self) -> None: From d18b0210d3f64cfc450def36d4482d7e71682cbd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 08:42:21 -0500 Subject: [PATCH 184/242] updated audio config to be used from model provider --- src/strands/experimental/bidi/agent/agent.py | 14 +- src/strands/experimental/bidi/io/audio.py | 129 ++++++------- .../experimental/bidi/models/bidi_model.py | 7 + .../experimental/bidi/models/gemini_live.py | 2 +- .../experimental/bidi/models/novasonic.py | 2 +- .../experimental/bidi/models/openai.py | 10 +- .../experimental/bidi/types/bidi_model.py | 34 ++++ src/strands/experimental/bidi/types/io.py | 24 +-- .../experimental/bidi/io/test_audio.py | 170 ++++++++++++------ 9 files changed, 230 insertions(+), 162 deletions(-) create mode 100644 src/strands/experimental/bidi/types/bidi_model.py diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 0e74da392..ccba6f93e 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -30,7 +30,6 @@ from ....types.tools import AgentTool, ToolResult, ToolUse from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider -from ..hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from .._async import stop_all from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel @@ -435,7 +434,9 @@ async def run( Example: ```python - audio_io = BidiAudioIO(input_rate=16000) + # Using model defaults: + model = BidiNovaSonicModel() + audio_io = BidiAudioIO() text_io = BidiTextIO() agent = BidiAgent(model=model, tools=[calculator]) await agent.run( @@ -443,6 +444,15 @@ async def run( outputs=[audio_io.output(), text_io.output()], invocation_state={"user_id": "user_123"} ) + + # Using custom audio config: + model = BidiNovaSonicModel(audio_config={"input_rate": 48000, "output_rate": 24000}) + audio_io = BidiAudioIO() + agent = BidiAgent(model=model, tools=[calculator]) + await agent.run( + inputs=[audio_io.input()], + outputs=[audio_io.output()], + ) ``` """ diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 98f64ca1e..d8ea61496 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -2,6 +2,8 @@ Reads user audio from input device and sends agent audio to output device using PyAudio. If a user interrupts the agent, the output buffer is cleared to stop playback. + +Audio configuration is provided by the model via agent.model.audio_config. """ import asyncio @@ -13,7 +15,7 @@ import pyaudio from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent -from ..types.io import AudioConfig, BidiInput, BidiOutput +from ..types.io import BidiInput, BidiOutput if TYPE_CHECKING: from ..agent.agent import BidiAgent @@ -27,30 +29,28 @@ class _BidiAudioInput(BidiInput): Attributes: _audio: PyAudio instance for audio system access. _stream: Audio input stream. - _user_config_set: Track which config values were explicitly set by user. """ _audio: pyaudio.PyAudio _stream: pyaudio.Stream - _user_config_set: set[str] - _CHANNELS: int = 1 + # Audio device constants _DEVICE_INDEX: int | None = None - _ENCODING: str = "pcm" - _FORMAT: int = pyaudio.paInt16 + _PYAUDIO_FORMAT: int = pyaudio.paInt16 _FRAMES_PER_BUFFER: int = 512 - _RATE: int = 16000 def __init__(self, config: dict[str, Any]) -> None: - """Extract configs and track which were explicitly set by user.""" - # Track which config values were explicitly provided by user - self._user_config_set = set(config.keys()) - - self._channels = config.get("input_channels", _BidiAudioInput._CHANNELS) - self._device_index = config.get("input_device_index", _BidiAudioInput._DEVICE_INDEX) - self._format = config.get("input_format", _BidiAudioInput._FORMAT) - self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER) - self._rate = config.get("input_rate", _BidiAudioInput._RATE) + """Initialize audio input handler. + + Args: + config: Configuration dictionary with optional overrides: + - input_device_index: Specific input device to use + - input_frames_per_buffer: Number of frames per buffer + """ + # Initialize instance variables from config or class constants + self._device_index = config.get("input_device_index", self._DEVICE_INDEX) + self._pyaudio_format = self._PYAUDIO_FORMAT + self._frames_per_buffer = config.get("input_frames_per_buffer", self._FRAMES_PER_BUFFER) async def start(self, agent: "BidiAgent") -> None: """Start input stream. @@ -58,17 +58,10 @@ async def start(self, agent: "BidiAgent") -> None: Args: agent: The BidiAgent instance, providing access to model configuration. """ - # Extract audio config from agent's model - audio_config = getattr(agent.model, "audio_config", None) - - # Apply audio config overrides only if user didn't explicitly set them - if audio_config: - if "input_rate" in audio_config and "input_rate" not in self._user_config_set: - self._rate = audio_config["input_rate"] - logger.debug("audio_config | applying model input rate: %d Hz", self._rate) - if "channels" in audio_config and "input_channels" not in self._user_config_set: - self._channels = audio_config["channels"] - logger.debug("audio_config | applying model channels: %d", self._channels) + # Get audio parameters from model config + self._rate = agent.model.audio_config["input_rate"] + self._channels = agent.model.audio_config["channels"] + self._format = agent.model.audio_config.get("format", "pcm") # Encoding format for events logger.debug( "rate=<%d>, channels=<%d>, device_index=<%s> | starting audio input stream", @@ -79,7 +72,7 @@ async def start(self, agent: "BidiAgent") -> None: self._audio = pyaudio.PyAudio() self._stream = self._audio.open( channels=self._channels, - format=self._format, + format=self._pyaudio_format, frames_per_buffer=self._frames_per_buffer, input=True, input_device_index=self._device_index, @@ -106,7 +99,7 @@ async def __call__(self) -> BidiAudioInputEvent: return BidiAudioInputEvent( audio=base64.b64encode(audio_bytes).decode("utf-8"), channels=self._channels, - format=_BidiAudioInput._ENCODING, + format=self._format, sample_rate=self._rate, ) @@ -120,7 +113,6 @@ class _BidiAudioOutput(BidiOutput): _buffer: Deque buffer for queuing audio data. _buffer_event: Event to signal when buffer has data. _output_task: Background task for processing audio output. - _user_config_set: Track which config values were explicitly set by user. """ _audio: pyaudio.PyAudio @@ -128,26 +120,27 @@ class _BidiAudioOutput(BidiOutput): _buffer: deque _buffer_event: asyncio.Event _output_task: asyncio.Task - _user_config_set: set[str] + # Audio device constants _BUFFER_SIZE: int | None = None - _CHANNELS: int = 1 _DEVICE_INDEX: int | None = None - _FORMAT: int = pyaudio.paInt16 + _PYAUDIO_FORMAT: int = pyaudio.paInt16 _FRAMES_PER_BUFFER: int = 512 - _RATE: int = 16000 def __init__(self, config: dict[str, Any]) -> None: - """Extract configs and track which were explicitly set by user.""" - # Track which config values were explicitly provided by user - self._user_config_set = set(config.keys()) - - self._buffer_size = config.get("output_buffer_size", _BidiAudioOutput._BUFFER_SIZE) - self._channels = config.get("output_channels", _BidiAudioOutput._CHANNELS) - self._device_index = config.get("output_device_index", _BidiAudioOutput._DEVICE_INDEX) - self._format = config.get("output_format", _BidiAudioOutput._FORMAT) - self._frames_per_buffer = config.get("output_frames_per_buffer", _BidiAudioOutput._FRAMES_PER_BUFFER) - self._rate = config.get("output_rate", _BidiAudioOutput._RATE) + """Initialize audio output handler. + + Args: + config: Configuration dictionary with optional overrides: + - output_device_index: Specific output device to use + - output_frames_per_buffer: Number of frames per buffer + - output_buffer_size: Maximum buffer size (None = unlimited) + """ + # Initialize instance variables from config or class constants + self._buffer_size = config.get("output_buffer_size", self._BUFFER_SIZE) + self._device_index = config.get("output_device_index", self._DEVICE_INDEX) + self._pyaudio_format = self._PYAUDIO_FORMAT # Not configurable + self._frames_per_buffer = config.get("output_frames_per_buffer", self._FRAMES_PER_BUFFER) async def start(self, agent: "BidiAgent") -> None: """Start output stream. @@ -155,29 +148,19 @@ async def start(self, agent: "BidiAgent") -> None: Args: agent: The BidiAgent instance, providing access to model configuration. """ - # Extract audio config from agent's model - audio_config = getattr(agent.model, "audio_config", None) - - # Apply audio config overrides only if user didn't explicitly set them - if audio_config: - if "output_rate" in audio_config and "output_rate" not in self._user_config_set: - self._rate = audio_config["output_rate"] - logger.debug("audio_config | applying model output rate: %d Hz", self._rate) - if "channels" in audio_config and "output_channels" not in self._user_config_set: - self._channels = audio_config["channels"] - logger.debug("audio_config | applying model channels: %d", self._channels) + # Get audio parameters from model config + self._rate = agent.model.audio_config["output_rate"] + self._channels = agent.model.audio_config["channels"] logger.debug( - "rate=<%d>, channels=<%d>, device_index=<%s>, buffer_size=<%s> | starting audio output stream", + "rate=<%d>, channels=<%d> | starting audio output stream", self._rate, self._channels, - self._device_index, - self._buffer_size, ) self._audio = pyaudio.PyAudio() self._stream = self._audio.open( channels=self._channels, - format=self._format, + format=self._pyaudio_format, frames_per_buffer=self._frames_per_buffer, output=True, output_device_index=self._device_index, @@ -228,25 +211,19 @@ async def _output(self) -> None: class BidiAudioIO: - """Send and receive audio data from devices.""" + """Send and receive audio data from devices. + + Args: + **config: Optional device configuration: + - input_device_index (int): Specific input device (default: None = system default) + - output_device_index (int): Specific output device (default: None = system default) + - input_frames_per_buffer (int): Input buffer size (default: 512) + - output_frames_per_buffer (int): Output buffer size (default: 512) + - output_buffer_size (int | None): Max output queue size (default: None = unlimited) + """ def __init__(self, **config: Any) -> None: - """Initialize audio devices. - - Args: - **config: Dictionary containing audio configuration: - - input_channels (int): Input channels (default: 1) - - input_device_index (int): Specific input device (optional) - - input_format (int): Audio format (default: paInt16) - - input_frames_per_buffer (int): Frames per buffer (default: 512) - - input_rate (int): Input sample rate (default: 16000) - - output_buffer_size (int): Maximum output buffer size (default: None) - - output_channels (int): Output channels (default: 1) - - output_device_index (int): Specific output device (optional) - - output_format (int): Audio format (default: paInt16) - - output_frames_per_buffer (int): Frames per buffer (default: 512) - - output_rate (int): Output sample rate (default: 16000) - """ + """Initialize audio devices.""" self._config = config def input(self) -> _BidiAudioInput: diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index ad91a81b0..7ced098f8 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -22,6 +22,7 @@ BidiInputEvent, BidiOutputEvent, ) +from ..types.model import AudioConfig logger = logging.getLogger(__name__) @@ -32,8 +33,14 @@ class BidiModel(Protocol): This interface defines the contract for models that support persistent streaming connections with real-time audio and text communication. Implementations handle provider-specific protocols while exposing a standardized event-based API. + + All bidirectional models must provide an audio_config property that specifies + their audio processing requirements. This configuration is built by merging + user-provided values with model-specific defaults. """ + audio_config: AudioConfig + async def start( self, system_prompt: str | None = None, diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index e02bc8869..716b7a0fc 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -39,7 +39,7 @@ ModalityUsage, SampleRate, ) -from ..types.io import AudioConfig +from ..types.model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 40dcb181b..8ebf89785 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -48,7 +48,7 @@ BidiUsageEvent, SampleRate, ) -from ..types.io import AudioConfig +from ..types.model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 8344e695c..192b15b06 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -34,7 +34,7 @@ SampleRate, StopReason, ) -from ..types.io import AudioConfig +from ..types.model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ class BidiOpenAIRealtimeModel(BidiModel): def __init__( self, - model: str = DEFAULT_MODEL, + model_id: str = DEFAULT_MODEL, api_key: str | None = None, organization: str | None = None, project: str | None = None, @@ -98,7 +98,7 @@ def __init__( **kwargs: Reserved for future parameters. """ # Model configuration - self.model = model + self.model_id = model_id self.api_key = api_key self.organization = organization self.project = project @@ -171,7 +171,7 @@ async def start( self._function_call_buffer = {} # Establish WebSocket connection - url = f"{OPENAI_REALTIME_URL}?model={self.model}" + url = f"{OPENAI_REALTIME_URL}?model={self.model_id}" headers = [("Authorization", f"Bearer {self.api_key}")] if self.organization: @@ -305,7 +305,7 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: if not self._connection_id: raise RuntimeError("model not started | call start before receiving") - yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model) + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) async for message in self._websocket: openai_event = json.loads(message) diff --git a/src/strands/experimental/bidi/types/bidi_model.py b/src/strands/experimental/bidi/types/bidi_model.py new file mode 100644 index 000000000..32aaa0079 --- /dev/null +++ b/src/strands/experimental/bidi/types/bidi_model.py @@ -0,0 +1,34 @@ +"""Model-related type definitions for bidirectional streaming. + +Defines types and configurations that are central to model providers, +including audio configuration that models use to specify their audio +processing requirements. +""" + +from typing import Literal, TypedDict + + +class AudioConfig(TypedDict, total=False): + """Audio configuration for bidirectional streaming models. + + Defines standard audio parameters that model providers use to specify + their audio processing requirements. All fields are optional to support + models that may not use audio or only need specific parameters. + + Model providers build this configuration by merging user-provided values + with their own defaults. The resulting configuration is then used by + audio I/O implementations to configure hardware appropriately. + + Attributes: + input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) + output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) + channels: Number of audio channels (1=mono, 2=stereo) + format: Audio encoding format + voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") + """ + + input_rate: int + output_rate: int + channels: int + format: Literal["pcm", "wav", "opus", "mp3"] + voice: str diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py index 57b49a834..93d8059f8 100644 --- a/src/strands/experimental/bidi/types/io.py +++ b/src/strands/experimental/bidi/types/io.py @@ -5,7 +5,7 @@ by separating input and output concerns into independent callables. """ -from typing import TYPE_CHECKING, Awaitable, Literal, Protocol, TypedDict +from typing import TYPE_CHECKING, Awaitable, Protocol from ..types.events import BidiInputEvent, BidiOutputEvent @@ -13,28 +13,6 @@ from ..agent.agent import BidiAgent -class AudioConfig(TypedDict, total=False): - """Audio configuration for bidirectional streaming. - - Defines standard audio parameters shared between model providers - and audio I/O implementations. All fields are optional to support - models that may not use audio or only need specific parameters. - - Attributes: - input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) - output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) - channels: Number of audio channels (1=mono, 2=stereo) - format: Audio encoding format - voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") - """ - - input_rate: int - output_rate: int - channels: int - format: Literal["pcm", "wav", "opus", "mp3"] - voice: str - - class BidiInput(Protocol): """Protocol for bidirectional input callables. diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py index 9517ad108..d22ce5d39 100644 --- a/tests/strands/experimental/bidi/io/test_audio.py +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -19,6 +19,34 @@ def audio_io(): return BidiAudioIO() +@pytest.fixture +def mock_agent(): + """Create a mock agent with model that has default audio_config.""" + agent = unittest.mock.MagicMock() + agent.model.audio_config = { + "input_rate": 16000, + "output_rate": 16000, + "channels": 1, + "format": "pcm", + "voice": "matthew", + } + return agent + + +@pytest.fixture +def mock_agent_custom_config(): + """Create a mock agent with custom audio_config.""" + agent = unittest.mock.MagicMock() + agent.model.audio_config = { + "input_rate": 48000, + "output_rate": 24000, + "channels": 2, + "format": "pcm", + "voice": "alloy", + } + return agent + + @pytest.fixture def audio_input(audio_io): return audio_io.input() @@ -30,13 +58,14 @@ def audio_output(audio_io): @pytest.mark.asyncio -async def test_bidi_audio_io_input(py_audio, audio_input): +async def test_bidi_audio_io_input(py_audio, audio_input, mock_agent): + """Test basic audio input functionality.""" microphone = unittest.mock.Mock() microphone.read.return_value = b"test-audio" py_audio.open.return_value = microphone - await audio_input.start() + await audio_input.start(mock_agent) tru_event = await audio_input() await audio_input.stop() @@ -52,7 +81,8 @@ async def test_bidi_audio_io_input(py_audio, audio_input): @pytest.mark.asyncio -async def test_bidi_audio_io_output(py_audio, audio_output): +async def test_bidi_audio_io_output(py_audio, audio_output, mock_agent): + """Test basic audio output functionality.""" write_future = asyncio.Future() write_event = asyncio.Event() @@ -65,7 +95,7 @@ def write(data): py_audio.open.return_value = speaker - await audio_output.start() + await audio_output.start(mock_agent) audio_event = BidiAudioStreamEvent( audio=base64.b64encode(b"test-audio").decode("utf-8"), @@ -85,119 +115,151 @@ def write(data): @pytest.mark.asyncio -async def test_audio_input_respects_user_config(py_audio): - """Test that user-provided config takes precedence over model config.""" - audio_io = BidiAudioIO(input_rate=48000, input_channels=2) +async def test_audio_input_uses_model_config(py_audio, audio_io, mock_agent): + """Test that audio input uses model's audio_config.""" audio_input = audio_io.input() microphone = unittest.mock.Mock() microphone.read.return_value = b"test-audio" py_audio.open.return_value = microphone - # Model provides different config - model_audio_config = {"input_rate": 16000, "channels": 1} + await audio_input.start(mock_agent) - await audio_input.start(audio_config=model_audio_config) - - # User config should be used + # Model config should be used py_audio.open.assert_called_once() call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 48000 # User config - assert call_kwargs["channels"] == 2 # User config + assert call_kwargs["rate"] == 16000 # From mock_agent.model.audio_config + assert call_kwargs["channels"] == 1 # From mock_agent.model.audio_config await audio_input.stop() @pytest.mark.asyncio -async def test_audio_input_applies_model_config_when_user_not_set(py_audio): - """Test that model config is applied when user doesn't provide values.""" - audio_io = BidiAudioIO() # No user config +async def test_audio_input_uses_custom_model_config(py_audio, audio_io, mock_agent_custom_config): + """Test that audio input uses custom model audio_config.""" audio_input = audio_io.input() microphone = unittest.mock.Mock() microphone.read.return_value = b"test-audio" py_audio.open.return_value = microphone - # Model provides config - model_audio_config = {"input_rate": 24000, "channels": 2} - - await audio_input.start(audio_config=model_audio_config) + await audio_input.start(mock_agent_custom_config) - # Model config should be used + # Custom model config should be used py_audio.open.assert_called_once() call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 24000 # Model config - assert call_kwargs["channels"] == 2 # Model config + assert call_kwargs["rate"] == 48000 # From custom config + assert call_kwargs["channels"] == 2 # From custom config await audio_input.stop() @pytest.mark.asyncio -async def test_audio_output_respects_user_config(py_audio): - """Test that user-provided config takes precedence over model config.""" - audio_io = BidiAudioIO(output_rate=48000, output_channels=2) +async def test_audio_output_uses_model_config(py_audio, audio_io, mock_agent): + """Test that audio output uses model's audio_config.""" audio_output = audio_io.output() speaker = unittest.mock.Mock() py_audio.open.return_value = speaker - # Model provides different config - model_audio_config = {"output_rate": 16000, "channels": 1} - - await audio_output.start(audio_config=model_audio_config) + await audio_output.start(mock_agent) - # User config should be used + # Model config should be used py_audio.open.assert_called_once() call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 48000 # User config - assert call_kwargs["channels"] == 2 # User config + assert call_kwargs["rate"] == 16000 # From mock_agent.model.audio_config + assert call_kwargs["channels"] == 1 # From mock_agent.model.audio_config await audio_output.stop() @pytest.mark.asyncio -async def test_audio_output_applies_model_config_when_user_not_set(py_audio): - """Test that model config is applied when user doesn't provide values.""" - audio_io = BidiAudioIO() # No user config +async def test_audio_output_uses_custom_model_config(py_audio, audio_io, mock_agent_custom_config): + """Test that audio output uses custom model audio_config.""" audio_output = audio_io.output() speaker = unittest.mock.Mock() py_audio.open.return_value = speaker - # Model provides config - model_audio_config = {"output_rate": 24000, "channels": 2} - - await audio_output.start(audio_config=model_audio_config) + await audio_output.start(mock_agent_custom_config) - # Model config should be used + # Custom model config should be used py_audio.open.assert_called_once() call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 24000 # Model config - assert call_kwargs["channels"] == 2 # Model config + assert call_kwargs["rate"] == 24000 # From custom config + assert call_kwargs["channels"] == 2 # From custom config await audio_output.stop() +# Device Configuration Tests + + @pytest.mark.asyncio -async def test_audio_partial_user_config(py_audio): - """Test that partial user config works correctly.""" - # User only sets rate, not channels - audio_io = BidiAudioIO(input_rate=48000) +async def test_audio_input_respects_user_device_config(py_audio, mock_agent): + """Test that user-provided device config overrides defaults.""" + audio_io = BidiAudioIO(input_device_index=5, input_frames_per_buffer=1024) audio_input = audio_io.input() microphone = unittest.mock.Mock() microphone.read.return_value = b"test-audio" py_audio.open.return_value = microphone - # Model provides both rate and channels - model_audio_config = {"input_rate": 16000, "channels": 2} + await audio_input.start(mock_agent) + + # User device config should be used + py_audio.open.assert_called_once() + call_kwargs = py_audio.open.call_args.kwargs + assert call_kwargs["input_device_index"] == 5 # User config + assert call_kwargs["frames_per_buffer"] == 1024 # User config + # Model config still used for audio parameters + assert call_kwargs["rate"] == 16000 # From model + assert call_kwargs["channels"] == 1 # From model + + await audio_input.stop() + + +@pytest.mark.asyncio +async def test_audio_output_respects_user_device_config(py_audio, mock_agent): + """Test that user-provided device config overrides defaults.""" + audio_io = BidiAudioIO(output_device_index=3, output_frames_per_buffer=2048, output_buffer_size=50) + audio_output = audio_io.output() + + speaker = unittest.mock.Mock() + py_audio.open.return_value = speaker + + await audio_output.start(mock_agent) + + # User device config should be used + py_audio.open.assert_called_once() + call_kwargs = py_audio.open.call_args.kwargs + assert call_kwargs["output_device_index"] == 3 # User config + assert call_kwargs["frames_per_buffer"] == 2048 # User config + # Model config still used for audio parameters + assert call_kwargs["rate"] == 16000 # From model + assert call_kwargs["channels"] == 1 # From model + # Buffer size should be set + assert audio_output._buffer_size == 50 # User config + + await audio_output.stop() + + +@pytest.mark.asyncio +async def test_audio_io_uses_defaults_when_no_config(py_audio, mock_agent): + """Test that defaults are used when no config provided.""" + audio_io = BidiAudioIO() # No config + audio_input = audio_io.input() + + microphone = unittest.mock.Mock() + microphone.read.return_value = b"test-audio" + py_audio.open.return_value = microphone - await audio_input.start(audio_config=model_audio_config) + await audio_input.start(mock_agent) - # User rate should be used, model channels should be applied + # Defaults should be used py_audio.open.assert_called_once() call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 48000 # User config - assert call_kwargs["channels"] == 2 # Model config + assert call_kwargs["input_device_index"] is None # Default + assert call_kwargs["frames_per_buffer"] == 512 # Default await audio_input.stop() From 802099b832000ac7cfce5ab2cb4108c42712676e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 08:44:45 -0500 Subject: [PATCH 185/242] remove code from bad rebase --- src/strands/experimental/bidi/agent/agent.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index ccba6f93e..a3a38017c 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -470,18 +470,6 @@ async def run_outputs() -> None: tasks = [output(event) for output in outputs] await asyncio.gather(*tasks) - await self.start() - - # Start inputs, passing agent for configuration - for input_ in inputs: - if hasattr(input_, "start"): - await input_.start(self) - - # Start outputs, passing agent for configuration - for output in outputs: - if hasattr(output, "start"): - await output.start(self) - try: await self.start(invocation_state) From ff0b89b5d5d0136fd180dacec0dca7f6236d28f4 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 08:45:37 -0500 Subject: [PATCH 186/242] pass agent to io start method --- src/strands/experimental/bidi/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index a3a38017c..fac9fe0bc 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -476,7 +476,7 @@ async def run_outputs() -> None: start_inputs = [input_.start for input_ in inputs if hasattr(input_, "start")] start_outputs = [output.start for output in outputs if hasattr(output, "start")] for start in [*start_inputs, *start_outputs]: - await start() + await start(self) async with asyncio.TaskGroup() as task_group: task_group.create_task(run_inputs()) From 7b2c30ee2aa1be8c69a265ec7886e41659574ab3 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 08:47:30 -0500 Subject: [PATCH 187/242] update imports --- src/strands/experimental/bidi/models/bidi_model.py | 2 +- src/strands/experimental/bidi/models/gemini_live.py | 2 +- src/strands/experimental/bidi/models/novasonic.py | 2 +- src/strands/experimental/bidi/models/openai.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 7ced098f8..f1eae4208 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -22,7 +22,7 @@ BidiInputEvent, BidiOutputEvent, ) -from ..types.model import AudioConfig +from ..types.bidi_model import AudioConfig logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 716b7a0fc..3b562efed 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -39,7 +39,7 @@ ModalityUsage, SampleRate, ) -from ..types.model import AudioConfig +from ..types.bidi_model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 8ebf89785..06ac99012 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -48,7 +48,7 @@ BidiUsageEvent, SampleRate, ) -from ..types.model import AudioConfig +from ..types.bidi_model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 192b15b06..14aa221d5 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -34,7 +34,7 @@ SampleRate, StopReason, ) -from ..types.model import AudioConfig +from ..types.bidi_model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) From 489bba049710478f4f21048a2ddeb77e789da287 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 25 Nov 2025 09:00:38 -0500 Subject: [PATCH 188/242] test_bidi.py - remove use of with context (#81) --- scripts/bidi/test_bidi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/bidi/test_bidi.py b/scripts/bidi/test_bidi.py index f7447871c..2beb3ddd7 100644 --- a/scripts/bidi/test_bidi.py +++ b/scripts/bidi/test_bidi.py @@ -15,11 +15,11 @@ async def main(): audio_io = BidiAudioIO() text_io = BidiTextIO() model = BidiNovaSonicModel(region="us-east-1") + agent = BidiAgent(model=model, tools=[calculator]) - async with BidiAgent(model=model, tools=[calculator]) as agent: - print("New BidiAgent Experience") - print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") - await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) + print("New BidiAgent Experience") + print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") + await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) if __name__ == "__main__": From cd06d208d6f90509bd017a1ec2b8aeabe48c4762 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 09:17:07 -0500 Subject: [PATCH 189/242] remove unused ToolCaller code from agent and delete caller.py --- src/strands/experimental/bidi/agent/agent.py | 102 ++---------- src/strands/tools/caller.py | 164 ------------------- 2 files changed, 12 insertions(+), 254 deletions(-) delete mode 100644 src/strands/tools/caller.py diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 3ca0d2ad3..238089a8c 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -13,7 +13,6 @@ """ import asyncio -import json import logging from typing import Any, AsyncGenerator @@ -21,13 +20,13 @@ from ....agent.state import AgentState from ....hooks import HookProvider, HookRegistry from ....interrupt import _InterruptState -from ....tools.caller import _ToolCaller +from ....tools._caller import _ToolCaller from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry from ....tools.watcher import ToolWatcher -from ....types.content import ContentBlock, Message, Messages -from ....types.tools import AgentTool, ToolResult, ToolUse +from ....types.content import Message, Messages +from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider from .._async import stop_all @@ -131,7 +130,7 @@ def __init__( self.state = AgentState() # Initialize other components - self._tool_caller = _ToolCaller(self) + self.tool_caller = _ToolCaller(self) # Initialize tool executor self.tool_executor = tool_executor or ConcurrentToolExecutor() @@ -165,7 +164,7 @@ def tool(self) -> _ToolCaller: agent.tool.calculator(expression="2+2") ``` """ - return self._tool_caller + return self.tool_caller @property def tool_names(self) -> list[str]: @@ -177,91 +176,12 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def _record_tool_execution( - self, - tool: ToolUse, - tool_result: ToolResult, - user_message_override: str | None, - ) -> None: - """Record a tool execution in the message history. - - Creates a sequence of messages that represent the tool execution: - - 1. A user message describing the tool call - 2. An assistant message with the tool use - 3. A user message with the tool result - 4. An assistant message acknowledging the tool call - - Args: - tool: The tool call information. - tool_result: The result returned by the tool. - user_message_override: Optional custom message to include. - """ - # Filter tool input parameters to only include those defined in tool spec - filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) - - # Create user message describing the tool call - input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") - - user_msg_content: list[ContentBlock] = [ - {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} - ] - - # Add override message if provided - if user_message_override: - user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) - - # Create filtered tool use for message history - filtered_tool: ToolUse = { - "toolUseId": tool["toolUseId"], - "name": tool["name"], - "input": filtered_input, - } - - # Create the message sequence - user_msg: Message = { - "role": "user", - "content": user_msg_content, - } - tool_use_msg: Message = { - "role": "assistant", - "content": [{"toolUse": filtered_tool}], - } - tool_result_msg: Message = { - "role": "user", - "content": [{"toolResult": tool_result}], - } - assistant_msg: Message = { - "role": "assistant", - "content": [{"text": f"agent.tool.{tool['name']} was called."}], - } - - # Add to message history - self.messages.append(user_msg) - self.messages.append(tool_use_msg) - self.messages.append(tool_result_msg) - self.messages.append(assistant_msg) - - logger.debug("tool_name=<%s> | direct tool call recorded in message history", tool["name"]) - - def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: - """Filter input parameters to only include those defined in the tool specification. - - Args: - tool_name: Name of the tool to get specification for - input_params: Original input parameters - - Returns: - Filtered parameters containing only those defined in tool spec - """ - all_tools_config = self.tool_registry.get_all_tools_config() - tool_spec = all_tools_config.get(tool_name) - if not tool_spec or "inputSchema" not in tool_spec: - return input_params.copy() - properties = tool_spec["inputSchema"]["json"]["properties"] - return {k: v for k, v in input_params.items() if k in properties} + async def _append_message(self, message: Message) -> None: + """Appends a message to the agent's list of messages and invokes the callbacks for the MessageAddedEvent.""" + self.messages.append(message) + await self.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self, message=message)) async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start a persistent bidirectional conversation connection. @@ -460,7 +380,9 @@ async def run_outputs() -> None: await asyncio.gather(*tasks) try: - await self.start(invocation_state) + # Only start if not already started (e.g., when used with async context manager) + if not self._started: + await self.start(invocation_state) start_inputs = [input_.start for input_ in inputs if hasattr(input_, "start")] start_outputs = [output.start for output in outputs if hasattr(output, "start")] diff --git a/src/strands/tools/caller.py b/src/strands/tools/caller.py deleted file mode 100644 index 68357f266..000000000 --- a/src/strands/tools/caller.py +++ /dev/null @@ -1,164 +0,0 @@ -"""ToolCaller base class.""" - -import random -from typing import Any, Callable, Optional - -from .._async import run_async -from ..tools.executors._executor import ToolExecutor -from ..types._events import ToolInterruptEvent -from ..types.tools import ToolResult, ToolUse - - -class _ToolCaller: - """Provides common tool calling functionality for Agent classes. - - Can be used by both traditional Agent and BidirectionalAgent classes with - agent-specific customizations. - - Automatically detects agent type and applies appropriate behavior: - - Traditional agents: Uses conversation_manager.apply_management() - """ - - def __init__(self, agent: Any) -> None: - """Initialize base tool caller. - - Args: - agent: Agent instance that will process tool results. - """ - # WARNING: Do not add other member variables to avoid conflicts with tool names - self._agent = agent - - def __getattr__(self, name: str) -> Callable[..., Any]: - """Enable method-style tool calling interface. - - This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). - It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). - - Args: - name: The name of the attribute (tool) being accessed. - - Returns: - A function that when called will execute the named tool. - - Raises: - AttributeError: If no tool with the given name exists or if multiple tools match the given name. - """ - - def caller( - user_message_override: Optional[str] = None, - record_direct_tool_call: Optional[bool] = None, - **kwargs: Any, - ) -> Any: - """Call a tool directly by name. - - Args: - user_message_override: Optional custom message to record instead of default. - record_direct_tool_call: Whether to record direct tool calls in message history. - **kwargs: Keyword arguments to pass to the tool. - - Returns: - The result returned by the tool. - - Raises: - AttributeError: If the tool doesn't exist. - RuntimeError: If called during an interrupt or if interrupt is raised. - """ - # Check if agent has interrupt state and if it's activated - if hasattr(self._agent, "_interrupt_state") and self._agent._interrupt_state.activated: - raise RuntimeError("cannot directly call tool during interrupt") - - normalized_name = self._find_normalized_tool_name(name) - - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - - # Execute tool using shared execution pipeline - tool_result = self._execute_tool_async(tool_use, kwargs, user_message_override, record_direct_tool_call) - - # Apply conversation management if agent supports it (traditional agents) - if hasattr(self._agent, "conversation_manager"): - self._agent.conversation_manager.apply_management(self._agent) - - return tool_result - - return caller - - def _find_normalized_tool_name(self, name: str) -> str: - """Lookup the tool represented by name, replacing characters with underscores as necessary. - - Args: - name: Tool name to normalize. - - Returns: - Normalized tool name that exists in registry. - - Raises: - AttributeError: If tool not found. - """ - tool_registry = self._agent.tool_registry.registry - - if tool_registry.get(name, None): - return name - - # Handle underscore placeholder for characters that can't be python identifiers - if "_" in name: - filtered_tools = [ - tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name - ] - - # Registry defends against similar names, so take first match - if filtered_tools: - return filtered_tools[0] # type: ignore - - raise AttributeError(f"Tool '{name}' not found") - - def _execute_tool_async( - self, - tool_use: ToolUse, - invocation_state: dict[str, Any], - user_message_override: Optional[str], - record_direct_tool_call: Optional[bool], - ) -> ToolResult: - """Execute tool asynchronously using shared Strands pipeline. - - Args: - tool_use: Tool execution request. - invocation_state: Execution context. - user_message_override: Optional message override. - record_direct_tool_call: Optional recording override. - - Returns: - Tool execution result. - - Raises: - RuntimeError: If interrupt is raised during tool execution. - """ - tool_results: list[ToolResult] = [] - - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - # Check for interrupt events - if isinstance(event, ToolInterruptEvent): - if hasattr(self._agent, "_interrupt_state"): - self._agent._interrupt_state.deactivate() - raise RuntimeError("cannot raise interrupt in direct tool call") - - tool_result = tool_results[0] - - # Determine if we should record the tool call - should_record = ( - record_direct_tool_call if record_direct_tool_call is not None else self._agent.record_direct_tool_call - ) - - if should_record: - # Use agent's async recording method - await self._agent._record_tool_execution(tool_use, tool_result, user_message_override) - - return tool_result - - return run_async(acall) From b3a6e31e8b53970af5981fabbf1d3f81e833d2e0 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 25 Nov 2025 09:30:20 -0500 Subject: [PATCH 190/242] agent - add text input to messages (#78) --- src/strands/experimental/bidi/agent/agent.py | 68 ++++++------------- src/strands/experimental/bidi/agent/loop.py | 23 ++++++- src/strands/experimental/bidi/types/events.py | 12 ++-- 3 files changed, 49 insertions(+), 54 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 3ca0d2ad3..cb1d10fe0 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -28,7 +28,7 @@ from ....tools.watcher import ToolWatcher from ....types.content import ContentBlock, Message, Messages from ....types.tools import AgentTool, ToolResult, ToolUse -from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent +from ...hooks.events import BidiAgentInitializedEvent from ...tools import ToolProvider from .._async import stop_all from ..models.bidi_model import BidiModel @@ -304,8 +304,7 @@ async def send(self, input_data: BidiAgentInput | dict[str, Any]) -> None: Args: input_data: Can be: - str: Text message from user - - BidiAudioInputEvent: Audio data with format/sample rate - - BidiImageInputEvent: Image data with MIME type + - BidiInputEvent: TypedEvent - dict: Event dictionary (will be reconstructed to TypedEvent) Raises: @@ -320,54 +319,29 @@ async def send(self, input_data: BidiAgentInput | dict[str, Any]) -> None: if not self._started: raise RuntimeError("agent not started | call start before sending") - # Handle string input + input_event: BidiInputEvent + if isinstance(input_data, str): - # Add user text message to history - user_message: Message = {"role": "user", "content": [{"text": input_data}]} - - self.messages.append(user_message) - await self.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self, message=user_message)) - - logger.debug("text_length=<%d> | text sent to model", len(input_data)) - # Create BidiTextInputEvent for send() - text_event = BidiTextInputEvent(text=input_data, role="user") - await self.model.send(text_event) - return - - # Handle BidiInputEvent instances - # Check this before dict since TypedEvent inherits from dict - if isinstance(input_data, BidiInputEvent): - await self.model.send(input_data) - return - - # Handle plain dict - reconstruct TypedEvent for WebSocket integration - if isinstance(input_data, dict) and "type" in input_data: - event_type = input_data["type"] - input_event: BidiInputEvent - if event_type == "bidi_text_input": - input_event = BidiTextInputEvent(text=input_data["text"], role=input_data["role"]) - elif event_type == "bidi_audio_input": - input_event = BidiAudioInputEvent( - audio=input_data["audio"], - format=input_data["format"], - sample_rate=input_data["sample_rate"], - channels=input_data["channels"], - ) - elif event_type == "bidi_image_input": - input_event = BidiImageInputEvent(image=input_data["image"], mime_type=input_data["mime_type"]) + input_event = BidiTextInputEvent(text=input_data) + + elif isinstance(input_data, BidiInputEvent): + input_event = input_data + + elif isinstance(input_data, dict) and "type" in input_data: + input_type = input_data["type"] + if input_type == "bidi_text_input": + input_event = BidiTextInputEvent(**input_data) + elif input_type == "bidi_audio_input": + input_event = BidiAudioInputEvent(**input_data) + elif input_type == "bidi_image_input": + input_event = BidiImageInputEvent(**input_data) else: - raise ValueError(f"Unknown event type: {event_type}") + raise ValueError(f"input_type=<{input_type}> | input type not supported") - # Send the reconstructed TypedEvent - await self.model.send(input_event) - return + else: + raise ValueError("invalid input | must be str, BidiInputEvent, or event dict") - # If we get here, input type is invalid - raise ValueError( - f"Input must be a string, BidiInputEvent " - f"(BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), " - f"or event dict with 'type' field, got: {type(input_data)}" - ) + await self._loop.send(input_event) async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive events from the model including audio, text, and tool calls. diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 1003dea8c..4d00cd714 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -19,7 +19,13 @@ BidiInterruptionEvent as BidiInterruptionHookEvent, ) from .._async import _TaskPool, stop_all -from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent +from ..types.events import ( + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) if TYPE_CHECKING: from .agent import BidiAgent @@ -106,6 +112,21 @@ async def stop_model() -> None: finally: await self._agent.hooks.invoke_callbacks_async(BidiAfterInvocationEvent(agent=self._agent)) + async def send(self, event: BidiInputEvent) -> None: + """Send model event. + + Additional, add text input to messages array. + + Args: + event: BidiInputEvent. + """ + if isinstance(event, BidiTextInputEvent): + message: Message = {"role": "user", "content": [{"text": event.text}]} + self._agent.messages.append(message) + await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) + + await self._agent.model.send(event) + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive model and tool call events. diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 70d0f8f3d..e9f53d0e6 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -43,10 +43,10 @@ class BidiTextInputEvent(TypedEvent): Parameters: text: The text content to send to the model. - role: The role of the message sender (typically "user"). + role: The role of the message sender (default: "user"). """ - def __init__(self, text: str, role: str): + def __init__(self, text: str, role: Role = "user"): """Initialize text input event.""" super().__init__( { @@ -62,9 +62,9 @@ def text(self) -> str: return cast(str, self.get("text")) @property - def role(self) -> str: + def role(self) -> Role: """The role of the message sender.""" - return cast(str, self.get("role")) + return cast(Role, self["role"]) class BidiAudioInputEvent(TypedEvent): @@ -298,9 +298,9 @@ def text(self) -> str: return cast(str, self.get("text")) @property - def role(self) -> str: + def role(self) -> Role: """The role of the message sender.""" - return cast(str, self.get("role")) + return cast(Role, self["role"]) @property def is_final(self) -> bool: From f29f4f88ab0724c7977ccb7cb36fa0006246a293 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 25 Nov 2025 09:30:44 -0500 Subject: [PATCH 191/242] bidi text input (#79) --- src/strands/experimental/bidi/agent/agent.py | 12 ++--- src/strands/experimental/bidi/io/text.py | 35 ++++++++---- src/strands/experimental/bidi/types/io.py | 12 +++-- .../strands/experimental/bidi/io/test_text.py | 53 +++++++++++++++++++ 4 files changed, 92 insertions(+), 20 deletions(-) create mode 100644 tests/strands/experimental/bidi/io/test_text.py diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index cb1d10fe0..60cb444a5 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -436,9 +436,9 @@ async def run_outputs() -> None: try: await self.start(invocation_state) - start_inputs = [input_.start for input_ in inputs if hasattr(input_, "start")] - start_outputs = [output.start for output in outputs if hasattr(output, "start")] - for start in [*start_inputs, *start_outputs]: + input_starts = [input_.start for input_ in inputs if isinstance(input_, BidiInput)] + output_starts = [output.start for output in outputs if isinstance(output, BidiOutput)] + for start in [*input_starts, *output_starts]: await start() async with asyncio.TaskGroup() as task_group: @@ -446,7 +446,7 @@ async def run_outputs() -> None: task_group.create_task(run_outputs()) finally: - stop_inputs = [input_.stop for input_ in inputs if hasattr(input_, "stop")] - stop_outputs = [output.stop for output in outputs if hasattr(output, "stop")] + input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] + output_stops = [output.stop for output in outputs if isinstance(output, BidiOutput)] - await stop_all(*stop_inputs, *stop_outputs, self.stop) + await stop_all(*input_stops, *output_stops, self.stop) diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 8ecbae149..715dc4452 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -1,23 +1,36 @@ """Handle text input and output from bidi agent.""" +import asyncio import logging +import sys -from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent -from ..types.io import BidiOutput +from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTextInputEvent, BidiTranscriptStreamEvent +from ..types.io import BidiInput, BidiOutput logger = logging.getLogger(__name__) -class _BidiTextOutput(BidiOutput): - """Handle text output from bidi agent.""" +class _BidiTextInput(BidiInput): + """Handle text input from user.""" + + def __init__(self) -> None: + """Setup async stream reader.""" + self._reader = asyncio.StreamReader() async def start(self) -> None: - """Start text output.""" - pass + """Connect reader to stdin.""" + loop = asyncio.get_running_loop() + protocol = asyncio.StreamReaderProtocol(self._reader) + await loop.connect_read_pipe(lambda: protocol, sys.stdin) + + async def __call__(self) -> BidiTextInputEvent: + """Read user input from stdin.""" + text = (await self._reader.readline()).decode().strip() + return BidiTextInputEvent(text, role="user") - async def stop(self) -> None: - """Stop text output.""" - pass + +class _BidiTextOutput(BidiOutput): + """Handle text output from bidi agent.""" async def __call__(self, event: BidiOutputEvent) -> None: """Print text events to stdout.""" @@ -46,6 +59,10 @@ async def __call__(self, event: BidiOutputEvent) -> None: class BidiTextIO: """Handle text input and output from bidi agent.""" + def input(self) -> _BidiTextInput: + """Return text processing BidiInput.""" + return _BidiTextInput() + def output(self) -> _BidiTextOutput: """Return text processing BidiOutput.""" return _BidiTextOutput() diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py index 10ae5db77..35b695f1c 100644 --- a/src/strands/experimental/bidi/types/io.py +++ b/src/strands/experimental/bidi/types/io.py @@ -5,11 +5,12 @@ by separating input and output concerns into independent callables. """ -from typing import Awaitable, Protocol +from typing import Awaitable, Protocol, runtime_checkable from ..types.events import BidiInputEvent, BidiOutputEvent +@runtime_checkable class BidiInput(Protocol): """Protocol for bidirectional input callables. @@ -19,11 +20,11 @@ class BidiInput(Protocol): async def start(self) -> None: """Start input.""" - ... + return async def stop(self) -> None: """Stop input.""" - ... + return def __call__(self) -> Awaitable[BidiInputEvent]: """Read input data from the source. @@ -34,6 +35,7 @@ def __call__(self) -> Awaitable[BidiInputEvent]: ... +@runtime_checkable class BidiOutput(Protocol): """Protocol for bidirectional output callables. @@ -43,11 +45,11 @@ class BidiOutput(Protocol): async def start(self) -> None: """Start output.""" - ... + return async def stop(self) -> None: """Stop output.""" - ... + return def __call__(self, event: BidiOutputEvent) -> Awaitable[None]: """Process output events from the agent. diff --git a/tests/strands/experimental/bidi/io/test_text.py b/tests/strands/experimental/bidi/io/test_text.py new file mode 100644 index 000000000..9ecf22eaf --- /dev/null +++ b/tests/strands/experimental/bidi/io/test_text.py @@ -0,0 +1,53 @@ +import unittest.mock + +import pytest + +from strands.experimental.bidi.io import BidiTextIO +from strands.experimental.bidi.types.events import BidiInterruptionEvent, BidiTextInputEvent, BidiTranscriptStreamEvent + + +@pytest.fixture +def stream_reader(): + with unittest.mock.patch("strands.experimental.bidi.io.text.asyncio.StreamReader") as mock: + yield mock.return_value + + +@pytest.fixture +def text_io(): + return BidiTextIO() + + +@pytest.fixture +def text_input(text_io): + return text_io.input() + + +@pytest.fixture +def text_output(text_io): + return text_io.output() + + +@pytest.mark.asyncio +async def test_bidi_text_io_input(stream_reader, text_input): + stream_reader.readline = unittest.mock.AsyncMock() + stream_reader.readline.return_value = b"test value" + + tru_event = await text_input() + exp_event = BidiTextInputEvent(text="test value", role="user") + assert tru_event == exp_event + + +@pytest.mark.parametrize( + ("event", "exp_print"), + [ + (BidiInterruptionEvent(reason="user_speech"), "interrupted"), + (BidiTranscriptStreamEvent(text="test text", delta="", is_final=False, role="user"), "Preview: test text"), + (BidiTranscriptStreamEvent(text="test text", delta="", is_final=True, role="user"), "test text"), + ] +) +@pytest.mark.asyncio +async def test_bidi_text_io_output_interrupt(event, exp_print, text_output, capsys): + await text_output(event) + + tru_print = capsys.readouterr().out.strip() + assert tru_print == exp_print From 923d93ebc71c359df0b8db3aabbd4a5d5723adad Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 10:24:01 -0500 Subject: [PATCH 192/242] updated agent and _caller --- src/strands/experimental/bidi/agent/agent.py | 9 --------- src/strands/tools/_caller.py | 6 +++++- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 238089a8c..0d1eb2e02 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -177,12 +177,6 @@ def tool_names(self) -> list[str]: return list(all_tools.keys()) - - async def _append_message(self, message: Message) -> None: - """Appends a message to the agent's list of messages and invokes the callbacks for the MessageAddedEvent.""" - self.messages.append(message) - await self.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self, message=message)) - async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start a persistent bidirectional conversation connection. @@ -380,9 +374,6 @@ async def run_outputs() -> None: await asyncio.gather(*tasks) try: - # Only start if not already started (e.g., when used with async context manager) - if not self._started: - await self.start(invocation_state) start_inputs = [input_.start for input_ in inputs if hasattr(input_, "start")] start_outputs = [output.start for output in outputs if hasattr(output, "start")] diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index fc7a3efb9..94fdcfec4 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -104,7 +104,11 @@ async def acall() -> ToolResult: return tool_result tool_result = run_async(acall) - self._agent.conversation_manager.apply_management(self._agent) + + # Apply conversation management if agent supports it (traditional agents) + if hasattr(self._agent, "conversation_manager"): + self._agent.conversation_manager.apply_management(self._agent) + return tool_result return caller From dec6c8edfe4e97e1497ca460ddf5856f47307797 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 10:30:31 -0500 Subject: [PATCH 193/242] udpated agent --- src/strands/experimental/bidi/agent/agent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 0d1eb2e02..17393f023 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -130,7 +130,7 @@ def __init__( self.state = AgentState() # Initialize other components - self.tool_caller = _ToolCaller(self) + self._tool_caller = _ToolCaller(self) # Initialize tool executor self.tool_executor = tool_executor or ConcurrentToolExecutor() @@ -164,7 +164,7 @@ def tool(self) -> _ToolCaller: agent.tool.calculator(expression="2+2") ``` """ - return self.tool_caller + return self._tool_caller @property def tool_names(self) -> list[str]: @@ -374,6 +374,7 @@ async def run_outputs() -> None: await asyncio.gather(*tasks) try: + await self.start(invocation_state) start_inputs = [input_.start for input_ in inputs if hasattr(input_, "start")] start_outputs = [output.start for output in outputs if hasattr(output, "start")] From 0d7406cedea0ff68601f018044de6900fee5fbb3 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 11:38:44 -0500 Subject: [PATCH 194/242] update configuration for audio --- src/strands/experimental/bidi/agent/agent.py | 4 +- src/strands/experimental/bidi/io/audio.py | 18 +++-- src/strands/experimental/bidi/io/text.py | 2 +- .../experimental/bidi/models/bidi_model.py | 5 +- .../experimental/bidi/models/gemini_live.py | 52 ++++++++------- .../experimental/bidi/models/novasonic.py | 59 +++++++++-------- .../experimental/bidi/models/openai.py | 57 ++++++++-------- src/strands/experimental/bidi/types/io.py | 4 +- .../bidi/models/test_gemini_live.py | 62 ++++++++--------- .../bidi/models/test_novasonic.py | 50 +++++++------- .../experimental/bidi/models/test_openai.py | 66 ++++++++++--------- 11 files changed, 197 insertions(+), 182 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index e1cde9fbe..16e5a4b90 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -334,7 +334,7 @@ async def run( ) # Using custom audio config: - model = BidiNovaSonicModel(audio_config={"input_rate": 48000, "output_rate": 24000}) + model = BidiNovaSonicModel(config={"audio": {"input_rate": 48000, "output_rate": 24000}}) audio_io = BidiAudioIO() agent = BidiAgent(model=model, tools=[calculator]) await agent.run( @@ -364,7 +364,7 @@ async def run_outputs() -> None: input_starts = [input_.start for input_ in inputs if isinstance(input_, BidiInput)] output_starts = [output.start for output in outputs if isinstance(output, BidiOutput)] for start in [*input_starts, *output_starts]: - await start() + await start(self) async with asyncio.TaskGroup() as task_group: task_group.create_task(run_inputs()) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index d8ea61496..919a69134 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -45,11 +45,12 @@ def __init__(self, config: dict[str, Any]) -> None: Args: config: Configuration dictionary with optional overrides: - input_device_index: Specific input device to use + - input_pyaudio_format: PyAudio format (default: paInt16) - input_frames_per_buffer: Number of frames per buffer """ # Initialize instance variables from config or class constants self._device_index = config.get("input_device_index", self._DEVICE_INDEX) - self._pyaudio_format = self._PYAUDIO_FORMAT + self._pyaudio_format = config.get("input_pyaudio_format", self._PYAUDIO_FORMAT) self._frames_per_buffer = config.get("input_frames_per_buffer", self._FRAMES_PER_BUFFER) async def start(self, agent: "BidiAgent") -> None: @@ -59,9 +60,9 @@ async def start(self, agent: "BidiAgent") -> None: agent: The BidiAgent instance, providing access to model configuration. """ # Get audio parameters from model config - self._rate = agent.model.audio_config["input_rate"] - self._channels = agent.model.audio_config["channels"] - self._format = agent.model.audio_config.get("format", "pcm") # Encoding format for events + self._rate = agent.model.config["audio"]["input_rate"] + self._channels = agent.model.config["audio"]["channels"] + self._format = agent.model.config["audio"].get("format", "pcm") # Encoding format for events logger.debug( "rate=<%d>, channels=<%d>, device_index=<%s> | starting audio input stream", @@ -133,13 +134,14 @@ def __init__(self, config: dict[str, Any]) -> None: Args: config: Configuration dictionary with optional overrides: - output_device_index: Specific output device to use + - output_pyaudio_format: PyAudio format (default: paInt16) - output_frames_per_buffer: Number of frames per buffer - output_buffer_size: Maximum buffer size (None = unlimited) """ # Initialize instance variables from config or class constants self._buffer_size = config.get("output_buffer_size", self._BUFFER_SIZE) self._device_index = config.get("output_device_index", self._DEVICE_INDEX) - self._pyaudio_format = self._PYAUDIO_FORMAT # Not configurable + self._pyaudio_format = config.get("output_pyaudio_format", self._PYAUDIO_FORMAT) self._frames_per_buffer = config.get("output_frames_per_buffer", self._FRAMES_PER_BUFFER) async def start(self, agent: "BidiAgent") -> None: @@ -149,8 +151,8 @@ async def start(self, agent: "BidiAgent") -> None: agent: The BidiAgent instance, providing access to model configuration. """ # Get audio parameters from model config - self._rate = agent.model.audio_config["output_rate"] - self._channels = agent.model.audio_config["channels"] + self._rate = agent.model.config["audio"]["output_rate"] + self._channels = agent.model.config["audio"]["channels"] logger.debug( "rate=<%d>, channels=<%d> | starting audio output stream", @@ -217,6 +219,8 @@ class BidiAudioIO: **config: Optional device configuration: - input_device_index (int): Specific input device (default: None = system default) - output_device_index (int): Specific output device (default: None = system default) + - input_pyaudio_format (int): PyAudio format for input (default: pyaudio.paInt16) + - output_pyaudio_format (int): PyAudio format for output (default: pyaudio.paInt16) - input_frames_per_buffer (int): Input buffer size (default: 512) - output_frames_per_buffer (int): Output buffer size (default: 512) - output_buffer_size (int | None): Max output queue size (default: None = unlimited) diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index ff31fee0e..7eadcb341 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -21,7 +21,7 @@ def __init__(self) -> None: """Setup async stream reader.""" self._reader = asyncio.StreamReader() - async def start(self) -> None: + async def start(self, agent: "BidiAgent") -> None: """Connect reader to stdin.""" loop = asyncio.get_running_loop() protocol = asyncio.StreamReaderProtocol(self._reader) diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index f1eae4208..69197dff3 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -34,12 +34,9 @@ class BidiModel(Protocol): connections with real-time audio and text communication. Implementations handle provider-specific protocols while exposing a standardized event-based API. - All bidirectional models must provide an audio_config property that specifies - their audio processing requirements. This configuration is built by merging - user-provided values with model-specific defaults. """ - audio_config: AudioConfig + config: dict[str, Any] async def start( self, diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 3b562efed..b9f88b717 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -62,7 +62,7 @@ def __init__( self, model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", api_key: str | None = None, - audio_config: AudioConfig | None = None, + config: dict[str, Any] | None = None, live_config: dict[str, Any] | None = None, **kwargs: Any, ): @@ -71,9 +71,9 @@ def __init__( Args: model_id: Gemini Live model identifier. api_key: Google AI API key for authentication. + config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. + If not provided or if "audio" key is missing, uses Gemini Live API's default audio configuration. live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). - audio_config: Optional audio configuration override. If not provided, - uses Gemini Live API's default configuration. **kwargs: Reserved for future parameters. """ # Model configuration @@ -108,33 +108,35 @@ def __init__( self._live_session_context_manager: Any = None self._connection_id: str | None = None + # Extract audio config from config dict if provided + user_audio_config = config.get("audio", {}) if config else {} + # Extract voice from live_config if provided - default_voice = None + live_config_voice = None if self.live_config and "speech_config" in self.live_config: speech_config = self.live_config["speech_config"] if isinstance(speech_config, dict): - default_voice = speech_config.get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") - - # Build audio configuration - use provided values or defaults - config_dict: AudioConfig = { - "input_rate": audio_config.get("input_rate", GEMINI_INPUT_SAMPLE_RATE) - if audio_config - else GEMINI_INPUT_SAMPLE_RATE, - "output_rate": audio_config.get("output_rate", GEMINI_OUTPUT_SAMPLE_RATE) - if audio_config - else GEMINI_OUTPUT_SAMPLE_RATE, - "channels": audio_config.get("channels", GEMINI_CHANNELS) if audio_config else GEMINI_CHANNELS, - "format": audio_config.get("format", "pcm") if audio_config else "pcm", + live_config_voice = speech_config.get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") + + # Define default audio configuration + default_audio_config: AudioConfig = { + "input_rate": GEMINI_INPUT_SAMPLE_RATE, + "output_rate": GEMINI_OUTPUT_SAMPLE_RATE, + "channels": GEMINI_CHANNELS, + "format": "pcm", } - # Add voice if configured (either from user or live_config) - voice_value = audio_config.get("voice", default_voice) if audio_config else default_voice - if voice_value: - config_dict["voice"] = voice_value + # Add voice to defaults if configured in live_config + if live_config_voice: + default_audio_config["voice"] = live_config_voice + + # Merge user config with defaults (user values take precedence) + merged_audio_config = cast(AudioConfig, {**default_audio_config, **user_audio_config}) - self.audio_config = config_dict + # Store config with audio defaults always populated + self.config: dict[str, Any] = {"audio": merged_audio_config} - if audio_config: + if user_audio_config: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default Gemini Live audio configuration") @@ -487,11 +489,11 @@ def _build_live_config( if tools: config_dict["tools"] = self._format_tools_for_live_api(tools) - # Override voice with audio_config value if present (audio_config takes precedence) - if "voice" in self.audio_config: + # Override voice with config value if present (config takes precedence) + if "voice" in self.config["audio"]: config_dict.setdefault("speech_config", {}).setdefault("voice_config", {}).setdefault( "prebuilt_voice_config", {} - )["voice_name"] = self.audio_config["voice"] + )["voice_name"] = self.config["audio"]["voice"] return config_dict diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 06ac99012..818f627f9 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, AsyncGenerator, cast +from typing import Any, AsyncGenerator, cast, Literal import boto3 from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput @@ -97,7 +97,7 @@ def __init__( model_id: str = "amazon.nova-sonic-v1:0", boto_session: boto3.Session | None = None, region: str | None = None, - audio_config: AudioConfig | None = None, + config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Initialize Nova Sonic bidirectional model. @@ -106,8 +106,8 @@ def __init__( model_id: Nova Sonic model identifier. boto_session: Boto Session to use when calling the Nova Sonic Model. region: AWS region - audio_config: Optional audio configuration override. If not provided, - uses Nova Sonic's default configuration. + config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. + If not provided or if "audio" key is missing, uses Nova Sonic's default audio configuration. **kwargs: Reserved for future parameters. """ if region and boto_session: @@ -134,24 +134,25 @@ def __init__( logger.debug("model_id=<%s> | nova sonic model initialized", model_id) - # Build audio configuration - use provided values or defaults - self.audio_config: AudioConfig = { - "input_rate": audio_config.get("input_rate", NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]) - if audio_config - else NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"], # type: ignore[typeddict-item] - "output_rate": audio_config.get("output_rate", NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]) - if audio_config - else NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"], # type: ignore[typeddict-item] - "channels": audio_config.get("channels", NOVA_AUDIO_INPUT_CONFIG["channelCount"]) - if audio_config - else NOVA_AUDIO_INPUT_CONFIG["channelCount"], # type: ignore[typeddict-item] - "format": audio_config.get("format", "pcm") if audio_config else "pcm", - "voice": audio_config.get("voice", NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]) - if audio_config - else NOVA_AUDIO_OUTPUT_CONFIG["voiceId"], # type: ignore[typeddict-item] + # Extract audio config from config dict if provided + user_audio_config = config.get("audio", {}) if config else {} + + # Define default audio configuration + default_audio_config: AudioConfig = { + "input_rate": cast(int, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), + "output_rate": cast(int, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), + "channels": cast(int, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), + "format": "pcm", + "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), } - if audio_config: + # Merge user config with defaults (user values take precedence) + merged_audio_config = cast(AudioConfig, {**default_audio_config, **user_audio_config}) + + # Store config with audio defaults always populated + self.config: dict[str, Any] = {"audio": merged_audio_config} + + if user_audio_config: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default Nova Sonic audio configuration") @@ -309,12 +310,12 @@ async def _start_audio_connection(self) -> None: logger.debug("nova audio connection starting") self._audio_content_name = str(uuid.uuid4()) - # Build audio input configuration from audio_config + # Build audio input configuration from config audio_input_config = { "mediaType": "audio/lpcm", - "sampleRateHertz": self.audio_config["input_rate"], + "sampleRateHertz": self.config["audio"]["input_rate"], "sampleSizeBits": 16, - "channelCount": self.audio_config["channels"], + "channelCount": self.config["audio"]["channels"], "audioType": "SPEECH", "encoding": "base64", } @@ -458,8 +459,8 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N if "audioOutput" in nova_event: # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] - # Channels from audio_config is guaranteed to be 1 or 2 - channels: Literal[1, 2] = self.audio_config["channels"] # type: ignore[assignment] + # Channels from config is guaranteed to be 1 or 2 + channels = cast(Literal[1, 2], self.config["audio"]["channels"]) return BidiAudioStreamEvent( audio=audio_content, format="pcm", @@ -536,13 +537,13 @@ def _get_connection_start_event(self) -> str: def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" - # Build audio output configuration from audio_config + # Build audio output configuration from config audio_output_config = { "mediaType": "audio/lpcm", - "sampleRateHertz": self.audio_config["output_rate"], + "sampleRateHertz": self.config["audio"]["output_rate"], "sampleSizeBits": 16, - "channelCount": self.audio_config["channels"], - "voiceId": self.audio_config.get("voice", "matthew"), + "channelCount": self.config["audio"]["channels"], + "voiceId": self.config["audio"].get("voice", "matthew"), "encoding": "base64", "audioType": "SPEECH", } diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 14aa221d5..6fc4f458d 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -8,7 +8,7 @@ import logging import os import uuid -from typing import Any, AsyncGenerator, cast +from typing import Any, AsyncGenerator, cast, Literal import websockets from websockets import ClientConnection @@ -82,7 +82,7 @@ def __init__( organization: str | None = None, project: str | None = None, session_config: dict[str, Any] | None = None, - audio_config: AudioConfig | None = None, + config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Initialize OpenAI Realtime bidirectional model. @@ -93,8 +93,8 @@ def __init__( organization: OpenAI organization ID for API requests. project: OpenAI project ID for API requests. session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). - audio_config: Optional audio configuration override. If not provided, - uses OpenAI Realtime API's default configuration. + config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. + If not provided or if "audio" key is missing, uses OpenAI Realtime API's default audio configuration. **kwargs: Reserved for future parameters. """ # Model configuration @@ -116,31 +116,36 @@ def __init__( self._function_call_buffer: dict[str, Any] = {} - logger.debug("model=<%s> | openai realtime model initialized", model) + logger.debug("model=<%s> | openai realtime model initialized", model_id) - # Extract voice from session_config if provided, otherwise use default - default_voice = "alloy" + # Extract audio config from config dict if provided + user_audio_config = config.get("audio", {}) if config else {} + + # Extract voice from session_config if provided + session_config_voice = "alloy" if self.session_config and "audio" in self.session_config: audio_settings = self.session_config["audio"] if isinstance(audio_settings, dict) and "output" in audio_settings: output_settings = audio_settings["output"] if isinstance(output_settings, dict): - default_voice = output_settings.get("voice", default_voice) - - # Build audio configuration - use provided values or defaults - self.audio_config: AudioConfig = { - "input_rate": audio_config.get("input_rate", AUDIO_FORMAT["rate"]) - if audio_config - else AUDIO_FORMAT["rate"], # type: ignore[typeddict-item] - "output_rate": audio_config.get("output_rate", AUDIO_FORMAT["rate"]) - if audio_config - else AUDIO_FORMAT["rate"], # type: ignore[typeddict-item] - "channels": audio_config.get("channels", 1) if audio_config else 1, - "format": audio_config.get("format", "pcm") if audio_config else "pcm", - "voice": audio_config.get("voice", default_voice) if audio_config else default_voice, + session_config_voice = output_settings.get("voice", "alloy") + + # Define default audio configuration + default_audio_config: AudioConfig = { + "input_rate": cast(int, AUDIO_FORMAT["rate"]), + "output_rate": cast(int, AUDIO_FORMAT["rate"]), + "channels": 1, + "format": "pcm", + "voice": session_config_voice, } - if audio_config: + # Merge user config with defaults (user values take precedence) + merged_audio_config = cast(AudioConfig, {**default_audio_config, **user_audio_config}) + + # Store config with audio defaults always populated + self.config: dict[str, Any] = {"audio": merged_audio_config} + + if user_audio_config: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default OpenAI Realtime audio configuration") @@ -250,9 +255,9 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] else: logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) - # Override voice with audio_config value if present (audio_config takes precedence) - if "voice" in self.audio_config: - config.setdefault("audio", {}).setdefault("output", {})["voice"] = self.audio_config["voice"] # type: ignore + # Override voice with config value if present (config takes precedence) + if "voice" in self.config["audio"]: + config.setdefault("audio", {}).setdefault("output", {})["voice"] = self.config["audio"]["voice"] return config @@ -326,8 +331,8 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput # Audio output elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI - # Channels from audio_config is guaranteed to be 1 or 2 - channels: Literal[1, 2] = self.audio_config["channels"] # type: ignore[assignment] + # Channels from config is guaranteed to be 1 or 2 + channels = cast(Literal[1, 2], self.config["audio"]["channels"]) return [ BidiAudioStreamEvent( audio=openai_event["delta"], diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py index 0c4b44704..7125eb5ef 100644 --- a/src/strands/experimental/bidi/types/io.py +++ b/src/strands/experimental/bidi/types/io.py @@ -22,7 +22,7 @@ class BidiInput(Protocol): and return events to be sent to the agent. """ - async def start(self) -> None: + async def start(self, agent: "BidiAgent") -> None: """Start input.""" return @@ -47,7 +47,7 @@ class BidiOutput(Protocol): (play audio, display text, send over websocket, etc.). """ - async def start(self) -> None: + async def start(self, agent: "BidiAgent") -> None: """Start output.""" return diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index cf0c70dc0..576e8c3df 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -454,62 +454,64 @@ def test_audio_config_defaults(mock_genai_client, model_id, api_key): model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) - assert model.audio_config["input_rate"] == 16000 - assert model.audio_config["output_rate"] == 24000 - assert model.audio_config["channels"] == 1 - assert model.audio_config["format"] == "pcm" - assert "voice" not in model.audio_config # No default voice + assert model.config["audio"]["input_rate"] == 16000 + assert model.config["audio"]["output_rate"] == 24000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + assert "voice" not in model.config["audio"] # No default voice def test_audio_config_partial_override(mock_genai_client, model_id, api_key): """Test partial audio configuration override.""" _ = mock_genai_client - audio_config = {"output_rate": 48000, "voice": "Puck"} - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, audio_config=audio_config) + config = {"audio": {"output_rate": 48000, "voice": "Puck"}} + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, config=config) # Overridden values - assert model.audio_config["output_rate"] == 48000 - assert model.audio_config["voice"] == "Puck" + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["voice"] == "Puck" # Default values preserved - assert model.audio_config["input_rate"] == 16000 - assert model.audio_config["channels"] == 1 - assert model.audio_config["format"] == "pcm" + assert model.config["audio"]["input_rate"] == 16000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" def test_audio_config_full_override(mock_genai_client, model_id, api_key): """Test full audio configuration override.""" _ = mock_genai_client - audio_config = { - "input_rate": 48000, - "output_rate": 48000, - "channels": 2, - "format": "pcm", - "voice": "Aoede", + config = { + "audio": { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "Aoede", + } } - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, audio_config=audio_config) + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, config=config) - assert model.audio_config["input_rate"] == 48000 - assert model.audio_config["output_rate"] == 48000 - assert model.audio_config["channels"] == 2 - assert model.audio_config["format"] == "pcm" - assert model.audio_config["voice"] == "Aoede" + assert model.config["audio"]["input_rate"] == 48000 + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["channels"] == 2 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "Aoede" def test_audio_config_voice_priority(mock_genai_client, model_id, api_key): - """Test that audio_config voice takes precedence over live_config voice.""" + """Test that config audio voice takes precedence over live_config voice.""" _ = mock_genai_client live_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}} - audio_config = {"voice": "Aoede"} + config = {"audio": {"voice": "Aoede"}} - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, live_config=live_config, audio_config=audio_config) + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, live_config=live_config, config=config) - # Build config and verify audio_config voice takes precedence - config = model._build_live_config() - assert config["speech_config"]["voice_config"]["prebuilt_voice_config"]["voice_name"] == "Aoede" + # Build config and verify config audio voice takes precedence + built_config = model._build_live_config() + assert built_config["speech_config"]["voice_config"]["prebuilt_voice_config"]["voice_name"] == "Aoede" # Helper Method Tests diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 267327a0d..1c8149b75 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -94,46 +94,48 @@ async def test_audio_config_defaults(model_id, region): """Test default audio configuration.""" model = BidiNovaSonicModel(model_id=model_id, region=region) - assert model.audio_config["input_rate"] == 16000 - assert model.audio_config["output_rate"] == 16000 - assert model.audio_config["channels"] == 1 - assert model.audio_config["format"] == "pcm" - assert model.audio_config["voice"] == "matthew" + assert model.config["audio"]["input_rate"] == 16000 + assert model.config["audio"]["output_rate"] == 16000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "matthew" @pytest.mark.asyncio async def test_audio_config_partial_override(model_id, region): """Test partial audio configuration override.""" - audio_config = {"output_rate": 24000, "voice": "ruth"} - model = BidiNovaSonicModel(model_id=model_id, region=region, audio_config=audio_config) + config = {"audio": {"output_rate": 24000, "voice": "ruth"}} + model = BidiNovaSonicModel(model_id=model_id, region=region, config=config) # Overridden values - assert model.audio_config["output_rate"] == 24000 - assert model.audio_config["voice"] == "ruth" + assert model.config["audio"]["output_rate"] == 24000 + assert model.config["audio"]["voice"] == "ruth" # Default values preserved - assert model.audio_config["input_rate"] == 16000 - assert model.audio_config["channels"] == 1 - assert model.audio_config["format"] == "pcm" + assert model.config["audio"]["input_rate"] == 16000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" @pytest.mark.asyncio async def test_audio_config_full_override(model_id, region): """Test full audio configuration override.""" - audio_config = { - "input_rate": 48000, - "output_rate": 48000, - "channels": 2, - "format": "pcm", - "voice": "stephen", + config = { + "audio": { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "stephen", + } } - model = BidiNovaSonicModel(model_id=model_id, region=region, audio_config=audio_config) + model = BidiNovaSonicModel(model_id=model_id, region=region, config=config) - assert model.audio_config["input_rate"] == 48000 - assert model.audio_config["output_rate"] == 48000 - assert model.audio_config["channels"] == 2 - assert model.audio_config["format"] == "pcm" - assert model.audio_config["voice"] == "stephen" + assert model.config["audio"]["input_rate"] == 48000 + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["channels"] == 2 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "stephen" @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index 55cfac26d..ee3dd45c9 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -117,68 +117,70 @@ def test_audio_config_defaults(api_key, model_name): """Test default audio configuration.""" model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - assert model.audio_config["input_rate"] == 24000 - assert model.audio_config["output_rate"] == 24000 - assert model.audio_config["channels"] == 1 - assert model.audio_config["format"] == "pcm" - assert model.audio_config["voice"] == "alloy" + assert model.config["audio"]["input_rate"] == 24000 + assert model.config["audio"]["output_rate"] == 24000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "alloy" def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" - audio_config = {"output_rate": 48000, "voice": "echo"} - model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, audio_config=audio_config) + config = {"audio": {"output_rate": 48000, "voice": "echo"}} + model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, config=config) # Overridden values - assert model.audio_config["output_rate"] == 48000 - assert model.audio_config["voice"] == "echo" + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["voice"] == "echo" # Default values preserved - assert model.audio_config["input_rate"] == 24000 - assert model.audio_config["channels"] == 1 - assert model.audio_config["format"] == "pcm" + assert model.config["audio"]["input_rate"] == 24000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" def test_audio_config_full_override(api_key, model_name): """Test full audio configuration override.""" - audio_config = { - "input_rate": 48000, - "output_rate": 48000, - "channels": 2, - "format": "pcm", - "voice": "shimmer", + config = { + "audio": { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "shimmer", + } } - model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, audio_config=audio_config) + model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, config=config) - assert model.audio_config["input_rate"] == 48000 - assert model.audio_config["output_rate"] == 48000 - assert model.audio_config["channels"] == 2 - assert model.audio_config["format"] == "pcm" - assert model.audio_config["voice"] == "shimmer" + assert model.config["audio"]["input_rate"] == 48000 + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["channels"] == 2 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "shimmer" def test_audio_config_voice_priority(api_key, model_name): - """Test that audio_config voice takes precedence over session_config voice.""" + """Test that config audio voice takes precedence over session_config voice.""" session_config = {"audio": {"output": {"voice": "alloy"}}} - audio_config = {"voice": "nova"} + config = {"audio": {"voice": "nova"}} model = BidiOpenAIRealtimeModel( - model=model_name, api_key=api_key, session_config=session_config, audio_config=audio_config + model=model_name, api_key=api_key, session_config=session_config, config=config ) - # Build config and verify audio_config voice takes precedence - config = model._build_session_config(None, None) - assert config["audio"]["output"]["voice"] == "nova" + # Build config and verify config audio voice takes precedence + built_config = model._build_session_config(None, None) + assert built_config["audio"]["output"]["voice"] == "nova" def test_audio_config_extracts_voice_from_session_config(api_key, model_name): - """Test that voice is extracted from session_config when audio_config not provided.""" + """Test that voice is extracted from session_config when config audio not provided.""" session_config = {"audio": {"output": {"voice": "fable"}}} model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, session_config=session_config) # Should extract voice from session_config - assert model.audio_config["voice"] == "fable" + assert model.config["audio"]["voice"] == "fable" def test_init_without_api_key_raises(): From 8c2c3c95546abdc610d286822bc4ed0647ba9367 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 11:54:47 -0500 Subject: [PATCH 195/242] updating doc strings --- src/strands/experimental/bidi/io/audio.py | 14 +++++++------- src/strands/experimental/bidi/models/bidi_model.py | 2 ++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 919a69134..62120b4c4 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -213,9 +213,12 @@ async def _output(self) -> None: class BidiAudioIO: - """Send and receive audio data from devices. - - Args: + """Send and receive audio data from devices.""" + + def __init__(self, **config: Any) -> None: + """Initialize audio devices. + + Args: **config: Optional device configuration: - input_device_index (int): Specific input device (default: None = system default) - output_device_index (int): Specific output device (default: None = system default) @@ -224,10 +227,7 @@ class BidiAudioIO: - input_frames_per_buffer (int): Input buffer size (default: 512) - output_frames_per_buffer (int): Output buffer size (default: 512) - output_buffer_size (int | None): Max output queue size (default: None = unlimited) - """ - - def __init__(self, **config: Any) -> None: - """Initialize audio devices.""" + """ self._config = config def input(self) -> _BidiAudioInput: diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 69197dff3..253a5d440 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -34,6 +34,8 @@ class BidiModel(Protocol): connections with real-time audio and text communication. Implementations handle provider-specific protocols while exposing a standardized event-based API. + Attributes: + config: Configuration dictionary with provider-specific settings. """ config: dict[str, Any] From a9c7129b5d1dd0a5183c66e89256c67674eac4c2 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 25 Nov 2025 12:02:29 -0500 Subject: [PATCH 196/242] fix indentation --- src/strands/experimental/bidi/io/audio.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 62120b4c4..744404882 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -219,14 +219,14 @@ def __init__(self, **config: Any) -> None: """Initialize audio devices. Args: - **config: Optional device configuration: - - input_device_index (int): Specific input device (default: None = system default) - - output_device_index (int): Specific output device (default: None = system default) - - input_pyaudio_format (int): PyAudio format for input (default: pyaudio.paInt16) - - output_pyaudio_format (int): PyAudio format for output (default: pyaudio.paInt16) - - input_frames_per_buffer (int): Input buffer size (default: 512) - - output_frames_per_buffer (int): Output buffer size (default: 512) - - output_buffer_size (int | None): Max output queue size (default: None = unlimited) + **config: Optional device configuration: + - input_device_index (int): Specific input device (default: None = system default) + - output_device_index (int): Specific output device (default: None = system default) + - input_pyaudio_format (int): PyAudio format for input (default: pyaudio.paInt16) + - output_pyaudio_format (int): PyAudio format for output (default: pyaudio.paInt16) + - input_frames_per_buffer (int): Input buffer size (default: 512) + - output_frames_per_buffer (int): Output buffer size (default: 512) + - output_buffer_size (int | None): Max output queue size (default: None = unlimited) """ self._config = config From fabaf5941794fbb375052eea4183997e51882608 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 26 Nov 2025 10:55:35 +0100 Subject: [PATCH 197/242] address comments --- src/strands/experimental/bidi/agent/agent.py | 2 -- .../experimental/bidi/models/openai.py | 23 +++++++++---------- .../test_repository_session_manager.py | 15 +----------- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 7cb39ad9e..fea187b47 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -151,8 +151,6 @@ def __init__( self._session_manager = session_manager if self._session_manager: self.hooks.add_hook(self._session_manager) - # Initialize invocation state (will be set in start()) - self._invocation_state: dict[str, Any] = {} self._loop = _BidiAgentLoop(self) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 7b18f3294..36c7b0356 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -346,6 +346,7 @@ async def _add_conversation_history(self, messages: Messages) -> None: elif "toolResult" in block: # Tool result - create as function_call_output item tool_result = block["toolResult"] + original_id = tool_result["toolUseId"] # Serialize the entire tool result content, preserving all data types result_output = "" @@ -361,18 +362,19 @@ async def _add_conversation_history(self, messages: Messages) -> None: content_parts.append( json.dumps(json_content) if not isinstance(json_content, str) else json_content ) - elif "image" in result_block: - logger.warning("image content in tool results not supported by openai realtime api") - elif "document" in result_block: - logger.warning("document content in tool results not supported by openai realtime api") + else: + # Generic warning for unsupported content types + logger.warning( + "tool_use_id=<%s>, content_types=<%s> | content type in tool results not supported by openai realtime api", + original_id, + list(result_block.keys()), + ) # Combine all parts - if single part, use as-is; if multiple, combine if len(content_parts) == 1: result_output = content_parts[0] elif content_parts: result_output = "\n".join(content_parts) - - original_id = tool_result["toolUseId"] # Use mapped call_id if available, otherwise skip orphaned result if original_id not in call_id_map: continue # Skip this tool result since we don't have the call @@ -740,13 +742,10 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: content_parts.append( json.dumps(json_content) if not isinstance(json_content, str) else json_content ) - elif "image" in block: - raise ValueError( - f"tool_use_id=<{tool_use_id}> | Image content in tool results is not supported by OpenAI Realtime API" - ) - elif "document" in block: + else: + # Generic error for unsupported content types raise ValueError( - f"tool_use_id=<{tool_use_id}> | Document content in tool results is not supported by OpenAI Realtime API" + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by OpenAI Realtime API" ) # Combine all parts - if single part, use as-is; if multiple, combine diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 61e66c569..0b5623ae0 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -7,6 +7,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.agent.state import AgentState from strands.interrupt import _InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock @@ -423,10 +424,6 @@ def test_fix_broken_tool_use_does_not_change_valid_message(session_manager): @pytest.fixture def mock_bidi_agent(): """Create a mock BidiAgent for testing.""" - from unittest.mock import Mock - - from strands.agent.state import AgentState - agent = Mock() agent.agent_id = "bidi-agent-1" agent.messages = [{"role": "user", "content": [{"text": "Hello from bidi!"}]}] @@ -454,8 +451,6 @@ def test_initialize_bidi_agent_creates_new(session_manager, mock_bidi_agent): def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agent): """Test initializing BidiAgent restores from existing session.""" - from strands.types.session import SessionAgent, SessionMessage - # Create existing session data session_agent = SessionAgent( agent_id="bidi-agent-1", @@ -499,8 +494,6 @@ def test_append_bidi_message(session_manager, mock_bidi_agent): def test_sync_bidi_agent(session_manager, mock_bidi_agent): """Test syncing BidiAgent state to session.""" - from strands.agent.state import AgentState - # Initialize agent session_manager.initialize_bidi_agent(mock_bidi_agent) @@ -530,10 +523,6 @@ def test_bidi_agent_unique_id_constraint(session_manager, mock_bidi_agent): session_manager.initialize_bidi_agent(mock_bidi_agent) # Try to initialize another agent with same ID - from unittest.mock import Mock - - from strands.agent.state import AgentState - agent2 = Mock() agent2.agent_id = "bidi-agent-1" # Same ID agent2.messages = [] @@ -545,8 +534,6 @@ def test_bidi_agent_unique_id_constraint(session_manager, mock_bidi_agent): def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): """Test that BidiAgent uses offset=0 for message restoration (no conversation_manager).""" - from strands.types.session import SessionAgent, SessionMessage - # Create session with messages session_agent = SessionAgent( agent_id="bidi-agent-1", From 9d01277684b2f2571000030099d88872a6e6c44b Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 26 Nov 2025 13:47:43 +0100 Subject: [PATCH 198/242] improve tool result handling --- .../experimental/bidi/models/gemini_live.py | 15 +++--- .../experimental/bidi/models/novasonic.py | 12 ++++- .../experimental/bidi/models/openai.py | 49 +++++-------------- 3 files changed, 31 insertions(+), 45 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index b9f88b717..769e50aa8 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -431,15 +431,14 @@ async def _send_text_content(self, text: str) -> None: async def _send_tool_result(self, tool_result: ToolResult) -> None: """Internal: Send tool result using Gemini Live API.""" tool_use_id = tool_result.get("toolUseId") + content = tool_result.get("content", []) - # TODO: We need to extract all content and content types - result_data = {} - if "content" in tool_result: - # Extract text from content blocks - for block in tool_result["content"]: - if "text" in block: - result_data = {"result": block["text"]} - break + # Optimize for single content item - unwrap the array + if len(content) == 1: + result_data: dict[str, Any] = content[0] + else: + # Multiple items - send as array + result_data = {"result": content} # Create function response func_response = genai_types.FunctionResponse( diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 428eab15e..13a78e08c 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -391,9 +391,19 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) - # Nova Sonic expects stringified JSON in toolResult.content + # Validate content types and preserve structure content = tool_result.get("content", []) + # Validate all content types are supported + for block in content: + if "text" not in block and "json" not in block: + # Unsupported content type - log warning + logger.warning( + "tool_use_id=<%s>, content_types=<%s> | content type in tool results not supported by nova sonic", + tool_use_id, + list(block.keys()), + ) + # Optimize for single content item - unwrap the array if len(content) == 1: result_data: dict[str, Any] = content[0] diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 36c7b0356..18843552c 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -348,33 +348,22 @@ async def _add_conversation_history(self, messages: Messages) -> None: tool_result = block["toolResult"] original_id = tool_result["toolUseId"] - # Serialize the entire tool result content, preserving all data types + # Validate content types and serialize, preserving structure result_output = "" if "content" in tool_result: - # Collect all content blocks - content_parts = [] + # First validate all content types are supported for result_block in tool_result["content"]: - if "text" in result_block: - content_parts.append(result_block["text"]) - elif "json" in result_block: - # Preserve JSON content - json_content = result_block["json"] - content_parts.append( - json.dumps(json_content) if not isinstance(json_content, str) else json_content - ) - else: - # Generic warning for unsupported content types + if "text" not in result_block and "json" not in result_block: + # Unsupported content type - log warning and skip logger.warning( "tool_use_id=<%s>, content_types=<%s> | content type in tool results not supported by openai realtime api", original_id, list(result_block.keys()), ) - # Combine all parts - if single part, use as-is; if multiple, combine - if len(content_parts) == 1: - result_output = content_parts[0] - elif content_parts: - result_output = "\n".join(content_parts) + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) + # Use mapped call_id if available, otherwise skip orphaned result if original_id not in call_id_map: continue # Skip this tool result since we don't have the call @@ -728,31 +717,19 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) - # Serialize the entire tool result content, preserving all data types + # Validate content types and serialize, preserving structure result_output = "" if "content" in tool_result: - # Collect all content blocks - content_parts = [] + # First validate all content types are supported for block in tool_result["content"]: - if "text" in block: - content_parts.append(block["text"]) - elif "json" in block: - # Preserve JSON content - json_content = block["json"] - content_parts.append( - json.dumps(json_content) if not isinstance(json_content, str) else json_content - ) - else: - # Generic error for unsupported content types + if "text" not in block and "json" not in block: + # Unsupported content type - raise error raise ValueError( f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by OpenAI Realtime API" ) - # Combine all parts - if single part, use as-is; if multiple, combine - if len(content_parts) == 1: - result_output = content_parts[0] - elif content_parts: - result_output = "\n".join(content_parts) + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_output} await self._send_event({"type": "conversation.item.create", "item": item_data}) From b5d2bb1ca7a6df56c6c4612f68f6e329d34e0ed3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 26 Nov 2025 13:47:50 +0100 Subject: [PATCH 199/242] fix integ tests --- tests_integ/bidi/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/bidi/context.py b/tests_integ/bidi/context.py index 830857564..f60379b60 100644 --- a/tests_integ/bidi/context.py +++ b/tests_integ/bidi/context.py @@ -84,7 +84,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): """Stop context manager, cleanup threads, and end agent session.""" # End agent session FIRST - this will cause receive() to exit cleanly - if self.agent._loop and self.agent._loop.active: + if self.agent._started: await self.agent.stop() logger.debug("Agent session stopped") From be51621a188904e275b4d470b6e3074a84f1556b Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 26 Nov 2025 15:53:40 +0100 Subject: [PATCH 200/242] make tool result content handling consistent --- .../experimental/bidi/models/gemini_live.py | 8 ++ .../experimental/bidi/models/novasonic.py | 8 +- .../experimental/bidi/models/openai.py | 8 +- .../bidi/models/test_gemini_live.py | 115 ++++++++++++++++++ .../bidi/models/test_novasonic.py | 38 ++++++ .../experimental/bidi/models/test_openai.py | 59 +++++---- 6 files changed, 199 insertions(+), 37 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 769e50aa8..e14b79bce 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -433,6 +433,14 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: tool_use_id = tool_result.get("toolUseId") content = tool_result.get("content", []) + # Validate all content types are supported + for block in content: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by Gemini Live API" + ) + # Optimize for single content item - unwrap the array if len(content) == 1: result_data: dict[str, Any] = content[0] diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 13a78e08c..7105bde00 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -397,11 +397,9 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: # Validate all content types are supported for block in content: if "text" not in block and "json" not in block: - # Unsupported content type - log warning - logger.warning( - "tool_use_id=<%s>, content_types=<%s> | content type in tool results not supported by nova sonic", - tool_use_id, - list(block.keys()), + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by Nova Sonic" ) # Optimize for single content item - unwrap the array diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 18843552c..00fac64f8 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -354,11 +354,9 @@ async def _add_conversation_history(self, messages: Messages) -> None: # First validate all content types are supported for result_block in tool_result["content"]: if "text" not in result_block and "json" not in result_block: - # Unsupported content type - log warning and skip - logger.warning( - "tool_use_id=<%s>, content_types=<%s> | content type in tool results not supported by openai realtime api", - original_id, - list(result_block.keys()), + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | Content type not supported by OpenAI Realtime API" ) # Preserve structure by JSON-dumping the entire content array diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 576e8c3df..053f13dde 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -543,3 +543,118 @@ def test_tool_formatting(model, tool_spec): # Test empty list formatted_empty = model._format_tools_for_live_api([]) assert formatted_empty == [] + + + +# Tool Result Content Tests + + +@pytest.mark.asyncio +async def test_tool_result_single_content_unwrapped(mock_genai_client, model): + """Test that single content item is unwrapped (optimization).""" + _, mock_live_session, _ = mock_genai_client + await model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Single result"}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the tool response was sent + mock_live_session.send_tool_response.assert_called_once() + call_args = mock_live_session.send_tool_response.call_args + function_responses = call_args.kwargs.get("function_responses", []) + + assert len(function_responses) == 1 + func_response = function_responses[0] + assert func_response.id == "tool-123" + # Single content should be unwrapped (not in array) + assert func_response.response == {"text": "Single result"} + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_multiple_content_as_array(mock_genai_client, model): + """Test that multiple content items are sent as array.""" + _, mock_live_session, _ = mock_genai_client + await model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the tool response was sent + mock_live_session.send_tool_response.assert_called_once() + call_args = mock_live_session.send_tool_response.call_args + function_responses = call_args.kwargs.get("function_responses", []) + + assert len(function_responses) == 1 + func_response = function_responses[0] + assert func_response.id == "tool-456" + # Multiple content should be in array format + assert "result" in func_response.response + assert isinstance(func_response.response["result"], list) + assert len(func_response.response["result"]) == 2 + assert func_response.response["result"][0] == {"text": "Part 1"} + assert func_response.response["result"][1] == {"json": {"data": "value"}} + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_unsupported_content_type(mock_genai_client, model): + """Test that unsupported content types raise ValueError.""" + _, _, _ = mock_genai_client + await model.start() + + # Test with image content (unsupported) + tool_result_image: ToolResult = { + "toolUseId": "tool-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Gemini Live API"): + await model.send(ToolResultEvent(tool_result_image)) + + # Test with document content (unsupported) + tool_result_doc: ToolResult = { + "toolUseId": "tool-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"doc_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Gemini Live API"): + await model.send(ToolResultEvent(tool_result_doc)) + + # Test with mixed content (one unsupported) + tool_result_mixed: ToolResult = { + "toolUseId": "tool-777", + "status": "success", + "content": [{"text": "Valid text"}, {"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Gemini Live API"): + await model.send(ToolResultEvent(tool_result_mixed)) + + await model.stop() + + +# Helper fixture for async generator +@pytest.fixture +def agenerator(): + """Helper to create async generators for testing.""" + + async def _agenerator(items): + for item in items: + yield item + + return _agenerator diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index c2bdd27a2..19aac19c2 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -653,3 +653,41 @@ async def test_tool_result_empty_content(nova_model, mock_stream): assert content == {"content": []} await nova_model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_unsupported_content_type(nova_model): + """Test that unsupported content types raise ValueError.""" + await nova_model.start() + + # Test with image content (unsupported) + tool_result_image: ToolResult = { + "toolUseId": "tool-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Nova Sonic"): + await nova_model.send(ToolResultEvent(tool_result_image)) + + # Test with document content (unsupported) + tool_result_doc: ToolResult = { + "toolUseId": "tool-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"doc_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Nova Sonic"): + await nova_model.send(ToolResultEvent(tool_result_doc)) + + # Test with mixed content (one unsupported) + tool_result_mixed: ToolResult = { + "toolUseId": "tool-777", + "status": "success", + "content": [{"text": "Valid text"}, {"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Nova Sonic"): + await nova_model.send(ToolResultEvent(tool_result_mixed)) + + await nova_model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index 32a5f257a..fb2bfb798 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -91,12 +91,12 @@ def test_model_initialization(api_key, model_name): """Test model initialization with various configurations.""" # Test default config model_default = BidiOpenAIRealtimeModel(api_key="test-key") - assert model_default.model == "gpt-realtime" + assert model_default.model_id == "gpt-realtime" assert model_default.api_key == "test-key" # Test with custom model model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - assert model_custom.model == model_name + assert model_custom.model_id == model_name assert model_custom.api_key == api_key # Test with organization and project @@ -308,7 +308,9 @@ async def test_connection_with_message_history(mock_websockets_connect, model): assert len(function_output_items) >= 1 func_output = function_output_items[0] assert func_output["item"]["call_id"] == "call-123" - assert "Sunny, 72°F" in func_output["item"]["output"] + # Content is now preserved as JSON array + output = json.loads(func_output["item"]["output"]) + assert output == [{"text": "Sunny, 72°F"}] await model.stop() @@ -399,7 +401,9 @@ async def test_send_all_content_types(mock_websockets_connect, model): item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "tool-123" - assert item.get("output") == "Result: 42" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"text": "Result: 42"}] # Test tool result with JSON content tool_result_json: ToolResult = { @@ -414,8 +418,9 @@ async def test_send_all_content_types(mock_websockets_connect, model): item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "tool-456" - # JSON should be serialized - assert json.loads(item.get("output")) == {"result": 42, "status": "ok"} + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"json": {"result": 42, "status": "ok"}}] # Test tool result with multiple content blocks tool_result_multi: ToolResult = { @@ -430,11 +435,9 @@ async def test_send_all_content_types(mock_websockets_connect, model): item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "tool-789" - # Multiple parts should be joined with newlines - output = item.get("output") - assert "Part 1" in output - assert '"data": "value"' in output or "'data': 'value'" in output - assert "Part 2" in output + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"text": "Part 1"}, {"json": {"data": "value"}}, {"text": "Part 2"}] # Test tool result with image content (should raise error) tool_result_image: ToolResult = { @@ -442,7 +445,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): "status": "success", "content": [{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], } - with pytest.raises(ValueError, match=r"Image content.*not supported"): + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): await model.send(ToolResultEvent(tool_result_image)) # Test tool result with document content (should raise error) @@ -451,7 +454,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): "status": "success", "content": [{"document": {"format": "pdf", "source": {"bytes": b"doc_data"}}}], } - with pytest.raises(ValueError, match=r"Document content.*not supported"): + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): await model.send(ToolResultEvent(tool_result_doc)) await model.stop() @@ -498,7 +501,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): # First event should be connection start (new TypedEvent format) assert first_event.get("type") == "bidi_connection_start" assert first_event.get("connection_id") == model._connection_id - assert first_event.get("model") == model.model + assert first_event.get("model") == model.model_id # Close to trigger session end await model.stop() @@ -790,7 +793,9 @@ async def test_tool_result_single_text_content(mock_websockets_connect, api_key) item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "call-123" - assert item.get("output") == "Simple text result" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"text": "Simple text result"}] await model.stop() @@ -818,10 +823,9 @@ async def test_tool_result_single_json_content(mock_websockets_connect, api_key) item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "call-456" - # JSON should be serialized as string - output = item.get("output") - parsed = json.loads(output) - assert parsed == {"temperature": 72, "condition": "sunny"} + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"json": {"temperature": 72, "condition": "sunny"}}] await model.stop() @@ -853,12 +857,13 @@ async def test_tool_result_multiple_content_blocks(mock_websockets_connect, api_ item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "call-789" - # Multiple parts should be joined with newlines - output = item.get("output") - assert "Weather data:" in output - assert "temp" in output - assert "humidity" in output - assert "Forecast: sunny" in output + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [ + {"text": "Weather data:"}, + {"json": {"temp": 72, "humidity": 65}}, + {"text": "Forecast: sunny"}, + ] await model.stop() @@ -876,7 +881,7 @@ async def test_tool_result_image_content_raises_error(mock_websockets_connect, a "content": [{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}}], } - with pytest.raises(ValueError, match=r"Image content.*not supported.*OpenAI Realtime API"): + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): await model.send(ToolResultEvent(tool_result)) await model.stop() @@ -895,7 +900,7 @@ async def test_tool_result_document_content_raises_error(mock_websockets_connect "content": [{"document": {"format": "pdf", "source": {"bytes": b"fake_pdf_data"}}}], } - with pytest.raises(ValueError, match=r"Document content.*not supported.*OpenAI Realtime API"): + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): await model.send(ToolResultEvent(tool_result)) await model.stop() From f03b73b7e8697be4891caa98103b521e34dc7978 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 26 Nov 2025 13:45:23 -0500 Subject: [PATCH 201/242] audio io - fix hanging on stop (#80) --- scripts/bidi/test_bidi_novasonic.py | 4 +- scripts/bidi/test_bidi_openai.py | 2 +- scripts/bidi/test_gemini_live.py | 2 +- src/strands/experimental/bidi/agent/agent.py | 7 +- src/strands/experimental/bidi/io/audio.py | 255 +++++++------ src/strands/experimental/bidi/io/text.py | 2 +- .../experimental/bidi/models/bidi_model.py | 1 - .../experimental/bidi/models/gemini_live.py | 6 +- .../experimental/bidi/models/novasonic.py | 6 +- .../experimental/bidi/models/openai.py | 8 +- src/strands/experimental/bidi/types/io.py | 3 +- src/strands/tools/_caller.py | 3 +- .../experimental/bidi/io/test_audio.py | 342 +++++++----------- .../experimental/bidi/models/test_openai.py | 28 +- 14 files changed, 316 insertions(+), 353 deletions(-) diff --git a/scripts/bidi/test_bidi_novasonic.py b/scripts/bidi/test_bidi_novasonic.py index 2ed62e455..baa39226f 100644 --- a/scripts/bidi/test_bidi_novasonic.py +++ b/scripts/bidi/test_bidi_novasonic.py @@ -46,7 +46,7 @@ async def play(context): channels=1, format=pyaudio.paInt16, output=True, - rate=24000, + rate=16000, frames_per_buffer=1024, ) @@ -216,7 +216,7 @@ async def main(duration=180): "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), - "connection": agent._agent_loop, + "connection": agent._loop, "duration": duration, "start_time": time.time(), "interrupted": False, diff --git a/scripts/bidi/test_bidi_openai.py b/scripts/bidi/test_bidi_openai.py index 807629feb..50d2d2f55 100644 --- a/scripts/bidi/test_bidi_openai.py +++ b/scripts/bidi/test_bidi_openai.py @@ -233,7 +233,7 @@ async def main(): # Create OpenAI model model = BidiOpenAIRealtimeModel( - model="gpt-4o-realtime-preview", + model_id="gpt-4o-realtime-preview", api_key=api_key, session={ "output_modalities": ["audio"], diff --git a/scripts/bidi/test_gemini_live.py b/scripts/bidi/test_gemini_live.py index 31dfc6af0..656ca6dcd 100644 --- a/scripts/bidi/test_gemini_live.py +++ b/scripts/bidi/test_gemini_live.py @@ -316,7 +316,7 @@ async def main(duration=180): "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), - "connection": agent._agent_loop, + "connection": agent._loop, "duration": duration, "start_time": time.time(), "interrupted": False, diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 16e5a4b90..600330bb4 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -25,8 +25,8 @@ from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry from ....tools.watcher import ToolWatcher -from ....types.content import ContentBlock, Message, Messages -from ....types.tools import AgentTool, ToolResult, ToolUse +from ....types.content import Messages +from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent from ...tools import ToolProvider from .._async import stop_all @@ -176,7 +176,6 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start a persistent bidirectional conversation connection. @@ -332,7 +331,7 @@ async def run( outputs=[audio_io.output(), text_io.output()], invocation_state={"user_id": "user_123"} ) - + # Using custom audio config: model = BidiNovaSonicModel(config={"audio": {"input_rate": 48000, "output_rate": 24000}}) audio_io = BidiAudioIO() diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index 744404882..b5404b749 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -3,13 +3,13 @@ Reads user audio from input device and sends agent audio to output device using PyAudio. If a user interrupts the agent, the output buffer is cleared to stop playback. -Audio configuration is provided by the model via agent.model.audio_config. +Audio configuration is provided by the model via agent.model.config["audio"]. """ import asyncio import base64 import logging -from collections import deque +import queue from typing import TYPE_CHECKING, Any import pyaudio @@ -23,35 +23,112 @@ logger = logging.getLogger(__name__) +class _BidiAudioBuffer: + """Buffer chunks of audio data between agent and PyAudio.""" + + _buffer: queue.Queue + _data: bytearray + + def __init__(self, size: int | None = None): + """Initialize buffer settings. + + Args: + size: Size of the buffer (default: unbounded). + """ + self._size = size or 0 + + def start(self) -> None: + """Setup buffer.""" + self._buffer = queue.Queue(self._size) + self._data = bytearray() + + def stop(self) -> None: + """Tear down buffer.""" + if hasattr(self, "_data"): + self._data.clear() + if hasattr(self, "_buffer"): + # Unblocking waited get calls by putting an empty chunk + # Note, Queue.shutdown exists but is a 3.13+ only feature + # We simulate shutdown with the below logic + self._buffer.put_nowait(b"") + self._buffer = queue.Queue(self._size) + + def put(self, chunk: bytes) -> None: + """Put data chunk into buffer. + + If full, removes the oldest chunk. + """ + if self._buffer.full(): + logger.debug("buffer is full | removing oldest chunk") + try: + self._buffer.get_nowait() + except queue.Empty: + logger.debug("buffer already empty") + pass + + self._buffer.put_nowait(chunk) + + def get(self, byte_count: int | None = None) -> bytes: + """Get the number of bytes specified from the buffer. + + Args: + byte_count: Number of bytes to get from buffer. + - If the number of bytes specified is not available, the return is padded with silence. + - If the number of bytes is not specified, get the first chunk put in the buffer. + + Returns: + Specified number of bytes. + """ + if not byte_count: + self._data.extend(self._buffer.get()) + byte_count = len(self._data) + + while len(self._data) < byte_count: + try: + self._data.extend(self._buffer.get_nowait()) + except queue.Empty: + break + + padding_bytes = b"\x00" * max(byte_count - len(self._data), 0) + self._data.extend(padding_bytes) + + data = self._data[:byte_count] + del self._data[:byte_count] + + return bytes(data) + + def clear(self) -> None: + """Clear the buffer.""" + while True: + try: + self._buffer.get_nowait() + except queue.Empty: + break + + class _BidiAudioInput(BidiInput): """Handle audio input from user. Attributes: _audio: PyAudio instance for audio system access. _stream: Audio input stream. + _buffer: Buffer for sharing audio data between agent and PyAudio. """ _audio: pyaudio.PyAudio _stream: pyaudio.Stream - # Audio device constants - _DEVICE_INDEX: int | None = None - _PYAUDIO_FORMAT: int = pyaudio.paInt16 - _FRAMES_PER_BUFFER: int = 512 + _BUFFER_SIZE = None + _DEVICE_INDEX = None + _FRAMES_PER_BUFFER = 512 def __init__(self, config: dict[str, Any]) -> None: - """Initialize audio input handler. - - Args: - config: Configuration dictionary with optional overrides: - - input_device_index: Specific input device to use - - input_pyaudio_format: PyAudio format (default: paInt16) - - input_frames_per_buffer: Number of frames per buffer - """ - # Initialize instance variables from config or class constants - self._device_index = config.get("input_device_index", self._DEVICE_INDEX) - self._pyaudio_format = config.get("input_pyaudio_format", self._PYAUDIO_FORMAT) - self._frames_per_buffer = config.get("input_frames_per_buffer", self._FRAMES_PER_BUFFER) + """Extract configs.""" + self._buffer_size = config.get("input_buffer_size", _BidiAudioInput._BUFFER_SIZE) + self._device_index = config.get("input_device_index", _BidiAudioInput._DEVICE_INDEX) + self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER) + + self._buffer = _BidiAudioBuffer(self._buffer_size) async def start(self, agent: "BidiAgent") -> None: """Start input stream. @@ -59,51 +136,55 @@ async def start(self, agent: "BidiAgent") -> None: Args: agent: The BidiAgent instance, providing access to model configuration. """ - # Get audio parameters from model config - self._rate = agent.model.config["audio"]["input_rate"] + logger.debug("starting audio input stream") + self._channels = agent.model.config["audio"]["channels"] - self._format = agent.model.config["audio"].get("format", "pcm") # Encoding format for events + self._format = agent.model.config["audio"]["format"] + self._rate = agent.model.config["audio"]["input_rate"] - logger.debug( - "rate=<%d>, channels=<%d>, device_index=<%s> | starting audio input stream", - self._rate, - self._channels, - self._device_index, - ) + self._buffer.start() self._audio = pyaudio.PyAudio() self._stream = self._audio.open( channels=self._channels, - format=self._pyaudio_format, + format=pyaudio.paInt16, frames_per_buffer=self._frames_per_buffer, input=True, input_device_index=self._device_index, rate=self._rate, + stream_callback=self._callback, ) - logger.info("rate=<%d>, channels=<%d> | audio input stream started", self._rate, self._channels) + + logger.debug("audio input stream started") async def stop(self) -> None: """Stop input stream.""" logger.debug("stopping audio input stream") - # TODO: Provide time for streaming thread to exit cleanly to prevent conflicts with the Nova threads. - # See if we can remove after properly handling cancellation for agent. - await asyncio.sleep(0.1) - self._stream.close() - self._audio.terminate() + if hasattr(self, "_stream"): + self._stream.close() + if hasattr(self, "_audio"): + self._audio.terminate() + if hasattr(self, "_buffer"): + self._buffer.stop() logger.debug("audio input stream stopped") async def __call__(self) -> BidiAudioInputEvent: """Read audio from input stream.""" - audio_bytes = await asyncio.to_thread(self._stream.read, self._frames_per_buffer, exception_on_overflow=False) + data = await asyncio.to_thread(self._buffer.get) return BidiAudioInputEvent( - audio=base64.b64encode(audio_bytes).decode("utf-8"), + audio=base64.b64encode(data).decode("utf-8"), channels=self._channels, format=self._format, sample_rate=self._rate, ) + def _callback(self, in_data: bytes, *_: Any) -> tuple[None, Any]: + """Callback to receive audio data from PyAudio.""" + self._buffer.put(in_data) + return (None, pyaudio.paContinue) + class _BidiAudioOutput(BidiOutput): """Handle audio output from bidi agent. @@ -111,38 +192,23 @@ class _BidiAudioOutput(BidiOutput): Attributes: _audio: PyAudio instance for audio system access. _stream: Audio output stream. - _buffer: Deque buffer for queuing audio data. - _buffer_event: Event to signal when buffer has data. - _output_task: Background task for processing audio output. + _buffer: Buffer for sharing audio data between agent and PyAudio. """ _audio: pyaudio.PyAudio _stream: pyaudio.Stream - _buffer: deque - _buffer_event: asyncio.Event - _output_task: asyncio.Task - # Audio device constants - _BUFFER_SIZE: int | None = None - _DEVICE_INDEX: int | None = None - _PYAUDIO_FORMAT: int = pyaudio.paInt16 - _FRAMES_PER_BUFFER: int = 512 + _BUFFER_SIZE = None + _DEVICE_INDEX = None + _FRAMES_PER_BUFFER = 512 def __init__(self, config: dict[str, Any]) -> None: - """Initialize audio output handler. - - Args: - config: Configuration dictionary with optional overrides: - - output_device_index: Specific output device to use - - output_pyaudio_format: PyAudio format (default: paInt16) - - output_frames_per_buffer: Number of frames per buffer - - output_buffer_size: Maximum buffer size (None = unlimited) - """ - # Initialize instance variables from config or class constants - self._buffer_size = config.get("output_buffer_size", self._BUFFER_SIZE) - self._device_index = config.get("output_device_index", self._DEVICE_INDEX) - self._pyaudio_format = config.get("output_pyaudio_format", self._PYAUDIO_FORMAT) - self._frames_per_buffer = config.get("output_frames_per_buffer", self._FRAMES_PER_BUFFER) + """Extract configs.""" + self._buffer_size = config.get("output_buffer_size", _BidiAudioOutput._BUFFER_SIZE) + self._device_index = config.get("output_device_index", _BidiAudioOutput._DEVICE_INDEX) + self._frames_per_buffer = config.get("output_frames_per_buffer", _BidiAudioOutput._FRAMES_PER_BUFFER) + + self._buffer = _BidiAudioBuffer(self._buffer_size) async def start(self, agent: "BidiAgent") -> None: """Start output stream. @@ -150,66 +216,54 @@ async def start(self, agent: "BidiAgent") -> None: Args: agent: The BidiAgent instance, providing access to model configuration. """ - # Get audio parameters from model config - self._rate = agent.model.config["audio"]["output_rate"] + logger.debug("starting audio output stream") + self._channels = agent.model.config["audio"]["channels"] + self._rate = agent.model.config["audio"]["output_rate"] - logger.debug( - "rate=<%d>, channels=<%d> | starting audio output stream", - self._rate, - self._channels, - ) + self._buffer.start() self._audio = pyaudio.PyAudio() self._stream = self._audio.open( channels=self._channels, - format=self._pyaudio_format, + format=pyaudio.paInt16, frames_per_buffer=self._frames_per_buffer, output=True, output_device_index=self._device_index, rate=self._rate, + stream_callback=self._callback, ) - self._buffer = deque(maxlen=self._buffer_size) - self._buffer_event = asyncio.Event() - self._output_task = asyncio.create_task(self._output()) - logger.info("rate=<%d>, channels=<%d> | audio output stream started", self._rate, self._channels) + + logger.debug("audio output stream started") async def stop(self) -> None: """Stop output stream.""" logger.debug("stopping audio output stream") - self._buffer.clear() - self._buffer.append(None) - self._buffer_event.set() - await self._output_task - self._stream.close() - self._audio.terminate() + if hasattr(self, "_stream"): + self._stream.close() + if hasattr(self, "_audio"): + self._audio.terminate() + if hasattr(self, "_buffer"): + self._buffer.stop() logger.debug("audio output stream stopped") async def __call__(self, event: BidiOutputEvent) -> None: - """Handle audio events with direct stream writing.""" + """Send audio to output stream.""" if isinstance(event, BidiAudioStreamEvent): - audio_bytes = base64.b64decode(event["audio"]) - self._buffer.append(audio_bytes) - self._buffer_event.set() - logger.debug("audio_bytes=<%d> | audio chunk buffered for playback", len(audio_bytes)) + data = base64.b64decode(event["audio"]) + self._buffer.put(data) + logger.debug("audio_bytes=<%d> | audio chunk buffered for playback", len(data)) elif isinstance(event, BidiInterruptionEvent): logger.debug("reason=<%s> | clearing audio buffer due to interruption", event["reason"]) self._buffer.clear() - self._buffer_event.clear() - - async def _output(self) -> None: - while True: - await self._buffer_event.wait() - self._buffer_event.clear() - - while self._buffer: - audio_bytes = self._buffer.popleft() - if not audio_bytes: - return - await asyncio.to_thread(self._stream.write, audio_bytes) + def _callback(self, _in_data: None, frame_count: int, *_: Any) -> tuple[bytes, Any]: + """Callback to send audio data to PyAudio.""" + byte_count = frame_count * pyaudio.get_sample_size(pyaudio.paInt16) + data = self._buffer.get(byte_count) + return (data, pyaudio.paContinue) class BidiAudioIO: @@ -217,16 +271,15 @@ class BidiAudioIO: def __init__(self, **config: Any) -> None: """Initialize audio devices. - + Args: **config: Optional device configuration: + - input_buffer_size (int): Maximum input buffer size (default: None) - input_device_index (int): Specific input device (default: None = system default) - - output_device_index (int): Specific output device (default: None = system default) - - input_pyaudio_format (int): PyAudio format for input (default: pyaudio.paInt16) - - output_pyaudio_format (int): PyAudio format for output (default: pyaudio.paInt16) - input_frames_per_buffer (int): Input buffer size (default: 512) + - output_buffer_size (int): Maximum output buffer size (default: None) + - output_device_index (int): Specific output device (default: None = system default) - output_frames_per_buffer (int): Output buffer size (default: 512) - - output_buffer_size (int | None): Max output queue size (default: None = unlimited) """ self._config = config diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 7eadcb341..e123de766 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -2,8 +2,8 @@ import asyncio import logging -from typing import TYPE_CHECKING import sys +from typing import TYPE_CHECKING from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTextInputEvent, BidiTranscriptStreamEvent from ..types.io import BidiInput, BidiOutput diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 253a5d440..1fb765bd8 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -22,7 +22,6 @@ BidiInputEvent, BidiOutputEvent, ) -from ..types.bidi_model import AudioConfig logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index b9f88b717..88484275d 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -24,6 +24,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all +from ..types.bidi_model import AudioConfig from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -39,7 +40,6 @@ ModalityUsage, SampleRate, ) -from ..types.bidi_model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) @@ -116,7 +116,9 @@ def __init__( if self.live_config and "speech_config" in self.live_config: speech_config = self.live_config["speech_config"] if isinstance(speech_config, dict): - live_config_voice = speech_config.get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") + live_config_voice = ( + speech_config.get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") + ) # Define default audio configuration default_audio_config: AudioConfig = { diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 818f627f9..4c89b9cf3 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, AsyncGenerator, cast, Literal +from typing import Any, AsyncGenerator, Literal, cast import boto3 from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput @@ -34,6 +34,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all +from ..types.bidi_model import AudioConfig from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -48,7 +49,6 @@ BidiUsageEvent, SampleRate, ) -from ..types.bidi_model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) @@ -465,7 +465,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N audio=audio_content, format="pcm", sample_rate=cast(SampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), - channels=1, + channels=channels, ) # Handle text output (transcripts) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 6fc4f458d..c70536174 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -8,7 +8,7 @@ import logging import os import uuid -from typing import Any, AsyncGenerator, cast, Literal +from typing import Any, AsyncGenerator, Literal, cast import websockets from websockets import ClientConnection @@ -17,6 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all +from ..types.bidi_model import AudioConfig from ..types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -34,7 +35,6 @@ SampleRate, StopReason, ) -from ..types.bidi_model import AudioConfig from .bidi_model import BidiModel logger = logging.getLogger(__name__) @@ -88,7 +88,7 @@ def __init__( """Initialize OpenAI Realtime bidirectional model. Args: - model: OpenAI model identifier (default: gpt-realtime). + model_id: OpenAI model identifier (default: gpt-realtime). api_key: OpenAI API key for authentication. organization: OpenAI organization ID for API requests. project: OpenAI project ID for API requests. @@ -338,7 +338,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput audio=openai_event["delta"], format="pcm", sample_rate=cast(SampleRate, AUDIO_FORMAT["rate"]), - channels=1, + channels=channels, ) ] diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py index 7125eb5ef..bdb7d9c9d 100644 --- a/src/strands/experimental/bidi/types/io.py +++ b/src/strands/experimental/bidi/types/io.py @@ -5,8 +5,7 @@ by separating input and output concerns into independent callables. """ -from typing import TYPE_CHECKING, Awaitable, Protocol -from typing import Awaitable, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Awaitable, Protocol, runtime_checkable from ..types.events import BidiInputEvent, BidiOutputEvent diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 94fdcfec4..3ab576947 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -19,12 +19,13 @@ if TYPE_CHECKING: from ..agent import Agent + from ..experimental.bidi.agent import BidiAgent class _ToolCaller: """Call tool as a function.""" - def __init__(self, agent: "Agent") -> None: + def __init__(self, agent: "Agent | BidiAgent") -> None: """Initialize instance. Args: diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py index d22ce5d39..459faa78a 100644 --- a/tests/strands/experimental/bidi/io/test_audio.py +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -1,265 +1,175 @@ -import asyncio import base64 import unittest.mock +import pyaudio import pytest +import pytest_asyncio -from strands.experimental.bidi.io import BidiAudioIO -from strands.experimental.bidi.types.events import BidiAudioInputEvent, BidiAudioStreamEvent +from strands.experimental.bidi.io.audio import BidiAudioIO, _BidiAudioBuffer +from strands.experimental.bidi.types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent @pytest.fixture -def py_audio(): - with unittest.mock.patch("strands.experimental.bidi.io.audio.pyaudio") as mock: - yield mock.PyAudio() +def audio_buffer(): + buffer = _BidiAudioBuffer(size=1) + buffer.start() + yield buffer + buffer.stop() @pytest.fixture -def audio_io(): - return BidiAudioIO() - - -@pytest.fixture -def mock_agent(): - """Create a mock agent with model that has default audio_config.""" - agent = unittest.mock.MagicMock() - agent.model.audio_config = { - "input_rate": 16000, - "output_rate": 16000, - "channels": 1, - "format": "pcm", - "voice": "matthew", +def agent(): + mock = unittest.mock.MagicMock() + mock.model.config = { + "audio": { + "input_rate": 24000, + "output_rate": 16000, + "channels": 2, + "format": "test-format", + "voice": "test-voice", + }, } - return agent + return mock @pytest.fixture -def mock_agent_custom_config(): - """Create a mock agent with custom audio_config.""" - agent = unittest.mock.MagicMock() - agent.model.audio_config = { - "input_rate": 48000, - "output_rate": 24000, - "channels": 2, - "format": "pcm", - "voice": "alloy", - } - return agent +def py_audio(): + with unittest.mock.patch("strands.experimental.bidi.io.audio.pyaudio.PyAudio") as mock: + yield mock.return_value @pytest.fixture -def audio_input(audio_io): - return audio_io.input() - +def config(): + return { + "input_buffer_size": 1, + "input_device_index": 1, + "input_frames_per_buffer": 1024, + "output_buffer_size": 2, + "output_device_index": 2, + "output_frames_per_buffer": 2048, + } @pytest.fixture -def audio_output(audio_io): - return audio_io.output() - - -@pytest.mark.asyncio -async def test_bidi_audio_io_input(py_audio, audio_input, mock_agent): - """Test basic audio input functionality.""" - microphone = unittest.mock.Mock() - microphone.read.return_value = b"test-audio" - - py_audio.open.return_value = microphone - - await audio_input.start(mock_agent) - tru_event = await audio_input() - await audio_input.stop() - - exp_event = BidiAudioInputEvent( - audio=base64.b64encode(b"test-audio").decode("utf-8"), - channels=1, - format="pcm", - sample_rate=16000, - ) - assert tru_event == exp_event - - microphone.read.assert_called_once_with(512, exception_on_overflow=False) - - -@pytest.mark.asyncio -async def test_bidi_audio_io_output(py_audio, audio_output, mock_agent): - """Test basic audio output functionality.""" - write_future = asyncio.Future() - write_event = asyncio.Event() - - def write(data): - write_future.set_result(data) - write_event.set() - - speaker = unittest.mock.Mock() - speaker.write.side_effect = write - - py_audio.open.return_value = speaker - - await audio_output.start(mock_agent) - - audio_event = BidiAudioStreamEvent( - audio=base64.b64encode(b"test-audio").decode("utf-8"), - channels=1, - format="pcm", - sample_rate=1600, - ) - await audio_output(audio_event) - await write_event.wait() - - await audio_output.stop() - - speaker.write.assert_called_once_with(write_future.result()) - - -# Audio Configuration Tests - +def audio_io(py_audio, config): + _ = py_audio + return BidiAudioIO(**config) -@pytest.mark.asyncio -async def test_audio_input_uses_model_config(py_audio, audio_io, mock_agent): - """Test that audio input uses model's audio_config.""" - audio_input = audio_io.input() - microphone = unittest.mock.Mock() - microphone.read.return_value = b"test-audio" - py_audio.open.return_value = microphone +@pytest_asyncio.fixture +async def audio_input(audio_io, agent): + input_ = audio_io.input() + await input_.start(agent) + yield input_ + await input_.stop() - await audio_input.start(mock_agent) - # Model config should be used - py_audio.open.assert_called_once() - call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 16000 # From mock_agent.model.audio_config - assert call_kwargs["channels"] == 1 # From mock_agent.model.audio_config +@pytest_asyncio.fixture +async def audio_output(audio_io, agent): + output = audio_io.output() + await output.start(agent) + yield output + await output.stop() - await audio_input.stop() +def test_bidi_audio_buffer_put(audio_buffer): + audio_buffer.put(b"test-chunk") -@pytest.mark.asyncio -async def test_audio_input_uses_custom_model_config(py_audio, audio_io, mock_agent_custom_config): - """Test that audio input uses custom model audio_config.""" - audio_input = audio_io.input() + tru_chunk = audio_buffer.get() + exp_chunk = b"test-chunk" + assert tru_chunk == exp_chunk - microphone = unittest.mock.Mock() - microphone.read.return_value = b"test-audio" - py_audio.open.return_value = microphone - await audio_input.start(mock_agent_custom_config) +def test_bidi_audio_buffer_put_full(audio_buffer): + audio_buffer.put(b"test-chunk-1") + audio_buffer.put(b"test-chunk-2") - # Custom model config should be used - py_audio.open.assert_called_once() - call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 48000 # From custom config - assert call_kwargs["channels"] == 2 # From custom config + tru_chunk = audio_buffer.get() + exp_chunk = b"test-chunk-2" + assert tru_chunk == exp_chunk - await audio_input.stop() +def test_bidi_audio_buffer_get_padding(audio_buffer): + audio_buffer.put(b"test-chunk") -@pytest.mark.asyncio -async def test_audio_output_uses_model_config(py_audio, audio_io, mock_agent): - """Test that audio output uses model's audio_config.""" - audio_output = audio_io.output() + tru_chunk = audio_buffer.get(11) + exp_chunk = b"test-chunk\x00" + assert tru_chunk == exp_chunk - speaker = unittest.mock.Mock() - py_audio.open.return_value = speaker - await audio_output.start(mock_agent) +def test_bidi_audio_buffer_clear(audio_buffer): + audio_buffer.put(b"test-chunk") + audio_buffer.clear() - # Model config should be used - py_audio.open.assert_called_once() - call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 16000 # From mock_agent.model.audio_config - assert call_kwargs["channels"] == 1 # From mock_agent.model.audio_config - - await audio_output.stop() + tru_byte = audio_buffer.get(1) + exp_byte = b"\x00" + assert tru_byte == exp_byte @pytest.mark.asyncio -async def test_audio_output_uses_custom_model_config(py_audio, audio_io, mock_agent_custom_config): - """Test that audio output uses custom model audio_config.""" - audio_output = audio_io.output() - - speaker = unittest.mock.Mock() - py_audio.open.return_value = speaker - - await audio_output.start(mock_agent_custom_config) - - # Custom model config should be used - py_audio.open.assert_called_once() - call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["rate"] == 24000 # From custom config - assert call_kwargs["channels"] == 2 # From custom config - - await audio_output.stop() - - -# Device Configuration Tests +async def test_bidi_audio_io_input(audio_input): + audio_input._callback(b"test-audio") + tru_event = await audio_input() + exp_event = BidiAudioInputEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=2, + format="test-format", + sample_rate=24000, + ) + assert tru_event == exp_event -@pytest.mark.asyncio -async def test_audio_input_respects_user_device_config(py_audio, mock_agent): - """Test that user-provided device config overrides defaults.""" - audio_io = BidiAudioIO(input_device_index=5, input_frames_per_buffer=1024) - audio_input = audio_io.input() - - microphone = unittest.mock.Mock() - microphone.read.return_value = b"test-audio" - py_audio.open.return_value = microphone - - await audio_input.start(mock_agent) - - # User device config should be used - py_audio.open.assert_called_once() - call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["input_device_index"] == 5 # User config - assert call_kwargs["frames_per_buffer"] == 1024 # User config - # Model config still used for audio parameters - assert call_kwargs["rate"] == 16000 # From model - assert call_kwargs["channels"] == 1 # From model - await audio_input.stop() +def test_bidi_audio_io_input_configs(py_audio, audio_input): + py_audio.open.assert_called_once_with( + channels=2, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + input_device_index=1, + rate=24000, + stream_callback=audio_input._callback, + ) @pytest.mark.asyncio -async def test_audio_output_respects_user_device_config(py_audio, mock_agent): - """Test that user-provided device config overrides defaults.""" - audio_io = BidiAudioIO(output_device_index=3, output_frames_per_buffer=2048, output_buffer_size=50) - audio_output = audio_io.output() - - speaker = unittest.mock.Mock() - py_audio.open.return_value = speaker - - await audio_output.start(mock_agent) - - # User device config should be used - py_audio.open.assert_called_once() - call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["output_device_index"] == 3 # User config - assert call_kwargs["frames_per_buffer"] == 2048 # User config - # Model config still used for audio parameters - assert call_kwargs["rate"] == 16000 # From model - assert call_kwargs["channels"] == 1 # From model - # Buffer size should be set - assert audio_output._buffer_size == 50 # User config +async def test_bidi_audio_io_output(audio_output): + audio_event = BidiAudioStreamEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=2, + format="test-format", + sample_rate=16000, + ) + await audio_output(audio_event) - await audio_output.stop() + tru_data, _ = audio_output._callback(None, frame_count=4) + exp_data = b"test-aud" + assert tru_data == exp_data @pytest.mark.asyncio -async def test_audio_io_uses_defaults_when_no_config(py_audio, mock_agent): - """Test that defaults are used when no config provided.""" - audio_io = BidiAudioIO() # No config - audio_input = audio_io.input() - - microphone = unittest.mock.Mock() - microphone.read.return_value = b"test-audio" - py_audio.open.return_value = microphone - - await audio_input.start(mock_agent) - - # Defaults should be used - py_audio.open.assert_called_once() - call_kwargs = py_audio.open.call_args.kwargs - assert call_kwargs["input_device_index"] is None # Default - assert call_kwargs["frames_per_buffer"] == 512 # Default - - await audio_input.stop() +async def test_bidi_audio_io_output_interrupt(audio_output): + audio_event = BidiAudioStreamEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=2, + format="test-format", + sample_rate=16000, + ) + await audio_output(audio_event) + interrupt_event = BidiInterruptionEvent(reason="user_speech") + await audio_output(interrupt_event) + + tru_data, _ = audio_output._callback(None, frame_count=1) + exp_data = b"\x00\x00" + assert tru_data == exp_data + + +def test_bidi_audio_io_output_configs(py_audio, audio_output): + py_audio.open.assert_called_once_with( + channels=2, + format=pyaudio.paInt16, + frames_per_buffer=2048, + output=True, + output_device_index=2, + rate=16000, + stream_callback=audio_output._callback, + ) diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index ee3dd45c9..778c7dcbd 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -91,16 +91,16 @@ def test_model_initialization(api_key, model_name): """Test model initialization with various configurations.""" # Test default config model_default = BidiOpenAIRealtimeModel(api_key="test-key") - assert model_default.model == "gpt-realtime" + assert model_default.model_id == "gpt-realtime" assert model_default.api_key == "test-key" # Test with custom model - model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) - assert model_custom.model == model_name + model_custom = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + assert model_custom.model_id == model_name assert model_custom.api_key == api_key # Test with organization and project - model_org = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, organization="org-123", project="proj-456") + model_org = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, organization="org-123", project="proj-456") assert model_org.organization == "org-123" assert model_org.project == "proj-456" @@ -115,7 +115,7 @@ def test_model_initialization(api_key, model_name): def test_audio_config_defaults(api_key, model_name): """Test default audio configuration.""" - model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) assert model.config["audio"]["input_rate"] == 24000 assert model.config["audio"]["output_rate"] == 24000 @@ -127,7 +127,7 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, config=config) + model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, config=config) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -150,7 +150,7 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, config=config) + model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, config=config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -165,7 +165,7 @@ def test_audio_config_voice_priority(api_key, model_name): config = {"audio": {"voice": "nova"}} model = BidiOpenAIRealtimeModel( - model=model_name, api_key=api_key, session_config=session_config, config=config + model_id=model_name, api_key=api_key, session_config=session_config, config=config ) # Build config and verify config audio voice takes precedence @@ -177,7 +177,7 @@ def test_audio_config_extracts_voice_from_session_config(api_key, model_name): """Test that voice is extracted from session_config when config audio not provided.""" session_config = {"audio": {"output": {"voice": "fable"}}} - model = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key, session_config=session_config) + model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, session_config=session_config) # Should extract voice from session_config assert model.config["audio"]["voice"] == "fable" @@ -257,7 +257,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam mock_connect, mock_ws = mock_websockets_connect # Test connection error - model1 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + model1 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) mock_connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.start() @@ -269,18 +269,18 @@ async def async_connect(*args, **kwargs): mock_connect.side_effect = async_connect # Test double connection - model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + model2 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) await model2.start() with pytest.raises(RuntimeError, match=r"call stop before starting again"): await model2.start() await model2.stop() # Test close when not connected - model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + model3 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) await model3.stop() # Should not raise # Test close error - model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + model4 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) await model4.start() mock_ws.close.side_effect = Exception("Close failed") with pytest.raises(ExceptionGroup): @@ -382,7 +382,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): # First event should be connection start (new TypedEvent format) assert first_event.get("type") == "bidi_connection_start" assert first_event.get("connection_id") == model._connection_id - assert first_event.get("model") == model.model + assert first_event.get("model") == model.model_id # Close to trigger session end await model.stop() From a3bc749200a3d8a5cf8e1eb153681d96d8aab27e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 26 Nov 2025 17:44:31 -0500 Subject: [PATCH 202/242] model provider consistency improvmeents --- .../experimental/bidi/models/gemini_live.py | 58 +++++++------ .../experimental/bidi/models/novasonic.py | 5 +- .../experimental/bidi/models/openai.py | 86 +++++++++++-------- .../bidi/models/test_gemini_live.py | 82 +++++++++++++++--- .../bidi/models/test_novasonic.py | 44 ++++++++++ .../experimental/bidi/models/test_openai.py | 56 ++++++------ 6 files changed, 228 insertions(+), 103 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index c1a3e29eb..16949c5fd 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -63,7 +63,7 @@ def __init__( model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", api_key: str | None = None, config: dict[str, Any] | None = None, - live_config: dict[str, Any] | None = None, + provider_config: dict[str, Any] | None = None, **kwargs: Any, ): """Initialize Gemini Live API bidirectional model. @@ -73,7 +73,7 @@ def __init__( api_key: Google AI API key for authentication. config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. If not provided or if "audio" key is missing, uses Gemini Live API's default audio configuration. - live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). + provider_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). **kwargs: Reserved for future parameters. """ # Model configuration @@ -88,10 +88,10 @@ def __init__( } # Merge user config with defaults (user config takes precedence) - if live_config: - default_config.update(live_config) + if provider_config: + default_config.update(provider_config) - self.live_config = default_config + self.provider_config = default_config # Create Gemini client with proper API version client_kwargs: dict[str, Any] = {} @@ -111,14 +111,8 @@ def __init__( # Extract audio config from config dict if provided user_audio_config = config.get("audio", {}) if config else {} - # Extract voice from live_config if provided - live_config_voice = None - if self.live_config and "speech_config" in self.live_config: - speech_config = self.live_config["speech_config"] - if isinstance(speech_config, dict): - live_config_voice = ( - speech_config.get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") - ) + # Extract voice from provider_config if provided + provider_voice = self._extract_voice_from_provider_config() # Define default audio configuration default_audio_config: AudioConfig = { @@ -128,9 +122,9 @@ def __init__( "format": "pcm", } - # Add voice to defaults if configured in live_config - if live_config_voice: - default_audio_config["voice"] = live_config_voice + # Add voice to defaults if configured in provider_config + if provider_voice: + default_audio_config["voice"] = provider_voice # Merge user config with defaults (user values take precedence) merged_audio_config = cast(AudioConfig, {**default_audio_config, **user_audio_config}) @@ -143,6 +137,16 @@ def __init__( else: logger.debug("audio_config | using default Gemini Live audio configuration") + def _extract_voice_from_provider_config(self) -> str | None: + """Extract voice from provider-specific config.""" + if "speech_config" in self.provider_config: + speech_config = self.provider_config["speech_config"] + if isinstance(speech_config, dict): + return (speech_config.get("voice_config", {}) + .get("prebuilt_voice_config", {}) + .get("voice_name")) + return None + async def start( self, system_prompt: str | None = None, @@ -203,7 +207,7 @@ async def _send_message_history(self, messages: Messages) -> None: async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive Gemini Live API events and convert to provider-agnostic format.""" if not self._connection_id: - raise RuntimeError("model not started | call start before receiving") + raise RuntimeError("model not started | call start before sending/receiving") yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) @@ -274,8 +278,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut BidiAudioStreamEvent( audio=audio_b64, format="pcm", - sample_rate=cast(SampleRate, GEMINI_OUTPUT_SAMPLE_RATE), - channels=cast(Channel, GEMINI_CHANNELS), + sample_rate=cast(SampleRate, self.config["audio"]["output_rate"]), + channels=cast(Channel, self.config["audio"]["channels"]), ) ] @@ -380,7 +384,7 @@ async def send( ValueError: If content type not supported (e.g., image content). """ if not self._connection_id: - raise RuntimeError("model not started | call start before sending") + raise RuntimeError("model not started | call start before sending/receiving") if isinstance(content, BidiTextInputEvent): await self._send_text_content(content.text) @@ -405,7 +409,8 @@ async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: audio_bytes = base64.b64decode(audio_input.audio) # Create audio blob for the SDK - audio_blob = genai_types.Blob(data=audio_bytes, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}") + mime_type = f"audio/pcm;rate={self.config['audio']['input_rate']}" + audio_blob = genai_types.Blob(data=audio_bytes, mime_type=mime_type) # Send real-time audio input - this automatically handles VAD and interruption await self._live_session.send_realtime_input(audio=audio_blob) @@ -440,7 +445,8 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" not in block and "json" not in block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by Gemini Live API" + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by Gemini Live API" ) # Optimize for single content item - unwrap the array @@ -479,13 +485,13 @@ def _build_live_config( ) -> dict[str, Any]: """Build LiveConnectConfig for the official SDK. - Simply passes through all config parameters from live_config, allowing users + Simply passes through all config parameters from provider_config, allowing users to configure any Gemini Live API parameter directly. """ - # Start with user-provided live_config + # Start with user-provided provider_config config_dict: dict[str, Any] = {} - if self.live_config: - config_dict.update(self.live_config) + if self.provider_config: + config_dict.update(self.provider_config) # Override with any kwargs from start() config_dict.update(kwargs) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 14c0d6f76..d5e238f8e 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -399,7 +399,8 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" not in block and "json" not in block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by Nova Sonic" + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by Nova Sonic" ) # Optimize for single content item - unwrap the array @@ -475,7 +476,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N return BidiAudioStreamEvent( audio=audio_content, format="pcm", - sample_rate=cast(SampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), + sample_rate=cast(SampleRate, self.config["audio"]["output_rate"]), channels=channels, ) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index f64f65f64..d80cd76b0 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -42,8 +42,7 @@ # OpenAI Realtime API configuration OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" DEFAULT_MODEL = "gpt-realtime" - -AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} +DEFAULT_SAMPLE_RATE = 24000 DEFAULT_SESSION_CONFIG = { "type": "realtime", @@ -51,7 +50,7 @@ "output_modalities": ["audio"], "audio": { "input": { - "format": AUDIO_FORMAT, + "format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, "transcription": {"model": "gpt-4o-transcribe"}, "turn_detection": { "type": "server_vad", @@ -60,7 +59,7 @@ "silence_duration_ms": 500, }, }, - "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + "output": {"format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, "voice": "alloy"}, }, } @@ -79,10 +78,8 @@ def __init__( self, model_id: str = DEFAULT_MODEL, api_key: str | None = None, - organization: str | None = None, - project: str | None = None, - session_config: dict[str, Any] | None = None, config: dict[str, Any] | None = None, + provider_config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Initialize OpenAI Realtime bidirectional model. @@ -90,20 +87,22 @@ def __init__( Args: model_id: OpenAI model identifier (default: gpt-realtime). api_key: OpenAI API key for authentication. - organization: OpenAI organization ID for API requests. - project: OpenAI project ID for API requests. - session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). + provider_config: Session configuration parameters (e.g., voice, turn_detection, modalities). config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. If not provided or if "audio" key is missing, uses OpenAI Realtime API's default audio configuration. **kwargs: Reserved for future parameters. + + Environment Variables: + OPENAI_API_KEY: API key (if not provided as parameter) + OPENAI_ORGANIZATION: Organization ID for billing/organization + OPENAI_PROJECT: Project ID for billing/organization """ # Model configuration self.model_id = model_id self.api_key = api_key - self.organization = organization - self.project = project - self.session_config = session_config or {} + self.provider_config = provider_config or {} + # Read from environment variables with same pattern as API key if not self.api_key: self.api_key = os.getenv("OPENAI_API_KEY") if not self.api_key: @@ -111,6 +110,10 @@ def __init__( "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." ) + # Read organization and project from environment (no parameters needed) + self.organization = os.getenv("OPENAI_ORGANIZATION") + self.project = os.getenv("OPENAI_PROJECT") + # Connection state (initialized in start()) self._connection_id: str | None = None @@ -121,22 +124,16 @@ def __init__( # Extract audio config from config dict if provided user_audio_config = config.get("audio", {}) if config else {} - # Extract voice from session_config if provided - session_config_voice = "alloy" - if self.session_config and "audio" in self.session_config: - audio_settings = self.session_config["audio"] - if isinstance(audio_settings, dict) and "output" in audio_settings: - output_settings = audio_settings["output"] - if isinstance(output_settings, dict): - session_config_voice = output_settings.get("voice", "alloy") + # Extract voice from provider_config if provided + provider_voice = self._extract_voice_from_provider_config() # Define default audio configuration default_audio_config: AudioConfig = { - "input_rate": cast(int, AUDIO_FORMAT["rate"]), - "output_rate": cast(int, AUDIO_FORMAT["rate"]), + "input_rate": DEFAULT_SAMPLE_RATE, + "output_rate": DEFAULT_SAMPLE_RATE, "channels": 1, "format": "pcm", - "voice": session_config_voice, + "voice": provider_voice or "alloy", } # Merge user config with defaults (user values take precedence) @@ -150,6 +147,16 @@ def __init__( else: logger.debug("audio_config | using default OpenAI Realtime audio configuration") + def _extract_voice_from_provider_config(self) -> str | None: + """Extract voice from provider-specific config.""" + if "audio" in self.provider_config: + audio_settings = self.provider_config["audio"] + if isinstance(audio_settings, dict) and "output" in audio_settings: + output_settings = audio_settings["output"] + if isinstance(output_settings, dict): + return output_settings.get("voice") + return None + async def start( self, system_prompt: str | None = None, @@ -249,15 +256,23 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] "turn_detection", } - for key, value in self.session_config.items(): + for key, value in self.provider_config.items(): if key in supported_params: config[key] = value else: logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) - # Override voice with config value if present (config takes precedence) + # Override audio configuration with config values if present (config takes precedence) if "voice" in self.config["audio"]: config.setdefault("audio", {}).setdefault("output", {})["voice"] = self.config["audio"]["voice"] + + if "input_rate" in self.config["audio"]: + input_config = config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {}) + input_config["rate"] = self.config["audio"]["input_rate"] + + if "output_rate" in self.config["audio"]: + output_config = config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {}) + output_config["rate"] = self.config["audio"]["output_rate"] return config @@ -356,7 +371,8 @@ async def _add_conversation_history(self, messages: Messages) -> None: if "text" not in result_block and "json" not in result_block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | Content type not supported by OpenAI Realtime API" + f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" ) # Preserve structure by JSON-dumping the entire content array @@ -392,7 +408,7 @@ async def _add_conversation_history(self, messages: Messages) -> None: async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive OpenAI events and convert to Strands TypedEvent format.""" if not self._connection_id: - raise RuntimeError("model not started | call start before receiving") + raise RuntimeError("model not started | call start before sending/receiving") yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) @@ -415,13 +431,8 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput # Audio output elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI - # Get sample rate from user's session config if provided, otherwise use default - sample_rate = ( - self.session_config.get("audio", {}) - .get("output", {}) - .get("format", {}) - .get("rate", AUDIO_FORMAT["rate"]) - ) + # Use the resolved output sample rate from our merged configuration + sample_rate = self.config["audio"]["output_rate"] # Channels from config is guaranteed to be 1 or 2 channels = cast(Literal[1, 2], self.config["audio"]["channels"]) @@ -429,7 +440,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", - sample_rate=cast(SampleRate, AUDIO_FORMAT["rate"]), + sample_rate=cast(SampleRate, sample_rate), channels=channels, ) ] @@ -723,7 +734,8 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" not in block and "json" not in block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by OpenAI Realtime API" + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" ) # Preserve structure by JSON-dumping the entire content array diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 053f13dde..846f2b526 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -97,9 +97,9 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert model_default.api_key is None assert model_default._live_session is None # Check default config includes transcription - assert model_default.live_config["response_modalities"] == ["AUDIO"] - assert "outputAudioTranscription" in model_default.live_config - assert "inputAudioTranscription" in model_default.live_config + assert model_default.provider_config["response_modalities"] == ["AUDIO"] + assert "outputAudioTranscription" in model_default.provider_config + assert "inputAudioTranscription" in model_default.provider_config # Test with API key model_with_key = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) @@ -107,13 +107,13 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert model_with_key.api_key == api_key # Test with custom config (merges with defaults) - live_config = {"temperature": 0.7, "top_p": 0.9} - model_custom = BidiGeminiLiveModel(model_id=model_id, live_config=live_config) + provider_config = {"temperature": 0.7, "top_p": 0.9} + model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config) # Custom config should be merged with defaults - assert model_custom.live_config["temperature"] == 0.7 - assert model_custom.live_config["top_p"] == 0.9 + assert model_custom.provider_config["temperature"] == 0.7 + assert model_custom.provider_config["top_p"] == 0.9 # Defaults should still be present - assert "response_modalities" in model_custom.live_config + assert "response_modalities" in model_custom.provider_config # Connection Tests @@ -501,13 +501,13 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): def test_audio_config_voice_priority(mock_genai_client, model_id, api_key): - """Test that config audio voice takes precedence over live_config voice.""" + """Test that config audio voice takes precedence over provider_config voice.""" _ = mock_genai_client - live_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}} + provider_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}} config = {"audio": {"voice": "Aoede"}} - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, live_config=live_config, config=config) + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, provider_config=provider_config, config=config) # Build config and verify config audio voice takes precedence built_config = model._build_live_config() @@ -549,6 +549,66 @@ def test_tool_formatting(model, tool_spec): # Tool Result Content Tests +@pytest.mark.asyncio +async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key): + """Test that audio events use configured sample rates and channels.""" + _, _, _ = mock_genai_client + + # Create model with custom audio configuration + config = {"audio": {"output_rate": 48000, "channels": 2}} + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, config=config) + await model.start() + + # Test audio output event uses custom configuration + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_events = model._convert_gemini_live_event(mock_audio) + assert len(audio_events) == 1 + audio_event = audio_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + # Should use configured rates, not constants + assert audio_event.sample_rate == 48000 # Custom config + assert audio_event.channels == 2 # Custom config + assert audio_event.format == "pcm" + + await model.stop() + + +@pytest.mark.asyncio +async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_key): + """Test that audio events use default sample rates when no custom config.""" + _, _, _ = mock_genai_client + + # Create model without custom audio configuration + model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + await model.start() + + # Test audio output event uses defaults + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_events = model._convert_gemini_live_event(mock_audio) + assert len(audio_events) == 1 + audio_event = audio_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + # Should use default rates + assert audio_event.sample_rate == 24000 # Default output rate + assert audio_event.channels == 1 # Default channels + assert audio_event.format == "pcm" + + await model.stop() + + +# Tool Result Content Tests + + @pytest.mark.asyncio async def test_tool_result_single_content_unwrapped(mock_genai_client, model): """Test that single content item is unwrapped (optimization).""" diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 19aac19c2..1172aae53 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -524,6 +524,50 @@ async def test_message_history_empty_and_edge_cases(nova_model): # Error Handling Tests +@pytest.mark.asyncio +async def test_custom_audio_rates_in_events(model_id, region): + """Test that audio events use configured sample rates.""" + # Create model with custom audio configuration + config = {"audio": {"output_rate": 48000, "channels": 2}} + model = BidiNovaSonicModel(model_id=model_id, region=region, config=config) + + # Test audio output event uses custom configuration + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + nova_event = {"audioOutput": {"content": audio_base64}} + result = model._convert_nova_event(nova_event) + + assert result is not None + assert isinstance(result, BidiAudioStreamEvent) + # Should use configured rates, not constants + assert result.sample_rate == 48000 # Custom config + assert result.channels == 2 # Custom config + assert result.format == "pcm" + + +@pytest.mark.asyncio +async def test_default_audio_rates_in_events(model_id, region): + """Test that audio events use default sample rates when no custom config.""" + # Create model without custom audio configuration + model = BidiNovaSonicModel(model_id=model_id, region=region) + + # Test audio output event uses defaults + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + nova_event = {"audioOutput": {"content": audio_base64}} + result = model._convert_nova_event(nova_event) + + assert result is not None + assert isinstance(result, BidiAudioStreamEvent) + # Should use default rates + assert result.sample_rate == 16000 # Default output rate + assert result.channels == 1 # Default channels + assert result.format == "pcm" + + +# Error Handling Tests + + @pytest.mark.asyncio async def test_error_handling(nova_model, mock_stream): """Test error handling in various scenarios.""" diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index ede0920a6..6a6baf011 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -99,10 +99,11 @@ def test_model_initialization(api_key, model_name): assert model_custom.model_id == model_name assert model_custom.api_key == api_key - # Test with organization and project - model_org = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, organization="org-123", project="proj-456") - assert model_org.organization == "org-123" - assert model_org.project == "proj-456" + # Test with organization and project via environment variables + with unittest.mock.patch.dict("os.environ", {"OPENAI_ORGANIZATION": "org-123", "OPENAI_PROJECT": "proj-456"}): + model_env = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + assert model_env.organization == "org-123" + assert model_env.project == "proj-456" # Test with env API key with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): @@ -160,12 +161,12 @@ def test_audio_config_full_override(api_key, model_name): def test_audio_config_voice_priority(api_key, model_name): - """Test that config audio voice takes precedence over session_config voice.""" - session_config = {"audio": {"output": {"voice": "alloy"}}} + """Test that config audio voice takes precedence over provider_config voice.""" + provider_config = {"audio": {"output": {"voice": "alloy"}}} config = {"audio": {"voice": "nova"}} model = BidiOpenAIRealtimeModel( - model_id=model_name, api_key=api_key, session_config=session_config, config=config + model_id=model_name, api_key=api_key, provider_config=provider_config, config=config ) # Build config and verify config audio voice takes precedence @@ -173,13 +174,13 @@ def test_audio_config_voice_priority(api_key, model_name): assert built_config["audio"]["output"]["voice"] == "nova" -def test_audio_config_extracts_voice_from_session_config(api_key, model_name): - """Test that voice is extracted from session_config when config audio not provided.""" - session_config = {"audio": {"output": {"voice": "fable"}}} +def test_audio_config_extracts_voice_from_provider_config(api_key, model_name): + """Test that voice is extracted from provider_config when config audio not provided.""" + provider_config = {"audio": {"output": {"voice": "fable"}}} - model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, session_config=session_config) + model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, provider_config=provider_config) - # Should extract voice from session_config + # Should extract voice from provider_config assert model.config["audio"]["voice"] == "fable" @@ -240,15 +241,16 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp assert len(item_creates) > 0 await model.stop() - # Test connection with organization header - model_org = BidiOpenAIRealtimeModel(api_key="test-key", organization="org-123") - await model_org.start() - call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("additional_headers", []) - org_header = [h for h in headers if h[0] == "OpenAI-Organization"] - assert len(org_header) == 1 - assert org_header[0][1] == "org-123" - await model_org.stop() + # Test connection with organization header (via environment) + with unittest.mock.patch.dict("os.environ", {"OPENAI_ORGANIZATION": "org-123"}): + model_org = BidiOpenAIRealtimeModel(api_key="test-key") + await model_org.start() + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("additional_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + await model_org.stop() @pytest.mark.asyncio @@ -681,13 +683,13 @@ async def test_send_event_helper(mock_websockets_connect, model): @pytest.mark.asyncio async def test_custom_audio_sample_rate(mock_websockets_connect, api_key): - """Test that custom audio sample rate from session_config is used in audio events.""" + """Test that custom audio sample rate from provider_config is used in audio events.""" _, mock_ws = mock_websockets_connect # Create model with custom sample rate custom_sample_rate = 48000 - session_config = {"audio": {"output": {"format": {"rate": custom_sample_rate}}}} - model = BidiOpenAIRealtimeModel(api_key=api_key, session_config=session_config) + provider_config = {"audio": {"output": {"format": {"rate": custom_sample_rate}}}} + model = BidiOpenAIRealtimeModel(api_key=api_key, provider_config=provider_config) await model.start() @@ -730,7 +732,7 @@ async def test_default_audio_sample_rate(mock_websockets_connect, api_key): assert len(converted_events) == 1 audio_event = converted_events[0] assert isinstance(audio_event, BidiAudioStreamEvent) - assert audio_event.sample_rate == 24000 # Default from AUDIO_FORMAT + assert audio_event.sample_rate == 24000 # Default from DEFAULT_SAMPLE_RATE assert audio_event.format == "pcm" assert audio_event.channels == 1 @@ -743,8 +745,8 @@ async def test_partial_audio_config(mock_websockets_connect, api_key): _, mock_ws = mock_websockets_connect # Create model with partial audio config (missing format.rate) - session_config = {"audio": {"output": {"voice": "alloy"}}} - model = BidiOpenAIRealtimeModel(api_key=api_key, session_config=session_config) + provider_config = {"audio": {"output": {"voice": "alloy"}}} + model = BidiOpenAIRealtimeModel(api_key=api_key, provider_config=provider_config) await model.start() From e7f3ed5f2060c3d0de8771b23f2e7f149c818e8e Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 26 Nov 2025 18:21:57 -0500 Subject: [PATCH 203/242] remove runtime import of bidi (#86) --- .../experimental/bidi/models/novasonic.py | 4 +-- .../experimental/bidi/models/openai.py | 8 ++--- src/strands/tools/executors/_executor.py | 32 +++++++------------ src/strands/tools/executors/concurrent.py | 8 ++--- src/strands/tools/executors/sequential.py | 6 ++-- 5 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 14c0d6f76..b528ff5fd 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -393,7 +393,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: # Validate content types and preserve structure content = tool_result.get("content", []) - + # Validate all content types are supported for block in content: if "text" not in block and "json" not in block: @@ -401,7 +401,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: raise ValueError( f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by Nova Sonic" ) - + # Optimize for single content item - unwrap the array if len(content) == 1: result_data: dict[str, Any] = content[0] diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index f64f65f64..48c42a3e2 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -347,7 +347,7 @@ async def _add_conversation_history(self, messages: Messages) -> None: # Tool result - create as function_call_output item tool_result = block["toolResult"] original_id = tool_result["toolUseId"] - + # Validate content types and serialize, preserving structure result_output = "" if "content" in tool_result: @@ -358,10 +358,10 @@ async def _add_conversation_history(self, messages: Messages) -> None: raise ValueError( f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | Content type not supported by OpenAI Realtime API" ) - + # Preserve structure by JSON-dumping the entire content array result_output = json.dumps(tool_result["content"]) - + # Use mapped call_id if available, otherwise skip orphaned result if original_id not in call_id_map: continue # Skip this tool result since we don't have the call @@ -725,7 +725,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: raise ValueError( f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by OpenAI Realtime API" ) - + # Preserve structure by JSON-dumping the entire content array result_output = json.dumps(tool_result["content"]) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index ce8d04661..8ef8ee673 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,7 +7,7 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, AsyncGenerator, Union, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from opentelemetry import trace as trace_api @@ -23,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi.agent.agent import BidiAgent + from ...experimental.bidi import BidiAgent logger = logging.getLogger(__name__) @@ -32,27 +32,17 @@ class ToolExecutor(abc.ABC): """Abstract base class for tool executors.""" @staticmethod - def _is_bidi_agent(agent: Union["Agent", "BidiAgent"]) -> bool: - """Check if the agent is a BidiAgent using isinstance. - - Uses runtime import to avoid circular dependency at module load time. - This properly handles subclasses of BidiAgent. - """ - try: - from ...experimental.bidi.agent.agent import BidiAgent - - return isinstance(agent, BidiAgent) - except ImportError: - # If BidiAgent is not available, it can't be a BidiAgent - return False + def _is_bidi_agent(agent: "Agent | BidiAgent") -> bool: + """Check if the agent is a BidiAgent instance.""" + return agent.__class__.__name__ == "BidiAgent" @staticmethod async def _invoke_before_tool_call_hook( - agent: Union["Agent", "BidiAgent"], + agent: "Agent | BidiAgent", tool_func: Any, tool_use: ToolUse, invocation_state: dict[str, Any], - ) -> tuple[Union[BeforeToolCallEvent, BidiBeforeToolCallEvent], list[Interrupt]]: + ) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]: """Invoke the appropriate before tool call hook based on agent type.""" event_cls = BidiBeforeToolCallEvent if ToolExecutor._is_bidi_agent(agent) else BeforeToolCallEvent return await agent.hooks.invoke_callbacks_async( @@ -66,14 +56,14 @@ async def _invoke_before_tool_call_hook( @staticmethod async def _invoke_after_tool_call_hook( - agent: Union["Agent", "BidiAgent"], + agent: "Agent | BidiAgent", selected_tool: Any, tool_use: ToolUse, invocation_state: dict[str, Any], result: ToolResult, exception: Exception | None = None, cancel_message: str | None = None, - ) -> tuple[Union[AfterToolCallEvent, BidiAfterToolCallEvent], list[Interrupt]]: + ) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]: """Invoke the appropriate after tool call hook based on agent type.""" event_cls = BidiAfterToolCallEvent if ToolExecutor._is_bidi_agent(agent) else AfterToolCallEvent return await agent.hooks.invoke_callbacks_async( @@ -252,7 +242,7 @@ async def _stream( @staticmethod async def _stream_with_trace( - agent: Union["Agent", "BidiAgent"], + agent: "Agent | BidiAgent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -313,7 +303,7 @@ async def _stream_with_trace( # pragma: no cover def _execute( self, - agent: Union["Agent", "BidiAgent"], + agent: "Agent | BidiAgent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 1a586c589..da5c1ff10 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,7 +1,7 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any, AsyncGenerator, Union +from typing import TYPE_CHECKING, Any, AsyncGenerator from typing_extensions import override @@ -12,7 +12,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi.agent.agent import BidiAgent + from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -22,7 +22,7 @@ class ConcurrentToolExecutor(ToolExecutor): @override async def _execute( self, - agent: Union["Agent", "BidiAgent"], + agent: "Agent | BidiAgent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -79,7 +79,7 @@ async def _execute( async def _task( self, - agent: Union["Agent", "BidiAgent"], + agent: "Agent | BidiAgent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index e4ac0ecda..6163fc195 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,6 +1,6 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any, AsyncGenerator, Union +from typing import TYPE_CHECKING, Any, AsyncGenerator from typing_extensions import override @@ -11,7 +11,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi.agent.agent import BidiAgent + from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -21,7 +21,7 @@ class SequentialToolExecutor(ToolExecutor): @override async def _execute( self, - agent: Union["Agent", "BidiAgent"], + agent: "Agent | BidiAgent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, From 9c88f8721ecb7bf21573bc867a0e1a2d372bc953 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 26 Nov 2025 21:27:38 -0500 Subject: [PATCH 204/242] fix lint and type errors (#87) --- src/strands/experimental/bidi/models/gemini_live.py | 5 +++-- src/strands/experimental/bidi/models/novasonic.py | 5 +++-- src/strands/experimental/bidi/models/openai.py | 9 +++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index c1a3e29eb..802551394 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -440,12 +440,13 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" not in block and "json" not in block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by Gemini Live API" + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}>" + " | Content type not supported by Gemini Live API" ) # Optimize for single content item - unwrap the array if len(content) == 1: - result_data: dict[str, Any] = content[0] + result_data = cast(dict[str, Any], content[0]) else: # Multiple items - send as array result_data = {"result": content} diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index b528ff5fd..6ee917022 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -399,12 +399,13 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" not in block and "json" not in block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by Nova Sonic" + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}>" + " | Content type not supported by Nova Sonic" ) # Optimize for single content item - unwrap the array if len(content) == 1: - result_data: dict[str, Any] = content[0] + result_data = cast(dict[str, Any], content[0]) else: # Multiple items - send as array result_data = {"content": content} diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 48c42a3e2..2bdf62e9c 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -32,7 +32,6 @@ BidiUsageEvent, ModalityUsage, Role, - SampleRate, StopReason, ) from .bidi_model import BidiModel @@ -356,7 +355,8 @@ async def _add_conversation_history(self, messages: Messages) -> None: if "text" not in result_block and "json" not in result_block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | Content type not supported by OpenAI Realtime API" + f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}>" + " | Content type not supported by OpenAI Realtime API" ) # Preserve structure by JSON-dumping the entire content array @@ -429,7 +429,7 @@ def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutput BidiAudioStreamEvent( audio=openai_event["delta"], format="pcm", - sample_rate=cast(SampleRate, AUDIO_FORMAT["rate"]), + sample_rate=sample_rate, channels=channels, ) ] @@ -723,7 +723,8 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None: if "text" not in block and "json" not in block: # Unsupported content type - raise error raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | Content type not supported by OpenAI Realtime API" + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}>" + " | Content type not supported by OpenAI Realtime API" ) # Preserve structure by JSON-dumping the entire content array From 610a8ea10ab33911113651ccfb5e8dedb0a44644 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 26 Nov 2025 21:28:04 -0500 Subject: [PATCH 205/242] tool executor - add agent isinstance check (#88) --- src/strands/tools/executors/_executor.py | 17 +++++++++++------ tests/strands/event_loop/test_event_loop.py | 2 ++ tests/strands/tools/executors/conftest.py | 2 ++ 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 8ef8ee673..a4f9e7e1f 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -32,9 +32,14 @@ class ToolExecutor(abc.ABC): """Abstract base class for tool executors.""" @staticmethod - def _is_bidi_agent(agent: "Agent | BidiAgent") -> bool: - """Check if the agent is a BidiAgent instance.""" - return agent.__class__.__name__ == "BidiAgent" + def _is_agent(agent: "Agent | BidiAgent") -> bool: + """Check if the agent is an Agent instance, otherwise we assume BidiAgent. + + Note, we use a runtime import to avoid a circular dependency error. + """ + from ...agent import Agent + + return isinstance(agent, Agent) @staticmethod async def _invoke_before_tool_call_hook( @@ -44,7 +49,7 @@ async def _invoke_before_tool_call_hook( invocation_state: dict[str, Any], ) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]: """Invoke the appropriate before tool call hook based on agent type.""" - event_cls = BidiBeforeToolCallEvent if ToolExecutor._is_bidi_agent(agent) else BeforeToolCallEvent + event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent return await agent.hooks.invoke_callbacks_async( event_cls( agent=agent, @@ -65,7 +70,7 @@ async def _invoke_after_tool_call_hook( cancel_message: str | None = None, ) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]: """Invoke the appropriate after tool call hook based on agent type.""" - event_cls = BidiAfterToolCallEvent if ToolExecutor._is_bidi_agent(agent) else AfterToolCallEvent + event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent return await agent.hooks.invoke_callbacks_async( event_cls( agent=agent, @@ -293,7 +298,7 @@ async def _stream_with_trace( tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time message = Message(role="user", content=[{"toolResult": result}]) - if not ToolExecutor._is_bidi_agent(agent): + if ToolExecutor._is_agent(agent): agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0a323b30d..52980729c 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,6 +6,7 @@ import strands import strands.telemetry +from strands import Agent from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -133,6 +134,7 @@ def tool_executor(): @pytest.fixture def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry, tool_executor): mock = unittest.mock.Mock(name="agent") + mock.__class__ = Agent mock.config.cache_points = [] mock.model = model mock.system_prompt = system_prompt diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index 5984e33ab..ad92ba603 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,6 +4,7 @@ import pytest import strands +from strands import Agent from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry from strands.interrupt import _InterruptState from strands.tools.registry import ToolRegistry @@ -102,6 +103,7 @@ def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, i @pytest.fixture def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() + mock_agent.__class__ = Agent mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry mock_agent._interrupt_state = _InterruptState() From 7f16aa3712c96fd7d30dc6852e0ee1b40d4aac30 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 26 Nov 2025 22:15:28 -0500 Subject: [PATCH 206/242] add unit test for bidi agent --- scripts/bidi/test_bidi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/bidi/test_bidi.py b/scripts/bidi/test_bidi.py index 2beb3ddd7..eb8e3d462 100644 --- a/scripts/bidi/test_bidi.py +++ b/scripts/bidi/test_bidi.py @@ -6,7 +6,7 @@ from strands.experimental.bidi.agent.agent import BidiAgent from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO -from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel +from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel async def main(): @@ -14,7 +14,7 @@ async def main(): # Nova Sonic model audio_io = BidiAudioIO() text_io = BidiTextIO() - model = BidiNovaSonicModel(region="us-east-1") + model = BidiGeminiLiveModel(region="us-east-1", config={"audio": {"voice": "ash"}}) agent = BidiAgent(model=model, tools=[calculator]) print("New BidiAgent Experience") From aa90bc3852ba049e28ff8b45f5801a04b8643ba0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 26 Nov 2025 22:17:26 -0500 Subject: [PATCH 207/242] add unit test for bidi agent --- scripts/bidi/test_bidi.py | 4 +- .../experimental/bidi/agent/__init__.py | 1 + .../experimental/bidi/agent/test_agent.py | 412 ++++++++++++++++++ 3 files changed, 415 insertions(+), 2 deletions(-) create mode 100644 tests/strands/experimental/bidi/agent/__init__.py create mode 100644 tests/strands/experimental/bidi/agent/test_agent.py diff --git a/scripts/bidi/test_bidi.py b/scripts/bidi/test_bidi.py index eb8e3d462..2beb3ddd7 100644 --- a/scripts/bidi/test_bidi.py +++ b/scripts/bidi/test_bidi.py @@ -6,7 +6,7 @@ from strands.experimental.bidi.agent.agent import BidiAgent from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO -from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel async def main(): @@ -14,7 +14,7 @@ async def main(): # Nova Sonic model audio_io = BidiAudioIO() text_io = BidiTextIO() - model = BidiGeminiLiveModel(region="us-east-1", config={"audio": {"voice": "ash"}}) + model = BidiNovaSonicModel(region="us-east-1") agent = BidiAgent(model=model, tools=[calculator]) print("New BidiAgent Experience") diff --git a/tests/strands/experimental/bidi/agent/__init__.py b/tests/strands/experimental/bidi/agent/__init__.py new file mode 100644 index 000000000..c7a89939d --- /dev/null +++ b/tests/strands/experimental/bidi/agent/__init__.py @@ -0,0 +1 @@ +# Empty init file for bidi agent test package \ No newline at end of file diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py new file mode 100644 index 000000000..88c5d42c8 --- /dev/null +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -0,0 +1,412 @@ +"""Unit tests for BidiAgent. + +Tests the bidirectional streaming agent including: +- Agent initialization and configuration +- Lifecycle management (start/stop/context manager) +- Send/receive methods with different content types +- Tool integration and execution +- Event processing and conversion +- Error handling and edge cases +""" + +import unittest.mock +import asyncio +import pytest +from uuid import uuid4 + +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel +from strands.experimental.bidi.types.events import ( + BidiTextInputEvent, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiTranscriptStreamEvent, + BidiConnectionStartEvent, + BidiConnectionCloseEvent, +) +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult, ToolSpec + + +# Mock model fixtures + + +class MockBidiModel: + """Mock bidirectional model for testing.""" + + def __init__(self, config=None, model_id="mock-model"): + self.config = config or {"audio": {"input_rate": 16000, "output_rate": 24000, "channels": 1}} + self.model_id = model_id + self._connection_id = None + self._started = False + self._events_to_yield = [] + + async def start(self, system_prompt=None, tools=None, messages=None, **kwargs): + if self._started: + raise RuntimeError("model already started | call stop before starting again") + self._connection_id = str(uuid4()) + self._started = True + + async def stop(self): + if self._started: + self._started = False + self._connection_id = None + + async def send(self, content): + if not self._started: + raise RuntimeError("model not started | call start before sending/receiving") + # Mock implementation - in real tests, this would trigger events + + async def receive(self): + """Async generator yielding mock events.""" + if not self._started: + raise RuntimeError("model not started | call start before sending/receiving") + + # Yield connection start event + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + # Yield any configured events + for event in self._events_to_yield: + yield event + + # Yield connection end event + yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete") + + def set_events(self, events): + """Helper to set events this mock model will yield.""" + self._events_to_yield = events + + +@pytest.fixture +def mock_model(): + """Create a mock BidiModel instance.""" + return MockBidiModel() + + + + + +@pytest.fixture +def mock_tool_registry(): + """Mock tool registry with some basic tools.""" + registry = unittest.mock.Mock() + registry.get_all_tool_specs.return_value = [ + { + "name": "calculator", + "description": "Perform calculations", + "inputSchema": {"json": {"type": "object", "properties": {}}} + } + ] + registry.get_all_tools_config.return_value = {"calculator": {}} + return registry + + +@pytest.fixture +def mock_tool_caller(): + """Mock tool caller for testing tool execution.""" + caller = unittest.mock.AsyncMock() + caller.call_tool = unittest.mock.AsyncMock() + return caller + + +@pytest.fixture +def agent(mock_model, mock_tool_registry, mock_tool_caller): + """Create a BidiAgent instance for testing.""" + with unittest.mock.patch("strands.experimental.bidi.agent.agent.ToolRegistry") as mock_registry_class: + mock_registry_class.return_value = mock_tool_registry + + with unittest.mock.patch("strands.experimental.bidi.agent.agent._ToolCaller") as mock_caller_class: + mock_caller_class.return_value = mock_tool_caller + + # Don't pass tools to avoid real tool loading + agent = BidiAgent(model=mock_model) + return agent + + +# ============================================================================ +# Initialization Tests +# ============================================================================ + + +def test_agent_initialization(): + """Test agent initialization with various configurations.""" + # Test default initialization + mock_model = MockBidiModel() + agent = BidiAgent(model=mock_model) + + assert agent.model == mock_model + assert agent.system_prompt is None + assert not agent._started + assert agent.model._connection_id is None + + # Test with configuration + system_prompt = "You are a helpful assistant." + agent_with_config = BidiAgent( + model=mock_model, + system_prompt=system_prompt, + agent_id="test_agent" + ) + + assert agent_with_config.system_prompt == system_prompt + assert agent_with_config.agent_id == "test_agent" + + # Test with string model ID + model_id = "amazon.nova-sonic-v1:0" + agent_with_string = BidiAgent(model=model_id) + + assert isinstance(agent_with_string.model, BidiNovaSonicModel) + assert agent_with_string.model.model_id == model_id + + # Test model config access + config = agent.model.config + assert config["audio"]["input_rate"] == 16000 + assert config["audio"]["output_rate"] == 24000 + assert config["audio"]["channels"] == 1 + + +# ============================================================================ +# Lifecycle Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_agent_lifecycle(agent): + """Test agent start/stop lifecycle and state management.""" + # Initial state + assert not agent._started + assert agent.model._connection_id is None + + # Start agent + await agent.start() + assert agent._started + assert agent.model._connection_id is not None + connection_id = agent.model._connection_id + + # Double start should error + with pytest.raises(RuntimeError, match="agent already started"): + await agent.start() + + # Stop agent + await agent.stop() + assert not agent._started + assert agent.model._connection_id is None + + # Multiple stops should be safe + await agent.stop() + await agent.stop() + + # Restart should work with new connection ID + await agent.start() + assert agent._started + assert agent.model._connection_id != connection_id + + +# ============================================================================ +# Send Method Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_send_methods(agent): + """Test sending various input types through agent.send().""" + await agent.start() + + # Test text input with TypedEvent + text_input = BidiTextInputEvent(text="Hello", role="user") + await agent.send(text_input) + assert len(agent.messages) == 1 + assert agent.messages[0]["content"][0]["text"] == "Hello" + + # Test string input (shorthand) + await agent.send("World") + assert len(agent.messages) == 2 + assert agent.messages[1]["content"][0]["text"] == "World" + + # Test audio input (doesn't add to messages) + audio_input = BidiAudioInputEvent( + audio="dGVzdA==", # base64 "test" + format="pcm", + sample_rate=16000, + channels=1 + ) + await agent.send(audio_input) + assert len(agent.messages) == 2 # Still 2, audio doesn't add + + # Test concurrent sends + sends = [ + agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) + for i in range(3) + ] + await asyncio.gather(*sends) + assert len(agent.messages) == 5 # 2 + 3 new messages + + +# ============================================================================ +# Receive Method Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_receive_methods(agent): + """Test receiving events from model.""" + # Configure mock model to yield events + events = [ + BidiAudioStreamEvent( + audio="dGVzdA==", + format="pcm", + sample_rate=24000, + channels=1 + ), + BidiTranscriptStreamEvent( + text="Hello world", + role="assistant", + is_final=True, + delta={"text": "Hello world"}, + current_transcript="Hello world" + ) + ] + agent.model.set_events(events) + + await agent.start() + + received_events = [] + async for event in agent.receive(): + received_events.append(event) + if len(received_events) >= 4: # Stop after getting expected events + break + + # Verify event types and order + assert len(received_events) >= 3 + assert isinstance(received_events[0], BidiConnectionStartEvent) + assert isinstance(received_events[1], BidiAudioStreamEvent) + assert isinstance(received_events[2], BidiTranscriptStreamEvent) + + # Test empty events + agent.model.set_events([]) + await agent.stop() + await agent.start() + + empty_events = [] + async for event in agent.receive(): + empty_events.append(event) + if len(empty_events) >= 2: + break + + assert len(empty_events) >= 1 + assert isinstance(empty_events[0], BidiConnectionStartEvent) + + +# ============================================================================ +# Tool Integration Tests +# ============================================================================ + + +def test_agent_tools(agent, mock_tool_registry): + """Test agent tool integration and properties.""" + # Test tool property access + assert hasattr(agent, 'tool') + assert agent.tool is not None + assert agent.tool == agent._tool_caller + + # Test tool names property + mock_tool_registry.get_all_tools_config.return_value = { + "calculator": {}, + "weather": {} + } + + tool_names = agent.tool_names + assert isinstance(tool_names, list) + assert len(tool_names) == 2 + assert "calculator" in tool_names + assert "weather" in tool_names + + +# ============================================================================ +# Error Handling Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_error_handling(agent): + """Test error handling in various scenarios.""" + # Test send before start + with pytest.raises(RuntimeError, match="call start before"): + await agent.send(BidiTextInputEvent(text="Hello", role="user")) + + # Test receive before start + with pytest.raises(RuntimeError, match="call start before"): + async for event in agent.receive(): + pass + + # Test send after stop + await agent.start() + await agent.stop() + with pytest.raises(RuntimeError, match="call start before"): + await agent.send(BidiTextInputEvent(text="Hello", role="user")) + + # Test receive after stop + with pytest.raises(RuntimeError, match="call start before"): + async for event in agent.receive(): + pass + + +@pytest.mark.asyncio +async def test_model_error_propagation(): + """Test that model errors are properly propagated.""" + # Test model start error + mock_model = MockBidiModel() + mock_model.start = unittest.mock.AsyncMock(side_effect=Exception("Connection failed")) + error_agent = BidiAgent(model=mock_model) + + with pytest.raises(Exception, match="Connection failed"): + await error_agent.start() + + # Test model receive error + mock_model2 = MockBidiModel() + agent2 = BidiAgent(model=mock_model2) + await agent2.start() + + async def failing_receive(): + yield BidiConnectionStartEvent(connection_id="test", model="test-model") + raise Exception("Receive failed") + + agent2.model.receive = failing_receive + with pytest.raises(Exception, match="Receive failed"): + async for event in agent2.receive(): + pass + + +# ============================================================================ +# State Consistency Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_agent_state_consistency(agent): + """Test that agent state remains consistent across operations.""" + # Initial state + assert not agent._started + assert agent.model._connection_id is None + + # Start + await agent.start() + assert agent._started + assert agent.model._connection_id is not None + connection_id = agent.model._connection_id + + # Send operations shouldn't change connection state + await agent.send(BidiTextInputEvent(text="Hello", role="user")) + assert agent._started + assert agent.model._connection_id == connection_id + + # Stop + await agent.stop() + assert not agent._started + assert agent.model._connection_id is None + + +# ============================================================================ +# Integration Tests +# ============================================================================ + + From 18a43045b745fc22e77294d489ff2b41f308baa0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 26 Nov 2025 22:21:24 -0500 Subject: [PATCH 208/242] add unit test for bidi agent --- .../experimental/bidi/agent/__init__.py | 2 +- .../experimental/bidi/agent/test_agent.py | 73 +------------------ 2 files changed, 3 insertions(+), 72 deletions(-) diff --git a/tests/strands/experimental/bidi/agent/__init__.py b/tests/strands/experimental/bidi/agent/__init__.py index c7a89939d..3359c6565 100644 --- a/tests/strands/experimental/bidi/agent/__init__.py +++ b/tests/strands/experimental/bidi/agent/__init__.py @@ -1 +1 @@ -# Empty init file for bidi agent test package \ No newline at end of file +"""Bidirectional streaming agent tests.""" \ No newline at end of file diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py index 88c5d42c8..79eb6b83b 100644 --- a/tests/strands/experimental/bidi/agent/test_agent.py +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -1,13 +1,4 @@ -"""Unit tests for BidiAgent. - -Tests the bidirectional streaming agent including: -- Agent initialization and configuration -- Lifecycle management (start/stop/context manager) -- Send/receive methods with different content types -- Tool integration and execution -- Event processing and conversion -- Error handling and edge cases -""" +"""Unit tests for BidiAgent.""" import unittest.mock import asyncio @@ -24,12 +15,6 @@ BidiConnectionStartEvent, BidiConnectionCloseEvent, ) -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolResult, ToolSpec - - -# Mock model fixtures - class MockBidiModel: """Mock bidirectional model for testing.""" @@ -76,16 +61,11 @@ def set_events(self, events): """Helper to set events this mock model will yield.""" self._events_to_yield = events - @pytest.fixture def mock_model(): """Create a mock BidiModel instance.""" return MockBidiModel() - - - - @pytest.fixture def mock_tool_registry(): """Mock tool registry with some basic tools.""" @@ -122,12 +102,6 @@ def agent(mock_model, mock_tool_registry, mock_tool_caller): agent = BidiAgent(model=mock_model) return agent - -# ============================================================================ -# Initialization Tests -# ============================================================================ - - def test_agent_initialization(): """Test agent initialization with various configurations.""" # Test default initialization @@ -163,12 +137,6 @@ def test_agent_initialization(): assert config["audio"]["output_rate"] == 24000 assert config["audio"]["channels"] == 1 - -# ============================================================================ -# Lifecycle Tests -# ============================================================================ - - @pytest.mark.asyncio async def test_agent_lifecycle(agent): """Test agent start/stop lifecycle and state management.""" @@ -200,12 +168,6 @@ async def test_agent_lifecycle(agent): assert agent._started assert agent.model._connection_id != connection_id - -# ============================================================================ -# Send Method Tests -# ============================================================================ - - @pytest.mark.asyncio async def test_send_methods(agent): """Test sending various input types through agent.send().""" @@ -240,12 +202,6 @@ async def test_send_methods(agent): await asyncio.gather(*sends) assert len(agent.messages) == 5 # 2 + 3 new messages - -# ============================================================================ -# Receive Method Tests -# ============================================================================ - - @pytest.mark.asyncio async def test_receive_methods(agent): """Test receiving events from model.""" @@ -295,12 +251,6 @@ async def test_receive_methods(agent): assert len(empty_events) >= 1 assert isinstance(empty_events[0], BidiConnectionStartEvent) - -# ============================================================================ -# Tool Integration Tests -# ============================================================================ - - def test_agent_tools(agent, mock_tool_registry): """Test agent tool integration and properties.""" # Test tool property access @@ -320,12 +270,6 @@ def test_agent_tools(agent, mock_tool_registry): assert "calculator" in tool_names assert "weather" in tool_names - -# ============================================================================ -# Error Handling Tests -# ============================================================================ - - @pytest.mark.asyncio async def test_error_handling(agent): """Test error handling in various scenarios.""" @@ -375,12 +319,6 @@ async def failing_receive(): async for event in agent2.receive(): pass - -# ============================================================================ -# State Consistency Tests -# ============================================================================ - - @pytest.mark.asyncio async def test_agent_state_consistency(agent): """Test that agent state remains consistent across operations.""" @@ -402,11 +340,4 @@ async def test_agent_state_consistency(agent): # Stop await agent.stop() assert not agent._started - assert agent.model._connection_id is None - - -# ============================================================================ -# Integration Tests -# ============================================================================ - - + assert agent.model._connection_id is None \ No newline at end of file From 43e3125e0ac9db40a5bb828a44858b1772bbbc7e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 26 Nov 2025 22:23:26 -0500 Subject: [PATCH 209/242] remove scripts directory before merging with sdk-python/main --- scripts/bidi/test_bidi.py | 34 --- scripts/bidi/test_bidi_novasonic.py | 246 -------------------- scripts/bidi/test_bidi_openai.py | 308 ------------------------ scripts/bidi/test_gemini_live.py | 349 ---------------------------- 4 files changed, 937 deletions(-) delete mode 100644 scripts/bidi/test_bidi.py delete mode 100644 scripts/bidi/test_bidi_novasonic.py delete mode 100644 scripts/bidi/test_bidi_openai.py delete mode 100644 scripts/bidi/test_gemini_live.py diff --git a/scripts/bidi/test_bidi.py b/scripts/bidi/test_bidi.py deleted file mode 100644 index 2beb3ddd7..000000000 --- a/scripts/bidi/test_bidi.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Test BidirectionalAgent with simple developer experience.""" - -import asyncio - -from strands_tools import calculator - -from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO -from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel - - -async def main(): - """Test the BidirectionalAgent API.""" - # Nova Sonic model - audio_io = BidiAudioIO() - text_io = BidiTextIO() - model = BidiNovaSonicModel(region="us-east-1") - agent = BidiAgent(model=model, tools=[calculator]) - - print("New BidiAgent Experience") - print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'") - await agent.run(inputs=[audio_io.input()], outputs=[audio_io.output(), text_io.output()]) - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\n⏹️ Conversation ended by user") - except Exception as e: - print(f"❌ Error: {e}") - import traceback - - traceback.print_exc() diff --git a/scripts/bidi/test_bidi_novasonic.py b/scripts/bidi/test_bidi_novasonic.py deleted file mode 100644 index baa39226f..000000000 --- a/scripts/bidi/test_bidi_novasonic.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Test suite for bidirectional streaming with real-time audio interaction. - -Tests the complete bidirectional streaming system including audio input/output, -interruption handling, and concurrent tool execution using Nova Sonic. -""" - -import asyncio -import base64 -import os -import time - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel - - -def test_direct_tools(): - """Test direct tool calling.""" - print("Testing direct tool calling...") - - # Check AWS credentials - if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): - print("AWS credentials not set - skipping test") - return - - try: - model = BidiNovaSonicModel() - agent = BidiAgent(model=model, tools=[calculator]) - - # Test calculator - result = agent.tool.calculator(expression="2 * 3") - content = result.get("content", [{}])[0].get("text", "") - print(f"Result: {content}") - print("Test completed") - - except Exception as e: - print(f"Test failed: {e}") - - -async def play(context): - """Play audio output with responsive interruption support.""" - audio = pyaudio.PyAudio() - speaker = audio.open( - channels=1, - format=pyaudio.paInt16, - output=True, - rate=16000, - frames_per_buffer=1024, - ) - - try: - while context["active"]: - try: - # Check for interruption first - if context.get("interrupted", False): - # Clear entire audio queue immediately - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get next audio data - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - # Check for interruption before each chunk - if context.get("interrupted", False) or not context["active"]: - break - - end = min(i + chunk_size, len(audio_data)) - chunk = audio_data[i:end] - speaker.write(chunk) - await asyncio.sleep(0.001) - - except asyncio.TimeoutError: - continue # No audio available - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - finally: - speaker.close() - audio.terminate() - - -async def record(context): - """Record audio input from microphone.""" - audio = pyaudio.PyAudio() - microphone = audio.open( - channels=1, - format=pyaudio.paInt16, - frames_per_buffer=1024, - input=True, - rate=16000, - ) - - try: - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - context["audio_in"].put_nowait(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - except asyncio.CancelledError: - pass - finally: - microphone.close() - audio.terminate() - - -async def receive(agent, context): - """Receive and process events from agent.""" - try: - async for event in agent.receive(): - event_type = event.get("type", "unknown") - - # Handle audio stream events (bidi_audio_stream) - if event_type == "bidi_audio_stream": - if not context.get("interrupted", False): - # Decode base64 audio string to bytes for playback - audio_b64 = event["audio"] - audio_data = base64.b64decode(audio_b64) - context["audio_out"].put_nowait(audio_data) - - # Handle interruption events (bidi_interruption) - elif event_type == "bidi_interruption": - context["interrupted"] = True - - # Handle transcript events (bidi_transcript_stream) - elif event_type == "bidi_transcript_stream": - text_content = event.get("text", "") - role = event.get("role", "unknown") - - # Log transcript output - if role == "user": - print(f"User: {text_content}") - elif role == "assistant": - print(f"Assistant: {text_content}") - - # Handle response complete events (bidi_response_complete) - elif event_type == "bidi_response_complete": - # Reset interrupted state since the turn is complete - context["interrupted"] = False - - # Handle tool use events (tool_use_stream) - elif event_type == "tool_use_stream": - tool_use = event.get("current_tool_use", {}) - tool_name = tool_use.get("name", "unknown") - tool_input = tool_use.get("input", {}) - print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - - # Handle tool result events (tool_result) - elif event_type == "tool_result": - tool_result = event.get("tool_result", {}) - tool_name = tool_result.get("name", "unknown") - result_content = tool_result.get("content", []) - result_text = "" - for block in result_content: - if isinstance(block, dict) and block.get("type") == "text": - result_text = block.get("text", "") - break - print(f"✅ Tool result from {tool_name}: {result_text}") - - except asyncio.CancelledError: - pass - - -async def send(agent, context): - """Send audio input to agent.""" - try: - while time.time() - context["start_time"] < context["duration"]: - try: - audio_bytes = context["audio_in"].get_nowait() - # Create audio event using TypedEvent - from strands.experimental.bidi.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") - audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) - await agent.send(audio_event) - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) # Restored to working timing - except asyncio.CancelledError: - break - - context["active"] = False - except asyncio.CancelledError: - pass - - -async def main(duration=180): - """Main function for bidirectional streaming test.""" - print("Starting bidirectional streaming test...") - print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") - - # Initialize model and agent - model = BidiNovaSonicModel(region="us-east-1") - agent = BidiAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") - - await agent.start() - - # Create shared context for all tasks - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "connection": agent._loop, - "duration": duration, - "start_time": time.time(), - "interrupted": False, - } - - print("Speak into microphone. Press Ctrl+C to exit.") - - try: - # Run all tasks concurrently - await asyncio.gather( - play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True - ) - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - finally: - print("Cleaning up...") - context["active"] = False - await agent.stop() - - -if __name__ == "__main__": - # Test direct tool calling first - test_direct_tools() - - asyncio.run(main()) diff --git a/scripts/bidi/test_bidi_openai.py b/scripts/bidi/test_bidi_openai.py deleted file mode 100644 index 50d2d2f55..000000000 --- a/scripts/bidi/test_bidi_openai.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 -"""Test OpenAI Realtime API speech-to-speech interaction.""" - -import asyncio -import base64 -import os -import time - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel - - -async def play(context): - """Handle audio playback with interruption support.""" - audio = pyaudio.PyAudio() - - try: - speaker = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # OpenAI Realtime uses 24kHz - output=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - # Check for interruption - if context.get("interrupted", False): - # Clear audio queue on interruption - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get audio data with timeout - try: - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - # Play in chunks to allow interruption - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - if context.get("interrupted", False) or not context["active"]: - break - - chunk = audio_data[i : i + chunk_size] - speaker.write(chunk) - await asyncio.sleep(0.001) # Brief pause for responsiveness - - except asyncio.TimeoutError: - continue - - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Audio playback error: {e}") - finally: - try: - speaker.close() - except Exception: - pass - audio.terminate() - - -async def record(context): - """Handle microphone recording.""" - audio = pyaudio.PyAudio() - - try: - microphone = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # Match OpenAI's expected input rate - input=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - await context["audio_in"].put(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Microphone recording error: {e}") - finally: - try: - microphone.close() - except Exception: - pass - audio.terminate() - - -async def receive(agent, context): - """Handle events from the agent.""" - try: - async for event in agent.receive(): - if not context["active"]: - break - - # Get event type - event_type = event.get("type", "unknown") - - # Handle audio stream events (bidi_audio_stream) - if event_type == "bidi_audio_stream": - # Decode base64 audio string to bytes for playback - audio_b64 = event["audio"] - audio_data = base64.b64decode(audio_b64) - - if not context.get("interrupted", False): - await context["audio_out"].put(audio_data) - - # Handle transcript events (bidi_transcript_stream) - elif event_type == "bidi_transcript_stream": - source = event.get("role", "assistant") - text = event.get("text", "").strip() - - if text: - if source == "user": - print(f"🎤 User: {text}") - elif source == "assistant": - print(f"🔊 Assistant: {text}") - - # Handle interruption events (bidi_interruption) - elif event_type == "bidi_interruption": - context["interrupted"] = True - print("⚠️ Interruption detected") - - # Handle connection start events (bidi_connection_start) - elif event_type == "bidi_connection_start": - print(f"✓ Session started: {event.get('model', 'unknown')}") - - # Handle connection close events (bidi_connection_close) - elif event_type == "bidi_connection_close": - print(f"✓ Session ended: {event.get('reason', 'unknown')}") - context["active"] = False - break - - # Handle response complete events (bidi_response_complete) - elif event_type == "bidi_response_complete": - # Reset interrupted state since the turn is complete - context["interrupted"] = False - - # Handle tool use events (tool_use_stream) - elif event_type == "tool_use_stream": - tool_use = event.get("current_tool_use", {}) - tool_name = tool_use.get("name", "unknown") - tool_input = tool_use.get("input", {}) - print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - - # Handle tool result events (tool_result) - elif event_type == "tool_result": - tool_result = event.get("tool_result", {}) - tool_name = tool_result.get("name", "unknown") - result_content = tool_result.get("content", []) - result_text = "" - for block in result_content: - if isinstance(block, dict) and block.get("type") == "text": - result_text = block.get("text", "") - break - print(f"✅ Tool result from {tool_name}: {result_text}") - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Receive handler error: {e}") - finally: - pass - - -async def send(agent, context): - """Send audio from microphone to agent.""" - try: - while context["active"]: - try: - audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - - # Create audio event using TypedEvent - # Encode audio bytes to base64 string for JSON serializability - from strands.experimental.bidi.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") - audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=24000, channels=1) - - await agent.send(audio_event) - - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Send handler error: {e}") - finally: - pass - - -async def main(): - """Main test function for OpenAI voice chat.""" - print("Starting OpenAI Realtime API test...") - - # Check API key - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY environment variable not set") - return False - - # Check audio system - try: - audio = pyaudio.PyAudio() - audio.terminate() - except Exception as e: - print(f"Audio system error: {e}") - return False - - # Create OpenAI model - model = BidiOpenAIRealtimeModel( - model_id="gpt-4o-realtime-preview", - api_key=api_key, - session={ - "output_modalities": ["audio"], - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": 24000}, - "turn_detection": {"type": "server_vad", "threshold": 0.5, "silence_duration_ms": 700}, - }, - "output": {"format": {"type": "audio/pcm", "rate": 24000}, "voice": "alloy"}, - }, - }, - ) - - # Create agent - agent = BidiAgent( - model=model, - tools=[calculator], - system_prompt=( - "You are a helpful voice assistant. " - "Keep your responses brief and natural. " - "Say hello when you first connect." - ), - ) - - # Start the session - await agent.start() - - # Create shared context - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "interrupted": False, - "start_time": time.time(), - } - - print("Speak into your microphone. Press Ctrl+C to stop.") - - try: - # Run all tasks concurrently - await asyncio.gather( - play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True - ) - - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - except Exception as e: - print(f"\nError during voice chat: {e}") - finally: - print("Cleaning up...") - context["active"] = False - - try: - await agent.stop() - except Exception as e: - print(f"Cleanup error: {e}") - - return True - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Test error: {e}") - import traceback - - traceback.print_exc() diff --git a/scripts/bidi/test_gemini_live.py b/scripts/bidi/test_gemini_live.py deleted file mode 100644 index 656ca6dcd..000000000 --- a/scripts/bidi/test_gemini_live.py +++ /dev/null @@ -1,349 +0,0 @@ -"""Test suite for Gemini Live bidirectional streaming with camera support. - -Tests the Gemini Live API with real-time audio and video interaction including: -- Audio input/output streaming -- Camera frame capture and transmission -- Interruption handling -- Concurrent tool execution -- Transcript events - -Requirements: -- pip install opencv-python pillow pyaudio google-genai -- Camera access permissions -- GOOGLE_AI_API_KEY environment variable -""" - -import asyncio -import base64 -import io -import logging -import os -import time - -try: - import cv2 - import PIL.Image - - CAMERA_AVAILABLE = True -except ImportError as e: - print(f"Camera dependencies not available: {e}") - print("Install with: pip install opencv-python pillow") - CAMERA_AVAILABLE = False - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel - -# Configure logging - debug only for Gemini Live, info for everything else -logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -gemini_logger = logging.getLogger("strands.experimental.bidirectional_streaming.models.gemini_live") -gemini_logger.setLevel(logging.WARNING) -logger = logging.getLogger(__name__) - - -async def play(context): - """Play audio output with responsive interruption support.""" - audio = pyaudio.PyAudio() - speaker = audio.open( - channels=1, - format=pyaudio.paInt16, - output=True, - rate=24000, - frames_per_buffer=1024, - ) - - try: - while context["active"]: - try: - # Check for interruption first - if context.get("interrupted", False): - # Clear entire audio queue immediately - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get next audio data - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - # Check for interruption before each chunk - if context.get("interrupted", False) or not context["active"]: - break - - end = min(i + chunk_size, len(audio_data)) - chunk = audio_data[i:end] - speaker.write(chunk) - await asyncio.sleep(0.001) - - except asyncio.TimeoutError: - continue # No audio available - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - finally: - speaker.close() - audio.terminate() - - -async def record(context): - """Record audio input from microphone.""" - audio = pyaudio.PyAudio() - - # List all available audio devices - print("Available audio devices:") - for i in range(audio.get_device_count()): - device_info = audio.get_device_info_by_index(i) - if device_info["maxInputChannels"] > 0: # Only show input devices - print(f" Device {i}: {device_info['name']} (inputs: {device_info['maxInputChannels']})") - - # Get default input device info - default_device = audio.get_default_input_device_info() - print(f"\nUsing default input device: {default_device['name']} (Device {default_device['index']})") - - microphone = audio.open( - channels=1, - format=pyaudio.paInt16, - frames_per_buffer=1024, - input=True, - rate=16000, - ) - - try: - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - context["audio_in"].put_nowait(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - except asyncio.CancelledError: - pass - finally: - microphone.close() - audio.terminate() - - -async def receive(agent, context): - """Receive and process events from agent.""" - try: - async for event in agent.receive(): - event_type = event.get("type", "unknown") - - # Handle audio stream events (bidi_audio_stream) - if event_type == "bidi_audio_stream": - if not context.get("interrupted", False): - # Decode base64 audio string to bytes for playback - audio_b64 = event["audio"] - audio_data = base64.b64decode(audio_b64) - context["audio_out"].put_nowait(audio_data) - - # Handle interruption events (bidi_interruption) - elif event_type == "bidi_interruption": - context["interrupted"] = True - print("⚠️ Interruption detected") - - # Handle transcript events (bidi_transcript_stream) - elif event_type == "bidi_transcript_stream": - transcript_text = event.get("text", "") - transcript_role = event.get("role", "unknown") - - # Print transcripts with special formatting - if transcript_role == "user": - print(f"🎤 User: {transcript_text}") - elif transcript_role == "assistant": - print(f"🔊 Assistant: {transcript_text}") - - # Handle response complete events (bidi_response_complete) - elif event_type == "bidi_response_complete": - # Reset interrupted state since the response is complete - context["interrupted"] = False - - # Handle tool use events (tool_use_stream) - elif event_type == "tool_use_stream": - tool_use = event.get("current_tool_use", {}) - tool_name = tool_use.get("name", "unknown") - tool_input = tool_use.get("input", {}) - print(f"🔧 Tool called: {tool_name} with input: {tool_input}") - - # Handle tool result events (tool_result) - elif event_type == "tool_result": - tool_result = event.get("tool_result", {}) - tool_name = tool_result.get("name", "unknown") - result_content = tool_result.get("content", []) - # Extract text from content blocks - result_text = "" - for block in result_content: - if isinstance(block, dict) and block.get("type") == "text": - result_text = block.get("text", "") - break - print(f"✅ Tool result from {tool_name}: {result_text}") - - except asyncio.CancelledError: - pass - - -def _get_frame(cap): - """Capture and process a frame from camera.""" - if not CAMERA_AVAILABLE: - return None - - # Read the frame - ret, frame = cap.read() - # Check if the frame was read successfully - if not ret: - return None - # Convert BGR to RGB color space - # OpenCV captures in BGR but PIL expects RGB format - # This prevents the blue tint in the video feed - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - img = PIL.Image.fromarray(frame_rgb) - img.thumbnail([1024, 1024]) - - image_io = io.BytesIO() - img.save(image_io, format="jpeg") - image_io.seek(0) - - mime_type = "image/jpeg" - image_bytes = image_io.read() - return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} - - -async def get_frames(context): - """Capture frames from camera and send to agent.""" - if not CAMERA_AVAILABLE: - print("Camera not available - skipping video capture") - return - - # This takes about a second, and will block the whole program - # causing the audio pipeline to overflow if you don't to_thread it. - cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera - - print("Camera initialized. Starting video capture...") - - try: - while context["active"] and time.time() - context["start_time"] < context["duration"]: - frame = await asyncio.to_thread(_get_frame, cap) - if frame is None: - break - - # Send frame to agent as image input - try: - from strands.experimental.bidi.types.events import BidiImageInputEvent - - image_event = BidiImageInputEvent( - image=frame["data"], # Already base64 encoded - mime_type=frame["mime_type"], - ) - await context["agent"].send(image_event) - print("📸 Frame sent to model") - except Exception as e: - logger.error("error=<%s> | error sending frame", e) - - # Wait 1 second between frames (1 FPS) - await asyncio.sleep(1.0) - - except asyncio.CancelledError: - pass - finally: - # Release the VideoCapture object - cap.release() - - -async def send(agent, context): - """Send audio input to agent.""" - try: - while time.time() - context["start_time"] < context["duration"]: - try: - audio_bytes = context["audio_in"].get_nowait() - # Create audio event using TypedEvent - from strands.experimental.bidi.types.events import BidiAudioInputEvent - - audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") - audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) - await agent.send(audio_event) - except asyncio.QueueEmpty: - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - context["active"] = False - except asyncio.CancelledError: - pass - - -async def main(duration=180): - """Main function for Gemini Live bidirectional streaming test with camera support.""" - print("Starting Gemini Live bidirectional streaming test with camera...") - print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") - print("Video: Camera frames sent at 1 FPS to model") - - # Get API key from environment variable - api_key = os.getenv("GOOGLE_AI_API_KEY") - - if not api_key: - print("ERROR: GOOGLE_AI_API_KEY environment variable not set") - print("Please set it with: export GOOGLE_AI_API_KEY=your_api_key") - return - - # Initialize Gemini Live model with proper configuration - logger.info("Initializing Gemini Live model with API key") - - # Use default model and config (includes transcription enabled by default) - model = BidiGeminiLiveModel(api_key=api_key) - logger.info("Gemini Live model initialized successfully") - print("Using Gemini Live model with default config (audio output + transcription enabled)") - - agent = BidiAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") - - await agent.start() - - # Create shared context for all tasks - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "connection": agent._loop, - "duration": duration, - "start_time": time.time(), - "interrupted": False, - "agent": agent, # Add agent reference for camera task - } - - print("Speak into microphone and show things to camera. Press Ctrl+C to exit.") - - try: - # Run all tasks concurrently including camera - await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - get_frames(context), # Add camera task - return_exceptions=True, - ) - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - finally: - print("Cleaning up...") - context["active"] = False - await agent.stop() - - -if __name__ == "__main__": - asyncio.run(main()) From 3c812b0431738d2b01a715a623ce30fe00921db4 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 26 Nov 2025 22:43:43 -0500 Subject: [PATCH 210/242] remove scripts directory before merging with sdk-python/main --- .../experimental/bidi/agent/test_agent.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py index 79eb6b83b..9a56a3c63 100644 --- a/tests/strands/experimental/bidi/agent/test_agent.py +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -102,7 +102,7 @@ def agent(mock_model, mock_tool_registry, mock_tool_caller): agent = BidiAgent(model=mock_model) return agent -def test_agent_initialization(): +def test_bidi_agent_init_with_various_configurations(): """Test agent initialization with various configurations.""" # Test default initialization mock_model = MockBidiModel() @@ -138,7 +138,7 @@ def test_agent_initialization(): assert config["audio"]["channels"] == 1 @pytest.mark.asyncio -async def test_agent_lifecycle(agent): +async def test_bidi_agent_start_stop_lifecycle(agent): """Test agent start/stop lifecycle and state management.""" # Initial state assert not agent._started @@ -169,7 +169,7 @@ async def test_agent_lifecycle(agent): assert agent.model._connection_id != connection_id @pytest.mark.asyncio -async def test_send_methods(agent): +async def test_bidi_agent_send_with_input_types(agent): """Test sending various input types through agent.send().""" await agent.start() @@ -203,7 +203,7 @@ async def test_send_methods(agent): assert len(agent.messages) == 5 # 2 + 3 new messages @pytest.mark.asyncio -async def test_receive_methods(agent): +async def test_bidi_agent_receive_events_from_model(agent): """Test receiving events from model.""" # Configure mock model to yield events events = [ @@ -251,7 +251,7 @@ async def test_receive_methods(agent): assert len(empty_events) >= 1 assert isinstance(empty_events[0], BidiConnectionStartEvent) -def test_agent_tools(agent, mock_tool_registry): +def test_bidi_agent_tool_integration(agent, mock_tool_registry): """Test agent tool integration and properties.""" # Test tool property access assert hasattr(agent, 'tool') @@ -271,7 +271,7 @@ def test_agent_tools(agent, mock_tool_registry): assert "weather" in tool_names @pytest.mark.asyncio -async def test_error_handling(agent): +async def test_bidi_agent_send_receive_error_before_start(agent): """Test error handling in various scenarios.""" # Test send before start with pytest.raises(RuntimeError, match="call start before"): @@ -295,7 +295,7 @@ async def test_error_handling(agent): @pytest.mark.asyncio -async def test_model_error_propagation(): +async def test_bidi_agent_start_receive_propagates_model_errors(): """Test that model errors are properly propagated.""" # Test model start error mock_model = MockBidiModel() @@ -320,7 +320,7 @@ async def failing_receive(): pass @pytest.mark.asyncio -async def test_agent_state_consistency(agent): +async def test_bidi_agent_state_consistency(agent): """Test that agent state remains consistent across operations.""" # Initial state assert not agent._started From 9c5835e1ec8780920aafff611ae95b0dabe77630 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 27 Nov 2025 00:06:14 -0500 Subject: [PATCH 211/242] bidi io text - switch to using prompt toolkit (#91) --- pyproject.toml | 1 + src/strands/experimental/bidi/io/text.py | 44 ++++++++++--------- .../strands/experimental/bidi/io/test_text.py | 11 +++-- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eb61bf5b3..2a8b250fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ a2a = [ bidi = [ "aws_sdk_bedrock_runtime; python_version>='3.12'", + "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", "smithy-aws-core>=0.0.1; python_version>='3.12'", ] diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index e123de766..99056f38b 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -1,36 +1,28 @@ -"""Handle text input and output from bidi agent.""" +"""Handle text input and output to and from bidi agent.""" -import asyncio import logging -import sys -from typing import TYPE_CHECKING +from typing import Any + +from prompt_toolkit import PromptSession from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTextInputEvent, BidiTranscriptStreamEvent from ..types.io import BidiInput, BidiOutput -if TYPE_CHECKING: - from ..agent.agent import BidiAgent - logger = logging.getLogger(__name__) class _BidiTextInput(BidiInput): """Handle text input from user.""" - def __init__(self) -> None: - """Setup async stream reader.""" - self._reader = asyncio.StreamReader() - - async def start(self, agent: "BidiAgent") -> None: - """Connect reader to stdin.""" - loop = asyncio.get_running_loop() - protocol = asyncio.StreamReaderProtocol(self._reader) - await loop.connect_read_pipe(lambda: protocol, sys.stdin) + def __init__(self, config: dict[str, Any]) -> None: + """Extract configs and setup prompt session.""" + prompt = config.get("input_prompt", "") + self._session: PromptSession = PromptSession(prompt) async def __call__(self) -> BidiTextInputEvent: """Read user input from stdin.""" - text = (await self._reader.readline()).decode().strip() - return BidiTextInputEvent(text, role="user") + text = await self._session.prompt_async() + return BidiTextInputEvent(text.strip(), role="user") class _BidiTextOutput(BidiOutput): @@ -61,11 +53,23 @@ async def __call__(self, event: BidiOutputEvent) -> None: class BidiTextIO: - """Handle text input and output from bidi agent.""" + """Handle text input and output to and from bidi agent. + + Accepts input from stdin and outputs to stdout. + """ + + def __init__(self, **config: Any) -> None: + """Initialize I/O. + + Args: + **config: Optional I/O configurations. + - input_prompt (str): Input prompt to display on screen (default: blank) + """ + self._config = config def input(self) -> _BidiTextInput: """Return text processing BidiInput.""" - return _BidiTextInput() + return _BidiTextInput(self._config) def output(self) -> _BidiTextOutput: """Return text processing BidiOutput.""" diff --git a/tests/strands/experimental/bidi/io/test_text.py b/tests/strands/experimental/bidi/io/test_text.py index 9ecf22eaf..5507a8c0f 100644 --- a/tests/strands/experimental/bidi/io/test_text.py +++ b/tests/strands/experimental/bidi/io/test_text.py @@ -7,8 +7,8 @@ @pytest.fixture -def stream_reader(): - with unittest.mock.patch("strands.experimental.bidi.io.text.asyncio.StreamReader") as mock: +def prompt_session(): + with unittest.mock.patch("strands.experimental.bidi.io.text.PromptSession") as mock: yield mock.return_value @@ -28,9 +28,8 @@ def text_output(text_io): @pytest.mark.asyncio -async def test_bidi_text_io_input(stream_reader, text_input): - stream_reader.readline = unittest.mock.AsyncMock() - stream_reader.readline.return_value = b"test value" +async def test_bidi_text_io_input(prompt_session, text_input): + prompt_session.prompt_async = unittest.mock.AsyncMock(return_value="test value") tru_event = await text_input() exp_event = BidiTextInputEvent(text="test value", role="user") @@ -46,7 +45,7 @@ async def test_bidi_text_io_input(stream_reader, text_input): ] ) @pytest.mark.asyncio -async def test_bidi_text_io_output_interrupt(event, exp_print, text_output, capsys): +async def test_bidi_text_io_output(event, exp_print, text_output, capsys): await text_output(event) tru_print = capsys.readouterr().out.strip() From 651565e4b292d33567bebc2b62a8f012dfebf3c1 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 28 Nov 2025 10:16:14 -0500 Subject: [PATCH 212/242] types - events - minor clean ups (#92) --- .../experimental/bidi/models/gemini_live.py | 14 +- .../experimental/bidi/models/novasonic.py | 11 +- .../experimental/bidi/models/openai.py | 5 +- .../experimental/bidi/types/bidi_model.py | 12 +- src/strands/experimental/bidi/types/events.py | 137 ++++++++++-------- src/strands/experimental/hooks/events.py | 10 +- 6 files changed, 106 insertions(+), 83 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 802551394..201773a9d 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -26,6 +26,8 @@ from .._async import stop_all from ..types.bidi_model import AudioConfig from ..types.events import ( + AudioChannel, + AudioSampleRate, BidiAudioInputEvent, BidiAudioStreamEvent, BidiConnectionStartEvent, @@ -36,18 +38,16 @@ BidiTextInputEvent, BidiTranscriptStreamEvent, BidiUsageEvent, - Channel, ModalityUsage, - SampleRate, ) from .bidi_model import BidiModel logger = logging.getLogger(__name__) # Audio format constants -GEMINI_INPUT_SAMPLE_RATE = 16000 -GEMINI_OUTPUT_SAMPLE_RATE = 24000 -GEMINI_CHANNELS = 1 +GEMINI_INPUT_SAMPLE_RATE: AudioSampleRate = 16000 +GEMINI_OUTPUT_SAMPLE_RATE: AudioSampleRate = 24000 +GEMINI_CHANNELS: AudioChannel = 1 class BidiGeminiLiveModel(BidiModel): @@ -274,8 +274,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut BidiAudioStreamEvent( audio=audio_b64, format="pcm", - sample_rate=cast(SampleRate, GEMINI_OUTPUT_SAMPLE_RATE), - channels=cast(Channel, GEMINI_CHANNELS), + sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, + channels=GEMINI_CHANNELS, ) ] diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 6ee917022..ed6d183d7 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -36,6 +36,8 @@ from .._async import stop_all from ..types.bidi_model import AudioConfig from ..types.events import ( + AudioChannel, + AudioSampleRate, BidiAudioInputEvent, BidiAudioStreamEvent, BidiConnectionStartEvent, @@ -47,7 +49,6 @@ BidiTextInputEvent, BidiTranscriptStreamEvent, BidiUsageEvent, - SampleRate, ) from .bidi_model import BidiModel @@ -139,9 +140,9 @@ def __init__( # Define default audio configuration default_audio_config: AudioConfig = { - "input_rate": cast(int, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), - "output_rate": cast(int, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), - "channels": cast(int, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), + "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), + "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), + "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), "format": "pcm", "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), } @@ -476,7 +477,7 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N return BidiAudioStreamEvent( audio=audio_content, format="pcm", - sample_rate=cast(SampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), + sample_rate=cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), channels=channels, ) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 2bdf62e9c..46a9de14f 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -19,6 +19,7 @@ from .._async import stop_all from ..types.bidi_model import AudioConfig from ..types.events import ( + AudioSampleRate, BidiAudioInputEvent, BidiAudioStreamEvent, BidiConnectionStartEvent, @@ -131,8 +132,8 @@ def __init__( # Define default audio configuration default_audio_config: AudioConfig = { - "input_rate": cast(int, AUDIO_FORMAT["rate"]), - "output_rate": cast(int, AUDIO_FORMAT["rate"]), + "input_rate": cast(AudioSampleRate, AUDIO_FORMAT["rate"]), + "output_rate": cast(AudioSampleRate, AUDIO_FORMAT["rate"]), "channels": 1, "format": "pcm", "voice": session_config_voice, diff --git a/src/strands/experimental/bidi/types/bidi_model.py b/src/strands/experimental/bidi/types/bidi_model.py index 32aaa0079..de41de1a9 100644 --- a/src/strands/experimental/bidi/types/bidi_model.py +++ b/src/strands/experimental/bidi/types/bidi_model.py @@ -5,7 +5,9 @@ processing requirements. """ -from typing import Literal, TypedDict +from typing import TypedDict + +from .events import AudioChannel, AudioFormat, AudioSampleRate class AudioConfig(TypedDict, total=False): @@ -27,8 +29,8 @@ class AudioConfig(TypedDict, total=False): voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") """ - input_rate: int - output_rate: int - channels: int - format: Literal["pcm", "wav", "opus", "mp3"] + input_rate: AudioSampleRate + output_rate: AudioSampleRate + channels: AudioChannel + format: AudioFormat voice: str diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index e9f53d0e6..95c0b5710 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -19,17 +19,36 @@ - Audio data stored as base64-encoded strings for JSON compatibility """ -from typing import Any, Dict, List, Literal, Optional, cast +from typing import Any, Literal, cast from ....types._events import ModelStreamEvent, ToolUseStreamEvent, TypedEvent from ....types.streaming import ContentBlockDelta -# Audio format constants +AudioChannel = Literal[1, 2] +"""Number of audio channels. +- Mono: 1 +- Stereo: 2 +""" AudioFormat = Literal["pcm", "wav", "opus", "mp3"] -SampleRate = Literal[16000, 24000, 48000] -Channel = Literal[1, 2] # 1=mono, 2=stereo +"""Audio encoding format.""" +AudioSampleRate = Literal[16000, 24000, 48000] +"""Audio sample rate in Hz.""" + Role = Literal["user", "assistant"] -StopReason = Literal["complete", "interrupted", "tool_use", "error"] +"""Role of a message sender. + +- "user": Messages from the user to the assistant. +- "assistant": Messages from the assistant to the user. +""" + +StopReason = Literal["complete", "error", "interrupted", "tool_use"] +"""Reason for the model ending its response generation. + +- "complete": Model completed its response. +- "error": Model encountered an error. +- "interrupted": Model was interrupted by the user. +- "tool_use": Model is requesting a tool use. +""" # ============================================================================ # Input Events (sent via agent.send()) @@ -59,7 +78,7 @@ def __init__(self, text: str, role: Role = "user"): @property def text(self) -> str: """The text content to send to the model.""" - return cast(str, self.get("text")) + return cast(str, self["text"]) @property def role(self) -> Role: @@ -83,8 +102,8 @@ def __init__( self, audio: str, format: AudioFormat | str, - sample_rate: SampleRate, - channels: Channel, + sample_rate: AudioSampleRate, + channels: AudioChannel, ): """Initialize audio input event.""" super().__init__( @@ -100,22 +119,22 @@ def __init__( @property def audio(self) -> str: """Base64-encoded audio string.""" - return cast(str, self.get("audio")) + return cast(str, self["audio"]) @property - def format(self) -> str: + def format(self) -> AudioFormat: """Audio encoding format.""" - return cast(str, self.get("format")) + return cast(AudioFormat, self["format"]) @property - def sample_rate(self) -> int: + def sample_rate(self) -> AudioSampleRate: """Number of audio samples per second in Hz.""" - return cast(int, self.get("sample_rate")) + return cast(AudioSampleRate, self["sample_rate"]) @property - def channels(self) -> int: + def channels(self) -> AudioChannel: """Number of audio channels (1=mono, 2=stereo).""" - return cast(int, self.get("channels")) + return cast(AudioChannel, self["channels"]) class BidiImageInputEvent(TypedEvent): @@ -145,12 +164,12 @@ def __init__( @property def image(self) -> str: """Base64-encoded image string.""" - return cast(str, self.get("image")) + return cast(str, self["image"]) @property def mime_type(self) -> str: """MIME type of the image (e.g., "image/jpeg", "image/png").""" - return cast(str, self.get("mime_type")) + return cast(str, self["mime_type"]) # ============================================================================ @@ -179,12 +198,12 @@ def __init__(self, connection_id: str, model: str): @property def connection_id(self) -> str: """Unique identifier for this streaming connection.""" - return cast(str, self.get("connection_id")) + return cast(str, self["connection_id"]) @property def model(self) -> str: """Model identifier (e.g., 'gpt-realtime', 'gemini-2.0-flash-live').""" - return cast(str, self.get("model")) + return cast(str, self["model"]) class BidiResponseStartEvent(TypedEvent): @@ -201,7 +220,7 @@ def __init__(self, response_id: str): @property def response_id(self) -> str: """Unique identifier for this response.""" - return cast(str, self.get("response_id")) + return cast(str, self["response_id"]) class BidiAudioStreamEvent(TypedEvent): @@ -218,8 +237,8 @@ def __init__( self, audio: str, format: AudioFormat, - sample_rate: SampleRate, - channels: Channel, + sample_rate: AudioSampleRate, + channels: AudioChannel, ): """Initialize audio stream event.""" super().__init__( @@ -235,22 +254,22 @@ def __init__( @property def audio(self) -> str: """Base64-encoded audio string.""" - return cast(str, self.get("audio")) + return cast(str, self["audio"]) @property - def format(self) -> str: + def format(self) -> AudioFormat: """Audio encoding format.""" - return cast(str, self.get("format")) + return cast(AudioFormat, self["format"]) @property - def sample_rate(self) -> int: + def sample_rate(self) -> AudioSampleRate: """Number of audio samples per second in Hz.""" - return cast(int, self.get("sample_rate")) + return cast(AudioSampleRate, self["sample_rate"]) @property - def channels(self) -> int: + def channels(self) -> AudioChannel: """Number of audio channels (1=mono, 2=stereo).""" - return cast(int, self.get("channels")) + return cast(AudioChannel, self["channels"]) class BidiTranscriptStreamEvent(ModelStreamEvent): @@ -273,7 +292,7 @@ def __init__( text: str, role: Role, is_final: bool, - current_transcript: Optional[str] = None, + current_transcript: str | None = None, ): """Initialize transcript stream event.""" super().__init__( @@ -290,12 +309,12 @@ def __init__( @property def delta(self) -> ContentBlockDelta: """The incremental transcript change.""" - return cast(ContentBlockDelta, self.get("delta")) + return cast(ContentBlockDelta, self["delta"]) @property def text(self) -> str: """The text content to send to the model.""" - return cast(str, self.get("text")) + return cast(str, self["text"]) @property def role(self) -> Role: @@ -305,12 +324,12 @@ def role(self) -> Role: @property def is_final(self) -> bool: """Whether this is the final/complete transcript.""" - return cast(bool, self.get("is_final")) + return cast(bool, self["is_final"]) @property - def current_transcript(self) -> Optional[str]: + def current_transcript(self) -> str | None: """The accumulated transcript text so far.""" - return cast(Optional[str], self.get("current_transcript")) + return cast(str | None, self.get("current_transcript")) class BidiInterruptionEvent(TypedEvent): @@ -333,7 +352,7 @@ def __init__(self, reason: Literal["user_speech", "error"]): @property def reason(self) -> str: """Why the interruption occurred.""" - return cast(str, self.get("reason")) + return cast(str, self["reason"]) class BidiResponseCompleteEvent(TypedEvent): @@ -361,12 +380,12 @@ def __init__( @property def response_id(self) -> str: """Unique identifier for this response.""" - return cast(str, self.get("response_id")) + return cast(str, self["response_id"]) @property - def stop_reason(self) -> str: + def stop_reason(self) -> StopReason: """Why the response ended.""" - return cast(str, self.get("stop_reason")) + return cast(StopReason, self["stop_reason"]) class ModalityUsage(dict): @@ -403,12 +422,12 @@ def __init__( input_tokens: int, output_tokens: int, total_tokens: int, - modality_details: Optional[List[ModalityUsage]] = None, - cache_read_input_tokens: Optional[int] = None, - cache_write_input_tokens: Optional[int] = None, + modality_details: list[ModalityUsage] | None = None, + cache_read_input_tokens: int | None = None, + cache_write_input_tokens: int | None = None, ): """Initialize usage event.""" - data: Dict[str, Any] = { + data: dict[str, Any] = { "type": "bidi_usage", "inputTokens": input_tokens, "outputTokens": output_tokens, @@ -425,32 +444,32 @@ def __init__( @property def input_tokens(self) -> int: """Total tokens used for all input modalities.""" - return cast(int, self.get("inputTokens")) + return cast(int, self["inputTokens"]) @property def output_tokens(self) -> int: """Total tokens used for all output modalities.""" - return cast(int, self.get("outputTokens")) + return cast(int, self["outputTokens"]) @property def total_tokens(self) -> int: """Sum of input and output tokens.""" - return cast(int, self.get("totalTokens")) + return cast(int, self["totalTokens"]) @property - def modality_details(self) -> List[ModalityUsage]: + def modality_details(self) -> list[ModalityUsage]: """Optional list of token usage per modality.""" - return cast(List[ModalityUsage], self.get("modality_details", [])) + return cast(list[ModalityUsage], self.get("modality_details", [])) @property - def cache_read_input_tokens(self) -> Optional[int]: + def cache_read_input_tokens(self) -> int | None: """Optional tokens read from cache.""" - return cast(Optional[int], self.get("cacheReadInputTokens")) + return cast(int | None, self.get("cacheReadInputTokens")) @property - def cache_write_input_tokens(self) -> Optional[int]: + def cache_write_input_tokens(self) -> int | None: """Optional tokens written to cache.""" - return cast(Optional[int], self.get("cacheWriteInputTokens")) + return cast(int | None, self.get("cacheWriteInputTokens")) class BidiConnectionCloseEvent(TypedEvent): @@ -478,12 +497,12 @@ def __init__( @property def connection_id(self) -> str: """Unique identifier for this streaming connection.""" - return cast(str, self.get("connection_id")) + return cast(str, self["connection_id"]) @property def reason(self) -> str: """Why the interruption occurred.""" - return cast(str, self.get("reason")) + return cast(str, self["reason"]) class BidiErrorEvent(TypedEvent): @@ -501,7 +520,7 @@ class BidiErrorEvent(TypedEvent): def __init__( self, error: Exception, - details: Optional[Dict[str, Any]] = None, + details: dict[str, Any] | None = None, ): """Initialize error event.""" # Store serializable data in dict (for JSON serialization) @@ -527,17 +546,17 @@ def error(self) -> Exception: @property def code(self) -> str: """Error code derived from exception class name.""" - return cast(str, self.get("code")) + return cast(str, self["code"]) @property def message(self) -> str: """Human-readable error message from the exception.""" - return cast(str, self.get("message")) + return cast(str, self["message"]) @property - def details(self) -> Optional[Dict[str, Any]]: + def details(self) -> dict[str, Any] | None: """Additional error context beyond the exception itself.""" - return cast(Optional[Dict[str, Any]], self.get("details")) + return cast(dict[str, Any] | None, self.get("details")) # ============================================================================ diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 485e8d201..6403df02f 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -7,7 +7,7 @@ import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias +from typing import TYPE_CHECKING, Any, Literal, TypeAlias from ...hooks.events import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, BeforeToolCallEvent from ...hooks.registry import BaseHookEvent @@ -129,7 +129,7 @@ class BidiBeforeToolCallEvent(BidiHookEvent): the tool call and use a default cancel message. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] cancel_tool: bool | str = False @@ -160,11 +160,11 @@ class BidiAfterToolCallEvent(BidiHookEvent): cancel_message: The cancellation message if the user cancelled the tool call. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] result: ToolResult - exception: Optional[Exception] = None + exception: Exception | None = None cancel_message: str | None = None def _can_write(self, name: str) -> bool: @@ -193,4 +193,4 @@ class BidiInterruptionEvent(BidiHookEvent): """ reason: Literal["user_speech", "error"] - interrupted_response_id: Optional[str] = None + interrupted_response_id: str | None = None From be4c10f2cb013ca9f29e67a95ee6c6e141d07f32 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 28 Nov 2025 10:50:06 -0500 Subject: [PATCH 213/242] model timeout - restart connection - nova sonic (#84) --- src/strands/experimental/bidi/agent/loop.py | 59 ++++++++++++++++++- .../experimental/bidi/models/__init__.py | 3 +- .../experimental/bidi/models/bidi_model.py | 11 ++++ .../experimental/bidi/models/novasonic.py | 17 +++++- .../experimental/bidi/types/__init__.py | 2 + src/strands/experimental/bidi/types/events.py | 28 ++++++++- src/strands/experimental/hooks/events.py | 24 ++++++++ .../experimental/bidi/agent/test_loop.py | 51 ++++++++++++++++ .../bidi/models/test_novasonic.py | 28 +++++++++ 9 files changed, 216 insertions(+), 7 deletions(-) create mode 100644 tests/strands/experimental/bidi/agent/test_loop.py diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 4d00cd714..76907b6fd 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -11,7 +11,9 @@ from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ...hooks.events import ( + BidiAfterConnectionRestartEvent, BidiAfterInvocationEvent, + BidiBeforeConnectionRestartEvent, BidiBeforeInvocationEvent, BidiMessageAddedEvent, ) @@ -19,7 +21,9 @@ BidiInterruptionEvent as BidiInterruptionHookEvent, ) from .._async import _TaskPool, stop_all +from ..models import BidiModelTimeoutError from ..types.events import ( + BidiConnectionRestartEvent, BidiInputEvent, BidiInterruptionEvent, BidiOutputEvent, @@ -44,6 +48,8 @@ class _BidiAgentLoop: _invocation_state: Optional context to pass to tools during execution. This allows passing custom data (user_id, session_id, database connections, etc.) that tools can access via their invocation_state parameter. + _send_gate: Gate the sending of events to the model. + Blocks when agent is reseting the model connection after timeout. """ def __init__(self, agent: "BidiAgent") -> None: @@ -60,6 +66,8 @@ def __init__(self, agent: "BidiAgent") -> None: self._event_queue: asyncio.Queue self._invocation_state: dict[str, Any] + self._send_gate = asyncio.Event() + async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start the agent loop. @@ -92,6 +100,7 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: self._task_pool.create(self._run_model()) self._invocation_state = invocation_state or {} + self._send_gate.set() self._started = True async def stop(self) -> None: @@ -99,6 +108,7 @@ async def stop(self) -> None: logger.debug("agent loop stopping") self._started = False + self._send_gate.clear() self._invocation_state = {} async def stop_tasks() -> None: @@ -112,14 +122,21 @@ async def stop_model() -> None: finally: await self._agent.hooks.invoke_callbacks_async(BidiAfterInvocationEvent(agent=self._agent)) - async def send(self, event: BidiInputEvent) -> None: + async def send(self, event: BidiInputEvent | ToolResultEvent) -> None: """Send model event. - Additional, add text input to messages array. + Additionally, add text input to messages array. Args: event: BidiInputEvent. """ + if not self._started: + raise RuntimeError("loop not started | call start before sending") + + if not self._send_gate.is_set(): + logger.debug("waiting for model send signal") + await self._send_gate.wait() + if isinstance(event, BidiTextInputEvent): message: Message = {"role": "user", "content": [{"text": event.text}]} self._agent.messages.append(message) @@ -138,11 +155,47 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: while True: event = await self._event_queue.get() + if isinstance(event, BidiModelTimeoutError): + logger.debug("model timeout error received") + yield BidiConnectionRestartEvent(event) + await self._restart_connection(event) + continue + if isinstance(event, Exception): raise event yield event + async def _restart_connection(self, timeout_error: BidiModelTimeoutError) -> None: + """Restart the model connection after timeout. + + Args: + timeout_error: Timeout error reported by the model. + """ + logger.debug("reseting model connection") + + self._send_gate.clear() + + await self._agent.hooks.invoke_callbacks_async(BidiBeforeConnectionRestartEvent(self._agent, timeout_error)) + + restart_exception = None + try: + await self._agent.model.stop() + await self._agent.model.start( + self._agent.system_prompt, + self._agent.tool_registry.get_all_tool_specs(), + self._agent.messages, + ) + self._task_pool.create(self._run_model()) + except Exception as exception: + restart_exception = exception + finally: + await self._agent.hooks.invoke_callbacks_async( + BidiAfterConnectionRestartEvent(self._agent, restart_exception) + ) + + self._send_gate.set() + async def _run_model(self) -> None: """Task for running the model. @@ -217,7 +270,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: if isinstance(event, ToolResultEvent): result = event.tool_result - await self._agent.model.send(ToolResultEvent(result)) + await self.send(ToolResultEvent(result)) message: Message = { "role": "user", diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index 13aaa9697..d1221df36 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,12 +1,13 @@ """Bidirectional model interfaces and implementations.""" -from .bidi_model import BidiModel +from .bidi_model import BidiModel, BidiModelTimeoutError from .gemini_live import BidiGeminiLiveModel from .novasonic import BidiNovaSonicModel from .openai import BidiOpenAIRealtimeModel __all__ = [ "BidiModel", + "BidiModelTimeoutError", "BidiGeminiLiveModel", "BidiNovaSonicModel", "BidiOpenAIRealtimeModel", diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 1fb765bd8..bc2806e78 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -108,3 +108,14 @@ async def send( await model.send(ToolResultEvent(tool_result)) """ ... + + +class BidiModelTimeoutError(Exception): + """Model timeout error. + + Bidirectional models are often configured with a connection time limit. Nova sonic for example keeps the connection + open for 8 minutes max. Upon receiving a timeout, the agent loop is configured to restart the model connection so as + to create a seamless, uninterrupted experience for the user. + """ + + pass diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index ed6d183d7..a4e2952d0 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -25,6 +25,8 @@ from aws_sdk_bedrock_runtime.models import ( BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, + ModelTimeoutException, + ValidationException, ) from smithy_aws_core.identity.static import StaticCredentialsResolver from smithy_core.aio.eventstream import DuplexEventStream @@ -50,7 +52,7 @@ BidiTranscriptStreamEvent, BidiUsageEvent, ) -from .bidi_model import BidiModel +from .bidi_model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) @@ -272,7 +274,18 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: _, output = await self._stream.await_output() while True: - event_data = await output.receive() + try: + event_data = await output.receive() + + except ValidationException as error: + if "InternalErrorCode=531" in str(error): + # nova also times out if user is silent for 175 seconds + raise BidiModelTimeoutError(error) from error + raise + + except ModelTimeoutException as error: + raise BidiModelTimeoutError(error) from error + if not event_data: continue diff --git a/src/strands/experimental/bidi/types/__init__.py b/src/strands/experimental/bidi/types/__init__.py index 1fa1d9048..903a54508 100644 --- a/src/strands/experimental/bidi/types/__init__.py +++ b/src/strands/experimental/bidi/types/__init__.py @@ -5,6 +5,7 @@ BidiAudioInputEvent, BidiAudioStreamEvent, BidiConnectionCloseEvent, + BidiConnectionRestartEvent, BidiConnectionStartEvent, BidiErrorEvent, BidiImageInputEvent, @@ -31,6 +32,7 @@ "BidiInputEvent", # Output Events "BidiConnectionStartEvent", + "BidiConnectionRestartEvent", "BidiConnectionCloseEvent", "BidiResponseStartEvent", "BidiResponseCompleteEvent", diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 95c0b5710..e6504c2a4 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -19,11 +19,14 @@ - Audio data stored as base64-encoded strings for JSON compatibility """ -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast from ....types._events import ModelStreamEvent, ToolUseStreamEvent, TypedEvent from ....types.streaming import ContentBlockDelta +if TYPE_CHECKING: + from ..models.bidi_model import BidiModelTimeoutError + AudioChannel = Literal[1, 2] """Number of audio channels. - Mono: 1 @@ -206,6 +209,28 @@ def model(self) -> str: return cast(str, self["model"]) +class BidiConnectionRestartEvent(TypedEvent): + """Agent is restarting the model connection after timeout.""" + + def __init__(self, timeout_error: "BidiModelTimeoutError"): + """Initialize. + + Args: + timeout_error: Timeout error reported by the model. + """ + super().__init__( + { + "type": "bidi_connection_restart", + "timeout_error": timeout_error, + } + ) + + @property + def timeout_error(self) -> "BidiModelTimeoutError": + """Model timeout error.""" + return cast("BidiModelTimeoutError", self["timeout_error"]) + + class BidiResponseStartEvent(TypedEvent): """Model starts generating a response. @@ -570,6 +595,7 @@ def details(self) -> dict[str, Any] | None: BidiOutputEvent = ( BidiConnectionStartEvent + | BidiConnectionRestartEvent | BidiResponseStartEvent | BidiAudioStreamEvent | BidiTranscriptStreamEvent diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 6403df02f..f486f5ec4 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from ..bidi.agent.agent import BidiAgent + from ..bidi.models import BidiModelTimeoutError warnings.warn( "These events have been moved to production with updated names. Use BeforeModelCallEvent, " @@ -194,3 +195,26 @@ class BidiInterruptionEvent(BidiHookEvent): reason: Literal["user_speech", "error"] interrupted_response_id: str | None = None + + +@dataclass +class BidiBeforeConnectionRestartEvent(BidiHookEvent): + """Event emitted before agent attempts to restart model connection after timeout. + + Attributes: + timeout_error: Timeout error reported by the model. + """ + + timeout_error: "BidiModelTimeoutError" + + +@dataclass +class BidiAfterConnectionRestartEvent(BidiHookEvent): + """Event emitted after agent attempts to restart model connection after timeout. + + Attribtues: + exception: Populated if exception was raised during connection restart. + None value means the restart was successful. + """ + + exception: Exception | None = None diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py new file mode 100644 index 000000000..1ec5712af --- /dev/null +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -0,0 +1,51 @@ +import unittest.mock + +import pytest +import pytest_asyncio + +from strands.experimental.bidi.agent.loop import _BidiAgentLoop +from strands.experimental.bidi.models import BidiModelTimeoutError +from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent + + +@pytest.fixture +def agent(): + mock = unittest.mock.Mock() + mock.hooks = unittest.mock.AsyncMock() + mock.model = unittest.mock.AsyncMock() + return mock + + +@pytest_asyncio.fixture +async def loop(agent): + return _BidiAgentLoop(agent) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerator): + timeout_error = BidiModelTimeoutError("test timeout") + text_event = BidiTextInputEvent(text="test after restart") + + agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])]) + + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 2: + break + + exp_events = [ + BidiConnectionRestartEvent(timeout_error), + text_event, + ] + assert tru_events == exp_events + + agent.model.stop.assert_called_once() + assert agent.model.start.call_count == 2 + agent.model.start.assert_any_call( + agent.system_prompt, + agent.tool_registry.get_all_tool_specs.return_value, + agent.messages, + ) diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 19aac19c2..744dedf12 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -11,10 +11,12 @@ import pytest import pytest_asyncio +from aws_sdk_bedrock_runtime.models import ModelTimeoutException, ValidationException from strands.experimental.bidi.models.novasonic import ( BidiNovaSonicModel, ) +from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -524,6 +526,32 @@ async def test_message_history_empty_and_edge_cases(nova_model): # Error Handling Tests +@pytest.mark.asyncio +async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): + mock_output = AsyncMock() + mock_output.receive.side_effect = ModelTimeoutException("Connection timeout") + mock_stream.await_output.return_value = (None, mock_output) + + await nova_model.start() + + with pytest.raises(BidiModelTimeoutError): + async for _ in nova_model.receive(): + pass + + +@pytest.mark.asyncio +async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock_stream): + mock_output = AsyncMock() + mock_output.receive.side_effect = ValidationException("InternalErrorCode=531: Request timeout") + mock_stream.await_output.return_value = (None, mock_output) + + await nova_model.start() + + with pytest.raises(BidiModelTimeoutError): + async for _ in nova_model.receive(): + pass + + @pytest.mark.asyncio async def test_error_handling(nova_model, mock_stream): """Test error handling in various scenarios.""" From 2b3438a5ecab395040686b0666b296525706105e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 28 Nov 2025 11:24:36 -0500 Subject: [PATCH 214/242] add feature to stop connection through voice --- src/strands/experimental/bidi/__init__.py | 5 ++++ src/strands/experimental/bidi/agent/agent.py | 23 ++++++++++++++++++- src/strands/experimental/bidi/agent/loop.py | 10 ++++++++ src/strands/experimental/bidi/io/text.py | 13 ++++++++++- .../experimental/bidi/tools/__init__.py | 5 ++++ .../bidi/tools/stop_connection.py | 19 +++++++++++++++ src/strands/experimental/bidi/types/events.py | 2 +- 7 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 src/strands/experimental/bidi/tools/__init__.py create mode 100644 src/strands/experimental/bidi/tools/stop_connection.py diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 712451af9..9549cb805 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -25,6 +25,9 @@ from .models.novasonic import BidiNovaSonicModel from .models.openai import BidiOpenAIRealtimeModel +# Built-in tools +from .tools import stop_connection + # Event types - For type hints and event handling from .types.events import ( BidiAudioInputEvent, @@ -53,6 +56,8 @@ "BidiGeminiLiveModel", "BidiNovaSonicModel", "BidiOpenAIRealtimeModel", + # Built-in tools + "stop_connection", # Input Event types "BidiTextInputEvent", "BidiAudioInputEvent", diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 1ec4a63ee..c11bb426c 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -33,7 +33,14 @@ from ..models.bidi_model import BidiModel from ..models.novasonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput -from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiInputEvent, BidiOutputEvent, BidiTextInputEvent +from ..types.events import ( + BidiAudioInputEvent, + BidiConnectionCloseEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiOutputEvent, + BidiTextInputEvent, +) from ..types.io import BidiInput, BidiOutput from .loop import _BidiAgentLoop @@ -358,6 +365,9 @@ async def run_inputs() -> None: async def task(input_: BidiInput) -> None: while True: event = await input_() + if not self._started: + logger.debug("agent stopped, exiting input task") + return await self.send(event) tasks = [task(input_) for input_ in inputs] @@ -368,6 +378,17 @@ async def run_outputs() -> None: tasks = [output(event) for output in outputs] await asyncio.gather(*tasks) + if isinstance(event, BidiConnectionCloseEvent) and event.reason == "user_request": + logger.info( + "connection_id=<%s>, reason=<%s> | graceful shutdown initiated", + event.connection_id, + event.reason, + ) + # Set flag to signal shutdown + self._started = False + # Return - TaskGroup will cancel remaining tasks, finally block handles cleanup + return + try: await self.start(invocation_state) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 4d00cd714..9f9c50792 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -20,6 +20,7 @@ ) from .._async import _TaskPool, stop_all from ..types.events import ( + BidiConnectionCloseEvent, BidiInputEvent, BidiInterruptionEvent, BidiOutputEvent, @@ -227,5 +228,14 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) await self._event_queue.put(ToolResultMessageEvent(message)) + # Check if this was the stop_connection tool + if tool_use["name"] == "stop_connection": + logger.info("tool_name=<%s> | connection stop requested by tool", tool_use["name"]) + # Get connection_id from the model + connection_id = getattr(self._agent.model, "_connection_id", None) or "unknown" + await self._event_queue.put( + BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") + ) + except Exception as error: await self._event_queue.put(error) diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index e123de766..b7daad832 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -5,7 +5,13 @@ import sys from typing import TYPE_CHECKING -from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTextInputEvent, BidiTranscriptStreamEvent +from ..types.events import ( + BidiConnectionCloseEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) from ..types.io import BidiInput, BidiOutput if TYPE_CHECKING: @@ -42,6 +48,11 @@ async def __call__(self, event: BidiOutputEvent) -> None: logger.debug("reason=<%s> | text output interrupted", event["reason"]) print("interrupted") + elif isinstance(event, BidiConnectionCloseEvent): + if event.reason == "user_request": + logger.debug("connection_id=<%s> | user requested connection close", event.connection_id) + else: + logger.debug("connection_id=<%s>, reason=<%s> | connection closed", event.connection_id, event.reason) elif isinstance(event, BidiTranscriptStreamEvent): text = event["text"] is_final = event["is_final"] diff --git a/src/strands/experimental/bidi/tools/__init__.py b/src/strands/experimental/bidi/tools/__init__.py new file mode 100644 index 000000000..e821ada5d --- /dev/null +++ b/src/strands/experimental/bidi/tools/__init__.py @@ -0,0 +1,5 @@ +"""Built-in tools for bidirectional agents.""" + +from .stop_connection import stop_connection + +__all__ = ["stop_connection"] diff --git a/src/strands/experimental/bidi/tools/stop_connection.py b/src/strands/experimental/bidi/tools/stop_connection.py new file mode 100644 index 000000000..9fdd78975 --- /dev/null +++ b/src/strands/experimental/bidi/tools/stop_connection.py @@ -0,0 +1,19 @@ +"""Tool to gracefully stop a bidirectional connection.""" + +from ....tools.decorator import tool + + +@tool +def stop_connection() -> dict: + """Stop the bidirectional conversation gracefully. + + Use this tool when the user wants to end the conversation, such as when they say: + goodbye, bye, end conversation, stop, exit, quit, that's all, or I'm done. + + Returns: + Success message confirming the conversation will end + """ + return { + "status": "success", + "content": [{"text": "Ending conversation"}], + } diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index e9f53d0e6..c1ebf2a84 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -464,7 +464,7 @@ class BidiConnectionCloseEvent(TypedEvent): def __init__( self, connection_id: str, - reason: Literal["client_disconnect", "timeout", "error", "complete"], + reason: Literal["client_disconnect", "timeout", "error", "complete", "user_request"], ): """Initialize connection close event.""" super().__init__( From 2cf2c13c4a52b3dc9ba660d33157e8459b260533 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 28 Nov 2025 23:35:05 -0500 Subject: [PATCH 215/242] update stop_connection tool to stop_coversation --- src/strands/experimental/bidi/__init__.py | 4 ++-- src/strands/experimental/bidi/agent/agent.py | 19 ++++++------------- src/strands/experimental/bidi/agent/loop.py | 9 ++++----- src/strands/experimental/bidi/io/text.py | 3 +-- .../experimental/bidi/tools/__init__.py | 4 ++-- .../bidi/tools/stop_connection.py | 19 ------------------- 6 files changed, 15 insertions(+), 43 deletions(-) delete mode 100644 src/strands/experimental/bidi/tools/stop_connection.py diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 9549cb805..7e2ad2bb3 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -26,7 +26,7 @@ from .models.openai import BidiOpenAIRealtimeModel # Built-in tools -from .tools import stop_connection +from .tools import stop_conversation # Event types - For type hints and event handling from .types.events import ( @@ -57,7 +57,7 @@ "BidiNovaSonicModel", "BidiOpenAIRealtimeModel", # Built-in tools - "stop_connection", + "stop_conversation", # Input Event types "BidiTextInputEvent", "BidiAudioInputEvent", diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index c11bb426c..11e533b08 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -365,18 +365,13 @@ async def run_inputs() -> None: async def task(input_: BidiInput) -> None: while True: event = await input_() - if not self._started: - logger.debug("agent stopped, exiting input task") - return await self.send(event) - tasks = [task(input_) for input_ in inputs] - await asyncio.gather(*tasks) + await asyncio.gather(*[task(input_) for input_ in inputs]) - async def run_outputs() -> None: + async def run_outputs(inputs_task: asyncio.Task) -> None: async for event in self.receive(): - tasks = [output(event) for output in outputs] - await asyncio.gather(*tasks) + await asyncio.gather(*[output(event) for output in outputs]) if isinstance(event, BidiConnectionCloseEvent) and event.reason == "user_request": logger.info( @@ -384,9 +379,7 @@ async def run_outputs() -> None: event.connection_id, event.reason, ) - # Set flag to signal shutdown - self._started = False - # Return - TaskGroup will cancel remaining tasks, finally block handles cleanup + inputs_task.cancel() return try: @@ -398,8 +391,8 @@ async def run_outputs() -> None: await start(self) async with asyncio.TaskGroup() as task_group: - task_group.create_task(run_inputs()) - task_group.create_task(run_outputs()) + inputs_task = task_group.create_task(run_inputs()) + task_group.create_task(run_outputs(inputs_task)) finally: input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 2160d4ecd..119fd07b9 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -281,11 +281,10 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) await self._event_queue.put(ToolResultMessageEvent(message)) - # Check if this was the stop_connection tool - if tool_use["name"] == "stop_connection": - logger.info("tool_name=<%s> | connection stop requested by tool", tool_use["name"]) - # Get connection_id from the model - connection_id = getattr(self._agent.model, "_connection_id", None) or "unknown" + # Check if this was the stop_conversation tool + if tool_use["name"] == "stop_conversation": + logger.info("tool_name=<%s> | conversation stop requested by tool", tool_use["name"]) + connection_id = getattr(self._agent.model, "_connection_id", "unknown") await self._event_queue.put( BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") ) diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 31577ecf0..1fe906de0 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -42,9 +42,8 @@ async def __call__(self, event: BidiOutputEvent) -> None: elif isinstance(event, BidiConnectionCloseEvent): if event.reason == "user_request": + print("user requested connection close using the stop_conversation tool.") logger.debug("connection_id=<%s> | user requested connection close", event.connection_id) - else: - logger.debug("connection_id=<%s>, reason=<%s> | connection closed", event.connection_id, event.reason) elif isinstance(event, BidiTranscriptStreamEvent): text = event["text"] is_final = event["is_final"] diff --git a/src/strands/experimental/bidi/tools/__init__.py b/src/strands/experimental/bidi/tools/__init__.py index e821ada5d..c665dc65a 100644 --- a/src/strands/experimental/bidi/tools/__init__.py +++ b/src/strands/experimental/bidi/tools/__init__.py @@ -1,5 +1,5 @@ """Built-in tools for bidirectional agents.""" -from .stop_connection import stop_connection +from .stop_conversation import stop_conversation -__all__ = ["stop_connection"] +__all__ = ["stop_conversation"] diff --git a/src/strands/experimental/bidi/tools/stop_connection.py b/src/strands/experimental/bidi/tools/stop_connection.py deleted file mode 100644 index 9fdd78975..000000000 --- a/src/strands/experimental/bidi/tools/stop_connection.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Tool to gracefully stop a bidirectional connection.""" - -from ....tools.decorator import tool - - -@tool -def stop_connection() -> dict: - """Stop the bidirectional conversation gracefully. - - Use this tool when the user wants to end the conversation, such as when they say: - goodbye, bye, end conversation, stop, exit, quit, that's all, or I'm done. - - Returns: - Success message confirming the conversation will end - """ - return { - "status": "success", - "content": [{"text": "Ending conversation"}], - } From 4e68657c981c38a24099603466536ba21e69aee8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 28 Nov 2025 23:36:33 -0500 Subject: [PATCH 216/242] update stop_connection tool to stop_coversation --- .../bidi/tools/stop_conversation.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 src/strands/experimental/bidi/tools/stop_conversation.py diff --git a/src/strands/experimental/bidi/tools/stop_conversation.py b/src/strands/experimental/bidi/tools/stop_conversation.py new file mode 100644 index 000000000..ad61fb1d0 --- /dev/null +++ b/src/strands/experimental/bidi/tools/stop_conversation.py @@ -0,0 +1,19 @@ +"""Tool to gracefully stop a bidirectional connection.""" + +from ....tools.decorator import tool + + +@tool +def stop_conversation() -> dict: + """Stop the bidirectional conversation gracefully. + + Use ONLY when user says "stop conversation" exactly. + Do NOT use for: "stop", "goodbye", "bye", "exit", "quit", "end" or other farewells or phrases. + + Returns: + Success message confirming the conversation will end + """ + return { + "status": "success", + "content": [{"text": "Ending conversation"}], + } From 625b25a8a4e6e7f1affe920502023edf333ee925 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sat, 29 Nov 2025 01:03:00 -0500 Subject: [PATCH 217/242] updated implementation based on comments --- src/strands/experimental/bidi/agent/agent.py | 4 +- .../experimental/bidi/models/gemini_live.py | 109 ++++++++------- .../experimental/bidi/models/novasonic.py | 70 ++++++---- .../experimental/bidi/models/openai.py | 130 ++++++++++-------- .../bidi/models/test_gemini_live.py | 26 ++-- .../bidi/models/test_novasonic.py | 20 +-- .../experimental/bidi/models/test_openai.py | 86 +++++++----- 7 files changed, 252 insertions(+), 193 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 1ec4a63ee..d4a4c7365 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -344,7 +344,9 @@ async def run( ) # Using custom audio config: - model = BidiNovaSonicModel(config={"audio": {"input_rate": 48000, "output_rate": 24000}}) + model = BidiNovaSonicModel( + provider_config={"audio": {"input_rate": 48000, "output_rate": 24000}} + ) audio_io = BidiAudioIO() agent = BidiAgent(model=model, tools=[calculator]) await agent.run( diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 16949c5fd..103270b72 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -61,91 +61,96 @@ class BidiGeminiLiveModel(BidiModel): def __init__( self, model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", - api_key: str | None = None, - config: dict[str, Any] | None = None, provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, **kwargs: Any, ): """Initialize Gemini Live API bidirectional model. Args: - model_id: Gemini Live model identifier. - api_key: Google AI API key for authentication. - config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. - If not provided or if "audio" key is missing, uses Gemini Live API's default audio configuration. - provider_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). + model_id: Model identifier (default: gemini-2.5-flash-native-audio-preview-09-2025) + provider_config: Model behavior (audio, response_modalities, speech_config, transcription) + client_config: Authentication (api_key, http_options) **kwargs: Reserved for future parameters. + """ - # Model configuration + # Store model ID self.model_id = model_id - self.api_key = api_key - - # Set default live_config with transcription enabled - default_config = { - "response_modalities": ["AUDIO"], - "outputAudioTranscription": {}, # Enable output transcription by default - "inputAudioTranscription": {}, # Enable input transcription by default - } - # Merge user config with defaults (user config takes precedence) - if provider_config: - default_config.update(provider_config) + # Resolve client config with defaults + self._client_config = self._resolve_client_config(client_config or {}) - self.provider_config = default_config + # Resolve provider config with defaults + self._provider_config = self._resolve_provider_config(provider_config or {}) - # Create Gemini client with proper API version - client_kwargs: dict[str, Any] = {} - if api_key: - client_kwargs["api_key"] = api_key + # Extract and store audio config for IO coordination + self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} - # Use v1alpha for Live API as it has better model support - client_kwargs["http_options"] = {"api_version": "v1alpha"} + # Store API key for later use + self.api_key = self._client_config.get("api_key") - self._client = genai.Client(**client_kwargs) + # Create Gemini client + self._client = genai.Client(**self._client_config) # Connection state (initialized in start()) self._live_session: Any = None self._live_session_context_manager: Any = None self._connection_id: str | None = None - # Extract audio config from config dict if provided - user_audio_config = config.get("audio", {}) if config else {} + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve client config (sets default http_options if not provided).""" + resolved = config.copy() + + # Set default http_options if not provided + if "http_options" not in resolved: + resolved["http_options"] = {"api_version": "v1alpha"} + + return resolved - # Extract voice from provider_config if provided - provider_voice = self._extract_voice_from_provider_config() + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + # Extract voice from provider-specific speech_config.voice_config.prebuilt_voice_config.voice_name if present + provider_voice = None + if "speech_config" in config and isinstance(config["speech_config"], dict): + provider_voice = ( + config["speech_config"] + .get("voice_config", {}) + .get("prebuilt_voice_config", {}) + .get("voice_name") + ) # Define default audio configuration - default_audio_config: AudioConfig = { + default_audio: AudioConfig = { "input_rate": GEMINI_INPUT_SAMPLE_RATE, "output_rate": GEMINI_OUTPUT_SAMPLE_RATE, "channels": GEMINI_CHANNELS, "format": "pcm", } - # Add voice to defaults if configured in provider_config if provider_voice: - default_audio_config["voice"] = provider_voice + default_audio["voice"] = provider_voice - # Merge user config with defaults (user values take precedence) - merged_audio_config = cast(AudioConfig, {**default_audio_config, **user_audio_config}) + user_audio = config.get("audio", {}) + merged_audio = {**default_audio, **user_audio} - # Store config with audio defaults always populated - self.config: dict[str, Any] = {"audio": merged_audio_config} + default_provider_settings = { + "response_modalities": ["AUDIO"], + "outputAudioTranscription": {}, + "inputAudioTranscription": {}, + } - if user_audio_config: + resolved = { + **default_provider_settings, + **config, + "audio": merged_audio, # Audio always uses merged version + } + + if user_audio: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default Gemini Live audio configuration") - def _extract_voice_from_provider_config(self) -> str | None: - """Extract voice from provider-specific config.""" - if "speech_config" in self.provider_config: - speech_config = self.provider_config["speech_config"] - if isinstance(speech_config, dict): - return (speech_config.get("voice_config", {}) - .get("prebuilt_voice_config", {}) - .get("voice_name")) - return None + return resolved async def start( self, @@ -207,7 +212,7 @@ async def _send_message_history(self, messages: Messages) -> None: async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive Gemini Live API events and convert to provider-agnostic format.""" if not self._connection_id: - raise RuntimeError("model not started | call start before sending/receiving") + raise RuntimeError("model not started | call start before receiving") yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) @@ -488,10 +493,9 @@ def _build_live_config( Simply passes through all config parameters from provider_config, allowing users to configure any Gemini Live API parameter directly. """ - # Start with user-provided provider_config config_dict: dict[str, Any] = {} - if self.provider_config: - config_dict.update(self.provider_config) + if self._provider_config: + config_dict.update({k: v for k, v in self._provider_config.items() if k != "audio"}) # Override with any kwargs from start() config_dict.update(kwargs) @@ -504,7 +508,6 @@ def _build_live_config( if tools: config_dict["tools"] = self._format_tools_for_live_api(tools) - # Override voice with config value if present (config takes precedence) if "voice" in self.config["audio"]: config_dict.setdefault("speech_config", {}).setdefault("voice_config", {}).setdefault( "prebuilt_voice_config", {} diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index d5e238f8e..10b74d0d4 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -95,31 +95,33 @@ class BidiNovaSonicModel(BidiModel): def __init__( self, model_id: str = "amazon.nova-sonic-v1:0", - boto_session: boto3.Session | None = None, - region: str | None = None, - config: dict[str, Any] | None = None, + provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Initialize Nova Sonic bidirectional model. Args: - model_id: Nova Sonic model identifier. - boto_session: Boto Session to use when calling the Nova Sonic Model. - region: AWS region - config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. - If not provided or if "audio" key is missing, uses Nova Sonic's default audio configuration. + model_id: Model identifier (default: amazon.nova-sonic-v1:0) + provider_config: Model behavior (audio, inference settings) + client_config: AWS authentication (boto_session OR region, not both) **kwargs: Reserved for future parameters. """ - if region and boto_session: - raise ValueError("Cannot specify both `region_name` and `boto_session`.") + # Store model ID + self.model_id = model_id - # Create session and resolve region - self._session = boto_session or boto3.Session() - resolved_region = region or self._session.region_name or "us-east-1" + # Resolve client config with defaults + self._client_config = self._resolve_client_config(client_config or {}) - # Model configuration - self.model_id = model_id - self.region = resolved_region + # Resolve provider config with defaults + self._provider_config = self._resolve_provider_config(provider_config or {}) + + # Extract and store audio config for IO coordination + self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + + # Store session and region for later use + self._session = self._client_config["boto_session"] + self.region = self._client_config["region"] # Track API-provided identifiers self._connection_id: str | None = None @@ -134,11 +136,27 @@ def __init__( logger.debug("model_id=<%s> | nova sonic model initialized", model_id) - # Extract audio config from config dict if provided - user_audio_config = config.get("audio", {}) if config else {} + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve AWS client config (creates boto session if needed).""" + if "boto_session" in config and "region" in config: + raise ValueError("Cannot specify both 'boto_session' and 'region' in client_config") + + resolved = config.copy() + # Create boto session if not provided + if "boto_session" not in resolved: + resolved["boto_session"] = boto3.Session() + + # Resolve region from session or use default + if "region" not in resolved: + resolved["region"] = resolved["boto_session"].region_name or "us-east-1" + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" # Define default audio configuration - default_audio_config: AudioConfig = { + default_audio: AudioConfig = { "input_rate": cast(int, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), "output_rate": cast(int, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), "channels": cast(int, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), @@ -146,17 +164,21 @@ def __init__( "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), } - # Merge user config with defaults (user values take precedence) - merged_audio_config = cast(AudioConfig, {**default_audio_config, **user_audio_config}) + user_audio = config.get("audio", {}) + merged_audio = {**default_audio, **user_audio} - # Store config with audio defaults always populated - self.config: dict[str, Any] = {"audio": merged_audio_config} + resolved = { + "audio": merged_audio, + **{k: v for k, v in config.items() if k != "audio"}, + } - if user_audio_config: + if user_audio: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default Nova Sonic audio configuration") + return resolved + async def start( self, system_prompt: str | None = None, diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index d80cd76b0..f26e04403 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -77,42 +77,36 @@ class BidiOpenAIRealtimeModel(BidiModel): def __init__( self, model_id: str = DEFAULT_MODEL, - api_key: str | None = None, - config: dict[str, Any] | None = None, provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Initialize OpenAI Realtime bidirectional model. Args: - model_id: OpenAI model identifier (default: gpt-realtime). - api_key: OpenAI API key for authentication. - provider_config: Session configuration parameters (e.g., voice, turn_detection, modalities). - config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. - If not provided or if "audio" key is missing, uses OpenAI Realtime API's default audio configuration. + model_id: Model identifier (default: gpt-realtime) + provider_config: Model behavior (audio, instructions, turn_detection, etc.) + client_config: Authentication (api_key, organization, project) + Falls back to OPENAI_API_KEY, OPENAI_ORGANIZATION, OPENAI_PROJECT env vars **kwargs: Reserved for future parameters. - Environment Variables: - OPENAI_API_KEY: API key (if not provided as parameter) - OPENAI_ORGANIZATION: Organization ID for billing/organization - OPENAI_PROJECT: Project ID for billing/organization """ - # Model configuration + # Store model ID self.model_id = model_id - self.api_key = api_key - self.provider_config = provider_config or {} - - # Read from environment variables with same pattern as API key - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError( - "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." - ) - # Read organization and project from environment (no parameters needed) - self.organization = os.getenv("OPENAI_ORGANIZATION") - self.project = os.getenv("OPENAI_PROJECT") + # Resolve client config with defaults and env vars + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self._provider_config = self._resolve_provider_config(provider_config or {}) + + # Extract and store audio config for IO coordination + self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + + # Store client config values for later use + self.api_key = self._client_config["api_key"] + self.organization = self._client_config.get("organization") + self.project = self._client_config.get("project") # Connection state (initialized in start()) self._connection_id: str | None = None @@ -121,14 +115,40 @@ def __init__( logger.debug("model=<%s> | openai realtime model initialized", model_id) - # Extract audio config from config dict if provided - user_audio_config = config.get("audio", {}) if config else {} + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve client config with env var fallback (config takes precedence).""" + resolved = config.copy() + + if "api_key" not in resolved: + resolved["api_key"] = os.getenv("OPENAI_API_KEY") + + if not resolved.get("api_key"): + raise ValueError( + "OpenAI API key is required. Provide via client_config={'api_key': '...'} " + "or set OPENAI_API_KEY environment variable." + ) + if "organization" not in resolved: + env_org = os.getenv("OPENAI_ORGANIZATION") + if env_org: + resolved["organization"] = env_org + + if "project" not in resolved: + env_project = os.getenv("OPENAI_PROJECT") + if env_project: + resolved["project"] = env_project + + return resolved - # Extract voice from provider_config if provided - provider_voice = self._extract_voice_from_provider_config() + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + # Extract voice from provider-specific audio.output.voice if present + provider_voice = None + if "audio" in config and isinstance(config["audio"], dict): + if "output" in config["audio"] and isinstance(config["audio"]["output"], dict): + provider_voice = config["audio"]["output"].get("voice") # Define default audio configuration - default_audio_config: AudioConfig = { + default_audio: AudioConfig = { "input_rate": DEFAULT_SAMPLE_RATE, "output_rate": DEFAULT_SAMPLE_RATE, "channels": 1, @@ -136,26 +156,20 @@ def __init__( "voice": provider_voice or "alloy", } - # Merge user config with defaults (user values take precedence) - merged_audio_config = cast(AudioConfig, {**default_audio_config, **user_audio_config}) + user_audio = config.get("audio", {}) + merged_audio = {**default_audio, **user_audio} - # Store config with audio defaults always populated - self.config: dict[str, Any] = {"audio": merged_audio_config} + resolved = { + "audio": merged_audio, + **{k: v for k, v in config.items() if k != "audio"}, + } - if user_audio_config: + if user_audio: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default OpenAI Realtime audio configuration") - def _extract_voice_from_provider_config(self) -> str | None: - """Extract voice from provider-specific config.""" - if "audio" in self.provider_config: - audio_settings = self.provider_config["audio"] - if isinstance(audio_settings, dict) and "output" in audio_settings: - output_settings = audio_settings["output"] - if isinstance(output_settings, dict): - return output_settings.get("voice") - return None + return resolved async def start( self, @@ -247,7 +261,6 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] "output_modalities", "instructions", "voice", - "audio", "tools", "tool_choice", "input_audio_format", @@ -256,23 +269,28 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] "turn_detection", } - for key, value in self.provider_config.items(): - if key in supported_params: + for key, value in self._provider_config.items(): + if key == "audio": + continue + elif key in supported_params: config[key] = value else: logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) - # Override audio configuration with config values if present (config takes precedence) - if "voice" in self.config["audio"]: - config.setdefault("audio", {}).setdefault("output", {})["voice"] = self.config["audio"]["voice"] + audio_config = self.config["audio"] - if "input_rate" in self.config["audio"]: - input_config = config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {}) - input_config["rate"] = self.config["audio"]["input_rate"] + if "voice" in audio_config: + config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] - if "output_rate" in self.config["audio"]: - output_config = config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {}) - output_config["rate"] = self.config["audio"]["output_rate"] + if "input_rate" in audio_config: + config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ + "input_rate" + ] + + if "output_rate" in audio_config: + config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ + "output_rate" + ] return config diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 846f2b526..48c1d9e09 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -62,7 +62,7 @@ def api_key(): def model(mock_genai_client, model_id, api_key): """Create a BidiGeminiLiveModel instance.""" _ = mock_genai_client - return BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + return BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) @pytest.fixture @@ -102,7 +102,7 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert "inputAudioTranscription" in model_default.provider_config # Test with API key - model_with_key = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) assert model_with_key.model_id == model_id assert model_with_key.api_key == api_key @@ -161,7 +161,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client, _, mock_live_session_cm = mock_genai_client # Test connection error - model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + model1 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match=r"Connection failed"): await model1.start() @@ -170,18 +170,18 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client.aio.live.connect.side_effect = None # Test double connection - model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + model2 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model2.start() with pytest.raises(RuntimeError, match="call stop before starting again"): await model2.start() await model2.stop() # Test close when not connected - model3 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + model3 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model3.stop() # Should not raise # Test close error handling - model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") with pytest.raises(ExceptionGroup): @@ -452,7 +452,7 @@ def test_audio_config_defaults(mock_genai_client, model_id, api_key): """Test default audio configuration.""" _ = mock_genai_client - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) assert model.config["audio"]["input_rate"] == 16000 assert model.config["audio"]["output_rate"] == 24000 @@ -466,7 +466,7 @@ def test_audio_config_partial_override(mock_genai_client, model_id, api_key): _ = mock_genai_client config = {"audio": {"output_rate": 48000, "voice": "Puck"}} - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, config=config) + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -491,7 +491,7 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): "voice": "Aoede", } } - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, config=config) + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -507,7 +507,9 @@ def test_audio_config_voice_priority(mock_genai_client, model_id, api_key): provider_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}} config = {"audio": {"voice": "Aoede"}} - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, provider_config=provider_config, config=config) + model = BidiGeminiLiveModel( + model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config, config=config + ) # Build config and verify config audio voice takes precedence built_config = model._build_live_config() @@ -556,7 +558,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key # Create model with custom audio configuration config = {"audio": {"output_rate": 48000, "channels": 2}} - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key, config=config) + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) await model.start() # Test audio output event uses custom configuration @@ -584,7 +586,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke _, _, _ = mock_genai_client # Create model without custom audio configuration - model = BidiGeminiLiveModel(model_id=model_id, api_key=api_key) + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model.start() # Test audio output event uses defaults diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 1172aae53..b83758335 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -69,7 +69,7 @@ def nova_model(model_id, region, mock_client): """Create Nova Sonic model instance.""" _ = mock_client - model = BidiNovaSonicModel(model_id=model_id, region=region) + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) yield model @@ -79,7 +79,7 @@ def nova_model(model_id, region, mock_client): @pytest.mark.asyncio async def test_model_initialization(model_id, region): """Test model initialization with configuration.""" - model = BidiNovaSonicModel(model_id=model_id, region=region) + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) assert model.model_id == model_id assert model.region == region @@ -92,7 +92,7 @@ async def test_model_initialization(model_id, region): @pytest.mark.asyncio async def test_audio_config_defaults(model_id, region): """Test default audio configuration.""" - model = BidiNovaSonicModel(model_id=model_id, region=region) + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) assert model.config["audio"]["input_rate"] == 16000 assert model.config["audio"]["output_rate"] == 16000 @@ -104,8 +104,8 @@ async def test_audio_config_defaults(model_id, region): @pytest.mark.asyncio async def test_audio_config_partial_override(model_id, region): """Test partial audio configuration override.""" - config = {"audio": {"output_rate": 24000, "voice": "ruth"}} - model = BidiNovaSonicModel(model_id=model_id, region=region, config=config) + provider_config = {"audio": {"output_rate": 24000, "voice": "ruth"}} + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) # Overridden values assert model.config["audio"]["output_rate"] == 24000 @@ -120,7 +120,7 @@ async def test_audio_config_partial_override(model_id, region): @pytest.mark.asyncio async def test_audio_config_full_override(model_id, region): """Test full audio configuration override.""" - config = { + provider_config = { "audio": { "input_rate": 48000, "output_rate": 48000, @@ -129,7 +129,7 @@ async def test_audio_config_full_override(model_id, region): "voice": "stephen", } } - model = BidiNovaSonicModel(model_id=model_id, region=region, config=config) + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -528,8 +528,8 @@ async def test_message_history_empty_and_edge_cases(nova_model): async def test_custom_audio_rates_in_events(model_id, region): """Test that audio events use configured sample rates.""" # Create model with custom audio configuration - config = {"audio": {"output_rate": 48000, "channels": 2}} - model = BidiNovaSonicModel(model_id=model_id, region=region, config=config) + provider_config = {"audio": {"output_rate": 48000, "channels": 2}} + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) # Test audio output event uses custom configuration audio_bytes = b"test audio data" @@ -549,7 +549,7 @@ async def test_custom_audio_rates_in_events(model_id, region): async def test_default_audio_rates_in_events(model_id, region): """Test that audio events use default sample rates when no custom config.""" # Create model without custom audio configuration - model = BidiNovaSonicModel(model_id=model_id, region=region) + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) # Test audio output event uses defaults audio_bytes = b"test audio data" diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index 6a6baf011..1f36d3fa4 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -62,7 +62,7 @@ def api_key(): @pytest.fixture def model(api_key, model_name): """Create an BidiOpenAIRealtimeModel instance.""" - return BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) + return BidiOpenAIRealtimeModel(model=model_name, client_config={"api_key": api_key}) @pytest.fixture @@ -87,28 +87,29 @@ def messages(): # Initialization Tests -def test_model_initialization(api_key, model_name): +def test_model_initialization(api_key, model_name, monkeypatch): """Test model initialization with various configurations.""" # Test default config - model_default = BidiOpenAIRealtimeModel(api_key="test-key") + model_default = BidiOpenAIRealtimeModel(client_config={"api_key": "test-key"}) assert model_default.model_id == "gpt-realtime" assert model_default.api_key == "test-key" # Test with custom model - model_custom = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + model_custom = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) assert model_custom.model_id == model_name assert model_custom.api_key == api_key # Test with organization and project via environment variables - with unittest.mock.patch.dict("os.environ", {"OPENAI_ORGANIZATION": "org-123", "OPENAI_PROJECT": "proj-456"}): - model_env = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) - assert model_env.organization == "org-123" - assert model_env.project == "proj-456" + monkeypatch.setenv("OPENAI_ORGANIZATION", "org-123") + monkeypatch.setenv("OPENAI_PROJECT", "proj-456") + model_env = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + assert model_env.organization == "org-123" + assert model_env.project == "proj-456" # Test with env API key - with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): - model_env = BidiOpenAIRealtimeModel() - assert model_env.api_key == "env-key" + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + model_env = BidiOpenAIRealtimeModel() + assert model_env.api_key == "env-key" # Audio Configuration Tests @@ -116,7 +117,7 @@ def test_model_initialization(api_key, model_name): def test_audio_config_defaults(api_key, model_name): """Test default audio configuration.""" - model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) assert model.config["audio"]["input_rate"] == 24000 assert model.config["audio"]["output_rate"] == 24000 @@ -128,7 +129,7 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, config=config) + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -151,7 +152,7 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, config=config) + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -166,7 +167,7 @@ def test_audio_config_voice_priority(api_key, model_name): config = {"audio": {"voice": "nova"}} model = BidiOpenAIRealtimeModel( - model_id=model_name, api_key=api_key, provider_config=provider_config, config=config + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config, config=config ) # Build config and verify config audio voice takes precedence @@ -178,17 +179,19 @@ def test_audio_config_extracts_voice_from_provider_config(api_key, model_name): """Test that voice is extracted from provider_config when config audio not provided.""" provider_config = {"audio": {"output": {"voice": "fable"}}} - model = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) # Should extract voice from provider_config assert model.config["audio"]["voice"] == "fable" -def test_init_without_api_key_raises(): +def test_init_without_api_key_raises(monkeypatch): """Test that initialization without API key raises error.""" - with unittest.mock.patch.dict("os.environ", {}, clear=True): - with pytest.raises(ValueError, match="OpenAI API key is required"): - BidiOpenAIRealtimeModel() + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with pytest.raises(ValueError, match="OpenAI API key is required"): + BidiOpenAIRealtimeModel() # Connection Tests @@ -242,15 +245,24 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp await model.stop() # Test connection with organization header (via environment) - with unittest.mock.patch.dict("os.environ", {"OPENAI_ORGANIZATION": "org-123"}): - model_org = BidiOpenAIRealtimeModel(api_key="test-key") - await model_org.start() - call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("additional_headers", []) - org_header = [h for h in headers if h[0] == "OpenAI-Organization"] - assert len(org_header) == 1 - assert org_header[0][1] == "org-123" - await model_org.stop() + # Note: This test needs to be in a separate test function to use monkeypatch properly + # Skipping inline environment test here - see test_connection_with_org_header + + +@pytest.mark.asyncio +async def test_connection_with_org_header(mock_websockets_connect, monkeypatch): + """Test connection with organization header from environment.""" + mock_connect, mock_ws = mock_websockets_connect + + monkeypatch.setenv("OPENAI_ORGANIZATION", "org-123") + model_org = BidiOpenAIRealtimeModel(client_config={"api_key": "test-key"}) + await model_org.start() + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("additional_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + await model_org.stop() @pytest.mark.asyncio @@ -323,7 +335,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam mock_connect, mock_ws = mock_websockets_connect # Test connection error - model1 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + model1 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) mock_connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.start() @@ -335,18 +347,18 @@ async def async_connect(*args, **kwargs): mock_connect.side_effect = async_connect # Test double connection - model2 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + model2 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model2.start() with pytest.raises(RuntimeError, match=r"call stop before starting again"): await model2.start() await model2.stop() # Test close when not connected - model3 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + model3 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model3.stop() # Should not raise # Test close error - model4 = BidiOpenAIRealtimeModel(model_id=model_name, api_key=api_key) + model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") with pytest.raises(ExceptionGroup): @@ -689,7 +701,7 @@ async def test_custom_audio_sample_rate(mock_websockets_connect, api_key): # Create model with custom sample rate custom_sample_rate = 48000 provider_config = {"audio": {"output": {"format": {"rate": custom_sample_rate}}}} - model = BidiOpenAIRealtimeModel(api_key=api_key, provider_config=provider_config) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}, provider_config=provider_config) await model.start() @@ -717,7 +729,7 @@ async def test_default_audio_sample_rate(mock_websockets_connect, api_key): _, mock_ws = mock_websockets_connect # Create model without custom audio config - model = BidiOpenAIRealtimeModel(api_key=api_key) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) await model.start() @@ -746,7 +758,7 @@ async def test_partial_audio_config(mock_websockets_connect, api_key): # Create model with partial audio config (missing format.rate) provider_config = {"audio": {"output": {"voice": "alloy"}}} - model = BidiOpenAIRealtimeModel(api_key=api_key, provider_config=provider_config) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}, provider_config=provider_config) await model.start() @@ -775,7 +787,7 @@ async def test_partial_audio_config(mock_websockets_connect, api_key): async def test_tool_result_single_text_content(mock_websockets_connect, api_key): """Test tool result with single text content block.""" _, mock_ws = mock_websockets_connect - model = BidiOpenAIRealtimeModel(api_key=api_key) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) await model.start() tool_result: ToolResult = { From 028483758f675bc7f35783faff25bf2fa4be02a7 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sat, 29 Nov 2025 01:08:21 -0500 Subject: [PATCH 218/242] update tool --- src/strands/experimental/bidi/tools/stop_conversation.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/bidi/tools/stop_conversation.py b/src/strands/experimental/bidi/tools/stop_conversation.py index ad61fb1d0..9c7e1c6cd 100644 --- a/src/strands/experimental/bidi/tools/stop_conversation.py +++ b/src/strands/experimental/bidi/tools/stop_conversation.py @@ -4,7 +4,7 @@ @tool -def stop_conversation() -> dict: +def stop_conversation() -> str: """Stop the bidirectional conversation gracefully. Use ONLY when user says "stop conversation" exactly. @@ -13,7 +13,4 @@ def stop_conversation() -> dict: Returns: Success message confirming the conversation will end """ - return { - "status": "success", - "content": [{"text": "Ending conversation"}], - } + return "Ending conversation" From b47497824404b6481f407ad9b608a690165e6d2e Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sat, 29 Nov 2025 01:16:06 -0500 Subject: [PATCH 219/242] loop - add tool use and tool result in sequence to history (#94) --- src/strands/experimental/bidi/agent/loop.py | 74 +++++++++++-------- .../experimental/bidi/agent/test_loop.py | 61 ++++++++++++++- 2 files changed, 101 insertions(+), 34 deletions(-) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 76907b6fd..421c55e53 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -5,7 +5,7 @@ import asyncio import logging -from typing import TYPE_CHECKING, Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent from ....types.content import Message @@ -50,6 +50,8 @@ class _BidiAgentLoop: that tools can access via their invocation_state parameter. _send_gate: Gate the sending of events to the model. Blocks when agent is reseting the model connection after timeout. + _message_lock: Lock to ensure that paired messages are added to history in sequence without interference. + For example, tool use and tool result messages must be added adjacent to each other. """ def __init__(self, agent: "BidiAgent") -> None: @@ -67,6 +69,7 @@ def __init__(self, agent: "BidiAgent") -> None: self._invocation_state: dict[str, Any] self._send_gate = asyncio.Event() + self._message_lock = asyncio.Lock() async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start the agent loop. @@ -79,8 +82,7 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None: that tools can access via their invocation_state parameter. Raises: - RuntimeError: - If loop already started. + RuntimeError: If loop already started. """ if self._started: raise RuntimeError("loop already started | call stop before starting again") @@ -128,7 +130,10 @@ async def send(self, event: BidiInputEvent | ToolResultEvent) -> None: Additionally, add text input to messages array. Args: - event: BidiInputEvent. + event: User input event or tool result. + + Raises: + RuntimeError: If start has not been called. """ if not self._started: raise RuntimeError("loop not started | call start before sending") @@ -139,14 +144,16 @@ async def send(self, event: BidiInputEvent | ToolResultEvent) -> None: if isinstance(event, BidiTextInputEvent): message: Message = {"role": "user", "content": [{"text": event.text}]} - self._agent.messages.append(message) - await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) + await self._add_messages(message) await self._agent.model.send(event) async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive model and tool call events. + Returns: + Model and tool call events. + Raises: RuntimeError: If start has not been called. """ @@ -210,21 +217,12 @@ async def _run_model(self) -> None: if isinstance(event, BidiTranscriptStreamEvent): if event["is_final"]: message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} - self._agent.messages.append(message) - await self._agent.hooks.invoke_callbacks_async( - BidiMessageAddedEvent(agent=self._agent, message=message) - ) + await self._add_messages(message) elif isinstance(event, ToolUseStreamEvent): tool_use = event["current_tool_use"] self._task_pool.create(self._run_tool(tool_use)) - tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - self._agent.messages.append(tool_message) - await self._agent.hooks.invoke_callbacks_async( - BidiMessageAddedEvent(agent=self._agent, message=tool_message) - ) - elif isinstance(event, BidiInterruptionEvent): await self._agent.hooks.invoke_callbacks_async( BidiInterruptionHookEvent( @@ -238,7 +236,11 @@ async def _run_model(self) -> None: await self._event_queue.put(error) async def _run_tool(self, tool_use: ToolUse) -> None: - """Task for running tool requested by the model using the tool executor.""" + """Task for running tool requested by the model using the tool executor. + + Args: + tool_use: Tool use request from model. + """ logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) tool_results: list[ToolResult] = [] @@ -260,25 +262,35 @@ async def _run_tool(self, tool_use: ToolUse) -> None: structured_output_context=None, ) - async for event in tool_events: - if isinstance(event, ToolInterruptEvent): + async for tool_event in tool_events: + if isinstance(tool_event, ToolInterruptEvent): self._agent._interrupt_state.deactivate() - interrupt_names = [interrupt.name for interrupt in event.interrupts] + interrupt_names = [interrupt.name for interrupt in tool_event.interrupts] raise RuntimeError(f"interrupts={interrupt_names} | tool interrupts are not supported in bidi") - await self._event_queue.put(event) - if isinstance(event, ToolResultEvent): - result = event.tool_result + await self._event_queue.put(tool_event) + + tool_result_event = cast(ToolResultEvent, tool_event) - await self.send(ToolResultEvent(result)) + tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + tool_result_message: Message = {"role": "user", "content": [{"toolResult": tool_result_event.tool_result}]} + await self._add_messages(tool_use_message, tool_result_message) - message: Message = { - "role": "user", - "content": [{"toolResult": result}], - } - self._agent.messages.append(message) - await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) - await self._event_queue.put(ToolResultMessageEvent(message)) + await self._event_queue.put(ToolResultMessageEvent(tool_result_message)) + await self.send(tool_result_event) except Exception as error: await self._event_queue.put(error) + + async def _add_messages(self, *messages: Message) -> None: + """Add messages to history in sequence without interference. + + Args: + *messages: List of messages to add into history. + """ + async with self._message_lock: + for message in messages: + self._agent.messages.append(message) + await self._agent.hooks.invoke_callbacks_async( + BidiMessageAddedEvent(agent=self._agent, message=message) + ) diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index 1ec5712af..68346ab19 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -3,16 +3,35 @@ import pytest import pytest_asyncio +from strands import tool from strands.experimental.bidi.agent.loop import _BidiAgentLoop from strands.experimental.bidi.models import BidiModelTimeoutError from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent +from strands.hooks import HookRegistry +from strands.tools.executors import SequentialToolExecutor +from strands.tools.registry import ToolRegistry +from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @pytest.fixture -def agent(): +def time_tool(): + @tool(name="time_tool") + async def func(): + return "12:00" + + return func + + +@pytest.fixture +def agent(time_tool): mock = unittest.mock.Mock() - mock.hooks = unittest.mock.AsyncMock() + mock.hooks = HookRegistry() + mock.messages = [] mock.model = unittest.mock.AsyncMock() + mock.tool_executor = SequentialToolExecutor() + mock.tool_registry = ToolRegistry() + mock.tool_registry.process_tools([time_tool]) + return mock @@ -46,6 +65,42 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato assert agent.model.start.call_count == 2 agent.model.start.assert_any_call( agent.system_prompt, - agent.tool_registry.get_all_tool_specs.return_value, + agent.tool_registry.get_all_tool_specs(), agent.messages, ) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): + + tool_use = {"toolUseId": "t1", "name": "time_tool", "input": {}} + tool_result = {"toolUseId": "t1", "status": "success", "content": [{"text": "12:00"}]} + + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + tool_result_event = ToolResultEvent(tool_result) + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + exp_events = [ + tool_use_event, + tool_result_event, + ToolResultMessageEvent({"role": "user", "content": [{"toolResult": tool_result}]}), + ] + assert tru_events == exp_events + + tru_messages = agent.messages + exp_messages = [ + {"role": "assistant", "content": [{"toolUse": tool_use}]}, + {"role": "user", "content": [{"toolResult": tool_result}]}, + ] + assert tru_messages == exp_messages + + agent.model.send.assert_called_with(tool_result_event) From 53e4b31b804a41ad1fb49dfe50297ea3dd61ced5 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sat, 29 Nov 2025 12:13:29 -0500 Subject: [PATCH 220/242] adjust logic based on comments --- src/strands/experimental/bidi/agent/agent.py | 7 +++--- src/strands/experimental/bidi/agent/loop.py | 23 +++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 11e533b08..7b11c87d9 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -374,13 +374,14 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: await asyncio.gather(*[output(event) for output in outputs]) if isinstance(event, BidiConnectionCloseEvent) and event.reason == "user_request": - logger.info( + logger.debug( "connection_id=<%s>, reason=<%s> | graceful shutdown initiated", event.connection_id, event.reason, ) - inputs_task.cancel() - return + break + + inputs_task.cancel() try: await self.start(invocation_state) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 119fd07b9..ab8782ec0 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -165,6 +165,11 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: if isinstance(event, Exception): raise event + # Check for graceful shutdown event + if isinstance(event, BidiConnectionCloseEvent) and event.reason == "user_request": + yield event + break + yield event async def _restart_connection(self, timeout_error: BidiModelTimeoutError) -> None: @@ -271,6 +276,16 @@ async def _run_tool(self, tool_use: ToolUse) -> None: if isinstance(event, ToolResultEvent): result = event.tool_result + # Check for stop_conversation BEFORE sending result + if tool_use["name"] == "stop_conversation": + logger.info("tool_name=<%s> | conversation stop requested, skipping result send", tool_use["name"]) + connection_id = getattr(self._agent.model, "_connection_id", "unknown") + await self._event_queue.put( + BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") + ) + return # Early exit - don't send result, don't add to messages + + # Normal flow for all other tools await self.send(ToolResultEvent(result)) message: Message = { @@ -281,13 +296,5 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message)) await self._event_queue.put(ToolResultMessageEvent(message)) - # Check if this was the stop_conversation tool - if tool_use["name"] == "stop_conversation": - logger.info("tool_name=<%s> | conversation stop requested by tool", tool_use["name"]) - connection_id = getattr(self._agent.model, "_connection_id", "unknown") - await self._event_queue.put( - BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") - ) - except Exception as error: await self._event_queue.put(error) From e9c26abcbd3e5f653370783438e0aaa046418975 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sat, 29 Nov 2025 12:33:45 -0500 Subject: [PATCH 221/242] updated implementation --- src/strands/experimental/bidi/agent/agent.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 7b11c87d9..7d8326a85 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -373,14 +373,6 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: async for event in self.receive(): await asyncio.gather(*[output(event) for output in outputs]) - if isinstance(event, BidiConnectionCloseEvent) and event.reason == "user_request": - logger.debug( - "connection_id=<%s>, reason=<%s> | graceful shutdown initiated", - event.connection_id, - event.reason, - ) - break - inputs_task.cancel() try: From 16677f161f96e2a2aafac038acf0ba47fc0eeec5 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sat, 29 Nov 2025 13:07:03 -0500 Subject: [PATCH 222/242] update implementation --- src/strands/experimental/bidi/agent/loop.py | 22 +++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index ead914aec..80245d9b2 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -276,16 +276,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._event_queue.put(tool_event) - # Check for stop_conversation BEFORE sending result - if tool_use["name"] == "stop_conversation": - logger.info("tool_name=<%s> | conversation stop requested, skipping result send", tool_use["name"]) - connection_id = getattr(self._agent.model, "_connection_id", "unknown") - await self._event_queue.put( - BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") - ) - return # Early exit - don't send result, don't add to messages - - # Normal flow for all other tools + # Normal flow for all tools (including stop_conversation) tool_result_event = cast(ToolResultEvent, tool_event) tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} @@ -293,6 +284,17 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._add_messages(tool_use_message, tool_result_message) await self._event_queue.put(ToolResultMessageEvent(tool_result_message)) + + # Check for stop_conversation before sending to model + if tool_use["name"] == "stop_conversation": + logger.info("tool_name=<%s> | conversation stop requested, skipping model send", tool_use["name"]) + connection_id = getattr(self._agent.model, "_connection_id", "unknown") + await self._event_queue.put( + BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") + ) + return # Skip the model send + + # Send result to model (all tools except stop_conversation) await self.send(tool_result_event) except Exception as error: From 4d14bb422d89cf23c8c05d68694f02b8971871db Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sat, 29 Nov 2025 13:15:04 -0500 Subject: [PATCH 223/242] loop - restart connection - openai (#95) --- .../experimental/bidi/models/openai.py | 42 ++++++++++++-- .../experimental/bidi/models/test_openai.py | 57 ++++++++++++------- 2 files changed, 72 insertions(+), 27 deletions(-) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 46a9de14f..3849ecaf7 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -4,9 +4,11 @@ with WebSocket connections, voice activity detection, and function calling. """ +import asyncio import json import logging import os +import time import uuid from typing import Any, AsyncGenerator, Literal, cast @@ -35,11 +37,21 @@ Role, StopReason, ) -from .bidi_model import BidiModel +from .bidi_model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) +# Test idle_timeout_ms + # OpenAI Realtime API configuration +OPENAI_MAX_TIMEOUT_S = 3000 # 50 minutes +"""Max timeout before closing connection. + +OpenAI documents a 60 minute limit on realtime sessions +(https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events). However, OpenAI does not +emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully +handle the connection closure. We set the max to 50 minutes to provide enough buffer before hitting the real limit. +""" OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" DEFAULT_MODEL = "gpt-realtime" @@ -74,6 +86,7 @@ class BidiOpenAIRealtimeModel(BidiModel): """ _websocket: ClientConnection + _start_time: int def __init__( self, @@ -81,6 +94,7 @@ def __init__( api_key: str | None = None, organization: str | None = None, project: str | None = None, + timeout_s: int = OPENAI_MAX_TIMEOUT_S, session_config: dict[str, Any] | None = None, config: dict[str, Any] | None = None, **kwargs: Any, @@ -92,9 +106,11 @@ def __init__( api_key: OpenAI API key for authentication. organization: OpenAI organization ID for API requests. project: OpenAI project ID for API requests. + timeout_s: Connection timeout in seconds (max: 3000s). + Model will raise a BidiModelTimeoutError after hitting this limit. session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). config: Optional configuration dictionary with structure {"audio": AudioConfig, ...}. - If not provided or if "audio" key is missing, uses OpenAI Realtime API's default audio configuration. + If not provided or if "audio" key is missing, uses OpenAI Realtime API's default audio configuration. **kwargs: Reserved for future parameters. """ # Model configuration @@ -102,6 +118,7 @@ def __init__( self.api_key = api_key self.organization = organization self.project = project + self.timeout_s = timeout_s self.session_config = session_config or {} if not self.api_key: @@ -111,6 +128,11 @@ def __init__( "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." ) + if self.timeout_s > OPENAI_MAX_TIMEOUT_S: + raise ValueError( + f"timeout_s=<{timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" + ) + # Connection state (initialized in start()) self._connection_id: str | None = None @@ -168,10 +190,11 @@ async def start( if self._connection_id: raise RuntimeError("model already started | call stop before starting again") - logger.info("openai realtime connection starting") + logger.debug("openai realtime connection starting") # Initialize connection state self._connection_id = str(uuid.uuid4()) + self._start_time = int(time.time()) self._function_call_buffer = {} @@ -185,7 +208,7 @@ async def start( headers.append(("OpenAI-Project", self.project)) self._websocket = await websockets.connect(url, additional_headers=headers) - logger.info("connection_id=<%s> | websocket connected successfully", self._connection_id) + logger.debug("connection_id=<%s> | websocket connected successfully", self._connection_id) # Configure session session_config = self._build_session_config(system_prompt, tools) @@ -397,7 +420,16 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - async for message in self._websocket: + while True: + duration = time.time() - self._start_time + if duration >= self.timeout_s: + raise BidiModelTimeoutError(f"timeout_s=<{self.timeout_s}>") + + try: + message = await asyncio.wait_for(self._websocket.recv(), timeout=10) + except asyncio.TimeoutError: + continue + openai_event = json.loads(message) for event in self._convert_openai_event(openai_event) or []: diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index ede0920a6..cb221b917 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -14,10 +14,12 @@ import pytest +from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, + BidiConnectionStartEvent, BidiImageInputEvent, BidiInterruptionEvent, BidiResponseCompleteEvent, @@ -60,8 +62,9 @@ def api_key(): @pytest.fixture -def model(api_key, model_name): +def model(mock_websockets_connect, api_key, model_name): """Create an BidiOpenAIRealtimeModel instance.""" + _ = mock_websockets_connect return BidiOpenAIRealtimeModel(model=model_name, api_key=api_key) @@ -488,37 +491,47 @@ async def test_send_edge_cases(mock_websockets_connect, model): @pytest.mark.asyncio -async def test_receive_lifecycle_events(mock_websockets_connect, model): - """Test that receive() emits connection start and end events.""" - _, _ = mock_websockets_connect +async def test_receive_lifecycle_events(mock_websocket, model): + audio_message = '{"type": "response.output_audio.delta", "delta": ""}' + mock_websocket.recv.return_value = audio_message await model.start() + model._connection_id = "c1" + + tru_events = [] + async for event in model.receive(): + tru_events.append(event) + if len(tru_events) >= 2: + break + + exp_events = [ + BidiConnectionStartEvent(connection_id="c1", model="gpt-realtime"), + BidiAudioStreamEvent( + audio="", + format="pcm", + sample_rate=24000, + channels=1, + ) + ] + assert tru_events == exp_events - # Get first event - receive_gen = model.receive() - first_event = await anext(receive_gen) - # First event should be connection start (new TypedEvent format) - assert first_event.get("type") == "bidi_connection_start" - assert first_event.get("connection_id") == model._connection_id - assert first_event.get("model") == model.model_id +@unittest.mock.patch("strands.experimental.bidi.models.openai.time.time") +@pytest.mark.asyncio +async def test_receive_timeout(mock_time, model): + mock_time.side_effect = [1, 2] + model.timeout_s = 1 - # Close to trigger session end - await model.stop() + await model.start() - # Collect remaining events - events = [first_event] - try: - async for event in receive_gen: - events.append(event) - except StopAsyncIteration: - pass + with pytest.raises(BidiModelTimeoutError): + async for _ in model.receive(): + pass @pytest.mark.asyncio -async def test_event_conversion(mock_websockets_connect, model): +async def test_event_conversion(model): """Test conversion of all OpenAI event types to standard format.""" - _, _ = mock_websockets_connect await model.start() # Test audio output (now returns list with BidiAudioStreamEvent) From 505de3e7a862c351ebe0a94c5228e688ccacc486 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sat, 29 Nov 2025 18:05:23 -0500 Subject: [PATCH 224/242] run bidi:prepare (#97) --- src/strands/experimental/bidi/agent/agent.py | 1 - .../experimental/bidi/models/gemini_live.py | 18 +++------ .../experimental/bidi/models/novasonic.py | 19 ++++----- .../experimental/bidi/models/openai.py | 27 +++++++------ .../bidi/models/test_gemini_live.py | 40 ++++++------------- .../bidi/models/test_novasonic.py | 5 ++- .../experimental/bidi/models/test_openai.py | 28 ++++--------- 7 files changed, 49 insertions(+), 89 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 56c55e4a9..74b65ba10 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -35,7 +35,6 @@ from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, - BidiConnectionCloseEvent, BidiImageInputEvent, BidiInputEvent, BidiOutputEvent, diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 79030e03f..2e9a13b54 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -81,10 +81,7 @@ def __init__( self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self._provider_config = self._resolve_provider_config(provider_config or {}) - - # Extract and store audio config for IO coordination - self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + self.config = self._resolve_provider_config(provider_config or {}) # Store API key for later use self.api_key = self._client_config.get("api_key") @@ -113,10 +110,7 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: provider_voice = None if "speech_config" in config and isinstance(config["speech_config"], dict): provider_voice = ( - config["speech_config"] - .get("voice_config", {}) - .get("prebuilt_voice_config", {}) - .get("voice_name") + config["speech_config"].get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") ) # Define default audio configuration @@ -283,8 +277,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut BidiAudioStreamEvent( audio=audio_b64, format="pcm", - sample_rate=cast(SampleRate, self.config["audio"]["output_rate"]), - channels=cast(Channel, self.config["audio"]["channels"]), + sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), ) ] @@ -494,8 +488,8 @@ def _build_live_config( to configure any Gemini Live API parameter directly. """ config_dict: dict[str, Any] = {} - if self._provider_config: - config_dict.update({k: v for k, v in self._provider_config.items() if k != "audio"}) + if self.config: + config_dict.update({k: v for k, v in self.config.items() if k != "audio"}) # Override with any kwargs from start() config_dict.update(kwargs) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 2a16ee91e..24c932ab0 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, AsyncGenerator, Literal, cast +from typing import Any, AsyncGenerator, cast import boto3 from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput @@ -117,10 +117,7 @@ def __init__( self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self._provider_config = self._resolve_provider_config(provider_config or {}) - - # Extract and store audio config for IO coordination - self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + self.config = self._resolve_provider_config(provider_config or {}) # Store session and region for later use self._session = self._client_config["boto_session"] @@ -167,15 +164,15 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), } - user_audio = config.get("audio", {}) - merged_audio = {**default_audio, **user_audio} + user_audio_config = config.get("audio", {}) + merged_audio = {**default_audio_config, **user_audio_config} resolved = { "audio": merged_audio, **{k: v for k, v in config.items() if k != "audio"}, } - if user_audio: + if user_audio_config: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default Nova Sonic audio configuration") @@ -507,13 +504,11 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N if "audioOutput" in nova_event: # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] - # Channels from config is guaranteed to be 1 or 2 - channels = cast(Literal[1, 2], self.config["audio"]["channels"]) return BidiAudioStreamEvent( audio=audio_content, format="pcm", - sample_rate=cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), - channels=channels, + sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), ) # Handle text output (transcripts) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 59fb55f5b..bfe3ad533 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -111,19 +111,17 @@ def __init__( self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self._provider_config = self._resolve_provider_config(provider_config or {}) - - # Extract and store audio config for IO coordination - self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + self.config = self._resolve_provider_config(provider_config or {}) # Store client config values for later use self.api_key = self._client_config["api_key"] self.organization = self._client_config.get("organization") self.project = self._client_config.get("project") + self.timeout_s = self._client_config["timeout_s"] if self.timeout_s > OPENAI_MAX_TIMEOUT_S: raise ValueError( - f"timeout_s=<{timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" + f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" ) # Connection state (initialized in start()) @@ -139,7 +137,7 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: if "api_key" not in resolved: resolved["api_key"] = os.getenv("OPENAI_API_KEY") - + if not resolved.get("api_key"): raise ValueError( "OpenAI API key is required. Provide via client_config={'api_key': '...'} " @@ -149,12 +147,15 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: env_org = os.getenv("OPENAI_ORGANIZATION") if env_org: resolved["organization"] = env_org - + if "project" not in resolved: env_project = os.getenv("OPENAI_PROJECT") if env_project: resolved["project"] = env_project + if "timeout_s" not in resolved: + resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S + return resolved def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: @@ -167,8 +168,8 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: # Define default audio configuration default_audio: AudioConfig = { - "input_rate": DEFAULT_SAMPLE_RATE, - "output_rate": DEFAULT_SAMPLE_RATE, + "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), "channels": 1, "format": "pcm", "voice": provider_voice or "alloy", @@ -288,7 +289,7 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] "turn_detection", } - for key, value in self._provider_config.items(): + for key, value in self.config.items(): if key == "audio": continue elif key in supported_params: @@ -297,15 +298,15 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) audio_config = self.config["audio"] - + if "voice" in audio_config: config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] - + if "input_rate" in audio_config: config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ "input_rate" ] - + if "output_rate" in audio_config: config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ "output_rate" diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 48c1d9e09..dec83dbe3 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -97,9 +97,9 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert model_default.api_key is None assert model_default._live_session is None # Check default config includes transcription - assert model_default.provider_config["response_modalities"] == ["AUDIO"] - assert "outputAudioTranscription" in model_default.provider_config - assert "inputAudioTranscription" in model_default.provider_config + assert model_default.config["response_modalities"] == ["AUDIO"] + assert "outputAudioTranscription" in model_default.config + assert "inputAudioTranscription" in model_default.config # Test with API key model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) @@ -110,10 +110,10 @@ def test_model_initialization(mock_genai_client, model_id, api_key): provider_config = {"temperature": 0.7, "top_p": 0.9} model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config) # Custom config should be merged with defaults - assert model_custom.provider_config["temperature"] == 0.7 - assert model_custom.provider_config["top_p"] == 0.9 + assert model_custom.config["temperature"] == 0.7 + assert model_custom.config["top_p"] == 0.9 # Defaults should still be present - assert "response_modalities" in model_custom.provider_config + assert "response_modalities" in model_custom.config # Connection Tests @@ -465,8 +465,8 @@ def test_audio_config_partial_override(mock_genai_client, model_id, api_key): """Test partial audio configuration override.""" _ = mock_genai_client - config = {"audio": {"output_rate": 48000, "voice": "Puck"}} - model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) + provider_config = {"audio": {"output_rate": 48000, "voice": "Puck"}} + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -482,7 +482,7 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): """Test full audio configuration override.""" _ = mock_genai_client - config = { + provider_config = { "audio": { "input_rate": 48000, "output_rate": 48000, @@ -491,7 +491,7 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): "voice": "Aoede", } } - model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -500,22 +500,6 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): assert model.config["audio"]["voice"] == "Aoede" -def test_audio_config_voice_priority(mock_genai_client, model_id, api_key): - """Test that config audio voice takes precedence over provider_config voice.""" - _ = mock_genai_client - - provider_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}} - config = {"audio": {"voice": "Aoede"}} - - model = BidiGeminiLiveModel( - model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config, config=config - ) - - # Build config and verify config audio voice takes precedence - built_config = model._build_live_config() - assert built_config["speech_config"]["voice_config"]["prebuilt_voice_config"]["voice_name"] == "Aoede" - - # Helper Method Tests @@ -557,8 +541,8 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key _, _, _ = mock_genai_client # Create model with custom audio configuration - config = {"audio": {"output_rate": 48000, "channels": 2}} - model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) + provider_config = {"audio": {"output_rate": 48000, "channels": 2}} + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config) await model.start() # Test audio output event uses custom configuration diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 3f4f6c2bc..39524e434 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -568,6 +568,7 @@ async def test_default_audio_rates_in_events(model_id, region): # Error Handling Tests +@pytest.mark.asyncio async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): mock_output = AsyncMock() mock_output.receive.side_effect = ModelTimeoutException("Connection timeout") @@ -575,7 +576,7 @@ async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): await nova_model.start() - with pytest.raises(BidiModelTimeoutError): + with pytest.raises(BidiModelTimeoutError, match=r"Connection timeout"): async for _ in nova_model.receive(): pass @@ -588,7 +589,7 @@ async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock await nova_model.start() - with pytest.raises(BidiModelTimeoutError): + with pytest.raises(BidiModelTimeoutError, match=r"InternalErrorCode=531"): async for _ in nova_model.receive(): pass diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index ab1705cd9..85a1cc097 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -130,8 +130,8 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" - config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config) + provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -145,7 +145,7 @@ def test_audio_config_partial_override(api_key, model_name): def test_audio_config_full_override(api_key, model_name): """Test full audio configuration override.""" - config = { + provider_config = { "audio": { "input_rate": 48000, "output_rate": 48000, @@ -154,7 +154,7 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config) + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -163,23 +163,9 @@ def test_audio_config_full_override(api_key, model_name): assert model.config["audio"]["voice"] == "shimmer" -def test_audio_config_voice_priority(api_key, model_name): - """Test that config audio voice takes precedence over provider_config voice.""" - provider_config = {"audio": {"output": {"voice": "alloy"}}} - config = {"audio": {"voice": "nova"}} - - model = BidiOpenAIRealtimeModel( - model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config, config=config - ) - - # Build config and verify config audio voice takes precedence - built_config = model._build_session_config(None, None) - assert built_config["audio"]["output"]["voice"] == "nova" - - def test_audio_config_extracts_voice_from_provider_config(api_key, model_name): """Test that voice is extracted from provider_config when config audio not provided.""" - provider_config = {"audio": {"output": {"voice": "fable"}}} + provider_config = {"audio": {"voice": "fable"}} model = BidiOpenAIRealtimeModel( model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config @@ -537,7 +523,7 @@ async def test_receive_timeout(mock_time, model): await model.start() - with pytest.raises(BidiModelTimeoutError): + with pytest.raises(BidiModelTimeoutError, match=r"timeout_s=<1>"): async for _ in model.receive(): pass @@ -712,7 +698,7 @@ async def test_custom_audio_sample_rate(mock_websockets_connect, api_key): # Create model with custom sample rate custom_sample_rate = 48000 - provider_config = {"audio": {"output": {"format": {"rate": custom_sample_rate}}}} + provider_config = {"audio": {"output_rate": custom_sample_rate}} model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}, provider_config=provider_config) await model.start() From 40aac26e69779ad2df62c4f61e024a09c623397b Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sun, 30 Nov 2025 13:48:01 -0500 Subject: [PATCH 225/242] loop - restart connection - gemini (#96) --- src/strands/experimental/bidi/agent/loop.py | 1 + .../experimental/bidi/models/bidi_model.py | 11 ++++- .../experimental/bidi/models/gemini_live.py | 25 ++++++++--- .../experimental/bidi/models/novasonic.py | 6 +-- .../experimental/bidi/agent/test_loop.py | 5 ++- .../bidi/models/test_gemini_live.py | 45 +++++++++++++++++++ 6 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 80245d9b2..13b7033a4 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -198,6 +198,7 @@ async def _restart_connection(self, timeout_error: BidiModelTimeoutError) -> Non self._agent.system_prompt, self._agent.tool_registry.get_all_tool_specs(), self._agent.messages, + **timeout_error.restart_config, ) self._task_pool.create(self._run_model()) except Exception as exception: diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index bc2806e78..0d0da63d2 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -118,4 +118,13 @@ class BidiModelTimeoutError(Exception): to create a seamless, uninterrupted experience for the user. """ - pass + def __init__(self, message: str, **restart_config: Any) -> None: + """Initialize error. + + Args: + message: Timeout message from model. + **restart_config: Configure restart specific behaviors in the call to model start. + """ + super().__init__(self, message) + + self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 2e9a13b54..1f2b2d5cd 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -40,7 +40,7 @@ BidiUsageEvent, ModalityUsage, ) -from .bidi_model import BidiModel +from .bidi_model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) @@ -92,6 +92,7 @@ def __init__( # Connection state (initialized in start()) self._live_session: Any = None self._live_session_context_manager: Any = None + self._live_session_handle: str | None = None self._connection_id: str | None = None def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: @@ -175,8 +176,8 @@ async def start( ) self._live_session = await self._live_session_context_manager.__aenter__() - # Send initial message history if provided - if messages: + # Gemini itself restores message history when resuming from session + if messages and "live_session_handle" not in kwargs: await self._send_message_history(messages) async def _send_message_history(self, messages: Messages) -> None: @@ -227,7 +228,22 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut Returns: List of event dicts (empty list if no events to emit). + + Raises: + BidiModelTimeoutError: If gemini responds with go away message. """ + if message.go_away: + raise BidiModelTimeoutError( + message.go_away.model_dump_json(), live_session_handle=self._live_session_handle + ) + + if message.session_resumption_update: + resumption_update = message.session_resumption_update + if resumption_update.resumable and resumption_update.new_handle: + self._live_session_handle = resumption_update.new_handle + logger.debug("session_handle=<%s> | updating gemini session handle", self._live_session_handle) + return [] + # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: return [BidiInterruptionEvent(reason="user_speech")] @@ -491,8 +507,7 @@ def _build_live_config( if self.config: config_dict.update({k: v for k, v in self.config.items() if k != "audio"}) - # Override with any kwargs from start() - config_dict.update(kwargs) + config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")} # Add system instruction if provided if system_prompt: diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 24c932ab0..713afe028 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -297,13 +297,13 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: event_data = await output.receive() except ValidationException as error: - if "InternalErrorCode=531" in str(error): + if "InternalErrorCode=531" in error.message: # nova also times out if user is silent for 175 seconds - raise BidiModelTimeoutError(error) from error + raise BidiModelTimeoutError(error.message) from error raise except ModelTimeoutException as error: - raise BidiModelTimeoutError(error) from error + raise BidiModelTimeoutError(error.message) from error if not event_data: continue diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index 68346ab19..d19cada60 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -42,7 +42,7 @@ async def loop(agent): @pytest.mark.asyncio async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerator): - timeout_error = BidiModelTimeoutError("test timeout") + timeout_error = BidiModelTimeoutError("test timeout", test_restart_config=1) text_event = BidiTextInputEvent(text="test after restart") agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])]) @@ -63,10 +63,11 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato agent.model.stop.assert_called_once() assert agent.model.start.call_count == 2 - agent.model.start.assert_any_call( + agent.model.start.assert_called_with( agent.system_prompt, agent.tool_registry.get_all_tool_specs(), agent.messages, + test_restart_config=1, ) diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index dec83dbe3..a880bb223 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,6 +13,7 @@ import pytest from google.genai import types as genai_types +from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, @@ -279,6 +280,34 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): assert event.connection_id == model._connection_id +@pytest.mark.asyncio +async def test_receive_timeout(mock_genai_client, model, agenerator): + mock_resumption_response = unittest.mock.Mock() + mock_resumption_response.go_away = None + mock_resumption_response.session_resumption_update = unittest.mock.Mock() + mock_resumption_response.session_resumption_update.resumable = True + mock_resumption_response.session_resumption_update.new_handle = "h1" + + mock_timeout_response = unittest.mock.Mock() + mock_timeout_response.go_away = unittest.mock.Mock() + mock_timeout_response.go_away.model_dump_json.return_value = "test timeout" + + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive = unittest.mock.Mock( + return_value=agenerator([mock_resumption_response, mock_timeout_response]) + ) + + await model.start() + + with pytest.raises(BidiModelTimeoutError, match=r"test timeout"): + async for _ in model.receive(): + pass + + tru_handle = model._live_session_handle + exp_handle = "h1" + assert tru_handle == exp_handle + + @pytest.mark.asyncio async def test_event_conversion(mock_genai_client, model): """Test conversion of all Gemini Live event types to standard format.""" @@ -288,6 +317,8 @@ async def test_event_conversion(mock_genai_client, model): # Test text output (converted to transcript via model_turn.parts) mock_text = unittest.mock.Mock() mock_text.data = None + mock_text.go_away = None + mock_text.session_resumption_update = None mock_text.tool_call = None # Create proper server_content structure with model_turn @@ -319,6 +350,8 @@ async def test_event_conversion(mock_genai_client, model): # Test multiple text parts (should concatenate) mock_multi_text = unittest.mock.Mock() mock_multi_text.data = None + mock_multi_text.go_away = None + mock_multi_text.session_resumption_update = None mock_multi_text.tool_call = None mock_server_content_multi = unittest.mock.Mock() @@ -347,6 +380,8 @@ async def test_event_conversion(mock_genai_client, model): mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None mock_audio.tool_call = None mock_audio.server_content = None @@ -373,6 +408,8 @@ async def test_event_conversion(mock_genai_client, model): mock_tool = unittest.mock.Mock() mock_tool.text = None mock_tool.data = None + mock_tool.go_away = None + mock_tool.session_resumption_update = None mock_tool.tool_call = mock_tool_call mock_tool.server_content = None @@ -404,6 +441,8 @@ async def test_event_conversion(mock_genai_client, model): mock_tool_multi = unittest.mock.Mock() mock_tool_multi.text = None mock_tool_multi.data = None + mock_tool_multi.go_away = None + mock_tool_multi.session_resumption_update = None mock_tool_multi.tool_call = mock_tool_call_multi mock_tool_multi.server_content = None @@ -431,6 +470,8 @@ async def test_event_conversion(mock_genai_client, model): mock_interrupt = unittest.mock.Mock() mock_interrupt.text = None mock_interrupt.data = None + mock_interrupt.go_away = None + mock_interrupt.session_resumption_update = None mock_interrupt.tool_call = None mock_interrupt.server_content = mock_server_content @@ -549,6 +590,8 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None mock_audio.tool_call = None mock_audio.server_content = None @@ -577,6 +620,8 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None mock_audio.tool_call = None mock_audio.server_content = None From 642752abb690a79220b8ecc4a24f7ad371ac9bf8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 16:08:33 -0500 Subject: [PATCH 226/242] rename modules --- src/strands/experimental/bidi/__init__.py | 6 +- src/strands/experimental/bidi/agent/agent.py | 4 +- .../experimental/bidi/models/__init__.py | 6 +- .../experimental/bidi/models/bidi_model.py | 130 --- .../experimental/bidi/models/gemini_live.py | 2 +- .../experimental/bidi/models/novasonic.py | 760 ---------------- .../experimental/bidi/models/openai.py | 816 ------------------ .../experimental/bidi/types/bidi_model.py | 36 - src/strands/experimental/bidi/types/events.py | 2 +- .../bidi/models/test_gemini_live.py | 2 +- .../bidi/models/test_novasonic.py | 4 +- .../experimental/bidi/models/test_openai.py | 4 +- 12 files changed, 15 insertions(+), 1757 deletions(-) delete mode 100644 src/strands/experimental/bidi/models/bidi_model.py delete mode 100644 src/strands/experimental/bidi/models/novasonic.py delete mode 100644 src/strands/experimental/bidi/models/openai.py delete mode 100644 src/strands/experimental/bidi/types/bidi_model.py diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 7e2ad2bb3..13c5b51e1 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -18,12 +18,12 @@ from .io.audio import BidiAudioIO # Model interface (for custom implementations) -from .models.bidi_model import BidiModel +from .models.model import BidiModel # Model providers - What users need to create models from .models.gemini_live import BidiGeminiLiveModel -from .models.novasonic import BidiNovaSonicModel -from .models.openai import BidiOpenAIRealtimeModel +from .models.nova_sonic import BidiNovaSonicModel +from .models.openai_realtime import BidiOpenAIRealtimeModel # Built-in tools from .tools import stop_conversation diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 74b65ba10..68075d0b2 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -30,8 +30,8 @@ from ...hooks.events import BidiAgentInitializedEvent from ...tools import ToolProvider from .._async import stop_all -from ..models.bidi_model import BidiModel -from ..models.novasonic import BidiNovaSonicModel +from ..models.model import BidiModel +from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index d1221df36..29a2229c5 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,9 +1,9 @@ """Bidirectional model interfaces and implementations.""" -from .bidi_model import BidiModel, BidiModelTimeoutError +from .model import BidiModel, BidiModelTimeoutError from .gemini_live import BidiGeminiLiveModel -from .novasonic import BidiNovaSonicModel -from .openai import BidiOpenAIRealtimeModel +from .nova_sonic import BidiNovaSonicModel +from .openai_realtime import BidiOpenAIRealtimeModel __all__ = [ "BidiModel", diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py deleted file mode 100644 index 0d0da63d2..000000000 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Bidirectional streaming model interface. - -Defines the abstract interface for models that support real-time bidirectional -communication with persistent connections. Unlike traditional request-response -models, bidirectional models maintain an open connection for streaming audio, -text, and tool interactions. - -Features: -- Persistent connection management with connect/close lifecycle -- Real-time bidirectional communication (send and receive simultaneously) -- Provider-agnostic event normalization -- Support for audio, text, image, and tool result streaming -""" - -import logging -from typing import Any, AsyncIterable, Protocol - -from ....types._events import ToolResultEvent -from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.events import ( - BidiInputEvent, - BidiOutputEvent, -) - -logger = logging.getLogger(__name__) - - -class BidiModel(Protocol): - """Protocol for bidirectional streaming models. - - This interface defines the contract for models that support persistent streaming - connections with real-time audio and text communication. Implementations handle - provider-specific protocols while exposing a standardized event-based API. - - Attributes: - config: Configuration dictionary with provider-specific settings. - """ - - config: dict[str, Any] - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs: Any, - ) -> None: - """Establish a persistent streaming connection with the model. - - Opens a bidirectional connection that remains active for real-time communication. - The connection supports concurrent sending and receiving of events until explicitly - closed. Must be called before any send() or receive() operations. - - Args: - system_prompt: System instructions to configure model behavior. - tools: Tool specifications that the model can invoke during the conversation. - messages: Initial conversation history to provide context. - **kwargs: Provider-specific configuration options. - """ - ... - - async def stop(self) -> None: - """Close the streaming connection and release resources. - - Terminates the active bidirectional connection and cleans up any associated - resources such as network connections, buffers, or background tasks. After - calling close(), the model instance cannot be used until start() is called again. - """ - ... - - def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive streaming events from the model. - - Continuously yields events from the model as they arrive over the connection. - Events are normalized to a provider-agnostic format for uniform processing. - This method should be called in a loop or async task to process model responses. - - The stream continues until the connection is closed or an error occurs. - - Yields: - BidiOutputEvent: Standardized event objects containing audio output, - transcripts, tool calls, or control signals. - """ - ... - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Send content to the model over the active connection. - - Transmits user input or tool results to the model during an active streaming - session. Supports multiple content types including text, audio, images, and - tool execution results. Can be called multiple times during a conversation. - - Args: - content: The content to send. Must be one of: - - BidiTextInputEvent: Text message from the user - - BidiAudioInputEvent: Audio data for speech input - - BidiImageInputEvent: Image data for visual understanding - - ToolResultEvent: Result from a tool execution - - Example: - await model.send(BidiTextInputEvent(text="Hello", role="user")) - await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) - await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) - await model.send(ToolResultEvent(tool_result)) - """ - ... - - -class BidiModelTimeoutError(Exception): - """Model timeout error. - - Bidirectional models are often configured with a connection time limit. Nova sonic for example keeps the connection - open for 8 minutes max. Upon receiving a timeout, the agent loop is configured to restart the model connection so as - to create a seamless, uninterrupted experience for the user. - """ - - def __init__(self, message: str, **restart_config: Any) -> None: - """Initialize error. - - Args: - message: Timeout message from model. - **restart_config: Configure restart specific behaviors in the call to model start. - """ - super().__init__(self, message) - - self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 1f2b2d5cd..efc1d1832 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -40,7 +40,7 @@ BidiUsageEvent, ModalityUsage, ) -from .bidi_model import BidiModel, BidiModelTimeoutError +from .model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py deleted file mode 100644 index 713afe028..000000000 --- a/src/strands/experimental/bidi/models/novasonic.py +++ /dev/null @@ -1,760 +0,0 @@ -"""Nova Sonic bidirectional model provider for real-time streaming conversations. - -Implements the BidiModel interface for Amazon's Nova Sonic, handling the -complex event sequencing and audio processing required by Nova Sonic's -InvokeModelWithBidirectionalStream protocol. - -Nova Sonic specifics: -- Hierarchical event sequences: connectionStart → promptStart → content streaming -- Base64-encoded audio format with hex encoding -- Tool execution with content containers and identifier tracking -- 8-minute connection limits with proper cleanup sequences -- Interruption detection through stopReason events -""" - -import asyncio -import base64 -import json -import logging -import uuid -from typing import Any, AsyncGenerator, cast - -import boto3 -from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput -from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import ( - BidirectionalInputPayloadPart, - InvokeModelWithBidirectionalStreamInputChunk, - ModelTimeoutException, - ValidationException, -) -from smithy_aws_core.identity.static import StaticCredentialsResolver -from smithy_core.aio.eventstream import DuplexEventStream -from smithy_core.shapes import ShapeID - -from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ....types.content import Messages -from ....types.tools import ToolResult, ToolSpec, ToolUse -from .._async import stop_all -from ..types.bidi_model import AudioConfig -from ..types.events import ( - AudioChannel, - AudioSampleRate, - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionStartEvent, - BidiInputEvent, - BidiInterruptionEvent, - BidiOutputEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiUsageEvent, -) -from .bidi_model import BidiModel, BidiModelTimeoutError - -logger = logging.getLogger(__name__) - -# Nova Sonic configuration constants -NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} - -NOVA_AUDIO_INPUT_CONFIG = { - "mediaType": "audio/lpcm", - "sampleRateHertz": 16000, - "sampleSizeBits": 16, - "channelCount": 1, - "audioType": "SPEECH", - "encoding": "base64", -} - -NOVA_AUDIO_OUTPUT_CONFIG = { - "mediaType": "audio/lpcm", - "sampleRateHertz": 16000, - "sampleSizeBits": 16, - "channelCount": 1, - "voiceId": "matthew", - "encoding": "base64", - "audioType": "SPEECH", -} - -NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} -NOVA_TOOL_CONFIG = {"mediaType": "application/json"} - - -class BidiNovaSonicModel(BidiModel): - """Nova Sonic implementation for bidirectional streaming. - - Combines model configuration and connection state in a single class. - Manages Nova Sonic's complex event sequencing, audio format conversion, and - tool execution patterns while providing the standard BidiModel interface. - - Attributes: - _stream: open bedrock stream to nova sonic. - """ - - _stream: DuplexEventStream - - def __init__( - self, - model_id: str = "amazon.nova-sonic-v1:0", - provider_config: dict[str, Any] | None = None, - client_config: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """Initialize Nova Sonic bidirectional model. - - Args: - model_id: Model identifier (default: amazon.nova-sonic-v1:0) - provider_config: Model behavior (audio, inference settings) - client_config: AWS authentication (boto_session OR region, not both) - **kwargs: Reserved for future parameters. - """ - # Store model ID - self.model_id = model_id - - # Resolve client config with defaults - self._client_config = self._resolve_client_config(client_config or {}) - - # Resolve provider config with defaults - self.config = self._resolve_provider_config(provider_config or {}) - - # Store session and region for later use - self._session = self._client_config["boto_session"] - self.region = self._client_config["region"] - - # Track API-provided identifiers - self._connection_id: str | None = None - self._audio_content_name: str | None = None - self._current_completion_id: str | None = None - - # Indicates if model is done generating transcript - self._generation_stage: str | None = None - - # Ensure certain events are sent in sequence when required - self._send_lock = asyncio.Lock() - - logger.debug("model_id=<%s> | nova sonic model initialized", model_id) - - def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Resolve AWS client config (creates boto session if needed).""" - if "boto_session" in config and "region" in config: - raise ValueError("Cannot specify both 'boto_session' and 'region' in client_config") - - resolved = config.copy() - - # Create boto session if not provided - if "boto_session" not in resolved: - resolved["boto_session"] = boto3.Session() - - # Resolve region from session or use default - if "region" not in resolved: - resolved["region"] = resolved["boto_session"].region_name or "us-east-1" - - return resolved - - def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Merge user config with defaults (user takes precedence).""" - # Define default audio configuration - default_audio_config: AudioConfig = { - "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), - "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), - "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), - "format": "pcm", - "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), - } - - user_audio_config = config.get("audio", {}) - merged_audio = {**default_audio_config, **user_audio_config} - - resolved = { - "audio": merged_audio, - **{k: v for k, v in config.items() if k != "audio"}, - } - - if user_audio_config: - logger.debug("audio_config | merged user-provided config with defaults") - else: - logger.debug("audio_config | using default Nova Sonic audio configuration") - - return resolved - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs: Any, - ) -> None: - """Establish bidirectional connection to Nova Sonic. - - Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. - **kwargs: Additional configuration options. - - Raises: - RuntimeError: If user calls start again without first stopping. - """ - if self._connection_id: - raise RuntimeError("model already started | call stop before starting again") - - logger.debug("nova connection starting") - - self._connection_id = str(uuid.uuid4()) - - # Get credentials from boto3 session (full credential chain) - credentials = self._session.get_credentials() - - if not credentials: - raise ValueError( - "no AWS credentials found. configure credentials via environment variables, " - "credential files, IAM roles, or SSO." - ) - - # Use static resolver with credentials configured as properties - resolver = StaticCredentialsResolver() - - config = Config( - endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", - region=self.region, - aws_credentials_identity_resolver=resolver, - auth_scheme_resolver=HTTPAuthSchemeResolver(), - auth_schemes={ShapeID("aws.auth#sigv4"): SigV4AuthScheme(service="bedrock")}, - # Configure static credentials as properties - aws_access_key_id=credentials.access_key, - aws_secret_access_key=credentials.secret_key, - aws_session_token=credentials.token, - ) - - self.client = BedrockRuntimeClient(config=config) - logger.debug("region=<%s> | nova sonic client initialized", self.region) - - client = BedrockRuntimeClient(config=config) - self._stream = await client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ) - logger.debug("region=<%s> | nova sonic client initialized", self.region) - - init_events = self._build_initialization_events(system_prompt, tools, messages) - logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events)) - await self._send_nova_events(init_events) - - logger.info("connection_id=<%s> | nova sonic connection established", self._connection_id) - - def _build_initialization_events( - self, system_prompt: str | None, tools: list[ToolSpec] | None, messages: Messages | None - ) -> list[str]: - """Build the sequence of initialization events.""" - tools = tools or [] - events = [ - self._get_connection_start_event(), - self._get_prompt_start_event(tools), - *self._get_system_prompt_events(system_prompt), - ] - - # Add conversation history if provided - if messages: - events.extend(self._get_message_history_events(messages)) - logger.debug("message_count=<%d> | conversation history added to initialization", len(messages)) - - return events - - def _log_event_type(self, nova_event: dict[str, Any]) -> None: - """Log specific Nova Sonic event types for debugging.""" - if "usageEvent" in nova_event: - logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) - elif "textOutput" in nova_event: - logger.debug("nova text output received") - elif "toolUse" in nova_event: - tool_use = nova_event["toolUse"] - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", - tool_use["toolName"], - tool_use["toolUseId"], - ) - elif "audioOutput" in nova_event: - audio_content = nova_event["audioOutput"]["content"] - audio_bytes = base64.b64decode(audio_content) - logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) - - async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: - """Receive Nova Sonic events and convert to provider-agnostic format. - - Raises: - RuntimeError: If start has not been called. - """ - if not self._connection_id: - raise RuntimeError("model not started | call start before receiving") - - logger.debug("nova event stream starting") - yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - - _, output = await self._stream.await_output() - while True: - try: - event_data = await output.receive() - - except ValidationException as error: - if "InternalErrorCode=531" in error.message: - # nova also times out if user is silent for 175 seconds - raise BidiModelTimeoutError(error.message) from error - raise - - except ModelTimeoutException as error: - raise BidiModelTimeoutError(error.message) from error - - if not event_data: - continue - - nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] - self._log_event_type(nova_event) - - model_event = self._convert_nova_event(nova_event) - if model_event: - yield model_event - - async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: - """Unified send method for all content types. Sends the given content to Nova Sonic. - - Dispatches to appropriate internal handler based on content type. - - Args: - content: Input event. - - Raises: - ValueError: If content type not supported (e.g., image content). - """ - if not self._connection_id: - raise RuntimeError("model not started | call start before sending") - - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - else: - raise ValueError(f"content_type={type(content)} | content not supported") - - async def _start_audio_connection(self) -> None: - """Internal: Start audio input connection (call once before sending audio chunks).""" - logger.debug("nova audio connection starting") - self._audio_content_name = str(uuid.uuid4()) - - # Build audio input configuration from config - audio_input_config = { - "mediaType": "audio/lpcm", - "sampleRateHertz": self.config["audio"]["input_rate"], - "sampleSizeBits": 16, - "channelCount": self.config["audio"]["channels"], - "audioType": "SPEECH", - "encoding": "base64", - } - - audio_content_start = json.dumps( - { - "event": { - "contentStart": { - "promptName": self._connection_id, - "contentName": self._audio_content_name, - "type": "AUDIO", - "interactive": True, - "role": "USER", - "audioInputConfiguration": audio_input_config, - } - } - } - ) - - await self._send_nova_events([audio_content_start]) - - async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: - """Internal: Send audio using Nova Sonic protocol-specific format.""" - # Start audio connection if not already active - if not self._audio_content_name: - await self._start_audio_connection() - - # Audio is already base64 encoded in the event - # Send audio input event - audio_event = json.dumps( - { - "event": { - "audioInput": { - "promptName": self._connection_id, - "contentName": self._audio_content_name, - "content": audio_input.audio, - } - } - } - ) - - await self._send_nova_events([audio_event]) - - async def _end_audio_input(self) -> None: - """Internal: End current audio input connection to trigger Nova Sonic processing.""" - if not self._audio_content_name: - return - - logger.debug("nova audio connection ending") - - audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self._connection_id, "contentName": self._audio_content_name}}} - ) - - await self._send_nova_events([audio_content_end]) - self._audio_content_name = None - - async def _send_text_content(self, text: str) -> None: - """Internal: Send text content using Nova Sonic format.""" - content_name = str(uuid.uuid4()) - events = [ - self._get_text_content_start_event(content_name), - self._get_text_input_event(content_name, text), - self._get_content_end_event(content_name), - ] - await self._send_nova_events(events) - - async def _send_tool_result(self, tool_result: ToolResult) -> None: - """Internal: Send tool result using Nova Sonic toolResult format.""" - tool_use_id = tool_result["toolUseId"] - - logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) - - # Validate content types and preserve structure - content = tool_result.get("content", []) - - # Validate all content types are supported - for block in content: - if "text" not in block and "json" not in block: - # Unsupported content type - raise error - raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " - f"Content type not supported by Nova Sonic" - ) - - # Optimize for single content item - unwrap the array - if len(content) == 1: - result_data = cast(dict[str, Any], content[0]) - else: - # Multiple items - send as array - result_data = {"content": content} - - content_name = str(uuid.uuid4()) - events = [ - self._get_tool_content_start_event(content_name, tool_use_id), - self._get_tool_result_event(content_name, result_data), - self._get_content_end_event(content_name), - ] - await self._send_nova_events(events) - - async def stop(self) -> None: - """Close Nova Sonic connection with proper cleanup sequence.""" - logger.debug("nova connection cleanup starting") - - async def stop_events() -> None: - if not self._connection_id: - return - - await self._end_audio_input() - cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] - await self._send_nova_events(cleanup_events) - - async def stop_stream() -> None: - if not hasattr(self, "_stream"): - return - - await self._stream.close() - - async def stop_connection() -> None: - self._connection_id = None - - await stop_all(stop_events, stop_stream, stop_connection) - - logger.debug("nova connection closed") - - def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | None: - """Convert Nova Sonic events to TypedEvent format.""" - # Handle completion start - track completionId - if "completionStart" in nova_event: - completion_data = nova_event["completionStart"] - self._current_completion_id = completion_data.get("completionId") - logger.debug("completion_id=<%s> | nova completion started", self._current_completion_id) - return None - - # Handle completion end - if "completionEnd" in nova_event: - completion_data = nova_event["completionEnd"] - completion_id = completion_data.get("completionId", self._current_completion_id) - stop_reason = completion_data.get("stopReason", "END_TURN") - - event = BidiResponseCompleteEvent( - response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing - stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete", - ) - - # Clear completion tracking - self._current_completion_id = None - return event - - # Handle audio output - if "audioOutput" in nova_event: - # Audio is already base64 string from Nova Sonic - audio_content = nova_event["audioOutput"]["content"] - return BidiAudioStreamEvent( - audio=audio_content, - format="pcm", - sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), - channels=cast(AudioChannel, self.config["audio"]["channels"]), - ) - - # Handle text output (transcripts) - elif "textOutput" in nova_event: - text_output = nova_event["textOutput"] - text_content = text_output["content"] - # Check for Nova Sonic interruption pattern - if '{ "interrupted" : true }' in text_content: - logger.debug("nova interruption detected in text output") - return BidiInterruptionEvent(reason="user_speech") - - return BidiTranscriptStreamEvent( - delta={"text": text_content}, - text=text_content, - role=text_output["role"].lower(), - is_final=self._generation_stage == "FINAL", - current_transcript=text_content, - ) - - # Handle tool use - if "toolUse" in nova_event: - tool_use = nova_event["toolUse"] - tool_use_event: ToolUse = { - "toolUseId": tool_use["toolUseId"], - "name": tool_use["toolName"], - "input": json.loads(tool_use["content"]), - } - # Return ToolUseStreamEvent - cast to dict for type compatibility - return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) - - # Handle interruption - if nova_event.get("stopReason") == "INTERRUPTED": - logger.debug("nova interruption detected via stop reason") - return BidiInterruptionEvent(reason="user_speech") - - # Handle usage events - convert to multimodal usage format - if "usageEvent" in nova_event: - usage_data = nova_event["usageEvent"] - total_input = usage_data.get("totalInputTokens", 0) - total_output = usage_data.get("totalOutputTokens", 0) - - return BidiUsageEvent( - input_tokens=total_input, - output_tokens=total_output, - total_tokens=usage_data.get("totalTokens", total_input + total_output), - ) - - # Handle content start events (emit response start) - if "contentStart" in nova_event: - content_data = nova_event["contentStart"] - if content_data["type"] == "TEXT": - self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] - - # Emit response start event using API-provided completionId - # completionId should already be tracked from completionStart event - return BidiResponseStartEvent( - response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing - ) - - if "contentEnd" in nova_event: - self._generation_stage = None - - # Ignore all other events - return None - - def _get_connection_start_event(self) -> str: - """Generate Nova Sonic connection start event.""" - return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) - - def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: - """Generate Nova Sonic prompt start event with tool configuration.""" - # Build audio output configuration from config - audio_output_config = { - "mediaType": "audio/lpcm", - "sampleRateHertz": self.config["audio"]["output_rate"], - "sampleSizeBits": 16, - "channelCount": self.config["audio"]["channels"], - "voiceId": self.config["audio"].get("voice", "matthew"), - "encoding": "base64", - "audioType": "SPEECH", - } - - prompt_start_event: dict[str, Any] = { - "event": { - "promptStart": { - "promptName": self._connection_id, - "textOutputConfiguration": NOVA_TEXT_CONFIG, - "audioOutputConfiguration": audio_output_config, - } - } - } - - if tools: - tool_config = self._build_tool_configuration(tools) - prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG - prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} - - return json.dumps(prompt_start_event) - - def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict[str, Any]]: - """Build tool configuration from tool specs.""" - tool_config: list[dict[str, Any]] = [] - for tool in tools: - input_schema = ( - {"json": json.dumps(tool["inputSchema"]["json"])} - if "json" in tool["inputSchema"] - else {"json": json.dumps(tool["inputSchema"])} - ) - - tool_config.append( - {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} - ) - return tool_config - - def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: - """Generate system prompt events.""" - content_name = str(uuid.uuid4()) - return [ - self._get_text_content_start_event(content_name, "SYSTEM"), - self._get_text_input_event(content_name, system_prompt or ""), - self._get_content_end_event(content_name), - ] - - def _get_message_history_events(self, messages: Messages) -> list[str]: - """Generate conversation history events from agent messages. - - Converts agent message history to Nova Sonic format following the - contentStart/textInput/contentEnd pattern for each message. - - Args: - messages: List of conversation messages with role and content. - - Returns: - List of JSON event strings for Nova Sonic. - """ - events = [] - - for message in messages: - role = message["role"].upper() # Convert to ASSISTANT or USER - content_blocks = message.get("content", []) - - # Extract text content from content blocks - text_parts = [] - for block in content_blocks: - if "text" in block: - text_parts.append(block["text"]) - - # Combine all text parts - if text_parts: - combined_text = "\n".join(text_parts) - content_name = str(uuid.uuid4()) - - # Add contentStart, textInput, and contentEnd events - events.extend( - [ - self._get_text_content_start_event(content_name, role), - self._get_text_input_event(content_name, combined_text), - self._get_content_end_event(content_name), - ] - ) - - return events - - def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: - """Generate text content start event.""" - return json.dumps( - { - "event": { - "contentStart": { - "promptName": self._connection_id, - "contentName": content_name, - "type": "TEXT", - "role": role, - "interactive": True, - "textInputConfiguration": NOVA_TEXT_CONFIG, - } - } - } - ) - - def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: - """Generate tool content start event.""" - return json.dumps( - { - "event": { - "contentStart": { - "promptName": self._connection_id, - "contentName": content_name, - "interactive": False, - "type": "TOOL", - "role": "TOOL", - "toolResultInputConfiguration": { - "toolUseId": tool_use_id, - "type": "TEXT", - "textInputConfiguration": NOVA_TEXT_CONFIG, - }, - } - } - } - ) - - def _get_text_input_event(self, content_name: str, text: str) -> str: - """Generate text input event.""" - return json.dumps( - {"event": {"textInput": {"promptName": self._connection_id, "contentName": content_name, "content": text}}} - ) - - def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> str: - """Generate tool result event.""" - return json.dumps( - { - "event": { - "toolResult": { - "promptName": self._connection_id, - "contentName": content_name, - "content": json.dumps(result), - } - } - } - ) - - def _get_content_end_event(self, content_name: str) -> str: - """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self._connection_id, "contentName": content_name}}}) - - def _get_prompt_end_event(self) -> str: - """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self._connection_id}}}) - - def _get_connection_end_event(self) -> str: - """Generate connection end event.""" - return json.dumps({"event": {"connectionEnd": {}}}) - - async def _send_nova_events(self, events: list[str]) -> None: - """Send event JSON string to Nova Sonic stream. - - A lock is used to send events in sequence when required (e.g., tool result start, content, and end). - - Args: - events: Jsonified events. - """ - async with self._send_lock: - for event in events: - bytes_data = event.encode("utf-8") - chunk = InvokeModelWithBidirectionalStreamInputChunk( - value=BidirectionalInputPayloadPart(bytes_=bytes_data) - ) - await self._stream.input_stream.send(chunk) - logger.debug("nova sonic event sent successfully") diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py deleted file mode 100644 index bfe3ad533..000000000 --- a/src/strands/experimental/bidi/models/openai.py +++ /dev/null @@ -1,816 +0,0 @@ -"""OpenAI Realtime API provider for Strands bidirectional streaming. - -Provides real-time audio and text communication through OpenAI's Realtime API -with WebSocket connections, voice activity detection, and function calling. -""" - -import asyncio -import json -import logging -import os -import time -import uuid -from typing import Any, AsyncGenerator, Literal, cast - -import websockets -from websockets import ClientConnection - -from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ....types.content import Messages -from ....types.tools import ToolResult, ToolSpec, ToolUse -from .._async import stop_all -from ..types.bidi_model import AudioConfig -from ..types.events import ( - AudioSampleRate, - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionStartEvent, - BidiInputEvent, - BidiInterruptionEvent, - BidiOutputEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiUsageEvent, - ModalityUsage, - Role, - StopReason, -) -from .bidi_model import BidiModel, BidiModelTimeoutError - -logger = logging.getLogger(__name__) - -# Test idle_timeout_ms - -# OpenAI Realtime API configuration -OPENAI_MAX_TIMEOUT_S = 3000 # 50 minutes -"""Max timeout before closing connection. - -OpenAI documents a 60 minute limit on realtime sessions -(https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events). However, OpenAI does not -emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully -handle the connection closure. We set the max to 50 minutes to provide enough buffer before hitting the real limit. -""" -OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" -DEFAULT_MODEL = "gpt-realtime" -DEFAULT_SAMPLE_RATE = 24000 - -DEFAULT_SESSION_CONFIG = { - "type": "realtime", - "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", - "output_modalities": ["audio"], - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, - "transcription": {"model": "gpt-4o-transcribe"}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500, - }, - }, - "output": {"format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, "voice": "alloy"}, - }, -} - - -class BidiOpenAIRealtimeModel(BidiModel): - """OpenAI Realtime API implementation for bidirectional streaming. - - Combines model configuration and connection state in a single class. - Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, - function calling, and event conversion to Strands format. - """ - - _websocket: ClientConnection - _start_time: int - - def __init__( - self, - model_id: str = DEFAULT_MODEL, - provider_config: dict[str, Any] | None = None, - client_config: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """Initialize OpenAI Realtime bidirectional model. - - Args: - model_id: Model identifier (default: gpt-realtime) - provider_config: Model behavior (audio, instructions, turn_detection, etc.) - client_config: Authentication (api_key, organization, project) - Falls back to OPENAI_API_KEY, OPENAI_ORGANIZATION, OPENAI_PROJECT env vars - **kwargs: Reserved for future parameters. - - """ - # Store model ID - self.model_id = model_id - - # Resolve client config with defaults and env vars - self._client_config = self._resolve_client_config(client_config or {}) - - # Resolve provider config with defaults - self.config = self._resolve_provider_config(provider_config or {}) - - # Store client config values for later use - self.api_key = self._client_config["api_key"] - self.organization = self._client_config.get("organization") - self.project = self._client_config.get("project") - self.timeout_s = self._client_config["timeout_s"] - - if self.timeout_s > OPENAI_MAX_TIMEOUT_S: - raise ValueError( - f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" - ) - - # Connection state (initialized in start()) - self._connection_id: str | None = None - - self._function_call_buffer: dict[str, Any] = {} - - logger.debug("model=<%s> | openai realtime model initialized", model_id) - - def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Resolve client config with env var fallback (config takes precedence).""" - resolved = config.copy() - - if "api_key" not in resolved: - resolved["api_key"] = os.getenv("OPENAI_API_KEY") - - if not resolved.get("api_key"): - raise ValueError( - "OpenAI API key is required. Provide via client_config={'api_key': '...'} " - "or set OPENAI_API_KEY environment variable." - ) - if "organization" not in resolved: - env_org = os.getenv("OPENAI_ORGANIZATION") - if env_org: - resolved["organization"] = env_org - - if "project" not in resolved: - env_project = os.getenv("OPENAI_PROJECT") - if env_project: - resolved["project"] = env_project - - if "timeout_s" not in resolved: - resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S - - return resolved - - def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Merge user config with defaults (user takes precedence).""" - # Extract voice from provider-specific audio.output.voice if present - provider_voice = None - if "audio" in config and isinstance(config["audio"], dict): - if "output" in config["audio"] and isinstance(config["audio"]["output"], dict): - provider_voice = config["audio"]["output"].get("voice") - - # Define default audio configuration - default_audio: AudioConfig = { - "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), - "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), - "channels": 1, - "format": "pcm", - "voice": provider_voice or "alloy", - } - - user_audio = config.get("audio", {}) - merged_audio = {**default_audio, **user_audio} - - resolved = { - "audio": merged_audio, - **{k: v for k, v in config.items() if k != "audio"}, - } - - if user_audio: - logger.debug("audio_config | merged user-provided config with defaults") - else: - logger.debug("audio_config | using default OpenAI Realtime audio configuration") - - return resolved - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs: Any, - ) -> None: - """Establish bidirectional connection to OpenAI Realtime API. - - Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. - **kwargs: Additional configuration options. - """ - if self._connection_id: - raise RuntimeError("model already started | call stop before starting again") - - logger.debug("openai realtime connection starting") - - # Initialize connection state - self._connection_id = str(uuid.uuid4()) - self._start_time = int(time.time()) - - self._function_call_buffer = {} - - # Establish WebSocket connection - url = f"{OPENAI_REALTIME_URL}?model={self.model_id}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if self.organization: - headers.append(("OpenAI-Organization", self.organization)) - if self.project: - headers.append(("OpenAI-Project", self.project)) - - self._websocket = await websockets.connect(url, additional_headers=headers) - logger.debug("connection_id=<%s> | websocket connected successfully", self._connection_id) - - # Configure session - session_config = self._build_session_config(system_prompt, tools) - await self._send_event({"type": "session.update", "session": session_config}) - - # Add conversation history if provided - if messages: - await self._add_conversation_history(messages) - - def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: - """Create standardized transcript event. - - Args: - text: The transcript text - role: The role (will be normalized to lowercase) - is_final: Whether this is the final transcript - """ - # Normalize role to lowercase and ensure it's either "user" or "assistant" - normalized_role = role.lower() if isinstance(role, str) else "assistant" - if normalized_role not in ["user", "assistant"]: - normalized_role = "assistant" - - return BidiTranscriptStreamEvent( - delta={"text": text}, - text=text, - role=cast(Role, normalized_role), - is_final=is_final, - current_transcript=text if is_final else None, - ) - - def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: - """Create standardized interruption event for voice activity.""" - # Only speech_started triggers interruption - if activity_type == "speech_started": - return BidiInterruptionEvent(reason="user_speech") - # Other voice activity events are logged but don't create events - return None - - def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict[str, Any]: - """Build session configuration for OpenAI Realtime API.""" - config: dict[str, Any] = DEFAULT_SESSION_CONFIG.copy() - - if system_prompt: - config["instructions"] = system_prompt - - if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) - - # Apply user-provided session configuration - supported_params = { - "type", - "output_modalities", - "instructions", - "voice", - "tools", - "tool_choice", - "input_audio_format", - "output_audio_format", - "input_audio_transcription", - "turn_detection", - } - - for key, value in self.config.items(): - if key == "audio": - continue - elif key in supported_params: - config[key] = value - else: - logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) - - audio_config = self.config["audio"] - - if "voice" in audio_config: - config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] - - if "input_rate" in audio_config: - config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ - "input_rate" - ] - - if "output_rate" in audio_config: - config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ - "output_rate" - ] - - return config - - def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI Realtime API format.""" - openai_tools = [] - - for tool in tools: - input_schema = tool["inputSchema"] - if "json" in input_schema: - schema = ( - json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] - ) - else: - schema = input_schema - - # OpenAI Realtime API expects flat structure, not nested under "function" - openai_tool = { - "type": "function", - "name": tool["name"], - "description": tool["description"], - "parameters": schema, - } - openai_tools.append(openai_tool) - - return openai_tools - - async def _add_conversation_history(self, messages: Messages) -> None: - """Add conversation history to the session. - - Converts agent message history to OpenAI Realtime API format using - conversation.item.create events for each message. - - Note: OpenAI Realtime API has a 32-character limit on call_id, so we truncate - UUIDs consistently to ensure tool calls and their results match. - - Args: - messages: List of conversation messages with role and content. - """ - # Track tool call IDs to ensure consistency between calls and results - call_id_map: dict[str, str] = {} - - # First pass: collect all tool call IDs - for message in messages: - for block in message.get("content", []): - if "toolUse" in block: - tool_use = block["toolUse"] - original_id = tool_use["toolUseId"] - call_id = original_id[:32] - call_id_map[original_id] = call_id - - # Second pass: send messages - for message in messages: - role = message["role"] - content_blocks = message.get("content", []) - - # Build content array for OpenAI format - openai_content = [] - - for block in content_blocks: - if "text" in block: - # Text content - use appropriate type based on role - # User messages use "input_text", assistant messages use "output_text" - if role == "user": - openai_content.append({"type": "input_text", "text": block["text"]}) - else: # assistant - openai_content.append({"type": "output_text", "text": block["text"]}) - elif "toolUse" in block: - # Tool use - create as function_call item - tool_use = block["toolUse"] - original_id = tool_use["toolUseId"] - # Use pre-mapped call_id - call_id = call_id_map[original_id] - - tool_item = { - "type": "conversation.item.create", - "item": { - "type": "function_call", - "call_id": call_id, - "name": tool_use["name"], - "arguments": json.dumps(tool_use["input"]), - }, - } - await self._send_event(tool_item) - continue # Tool use is sent separately, not in message content - elif "toolResult" in block: - # Tool result - create as function_call_output item - tool_result = block["toolResult"] - original_id = tool_result["toolUseId"] - - # Validate content types and serialize, preserving structure - result_output = "" - if "content" in tool_result: - # First validate all content types are supported - for result_block in tool_result["content"]: - if "text" not in result_block and "json" not in result_block: - # Unsupported content type - raise error - raise ValueError( - f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | " - f"Content type not supported by OpenAI Realtime API" - ) - - # Preserve structure by JSON-dumping the entire content array - result_output = json.dumps(tool_result["content"]) - - # Use mapped call_id if available, otherwise skip orphaned result - if original_id not in call_id_map: - continue # Skip this tool result since we don't have the call - - call_id = call_id_map[original_id] - - result_item = { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": call_id, - "output": result_output, - }, - } - await self._send_event(result_item) - continue # Tool result is sent separately, not in message content - - # Only create message item if there's text content - if openai_content: - conversation_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": role, "content": openai_content}, - } - await self._send_event(conversation_item) - - logger.debug("message_count=<%d> | conversation history added to openai session", len(messages)) - - async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: - """Receive OpenAI events and convert to Strands TypedEvent format.""" - if not self._connection_id: - raise RuntimeError("model not started | call start before sending/receiving") - - yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - - while True: - duration = time.time() - self._start_time - if duration >= self.timeout_s: - raise BidiModelTimeoutError(f"timeout_s=<{self.timeout_s}>") - - try: - message = await asyncio.wait_for(self._websocket.recv(), timeout=10) - except asyncio.TimeoutError: - continue - - openai_event = json.loads(message) - - for event in self._convert_openai_event(openai_event) or []: - yield event - - def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None: - """Convert OpenAI events to Strands TypedEvent format.""" - event_type = openai_event.get("type") - - # Turn start - response begins - if event_type == "response.created": - response = openai_event.get("response", {}) - response_id = response.get("id", str(uuid.uuid4())) - return [BidiResponseStartEvent(response_id=response_id)] - - # Audio output - elif event_type == "response.output_audio.delta": - # Audio is already base64 string from OpenAI - # Use the resolved output sample rate from our merged configuration - sample_rate = self.config["audio"]["output_rate"] - - # Channels from config is guaranteed to be 1 or 2 - channels = cast(Literal[1, 2], self.config["audio"]["channels"]) - return [ - BidiAudioStreamEvent( - audio=openai_event["delta"], - format="pcm", - sample_rate=sample_rate, - channels=channels, - ) - ] - - # Assistant text output events - combine multiple similar events - elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: - role = openai_event.get("role", "assistant") - return [ - self._create_text_event( - openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False - ) - ] - - elif event_type in ["response.output_audio_transcript.done"]: - role = openai_event.get("role", "assistant").lower() - return [self._create_text_event(openai_event["transcript"], role)] - - elif event_type in ["response.output_text.done"]: - role = openai_event.get("role", "assistant").lower() - return [self._create_text_event(openai_event["text"], role)] - - # User transcription events - combine multiple similar events - elif event_type in [ - "conversation.item.input_audio_transcription.delta", - "conversation.item.input_audio_transcription.completed", - ]: - text_key = "delta" if "delta" in event_type else "transcript" - text = openai_event.get(text_key, "") - role = openai_event.get("role", "user") - is_final = "completed" in event_type - return ( - [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] - if text.strip() - else None - ) - - elif event_type == "conversation.item.input_audio_transcription.segment": - segment_data = openai_event.get("segment", {}) - text = segment_data.get("text", "") - role = segment_data.get("role", "user") - return ( - [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] - if text.strip() - else None - ) - - elif event_type == "conversation.item.input_audio_transcription.failed": - error_info = openai_event.get("error", {}) - logger.warning("error=<%s> | openai transcription failed", error_info.get("message", "unknown error")) - return None - - # Function call processing - elif event_type == "response.function_call_arguments.delta": - call_id = openai_event.get("call_id") - delta = openai_event.get("delta", "") - if call_id: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} - else: - self._function_call_buffer[call_id]["arguments"] += delta - return None - - elif event_type == "response.function_call_arguments.done": - call_id = openai_event.get("call_id") - if call_id and call_id in self._function_call_buffer: - function_call = self._function_call_buffer[call_id] - try: - tool_use: ToolUse = { - "toolUseId": call_id, - "name": function_call["name"], - "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, - } - del self._function_call_buffer[call_id] - # Return ToolUseStreamEvent for consistency with standard agent - return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=dict(tool_use))] - except (json.JSONDecodeError, KeyError) as e: - logger.warning("call_id=<%s>, error=<%s> | error parsing function arguments", call_id, e) - del self._function_call_buffer[call_id] - return None - - # Voice activity detection - speech_started triggers interruption - elif event_type == "input_audio_buffer.speech_started": - # This is the primary interruption signal - handle it first - return [BidiInterruptionEvent(reason="user_speech")] - - # Response cancelled - handle interruption - elif event_type == "response.cancelled": - response = openai_event.get("response", {}) - response_id = response.get("id", "unknown") - logger.debug("response_id=<%s> | openai response cancelled", response_id) - return [BidiResponseCompleteEvent(response_id=response_id, stop_reason="interrupted")] - - # Turn complete and usage - response finished - elif event_type == "response.done": - response = openai_event.get("response", {}) - response_id = response.get("id", "unknown") - status = response.get("status", "completed") - usage = response.get("usage") - - # Map OpenAI status to our stop_reason - stop_reason_map = { - "completed": "complete", - "cancelled": "interrupted", - "failed": "error", - "incomplete": "interrupted", - } - - # Build list of events to return - events: list[Any] = [] - - # Always add response complete event - events.append( - BidiResponseCompleteEvent( - response_id=response_id, - stop_reason=cast(StopReason, stop_reason_map.get(status, "complete")), - ), - ) - - # Add usage event if available - if usage: - input_details = usage.get("input_token_details", {}) - output_details = usage.get("output_token_details", {}) - - # Build modality details - modality_details = [] - - # Text modality - text_input = input_details.get("text_tokens", 0) - text_output = output_details.get("text_tokens", 0) - if text_input > 0 or text_output > 0: - modality_details.append( - {"modality": "text", "input_tokens": text_input, "output_tokens": text_output} - ) - - # Audio modality - audio_input = input_details.get("audio_tokens", 0) - audio_output = output_details.get("audio_tokens", 0) - if audio_input > 0 or audio_output > 0: - modality_details.append( - {"modality": "audio", "input_tokens": audio_input, "output_tokens": audio_output} - ) - - # Image modality - image_input = input_details.get("image_tokens", 0) - if image_input > 0: - modality_details.append({"modality": "image", "input_tokens": image_input, "output_tokens": 0}) - - # Cached tokens - cached_tokens = input_details.get("cached_tokens", 0) - - # Add usage event - events.append( - BidiUsageEvent( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - total_tokens=usage.get("total_tokens", 0), - modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, - cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, - ) - ) - - # Return list of events - return events - - # Lifecycle events (log only) - combine multiple similar events - elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: - item = openai_event.get("item", {}) - action = "retrieved" if "retrieve" in event_type else "added" - logger.debug("action=<%s>, item_id=<%s> | openai conversation item event", action, item.get("id")) - return None - - elif event_type == "conversation.item.done": - logger.debug("item_id=<%s> | openai conversation item done", openai_event.get("item", {}).get("id")) - return None - - # Response output events - combine similar events - elif event_type in [ - "response.output_item.added", - "response.output_item.done", - "response.content_part.added", - "response.content_part.done", - ]: - item_data = openai_event.get("item") or openai_event.get("part") - logger.debug( - "event_type=<%s>, item_id=<%s> | openai output event", - event_type, - item_data.get("id") if item_data else "unknown", - ) - - # Track function call names from response.output_item.added - if event_type == "response.output_item.added": - item = openai_event.get("item", {}) - if item.get("type") == "function_call": - call_id = item.get("call_id") - function_name = item.get("name") - if call_id and function_name: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = { - "call_id": call_id, - "name": function_name, - "arguments": "", - } - else: - self._function_call_buffer[call_id]["name"] = function_name - return None - - # Session/buffer events - combine simple log-only events - elif event_type in [ - "input_audio_buffer.committed", - "input_audio_buffer.cleared", - "session.created", - "session.updated", - ]: - logger.debug("event_type=<%s> | openai event received", event_type) - return None - - elif event_type == "error": - error_data = openai_event.get("error", {}) - error_code = error_data.get("code", "") - - # Suppress expected errors that don't affect session state - if error_code == "response_cancel_not_active": - # This happens when trying to cancel a response that's not active - # It's safe to ignore as the session remains functional - logger.debug("openai response cancel attempted when no response active") - return None - - # Log other errors - logger.error("error=<%s> | openai realtime error", error_data) - return None - - else: - logger.debug("event_type=<%s> | unhandled openai event type", event_type) - return None - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Unified send method for all content types. Sends the given content to OpenAI. - - Dispatches to appropriate internal handler based on content type. - - Args: - content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). - - Raises: - ValueError: If content type not supported (e.g., image content). - """ - if not self._connection_id: - raise RuntimeError("model not started | call start before sending") - - # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - else: - raise ValueError(f"content_type={type(content)} | content not supported") - - async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: - """Internal: Send audio content to OpenAI for processing.""" - # Audio is already base64 encoded in the event - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) - - async def _send_text_content(self, text: str) -> None: - """Internal: Send text content to OpenAI for processing.""" - item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def _send_interrupt(self) -> None: - """Internal: Send interruption signal to OpenAI.""" - await self._send_event({"type": "response.cancel"}) - - async def _send_tool_result(self, tool_result: ToolResult) -> None: - """Internal: Send tool result back to OpenAI.""" - tool_use_id = tool_result.get("toolUseId") - - logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) - - # Validate content types and serialize, preserving structure - result_output = "" - if "content" in tool_result: - # First validate all content types are supported - for block in tool_result["content"]: - if "text" not in block and "json" not in block: - # Unsupported content type - raise error - raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " - f"Content type not supported by OpenAI Realtime API" - ) - - # Preserve structure by JSON-dumping the entire content array - result_output = json.dumps(tool_result["content"]) - - item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_output} - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def stop(self) -> None: - """Close session and cleanup resources.""" - logger.debug("openai realtime connection cleanup starting") - - async def stop_websocket() -> None: - if not hasattr(self, "_websocket"): - return - - await self._websocket.close() - - async def stop_connection() -> None: - self._connection_id = None - - await stop_all(stop_websocket, stop_connection) - - logger.debug("openai realtime connection closed") - - async def _send_event(self, event: dict[str, Any]) -> None: - """Send event to OpenAI via WebSocket.""" - message = json.dumps(event) - await self._websocket.send(message) - logger.debug("event_type=<%s> | openai event sent", event.get("type")) diff --git a/src/strands/experimental/bidi/types/bidi_model.py b/src/strands/experimental/bidi/types/bidi_model.py deleted file mode 100644 index de41de1a9..000000000 --- a/src/strands/experimental/bidi/types/bidi_model.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Model-related type definitions for bidirectional streaming. - -Defines types and configurations that are central to model providers, -including audio configuration that models use to specify their audio -processing requirements. -""" - -from typing import TypedDict - -from .events import AudioChannel, AudioFormat, AudioSampleRate - - -class AudioConfig(TypedDict, total=False): - """Audio configuration for bidirectional streaming models. - - Defines standard audio parameters that model providers use to specify - their audio processing requirements. All fields are optional to support - models that may not use audio or only need specific parameters. - - Model providers build this configuration by merging user-provided values - with their own defaults. The resulting configuration is then used by - audio I/O implementations to configure hardware appropriately. - - Attributes: - input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) - output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) - channels: Number of audio channels (1=mono, 2=stereo) - format: Audio encoding format - voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") - """ - - input_rate: AudioSampleRate - output_rate: AudioSampleRate - channels: AudioChannel - format: AudioFormat - voice: str diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 7ea2b6345..d9905c16b 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -25,7 +25,7 @@ from ....types.streaming import ContentBlockDelta if TYPE_CHECKING: - from ..models.bidi_model import BidiModelTimeoutError + from ..models.model import BidiModelTimeoutError AudioChannel = Literal[1, 2] """Number of audio channels. diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index a880bb223..c92211816 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,7 +13,7 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 39524e434..7ec0c32a1 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -13,10 +13,10 @@ import pytest_asyncio from aws_sdk_bedrock_runtime.models import ModelTimeoutException, ValidationException -from strands.experimental.bidi.models.novasonic import ( +from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, ) -from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index 85a1cc097..5b3d627fd 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -14,8 +14,8 @@ import pytest -from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError -from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, From 0dd05fe99ecdfb9d8243686adefed6ce0b881780 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 16:09:27 -0500 Subject: [PATCH 227/242] rename files --- src/strands/experimental/bidi/models/model.py | 130 +++ .../experimental/bidi/models/nova_sonic.py | 760 ++++++++++++++++ .../bidi/models/openai_realtime.py | 816 ++++++++++++++++++ src/strands/experimental/bidi/types/model.py | 36 + 4 files changed, 1742 insertions(+) create mode 100644 src/strands/experimental/bidi/models/model.py create mode 100644 src/strands/experimental/bidi/models/nova_sonic.py create mode 100644 src/strands/experimental/bidi/models/openai_realtime.py create mode 100644 src/strands/experimental/bidi/types/model.py diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py new file mode 100644 index 000000000..0d0da63d2 --- /dev/null +++ b/src/strands/experimental/bidi/models/model.py @@ -0,0 +1,130 @@ +"""Bidirectional streaming model interface. + +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. + +Features: +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) +- Provider-agnostic event normalization +- Support for audio, text, image, and tool result streaming +""" + +import logging +from typing import Any, AsyncIterable, Protocol + +from ....types._events import ToolResultEvent +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.events import ( + BidiInputEvent, + BidiOutputEvent, +) + +logger = logging.getLogger(__name__) + + +class BidiModel(Protocol): + """Protocol for bidirectional streaming models. + + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. + + Attributes: + config: Configuration dictionary with provider-specific settings. + """ + + config: dict[str, Any] + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish a persistent streaming connection with the model. + + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. + + Args: + system_prompt: System instructions to configure model behavior. + tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. + **kwargs: Provider-specific configuration options. + """ + ... + + async def stop(self) -> None: + """Close the streaming connection and release resources. + + Terminates the active bidirectional connection and cleans up any associated + resources such as network connections, buffers, or background tasks. After + calling close(), the model instance cannot be used until start() is called again. + """ + ... + + def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive streaming events from the model. + + Continuously yields events from the model as they arrive over the connection. + Events are normalized to a provider-agnostic format for uniform processing. + This method should be called in a loop or async task to process model responses. + + The stream continues until the connection is closed or an error occurs. + + Yields: + BidiOutputEvent: Standardized event objects containing audio output, + transcripts, tool calls, or control signals. + """ + ... + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Send content to the model over the active connection. + + Transmits user input or tool results to the model during an active streaming + session. Supports multiple content types including text, audio, images, and + tool execution results. Can be called multiple times during a conversation. + + Args: + content: The content to send. Must be one of: + - BidiTextInputEvent: Text message from the user + - BidiAudioInputEvent: Audio data for speech input + - BidiImageInputEvent: Image data for visual understanding + - ToolResultEvent: Result from a tool execution + + Example: + await model.send(BidiTextInputEvent(text="Hello", role="user")) + await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) + await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) + await model.send(ToolResultEvent(tool_result)) + """ + ... + + +class BidiModelTimeoutError(Exception): + """Model timeout error. + + Bidirectional models are often configured with a connection time limit. Nova sonic for example keeps the connection + open for 8 minutes max. Upon receiving a timeout, the agent loop is configured to restart the model connection so as + to create a seamless, uninterrupted experience for the user. + """ + + def __init__(self, message: str, **restart_config: Any) -> None: + """Initialize error. + + Args: + message: Timeout message from model. + **restart_config: Configure restart specific behaviors in the call to model start. + """ + super().__init__(self, message) + + self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py new file mode 100644 index 000000000..262b37240 --- /dev/null +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -0,0 +1,760 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +Implements the BidiModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol. + +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events +""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import Any, AsyncGenerator, cast + +import boto3 +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import ( + BidirectionalInputPayloadPart, + InvokeModelWithBidirectionalStreamInputChunk, + ModelTimeoutException, + ValidationException, +) +from smithy_aws_core.identity.static import StaticCredentialsResolver +from smithy_core.aio.eventstream import DuplexEventStream +from smithy_core.shapes import ShapeID + +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all +from ..types.bidi_model import AudioConfig +from ..types.events import ( + AudioChannel, + AudioSampleRate, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, +) +from .model import BidiModel, BidiModelTimeoutError + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64", +} + +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH", +} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + + +class BidiNovaSonicModel(BidiModel): + """Nova Sonic implementation for bidirectional streaming. + + Combines model configuration and connection state in a single class. + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidiModel interface. + + Attributes: + _stream: open bedrock stream to nova sonic. + """ + + _stream: DuplexEventStream + + def __init__( + self, + model_id: str = "amazon.nova-sonic-v1:0", + provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialize Nova Sonic bidirectional model. + + Args: + model_id: Model identifier (default: amazon.nova-sonic-v1:0) + provider_config: Model behavior (audio, inference settings) + client_config: AWS authentication (boto_session OR region, not both) + **kwargs: Reserved for future parameters. + """ + # Store model ID + self.model_id = model_id + + # Resolve client config with defaults + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self.config = self._resolve_provider_config(provider_config or {}) + + # Store session and region for later use + self._session = self._client_config["boto_session"] + self.region = self._client_config["region"] + + # Track API-provided identifiers + self._connection_id: str | None = None + self._audio_content_name: str | None = None + self._current_completion_id: str | None = None + + # Indicates if model is done generating transcript + self._generation_stage: str | None = None + + # Ensure certain events are sent in sequence when required + self._send_lock = asyncio.Lock() + + logger.debug("model_id=<%s> | nova sonic model initialized", model_id) + + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve AWS client config (creates boto session if needed).""" + if "boto_session" in config and "region" in config: + raise ValueError("Cannot specify both 'boto_session' and 'region' in client_config") + + resolved = config.copy() + + # Create boto session if not provided + if "boto_session" not in resolved: + resolved["boto_session"] = boto3.Session() + + # Resolve region from session or use default + if "region" not in resolved: + resolved["region"] = resolved["boto_session"].region_name or "us-east-1" + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + # Define default audio configuration + default_audio_config: AudioConfig = { + "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), + "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), + "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), + "format": "pcm", + "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), + } + + user_audio_config = config.get("audio", {}) + merged_audio = {**default_audio_config, **user_audio_config} + + resolved = { + "audio": merged_audio, + **{k: v for k, v in config.items() if k != "audio"}, + } + + if user_audio_config: + logger.debug("audio_config | merged user-provided config with defaults") + else: + logger.debug("audio_config | using default Nova Sonic audio configuration") + + return resolved + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish bidirectional connection to Nova Sonic. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + + Raises: + RuntimeError: If user calls start again without first stopping. + """ + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") + + logger.debug("nova connection starting") + + self._connection_id = str(uuid.uuid4()) + + # Get credentials from boto3 session (full credential chain) + credentials = self._session.get_credentials() + + if not credentials: + raise ValueError( + "no AWS credentials found. configure credentials via environment variables, " + "credential files, IAM roles, or SSO." + ) + + # Use static resolver with credentials configured as properties + resolver = StaticCredentialsResolver() + + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=resolver, + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={ShapeID("aws.auth#sigv4"): SigV4AuthScheme(service="bedrock")}, + # Configure static credentials as properties + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token, + ) + + self.client = BedrockRuntimeClient(config=config) + logger.debug("region=<%s> | nova sonic client initialized", self.region) + + client = BedrockRuntimeClient(config=config) + self._stream = await client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + logger.debug("region=<%s> | nova sonic client initialized", self.region) + + init_events = self._build_initialization_events(system_prompt, tools, messages) + logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events)) + await self._send_nova_events(init_events) + + logger.info("connection_id=<%s> | nova sonic connection established", self._connection_id) + + def _build_initialization_events( + self, system_prompt: str | None, tools: list[ToolSpec] | None, messages: Messages | None + ) -> list[str]: + """Build the sequence of initialization events.""" + tools = tools or [] + events = [ + self._get_connection_start_event(), + self._get_prompt_start_event(tools), + *self._get_system_prompt_events(system_prompt), + ] + + # Add conversation history if provided + if messages: + events.extend(self._get_message_history_events(messages)) + logger.debug("message_count=<%d> | conversation history added to initialization", len(messages)) + + return events + + def _log_event_type(self, nova_event: dict[str, Any]) -> None: + """Log specific Nova Sonic event types for debugging.""" + if "usageEvent" in nova_event: + logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) + elif "textOutput" in nova_event: + logger.debug("nova text output received") + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", + tool_use["toolName"], + tool_use["toolUseId"], + ) + elif "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive Nova Sonic events and convert to provider-agnostic format. + + Raises: + RuntimeError: If start has not been called. + """ + if not self._connection_id: + raise RuntimeError("model not started | call start before receiving") + + logger.debug("nova event stream starting") + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + _, output = await self._stream.await_output() + while True: + try: + event_data = await output.receive() + + except ValidationException as error: + if "InternalErrorCode=531" in error.message: + # nova also times out if user is silent for 175 seconds + raise BidiModelTimeoutError(error.message) from error + raise + + except ModelTimeoutException as error: + raise BidiModelTimeoutError(error.message) from error + + if not event_data: + continue + + nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + self._log_event_type(nova_event) + + model_event = self._convert_nova_event(nova_event) + if model_event: + yield model_event + + async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: + """Unified send method for all content types. Sends the given content to Nova Sonic. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Input event. + + Raises: + ValueError: If content type not supported (e.g., image content). + """ + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") + + async def _start_audio_connection(self) -> None: + """Internal: Start audio input connection (call once before sending audio chunks).""" + logger.debug("nova audio connection starting") + self._audio_content_name = str(uuid.uuid4()) + + # Build audio input configuration from config + audio_input_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.config["audio"]["input_rate"], + "sampleSizeBits": 16, + "channelCount": self.config["audio"]["channels"], + "audioType": "SPEECH", + "encoding": "base64", + } + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": self._audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": audio_input_config, + } + } + } + ) + + await self._send_nova_events([audio_content_start]) + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio using Nova Sonic protocol-specific format.""" + # Start audio connection if not already active + if not self._audio_content_name: + await self._start_audio_connection() + + # Audio is already base64 encoded in the event + # Send audio input event + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self._connection_id, + "contentName": self._audio_content_name, + "content": audio_input.audio, + } + } + } + ) + + await self._send_nova_events([audio_event]) + + async def _end_audio_input(self) -> None: + """Internal: End current audio input connection to trigger Nova Sonic processing.""" + if not self._audio_content_name: + return + + logger.debug("nova audio connection ending") + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": self._connection_id, "contentName": self._audio_content_name}}} + ) + + await self._send_nova_events([audio_content_end]) + self._audio_content_name = None + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Nova Sonic format.""" + content_name = str(uuid.uuid4()) + events = [ + self._get_text_content_start_event(content_name), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name), + ] + await self._send_nova_events(events) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Nova Sonic toolResult format.""" + tool_use_id = tool_result["toolUseId"] + + logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) + + # Validate content types and preserve structure + content = tool_result.get("content", []) + + # Validate all content types are supported + for block in content: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by Nova Sonic" + ) + + # Optimize for single content item - unwrap the array + if len(content) == 1: + result_data = cast(dict[str, Any], content[0]) + else: + # Multiple items - send as array + result_data = {"content": content} + + content_name = str(uuid.uuid4()) + events = [ + self._get_tool_content_start_event(content_name, tool_use_id), + self._get_tool_result_event(content_name, result_data), + self._get_content_end_event(content_name), + ] + await self._send_nova_events(events) + + async def stop(self) -> None: + """Close Nova Sonic connection with proper cleanup sequence.""" + logger.debug("nova connection cleanup starting") + + async def stop_events() -> None: + if not self._connection_id: + return + + await self._end_audio_input() + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + await self._send_nova_events(cleanup_events) + + async def stop_stream() -> None: + if not hasattr(self, "_stream"): + return + + await self._stream.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_events, stop_stream, stop_connection) + + logger.debug("nova connection closed") + + def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | None: + """Convert Nova Sonic events to TypedEvent format.""" + # Handle completion start - track completionId + if "completionStart" in nova_event: + completion_data = nova_event["completionStart"] + self._current_completion_id = completion_data.get("completionId") + logger.debug("completion_id=<%s> | nova completion started", self._current_completion_id) + return None + + # Handle completion end + if "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + completion_id = completion_data.get("completionId", self._current_completion_id) + stop_reason = completion_data.get("stopReason", "END_TURN") + + event = BidiResponseCompleteEvent( + response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete", + ) + + # Clear completion tracking + self._current_completion_id = None + return event + + # Handle audio output + if "audioOutput" in nova_event: + # Audio is already base64 string from Nova Sonic + audio_content = nova_event["audioOutput"]["content"] + return BidiAudioStreamEvent( + audio=audio_content, + format="pcm", + sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), + ) + + # Handle text output (transcripts) + elif "textOutput" in nova_event: + text_output = nova_event["textOutput"] + text_content = text_output["content"] + # Check for Nova Sonic interruption pattern + if '{ "interrupted" : true }' in text_content: + logger.debug("nova interruption detected in text output") + return BidiInterruptionEvent(reason="user_speech") + + return BidiTranscriptStreamEvent( + delta={"text": text_content}, + text=text_content, + role=text_output["role"].lower(), + is_final=self._generation_stage == "FINAL", + current_transcript=text_content, + ) + + # Handle tool use + if "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]), + } + # Return ToolUseStreamEvent - cast to dict for type compatibility + return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) + + # Handle interruption + if nova_event.get("stopReason") == "INTERRUPTED": + logger.debug("nova interruption detected via stop reason") + return BidiInterruptionEvent(reason="user_speech") + + # Handle usage events - convert to multimodal usage format + if "usageEvent" in nova_event: + usage_data = nova_event["usageEvent"] + total_input = usage_data.get("totalInputTokens", 0) + total_output = usage_data.get("totalOutputTokens", 0) + + return BidiUsageEvent( + input_tokens=total_input, + output_tokens=total_output, + total_tokens=usage_data.get("totalTokens", total_input + total_output), + ) + + # Handle content start events (emit response start) + if "contentStart" in nova_event: + content_data = nova_event["contentStart"] + if content_data["type"] == "TEXT": + self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] + + # Emit response start event using API-provided completionId + # completionId should already be tracked from completionStart event + return BidiResponseStartEvent( + response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing + ) + + if "contentEnd" in nova_event: + self._generation_stage = None + + # Ignore all other events + return None + + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + # Build audio output configuration from config + audio_output_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.config["audio"]["output_rate"], + "sampleSizeBits": 16, + "channelCount": self.config["audio"]["channels"], + "voiceId": self.config["audio"].get("voice", "matthew"), + "encoding": "base64", + "audioType": "SPEECH", + } + + prompt_start_event: dict[str, Any] = { + "event": { + "promptStart": { + "promptName": self._connection_id, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": audio_output_config, + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict[str, Any]]: + """Build tool configuration from tool specs.""" + tool_config: list[dict[str, Any]] = [] + for tool in tools: + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} + ) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt or ""), + self._get_content_end_event(content_name), + ] + + def _get_message_history_events(self, messages: Messages) -> list[str]: + """Generate conversation history events from agent messages. + + Converts agent message history to Nova Sonic format following the + contentStart/textInput/contentEnd pattern for each message. + + Args: + messages: List of conversation messages with role and content. + + Returns: + List of JSON event strings for Nova Sonic. + """ + events = [] + + for message in messages: + role = message["role"].upper() # Convert to ASSISTANT or USER + content_blocks = message.get("content", []) + + # Extract text content from content blocks + text_parts = [] + for block in content_blocks: + if "text" in block: + text_parts.append(block["text"]) + + # Combine all text parts + if text_parts: + combined_text = "\n".join(text_parts) + content_name = str(uuid.uuid4()) + + # Add contentStart, textInput, and contentEnd events + events.extend( + [ + self._get_text_content_start_event(content_name, role), + self._get_text_input_event(content_name, combined_text), + self._get_content_end_event(content_name), + ] + ) + + return events + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } + } + } + ) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: + """Generate tool content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, + } + } + } + ) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps( + {"event": {"textInput": {"promptName": self._connection_id, "contentName": content_name, "content": text}}} + ) + + def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> str: + """Generate tool result event.""" + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self._connection_id, + "contentName": content_name, + "content": json.dumps(result), + } + } + } + ) + + def _get_content_end_event(self, content_name: str) -> str: + """Generate content end event.""" + return json.dumps({"event": {"contentEnd": {"promptName": self._connection_id, "contentName": content_name}}}) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({"event": {"promptEnd": {"promptName": self._connection_id}}}) + + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" + return json.dumps({"event": {"connectionEnd": {}}}) + + async def _send_nova_events(self, events: list[str]) -> None: + """Send event JSON string to Nova Sonic stream. + + A lock is used to send events in sequence when required (e.g., tool result start, content, and end). + + Args: + events: Jsonified events. + """ + async with self._send_lock: + for event in events: + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self._stream.input_stream.send(chunk) + logger.debug("nova sonic event sent successfully") diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py new file mode 100644 index 000000000..79ef5f78c --- /dev/null +++ b/src/strands/experimental/bidi/models/openai_realtime.py @@ -0,0 +1,816 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import json +import logging +import os +import time +import uuid +from typing import Any, AsyncGenerator, Literal, cast + +import websockets +from websockets import ClientConnection + +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all +from ..types.bidi_model import AudioConfig +from ..types.events import ( + AudioSampleRate, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, + ModalityUsage, + Role, + StopReason, +) +from .model import BidiModel, BidiModelTimeoutError + +logger = logging.getLogger(__name__) + +# Test idle_timeout_ms + +# OpenAI Realtime API configuration +OPENAI_MAX_TIMEOUT_S = 3000 # 50 minutes +"""Max timeout before closing connection. + +OpenAI documents a 60 minute limit on realtime sessions +(https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events). However, OpenAI does not +emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully +handle the connection closure. We set the max to 50 minutes to provide enough buffer before hitting the real limit. +""" +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" +DEFAULT_SAMPLE_RATE = 24000 + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, + "transcription": {"model": "gpt-4o-transcribe"}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + }, + }, + "output": {"format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, "voice": "alloy"}, + }, +} + + +class BidiOpenAIRealtimeModel(BidiModel): + """OpenAI Realtime API implementation for bidirectional streaming. + + Combines model configuration and connection state in a single class. + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + _websocket: ClientConnection + _start_time: int + + def __init__( + self, + model_id: str = DEFAULT_MODEL, + provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialize OpenAI Realtime bidirectional model. + + Args: + model_id: Model identifier (default: gpt-realtime) + provider_config: Model behavior (audio, instructions, turn_detection, etc.) + client_config: Authentication (api_key, organization, project) + Falls back to OPENAI_API_KEY, OPENAI_ORGANIZATION, OPENAI_PROJECT env vars + **kwargs: Reserved for future parameters. + + """ + # Store model ID + self.model_id = model_id + + # Resolve client config with defaults and env vars + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self.config = self._resolve_provider_config(provider_config or {}) + + # Store client config values for later use + self.api_key = self._client_config["api_key"] + self.organization = self._client_config.get("organization") + self.project = self._client_config.get("project") + self.timeout_s = self._client_config["timeout_s"] + + if self.timeout_s > OPENAI_MAX_TIMEOUT_S: + raise ValueError( + f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" + ) + + # Connection state (initialized in start()) + self._connection_id: str | None = None + + self._function_call_buffer: dict[str, Any] = {} + + logger.debug("model=<%s> | openai realtime model initialized", model_id) + + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve client config with env var fallback (config takes precedence).""" + resolved = config.copy() + + if "api_key" not in resolved: + resolved["api_key"] = os.getenv("OPENAI_API_KEY") + + if not resolved.get("api_key"): + raise ValueError( + "OpenAI API key is required. Provide via client_config={'api_key': '...'} " + "or set OPENAI_API_KEY environment variable." + ) + if "organization" not in resolved: + env_org = os.getenv("OPENAI_ORGANIZATION") + if env_org: + resolved["organization"] = env_org + + if "project" not in resolved: + env_project = os.getenv("OPENAI_PROJECT") + if env_project: + resolved["project"] = env_project + + if "timeout_s" not in resolved: + resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + # Extract voice from provider-specific audio.output.voice if present + provider_voice = None + if "audio" in config and isinstance(config["audio"], dict): + if "output" in config["audio"] and isinstance(config["audio"]["output"], dict): + provider_voice = config["audio"]["output"].get("voice") + + # Define default audio configuration + default_audio: AudioConfig = { + "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "channels": 1, + "format": "pcm", + "voice": provider_voice or "alloy", + } + + user_audio = config.get("audio", {}) + merged_audio = {**default_audio, **user_audio} + + resolved = { + "audio": merged_audio, + **{k: v for k, v in config.items() if k != "audio"}, + } + + if user_audio: + logger.debug("audio_config | merged user-provided config with defaults") + else: + logger.debug("audio_config | using default OpenAI Realtime audio configuration") + + return resolved + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish bidirectional connection to OpenAI Realtime API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") + + logger.debug("openai realtime connection starting") + + # Initialize connection state + self._connection_id = str(uuid.uuid4()) + self._start_time = int(time.time()) + + self._function_call_buffer = {} + + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model_id}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) + + self._websocket = await websockets.connect(url, additional_headers=headers) + logger.debug("connection_id=<%s> | websocket connected successfully", self._connection_id) + + # Configure session + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + # Add conversation history if provided + if messages: + await self._add_conversation_history(messages) + + def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: + """Create standardized transcript event. + + Args: + text: The transcript text + role: The role (will be normalized to lowercase) + is_final: Whether this is the final transcript + """ + # Normalize role to lowercase and ensure it's either "user" or "assistant" + normalized_role = role.lower() if isinstance(role, str) else "assistant" + if normalized_role not in ["user", "assistant"]: + normalized_role = "assistant" + + return BidiTranscriptStreamEvent( + delta={"text": text}, + text=text, + role=cast(Role, normalized_role), + is_final=is_final, + current_transcript=text if is_final else None, + ) + + def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: + """Create standardized interruption event for voice activity.""" + # Only speech_started triggers interruption + if activity_type == "speech_started": + return BidiInterruptionEvent(reason="user_speech") + # Other voice activity events are logged but don't create events + return None + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict[str, Any]: + """Build session configuration for OpenAI Realtime API.""" + config: dict[str, Any] = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + # Apply user-provided session configuration + supported_params = { + "type", + "output_modalities", + "instructions", + "voice", + "tools", + "tool_choice", + "input_audio_format", + "output_audio_format", + "input_audio_transcription", + "turn_detection", + } + + for key, value in self.config.items(): + if key == "audio": + continue + elif key in supported_params: + config[key] = value + else: + logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) + + audio_config = self.config["audio"] + + if "voice" in audio_config: + config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] + + if "input_rate" in audio_config: + config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ + "input_rate" + ] + + if "output_rate" in audio_config: + config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ + "output_rate" + ] + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = ( + json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + ) + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema, + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session. + + Converts agent message history to OpenAI Realtime API format using + conversation.item.create events for each message. + + Note: OpenAI Realtime API has a 32-character limit on call_id, so we truncate + UUIDs consistently to ensure tool calls and their results match. + + Args: + messages: List of conversation messages with role and content. + """ + # Track tool call IDs to ensure consistency between calls and results + call_id_map: dict[str, str] = {} + + # First pass: collect all tool call IDs + for message in messages: + for block in message.get("content", []): + if "toolUse" in block: + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + call_id = original_id[:32] + call_id_map[original_id] = call_id + + # Second pass: send messages + for message in messages: + role = message["role"] + content_blocks = message.get("content", []) + + # Build content array for OpenAI format + openai_content = [] + + for block in content_blocks: + if "text" in block: + # Text content - use appropriate type based on role + # User messages use "input_text", assistant messages use "output_text" + if role == "user": + openai_content.append({"type": "input_text", "text": block["text"]}) + else: # assistant + openai_content.append({"type": "output_text", "text": block["text"]}) + elif "toolUse" in block: + # Tool use - create as function_call item + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + # Use pre-mapped call_id + call_id = call_id_map[original_id] + + tool_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call", + "call_id": call_id, + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + } + await self._send_event(tool_item) + continue # Tool use is sent separately, not in message content + elif "toolResult" in block: + # Tool result - create as function_call_output item + tool_result = block["toolResult"] + original_id = tool_result["toolUseId"] + + # Validate content types and serialize, preserving structure + result_output = "" + if "content" in tool_result: + # First validate all content types are supported + for result_block in tool_result["content"]: + if "text" not in result_block and "json" not in result_block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" + ) + + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) + + # Use mapped call_id if available, otherwise skip orphaned result + if original_id not in call_id_map: + continue # Skip this tool result since we don't have the call + + call_id = call_id_map[original_id] + + result_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result_output, + }, + } + await self._send_event(result_item) + continue # Tool result is sent separately, not in message content + + # Only create message item if there's text content + if openai_content: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": role, "content": openai_content}, + } + await self._send_event(conversation_item) + + logger.debug("message_count=<%d> | conversation history added to openai session", len(messages)) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive OpenAI events and convert to Strands TypedEvent format.""" + if not self._connection_id: + raise RuntimeError("model not started | call start before sending/receiving") + + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + while True: + duration = time.time() - self._start_time + if duration >= self.timeout_s: + raise BidiModelTimeoutError(f"timeout_s=<{self.timeout_s}>") + + try: + message = await asyncio.wait_for(self._websocket.recv(), timeout=10) + except asyncio.TimeoutError: + continue + + openai_event = json.loads(message) + + for event in self._convert_openai_event(openai_event) or []: + yield event + + def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None: + """Convert OpenAI events to Strands TypedEvent format.""" + event_type = openai_event.get("type") + + # Turn start - response begins + if event_type == "response.created": + response = openai_event.get("response", {}) + response_id = response.get("id", str(uuid.uuid4())) + return [BidiResponseStartEvent(response_id=response_id)] + + # Audio output + elif event_type == "response.output_audio.delta": + # Audio is already base64 string from OpenAI + # Use the resolved output sample rate from our merged configuration + sample_rate = self.config["audio"]["output_rate"] + + # Channels from config is guaranteed to be 1 or 2 + channels = cast(Literal[1, 2], self.config["audio"]["channels"]) + return [ + BidiAudioStreamEvent( + audio=openai_event["delta"], + format="pcm", + sample_rate=sample_rate, + channels=channels, + ) + ] + + # Assistant text output events - combine multiple similar events + elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: + role = openai_event.get("role", "assistant") + return [ + self._create_text_event( + openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False + ) + ] + + elif event_type in ["response.output_audio_transcript.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["transcript"], role)] + + elif event_type in ["response.output_text.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["text"], role)] + + # User transcription events - combine multiple similar events + elif event_type in [ + "conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed", + ]: + text_key = "delta" if "delta" in event_type else "transcript" + text = openai_event.get(text_key, "") + role = openai_event.get("role", "user") + is_final = "completed" in event_type + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] + if text.strip() + else None + ) + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + role = segment_data.get("role", "user") + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] + if text.strip() + else None + ) + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("error=<%s> | openai transcription failed", error_info.get("message", "unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + # Return ToolUseStreamEvent for consistency with standard agent + return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=dict(tool_use))] + except (json.JSONDecodeError, KeyError) as e: + logger.warning("call_id=<%s>, error=<%s> | error parsing function arguments", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection - speech_started triggers interruption + elif event_type == "input_audio_buffer.speech_started": + # This is the primary interruption signal - handle it first + return [BidiInterruptionEvent(reason="user_speech")] + + # Response cancelled - handle interruption + elif event_type == "response.cancelled": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + logger.debug("response_id=<%s> | openai response cancelled", response_id) + return [BidiResponseCompleteEvent(response_id=response_id, stop_reason="interrupted")] + + # Turn complete and usage - response finished + elif event_type == "response.done": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + status = response.get("status", "completed") + usage = response.get("usage") + + # Map OpenAI status to our stop_reason + stop_reason_map = { + "completed": "complete", + "cancelled": "interrupted", + "failed": "error", + "incomplete": "interrupted", + } + + # Build list of events to return + events: list[Any] = [] + + # Always add response complete event + events.append( + BidiResponseCompleteEvent( + response_id=response_id, + stop_reason=cast(StopReason, stop_reason_map.get(status, "complete")), + ), + ) + + # Add usage event if available + if usage: + input_details = usage.get("input_token_details", {}) + output_details = usage.get("output_token_details", {}) + + # Build modality details + modality_details = [] + + # Text modality + text_input = input_details.get("text_tokens", 0) + text_output = output_details.get("text_tokens", 0) + if text_input > 0 or text_output > 0: + modality_details.append( + {"modality": "text", "input_tokens": text_input, "output_tokens": text_output} + ) + + # Audio modality + audio_input = input_details.get("audio_tokens", 0) + audio_output = output_details.get("audio_tokens", 0) + if audio_input > 0 or audio_output > 0: + modality_details.append( + {"modality": "audio", "input_tokens": audio_input, "output_tokens": audio_output} + ) + + # Image modality + image_input = input_details.get("image_tokens", 0) + if image_input > 0: + modality_details.append({"modality": "image", "input_tokens": image_input, "output_tokens": 0}) + + # Cached tokens + cached_tokens = input_details.get("cached_tokens", 0) + + # Add usage event + events.append( + BidiUsageEvent( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, + cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, + ) + ) + + # Return list of events + return events + + # Lifecycle events (log only) - combine multiple similar events + elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: + item = openai_event.get("item", {}) + action = "retrieved" if "retrieve" in event_type else "added" + logger.debug("action=<%s>, item_id=<%s> | openai conversation item event", action, item.get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("item_id=<%s> | openai conversation item done", openai_event.get("item", {}).get("id")) + return None + + # Response output events - combine similar events + elif event_type in [ + "response.output_item.added", + "response.output_item.done", + "response.content_part.added", + "response.content_part.done", + ]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug( + "event_type=<%s>, item_id=<%s> | openai output event", + event_type, + item_data.get("id") if item_data else "unknown", + ) + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = { + "call_id": call_id, + "name": function_name, + "arguments": "", + } + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + # Session/buffer events - combine simple log-only events + elif event_type in [ + "input_audio_buffer.committed", + "input_audio_buffer.cleared", + "session.created", + "session.updated", + ]: + logger.debug("event_type=<%s> | openai event received", event_type) + return None + + elif event_type == "error": + error_data = openai_event.get("error", {}) + error_code = error_data.get("code", "") + + # Suppress expected errors that don't affect session state + if error_code == "response_cancel_not_active": + # This happens when trying to cancel a response that's not active + # It's safe to ignore as the session remains functional + logger.debug("openai response cancel attempted when no response active") + return None + + # Log other errors + logger.error("error=<%s> | openai realtime error", error_data) + return None + + else: + logger.debug("event_type=<%s> | unhandled openai event type", event_type) + return None + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Unified send method for all content types. Sends the given content to OpenAI. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). + + Raises: + ValueError: If content type not supported (e.g., image content). + """ + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio content to OpenAI for processing.""" + # Audio is already base64 encoded in the event + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content to OpenAI for processing.""" + item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to OpenAI.""" + await self._send_event({"type": "response.cancel"}) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result back to OpenAI.""" + tool_use_id = tool_result.get("toolUseId") + + logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) + + # Validate content types and serialize, preserving structure + result_output = "" + if "content" in tool_result: + # First validate all content types are supported + for block in tool_result["content"]: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" + ) + + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) + + item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_output} + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def stop(self) -> None: + """Close session and cleanup resources.""" + logger.debug("openai realtime connection cleanup starting") + + async def stop_websocket() -> None: + if not hasattr(self, "_websocket"): + return + + await self._websocket.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_websocket, stop_connection) + + logger.debug("openai realtime connection closed") + + async def _send_event(self, event: dict[str, Any]) -> None: + """Send event to OpenAI via WebSocket.""" + message = json.dumps(event) + await self._websocket.send(message) + logger.debug("event_type=<%s> | openai event sent", event.get("type")) diff --git a/src/strands/experimental/bidi/types/model.py b/src/strands/experimental/bidi/types/model.py new file mode 100644 index 000000000..de41de1a9 --- /dev/null +++ b/src/strands/experimental/bidi/types/model.py @@ -0,0 +1,36 @@ +"""Model-related type definitions for bidirectional streaming. + +Defines types and configurations that are central to model providers, +including audio configuration that models use to specify their audio +processing requirements. +""" + +from typing import TypedDict + +from .events import AudioChannel, AudioFormat, AudioSampleRate + + +class AudioConfig(TypedDict, total=False): + """Audio configuration for bidirectional streaming models. + + Defines standard audio parameters that model providers use to specify + their audio processing requirements. All fields are optional to support + models that may not use audio or only need specific parameters. + + Model providers build this configuration by merging user-provided values + with their own defaults. The resulting configuration is then used by + audio I/O implementations to configure hardware appropriately. + + Attributes: + input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) + output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) + channels: Number of audio channels (1=mono, 2=stereo) + format: Audio encoding format + voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") + """ + + input_rate: AudioSampleRate + output_rate: AudioSampleRate + channels: AudioChannel + format: AudioFormat + voice: str From 78b3ebc4ba16c046fed0af4d185127ecb6ea693c Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sun, 30 Nov 2025 16:14:57 -0500 Subject: [PATCH 228/242] fix formatting on docstrings (#98) --- src/strands/experimental/bidi/agent/agent.py | 2 ++ src/strands/experimental/bidi/io/audio.py | 2 ++ src/strands/experimental/bidi/io/text.py | 1 + src/strands/experimental/bidi/models/bidi_model.py | 4 ++++ src/strands/experimental/bidi/models/gemini_live.py | 2 ++ src/strands/experimental/bidi/models/novasonic.py | 1 + src/strands/experimental/bidi/models/openai.py | 4 ++-- src/strands/experimental/bidi/types/events.py | 6 +++++- src/strands/experimental/hooks/__init__.py | 5 +---- src/strands/experimental/hooks/events.py | 11 ++++------- tests/strands/experimental/bidi/models/test_openai.py | 8 ++++---- tests/strands/experimental/hooks/test_hook_aliases.py | 2 +- 12 files changed, 29 insertions(+), 19 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 74b65ba10..2bfbdb3fa 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -6,6 +6,7 @@ continuous responses including audio output. Key capabilities: + - Persistent conversation connections with concurrent processing - Real-time audio input/output streaming - Automatic interruption detection and tool execution @@ -233,6 +234,7 @@ async def send(self, input_data: BidiAgentInput | dict[str, Any]) -> None: Args: input_data: Can be: + - str: Text message from user - BidiInputEvent: TypedEvent - dict: Event dictionary (will be reconstructed to TypedEvent) diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py index b5404b749..5eff829e9 100644 --- a/src/strands/experimental/bidi/io/audio.py +++ b/src/strands/experimental/bidi/io/audio.py @@ -73,6 +73,7 @@ def get(self, byte_count: int | None = None) -> bytes: Args: byte_count: Number of bytes to get from buffer. + - If the number of bytes specified is not available, the return is padded with silence. - If the number of bytes is not specified, get the first chunk put in the buffer. @@ -274,6 +275,7 @@ def __init__(self, **config: Any) -> None: Args: **config: Optional device configuration: + - input_buffer_size (int): Maximum input buffer size (default: None) - input_device_index (int): Specific input device (default: None = system default) - input_frames_per_buffer (int): Input buffer size (default: 512) diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index 1fe906de0..f575c5606 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -73,6 +73,7 @@ def __init__(self, **config: Any) -> None: Args: **config: Optional I/O configurations. + - input_prompt (str): Input prompt to display on screen (default: blank) """ self._config = config diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index 0d0da63d2..f5e34aa50 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -6,6 +6,7 @@ text, and tool interactions. Features: + - Persistent connection management with connect/close lifecycle - Real-time bidirectional communication (send and receive simultaneously) - Provider-agnostic event normalization @@ -96,16 +97,19 @@ async def send( Args: content: The content to send. Must be one of: + - BidiTextInputEvent: Text message from the user - BidiAudioInputEvent: Audio data for speech input - BidiImageInputEvent: Image data for visual understanding - ToolResultEvent: Result from a tool execution Example: + ``` await model.send(BidiTextInputEvent(text="Hello", role="user")) await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) await model.send(ToolResultEvent(tool_result)) + ``` """ ... diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 1f2b2d5cd..b8daff291 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -4,6 +4,7 @@ official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: + - Uses official google-genai SDK with native Live API support - Simplified session management with client.aio.live.connect() - Built-in tool integration and event handling @@ -221,6 +222,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut """Convert Gemini Live API events to provider-agnostic format. Handles different types of content: + - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 713afe028..968c42358 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -5,6 +5,7 @@ InvokeModelWithBidirectionalStream protocol. Nova Sonic specifics: + - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding - Tool execution with content containers and identifier tracking diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index bfe3ad533..af38ef706 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -48,8 +48,8 @@ """Max timeout before closing connection. OpenAI documents a 60 minute limit on realtime sessions -(https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events). However, OpenAI does not -emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully +([docs](https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events)). However, OpenAI does +not emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully handle the connection closure. We set the max to 50 minutes to provide enough buffer before hitting the real limit. """ OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 7ea2b6345..572ab56db 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -4,6 +4,7 @@ capabilities with real-time audio and persistent connection support. Key features: + - Audio input/output events with standardized formats - Interruption detection and handling - Connection lifecycle management @@ -12,6 +13,7 @@ - JSON-serializable events (audio/images stored as base64 strings) Audio format normalization: + - Supports PCM, WAV, Opus, and MP3 formats - Standardizes sample rates (16kHz, 24kHz, 48kHz) - Normalizes channel configurations (mono/stereo) @@ -29,6 +31,7 @@ AudioChannel = Literal[1, 2] """Number of audio channels. + - Mono: 1 - Stereo: 2 """ @@ -362,7 +365,6 @@ class BidiInterruptionEvent(TypedEvent): Parameters: reason: Why the interruption occurred. - response_id: ID of the response that was interrupted (may be None). """ def __init__(self, reason: Literal["user_speech", "error"]): @@ -592,6 +594,7 @@ def details(self) -> dict[str, Any] | None: # BidiInputEvent in send() methods for sending tool results back to the model. BidiInputEvent = BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent +"""Union of different bidi input event types.""" BidiOutputEvent = ( BidiConnectionStartEvent @@ -606,3 +609,4 @@ def details(self) -> dict[str, Any] | None: | BidiErrorEvent | ToolUseStreamEvent ) +"""Union of different bidi output event types.""" diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 7c3c2b269..c76b57ea4 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -1,7 +1,4 @@ -"""Experimental hook functionality that has not yet reached stability. - -BidiAgent hooks are also available here to avoid circular imports. -""" +"""Experimental hook functionality that has not yet reached stability.""" from .events import ( AfterModelInvocationEvent, diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index f486f5ec4..8a8d80629 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -1,8 +1,6 @@ -"""Experimental hook events emitted as part of invoking Agents. +"""Experimental hook events emitted as part of invoking Agents and BidiAgents. -This module defines the events that are emitted as Agents run through the lifecycle of a request. - -BidiAgent hook events are also defined here to avoid circular imports. +This module defines the events that are emitted as Agents and BidiAgents run through the lifecycle of a request. """ import warnings @@ -19,8 +17,8 @@ from ..bidi.models import BidiModelTimeoutError warnings.warn( - "These events have been moved to production with updated names. Use BeforeModelCallEvent, " - "AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent from strands.hooks instead.", + "BeforeModelCallEvent, AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent are no longer experimental." + "Import from strands.hooks instead.", DeprecationWarning, stacklevel=2, ) @@ -32,7 +30,6 @@ # BidiAgent Hook Events -# These are defined here to avoid circular imports with the bidi package @dataclass diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index 85a1cc097..04381810e 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -816,7 +816,7 @@ async def test_tool_result_single_text_content(mock_websockets_connect, api_key) async def test_tool_result_single_json_content(mock_websockets_connect, api_key): """Test tool result with single JSON content block.""" _, mock_ws = mock_websockets_connect - model = BidiOpenAIRealtimeModel(api_key=api_key) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) await model.start() tool_result: ToolResult = { @@ -846,7 +846,7 @@ async def test_tool_result_single_json_content(mock_websockets_connect, api_key) async def test_tool_result_multiple_content_blocks(mock_websockets_connect, api_key): """Test tool result with multiple content blocks (text and json).""" _, mock_ws = mock_websockets_connect - model = BidiOpenAIRealtimeModel(api_key=api_key) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) await model.start() tool_result: ToolResult = { @@ -884,7 +884,7 @@ async def test_tool_result_multiple_content_blocks(mock_websockets_connect, api_ async def test_tool_result_image_content_raises_error(mock_websockets_connect, api_key): """Test that tool result with image content raises ValueError.""" _, mock_ws = mock_websockets_connect - model = BidiOpenAIRealtimeModel(api_key=api_key) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) await model.start() tool_result: ToolResult = { @@ -903,7 +903,7 @@ async def test_tool_result_image_content_raises_error(mock_websockets_connect, a async def test_tool_result_document_content_raises_error(mock_websockets_connect, api_key): """Test that tool result with document content raises ValueError.""" _, mock_ws = mock_websockets_connect - model = BidiOpenAIRealtimeModel(api_key=api_key) + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) await model.start() tool_result: ToolResult = { diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index 6744aa00c..f4899f2ab 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -123,7 +123,7 @@ def test_deprecation_warning_on_import(captured_warnings): assert len(captured_warnings) == 1 assert issubclass(captured_warnings[0].category, DeprecationWarning) - assert "moved to production with updated names" in str(captured_warnings[0].message) + assert "are no longer experimental" in str(captured_warnings[0].message) def test_deprecation_warning_on_import_only_for_experimental(captured_warnings): From e472b929e98a2a3cd5dca9cb3e77cb144b29a8fd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:06:10 -0500 Subject: [PATCH 229/242] addressed comments --- src/strands/experimental/bidi/__init__.py | 6 +++--- src/strands/experimental/bidi/models/__init__.py | 2 +- src/strands/experimental/bidi/models/gemini_live.py | 4 +--- src/strands/experimental/bidi/models/nova_sonic.py | 3 +-- src/strands/experimental/bidi/models/openai_realtime.py | 2 +- .../bidi/models/{test_novasonic.py => test_nova_sonic.py} | 0 .../bidi/models/{test_openai.py => test_openai_realtime.py} | 0 7 files changed, 7 insertions(+), 10 deletions(-) rename tests/strands/experimental/bidi/models/{test_novasonic.py => test_nova_sonic.py} (100%) rename tests/strands/experimental/bidi/models/{test_openai.py => test_openai_realtime.py} (100%) diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 13c5b51e1..d274bfbcb 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -17,11 +17,11 @@ # IO channels - Hardware abstraction from .io.audio import BidiAudioIO -# Model interface (for custom implementations) -from .models.model import BidiModel - # Model providers - What users need to create models from .models.gemini_live import BidiGeminiLiveModel + +# Model interface (for custom implementations) +from .models.model import BidiModel from .models.nova_sonic import BidiNovaSonicModel from .models.openai_realtime import BidiOpenAIRealtimeModel diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index 29a2229c5..b56208c1e 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,7 +1,7 @@ """Bidirectional model interfaces and implementations.""" -from .model import BidiModel, BidiModelTimeoutError from .gemini_live import BidiGeminiLiveModel +from .model import BidiModel, BidiModelTimeoutError from .nova_sonic import BidiNovaSonicModel from .openai_realtime import BidiOpenAIRealtimeModel diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 4f7a9db44..a267211d1 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -4,7 +4,6 @@ official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: - - Uses official google-genai SDK with native Live API support - Simplified session management with client.aio.live.connect() - Built-in tool integration and event handling @@ -25,7 +24,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.bidi_model import AudioConfig +from ..types.model import AudioConfig from ..types.events import ( AudioChannel, AudioSampleRate, @@ -222,7 +221,6 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut """Convert Gemini Live API events to provider-agnostic format. Handles different types of content: - - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 04037a90f..9ccc3d58f 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -5,7 +5,6 @@ InvokeModelWithBidirectionalStream protocol. Nova Sonic specifics: - - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding - Tool execution with content containers and identifier tracking @@ -37,7 +36,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.bidi_model import AudioConfig +from ..types.model import AudioConfig from ..types.events import ( AudioChannel, AudioSampleRate, diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py index 9a4584365..39312c7d3 100644 --- a/src/strands/experimental/bidi/models/openai_realtime.py +++ b/src/strands/experimental/bidi/models/openai_realtime.py @@ -19,7 +19,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.bidi_model import AudioConfig +from ..types.model import AudioConfig from ..types.events import ( AudioSampleRate, BidiAudioInputEvent, diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py similarity index 100% rename from tests/strands/experimental/bidi/models/test_novasonic.py rename to tests/strands/experimental/bidi/models/test_nova_sonic.py diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py similarity index 100% rename from tests/strands/experimental/bidi/models/test_openai.py rename to tests/strands/experimental/bidi/models/test_openai_realtime.py From 75a77751a2458e75afbb9f52b765c77698a022d2 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:23:23 -0500 Subject: [PATCH 230/242] addrsssed comments --- pyproject.toml | 2 ++ scripts/bidi/test_bidi_openai.py | 2 +- src/strands/experimental/bidi/__init__.py | 6 ------ src/strands/experimental/bidi/models/__init__.py | 4 ---- src/strands/experimental/bidi/models/gemini_live.py | 1 + src/strands/experimental/bidi/models/nova_sonic.py | 1 + 6 files changed, 5 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2a8b250fe..944a1b3a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,8 @@ bidi = [ "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", "smithy-aws-core>=0.0.1; python_version>='3.12'", + "google-genai>=1.32.0,<2.0.0", + "websockets>=15.0.0,<16.0.0", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<16.0.0"] diff --git a/scripts/bidi/test_bidi_openai.py b/scripts/bidi/test_bidi_openai.py index 50d2d2f55..677c12981 100644 --- a/scripts/bidi/test_bidi_openai.py +++ b/scripts/bidi/test_bidi_openai.py @@ -10,7 +10,7 @@ from strands_tools import calculator from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel async def play(context): diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index d274bfbcb..57986062e 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -17,13 +17,9 @@ # IO channels - Hardware abstraction from .io.audio import BidiAudioIO -# Model providers - What users need to create models -from .models.gemini_live import BidiGeminiLiveModel - # Model interface (for custom implementations) from .models.model import BidiModel from .models.nova_sonic import BidiNovaSonicModel -from .models.openai_realtime import BidiOpenAIRealtimeModel # Built-in tools from .tools import stop_conversation @@ -53,9 +49,7 @@ # IO channels "BidiAudioIO", # Model providers - "BidiGeminiLiveModel", "BidiNovaSonicModel", - "BidiOpenAIRealtimeModel", # Built-in tools "stop_conversation", # Input Event types diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index b56208c1e..cc62c9987 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,14 +1,10 @@ """Bidirectional model interfaces and implementations.""" -from .gemini_live import BidiGeminiLiveModel from .model import BidiModel, BidiModelTimeoutError from .nova_sonic import BidiNovaSonicModel -from .openai_realtime import BidiOpenAIRealtimeModel __all__ = [ "BidiModel", "BidiModelTimeoutError", - "BidiGeminiLiveModel", "BidiNovaSonicModel", - "BidiOpenAIRealtimeModel", ] diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index a267211d1..ca69b9453 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -4,6 +4,7 @@ official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: + - Uses official google-genai SDK with native Live API support - Simplified session management with client.aio.live.connect() - Built-in tool integration and event handling diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 9ccc3d58f..0cfa51181 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -5,6 +5,7 @@ InvokeModelWithBidirectionalStream protocol. Nova Sonic specifics: + - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding - Tool execution with content containers and identifier tracking From 69d8f09134045f16e4088ca2d66d77a26e248bb0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:30:09 -0500 Subject: [PATCH 231/242] address comments --- pyproject.toml | 2 -- src/strands/experimental/bidi/models/gemini_live.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 944a1b3a5..2a8b250fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,8 +75,6 @@ bidi = [ "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", "smithy-aws-core>=0.0.1; python_version>='3.12'", - "google-genai>=1.32.0,<2.0.0", - "websockets>=15.0.0,<16.0.0", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<16.0.0"] diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index ca69b9453..dc3810520 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -222,6 +222,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut """Convert Gemini Live API events to provider-agnostic format. Handles different types of content: + - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model From 98294945ac3bbab041e459fb24c861d2eda9070c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:38:12 -0500 Subject: [PATCH 232/242] minor update --- src/strands/experimental/bidi/models/gemini_live.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index dc3810520..3af8d707f 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -222,7 +222,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut """Convert Gemini Live API events to provider-agnostic format. Handles different types of content: - + - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model From 4c58c43bcf7e2b694c388ed845e6074313401cb8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 19:11:44 -0500 Subject: [PATCH 233/242] minor update --- tests_integ/bidi/test_bidirectional_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index ebc92c852..61cf78723 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -15,8 +15,8 @@ from strands import tool from strands.experimental.bidi.agent.agent import BidiAgent from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel -from strands.experimental.bidi.models.novasonic import BidiNovaSonicModel -from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel +from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel from .context import BidirectionalTestContext from .hook_utils import HookEventCollector From a46828de24bca605ac0e348a9c15f19ae05aa2b5 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 1 Dec 2025 20:35:41 -0500 Subject: [PATCH 234/242] isolate model inference configs (#100) --- .../experimental/bidi/models/gemini_live.py | 42 ++++++------------ .../experimental/bidi/models/nova_sonic.py | 31 ++++++------- .../bidi/models/openai_realtime.py | 43 +++++-------------- .../bidi/models/test_gemini_live.py | 14 +++--- .../bidi/models/test_nova_sonic.py | 2 +- .../bidi/models/test_openai_realtime.py | 4 +- 6 files changed, 46 insertions(+), 90 deletions(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 3af8d707f..88d7f5a0c 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -25,7 +25,6 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.model import AudioConfig from ..types.events import ( AudioChannel, AudioSampleRate, @@ -41,6 +40,7 @@ BidiUsageEvent, ModalityUsage, ) +from ..types.model import AudioConfig from .model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) @@ -70,7 +70,7 @@ def __init__( Args: model_id: Model identifier (default: gemini-2.5-flash-native-audio-preview-09-2025) - provider_config: Model behavior (audio, response_modalities, speech_config, transcription) + provider_config: Model behavior (audio, inference) client_config: Authentication (api_key, http_options) **kwargs: Reserved for future parameters. @@ -108,44 +108,28 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: """Merge user config with defaults (user takes precedence).""" - # Extract voice from provider-specific speech_config.voice_config.prebuilt_voice_config.voice_name if present - provider_voice = None - if "speech_config" in config and isinstance(config["speech_config"], dict): - provider_voice = ( - config["speech_config"].get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") - ) - - # Define default audio configuration default_audio: AudioConfig = { "input_rate": GEMINI_INPUT_SAMPLE_RATE, "output_rate": GEMINI_OUTPUT_SAMPLE_RATE, "channels": GEMINI_CHANNELS, "format": "pcm", } - - if provider_voice: - default_audio["voice"] = provider_voice - - user_audio = config.get("audio", {}) - merged_audio = {**default_audio, **user_audio} - - default_provider_settings = { + default_inference = { "response_modalities": ["AUDIO"], "outputAudioTranscription": {}, "inputAudioTranscription": {}, } resolved = { - **default_provider_settings, - **config, - "audio": merged_audio, # Audio always uses merged version + "audio": { + **default_audio, + **config.get("audio", {}), + }, + "inference": { + **default_inference, + **config.get("inference", {}), + }, } - - if user_audio: - logger.debug("audio_config | merged user-provided config with defaults") - else: - logger.debug("audio_config | using default Gemini Live audio configuration") - return resolved async def start( @@ -505,9 +489,7 @@ def _build_live_config( Simply passes through all config parameters from provider_config, allowing users to configure any Gemini Live API parameter directly. """ - config_dict: dict[str, Any] = {} - if self.config: - config_dict.update({k: v for k, v in self.config.items() if k != "audio"}) + config_dict: dict[str, Any] = self.config["inference"].copy() config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")} diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 0cfa51181..6a2477e22 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -37,7 +37,6 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.model import AudioConfig from ..types.events import ( AudioChannel, AudioSampleRate, @@ -53,12 +52,16 @@ BidiTranscriptStreamEvent, BidiUsageEvent, ) +from ..types.model import AudioConfig from .model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) -# Nova Sonic configuration constants -NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} +_NOVA_INFERENCE_CONFIG_KEYS = { + "max_tokens": "maxTokens", + "temperature": "temperature", + "top_p": "topP", +} NOVA_AUDIO_INPUT_CONFIG = { "mediaType": "audio/lpcm", @@ -156,8 +159,7 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: """Merge user config with defaults (user takes precedence).""" - # Define default audio configuration - default_audio_config: AudioConfig = { + default_audio: AudioConfig = { "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), @@ -165,19 +167,13 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), } - user_audio_config = config.get("audio", {}) - merged_audio = {**default_audio_config, **user_audio_config} - resolved = { - "audio": merged_audio, - **{k: v for k, v in config.items() if k != "audio"}, + "audio": { + **default_audio, + **config.get("audio", {}), + }, + "inference": config.get("inference", {}), } - - if user_audio_config: - logger.debug("audio_config | merged user-provided config with defaults") - else: - logger.debug("audio_config | using default Nova Sonic audio configuration") - return resolved async def start( @@ -577,7 +573,8 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" - return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()} + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}}) def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py index 39312c7d3..9196a39d5 100644 --- a/src/strands/experimental/bidi/models/openai_realtime.py +++ b/src/strands/experimental/bidi/models/openai_realtime.py @@ -19,7 +19,6 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.model import AudioConfig from ..types.events import ( AudioSampleRate, BidiAudioInputEvent, @@ -37,6 +36,7 @@ Role, StopReason, ) +from ..types.model import AudioConfig from .model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) @@ -160,34 +160,21 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: """Merge user config with defaults (user takes precedence).""" - # Extract voice from provider-specific audio.output.voice if present - provider_voice = None - if "audio" in config and isinstance(config["audio"], dict): - if "output" in config["audio"] and isinstance(config["audio"]["output"], dict): - provider_voice = config["audio"]["output"].get("voice") - - # Define default audio configuration default_audio: AudioConfig = { "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), "channels": 1, "format": "pcm", - "voice": provider_voice or "alloy", + "voice": "alloy", } - user_audio = config.get("audio", {}) - merged_audio = {**default_audio, **user_audio} - resolved = { - "audio": merged_audio, - **{k: v for k, v in config.items() if k != "audio"}, + "audio": { + **default_audio, + **config.get("audio", {}), + }, + "inference": config.get("inference", {}), } - - if user_audio: - logger.debug("audio_config | merged user-provided config with defaults") - else: - logger.debug("audio_config | using default OpenAI Realtime audio configuration") - return resolved async def start( @@ -277,22 +264,12 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] # Apply user-provided session configuration supported_params = { - "type", + "max_output_tokens", "output_modalities", - "instructions", - "voice", - "tools", "tool_choice", - "input_audio_format", - "output_audio_format", - "input_audio_transcription", - "turn_detection", } - - for key, value in self.config.items(): - if key == "audio": - continue - elif key in supported_params: + for key, value in self.config["inference"].items(): + if key in supported_params: config[key] = value else: logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index c92211816..da516d4a0 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -98,9 +98,9 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert model_default.api_key is None assert model_default._live_session is None # Check default config includes transcription - assert model_default.config["response_modalities"] == ["AUDIO"] - assert "outputAudioTranscription" in model_default.config - assert "inputAudioTranscription" in model_default.config + assert model_default.config["inference"]["response_modalities"] == ["AUDIO"] + assert "outputAudioTranscription" in model_default.config["inference"] + assert "inputAudioTranscription" in model_default.config["inference"] # Test with API key model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) @@ -108,13 +108,13 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert model_with_key.api_key == api_key # Test with custom config (merges with defaults) - provider_config = {"temperature": 0.7, "top_p": 0.9} + provider_config = {"inference": {"temperature": 0.7, "top_p": 0.9}} model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config) # Custom config should be merged with defaults - assert model_custom.config["temperature"] == 0.7 - assert model_custom.config["top_p"] == 0.9 + assert model_custom.config["inference"]["temperature"] == 0.7 + assert model_custom.config["inference"]["top_p"] == 0.9 # Defaults should still be present - assert "response_modalities" in model_custom.config + assert "response_modalities" in model_custom.config["inference"] # Connection Tests diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 7ec0c32a1..04f8043be 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -58,7 +58,7 @@ def mock_stream(): @pytest.fixture def mock_client(mock_stream): """Mock Bedrock Runtime client.""" - with patch("strands.experimental.bidi.models.novasonic.BedrockRuntimeClient") as mock_cls: + with patch("strands.experimental.bidi.models.nova_sonic.BedrockRuntimeClient") as mock_cls: mock_instance = AsyncMock() mock_instance.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) mock_cls.return_value = mock_instance diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 805144446..5c9c0900d 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -46,7 +46,7 @@ def mock_websockets_connect(mock_websocket): async def async_connect(*args, **kwargs): return mock_websocket - with unittest.mock.patch("strands.experimental.bidi.models.openai.websockets.connect") as mock_connect: + with unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.websockets.connect") as mock_connect: mock_connect.side_effect = async_connect yield mock_connect, mock_websocket @@ -515,7 +515,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): assert tru_events == exp_events -@unittest.mock.patch("strands.experimental.bidi.models.openai.time.time") +@unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.time.time") @pytest.mark.asyncio async def test_receive_timeout(mock_time, model): mock_time.side_effect = [1, 2] From 78423eb5843a7211c7080569a893ea06c289b06a Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 1 Dec 2025 21:32:46 -0500 Subject: [PATCH 235/242] fix agent send dict to event construction (#101) --- src/strands/experimental/bidi/agent/agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 84e0e1e4f..360dfe707 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -261,6 +261,7 @@ async def send(self, input_data: BidiAgentInput | dict[str, Any]) -> None: elif isinstance(input_data, dict) and "type" in input_data: input_type = input_data["type"] + input_data = {key: value for key, value in input_data.items() if key != "type"} if input_type == "bidi_text_input": input_event = BidiTextInputEvent(**input_data) elif input_type == "bidi_audio_input": From 75b91aa13e10ba19c9f4d0f2fc71b25c118bbea0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 2 Dec 2025 10:49:15 -0500 Subject: [PATCH 236/242] add bidi to README --- README.md | 78 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/README.md b/README.md index 3ff0ec2e4..b3887de10 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,84 @@ Built-in providers: Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) +### Bidirectional Streaming + +> **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. + +Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. + +**Key Features:** +- Real-time audio input/output streaming +- Automatic interruption detection +- Concurrent tool execution during conversations +- Support for text, audio, and image inputs +- Provider-agnostic event system + +**Supported Model Providers:** +- Amazon Nova Sonic (`amazon.nova-sonic-v1:0`) +- Google Gemini Live (`gemini-2.5-flash-native-audio-preview-09-2025`) +- OpenAI Realtime API (`gpt-realtime`) + +**Quick Example:** + +```python +from strands.experimental.bidi import BidiAgent +from strands.experimental.bidi.models import BidiNovaSonicModel +from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO +from strands_tools import calculator + +# Create bidirectional agent with audio model +model = BidiNovaSonicModel() +agent = BidiAgent(model=model, tools=[calculator]) + +# Setup audio and text I/O +audio_io = BidiAudioIO() +text_io = BidiTextIO() + +# Run with real-time audio streaming +await agent.run( + inputs=[audio_io.input()], + outputs=[audio_io.output(), text_io.output()] +) +``` + +**Configuration Options:** + +```python +# Configure audio settings +model = BidiNovaSonicModel( + provider_config={ + "audio": { + "input_rate": 16000, + "output_rate": 16000, + "voice": "matthew" + }, + "inference": { + "max_tokens": 2048, + "temperature": 0.7 + } + } +) + +# Configure I/O devices +audio_io = BidiAudioIO( + input_device_index=0, # Specific microphone + output_device_index=1, # Specific speaker + input_buffer_size=10, + output_buffer_size=10 +) +``` + +**Event Types:** + +The bidirectional streaming system uses a rich event model: + +- **Input Events**: `BidiTextInputEvent`, `BidiAudioInputEvent`, `BidiImageInputEvent` +- **Output Events**: `BidiAudioStreamEvent`, `BidiTranscriptStreamEvent`, `BidiInterruptionEvent`, `BidiUsageEvent`, `ToolUseStreamEvent` +- **Lifecycle Events**: `BidiConnectionStartEvent`, `BidiResponseStartEvent`, `BidiResponseCompleteEvent`, `BidiConnectionCloseEvent` + +All events are strongly typed and JSON-serializable for easy integration with web applications and logging systems. + ### Example tools Strands offers an optional strands-agents-tools package with pre-built tools for quick experimentation: From 185db15972c6998760f0b7de8fac687e939bf398 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 2 Dec 2025 11:37:27 -0500 Subject: [PATCH 237/242] address comments --- README.md | 67 ++++++++++++++++++++++++++----------------------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index b3887de10..5038b8661 100644 --- a/README.md +++ b/README.md @@ -184,11 +184,22 @@ Built-in providers: Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) -### Bidirectional Streaming +### Example tools -> **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. +Strands offers an optional strands-agents-tools package with pre-built tools for quick experimentation: + +```python +from strands import Agent +from strands_tools import calculator +agent = Agent(tools=[calculator]) +agent("What is the square root of 1764") +``` + +It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools). -Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. +### [Bidirectional Streaming](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) + +> **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. **Key Features:** - Real-time audio input/output streaming @@ -205,24 +216,31 @@ Build real-time voice and audio conversations with persistent streaming connecti **Quick Example:** ```python +import asyncio from strands.experimental.bidi import BidiAgent from strands.experimental.bidi.models import BidiNovaSonicModel from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO +from strands.experimental.bidi.tools import stop_conversation from strands_tools import calculator -# Create bidirectional agent with audio model -model = BidiNovaSonicModel() -agent = BidiAgent(model=model, tools=[calculator]) +async def main(): + # Create bidirectional agent with audio model + model = BidiNovaSonicModel() + agent = BidiAgent(model=model, tools=[calculator, stop_conversation]) -# Setup audio and text I/O -audio_io = BidiAudioIO() -text_io = BidiTextIO() + # Setup audio and text I/O + audio_io = BidiAudioIO() + text_io = BidiTextIO() -# Run with real-time audio streaming -await agent.run( - inputs=[audio_io.input()], - outputs=[audio_io.output(), text_io.output()] -) + # Run with real-time audio streaming + # Say "stop conversation" to gracefully end the conversation + await agent.run( + inputs=[audio_io.input()], + outputs=[audio_io.output(), text_io.output()] + ) + +if __name__ == "__main__": + asyncio.run(main()) ``` **Configuration Options:** @@ -252,29 +270,8 @@ audio_io = BidiAudioIO( ) ``` -**Event Types:** - -The bidirectional streaming system uses a rich event model: - -- **Input Events**: `BidiTextInputEvent`, `BidiAudioInputEvent`, `BidiImageInputEvent` -- **Output Events**: `BidiAudioStreamEvent`, `BidiTranscriptStreamEvent`, `BidiInterruptionEvent`, `BidiUsageEvent`, `ToolUseStreamEvent` -- **Lifecycle Events**: `BidiConnectionStartEvent`, `BidiResponseStartEvent`, `BidiResponseCompleteEvent`, `BidiConnectionCloseEvent` - All events are strongly typed and JSON-serializable for easy integration with web applications and logging systems. -### Example tools - -Strands offers an optional strands-agents-tools package with pre-built tools for quick experimentation: - -```python -from strands import Agent -from strands_tools import calculator -agent = Agent(tools=[calculator]) -agent("What is the square root of 1764") -``` - -It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools). - ## Documentation For detailed guidance & examples, explore our documentation: From f9f0e2d223570c1c82c2f7b925a90fbb0448137f Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 2 Dec 2025 11:38:18 -0500 Subject: [PATCH 238/242] address comments --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 5038b8661..d241a5b54 100644 --- a/README.md +++ b/README.md @@ -270,8 +270,6 @@ audio_io = BidiAudioIO( ) ``` -All events are strongly typed and JSON-serializable for easy integration with web applications and logging systems. - ## Documentation For detailed guidance & examples, explore our documentation: From 6cbac512c1384dc002bd3edcb7a028cb1eebf580 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 2 Dec 2025 11:41:41 -0500 Subject: [PATCH 239/242] address comments --- README.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d241a5b54..7af8af333 100644 --- a/README.md +++ b/README.md @@ -197,16 +197,11 @@ agent("What is the square root of 1764") It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools). -### [Bidirectional Streaming](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) +### Bidirectional Streaming > **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. -**Key Features:** -- Real-time audio input/output streaming -- Automatic interruption detection -- Concurrent tool execution during conversations -- Support for text, audio, and image inputs -- Provider-agnostic event system +Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart]((https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart)) guide. **Supported Model Providers:** - Amazon Nova Sonic (`amazon.nova-sonic-v1:0`) From 873da65a6e9c80666750de0780b48a7eaf3252cb Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 2 Dec 2025 11:44:56 -0500 Subject: [PATCH 240/242] address comments --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7af8af333..e7d1b2a7e 100644 --- a/README.md +++ b/README.md @@ -201,7 +201,7 @@ It's also available on GitHub via [strands-agents/tools](https://github.com/stra > **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. -Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart]((https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart)) guide. +Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) guide. **Supported Model Providers:** - Amazon Nova Sonic (`amazon.nova-sonic-v1:0`) From 4ce00d8c9266df533f6041dd0dcd849a9f568767 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 2 Dec 2025 12:33:04 -0500 Subject: [PATCH 241/242] fix minor linting and integ test failure errors --- pyproject.toml | 2 +- src/strands/types/_events.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2a8b250fe..f5738a68b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,7 +235,7 @@ convention = "google" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" -addopts = "--ignore=tests/strands/experimental/bidi" +addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi" [tool.coverage.run] diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1337237ba..efe0894ea 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -286,7 +286,7 @@ def __init__(self, tool_result: ToolResult) -> None: @property def tool_use_id(self) -> str: """The toolUseId associated with this result.""" - return cast(str, cast(ToolResult, self.get("tool_result")).get("toolUseId")) + return cast(ToolResult, self.get("tool_result")).get("toolUseId") @property def tool_result(self) -> ToolResult: @@ -314,7 +314,7 @@ def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: @property def tool_use_id(self) -> str: """The toolUseId associated with this stream.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId")) + return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId") class ToolCancelEvent(TypedEvent): @@ -332,7 +332,7 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId")) + return cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId") @property def message(self) -> str: @@ -350,7 +350,7 @@ def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None: @property def tool_use_id(self) -> str: """The id of the tool interrupted.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId")) + return cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId") @property def interrupts(self) -> list[Interrupt]: From 2498674b38d3f7dae6b81039d5591bfdd98d1643 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 3 Dec 2025 09:59:44 -0500 Subject: [PATCH 242/242] pyproject - static analysis - python 3.10 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f5738a68b..2c2a6b260 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ dependencies = [ # Include required package dependencies for mypy "strands-agents @ {root:uri}", ] +python = "3.10" # Define static-analysis scripts so we can include mypy as part of the linting check [tool.hatch.envs.hatch-static-analysis.scripts]