Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/strands/tools/mcp/mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import logging
from datetime import timedelta
from typing import TYPE_CHECKING, Any

from mcp.types import Tool as MCPTool
Expand All @@ -28,20 +29,28 @@ class MCPAgentTool(AgentTool):
seamlessly within the agent framework.
"""

def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None:
def __init__(
self,
mcp_tool: MCPTool,
mcp_client: "MCPClient",
name_override: str | None = None,
timeout: timedelta | None = None,
) -> None:
"""Initialize a new MCPAgentTool instance.
Args:
mcp_tool: The MCP tool to adapt
mcp_client: The MCP server connection to use for tool invocation
name_override: Optional name to use for the agent tool (for disambiguation)
If None, uses the original MCP tool name
timeout: Optional timeout duration for tool execution
"""
super().__init__()
logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
self._agent_tool_name = name_override or mcp_tool.name
self.timeout = timeout

@property
def tool_name(self) -> str:
Expand Down Expand Up @@ -105,5 +114,6 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
tool_use_id=tool_use["toolUseId"],
name=self.mcp_tool.name, # Use original MCP name for server communication
arguments=tool_use["input"],
read_timeout_seconds=self.timeout,
)
yield ToolResultEvent(result)
29 changes: 28 additions & 1 deletion tests/strands/tools/mcp/test_mcp_agent_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -88,5 +89,31 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist):

assert tru_events == exp_events
mock_mcp_client.call_tool_async.assert_called_once_with(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None
)


def test_timeout_initialization(mock_mcp_tool, mock_mcp_client):
timeout = timedelta(seconds=30)
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
assert agent_tool.timeout == timeout


def test_timeout_default_none(mock_mcp_tool, mock_mcp_client):
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client)
assert agent_tool.timeout is None


@pytest.mark.asyncio
async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist):
timeout = timedelta(seconds=45)
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}}

tru_events = await alist(agent_tool.stream(tool_use, {}))
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]

assert tru_events == exp_events
mock_mcp_client.call_tool_async.assert_called_once_with(
tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout
)
Loading