diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 121e1d9b867..a4900071db9 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -22,6 +22,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; @@ -110,4 +111,10 @@ public String call(String functionInput) { .block(); } + @Override + public String call(String toolArguments, ToolContext toolContext) { + // ToolContext is not supported by the MCP tools + return this.call(toolArguments); + } + } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index ffbe09303d6..80cc6f8d718 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -23,6 +23,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; @@ -111,4 +112,10 @@ public String call(String functionInput) { return ModelOptionsUtils.toJsonString(response.content()); } + @Override + public String call(String toolArguments, ToolContext toolContext) { + // ToolContext is not supported by the MCP tools + return this.call(toolArguments); + } + } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java index 7131385258b..70f2da83f02 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java @@ -16,6 +16,8 @@ package org.springframework.ai.mcp; +import java.util.Map; + import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -25,6 +27,8 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.model.ToolContext; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -71,4 +75,20 @@ void callShouldHandleJsonInputAndOutput() { assertThat(response).isNotNull(); } + @Test + void callShoulIngroeToolContext() { + // Arrange + when(tool.name()).thenReturn("testTool"); + CallToolResult callResult = mock(CallToolResult.class); + when(mcpClient.callTool(any(CallToolRequest.class))).thenReturn(callResult); + + SyncMcpToolCallback callback = new SyncMcpToolCallback(mcpClient, tool); + + // Act + String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar"))); + + // Assert + assertThat(response).isNotNull(); + } + }