Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,47 @@ def _append_message(self, message: Message) -> None:
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
self.messages.append(message)
self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))

def with_session_manager(
self, session_manager: SessionManager, request_metadata: dict[str, Any] | None = None
) -> "Agent":
"""Create a new agent instance with session management enabled.

This method creates a copy of the current agent instance and adds session
management capabilities while preserving all other configuration. The original
agent must not already have a session manager or any messages.

Args:
session_manager: The session manager to add to the new agent instance.
request_metadata: Optional metadata to add to the new agent's state.

Returns:
A new Agent instance with the same configuration plus session management.

Raises:
ValueError: If the current agent already has a session manager or messages.
"""
import copy

if self._session_manager is not None:
raise ValueError("Agent must not have a session manager")

if self.messages:
raise ValueError("Agent must not have messages")

# Create a deep copy of the current agent
new_agent = copy.deepcopy(self)

# Reset the session manager and messages
new_agent._session_manager = session_manager

# Add request metadata to the new agent's state
if request_metadata:
new_agent.state.set("request_metadata", request_metadata)

# Re-register the new session manager hook
# Since we can't easily remove the old session manager hook, we'll just add the new one
# The new session manager will register its own hooks
new_agent.hooks.add_hook(session_manager)

return new_agent
77 changes: 63 additions & 14 deletions src/strands/multiagent/a2a/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import json
import logging
import mimetypes
from typing import Any, Literal
import uuid
from typing import Any, Callable, Literal

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
Expand All @@ -22,6 +23,8 @@

from ...agent.agent import Agent as SAAgent
from ...agent.agent import AgentResult as SAAgentResult
from ...session.file_session_manager import FileSessionManager
from ...session.session_manager import SessionManager
from ...types.content import ContentBlock
from ...types.media import (
DocumentContent,
Expand All @@ -48,13 +51,19 @@ class StrandsA2AExecutor(AgentExecutor):
# Handle special cases where format differs from extension
FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"}

def __init__(self, agent: SAAgent):
def __init__(self, agent: SAAgent, session_manager_factory: Callable[[str], SessionManager] | None = None):
"""Initialize a StrandsA2AExecutor.

Args:
agent: The Strands Agent instance to adapt to the A2A protocol.
session_manager_factory: A callable that takes a session_id (str) and returns a SessionManager.
"""
self.agent = agent
if session_manager_factory is None:
logger.warning("No session_manager_factory provided. Using FileSessionManager as default.")
self.session_manager_factory = self._default_session_manager_factory
else:
self.session_manager_factory = session_manager_factory # type: ignore[assignment]

async def execute(
self,
Expand All @@ -63,15 +72,34 @@ async def execute(
) -> None:
"""Execute a request using the Strands Agent and send the response as A2A events.

This method executes the user's input using the Strands Agent in streaming mode
and converts the agent's response to A2A events.
This method processes an A2A request by converting the incoming message parts
to Strands ContentBlocks, executing the agent with proper session management,
and streaming the response back as A2A events.

The method handles various content types including:
- Text content
- Image files (with bytes or URI)
- Video files (with bytes or URI)
- Document files (with bytes or URI)
- Structured data (JSON)

Args:
context: The A2A request context, containing the user's input and task metadata.
event_queue: The A2A event queue used to send response events back to the client.
context: The A2A request context containing:
- User input message parts
- Task and session metadata
- Context ID for session management
event_queue: The A2A event queue for sending response events back to the client.

Raises:
ServerError: If an error occurs during agent execution
ServerError: If an error occurs during:
- Message part conversion
- Agent execution
- Event streaming
ValueError: If the context ID is missing or invalid

Note:
This method creates a new agent instance with a session manager for each
request to ensure proper session isolation and state management.
"""
task = context.current_task
if not task:
Expand Down Expand Up @@ -103,8 +131,14 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
else:
raise ValueError("No content blocks available")

context_id = context.context_id if context.context_id else str(uuid.uuid4())

agent = self.agent.with_session_manager(
session_manager=self.session_manager_factory(session_id=context_id), request_metadata=context.metadata
)

try:
async for event in self.agent.stream_async(content_blocks):
async for event in agent.stream_async(content_blocks):
await self._handle_streaming_event(event, updater)
except Exception:
logger.exception("Error in streaming execution")
Expand Down Expand Up @@ -155,17 +189,21 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
"""Cancel an ongoing execution.

This method is called when a request cancellation is requested. Currently,
cancellation is not supported by the Strands Agent executor, so this method
always raises an UnsupportedOperationError.
This method is called when a request cancellation is requested by the client.
Currently, cancellation is not supported by the Strands Agent executor, as
the underlying agent execution cannot be interrupted once started.

Args:
context: The A2A request context.
event_queue: The A2A event queue.
context: The A2A request context containing the cancellation request.
event_queue: The A2A event queue (unused in this implementation).

Raises:
ServerError: Always raised with an UnsupportedOperationError, as cancellation
is not currently supported.
is not currently supported by the Strands Agent executor.

Note:
Future versions may support cancellation by implementing proper task
interruption mechanisms in the underlying agent execution.
"""
logger.warning("Cancellation requested but not supported")
raise ServerError(error=UnsupportedOperationError())
Expand Down Expand Up @@ -197,6 +235,17 @@ def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["docum
else:
return "unknown"

def _default_session_manager_factory(self, session_id: str) -> SessionManager:
"""Default session manager factory using FileSessionManager.

Args:
session_id(str): The session ID for the session manager.

Returns:
SessionManager: A FileSessionManager instance for the given session ID.
"""
return FileSessionManager(session_id=session_id)

def _get_file_format_from_mime_type(self, mime_type: str | None, file_type: str) -> str:
"""Extract file format from MIME type using Python's mimetypes library.

Expand Down
12 changes: 10 additions & 2 deletions src/strands/multiagent/a2a/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import logging
from typing import Any, Literal
from typing import Any, Callable, Literal
from urllib.parse import urlparse

import uvicorn
Expand All @@ -18,6 +18,7 @@
from starlette.applications import Starlette

from ...agent.agent import Agent as SAAgent
from ...session.session_manager import SessionManager
from .executor import StrandsA2AExecutor

logger = logging.getLogger(__name__)
Expand All @@ -30,6 +31,7 @@ def __init__(
self,
agent: SAAgent,
*,
session_manager_factory: Callable[[str], SessionManager] | None = None,
# AgentCard
host: str = "127.0.0.1",
port: int = 9000,
Expand All @@ -47,6 +49,9 @@ def __init__(

Args:
agent: The Strands Agent to wrap with A2A compatibility.
session_manager_factory: A callable that takes a session_id (str) and returns a SessionManager.
This factory will be used to create session managers for each agent context.
If None, defaults to using FileSessionManager with a warning.
host: The hostname or IP address to bind the A2A server to. Defaults to "127.0.0.1".
port: The port to bind the A2A server to. Defaults to 9000.
http_url: The public HTTP URL where this agent will be accessible. If provided,
Expand Down Expand Up @@ -90,7 +95,10 @@ def __init__(
self.description = self.strands_agent.description
self.capabilities = AgentCapabilities(streaming=True)
self.request_handler = DefaultRequestHandler(
agent_executor=StrandsA2AExecutor(self.strands_agent),
agent_executor=StrandsA2AExecutor(
agent=self.strands_agent,
session_manager_factory=session_manager_factory,
),
task_store=task_store or InMemoryTaskStore(),
queue_manager=queue_manager,
push_config_store=push_config_store,
Expand Down
16 changes: 15 additions & 1 deletion tests/strands/multiagent/a2a/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Common fixtures for A2A module tests."""

from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, PropertyMock

import pytest
from a2a.server.agent_execution import RequestContext
Expand Down Expand Up @@ -31,6 +31,19 @@ def mock_strands_agent():
mock_tool_registry.get_all_tools_config.return_value = {}
agent.tool_registry = mock_tool_registry

# Setup with_session_manager to return a copy of the agent
def mock_with_session_manager(session_manager=None, request_metadata=None):
"""Create a copy of the agent with session manager."""
agent_copy = MagicMock(spec=SAAgent)
agent_copy.name = agent.name
agent_copy.description = agent.description
agent_copy.invoke_async = agent.invoke_async
agent_copy.stream_async = agent.stream_async
agent_copy.tool_registry = agent.tool_registry
return agent_copy

agent.with_session_manager = MagicMock(side_effect=mock_with_session_manager)

return agent


Expand All @@ -39,6 +52,7 @@ def mock_request_context():
"""Create a mock RequestContext for testing."""
context = MagicMock(spec=RequestContext)
context.get_user_input.return_value = "Test input"
type(context).context_id = PropertyMock(return_value="test-context-id")
return context


Expand Down