From c4e25207af01a39eeb1dbe21a7d874eb8046d602 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 21 Oct 2025 08:39:38 -0400 Subject: [PATCH] mcp elicitation --- src/strands/tools/mcp/mcp_client.py | 15 ++++++--- tests_integ/mcp/elicitation_server.py | 41 +++++++++++++++++++++++++ tests_integ/mcp/test_mcp_elicitation.py | 40 ++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 tests_integ/mcp/elicitation_server.py create mode 100644 tests_integ/mcp/test_mcp_elicitation.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 61f3d9185..2fe006466 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -20,6 +20,7 @@ import anyio from mcp import ClientSession, ListToolsResult +from mcp.client.session import ElicitationFnT from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents from mcp.types import CallToolResult as MCPCallToolResult from mcp.types import EmbeddedResource as MCPEmbeddedResource @@ -98,19 +99,22 @@ def __init__( startup_timeout: int = 30, tool_filters: ToolFilters | None = None, prefix: str | None = None, - ): + elicitation_callback: Optional[ElicitationFnT] = None, + ) -> None: """Initialize a new MCP Server connection. Args: - transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple - startup_timeout: Timeout after which MCP server initialization should be cancelled + transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple. + startup_timeout: Timeout after which MCP server initialization should be cancelled. Defaults to 30. tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. + elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters self._prefix = prefix + self._elicitation_callback = elicitation_callback mcp_instrumentation() self._session_id = uuid.uuid4() @@ -563,7 +567,10 @@ async def _async_background_thread(self) -> None: async with self._transport_callable() as (read_stream, write_stream, *_): self._log_debug_with_thread("transport connection established") async with ClientSession( - read_stream, write_stream, message_handler=self._handle_error_message + read_stream, + write_stream, + message_handler=self._handle_error_message, + elicitation_callback=self._elicitation_callback, ) as session: self._log_debug_with_thread("initializing MCP session") await session.initialize() diff --git a/tests_integ/mcp/elicitation_server.py b/tests_integ/mcp/elicitation_server.py new file mode 100644 index 000000000..337f29fa1 --- /dev/null +++ b/tests_integ/mcp/elicitation_server.py @@ -0,0 +1,41 @@ +"""MCP server for testing elicitation. + +- Docs: https://modelcontextprotocol.io/specification/draft/client/elicitation +""" + +from mcp.server import FastMCP +from mcp.types import ElicitRequest, ElicitRequestParams, ElicitResult + + +def server() -> None: + """Simulate approval through MCP elicitation.""" + server_ = FastMCP() + + @server_.tool(description="Tool to request approval") + async def approval_tool() -> str: + """Simulated approval tool. + + Returns: + The elicitation result from the user. + """ + request = ElicitRequest( + params=ElicitRequestParams( + message="Do you approve", + requestedSchema={ + "type": "object", + "properties": { + "message": {"type": "string", "description": "request message"}, + }, + "required": ["message"], + }, + ), + ) + result = await server_.get_context().session.send_request(request, ElicitResult) + + return result.model_dump_json() + + server_.run(transport="stdio") + + +if __name__ == "__main__": + server() diff --git a/tests_integ/mcp/test_mcp_elicitation.py b/tests_integ/mcp/test_mcp_elicitation.py new file mode 100644 index 000000000..4e5a224c1 --- /dev/null +++ b/tests_integ/mcp/test_mcp_elicitation.py @@ -0,0 +1,40 @@ +import json + +import pytest +from mcp import StdioServerParameters, stdio_client +from mcp.types import ElicitResult + +from strands import Agent +from strands.tools.mcp import MCPClient + + +@pytest.fixture +def callback(): + async def callback_(_, params): + return ElicitResult(action="accept", content={"message": params.message}) + + return callback_ + + +@pytest.fixture +def client(callback): + return MCPClient( + lambda: stdio_client( + StdioServerParameters(command="python", args=["tests_integ/mcp/elicitation_server.py"]), + ), + elicitation_callback=callback, + ) + + +def test_mcp_elicitation(client): + with client: + tools = client.list_tools_sync() + agent = Agent(tools=tools) + + agent("Can you get approval") + + tool_result = agent.messages[-2] + + tru_result = json.loads(tool_result["content"][0]["toolResult"]["content"][0]["text"]) + exp_result = {"meta": None, "action": "accept", "content": {"message": "Do you approve"}} + assert tru_result == exp_result