generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 483
feat(a2a): add A2AAgent class as an implementation of the agent interface for remote A2A protocol based agents #1174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
awsarron
wants to merge
5
commits into
strands-agents:main
Choose a base branch
from
awsarron:feat-a2a-agent
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
a7b4036
feat(a2a): add A2AAgent class as an implementation of the agent inter…
awsarron 5034f18
add integ test for a2a
awsarron 64520b7
move conversion functions to a utility module + rename _get_agent_car…
awsarron 90124cf
fix a2a integ test
awsarron 182f9a4
add A2AStreamEvent + handle partial responses in stream_async + modif…
awsarron File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,253 @@ | ||
| """A2A Agent client for Strands Agents. | ||
|
|
||
| This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, | ||
| allowing them to be used standalone or as part of multi-agent patterns. | ||
|
|
||
| A2AAgent can be used to get the Agent Card and interact with the agent. | ||
| """ | ||
|
|
||
| import logging | ||
| from typing import Any, AsyncIterator | ||
|
|
||
| import httpx | ||
| from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory | ||
| from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent | ||
|
|
||
| from .._async import run_async | ||
| from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result | ||
| from ..types._events import AgentResultEvent | ||
| from ..types.a2a import A2AResponse, A2AStreamEvent | ||
| from ..types.agent import AgentInput | ||
| from .agent_result import AgentResult | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| DEFAULT_TIMEOUT = 300 | ||
|
|
||
|
|
||
| class A2AAgent: | ||
| """Client wrapper for remote A2A agents.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| endpoint: str, | ||
| *, | ||
| name: str | None = None, | ||
| description: str = "", | ||
| timeout: int = DEFAULT_TIMEOUT, | ||
| a2a_client_factory: ClientFactory | None = None, | ||
| ): | ||
| """Initialize A2A agent. | ||
|
|
||
| Args: | ||
| endpoint: The base URL of the remote A2A agent. | ||
| name: Agent name. If not provided, will be populated from agent card. | ||
| description: Agent description. If empty, will be populated from agent card. | ||
| timeout: Timeout for HTTP operations in seconds (defaults to 300). | ||
| a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, | ||
| it will be used to create the A2A client after discovering the agent card. | ||
| """ | ||
| self.endpoint = endpoint | ||
| self.name = name | ||
| self.description = description | ||
| self.timeout = timeout | ||
| self._httpx_client: httpx.AsyncClient | None = None | ||
| self._owns_client = a2a_client_factory is None | ||
| self._agent_card: AgentCard | None = None | ||
| self._a2a_client: Client | None = None | ||
| self._a2a_client_factory: ClientFactory | None = a2a_client_factory | ||
|
|
||
| def _get_httpx_client(self) -> httpx.AsyncClient: | ||
| """Get or create the httpx client for this agent. | ||
|
|
||
| Returns: | ||
| Configured httpx.AsyncClient instance. | ||
| """ | ||
| if self._httpx_client is None: | ||
| self._httpx_client = httpx.AsyncClient(timeout=self.timeout) | ||
| return self._httpx_client | ||
|
|
||
| async def _get_agent_card(self) -> AgentCard: | ||
| """Discover and cache the agent card from the remote endpoint. | ||
|
|
||
| Returns: | ||
| The discovered AgentCard. | ||
| """ | ||
| if self._agent_card is not None: | ||
| return self._agent_card | ||
|
|
||
| httpx_client = self._get_httpx_client() | ||
| resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint) | ||
| self._agent_card = await resolver.get_agent_card() | ||
|
|
||
| # Populate name from card if not set | ||
| if self.name is None and self._agent_card.name: | ||
| self.name = self._agent_card.name | ||
|
|
||
| # Populate description from card if not set | ||
| if not self.description and self._agent_card.description: | ||
| self.description = self._agent_card.description | ||
|
|
||
| logger.info("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) | ||
| return self._agent_card | ||
|
|
||
| async def _get_a2a_client(self) -> Client: | ||
| """Get or create the A2A client for this agent. | ||
|
|
||
| Returns: | ||
| Configured A2A client instance. | ||
| """ | ||
| if self._a2a_client is None: | ||
| agent_card = await self._get_agent_card() | ||
|
|
||
| if self._a2a_client_factory is not None: | ||
| # Use provided factory | ||
| factory = self._a2a_client_factory | ||
| else: | ||
| # Create default factory | ||
| httpx_client = self._get_httpx_client() | ||
| config = ClientConfig(httpx_client=httpx_client, streaming=False) | ||
| factory = ClientFactory(config) | ||
|
|
||
| self._a2a_client = factory.create(agent_card) | ||
| return self._a2a_client | ||
|
|
||
| async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: | ||
| """Send message to A2A agent. | ||
|
|
||
| Args: | ||
| prompt: Input to send to the agent. | ||
|
|
||
| Returns: | ||
| Async iterator of A2A events. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| """ | ||
| if prompt is None: | ||
| raise ValueError("prompt is required for A2AAgent") | ||
|
|
||
| client = await self._get_a2a_client() | ||
| message = convert_input_to_message(prompt) | ||
|
|
||
| logger.info("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) | ||
| return client.send_message(message) | ||
|
|
||
| def _is_complete_event(self, event: A2AResponse) -> bool: | ||
| """Check if an A2A event represents a complete response. | ||
|
|
||
| Args: | ||
| event: A2A event. | ||
|
|
||
| Returns: | ||
| True if the event represents a complete response. | ||
| """ | ||
| # Direct Message is always complete | ||
| if isinstance(event, Message): | ||
| return True | ||
|
|
||
| # Handle tuple responses (Task, UpdateEvent | None) | ||
| if isinstance(event, tuple) and len(event) == 2: | ||
| task, update_event = event | ||
|
|
||
| # Initial task response (no update event) | ||
| if update_event is None: | ||
| return True | ||
|
|
||
| # Artifact update with last_chunk flag | ||
| if isinstance(update_event, TaskArtifactUpdateEvent): | ||
| if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None: | ||
| return update_event.last_chunk | ||
| return False | ||
|
|
||
| # Status update with completed state | ||
| if isinstance(update_event, TaskStatusUpdateEvent): | ||
| if update_event.status and hasattr(update_event.status, "state"): | ||
| return update_event.status.state == TaskState.completed | ||
|
|
||
| return False | ||
|
|
||
| async def invoke_async( | ||
| self, | ||
| prompt: AgentInput = None, | ||
| **kwargs: Any, | ||
| ) -> AgentResult: | ||
| """Asynchronously invoke the remote A2A agent. | ||
|
|
||
| Args: | ||
| prompt: Input to the agent (string, message list, or content blocks). | ||
| **kwargs: Additional arguments (ignored). | ||
|
|
||
| Returns: | ||
| AgentResult containing the agent's response. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| RuntimeError: If no response received from agent. | ||
| """ | ||
| async for event in await self._send_message(prompt): | ||
| return convert_response_to_agent_result(event) | ||
|
|
||
| raise RuntimeError("No response received from A2A agent") | ||
|
|
||
| def __call__( | ||
| self, | ||
| prompt: AgentInput = None, | ||
| **kwargs: Any, | ||
| ) -> AgentResult: | ||
| """Synchronously invoke the remote A2A agent. | ||
|
|
||
| Args: | ||
| prompt: Input to the agent (string, message list, or content blocks). | ||
| **kwargs: Additional arguments (ignored). | ||
|
|
||
| Returns: | ||
| AgentResult containing the agent's response. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| RuntimeError: If no response received from agent. | ||
| """ | ||
| return run_async(lambda: self.invoke_async(prompt, **kwargs)) | ||
|
|
||
| async def stream_async( | ||
| self, | ||
| prompt: AgentInput = None, | ||
| **kwargs: Any, | ||
| ) -> AsyncIterator[Any]: | ||
| """Stream agent execution asynchronously. | ||
|
|
||
| Args: | ||
| prompt: Input to the agent (string, message list, or content blocks). | ||
| **kwargs: Additional arguments (ignored). | ||
|
|
||
| Yields: | ||
| A2A events and a final AgentResult event. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| """ | ||
| last_event = None | ||
| last_complete_event = None | ||
|
|
||
| async for event in await self._send_message(prompt): | ||
| last_event = event | ||
| if self._is_complete_event(event): | ||
| last_complete_event = event | ||
| yield A2AStreamEvent(event) | ||
|
|
||
| # Use the last complete event if available, otherwise fall back to last event | ||
| final_event = last_complete_event if last_complete_event is not None else last_event | ||
|
|
||
| if final_event is not None: | ||
| result = convert_response_to_agent_result(final_event) | ||
| yield AgentResultEvent(result) | ||
|
|
||
| def __del__(self) -> None: | ||
| """Clean up resources when agent is garbage collected.""" | ||
| if self._owns_client and self._httpx_client is not None: | ||
| try: | ||
| client = self._httpx_client | ||
| run_async(lambda: client.aclose()) | ||
| except Exception: | ||
| pass # Best effort cleanup, ignore errors in __del__ | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| """Conversion functions between Strands and A2A types.""" | ||
|
|
||
| from typing import cast | ||
| from uuid import uuid4 | ||
|
|
||
| from a2a.types import Message as A2AMessage | ||
| from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart | ||
|
|
||
| from ...agent.agent_result import AgentResult | ||
| from ...telemetry.metrics import EventLoopMetrics | ||
| from ...types.a2a import A2AResponse | ||
| from ...types.agent import AgentInput | ||
| from ...types.content import ContentBlock, Message | ||
|
|
||
|
|
||
| def convert_input_to_message(prompt: AgentInput) -> A2AMessage: | ||
| """Convert AgentInput to A2A Message. | ||
|
|
||
| Args: | ||
| prompt: Input in various formats (string, message list, or content blocks). | ||
|
|
||
| Returns: | ||
| A2AMessage ready to send to the remote agent. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt format is unsupported. | ||
| """ | ||
| message_id = uuid4().hex | ||
|
|
||
| if isinstance(prompt, str): | ||
| return A2AMessage( | ||
| kind="message", | ||
| role=Role.user, | ||
| parts=[Part(TextPart(kind="text", text=prompt))], | ||
| message_id=message_id, | ||
| ) | ||
|
|
||
| if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): | ||
| if "role" in prompt[0]: | ||
| for msg in reversed(prompt): | ||
| if msg.get("role") == "user": | ||
| content = cast(list[ContentBlock], msg.get("content", [])) | ||
| parts = convert_content_blocks_to_parts(content) | ||
| return A2AMessage( | ||
| kind="message", | ||
| role=Role.user, | ||
| parts=parts, | ||
| message_id=message_id, | ||
| ) | ||
| else: | ||
| parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) | ||
| return A2AMessage( | ||
| kind="message", | ||
| role=Role.user, | ||
| parts=parts, | ||
| message_id=message_id, | ||
| ) | ||
|
|
||
| raise ValueError(f"Unsupported input type: {type(prompt)}") | ||
|
|
||
|
|
||
| def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]: | ||
| """Convert Strands ContentBlocks to A2A Parts. | ||
|
|
||
| Args: | ||
| content_blocks: List of Strands content blocks. | ||
|
|
||
| Returns: | ||
| List of A2A Part objects. | ||
| """ | ||
| parts = [] | ||
| for block in content_blocks: | ||
| if "text" in block: | ||
| parts.append(Part(TextPart(kind="text", text=block["text"]))) | ||
| return parts | ||
|
|
||
|
|
||
| def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: | ||
| """Convert A2A response to AgentResult. | ||
|
|
||
| Args: | ||
| response: A2A response (either A2AMessage or tuple of task and update event). | ||
|
|
||
| Returns: | ||
| AgentResult with extracted content and metadata. | ||
| """ | ||
| content: list[ContentBlock] = [] | ||
|
|
||
| if isinstance(response, tuple) and len(response) == 2: | ||
| task, update_event = response | ||
|
|
||
| # Handle artifact updates | ||
| if isinstance(update_event, TaskArtifactUpdateEvent): | ||
| if update_event.artifact and hasattr(update_event.artifact, "parts"): | ||
| for part in update_event.artifact.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
| # Handle status updates with messages | ||
| elif isinstance(update_event, TaskStatusUpdateEvent): | ||
| if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: | ||
| for part in update_event.status.message.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
| # Handle initial task or task without update event | ||
| elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: | ||
| for artifact in task.artifacts: | ||
| if hasattr(artifact, "parts"): | ||
| for part in artifact.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
| elif isinstance(response, A2AMessage): | ||
| for part in response.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
|
|
||
| message: Message = { | ||
| "role": "assistant", | ||
| "content": content, | ||
| } | ||
|
|
||
| return AgentResult( | ||
| stop_reason="end_turn", | ||
| message=message, | ||
| metrics=EventLoopMetrics(), | ||
| state={}, | ||
| ) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.