From 21065d111737b99b601af5955d98b1d74b52ec51 Mon Sep 17 00:00:00 2001 From: Julian Grueber Date: Sat, 11 Oct 2025 14:15:21 +0200 Subject: [PATCH 1/2] feat: copy agent with new session manager --- src/strands/agent/agent.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8607a2601..f8460d982 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -844,3 +844,47 @@ def _append_message(self, message: Message) -> None: """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" self.messages.append(message) self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message)) + + def with_session_manager( + self, session_manager: SessionManager, request_metadata: dict[str, Any] | None = None + ) -> "Agent": + """Create a new agent instance with session management enabled. + + This method creates a copy of the current agent instance and adds session + management capabilities while preserving all other configuration. The original + agent must not already have a session manager or any messages. + + Args: + session_manager: The session manager to add to the new agent instance. + request_metadata: Optional metadata to add to the new agent's state. + + Returns: + A new Agent instance with the same configuration plus session management. + + Raises: + ValueError: If the current agent already has a session manager or messages. + """ + import copy + + if self._session_manager is not None: + raise ValueError("Agent must not have a session manager") + + if self.messages: + raise ValueError("Agent must not have messages") + + # Create a deep copy of the current agent + new_agent = copy.deepcopy(self) + + # Reset the session manager and messages + new_agent._session_manager = session_manager + + # Add request metadata to the new agent's state + if request_metadata: + new_agent.state.set("request_metadata", request_metadata) + + # Re-register the new session manager hook + # Since we can't easily remove the old session manager hook, we'll just add the new one + # The new session manager will register its own hooks + new_agent.hooks.add_hook(session_manager) + + return new_agent From 2a6178aa8fa8c70767f4443e860933ebc2f6b611 Mon Sep 17 00:00:00 2001 From: Julian Grueber Date: Sat, 11 Oct 2025 14:16:38 +0200 Subject: [PATCH 2/2] fix: a2a server --- src/strands/multiagent/a2a/executor.py | 77 +++++++++++++++++++----- src/strands/multiagent/a2a/server.py | 12 +++- tests/strands/multiagent/a2a/conftest.py | 16 ++++- 3 files changed, 88 insertions(+), 17 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 74ecc6531..48bd15b9b 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -11,7 +11,8 @@ import json import logging import mimetypes -from typing import Any, Literal +import uuid +from typing import Any, Callable, Literal from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue @@ -22,6 +23,8 @@ from ...agent.agent import Agent as SAAgent from ...agent.agent import AgentResult as SAAgentResult +from ...session.file_session_manager import FileSessionManager +from ...session.session_manager import SessionManager from ...types.content import ContentBlock from ...types.media import ( DocumentContent, @@ -48,13 +51,19 @@ class StrandsA2AExecutor(AgentExecutor): # Handle special cases where format differs from extension FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"} - def __init__(self, agent: SAAgent): + def __init__(self, agent: SAAgent, session_manager_factory: Callable[[str], SessionManager] | None = None): """Initialize a StrandsA2AExecutor. Args: agent: The Strands Agent instance to adapt to the A2A protocol. + session_manager_factory: A callable that takes a session_id (str) and returns a SessionManager. """ self.agent = agent + if session_manager_factory is None: + logger.warning("No session_manager_factory provided. Using FileSessionManager as default.") + self.session_manager_factory = self._default_session_manager_factory + else: + self.session_manager_factory = session_manager_factory # type: ignore[assignment] async def execute( self, @@ -63,15 +72,34 @@ async def execute( ) -> None: """Execute a request using the Strands Agent and send the response as A2A events. - This method executes the user's input using the Strands Agent in streaming mode - and converts the agent's response to A2A events. + This method processes an A2A request by converting the incoming message parts + to Strands ContentBlocks, executing the agent with proper session management, + and streaming the response back as A2A events. + + The method handles various content types including: + - Text content + - Image files (with bytes or URI) + - Video files (with bytes or URI) + - Document files (with bytes or URI) + - Structured data (JSON) Args: - context: The A2A request context, containing the user's input and task metadata. - event_queue: The A2A event queue used to send response events back to the client. + context: The A2A request context containing: + - User input message parts + - Task and session metadata + - Context ID for session management + event_queue: The A2A event queue for sending response events back to the client. Raises: - ServerError: If an error occurs during agent execution + ServerError: If an error occurs during: + - Message part conversion + - Agent execution + - Event streaming + ValueError: If the context ID is missing or invalid + + Note: + This method creates a new agent instance with a session manager for each + request to ensure proper session isolation and state management. """ task = context.current_task if not task: @@ -103,8 +131,14 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater else: raise ValueError("No content blocks available") + context_id = context.context_id if context.context_id else str(uuid.uuid4()) + + agent = self.agent.with_session_manager( + session_manager=self.session_manager_factory(session_id=context_id), request_metadata=context.metadata + ) + try: - async for event in self.agent.stream_async(content_blocks): + async for event in agent.stream_async(content_blocks): await self._handle_streaming_event(event, updater) except Exception: logger.exception("Error in streaming execution") @@ -155,17 +189,21 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Cancel an ongoing execution. - This method is called when a request cancellation is requested. Currently, - cancellation is not supported by the Strands Agent executor, so this method - always raises an UnsupportedOperationError. + This method is called when a request cancellation is requested by the client. + Currently, cancellation is not supported by the Strands Agent executor, as + the underlying agent execution cannot be interrupted once started. Args: - context: The A2A request context. - event_queue: The A2A event queue. + context: The A2A request context containing the cancellation request. + event_queue: The A2A event queue (unused in this implementation). Raises: ServerError: Always raised with an UnsupportedOperationError, as cancellation - is not currently supported. + is not currently supported by the Strands Agent executor. + + Note: + Future versions may support cancellation by implementing proper task + interruption mechanisms in the underlying agent execution. """ logger.warning("Cancellation requested but not supported") raise ServerError(error=UnsupportedOperationError()) @@ -197,6 +235,17 @@ def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["docum else: return "unknown" + def _default_session_manager_factory(self, session_id: str) -> SessionManager: + """Default session manager factory using FileSessionManager. + + Args: + session_id(str): The session ID for the session manager. + + Returns: + SessionManager: A FileSessionManager instance for the given session ID. + """ + return FileSessionManager(session_id=session_id) + def _get_file_format_from_mime_type(self, mime_type: str | None, file_type: str) -> str: """Extract file format from MIME type using Python's mimetypes library. diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index bbfbc824d..bf97e94bd 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -5,7 +5,7 @@ """ import logging -from typing import Any, Literal +from typing import Any, Callable, Literal from urllib.parse import urlparse import uvicorn @@ -18,6 +18,7 @@ from starlette.applications import Starlette from ...agent.agent import Agent as SAAgent +from ...session.session_manager import SessionManager from .executor import StrandsA2AExecutor logger = logging.getLogger(__name__) @@ -30,6 +31,7 @@ def __init__( self, agent: SAAgent, *, + session_manager_factory: Callable[[str], SessionManager] | None = None, # AgentCard host: str = "127.0.0.1", port: int = 9000, @@ -47,6 +49,9 @@ def __init__( Args: agent: The Strands Agent to wrap with A2A compatibility. + session_manager_factory: A callable that takes a session_id (str) and returns a SessionManager. + This factory will be used to create session managers for each agent context. + If None, defaults to using FileSessionManager with a warning. host: The hostname or IP address to bind the A2A server to. Defaults to "127.0.0.1". port: The port to bind the A2A server to. Defaults to 9000. http_url: The public HTTP URL where this agent will be accessible. If provided, @@ -90,7 +95,10 @@ def __init__( self.description = self.strands_agent.description self.capabilities = AgentCapabilities(streaming=True) self.request_handler = DefaultRequestHandler( - agent_executor=StrandsA2AExecutor(self.strands_agent), + agent_executor=StrandsA2AExecutor( + agent=self.strands_agent, + session_manager_factory=session_manager_factory, + ), task_store=task_store or InMemoryTaskStore(), queue_manager=queue_manager, push_config_store=push_config_store, diff --git a/tests/strands/multiagent/a2a/conftest.py b/tests/strands/multiagent/a2a/conftest.py index e0061a025..72a2b89b5 100644 --- a/tests/strands/multiagent/a2a/conftest.py +++ b/tests/strands/multiagent/a2a/conftest.py @@ -1,6 +1,6 @@ """Common fixtures for A2A module tests.""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, PropertyMock import pytest from a2a.server.agent_execution import RequestContext @@ -31,6 +31,19 @@ def mock_strands_agent(): mock_tool_registry.get_all_tools_config.return_value = {} agent.tool_registry = mock_tool_registry + # Setup with_session_manager to return a copy of the agent + def mock_with_session_manager(session_manager=None, request_metadata=None): + """Create a copy of the agent with session manager.""" + agent_copy = MagicMock(spec=SAAgent) + agent_copy.name = agent.name + agent_copy.description = agent.description + agent_copy.invoke_async = agent.invoke_async + agent_copy.stream_async = agent.stream_async + agent_copy.tool_registry = agent.tool_registry + return agent_copy + + agent.with_session_manager = MagicMock(side_effect=mock_with_session_manager) + return agent @@ -39,6 +52,7 @@ def mock_request_context(): """Create a mock RequestContext for testing.""" context = MagicMock(spec=RequestContext) context.get_user_input.return_value = "Test input" + type(context).context_id = PropertyMock(return_value="test-context-id") return context