Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions src/strands/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
"""This package provides the core Agent interface and supporting components for building AI agents with the SDK.

It includes:

- Agent: The main interface for interacting with AI models and tools
- ConversationManager: Classes for managing conversation history and context windows
"""
"""This package provides the core Agent interface and supporting components for building AI agents with the SDK."""

from .agent import Agent
from .agent_result import AgentResult
Expand Down
253 changes: 253 additions & 0 deletions src/strands/agent/a2a_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
"""A2A Agent client for Strands Agents.

This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents,
allowing them to be used standalone or as part of multi-agent patterns.

A2AAgent can be used to get the Agent Card and interact with the agent.
"""

import logging
from typing import Any, AsyncIterator

import httpx
from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory
from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent

from .._async import run_async
from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result
from ..types._events import AgentResultEvent
from ..types.a2a import A2AResponse, A2AStreamEvent
from ..types.agent import AgentInput
from .agent_result import AgentResult

logger = logging.getLogger(__name__)

DEFAULT_TIMEOUT = 300


class A2AAgent:
"""Client wrapper for remote A2A agents."""

def __init__(
self,
endpoint: str,
*,
name: str | None = None,
description: str = "",
timeout: int = DEFAULT_TIMEOUT,
a2a_client_factory: ClientFactory | None = None,
):
"""Initialize A2A agent.

Args:
endpoint: The base URL of the remote A2A agent.
name: Agent name. If not provided, will be populated from agent card.
description: Agent description. If empty, will be populated from agent card.
timeout: Timeout for HTTP operations in seconds (defaults to 300).
a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided,
it will be used to create the A2A client after discovering the agent card.
"""
self.endpoint = endpoint
self.name = name
self.description = description
self.timeout = timeout
self._httpx_client: httpx.AsyncClient | None = None
self._owns_client = a2a_client_factory is None
self._agent_card: AgentCard | None = None
self._a2a_client: Client | None = None
self._a2a_client_factory: ClientFactory | None = a2a_client_factory

def _get_httpx_client(self) -> httpx.AsyncClient:
"""Get or create the httpx client for this agent.

Returns:
Configured httpx.AsyncClient instance.
"""
if self._httpx_client is None:
self._httpx_client = httpx.AsyncClient(timeout=self.timeout)
return self._httpx_client

async def _get_agent_card(self) -> AgentCard:
"""Discover and cache the agent card from the remote endpoint.

Returns:
The discovered AgentCard.
"""
if self._agent_card is not None:
return self._agent_card

httpx_client = self._get_httpx_client()
resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint)
self._agent_card = await resolver.get_agent_card()

# Populate name from card if not set
if self.name is None and self._agent_card.name:
self.name = self._agent_card.name

# Populate description from card if not set
if not self.description and self._agent_card.description:
self.description = self._agent_card.description

logger.info("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint)
return self._agent_card

async def _get_a2a_client(self) -> Client:
"""Get or create the A2A client for this agent.

Returns:
Configured A2A client instance.
"""
if self._a2a_client is None:
agent_card = await self._get_agent_card()

if self._a2a_client_factory is not None:
# Use provided factory
factory = self._a2a_client_factory
else:
# Create default factory
httpx_client = self._get_httpx_client()
config = ClientConfig(httpx_client=httpx_client, streaming=False)
factory = ClientFactory(config)

self._a2a_client = factory.create(agent_card)
return self._a2a_client

async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]:
"""Send message to A2A agent.

Args:
prompt: Input to send to the agent.

Returns:
Async iterator of A2A events.

Raises:
ValueError: If prompt is None.
"""
if prompt is None:
raise ValueError("prompt is required for A2AAgent")

client = await self._get_a2a_client()
message = convert_input_to_message(prompt)

logger.info("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint)
return client.send_message(message)

def _is_complete_event(self, event: A2AResponse) -> bool:
"""Check if an A2A event represents a complete response.

Args:
event: A2A event.

Returns:
True if the event represents a complete response.
"""
# Direct Message is always complete
if isinstance(event, Message):
return True

# Handle tuple responses (Task, UpdateEvent | None)
if isinstance(event, tuple) and len(event) == 2:
task, update_event = event

# Initial task response (no update event)
if update_event is None:
return True

# Artifact update with last_chunk flag
if isinstance(update_event, TaskArtifactUpdateEvent):
if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None:
return update_event.last_chunk
return False

# Status update with completed state
if isinstance(update_event, TaskStatusUpdateEvent):
if update_event.status and hasattr(update_event.status, "state"):
return update_event.status.state == TaskState.completed

return False

async def invoke_async(
self,
prompt: AgentInput = None,
**kwargs: Any,
) -> AgentResult:
"""Asynchronously invoke the remote A2A agent.

Args:
prompt: Input to the agent (string, message list, or content blocks).
**kwargs: Additional arguments (ignored).

Returns:
AgentResult containing the agent's response.

Raises:
ValueError: If prompt is None.
RuntimeError: If no response received from agent.
"""
async for event in await self._send_message(prompt):
return convert_response_to_agent_result(event)

raise RuntimeError("No response received from A2A agent")

def __call__(
self,
prompt: AgentInput = None,
**kwargs: Any,
) -> AgentResult:
"""Synchronously invoke the remote A2A agent.

Args:
prompt: Input to the agent (string, message list, or content blocks).
**kwargs: Additional arguments (ignored).

Returns:
AgentResult containing the agent's response.

Raises:
ValueError: If prompt is None.
RuntimeError: If no response received from agent.
"""
return run_async(lambda: self.invoke_async(prompt, **kwargs))

async def stream_async(
self,
prompt: AgentInput = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Stream agent execution asynchronously.

Args:
prompt: Input to the agent (string, message list, or content blocks).
**kwargs: Additional arguments (ignored).

Yields:
A2A events and a final AgentResult event.

Raises:
ValueError: If prompt is None.
"""
last_event = None
last_complete_event = None

async for event in await self._send_message(prompt):
last_event = event
if self._is_complete_event(event):
last_complete_event = event
yield A2AStreamEvent(event)

# Use the last complete event if available, otherwise fall back to last event
final_event = last_complete_event if last_complete_event is not None else last_event

if final_event is not None:
result = convert_response_to_agent_result(final_event)
yield AgentResultEvent(result)

def __del__(self) -> None:
"""Clean up resources when agent is garbage collected."""
if self._owns_client and self._httpx_client is not None:
try:
client = self._httpx_client
run_async(lambda: client.aclose())
except Exception:
pass # Best effort cleanup, ignore errors in __del__
126 changes: 126 additions & 0 deletions src/strands/multiagent/a2a/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Conversion functions between Strands and A2A types."""

from typing import cast
from uuid import uuid4

from a2a.types import Message as A2AMessage
from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart

from ...agent.agent_result import AgentResult
from ...telemetry.metrics import EventLoopMetrics
from ...types.a2a import A2AResponse
from ...types.agent import AgentInput
from ...types.content import ContentBlock, Message


def convert_input_to_message(prompt: AgentInput) -> A2AMessage:
"""Convert AgentInput to A2A Message.

Args:
prompt: Input in various formats (string, message list, or content blocks).

Returns:
A2AMessage ready to send to the remote agent.

Raises:
ValueError: If prompt format is unsupported.
"""
message_id = uuid4().hex

if isinstance(prompt, str):
return A2AMessage(
kind="message",
role=Role.user,
parts=[Part(TextPart(kind="text", text=prompt))],
message_id=message_id,
)

if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)):
if "role" in prompt[0]:
for msg in reversed(prompt):
if msg.get("role") == "user":
content = cast(list[ContentBlock], msg.get("content", []))
parts = convert_content_blocks_to_parts(content)
return A2AMessage(
kind="message",
role=Role.user,
parts=parts,
message_id=message_id,
)
else:
parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt))
return A2AMessage(
kind="message",
role=Role.user,
parts=parts,
message_id=message_id,
)

raise ValueError(f"Unsupported input type: {type(prompt)}")


def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]:
"""Convert Strands ContentBlocks to A2A Parts.

Args:
content_blocks: List of Strands content blocks.

Returns:
List of A2A Part objects.
"""
parts = []
for block in content_blocks:
if "text" in block:
parts.append(Part(TextPart(kind="text", text=block["text"])))
return parts


def convert_response_to_agent_result(response: A2AResponse) -> AgentResult:
"""Convert A2A response to AgentResult.

Args:
response: A2A response (either A2AMessage or tuple of task and update event).

Returns:
AgentResult with extracted content and metadata.
"""
content: list[ContentBlock] = []

if isinstance(response, tuple) and len(response) == 2:
task, update_event = response

# Handle artifact updates
if isinstance(update_event, TaskArtifactUpdateEvent):
if update_event.artifact and hasattr(update_event.artifact, "parts"):
for part in update_event.artifact.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
# Handle status updates with messages
elif isinstance(update_event, TaskStatusUpdateEvent):
if update_event.status and hasattr(update_event.status, "message") and update_event.status.message:
for part in update_event.status.message.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
# Handle initial task or task without update event
elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None:
for artifact in task.artifacts:
if hasattr(artifact, "parts"):
for part in artifact.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
elif isinstance(response, A2AMessage):
for part in response.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})

message: Message = {
"role": "assistant",
"content": content,
}

return AgentResult(
stop_reason="end_turn",
message=message,
metrics=EventLoopMetrics(),
state={},
)
Loading
Loading