From 92fb71df190a3683cf1b787f1c98567fa238b670 Mon Sep 17 00:00:00 2001 From: Anirudh Konduru Date: Fri, 14 Nov 2025 09:51:47 -0800 Subject: [PATCH] feat: allow setting a timeout when creating MCPAgentTool --- src/strands/tools/mcp/mcp_agent_tool.py | 12 +++++++- .../strands/tools/mcp/test_mcp_agent_tool.py | 29 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index af0c069a1..bedd93f24 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -6,6 +6,7 @@ """ import logging +from datetime import timedelta from typing import TYPE_CHECKING, Any from mcp.types import Tool as MCPTool @@ -28,7 +29,13 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: + def __init__( + self, + mcp_tool: MCPTool, + mcp_client: "MCPClient", + name_override: str | None = None, + timeout: timedelta | None = None, + ) -> None: """Initialize a new MCPAgentTool instance. Args: @@ -36,12 +43,14 @@ def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: st mcp_client: The MCP server connection to use for tool invocation name_override: Optional name to use for the agent tool (for disambiguation) If None, uses the original MCP tool name + timeout: Optional timeout duration for tool execution """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client self._agent_tool_name = name_override or mcp_tool.name + self.timeout = timeout @property def tool_name(self) -> str: @@ -105,5 +114,6 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw tool_use_id=tool_use["toolUseId"], name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], + read_timeout_seconds=self.timeout, ) yield ToolResultEvent(result) diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 442a9919b..81a2d9afb 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -1,3 +1,4 @@ +from datetime import timedelta from unittest.mock import MagicMock import pytest @@ -88,5 +89,31 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( - tool_use_id="test-123", name="test_tool", arguments={"param": "value"} + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None + ) + + +def test_timeout_initialization(mock_mcp_tool, mock_mcp_client): + timeout = timedelta(seconds=30) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + assert agent_tool.timeout == timeout + + +def test_timeout_default_none(mock_mcp_tool, mock_mcp_client): + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) + assert agent_tool.timeout is None + + +@pytest.mark.asyncio +async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist): + timeout = timedelta(seconds=45) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}} + + tru_events = await alist(agent_tool.stream(tool_use, {})) + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] + + assert tru_events == exp_events + mock_mcp_client.call_tool_async.assert_called_once_with( + tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout )