Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import anyio
from mcp import ClientSession, ListToolsResult
from mcp.client.session import ElicitationFnT
from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import EmbeddedResource as MCPEmbeddedResource
Expand Down Expand Up @@ -98,19 +99,22 @@ def __init__(
startup_timeout: int = 30,
tool_filters: ToolFilters | None = None,
prefix: str | None = None,
):
elicitation_callback: Optional[ElicitationFnT] = None,
) -> None:
"""Initialize a new MCP Server connection.
Args:
transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple
startup_timeout: Timeout after which MCP server initialization should be cancelled
transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple.
startup_timeout: Timeout after which MCP server initialization should be cancelled.
Defaults to 30.
tool_filters: Optional filters to apply to tools.
prefix: Optional prefix for tool names.
elicitation_callback: Optional callback function to handle elicitation requests from the MCP server.
"""
self._startup_timeout = startup_timeout
self._tool_filters = tool_filters
self._prefix = prefix
self._elicitation_callback = elicitation_callback

mcp_instrumentation()
self._session_id = uuid.uuid4()
Expand Down Expand Up @@ -563,7 +567,10 @@ async def _async_background_thread(self) -> None:
async with self._transport_callable() as (read_stream, write_stream, *_):
self._log_debug_with_thread("transport connection established")
async with ClientSession(
read_stream, write_stream, message_handler=self._handle_error_message
read_stream,
write_stream,
message_handler=self._handle_error_message,
elicitation_callback=self._elicitation_callback,
) as session:
self._log_debug_with_thread("initializing MCP session")
await session.initialize()
Expand Down
41 changes: 41 additions & 0 deletions tests_integ/mcp/elicitation_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""MCP server for testing elicitation.

- Docs: https://modelcontextprotocol.io/specification/draft/client/elicitation
"""

from mcp.server import FastMCP
from mcp.types import ElicitRequest, ElicitRequestParams, ElicitResult


def server() -> None:
"""Simulate approval through MCP elicitation."""
server_ = FastMCP()

@server_.tool(description="Tool to request approval")
async def approval_tool() -> str:
"""Simulated approval tool.

Returns:
The elicitation result from the user.
"""
request = ElicitRequest(
params=ElicitRequestParams(
message="Do you approve",
requestedSchema={
"type": "object",
"properties": {
"message": {"type": "string", "description": "request message"},
},
"required": ["message"],
},
),
)
result = await server_.get_context().session.send_request(request, ElicitResult)

return result.model_dump_json()

server_.run(transport="stdio")


if __name__ == "__main__":
server()
40 changes: 40 additions & 0 deletions tests_integ/mcp/test_mcp_elicitation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import json

import pytest
from mcp import StdioServerParameters, stdio_client
from mcp.types import ElicitResult

from strands import Agent
from strands.tools.mcp import MCPClient


@pytest.fixture
def callback():
async def callback_(_, params):
return ElicitResult(action="accept", content={"message": params.message})

return callback_


@pytest.fixture
def client(callback):
return MCPClient(
lambda: stdio_client(
StdioServerParameters(command="python", args=["tests_integ/mcp/elicitation_server.py"]),
),
elicitation_callback=callback,
)


def test_mcp_elicitation(client):
with client:
tools = client.list_tools_sync()
agent = Agent(tools=tools)

agent("Can you get approval")

tool_result = agent.messages[-2]

tru_result = json.loads(tool_result["content"][0]["toolResult"]["content"][0]["text"])
exp_result = {"meta": None, "action": "accept", "content": {"message": "Do you approve"}}
assert tru_result == exp_result