diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..8497a67d2 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -1,10 +1,4 @@ -"""This package provides the core Agent interface and supporting components for building AI agents with the SDK. - -It includes: - -- Agent: The main interface for interacting with AI models and tools -- ConversationManager: Classes for managing conversation history and context windows -""" +"""This package provides the core Agent interface and supporting components for building AI agents with the SDK.""" from .agent import Agent from .agent_result import AgentResult diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py new file mode 100644 index 000000000..1f2e015a6 --- /dev/null +++ b/src/strands/agent/a2a_agent.py @@ -0,0 +1,253 @@ +"""A2A Agent client for Strands Agents. + +This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, +allowing them to be used standalone or as part of multi-agent patterns. + +A2AAgent can be used to get the Agent Card and interact with the agent. +""" + +import logging +from typing import Any, AsyncIterator + +import httpx +from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent + +from .._async import run_async +from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result +from ..types._events import AgentResultEvent +from ..types.a2a import A2AResponse, A2AStreamEvent +from ..types.agent import AgentInput +from .agent_result import AgentResult + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 300 + + +class A2AAgent: + """Client wrapper for remote A2A agents.""" + + def __init__( + self, + endpoint: str, + *, + name: str | None = None, + description: str = "", + timeout: int = DEFAULT_TIMEOUT, + a2a_client_factory: ClientFactory | None = None, + ): + """Initialize A2A agent. + + Args: + endpoint: The base URL of the remote A2A agent. + name: Agent name. If not provided, will be populated from agent card. + description: Agent description. If empty, will be populated from agent card. + timeout: Timeout for HTTP operations in seconds (defaults to 300). + a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, + it will be used to create the A2A client after discovering the agent card. + """ + self.endpoint = endpoint + self.name = name + self.description = description + self.timeout = timeout + self._httpx_client: httpx.AsyncClient | None = None + self._owns_client = a2a_client_factory is None + self._agent_card: AgentCard | None = None + self._a2a_client: Client | None = None + self._a2a_client_factory: ClientFactory | None = a2a_client_factory + + def _get_httpx_client(self) -> httpx.AsyncClient: + """Get or create the httpx client for this agent. + + Returns: + Configured httpx.AsyncClient instance. + """ + if self._httpx_client is None: + self._httpx_client = httpx.AsyncClient(timeout=self.timeout) + return self._httpx_client + + async def _get_agent_card(self) -> AgentCard: + """Discover and cache the agent card from the remote endpoint. + + Returns: + The discovered AgentCard. + """ + if self._agent_card is not None: + return self._agent_card + + httpx_client = self._get_httpx_client() + resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() + + # Populate name from card if not set + if self.name is None and self._agent_card.name: + self.name = self._agent_card.name + + # Populate description from card if not set + if not self.description and self._agent_card.description: + self.description = self._agent_card.description + + logger.info("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) + return self._agent_card + + async def _get_a2a_client(self) -> Client: + """Get or create the A2A client for this agent. + + Returns: + Configured A2A client instance. + """ + if self._a2a_client is None: + agent_card = await self._get_agent_card() + + if self._a2a_client_factory is not None: + # Use provided factory + factory = self._a2a_client_factory + else: + # Create default factory + httpx_client = self._get_httpx_client() + config = ClientConfig(httpx_client=httpx_client, streaming=False) + factory = ClientFactory(config) + + self._a2a_client = factory.create(agent_card) + return self._a2a_client + + async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: + """Send message to A2A agent. + + Args: + prompt: Input to send to the agent. + + Returns: + Async iterator of A2A events. + + Raises: + ValueError: If prompt is None. + """ + if prompt is None: + raise ValueError("prompt is required for A2AAgent") + + client = await self._get_a2a_client() + message = convert_input_to_message(prompt) + + logger.info("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) + return client.send_message(message) + + def _is_complete_event(self, event: A2AResponse) -> bool: + """Check if an A2A event represents a complete response. + + Args: + event: A2A event. + + Returns: + True if the event represents a complete response. + """ + # Direct Message is always complete + if isinstance(event, Message): + return True + + # Handle tuple responses (Task, UpdateEvent | None) + if isinstance(event, tuple) and len(event) == 2: + task, update_event = event + + # Initial task response (no update event) + if update_event is None: + return True + + # Artifact update with last_chunk flag + if isinstance(update_event, TaskArtifactUpdateEvent): + if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None: + return update_event.last_chunk + return False + + # Status update with completed state + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state == TaskState.completed + + return False + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + async for event in await self._send_message(prompt): + return convert_response_to_agent_result(event) + + raise RuntimeError("No response received from A2A agent") + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + return run_async(lambda: self.invoke_async(prompt, **kwargs)) + + async def stream_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream agent execution asynchronously. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Yields: + A2A events and a final AgentResult event. + + Raises: + ValueError: If prompt is None. + """ + last_event = None + last_complete_event = None + + async for event in await self._send_message(prompt): + last_event = event + if self._is_complete_event(event): + last_complete_event = event + yield A2AStreamEvent(event) + + # Use the last complete event if available, otherwise fall back to last event + final_event = last_complete_event if last_complete_event is not None else last_event + + if final_event is not None: + result = convert_response_to_agent_result(final_event) + yield AgentResultEvent(result) + + def __del__(self) -> None: + """Clean up resources when agent is garbage collected.""" + if self._owns_client and self._httpx_client is not None: + try: + client = self._httpx_client + run_async(lambda: client.aclose()) + except Exception: + pass # Best effort cleanup, ignore errors in __del__ diff --git a/src/strands/multiagent/a2a/converters.py b/src/strands/multiagent/a2a/converters.py new file mode 100644 index 000000000..efd04cddc --- /dev/null +++ b/src/strands/multiagent/a2a/converters.py @@ -0,0 +1,126 @@ +"""Conversion functions between Strands and A2A types.""" + +from typing import cast +from uuid import uuid4 + +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart + +from ...agent.agent_result import AgentResult +from ...telemetry.metrics import EventLoopMetrics +from ...types.a2a import A2AResponse +from ...types.agent import AgentInput +from ...types.content import ContentBlock, Message + + +def convert_input_to_message(prompt: AgentInput) -> A2AMessage: + """Convert AgentInput to A2A Message. + + Args: + prompt: Input in various formats (string, message list, or content blocks). + + Returns: + A2AMessage ready to send to the remote agent. + + Raises: + ValueError: If prompt format is unsupported. + """ + message_id = uuid4().hex + + if isinstance(prompt, str): + return A2AMessage( + kind="message", + role=Role.user, + parts=[Part(TextPart(kind="text", text=prompt))], + message_id=message_id, + ) + + if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): + if "role" in prompt[0]: + for msg in reversed(prompt): + if msg.get("role") == "user": + content = cast(list[ContentBlock], msg.get("content", [])) + parts = convert_content_blocks_to_parts(content) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + else: + parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + + raise ValueError(f"Unsupported input type: {type(prompt)}") + + +def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]: + """Convert Strands ContentBlocks to A2A Parts. + + Args: + content_blocks: List of Strands content blocks. + + Returns: + List of A2A Part objects. + """ + parts = [] + for block in content_blocks: + if "text" in block: + parts.append(Part(TextPart(kind="text", text=block["text"]))) + return parts + + +def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: + """Convert A2A response to AgentResult. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + AgentResult with extracted content and metadata. + """ + content: list[ContentBlock] = [] + + if isinstance(response, tuple) and len(response) == 2: + task, update_event = response + + # Handle artifact updates + if isinstance(update_event, TaskArtifactUpdateEvent): + if update_event.artifact and hasattr(update_event.artifact, "parts"): + for part in update_event.artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle status updates with messages + elif isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: + for part in update_event.status.message.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle initial task or task without update event + elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: + for artifact in task.artifacts: + if hasattr(artifact, "parts"): + for part in artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + elif isinstance(response, A2AMessage): + for part in response.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + + message: Message = { + "role": "assistant", + "content": content, + } + + return AgentResult( + stop_reason="end_turn", + message=message, + metrics=EventLoopMetrics(), + state={}, + ) diff --git a/src/strands/types/a2a.py b/src/strands/types/a2a.py new file mode 100644 index 000000000..3a7bb091b --- /dev/null +++ b/src/strands/types/a2a.py @@ -0,0 +1,26 @@ +"""Additional A2A types.""" + +from typing import TypeAlias + +from a2a.types import Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent + +from ._events import TypedEvent + +A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message + + +class A2AStreamEvent(TypedEvent): + """Event that wraps streamed A2A types.""" + + def __init__(self, a2a_event: A2AResponse) -> None: + """Initialize with A2A event. + + Args: + a2a_event: The original A2A event (Task tuple or Message) + """ + super().__init__( + { + "type": "a2a_stream", + "event": a2a_event, # Nest A2A event to avoid field conflicts + } + ) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py new file mode 100644 index 000000000..d595787b7 --- /dev/null +++ b/tests/strands/agent/test_a2a_agent.py @@ -0,0 +1,231 @@ +"""Tests for A2AAgent class.""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from a2a.types import AgentCard, Message, Part, Role, TextPart + +from strands.agent.a2a_agent import A2AAgent +from strands.agent.agent_result import AgentResult + + +@pytest.fixture +def mock_agent_card(): + """Mock AgentCard for testing.""" + return AgentCard( + name="test-agent", + description="Test agent", + url="http://localhost:8000", + version="1.0.0", + capabilities={}, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + skills=[], + ) + + +@pytest.fixture +def a2a_agent(): + """Create A2AAgent instance for testing.""" + return A2AAgent(endpoint="http://localhost:8000") + + +def test_init_with_defaults(): + """Test initialization with default parameters.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert agent.endpoint == "http://localhost:8000" + assert agent.timeout == 300 + assert agent._agent_card is None + assert agent.name is None + assert agent.description == "" + + +def test_init_with_name_and_description(): + """Test initialization with custom name and description.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="my-agent", description="My custom agent") + assert agent.name == "my-agent" + assert agent.description == "My custom agent" + + +def test_init_with_custom_timeout(): + """Test initialization with custom timeout.""" + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + assert agent.timeout == 600 + + +def test_init_with_external_a2a_client_factory(): + """Test initialization with external A2A client factory.""" + external_factory = MagicMock() + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + assert agent._a2a_client_factory is external_factory + assert not agent._owns_client + + +@pytest.mark.asyncio +async def test_get_agent_card(a2a_agent, mock_agent_card): + """Test agent card discovery.""" + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + card = await a2a_agent._get_agent_card() + + assert card == mock_agent_card + assert a2a_agent._agent_card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_cached(a2a_agent, mock_agent_card): + """Test that agent card is cached after first discovery.""" + a2a_agent._agent_card = mock_agent_card + + card = await a2a_agent._get_agent_card() + + assert card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_populates_name_and_description(mock_agent_card): + """Test that agent card populates name and description if not set.""" + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent._get_agent_card() + + assert agent.name == mock_agent_card.name + assert agent.description == mock_agent_card.description + + +@pytest.mark.asyncio +async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_card): + """Test that custom name and description are not overridden by agent card.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="custom-name", description="Custom description") + + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent._get_agent_card() + + assert agent.name == "custom-name" + assert agent.description == "Custom description" + + +@pytest.mark.asyncio +async def test_invoke_async_success(a2a_agent, mock_agent_card): + """Test successful async invocation.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "_get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_client = AsyncMock() + mock_client.send_message = mock_send_message + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + + result = await a2a_agent.invoke_async("Hello") + + assert isinstance(result, AgentResult) + assert result.message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_invoke_async_no_prompt(a2a_agent): + """Test that invoke_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + await a2a_agent.invoke_async(None) + + +@pytest.mark.asyncio +async def test_invoke_async_no_response(a2a_agent, mock_agent_card): + """Test that invoke_async raises RuntimeError when no response received.""" + + async def mock_send_message(*args, **kwargs): + return + yield # Make it an async generator + + with patch.object(a2a_agent, "_get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_client = AsyncMock() + mock_client.send_message = mock_send_message + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + + with pytest.raises(RuntimeError, match="No response received"): + await a2a_agent.invoke_async("Hello") + + +def test_call_sync(a2a_agent): + """Test synchronous call method.""" + mock_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=MagicMock(), + state={}, + ) + + with patch("strands.agent.a2a_agent.run_async") as mock_run_async: + mock_run_async.return_value = mock_result + + result = a2a_agent("Hello") + + assert result == mock_result + mock_run_async.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_async_success(a2a_agent, mock_agent_card): + """Test successful async streaming.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "_get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_client = AsyncMock() + mock_client.send_message = mock_send_message + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + assert len(events) == 2 + # First event is A2A stream event + assert events[0]["type"] == "a2a_stream" + assert events[0]["event"] == mock_response + # Final event is AgentResult + assert "result" in events[1] + assert isinstance(events[1]["result"], AgentResult) + assert events[1]["result"].message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_stream_async_no_prompt(a2a_agent): + """Test that stream_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + async for _ in a2a_agent.stream_async(None): + pass diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py new file mode 100644 index 000000000..0d443e4ed --- /dev/null +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -0,0 +1,115 @@ +"""Tests for A2A converter functions.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TextPart + +from strands.agent.agent_result import AgentResult +from strands.multiagent.a2a.converters import ( + convert_content_blocks_to_parts, + convert_input_to_message, + convert_response_to_agent_result, +) + + +def test_convert_string_input(): + """Test converting string input to A2A message.""" + message = convert_input_to_message("Hello") + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + assert message.parts[0].root.text == "Hello" + + +def test_convert_message_list_input(): + """Test converting message list input to A2A message.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + message = convert_input_to_message(messages) + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + + +def test_convert_content_blocks_input(): + """Test converting content blocks input to A2A message.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + message = convert_input_to_message(content_blocks) + + assert isinstance(message, A2AMessage) + assert len(message.parts) == 2 + + +def test_convert_unsupported_input(): + """Test that unsupported input types raise ValueError.""" + with pytest.raises(ValueError, match="Unsupported input type"): + convert_input_to_message(123) + + +def test_convert_content_blocks_to_parts(): + """Test converting content blocks to A2A parts.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + assert parts[0].root.text == "Hello" + assert parts[1].root.text == "World" + + +def test_convert_a2a_message_response(): + """Test converting A2A message response to AgentResult.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert isinstance(result, AgentResult) + assert result.message["role"] == "assistant" + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Response" + + +def test_convert_task_response(): + """Test converting task response to AgentResult.""" + mock_task = MagicMock() + mock_artifact = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Task response" + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + result = convert_response_to_agent_result((mock_task, None)) + + assert isinstance(result, AgentResult) + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Task response" + + +def test_convert_multiple_parts_response(): + """Test converting response with multiple parts to separate content blocks.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[ + Part(TextPart(kind="text", text="First")), + Part(TextPart(kind="text", text="Second")), + ], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert len(result.message["content"]) == 2 + assert result.message["content"][0]["text"] == "First" + assert result.message["content"][1]["text"] == "Second" diff --git a/tests_integ/a2a/__init__.py b/tests_integ/a2a/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/a2a/a2a_server.py b/tests_integ/a2a/a2a_server.py new file mode 100644 index 000000000..047edc3ba --- /dev/null +++ b/tests_integ/a2a/a2a_server.py @@ -0,0 +1,15 @@ +from strands import Agent +from strands.multiagent.a2a import A2AServer + +# Create an agent and serve it over A2A +agent = Agent( + name="Test agent", + description="Test description here", + callback_handler=None, +) +a2a_server = A2AServer( + agent=agent, + host="localhost", + port=9000, +) +a2a_server.serve() diff --git a/tests_integ/a2a/test_multiagent_a2a.py b/tests_integ/a2a/test_multiagent_a2a.py new file mode 100644 index 000000000..5d8d13fc5 --- /dev/null +++ b/tests_integ/a2a/test_multiagent_a2a.py @@ -0,0 +1,27 @@ +import atexit +import os +import subprocess +import time + +from strands.agent.a2a_agent import A2AAgent + + +def test_a2a_agent(): + # Start our A2A server + server_path = os.path.join(os.path.dirname(__file__), "a2a_server.py") + server = subprocess.Popen(["python", server_path]) + + def cleanup(): + server.terminate() + + atexit.register(cleanup) + time.sleep(5) # Wait for A2A server to start + + # Connect to our A2A server + a2a_agent = A2AAgent(endpoint="http://localhost:9000") + + # Invoke our A2A server + result = a2a_agent("Hello there!") + + # Ensure that it was successful + assert result.stop_reason == "end_turn"