Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@
matchIfMissing = true)
public class McpClientAutoConfiguration {

/**
* Create a dynamic client name based on the client name and the name of the server
* connection.
* @param clientName the client name as defined by the configuration
* @param serverConnectionName the name of the server connection being used by the
* client
* @return the connected client name
*/
private String connectedClientName(String clientName, String serverConnectionName) {
return clientName + " - " + serverConnectionName;
}

/**
* Creates a list of {@link McpSyncClient} instances based on the available
* transports.
Expand Down Expand Up @@ -144,7 +156,8 @@ public List<McpSyncClient> mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC
if (!CollectionUtils.isEmpty(namedTransports)) {
for (NamedClientMcpTransport namedTransport : namedTransports) {

McpSchema.Implementation clientInfo = new McpSchema.Implementation(commonProperties.getName(),
McpSchema.Implementation clientInfo = new McpSchema.Implementation(
this.connectedClientName(commonProperties.getName(), namedTransport.name()),
commonProperties.getVersion());

McpClient.SyncSpec syncSpec = McpClient.sync(namedTransport.transport())
Expand Down Expand Up @@ -256,7 +269,8 @@ public List<McpAsyncClient> mcpAsyncClients(McpAsyncClientConfigurer mcpSyncClie
if (!CollectionUtils.isEmpty(namedTransports)) {
for (NamedClientMcpTransport namedTransport : namedTransports) {

McpSchema.Implementation clientInfo = new McpSchema.Implementation(commonProperties.getName(),
McpSchema.Implementation clientInfo = new McpSchema.Implementation(
this.connectedClientName(commonProperties.getName(), namedTransport.name()),
commonProperties.getVersion());

McpClient.AsyncSpec syncSpec = McpClient.async(namedTransport.transport())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

/**
* A named MCP client transport. Usually created by the transport auto-configurations, but
* you can also create them manually. Expose the list castom NamedClientMcpTransport
* as @Bean.
* you can also create them manually.
*
* @param name the name of the transport. Usually the name of the server connection.
* @param transport the MCP client transport.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.autoconfigure.mcp.client.properties;

import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -30,7 +29,6 @@

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;

/**
* Configuration properties for the Model Context Protocol (MCP) stdio client.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mcp;

import java.util.Map;
import java.util.UUID;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
Expand Down Expand Up @@ -85,7 +86,7 @@ public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) {
@Override
public ToolDefinition getToolDefinition() {
return ToolDefinition.builder()
.name(this.tool.name())
.name(this.asyncMcpClient.getClientInfo().name() + "-" + this.tool.name())
.description(this.tool.description())
.inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema()))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mcp;

import java.util.Map;
import java.util.UUID;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
Expand Down Expand Up @@ -70,6 +71,7 @@ public class SyncMcpToolCallback implements ToolCallback {
public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {
this.mcpClient = mcpClient;
this.tool = tool;

}

/**
Expand All @@ -86,7 +88,7 @@ public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {
@Override
public ToolDefinition getToolDefinition() {
return ToolDefinition.builder()
.name(this.tool.name())
.name(mcpClient.getClientInfo().name() + "-" + this.tool.name())
.description(this.tool.description())
.inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema()))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.mockito.junit.jupiter.MockitoExtension;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema.Implementation;
import io.modelcontextprotocol.spec.McpSchema.ListToolsResult;
import io.modelcontextprotocol.spec.McpSchema.Tool;

Expand All @@ -40,23 +41,24 @@ class SyncMcpToolCallbackProviderTests {

@Test
void getToolCallbacksShouldReturnEmptyArrayWhenNoTools() {
// Arrange

ListToolsResult listToolsResult = mock(ListToolsResult.class);
when(listToolsResult.tools()).thenReturn(List.of());
when(mcpClient.listTools()).thenReturn(listToolsResult);

SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient);

// Act
var callbacks = provider.getToolCallbacks();

// Assert
assertThat(callbacks).isEmpty();
}

@Test
void getToolCallbacksShouldReturnCallbacksForEachTool() {
// Arrange

var clientInfo = new Implementation("testClient", "1.0.0");
when(mcpClient.getClientInfo()).thenReturn(clientInfo);

Tool tool1 = mock(Tool.class);
when(tool1.name()).thenReturn("tool1");

Expand All @@ -69,16 +71,16 @@ void getToolCallbacksShouldReturnCallbacksForEachTool() {

SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient);

// Act
var callbacks = provider.getToolCallbacks();

// Assert
assertThat(callbacks).hasSize(2);
}

@Test
void getToolCallbacksShouldThrowExceptionForDuplicateToolNames() {
// Arrange
var clientInfo = new Implementation("testClient", "1.0.0");
when(mcpClient.getClientInfo()).thenReturn(clientInfo);

Tool tool1 = mock(Tool.class);
when(tool1.name()).thenReturn("sameName");

Expand All @@ -91,9 +93,40 @@ void getToolCallbacksShouldThrowExceptionForDuplicateToolNames() {

SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient);

// Act & Assert
assertThatThrownBy(() -> provider.getToolCallbacks()).isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Multiple tools with the same name");
}

@Test
void getSameNameToolsButDifferntClientInfoNamesShouldProduceDifferentToolCallbackNames() {

Tool tool1 = mock(Tool.class);
when(tool1.name()).thenReturn("sameName");

Tool tool2 = mock(Tool.class);
when(tool2.name()).thenReturn("sameName");

McpSyncClient mcpClient1 = mock(McpSyncClient.class);
ListToolsResult listToolsResult1 = mock(ListToolsResult.class);
when(listToolsResult1.tools()).thenReturn(List.of(tool1));
when(mcpClient1.listTools()).thenReturn(listToolsResult1);

var clientInfo1 = new Implementation("testClient1", "1.0.0");
when(mcpClient1.getClientInfo()).thenReturn(clientInfo1);

McpSyncClient mcpClient2 = mock(McpSyncClient.class);
ListToolsResult listToolsResult2 = mock(ListToolsResult.class);
when(listToolsResult2.tools()).thenReturn(List.of(tool2));
when(mcpClient2.listTools()).thenReturn(listToolsResult2);

var clientInfo2 = new Implementation("testClient2", "1.0.0");
when(mcpClient2.getClientInfo()).thenReturn(clientInfo2);

SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient1, mcpClient2);

var callbacks = provider.getToolCallbacks();

assertThat(callbacks).hasSize(2);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.Implementation;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -45,30 +46,31 @@ class SyncMcpToolCallbackTests {

@Test
void getToolDefinitionShouldReturnCorrectDefinition() {
// Arrange

var clientInfo = new Implementation("testClient", "1.0.0");
when(mcpClient.getClientInfo()).thenReturn(clientInfo);
when(tool.name()).thenReturn("testTool");
when(tool.description()).thenReturn("Test tool description");

SyncMcpToolCallback callback = new SyncMcpToolCallback(mcpClient, tool);

// Act
var toolDefinition = callback.getToolDefinition();

// Assert
assertThat(toolDefinition.name()).isEqualTo("testTool");
assertThat(toolDefinition.name()).isEqualTo(clientInfo.name() + "-testTool");
assertThat(toolDefinition.description()).isEqualTo("Test tool description");
}

@Test
void callShouldHandleJsonInputAndOutput() {
// Arrange

when(mcpClient.getClientInfo()).thenReturn(new Implementation("testClient", "1.0.0"));

when(tool.name()).thenReturn("testTool");
CallToolResult callResult = mock(CallToolResult.class);
when(mcpClient.callTool(any(CallToolRequest.class))).thenReturn(callResult);

SyncMcpToolCallback callback = new SyncMcpToolCallback(mcpClient, tool);

// Act
String response = callback.call("{\"param\":\"value\"}");

// Assert
Expand All @@ -77,17 +79,16 @@ void callShouldHandleJsonInputAndOutput() {

@Test
void callShoulIngroeToolContext() {
// Arrange
when(mcpClient.getClientInfo()).thenReturn(new Implementation("testClient", "1.0.0"));

when(tool.name()).thenReturn("testTool");
CallToolResult callResult = mock(CallToolResult.class);
when(mcpClient.callTool(any(CallToolRequest.class))).thenReturn(callResult);

SyncMcpToolCallback callback = new SyncMcpToolCallback(mcpClient, tool);

// Act
String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar")));

// Assert
assertThat(response).isNotNull();
}

Expand Down
Loading