From a7b40366d2f5230b532565a29c345164cb88699a Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Thu, 13 Nov 2025 00:21:02 -0500 Subject: [PATCH 1/5] feat(a2a): add A2AAgent class as an implementation of the agent interface for remote A2A protocol based agents --- src/strands/agent/__init__.py | 8 +- src/strands/agent/a2a_agent.py | 283 +++++++++++++++++++++++++ tests/strands/agent/test_a2a_agent.py | 287 ++++++++++++++++++++++++++ 3 files changed, 571 insertions(+), 7 deletions(-) create mode 100644 src/strands/agent/a2a_agent.py create mode 100644 tests/strands/agent/test_a2a_agent.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..8497a67d2 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -1,10 +1,4 @@ -"""This package provides the core Agent interface and supporting components for building AI agents with the SDK. - -It includes: - -- Agent: The main interface for interacting with AI models and tools -- ConversationManager: Classes for managing conversation history and context windows -""" +"""This package provides the core Agent interface and supporting components for building AI agents with the SDK.""" from .agent import Agent from .agent_result import AgentResult diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py new file mode 100644 index 000000000..bd58c162b --- /dev/null +++ b/src/strands/agent/a2a_agent.py @@ -0,0 +1,283 @@ +"""A2A Agent client for Strands Agents. + +This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, +allowing them to be used in graphs, swarms, and other multi-agent patterns. +""" + +import logging +from typing import Any, AsyncIterator, cast +from uuid import uuid4 + +import httpx +from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.types import AgentCard, Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart +from a2a.types import Message as A2AMessage + +from .._async import run_async +from ..telemetry.metrics import EventLoopMetrics +from ..types.agent import AgentInput +from ..types.content import ContentBlock, Message +from .agent_result import AgentResult + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 300 + + +class A2AAgent: + """Client wrapper for remote A2A agents. + + Implements the AgentBase protocol to enable remote A2A agents to be used + in graphs, swarms, and other multi-agent patterns. + """ + + def __init__( + self, + endpoint: str, + timeout: int = DEFAULT_TIMEOUT, + httpx_client_args: dict[str, Any] | None = None, + ): + """Initialize A2A agent client. + + Args: + endpoint: The base URL of the remote A2A agent + timeout: Timeout for HTTP operations in seconds (defaults to 300) + httpx_client_args: Optional dictionary of arguments to pass to httpx.AsyncClient + constructor. Allows custom auth, headers, proxies, etc. + Example: {"headers": {"Authorization": "Bearer token"}} + """ + self.endpoint = endpoint + self.timeout = timeout + self._httpx_client_args: dict[str, Any] = httpx_client_args or {} + + if "timeout" not in self._httpx_client_args: + self._httpx_client_args["timeout"] = self.timeout + + self._agent_card: AgentCard | None = None + + def _get_httpx_client(self) -> httpx.AsyncClient: + """Get a fresh httpx client for the current operation. + + Returns: + Configured httpx.AsyncClient instance. + """ + return httpx.AsyncClient(**self._httpx_client_args) + + def _get_client_factory(self, streaming: bool = False) -> ClientFactory: + """Get a ClientFactory for the current operation. + + Args: + streaming: Whether to enable streaming mode. + + Returns: + Configured ClientFactory instance. + """ + httpx_client = self._get_httpx_client() + config = ClientConfig( + httpx_client=httpx_client, + streaming=streaming, + ) + return ClientFactory(config) + + async def _discover_agent_card(self) -> AgentCard: + """Discover and cache the agent card from the remote endpoint. + + Returns: + The discovered AgentCard. + """ + if self._agent_card is not None: + return self._agent_card + + httpx_client = self._get_httpx_client() + resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() + logger.info("endpoint=<%s> | discovered agent card", self.endpoint) + return self._agent_card + + def _convert_input_to_message(self, prompt: AgentInput) -> A2AMessage: + """Convert AgentInput to A2A Message. + + Args: + prompt: Input in various formats (string, message list, or content blocks). + + Returns: + A2AMessage ready to send to the remote agent. + + Raises: + ValueError: If prompt format is unsupported. + """ + message_id = uuid4().hex + + if isinstance(prompt, str): + return A2AMessage( + kind="message", + role=Role.user, + parts=[Part(TextPart(kind="text", text=prompt))], + message_id=message_id, + ) + + if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): + if "role" in prompt[0]: + # Message list - extract last user message + for msg in reversed(prompt): + if msg.get("role") == "user": + content = cast(list[ContentBlock], msg.get("content", [])) + parts = self._convert_content_blocks_to_parts(content) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + else: + # ContentBlock list + parts = self._convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + + raise ValueError(f"Unsupported input type: {type(prompt)}") + + def _convert_content_blocks_to_parts(self, content_blocks: list[ContentBlock]) -> list[Part]: + """Convert Strands ContentBlocks to A2A Parts. + + Args: + content_blocks: List of Strands content blocks. + + Returns: + List of A2A Part objects. + """ + parts = [] + for block in content_blocks: + if "text" in block: + parts.append(Part(TextPart(kind="text", text=block["text"]))) + return parts + + def _convert_response_to_agent_result(self, response: Any) -> AgentResult: + """Convert A2A response to AgentResult. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + AgentResult with extracted content and metadata. + """ + content: list[ContentBlock] = [] + + if isinstance(response, tuple) and len(response) == 2: + task, update_event = response + if update_event is None and task and hasattr(task, "artifacts"): + # Non-streaming response: extract from task artifacts + for artifact in task.artifacts: + if hasattr(artifact, "parts"): + for part in artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + elif isinstance(response, A2AMessage): + # Direct message response + for part in response.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + + message: Message = { + "role": "assistant", + "content": content, + } + + return AgentResult( + stop_reason="end_turn", + message=message, + metrics=EventLoopMetrics(), + state={}, + ) + + async def _send_message( + self, prompt: AgentInput, streaming: bool + ) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage]: + """Send message to A2A agent. + + Args: + prompt: Input to send to the agent. + streaming: Whether to use streaming mode. + + Returns: + Async iterator of A2A events. + + Raises: + ValueError: If prompt is None. + """ + if prompt is None: + raise ValueError("prompt is required for A2AAgent") + + agent_card = await self._discover_agent_card() + client = self._get_client_factory(streaming=streaming).create(agent_card) + message = self._convert_input_to_message(prompt) + + logger.info("endpoint=<%s> | %s message", self.endpoint, "streaming" if streaming else "sending") + return client.send_message(message) + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + async for event in await self._send_message(prompt, streaming=False): + return self._convert_response_to_agent_result(event) + + raise RuntimeError("No response received from A2A agent") + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + return run_async(lambda: self.invoke_async(prompt, **kwargs)) + + async def stream_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream agent execution asynchronously. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Yields: + A2A events wrapped in dictionaries with an 'a2a_event' key. + + Raises: + ValueError: If prompt is None. + """ + async for event in await self._send_message(prompt, streaming=True): + yield {"a2a_event": event} diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py new file mode 100644 index 000000000..fd94fc4a6 --- /dev/null +++ b/tests/strands/agent/test_a2a_agent.py @@ -0,0 +1,287 @@ +"""Tests for A2AAgent class.""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from a2a.types import AgentCard, Part, Role, TextPart +from a2a.types import Message as A2AMessage + +from strands.agent.a2a_agent import A2AAgent +from strands.agent.agent_result import AgentResult + + +@pytest.fixture +def mock_agent_card(): + """Mock AgentCard for testing.""" + return AgentCard( + name="test-agent", + description="Test agent", + url="http://localhost:8000", + version="1.0.0", + capabilities={}, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + skills=[], + ) + + +@pytest.fixture +def a2a_agent(): + """Create A2AAgent instance for testing.""" + return A2AAgent(endpoint="http://localhost:8000") + + +def test_init_with_defaults(): + """Test initialization with default parameters.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert agent.endpoint == "http://localhost:8000" + assert agent.timeout == 300 + assert agent._agent_card is None + + +def test_init_with_custom_timeout(): + """Test initialization with custom timeout.""" + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + assert agent.timeout == 600 + assert agent._httpx_client_args["timeout"] == 600 + + +def test_init_with_httpx_client_args(): + """Test initialization with custom httpx client arguments.""" + agent = A2AAgent( + endpoint="http://localhost:8000", + httpx_client_args={"headers": {"Authorization": "Bearer token"}}, + ) + assert "headers" in agent._httpx_client_args + assert agent._httpx_client_args["headers"]["Authorization"] == "Bearer token" + + +@pytest.mark.asyncio +async def test_discover_agent_card(a2a_agent, mock_agent_card): + """Test agent card discovery.""" + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + card = await a2a_agent._discover_agent_card() + + assert card == mock_agent_card + assert a2a_agent._agent_card == mock_agent_card + + +@pytest.mark.asyncio +async def test_discover_agent_card_cached(a2a_agent, mock_agent_card): + """Test that agent card is cached after first discovery.""" + a2a_agent._agent_card = mock_agent_card + + card = await a2a_agent._discover_agent_card() + + assert card == mock_agent_card + + +def test_convert_string_input(a2a_agent): + """Test converting string input to A2A message.""" + message = a2a_agent._convert_input_to_message("Hello") + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + assert message.parts[0].root.text == "Hello" + + +def test_convert_message_list_input(a2a_agent): + """Test converting message list input to A2A message.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + message = a2a_agent._convert_input_to_message(messages) + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + + +def test_convert_content_blocks_input(a2a_agent): + """Test converting content blocks input to A2A message.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + message = a2a_agent._convert_input_to_message(content_blocks) + + assert isinstance(message, A2AMessage) + assert len(message.parts) == 2 + + +def test_convert_unsupported_input(a2a_agent): + """Test that unsupported input types raise ValueError.""" + with pytest.raises(ValueError, match="Unsupported input type"): + a2a_agent._convert_input_to_message(123) + + +def test_convert_content_blocks_to_parts(a2a_agent): + """Test converting content blocks to A2A parts.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + parts = a2a_agent._convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + assert parts[0].root.text == "Hello" + assert parts[1].root.text == "World" + + +def test_convert_a2a_message_response(a2a_agent): + """Test converting A2A message response to AgentResult.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + result = a2a_agent._convert_response_to_agent_result(a2a_message) + + assert isinstance(result, AgentResult) + assert result.message["role"] == "assistant" + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Response" + + +def test_convert_task_response(a2a_agent): + """Test converting task response to AgentResult.""" + mock_task = MagicMock() + mock_artifact = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Task response" + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + result = a2a_agent._convert_response_to_agent_result((mock_task, None)) + + assert isinstance(result, AgentResult) + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Task response" + + +def test_convert_multiple_parts_response(a2a_agent): + """Test converting response with multiple parts to separate content blocks.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[ + Part(TextPart(kind="text", text="First")), + Part(TextPart(kind="text", text="Second")), + ], + ) + + result = a2a_agent._convert_response_to_agent_result(a2a_message) + + assert len(result.message["content"]) == 2 + assert result.message["content"][0]["text"] == "First" + assert result.message["content"][1]["text"] == "Second" + + +@pytest.mark.asyncio +async def test_invoke_async_success(a2a_agent, mock_agent_card): + """Test successful async invocation.""" + mock_response = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "_discover_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_client = AsyncMock() + mock_client.send_message = mock_send_message + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + + result = await a2a_agent.invoke_async("Hello") + + assert isinstance(result, AgentResult) + assert result.message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_invoke_async_no_prompt(a2a_agent): + """Test that invoke_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + await a2a_agent.invoke_async(None) + + +@pytest.mark.asyncio +async def test_invoke_async_no_response(a2a_agent, mock_agent_card): + """Test that invoke_async raises RuntimeError when no response received.""" + + async def mock_send_message(*args, **kwargs): + return + yield # Make it an async generator + + with patch.object(a2a_agent, "_discover_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_client = AsyncMock() + mock_client.send_message = mock_send_message + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + + with pytest.raises(RuntimeError, match="No response received"): + await a2a_agent.invoke_async("Hello") + + +def test_call_sync(a2a_agent): + """Test synchronous call method.""" + mock_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=MagicMock(), + state={}, + ) + + with patch("strands.agent.a2a_agent.run_async") as mock_run_async: + mock_run_async.return_value = mock_result + + result = a2a_agent("Hello") + + assert result == mock_result + mock_run_async.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_async_success(a2a_agent, mock_agent_card): + """Test successful async streaming.""" + mock_event1 = MagicMock() + mock_event2 = MagicMock() + + async def mock_send_message(*args, **kwargs): + yield mock_event1 + yield mock_event2 + + with patch.object(a2a_agent, "_discover_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_client = AsyncMock() + mock_client.send_message = mock_send_message + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + assert len(events) == 2 + assert events[0]["a2a_event"] == mock_event1 + assert events[1]["a2a_event"] == mock_event2 + + +@pytest.mark.asyncio +async def test_stream_async_no_prompt(a2a_agent): + """Test that stream_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + async for _ in a2a_agent.stream_async(None): + pass From 5034f184ccf6128ac8b44c4567caea769ce57d47 Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Thu, 13 Nov 2025 09:57:33 -0500 Subject: [PATCH 2/5] add integ test for a2a --- tests_integ/a2a/__init__.py | 0 tests_integ/a2a/a2a_server.py | 15 +++++++++++++++ tests_integ/a2a/test_multiagent_a2a.py | 26 ++++++++++++++++++++++++++ 3 files changed, 41 insertions(+) create mode 100644 tests_integ/a2a/__init__.py create mode 100644 tests_integ/a2a/a2a_server.py create mode 100644 tests_integ/a2a/test_multiagent_a2a.py diff --git a/tests_integ/a2a/__init__.py b/tests_integ/a2a/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/a2a/a2a_server.py b/tests_integ/a2a/a2a_server.py new file mode 100644 index 000000000..047edc3ba --- /dev/null +++ b/tests_integ/a2a/a2a_server.py @@ -0,0 +1,15 @@ +from strands import Agent +from strands.multiagent.a2a import A2AServer + +# Create an agent and serve it over A2A +agent = Agent( + name="Test agent", + description="Test description here", + callback_handler=None, +) +a2a_server = A2AServer( + agent=agent, + host="localhost", + port=9000, +) +a2a_server.serve() diff --git a/tests_integ/a2a/test_multiagent_a2a.py b/tests_integ/a2a/test_multiagent_a2a.py new file mode 100644 index 000000000..ed7c43cd1 --- /dev/null +++ b/tests_integ/a2a/test_multiagent_a2a.py @@ -0,0 +1,26 @@ +import atexit +import os +import subprocess +import time + +from strands.agent.a2a_agent import A2AAgent + +# Start our A2A server +server_path = os.path.join(os.path.dirname(__file__), "a2a_server.py") +server = subprocess.Popen(["python", server_path]) + + +def cleanup(): + server.terminate() + + +atexit.register(cleanup) +time.sleep(5) # Wait for A2A server to start + +# Connect to our A2A server +a2a_agent = A2AAgent(endpoint="http://localhost:9000") + +# Invoke our A2A server +result = a2a_agent("Hello there!") + +# TODO: assertions on result From 64520b7e5f17664ce8efd5334651c7b0580c4f65 Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Thu, 13 Nov 2025 09:58:34 -0500 Subject: [PATCH 3/5] move conversion functions to a utility module + rename _get_agent_card function + reduce use of Any type --- src/strands/agent/a2a_agent.py | 121 ++---------------- src/strands/multiagent/a2a/converters.py | 113 ++++++++++++++++ tests/strands/agent/test_a2a_agent.py | 114 +---------------- .../strands/multiagent/a2a/test_converters.py | 115 +++++++++++++++++ 4 files changed, 243 insertions(+), 220 deletions(-) create mode 100644 src/strands/multiagent/a2a/converters.py create mode 100644 tests/strands/multiagent/a2a/test_converters.py diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index bd58c162b..38bee41f9 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -5,18 +5,15 @@ """ import logging -from typing import Any, AsyncIterator, cast -from uuid import uuid4 +from typing import Any, AsyncIterator import httpx from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import AgentCard, Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart -from a2a.types import Message as A2AMessage +from a2a.types import AgentCard from .._async import run_async -from ..telemetry.metrics import EventLoopMetrics +from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result from ..types.agent import AgentInput -from ..types.content import ContentBlock, Message from .agent_result import AgentResult logger = logging.getLogger(__name__) @@ -79,7 +76,7 @@ def _get_client_factory(self, streaming: bool = False) -> ClientFactory: ) return ClientFactory(config) - async def _discover_agent_card(self) -> AgentCard: + async def _get_agent_card(self) -> AgentCard: """Discover and cache the agent card from the remote endpoint. Returns: @@ -94,109 +91,7 @@ async def _discover_agent_card(self) -> AgentCard: logger.info("endpoint=<%s> | discovered agent card", self.endpoint) return self._agent_card - def _convert_input_to_message(self, prompt: AgentInput) -> A2AMessage: - """Convert AgentInput to A2A Message. - - Args: - prompt: Input in various formats (string, message list, or content blocks). - - Returns: - A2AMessage ready to send to the remote agent. - - Raises: - ValueError: If prompt format is unsupported. - """ - message_id = uuid4().hex - - if isinstance(prompt, str): - return A2AMessage( - kind="message", - role=Role.user, - parts=[Part(TextPart(kind="text", text=prompt))], - message_id=message_id, - ) - - if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): - if "role" in prompt[0]: - # Message list - extract last user message - for msg in reversed(prompt): - if msg.get("role") == "user": - content = cast(list[ContentBlock], msg.get("content", [])) - parts = self._convert_content_blocks_to_parts(content) - return A2AMessage( - kind="message", - role=Role.user, - parts=parts, - message_id=message_id, - ) - else: - # ContentBlock list - parts = self._convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) - return A2AMessage( - kind="message", - role=Role.user, - parts=parts, - message_id=message_id, - ) - - raise ValueError(f"Unsupported input type: {type(prompt)}") - - def _convert_content_blocks_to_parts(self, content_blocks: list[ContentBlock]) -> list[Part]: - """Convert Strands ContentBlocks to A2A Parts. - - Args: - content_blocks: List of Strands content blocks. - - Returns: - List of A2A Part objects. - """ - parts = [] - for block in content_blocks: - if "text" in block: - parts.append(Part(TextPart(kind="text", text=block["text"]))) - return parts - - def _convert_response_to_agent_result(self, response: Any) -> AgentResult: - """Convert A2A response to AgentResult. - - Args: - response: A2A response (either A2AMessage or tuple of task and update event). - - Returns: - AgentResult with extracted content and metadata. - """ - content: list[ContentBlock] = [] - - if isinstance(response, tuple) and len(response) == 2: - task, update_event = response - if update_event is None and task and hasattr(task, "artifacts"): - # Non-streaming response: extract from task artifacts - for artifact in task.artifacts: - if hasattr(artifact, "parts"): - for part in artifact.parts: - if hasattr(part, "root") and hasattr(part.root, "text"): - content.append({"text": part.root.text}) - elif isinstance(response, A2AMessage): - # Direct message response - for part in response.parts: - if hasattr(part, "root") and hasattr(part.root, "text"): - content.append({"text": part.root.text}) - - message: Message = { - "role": "assistant", - "content": content, - } - - return AgentResult( - stop_reason="end_turn", - message=message, - metrics=EventLoopMetrics(), - state={}, - ) - - async def _send_message( - self, prompt: AgentInput, streaming: bool - ) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage]: + async def _send_message(self, prompt: AgentInput, streaming: bool) -> AsyncIterator[Any]: """Send message to A2A agent. Args: @@ -212,9 +107,9 @@ async def _send_message( if prompt is None: raise ValueError("prompt is required for A2AAgent") - agent_card = await self._discover_agent_card() + agent_card = await self._get_agent_card() client = self._get_client_factory(streaming=streaming).create(agent_card) - message = self._convert_input_to_message(prompt) + message = convert_input_to_message(prompt) logger.info("endpoint=<%s> | %s message", self.endpoint, "streaming" if streaming else "sending") return client.send_message(message) @@ -238,7 +133,7 @@ async def invoke_async( RuntimeError: If no response received from agent. """ async for event in await self._send_message(prompt, streaming=False): - return self._convert_response_to_agent_result(event) + return convert_response_to_agent_result(event) raise RuntimeError("No response received from A2A agent") diff --git a/src/strands/multiagent/a2a/converters.py b/src/strands/multiagent/a2a/converters.py new file mode 100644 index 000000000..420e4532c --- /dev/null +++ b/src/strands/multiagent/a2a/converters.py @@ -0,0 +1,113 @@ +"""Conversion functions between Strands and A2A types.""" + +from typing import TypeAlias, cast +from uuid import uuid4 + +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart + +from ...agent.agent_result import AgentResult +from ...telemetry.metrics import EventLoopMetrics +from ...types.agent import AgentInput +from ...types.content import ContentBlock, Message + +A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage + + +def convert_input_to_message(prompt: AgentInput) -> A2AMessage: + """Convert AgentInput to A2A Message. + + Args: + prompt: Input in various formats (string, message list, or content blocks). + + Returns: + A2AMessage ready to send to the remote agent. + + Raises: + ValueError: If prompt format is unsupported. + """ + message_id = uuid4().hex + + if isinstance(prompt, str): + return A2AMessage( + kind="message", + role=Role.user, + parts=[Part(TextPart(kind="text", text=prompt))], + message_id=message_id, + ) + + if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): + if "role" in prompt[0]: + for msg in reversed(prompt): + if msg.get("role") == "user": + content = cast(list[ContentBlock], msg.get("content", [])) + parts = convert_content_blocks_to_parts(content) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + else: + parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + + raise ValueError(f"Unsupported input type: {type(prompt)}") + + +def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]: + """Convert Strands ContentBlocks to A2A Parts. + + Args: + content_blocks: List of Strands content blocks. + + Returns: + List of A2A Part objects. + """ + parts = [] + for block in content_blocks: + if "text" in block: + parts.append(Part(TextPart(kind="text", text=block["text"]))) + return parts + + +def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: + """Convert A2A response to AgentResult. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + AgentResult with extracted content and metadata. + """ + content: list[ContentBlock] = [] + + if isinstance(response, tuple) and len(response) == 2: + task, update_event = response + if update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: + for artifact in task.artifacts: + if hasattr(artifact, "parts"): + for part in artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + elif isinstance(response, A2AMessage): + for part in response.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + + message: Message = { + "role": "assistant", + "content": content, + } + + return AgentResult( + stop_reason="end_turn", + message=message, + metrics=EventLoopMetrics(), + state={}, + ) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index fd94fc4a6..ab0a495f2 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -58,129 +58,29 @@ def test_init_with_httpx_client_args(): @pytest.mark.asyncio -async def test_discover_agent_card(a2a_agent, mock_agent_card): +async def test_get_agent_card(a2a_agent, mock_agent_card): """Test agent card discovery.""" with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: mock_resolver = AsyncMock() mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) mock_resolver_class.return_value = mock_resolver - card = await a2a_agent._discover_agent_card() + card = await a2a_agent._get_agent_card() assert card == mock_agent_card assert a2a_agent._agent_card == mock_agent_card @pytest.mark.asyncio -async def test_discover_agent_card_cached(a2a_agent, mock_agent_card): +async def test_get_agent_card_cached(a2a_agent, mock_agent_card): """Test that agent card is cached after first discovery.""" a2a_agent._agent_card = mock_agent_card - card = await a2a_agent._discover_agent_card() + card = await a2a_agent._get_agent_card() assert card == mock_agent_card -def test_convert_string_input(a2a_agent): - """Test converting string input to A2A message.""" - message = a2a_agent._convert_input_to_message("Hello") - - assert isinstance(message, A2AMessage) - assert message.role == Role.user - assert len(message.parts) == 1 - assert message.parts[0].root.text == "Hello" - - -def test_convert_message_list_input(a2a_agent): - """Test converting message list input to A2A message.""" - messages = [ - {"role": "user", "content": [{"text": "Hello"}]}, - ] - - message = a2a_agent._convert_input_to_message(messages) - - assert isinstance(message, A2AMessage) - assert message.role == Role.user - assert len(message.parts) == 1 - - -def test_convert_content_blocks_input(a2a_agent): - """Test converting content blocks input to A2A message.""" - content_blocks = [{"text": "Hello"}, {"text": "World"}] - - message = a2a_agent._convert_input_to_message(content_blocks) - - assert isinstance(message, A2AMessage) - assert len(message.parts) == 2 - - -def test_convert_unsupported_input(a2a_agent): - """Test that unsupported input types raise ValueError.""" - with pytest.raises(ValueError, match="Unsupported input type"): - a2a_agent._convert_input_to_message(123) - - -def test_convert_content_blocks_to_parts(a2a_agent): - """Test converting content blocks to A2A parts.""" - content_blocks = [{"text": "Hello"}, {"text": "World"}] - - parts = a2a_agent._convert_content_blocks_to_parts(content_blocks) - - assert len(parts) == 2 - assert parts[0].root.text == "Hello" - assert parts[1].root.text == "World" - - -def test_convert_a2a_message_response(a2a_agent): - """Test converting A2A message response to AgentResult.""" - a2a_message = A2AMessage( - message_id=uuid4().hex, - role=Role.agent, - parts=[Part(TextPart(kind="text", text="Response"))], - ) - - result = a2a_agent._convert_response_to_agent_result(a2a_message) - - assert isinstance(result, AgentResult) - assert result.message["role"] == "assistant" - assert len(result.message["content"]) == 1 - assert result.message["content"][0]["text"] == "Response" - - -def test_convert_task_response(a2a_agent): - """Test converting task response to AgentResult.""" - mock_task = MagicMock() - mock_artifact = MagicMock() - mock_part = MagicMock() - mock_part.root.text = "Task response" - mock_artifact.parts = [mock_part] - mock_task.artifacts = [mock_artifact] - - result = a2a_agent._convert_response_to_agent_result((mock_task, None)) - - assert isinstance(result, AgentResult) - assert len(result.message["content"]) == 1 - assert result.message["content"][0]["text"] == "Task response" - - -def test_convert_multiple_parts_response(a2a_agent): - """Test converting response with multiple parts to separate content blocks.""" - a2a_message = A2AMessage( - message_id=uuid4().hex, - role=Role.agent, - parts=[ - Part(TextPart(kind="text", text="First")), - Part(TextPart(kind="text", text="Second")), - ], - ) - - result = a2a_agent._convert_response_to_agent_result(a2a_message) - - assert len(result.message["content"]) == 2 - assert result.message["content"][0]["text"] == "First" - assert result.message["content"][1]["text"] == "Second" - - @pytest.mark.asyncio async def test_invoke_async_success(a2a_agent, mock_agent_card): """Test successful async invocation.""" @@ -193,7 +93,7 @@ async def test_invoke_async_success(a2a_agent, mock_agent_card): async def mock_send_message(*args, **kwargs): yield mock_response - with patch.object(a2a_agent, "_discover_agent_card", return_value=mock_agent_card): + with patch.object(a2a_agent, "_get_agent_card", return_value=mock_agent_card): with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: mock_client = AsyncMock() mock_client.send_message = mock_send_message @@ -222,7 +122,7 @@ async def mock_send_message(*args, **kwargs): return yield # Make it an async generator - with patch.object(a2a_agent, "_discover_agent_card", return_value=mock_agent_card): + with patch.object(a2a_agent, "_get_agent_card", return_value=mock_agent_card): with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: mock_client = AsyncMock() mock_client.send_message = mock_send_message @@ -262,7 +162,7 @@ async def mock_send_message(*args, **kwargs): yield mock_event1 yield mock_event2 - with patch.object(a2a_agent, "_discover_agent_card", return_value=mock_agent_card): + with patch.object(a2a_agent, "_get_agent_card", return_value=mock_agent_card): with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: mock_client = AsyncMock() mock_client.send_message = mock_send_message diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py new file mode 100644 index 000000000..0d443e4ed --- /dev/null +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -0,0 +1,115 @@ +"""Tests for A2A converter functions.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TextPart + +from strands.agent.agent_result import AgentResult +from strands.multiagent.a2a.converters import ( + convert_content_blocks_to_parts, + convert_input_to_message, + convert_response_to_agent_result, +) + + +def test_convert_string_input(): + """Test converting string input to A2A message.""" + message = convert_input_to_message("Hello") + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + assert message.parts[0].root.text == "Hello" + + +def test_convert_message_list_input(): + """Test converting message list input to A2A message.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + message = convert_input_to_message(messages) + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + + +def test_convert_content_blocks_input(): + """Test converting content blocks input to A2A message.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + message = convert_input_to_message(content_blocks) + + assert isinstance(message, A2AMessage) + assert len(message.parts) == 2 + + +def test_convert_unsupported_input(): + """Test that unsupported input types raise ValueError.""" + with pytest.raises(ValueError, match="Unsupported input type"): + convert_input_to_message(123) + + +def test_convert_content_blocks_to_parts(): + """Test converting content blocks to A2A parts.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + assert parts[0].root.text == "Hello" + assert parts[1].root.text == "World" + + +def test_convert_a2a_message_response(): + """Test converting A2A message response to AgentResult.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert isinstance(result, AgentResult) + assert result.message["role"] == "assistant" + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Response" + + +def test_convert_task_response(): + """Test converting task response to AgentResult.""" + mock_task = MagicMock() + mock_artifact = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Task response" + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + result = convert_response_to_agent_result((mock_task, None)) + + assert isinstance(result, AgentResult) + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Task response" + + +def test_convert_multiple_parts_response(): + """Test converting response with multiple parts to separate content blocks.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[ + Part(TextPart(kind="text", text="First")), + Part(TextPart(kind="text", text="Second")), + ], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert len(result.message["content"]) == 2 + assert result.message["content"][0]["text"] == "First" + assert result.message["content"][1]["text"] == "Second" From 90124cf62754a7c781fb1e98d93eed76ef00c3fd Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Thu, 13 Nov 2025 12:56:17 -0500 Subject: [PATCH 4/5] fix a2a integ test --- tests_integ/a2a/test_multiagent_a2a.py | 27 +++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests_integ/a2a/test_multiagent_a2a.py b/tests_integ/a2a/test_multiagent_a2a.py index ed7c43cd1..5d8d13fc5 100644 --- a/tests_integ/a2a/test_multiagent_a2a.py +++ b/tests_integ/a2a/test_multiagent_a2a.py @@ -5,22 +5,23 @@ from strands.agent.a2a_agent import A2AAgent -# Start our A2A server -server_path = os.path.join(os.path.dirname(__file__), "a2a_server.py") -server = subprocess.Popen(["python", server_path]) +def test_a2a_agent(): + # Start our A2A server + server_path = os.path.join(os.path.dirname(__file__), "a2a_server.py") + server = subprocess.Popen(["python", server_path]) -def cleanup(): - server.terminate() + def cleanup(): + server.terminate() + atexit.register(cleanup) + time.sleep(5) # Wait for A2A server to start -atexit.register(cleanup) -time.sleep(5) # Wait for A2A server to start + # Connect to our A2A server + a2a_agent = A2AAgent(endpoint="http://localhost:9000") -# Connect to our A2A server -a2a_agent = A2AAgent(endpoint="http://localhost:9000") + # Invoke our A2A server + result = a2a_agent("Hello there!") -# Invoke our A2A server -result = a2a_agent("Hello there!") - -# TODO: assertions on result + # Ensure that it was successful + assert result.stop_reason == "end_turn" From 182f9a4f22446a6b02828636db10144efba2ec27 Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Tue, 18 Nov 2025 14:03:54 -0500 Subject: [PATCH 5/5] add A2AStreamEvent + handle partial responses in stream_async + modify constructor parameters --- src/strands/agent/a2a_agent.py | 171 ++++++++++++++++------- src/strands/multiagent/a2a/converters.py | 23 ++- src/strands/types/a2a.py | 26 ++++ tests/strands/agent/test_a2a_agent.py | 80 ++++++++--- 4 files changed, 229 insertions(+), 71 deletions(-) create mode 100644 src/strands/types/a2a.py diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index 38bee41f9..1f2e015a6 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -1,18 +1,22 @@ """A2A Agent client for Strands Agents. This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, -allowing them to be used in graphs, swarms, and other multi-agent patterns. +allowing them to be used standalone or as part of multi-agent patterns. + +A2AAgent can be used to get the Agent Card and interact with the agent. """ import logging from typing import Any, AsyncIterator import httpx -from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import AgentCard +from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent from .._async import run_async from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result +from ..types._events import AgentResultEvent +from ..types.a2a import A2AResponse, A2AStreamEvent from ..types.agent import AgentInput from .agent_result import AgentResult @@ -22,59 +26,46 @@ class A2AAgent: - """Client wrapper for remote A2A agents. - - Implements the AgentBase protocol to enable remote A2A agents to be used - in graphs, swarms, and other multi-agent patterns. - """ + """Client wrapper for remote A2A agents.""" def __init__( self, endpoint: str, + *, + name: str | None = None, + description: str = "", timeout: int = DEFAULT_TIMEOUT, - httpx_client_args: dict[str, Any] | None = None, + a2a_client_factory: ClientFactory | None = None, ): - """Initialize A2A agent client. + """Initialize A2A agent. Args: - endpoint: The base URL of the remote A2A agent - timeout: Timeout for HTTP operations in seconds (defaults to 300) - httpx_client_args: Optional dictionary of arguments to pass to httpx.AsyncClient - constructor. Allows custom auth, headers, proxies, etc. - Example: {"headers": {"Authorization": "Bearer token"}} + endpoint: The base URL of the remote A2A agent. + name: Agent name. If not provided, will be populated from agent card. + description: Agent description. If empty, will be populated from agent card. + timeout: Timeout for HTTP operations in seconds (defaults to 300). + a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, + it will be used to create the A2A client after discovering the agent card. """ self.endpoint = endpoint + self.name = name + self.description = description self.timeout = timeout - self._httpx_client_args: dict[str, Any] = httpx_client_args or {} - - if "timeout" not in self._httpx_client_args: - self._httpx_client_args["timeout"] = self.timeout - + self._httpx_client: httpx.AsyncClient | None = None + self._owns_client = a2a_client_factory is None self._agent_card: AgentCard | None = None + self._a2a_client: Client | None = None + self._a2a_client_factory: ClientFactory | None = a2a_client_factory def _get_httpx_client(self) -> httpx.AsyncClient: - """Get a fresh httpx client for the current operation. + """Get or create the httpx client for this agent. Returns: Configured httpx.AsyncClient instance. """ - return httpx.AsyncClient(**self._httpx_client_args) - - def _get_client_factory(self, streaming: bool = False) -> ClientFactory: - """Get a ClientFactory for the current operation. - - Args: - streaming: Whether to enable streaming mode. - - Returns: - Configured ClientFactory instance. - """ - httpx_client = self._get_httpx_client() - config = ClientConfig( - httpx_client=httpx_client, - streaming=streaming, - ) - return ClientFactory(config) + if self._httpx_client is None: + self._httpx_client = httpx.AsyncClient(timeout=self.timeout) + return self._httpx_client async def _get_agent_card(self) -> AgentCard: """Discover and cache the agent card from the remote endpoint. @@ -88,15 +79,44 @@ async def _get_agent_card(self) -> AgentCard: httpx_client = self._get_httpx_client() resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint) self._agent_card = await resolver.get_agent_card() - logger.info("endpoint=<%s> | discovered agent card", self.endpoint) + + # Populate name from card if not set + if self.name is None and self._agent_card.name: + self.name = self._agent_card.name + + # Populate description from card if not set + if not self.description and self._agent_card.description: + self.description = self._agent_card.description + + logger.info("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) return self._agent_card - async def _send_message(self, prompt: AgentInput, streaming: bool) -> AsyncIterator[Any]: + async def _get_a2a_client(self) -> Client: + """Get or create the A2A client for this agent. + + Returns: + Configured A2A client instance. + """ + if self._a2a_client is None: + agent_card = await self._get_agent_card() + + if self._a2a_client_factory is not None: + # Use provided factory + factory = self._a2a_client_factory + else: + # Create default factory + httpx_client = self._get_httpx_client() + config = ClientConfig(httpx_client=httpx_client, streaming=False) + factory = ClientFactory(config) + + self._a2a_client = factory.create(agent_card) + return self._a2a_client + + async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: """Send message to A2A agent. Args: prompt: Input to send to the agent. - streaming: Whether to use streaming mode. Returns: Async iterator of A2A events. @@ -107,13 +127,46 @@ async def _send_message(self, prompt: AgentInput, streaming: bool) -> AsyncItera if prompt is None: raise ValueError("prompt is required for A2AAgent") - agent_card = await self._get_agent_card() - client = self._get_client_factory(streaming=streaming).create(agent_card) + client = await self._get_a2a_client() message = convert_input_to_message(prompt) - logger.info("endpoint=<%s> | %s message", self.endpoint, "streaming" if streaming else "sending") + logger.info("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) return client.send_message(message) + def _is_complete_event(self, event: A2AResponse) -> bool: + """Check if an A2A event represents a complete response. + + Args: + event: A2A event. + + Returns: + True if the event represents a complete response. + """ + # Direct Message is always complete + if isinstance(event, Message): + return True + + # Handle tuple responses (Task, UpdateEvent | None) + if isinstance(event, tuple) and len(event) == 2: + task, update_event = event + + # Initial task response (no update event) + if update_event is None: + return True + + # Artifact update with last_chunk flag + if isinstance(update_event, TaskArtifactUpdateEvent): + if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None: + return update_event.last_chunk + return False + + # Status update with completed state + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state == TaskState.completed + + return False + async def invoke_async( self, prompt: AgentInput = None, @@ -132,7 +185,7 @@ async def invoke_async( ValueError: If prompt is None. RuntimeError: If no response received from agent. """ - async for event in await self._send_message(prompt, streaming=False): + async for event in await self._send_message(prompt): return convert_response_to_agent_result(event) raise RuntimeError("No response received from A2A agent") @@ -169,10 +222,32 @@ async def stream_async( **kwargs: Additional arguments (ignored). Yields: - A2A events wrapped in dictionaries with an 'a2a_event' key. + A2A events and a final AgentResult event. Raises: ValueError: If prompt is None. """ - async for event in await self._send_message(prompt, streaming=True): - yield {"a2a_event": event} + last_event = None + last_complete_event = None + + async for event in await self._send_message(prompt): + last_event = event + if self._is_complete_event(event): + last_complete_event = event + yield A2AStreamEvent(event) + + # Use the last complete event if available, otherwise fall back to last event + final_event = last_complete_event if last_complete_event is not None else last_event + + if final_event is not None: + result = convert_response_to_agent_result(final_event) + yield AgentResultEvent(result) + + def __del__(self) -> None: + """Clean up resources when agent is garbage collected.""" + if self._owns_client and self._httpx_client is not None: + try: + client = self._httpx_client + run_async(lambda: client.aclose()) + except Exception: + pass # Best effort cleanup, ignore errors in __del__ diff --git a/src/strands/multiagent/a2a/converters.py b/src/strands/multiagent/a2a/converters.py index 420e4532c..efd04cddc 100644 --- a/src/strands/multiagent/a2a/converters.py +++ b/src/strands/multiagent/a2a/converters.py @@ -1,18 +1,17 @@ """Conversion functions between Strands and A2A types.""" -from typing import TypeAlias, cast +from typing import cast from uuid import uuid4 from a2a.types import Message as A2AMessage -from a2a.types import Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart from ...agent.agent_result import AgentResult from ...telemetry.metrics import EventLoopMetrics +from ...types.a2a import A2AResponse from ...types.agent import AgentInput from ...types.content import ContentBlock, Message -A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage - def convert_input_to_message(prompt: AgentInput) -> A2AMessage: """Convert AgentInput to A2A Message. @@ -89,7 +88,21 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: if isinstance(response, tuple) and len(response) == 2: task, update_event = response - if update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: + + # Handle artifact updates + if isinstance(update_event, TaskArtifactUpdateEvent): + if update_event.artifact and hasattr(update_event.artifact, "parts"): + for part in update_event.artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle status updates with messages + elif isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: + for part in update_event.status.message.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle initial task or task without update event + elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: for artifact in task.artifacts: if hasattr(artifact, "parts"): for part in artifact.parts: diff --git a/src/strands/types/a2a.py b/src/strands/types/a2a.py new file mode 100644 index 000000000..3a7bb091b --- /dev/null +++ b/src/strands/types/a2a.py @@ -0,0 +1,26 @@ +"""Additional A2A types.""" + +from typing import TypeAlias + +from a2a.types import Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent + +from ._events import TypedEvent + +A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message + + +class A2AStreamEvent(TypedEvent): + """Event that wraps streamed A2A types.""" + + def __init__(self, a2a_event: A2AResponse) -> None: + """Initialize with A2A event. + + Args: + a2a_event: The original A2A event (Task tuple or Message) + """ + super().__init__( + { + "type": "a2a_stream", + "event": a2a_event, # Nest A2A event to avoid field conflicts + } + ) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index ab0a495f2..d595787b7 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -4,8 +4,7 @@ from uuid import uuid4 import pytest -from a2a.types import AgentCard, Part, Role, TextPart -from a2a.types import Message as A2AMessage +from a2a.types import AgentCard, Message, Part, Role, TextPart from strands.agent.a2a_agent import A2AAgent from strands.agent.agent_result import AgentResult @@ -38,23 +37,29 @@ def test_init_with_defaults(): assert agent.endpoint == "http://localhost:8000" assert agent.timeout == 300 assert agent._agent_card is None + assert agent.name is None + assert agent.description == "" + + +def test_init_with_name_and_description(): + """Test initialization with custom name and description.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="my-agent", description="My custom agent") + assert agent.name == "my-agent" + assert agent.description == "My custom agent" def test_init_with_custom_timeout(): """Test initialization with custom timeout.""" agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) assert agent.timeout == 600 - assert agent._httpx_client_args["timeout"] == 600 -def test_init_with_httpx_client_args(): - """Test initialization with custom httpx client arguments.""" - agent = A2AAgent( - endpoint="http://localhost:8000", - httpx_client_args={"headers": {"Authorization": "Bearer token"}}, - ) - assert "headers" in agent._httpx_client_args - assert agent._httpx_client_args["headers"]["Authorization"] == "Bearer token" +def test_init_with_external_a2a_client_factory(): + """Test initialization with external A2A client factory.""" + external_factory = MagicMock() + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + assert agent._a2a_client_factory is external_factory + assert not agent._owns_client @pytest.mark.asyncio @@ -81,10 +86,42 @@ async def test_get_agent_card_cached(a2a_agent, mock_agent_card): assert card == mock_agent_card +@pytest.mark.asyncio +async def test_get_agent_card_populates_name_and_description(mock_agent_card): + """Test that agent card populates name and description if not set.""" + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent._get_agent_card() + + assert agent.name == mock_agent_card.name + assert agent.description == mock_agent_card.description + + +@pytest.mark.asyncio +async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_card): + """Test that custom name and description are not overridden by agent card.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="custom-name", description="Custom description") + + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent._get_agent_card() + + assert agent.name == "custom-name" + assert agent.description == "Custom description" + + @pytest.mark.asyncio async def test_invoke_async_success(a2a_agent, mock_agent_card): """Test successful async invocation.""" - mock_response = A2AMessage( + mock_response = Message( message_id=uuid4().hex, role=Role.agent, parts=[Part(TextPart(kind="text", text="Response"))], @@ -155,12 +192,14 @@ def test_call_sync(a2a_agent): @pytest.mark.asyncio async def test_stream_async_success(a2a_agent, mock_agent_card): """Test successful async streaming.""" - mock_event1 = MagicMock() - mock_event2 = MagicMock() + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) async def mock_send_message(*args, **kwargs): - yield mock_event1 - yield mock_event2 + yield mock_response with patch.object(a2a_agent, "_get_agent_card", return_value=mock_agent_card): with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: @@ -175,8 +214,13 @@ async def mock_send_message(*args, **kwargs): events.append(event) assert len(events) == 2 - assert events[0]["a2a_event"] == mock_event1 - assert events[1]["a2a_event"] == mock_event2 + # First event is A2A stream event + assert events[0]["type"] == "a2a_stream" + assert events[0]["event"] == mock_response + # Final event is AgentResult + assert "result" in events[1] + assert isinstance(events[1]["result"], AgentResult) + assert events[1]["result"].message["content"][0]["text"] == "Response" @pytest.mark.asyncio